In [30]:
import copy
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split, \
TensorDataset

In [31]:
class Encoder(nn.Module):
    def __init__(self,n_features,hidden_dim):
        
        super().__init__()
        self.hidden_dim=hidden_dim
        self.n_features=n_features
        self.hidden=None
        self.basic_rnn=nn.GRU(self.n_features,
                             self.hidden_dim,
                             batch_first=True)
        
    def forward(self,x):
        rnn_out,self.hidden=self.basic_rnn(x)
        return rnn_out
    

In [32]:
full_seq = (torch.tensor([[-1, -1], [-1, 1], [1, 1], [1, -1]])
.float()
.view(1, 4, 2))
source_seq = full_seq[:, :2] # first two corners
target_seq = full_seq[:, 2:] # last two corners

In [33]:
torch.manual_seed(21)
encoder = Encoder(n_features=2, hidden_dim=2)
hidden_seq = encoder(source_seq) # output is N, L, F
hidden_final = hidden_seq[:, -1:]
# takes last hidden state
hidden_final,hidden_seq

(tensor([[[ 0.3105, -0.5263]]], grad_fn=<SliceBackward0>),
 tensor([[[ 0.0832, -0.0356],
          [ 0.3105, -0.5263]]], grad_fn=<TransposeBackward1>))

In [34]:
class Decoder(nn.Module):
    
    def __init__(self,n_features,hidden_dim):
        
        super().__init__()
        
        self.hidden_dim=hidden_dim
        
        self.n_features=n_features
        
        self.hidden=None
        
        self.basic_rnn=nn.GRU(self.n_features,
                             self.hidden_dim,
                             batch_first=True)
        self.regression=nn.Linear(self.hidden_dim,
                                 self.n_features)
    def init_hidden(self,hidden_seq):
        
        hidden_final=hidden_seq[:,-1:]
        
        self.hidden=hidden_final.permute(1,0,2)
        
    def forward(self,x):
        
        batch_first_output , self.hidden=self.basic_rnn(x,self.hidden)
        
        last_output=batch_first_output[:,-1:]
        
        out=self.regression(last_output)
        
        return out.view(-1,1,self.n_features)
    
        

In [35]:
torch.manual_seed(21)
decoder=Decoder(2,2)

decoder.init_hidden(hidden_seq)
inputs=source_seq[:,-1:]

print('Inputs :',inputs)
target_len=2
for i in range(target_len):
    print(f'Hidden : {decoder.hidden}')
    out=decoder(inputs)
    print(f'outputs :{out}')
    inputs=out

Inputs : tensor([[[-1.,  1.]]])
Hidden : tensor([[[ 0.3105, -0.5263]]], grad_fn=<PermuteBackward0>)
outputs :tensor([[[-0.2339,  0.4702]]], grad_fn=<ViewBackward0>)
Hidden : tensor([[[ 0.3913, -0.6853]]], grad_fn=<StackBackward0>)
outputs :tensor([[[-0.0226,  0.4628]]], grad_fn=<ViewBackward0>)


In [36]:
# Initial hidden state will be encoder's final hidden state
decoder.init_hidden(hidden_seq)
# Initial data point is the last element of source sequence
inputs = source_seq[:, -1:]

target_len=2
for i in range(target_len):
    print(f'Hidden : {decoder.hidden}')
    out=decoder(inputs)
    print(f'outputs :{out}')
    inputs=target_seq[:,i:i+1]
    print(f'i :{target_seq[:,i:i+1]}')

Hidden : tensor([[[ 0.3105, -0.5263]]], grad_fn=<PermuteBackward0>)
outputs :tensor([[[-0.2339,  0.4702]]], grad_fn=<ViewBackward0>)
i :tensor([[[1., 1.]]])
Hidden : tensor([[[ 0.3913, -0.6853]]], grad_fn=<StackBackward0>)
outputs :tensor([[[0.2265, 0.4529]]], grad_fn=<ViewBackward0>)
i :tensor([[[ 1., -1.]]])


In [38]:
decoder.init_hidden(hidden_seq)

inputs=source_seq[:,-1:]

teacher_forcing_prob=0.5
target_len=2
for i in range(target_len):
    print(f'Hidden state :{decoder.hidden}')
    out=decoder(inputs)
    print(f'Output : {out}')
    
    if torch.randn(1)<=teacher_forcing_prob:
        inputs=target_seq[:,i:i+1]
        
    else:
        inputs=out

