In [1]:
import torch
import torch.nn as nn
from torch.optim import SGD 
import numpy as np

# Упражнение: реализация «ванильной» RNN
* Попробуем обучить сеть восстанавливать слово hello по первой букве, т.е. построим character-level модель.

In [2]:
# Two constant 3x3 float matrices for contrasting the matrix product with the
# element-wise product.
a = torch.full((3, 3), 3.0)
b = torch.full((3, 3), 5.0)

In [3]:
# Matrix product: every entry is a row-by-column dot product.
torch.matmul(a, b)

tensor([[45., 45., 45.],
        [45., 45., 45.],
        [45., 45., 45.]])

In [4]:
# Element-wise (Hadamard) product: entries are multiplied position by position.
torch.mul(a, b)

tensor([[15., 15., 15.],
        [15., 15., 15.],
        [15., 15., 15.]])

In [5]:
# Word the networks are trained to reproduce. The same character is followed by
# different characters at different positions (e.g. 'o' -> 'l' and 'o' -> 'a'),
# so the current character alone does not determine the next one.
word = 'ololoasdasddqweqw123456789'
# word = 'hello'

## Датасет. 
Позволяет:
* Закодировать символ при помощи one-hot
* Делать итератор по слову, который возвращает текущий символ и следующий как таргет

In [6]:
class WordDataSet:
    """Character-level view of a single word.

    Each distinct character gets an integer id in order of first appearance,
    and the word is stored as the resulting sequence of ids.

    Attributes:
        chars2idx: mapping character -> id.
        indexs: the word encoded as a list of ids.
        vec_size: number of distinct characters (length of a one-hot vector).
        seq_len: number of characters in the word.
    """

    def __init__(self, word):
        self.chars2idx = {}
        self.indexs = []
        for ch in word:
            # setdefault registers an unseen character under the next free id
            # and returns the id of the character either way.
            char_id = self.chars2idx.setdefault(ch, len(self.chars2idx))
            self.indexs.append(char_id)

        self.vec_size = len(self.chars2idx)
        self.seq_len = len(word)

    def get_one_hot(self, idx):
        """Return a float vector of length ``vec_size`` with a 1 at ``idx``."""
        one_hot = torch.zeros(self.vec_size)
        one_hot[idx] = 1
        return one_hot

    def __iter__(self):
        """Iterate over ``(current_id, next_id)`` pairs of adjacent characters."""
        current_ids = self.indexs[:-1]
        next_ids = self.indexs[1:]
        return zip(current_ids, next_ids)

    def __len__(self):
        """Number of characters in the word (one more than the number of pairs)."""
        return self.seq_len

    def get_char_by_id(self, id):
        """Return the character registered under ``id``, or ``None`` if unknown.

        The lookup compares with ``==`` instead of hashing, so ``id`` may be
        anything that compares equal to the stored integer id.
        """
        matches = (char for char, char_id in self.chars2idx.items() if id == char_id)
        return next(matches, None)

## Реализация базовой RNN
<br/>
Скрытый элемент
$$ h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t) $$
Выход сети

$$ y_t = W_{hy} h_t $$

In [7]:
class VanillaRNN(nn.Module):
    """Single-step Elman RNN cell with a linear read-out.

    h_t = tanh(W_xh x_t + W_hh h_{t-1}),  y_t = W_hy h_t
    (each projection is an ``nn.Linear`` and therefore also carries a bias).

    Args:
        in_size: width of the input vector.
        hidden_size: width of the hidden state.
        out_size: width of the output vector.
    """

    def __init__(self, in_size=5, hidden_size=3, out_size=5):
        super().__init__()

        self.n_a, self.n_x, self.n_y = hidden_size, in_size, out_size

        self.x2hidden = nn.Linear(in_size, hidden_size)
        self.hidden = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()
        self.outweight = nn.Linear(hidden_size, out_size)

    def forward(self, x, prev_hidden):
        """Advance the cell by one step.

        Args:
            x: input for the current step.
            prev_hidden: hidden state from the previous step.

        Returns:
            ``(output, hidden)``: the read-out of the new hidden state and the
            new hidden state itself.
        """
        pre_activation = self.x2hidden(x) + self.hidden(prev_hidden)
        # The tanh keeps the state bounded; feeding `pre_activation` forward
        # directly (no activation) can lead to exploding gradients.
        hidden = self.activation(pre_activation)
        return self.outweight(hidden), hidden

