In [1]:
import torch
from torchtext.legacy.data import Field, TabularDataset, BucketIterator, Iterator, Pipeline, RawField
from torchtext import datasets

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchtext.legacy.datasets import Multi30k
from torchtext.legacy.data import Field, BucketIterator

import spacy
import numpy as np
import os

import random
import math
import time

from torchtext.data.utils import get_tokenizer

# torchtext's built-in "basic_english" tokenizer; shared by the SOURCE and TARGET fields below.
tokenizer = get_tokenizer(tokenizer="basic_english")


  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


## Load Data

In [None]:
# Both columns are processed identically: same tokenizer, lower-cased, wrapped in
# <sos>/<eos>, with the batch dimension first.
FIELD_KWARGS = dict(
    init_token='<sos>',
    eos_token='<eos>',
    tokenize=tokenizer,
    lower=True,
    batch_first=True,
)
SOURCE = Field(**FIELD_KWARGS)
TARGET = Field(**FIELD_KWARGS)

# (name, Field) pairs listed in the same order as the CSV columns; the name becomes the
# attribute on each Example (e.g. example.SQL, used as the bucketing key later).
fields = [
    ('SQL', SOURCE),
    ('text', TARGET),
]

# load the train/dev/test splits from CSV files in the current directory
train_ds, valid_ds, test_ds = TabularDataset.splits(
    path='./',
    train='train_sql_text_no_pad.csv',
    validation='dev_sql_text_no_pad.csv',
    test='test_sql_text_no_pad.csv',
    format='csv',
    fields=fields,
    skip_header=True,  # first CSV row is a header, not an example
)

# check an example (displayed as the cell's output)
vars(train_ds[0])

In [None]:
from torchtext.vocab import GloVe, vocab

# Minimum number of occurrences in the training split for a token to enter a vocabulary.
MIN_FREQ = 25

# GloVe() without arguments loads torchtext's default pretrained vectors (name='840B', dim=300).
glove_vectors = GloVe()
# vocab() reads the dict's values as token frequencies and keeps tokens whose value is
# >= min_freq (default 1). glove_vectors.stoi maps token -> row index, so with the
# default the token at row 0 would be dropped and every later index shifted by one;
# min_freq=0 keeps every token, in GloVe row order.
glove_vocab = vocab(glove_vectors.stoi, min_freq=0)

SOURCE.build_vocab(train_ds, min_freq=MIN_FREQ, vectors=glove_vectors)
TARGET.build_vocab(train_ds, min_freq=MIN_FREQ, vectors=glove_vectors)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# One iterator per split. sort_key (length of the SQL column) is what BucketIterator
# uses to group similarly sized examples; sort=False means the split itself is not
# globally sorted by it.
train_it, valid_it, test_it = BucketIterator.splits(
    (train_ds, valid_ds, test_ds),
    batch_size=64,
    device=device,
    sort=False,
    sort_key=lambda example: len(example.SQL),
)

# Pull one training batch for interactive inspection; `batch` stays bound afterwards.
for batch in train_it:
    break

## Model

In [None]:
class Encoder(nn.Module):
    """Transformer encoder: scaled token embedding + learned positional embedding,
    followed by a stack of ``EncoderLayer`` blocks.

    Args:
        input_dim: source vocabulary size.
        hid_dim: model width.
        n_layers: number of ``EncoderLayer`` blocks.
        n_heads, pf_dim, dropout, device: forwarded to every ``EncoderLayer``.
        max_length: number of learned positions; longer inputs raise ``ValueError``.
    """
    def __init__(self, 
                 input_dim, 
                 hid_dim, 
                 n_layers, 
                 n_heads, 
                 pf_dim,
                 dropout, 
                 device,
                 max_length = 100):
        super().__init__()

        self.device = device
        
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        
        self.layers = nn.ModuleList([EncoderLayer(hid_dim, 
                                                  n_heads, 
                                                  pf_dim,
                                                  dropout, 
                                                  device) 
                                     for _ in range(n_layers)])
        
        self.dropout = nn.Dropout(dropout)
        
        # Buffer (not a plain attribute) so model.to(...) moves it with the weights;
        # persistent=False keeps it out of state_dict, so checkpoint keys are unchanged.
        self.register_buffer(
            "scale", torch.sqrt(torch.FloatTensor([hid_dim])).to(device), persistent=False)
        
    def forward(self, src, src_mask):
        """Encode a batch of source token ids.

        Args:
            src: [batch size, src len] source token ids.
            src_mask: [batch size, 1, 1, src len] padding mask passed to every layer.

        Returns:
            [batch size, src len, hid dim] encoded source.

        Raises:
            ValueError: if ``src len`` exceeds the positional embedding size.
        """
        batch_size = src.shape[0]
        src_len = src.shape[1]
        
        max_len = self.pos_embedding.num_embeddings
        if src_len > max_len:
            raise ValueError(
                f"source length {src_len} exceeds max_length={max_len} of the positional embedding")
        
        # built on the input's device so it matches the embeddings after model.to(...)
        pos = torch.arange(0, src_len, device=src.device).unsqueeze(0).repeat(batch_size, 1)
        
        #pos = [batch size, src len]
        
        src = self.dropout((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos))
        
        #src = [batch size, src len, hid dim]
        
        for layer in self.layers:
            src = layer(src, src_mask)
            
        #src = [batch size, src len, hid dim]
            
        return src
    
