In [213]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random

In [309]:
class TestGRU(nn.Module):
    def __init__(self, voc_len, sent_len, hidden_size, num_layers, filters, filter_len, num_convs, bidirectional):
        super(TestGRU, self).__init__()
        self.voc_len = voc_len
        self.sent_len = sent_len
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_filters = filters
        self.filter_len = filter_len
        self.num_convs = num_convs
        self.bidirectional = bidirectional
        multiplier = 1
        if self.bidirectional == True:
            multiplier = 2
        
        self.convs = []
        for n in range(self.num_convs):
            c = nn.Conv1d(sent_len,
                          self.num_filters,
                          n+1,
                          padding=0)
            self.convs.append(c)
 
        self.grus = []
        for n in range(self.num_convs):
            g = nn.GRU(self.num_filters * (self.voc_len - (n)),
                       self.hidden_size,
                       self.num_layers,
                       bidirectional=bidirectional)
            self.grus.append(g)

        self.fcs = []
        for n in range(self.num_convs):
            f =  nn.Linear(multiplier * self.hidden_size,
                           hidden_size)
            self.fcs.append(f)

        self.fc3 = nn.Linear(num_convs * self.hidden_size,
                             self.voc_len)
    
    def forward(self, x):
        x1 = []
        for n in range(self.num_convs):
            l = self.flatten(F.max_pool1d(F.relu(self.convs[n](x)), 3, 1, padding=1))
            x1.append(l)
        g1 = []
        h1 = []
        for n in range(self.num_convs):
            g, h = self.grus[n](x1[n])
            g = g.reshape(g.size()[1:])
            g1.append(g)
            h1.append(h)
        fc = []
        for n in range(self.num_convs):
            f = F.relu(self.fcs[n](g1[n]))
            fc.append(f)
        flat = torch.cat((fc), dim=1)
        out = self.fc3(flat)
        return out
    
    def flatten(self, v):
        f = self.flat_size(v)
        s = v.size()[0]
        return v.reshape(1, s, f)
    
    def flat_size(self, t):
        ns = t.size()[1:]
        f = 1
        for s in ns:
            f *= s
        return f

voc_len = 34
sent_len = 6
rnn_size = 20
num_layers = 1
filters = 20
filter_len = 1
num_convs = 8
bidirectional = False
t = TestGRU(voc_len, sent_len, rnn_size, num_layers, filters, filter_len, num_convs, bidirectional)
print(t)

seq_len = random.randint(2, 20)
i = torch.rand(seq_len, sent_len, voc_len)
print(i.size())
x = t(i)
print(x.size())
print("params")
for p in t.parameters():
    print(p.size())

TestGRU(
  (fc3): Linear(in_features=160, out_features=34, bias=True)
)
torch.Size([10, 6, 34])
torch.Size([10, 34])
params
torch.Size([34, 160])
torch.Size([34])
