In [1]:
import torch
import torch.nn as nn
from preprocessing import CharLevelDataset
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from ELMO import ELMo

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
path = '/scratch/sanika/cleaned_marathi.txt'

In [22]:
# read from characters_marathi.txt
with open('characters_marathi.txt', 'r') as f:
    characters = f.read().splitlines()

# add comma, space, . , newline, tab
characters += [' ', ',', '.', '\n', '\t', '-', '?', '!', ':', ';', '(', ')', '।', '॥', '“', '”']

In [None]:
# read from 'scratch/sanika/content/sample_data/fulldataset_dedup_final.txt
with open('/scratch/sanika/content/sample_data/fulldataset_dedup_final.txt', 'r', encoding='utf-8') as f:
    # split at newline
    new_lines = []
    for i, line in enumerate(f):
        # only add if all characters part of characters
        if all([c in characters for c in line]):
            new_lines.append(line)
        else:
            print(line)

In [29]:
new_lines[0]

'तो घरी आला.\n'

In [4]:
dataset = CharLevelDataset(path, 0, 1085998)
# dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=dataset.collate_fn)

In [6]:
# save the vocabulary
vocab = dataset.word_vocab
torch.save(vocab, 'word_vocab.pt')
torch.save(dataset.character_vocab, 'char_vocab.pt')
print(dataset.word_vocab.num_words)
print(dataset.character_vocab.num_chars)

50807
80


In [18]:
def train(model, len_vocab, word_vocab, character_vocabulary):
    current = 0
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(5):
        model.train()
        total_loss = 0
        total_items = 0
        correct = 0

        for i in range(10):
            batch_loss = 0
            total_items_batch = 0
            correct_batch = 0
            dataset = CharLevelDataset(path=path, start=current, end=current+100000, character_vocab=character_vocabulary, word_vocab=word_vocab)
            dataloader = DataLoader(dataset, batch_size=128, shuffle=True, collate_fn=dataset.collate_fn)
            current += 100000
            for batch in tqdm(dataloader):
                sentences, targets = batch
                optimizer.zero_grad()
                sentences = sentences.to(device)
                targets = targets.to(device)
                forward_output, backward_output, final_embeddings = model(sentences)
                # now, to calculate loss we need to calculate the probability of each word in the vocabulary for forward and backward separately
                # calculate the loss for the forward part
                # one hot encode the targets (shifted by 1)
                targets_forward = targets[:, 1:]
                targets_forward = torch.nn.functional.one_hot(targets_forward, num_classes=len_vocab)
                targets_forward = targets_forward.float()
                # shift the target by 1 to the other side for backward
                targets_backward = targets[:, :-1] 
                targets_backward = torch.nn.functional.one_hot(targets_backward, num_classes=len_vocab)
                targets_backward = targets_backward.float()
                # calculate the probabilities
                forward_output = forward_output[:, :-1, :]
                backward_output = backward_output[:, 1:, :]
                # calculate the loss per word
                loss = criterion(forward_output, targets_forward) + criterion(backward_output, targets_backward)
                loss.backward()
                optimizer.step()
                batch_loss += loss.item()
                total_loss += loss.item()
                

                # calculate accuracy, prediction is correct if it is within the top k predictions
                # return top k predictions for k = 20
                _, topk = torch.topk(forward_output, 20)
                targets_forward = torch.argmax(targets_forward, dim=2)
                targets_forward = targets_forward.unsqueeze(2)
                total_items_batch += targets_forward.shape[0]*targets_forward.shape[1]
                # check if target is in top k
                correct_predictions = targets_forward == topk
                # correct += targets_forward == topk
                correct_batch += torch.sum(correct_predictions)
                correct += torch.sum(correct_predictions)
                total_items += targets_forward.shape[0]*targets_forward.shape[1]
                
                _, topk = torch.topk(backward_output, 20)
                targets_backward = torch.argmax(targets_backward, dim=2)
                targets_backward = targets_backward.unsqueeze(2)
                total_items_batch += targets_backward.shape[0]*targets_backward.shape[1]
                # check if target is in top k
                correct_predictions = targets_backward == topk
                # correct += targets_backward == topk
                correct_batch += torch.sum(correct_predictions)
                correct += torch.sum(correct_predictions)
                total_items += targets_backward.shape[0]*targets_forward.shape[1]

            print(f'Batch Loss {i}: {batch_loss}')
            print(f'Batch Accuracy {i}: {correct_batch/total_items_batch}')

        # save model
        torch.save(model.state_dict(), f'elmo_epoch_{epoch}.pt')
        print(f'Epoch {epoch} Loss: {total_loss} Accuracy: {correct/total_items}')

In [8]:
print(dataset.word_vocab.num_words)

50807


In [19]:
model = ELMo(cnn_config = {'character_embedding_size': 16, 
                           'num_filters': 32, 
                           'kernel_size': 5, 
                           'max_word_length': 10, 
                           'char_vocab_size': dataset.character_vocab.num_chars}, 
             elmo_config = {'num_layers': 3,
                            'word_embedding_dim': 150,
                            'vocab_size': dataset.word_vocab.num_words}, 
             char_vocab_size = dataset.character_vocab.num_chars).to(device)


In [20]:
train(model, dataset.word_vocab.num_words, dataset.word_vocab, dataset.character_vocab)

100%|██████████| 782/782 [02:21<00:00,  5.51it/s]


Batch Loss 0: 1.2178260030923411
Batch Accuracy 0: 0.5567686557769775


100%|██████████| 782/782 [02:20<00:00,  5.55it/s]


Batch Loss 1: 1.1368769069667906
Batch Accuracy 1: 0.749113917350769


100%|██████████| 782/782 [02:20<00:00,  5.57it/s]


Batch Loss 2: 1.1107494086027145
Batch Accuracy 2: 0.7631421089172363


 99%|█████████▊| 771/782 [02:16<00:01,  5.63it/s]


KeyboardInterrupt: 

In [None]:
# save model as .pt
torch.save(model.state_dict(), 'model_elmo_marathi.pt')

In [None]:
# save character and word vocab as .pt
torch.save(dataset.character_vocab, 'character_vocab_marathi.pt')
torch.save(dataset.word_vocab, 'vocab_marathi.pt')
print(dataset.character_vocab.num_chars)

In [21]:
# print model layers
print(model)

ELMo(
  (char_cnn): CharCNN(
    (char_embedding): Embedding(126, 16)
    (conv_layers): ModuleList(
      (0): Conv1d(16, 32, kernel_size=(5,), stride=(1,))
      (1): Conv1d(16, 32, kernel_size=(5,), stride=(1,))
      (2): Conv1d(16, 32, kernel_size=(5,), stride=(1,))
      (3): Conv1d(16, 32, kernel_size=(5,), stride=(1,))
      (4): Conv1d(16, 32, kernel_size=(5,), stride=(1,))
      (5): Conv1d(16, 32, kernel_size=(5,), stride=(1,))
    )
    (fc): Linear(in_features=192, out_features=300, bias=True)
  )
  (forward_lstm): LSTM(300, 150)
  (backward_lstm): LSTM(300, 150)
  (forward_lstms): ModuleList(
    (0): LSTM(300, 150)
    (1): LSTM(300, 150)
    (2): LSTM(300, 150)
  )
  (backward_lstms): ModuleList(
    (0): LSTM(300, 150)
    (1): LSTM(300, 150)
    (2): LSTM(300, 150)
  )
  (fc): Linear(in_features=300, out_features=36377, bias=True)
)


: 