In [1]:
import torch, torch.autograd as autograd
import torch.nn as nn, torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable as avar
    
from SimpleTask import SimpleGridTask
from TransportTask import TransportTask
from NavTask import NavigationTask
from SeqData import SeqData

import os, sys, pickle, numpy as np, numpy.random as npr, random as r


In [3]:
#Pytorch LSTM input: Sequence * Batch * Input

class LSTMForwardModel(nn.Module):
    
    def __init__(self, inputSize, stateSize, h_size=400, nlayers=1):
        super(LSTMForwardModel, self).__init__()
        self.hdim, self.stateSize, self.nlayers, self.inputSize, self.actionSize = h_size, stateSize, nlayers, inputSize, inputSize - stateSize
        self.lstm = nn.LSTM(input_size=inputSize, hidden_size=self.hdim, num_layers=nlayers)
        self.hiddenToState = nn.Linear(self.hdim, stateSize)
        self.reInitialize(1)
        
    def reInitialize(self, batch_size):
        # Size = (num_layers, minibatch_size, hidden_dim)
        self.hidden = (avar(torch.zeros(self.nlayers,batch_size,self.hdim)), avar(torch.zeros(self.nlayers,batch_size,self.hdim)))
        
    def forward(self, inital_state, actions, seqn):
        #initalState [1*1*state_size] actions[batch*noOfActions*Action_size] 
        #print(actions[0].shape)
        #print(seqn)
        int_states = []
        
        current_state = avar(torch.from_numpy(inital_state).float())
        #print(current_state.shape)
        #print(torch.cat((current_state, actions[0]),0))
        for i in range(seqn):
            concat_vec = torch.cat((current_state, actions[i]),0).view(1,1,-1)
            lstm_out, self.hidden = self.lstm(concat_vec, self.hidden)
            output_state = self.hiddenToState(lstm_out[0,0,:])
            int_states.append(output_state)
            current_state = output_state
            
        return current_state, int_states
    
    def train(self, trainSeq, validSeq, nEpochs=1500, epochLen=500, validateEvery=20, vbs=500, printEvery=5, noiseSigma=0.4):
        optimizer = optim.Adam(self.parameters(), lr = 0.003)
        state_size, action_size, tenv = self.stateSize, self.actionSize, trainSeq.env
        for epoch in range(nEpochs):
            if epoch % printEvery == 0: print('Epoch:',epoch, end='')
            loss = 0.0
            self.zero_grad() # Zero out gradients
            for i in range(epochLen):
                self.reInitialize(1) # Reset LSTM hidden state
                seq,label = trainSeq.randomTrainingPair() # Current value
                actions = [ s[64:74]  for s in seq ]
                actions = [ avar(torch.from_numpy(s).float()) for s in actions] 
                intial_state = seq[0][0:64]
                seqn = len(seq)
                prediction, _ = self.forward(intial_state,actions,seqn) #[-1,:]
                label = avar(torch.from_numpy(label).float())
                loss += self._lossFunction(prediction, label, env=tenv)
            loss.backward()
            optimizer.step()
            if epoch % printEvery == 0: print(" -> AvgLoss",str(loss.data[0] / epochLen))
            if epoch % validateEvery == 0:
                bdata,blabels,bseqlen = validSeq.next(vbs,nopad=True)
                acc1, _ = self._accuracyBatch(bdata,blabels,validSeq.env)
                bdata,blabels,bseqlen = trainSeq.next(vbs,nopad=True)
                acc2, _ = self._accuracyBatch(bdata,blabels,tenv)
                print('\tCurrent Training Acc (est) =', acc1)
                print('\tCurrent Validation Acc (est) =', acc2)
    
    def _lossFunction(self,outputs,targets,useMSE=True,env=None):
        if useMSE:
            loss = nn.MSELoss()
            return loss(outputs,targets)
        else: # Use Cross-entropy
            loss = nn.CrossEntropyLoss()
            cost = avar( torch.FloatTensor( [0] ) )
            predVec = env.deconcatenateOneHotStateVector(outputs)
            labelVec = env.deconcatenateOneHotStateVector(targets)
            for pv,lv in zip(predVec,labelVec):
                val,ind = lv.max(0)
                cost += loss(pv.view(1,len(pv)), ind)
            return cost / len(predVec)
        
    def _accuracyBatch(self,seqs,labels,env):
        n, acc = float(len(seqs)), 0.0
        #print(len(seq))
        for s,l in zip(seqs,labels): acc += self._accuracySingle(s,l,env)
        return acc / n, int(n)

    # Accuracy averaged over subvecs
    def _accuracySingle(self,seq,label,env):
        seq = [avar(torch.from_numpy(s).float()) for s in seq] 
        seq = torch.cat(seq).view(len(seq), 1, -1) # [seqlen x batchlen x hidden_size]
        self.reInitialize(1) # Reset LSTM hidden state
        #print(seq.shape)
        actions = [ s[0][64:74]  for s in seq ]
        #actions = [ avar(torch.from_numpy(s).float()) for s in actions] 
        intial_state = seq[0][0][0:64].data.numpy()
        seqn = len(seq)
        prediction, _ = self.forward(intial_state,actions,seqn) #[-1,:]
        #prediction = self.forward(seq) # Only retrieves final time state
        predVec = env.deconcatenateOneHotStateVector(prediction)
        labelVec = env.deconcatenateOneHotStateVector(label)
        locAcc = 0.0
        for pv, lv in zip(predVec, labelVec):
            _, ind_pred = pv.max(0)
            ind_label = np.argmax(lv)
            locAcc += 1.0 if ind_pred.data[0] == ind_label else 0.0
        return locAcc / len(predVec)

