In [1]:
import torch
import torch.nn as nn
from torch.optim import SGD 
import numpy as np

# Упражнение: реализация «ванильной» RNN
* Попробуем обучить сеть восстанавливать слово hello по первой букве, т.е. построим character-level модель.

In [2]:
# Two constant 3x3 matrices used to contrast matrix and element-wise products.
a = torch.full((3, 3), 3.0)
b = torch.full((3, 3), 5.0)

In [3]:
torch.matmul(a, b)  # matrix product: every entry is 3 * 5 summed over 3 terms

tensor([[45., 45., 45.],
        [45., 45., 45.],
        [45., 45., 45.]])

In [4]:
torch.mul(a, b)  # element-wise product: every entry is 3 * 5

tensor([[15., 15., 15.],
        [15., 15., 15.],
        [15., 15., 15.]])

In [5]:
# Target sequence the networks learn to reproduce from its first character.
# A shorter alternative for quick experiments is 'hello'.
word = "ololoasdasddqweqw123456789"

## Датасет. 
Позволяет:
* Закодировать символ при помощи one-hot
* Делать итератор по слову, который возвращает текущий символ и следующий как таргет

In [161]:
class WordDataSet:
    """Character-level view of a single word.

    Each distinct character is assigned an integer id in order of first
    appearance, and the word is stored as the sequence of those ids.
    """

    def __init__(self, word):
        """Build the character-to-id mapping and the encoded id sequence."""
        self.chars2idx = {}
        # setdefault evaluates len() before inserting, so an unseen character
        # gets the next free id while a known one keeps the id it already has.
        self.indexs = [
            self.chars2idx.setdefault(ch, len(self.chars2idx)) for ch in word
        ]
        self.vec_size = len(self.chars2idx)
        self.seq_len = len(word)

    def get_one_hot(self, idx):
        """Return a zero vector of length ``vec_size`` with a 1 at ``idx``."""
        one_hot = torch.zeros(self.vec_size)
        one_hot[idx] = 1
        return one_hot

    def __iter__(self):
        """Iterate over ``(current_id, next_id)`` pairs of adjacent characters."""
        # zip stops at the shorter argument, so the last id is never a "current".
        return zip(self.indexs, self.indexs[1:])

    def __len__(self):
        """Return the number of characters in the word."""
        return self.seq_len

    def get_char_by_id(self, id):
        """Return the character mapped to ``id``, or ``None`` if there is none."""
        matching = (ch for ch, char_id in self.chars2idx.items() if id == char_id)
        return next(matching, None)

## Реализация базовой RNN
<br/>
Скрытый элемент
$$ h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t) $$
Выход сети

$$ y_t = W_{hy} h_t $$

In [162]:
class VanillaRNN(nn.Module):
    """One step of a vanilla RNN followed by a linear read-out.

    hidden_t = tanh(x2hidden(x_t) + hidden(hidden_{t-1}))
    output_t = outweight(hidden_t)
    """

    def __init__(self, in_size=5, hidden_size=3, out_size=5):
        """Create the input, recurrent and output linear maps."""
        super().__init__()
        self.x2hidden = nn.Linear(in_size, hidden_size)
        self.hidden = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()
        self.outweight = nn.Linear(hidden_size, out_size)

    def forward(self, x, prev_hidden):
        """Advance one time step; return ``(output, new_hidden)``."""
        pre_activation = self.x2hidden(x) + self.hidden(prev_hidden)
        # The tanh squashes the hidden state; without it the recurrence is
        # purely linear and gradients may explode.
        new_hidden = self.activation(pre_activation)
        return self.outweight(new_hidden), new_hidden

## Инициализация переменных 

In [178]:
# Training length and loss are independent of the model.
e_cnt = 10000
criterion = nn.CrossEntropyLoss()

