In [50]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
torch.random.manual_seed(13)
import sys
sys.path.append('..')
from ELMO import ELMo
import wandb
import torch
import json

In [51]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [52]:
# Load the Marathi wiki-cloze JSON; later cells read its 'cloze_data' list.
with open('wiki-cloze/mr.json', encoding='utf-8') as json_file:
    data = json.load(json_file)

In [53]:
from indicnlp.tokenize import indic_tokenize
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from tqdm import tqdm
import re
from typing import List, Tuple, Optional
from torch.nn.utils.rnn import pad_sequence
torch.manual_seed(13)
from collections import Counter
from torchtext.vocab import build_vocab_from_iterator
import sys
sys.path.append('..')
from preprocessing import tokenize, convert_to_oov, CharLevelVocab, WordLevelVocab


# Special-token strings. QuestionAnsweringDataset.collate_fn spells START_TAG, END_TAG and
# PAD_TAG out character by character through the char vocabulary, and also passes PAD_TAG
# to char_to_index as the fill value for unused character slots. OUT_OF_VOCAB is not
# referenced in the cells shown here.
OUT_OF_VOCAB = '<OOV>'
PAD_TAG = '<PAD>'
START_TAG = '<BOS>'
END_TAG = '<EOS>'



class QuestionAnsweringDataset(Dataset):
    """Cloze-style multiple-choice dataset, encoded at the character level for ELMo.

    ``__getitem__`` returns, for one question:
      * a list of 1-D long tensors of character indices, one per question token, with any
        token equal to ``'MASK'`` replaced by ``'<>'``;
      * a list of such tensors, one per candidate answer (each answer string is encoded
        whole, without tokenization);
      * a float tensor with 1.0 for every candidate equal to the correct answer, else 0.0.

    ``collate_fn`` builds one long tensor of shape
    ``(batch, max_seq_length + 2 * n_answers + 1, max_word_length)`` laid out as
    ``<BOS> question-tokens <PAD>... (<> answer) * n_answers <EOS>``.

    Args:
        question: raw questions, passed through ``tokenize`` from ``preprocessing``
            (presumably yielding one token list per question — TODO confirm).
        answers: per-question list of candidate answer strings.
        correct_answer: per-question correct answer string.
        word_vocab: stored on the instance; not used by this class.
        char_vocab: provides ``char_to_index`` for single characters.
        max_seq_length: word slots for ``<BOS>`` + question (truncated / padded).
        max_word_length: character slots per word (truncated / padded).
    """

    def __init__(self, question, answers, correct_answer, word_vocab: WordLevelVocab, char_vocab: CharLevelVocab, max_seq_length=50, max_word_length=10):
        self.question = tokenize(question)
        self.answers = answers
        self.correct_answer = correct_answer
        self.word_vocab = word_vocab
        self.char_vocab = char_vocab
        self.max_seq_length = max_seq_length
        self.max_word_length = max_word_length

    def __len__(self):
        return len(self.question)

    def _encode_word(self, word):
        """Return a 1-D long tensor holding the char index of every character of ``word``."""
        return torch.tensor([self.char_vocab.char_to_index(char) for char in word], dtype=torch.long)

    def _pad_word(self, word, char_pad):
        """Truncate / right-pad a char-index tensor to ``max_word_length`` using ``char_pad``."""
        word = word[:self.max_word_length]
        padding = torch.tensor([char_pad] * (self.max_word_length - len(word)), dtype=torch.long)
        return torch.cat([word, padding])

    def __getitem__(self, idx):
        # The mask token is renamed to '<>', the same marker collate_fn puts before each answer.
        question = ['<>' if word == 'MASK' else word for word in self.question[idx]]
        answers = self.answers[idx]
        correct = self.correct_answer[idx]
        # One target slot per candidate, so the target length always equals the number of
        # answers that collate_fn appends to the sequence (previously hard-coded to 4).
        onehot = [1 if answer == correct else 0 for answer in answers]

        return [self._encode_word(word) for word in question], \
               [self._encode_word(word) for word in answers], \
               torch.tensor(onehot, dtype=torch.float32)

    def collate_fn(self, batch):
        """Merge dataset items into ``(char_ids, onehot)`` batch tensors.

        Returns:
            char_ids: long tensor ``(batch, max_seq_length + 2 * n_answers + 1, max_word_length)``.
            onehot: float tensor ``(batch, n_answers)`` stacked from the item targets.
        """
        questions, answers, onehot = zip(*batch)
        # Special tokens are spelled out character by character, like ordinary words.
        bos_token = self._encode_word(START_TAG)
        eos_token = self._encode_word(END_TAG)
        pad_token = self._encode_word(PAD_TAG)
        middle_token = self._encode_word("<>")
        # Fill value for unused character slots inside a word (looked up once per batch).
        char_pad = self.char_vocab.char_to_index(PAD_TAG)

        sequences = []
        for question, candidates in zip(questions, answers):
            # <BOS> + question, truncated / padded to exactly max_seq_length word slots.
            words = ([bos_token] + list(question))[:self.max_seq_length]
            words += [pad_token] * (self.max_seq_length - len(words))
            # "<> answer" for every candidate, then <EOS>. QuestionAnswering reads the
            # candidates back from these trailing positions.
            for candidate in candidates:
                words += [middle_token, candidate]
            words.append(eos_token)
            sequences.append(torch.stack([self._pad_word(word, char_pad) for word in words]))

        return torch.stack(sequences), torch.stack(onehot)



In [54]:
# One entry per cloze item: the question text, its candidate options, and the correct answer.
cloze_items = data['cloze_data']
questions = [item['question'] for item in cloze_items]
options = [item['options'] for item in cloze_items]
answers = [item['answer'] for item in cloze_items]

In [55]:
# Load the word- and character-level vocabularies stored next to the ELMo checkpoint.
# NOTE(review): torch.load unpickles arbitrary Python objects — only load trusted files.
word_vocab = torch.load('../ELMo/word_vocab_marathi.pt')
char_vocab = torch.load('../ELMo/char_vocab_marathi.pt')

In [56]:
# max_word_length presumably has to match the ELMo cnn_config 'max_word_length' (10) — TODO confirm.
dataset = QuestionAnsweringDataset(questions, options, answers, word_vocab, char_vocab, max_seq_length=30, max_word_length=10)
# 80/20 train/validation split. A dedicated, seeded generator makes the split identical every
# time this cell runs. It no longer depends on how much of the global RNG stream earlier cells
# consumed, so re-running the cell cannot leak validation items into a partly trained model.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(13)
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size], generator=split_generator)

In [57]:
# Both loaders share the batch size and the dataset's collate_fn; only training data is shuffled.
shared_loader_kwargs = dict(batch_size=32, collate_fn=dataset.collate_fn)
train_dataloader = DataLoader(train_dataset, shuffle=True, **shared_loader_kwargs)
val_dataloader = DataLoader(val_dataset, shuffle=False, **shared_loader_kwargs)

In [58]:
class QuestionAnswering(nn.Module):
    def __init__(self, elmo, embedding_dim, num_classes):
        super(QuestionAnswering, self).__init__()
        self.elmo = elmo
        self.fc = nn.Linear(embedding_dim, embedding_dim//2)
        self.fc2 = nn.Linear(embedding_dim//2, embedding_dim//4)
        self.fc3 = nn.Linear(embedding_dim//4, num_classes)
        self.relu = nn.ReLU()
        self.lambdas = nn.Parameter(torch.rand(3))

        for param in self.elmo.parameters():
            param.requires_grad = False

    def forward(self, questions):
        _, _, questions = self.elmo(questions)
      
        encoding = torch.zeros_like(questions[0])

        for i in range(3):
            encoding += self.lambdas[i] * questions[i]
        selected_encodings = [encoding[:,-8,:], encoding[:,-6,:], encoding[:,-4,:], encoding[:,-2,:]]
        # stack
        selected_encodings = torch.stack(selected_encodings, dim=1)
        # take mean of the embeddings
        selected_encodings = torch.mean(selected_encodings, dim=1)
        # take mean of second last, 4th last, 6th last and 8th last layer
        x = self.fc(selected_encodings)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

In [59]:
# Character-CNN settings. 'max_word_length' presumably has to equal the dataset's
# max_word_length (10) so the character tensors line up — TODO confirm.
elmo_cnn_config = {'character_embedding_size': 16,
                   'num_filters': 32,
                   'kernel_size': 5,
                   'max_word_length': 10,
                   'char_vocab_size': char_vocab.num_chars}
# Language-model settings: 3 layers, word_embedding_dim 150, vocabulary from word_vocab.
elmo_lm_config = {'num_layers': 3,
                  'word_embedding_dim': 150,
                  'vocab_size': word_vocab.num_words}
elmo = ELMo(cnn_config=elmo_cnn_config,
            elmo_config=elmo_lm_config,
            char_vocab_size=char_vocab.num_chars).to(device)

In [60]:
# map_location lets a checkpoint saved from a GPU also load on the CPU fallback picked for
# `device`. NOTE(review): torch.load unpickles data — only load trusted checkpoints.
elmo.load_state_dict(torch.load('../ELMo/elmo_epoch_3.pt', map_location=device))

<All keys matched successfully>

In [61]:
# embedding_dim=300 is presumably 2 x the ELMo word_embedding_dim (150) for a bidirectional
# encoder — TODO confirm. num_classes=4 matches the four answer options per question.
question_answering = QuestionAnswering(elmo, embedding_dim=300, num_classes=4).to(device)

In [62]:
def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over ``loader``: train when ``optimizer`` is given, otherwise evaluate.

    Batches are ``(questions, target)`` pairs with a one-hot ``target``; both are moved to
    the notebook-level ``device``.

    Returns:
        ``(total_loss, accuracy)``. ``total_loss`` is the sum of per-batch losses.
        ``accuracy`` is the fraction of examples whose output argmax equals the target
        argmax, or ``float('nan')`` if the loader yields no examples.
    """
    is_training = optimizer is not None
    model.train(is_training)
    total_loss = 0.0
    correct = 0
    seen = 0
    # Autograd bookkeeping is only needed when we are going to backpropagate.
    with torch.set_grad_enabled(is_training):
        for questions, target in tqdm(loader):
            questions = questions.to(device)
            target = target.to(device)
            output = model(questions)
            loss = criterion(output, target)
            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            correct += (torch.argmax(output, dim=1) == torch.argmax(target, dim=1)).sum().item()
            seen += target.size(0)
    accuracy = correct / seen if seen else float('nan')
    return total_loss, accuracy


def train_question(model, train_loader, val_loader, optimizer, criterion, epochs):
    """Train ``model`` for ``epochs`` epochs, evaluating on ``val_loader`` after each one.

    Moves ``model`` to the notebook-level ``device``. Each epoch prints the summed training
    loss and accuracy, then the summed validation loss and accuracy.
    """
    model.to(device)
    for epoch in range(epochs):
        train_loss, train_accuracy = _run_epoch(model, train_loader, criterion, optimizer)
        print('Epoch:', epoch, 'Loss:', train_loss)
        print('Accuracy:', train_accuracy)

        val_loss, val_accuracy = _run_epoch(model, val_loader, criterion)
        print('Val Loss:', val_loss)
        print('Val Accuracy:', val_accuracy)
        

In [63]:
# Only the classifier head and the layer-mixing weights are trainable (ELMo is frozen inside
# QuestionAnswering), so only those parameters are handed to the optimizer.
trainable_params = [param for param in question_answering.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=0.001)
# Targets are one-hot float vectors, which CrossEntropyLoss accepts as class probabilities.
criterion = nn.CrossEntropyLoss()

In [64]:
# Train for 10 epochs; loss and accuracy for train and validation are printed every epoch.
train_question(question_answering, train_dataloader, val_dataloader, optimizer, criterion, epochs=10)

100%|██████████| 285/285 [00:28<00:00,  9.99it/s]


Epoch: 0 Loss: 397.0885375738144
Accuracy: 0.24472295514511874


100%|██████████| 72/72 [00:06<00:00, 10.39it/s]


Val Loss: 99.90960466861725
Val Accuracy: 0.24626209322779244


100%|██████████| 285/285 [00:28<00:00, 10.02it/s]


Epoch: 1 Loss: 395.5011662244797
Accuracy: 0.25395778364116095


100%|██████████| 72/72 [00:06<00:00, 10.29it/s]


Val Loss: 99.89568102359772
Val Accuracy: 0.2519788918205805


100%|██████████| 285/285 [00:28<00:00, 10.05it/s]


Epoch: 2 Loss: 395.18937361240387
Accuracy: 0.2615435356200528


100%|██████████| 72/72 [00:07<00:00, 10.23it/s]


Val Loss: 100.15773093700409
Val Accuracy: 0.23702726473175023


100%|██████████| 285/285 [00:28<00:00,  9.97it/s]


Epoch: 3 Loss: 394.9826751947403
Accuracy: 0.26824978012313105


100%|██████████| 72/72 [00:07<00:00, 10.20it/s]


Val Loss: 100.15262377262115
Val Accuracy: 0.23878627968337732


100%|██████████| 285/285 [00:28<00:00, 10.07it/s]


Epoch: 4 Loss: 394.55707013607025
Accuracy: 0.27341688654353563


100%|██████████| 72/72 [00:06<00:00, 10.31it/s]


Val Loss: 100.10006392002106
Val Accuracy: 0.23922603342128407


100%|██████████| 285/285 [00:28<00:00, 10.06it/s]


Epoch: 5 Loss: 393.7376916408539
Accuracy: 0.2773746701846966


100%|██████████| 72/72 [00:07<00:00, 10.15it/s]


Val Loss: 100.96801340579987
Val Accuracy: 0.23746701846965698


100%|██████████| 285/285 [00:28<00:00, 10.02it/s]


Epoch: 6 Loss: 393.3480746746063
Accuracy: 0.28924802110817943


100%|██████████| 72/72 [00:07<00:00, 10.07it/s]


Val Loss: 100.24395418167114
Val Accuracy: 0.2704485488126649


100%|██████████| 285/285 [00:28<00:00, 10.07it/s]


Epoch: 7 Loss: 392.3604700565338
Accuracy: 0.28770888302550574


100%|██████████| 72/72 [00:06<00:00, 10.39it/s]


Val Loss: 100.40632259845734
Val Accuracy: 0.2546174142480211


100%|██████████| 285/285 [00:28<00:00, 10.01it/s]


Epoch: 8 Loss: 390.7692667245865
Accuracy: 0.3005716798592788


100%|██████████| 72/72 [00:07<00:00, 10.12it/s]


Val Loss: 100.5695880651474
Val Accuracy: 0.2598944591029024


100%|██████████| 285/285 [00:28<00:00, 10.04it/s]


Epoch: 9 Loss: 389.617115855217
Accuracy: 0.3016710642040457


100%|██████████| 72/72 [00:07<00:00, 10.25it/s]

Val Loss: 101.78613924980164
Val Accuracy: 0.2546174142480211





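Conclusion: this cloze-style question-answering task is not a useful evaluation here. The questions in the dataset are highly specific (most concern particular places), so the only way the model improves is by overfitting. Training accuracy climbs from about 24% to 30% over 10 epochs. Meanwhile, validation accuracy stays close to the 25% chance level for four options, and validation loss does not decrease.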