In [4]:
f_model_name = 'forward-lstm-stochastic.pt'    
s = 'navigation' # 'transport'
trainf, validf = s + "-data-train-small.pickle", s + "-data-test-small.pickle"
print('Reading Data')
train, valid = SeqData(trainf), SeqData(validf)

Reading Data
Reading navigation-data-train-small.pickle
	Built
Reading navigation-data-test-small.pickle
	Built


In [5]:
print(train.lenOfInput,train.lenOfState)
fm = LSTMForwardModel(train.lenOfInput,train.lenOfState)

74 64


In [6]:
fm.train(train, valid)

Epoch: 0 -> AvgLoss 0.07861821746826173
	Current Training Acc (est) = 0.1312000000000004
	Current Validation Acc (est) = 0.12640000000000046
Epoch: 5 -> AvgLoss 0.06970014953613281
Epoch: 10 -> AvgLoss 0.06739921569824218
Epoch: 15 -> AvgLoss 0.06436611175537109
Epoch: 20 -> AvgLoss 0.06016217041015625
	Current Training Acc (est) = 0.5628000000000002
	Current Validation Acc (est) = 0.5700000000000006
Epoch: 25 -> AvgLoss 0.05613838195800781
Epoch: 30 -> AvgLoss 0.052529346466064455
Epoch: 35 -> AvgLoss 0.04765742111206055
Epoch: 40 -> AvgLoss 0.044381328582763675
	Current Training Acc (est) = 0.7344000000000045
	Current Validation Acc (est) = 0.7216000000000043
Epoch: 45 -> AvgLoss 0.04087410354614258
Epoch: 50 -> AvgLoss 0.03824283599853515
Epoch: 55 -> AvgLoss 0.03731814575195312
Epoch: 60 -> AvgLoss 0.035043338775634765
	Current Training Acc (est) = 0.7604000000000053
	Current Validation Acc (est) = 0.7552000000000049
Epoch: 65 -> AvgLoss 0.033585033416748046
Epoch: 70 -> AvgLoss 0.

In [7]:
torch.save(fm.state_dict(), "LSTM_FM_1_98")

In [None]:
fm.train(train, valid)

Epoch: 0 -> AvgLoss 0.013008377075195313
	Current Training Acc (est) = 0.8744000000000043
	Current Validation Acc (est) = 0.8700000000000041
