In [70]:
import torch, torch.autograd as autograd
import torch.nn as nn, torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable as avar
    
from SimpleTask import SimpleGridTask
from TransportTask import TransportTask
from NavTask import NavigationTask
from SeqData import SeqData

import os, sys, pickle, numpy as np, numpy.random as npr, random as r


In [245]:
#Pytorch LSTM input: Sequence * Batch * Input

class LSTMForwardModel(nn.Module):
    
    def __init__(self, inputSize, stateSize, h_size=100, nlayers=1):
        super(LSTMForwardModel, self).__init__()
        self.hdim, self.stateSize, self.nlayers, self.inputSize, self.actionSize = h_size, stateSize, nlayers, inputSize, inputSize - stateSize
        self.lstm = nn.LSTM(input_size=inputSize, hidden_size=self.hdim, num_layers=nlayers)
        self.hiddenToState = nn.Linear(self.hdim, stateSize)
        self.reInitialize(1)
        
    def reInitialize(self, batch_size):
        # Size = (num_layers, minibatch_size, hidden_dim)
        self.hidden = (avar(torch.zeros(self.nlayers,batch_size,self.hdim)), avar(torch.zeros(self.nlayers,batch_size,self.hdim)))
        
    def forward(self, inital_state, actions, seqn):
        #initalState [1*1*state_size] actions[batch*noOfActions*Action_size] 
        #print(actions[0].shape)
        #print(seqn)
        int_states = []
        
        current_state = avar(torch.from_numpy(inital_state).float())
        #print(current_state.shape)
        #print(torch.cat((current_state, actions[0]),0))
        for i in range(seqn):
            concat_vec = torch.cat((current_state, actions[i]),0).view(1,1,-1)
            lstm_out, self.hidden = self.lstm(concat_vec, self.hidden)
            output_state = self.hiddenToState(lstm_out[0,0,:])
            int_states.append(output_state)
            current_state = output_state
            
        return current_state, int_states
    
    def train(self, trainSeq, validSeq, nEpochs=200, epochLen=250, validateEvery=5, vbs=500, printEvery=5, noiseSigma=0.4):
        optimizer = optim.Adam(self.parameters(), lr = 0.0003)
        state_size, action_size, tenv = self.stateSize, self.actionSize, trainSeq.env
        for epoch in range(nEpochs):
            if epoch % printEvery == 0: print('Epoch:',epoch, end='')
            loss = 0.0
            self.zero_grad() # Zero out gradients
            for i in range(epochLen):
                self.reInitialize(1) # Reset LSTM hidden state
                seq,label = trainSeq.randomTrainingPair() # Current value
                actions = [ s[64:74]  for s in seq ]
                actions = [ avar(torch.from_numpy(s).float()) for s in actions] 
                intial_state = seq[0][0:64]
                seqn = len(seq)
                prediction, _ = self.forward(intial_state,actions,seqn) #[-1,:]
                label = avar(torch.from_numpy(label).float())
                loss += self._lossFunction(prediction, label, env=tenv)
            loss.backward()
            optimizer.step()
            if epoch % printEvery == 0: print(" -> AvgLoss",str(loss.data[0] / epochLen))
            if epoch % validateEvery == 0:
                bdata,blabels,bseqlen = validSeq.next(vbs,nopad=True)
                acc1, _ = self._accuracyBatch(bdata,blabels,validSeq.env)
                bdata,blabels,bseqlen = trainSeq.next(vbs,nopad=True)
                acc2, _ = self._accuracyBatch(bdata,blabels,tenv)
                print('\tCurrent Training Acc (est) =', acc1)
                print('\tCurrent Validation Acc (est) =', acc2)
    
    def _lossFunction(self,outputs,targets,useMSE=True,env=None):
        if useMSE:
            loss = nn.MSELoss()
            return loss(outputs,targets)
        else: # Use Cross-entropy
            loss = nn.CrossEntropyLoss()
            cost = avar( torch.FloatTensor( [0] ) )
            predVec = env.deconcatenateOneHotStateVector(outputs)
            labelVec = env.deconcatenateOneHotStateVector(targets)
            for pv,lv in zip(predVec,labelVec):
                val,ind = lv.max(0)
                cost += loss(pv.view(1,len(pv)), ind)
            return cost / len(predVec)
        
    def _accuracyBatch(self,seqs,labels,env):
        n, acc = float(len(seqs)), 0.0
        #print(len(seq))
        for s,l in zip(seqs,labels): acc += self._accuracySingle(s,l,env)
        return acc / n, int(n)

    # Accuracy averaged over subvecs
    def _accuracySingle(self,seq,label,env):
        seq = [avar(torch.from_numpy(s).float()) for s in seq] 
        seq = torch.cat(seq).view(len(seq), 1, -1) # [seqlen x batchlen x hidden_size]
        self.reInitialize(1) # Reset LSTM hidden state
        #print(seq.shape)
        actions = [ s[0][64:74]  for s in seq ]
        #actions = [ avar(torch.from_numpy(s).float()) for s in actions] 
        intial_state = seq[0][0][0:64].data.numpy()
        seqn = len(seq)
        prediction, _ = self.forward(intial_state,actions,seqn) #[-1,:]
        #prediction = self.forward(seq) # Only retrieves final time state
        predVec = env.deconcatenateOneHotStateVector(prediction)
        labelVec = env.deconcatenateOneHotStateVector(label)
        locAcc = 0.0
        for pv, lv in zip(predVec, labelVec):
            _, ind_pred = pv.max(0)
            ind_label = np.argmax(lv)
            locAcc += 1.0 if ind_pred.data[0] == ind_label else 0.0
        return locAcc / len(predVec)

