In [None]:
import copy
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split, \
TensorDataset

In [18]:
class Encoder(nn.Module):
    """Single-layer GRU encoder.

    Consumes a batch-first source sequence and returns the GRU output for
    every step. The final hidden state returned by the GRU is cached on
    ``self.hidden`` after each forward pass.

    Args:
        n_features: size of each input step (GRU input size).
        hidden_dim: size of the GRU hidden state.
    """

    def __init__(self, n_features, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_features = n_features
        # Stays None until the first call to ``forward``.
        self.hidden = None
        self.basic_rnn = nn.GRU(
            input_size=self.n_features,
            hidden_size=self.hidden_dim,
            batch_first=True,
        )

    def forward(self, x):
        """Encode ``x`` and return the GRU outputs for all steps.

        As a side effect the GRU's final hidden state is stored on
        ``self.hidden``.
        """
        outputs, final_state = self.basic_rnn(x)
        self.hidden = final_state
        return outputs
    

In [19]:
# Four 2-D corner points stored as a single batch: shape (1, 4, 2), float32.
full_seq = torch.tensor(
    [[-1, -1], [-1, 1], [1, 1], [1, -1]], dtype=torch.float
).unsqueeze(0)
# Cut the sequence in half along the length axis:
# first two corners are the source, last two corners are the target.
source_seq, target_seq = full_seq.split(2, dim=1)

In [20]:
# Fix the seed so the encoder's random weight initialisation is repeatable.
torch.manual_seed(21)
encoder = Encoder(n_features=2, hidden_dim=2)

# The encoder returns one output per source step, laid out as (N, L, F).
hidden_seq = encoder(source_seq)

# Slice (rather than index) the last step so the length axis is kept.
hidden_final = hidden_seq[:, -1:]

hidden_final, hidden_seq

(tensor([[[ 0.3105, -0.5263]]], grad_fn=<SliceBackward0>),
 tensor([[[ 0.0832, -0.0356],
          [ 0.3105, -0.5263]]], grad_fn=<TransposeBackward1>))

In [23]:
class Decoder(nn.Module):
    """Single-layer GRU decoder with a linear read-out.

    The decoder keeps its recurrent state on ``self.hidden`` between calls,
    so it can be stepped one element at a time: call ``init_hidden`` once
    with the encoder outputs, then call the module repeatedly.

    Args:
        n_features: size of each input step and of each prediction.
        hidden_dim: size of the GRU hidden state.
    """

    def __init__(self, n_features, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_features = n_features
        # Recurrent state; set by ``init_hidden`` and updated by ``forward``.
        self.hidden = None
        self.basic_rnn = nn.GRU(
            input_size=self.n_features,
            hidden_size=self.hidden_dim,
            batch_first=True,
        )
        # Maps a hidden state back to feature space.
        self.regression = nn.Linear(
            in_features=self.hidden_dim,
            out_features=self.n_features,
        )

    def init_hidden(self, hidden_seq):
        """Initialise the recurrent state from the last step of ``hidden_seq``.

        ``hidden_seq`` is batch-first; its last step is kept as a length-1
        slice and the first two axes are swapped so the batch axis comes
        second, which is the layout handed to the GRU as its initial state.
        """
        last_step = hidden_seq[:, -1:]
        self.hidden = last_step.permute(1, 0, 2)

    def forward(self, x):
        """Advance the GRU over ``x`` and predict from its last output.

        Updates ``self.hidden`` with the GRU's new final state and returns a
        tensor reshaped to ``(-1, 1, n_features)``.
        """
        rnn_out, self.hidden = self.basic_rnn(x, self.hidden)
        # Only the final step's output feeds the regression layer.
        prediction = self.regression(rnn_out[:, -1:])
        return prediction.view(-1, 1, self.n_features)
    
        

In [26]:
# Fix the seed so the decoder's random weight initialisation is repeatable.
torch.manual_seed(21)
decoder = Decoder(n_features=2, hidden_dim=2)

# Start from the encoder's last hidden state and the last source point.
decoder.init_hidden(hidden_seq)
inputs = source_seq[:, -1:]
print('Inputs :', inputs)

# Generate the target one step at a time, feeding each prediction back in.
target_len = 2
for step in range(target_len):
    print(f'Hidden : {decoder.hidden}')
    out = decoder(inputs)
    print(f'outputs :{out}')
    # The decoder's own prediction becomes the next input.
    inputs = out

Inputs tensor([[[-1.,  1.]]])
Hidden : tensor([[[ 0.3105, -0.5263]]], grad_fn=<PermuteBackward0>)
outputs :tensor([[[-0.2339,  0.4702]]], grad_fn=<ViewBackward0>)
Hidden : tensor([[[ 0.3913, -0.6853]]], grad_fn=<StackBackward0>)
outputs :tensor([[[-0.0226,  0.4628]]], grad_fn=<ViewBackward0>)


In [None]:
# Teacher forcing: after each step the decoder is fed the actual target
# point instead of its own prediction.

# The decoder starts from the encoder's final hidden state ...
decoder.init_hidden(hidden_seq)
# ... and from the last element of the source sequence.
inputs = source_seq[:, -1:]

target_len = 2
for i in range(target_len):
    print(f'Hidden : {decoder.hidden}')
    out = decoder(inputs)
    print(f'outputs :{out}')
    # The next input is the i-th actual target point, kept as a length-1 slice.
    inputs = target_seq[:, i:i + 1]
    print(f'i :{inputs}')

In [None]:
from tqdm import tqdm

# Randomised teacher forcing: at every step a coin flip decides whether the
# decoder's next input is the actual target point or its own prediction.

# Start from the encoder's final hidden state and the last source point.
decoder.init_hidden(hidden_seq)
inputs = source_seq[:, -1:]

# Probability of feeding the actual target at each step.
# (Previously misspelled, so the comparison below raised NameError.)
teacher_forcing_prob = 0.5
target_len = 2

for i in tqdm(range(target_len)):
    print(f'Hidden :{decoder.hidden}')

    out = decoder(inputs)

    print(f'output :{out}')

    if torch.rand(1) <= teacher_forcing_prob:
        # Teacher forcing: the next input is the i-th actual target point.
        # (Previously assigned to `input`, which left `inputs` unchanged and
        # shadowed the built-in.)
        inputs = target_seq[:, i:i + 1]
    else:
        # Otherwise the decoder's own prediction is fed back in.
        inputs = out