class EncoderLayer(nn.Module):
    """A single post-norm Transformer encoder layer.

    Two sub-layers, self-attention and a position-wise feed-forward network, are each
    wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, 
                 hid_dim, 
                 n_heads, 
                 pf_dim,  
                 dropout, 
                 device):
        super().__init__()
        # Submodules are created in this order; changing it would change the random
        # initialisation sequence and the state_dict key order.
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, norm, residual, sublayer_out):
        """Apply dropout to the sub-layer output, add the residual, then normalise."""
        return norm(residual + self.dropout(sublayer_out))

    def forward(self, src, src_mask):
        """src: [batch size, src len, hid dim]; src_mask: [batch size, 1, 1, src len].

        Returns a tensor shaped like ``src``.
        """
        attended, _ = self.self_attention(src, src, src, src_mask)
        src = self._add_and_norm(self.self_attn_layer_norm, src, attended)
        # the feed-forward output is computed before the residual dropout is applied
        return self._add_and_norm(self.ff_layer_norm, src, self.positionwise_feedforward(src))
    
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        hid_dim: model width; must be divisible by ``n_heads``.
        n_heads: number of heads; each operates on ``hid_dim // n_heads`` features.
        dropout: dropout probability applied to the attention weights.
        device: device the scaling constant sqrt(head_dim) is first created on.
    """
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        
        assert hid_dim % n_heads == 0, "hid_dim must be divisible by n_heads"
        
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        
        self.dropout = nn.Dropout(dropout)
        
        # Buffer (not a plain attribute) so model.to(...)/.double() move it together with
        # the weights; persistent=False keeps it out of state_dict, so checkpoint keys
        # are unchanged.
        self.register_buffer(
            "scale",
            torch.sqrt(torch.FloatTensor([self.head_dim])).to(device),
            persistent=False,
        )
        
    def forward(self, query, key, value, mask = None):
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: [batch size, query len, hid dim]
            key: [batch size, key len, hid dim]
            value: [batch size, value len, hid dim] (value len == key len)
            mask: optional tensor broadcastable to [batch size, n heads, query len, key len];
                positions where ``mask == 0`` receive (near-)zero attention.

        Returns:
            (x, attention): x is [batch size, query len, hid dim]; attention is
            [batch size, n heads, query len, key len] (before dropout).
        """
        batch_size = query.shape[0]
                
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        
        # split the hidden dim into heads: [batch, len, hid] -> [batch, n heads, len, head dim]
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
                
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        
        #energy = [batch size, n heads, query len, key len]
        
        if mask is not None:
            # a huge negative energy becomes ~0 weight after the softmax
            energy = energy.masked_fill(mask == 0, -1e10)
        
        attention = torch.softmax(energy, dim = -1)
                
        x = torch.matmul(self.dropout(attention), V)
        
        #x = [batch size, n heads, query len, head dim]
        
        # merge the heads back: -> [batch size, query len, hid dim]
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)
        
        x = self.fc_o(x)
        
        return x, attention
    
class PositionwiseFeedforwardLayer(nn.Module):
    """Two-layer feed-forward network applied independently at every sequence position.

    Computes ``fc_2(dropout(relu(fc_1(x))))``: expands ``hid_dim`` -> ``pf_dim`` and
    projects back to ``hid_dim``.
    """

    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        # Creation order (fc_1, fc_2, dropout) fixes the random init sequence and state_dict order.
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """x: [batch size, seq len, hid dim] -> [batch size, seq len, hid dim]."""
        expanded = torch.relu(self.fc_1(x))  # [batch size, seq len, pf dim]
        return self.fc_2(self.dropout(expanded))

class Decoder(nn.Module):
    """Transformer decoder: scaled token embedding + learned positional embedding,
    a stack of ``DecoderLayer`` blocks, then a linear projection to vocabulary logits.

    Args:
        output_dim: target vocabulary size.
        hid_dim: model width.
        n_layers: number of ``DecoderLayer`` blocks.
        n_heads, pf_dim, dropout, device: forwarded to every ``DecoderLayer``.
        max_length: number of learned positions; longer targets raise ``ValueError``.
    """
    def __init__(self, 
                 output_dim, 
                 hid_dim, 
                 n_layers, 
                 n_heads, 
                 pf_dim, 
                 dropout, 
                 device,
                 max_length = 100):
        super().__init__()
        
        self.device = device
        
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        
        self.layers = nn.ModuleList([DecoderLayer(hid_dim, 
                                                  n_heads, 
                                                  pf_dim, 
                                                  dropout, 
                                                  device)
                                     for _ in range(n_layers)])
        
        self.fc_out = nn.Linear(hid_dim, output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
        # Buffer so model.to(...) moves it with the weights; persistent=False keeps it
        # out of state_dict, so checkpoint keys are unchanged.
        self.register_buffer(
            "scale", torch.sqrt(torch.FloatTensor([hid_dim])).to(device), persistent=False)
        
    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Decode a batch of target prefixes against the encoded source.

        Args:
            trg: [batch size, trg len] target token ids.
            enc_src: [batch size, src len, hid dim] encoder output.
            trg_mask: [batch size, 1, trg len, trg len] target self-attention mask.
            src_mask: [batch size, 1, 1, src len] source padding mask.

        Returns:
            (output, attention): logits [batch size, trg len, output dim] and the last
            layer's encoder attention [batch size, n heads, trg len, src len]
            (``None`` when the decoder has no layers).

        Raises:
            ValueError: if ``trg len`` exceeds the positional embedding size.
        """
        batch_size = trg.shape[0]
        trg_len = trg.shape[1]
        
        max_len = self.pos_embedding.num_embeddings
        if trg_len > max_len:
            raise ValueError(
                f"target length {trg_len} exceeds max_length={max_len} of the positional embedding")
        
        # built on the input's device so it matches the embeddings after model.to(...)
        pos = torch.arange(0, trg_len, device=trg.device).unsqueeze(0).repeat(batch_size, 1)
            
        trg = self.dropout((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos))
                
        #trg = [batch size, trg len, hid dim]
        
        attention = None  # stays None only if there are no layers
        for layer in self.layers:
            trg, attention = layer(trg, enc_src, trg_mask, src_mask)
        
        output = self.fc_out(trg)
        
        #output = [batch size, trg len, output dim]
            
        return output, attention
    
class DecoderLayer(nn.Module):
    """A single post-norm Transformer decoder layer.

    Three sub-layers, each wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``:
    masked self-attention, attention over the encoder output, and a position-wise
    feed-forward network.
    """

    def __init__(self, 
                 hid_dim, 
                 n_heads, 
                 pf_dim, 
                 dropout, 
                 device):
        super().__init__()
        # Creation order fixes the random init sequence and the state_dict key order.
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, norm, residual, sublayer_out):
        """Apply dropout to the sub-layer output, add the residual, then normalise."""
        return norm(residual + self.dropout(sublayer_out))

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Run one decoder layer.

        trg: [batch size, trg len, hid dim]; enc_src: [batch size, src len, hid dim];
        trg_mask: [batch size, 1, trg len, trg len]; src_mask: [batch size, 1, 1, src len].

        Returns ``(trg, attention)``, where ``attention`` is the encoder-attention weights
        [batch size, n heads, trg len, src len].
        """
        # 1) masked self-attention over the target prefix
        self_attended, _ = self.self_attention(trg, trg, trg, trg_mask)
        trg = self._add_and_norm(self.self_attn_layer_norm, trg, self_attended)

        # 2) attention over the encoder output; these weights are returned to the caller
        cross_attended, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask)
        trg = self._add_and_norm(self.enc_attn_layer_norm, trg, cross_attended)

        # 3) position-wise feed-forward
        trg = self._add_and_norm(self.ff_layer_norm, trg, self.positionwise_feedforward(trg))

        return trg, attention
    
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that builds the attention masks and runs both halves.

    Args:
        encoder: module called as ``encoder(src, src_mask)``.
        decoder: module called as ``decoder(trg, enc_src, trg_mask, src_mask)``, returning
            ``(output, attention)``.
        src_pad_idx / trg_pad_idx: padding token ids in the source / target vocabularies.
        device: stored as ``self.device``; the masks themselves are created on the input
            tensors' device, so they stay correct after ``model.to(...)``.
    """
    def __init__(self, 
                 encoder, 
                 decoder, 
                 src_pad_idx, 
                 trg_pad_idx, 
                 device):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device
        
    def make_src_mask(self, src):
        """[batch size, src len] ids -> [batch size, 1, 1, src len] bool mask (True = real token)."""
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask
    
    def make_trg_mask(self, trg):
        """[batch size, trg len] ids -> [batch size, 1, trg len, trg len] bool mask.

        A position may attend to earlier-or-equal positions (causal) that are not padding.
        """
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
        
        #trg_pad_mask = [batch size, 1, 1, trg len]
        
        trg_len = trg.shape[1]
        
        # lower-triangular causal mask, created next to the input rather than on self.device
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device = trg.device)).bool()
        
        #trg_sub_mask = [trg len, trg len]
            
        trg_mask = trg_pad_mask & trg_sub_mask
        
        return trg_mask

    def forward(self, src, trg):
        """src: [batch size, src len], trg: [batch size, trg len] token ids.

        Returns the decoder's ``(output, attention)``; for the Decoder in this notebook that
        is logits [batch size, trg len, output dim] and attention
        [batch size, n heads, trg len, src len].
        """
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        
        enc_src = self.encoder(src, src_mask)
        
        #enc_src = [batch size, src len, hid dim]
                
        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)
        
        return output, attention
    
# Vocabulary ids of each field's padding token (SOURCE first, then TARGET).
SRC_PAD_IDX, TRG_PAD_IDX = (
    field.vocab.stoi[field.pad_token] for field in (SOURCE, TARGET)
)