In [None]:
f_model_name = 'forward-lstm-stochastic.pt'    
s = 'navigation' # 'transport'
trainf, validf = s + "-data-train-small.pickle", s + "-data-test-small.pickle"
print('Reading Data')
train, valid = SeqData(trainf), SeqData(validf)

In [190]:
print(train.lenOfInput,train.lenOfState)
fm = LSTMForwardModel(train.lenOfInput,train.lenOfState)

74 64


In [191]:
fm.train(train, valid)

Epoch: 0 -> AvgLoss 0.0791427001953125
	Current Training Acc (est) = 0.12440000000000034
	Current Validation Acc (est) = 0.12200000000000041
Epoch: 5 -> AvgLoss 0.07056641387939454
	Current Training Acc (est) = 0.19320000000000098
	Current Validation Acc (est) = 0.196400000000001
Epoch: 10 -> AvgLoss 0.06914867401123047
	Current Training Acc (est) = 0.2600000000000014
	Current Validation Acc (est) = 0.27440000000000125
Epoch: 15 -> AvgLoss 0.06811223602294922
	Current Training Acc (est) = 0.3064000000000011
	Current Validation Acc (est) = 0.30840000000000056
Epoch: 20 -> AvgLoss 0.06629911041259766
	Current Training Acc (est) = 0.3416000000000005
	Current Validation Acc (est) = 0.35400000000000026
Epoch: 25 -> AvgLoss 0.06410797882080078
	Current Training Acc (est) = 0.3628
	Current Validation Acc (est) = 0.3776
Epoch: 30 -> AvgLoss 0.06187154388427735
	Current Training Acc (est) = 0.44839999999999974
	Current Validation Acc (est) = 0.4491999999999996
Epoch: 35 -> AvgLoss 0.06004048919

Epoch: 290 -> AvgLoss 0.02411238479614258
	Current Training Acc (est) = 0.8172000000000049
	Current Validation Acc (est) = 0.8260000000000051
Epoch: 295 -> AvgLoss 0.024488079071044923
	Current Training Acc (est) = 0.817200000000005
	Current Validation Acc (est) = 0.8140000000000046
Epoch: 300 -> AvgLoss 0.02357991600036621
	Current Training Acc (est) = 0.8312000000000049
	Current Validation Acc (est) = 0.8284000000000047
Epoch: 305 -> AvgLoss 0.02456898307800293
	Current Training Acc (est) = 0.8148000000000053
	Current Validation Acc (est) = 0.8296000000000054
Epoch: 310 -> AvgLoss 0.024433824539184572
	Current Training Acc (est) = 0.819600000000005
	Current Validation Acc (est) = 0.8352000000000049
Epoch: 315 -> AvgLoss 0.022875852584838867
	Current Training Acc (est) = 0.8180000000000045
	Current Validation Acc (est) = 0.8280000000000047
Epoch: 320 -> AvgLoss 0.023552112579345703
	Current Training Acc (est) = 0.8324000000000042
	Current Validation Acc (est) = 0.8240000000000046
Epoc

Epoch: 580 -> AvgLoss 0.01901835250854492
	Current Training Acc (est) = 0.8532000000000044
	Current Validation Acc (est) = 0.8540000000000042
Epoch: 585 -> AvgLoss 0.019462574005126954
	Current Training Acc (est) = 0.8536000000000044
	Current Validation Acc (est) = 0.8632000000000042
Epoch: 590 -> AvgLoss 0.019160945892333983
	Current Training Acc (est) = 0.847200000000005
	Current Validation Acc (est) = 0.8608000000000043
Epoch: 595 -> AvgLoss 0.01950810432434082
	Current Training Acc (est) = 0.8456000000000046
	Current Validation Acc (est) = 0.8548000000000041
