In [1]:
import os
import re
import pandas as pd 
import numpy as np 
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torch.autograd import  Variable 
from torch.utils.data import DataLoader, Dataset
from AttentionTransformer.TrainClassificationTransformer import * 
from AttentionTransformer.ClassificationDataset import *
from AttentionTransformer.utilities import count_model_parameters
import pickle 
from tqdm import tqdm_notebook, tqdm, trange, tnrange 
import torch.optim as optim
from tokenizers import BertWordPieceTokenizer
import logging
import sys
sys.path.append('../scripts/')

from ClassificationDatasetFromDict import *

In [2]:
SEED = 3007
# Seed numpy's legacy global RNG as well as torch so that Restart & Run All is
# reproducible; torch.cuda.manual_seed is silently ignored when CUDA is absent.
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

def load_pickle(filepath):
    """Deserialize and return the first object stored in the pickle file at ``filepath``.

    NOTE: ``pickle.load`` can execute arbitrary code while unpickling — only
    call this on files you produced yourself or otherwise trust.
    """
    with open(filepath, mode = 'rb') as handle:
        obj = pickle.load(handle)
    return obj


In [3]:
# Paths are relative to this notebook's directory.
VOCAB_PATH = '../data/bert-word-piece-custom-wikitext-vocab-10k-vocab.txt'
QUESTIONS_PATH = '../data/tokenized_questions_classes_subclasses_dict.pkl'

tokenizer = BertWordPieceTokenizer(
    VOCAB_PATH,
    lowercase = True,
    strip_accents = True,
)

# Pre-tokenized questions with their class / subclass labels (unpickled, so
# the file must be trusted).
data = load_pickle(QUESTIONS_PATH)



In [4]:
from AttentionTransformer.Encoder import * 
from AttentionTransformer.Decoder import *

In [14]:
class OneEncoderTwoDecoderTransformer(nn.Module):
    """A shared Transformer encoder feeding two single-step decoders.

    ``decoder1`` is fed a constant one-token target (``CLS_label_id``) and its
    output goes through ``decoder1heads`` (``num_classes`` logits).
    ``decoder2`` is fed the ``classlabels`` passed to ``forward`` as a
    one-token target and its output goes through ``decoder2heads``
    (``num_subclasses`` logits) — presumably class / subclass prediction.

    NOTE(review): decoder2 consumes the labels given to ``forward`` directly;
    at inference time these must come from somewhere (e.g. decoder1's
    prediction) — confirm the intended usage.
    """

    def __init__(
        self, vocab_size, pad_id, CLS_label_id, emb_dim = 512, dim_model = 512, dim_inner = 2048,
        layers = 6, heads = 8, dim_key = 64, dim_value = 64, dropout = 0.1, num_pos = 200,
        num_classes = 6, num_subclasses = 47
    ):
        """Build the encoder, both decoders and both linear heads.

        Args:
            vocab_size: vocabulary size passed to the Encoder/Decoder stacks.
            pad_id: token id treated as padding in ``source_seq``.
            CLS_label_id: id used as the single input token of decoder1.
            emb_dim, dim_model, dim_inner, layers, heads, dim_key, dim_value,
            dropout, num_pos: forwarded unchanged to every Encoder/Decoder.
            num_classes: output size of ``decoder1heads`` (default 6).
            num_subclasses: output size of ``decoder2heads`` (default 47).
        """
        super().__init__()

        # Fail fast, before allocating three full Transformer stacks.
        assert dim_model == emb_dim, 'Dimensions of all the module objects must be same'

        self.pad_id = pad_id
        self.cls_label_id = CLS_label_id

        # All three stacks share one configuration; construction order
        # (encoder, decoder1, decoder2) is kept as before.
        stack_args = (vocab_size, emb_dim, layers, heads, dim_key, dim_value, dim_model, dim_inner, pad_id)
        stack_kwargs = dict(dropout = dropout, num_pos = num_pos)
        self.encoder = Encoder(*stack_args, **stack_kwargs)
        self.decoder1 = Decoder(*stack_args, **stack_kwargs)
        self.decoder2 = Decoder(*stack_args, **stack_kwargs)

        self.decoder1heads = nn.Linear(dim_model, num_classes)
        self.decoder2heads = nn.Linear(dim_model, num_subclasses)

        # Xavier-initialise every weight matrix / embedding table (dim > 1);
        # 1-D parameters such as biases keep their default initialisation.
        for parameter in self.parameters():
            if parameter.dim() > 1:
                nn.init.xavier_uniform_(parameter)

    def get_pad_mask(self, sequence, pad_id):
        """Return a bool mask, True where ``sequence`` is not ``pad_id``.

        For a 2-D ``(batch, seq_len)`` input the result is ``(batch, 1, seq_len)``
        so it broadcasts over the query dimension.
        """
        return (sequence != pad_id).unsqueeze(-2)

    def get_subsequent_mask(self, sequence):
        """Return a ``(1, seq_len, seq_len)`` lower-triangular (causal) bool mask.

        ``sequence`` must be 2-D ``(batch, seq_len)``; only its length and
        device are used.
        """
        _, seq_length = sequence.size()

        subsequent_mask = (
            1 - torch.triu(
                torch.ones((1, seq_length, seq_length), device = sequence.device), diagonal = 1
            )
        ).bool()

        return subsequent_mask

    def get_target_mask(self, target):
        """Causal self-attention mask for an internally built decoder target.

        Decoder targets here are never padded, so no pad mask is applied: the
        CLS id or a class label may equal ``pad_id`` (the notebook uses
        ``pad_id=2`` and ``CLS_label_id=2``). Pad-masking would then hide the
        only target token; if the attention fills masked scores with -inf that
        yields NaN. Returns a ``(batch, seq_len, seq_len)`` bool mask.
        """
        return self.get_subsequent_mask(target).expand(target.size(0), -1, -1)

    def make_target_seq(self, batch_size, device = None):
        """Return a float32 ``(batch_size, 1)`` tensor filled with the CLS label id."""
        return torch.full(
            (batch_size, 1), self.cls_label_id, dtype = torch.float32, device = device
        )

    def get_decoder2_target(self, labels):
        """Turn class labels (presumably shape ``(batch,)``) into a float ``(batch, 1)`` target."""
        return labels.float().unsqueeze(1)

    def forward(self, source_seq, classlabels):
        """Encode ``source_seq`` once and run both decoders on it.

        Args:
            source_seq: 2-D ``(batch, seq_len)`` token ids; ``pad_id`` marks padding.
            classlabels: per-example labels used as decoder2's one-token input.

        Returns:
            ``(classheads1, classheads2)``: logits of shape
            ``(batch, num_classes)`` and ``(batch, num_subclasses)``.
        """
        batch_size = source_seq.size(0)
        device = source_seq.device
        source_mask = self.get_pad_mask(source_seq, self.pad_id)

        targetdec1 = self.make_target_seq(batch_size, device = device)
        targetdec1_mask = self.get_target_mask(targetdec1)

        targetdec2 = self.get_decoder2_target(classlabels).to(device)
        targetdec2_mask = self.get_target_mask(targetdec2)

        encoder_output = self.encoder(source_seq, source_mask)
        decoder_output_1 = self.decoder1(
            targetdec1, targetdec1_mask, encoder_output, source_mask
        )
        decoder_output_2 = self.decoder2(
            targetdec2, targetdec2_mask, encoder_output, source_mask
        )

        # Targets have length 1, so flattening everything after the batch axis
        # leaves dim_model features per example — what the Linear heads expect
        # (assuming decoder output is (batch, target_len, dim_model)).
        classheads1 = self.decoder1heads(decoder_output_1.flatten(start_dim = 1))
        classheads2 = self.decoder2heads(decoder_output_2.flatten(start_dim = 1))

        return classheads1, classheads2