class PretrainedEncoder(nn.Module):
    """Transformer encoder whose token embedding starts from ``src_field``'s pretrained vectors.

    The embedding is fine-tuned (``freeze=False``). Because it is added to a
    ``hid_dim``-wide positional embedding, ``hid_dim`` must equal the vector width.

    Args:
        input_dim: unused (the vocabulary size comes from the pretrained vectors); kept
            so the signature matches ``Encoder``.
        hid_dim: model width; must equal ``src_field.vocab.vectors.shape[1]``.
        n_layers: number of ``EncoderLayer`` blocks.
        n_heads, pf_dim, dropout, device: forwarded to every ``EncoderLayer``.
        max_length: number of learned positions; longer inputs raise ``ValueError``.
        src_field: Field whose built vocabulary carries the pretrained vectors.
        trg_field: unused by this module.

    Raises:
        ValueError: if the vocabulary has no vectors or their width differs from ``hid_dim``.
    """
    def __init__(self, 
                 input_dim, 
                 hid_dim, 
                 n_layers, 
                 n_heads, 
                 pf_dim,
                 dropout, 
                 device,
                 max_length = 100,
                 src_field=SOURCE,
                 trg_field=TARGET):
        super().__init__()

        self.device = device
        
        vectors = src_field.vocab.vectors
        if vectors is None:
            raise ValueError("src_field.vocab has no pretrained vectors; "
                             "build it with build_vocab(..., vectors=...)")
        if vectors.shape[1] != hid_dim:
            raise ValueError(f"hid_dim={hid_dim} must equal the pretrained vector "
                             f"size {vectors.shape[1]}")
        self.tok_embedding = nn.Embedding.from_pretrained(vectors, freeze=False)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        
        self.layers = nn.ModuleList([EncoderLayer(hid_dim, 
                                                  n_heads, 
                                                  pf_dim,
                                                  dropout, 
                                                  device) 
                                     for _ in range(n_layers)])
        
        self.dropout = nn.Dropout(dropout)
        
        # Buffer so model.to(...) moves it with the weights; persistent=False keeps it
        # out of state_dict, so checkpoint keys are unchanged.
        self.register_buffer(
            "scale", torch.sqrt(torch.FloatTensor([hid_dim])).to(device), persistent=False)
        
    def forward(self, src, src_mask):
        """Encode a batch of source token ids.

        Args:
            src: [batch size, src len] source token ids.
            src_mask: [batch size, 1, 1, src len] padding mask passed to every layer.

        Returns:
            [batch size, src len, hid dim] encoded source.

        Raises:
            ValueError: if ``src len`` exceeds the positional embedding size.
        """
        batch_size = src.shape[0]
        src_len = src.shape[1]
        
        max_len = self.pos_embedding.num_embeddings
        if src_len > max_len:
            raise ValueError(
                f"source length {src_len} exceeds max_length={max_len} of the positional embedding")
        
        # built on the input's device so it matches the embeddings after model.to(...)
        pos = torch.arange(0, src_len, device=src.device).unsqueeze(0).repeat(batch_size, 1)
        
        src = self.dropout((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos))
        
        #src = [batch size, src len, hid dim]
        
        for layer in self.layers:
            src = layer(src, src_mask)
            
        return src
    
