# In[95]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Transformer as T

from constants import *
from model import SinusoidalEmbedding

# Edit-operation ids stored in tgt_op (insert, copy, padding).
INS_ID, CPY_ID, PAD_ID = 0, 1, 2

class DependencyEvolver(nn.Module):
    """Two-stage Transformer encoder/decoder driven by an INS/CPY edit script.

    Stage 1 (``t1``) builds decoder inputs by copying encoder states (non-INS
    positions) or inserting a learned placeholder (INS positions). Stage 2
    (``t2``) re-encodes those inputs and, at INS positions, adds the state of
    the position named by ``tgt_par``. Both stages share the same encoder and
    decoder modules.

    NOTE(review): ``t3`` and the output heads (``op_head``, ``idx_head``,
    ``tok_head``, ``rel_head``, ``pos_head``) are defined but not yet used by
    ``forward``.
    """

    def __init__(
        self,
        d_model=512, dim_feedforward=2048, nhead=8, dropout=0.1, N=64,
        encoder_layers=6, decoder_layers=6,
        tok_v=VOCAB_SIZE, rel_v=0, pos_v=0,
        pad_token_id=PAD_TOKEN_ID
    ):
        """Build the shared encoder/decoder, embeddings, masks and heads.

        Args:
            d_model, dim_feedforward, nhead, dropout: Transformer layer
                hyper-parameters shared by encoder and decoder layers.
            N: maximum sequence length; sizes the positional embedding and
                the causal mask.
            encoder_layers, decoder_layers: number of stacked layers.
            tok_v, rel_v, pos_v: vocabulary sizes; the embedding table has
                ``tok_v + rel_v + pos_v`` rows.
            pad_token_id: source id treated as padding in ``forward``.
        """
        super().__init__()

        codec_params = {
            'd_model': d_model,
            'dim_feedforward': dim_feedforward,
            'nhead': nhead,
            'dropout': dropout,
            'batch_first': True
        }

        encoder_layer = nn.TransformerEncoderLayer(**codec_params)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=encoder_layers)
        decoder_layer = nn.TransformerDecoderLayer(**codec_params)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=decoder_layers)

        # One embedding table sized for tokens + relations + positions;
        # rel_offset / pos_offset presumably mark where the relation and
        # position id ranges start (not used inside this class).
        self.rel_offset = tok_v
        self.pos_offset = tok_v + rel_v
        self.embedding = nn.Embedding(tok_v + rel_v + pos_v, d_model)
        self.positional_embedding = SinusoidalEmbedding(d_model, max_len=N)

        # Learned vectors: `done` is added at leaf positions, `plh` is the
        # placeholder used for inserted positions (see t1).
        self.done = nn.Parameter(torch.zeros(d_model))
        self.plh = nn.Parameter(torch.zeros(d_model))

        self.pad_token_id = pad_token_id
        # Boolean causal mask (True = may not attend), equivalent to the
        # -inf/0 float mask but matching the bool key-padding masks built in
        # forward. A non-persistent buffer follows .to(device) without
        # appearing in state_dict.
        self.register_buffer(
            'causal_mask',
            torch.ones(N, N, dtype=torch.bool).triu(diagonal=1),
            persistent=False,
        )

        self.op_head = nn.Linear(d_model, 2)
        # NOTE(review): output size is hard-coded to 512, independent of N
        # and d_model -- confirm intended.
        self.idx_head = nn.Linear(d_model, 512)
        self.tok_head = nn.Linear(d_model, tok_v)
        self.rel_head = nn.Linear(d_model, rel_v)
        self.pos_head = nn.Linear(d_model, pos_v)

    def t1(self, memory, tgt_op, tgt_cpy, is_leaf):
        """Build stage-1 decoder inputs from encoder states.

        Args:
            memory: rank-3 encoder output, unpacked as (B, S, D).
            tgt_op: (B, L) op ids; positions equal to INS_ID receive the
                placeholder ``plh``, all other positions copy a state.
            tgt_cpy: (B, L) indices along dim 1 of ``memory`` selecting the
                state copied into each target position.
            is_leaf: (B, L) boolean mask; ``done`` is added where True.

        Returns:
            The built target after ``positional_embedding(..., d=1)``
            (assumed to keep the (B, L, D) shape -- confirm against
            SinusoidalEmbedding).
        """
        B, _, D = memory.shape
        device = memory.device

        # NOTE(review): the meaning of d=-1 / d=1 is defined by
        # SinusoidalEmbedding, which is not visible here.
        memory = self.positional_embedding(memory, d=-1)
        batch_idx = torch.arange(B, device=device).unsqueeze(1)
        permuted_memory = memory[batch_idx, tgt_cpy]
        ins = tgt_op.eq(INS_ID).unsqueeze(-1).expand(-1, -1, D)
        tgt = torch.where(ins, self.plh.expand_as(ins), permuted_memory)
        tgt[is_leaf] += self.done
        tgt = self.positional_embedding(tgt, d=1)

        return tgt

    def t2(self, memory, tgt_op, tgt_par):
        """Add parent states to inserted positions.

        Returns ``memory[b, i] + memory[b, tgt_par[b, i]]`` where
        ``tgt_op[b, i] == INS_ID`` and ``memory[b, i]`` unchanged elsewhere.
        """
        B, _, D = memory.shape
        device = memory.device

        ins = tgt_op.eq(INS_ID).unsqueeze(-1).expand(-1, -1, D)
        batch_idx = torch.arange(B, device=device).unsqueeze(1)
        permuted_memory = memory[batch_idx, tgt_par]
        tgt = memory + torch.where(ins, permuted_memory, 0)

        return tgt

    def t3(self, memory, tgt_pos):
        """Placeholder for a third stage; not implemented (returns None)."""

    def forward(
        self,
        input_ids,
        tgt_op, tgt_cpy, tgt_par,
        tgt_pos, tgt_rel, tgt_tok,
        is_leaf,
        src=None
    ):
        """Run both stages and return their decoder outputs.

        Args:
            input_ids: (B, S) source ids; positions equal to ``pad_token_id``
                are masked in encoder self-attention and cross-attention.
            tgt_op, tgt_cpy, tgt_par: (B, L) edit-script tensors (see t1/t2).
            tgt_pos, tgt_rel, tgt_tok: accepted but currently unused.
            is_leaf: (B, L) boolean mask forwarded to ``t1``.
            src: optional pre-computed source embeddings; defaults to
                ``self.embedding(input_ids)``.

        Returns:
            Tuple ``(x1, x2)``: stage-1 and stage-2 decoder outputs.
        """
        src_pad_mask = input_ids.eq(self.pad_token_id)
        # NOTE(review): every position with tgt_par == 0 -- including copied
        # (non-INS) positions -- is treated as padding here. If all causally
        # visible keys of a query are masked, attention can produce NaN.
        # Confirm the intended target padding rule.
        tgt_pad_mask = tgt_par.eq(0)

        src = self.embedding(input_ids) if src is None else src

        # Stage 1: encode the source, build the copy/insert skeleton, decode.
        m1 = self.encoder(src, src_key_padding_mask=src_pad_mask)
        t1 = self.t1(m1, tgt_op, tgt_cpy, is_leaf)

        # Slice the precomputed N x N mask to the real target length so
        # targets shorter than N are accepted.
        tgt_len = t1.size(1)
        causal_mask = self.causal_mask[:tgt_len, :tgt_len]

        x1 = self.decoder(
            t1, m1,
            memory_key_padding_mask=src_pad_mask,
            tgt_key_padding_mask=tgt_pad_mask,
            tgt_mask=causal_mask
        )

        # Stage 2: re-encode the stage-1 inputs, attach parent states, and
        # decode the result (t2, not t1 -- t2 was previously computed and
        # never used).
        m2 = self.encoder(t1, src_key_padding_mask=tgt_pad_mask)
        t2 = self.t2(m2, tgt_op, tgt_par)
        x2 = self.decoder(
            t2, m2,
            memory_key_padding_mask=tgt_pad_mask,
            tgt_key_padding_mask=tgt_pad_mask,
            tgt_mask=causal_mask
        )

        return x1, x2
       