In [9]:
# Second positional argument of ClassificationDatasetDict — presumably the
# maximum (padded) sequence length; confirm against ClassificationDatasetFromDict.
MAX_LEN = 100
BATCH_SIZE = 4
NUM_WORKERS = 3

dataset = ClassificationDatasetDict(data, MAX_LEN)
dataloader = DataLoader(
    dataset,
    batch_size = BATCH_SIZE,
    shuffle = True,
    num_workers = NUM_WORKERS,
    pin_memory = True,
)


In [10]:
# Pull one batch for a quick smoke test of the model, then drop the iterator
# so its worker processes are not kept alive.
batch_iterator = iter(dataloader)
d = next(batch_iterator)
del batch_iterator

In [15]:
# NOTE(review): pad_id and CLS_label_id are both 2. With a BERT word-piece vocab
# trained with the usual special-token order ([PAD], [UNK], [CLS], ...) id 2 is
# typically [CLS], not [PAD] — confirm against the padding value produced by
# ClassificationDatasetDict (e.g. compare with tokenizer.token_to_id('[PAD]')).
model_kwargs = dict(
    vocab_size = tokenizer.get_vocab_size(),
    pad_id = 2,
    CLS_label_id = 2,
)
model = OneEncoderTwoDecoderTransformer(**model_kwargs)

In [17]:
# Model size, in millions of parameters.
num_params_millions = count_model_parameters(model) / 1e6
num_params_millions

84.691509

In [18]:
# model(d['source_seq'], d['class'])

(tensor([[ 0.0669, -1.5990,  0.7057,  0.4327,  1.2237,  2.1680],
         [-0.3351, -0.5885,  2.0430,  3.5518, -0.6105,  0.4860],
         [-1.4577, -0.8539,  2.8529,  0.5232,  0.8898, -1.9092],
         [-1.1834, -1.2539,  2.2092,  1.5802, -1.4747,  2.8295]],
        grad_fn=<AddmmBackward>),
 tensor([[ 0.0231, -0.9397, -1.3733,  2.1185,  0.1620,  0.3008, -1.9727,  1.0801,
          -1.2780,  0.0189,  0.2631,  0.6690, -0.4242, -0.0549,  4.2453,  0.3284,
           1.2134, -1.4837,  0.9786, -2.7903, -1.1144,  0.6317, -2.3783, -0.3723,
           1.2815,  0.9917,  1.5092,  2.2412, -0.3041, -1.2079,  1.2051,  0.8503,
          -1.4927,  0.9340,  1.8753,  3.8063, -1.1820,  0.8290, -2.3491,  2.4052,
          -2.0367,  0.1661,  0.7350,  1.4218,  0.5079, -3.6047, -1.2469],
         [ 2.2929, -1.6826, -2.1070,  0.1964,  1.6595,  0.0487, -2.3043, -2.3063,
          -1.1869, -1.4780,  0.2604, -0.3642, -0.7189,  1.3293,  4.5082,  1.4669,
           0.9167, -1.5544,  1.1655, -0.7214, -0.6193,  1