# Dataset, a vanilla RNN sized to its vocabulary, and the optimizer for it.
ds = WordDataSet(word=word)
rnn1 = VanillaRNN(
    in_size=ds.vec_size,
    hidden_size=3,
    out_size=ds.vec_size,
)
optim = SGD(rnn1.parameters(), lr=0.1, momentum=0.9)

# Обучение

In [179]:
CLIP_GRAD = True
# One clipping threshold for every step. Previously logged epochs were clipped
# at 2 and all other epochs at 1, so the act of logging changed the update.
MAX_GRAD_NORM = 1
LOG_EVERY = 10

for epoch in range(e_cnt):
    # Each pass over the word starts from a zero hidden state.
    hh = torch.zeros(rnn1.hidden.in_features)
    loss = 0
    optim.zero_grad()
    for sample, next_sample in ds:
        x = ds.get_one_hot(sample).unsqueeze(0)
        target = torch.tensor([next_sample])

        y, hh = rnn1(x, hh)
        # Sum the per-character losses so one backward pass covers the whole
        # sequence.
        loss += criterion(y, target)

    loss.backward()

    # clip_grad_norm_ rescales gradients in place and returns the total norm
    # measured before clipping, which is what gets logged below.
    grad_norm = None
    if CLIP_GRAD:
        grad_norm = torch.nn.utils.clip_grad_norm_(rnn1.parameters(), max_norm=MAX_GRAD_NORM)

    if epoch % LOG_EVERY == 0:
        print (loss.item())
        if CLIP_GRAD: print("Clip gradient : ", grad_norm)

    optim.step()

71.0382308959961
Clip gradient :  5.881846659066268
55.74384307861328
Clip gradient :  4.674657666492074
35.60104751586914
Clip gradient :  5.047921879825889
23.772014617919922
Clip gradient :  3.616060308489141
18.032136917114258
Clip gradient :  4.13736659590644
15.71683120727539
Clip gradient :  15.003905337801223
13.512069702148438
Clip gradient :  5.801668751589889
12.460285186767578
Clip gradient :  15.415602514608247
12.053683280944824
Clip gradient :  17.355391193150446
10.893165588378906
Clip gradient :  7.401825309789792
10.354090690612793
Clip gradient :  6.8910387846369145
9.643281936645508
Clip gradient :  3.9440388172386562
9.580869674682617
Clip gradient :  9.240081924249633
9.25111198425293
Clip gradient :  10.435001446845906
9.776644706726074
Clip gradient :  10.43875186859597
9.175935745239258
Clip gradient :  7.282981559624126
8.38603401184082
Clip gradient :  4.229935079939474
9.120325088500977
Clip gradient :  12.562104771803552
9.217437744140625
Clip gradient :  2

1.3462629318237305
Clip gradient :  3.4112419743302027
1.5480079650878906
Clip gradient :  7.979000563772641
1.702042579650879
Clip gradient :  10.846669791006098
1.4909677505493164
Clip gradient :  10.255905526085519
1.9308748245239258
Clip gradient :  15.407110055715627
1.671834945678711
Clip gradient :  15.081982130558055
1.5122089385986328
Clip gradient :  10.913488555126678
2.0426931381225586
Clip gradient :  12.906326386972703
1.7485952377319336
Clip gradient :  6.844710561419286
3.749911308288574
Clip gradient :  27.9603647036558
1.949082374572754
Clip gradient :  17.799823635379013
1.8200321197509766
Clip gradient :  15.878573983232501
1.206864356994629
Clip gradient :  5.236921097820488
4.495697975158691
Clip gradient :  60.71844309602484
1.3233451843261719
Clip gradient :  5.844804318075425
3.437544822692871
Clip gradient :  12.790124962011747
2.0343093872070312
Clip gradient :  14.958048661281055
1.4447870254516602
Clip gradient :  6.223204054780664
2.8775196075439453
Clip g