## Инициализация переменных 

In [8]:
# Build the character dataset for `word` and a vanilla RNN whose input and
# output widths both equal the vocabulary size (one-hot vector in, one score
# per character out); the hidden state has 3 units.
ds = WordDataSet(word=word)
rnn = VanillaRNN(in_size=ds.vec_size, hidden_size=3, out_size=ds.vec_size)

# Обучение

In [9]:
def train(net, lr=0.1, n_epochs=100, CLIP_GRAD=True, max_norm=5, dataset=None):
    """Train a recurrent cell to predict the next character of a word.

    One epoch is one full pass over the dataset's ``(current, next)`` character
    pairs: the per-step cross-entropy losses are summed, back-propagated
    through the whole sequence, and a single SGD step (momentum 0.9) is taken.
    Every 10th epoch the summed loss is printed, followed (when clipping is
    enabled) by the value returned by ``clip_grad_norm_``.

    Args:
        net: module called as ``net(x, hidden) -> (scores, hidden)`` that
            exposes its hidden size as ``net.n_a``.
        lr: SGD learning rate.
        n_epochs: number of passes over the sequence.
        CLIP_GRAD: if true, clip the gradient norm to ``max_norm`` before
            every optimizer step.
        max_norm: clipping threshold passed to ``clip_grad_norm_``.
        dataset: iterable of ``(index, next_index)`` pairs that also provides
            ``get_one_hot(index)``. Defaults to the notebook-level ``ds``.
    """
    if dataset is None:
        dataset = ds

    # The evaluation code switches the net to eval mode; switch back here.
    net.train()
    optim = SGD(net.parameters(), lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(n_epochs):
        hh = torch.zeros(net.n_a)  # the hidden state is reset for every pass
        loss = 0
        optim.zero_grad()
        for sample, next_sample in dataset:
            x = dataset.get_one_hot(sample).unsqueeze(0)
            target = torch.LongTensor([next_sample])
            y, hh = net(x, hh)
            loss += criterion(y, target)

        loss.backward()

        # Clip exactly once per epoch; the returned norm is reused for logging.
        grad_norm = None
        if CLIP_GRAD:
            grad_norm = torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm)

        if epoch % 10 == 0:
            print (loss.item())
            if CLIP_GRAD:
                print("Clip gradient : ", grad_norm)

        optim.step()

In [10]:
# Train the vanilla RNN for 150 epochs at lr=0.01; gradient clipping stays on
# (the default), and the loss is printed every 10th epoch.
train(rnn, lr=0.01, n_epochs=150)

72.42355346679688
Clip gradient :  6.600681159038764
65.40913391113281
Clip gradient :  4.805861910842513
54.41361618041992
Clip gradient :  5.328119321187901
41.039649963378906
Clip gradient :  4.976227012313821
30.203020095825195
Clip gradient :  2.8555914418112227
24.489974975585938
Clip gradient :  2.448440428206962
21.300498962402344
Clip gradient :  1.568093778378552
19.28201675415039
Clip gradient :  1.7587875163415618
17.760635375976562
Clip gradient :  1.289146025764035
17.5850772857666
Clip gradient :  38.295315676671294
15.97107219696045
Clip gradient :  12.313857602049628
16.357086181640625
Clip gradient :  45.789410182594644
15.648397445678711
Clip gradient :  28.645256490782295
14.555351257324219
Clip gradient :  9.488562717721312
13.638294219970703
Clip gradient :  7.97922947124924


# Тестирование