class AttentionPointerDecoder(nn.Module):
    """Transformer decoder whose logits are augmented with an additive copy (pointer) score.

    The decoder path is the usual one: a token embedding scaled by sqrt(hid_dim), plus a
    learned positional embedding, then a stack of ``DecoderLayer`` blocks, then a linear
    projection to ``output_dim`` logits. With ``copy=True`` the last layer's
    encoder-attention weights are summed over heads, giving one score per
    (target position, source position). Each score is added to the logit of the
    target-vocabulary entry for the word found at that source position.

    Args:
        output_dim: target vocabulary size.
        hid_dim: model width.
        n_layers: number of ``DecoderLayer`` blocks (>= 1 when ``copy`` is True).
        n_heads, pf_dim, dropout, device: forwarded to every ``DecoderLayer``.
        copy: whether to add the copy scores to the logits.
        source_field, target_field: Fields whose vocabularies translate a source token id
            into the target-vocabulary id of the same word.
        src_pad_idx, trg_pad_idx: padding ids; stored on the module, not used by ``forward``.
        max_length: number of learned positions; longer targets raise ``ValueError``.
    """
    def __init__(self, 
                 output_dim, 
                 hid_dim, 
                 n_layers, 
                 n_heads, 
                 pf_dim, 
                 dropout, 
                 device,
                 copy=True,
                 source_field = SOURCE,
                 target_field = TARGET,
                 src_pad_idx=SRC_PAD_IDX, 
                 trg_pad_idx=TRG_PAD_IDX, 
                 max_length = 100):
        super().__init__()
        
        if copy and n_layers < 1:
            raise ValueError("copy=True needs at least one decoder layer (its attention is copied)")
        
        self.device = device
        
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        
        self.layers = nn.ModuleList([DecoderLayer(hid_dim, 
                                                  n_heads, 
                                                  pf_dim, 
                                                  dropout, 
                                                  device)
                                     for _ in range(n_layers)])
        
        self.fc_out = nn.Linear(hid_dim, output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
        # Buffer so model.to(...) moves it with the weights; persistent=False keeps it
        # out of state_dict, so checkpoint keys are unchanged.
        self.register_buffer(
            "scale", torch.sqrt(torch.FloatTensor([hid_dim])).to(device), persistent=False)
        self.copy = copy
        self.output_dim = output_dim
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.source_field = source_field
        self.target_field = target_field
        
    def _source_ids_in_target_vocab(self, src):
        """Map every source-vocabulary id in ``src`` to the target-vocabulary id of the same word.

        Words absent from the target vocabulary map to its ``<unk>`` id. ``vocab[word]`` is
        used instead of ``vocab.stoi[word]``: legacy torchtext's ``stoi`` is a defaultdict,
        so indexing it with an unseen word would insert that word into the target
        vocabulary's ``stoi`` as a side effect of every forward pass.
        """
        src_itos = self.source_field.vocab.itos
        trg_vocab = self.target_field.vocab
        ids = [[trg_vocab[src_itos[idx]] for idx in row] for row in src.tolist()]
        return torch.tensor(ids, dtype=torch.long, device=src.device)

    def _add_copy_scores(self, output, alpha, src):
        """Add copy scores ``alpha`` [batch, trg len, src len] onto logits ``output``
        [batch, trg len, output dim] at each source token's target-vocabulary id.

        Scores of repeated source words accumulate on the same logit.
        """
        xid = self._source_ids_in_target_vocab(src)             # [batch, src len]
        xid = xid.unsqueeze(1).repeat(1, output.shape[1], 1)    # [batch, trg len, src len]
        copied = torch.zeros_like(output).scatter_add(2, xid, alpha)
        return output + copied
        
    def forward(self, trg, enc_src, trg_mask, src_mask, src):
        """Decode ``trg`` against the encoder output and (optionally) add copy scores.

        Args:
            trg: [batch size, trg len] target-vocabulary ids.
            enc_src: [batch size, src len, hid dim] encoder output.
            trg_mask: [batch size, 1, trg len, trg len] target self-attention mask.
            src_mask: [batch size, 1, 1, src len] source padding mask.
            src: [batch size, src len] source-vocabulary ids; used to place copy scores.

        Returns:
            (output, attention): logits [batch size, trg len, output dim] and the last
            layer's encoder attention [batch size, n heads, trg len, src len]
            (``None`` when there are no layers).

        Raises:
            ValueError: if ``trg len`` exceeds the positional embedding size.
        """
        batch_size = trg.shape[0]
        trg_len = trg.shape[1]
        
        max_len = self.pos_embedding.num_embeddings
        if trg_len > max_len:
            raise ValueError(
                f"target length {trg_len} exceeds max_length={max_len} of the positional embedding")
        
        pos = torch.arange(0, trg_len, device=trg.device).unsqueeze(0).repeat(batch_size, 1)
            
        hidden = self.dropout((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos))
        
        attention = None
        for layer in self.layers:
            hidden, attention = layer(hidden, enc_src, trg_mask, src_mask)
        
        output = self.fc_out(hidden)
        
        if self.copy:
            # one copy score per (target position, source position): sum over the heads
            alpha = attention.sum(dim=1)
            # NOTE(review): padded source positions are not masked out here; their
            # attention is presumably ~0 already because it was computed with src_mask.
            output = self._add_copy_scores(output, alpha, src)
        
        return output, attention

class AttentionPointerDecoderV2(nn.Module):
    """Copy (pointer) decoder whose per-head attention is mixed by a small learned MLP.

    Identical to a plain Transformer decoder except when ``copy=True``. In that case the
    last layer's encoder-attention weights for each (target position, source position)
    are fed, as an ``n_heads``-vector, through ``fc_nhead_to_one``. The resulting scalar is
    added to the logit of the target-vocabulary entry for the word at that source position.

    Args:
        output_dim: target vocabulary size.
        hid_dim: model width.
        n_layers: number of ``DecoderLayer`` blocks (>= 1 when ``copy`` is True).
        n_heads, pf_dim, dropout, device: forwarded to every ``DecoderLayer``.
        copy: whether to add the copy scores to the logits.
        source_field, target_field: Fields whose vocabularies translate a source token id
            into the target-vocabulary id of the same word.
        src_pad_idx, trg_pad_idx: padding ids; stored on the module, not used by ``forward``.
        max_length: number of learned positions; longer targets raise ``ValueError``.
    """
    def __init__(self, 
                 output_dim, 
                 hid_dim, 
                 n_layers, 
                 n_heads, 
                 pf_dim, 
                 dropout, 
                 device,
                 copy=True,
                 source_field = SOURCE,
                 target_field = TARGET,
                 src_pad_idx=SRC_PAD_IDX, 
                 trg_pad_idx=TRG_PAD_IDX, 
                 max_length = 100):
        super().__init__()
        
        if copy and n_layers < 1:
            raise ValueError("copy=True needs at least one decoder layer (its attention is copied)")
        
        self.device = device
        
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        
        self.layers = nn.ModuleList([DecoderLayer(hid_dim, 
                                                  n_heads, 
                                                  pf_dim, 
                                                  dropout, 
                                                  device)
                                     for _ in range(n_layers)])
        
        self.fc_out = nn.Linear(hid_dim, output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
        # Buffer so model.to(...) moves it with the weights; persistent=False keeps it
        # out of state_dict, so checkpoint keys are unchanged.
        self.register_buffer(
            "scale", torch.sqrt(torch.FloatTensor([hid_dim])).to(device), persistent=False)
        self.copy = copy
        self.output_dim = output_dim
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.source_field = source_field
        self.target_field = target_field
        # maps the n_heads attention weights of one (target, source) pair to one copy score
        self.fc_nhead_to_one = nn.Sequential(
          nn.Linear(n_heads, 64),
          nn.Tanh(),
          nn.Linear(64, 64),
          nn.Tanh(),
          nn.Linear(64, 1)
        )
        self.n_heads = n_heads
        
    def _source_ids_in_target_vocab(self, src):
        """Map every source-vocabulary id in ``src`` to the target-vocabulary id of the same word.

        Words absent from the target vocabulary map to its ``<unk>`` id. ``vocab[word]`` is
        used instead of ``vocab.stoi[word]``: legacy torchtext's ``stoi`` is a defaultdict,
        so indexing it with an unseen word would insert that word into the target
        vocabulary's ``stoi`` as a side effect of every forward pass.
        """
        src_itos = self.source_field.vocab.itos
        trg_vocab = self.target_field.vocab
        ids = [[trg_vocab[src_itos[idx]] for idx in row] for row in src.tolist()]
        return torch.tensor(ids, dtype=torch.long, device=src.device)

    def _add_copy_scores(self, output, alpha, src):
        """Add copy scores ``alpha`` [batch, trg len, src len] onto logits ``output``
        [batch, trg len, output dim] at each source token's target-vocabulary id.

        Scores of repeated source words accumulate on the same logit.
        """
        xid = self._source_ids_in_target_vocab(src)             # [batch, src len]
        xid = xid.unsqueeze(1).repeat(1, output.shape[1], 1)    # [batch, trg len, src len]
        copied = torch.zeros_like(output).scatter_add(2, xid, alpha)
        return output + copied
        
    def forward(self, trg, enc_src, trg_mask, src_mask, src):
        """Decode ``trg`` against the encoder output and (optionally) add copy scores.

        Args:
            trg: [batch size, trg len] target-vocabulary ids.
            enc_src: [batch size, src len, hid dim] encoder output.
            trg_mask: [batch size, 1, trg len, trg len] target self-attention mask.
            src_mask: [batch size, 1, 1, src len] source padding mask.
            src: [batch size, src len] source-vocabulary ids; used to place copy scores.

        Returns:
            (output, attention): logits [batch size, trg len, output dim] and the last
            layer's encoder attention [batch size, n heads, trg len, src len]
            (``None`` when there are no layers).

        Raises:
            ValueError: if ``trg len`` exceeds the positional embedding size.
        """
        batch_size = trg.shape[0]
        trg_len = trg.shape[1]
        src_len = src.shape[1]
        
        max_len = self.pos_embedding.num_embeddings
        if trg_len > max_len:
            raise ValueError(
                f"target length {trg_len} exceeds max_length={max_len} of the positional embedding")
        
        pos = torch.arange(0, trg_len, device=trg.device).unsqueeze(0).repeat(batch_size, 1)
            
        hidden = self.dropout((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos))
        
        attention = None
        for layer in self.layers:
            hidden, attention = layer(hidden, enc_src, trg_mask, src_mask)
        
        output = self.fc_out(hidden)
        
        if self.copy:
            # [B, n heads, T, S] -> one row of n_heads weights per (b, t, s), in b, t, s order
            per_head = attention.permute(0, 2, 3, 1).reshape(-1, self.n_heads)
            alpha = self.fc_nhead_to_one(per_head).view(batch_size, trg_len, src_len)
            # NOTE(review): the MLP has biases, so an all-zero (presumably padded) source
            # column generally maps to a non-zero score that is added at the target id of
            # the source pad token. Consider alpha.masked_fill(...) on source pads — TODO confirm.
            output = self._add_copy_scores(output, alpha, src)
        
        return output, attention


class AttentionPointerDecoderV3(nn.Module):
    """Copy (pointer) decoder with pretrained target embeddings and an MLP over attention heads.

    Like ``AttentionPointerDecoderV2``, except the token embedding starts from
    ``target_field``'s pretrained vectors and is fine-tuned (``freeze=False``). Because it
    is added to a ``hid_dim``-wide positional embedding, ``hid_dim`` must equal the vector
    width.

    Args:
        output_dim: target vocabulary size (width of ``fc_out``).
        hid_dim: model width; must equal ``target_field.vocab.vectors.shape[1]``.
        n_layers: number of ``DecoderLayer`` blocks (>= 1 when ``copy`` is True).
        n_heads, pf_dim, dropout, device: forwarded to every ``DecoderLayer``.
        copy: whether to add the copy scores to the logits.
        source_field, target_field: Fields whose vocabularies translate a source token id
            into the target-vocabulary id of the same word.
        src_pad_idx, trg_pad_idx: padding ids; stored on the module, not used by ``forward``.
        max_length: number of learned positions; longer targets raise ``ValueError``.

    Raises:
        ValueError: for missing or mismatched pretrained vectors, or copy without layers.
    """
    def __init__(self, 
                 output_dim, 
                 hid_dim, 
                 n_layers, 
                 n_heads, 
                 pf_dim, 
                 dropout, 
                 device,
                 copy=True,
                 source_field = SOURCE,
                 target_field = TARGET,
                 src_pad_idx=SRC_PAD_IDX, 
                 trg_pad_idx=TRG_PAD_IDX, 
                 max_length = 100):
        super().__init__()
        
        if copy and n_layers < 1:
            raise ValueError("copy=True needs at least one decoder layer (its attention is copied)")
        
        self.device = device
        
        vectors = target_field.vocab.vectors
        if vectors is None:
            raise ValueError("target_field.vocab has no pretrained vectors; "
                             "build it with build_vocab(..., vectors=...)")
        if vectors.shape[1] != hid_dim:
            raise ValueError(f"hid_dim={hid_dim} must equal the pretrained vector "
                             f"size {vectors.shape[1]}")
        self.tok_embedding = nn.Embedding.from_pretrained(vectors, freeze=False)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        
        self.layers = nn.ModuleList([DecoderLayer(hid_dim, 
                                                  n_heads, 
                                                  pf_dim, 
                                                  dropout, 
                                                  device)
                                     for _ in range(n_layers)])
        
        self.fc_out = nn.Linear(hid_dim, output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
        # Buffer so model.to(...) moves it with the weights; persistent=False keeps it
        # out of state_dict, so checkpoint keys are unchanged.
        self.register_buffer(
            "scale", torch.sqrt(torch.FloatTensor([hid_dim])).to(device), persistent=False)
        self.copy = copy
        self.output_dim = output_dim
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.source_field = source_field
        self.target_field = target_field
        # maps the n_heads attention weights of one (target, source) pair to one copy score
        self.fc_nhead_to_one = nn.Sequential(
          nn.Linear(n_heads, 64),
          nn.Tanh(),
          nn.Linear(64, 64),
          nn.Tanh(),
          nn.Linear(64, 1)
        )
        self.n_heads = n_heads
        
    def _source_ids_in_target_vocab(self, src):
        """Map every source-vocabulary id in ``src`` to the target-vocabulary id of the same word.

        Words absent from the target vocabulary map to its ``<unk>`` id. ``vocab[word]`` is
        used instead of ``vocab.stoi[word]``: legacy torchtext's ``stoi`` is a defaultdict,
        so indexing it with an unseen word would insert that word into the target
        vocabulary's ``stoi`` as a side effect of every forward pass.
        """
        src_itos = self.source_field.vocab.itos
        trg_vocab = self.target_field.vocab
        ids = [[trg_vocab[src_itos[idx]] for idx in row] for row in src.tolist()]
        return torch.tensor(ids, dtype=torch.long, device=src.device)

    def _add_copy_scores(self, output, alpha, src):
        """Add copy scores ``alpha`` [batch, trg len, src len] onto logits ``output``
        [batch, trg len, output dim] at each source token's target-vocabulary id.

        Scores of repeated source words accumulate on the same logit.
        """
        xid = self._source_ids_in_target_vocab(src)             # [batch, src len]
        xid = xid.unsqueeze(1).repeat(1, output.shape[1], 1)    # [batch, trg len, src len]
        copied = torch.zeros_like(output).scatter_add(2, xid, alpha)
        return output + copied
        
    def forward(self, trg, enc_src, trg_mask, src_mask, src):
        """Decode ``trg`` against the encoder output and (optionally) add copy scores.

        Args:
            trg: [batch size, trg len] target-vocabulary ids.
            enc_src: [batch size, src len, hid dim] encoder output.
            trg_mask: [batch size, 1, trg len, trg len] target self-attention mask.
            src_mask: [batch size, 1, 1, src len] source padding mask.
            src: [batch size, src len] source-vocabulary ids; used to place copy scores.

        Returns:
            (output, attention): logits [batch size, trg len, output dim] and the last
            layer's encoder attention [batch size, n heads, trg len, src len]
            (``None`` when there are no layers).

        Raises:
            ValueError: if ``trg len`` exceeds the positional embedding size.
        """
        batch_size = trg.shape[0]
        trg_len = trg.shape[1]
        src_len = src.shape[1]
        
        max_len = self.pos_embedding.num_embeddings
        if trg_len > max_len:
            raise ValueError(
                f"target length {trg_len} exceeds max_length={max_len} of the positional embedding")
        
        pos = torch.arange(0, trg_len, device=trg.device).unsqueeze(0).repeat(batch_size, 1)
            
        hidden = self.dropout((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos))
        
        attention = None
        for layer in self.layers:
            hidden, attention = layer(hidden, enc_src, trg_mask, src_mask)
        
        output = self.fc_out(hidden)
        
        if self.copy:
            # [B, n heads, T, S] -> one row of n_heads weights per (b, t, s), in b, t, s order
            per_head = attention.permute(0, 2, 3, 1).reshape(-1, self.n_heads)
            alpha = self.fc_nhead_to_one(per_head).view(batch_size, trg_len, src_len)
            # NOTE(review): the MLP has biases, so an all-zero (presumably padded) source
            # column generally maps to a non-zero score that is added at the target id of
            # the source pad token. Consider alpha.masked_fill(...) on source pads — TODO confirm.
            output = self._add_copy_scores(output, alpha, src)
        
        return output, attention

class AttentionPointerSeq2Seq(nn.Module):
    """Transformer encoder-decoder whose decoder can copy source tokens (SQL -> text).

    The decoder receives the raw source indices as an extra argument so that its
    copy mechanism can map source tokens into the target vocabulary.

    Args:
        encoder: module called as ``encoder(src, src_mask)``.
        decoder: module called as ``decoder(trg, enc_src, trg_mask, src_mask, src)``
            and returning ``(output, attention)``.
        src_pad_idx: padding index of the source vocabulary (for ``src_mask``).
        trg_pad_idx: padding index of the target vocabulary (for ``trg_mask``).
        device: device on which the causal target mask is created.
        copy: stored as ``self.copy``; not read by this class.
        output_dim: target vocabulary size; defaults to ``len(TARGET.vocab)``.
        source_field: source torchtext Field; defaults to the global ``SOURCE``.
        target_field: target torchtext Field; defaults to the global ``TARGET``.
    """
    def __init__(self, 
                 encoder, 
                 decoder, 
                 src_pad_idx, 
                 trg_pad_idx, 
                 device,
                 copy = True,
                 output_dim = None,
                 source_field = None,
                 target_field = None
                ):
        super().__init__()

        # Resolve the notebook-global defaults at construction time rather than at
        # class-definition time: `len(TARGET.vocab)` as a default argument would be
        # frozen when the defining cell runs (and crash if that precedes build_vocab).
        if output_dim is None:
            output_dim = len(TARGET.vocab)
        if source_field is None:
            source_field = SOURCE
        if target_field is None:
            target_field = TARGET

        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device
        self.copy = copy
        self.output_dim = output_dim
        self.source_field = source_field
        self.target_field = target_field
        
    def make_src_mask(self, src):
        """Return a [batch size, 1, 1, src len] bool mask that is False at source padding."""
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
    
    def make_trg_mask(self, trg):
        """Return a [batch size, 1, trg len, trg len] mask combining padding and causality.

        A query position may attend to a key position only if the key is not
        padding and is not in the future (lower-triangular mask).
        """
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
        # trg_pad_mask = [batch size, 1, 1, trg len]

        trg_len = trg.shape[1]
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device = self.device)).bool()
        # trg_sub_mask = [trg len, trg len]

        return trg_pad_mask & trg_sub_mask

    def forward(self, src, trg):
        """Run the encoder on ``src`` and the (teacher-forced) decoder on ``trg``.

        Args:
            src: [batch size, src len] source token indices.
            trg: [batch size, trg len] target token indices fed to the decoder.

        Returns:
            ``(output, attention)`` exactly as produced by the decoder; in this
            notebook output is [batch size, trg len, output dim].
        """
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)

        enc_src = self.encoder(src, src_mask)
        # enc_src = [batch size, src len, hid dim]

        # The raw `src` indices are passed on for the decoder's copy mechanism.
        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask, src)
        return output, attention
    


