In [1]:
import datetime

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

In [2]:
data_dir = './data/'
train_lang = 'en'

In [3]:
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

class DatasetSeq(Dataset):
    def __init__(self, data_dir, train_lang='en'):
        #open file
        with open(data_dir + train_lang + '.train', 'r') as f:
            train = f.read().split('\n\n')

        # delete extra tag markup
        train = [x for x in train if not '_ ' in x]
	    #init vocabs of tokens for encoding {<str> token: <int> id}
        self.target_vocab = {} # {p: 1, a: 2, r: 3, pu: 4}
        self.word_vocab = {} # {cat: 1, sat: 2, on: 3, mat: 4, '.': 5}
        self.char_vocab = {} # {c: 1, a: 2, t: 3, ' ': 4, s: 5}
	    
        # Cat sat on mat. -> [1, 2, 3, 4, 5]
        # p    a  r  p pu -> [1, 2, 3, 1, 4]
        # chars  -> [1, 2, 3, 4, 5, 2, 3, 4]

	    #init encoded sequences lists (processed data)
        self.encoded_sequences = []
        self.encoded_targets = []
        self.encoded_char_sequences = []
        # n=1 because first value is padding
        n_word = 1
        n_target = 1
        n_char = 1
        for line in train:
            sequence = []
            target = []
            chars = []
            for item in line.split('\n'):
                if item != '':
                    word, label = item.split(' ')

                    if self.word_vocab.get(word) is None:
                        self.word_vocab[word] = n_word
                        n_word += 1
                    if self.target_vocab.get(label) is None:
                        self.target_vocab[label] = n_target
                        n_target += 1
                    for char in word:
                        if self.char_vocab.get(char) is None:
                            self.char_vocab[char] = n_char
                            n_char += 1
                    sequence.append(self.word_vocab[word])
                    target.append(self.target_vocab[label])
                    chars.append([self.char_vocab[char] for char in word])
            self.encoded_sequences.append(sequence)
            self.encoded_targets.append(target)
            self.encoded_char_sequences.append(chars)

    def __len__(self):
        return len(self.encoded_sequences)

    def __getitem__(self, index):
        return {
            'data': self.encoded_sequences[index], # [1, 2, 3, 4, 6] len=5
            'char': self.encoded_char_sequences[index],# [[1,2,3], [4,5], [1,2], [2,6,5,4], []] len=5
            'target': self.encoded_targets[index], #  (1)
        }

In [4]:
dataset = DatasetSeq(data_dir)

In [5]:
#padding
# seq1 = [1, 2, 3, 4]
# seq2 = [9, 7, 6, 4, 3, 7, 5]
# pad seq1 equal seq2
# seq1 = [1, 2, 3, 4, 0, 0, 0]
# concat(seq1, seq2) [[1, 2, 3, 4, 0, 0, 0],
#                     [9, 7, 6, 4, 3, 7, 5]]

In [6]:
def collate_fn(batch):
    data = []
    target = []
    for item in batch:
        data.append(torch.as_tensor(item['data']))
        target.append(torch.as_tensor(item['target']))
    # pad different length sequences
    data = pad_sequence(data, batch_first=True, padding_value=0)
    target = pad_sequence(target, batch_first=True, padding_value=0)

    return {'data': data, 'target': target}

In [7]:
#hyper params
vocab_size = len(dataset.word_vocab) + 1
n_classes = len(dataset.target_vocab) + 1
n_chars = len(dataset.char_vocab) + 1
#TODO try to use other model parameters
emb_dim = 256
hidden = 256
n_epochs = 5 #10
cuda_device = -1
batch_size = 64
device = f'cuda:{cuda_device}' if cuda_device != -1 else 'cpu'

### RNN