In [11]:
# Greedy decoding with the vanilla RNN: start from the first character of the
# word, repeatedly feed the most likely next character back in, and compare the
# generated string with the original word.
rnn.eval()
hh = torch.zeros(rnn.hidden.in_features)
char_idx = 0  # id 0 is the first character of the word (ids follow first appearance)
predword = ds.get_char_by_id(char_idx)
with torch.no_grad():  # inference only, no autograd graph needed
    for _ in range(len(word) - 1):
        x = ds.get_one_hot(char_idx).unsqueeze(0)
        y, hh = rnn(x, hh)
        # Softmax is monotonic, so the arg-max of the raw scores is the same.
        char_idx = torch.argmax(y, dim=1).item()
        predword += ds.get_char_by_id(char_idx)
print ('Prediction:\t' , predword)
print("Original:\t", word)
# In the recorded run this assertion fails: the prediction collapses to a
# repeating 'ol' pattern.
assert predword == word

Prediction:	 ololololololololololololol
Original:	 ololoasdasddqweqw123456789


AssertionError: 

# ДЗ
Реализовать LSTM и GRU модули, обучить их предсказывать тестовое слово

In [12]:
# Test word
word = 'ololoasdasddqweqw123456789'

## Реализовать LSTM

In [13]:
class LSTM_cell(nn.Module):
    """Single-step LSTM cell with a linear read-out.

    With ``z = [h_{t-1}, x_t]``::

        f_t = sigmoid(W_f z)        forget gate
        i_t = sigmoid(W_i z)        update (input) gate
        o_t = sigmoid(W_o z)        output gate
        c~_t = tanh(W_c z)          candidate cell state
        c_t = i_t * c~_t + f_t * c_{t-1}
        h_t = o_t * tanh(c_t)
        y_t = W_y h_t

    Args:
        in_size: width of the input vector.
        hidden_size: width of the hidden and cell states.
        out_size: width of the output vector.
    """

    def __init__(self, in_size, hidden_size, out_size):
        super(LSTM_cell, self).__init__()

        self.n_a = hidden_size
        self.n_x = in_size
        self.n_y = out_size

        self.forget_gate = nn.Sequential(nn.Linear(in_size + hidden_size, hidden_size), nn.Sigmoid())
        self.update_gate = nn.Sequential(nn.Linear(in_size + hidden_size, hidden_size), nn.Sigmoid())
        self.output_gate = nn.Sequential(nn.Linear(in_size + hidden_size, hidden_size), nn.Sigmoid())
        self.cand_cell  = nn.Sequential(nn.Linear(in_size + hidden_size, hidden_size), nn.Tanh())
        self.out_weight = nn.Linear(hidden_size, out_size)

        self.hidden_activation = nn.Tanh()

    def forward(self, x, prev_hidden):
        """Advance the cell by one step.

        Args:
            x: input for the current step; after ``squeeze(0)`` it must be a
                1-D vector of length ``in_size``.
            prev_hidden: either
                * a ``(h, c)`` tuple/list holding the previous hidden and cell
                  states, which gives the full LSTM recurrence, or
                * a single tensor (legacy form). No separate cell state is
                  carried in this case: the tensor is used both as ``h_{t-1}``
                  and as ``c_{t-1}``.

        Returns:
            ``(output, state)`` where ``output`` is the read-out with a leading
            dimension of 1 and ``state`` has the same form as ``prev_hidden``:
            ``(h_t, c_t)`` for a tuple/list input, ``h_t`` for a tensor input.
        """
        carries_cell_state = isinstance(prev_hidden, (tuple, list))
        if carries_cell_state:
            h_prev, c_prev = prev_hidden
        else:
            # Legacy single-tensor state: the hidden state doubles as the cell state.
            h_prev = c_prev = prev_hidden

        concat = torch.cat((h_prev.squeeze(0), x.squeeze(0)))

        ft = self.forget_gate(concat)
        it = self.update_gate(concat)
        ot = self.output_gate(concat)

        cct = self.cand_cell(concat)
        c_next = cct * it + c_prev * ft
        a_next = ot * self.hidden_activation(c_next)

        output = self.out_weight(a_next).unsqueeze(0)
        if carries_cell_state:
            return output, (a_next, c_next)
        return output, a_next
    

