In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import time
import random


In [2]:
# building additional dictionaries

tag_idx = 0
word_idx = 0
char_idx = 0

word2idx = dict()
tag2idx = dict()
idx2tag = dict()
char2idx = dict()


st = time.time()
with open("corpus.train", "r") as file:
    for line in file:
        pairs = line.rstrip("\n").rstrip("\r").split(" ")
    
        for pair in pairs:
            spliting_pair =  pair.split("/")
            word = "/".join(spliting_pair[:len(spliting_pair)-1]).lower()
            tag = spliting_pair[len(spliting_pair)-1]
            
            
            if word not in word2idx.keys():
                word2idx[word] = word_idx
                word_idx+=1
                
            if tag not in tag2idx.keys():
                tag2idx[tag] = tag_idx
                tag_idx+=1
            
            for char in word:
                if char not in char2idx.keys():
                    char2idx[char] = char_idx
                    char_idx+=1
                    
for key in tag2idx:
    idx2tag[tag2idx[key]] = key

#Unknown word index
word2idx["<UNKNOWN>"] = word_idx
word_idx+=1

#Unknown char index
char2idx['\r'] = char_idx
char_idx+=1


print(tag2idx)
print(idx2tag)
print(char2idx)

{'IN': 0, 'DT': 1, 'NNP': 2, 'CD': 3, 'NN': 4, '``': 5, "''": 6, 'POS': 7, '-LRB-': 8, 'VBN': 9, 'NNS': 10, 'VBP': 11, ',': 12, 'CC': 13, '-RRB-': 14, 'VBD': 15, 'RB': 16, 'TO': 17, '.': 18, 'VBZ': 19, 'NNPS': 20, 'PRP': 21, 'PRP$': 22, 'VB': 23, 'JJ': 24, 'MD': 25, 'VBG': 26, 'RBR': 27, ':': 28, 'WP': 29, 'WDT': 30, 'JJR': 31, 'PDT': 32, 'RBS': 33, 'WRB': 34, 'JJS': 35, '$': 36, 'RP': 37, 'FW': 38, 'EX': 39, 'SYM': 40, '#': 41, 'LS': 42, 'UH': 43, 'WP$': 44}
{0: 'IN', 1: 'DT', 2: 'NNP', 3: 'CD', 4: 'NN', 5: '``', 6: "''", 7: 'POS', 8: '-LRB-', 9: 'VBN', 10: 'NNS', 11: 'VBP', 12: ',', 13: 'CC', 14: '-RRB-', 15: 'VBD', 16: 'RB', 17: 'TO', 18: '.', 19: 'VBZ', 20: 'NNPS', 21: 'PRP', 22: 'PRP$', 23: 'VB', 24: 'JJ', 25: 'MD', 26: 'VBG', 27: 'RBR', 28: ':', 29: 'WP', 30: 'WDT', 31: 'JJR', 32: 'PDT', 33: 'RBS', 34: 'WRB', 35: 'JJS', 36: '$', 37: 'RP', 38: 'FW', 39: 'EX', 40: 'SYM', 41: '#', 42: 'LS', 43: 'UH', 44: 'WP$'}
{'i': 0, 'n': 1, 'a': 2, 'o': 3, 'c': 4, 't': 5, '.': 6, '1': 7, '9': 8,

In [3]:
class CustomDataset(Dataset):
    def __init__(self, filename):        
        X, y = [], []
        with open(filename, "r") as file:
           
            for line in file:
                pairs = line.rstrip("\n").rstrip("\r").split(" ")
                
                X1, y1 = [], []
                
                for pair in pairs:
                    spliting_pair =  pair.split("/")
                    
                    word = "/".join(spliting_pair[:len(spliting_pair)-1]).lower()
                    tag = spliting_pair[len(spliting_pair)-1]
                    
                    X1.append(word)
                    y1.append(tag)
                
                X.append(" ".join(X1))
                y.append(" ".join(y1))
                            
        self.X = X
        self.y = y
        
    def __len__(self):
        return len(self.X)
   
    def __getitem__(self, idx):    
        return self.X[idx], self.y[idx]

In [4]:
# buildig neural network

class POS_tag_net(nn.Module):
    def __init__(self, vocab_size, char_size, target_size, 
                 word_emb_dim = 32, char_emb_dim = 8, 
                 out_channels = 16, kernel_size = 3,
                 lstm_hidden_dim = 64):
        
        super().__init__()
        
        self.target_size = target_size
        self.word_emb_dim = word_emb_dim
        self.char_emb_dim = char_emb_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.word_embedding = nn.Embedding(vocab_size, word_emb_dim)
        self.char_embedding = nn.Embedding(char_size, char_emb_dim)
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        
        self.conv1 = nn.Conv1d(self.char_emb_dim, out_channels=self.out_channels, kernel_size=self.kernel_size, padding=self.kernel_size//2)

        self.lstm = nn.LSTM(self.word_emb_dim+self.out_channels, 
                           self.lstm_hidden_dim,
                           bidirectional = True)
        
        self.fc1 = nn.Linear(self.lstm_hidden_dim*2, self.target_size)
    
    def forward(self, words_batch, chars_batch, max_len_word, max_len_sentence):
        
        right_words_batch = []
        for words in words_batch:
            embedding = self.word_embedding(torch.tensor(words))
            
            if embedding.shape[0] != max_sentence_len:
                embedding = torch.cat((embedding, torch.zeros(max_sentence_len-embedding.shape[0], self.word_emb_dim)))
            
            right_words_batch.append(embedding)
        
        right_words_batch = torch.stack(right_words_batch)
        
#         print(right_words_batch.shape)
        
        
        right_chars_batch = []
        
        for words in chars_batch:
            sentence = []
            for chars in words:
                c_embedding = self.char_embedding(torch.tensor(chars))
            
                if c_embedding.shape[0] != max_word_len:
                    c_embedding = torch.cat((c_embedding, torch.zeros(max_word_len-c_embedding.shape[0], self.char_emb_dim)))
                
                sentence.append(torch.t(c_embedding))
            
            while len(sentence)!=max_len_sentence:
                sentence.append(torch.t(torch.zeros(max_word_len, self.char_emb_dim)))
                
            sentence = torch.stack(sentence)
            right_chars_batch.append(sentence)
        
        right_chars_batch = torch.stack(right_chars_batch)
#         print(right_chars_batch.shape)
        
        right_chars_batch = right_chars_batch.view(-1, self.char_emb_dim, max_word_len)
#         print(right_chars_batch.shape)
        
        right_chars_batch = F.relu(self.conv1(right_chars_batch))
#         print("aaa", right_chars_batch.shape)
        
        right_chars_batch, _ = torch.max(right_chars_batch, dim=-1)
#         print("aaa", right_chars_batch.shape)
        
        right_chars_batch = right_chars_batch.view(-1, max_len_sentence, self.out_channels)
#         print("aaa", right_chars_batch.shape)
        
        lstm_input = torch.cat((right_words_batch, right_chars_batch), dim=-1)
        
        lstm_output, _ = self.lstm(lstm_input)
        
        
        out = F.log_softmax(self.fc1(lstm_output), dim = -1)
#         print(out.shape)
        return out

    def loss(self, Y_hat, Y):
        # TRICK 3 ********************************
        # before we calculate the negative log likelihood, we need to mask out the activations
        # this means we don't want to take into account padded items in the output vector
        # simplest way to think about this is to flatten ALL sequences into a REALLY long sequence
        # and calculate the loss on that.

        # flatten all the labels
        Y = Y.view(-1)

        Y_hat = Y_hat.view(-1, 45)

        mask = (Y > -1).float()

        # count how many tokens we have
        nb_tokens = int(torch.sum(mask).item())

        # pick the values for the label and zero out the rest with the mask
        Y_hat = Y_hat[range(Y_hat.shape[0]), Y] * mask

        # compute cross entropy loss which ignores all <PAD> tokens
        ce_loss = -torch.sum(Y_hat) / nb_tokens

        return ce_loss
        
    
my_net = POS_tag_net(len(word2idx.keys()), len(char2idx.keys()), len(tag2idx.keys()))

optimezer = optim.Adam(my_net.parameters(), lr = 0.001)                    



In [5]:
# Train neural network

EPOCHS = 1
start_time = time.time()

p_word_to_unknown = 0.02
p_char_to_unknown = 0.01



dataset = CustomDataset('corpus.train')
dataloader = DataLoader(dataset, batch_size = 32)


for epoch in range(EPOCHS):
    batch_n = 0
    for X, y in dataloader:
        # X - sentence, y - sentence tags
        
        X = [e.split(" ") for e in X]
        batch_y = [e.split(" ") for e in y]
                
        batch_tokens = X.copy()
        
        max_word_len = -1
        max_sentence_len = -1
        
        
        batch_char_X = []
        batch_words_X = []
        
        for sentence in X:
            if len(sentence) > max_sentence_len: 
                max_sentence_len = len(sentence)
            for word in sentence:
                if len(word) > max_word_len:
                    max_word_len = len(word)
        
        batch_norm_y = []
        
        for y in batch_y:
            y_tags = [tag2idx[e] for e in y]
            
            while len(y_tags) != max_sentence_len:
                y_tags.append(-1)
            
            
            batch_norm_y.append(y_tags)
        
        batch_norm_y = torch.tensor(batch_norm_y)
        
        for sentence in X:
          
            words_ids = []
            words_chars_ids = []
            
            for word in sentence:
                
                if p_word_to_unknown >= random.uniform(0, 1):
                    words_ids.append(word2idx["<UNKNOWN>"])
                else:
                    words_ids.append(word2idx[word])
                char_ids = []
                
                for c in word:
                    if p_char_to_unknown >= random.uniform(0, 1):
                        char_ids.append(char2idx['\r'])
                    else:
                        char_ids.append(char2idx[c])

                words_chars_ids.append(char_ids)
            
            batch_char_X.append(words_chars_ids)
            batch_words_X.append(words_ids)
           
        my_net.zero_grad()
        output = my_net(batch_words_X, batch_char_X, max_word_len, max_sentence_len)
        
        loss = my_net.loss(output, batch_norm_y)
        
      
        
        loss.backward()
        
        
        optimezer.step()
    
    
        batch_n+=1
        print(f"Epoch: {epoch} Batch: {batch_n}/{len(dataloader)} loss: {loss}")          
        
    print(f"======== Epoch_end: {epoch} loss: {loss} time: {time.time()-start_time} ========") 
    
    
    

Epoch: 0 Batch: 1/1186 loss: 3.820404529571533
Epoch: 0 Batch: 2/1186 loss: 3.78729510307312
Epoch: 0 Batch: 3/1186 loss: 3.7655272483825684
Epoch: 0 Batch: 4/1186 loss: 3.7398767471313477
Epoch: 0 Batch: 5/1186 loss: 3.7185237407684326
Epoch: 0 Batch: 6/1186 loss: 3.7019710540771484
Epoch: 0 Batch: 7/1186 loss: 3.6738228797912598
Epoch: 0 Batch: 8/1186 loss: 3.679391622543335
Epoch: 0 Batch: 9/1186 loss: 3.623492956161499
Epoch: 0 Batch: 10/1186 loss: 3.637249708175659
Epoch: 0 Batch: 11/1186 loss: 3.579315423965454
Epoch: 0 Batch: 12/1186 loss: 3.540592908859253
Epoch: 0 Batch: 13/1186 loss: 3.479952573776245
Epoch: 0 Batch: 14/1186 loss: 3.453953266143799
Epoch: 0 Batch: 15/1186 loss: 3.3783583641052246
Epoch: 0 Batch: 16/1186 loss: 3.3557252883911133
Epoch: 0 Batch: 17/1186 loss: 3.4183707237243652
Epoch: 0 Batch: 18/1186 loss: 3.3131446838378906
Epoch: 0 Batch: 19/1186 loss: 3.248081922531128
Epoch: 0 Batch: 20/1186 loss: 3.1276755332946777
Epoch: 0 Batch: 21/1186 loss: 3.06853127

Epoch: 0 Batch: 169/1186 loss: 1.5287188291549683
Epoch: 0 Batch: 170/1186 loss: 1.5277125835418701
Epoch: 0 Batch: 171/1186 loss: 1.5492793321609497
Epoch: 0 Batch: 172/1186 loss: 1.6274393796920776
Epoch: 0 Batch: 173/1186 loss: 1.457971215248108
Epoch: 0 Batch: 174/1186 loss: 1.514896035194397
Epoch: 0 Batch: 175/1186 loss: 1.4544639587402344
Epoch: 0 Batch: 176/1186 loss: 1.5065146684646606
Epoch: 0 Batch: 177/1186 loss: 1.4411065578460693
Epoch: 0 Batch: 178/1186 loss: 1.4912947416305542
Epoch: 0 Batch: 179/1186 loss: 1.3915094137191772
Epoch: 0 Batch: 180/1186 loss: 1.5430957078933716
Epoch: 0 Batch: 181/1186 loss: 1.448546290397644
Epoch: 0 Batch: 182/1186 loss: 1.4029890298843384
Epoch: 0 Batch: 183/1186 loss: 1.5664469003677368
Epoch: 0 Batch: 184/1186 loss: 1.4482345581054688
Epoch: 0 Batch: 185/1186 loss: 1.4787853956222534
Epoch: 0 Batch: 186/1186 loss: 1.4324078559875488
Epoch: 0 Batch: 187/1186 loss: 1.6121609210968018
Epoch: 0 Batch: 188/1186 loss: 1.4083739519119263
Epo

Epoch: 0 Batch: 334/1186 loss: 1.2469837665557861
Epoch: 0 Batch: 335/1186 loss: 1.2812048196792603
Epoch: 0 Batch: 336/1186 loss: 1.0647838115692139
Epoch: 0 Batch: 337/1186 loss: 1.0551412105560303
Epoch: 0 Batch: 338/1186 loss: 1.0435854196548462
Epoch: 0 Batch: 339/1186 loss: 1.1028565168380737
Epoch: 0 Batch: 340/1186 loss: 1.1234705448150635
Epoch: 0 Batch: 341/1186 loss: 1.1224489212036133
Epoch: 0 Batch: 342/1186 loss: 1.1543594598770142
Epoch: 0 Batch: 343/1186 loss: 1.1589972972869873
Epoch: 0 Batch: 344/1186 loss: 1.2261772155761719
Epoch: 0 Batch: 345/1186 loss: 1.2052335739135742
Epoch: 0 Batch: 346/1186 loss: 1.217176079750061
Epoch: 0 Batch: 347/1186 loss: 1.0590736865997314
Epoch: 0 Batch: 348/1186 loss: 1.1133091449737549
Epoch: 0 Batch: 349/1186 loss: 1.1278855800628662
Epoch: 0 Batch: 350/1186 loss: 1.1397182941436768
Epoch: 0 Batch: 351/1186 loss: 1.1889607906341553
Epoch: 0 Batch: 352/1186 loss: 1.1261130571365356
Epoch: 0 Batch: 353/1186 loss: 1.1087652444839478
E

Epoch: 0 Batch: 500/1186 loss: 0.7488059997558594
Epoch: 0 Batch: 501/1186 loss: 0.7359766364097595
Epoch: 0 Batch: 502/1186 loss: 0.786503791809082
Epoch: 0 Batch: 503/1186 loss: 0.7827017307281494
Epoch: 0 Batch: 504/1186 loss: 0.9035792946815491
Epoch: 0 Batch: 505/1186 loss: 1.0364521741867065
Epoch: 0 Batch: 506/1186 loss: 1.0465118885040283
Epoch: 0 Batch: 507/1186 loss: 0.9953327178955078
Epoch: 0 Batch: 508/1186 loss: 0.7948316335678101
Epoch: 0 Batch: 509/1186 loss: 0.9006295204162598
Epoch: 0 Batch: 510/1186 loss: 0.9541582465171814
Epoch: 0 Batch: 511/1186 loss: 0.9283533096313477
Epoch: 0 Batch: 512/1186 loss: 0.9761582612991333
Epoch: 0 Batch: 513/1186 loss: 0.8816893696784973
Epoch: 0 Batch: 514/1186 loss: 0.8196895718574524
Epoch: 0 Batch: 515/1186 loss: 0.8546434044837952
Epoch: 0 Batch: 516/1186 loss: 0.8404038548469543
Epoch: 0 Batch: 517/1186 loss: 0.8452087044715881
Epoch: 0 Batch: 518/1186 loss: 0.9295656085014343
Epoch: 0 Batch: 519/1186 loss: 0.6544380784034729
E

Epoch: 0 Batch: 665/1186 loss: 0.8294754028320312
Epoch: 0 Batch: 666/1186 loss: 0.8171490430831909
Epoch: 0 Batch: 667/1186 loss: 0.8295988440513611
Epoch: 0 Batch: 668/1186 loss: 0.7374513149261475
Epoch: 0 Batch: 669/1186 loss: 0.7949957251548767
Epoch: 0 Batch: 670/1186 loss: 0.7766125202178955
Epoch: 0 Batch: 671/1186 loss: 0.7071159482002258
Epoch: 0 Batch: 672/1186 loss: 0.8437638282775879
Epoch: 0 Batch: 673/1186 loss: 0.9995682835578918
Epoch: 0 Batch: 674/1186 loss: 0.7566387057304382
Epoch: 0 Batch: 675/1186 loss: 0.92465740442276
Epoch: 0 Batch: 676/1186 loss: 0.676298201084137
Epoch: 0 Batch: 677/1186 loss: 0.7193630337715149
Epoch: 0 Batch: 678/1186 loss: 0.6787469983100891
Epoch: 0 Batch: 679/1186 loss: 0.8419257998466492
Epoch: 0 Batch: 680/1186 loss: 0.8735013604164124
Epoch: 0 Batch: 681/1186 loss: 0.7213037610054016
Epoch: 0 Batch: 682/1186 loss: 0.768744945526123
Epoch: 0 Batch: 683/1186 loss: 0.6649228930473328
Epoch: 0 Batch: 684/1186 loss: 0.7435761094093323
Epoc

Epoch: 0 Batch: 831/1186 loss: 0.6773982644081116
Epoch: 0 Batch: 832/1186 loss: 0.6651740074157715
Epoch: 0 Batch: 833/1186 loss: 0.8221526741981506
Epoch: 0 Batch: 834/1186 loss: 0.8634192943572998
Epoch: 0 Batch: 835/1186 loss: 0.8047369122505188
Epoch: 0 Batch: 836/1186 loss: 0.7104342579841614
Epoch: 0 Batch: 837/1186 loss: 0.6944653987884521
Epoch: 0 Batch: 838/1186 loss: 0.7744914889335632
Epoch: 0 Batch: 839/1186 loss: 0.7004911303520203
Epoch: 0 Batch: 840/1186 loss: 0.7314794063568115
Epoch: 0 Batch: 841/1186 loss: 0.660953164100647
Epoch: 0 Batch: 842/1186 loss: 0.7046672701835632
Epoch: 0 Batch: 843/1186 loss: 0.7220805287361145
Epoch: 0 Batch: 844/1186 loss: 0.6429375410079956
Epoch: 0 Batch: 845/1186 loss: 0.6914911270141602
Epoch: 0 Batch: 846/1186 loss: 0.6488958597183228
Epoch: 0 Batch: 847/1186 loss: 0.6733807325363159
Epoch: 0 Batch: 848/1186 loss: 0.6356037259101868
Epoch: 0 Batch: 849/1186 loss: 0.6716345548629761
Epoch: 0 Batch: 850/1186 loss: 0.7138712406158447
E

Epoch: 0 Batch: 996/1186 loss: 0.5158782005310059
Epoch: 0 Batch: 997/1186 loss: 0.5501611828804016
Epoch: 0 Batch: 998/1186 loss: 0.6439496278762817
Epoch: 0 Batch: 999/1186 loss: 0.5919916033744812
Epoch: 0 Batch: 1000/1186 loss: 0.5637508630752563
Epoch: 0 Batch: 1001/1186 loss: 0.6497008800506592
Epoch: 0 Batch: 1002/1186 loss: 0.5812126994132996
Epoch: 0 Batch: 1003/1186 loss: 0.6350045204162598
Epoch: 0 Batch: 1004/1186 loss: 0.5963842868804932
Epoch: 0 Batch: 1005/1186 loss: 0.6273561716079712
Epoch: 0 Batch: 1006/1186 loss: 0.6541322469711304
Epoch: 0 Batch: 1007/1186 loss: 0.532993495464325
Epoch: 0 Batch: 1008/1186 loss: 0.5734787583351135
Epoch: 0 Batch: 1009/1186 loss: 0.7045409083366394
Epoch: 0 Batch: 1010/1186 loss: 0.5256373286247253
Epoch: 0 Batch: 1011/1186 loss: 0.5662253499031067
Epoch: 0 Batch: 1012/1186 loss: 0.6619151830673218
Epoch: 0 Batch: 1013/1186 loss: 0.7213365435600281
Epoch: 0 Batch: 1014/1186 loss: 0.5940865278244019
Epoch: 0 Batch: 1015/1186 loss: 0.61

Epoch: 0 Batch: 1158/1186 loss: 0.6512666940689087
Epoch: 0 Batch: 1159/1186 loss: 0.6948625445365906
Epoch: 0 Batch: 1160/1186 loss: 0.4584685266017914
Epoch: 0 Batch: 1161/1186 loss: 0.5490326285362244
Epoch: 0 Batch: 1162/1186 loss: 0.530508279800415
Epoch: 0 Batch: 1163/1186 loss: 0.6007548570632935
Epoch: 0 Batch: 1164/1186 loss: 0.5453662872314453
Epoch: 0 Batch: 1165/1186 loss: 0.5918939709663391
Epoch: 0 Batch: 1166/1186 loss: 0.7155564427375793
Epoch: 0 Batch: 1167/1186 loss: 0.5886008143424988
Epoch: 0 Batch: 1168/1186 loss: 0.6988006234169006
Epoch: 0 Batch: 1169/1186 loss: 0.7332509160041809
Epoch: 0 Batch: 1170/1186 loss: 0.6783777475357056
Epoch: 0 Batch: 1171/1186 loss: 0.5535977482795715
Epoch: 0 Batch: 1172/1186 loss: 0.5955678224563599
Epoch: 0 Batch: 1173/1186 loss: 0.47223493456840515
Epoch: 0 Batch: 1174/1186 loss: 0.48307639360427856
Epoch: 0 Batch: 1175/1186 loss: 0.6254425644874573
Epoch: 0 Batch: 1176/1186 loss: 0.6041740775108337
Epoch: 0 Batch: 1177/1186 loss

In [9]:
torch.save((word2idx, tag2idx, idx2tag, char2idx, my_net.state_dict()), 'my_model.pth')

In [6]:
out_file = open("corpus.out",'w')
with open("corpus.test", 'r') as f:
    for line in f:
        
        X = [line.rstrip("\n").rstrip("\r").lower().split(" ")]
        tokens = line.rstrip("\n").rstrip("\r").split(" ")
        
        max_word_len = -1
        max_sentence_len = -1
        
        
        batch_char_X = []
        batch_words_X = []
        
        for sentence in X:
            if len(sentence) > max_sentence_len: 
                max_sentence_len = len(sentence)
            for word in sentence:
                if len(word) > max_word_len:
                    max_word_len = len(word)
        
        
        for sentence in X:
          
            words_ids = []
            words_chars_ids = []
            
            for word in sentence:
                
                if word not in word2idx:
                    words_ids.append(word2idx["<UNKNOWN>"])
                else:
                    words_ids.append(word2idx[word])
                char_ids = []
                
                for c in word:
                    if c not in char2idx:
                        char_ids.append(char2idx['\r'])
                    else:
                        char_ids.append(char2idx[c])

                words_chars_ids.append(char_ids)
            
            batch_char_X.append(words_chars_ids)
            batch_words_X.append(words_ids)
        
        with torch.no_grad():
            tags = []
            output = my_net(batch_words_X, batch_char_X, max_word_len, max_sentence_len)
            output = output.view(max_sentence_len,-1)
            for e in output:
                tags.append(idx2tag[torch.argmax(e).item()])
                
            
        for token, tag in zip(tokens, tags):
            out_file.write(token+"/"+tag+" ")
        out_file.write('\n')
        
out_file.close()

In [7]:
!python tagger_eval.py corpus.out corpus.answer

Accuracy= 0.8069869573852978
