In [None]:
import argparse
import torch
from solver import Solver
from utils import *
from models import *
from dataset import *
from copy import deepcopy as cc

In [None]:
# Hand-built stand-in for the parsed command-line options that Solver receives.
# Attribute order is kept the same as in the one-by-one assignment form.
args = argparse.Namespace(
    no_cuda=True,
    model_dir='train_model',
    seq_length=50,
    batch_size=3,
    num_step=10,
    data_dir='data_dir',
    # mode switches
    load=False,
    train=True,
    test=False,
    # corpus locations
    valid_path='data/valid.txt',
    train_path='data/train.txt',
    test_path='data/test.txt',
)

In [None]:
# Instantiate the project's Solver with the hand-built `args` namespace.
solver = Solver(args)

In [None]:
# Pull one batch from the training-data iterator and report the sizes of the
# two tensors that are fed to the model.
data_yielder = solver.data_utils.train_data_yielder()
# Use the builtin next() rather than calling the dunder method directly.
batch = next(data_yielder)
print(batch['input'].size(), batch['input_mask'].size())

In [None]:
# Run the sampled batch through solver.model; the result is kept in `op` for inspection.
op = solver.model(inputs = batch['input'], mask = batch['input_mask'])

- --
- --
- --

In [None]:
class Config():
    """Configuration for BERT model.

    All hyper-parameters are plain class attributes, so `Config()` takes no
    arguments and every instance reads the same defaults.
    """
    vocab_size: int = 30522 # Size of Vocabulary
    dim: int = 768 # Dimension of Hidden Layer in Transformer Encoder
    n_layers: int = 12 # Number of Hidden Layers
    n_heads: int = 12 # Number of Heads in Multi-Headed Attention Layers
    dim_ff: int = 768*4 # Dimension of Intermediate Layers in Positionwise Feedforward Net
    activ_fn: str = "gelu" # Non-linear Activation Function Type in Hidden Layers
    p_drop_hidden: float = 0.1 # Probability of Dropout of various Hidden Layers
    p_drop_attn: float = 0.1 # Probability of Dropout of Attention Layers
    max_len: int = 512 # Maximum Length for Positional Embeddings
    n_segments: int = 2 # Number of Sentence Segments
    layer_norm_eps: float = 1e-12 # eps value for the LayerNorms (a float, not an int)
    output_attentions: bool = False # Whether to output the attention scores

# Shared configuration instance used by the layer-construction cells.
config = Config()

In [None]:
class BertEmbeddings(nn.Module):
    """Input embedding layer: token + position + token-type embeddings,
    followed by LayerNorm and dropout.

    Args:
        cfg: object exposing `vocab_size`, `dim`, `max_len`, `n_segments`,
            `layer_norm_eps` and `p_drop_hidden`.
    """

    def __init__(self, cfg):
        super().__init__()
        width = cfg.dim
        # Sub-module names are kept verbatim: they are the state_dict keys.
        self.word_embeddings = nn.Embedding(cfg.vocab_size, width, padding_idx=0)  # id 0 is the padding row
        self.position_embeddings = nn.Embedding(cfg.max_len, width)
        self.token_type_embeddings = nn.Embedding(cfg.n_segments, width)

        self.LayerNorm = nn.LayerNorm(width, eps=cfg.layer_norm_eps)
        self.dropout = nn.Dropout(cfg.p_drop_hidden)

    def forward(self, x, seg):
        """Embed token ids `x` (B, S) with token-type ids `seg`.

        Returns the normalised, dropout-regularised sum of the three embeddings.
        """
        # Positions 0..S-1, shared by every row of the batch: (S,) -> (B, S)
        positions = torch.arange(x.size(1), dtype=torch.long, device=x.device)
        positions = positions[None, :].expand_as(x)

        summed = self.word_embeddings(x)
        summed = summed + self.position_embeddings(positions)
        summed = summed + self.token_type_embeddings(seg)
        return self.dropout(self.LayerNorm(summed))