0.5665302276611328
Clip gradient :  1.8695802508723174
1.4600334167480469
Clip gradient :  29.55045608032087
0.6324882507324219
Clip gradient :  3.2719736468102933
0.8814888000488281
Clip gradient :  8.238295983987484
8.584476470947266
Clip gradient :  178.01911547163493
5.274883270263672
Clip gradient :  78.45390630183901
0.7118244171142578
Clip gradient :  2.626897362996927
0.7304983139038086
Clip gradient :  4.600870718039973
0.9129800796508789
Clip gradient :  7.041396332908566
0.7214794158935547
Clip gradient :  5.515673541738266
0.7102127075195312
Clip gradient :  4.91566605608599
1.7126922607421875
Clip gradient :  19.702844263835956
0.991154670715332
Clip gradient :  17.893333058469924
0.5363130569458008
Clip gradient :  1.6905116648136103
0.7588319778442383
Clip gradient :  6.393897807699777
1.5529117584228516
Clip gradient :  27.357093876200437
1.5378484725952148
Clip gradient :  25.768062106347823
0.8391990661621094
Clip gradient :  8.276847526469531
1.342233657836914
Clip g

0.44631004333496094
Clip gradient :  3.0904846081602253
0.39492225646972656
Clip gradient :  2.0778361416604234
1.3287067413330078
Clip gradient :  24.175500608610932
0.5795211791992188
Clip gradient :  4.228632788826174
1.187469482421875
Clip gradient :  12.773323332897998
3.2228965759277344
Clip gradient :  41.4480730885317
4.8127593994140625
Clip gradient :  57.68481024085199
0.7752227783203125
Clip gradient :  16.099816202933063
0.6434459686279297
Clip gradient :  6.132388459473015
0.4649085998535156
Clip gradient :  3.7466121680217293
0.4749603271484375
Clip gradient :  5.4149002293077
0.9210872650146484
Clip gradient :  9.114594220047728
2.2774734497070312
Clip gradient :  29.076122949902413
1.829742431640625
Clip gradient :  55.383769507650136
3.5915632247924805
Clip gradient :  22.63767716471367
1.1868820190429688
Clip gradient :  7.487788853379672
1.772775650024414
Clip gradient :  59.307496502193295
0.5889253616333008
Clip gradient :  5.0334076030888255
1.7239055633544922
Cli

0.23334884643554688
Clip gradient :  1.5546895397114076
0.2674274444580078
Clip gradient :  1.6610532856601425
0.20411300659179688
Clip gradient :  0.7019361068063912
0.19288253784179688
Clip gradient :  1.1023377295245314
0.20381927490234375
Clip gradient :  1.314430453312232
0.2187175750732422
Clip gradient :  3.065193198056277
0.39957237243652344
Clip gradient :  7.349297164704774
1.6682815551757812
Clip gradient :  45.22556992916989
2.1410770416259766
Clip gradient :  37.234830281787744
0.3763999938964844
Clip gradient :  4.10544430167686
1.3273582458496094
Clip gradient :  30.46708224307653
0.31284523010253906
Clip gradient :  2.504459123538107
0.9675941467285156
Clip gradient :  38.46243450902833
0.6703510284423828
Clip gradient :  14.273904027453247
0.6259765625
Clip gradient :  5.458771020844024
0.3272266387939453
Clip gradient :  3.807450240987609
0.22232437133789062
Clip gradient :  1.419572010264633
0.24737930297851562
Clip gradient :  1.9051601624657473
0.21570205688476562


0.5144100189208984
Clip gradient :  8.849927177882856
2.0158443450927734
Clip gradient :  59.7826783572345
2.1300888061523438
Clip gradient :  79.62064242774426
0.5359249114990234
Clip gradient :  10.537354167700478
0.8674182891845703
Clip gradient :  17.8795816972666
2.0936851501464844
Clip gradient :  42.42296897981315
0.2076091766357422
Clip gradient :  1.9079403293845203
0.5388984680175781
Clip gradient :  7.602341919065913
0.22070693969726562
Clip gradient :  4.821067457184352
0.24562835693359375
Clip gradient :  3.9045536386579833
0.3003978729248047
Clip gradient :  4.924820591660902
0.2108325958251953
Clip gradient :  2.2766844827407335
0.2040424346923828
Clip gradient :  1.4715899862185584
0.22865867614746094
Clip gradient :  2.683981114182937
0.5101871490478516
Clip gradient :  14.381229610942452
0.91217041015625
Clip gradient :  11.787749700624898
0.6115550994873047
Clip gradient :  12.899265546499317
0.3489494323730469
Clip gradient :  10.706224218089739
0.1703643798828125
C