class PretrainedReconstructorEncoder(nn.Module):
    """Transformer encoder whose input is one vector over the vocabulary per position.

    Instead of an embedding lookup, each position's vocabulary vector (in this
    notebook a softmax distribution or a one-hot row) is projected to ``hid_dim``
    with a linear layer. Gradients can therefore flow back into whatever produced
    the distribution.

    Args:
        input_dim: vocabulary size, i.e. the last dimension of ``src``.
        hid_dim, n_layers, n_heads, pf_dim, dropout: transformer settings;
            all but ``n_layers`` are forwarded to each ``EncoderLayer``.
        device: device for the position indices and the embedding scale.
        max_length: size of the learned positional-embedding table; inputs
            longer than this cannot be encoded.
        src_field, trg_field: accepted for signature compatibility; unused.
    """
    def __init__(self, 
                 input_dim, 
                 hid_dim, 
                 n_layers, 
                 n_heads, 
                 pf_dim,
                 dropout, 
                 device,
                 max_length = 100,
                 src_field=None,
                 trg_field=None):
        super().__init__()

        self.device = device

        # A linear projection acts as a "soft" embedding lookup: a one-hot row
        # selects a single weight column (plus bias), a distribution mixes several.
        self.tok_embedding = nn.Linear(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)

        self.layers = nn.ModuleList([EncoderLayer(hid_dim, 
                                                  n_heads, 
                                                  pf_dim,
                                                  dropout, 
                                                  device) 
                                     for _ in range(n_layers)])

        self.dropout = nn.Dropout(dropout)

        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
        self.hid_dim = hid_dim
        
    def forward(self, src, src_mask):
        """Encode a batch of per-position vocabulary vectors.

        Args:
            src: [batch size, src len, input_dim] float tensor (need not be contiguous).
            src_mask: [batch size, 1, 1, src len] mask passed to every layer.

        Returns:
            [batch size, src len, hid dim] encoded representation.
        """
        batch_size = src.shape[0]
        src_len = src.shape[1]

        pos = torch.arange(0, src_len, device=self.device).unsqueeze(0).repeat(batch_size, 1)
        # pos = [batch size, src len]

        # nn.Linear maps over the last dimension of any-rank input, so there is no
        # need to flatten with .view() (which fails on non-contiguous tensors).
        src = self.tok_embedding(src)
        src = self.dropout((src * self.scale) + self.pos_embedding(pos))
        # src = [batch size, src len, hid dim]

        for layer in self.layers:
            src = layer(src, src_mask)

        return src
    