Epoch: 600 -> AvgLoss 0.01852149200439453
	Current Training Acc (est) = 0.8524000000000038
	Current Validation Acc (est) = 0.8508000000000042
Epoch: 605 -> AvgLoss 0.019586666107177735
	Current Training Acc (est) = 0.8488000000000047
	Current Validation Acc (est) = 0.8512000000000046
Epoch: 610 -> AvgLoss 0.01902497863769531
	Current Training Acc (est) = 0.8480000000000044
	Current Validation Acc (est) = 0.8612000000000045
Epoc

Epoch: 870 -> AvgLoss 0.01645734214782715
	Current Training Acc (est) = 0.8576000000000046
	Current Validation Acc (est) = 0.8556000000000046
Epoch: 875 -> AvgLoss 0.016608142852783205
	Current Training Acc (est) = 0.8696000000000041
	Current Validation Acc (est) = 0.8600000000000042
Epoch: 880 -> AvgLoss 0.015363986968994141
	Current Training Acc (est) = 0.8552000000000043
	Current Validation Acc (est) = 0.8696000000000044
Epoch: 885 -> AvgLoss 0.016459337234497072
	Current Training Acc (est) = 0.8520000000000042
	Current Validation Acc (est) = 0.8684000000000042
Epoch: 890 -> AvgLoss 0.015526869773864745
	Current Training Acc (est) = 0.8628000000000039
	Current Validation Acc (est) = 0.8588000000000043
Epoch: 895 -> AvgLoss 0.015305726051330567
	Current Training Acc (est) = 0.8684000000000042
	Current Validation Acc (est) = 0.8624000000000045
Epoch: 900 -> AvgLoss 0.016951076507568358
	Current Training Acc (est) = 0.8624000000000042
	Current Validation Acc (est) = 0.8520000000000042


Epoch: 1160 -> AvgLoss 0.0154070987701416
	Current Training Acc (est) = 0.8712000000000042
	Current Validation Acc (est) = 0.8732000000000036
Epoch: 1165 -> AvgLoss 0.015112344741821288
	Current Training Acc (est) = 0.8600000000000044
	Current Validation Acc (est) = 0.8692000000000043
Epoch: 1170 -> AvgLoss 0.015098500251770019
	Current Training Acc (est) = 0.8664000000000041
	Current Validation Acc (est) = 0.8588000000000048
Epoch: 1175 -> AvgLoss 0.015208738327026367
	Current Training Acc (est) = 0.8640000000000045
	Current Validation Acc (est) = 0.8736000000000038
Epoch: 1180 -> AvgLoss 0.013622153282165527
	Current Training Acc (est) = 0.8580000000000049
	Current Validation Acc (est) = 0.8692000000000043
Epoch: 1185 -> AvgLoss 0.014071931838989258
	Current Training Acc (est) = 0.8684000000000043
	Current Validation Acc (est) = 0.8728000000000036
Epoch: 1190 -> AvgLoss 0.014772335052490234
	Current Training Acc (est) = 0.8580000000000046
	Current Validation Acc (est) = 0.87280000000

Epoch: 1445 -> AvgLoss 0.013739877700805665
	Current Training Acc (est) = 0.874000000000004
	Current Validation Acc (est) = 0.8892000000000033
Epoch: 1450 -> AvgLoss 0.012380976676940918
	Current Training Acc (est) = 0.869600000000004
	Current Validation Acc (est) = 0.8704000000000038
Epoch: 1455 -> AvgLoss 0.013888038635253907
	Current Training Acc (est) = 0.8596000000000042
	Current Validation Acc (est) = 0.8768000000000044
Epoch: 1460 -> AvgLoss 0.013712221145629883
	Current Training Acc (est) = 0.8692000000000042
	Current Validation Acc (est) = 0.8756000000000044
Epoch: 1465 -> AvgLoss 0.013280937194824219
	Current Training Acc (est) = 0.8788000000000036
	Current Validation Acc (est) = 0.878000000000004
Epoch: 1470 -> AvgLoss 0.013889076232910157
	Current Training Acc (est) = 0.8680000000000042
	Current Validation Acc (est) = 0.8744000000000045
Epoch: 1475 -> AvgLoss 0.012672592163085937
	Current Training Acc (est) = 0.8716000000000041
	Current Validation Acc (est) = 0.872800000000

In [193]:
torch.save(fm.state_dict(), "LSTM_FM_1")

In [None]:
fm.train(train, valid)

Epoch: 0 -> AvgLoss 0.013008377075195313
	Current Training Acc (est) = 0.8744000000000043
	Current Validation Acc (est) = 0.8700000000000041
Epoch: 5 -> AvgLoss 0.015376659393310546
	Current Training Acc (est) = 0.8656000000000043
	Current Validation Acc (est) = 0.8692000000000047