0.8463535308837891
Clip gradient :  28.963633970761574
0.18253135681152344
Clip gradient :  2.3574996583785524
0.21013450622558594
Clip gradient :  5.55542248844894
0.09014892578125
Clip gradient :  0.8262502519135411
0.27696800231933594
Clip gradient :  5.469494221686908
0.9196910858154297
Clip gradient :  21.470355652169566
0.19969558715820312
Clip gradient :  2.4231846208946832
5.104576110839844
Clip gradient :  100.04403789951023
2.4540939331054688
Clip gradient :  77.57506558899388
3.4731884002685547
Clip gradient :  216.46568132129252
0.09865379333496094
Clip gradient :  0.3678403440940792
0.13039779663085938
Clip gradient :  2.6318551854685968
0.44711875915527344
Clip gradient :  17.801542907432445
0.1287670135498047
Clip gradient :  1.6773797608292713
0.12378501892089844
Clip gradient :  1.8017415237440808
0.09549331665039062
Clip gradient :  1.0712680058193327
0.08515739440917969
Clip gradient :  0.6711600693185019
0.06825828552246094
Clip gradient :  0.22309998486294827
0.067

# Тестирование

In [180]:
# Greedy decoding with the trained vanilla RNN: feed the first character and
# then keep feeding the model its own most likely prediction.
rnn1.eval()
# Same zero initial hidden state as in training.
hh = torch.zeros(rnn1.hidden.in_features)
char_id = 0  # id of the first character of the word
predword = ds.get_char_by_id(char_id)
with torch.no_grad():  # inference only, no autograd graph needed
    for _ in range(len(word) - 1):
        x = ds.get_one_hot(char_id).unsqueeze(0)
        y, hh = rnn1(x, hh)
        # Softmax is monotonic, so the argmax of the logits is the prediction.
        char_id = y.argmax(dim=1).item()
        predword += ds.get_char_by_id(char_id)
print ('Prediction:\t' , predword)
print("Original:\t", word)
assert predword == word

Prediction:	 ololoasdasddqweqw123456789
Original:	 ololoasdasddqweqw123456789


# ДЗ
Реализовать LSTM и GRU модули, обучить их предсказывать тестовое слово

In [181]:
# Test word the LSTM and GRU models must learn to reproduce.
word = "ololoasdasddqweqw123456789"

## Реализовать LSTM