In [14]:
# Rebuild the dataset for the test word and create an LSTM cell with 3 hidden
# units whose input and output widths equal the vocabulary size.
ds = WordDataSet(word=word)
lstm_rnn = LSTM_cell(in_size=ds.vec_size, hidden_size=3, out_size=ds.vec_size)

In [18]:
# Train the LSTM cell for 300 epochs at lr=0.1 (gradient clipping on by default).
# NOTE(review): the recorded log starts at a loss of about 3.47 and the
# execution counter jumps from 14 to 18, so this cell was presumably run more
# than once on the same model; a fresh Restart & Run All may not reproduce the
# recorded output - confirm.
train(lstm_rnn, lr=0.1, n_epochs=300)

3.4686832427978516
Clip gradient :  16.469504612403775
7.138798713684082
Clip gradient :  15.02678278323623
4.251806735992432
Clip gradient :  2.7824029009647813
3.4166812896728516
Clip gradient :  2.014304863799851
3.014051914215088
Clip gradient :  1.137543142335301
2.7166848182678223
Clip gradient :  0.5987286151986356
2.5447959899902344
Clip gradient :  0.5014941530755052
2.4247255325317383
Clip gradient :  0.20293119753805675
2.335513114929199
Clip gradient :  0.20171146491761935
2.2603588104248047
Clip gradient :  0.12482335612137305
2.2075977325439453
Clip gradient :  0.9097749663593979
2.2112159729003906
Clip gradient :  1.1356064391567149
2.1422786712646484
Clip gradient :  0.5871379969164257
2.079409599304199
Clip gradient :  0.35148417782362
2.032583236694336
Clip gradient :  0.2544282075593267
1.9950084686279297
Clip gradient :  0.1392638815119641
1.962590217590332
Clip gradient :  0.08937392892108649
1.9338808059692383
Clip gradient :  0.07048798747028015
1.907950401306152

In [19]:
def evaluate(net, dataset=None, target_word=None):
    """Greedily decode a word with ``net`` and assert it matches the target.

    Decoding starts from the character with id 0 and repeatedly feeds the most
    likely next character back into the network until the generated string is
    as long as the target word. Both strings are printed.

    Args:
        net: module called as ``net(x, hidden) -> (scores, hidden)`` that
            exposes its hidden size as ``net.n_a``. It is switched to eval mode.
        dataset: object providing ``get_one_hot(index)`` and
            ``get_char_by_id(index)``. Defaults to the notebook-level ``ds``.
        target_word: word the prediction is compared with. Defaults to the
            notebook-level ``word``.

    Raises:
        AssertionError: if the decoded string differs from the target word.
    """
    if dataset is None:
        dataset = ds
    if target_word is None:
        target_word = word

    net = net.eval()
    hh = torch.zeros(net.n_a)
    char_idx = 0

    predword = dataset.get_char_by_id(char_idx)
    with torch.no_grad():  # inference only, no autograd graph needed
        for _ in range(len(target_word) - 1):
            x = dataset.get_one_hot(char_idx).unsqueeze(0)
            y, hh = net(x, hh)
            # Softmax is monotonic, so the arg-max of the raw scores is the same.
            char_idx = torch.argmax(y, dim=1).item()
            predword += dataset.get_char_by_id(char_idx)
    print ('Prediction:\t' , predword)
    print("Original:\t", target_word)
    assert predword == target_word

In [20]:
# Greedy-decode the word with the trained LSTM cell; raises AssertionError if
# the prediction differs from `word`.
evaluate(lstm_rnn)

Prediction:	 ololoasdasddqweqw123456789
Original:	 ololoasdasddqweqw123456789


## Реализовать GRU