class AttentionPointerReconstructorSeq2Seq(nn.Module):
    """Reconstructor seq2seq (text -> SQL) whose encoder consumes a separate input tensor.

    ``src`` (token indices) is used for the padding mask and passed to the
    decoder's copy mechanism. ``src_tensor`` is what the encoder actually reads;
    in this notebook it is a softmax distribution or one-hot rows over the text
    vocabulary.

    Args:
        encoder: module called as ``encoder(src_tensor, src_mask)``.
        decoder: module called as ``decoder(trg, enc_src, trg_mask, src_mask, src)``
            and returning ``(output, attention)``.
        src_pad_idx: padding index of the input vocabulary (for ``src_mask``).
        trg_pad_idx: padding index of the output vocabulary (for ``trg_mask``).
        device: device on which the causal target mask is created.
        copy: stored as ``self.copy``; not read by this class.
        output_dim: output vocabulary size; defaults to ``len(TARGET.vocab)``.
        source_field: input torchtext Field; defaults to the global ``SOURCE``.
        target_field: output torchtext Field; defaults to the global ``TARGET``.
    """
    def __init__(self, 
                 encoder, 
                 decoder, 
                 src_pad_idx, 
                 trg_pad_idx, 
                 device,
                 copy = True,
                 output_dim = None,
                 source_field = None,
                 target_field = None
                ):
        super().__init__()

        # Resolve notebook-global defaults when an instance is built; as default
        # arguments they would be evaluated once, when this class is defined.
        if output_dim is None:
            output_dim = len(TARGET.vocab)
        if source_field is None:
            source_field = SOURCE
        if target_field is None:
            target_field = TARGET

        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device
        self.copy = copy
        self.output_dim = output_dim
        self.source_field = source_field
        self.target_field = target_field
        
    def make_src_mask(self, src):
        """Return a [batch size, 1, 1, src len] bool mask that is False at input padding."""
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
    
    def make_trg_mask(self, trg):
        """Return a [batch size, 1, trg len, trg len] padding-and-causal mask for ``trg``."""
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
        # trg_pad_mask = [batch size, 1, 1, trg len]

        trg_len = trg.shape[1]
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device = self.device)).bool()
        # trg_sub_mask = [trg len, trg len]

        return trg_pad_mask & trg_sub_mask

    def forward(self, src, src_tensor, trg):
        """Encode ``src_tensor`` and decode ``trg`` (teacher forcing).

        Args:
            src: [batch size, src len] token indices used for masking and copying.
            src_tensor: encoder input aligned with ``src`` along the first two dims.
            trg: [batch size, trg len] output-side token indices fed to the decoder.

        Returns:
            ``(output, attention)`` exactly as produced by the decoder.
        """
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)

        enc_src = self.encoder(src_tensor, src_mask)
        # enc_src = [batch size, src len, hid dim]

        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask, src)
        return output, attention

In [None]:
## SQL2TextV3 (SQL -> text model) and text2SQLv3 (text -> SQL reconstructor)
# Fall back to CPU so the notebook still runs end to end on a machine without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hyperparameters shared by both directions.
HID_DIM = SOURCE.vocab.vectors.shape[1]  # model width = width of the pretrained source vectors
ENC_LAYERS = 3
DEC_LAYERS = 3
ENC_HEADS = 10
DEC_HEADS = 10
ENC_PF_DIM = 512
DEC_PF_DIM = 512
ENC_DROPOUT = 0.1
DEC_DROPOUT = 0.1
MAX_LENGTH = 100  # size of the positional-embedding tables
SRC_PAD_IDX = SOURCE.vocab.stoi[SOURCE.pad_token]
TRG_PAD_IDX = TARGET.vocab.stoi[TARGET.pad_token]

## SQL2TextV3: SQL (SOURCE) -> text (TARGET)
INPUT_DIM = len(SOURCE.vocab)
OUTPUT_DIM = len(TARGET.vocab)

enc = PretrainedEncoder(INPUT_DIM, 
                        HID_DIM, 
                        ENC_LAYERS, 
                        ENC_HEADS, 
                        ENC_PF_DIM, 
                        ENC_DROPOUT, 
                        device,
                        max_length = MAX_LENGTH,
                        src_field=SOURCE,
                        trg_field=TARGET)

dec = AttentionPointerDecoderV3(OUTPUT_DIM, 
                                HID_DIM, 
                                DEC_LAYERS, 
                                DEC_HEADS, 
                                DEC_PF_DIM, 
                                DEC_DROPOUT, 
                                device,
                                copy=True,
                                source_field = SOURCE,
                                target_field = TARGET,
                                src_pad_idx=SRC_PAD_IDX, 
                                trg_pad_idx=TRG_PAD_IDX, 
                                max_length = MAX_LENGTH)

model2 = AttentionPointerSeq2Seq(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device, 
                                 copy = True, output_dim=OUTPUT_DIM,
                                 source_field = SOURCE,
                                 target_field = TARGET).to(device)

## text2SQLv3: text (TARGET) -> SQL (SOURCE); the roles of the two vocabularies swap.
INPUT_DIM = len(TARGET.vocab)
OUTPUT_DIM = len(SOURCE.vocab)

enc = PretrainedReconstructorEncoder(INPUT_DIM, 
                                     HID_DIM, 
                                     ENC_LAYERS, 
                                     ENC_HEADS, 
                                     ENC_PF_DIM, 
                                     ENC_DROPOUT, 
                                     device,
                                     max_length = MAX_LENGTH,
                                     src_field=TARGET,
                                     trg_field=SOURCE)

dec = AttentionPointerDecoderV3(OUTPUT_DIM, 
                                HID_DIM, 
                                DEC_LAYERS, 
                                DEC_HEADS, 
                                DEC_PF_DIM, 
                                DEC_DROPOUT, 
                                device,
                                copy=True,
                                source_field = TARGET,
                                target_field = SOURCE,
                                src_pad_idx=TRG_PAD_IDX, 
                                trg_pad_idx=SRC_PAD_IDX, 
                                max_length = MAX_LENGTH)

reconstructor_model2 = AttentionPointerReconstructorSeq2Seq(enc, dec, TRG_PAD_IDX, SRC_PAD_IDX, device, 
                                                            copy = True, output_dim=OUTPUT_DIM,
                                                            source_field = TARGET,
                                                            target_field = SOURCE).to(device)

## Training and evaluation