In [182]:
# LSTM cell implementation, trained below to predict the word.
class LSTM(nn.Module):
    """One step of an LSTM cell followed by a linear read-out.

    Each gate and the candidate state combine a bias-free projection of the
    input with a biased projection of the previous hidden state.

    Args:
        in_size: length of the input vector.
        hidden_size: length of the hidden state fed back into the cell.
        state_size: length of the cell state; must equal ``hidden_size``.
        out_size: length of the output vector.

    Raises:
        ValueError: if ``hidden_size`` differs from ``state_size``.
    """

    # Initial value of the bias of every hidden-to-gate projection.
    BIAS_INIT = 6

    def __init__(self, in_size=5, hidden_size=6, state_size=6, out_size=5):
        # The new hidden state is output_gate * tanh(cell_state), so it has
        # state_size entries, yet it is consumed by layers that expect
        # hidden_size inputs. Reject a mismatch here instead of failing with
        # an opaque shape error inside forward().
        if hidden_size != state_size:
            raise ValueError(
                "hidden_size (%d) must equal state_size (%d)" % (hidden_size, state_size)
            )
        super().__init__()
        # Candidate cell state.
        self.xc = nn.Linear(in_features=in_size, out_features=state_size, bias=False)
        self.hc = nn.Linear(in_features=hidden_size, out_features=state_size, bias=True)
        nn.init.constant_(self.hc.bias, self.BIAS_INIT)
        self.tanh = nn.Tanh()

        # Input gate.
        self.xi = nn.Linear(in_features=in_size, out_features=state_size, bias=False)
        self.hi = nn.Linear(in_features=hidden_size, out_features=state_size, bias=True)
        nn.init.constant_(self.hi.bias, self.BIAS_INIT)

        # Forget gate.
        self.xf = nn.Linear(in_features=in_size, out_features=state_size, bias=False)
        self.hf = nn.Linear(in_features=hidden_size, out_features=state_size, bias=True)
        nn.init.constant_(self.hf.bias, self.BIAS_INIT)

        # Output gate.
        self.xo = nn.Linear(in_features=in_size, out_features=state_size, bias=False)
        self.ho = nn.Linear(in_features=hidden_size, out_features=state_size, bias=True)
        nn.init.constant_(self.ho.bias, self.BIAS_INIT)
        self.sigmoid = nn.Sigmoid()
        self.outweight = nn.Linear(in_features=hidden_size, out_features=out_size)

    def forward(self, x, prev_hidden, prev_state):
        """Advance one time step; return ``(output, hidden, cell_state)``."""
        candidate_cell_state = self.tanh(self.xc(x) + self.hc(prev_hidden))
        input_gate = self.sigmoid(self.xi(x) + self.hi(prev_hidden))
        forget_gate = self.sigmoid(self.xf(x) + self.hf(prev_hidden))
        output_gate = self.sigmoid(self.xo(x) + self.ho(prev_hidden))
        # Keep part of the old state and add the gated candidate.
        cell_state = forget_gate * prev_state + input_gate * candidate_cell_state
        hidden = output_gate * self.tanh(cell_state)
        output = self.outweight(hidden)

        return output, hidden, cell_state

## Инициализация LSTM

In [252]:
# Training length and loss are independent of the model.
e_cnt = 200
criterion = nn.CrossEntropyLoss()

# Dataset, an LSTM sized to its vocabulary, and the optimizer for it.
ds = WordDataSet(word=word)
rnn2 = LSTM(
    in_size=ds.vec_size,
    hidden_size=6,
    state_size=6,
    out_size=ds.vec_size,
)
optim = SGD(rnn2.parameters(), lr=1, momentum=0.9)

# Обучение LSTM

In [253]:
CLIP_GRAD = True
# One clipping threshold for every step. Previously logged epochs were clipped
# at 2 and all other epochs at 1, so the act of logging changed the update.
MAX_GRAD_NORM = 1
LOG_EVERY = 10

for epoch in range(e_cnt):
    # Each pass over the word starts from an all-ones hidden state and a
    # zero cell state.
    hh = torch.ones(rnn2.hc.in_features)
    state = torch.zeros(rnn2.xc.out_features)
    loss = 0
    optim.zero_grad()
    for sample, next_sample in ds:
        x = ds.get_one_hot(sample).unsqueeze(0)
        target = torch.tensor([next_sample])

        y, hh, state = rnn2(x, hh, state)
        # Sum the per-character losses so one backward pass covers the whole
        # sequence.
        loss += criterion(y, target)

    loss.backward()

    # clip_grad_norm_ rescales gradients in place and returns the total norm
    # measured before clipping, which is what gets logged below.
    grad_norm = None
    if CLIP_GRAD:
        grad_norm = torch.nn.utils.clip_grad_norm_(rnn2.parameters(), max_norm=MAX_GRAD_NORM)

    if epoch % LOG_EVERY == 0:
        print (loss.item())
        if CLIP_GRAD: print("Clip gradient : ", grad_norm)

    optim.step()

