# Load Dependencies



In [None]:
# Optional, Colab only: copy a cached GloVe archive (glove.840B.300d.zip) from Google Drive
# into ./.vector_cache. The VocabModel below uses pretrained_word_emb_name='840B', so having
# the archive locally presumably avoids re-downloading it.
# Skip this cell if you do not have the zip on Drive.

from google.colab import drive

drive.mount('/content/drive')
# NOTE: adjust this Drive path to wherever your copy of .vector_cache lives.
!cp -r drive/MyDrive/CompSci-590-NLP/Final\ Project/.vector_cache .
!ls .vector_cache

Mounted at /content/drive
glove.840B.300d.zip


In [None]:
# One-time setup: clone WikiSQL, install its requirements, unpack its data archive
# (presumably producing the ./data directory read by the cells below), and install graph4nlp.
# Re-run only if the WikiSQL folder is missing.

!git clone https://github.com/salesforce/WikiSQL
!pip install -r WikiSQL/requirements.txt
!tar xvjf WikiSQL/data.tar.bz2

!pip install graph4nlp-cu110

# To build graph4nlp from source instead, uncomment the lines below and
# comment out the `pip install graph4nlp-cu110` line above.
# !pip3 install PyYAML
# !pip3 install nltk
# !pip3 install scipy
# !git clone https://github.com/graph4ai/graph4nlp.git
# !graph4nlp/configure
# !python graph4nlp/setup.py install

In [1]:
# import all libraries

import json
import tqdm
import numpy as np
import os
from copy import deepcopy
from random import sample
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from collections import Counter

from graph4nlp.pytorch.data.data import to_batch, from_batch
from graph4nlp.pytorch.data import GraphData
from graph4nlp.pytorch.data.dataset import Text2TextDataItem as DataItem
from graph4nlp.pytorch.modules.utils.vocab_utils import VocabModel
from graph4nlp.pytorch.modules.config import get_basic_args
from graph4nlp.pytorch.models.graph2seq import Graph2Seq
from graph4nlp.pytorch.models.graph2seq_loss import Graph2SeqLoss
from graph4nlp.pytorch.modules.evaluation.base import EvaluationMetricBase

from graph4nlp.pytorch.modules.utils.padding_utils import pad_2d_vals_no_size

from graph4nlp.pytorch.modules.utils.copy_utils import prepare_ext_vocab

from google.colab import files



# Debugging aid: make CUDA kernel launches synchronous so errors surface at the
# offending call (slows training down).
os.environ.update(CUDA_LAUNCH_BLOCKING="1")

# Width of the pretrained word vectors; passed as word_emb_size when building the vocab.
embed_dim = 300

Using backend: pytorch


In [None]:
# Show the installed PyTorch build and the CUDA version it was built with.
print(torch.__version__, torch.version.cuda, sep='\n')

1.10.0+cu111
11.1


# Data Processing

## Load tables

Load the tables for later lookups when parsing queries. Only needs to be run once.

In [2]:
def _load_tables(path):
    """Read a WikiSQL ``*.tables.jsonl`` file into a dict mapping table id -> table json.

    Each non-blank line is parsed once; the file is closed even if parsing fails.
    """
    tables = {}
    with open(path) as fh:
        for line in fh:
            if not line.strip():  # tolerate a trailing blank line
                continue
            table = json.loads(line)
            tables[table['id']] = table
    return tables


# Store each split's tables in a dict: id --> json
train_tables = _load_tables('data/train.tables.jsonl')
dev_tables = _load_tables('data/dev.tables.jsonl')
test_tables = _load_tables('data/test.tables.jsonl')


## Parse Query Graph

This section defines the functions that parse queries into graphs (and into flat query strings). Only needs to be run once.

In [3]:
# build the graph given a query

def _parse_query_graph(q, table):
    g = GraphData()
    cnt = 0

    # node for select
    g.add_nodes(1)
    g.node_attributes[cnt]['token'] = 'SELECT'
    cnt += 1

    # node for agg
    if q['agg']:
        g.add_nodes(1)
        g.node_attributes[cnt]['token'] = agg_ops[q['agg']]
        g.add_edge(cnt - 1, cnt)
        cnt += 1
        

    # selected column
    g.add_nodes(1)
    g.node_attributes[cnt]['token'] = table['header'][q['sel']]
    g.add_edge(cnt - 1, cnt)
    cnt += 1


    # FROM node
    g.add_nodes(1)
    g.node_attributes[cnt]['token'] = 'FROM'
    g.add_edge(cnt - 1, cnt)
    cnt += 1

    # table node
    g.add_nodes(1)
    g.node_attributes[cnt]['token'] = 'table'
    g.add_edge(cnt - 1, cnt)
    cnt += 1

    # WHERE node
    g.add_nodes(1)
    g.node_attributes[cnt]['token'] = 'WHERE'
    g.add_edge(cnt - 1, cnt)
    cnt += 1

    # need 'and' node or not
    if len(q['conds']) > 1:
        and_idx = cnt

        g.add_nodes(1)
        g.node_attributes[cnt]['token'] = 'AND'
        g.add_edge(cnt - 1, cnt)
        cnt += 1

        for cond in q['conds']:
            g.add_nodes(1)
            g.node_attributes[cnt]['token'] = table['header'][cond[0]]
            g.add_edge(and_idx, cnt)
            cnt += 1

            g.add_nodes(1)
            g.node_attributes[cnt]['token'] = cond_ops[cond[1]]
            g.add_edge(cnt - 1, cnt)
            cnt += 1

            g.add_nodes(1)
            g.node_attributes[cnt]['token'] = str(cond[2])
            g.add_edge(cnt - 1, cnt)
            cnt += 1
    elif len(q['conds']) == 1:
        # order: col, op, const
        g.add_nodes(1)
        g.node_attributes[cnt]['token'] = table['header'][q['conds'][0][0]]
        g.add_edge(cnt - 1, cnt)
        cnt += 1

        g.add_nodes(1)
        g.node_attributes[cnt]['token'] = cond_ops[q['conds'][0][1]]
        g.add_edge(cnt - 1, cnt)
        cnt += 1

        g.add_nodes(1)
        g.node_attributes[cnt]['token'] = str(q['conds'][0][2])
        g.add_edge(cnt - 1, cnt)
        cnt += 1

    g.node_features['token_id'] = torch.zeros(cnt, 1, dtype=torch.long)

    return g



def _parse_query_json(q, table):
    query = 'SELECT {agg} {sel} FROM table'.format(
                agg=agg_ops[q['agg']],
                sel=table['header'][q['sel']],
            ) if q['agg'] else \
            'SELECT {sel} FROM table'.format(
                sel=table['header'][q['sel']],
            )
    if q['conds']:
        query += ' WHERE ' + ' AND '.join(['{} {} {}'.format(table['header'][i], cond_ops[o], v) for i, o, v in q['conds']])
    return query


## Vocab and Graph for Train & Dev

This section builds the vocabulary and parses the data into the train, dev, and test sets.

In [4]:
SHARE_VOCAB = False

# SQL keyword tables indexed by the integer codes in WikiSQL's ``sql`` dicts
# (read as globals by _parse_query_graph / _parse_query_json).
agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
cond_ops = ['=', '>', '<', 'OP']


def _build_split(path, tables, strip_question_punct=False):
    """Parse a WikiSQL question file into DataItems with an attached query graph.

    Each line is a JSON record with ``table_id``, ``question`` and ``sql``. The
    rendered SQL string is the DataItem's input text and the question its output text.

    Args:
        path: path to a ``*.jsonl`` split.
        tables: dict of table id -> table json for that split.
        strip_question_punct: drop one trailing '?', '!' or '.' from each question.

    Returns:
        list of DataItem, each with ``.graph`` set and the rendered query stored in
        ``graph.graph_attributes['query']``.
    """
    items = []
    with open(path) as fh:
        for line in fh:
            qjson = json.loads(line)
            table = tables[qjson['table_id']]
            question = qjson['question']
            # The `question and` guard keeps an empty question from raising IndexError.
            if strip_question_punct and question and question[-1] in ['?', '!', '.']:
                question = question[:-1]
            query = _parse_query_json(qjson['sql'], table)
            graph = _parse_query_graph(qjson['sql'], table)
            graph.graph_attributes['query'] = query
            # Every split (train, dev and test) uses the same SHARE_VOCAB setting.
            item = DataItem(query, question, None, share_vocab=SHARE_VOCAB)
            item.graph = graph
            items.append(item)
    return items


def _fill_token_ids(items, word_vocab):
    """Store each graph node's vocab id in ``graph.node_features['token_id'][i][0]``."""
    for item in items:
        graph = item.graph
        for i in range(len(graph.node_attributes)):
            graph.node_features['token_id'][i][0] = word_vocab.getIndex(graph.node_attributes[i]['token'])


# Build the training dataset; the vocab is built from the training split only.
# NOTE(review): only training questions have trailing punctuation stripped, while
# dev/test targets keep it. Confirm this asymmetry is intended for evaluation.
train_data_set = _build_split('data/train.jsonl', train_tables, strip_question_punct=True)

vocab_model = VocabModel(data_set=train_data_set, 
                         tokenizer=None, 
                         lower_case=False, 
                         max_word_vocab_size=None, 
                         min_word_vocab_freq=10, 
                         pretrained_word_emb_name='840B', 
                         word_emb_size=embed_dim, 
                         share_vocab=SHARE_VOCAB)

# Fill the initial token id of each node (input-side vocab) for every split.
_fill_token_ids(train_data_set, vocab_model.in_word_vocab)

dev_dataset = _build_split('data/dev.jsonl', dev_tables)
_fill_token_ids(dev_dataset, vocab_model.in_word_vocab)

test_dataset = _build_split('data/test.jsonl', test_tables)
_fill_token_ids(test_dataset, vocab_model.in_word_vocab)




Building vocabs...




Pretrained word embeddings hit ratio: 0.9030159668835009
Using pretrained word embeddings
Pretrained word embeddings hit ratio: 0.9116675839295542
Using pretrained word embeddings
[ Using separate word vocabs for input & output text ]
[ Initialized input word embeddings: (3382, 300) ]
[ Initialized output word embeddings: (3634, 300) ]


In [5]:
# dataset class
class SQL2TextDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each query-graph DataItem with its target token ids.

    ``__getitem__`` returns ``(item, target_ids)``. ``target_ids`` is
    ``vocab_model.out_word_vocab.to_index_sequence(item.output_text)`` with the
    output vocab's EOS id appended.
    """

    def __init__(self, glist, vocab_model):
        self.glist = glist
        self.vocab_model = vocab_model

    def __len__(self):
        return len(self.glist)

    def __getitem__(self, i):
        item = self.glist[i]
        out_vocab = self.vocab_model.out_word_vocab
        # Use the vocab's own EOS id instead of a hard-coded constant so the two cannot drift apart.
        target_ids = out_vocab.to_index_sequence(item.output_text) + [out_vocab.EOS]
        return item, target_ids

    @staticmethod
    def collate_fn(data_list):
        """Batch ``(item, target_ids)`` pairs for a DataLoader.

        Returns a dict with:
            graph: the item graphs batched with ``to_batch``.
            output_idx: LongTensor of target ids, padded with ``pad_2d_vals_no_size``.
            output_str: stripped target strings.
            query: rendered SQL strings from ``graph.graph_attributes['query']``.
        """
        items = [d[0] for d in data_list]
        graph_list = [item.graph for item in items]
        query_list = [g.graph_attributes['query'] for g in graph_list]
        graph_data = to_batch(graph_list)

        np_output = [np.array(d[1], dtype=np.int32) for d in data_list]
        # str is immutable, so no copy is needed.
        output_str = [item.output_text.strip() for item in items]
        output_pad = pad_2d_vals_no_size(np_output)

        output_idx = torch.from_numpy(output_pad).long()
        return {'graph': graph_data, 'output_idx': output_idx, 'output_str': output_str, 'query': query_list}

# Graph2Seq Model Setting

This section defines the settings for the Graph2Seq model. Rerun it whenever you change a setting, and then rerun the SQL2Text Model block as well.


In [6]:
# Graph2Seq hyper-parameters. SQL2TextModel passes this dict to Graph2Seq.from_args and
# reads decoder_args.rnn_decoder_share.use_copy / use_coverage from it.
opt = {
  "graph_construction_args": {
    "graph_construction_share": {
      "graph_type": "node_emb",
      "root_dir": None,
      "topology_subdir": "NodeEmbGraph",
      # same flag used when building the DataItems / VocabModel above
      "share_vocab": SHARE_VOCAB
    },
    "graph_construction_private": {
      "lower_case": False
    },
    "node_embedding": {
      # NOTE(review): presumably the word-embedding width; should equal embed_dim (300). Confirm.
      "input_size": 300,
      "hidden_size": 300,
      "word_dropout": 0,
      "rnn_dropout": 0.3,
      "fix_bert_emb": False,
      "fix_word_emb": False,
      "embedding_style": {
        "single_token_item": True,
        "emb_strategy": "w2v",
        "num_rnn_layers": 1,
        "bert_model_name": None,
        "bert_lower_case": None
      },
      # NOTE(review): similarity/neighbourhood options; confirm they are read for the
      # "node_emb" graph construction used here.
      "sim_metric_type": "weighted_cosine",
      "num_heads": 1,
      "top_k_neigh": 8,
    #   "epsilon_neigh": None,
    #   "smoothness_ratio": 0.1,
    #   "connectivity_ratio": 0.05,
    #   "sparsity_ratio": 0.1
    }
  },
  "graph_embedding_args": {
    "graph_embedding_share": {
      "num_layers": 3,
      "input_size": 300,
      "hidden_size": 300,
      "output_size": 300,
      "direction_option": "bi_sep",
      "feat_drop": 0.3
    },
    # NOTE(review): heads / attn_drop / negative_slope look like GAT options while
    # graph_embedding_name below is "graphsage". Confirm whether they are used.
    "graph_embedding_private": {
      "heads": [
        10, 10, 10
      ],
      "attn_drop": 0.2,
      "negative_slope": 0.2,
      "residual": False,
      "activation": "relu",
      "allow_zero_in_degree": False
    }
  },
  "decoder_args": {
    "rnn_decoder_share": {
      "rnn_type": "lstm",
      "input_size": 300,
      "hidden_size": 512,
      "rnn_emb_input_size": 300,
      # read by SQL2TextModel as self.enable_copy
      "use_copy": False,
      # read by SQL2TextModel and forwarded to Graph2SeqLoss (coverage_weight=0.3)
      "use_coverage": True,
      "graph_pooling_strategy": "max",
      "attention_type": "uniform",
      "fuse_strategy": "average",
      "dropout": 0.4
    },
    "rnn_decoder_private": {
      "max_decoder_step": 50,
      "node_type_num": None,
      "tgt_emb_as_output_layer": False,
      "teacher_forcing_rate": 1
    }
  },
  "graph_construction_name": "node_emb",
  "graph_embedding_name": "graphsage",
  "decoder_name": "stdrnn"
}

# Reconstruction Seq2Seq Model

Defines the reconstruction model. Only needs to be run once.

In [7]:
# Input/output vocabularies of the Graph2Seq model, reused as the source/target
# fields of the reconstruction model below.
SOURCE = vocab_model.in_word_vocab
TARGET = vocab_model.out_word_vocab
SRC_PAD_IDX = SOURCE.PAD
TRG_PAD_IDX = TARGET.PAD

class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Self-attention followed by a position-wise feed-forward network. Each sub-layer
    is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, 
                 hid_dim, 
                 n_heads, 
                 pf_dim,  
                 dropout, 
                 device):
        super().__init__()
        # Keep this creation order: it fixes parameter-init RNG consumption and state_dict layout.
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        """
        Args:
            src: [batch size, src len, hid dim]
            src_mask: [batch size, 1, 1, src len]
        Returns:
            [batch size, src len, hid dim]
        """
        attended, _ = self.self_attention(src, src, src, src_mask)
        normed = self.self_attn_layer_norm(src + self.dropout(attended))

        fed = self.positionwise_feedforward(normed)
        return self.ff_layer_norm(normed + self.dropout(fed))


class DecoderLayer(nn.Module):
    """One post-norm Transformer decoder layer.

    Masked self-attention, then encoder-decoder attention, then a position-wise
    feed-forward network. Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, 
                 hid_dim, 
                 n_heads, 
                 pf_dim, 
                 dropout, 
                 device):
        super().__init__()
        # Keep this creation order: it fixes parameter-init RNG consumption and state_dict layout.
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """
        Args:
            trg: [batch size, trg len, hid dim]
            enc_src: [batch size, src len, hid dim]
            trg_mask: [batch size, 1, trg len, trg len]
            src_mask: [batch size, 1, 1, src len]
        Returns:
            (hidden [batch size, trg len, hid dim],
             encoder attention [batch size, n heads, trg len, src len])
        """
        self_attended, _ = self.self_attention(trg, trg, trg, trg_mask)
        hidden = self.self_attn_layer_norm(trg + self.dropout(self_attended))

        cross_attended, attention = self.encoder_attention(hidden, enc_src, enc_src, src_mask)
        hidden = self.enc_attn_layer_norm(hidden + self.dropout(cross_attended))

        fed = self.positionwise_feedforward(hidden)
        hidden = self.ff_layer_norm(hidden + self.dropout(fed))

        return hidden, attention



class MultiHeadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        hid_dim: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
        device: accepted for interface compatibility. The scaling factor is a plain
            Python float, so the layer runs on whatever device its parameters and
            inputs live on.
    """

    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()

        assert hid_dim % n_heads == 0, f"hid_dim ({hid_dim}) must be divisible by n_heads ({n_heads})"

        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads

        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)

        self.fc_o = nn.Linear(hid_dim, hid_dim)

        self.dropout = nn.Dropout(dropout)

        # A float instead of a tensor pinned to `device`: an unregistered tensor is not
        # moved by .to()/.cuda() and would then mismatch the activations' device.
        self.scale = math.sqrt(self.head_dim)

    def _split_heads(self, x, batch_size):
        """[batch size, len, hid dim] -> [batch size, n heads, len, head dim]."""
        return x.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask = None):
        """
        Args:
            query: [batch size, query len, hid dim]
            key: [batch size, key len, hid dim]
            value: [batch size, value len, hid dim]
            mask: broadcastable to [batch size, n heads, query len, key len];
                positions where ``mask == 0`` receive (effectively) zero attention.
        Returns:
            (x [batch size, query len, hid dim],
             attention [batch size, n heads, query len, key len])
        """
        batch_size = query.shape[0]

        Q = self._split_heads(self.fc_q(query), batch_size)
        K = self._split_heads(self.fc_k(key), batch_size)
        V = self._split_heads(self.fc_v(value), batch_size)

        # energy = [batch size, n heads, query len, key len]
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale

        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)

        attention = torch.softmax(energy, dim = -1)

        # x = [batch size, n heads, query len, head dim] -> [batch size, query len, hid dim]
        x = torch.matmul(self.dropout(attention), V)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.hid_dim)

        return self.fc_o(x), attention
    

class PositionwiseFeedforwardLayer(nn.Module):
    """Transformer feed-forward sub-layer: Linear -> ReLU -> Dropout -> Linear.

    Maps [batch size, seq len, hid dim] to the same shape through an inner width of ``pf_dim``.
    """

    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        # Expand to pf_dim, then project back. Creation order fixes parameter init.
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = torch.relu(self.fc_1(x))        # [batch size, seq len, pf dim]
        return self.fc_2(self.dropout(hidden))   # [batch size, seq len, hid dim]



class AttentionPointerDecoderV3(nn.Module):
    """Transformer decoder with an optional attention-based copy mechanism.

    Token embeddings are initialised from ``target_field.embeddings`` and stay trainable.
    When ``copy`` is on, the last layer's per-head encoder-decoder attention is mixed
    by a small MLP (``fc_nhead_to_one``) into one copy score per (target position,
    source position). Each score is added onto the output logit of the target-vocab
    word that corresponds to that source word.
    """

    def __init__(self, 
                 output_dim, 
                 hid_dim, 
                 n_layers, 
                 n_heads, 
                 pf_dim, 
                 dropout, 
                 device,
                 copy=True,
                 source_field = SOURCE,
                 target_field = TARGET,
                 src_pad_idx=SRC_PAD_IDX, 
                 trg_pad_idx=TRG_PAD_IDX, 
                 max_length = 100):
        super().__init__()

        self.device = device

        self.tok_embedding = nn.Embedding.from_pretrained(torch.from_numpy(target_field.embeddings), freeze=False)

        self.pos_embedding = nn.Embedding(max_length, hid_dim)

        self.layers = nn.ModuleList([DecoderLayer(hid_dim, 
                                                  n_heads, 
                                                  pf_dim, 
                                                  dropout, 
                                                  device)
                                     for _ in range(n_layers)])

        self.fc_out = nn.Linear(hid_dim, output_dim)

        self.dropout = nn.Dropout(dropout)

        # Plain float so it follows the activations' device after .to()/.cuda().
        self.scale = math.sqrt(hid_dim)
        self.copy = copy
        self.output_dim = output_dim
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.source_field = source_field
        self.target_field = target_field
        # Mixes the n_heads attention weights of one (target, source) pair into a single copy score.
        self.fc_nhead_to_one = nn.Sequential(
          nn.Linear(n_heads, 64),
          nn.Tanh(),
          nn.Linear(64, 64),
          nn.Tanh(),
          nn.Linear(64, 1)
        )
        self.n_heads = n_heads

        # Source-vocab id -> word.
        self.tp_itos = [self.source_field.getWord(i) for i in range(self.source_field.get_vocab_size())]
        # Source-vocab id -> target-vocab id (whatever target_field.getIndex returns for that word).
        # Precomputed once instead of on every forward pass. It is non-persistent so
        # existing state_dicts/checkpoints load unchanged.
        self.register_buffer(
            'src_to_trg_idx',
            torch.tensor([self.target_field.getIndex(word) for word in self.tp_itos], dtype=torch.long),
            persistent=False,
        )

    def forward(self, trg, enc_src, trg_mask, src_mask, src):
        """
        Args:
            trg: [batch size, trg len] target token ids.
            enc_src: [batch size, src len, hid dim] encoder outputs.
            trg_mask: [batch size, 1, trg len, trg len]
            src_mask: [batch size, 1, 1, src len]
            src: [batch size, src len] source token ids (used for the copy mapping).
        Returns:
            (output [batch size, trg len, output dim],
             attention [batch size, n heads, trg len, src len])
        """
        batch_size = trg.shape[0]
        trg_len = trg.shape[1]

        pos = torch.arange(0, trg_len, device=trg.device).unsqueeze(0).repeat(batch_size, 1)

        hidden = self.dropout((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos))

        for layer in self.layers:
            hidden, attention = layer(hidden, enc_src, trg_mask, src_mask)

        output = self.fc_out(hidden)

        if self.copy:
            # [B, H, T, S] -> [B, T, S, H] so the head-mixing MLP acts on the head axis.
            alpha = self.fc_nhead_to_one(attention.permute(0, 2, 3, 1)).squeeze(-1)  # [B, T, S]
            # NOTE(review): padded source/target positions are not masked out of `alpha`.
            # Adding -inf here without a subsequent softmax would turn all-pad rows into
            # -inf/NaN logits, so masking needs a normalisation step if it is introduced.

            # Target-vocab index of every source token, repeated for each target step.
            trg_ids = self.src_to_trg_idx.to(src.device)[src]   # [B, S]
            xid = trg_ids.unsqueeze(1).repeat(1, trg_len, 1)     # [B, T, S]

            # Add each source position's copy score onto its target-vocab logit.
            output = output + torch.zeros_like(output).scatter_add(2, xid, alpha)

        return output, attention



class PretrainedReconstructorEncoder(nn.Module):
    """Transformer encoder over dense per-token input features.

    Tokens arrive as feature vectors of size ``input_dim`` and are projected to
    ``hid_dim`` with a Linear layer (not an embedding lookup). They are then scaled by
    sqrt(hid_dim), summed with learned positional embeddings (at most ``max_length``
    positions), and passed through ``n_layers`` EncoderLayers.
    ``src_field`` / ``trg_field`` are accepted for interface compatibility and are unused.
    """

    def __init__(self, 
                 input_dim, 
                 hid_dim, 
                 n_layers, 
                 n_heads, 
                 pf_dim,
                 dropout, 
                 device,
                 max_length = 100,
                 src_field=SOURCE,
                 trg_field=TARGET):
        super().__init__()

        self.device = device

        self.tok_embedding = nn.Linear(input_dim, hid_dim)

        self.pos_embedding = nn.Embedding(max_length, hid_dim)

        self.layers = nn.ModuleList([EncoderLayer(hid_dim, 
                                                  n_heads, 
                                                  pf_dim,
                                                  dropout, 
                                                  device) 
                                     for _ in range(n_layers)])

        self.dropout = nn.Dropout(dropout)

        # Plain float rather than a tensor pinned to `device`, so the module keeps
        # working after .to()/.cuda().
        self.scale = math.sqrt(hid_dim)
        self.hid_dim = hid_dim

    def forward(self, src, src_mask):
        """
        Args:
            src: [batch size, src len, input_dim] float features.
            src_mask: [batch size, 1, 1, src len]
        Returns:
            [batch size, src len, hid dim]
        """
        batch_size = src.shape[0]
        src_len = src.shape[1]

        pos = torch.arange(0, src_len, device=src.device).unsqueeze(0).repeat(batch_size, 1)

        # nn.Linear acts on the last dimension, so no flatten/reshape round-trip is needed.
        hidden = self.dropout((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos))

        for layer in self.layers:
            hidden = layer(hidden, src_mask)

        return hidden
    


class AttentionPointerReconstructorSeq2Seq(nn.Module):
    """Wire a reconstructor encoder to an attention-pointer decoder.

    ``forward(src, src_tensor, trg)``:
        src: [batch size, src len] source token ids, used for the padding mask and
            handed to the decoder for its copy mapping.
        src_tensor: encoder input features (passed to ``self.encoder`` as-is).
        trg: [batch size, trg len] target token ids.
    Returns the decoder's ``(output, attention)``.
    """

    def __init__(self, 
                 encoder, 
                 decoder, 
                 src_pad_idx, 
                 trg_pad_idx, 
                 device,
                 copy = True,
                 output_dim = TARGET.get_vocab_size(),
                 source_field = SOURCE,
                 target_field = TARGET
                ):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device
        self.copy = copy
        self.output_dim = output_dim
        self.source_field = source_field
        self.target_field = target_field

    def make_src_mask(self, src):
        """True at non-pad source positions: [batch size, src len] -> [batch size, 1, 1, src len]."""
        return (src != self.src_pad_idx)[:, None, None, :]

    def make_trg_mask(self, trg):
        """Padding mask AND causal (lower-triangular) mask: -> [batch size, 1, trg len, trg len]."""
        trg_len = trg.shape[1]
        not_pad = (trg != self.trg_pad_idx)[:, None, None, :]
        causal = torch.ones((trg_len, trg_len), device = self.device).tril().bool()
        return not_pad & causal

    def forward(self, src, src_tensor, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)

        encoded = self.encoder(src_tensor, src_mask)   # [batch size, src len, hid dim]

        output, attention = self.decoder(trg, encoded, trg_mask, src_mask, src)
        return output, attention

# SQL2Text Model

The SQL2Text Model provides the following functions:

1. train(epochs, batch_sz)
2. evaluate(batch_sz, split='dev')
3. train_with_reconstruct(epochs, batch_sz)
4. evaluate_with_reconstruct(batch_sz, split='dev')
5. translate_sample(batch_sz, sample_sz, random=False, split='test')
6. translate_to_file(filename, batch_sz, split='test')
7. translate_sample_post_copy(batch_sz, sample_sz, random=False, split='test')
8. translate_to_file_post_copy(filename, batch_sz, split='test')

Please use these functions accordingly. 

In [None]:
# This block only needs to be run once unless settings are tweaked.

# SQL keywords, operators and structural markers. Source tokens matching one of these
# (case-insensitively, via all_sql_syms) are never used to fill <unk> in post-copy.
agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
cond_ops = ['=', '>', '<', 'OP']
syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS']

all_sql_syms = [sym.lower() for sym in (*agg_ops, *cond_ops, *syms)]

# post copy using attention
def post_copy_processing(src, pred, attention):
    unk_locs = np.where(np.asarray(pred) == '<unk>')[0]
    refined_sentence = deepcopy(pred)
    exclude_idx = np.isin(np.asarray(src), np.asarray(all_sql_syms))
    excluded_src = np.asarray(src)[~exclude_idx]
    excluded_attention = attention[:, ~exclude_idx]

    for e_unk_idx in unk_locs:
        this_unk_attention = excluded_attention[e_unk_idx, :]
        best_matched_inp_idx = this_unk_attention.argmax().cpu().data.numpy()
        best_matched_inp = excluded_src[best_matched_inp_idx]
        refined_sentence[e_unk_idx] = best_matched_inp
        # set already matched to -inf
        excluded_attention[:, best_matched_inp_idx] = 0
    return ' '.join(refined_sentence)


# Convert each row of token ids to a string, stopping at the first <eos>
def wordid2str(word_ids, vocab):
    """Decode a 2-D batch of token ids into space-joined strings.

    Decoding of a row stops at its first ``vocab.EOS`` id, which is not included.

    Args:
        word_ids: 2-D tensor/array of token ids, one sequence per row.
        vocab: vocabulary exposing ``EOS`` and ``getWord(id)``.

    Returns:
        A list with one decoded string per row.
    """
    assert len(word_ids.shape) == 2, \
        f'expected a 2-D array of token ids, got shape {tuple(word_ids.shape)}'
    sentences = []
    # tolist() yields plain ints, so getWord / EOS comparison never see 0-d tensors
    for row in word_ids.tolist():
        tokens = []
        for token_id in row:
            if token_id == vocab.EOS:
                break
            tokens.append(vocab.getWord(token_id))
        sentences.append(" ".join(tokens))
    return sentences


# Convert each row of token ids to a string, keeping every token (including <eos>)
def wordid2str_all(word_ids, vocab):
    """Decode a 2-D batch of token ids into space-joined strings without truncation.

    Unlike ``wordid2str`` this keeps every position, so the number of tokens in each
    string equals the row length (post-copy relies on this to align with attention).

    Args:
        word_ids: 2-D tensor/array of token ids, one sequence per row.
        vocab: vocabulary exposing ``getWord(id)``.

    Returns:
        A list with one decoded string per row.
    """
    assert len(word_ids.shape) == 2, \
        f'expected a 2-D array of token ids, got shape {tuple(word_ids.shape)}'
    return [" ".join(vocab.getWord(token_id) for token_id in row) for row in word_ids.tolist()]



class SQL2TextModel:
    """Train, evaluate and run a Graph2Seq SQL-to-text model.

    ``self.network`` is a graph4nlp ``Graph2Seq`` model that turns a graph-encoded SQL
    query into text. ``self.reconstructor`` maps text back to SQL and is used only by
    ``train_with_reconstruct`` / ``evaluate_with_reconstruct``.

    Checkpoints are written to ``sql2text_best.pt`` (network) and ``text2sql_best.pt``
    (reconstructor) in the current working directory.
    """

    def __init__(self,
                 opt,
                 vocab_model,
                 train_dataset,
                 dev_dataset,
                 test_dataset=None):
        """Build the Graph2Seq network, its loss, and the text-to-SQL reconstructor.

        Args:
            opt: graph4nlp option dict passed to ``Graph2Seq.from_args``; the copy and
                coverage flags are read from ``opt['decoder_args']['rnn_decoder_share']``.
            vocab_model: graph4nlp vocab model; ``in_word_vocab`` encodes SQL and
                ``out_word_vocab`` encodes text.
            train_dataset, dev_dataset, test_dataset: data lists wrapped by
                ``SQL2TextDataset``.
        """
        # Replace with torch.device("cpu") to force CPU execution.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Data related
        self.train_data = train_dataset
        self.dev_data = dev_dataset
        self.test_data = test_dataset
        self.vocab_model = vocab_model

        # Graph2Seq
        self.opt = opt
        self.enable_copy = self.opt['decoder_args']['rnn_decoder_share']['use_copy']
        self.use_coverage = self.opt['decoder_args']['rnn_decoder_share']['use_coverage']
        self.network = Graph2Seq.from_args(opt=self.opt, vocab_model=vocab_model).to(self.device)
        self.loss = Graph2SeqLoss(ignore_index=self.vocab_model.out_word_vocab.PAD,
                                  use_coverage=self.use_coverage,
                                  coverage_weight=0.3)

        # Max gradient norm used by both training loops.
        self.CLIP = 5

        # Reconstructor (text -> SQL): input vocabulary is the text side, output the SQL side.
        INPUT_DIM = self.vocab_model.out_word_vocab.get_vocab_size()
        OUTPUT_DIM = self.vocab_model.in_word_vocab.get_vocab_size()
        HID_DIM = self.vocab_model.in_word_vocab.embeddings.shape[1]
        ENC_LAYERS = 3
        DEC_LAYERS = 3
        ENC_HEADS = 10
        DEC_HEADS = 10
        ENC_PF_DIM = 512
        DEC_PF_DIM = 512
        ENC_DROPOUT = 0.1
        DEC_DROPOUT = 0.1
        SRC_PAD_IDX = self.vocab_model.in_word_vocab.PAD
        TRG_PAD_IDX = self.vocab_model.out_word_vocab.PAD

        self.enc = PretrainedReconstructorEncoder(INPUT_DIM,
                                                  HID_DIM,
                                                  ENC_LAYERS,
                                                  ENC_HEADS,
                                                  ENC_PF_DIM,
                                                  ENC_DROPOUT,
                                                  self.device,
                                                  max_length = 100,
                                                  src_field=self.vocab_model.out_word_vocab,
                                                  trg_field=self.vocab_model.in_word_vocab)

        self.dec = AttentionPointerDecoderV3(OUTPUT_DIM,
                                             HID_DIM,
                                             DEC_LAYERS,
                                             DEC_HEADS,
                                             DEC_PF_DIM,
                                             DEC_DROPOUT,
                                             self.device,
                                             copy=True,
                                             source_field = self.vocab_model.out_word_vocab,
                                             target_field = self.vocab_model.in_word_vocab,
                                             src_pad_idx=TRG_PAD_IDX,
                                             trg_pad_idx=SRC_PAD_IDX,
                                             max_length = 100)

        self.reconstructor = AttentionPointerReconstructorSeq2Seq(self.enc, self.dec,
                                                                  TRG_PAD_IDX, SRC_PAD_IDX,
                                                                  self.device,
                                                                  copy = True, output_dim=OUTPUT_DIM,
                                                                  source_field = self.vocab_model.out_word_vocab,
                                                                  target_field = self.vocab_model.in_word_vocab).to(self.device)

        # SQL targets are padded with the SQL vocab's PAD id (the same id the reconstructor
        # is given as trg_pad_idx); ignore exactly that id instead of a hard-coded 1.
        self.recon_loss = nn.CrossEntropyLoss(ignore_index=SRC_PAD_IDX)

    # ------------------------------------------------------------------ helpers

    def _make_loader(self, datalist, batch_sz, shuffle):
        """Wrap a raw data list in a ``SQL2TextDataset`` DataLoader."""
        dataset = SQL2TextDataset(datalist, self.vocab_model)
        return DataLoader(dataset, batch_size=batch_sz, shuffle=shuffle, collate_fn=SQL2TextDataset.collate_fn)

    def _split_data(self, split):
        """Return the dev data for ``split == 'dev'``, otherwise the test data."""
        return self.dev_data if split == 'dev' else self.test_data

    def _sample_loader(self, batch_sz, sample_sz, random, split):
        """Validate arguments and build a loader over ``sample_sz`` items of ``split``."""
        assert batch_sz <= sample_sz, 'Batch size must be smaller or equal to the sample size.'
        assert split == 'test' or split == 'dev', 'Sample must come from dev or test set.'
        datalist = self.test_data if split == 'test' else self.dev_data
        assert len(datalist) >= sample_sz, f'Sample size exceed the number of data in {split} set.'
        datalist = sample(datalist, sample_sz) if random else datalist[:sample_sz]
        return self._make_loader(datalist, batch_sz, shuffle=random)

    def _full_loader(self, batch_sz, split):
        """Validate ``split`` and build an unshuffled loader over the whole split."""
        assert split == 'test' or split == 'dev', 'Sample must come from dev or test set.'
        datalist = self.test_data if split == 'test' else self.dev_data
        return self._make_loader(datalist, batch_sz, shuffle=False)

    def _oov_dict(self, x):
        """Extended (copy) vocabulary for an inference batch, or None when copy is off."""
        if not self.enable_copy:
            return None
        return prepare_ext_vocab(batch_graph=x, vocab=self.vocab_model, device=self.device)

    def _encode_sql(self, sql):
        """Encode raw SQL strings as a padded LongTensor of SQL-vocab ids wrapped in SOS/EOS."""
        vocab = self.vocab_model.in_word_vocab
        ids = [[vocab.SOS] + vocab.to_index_sequence(s) + [vocab.EOS] for s in sql]
        # Pass the ragged list directly: np.array() on ragged lists errors on recent numpy.
        padded = pad_2d_vals_no_size(ids, dtype=np.int32)
        return torch.as_tensor(padded, dtype=torch.long, device=self.device)

    @staticmethod
    def _truncate_at_eos(sentence, eos_token='</s>'):
        """Cut ``sentence`` just before the first ``eos_token``; return it unchanged if absent."""
        cut = sentence.find(eos_token)
        # str.find returns -1 (truthy) when missing, so test against -1 explicitly.
        return sentence if cut == -1 else sentence[:cut].rstrip()

    def _beam_translate_batch(self, d):
        """Beam-search (beam 4, top-1) one batch; return (sql, reference, prediction) tuples."""
        batch_graph, y_str, sql = d['graph'], d['output_str'], d['query']
        x = batch_graph.to(self.device)
        oov_dict = self._oov_dict(x)
        # With copy on, predicted ids index the extended vocabulary.
        ref_dict = oov_dict if self.enable_copy else self.vocab_model.out_word_vocab
        pred = self.network.translate(batch_graph=x, oov_dict=oov_dict, beam_size=4, topk=1)
        pred_str = wordid2str(pred[:, 0, :].detach().cpu(), ref_dict)
        return list(zip(sql, y_str, pred_str))

    def _post_copy_batch(self, d):
        """Decode one batch, post-copy ``<unk>`` via attention; return (sql, reference, prediction)."""
        batch_graph, y_idx, y_str, sql = d['graph'], d['output_idx'], d['output_str'], d['query']
        x, y = batch_graph.to(self.device), y_idx.to(self.device)

        # NOTE(review): the gold text y is fed to the decoder here (it may be teacher-forced),
        # so these are not necessarily free-running predictions -- confirm.
        prob, enc_attn_weights, _ = self.network(x, y, oov_dict=self._oov_dict(x))
        pred_ids = prob.argmax(dim=-1)
        pred_text = wordid2str_all(pred_ids.detach().cpu(), self.vocab_model.out_word_vocab)

        attn = torch.cat(enc_attn_weights, dim=0).permute(1, 0, 2)
        graphs = from_batch(x)

        results = []
        for i, text in enumerate(pred_text):
            node_tokens = [attr['token'] for attr in graphs[i].node_attributes]
            # Pad source tokens so their count matches the attention width.
            padded_src = node_tokens + ['#pad#'] * (attn[i].shape[1] - len(node_tokens))
            res = post_copy_processing(padded_src, text.split(), attn[i])
            results.append((sql[i], y_str[i], self._truncate_at_eos(res)))
        return results

    # ------------------------------------------------------------- NLL training

    def train(self, epochs, batch_sz):
        """Train the network with the Graph2Seq loss and early stopping on dev perplexity.

        The network is saved to ``sql2text_best.pt`` whenever dev perplexity does not
        increase. Training stops at the first epoch where it increases. The learning
        rate is halved when the change is below 0.3.

        Returns:
            The best dev perplexity seen.
        """
        loader = self._make_loader(self.train_data, batch_sz, shuffle=True)

        best_perplexity = float('inf')

        parameters = [p for p in self.network.parameters() if p.requires_grad]
        learning_rate = 1e-3
        optimizer = optim.Adam(parameters, lr=learning_rate)

        for epoch in range(epochs):
            loss_collect = []
            self.network.train()

            for d in tqdm.notebook.tqdm(loader, leave=False):
                batch_graph, y_idx, y_str = d['graph'], d['output_idx'], d['output_str']

                optimizer.zero_grad()

                x, y = batch_graph.to(self.device), y_idx.to(self.device)

                oov_dict = None
                if self.enable_copy:
                    # Also returns the gold ids re-indexed against the extended vocabulary.
                    oov_dict, y = prepare_ext_vocab(x, self.vocab_model, gt_str=y_str, device=self.device)

                prob, enc_attn_weights, coverage_vectors = self.network(x, y, oov_dict=oov_dict)
                loss = self.loss(logits=prob, label=y, enc_attn_weights=enc_attn_weights, coverage_vectors=coverage_vectors)
                loss_collect.append(loss.item())

                loss.backward()
                torch.nn.utils.clip_grad_norm_(parameters, self.CLIP)
                optimizer.step()

            perplexity = self.evaluate(batch_sz=batch_sz)

            print('Epoch {}: \n    Total loss: {:.3f} \n    Dev set perplexity: {:.3f}'.format(epoch, np.sum(loss_collect), perplexity))
            if perplexity <= best_perplexity:
                print("Best model saved, epoch {}".format(epoch))
                torch.save(self.network.state_dict(), 'sql2text_best.pt')
            else:
                print(f'Perplexity is {perplexity}, it stops to drop at Epoch {epoch}, early stop.')
                break

            if abs(perplexity - best_perplexity) < 0.3:
                learning_rate /= 2
                for g in optimizer.param_groups:
                    g['lr'] = learning_rate
                print(f'Update learning rate to be {learning_rate} since the difference between perplexity and best perplexity is smaller than 0.3')

            best_perplexity = min(best_perplexity, perplexity)

        return best_perplexity

    def evaluate(self, batch_sz, split='dev'):
        """Return exp of the batch-size-weighted mean Graph2Seq loss on ``split``.

        ``split='dev'`` uses the dev data; any other value uses the test data.
        Returns ``inf`` for an empty split.
        """
        self.network.eval()
        # No shuffling: batch composition changes the weighted mean, keep it deterministic.
        loader = self._make_loader(self._split_data(split), batch_sz, shuffle=False)

        loss_sum = 0.0
        n_examples = 0

        with torch.no_grad():
            for d in tqdm.notebook.tqdm(loader, leave=False):
                batch_graph, y_idx = d['graph'], d['output_idx']
                x, y = batch_graph.to(self.device), y_idx.to(self.device)

                prob, enc_attn_weights, coverage_vectors = self.network(x, y, oov_dict=self._oov_dict(x))
                loss = self.loss(logits=prob, label=y, enc_attn_weights=enc_attn_weights, coverage_vectors=coverage_vectors)

                loss_sum += loss.item() * len(y_idx)
                n_examples += len(y_idx)

        if n_examples == 0:
            return float('inf')
        return math.exp(loss_sum / n_examples)

    # --------------------------------------------------- reconstruction training

    def train_with_reconstruct(self, epochs, batch_sz):
        """Jointly train network and reconstructor on Graph2Seq loss + SQL reconstruction loss.

        Early stopping and learning-rate decay (divide by 10) are driven by the
        SQL-to-text dev perplexity. The best weights are saved to ``sql2text_best.pt``
        and ``text2sql_best.pt``.

        Returns:
            The best SQL-to-text dev perplexity seen.
        """
        loader = self._make_loader(self.train_data, batch_sz, shuffle=True)

        best_sql2text_loss = float('inf')

        parameters = list(self.network.parameters()) + list(self.reconstructor.parameters())
        learning_rate = 1e-3
        weight_decay = 1e-4
        optimizer = optim.Adam(parameters, lr=learning_rate, weight_decay=weight_decay)

        for epoch in range(epochs):
            loss_collect = []
            self.network.train()
            self.reconstructor.train()

            for d in tqdm.notebook.tqdm(loader, leave=False):
                batch_graph, y_idx, y_str, sql = d['graph'], d['output_idx'], d['output_str'], d['query']

                optimizer.zero_grad()

                x, y = batch_graph.to(self.device), y_idx.to(self.device)

                oov_dict = None
                if self.enable_copy:
                    oov_dict, y = prepare_ext_vocab(x, self.vocab_model, gt_str=y_str, device=self.device)

                prob, enc_attn_weights, coverage_vectors = self.network(x, y, oov_dict=oov_dict)
                pred_text = prob.argmax(dim=-1)

                src = self._encode_sql(sql)
                # Decoder input is the SQL without its last token; targets are shifted by one.
                output_sql, _ = self.reconstructor(pred_text, prob, src[:, :-1])
                output_sql = output_sql.contiguous().view(-1, output_sql.shape[-1])
                sql_trg = src[:, 1:].contiguous().view(-1)

                loss = self.loss(logits=prob,
                                 label=y,
                                 enc_attn_weights=enc_attn_weights,
                                 coverage_vectors=coverage_vectors) + self.recon_loss(output_sql, sql_trg)

                loss_collect.append(loss.item() * 0.5)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(parameters, self.CLIP)
                optimizer.step()

            sql2text_loss, text2sql_loss, total_loss = self.evaluate_with_reconstruct(batch_sz=batch_sz)

            print(f'Epoch {epoch}')
            print('    Total train loss: {:.3f}'.format(np.sum(loss_collect)))
            print(f'    SQL2Text Dev set ppl: {sql2text_loss}')
            print(f'    Text2SQL Dev set ppl: {text2sql_loss}')
            print(f'    Total Dev set ppl: {total_loss} \n')

            if sql2text_loss <= best_sql2text_loss:
                print("Best model saved, epoch {}".format(epoch))
                torch.save(self.network.state_dict(), 'sql2text_best.pt')
                torch.save(self.reconstructor.state_dict(), 'text2sql_best.pt')
            else:
                print(f'SQL2Text Perplexity is {sql2text_loss}, it stops to drop at Epoch {epoch}, early stop.')
                break

            if abs(sql2text_loss - best_sql2text_loss) < 0.3:
                learning_rate /= 10
                for g in optimizer.param_groups:
                    g['lr'] = learning_rate
                print(f'Update learning rate to be {learning_rate} since the difference between perplexity and best perplexity is smaller than 0.3')

            best_sql2text_loss = min(best_sql2text_loss, sql2text_loss)

        return best_sql2text_loss

    def evaluate_with_reconstruct(self, batch_sz, split='dev'):
        """Evaluate both directions on ``split`` ('dev' -> dev data, otherwise test data).

        The reconstructor is fed the gold text (as ids and one-hot vectors).

        Returns:
            ``(sql2text_ppl, text2sql_ppl, mean_ppl)``: exp of the batch-size-weighted
            mean of the Graph2Seq loss, the reconstruction loss, and their average.
            All three are ``inf`` for an empty split.
        """
        self.network.eval()
        # Dropout in the reconstructor must be off during evaluation as well.
        self.reconstructor.eval()

        loader = self._make_loader(self._split_data(split), batch_sz, shuffle=False)
        text_vocab_size = self.vocab_model.out_word_vocab.get_vocab_size()

        sql2text_sum = 0.0
        text2sql_sum = 0.0
        total_sum = 0.0
        n_examples = 0

        with torch.no_grad():
            for d in tqdm.notebook.tqdm(loader, leave=False):
                batch_graph, y_idx, sql = d['graph'], d['output_idx'], d['query']
                batch_n = len(y_idx)

                x, y = batch_graph.to(self.device), y_idx.to(self.device)

                prob, enc_attn_weights, coverage_vectors = self.network(x, y, oov_dict=self._oov_dict(x))

                # Gold text as one-hot "probabilities": (batch, trg len, text vocab size).
                one_hot_trg = torch.nn.functional.one_hot(y, num_classes=text_vocab_size).float()

                src = self._encode_sql(sql)
                output_sql, _ = self.reconstructor(y, one_hot_trg, src[:, :-1])
                output_sql = output_sql.contiguous().view(-1, output_sql.shape[-1])
                sql_trg = src[:, 1:].contiguous().view(-1)

                sql2text_loss = self.loss(logits=prob, label=y, enc_attn_weights=enc_attn_weights, coverage_vectors=coverage_vectors).item()
                text2sql_loss = self.recon_loss(output_sql, sql_trg).item()

                sql2text_sum += sql2text_loss * batch_n
                text2sql_sum += text2sql_loss * batch_n
                total_sum += (sql2text_loss + text2sql_loss) * 0.5 * batch_n
                n_examples += batch_n

        if n_examples == 0:
            return float('inf'), float('inf'), float('inf')
        return (math.exp(sql2text_sum / n_examples),
                math.exp(text2sql_sum / n_examples),
                math.exp(total_sum / n_examples))

    # --------------------------------------------------------------- translation

    def translate_sample(self, batch_sz, sample_sz, random=False, split='test'):
        """Beam-search translate ``sample_sz`` items of ``split`` and print them.

        Uses the first ``sample_sz`` items, or a random sample when ``random`` is True.
        """
        self.network.eval()
        loader = self._sample_loader(batch_sz, sample_sz, random, split)

        ex = 1
        with torch.no_grad():
            for d in tqdm.notebook.tqdm(loader, leave=False):
                for sql, ref_text, pred_text in self._beam_translate_batch(d):
                    print(f'Sample {ex}')
                    print(f'Original SQL: {sql}')
                    print(f'Original text: {ref_text}')
                    print(f'Predicted text: {pred_text} \n')
                    ex += 1

    def translate_to_file(self, filename, batch_sz, split='test'):
        """Beam-search translate the whole ``split`` and write the results to ``filename``."""
        self.network.eval()
        loader = self._full_loader(batch_sz, split)

        ex = 1
        # The original never opened the file; `with` also guarantees it is closed on error.
        with open(filename, "w") as f, torch.no_grad():
            for d in tqdm.notebook.tqdm(loader, leave=False):
                for sql, ref_text, pred_text in self._beam_translate_batch(d):
                    f.write(f'Sample {ex}\n')
                    f.write(f'Original SQL: {sql}\n')
                    f.write(f'Original text: {ref_text}\n')
                    f.write(f'Predicted text: {pred_text} \n\n')
                    ex += 1

    def translate_sample_post_copy(self, batch_sz, sample_sz, random=False, split='test'):
        """Print ``sample_sz`` argmax predictions of ``split`` with ``<unk>`` post-copied."""
        self.network.eval()
        loader = self._sample_loader(batch_sz, sample_sz, random, split)

        ex = 1
        with torch.no_grad():
            for d in tqdm.notebook.tqdm(loader, leave=False):
                for sql, ref_text, res in self._post_copy_batch(d):
                    print(f'Sample {ex}')
                    print(f'Original SQL: {sql}')
                    print(f'Original text: {ref_text}')
                    print(f'Predicted text after post copy: {res} \n')
                    ex += 1

    def translate_to_file_post_copy(self, filename, batch_sz, split='test'):
        """Write post-copied argmax predictions for the whole ``split`` to ``filename``."""
        self.network.eval()
        loader = self._full_loader(batch_sz, split)

        ex = 1
        with open(filename, "w") as f, torch.no_grad():
            for d in tqdm.notebook.tqdm(loader, leave=False):
                for sql, ref_text, res in self._post_copy_batch(d):
                    f.write(f'Sample {ex}\n')
                    f.write(f'Original SQL: {sql}\n')
                    f.write(f'Original text: {ref_text}\n')
                    f.write(f'Predicted text after post copy: {res} \n\n')
                    ex += 1
    

## Train with Reconstruction Loss

In [None]:
torch.cuda.empty_cache()
model_recon = SQL2TextModel(opt, vocab_model, train_data_set, dev_dataset, test_dataset)

# Show the graph encoder and the sequence decoder architectures
for submodule in (model_recon.network.gnn_encoder, model_recon.network.seq_decoder):
    print(submodule)

In [None]:
# Joint training; best checkpoints are saved to sql2text_best.pt and text2sql_best.pt
model_recon.train_with_reconstruct(epochs=10, batch_sz=64)
# model_recon.evaluate_with_reconstruct(64)

In [None]:
# Print a sample of translations from the best reconstruction-trained network.
# (`model` is only created further down, so this section must use `model_recon`.)
model_recon.network.load_state_dict(torch.load('sql2text_best.pt', map_location=model_recon.device))
model_recon.translate_sample(batch_sz=8, sample_sz=32, split='dev')

In [None]:
# Evaluate the best reconstruction-trained checkpoints on the test set. The in-memory
# weights are from the last epoch run, which may be the one that triggered early stopping.
model_recon.network.load_state_dict(torch.load('sql2text_best.pt', map_location=model_recon.device))
model_recon.reconstructor.load_state_dict(torch.load('text2sql_best.pt', map_location=model_recon.device))
model_recon.evaluate_with_reconstruct(64, 'test')

## Train with NLL Loss and Translate a Sample

In [None]:
torch.cuda.empty_cache()
model = SQL2TextModel(opt, vocab_model, train_data_set, dev_dataset, test_dataset)

# Show the graph encoder and the sequence decoder architectures
for submodule in (model.network.gnn_encoder, model.network.seq_decoder):
    print(submodule)

In [None]:
# Train with the Graph2Seq loss only; the best network is saved to sql2text_best.pt
model.train(epochs=10, batch_sz=64)

In [None]:
# Load the best NLL-trained network and print post-copied translations.
# (model.translate_sample(...) prints beam-search output without post copy instead.)
# map_location lets a GPU-saved checkpoint load on a CPU-only runtime.
model.network.load_state_dict(torch.load('sql2text_best.pt', map_location=model.device))
model.translate_sample_post_copy(batch_sz=8, sample_sz=32, split='dev')

In [None]:
# Test-set perplexity of the currently loaded network
model.evaluate(64, 'test')

# Predict and Write to File

Change the output file names below before running this cell. See the method list under **SQL2Text Model** for the functions that can be called.

In [None]:
# change this accordingly
OUTPUT_DIR = 'drive/MyDrive/CompSci-590-NLP/Final Project'
fname1 = os.path.join(OUTPUT_DIR, 'graph2seq_test_set_translate.txt')
fname2 = os.path.join(OUTPUT_DIR, 'graph2seq_test_set_translate_post_copy.txt')

model = SQL2TextModel(opt, vocab_model, train_data_set, dev_dataset, test_dataset)
# map_location lets a GPU-saved checkpoint load on a CPU-only runtime
model.network.load_state_dict(torch.load('sql2text_best.pt', map_location=model.device))
# model.reconstructor.load_state_dict(torch.load('text2sql_best.pt', map_location=model.device))
model.translate_to_file(fname1, 64, split='test')
model.translate_to_file_post_copy(filename=fname2, batch_sz=64, split='test')

# Evaluation

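This section compares the Seq2Seq and Graph2Seq predictions sample by sample using sentence-level BLEU, and writes the samples where Graph2Seq scores higher to `graph-better.txt`.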

In [1]:
# Load the prediction files written by the Seq2Seq and Graph2Seq models.
# Both files are assumed to list the same samples in the same order.
S2S_TARGET = 'Original Text:'
S2S_PRED = 'Predicted Pred:'
S2S_SQL = 'Original SQL:'
G2S_PRED = 'Predicted text after post copy:'

target = []    # tokenized reference texts, each wrapped in a list (BLEU takes a list of references)
g2s_pred = []  # tokenized Graph2Seq predictions (after post copy)
s2s_pred = []  # tokenized Seq2Seq predictions
sql = []       # source SQL queries

with open("sql2text-seq2seq.txt") as f1:
    for line in f1:
        # Match on the same prefix that is sliced off, so the two never disagree.
        if line.startswith(S2S_TARGET):
            target.append([line[len(S2S_TARGET):].strip().split()])
        elif line.startswith(S2S_PRED):
            s2s_pred.append(line[len(S2S_PRED):].strip().split())
        elif line.startswith(S2S_SQL):
            sql.append(line[len(S2S_SQL):].strip())

with open("sql2text-graph2seq.txt") as f2:
    for line in f2:
        if line.startswith(G2S_PRED):
            g2s_pred.append(line[len(G2S_PRED):].strip().split())

print(len(target), len(g2s_pred), len(s2s_pred))
# The comparison below indexes all four lists in lockstep.
assert len(target) == len(g2s_pred) == len(s2s_pred) == len(sql), 'prediction files are not aligned'

15878 15878 15878


In [4]:
from torchtext.data.metrics import bleu_score as bs

# Write every sample where Graph2Seq has a higher sentence-level BLEU than Seq2Seq
with open('graph-better.txt', 'w') as out_file:
    for idx in range(len(target)):
        references = [target[idx]]
        s2s_bleu = bs([s2s_pred[idx]], references)
        g2s_bleu = bs([g2s_pred[idx]], references)
        if not g2s_bleu > s2s_bleu:
            continue
        out_file.write(f'Sample {idx}\n')
        out_file.write(f'SQL: {sql[idx]}\n')
        out_file.write(f'graph2seq: {" ".join(g2s_pred[idx])}\n')
        out_file.write(f'seq2seq: {" ".join(s2s_pred[idx])}\n\n')