# In[96]:
model = DependencyEvolver()

# Worked example:
#   1. BOS (tok:rel:pos) tok:rel:pos EOS
#   2. BOS CPY(1) INS INS CPY(2) INS EOS

B, N, D = 1, 64, 512
input_ids = torch.randint(0, 10, (B, N))

# Ops for target positions 1..5; everything else stays PAD_ID.
tgt_op = torch.full((B, N), PAD_ID)
tgt_op[:, 1:6] = torch.tensor([CPY_ID, INS_ID, INS_ID, CPY_ID, INS_ID])

# Source position copied into each target position.
tgt_cpy = torch.full((B, N), 0)
tgt_cpy[:, 1] = 1
tgt_cpy[:, 4] = 2
tgt_cpy[:, 6] = 3

# Parent position for the inserted targets (2, 3 and 5 all hang off 4).
# NOTE(review): forward builds its target padding mask as tgt_par == 0, so
# every position except 2, 3 and 5 is treated as padding in this example.
tgt_par = torch.full((B, N), 0)
tgt_par[:, 2] = 4
tgt_par[:, 3] = 4
tgt_par[:, 5] = 4

is_leaf = torch.full((B, N), False)
is_leaf[:, 4] = True

# Call the module itself (not .forward) so registered hooks run.
outputs = model(
    input_ids,
    tgt_op=tgt_op, tgt_cpy=tgt_cpy, tgt_par=tgt_par,
    tgt_pos=None, tgt_rel=None, tgt_tok=None,
    is_leaf=is_leaf,
)

