In [1]:
import torch
import torch.nn as nn
from preprocessing import NextWordDataset
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
path = '/scratch/sanika/content/sample_data/fulldataset_dedup_final.txt'

In [4]:
# make the character embedding and convolutional layer with max pooling
class CharCNN(nn.Module):
    def __init__(self, character_embedding_size, num_filters, kernel_size, max_word_length, char_vocab_size, word_embedding_dim, device=None):
        super(CharCNN, self).__init__()
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.char_embedding = nn.Embedding(char_vocab_size, character_embedding_size).to(device)
        self.conv_layers = nn.ModuleList([nn.Conv1d(character_embedding_size, num_filters, kernel_size).to(device) for _ in range(max_word_length - kernel_size + 1)])
        self.fc = nn.Linear(num_filters * (max_word_length - kernel_size + 1), word_embedding_dim).to(device)
        self.device = device

    def forward(self, x):
        # x is a batch of words. Each word is a list of characters (batch_size, max_word_length)
        # first, we convert the characters to embeddings
        x = x.to(self.device)
        x = self.char_embedding(x) # (batch_size, max_word_length, character_embedding_size)
        # print(x.shape)
        x = x.permute(0, 2, 1) # (batch_size, character_embedding_size, max_word_length)

        # now we run the convolutional layers
        x = [conv(x) for conv in self.conv_layers]
        
        # now we max pool
        x = [torch.max(torch.relu(conv), dim=2)[0] for conv in x]

        # now we concatenate the results
        x = torch.cat(x, dim=1) # (batch_size, num_filters * (max_word_length - kernel_size + 1))
        
        # finally, we run the fully connected layer
        x = self.fc(x)
        return x


# ELMo part
class ELMo(nn.Module):
    def __init__(self, cnn_config, elmo_config, char_vocab_size):
        # input to this is a batch of sentences. Each sentence is a list of words. Each word is a list of characters.
        super(ELMo, self).__init__()
        # first, we convert the token to a representation using character embeddings
        self.char_cnn = CharCNN(cnn_config['character_embedding_size'], 
                                cnn_config['num_filters'], 
                                cnn_config['kernel_size'], 
                                cnn_config['max_word_length'], 
                                cnn_config['char_vocab_size'],
                                elmo_config['word_embedding_dim'],
                                device = device).to(device)
        self.forward_lstm = nn.LSTM(elmo_config['word_embedding_dim'], int(elmo_config['word_embedding_dim']/2), 
                                    1, bidirectional = False).to(device)
        self.backward_lstm = nn.LSTM(elmo_config['word_embedding_dim'], int(elmo_config['word_embedding_dim']/2),
                                    1, bidirectional = False).to(device)
        # based on the number of layers as passed in the argument, sequentially have that many layers
        self.forward_lstms = nn.ModuleList([self.forward_lstm for _ in range(elmo_config['num_layers'])])
        self.backward_lstms = nn.ModuleList([self.backward_lstm for _ in range(elmo_config['num_layers'])])
        self.num_layers = elmo_config['num_layers']
        self.fc = nn.Linear(elmo_config['word_embedding_dim'], elmo_config['vocab_size'])
        
    def forward(self, x):
        # character cnn
        # convert x to tensor
        x = torch.stack(x, dim=0)
        x = x.permute(1, 0, 2)
        x = [self.char_cnn(word) for word in x]
        # lstm1
        x = torch.stack(x, dim=1) 
        x = x.permute(1, 0, 2) 
        
        lstm_output = x
        for i in range(self.num_layers):
            forward_lstm_output, _ = self.forward_lstms[i](lstm_output)
            backward_lstm_output, _ = self.backward_lstms[i](torch.flip(lstm_output, [1]))
            backward_lstm_output = torch.flip(backward_lstm_output, [1])
            lstm_output = torch.cat((forward_lstm_output, backward_lstm_output), dim = 2)
        
        x = torch.mean(lstm_output, dim = 1)
        x = self.fc(x)
        return x, lstm_output
     

In [5]:
dataset = NextWordDataset(path, 0, 10000)
words, targets, vocab, character_vocab = dataset.format(dataset.sentences, 5, True)