75.84354400634766
Clip gradient :  11.522329270796165
74.91972351074219
Clip gradient :  9.591151626771234
73.9619140625
Clip gradient :  10.551394261642445
70.35196685791016
Clip gradient :  7.7366944425976865
68.41836547851562
Clip gradient :  19.936452892031834
51.46903610229492
Clip gradient :  6.705567648265524
33.15876007080078
Clip gradient :  8.433708581021468
19.583295822143555
Clip gradient :  6.551302343994269
9.421567916870117
Clip gradient :  2.1063759829360085
6.251310348510742
Clip gradient :  0.983257901055144
3.9333314895629883
Clip gradient :  1.9499699676476923
3.0828914642333984
Clip gradient :  0.3568301999199802
2.9148683547973633
Clip gradient :  0.20736570849706104
2.8450708389282227
Clip gradient :  0.04335171517285729
2.8078689575195312
Clip gradient :  0.06065720238225795
2.7147903442382812
Clip gradient :  0.14359491609534997
1.9104042053222656
Clip gradient :  0.772576448452729
0.4684133529663086
Clip gradient :  0.42435564511424195
0.2863607406616211
Clip 

# Тест

In [254]:
# Greedy decoding with the trained LSTM: feed the first character and then
# keep feeding the model its own most likely prediction.
rnn2.eval()
# Same initial hidden and cell state as in training.
hh = torch.ones(rnn2.hc.in_features)
state = torch.zeros(rnn2.xc.out_features)
char_id = 0  # id of the first character of the word
predword = ds.get_char_by_id(char_id)
with torch.no_grad():  # inference only, no autograd graph needed
    for _ in range(len(word) - 1):
        x = ds.get_one_hot(char_id).unsqueeze(0)
        y, hh, state = rnn2(x, hh, state)
        # Softmax is monotonic, so the argmax of the logits is the prediction.
        char_id = y.argmax(dim=1).item()
        predword += ds.get_char_by_id(char_id)
print ('Prediction:\t' , predword)
print("Original:\t", word)
assert predword == word

Prediction:	 ololoasdasddqweqw123456789
Original:	 ololoasdasddqweqw123456789


## Реализовать GRU

In [265]:
# GRU cell implementation, trained below to predict the word.
class GRU(nn.Module):
    """One step of a GRU cell followed by a linear read-out.

    The update gate interpolates between the previous hidden state (gate
    near 1) and the freshly computed candidate (gate near 0).
    """

    def __init__(self, in_size=5, hidden_size=6, out_size=5):
        """Create the gate, candidate and output projections."""
        super().__init__()
        # Update gate: bias-free input projection plus biased recurrent one.
        self.xu = nn.Linear(in_size, hidden_size, bias=False)
        self.hu = nn.Linear(hidden_size, hidden_size, bias=True)
        nn.init.zeros_(self.hu.bias)

        # Reset gate, built the same way.
        self.xr = nn.Linear(in_size, hidden_size, bias=False)
        self.hr = nn.Linear(hidden_size, hidden_size, bias=True)
        nn.init.zeros_(self.hr.bias)

        # Candidate hidden state: neither projection has a bias.
        self.xh_ = nn.Linear(in_size, hidden_size, bias=False)
        self.hh_ = nn.Linear(hidden_size, hidden_size, bias=False)

        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.outweight = nn.Linear(hidden_size, out_size)

    def forward(self, x, prev_hidden):
        """Advance one time step; return ``(output, new_hidden)``."""
        z = self.sigmoid(self.xu(x) + self.hu(prev_hidden))
        r = self.sigmoid(self.xr(x) + self.hr(prev_hidden))
        # The reset gate scales the previous state before it enters the candidate.
        gated_prev = r * prev_hidden
        candidate = self.tanh(self.xh_(x) + self.hh_(gated_prev))
        new_hidden = (1 - z) * candidate + z * prev_hidden
        return self.outweight(new_hidden), new_hidden

