RNN architectures exploration

In [1]:
import torch
from utils import count_params
from rnns import RNN, GRU, LSTM, BLSTM

In [2]:
# Build the plain RNN model. Judging by the parameter shapes printed later in
# this notebook (weight_ih_l0 is [256, 28], layers l0 and l1), the arguments
# appear to be input size 28, hidden size 256 and 2 stacked layers —
# TODO confirm against the signature in rnns.RNN.
rnn = RNN(28, 256, 2)
# Last expression of the cell, so the notebook displays its value: the model's
# parameter count as reported by utils.count_params.
count_params(rnn)

207370

In [3]:
# Smoke-test the forward pass on random data: create a (4, 1, 28, 28) tensor,
# drop its singleton dim 1 so the model receives shape (4, 28, 28), and
# display the shape of whatever the model returns.
dummy = torch.rand(4, 1, 28, 28)
result = rnn(dummy.squeeze(1))
result.shape

torch.Size([4, 10])

In [4]:
# List every registered parameter of the RNN model as "<shape> : <name>".
for label, tensor in rnn.named_parameters():
    print(f'{tensor.shape} : {label}')

torch.Size([256, 28]) : rnn.weight_ih_l0
torch.Size([256, 256]) : rnn.weight_hh_l0
torch.Size([256]) : rnn.bias_ih_l0
torch.Size([256]) : rnn.bias_hh_l0
torch.Size([256, 256]) : rnn.weight_ih_l1
torch.Size([256, 256]) : rnn.weight_hh_l1
torch.Size([256]) : rnn.bias_ih_l1
torch.Size([256]) : rnn.bias_hh_l1
torch.Size([10, 256]) : fc.weight
torch.Size([10]) : fc.bias


In [5]:
# Build the GRU model with the same (28, 256, 2) arguments as the plain RNN.
# The parameter listing later in this notebook shows recurrent weights with
# 768 (= 3 * 256) rows per layer, i.e. three times the plain RNN's 256 —
# argument meaning (input size, hidden size, layers) is inferred from those
# shapes; TODO confirm against rnns.GRU.
gru = GRU(28, 256, 2)
# Last expression of the cell: the notebook displays the parameter count
# reported by utils.count_params.
count_params(gru)

616970

In [6]:
# Smoke-test the GRU forward pass on random data: create a (4, 1, 28, 28)
# tensor, drop its singleton dim 1 so the model receives shape (4, 28, 28),
# and display the shape of whatever the model returns.
dummy = torch.rand(4, 1, 28, 28)
result = gru(dummy.squeeze(1))
result.shape

torch.Size([4, 10])

In [7]:
# List every registered parameter of the GRU model as "<shape> : <name>".
for label, tensor in gru.named_parameters():
    print(f'{tensor.shape} : {label}')

torch.Size([768, 28]) : gru.weight_ih_l0
torch.Size([768, 256]) : gru.weight_hh_l0
torch.Size([768]) : gru.bias_ih_l0
torch.Size([768]) : gru.bias_hh_l0
torch.Size([768, 256]) : gru.weight_ih_l1
torch.Size([768, 256]) : gru.weight_hh_l1
torch.Size([768]) : gru.bias_ih_l1
torch.Size([768]) : gru.bias_hh_l1
torch.Size([10, 256]) : fc.weight
torch.Size([10]) : fc.bias


In [8]:
# Build the LSTM model with the same (28, 256, 2) arguments as the other
# models. The parameter listing later in this notebook shows recurrent weights
# with 1024 (= 4 * 256) rows per layer, i.e. four times the plain RNN's 256 —
# argument meaning (input size, hidden size, layers) is inferred from those
# shapes; TODO confirm against rnns.LSTM.
lstm = LSTM(28, 256, 2)
# Last expression of the cell: the notebook displays the parameter count
# reported by utils.count_params.
count_params(lstm)

821770

In [9]:
# Smoke-test the LSTM forward pass on random data: create a (4, 1, 28, 28)
# tensor, drop its singleton dim 1 so the model receives shape (4, 28, 28),
# and display the shape of whatever the model returns.
dummy = torch.rand(4, 1, 28, 28)
result = lstm(dummy.squeeze(1))
result.shape

torch.Size([4, 10])

In [10]:
# List every registered parameter of the LSTM model as "<shape> : <name>".
for label, tensor in lstm.named_parameters():
    print(f'{tensor.shape} : {label}')

torch.Size([1024, 28]) : lstm.weight_ih_l0
torch.Size([1024, 256]) : lstm.weight_hh_l0
torch.Size([1024]) : lstm.bias_ih_l0
torch.Size([1024]) : lstm.bias_hh_l0
torch.Size([1024, 256]) : lstm.weight_ih_l1
torch.Size([1024, 256]) : lstm.weight_hh_l1
torch.Size([1024]) : lstm.bias_ih_l1
torch.Size([1024]) : lstm.bias_hh_l1
torch.Size([10, 256]) : fc.weight
torch.Size([10]) : fc.bias


In [11]:
# Build the bidirectional LSTM model with the same (28, 256, 2) arguments.
# The parameter listing later in this notebook shows a `_reverse` copy of each
# layer's parameters, a layer-1 input width of 512 (= 2 * 256) and a final
# linear layer taking 512 features — argument meaning (input size, hidden
# size, layers) is inferred from those shapes; TODO confirm against
# rnns.BLSTM.
blstm = BLSTM(28, 256, 2)
# Last expression of the cell: the notebook displays the parameter count
# reported by utils.count_params.
count_params(blstm)

2167818

In [12]:
# Smoke-test the bidirectional LSTM forward pass on random data: create a
# (4, 1, 28, 28) tensor, drop its singleton dim 1 so the model receives shape
# (4, 28, 28), and display the shape of whatever the model returns.
dummy = torch.rand(4, 1, 28, 28)
result = blstm(dummy.squeeze(1))
result.shape

torch.Size([4, 10])

In [13]:
# List every registered parameter of the bidirectional LSTM model as
# "<shape> : <name>".
for label, tensor in blstm.named_parameters():
    print(f'{tensor.shape} : {label}')

torch.Size([1024, 28]) : lstm.weight_ih_l0
torch.Size([1024, 256]) : lstm.weight_hh_l0
torch.Size([1024]) : lstm.bias_ih_l0
torch.Size([1024]) : lstm.bias_hh_l0
torch.Size([1024, 28]) : lstm.weight_ih_l0_reverse
torch.Size([1024, 256]) : lstm.weight_hh_l0_reverse
torch.Size([1024]) : lstm.bias_ih_l0_reverse
torch.Size([1024]) : lstm.bias_hh_l0_reverse
torch.Size([1024, 512]) : lstm.weight_ih_l1
torch.Size([1024, 256]) : lstm.weight_hh_l1
torch.Size([1024]) : lstm.bias_ih_l1
torch.Size([1024]) : lstm.bias_hh_l1
torch.Size([1024, 512]) : lstm.weight_ih_l1_reverse
torch.Size([1024, 256]) : lstm.weight_hh_l1_reverse
torch.Size([1024]) : lstm.bias_ih_l1_reverse
torch.Size([1024]) : lstm.bias_hh_l1_reverse
torch.Size([10, 512]) : fc.weight
torch.Size([10]) : fc.bias