In [21]:
class GRU_cell(nn.Module):
    """Single-step GRU cell with a linear read-out.

    With ``z = [h_{t-1}, x_t]``::

        u_t = sigmoid(W_u z)                      update gate
        r_t = sigmoid(W_r z)                      relevance (reset) gate
        h~_t = tanh(W_c [r_t * h_{t-1}, x_t])     candidate state
        h_t = (1 - u_t) * h~_t + u_t * h_{t-1}
        y_t = W_y h_t

    Args:
        in_size: width of the input vector.
        hidden_size: width of the hidden state.
        out_size: width of the output vector.
    """

    def __init__(self, in_size, hidden_size, out_size):
        super().__init__()

        self.n_a, self.n_x, self.n_y = hidden_size, in_size, out_size

        joint_size = in_size + hidden_size  # width of [h_{t-1}, x_t]
        self.update_gate = nn.Sequential(nn.Linear(joint_size, hidden_size), nn.Sigmoid())
        self.relevance_gate = nn.Sequential(nn.Linear(joint_size, hidden_size), nn.Sigmoid())
        self.candidate_cell = nn.Sequential(nn.Linear(joint_size, hidden_size), nn.Tanh())
        self.out_weight = nn.Linear(hidden_size, out_size)

        # Not referenced by forward(); kept so the module exposes the same attributes.
        self.hidden_activation = nn.Tanh()
        self.candiate_activation = nn.Tanh()

    def forward(self, x, prev_hidden):
        """Advance the cell by one step.

        Args:
            x: input for the current step; after ``squeeze(0)`` it must be a
                1-D vector of length ``in_size``.
            prev_hidden: previous hidden state; after ``squeeze(0)`` it must be
                a 1-D vector of length ``hidden_size``.

        Returns:
            ``(output, hidden)``: the read-out of the new hidden state with a
            leading dimension of 1, and the new hidden state.
        """
        h_prev = prev_hidden.squeeze(0)
        x_vec = x.squeeze(0)

        gate_input = torch.cat((h_prev, x_vec))
        update = self.update_gate(gate_input)
        relevance = self.relevance_gate(gate_input)

        # The candidate sees the previous state only through the relevance gate.
        candidate = self.candidate_cell(torch.cat((relevance * h_prev, x_vec)))

        h_next = candidate * (1 - update) + update * h_prev
        return self.out_weight(h_next).unsqueeze(0), h_next

In [22]:
# Rebuild the dataset for the test word and create a GRU cell with 5 hidden
# units whose input and output widths equal the vocabulary size.
ds = WordDataSet(word=word)
gru_rnn = GRU_cell(in_size=ds.vec_size, hidden_size=5, out_size=ds.vec_size)

In [23]:
# Train the GRU cell for 300 epochs at lr=0.1 (gradient clipping on by default).
train(gru_rnn, lr=0.1, n_epochs=300)

71.56558990478516
Clip gradient :  3.831838887589271
33.33504104614258
Clip gradient :  5.853957675093574
10.06079387664795
Clip gradient :  10.716120767357921
5.819554328918457
Clip gradient :  5.733841121431512
5.50185489654541
Clip gradient :  9.764282672122878
3.0489377975463867
Clip gradient :  2.280550932119522
1.0880203247070312
Clip gradient :  3.3801880246201286
0.3574981689453125
Clip gradient :  0.38930667379147355
0.44561290740966797
Clip gradient :  4.597023382639376
0.19213485717773438
Clip gradient :  0.5352082028258958
0.1529092788696289
Clip gradient :  0.5568063524997623
0.11562252044677734
Clip gradient :  0.13055075322003332
0.10169601440429688
Clip gradient :  0.05041791309738597
0.092987060546875
Clip gradient :  0.04977069143156657
0.08562374114990234
Clip gradient :  0.029451627886872573
0.07989692687988281
Clip gradient :  0.026834745507326313
0.07496356964111328
Clip gradient :  0.022148995739883572
0.0707244873046875
Clip gradient :  0.01960121608615734
0.067

In [24]:
# Greedy-decode the word with the trained GRU cell; raises AssertionError if
# the prediction differs from `word`.
evaluate(gru_rnn)

Prediction:	 ololoasdasddqweqw123456789
Original:	 ololoasdasddqweqw123456789