Epoch: 5 -> AvgLoss 0.015376659393310546
	Current Training Acc (est) = 0.8656000000000043
	Current Validation Acc (est) = 0.8692000000000047
Epoch: 10 -> AvgLoss 0.014667679786682129
	Current Training Acc (est) = 0.8656000000000045
	Current Validation Acc (est) = 0.8852000000000039
Epoch: 15 -> AvgLoss 0.013238248825073242
	Current Training Acc (est) = 0.8612000000000047
	Current Validation Acc (est) = 0.8828000000000039
Epoch: 20 -> AvgLoss 0.013871088981628418
	Current Training Acc (est) = 0.8688000000000043
	Current Validation Acc (est) = 0.8716000000000038
Epoch: 25 -> AvgLoss 0.012729496002197265
	Current Training Acc (est) = 0.8676000000000045
	Current Validation Acc (est) = 0.8776000000000037
Epoch: 30 -> AvgLoss 0.013888617515563965
	Current Training Acc (est) = 0.8692000000000042
	Current Validation Acc (est) = 0.8768000000000038
Epoch: 3

In [None]:
torch.save(fm.state_dict(), "LSTM_FM_2")

In [32]:
actions = []
for j in range(10):
    action = np.zeros( 10 )
    action[ npr.randint(0,10) ] = 1
    actions.append(action)
actions = np.asarray(actions)
actions = avar(torch.from_numpy(actions).float())
exampleEnv = NavigationTask()
inital_state = exampleEnv.getStateRep()
print(inital_state)
final_State , intermediate_state = fm.forward(inital_state, actions, 5)
final_State = final_State.data.numpy()

[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]


In [33]:
print(final_State)
print(final_State[0:15].argmax())
print(final_State[15:30].argmax())
print(final_State[30:34].argmax())
print(final_State[34:49].argmax())
print(final_State[49:64].argmax())
exampleEnv.performAction(np.argmax(actions.data.numpy()[0]))
exampleEnv.performAction(np.argmax(actions.data.numpy()[1]))
exampleEnv.performAction(np.argmax(actions.data.numpy()[2]))
exampleEnv.performAction(np.argmax(actions.data.numpy()[3]))
exampleEnv.performAction(np.argmax(actions.data.numpy()[4]))
# exampleEnv.performAction(np.argmax(actions.data.numpy()[3]))
print(exampleEnv.getStateRep(oneHotOutput=False))

[-1.1875875e-02  2.4601117e-02 -6.5202452e-03  7.3336400e-02
  1.2798220e-01  2.9160172e-01  3.1932461e-01  9.8608591e-02
  8.2296304e-02 -1.8412858e-02  1.0076821e-02  1.8662903e-03
  2.0893149e-02  5.4644514e-02 -2.2633627e-02  2.6738644e-02
 -5.6180120e-02 -2.0300578e-01  4.0981669e-02 -8.6036421e-02
  3.3219814e-01  6.2109228e-02  6.7701274e-01  6.2240537e-02
  1.1215083e-01 -1.8102080e-01  1.3881046e-01 -5.7481341e-02
  1.7736137e-01 -5.5281729e-02  1.6142691e-02 -1.1346847e-02
  1.0079575e+00  2.3069939e-02  3.3039648e-02  6.9703013e-02
  2.2586552e-02 -9.4303362e-02 -6.1772116e-02 -6.6349972e-03
  7.4274100e-02  6.8789423e-02  7.8592002e-03  7.3187649e-03
 -8.0595128e-03  6.5140426e-03  3.3716675e-02 -4.4507261e-02
  9.1718858e-01  3.7553817e-02  1.0577283e-01 -1.3109421e-02
  1.1834033e-02 -5.2856430e-03  2.7422644e-03 -4.4638518e-02
  5.2091077e-02 -9.4498619e-02 -7.2856620e-03 -3.7650973e-02
 -2.8304383e-04  8.1735373e-02 -7.2692037e-02  1.0059174e+00]
6
7
2
14
14
[ 0.  0.  0