In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.utils.data import Dataset,DataLoader
from torch.autograd import Variable
import torch.nn.functional as F
import json 
from configs.task_generator import CopyDataset

In [2]:
# =====================                                                                   
# Original Source: https://github.com/vlgiitr/ntm-pytorch/tree/master/ntm/datasets        
# =====================    

task_params = json.load(open("configs/copy.json"))                                        
                                                                                          
dataset = CopyDataset(task_params)                                                        
                                   

In [3]:
data = next(iter(dataset))

x_,y_ = data["input"],data["target"]

print(x_.shape)
print(y_.shape)

torch.Size([3, 10])
torch.Size([1, 8])


In [5]:
class NTM(nn.Module):
    def __init__(self):
        super(NTM,self).__init__()
        
        D_in = 10
        D_out = 40
        
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        dim = 10
        
        self.controller = torch.nn.LSTMCell(D_in,D_out)
        self.control_linear = nn.Linear(40,100)
        
        self.soft = nn.Softmax(dim=0)
        self.out_net = nn.Linear(140,8)
        
        # prev reads
        
        self.prev_reads = []
        
        
        # create a memory 
        N = 100 # number of location 
        M = 100 # size of location 
        self.memory = Variable(torch.ones(N,M))
        
        # layers for the read and write heads
        self.w = nn.Linear(40,100) # weight 
        self.e = nn.Linear(100,100) # erase
        self.a = nn.Linear(100,100) # add
        

        self.c_state = torch.zeros([1,dim])
        self.h_state = torch.zeros([1,dim])
    
    def addressing(self,output):
        
      
        beta = 2
        
        similarity_scores = F.cosine_similarity(output.unsqueeze(1), self.memory, dim=1)
        
        content_weights = F.softmax(beta * similarity_scores, dim=1) 
        
        
        return content_weights
    
    def output(self,reads):
        
        state_output = torch.cat([self.h_state,reads],dim=1)
        
        output = torch.sigmoid(self.out_net(state_output))
        
        return output
               
        
    def forward(self,x):
        
        
        self.h_state,self.c_state = self.controller(x)
        
        control_output = self.control_linear(self.c_state)
                
        # modify the weights th be able to use in addressing 
        w = self.addressing(control_output)
                              
        # single head
        
        # read
        read = torch.matmul(w , self.memory)
        
        
        # erase 
        e_t = self.e(control_output) # Linear
        
        er = (1 - (w *  e_t)).resize(100,1)
        erase = self.memory * er
        
        
        # write 
        a_t = self.a(control_output) # Linear
                
        weight = w.detach()
        det_erase = erase.detach()
        
        self.memory = det_erase + weight.T * a_t
        #self.memory = torch.matmul(w.T , a_t)

                
        output = self.output(read)
        
        # put read write head hear        
                
        return output


In [6]:
model = NTM()

In [7]:
criterion = nn.BCELoss() 
optimizer = torch.optim.RMSprop(model.parameters())

In [8]:
# how to calculate the cost


In [12]:
# training

losses = []

for data in dataset:
    
    
    
    optimizer.zero_grad()
    
    
    with torch.autograd.set_detect_anomaly(True):

        X,y = data["input"],data["target"]
        out = torch.zeros(y.size()) 

        zero_inputs = torch.zeros(X.size()[1]).unsqueeze(0) # dummy


        for i in range(X.size()[0]):
            model(X[i].unsqueeze(0))

        for i in range(y.size()[0]):                                                     
            out[i] = model(zero_inputs) 


        loss = criterion(out, y) 

        print(loss.item())

        losses.append(loss.item())                                                            
        loss.backward(retain_graph=True) 
        #loss.backward(retain_graph=True)
        optimizer.step()   
    

9.287558555603027
28.022823333740234
44.14920425415039
41.2113151550293
22.725820541381836
28.060293197631836
41.10667419433594
34.75285339355469
40.89961624145508
40.06613540649414
16.86406135559082
43.476287841796875
44.49610137939453
58.766082763671875
49.03845977783203
62.5
53.90625
54.605262756347656
47.65625
45.0
54.6875
43.05555725097656
37.5
58.33333206176758
48.21428680419922
68.75
39.16666793823242
51.97368240356445
55.55555725097656
31.25
45.83333206176758
75.0
49.264705657958984
58.33333206176758
52.88461685180664
50.657894134521484
51.04166793823242
52.88461685180664
50.0
62.5
46.875
39.70588302612305
41.66666793823242
52.5
57.14285659790039
33.33333206176758
50.0
43.75
45.83333206176758
62.5
50.0
45.83333206176758
52.67856979370117
47.91666793823242
50.0
47.5
40.0
50.0
44.73684310913086
50.0
52.77777862548828
58.92856979370117
51.66666793823242
48.02631759643555
60.41666793823242
50.69444274902344
53.125
55.35714340209961
47.91666793823242
47.91666793823242
56.25
48.4375


KeyboardInterrupt: 

In [13]:
%matplotlib inline