In [None]:
def cyclic_trainv2(model, reconstructor, iterator, optimizer, criterion, clip):
    """Train one epoch of the cyclic SQL -> text -> SQL objective.

    For each batch the SQL->text ``model`` is run with teacher forcing. Its
    softmaxed output is fed to the text->SQL ``reconstructor`` together with the
    greedy token predictions; those predictions drive the reconstructor's padding
    mask and copy mechanism. The two cross-entropy losses are summed, back-propagated
    and applied by the single ``optimizer``.

    Args:
        model: SQL->text model called as ``model(src, trg_in)``.
        reconstructor: text->SQL model called as ``reconstructor(tokens, probs, sql_in)``.
        iterator: batches exposing ``.SQL`` and ``.text`` index tensors
            ([batch size, len]); must support ``len()``.
        optimizer: optimizer over the parameters of both networks.
        criterion: token-level loss such as ``nn.CrossEntropyLoss``.
        clip: max global gradient norm, applied jointly to both networks.

    Returns:
        Mean over batches of ``0.5 * (text loss + SQL loss)``.
    """
    model.train()
    reconstructor.train()

    # Both networks are updated by the same optimizer from one joint loss, so
    # clip their gradients together; clipping only `model` left the
    # reconstructor's gradients unbounded.
    all_params = list(model.parameters()) + list(reconstructor.parameters())

    epoch_loss = 0

    for batch in iterator:
        src = batch.SQL   # [batch size, sql len]
        trg = batch.text  # [batch size, text len]

        optimizer.zero_grad()

        # SQL -> text, teacher-forced on the text without its last token.
        output_text, _ = model(src, trg[:, :-1])
        # output_text = [batch size, text len - 1, text vocab]

        # text -> SQL on the model's own (differentiable) text distribution.
        text_pred = output_text.argmax(dim=2)
        output_text_prob = torch.nn.functional.softmax(output_text, dim=2)
        output_sql, _ = reconstructor(text_pred, output_text_prob, src[:, :-1])
        # output_sql = [batch size, sql len - 1, sql vocab]

        text_loss = criterion(output_text.reshape(-1, output_text.shape[-1]),
                              trg[:, 1:].reshape(-1))
        sql_loss = criterion(output_sql.reshape(-1, output_sql.shape[-1]),
                             src[:, 1:].reshape(-1))
        loss = text_loss + sql_loss

        loss.backward()
        torch.nn.utils.clip_grad_norm_(all_params, clip)
        optimizer.step()

        # Report the average of the two directions.
        epoch_loss += 0.5 * loss.item()

    return epoch_loss / len(iterator)

def cyclic_evaluatev2(model, reconstructor, iterator, criterion, text_vocab_size=None):
    """Evaluate both directions of the cyclic model without gradient tracking.

    The SQL->text ``model`` is teacher-forced on the gold text. The text->SQL
    ``reconstructor`` is also teacher-forced, and it is fed the gold text as
    one-hot rows rather than the model's predictions, unlike in training.

    Args:
        model: SQL->text model called as ``model(src, trg_in)``.
        reconstructor: text->SQL model called as ``reconstructor(tokens, onehot, sql_in)``.
        iterator: batches exposing ``.SQL`` and ``.text``; must support ``len()``.
        criterion: token-level loss such as ``nn.CrossEntropyLoss``.
        text_vocab_size: width of the one-hot text encoding; defaults to
            ``len(TARGET.vocab)``.

    Returns:
        ``(mean SQL->text loss, mean text->SQL loss, mean of their average)``.
    """
    # Both networks must be in eval mode; leaving the reconstructor in train
    # mode kept its dropout active during validation.
    model.eval()
    reconstructor.eval()

    if text_vocab_size is None:
        text_vocab_size = len(TARGET.vocab)

    epoch_sql2text_loss = 0
    epoch_text_2sql_loss = 0
    epoch_total_loss = 0

    with torch.no_grad():
        for batch in iterator:
            src = batch.SQL   # [batch size, sql len]
            trg = batch.text  # [batch size, text len]

            output_text, _ = model(src, trg[:, :-1])

            # one_hot works on any shape: [batch, len] -> [batch, len, vocab]
            one_hot_trg = torch.nn.functional.one_hot(trg, num_classes=text_vocab_size).float()
            output_sql, _ = reconstructor(trg, one_hot_trg, src[:, :-1])

            sql2text_loss = criterion(output_text.reshape(-1, output_text.shape[-1]),
                                      trg[:, 1:].reshape(-1))
            text_2sql_loss = criterion(output_sql.reshape(-1, output_sql.shape[-1]),
                                       src[:, 1:].reshape(-1))

            epoch_sql2text_loss += sql2text_loss.item()
            epoch_text_2sql_loss += text_2sql_loss.item()
            # Reuse the two losses instead of evaluating the criterion again.
            epoch_total_loss += 0.5 * (sql2text_loss + text_2sql_loss).item()

    n_batches = len(iterator)
    return (epoch_sql2text_loss / n_batches,
            epoch_text_2sql_loss / n_batches,
            epoch_total_loss / n_batches)

def epoch_time(start_time, end_time):
    """Split the span between two ``time.time()`` stamps into (minutes, seconds).

    Both parts are truncated toward zero with ``int``.
    """
    seconds_total = end_time - start_time
    whole_minutes = int(seconds_total / 60)
    leftover_seconds = int(seconds_total - 60 * whole_minutes)
    return whole_minutes, leftover_seconds

In [None]:
LEARNING_RATE = 0.00035

# One optimizer drives both halves of the cycle (SQL->text and text->SQL).
optimizer = torch.optim.Adam(
    list(model2.parameters()) + list(reconstructor_model2.parameters()),
    lr = LEARNING_RATE,
    weight_decay=1e-5,
)

# The same criterion scores text (TARGET vocab) and SQL (SOURCE vocab) tokens,
# so it can ignore only one padding index: check both vocabularies agree on it
# instead of hard-coding `1`.
assert SRC_PAD_IDX == TRG_PAD_IDX, "SOURCE and TARGET must share the same padding index"
criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)


In [None]:
## cyclicV2: jointly train SQL->text (model2) and text->SQL (reconstructor_model2)
N_EPOCHS = 20
CLIP = 1

# Checkpoint locations (same files as before; change CHECKPOINT_DIR to relocate them).
CHECKPOINT_DIR = '/datacommons/carin/fk43/CS590/models'
MODEL_CHECKPOINT = os.path.join(
    CHECKPOINT_DIR, 'freq25_cyclicv2_glove_trained_fc_mapped_pointer_transformer-model5.pt')
RECONSTRUCTOR_CHECKPOINT = os.path.join(
    CHECKPOINT_DIR, 'cyclicv2_trained_simple_reconstructor-model4.pt')

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()

    train_loss = cyclic_trainv2(model2, reconstructor_model2, train_it, optimizer, criterion, CLIP)
    valid_sql2text_loss, valid_text_2sql_loss, valid_total_loss = cyclic_evaluatev2(
        model2, reconstructor_model2, valid_it, criterion)

    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    # Model selection uses the SQL->text validation loss only; the reconstructor
    # is saved at the same time so the two checkpoints stay paired.
    if valid_sql2text_loss < best_valid_loss:
        best_valid_loss = valid_sql2text_loss
        torch.save(model2.state_dict(), MODEL_CHECKPOINT)
        torch.save(reconstructor_model2.state_dict(), RECONSTRUCTOR_CHECKPOINT)

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Validation SQL2Text. Loss: {valid_sql2text_loss:.3f} |  Validation SQL2Text PPL: {math.exp(valid_sql2text_loss):7.3f}')
    print(f'\t Validation Text2SQL. Loss: {valid_text_2sql_loss:.3f} |  Validation Text2SQL PPL: {math.exp(valid_text_2sql_loss):7.3f}')
    print(f'\t Validation Total. Loss: {valid_total_loss:.3f} |  Validation Total PPL: {math.exp(valid_total_loss):7.3f}')

In [None]:
# Short aliases used by the evaluation / decoding cells that follow.
model, reconstructor_model = model2, reconstructor_model2

In [None]:
import os
# Directory holding the checkpoint; '' means the current working directory.
# NOTE(review): the training cell writes this file under
# '/datacommons/carin/fk43/CS590/models' — point model_dir there or copy the file.
model_dir = ''

# map_location lets a checkpoint written on a GPU be restored on a CPU-only machine.
model.load_state_dict(torch.load(
    os.path.join(model_dir, 'freq25_cyclicv2_glove_trained_fc_mapped_pointer_transformer-model5.pt'),
    map_location=device))


