In [1]:
import argparse
import torch
import os
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch import autograd
#from torchvision.datasets import MNIST
#from torchvision.transforms import transforms
from torchvision.utils import save_image
from torch.optim.lr_scheduler import StepLR
from torch.nn import functional as F

In [54]:
class Pyramidal_GRU(nn.Module):
    """Unidirectional GRU followed by a stack of time-halving GRU layers.

    ``gru0`` projects the input to ``hidden_size`` features. Each pyramid
    layer then concatenates every pair of adjacent time steps (halving the
    sequence length and doubling the feature width) and runs a GRU that maps
    the ``2 * hidden_size`` features back to ``hidden_size``.

    Args:
        input_size: number of features per time step of the input.
        hidden_size: hidden size of every GRU in the module.
        seq_len: nominal sequence length. Stored for reference only;
            ``forward`` uses the length of the tensor it actually receives.
        stack_size: number of time-halving pyramid layers.
        verbose: when True (default, matching the original behaviour),
            ``forward`` prints the tensor size at every stage.
    """

    def __init__(self, input_size, hidden_size, seq_len=8, stack_size=3, verbose=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.seq_len = seq_len
        # Honour the constructor argument so the loop bound always matches
        # the number of layers actually created below.
        self.stack_size = stack_size
        self.verbose = verbose
        self.gru0 = nn.GRU(self.input_size, self.hidden_size, batch_first=True)  # initial projection

        self.pyramid = nn.ModuleList(
            [nn.GRU(2 * hidden_size, hidden_size, batch_first=True) for _ in range(stack_size)])

    def forward(self, input):
        """Run the pyramid on a ``(batch, seq_len, input_size)`` tensor.

        Returns:
            Tensor of shape ``(batch, seq_len // 2**stack_size, hidden_size)``.

        Raises:
            ValueError: if ``input`` is not 3-D, or its sequence length is
                not divisible by ``2**stack_size`` (pairing time steps would
                otherwise spill across batch elements).
        """
        if input.dim() != 3:
            raise ValueError(
                'expected a (batch, seq_len, input_size) tensor, got %d dimension(s)' % input.dim())

        reduction = 2 ** len(self.pyramid)
        if input.size(1) % reduction != 0:
            raise ValueError(
                'sequence length %d is not divisible by %d (2**stack_size)'
                % (input.size(1), reduction))

        x, _ = self.gru0(input)
        if self.verbose:
            print('x.size', x.size())

        batch_size, seq_len = x.size(0), x.size(1)
        if self.verbose:
            print('seq_len', seq_len)

        for layer in self.pyramid:
            seq_len //= 2  # integer halving keeps the length an int
            # Explicit batch size (instead of -1) guarantees time steps are
            # only ever merged within the same batch element.
            # 'contiguous' is needed or view errors out on the GRU output.
            x = x.contiguous().view(batch_size, seq_len, 2 * self.hidden_size)
            if self.verbose:
                print('reshaped', x.size())

            x, _ = layer(x)
            if self.verbose:
                print('processed', x.size())

        if self.verbose:
            print('final', x.size())
        return x

In [55]:
# Unidirectional pyramid: 20 input features, 10 hidden units, 3 halving layers.
pyramid = Pyramidal_GRU(input_size=20, hidden_size=10, seq_len=8, stack_size=3)

In [56]:
# Print the module tree to check the input/hidden size of each GRU layer.
print(pyramid)

Pyramidal_GRU(
  (gru0): GRU(20, 10, batch_first=True)
  (pyramid): ModuleList(
    (0): GRU(20, 10, batch_first=True)
    (1): GRU(20, 10, batch_first=True)
    (2): GRU(20, 10, batch_first=True)
  )
)


In [57]:
# Random test batch; the GRUs use batch_first=True, so this is
# (batch=100, seq_len=8, input_size=20).
x = torch.randn(100,8,20)

In [58]:
# Forward pass; forward() prints the tensor size at every stage of the pyramid.
pyramid(x)

x.size torch.Size([100, 8, 10])
seq_len 8
reshaped torch.Size([100, 4, 20])
processed torch.Size([100, 4, 10])
reshaped torch.Size([100, 2, 20])
processed torch.Size([100, 2, 10])
reshaped torch.Size([100, 1, 20])
processed torch.Size([100, 1, 10])
final torch.Size([100, 1, 10])


In [84]:
class Pyramidal_BiGRU(nn.Module):
    """Bidirectional GRU followed by a stack of time-halving bidirectional GRUs.

    ``gru0`` maps the input to ``2 * hidden_size`` features (forward and
    backward states). Each pyramid layer then concatenates every pair of
    adjacent time steps (halving the sequence length, giving
    ``4 * hidden_size`` features) and runs a bidirectional GRU that maps them
    back to ``2 * hidden_size``.

    Args:
        input_size: number of features per time step of the input.
        hidden_size: per-direction hidden size of every GRU in the module.
        seq_len: nominal sequence length. Stored for reference only;
            ``forward`` uses the length of the tensor it actually receives.
        stack_size: number of time-halving pyramid layers.
        verbose: when True (default, matching the original behaviour),
            ``forward`` prints the tensor size at every stage.
    """

    def __init__(self, input_size, hidden_size, seq_len=8, stack_size=3, verbose=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.seq_len = seq_len
        # Honour the constructor argument so the loop bound always matches
        # the number of layers actually created below.
        self.stack_size = stack_size
        self.verbose = verbose
        self.gru0 = nn.GRU(self.input_size, self.hidden_size,
                           batch_first=True, bidirectional=True)  # initial projection

        self.pyramid = nn.ModuleList(
            [nn.GRU(4 * hidden_size, hidden_size, batch_first=True, bidirectional=True)
             for _ in range(stack_size)])

    def forward(self, input):
        """Run the pyramid on a ``(batch, seq_len, input_size)`` tensor.

        Returns:
            Tensor of shape ``(batch, seq_len // 2**stack_size, 2 * hidden_size)``.

        Raises:
            ValueError: if ``input`` is not 3-D, or its sequence length is
                not divisible by ``2**stack_size`` (pairing time steps would
                otherwise spill across batch elements).
        """
        if input.dim() != 3:
            raise ValueError(
                'expected a (batch, seq_len, input_size) tensor, got %d dimension(s)' % input.dim())

        reduction = 2 ** len(self.pyramid)
        if input.size(1) % reduction != 0:
            raise ValueError(
                'sequence length %d is not divisible by %d (2**stack_size)'
                % (input.size(1), reduction))

        x, _ = self.gru0(input)
        if self.verbose:
            print('x.size', x.size())

        # Use the real length of this batch, not the constructor's nominal
        # seq_len, so inputs of a different length are not silently re-batched.
        batch_size, seq_len = x.size(0), x.size(1)

        for layer in self.pyramid:
            seq_len //= 2  # integer halving keeps the length an int
            # Explicit batch size (instead of -1) guarantees time steps are
            # only ever merged within the same batch element.
            # 'contiguous' is needed or view errors out on the GRU output.
            x = x.contiguous().view(batch_size, seq_len, 4 * self.hidden_size)
            if self.verbose:
                print('reshaped', x.size())

            x, _ = layer(x)
            if self.verbose:
                print('processed', x.size())

        if self.verbose:
            print('final', x.size())
        return x
        

In [85]:
# Bidirectional pyramid: 20 input features, 10 hidden units per direction, 3 halving layers.
pyramid = Pyramidal_BiGRU(input_size=20, hidden_size=10, seq_len=8, stack_size=3)

In [86]:
# Print the module tree to check the input/hidden size of each bidirectional GRU layer.
print(pyramid)

Pyramidal_BiGRU(
  (gru0): GRU(20, 10, batch_first=True, bidirectional=True)
  (pyramid): ModuleList(
    (0): GRU(40, 10, batch_first=True, bidirectional=True)
    (1): GRU(40, 10, batch_first=True, bidirectional=True)
    (2): GRU(40, 10, batch_first=True, bidirectional=True)
  )
)


In [87]:
# Random test batch; the GRUs use batch_first=True, so this is
# (batch=10, seq_len=8, input_size=20).
x = torch.randn(10,8,20)

In [88]:
# Forward pass; forward() prints the tensor size at every stage of the pyramid.
pyramid(x)

x.size torch.Size([10, 8, 20])
reshaped torch.Size([10, 4, 40])
processed torch.Size([10, 4, 20])
reshaped torch.Size([10, 2, 40])
processed torch.Size([10, 2, 20])
reshaped torch.Size([10, 1, 40])
processed torch.Size([10, 1, 20])
final torch.Size([10, 1, 20])