In [6]:
from torch.utils.data import DataLoader

In [7]:
def create_dataloader(words, targets, batch_size):
    zipped = list(zip(words, targets))
    dataloader = {
        'train': DataLoader(zipped, batch_size = batch_size, shuffle = False),
    }
    return dataloader

In [8]:
dataloader = create_dataloader(words, targets, 128)

In [26]:
def train(model, len_vocab, word_vocab, character_vocabulary, dataloader):
    current = 0
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(10):
        model.train()
        total_loss = 0
        total_items = 0
        correct = 0

        for i in range(1):
            # dataset = NextWordDataset(path=path, start=current, end=current+100000, vocabulary=word_vocab, character_vocab=character_vocabulary)
            # words, targets = dataset.format(dataset.sentences, 5)
            # current += 100000
            # dataloader = create_dataloader(words, targets, 64)
            pbar = tqdm(enumerate(dataloader['train']), total=len(dataloader['train']))
            # reset pbar

            for j, (word, target) in enumerate(dataloader['train']):
                word = [w.to(device) for w in word]
                target = target.to(device)
                optimizer.zero_grad()
                output, _ = model(word)
                # one hot encode the target
                target = torch.nn.functional.one_hot(target, num_classes = len_vocab).float()
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                total_items += 1
                # calculate accuracy, prediction is correct if it is within the top k predictions
                # return top k predictions for k = 20
                _, topk = torch.topk(output, 20)
                # add 1 to correct if target is in top k
                target = torch.argmax(target, dim=1)
                correct += torch.sum(target.view(-1, 1) == topk)
                pbar.update(1)
            pbar.close()

        print(f'Epoch {epoch} Loss: {total_loss/total_items} Accuracy: {correct/total_items}')

In [24]:
model = ELMo(cnn_config = {'character_embedding_size': 16, 
                           'num_filters': 32, 
                           'kernel_size': 5, 
                           'max_word_length': 10, 
                           'char_vocab_size': len(character_vocab)}, 
             elmo_config = {'num_layers': 3,
                            'word_embedding_dim': 300,
                            'vocab_size': len(vocab)}, 
             char_vocab_size = len(character_vocab)).to(device)


In [27]:
train(model, len(vocab), vocab, character_vocab, dataloader)

  0%|          | 0/731 [00:00<?, ?it/s]

100%|██████████| 731/731 [04:17<00:00,  2.84it/s]


Epoch 0 Loss: 8.429014948259137 Accuracy: 18.803010940551758


100%|██████████| 731/731 [04:14<00:00,  2.87it/s]


Epoch 1 Loss: 7.934716470916685 Accuracy: 19.26812744140625


100%|██████████| 731/731 [04:14<00:00,  2.88it/s]


Epoch 2 Loss: 7.899834498824238 Accuracy: 19.051984786987305


100%|██████████| 731/731 [04:20<00:00,  2.80it/s]


Epoch 3 Loss: 7.905022729111762 Accuracy: 19.19562339782715


100%|██████████| 731/731 [04:20<00:00,  2.81it/s]


Epoch 4 Loss: 7.961161589165882 Accuracy: 19.391244888305664


100%|██████████| 731/731 [04:20<00:00,  2.80it/s]


Epoch 5 Loss: 8.026623571294113 Accuracy: 19.21887969970703


100%|██████████| 731/731 [04:21<00:00,  2.79it/s]


Epoch 6 Loss: 7.981884674888955 Accuracy: 19.3844051361084


100%|██████████| 731/731 [04:19<00:00,  2.82it/s]


Epoch 7 Loss: 8.013273296049617 Accuracy: 19.436389923095703


100%|██████████| 731/731 [04:19<00:00,  2.81it/s]


Epoch 8 Loss: 7.970953628400445 Accuracy: 19.482900619506836


100%|██████████| 731/731 [04:13<00:00,  2.88it/s]

Epoch 9 Loss: 7.966687541653781 Accuracy: 19.471956253051758





In [29]:
# save model as .pt
torch.save(model.state_dict(), 'model.pt')