In [None]:
## cyclic trained: loss / perplexity of the SQL->text model on the test split
test_loss = evaluate(model, test_it, criterion)
print(f'\t Test. Loss: {test_loss:.3f} |  '
      f'Test. PPL: {math.exp(test_loss):7.3f}')

## Sequence-to-sequence prediction on the test set

In [None]:
def copy_decode(src_tensor, src_field, trg_field, model, device, max_len=50):
    """Greedily decode a batch of source sequences with a (copy) seq2seq model.

    Every row starts from ``<sos>``; at each step the whole prefix is re-decoded
    and the arg-max token at the last position is appended. Decoding stops once
    every row has produced ``<eos>`` or after ``max_len`` steps. Finished rows
    keep being fed through the model, because the batch is decoded together.

    Args:
        src_tensor: [batch size, src len] LongTensor of source indices.
        src_field: unused; kept for signature compatibility.
        trg_field: target torchtext Field (``vocab``, ``init_token``, ``eos_token``).
        model: exposes ``eval``, ``make_src_mask``, ``make_trg_mask``, ``encoder``
            and ``decoder`` as used by the seq2seq classes in this notebook.
        device: device for the growing target tensor.
        max_len: maximum number of decoding steps.

    Returns:
        ``(pred_sentences, attention)``. ``pred_sentences`` holds one list of
        target tokens per row, cut before the first ``<eos>``. ``attention`` is
        the decoder attention from the last step, or ``None`` if ``max_len <= 0``.
    """
    assert isinstance(src_tensor, torch.Tensor)

    model.eval()
    sos_idx = trg_field.vocab.stoi[trg_field.init_token]
    eos_idx = trg_field.vocab.stoi[trg_field.eos_token]
    batch_size = len(src_tensor)

    trg_indexes = [[sos_idx] for _ in range(batch_size)]
    finished = [False] * batch_size
    attention = None  # stays None when no decoding step runs

    with torch.no_grad():
        src_mask = model.make_src_mask(src_tensor)
        enc_src = model.encoder(src_tensor, src_mask)
        # enc_src = [batch_sz, src_len, hid_dim]

        for _ in range(max_len):
            trg_tensor = torch.tensor(trg_indexes, dtype=torch.long, device=device)
            trg_mask = model.make_trg_mask(trg_tensor)
            output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask, src_tensor)
            # Greedy choice for the newest position; .tolist() yields plain ints
            # (not device tensors) so the index lists stay cheap to rebuild.
            next_tokens = output.argmax(2)[:, -1].tolist()
            for row, token in enumerate(next_tokens):
                trg_indexes[row].append(token)
                if token == eos_idx:
                    finished[row] = True
            if all(finished):
                break

    # Drop the leading <sos> and everything from the first <eos> on.
    pred_sentences = []
    for trg_sentence in trg_indexes:
        tokens = []
        for idx in trg_sentence[1:]:
            if idx == eos_idx:
                break
            tokens.append(trg_field.vocab.itos[idx])
        pred_sentences.append(tokens)

    return pred_sentences, attention

def SQL_decoder(indices):
    """Turn SOURCE-vocab indices into a SQL string without ' <pad>', '<sos>' or '<eos>'."""
    words = np.asarray(SOURCE.vocab.itos)[indices]
    sql_text = ' '.join(words)
    for marker in (' <pad>', '<sos>', '<eos>'):
        sql_text = sql_text.replace(marker, '')
    return sql_text
def Text_decoder(indices):
    """Turn TARGET-vocab indices into a space-joined string with ' <pad>' removed.

    Unlike ``SQL_decoder``, '<sos>' / '<eos>' markers are kept.
    """
    words = np.asarray(TARGET.vocab.itos)[indices]
    return ' '.join(words).replace(' <pad>', '')


In [None]:
def post_copy_processing_using_decoder(src, pred, attention, model, sql_symbols=None):
    """Replace ``<unk>`` tokens in a predicted sentence with source words picked by attention.

    The per-head decoder attention is collapsed to one score per
    (output position, input position) by ``model.decoder.fc_nhead_to_one``.
    Each ``<unk>`` in ``pred`` is replaced by the highest-scoring source word,
    in order of appearance. Only source words that are neither SQL symbols nor
    already present in ``pred`` are candidates, and each source position is used
    at most once.

    Args:
        src: list of source tokens; its length must equal attention's last dim.
        pred: list of predicted tokens; position t is assumed to line up with
            attention query position t.
        attention: [bsz, n_heads, out len, in len] decoder attention; only batch
            element 0 is used.
        model: model whose ``decoder.fc_nhead_to_one`` maps n_heads scores to one.
        sql_symbols: tokens that are never copied; defaults to the global
            ``all_sql_syms``.

    Returns:
        The refined sentence as a single space-joined string. ``<unk>`` tokens
        are left in place once no unused candidate remains.
    """
    if sql_symbols is None:
        sql_symbols = all_sql_syms

    bsz, n_heads, out_seq_len, in_seq_len = attention.shape

    with torch.no_grad():
        # Rows ordered (batch, out, in) with the heads as features -> one score per row.
        per_head = attention.permute(0, 2, 3, 1).reshape(-1, n_heads)
        scores = model.decoder.fc_nhead_to_one(per_head)
        scores = scores.reshape(bsz, 1, out_seq_len, in_seq_len)

        src_arr = np.asarray(src)
        # Use the `pred` argument (not a global) to exclude already-predicted words.
        keep = ~np.isin(src_arr, np.asarray(list(sql_symbols) + list(pred)))
        candidates = src_arr[keep]
        # Boolean indexing returns a copy, so masking below leaves `scores` intact.
        cand_scores = scores[:, :, :, torch.as_tensor(keep, device=scores.device)]

        refined_sentence = list(pred)
        remaining = len(candidates)
        unk_positions = [pos for pos, tok in enumerate(pred) if tok == '<unk>']
        for unk_pos in unk_positions:
            if remaining == 0:
                break  # every candidate has been used; keep the remaining <unk>s
            best = cand_scores[0, :, unk_pos, :].sum(dim=0).argmax().item()
            refined_sentence[unk_pos] = str(candidates[best])
            # Exclude the used source word. Scores come from a linear layer and may
            # be negative, so 0 is not a safe "never pick" value; -inf is.
            cand_scores[:, :, :, best] = float('-inf')
            remaining -= 1

    return ' '.join(refined_sentence)

In [None]:
import copy

# Decode the whole test split with the SQL->text model, patch <unk> tokens by
# copying attended source words, and write a plain-text report.
RESULTS_PATH = 'Cyclic_Transformer_sql2text_results_on_test_set.txt'

src_field = SOURCE
trg_field = TARGET
reconstructor_model.eval()
# `with` guarantees the report file is closed even if decoding raises part-way.
with open(RESULTS_PATH, 'w', encoding='utf-8') as myfile, torch.no_grad():
    for i, sample in tqdm.tqdm(enumerate(test_ds), total=len(test_ds)):

        src = [src_field.init_token] + sample.SQL + [src_field.eos_token]
        trg = sample.text
        src_tensor = torch.tensor([src_field.vocab.stoi[e_word] for e_word in src],
                                  dtype=torch.long, device=device).unsqueeze(0)

        pred_text, attention = copy_decode(src_tensor, src_field, trg_field, model, device, max_len=100)
        # NOTE(review): keep `pred_text` bound to the token list during the helper
        # call; check whether the helper reads it as a global before renaming.
        pred_text = pred_text[0] + [trg_field.eos_token]
        post_pred_text = post_copy_processing_using_decoder(src, pred_text, attention, model)

        sql = ' '.join(src)
        ref_text = ' '.join(trg)

        line_1 = '-'*10 + 'Example {}'.format(i+1) + '-'*10
        line_2 = "Original SQL: " + sql.replace('<sos>', '').replace('<eos>', '')
        line_3 = "Original Text: " + ref_text
        line_4 = "Predicted Pred:" + post_pred_text.replace('<sos>', '').replace('<eos>', '')
        for line in (line_1, line_2, line_3, line_4, '    '):
            myfile.write("%s\n" % line)