# Инициализация GRU

In [341]:
# Training length and loss are independent of the model.
e_cnt = 500
criterion = nn.CrossEntropyLoss()

# Dataset, a GRU sized to its vocabulary, and the optimizer for it.
ds = WordDataSet(word=word)
rnn3 = GRU(
    in_size=ds.vec_size,
    hidden_size=6,
    out_size=ds.vec_size,
)
optim = SGD(rnn3.parameters(), lr=0.1, momentum=0.9)

# Обучение GRU

In [342]:
CLIP_GRAD = True
# One clipping threshold for every step. Previously logged epochs were clipped
# at 2 and all other epochs at 1, so the act of logging changed the update.
MAX_GRAD_NORM = 1
LOG_EVERY = 10

for epoch in range(e_cnt):
    # Each pass over the word starts from a zero hidden state.
    hh = torch.zeros(rnn3.hu.in_features)
    loss = 0
    optim.zero_grad()
    for sample, next_sample in ds:
        x = ds.get_one_hot(sample).unsqueeze(0)
        target = torch.tensor([next_sample])

        y, hh = rnn3(x, hh)
        # Sum the per-character losses so one backward pass covers the whole
        # sequence.
        loss += criterion(y, target)

    loss.backward()

    # clip_grad_norm_ rescales gradients in place and returns the total norm
    # measured before clipping, which is what gets logged below.
    grad_norm = None
    if CLIP_GRAD:
        grad_norm = torch.nn.utils.clip_grad_norm_(rnn3.parameters(), max_norm=MAX_GRAD_NORM)

    if epoch % LOG_EVERY == 0:
        print (loss.item())
        if CLIP_GRAD: print("Clip gradient : ", grad_norm)

    optim.step()

71.53659057617188
Clip gradient :  3.404611410126129
61.22755813598633
Clip gradient :  4.05440308340974
32.53065872192383
Clip gradient :  4.382091490432036
11.969423294067383
Clip gradient :  2.7911241228538834
3.985781192779541
Clip gradient :  1.2169979116329834
2.0936622619628906
Clip gradient :  0.9504320235760928
1.1519384384155273
Clip gradient :  0.8172125945930072
0.33123779296875
Clip gradient :  0.3463197656576472
0.18771839141845703
Clip gradient :  0.16007212122341608
0.1376781463623047
Clip gradient :  0.0976659379027335
0.10856151580810547
Clip gradient :  0.05982024161931021
0.09231281280517578
Clip gradient :  0.051455649438854535
0.08148765563964844
Clip gradient :  0.0357676475847463
0.07346343994140625
Clip gradient :  0.025504556565421596
0.06722450256347656
Clip gradient :  0.023604549234468516
0.06206512451171875
Clip gradient :  0.020809149408325626
0.05772113800048828
Clip gradient :  0.019300138262100735
0.05398082733154297
Clip gradient :  0.0180376920274712

# Тесты

In [343]:
# Greedy decoding with the trained GRU: feed the first character and then
# keep feeding the model its own most likely prediction.
rnn3.eval()
# Same zero initial hidden state as in training, sized from the GRU itself
# (the previous version borrowed the LSTM's size and started from ones).
hh = torch.zeros(rnn3.hu.in_features)
char_id = 0  # id of the first character of the word
predword = ds.get_char_by_id(char_id)
with torch.no_grad():  # inference only, no autograd graph needed
    for _ in range(len(word) - 1):
        x = ds.get_one_hot(char_id).unsqueeze(0)
        y, hh = rnn3(x, hh)
        # Softmax is monotonic, so the argmax of the logits is the prediction.
        char_id = y.argmax(dim=1).item()
        predword += ds.get_char_by_id(char_id)
print ('Prediction:\t' , predword)
print("Original:\t", word)
assert predword == word

Prediction:	 ololoasdasddqweqw123456789
Original:	 ololoasdasddqweqw123456789