Hidden state :tensor([[[ 0.3105, -0.5263]]], grad_fn=<PermuteBackward0>)
Output : tensor([[[-0.2339,  0.4702]]], grad_fn=<ViewBackward0>)
Hidden state :tensor([[[ 0.3913, -0.6853]]], grad_fn=<StackBackward0>)
Output : tensor([[[-0.0226,  0.4628]]], grad_fn=<ViewBackward0>)


In [85]:
class EncoderDecoder(nn.Module):
    
    def __init__(self,
                encoder,
                decoder,
                input_len,
                target_len,
                teacher_forcing_prob=0.5):
        super().__init__()
        self.encoder=encoder
        self.decoder=decoder
        self.input_len=input_len
        self.target_len=target_len
        self.teacher_forcing_prob=teacher_forcing_prob
        self.outputs=None
        
    def init_outputs(self,batch_size):
        print('\n\nInitialization of outputs ')
        device = next(self.parameters()).device
        # N---->batch_size
        #L----->target_len
        #F----->ecnoder_features
        print(f'\n\nBatch_size ->(N) : {batch_size}\nInput_length ->(L) : {self.target_len} \nEncoder_feature-->(F) : {self.encoder.n_features}')
        self.outputs=torch.zeros(batch_size,self.target_len,
                                 self.encoder.n_features).to(device)
        
    def store_outputs(self,i,out):
        print('\n\nStore_outputs')
        self.outputs[:,i:i+1,:]=out
        print(f'Full Output : ',out)
        print('Iteration', i)
        print(f'Storing output from {i} : {i+1}')
        print(f'Stored outputs :{self.outputs[:,i:i+1]}')
        
    def forward(self,x):
        print('\n\nforward')
        #splits the data in source and target sequence
        #the target seq will be empty in testing mode
        # N,L,F
        
        source_seq=x[:,:self.input_len,:]
        print(f'Source Seq : {source_seq}')
        target_seq=x[:,self.input_len:,:]
        print(f'target Seq : {target_seq}')
        self.init_outputs(x.shape[0])
        
        hidden_seq=self.encoder(source_seq)
        print(f'Hidden_seq Of encoder : {hidden_seq}')
        self.decoder.init_hidden(hidden_seq)
        
        dec_inputs=source_seq[:,-1:,:]
        print('\n\n***************Decoder Iteration******************\n')
        for i in range(self.target_len):
            
            out=self.decoder(dec_inputs)
            
            self.store_outputs(i,out)
            
            prob=self.teacher_forcing_prob
            
            if not self.training:
                prob=0
                
            if torch.rand(1) <= prob:
                
                dec_inputs=target_seq[:,i:i+1,:]
                print(f'Decoder Ouputs : {target_seq[:,i:i+1,:]}')
            else:
                dec_inputs=out
        
        return self.outputs
        

In [86]:
encdec=EncoderDecoder(encoder,
                     decoder,
                     input_len=2,
                     target_len=2,
                     teacher_forcing_prob=0.5)


In [94]:

encdec.train()
print('Output : ',encdec(full_seq))



forward
Source Seq : tensor([[[-1., -1.],
         [-1.,  1.]]])
target Seq : tensor([[[ 1.,  1.],
         [ 1., -1.]]])


Initialization of outputs 


Batch_size ->(N) : 1
Input_length ->(L) : 2 
Encoder_feature-->(F) : 2
Hidden_seq Of encoder : tensor([[[ 0.0832, -0.0356],
         [ 0.3105, -0.5263]]], grad_fn=<TransposeBackward1>)


***************Decoder Iteration******************



Store_outputs
Full Output :  tensor([[[-0.2339,  0.4702]]], grad_fn=<ViewBackward0>)
Iteration 0
Storing output from 0 : 1
Stored outputs :tensor([[[-0.2339,  0.4702]]], grad_fn=<SliceBackward0>)


Store_outputs
Full Output :  tensor([[[-0.0226,  0.4628]]], grad_fn=<ViewBackward0>)
Iteration 1
Storing output from 1 : 2
Stored outputs :tensor([[[-0.0226,  0.4628]]], grad_fn=<SliceBackward0>)
Output :  tensor([[[-0.2339,  0.4702],
         [-0.0226,  0.4628]]], grad_fn=<CopySlices>)