In [None]:
# X -> self_attn -> X [calculated by self attention]
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        cfg: object exposing `dim`, `n_heads` and `p_drop_attn`.
    """

    def __init__(self, cfg):
        super().__init__()
        self.query = nn.Linear(cfg.dim, cfg.dim)
        self.key = nn.Linear(cfg.dim, cfg.dim)
        self.value = nn.Linear(cfg.dim, cfg.dim)
        self.dropout = nn.Dropout(cfg.p_drop_attn)
        self.n_heads = cfg.n_heads

    def forward(self, x, attention_mask = None, output_attentions = False):
        """Attend over `x` of shape (B, S, D).

        Args:
            x: input hidden states, (B, S, D) with D divisible by n_heads.
            attention_mask: optional (B, S) tensor; key positions equal to 0
                are filled with -1e9 before the softmax.
            output_attentions: when true, the (post-dropout) attention
                probabilities are returned as a second tuple element.

        Returns:
            `(hidden_states,)` or `(hidden_states, attention_probs)`.
        """
        batch, seq_len, dim = x.shape
        heads = self.n_heads
        head_dim = dim // heads
        assert head_dim * heads == dim

        def split_heads(t):
            # (B, S, D) -> (B, H, S, W)
            return t.reshape((batch, seq_len, heads, head_dim)).transpose(1, 2)

        q = split_heads(self.query(x))
        k = split_heads(self.key(x))
        v = split_heads(self.value(x))

        attn_scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))

        if attention_mask is not None:
            # Broadcast the key mask over heads and query positions.
            blocked = attention_mask[:, None, None, :] == 0
            attn_scores = attn_scores.masked_fill(blocked, -1e9)
        attn_probs = self.dropout(F.softmax(attn_scores, dim=-1))

        # (B, H, S, W) -> (B, S, D)
        merged = (attn_probs @ v).transpose(1, 2).contiguous().reshape(batch, seq_len, dim)
        if output_attentions:
            return (merged, attn_probs)
        return (merged,)

- $ a = \begin{matrix}  0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \\ 0 & 0 & 0 & 0 \\ \end{matrix} $ &#8594;
$ b = \begin{matrix}  1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \\ \end{matrix} $ &#8594;
$ c = \begin{matrix}  0 & 0 & 0 & 0 \\ 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ \end{matrix} $ &#8594;
$ tri\_matrix = \begin{matrix}  1 & 1 & 1 & 1 \\ 0 & 1 & 1 & 1 \\ 0 & 0 & 1 & 1 \\ 0 & 0 & 0 & 1 \\ \end{matrix} $
$ (a + c) = \begin{matrix} 0 & 1 & 0 & 0 \\ 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \\ 0 & 0 & 1 & 0 \\ \end{matrix} $

- $ SCORE =  \begin{matrix}  
\langle Q_0, K_0 \rangle & \dots & \langle Q_0, K_3 \rangle\\ \vdots & \ddots & \\ \langle Q_3, K_0 \rangle &  & \langle Q_3, K_3 \rangle \\ 
\end{matrix}
= \begin{matrix} S_{00} & \dots & S_{03} \\ \vdots & \ddots & \\S_{30} & & S_{33} \\ \end{matrix} $

- $ SCORE  \xrightarrow[- \infty]{MASK(a+c)} 
\begin{matrix} 
-\infty & S_{01} & -\infty & -\infty \\ S_{10} & -\infty & S_{12} & -\infty \\ -\infty & S_{21} & -\infty & S_{23} \\ -\infty & -\infty & S_{32} & -\infty \\ \end{matrix} 
\xrightarrow[]{SOFTMAX(dim = -1)} 
\begin{matrix} 
0 & P_{01} & 0 & 0 \\ P_{10} & 0 & P_{12} & 0 \\ 0 & P_{21} & 0 & P_{23} \\ 0 & 0 & P_{32} & 0 \\ \end{matrix}  
\xrightarrow[]{Score^T \cdot (Score+\epsilon)} 
\begin{matrix} 
0 & \sqrt{\langle P_{01}, P_{10} +\epsilon \rangle} & 0 & 0 \\ 
\sqrt{\langle P_{10},P_{01} +\epsilon \rangle} & 0 & \sqrt{\langle P_{12},P_{21} +\epsilon \rangle} & 0 \\ 
0 & \sqrt{\langle P_{21},P_{12} +\epsilon \rangle }& 0 & \sqrt{\langle P_{23},P_{32} +\epsilon \rangle} \\ 
0 & 0 & \sqrt{\langle P_{32},P_{23} +\epsilon \rangle} & 0 \\ \end{matrix}$

- $\hat{A^l} \sim
\begin{matrix} 
0 & \sqrt{\langle P_{01}P_{10} \rangle} & 0 & 0 \\ 
\sqrt{\langle P_{10}P_{01} \rangle} & 0 & \sqrt{\langle P_{12}P_{21} \rangle} & 0 \\ 
0 & \sqrt{\langle P_{21}P_{12} \rangle }& 0 & \sqrt{\langle P_{23}P_{32} \rangle} \\ 
0 & 0 & \sqrt{\langle P_{32}P_{23} \rangle}& 0 \\ \end{matrix}
\equiv
\begin{matrix} 
0 & \hat{a_0} & 0 & 0  \\ 
\hat{a_0} & 0 & \hat{a_1} & 0 \\ 
0 & \hat{a_1} & 0 & \hat{a_2} \\ 
0 & 0 & \hat{a_2} & 0 \\ \end{matrix}$

- $A^l = A^{l-1} + (1 - A^{l-1})\hat{A^l} = 
\begin{matrix} 0 & a_0 & 0 & 0  \\ a_0 & 0 & a_1 & 0 \\ 0 & a_1 & 0 & a_2 \\ 0 & 0 & a_2 & 0 \\ \end{matrix}
\xrightarrow[0]{MASK(a)} 
\begin{matrix} 0 & log(a_0 + \epsilon) & 0 & 0  \\ 0 & 0 & log(a_1 + \epsilon) & 0 \\ 0 & 0 & 0 & log(a_2 + \epsilon) \\ 0 & 0 & 0 & 0 \\ \end{matrix}
\xrightarrow[]{C.tri\_matrix}
\begin{matrix} 
0 & log(a_0 + \epsilon) & log(a_0+ \epsilon) & log(a_0+ \epsilon)  \\ 
0 & 0 & log(a_1+ \epsilon) & log(a_1+ \epsilon) \\ 
0 & 0 & 0 & log(a_2 + \epsilon) \\ 
0 & 0 & 0 & 0 \\ \end{matrix}
\xrightarrow[tri\_matrix^T.C]{log()} \sim
\begin{matrix} 
0 &  log(a_0 + \epsilon) & log(a_0)+log(a_1) & log(a_0)+log(a_1)+log(a_2)  \\ 
0 & 0 & log(a_1) & log(a_1)+log(a_2) \\ 
0 & 0 & 0 & log(a_2) \\ 
0 & 0 & 0 & 0 \\ 
\end{matrix}$

- $ \equiv 
\begin{matrix} 0 & log(a_0) & log(a_0a_1) & log(a_0a_1a_2)  \\ 
0 & 0 & log(a_1) & log(a_1a_2) \\ 
0 & 0 & 0 & log(a_2) \\ 
0 & 0 & 0 & 0 \\ 
\end{matrix}
\xrightarrow[]{exp()}
\begin{matrix} 0 & C_{01} & C_{02} & C_{03}  \\ 0 & 0 & C_{12} & C_{13} \\ 0 & 0 & 0 & C_{23} \\ 0 & 0 & 0 & 0 \\ \end{matrix}
\xrightarrow[]{ + C^T}
\begin{matrix} 0 & C_{01} & C_{02} & C_{03}  \\ 
C_{01} & 0 & C_{12} & C_{13} \\ 
C_{02} & C_{12} & 0 & C_{23} \\ 
C_{03} & C_{13} & C_{23} & 0 \\ 
\end{matrix}$

In [None]:
class GroupAttention(nn.Module):
    """Neighbour ("group") attention that produces a constituent prior.

    Each position attends only to its immediate left/right neighbour. The
    symmetric neighbour-link strengths are merged with the prior passed in
    (`prior_A`) and multiplied along every span (i, j) to give `C_prior`.

    Args:
        cfg: object exposing `dim`, `layer_norm_eps` and `p_drop_attn`.
    """

    def __init__(self, cfg):
        super(GroupAttention, self).__init__()
        self.d_model = cfg.dim
        self.linear_key = nn.Linear(cfg.dim, cfg.dim)
        self.linear_query = nn.Linear(cfg.dim, cfg.dim)
        self.LayerNorm = nn.LayerNorm(cfg.dim, eps= cfg.layer_norm_eps)
        # Kept for interface compatibility; forward() does not apply it.
        self.dropout = nn.Dropout(cfg.p_drop_attn)

    def forward(self, hidden_states, attention_mask, prior_A):
        """Compute the constituent prior and the updated neighbour matrix.

        Args:
            hidden_states: (B, S, dim) activations.
            attention_mask: (B, S) mask; entries equal to 0 are excluded as
                attention targets (any integer, bool or float dtype).
            prior_A: neighbour matrix carried in from the caller, a scalar or
                anything broadcastable to (B, S, S).

        Returns:
            (C_prior, A), both (B, S, S). C_prior is symmetric with a unit
            diagonal and keeps the dtype of A instead of being promoted to
            float64.
        """
        B, S = hidden_states.size()[:2]
        device = hidden_states.device

        context = self.LayerNorm(hidden_states)

        # Structural masks are built directly on the input's device so the
        # layer also works when the model lives on an accelerator.
        idx = torch.arange(S, device=device)
        offset = idx[None, :] - idx[:, None]      # column index - row index
        super_diag = offset == 1                   # "a" in the derivation
        neighbours = offset.abs() == 1             # "a + c"
        strict_upper = offset > 0                  # "tri_matrix - b"

        # A position may only look at a direct neighbour that is not masked out.
        mask = (attention_mask[:, None, :] != 0) & neighbours[None, :, :]

        key = self.linear_key(context)
        query = self.linear_query(context)

        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.d_model / 2)
        scores = scores.masked_fill(~mask, -1e9)
        A = F.softmax(scores, dim=-1)
        # Geometric mean of the two directed links makes the matrix symmetric.
        A = torch.sqrt(A * A.transpose(-2,-1) + 1e-9)
        A = prior_A + (1. - prior_A)*A

        # Upper-triangular ones (diagonal included) in A's dtype, so the
        # matmuls below never mix precisions.
        tri_matrix = (offset >= 0).to(A.dtype)

        # Two cumulative sums of log-links turn into span products after exp().
        t = torch.log(A + 1e-9).masked_fill(~super_diag, 0).matmul(tri_matrix)
        C_prior = tri_matrix.matmul(t).exp().masked_fill(~strict_upper, 0)

        identity = torch.eye(S, dtype=C_prior.dtype, device=C_prior.device)
        C_prior = C_prior + C_prior.transpose(-2, -1) + identity

        return C_prior, A

In [None]:
# Build the project's tree2bert_dataset from the `.val` file.
# NOTE(review): the path is relative, so it resolves against the notebook's working directory.
dataset = tree2bert_dataset(data_path = './../Data/raw/seq2seq/train_short_prefix.txt.val',
                            max_seq_len = 256,
                            max_ev_len = 20)

In [None]:
# WordPiece tokenizer with the three separator tokens registered as special tokens.
fast_tokenizer = BertWordPieceTokenizer(vocab='./vocab/vocab.txt')
fast_tokenizer.add_special_tokens(['[et_sep]', '[ea_sep]', '[ds_sep]'])
# Truncate and pad to 256, the same value passed as max_seq_len to the dataset.
fast_tokenizer.enable_truncation(max_length = 256)
fast_tokenizer.enable_padding(length=256)

# Freshly initialised layers built from the shared `config`.
emb_layer = BertEmbeddings(config)
att_layer = BertSelfAttention(config)
grp_att_layer = GroupAttention(config)

In [None]:
# Take a two-item slice of the dataset (presumably the first two samples; TODO confirm
# how tree2bert_dataset handles slices) and cast each field to int64.
# NOTE(review): dataset[:2] is evaluated once per field.
ip_text_tok = dataset[:2]['text_tok_src'].long()
ip_text_mask = dataset[:2]['text_mask_src'].long()
ip_event_loc = dataset[:2]['event_loc_src'].long()
ip_event_mask = dataset[:2]['event_mask_src'].long()

In [329]:
ip_text_A = torch.ones_like(ip_text_tok)  # B x S

# Per-token values that will sit on the first off-diagonal of A (B x S):
# 0.999 wherever the text mask is 1, and 0 wherever it is 0 (the [PAD]s).
link_vals = cc(ip_text_mask) * 0.999

# Zero out every position listed in ip_event_loc (the [et_sep] tokens).
link_vals = link_vals.scatter(1, ip_event_loc.long(), torch.zeros_like(ip_event_loc).float())

# A diagonal at offset 1 has S - 1 entries, so the last token is dropped.
link_vals = link_vals[:, :-1].float()

# One (S x S) matrix per sample with those values on the +1 diagonal ...
A_initial = torch.stack([torch.diag(row, 1) for row in link_vals])

# ... mirrored below the main diagonal, because A is meant to be symmetric.
A_initial = A_initial + A_initial.transpose(-1, -2)

In [330]:
# Top-left 6 x 6 corner of the first sample's initial neighbour matrix.
A_initial[0, :6, :6]

tensor([[0.0000, 0.9990, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9990, 0.0000, 0.9990, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.9990, 0.0000, 0.9990, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.9990, 0.0000, 0.9990, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.9990, 0.0000, 0.9990],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.9990, 0.0000]])

In [355]:
# Embed the token ids; every token is given token-type (segment) id 0.
hidden = emb_layer(x = ip_text_tok.long(), seg = torch.zeros_like(ip_text_tok).long())
# Advanced indexing: for each sample pick the hidden vectors at that sample's
# ip_event_loc positions -> (batch, number of event locations, dim).
hidden_event_subset = hidden[np.arange(hidden.shape[0])[:, None], ip_event_loc.long()]

In [358]:
# Scratch version of the GroupAttention computation, run only over the hidden
# states gathered at the event locations; the resulting neighbour links are
# then scattered back into a full-length neighbour matrix.
d_model = config.dim
# Fresh, untrained projection / normalisation layers created just for this cell.
linear_key = nn.Linear(config.dim, config.dim)
linear_query = nn.Linear(config.dim, config.dim)
LayerNorm = nn.LayerNorm(config.dim, eps= config.layer_norm_eps)
dropout = nn.Dropout(config.p_drop_attn)  # NOTE(review): created but never applied in this cell

# B = batch size, S = number of event locations per sample
B, S = hidden_event_subset.size()[:2]

context = LayerNorm(hidden_event_subset)

# S x S helpers: super-diagonal (a), identity (b), sub-diagonal (c), upper triangle.
# NOTE(review): `b` and `tri_matrix` are not read again in this cell.
a = torch.from_numpy(np.diag(np.ones(S - 1, dtype=np.int32),1))
b = torch.from_numpy(np.diag(np.ones(S, dtype=np.int32),0))
c = torch.from_numpy(np.diag(np.ones(S - 1, dtype=np.int32),-1))
tri_matrix = torch.from_numpy(np.triu(np.ones([S, S], dtype=np.float32),0))

# The same helpers at the full sequence length.
# NOTE(review): 256 is hard-coded; it matches max_seq_len / the tokenizer padding length
# used elsewhere in this notebook and must be kept in sync with them.
# NOTE(review): `c_full` is not used in this cell.
a_full = torch.from_numpy(np.diag(np.ones(256 - 1, dtype=np.int32),1))
b_full = torch.from_numpy(np.diag(np.ones(256, dtype=np.int32),0))
c_full = torch.from_numpy(np.diag(np.ones(256 - 1, dtype=np.int32),-1))
tri_matrix_full = torch.from_numpy(np.triu(np.ones([256, 256], dtype=np.float32),0))

# An event may only attend to an adjacent event whose mask bit is set.
mask = ip_event_mask[:, None, :] & (a+c)[None, :, :]

key = linear_key(context)
query = linear_query(context)

scores = torch.matmul(query, key.transpose(-2, -1)) / (d_model / 2)        
scores = scores.masked_fill(mask == 0, -1e10)

A = F.softmax(scores, dim=-1)
# Symmetrise: geometric mean of the two directed neighbour probabilities.
A = torch.sqrt(A * A.transpose(-2,-1) + 1e-10)
# Keep only the +1 diagonal of each sample -> (B, S - 1).
# (`batch` here is local to the comprehension and does not rebind the outer `batch`.)
A = torch.cat([torch.diagonal(A[batch], 1)[None, :] for batch in range(B)])

# Start from the text mask and write each event link at its event location;
# the last event location is left untouched because A has only S - 1 columns.
A_new = cc(ip_text_mask).float()
A_new = A_new.scatter(1, ip_event_loc[:, :-1].long(), A.float())

# Same diagonal-matrix construction as for A_initial: offset-1 diagonal, then mirror.
# (`b` in the comprehension is local to it; the tensor `b` above is not overwritten.)
A_new = A_new[:, :-1]
A_new= [torch.diag(A_new[b, :], 1)[None, :, :] for b in range(A_new.shape[0])]
A_new = torch.cat(A_new)
A_new += A_new.transpose(-1, -2).contiguous()
# Merge with the initial neighbour matrix.
A_new = A_initial + (1. - A_initial)*A_new

# Two cumulative sums of log-links, then exp(), give the product of links over each span.
t = torch.log(A_new + 1e-10).masked_fill(a_full==0, 0).matmul(tri_matrix_full)
C_prior = tri_matrix_full.matmul(t).exp().masked_fill((tri_matrix_full.int()-b_full)==0, 0)    
# NOTE(review): np.diag(np.ones(256)) is float64, so this sum promotes C_prior to float64
# (the recorded output of C_prior shows dtype=torch.float64).
C_prior = C_prior + C_prior.transpose(-2, -1) + torch.from_numpy(np.diag(np.ones(256)))



In [359]:
# First 12 entries of row 0 of the first sample's constituent prior.
C_prior[0, 0, :12]

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 0.7112], dtype=torch.float64, grad_fn=<SliceBackward>)

In [336]:
# Event-location indices for the two selected samples.
ip_event_loc

tensor([[ 10,  30,  39,  59,  68,  94, 104, 106, 107, 108, 109, 110, 111, 112,
         113, 114, 115, 116, 117, 118],
        [ 13,  31,  44,  59,  70,  85,  87,  88,  89,  90,  91,  92,  93,  94,
          95,  96,  97,  98,  99, 100]])

In [353]:
# Decode the first 11 token ids of sample 0, keeping special tokens visible.
fast_tokenizer.decode(ip_text_tok.numpy()[0, :11], skip_special_tokens = False)

'[CLS] some devices [ea_sep] can keep more efficiently [ea_sep] boots [et_sep]'

In [351]:
# NOTE(review): leftover interactive help lookup (IPython `?`); remove before sharing the notebook.
fast_tokenizer.decode?

[0;31mSignature:[0m
[0mfast_tokenizer[0m[0;34m.[0m[0mdecode[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mids[0m[0;34m:[0m[0mList[0m[0;34m[[0m[0mint[0m[0;34m][0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mskip_special_tokens[0m[0;34m:[0m[0mUnion[0m[0;34m[[0m[0mbool[0m[0;34m,[0m [0mNoneType[0m[0;34m][0m[0;34m=[0m[0;32mTrue[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m [0;34m->[0m [0mstr[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Decode the given list of ids to a string sequence

Args:
    ids: List[unsigned int]:
        A list of ids to be decoded

    skip_special_tokens: (`optional`) boolean:
        Whether to remove all the special tokens from the output string

Returns:
    The decoded string
[0;31mFile:[0m      ~/miniconda3/envs/Pytorch/lib/python3.6/site-packages/tokenizers/implementations/base_tokenizer.py
[0;31mType:[0m      method