Epoch: 10 -> AvgLoss 0.014667679786682129
	Current Training Acc (est) = 0.8656000000000045
	Current Validation Acc (est) = 0.8852000000000039
Epoch: 15 -> AvgLoss 0.013238248825073242
	Current Training Acc (est) = 0.8612000000000047
	Current Validation Acc (est) = 0.8828000000000039
Epoch: 20 -> AvgLoss 0.013871088981628418
	Current Training Acc (est) = 0.8688000000000043
	Current Validation Acc (est) = 0.8716000000000038
Epoch: 25 -> AvgLoss 0.012729496002197265
	Current Training Acc (est) = 0.8676000000000045
	Current Validation Acc (est) = 0.8776000000000037
Epoch: 30 -> AvgLoss 0.013888617515563965
	Current Training Acc (est) = 0.8692000000000042
	Current Validation Acc (est) = 0.8768000000000038
Epoch: 3

Epoch: 290 -> AvgLoss 0.011848825454711914
	Current Training Acc (est) = 0.882400000000004
	Current Validation Acc (est) = 0.8892000000000037
Epoch: 295 -> AvgLoss 0.011792118072509766
	Current Training Acc (est) = 0.8816000000000037
	Current Validation Acc (est) = 0.8828000000000039
Epoch: 300 -> AvgLoss 0.011854523658752442
	Current Training Acc (est) = 0.883600000000004
	Current Validation Acc (est) = 0.8760000000000039
Epoch: 305 -> AvgLoss 0.011519085884094239
	Current Training Acc (est) = 0.8788000000000039
	Current Validation Acc (est) = 0.885600000000004
Epoch: 310 -> AvgLoss 0.012330561637878419
	Current Training Acc (est) = 0.8740000000000039
	Current Validation Acc (est) = 0.8888000000000039
Epoch: 315 -> AvgLoss 0.012545440673828125
	Current Training Acc (est) = 0.8852000000000035
	Current Validation Acc (est) = 0.8928000000000035
Epoch: 320 -> AvgLoss 0.011830157279968261
	Current Training Acc (est) = 0.8780000000000039
	Current Validation Acc (est) = 0.8860000000000036
Ep

In [None]:
torch.save(fm.state_dict(), "LSTM_FM_2")

In [243]:
actions = []
for j in range(10):
    action = np.zeros( 10 )
    action[ npr.randint(0,10) ] = 1
    actions.append(action)
actions = np.asarray(actions)
actions = avar(torch.from_numpy(actions).float())
exampleEnv = NavigationTask()
inital_state = exampleEnv.getStateRep()
print(inital_state)
final_State , intermediate_state = fm.forward(inital_state, actions, 5)
final_State = final_State.data.numpy()

[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]


In [244]:
print(final_State)
print(final_State[0:15].argmax())
print(final_State[15:30].argmax())
print(final_State[30:34].argmax())
print(final_State[34:49].argmax())
print(final_State[49:64].argmax())
exampleEnv.performAction(np.argmax(actions.data.numpy()[0]))
exampleEnv.performAction(np.argmax(actions.data.numpy()[1]))
exampleEnv.performAction(np.argmax(actions.data.numpy()[2]))
exampleEnv.performAction(np.argmax(actions.data.numpy()[3]))
exampleEnv.performAction(np.argmax(actions.data.numpy()[4]))
# exampleEnv.performAction(np.argmax(actions.data.numpy()[3]))
print(exampleEnv.getStateRep(oneHotOutput=False))

[ 0.8225643   0.09766459  0.02171439  0.06105454  0.06160823  0.08762877
  0.01336757 -0.03495867 -0.02305578 -0.11920367 -0.01638512 -0.04701911
  0.02726124  0.03395923  0.0196971  -0.01911449 -0.0934657  -0.02328094
 -0.09144662  0.01276394  0.01553509 -0.00667164 -0.03118145  0.09538662
  0.06371482  0.02437986  0.01932992  0.05407467  0.20694506  0.71719193
  1.0477041  -0.07524498 -0.01323281  0.01471516 -0.0084057  -0.0689404
  0.17682275 -0.08818772  0.29019403 -0.15192288  0.09047423 -0.11848712
 -0.1672299   0.03674814  0.08108625  0.11288385  0.03727404 -0.10900491
  0.8611873   0.01559952 -0.02245157 -0.1319321   0.04237264  0.22253045
 -0.04425486 -0.04452154  0.00188921 -0.05718898  0.01302726  0.04246456
  0.22461896 -0.02833173 -0.10475935  0.85510135]
0
14
0
14
14
[ 0.  4.  1.  0.  0.  0. 14. 14.]