In [8]:
class RNNPredictor(nn.Module):
    def __init__(self, vocab_size, emb_dim, hidden_dim, n_classes):
        super().__init__()
        #TODO try to use other RNN archicetures, f.e. RNN and LSTM
        self.word_emb = nn.Embedding(vocab_size, emb_dim)
        # batch_first = False: T x B x Vec
        # batch_first = True: B x T x Vec
        self.rnn = nn.RNN(emb_dim, hidden_dim, batch_first=True) 
        self.clf = nn.Linear(hidden_dim, n_classes)
        self.do = nn.Dropout(0.1)
        self.hidden_dim = hidden_dim

    def forward(self, x):
        emb = self.word_emb(x)
        hidden, _ = self.rnn(emb)

        return self.clf(self.do(hidden))

In [9]:
model = RNNPredictor(vocab_size, emb_dim, hidden, n_classes).to(device)
model.train()
optim = torch.optim.Adam(model.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss()

In [10]:
start = datetime.datetime.now()

for epoch in range(n_epochs):
    dataloader = DataLoader(dataset, 
                            batch_size, 
                            shuffle=True, 
                            collate_fn=collate_fn,
                            drop_last = True,
                            )
    for i, batch in enumerate(dataloader):
        optim.zero_grad()

        predict = model(batch['data'].to(device))
        loss = loss_func(predict.view(-1, n_classes), batch['target'].to(device).view(-1))
        loss.backward()
        optim.step()
        
        if i % 100 == 0:
            print(f'epoch: {epoch:02d}, step: {i:03d}, loss: {loss.item():.16f}')
   
    torch.save(model.state_dict(), f'./data/rnn_chkpt_{epoch}.pth')
    
    end = datetime.datetime.now() - start

print(f'time: {end}')

epoch: 00, step: 000, loss: 2.8783037662506104
epoch: 00, step: 100, loss: 0.3520188927650452
epoch: 00, step: 200, loss: 0.2416084706783295
epoch: 00, step: 300, loss: 0.2162329405546188
epoch: 01, step: 000, loss: 0.1988760977983475
epoch: 01, step: 100, loss: 0.1368701159954071
epoch: 01, step: 200, loss: 0.1138945966959000
epoch: 01, step: 300, loss: 0.1934471130371094
epoch: 02, step: 000, loss: 0.1269728839397430
epoch: 02, step: 100, loss: 0.0996958985924721
epoch: 02, step: 200, loss: 0.1021485775709152
epoch: 02, step: 300, loss: 0.1236945837736130
epoch: 03, step: 000, loss: 0.1009382158517838
epoch: 03, step: 100, loss: 0.1115887165069580
epoch: 03, step: 200, loss: 0.0929358825087547
epoch: 03, step: 300, loss: 0.0899294540286064
epoch: 04, step: 000, loss: 0.0669453814625740
epoch: 04, step: 100, loss: 0.0569904074072838
epoch: 04, step: 200, loss: 0.0868021994829178
epoch: 04, step: 300, loss: 0.0921812132000923
time: 0:02:34.407387


In [11]:
#example
phrase = 'He ran quickly after the red bus and caught it'
words = phrase.split(' ')
tokens = [dataset.word_vocab[w] for w in words]

start = datetime.datetime.now()
with torch.no_grad():
    model.eval()
    predict = model(torch.tensor(tokens).unsqueeze(0).to(device)) # 1 x T x N_classes
    labels = torch.argmax(predict, dim=-1).squeeze().cpu().detach().tolist()
    end = datetime.datetime.now() - start

target_labels = list(dataset.target_vocab.keys())

print([target_labels[l-1] for l in labels])
print(f'time: {end}')

['PRON', 'VERB', 'ADV', 'ADP', 'DET', 'ADJ', 'NOUN', 'CCONJ', 'VERB', 'PRON']
time: 0:00:00.035584


### GRU

In [12]:
class GRUPredictor(nn.Module):
    def __init__(self, vocab_size, emb_dim, hidden_dim, n_classes):
        super().__init__()
        #TODO try to use other RNN archicetures, f.e. RNN and LSTM
        self.word_emb = nn.Embedding(vocab_size, emb_dim)
        # batch_first = False: T x B x Vec
        # batch_first = True: B x T x Vec
        self.rnn = nn.GRU(emb_dim, hidden_dim, batch_first=True) 
        self.clf = nn.Linear(hidden_dim, n_classes)
        self.do = nn.Dropout(0.1)
        self.hidden_dim = hidden_dim

    def forward(self, x):
        emb = self.word_emb(x)
        hidden, _ = self.rnn(emb)

        return self.clf(self.do(hidden))

In [13]:
model = GRUPredictor(vocab_size, emb_dim, hidden, n_classes).to(device)
model.train()
optim = torch.optim.Adam(model.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss()

In [14]:
start = datetime.datetime.now()

for epoch in range(n_epochs):
    dataloader = DataLoader(dataset, 
                            batch_size, 
                            shuffle=True, 
                            collate_fn=collate_fn,
                            drop_last = True,
                            )
    for i, batch in enumerate(dataloader):
        optim.zero_grad()

        predict = model(batch['data'].to(device))
        loss = loss_func(predict.view(-1, n_classes), batch['target'].to(device).view(-1))
        loss.backward()
        optim.step()
        
        if i % 100 == 0:
            print(f'epoch: {epoch:02d}, step: {i:03d}, loss: {loss.item():.16f}')
   
    torch.save(model.state_dict(), f'./data/gru_chkpt_{epoch}.pth')
    
    end = datetime.datetime.now() - start

print(f'time: {end}')

epoch: 00, step: 000, loss: 2.9530749320983887
epoch: 00, step: 100, loss: 0.3745443522930145
epoch: 00, step: 200, loss: 0.2296777069568634
epoch: 00, step: 300, loss: 0.1896043270826340
epoch: 01, step: 000, loss: 0.1815935224294662
epoch: 01, step: 100, loss: 0.1371265947818756
epoch: 01, step: 200, loss: 0.1430202573537827
epoch: 01, step: 300, loss: 0.1426171660423279
epoch: 02, step: 000, loss: 0.1110012009739876
epoch: 02, step: 100, loss: 0.1385727524757385
epoch: 02, step: 200, loss: 0.1213863268494606
epoch: 02, step: 300, loss: 0.0888922661542892
epoch: 03, step: 000, loss: 0.1266777813434601
epoch: 03, step: 100, loss: 0.0427642539143562
epoch: 03, step: 200, loss: 0.1140308678150177
epoch: 03, step: 300, loss: 0.0906578674912453
epoch: 04, step: 000, loss: 0.0654601976275444
epoch: 04, step: 100, loss: 0.0656056106090546
epoch: 04, step: 200, loss: 0.0940924882888794
epoch: 04, step: 300, loss: 0.0726933404803276
time: 0:04:40.834782


In [15]:
#example
phrase = 'He ran quickly after the red bus and caught it'
words = phrase.split(' ')
tokens = [dataset.word_vocab[w] for w in words]

start = datetime.datetime.now()
with torch.no_grad():
    model.eval()
    predict = model(torch.tensor(tokens).unsqueeze(0).to(device)) # 1 x T x N_classes
    labels = torch.argmax(predict, dim=-1).squeeze().cpu().detach().tolist()
    end = datetime.datetime.now() - start

target_labels = list(dataset.target_vocab.keys())

print([target_labels[l-1] for l in labels])
print(f'time: {end}')

['PRON', 'VERB', 'ADV', 'SCONJ', 'DET', 'ADJ', 'NOUN', 'CCONJ', 'VERB', 'PRON']
time: 0:00:00.021163


### LSTM

In [16]:
class LSTMPredictor(nn.Module):
    def __init__(self, vocab_size, emb_dim, hidden_dim, n_classes):
        super().__init__()
        #TODO try to use other RNN archicetures, f.e. RNN and LSTM
        self.word_emb = nn.Embedding(vocab_size, emb_dim)
        # batch_first = False: T x B x Vec
        # batch_first = True: B x T x Vec
        self.rnn = nn.LSTM(emb_dim, hidden_dim, batch_first=True) 
        self.clf = nn.Linear(hidden_dim, n_classes)
        self.do = nn.Dropout(0.1)
        self.hidden_dim = hidden_dim

    def forward(self, x):
        emb = self.word_emb(x)
        hidden, _ = self.rnn(emb)

        return self.clf(self.do(hidden))    

In [17]:
model = LSTMPredictor(vocab_size, emb_dim, hidden, n_classes).to(device)
model.train()
optim = torch.optim.Adam(model.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss()

In [18]:
start = datetime.datetime.now()

for epoch in range(n_epochs):
    dataloader = DataLoader(dataset, 
                            batch_size, 
                            shuffle=True, 
                            collate_fn=collate_fn,
                            drop_last = True,
                            )
    for i, batch in enumerate(dataloader):
        optim.zero_grad()

        predict = model(batch['data'].to(device))
        loss = loss_func(predict.view(-1, n_classes), batch['target'].to(device).view(-1))
        loss.backward()
        optim.step()
        
        if i % 100 == 0:
            print(f'epoch: {epoch:02d}, step: {i:03d}, loss: {loss.item():.16f}')
   
    torch.save(model.state_dict(), f'./data/lstm_chkpt_{epoch}.pth')
    
    end = datetime.datetime.now() - start
    
print(f'time: {end}')

epoch: 00, step: 000, loss: 2.9133541584014893
epoch: 00, step: 100, loss: 0.4188470840454102
epoch: 00, step: 200, loss: 0.2069870233535767
epoch: 00, step: 300, loss: 0.1259828656911850
epoch: 01, step: 000, loss: 0.2121656835079193
epoch: 01, step: 100, loss: 0.1377781778573990
epoch: 01, step: 200, loss: 0.1373562812805176
epoch: 01, step: 300, loss: 0.1905425339937210
epoch: 02, step: 000, loss: 0.1031825914978981
epoch: 02, step: 100, loss: 0.0835773050785065
epoch: 02, step: 200, loss: 0.1040356904268265
epoch: 02, step: 300, loss: 0.1029168292880058
epoch: 03, step: 000, loss: 0.0933668762445450
epoch: 03, step: 100, loss: 0.0724321752786636
epoch: 03, step: 200, loss: 0.0848787724971771
epoch: 03, step: 300, loss: 0.0636043921113014
epoch: 04, step: 000, loss: 0.0543968081474304
epoch: 04, step: 100, loss: 0.0681346356868744
epoch: 04, step: 200, loss: 0.0870344415307045
epoch: 04, step: 300, loss: 0.0513709969818592
time: 0:05:34.413929


In [19]:
#example
phrase = 'He ran quickly after the red bus and caught it'
words = phrase.split(' ')
tokens = [dataset.word_vocab[w] for w in words]

start = datetime.datetime.now()
with torch.no_grad():
    model.eval()
    predict = model(torch.tensor(tokens).unsqueeze(0).to(device)) # 1 x T x N_classes
    labels = torch.argmax(predict, dim=-1).squeeze().cpu().detach().tolist()
    end = datetime.datetime.now() - start

target_labels = list(dataset.target_vocab.keys())

print([target_labels[l-1] for l in labels])
print(f'time: {end}')

['PRON', 'VERB', 'ADV', 'SCONJ', 'DET', 'ADJ', 'NOUN', 'CCONJ', 'VERB', 'PRON']
time: 0:00:00.133802


**Вывод:**

Модель GRU оказалась самой оптимальной по качеству, скорости работы и скорости обучения.

Количество эпох: 3  
Размер бача: 128  

| Модель | Время обучения | Время работы | Loss |
| - | - | - | - |
| RNN | 0:00:57.377293 | 0:00:00.010970 | 0.1060799211263657 |
| GRU | 0:02:30.678112 | 0:00:00.010934 | 0.0839845910668373 |
| LSTM | 0:02:49.642584 | 0:00:00.013226 | 0.1057852804660797 |

Количество эпох: 5  
Размер бача: 64  

| Модель | Время обучения | Время работы | Loss |
| - | - | - | - |
| RNN | 0:02:34.407387 | 0:00:00.035584 | 0.0921812132000923 |
| GRU | 0:04:40.834782 | 0:00:00.021163 | 0.0726933404803276 |
| LSTM | 0:05:34.413929 | 0:00:00.133802 | 0.0513709969818592 |
