In [128]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import time
import os
import copy

#DESIGN PARAMETERS FOR NEURAL NETWORK
NR_LSTM_UNITS = 2
IMAGE_INPUT_SIZE = 50
IMAGE_AFTER_CONV_SIZE = 24

LSTM_INPUT_SIZE = IMAGE_AFTER_CONV_SIZE*IMAGE_AFTER_CONV_SIZE

RGB_CHANNELS = 3
TIMESTEPS = 10
BATCH_SIZE = 1 #until now just batch_size = 1

#USE RANDOM IMAGES TO SET UP WORKING EXAMPLE
class TEST_CNN_LSTM(nn.Module):
    def __init__(self):
        super(TEST_CNN_LSTM, self).__init__()
        self.conv = nn.Conv2d(3,1,3) #input 50x50x3
        self.pool = nn.MaxPool2d(2,2) #input 48x48 output 24x24
        self.lstm = nn.LSTM(LSTM_INPUT_SIZE,
                            LSTM_INPUT_SIZE,
                            NR_LSTM_UNITS)
        self.fc1 = nn.Linear(LSTM_INPUT_SIZE,
                             120)
        self.fc2 = nn.Linear(120,3)
        
        #initialize hidden states of LSTM
        self._hidden = (torch.randn(NR_LSTM_UNITS, BATCH_SIZE, LSTM_INPUT_SIZE), 
                        torch.randn(NR_LSTM_UNITS, BATCH_SIZE, LSTM_INPUT_SIZE))
        #print("Hidden:", _hidden)
        
    def forward(self,x):
        #print("Input:", x.size())
        x = x.float() #necessary for some reason
        x_arr = torch.zeros(TIMESTEPS,IMAGE_AFTER_CONV_SIZE,IMAGE_AFTER_CONV_SIZE)
        #print("X arr size", x_arr.size())
        for i in range(TIMESTEPS):#parallel convolutions which are later concatenated for LSTM
            x_tmp2 = self.pool(F.relu(self.conv(x[i].unsqueeze(0))))
            x_arr[i] = torch.squeeze(x_tmp2)
        
        x, _hidden = self.lstm(x_arr.view(TIMESTEPS,BATCH_SIZE,-1), self._hidden)
        x = x.view(-1,LSTM_INPUT_SIZE)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
print("Class defined")

rand_arr = np.random.rand(TIMESTEPS,RGB_CHANNELS,IMAGE_INPUT_SIZE,IMAGE_INPUT_SIZE)
test_images = torch.from_numpy(rand_arr)
#test_labels = torch.tensor([0,1,2,0,1,2,0,1,2,0]) #DIFFICULT
test_labels = torch.tensor([0,0,0,1,1,1,2,2,2,2])#EASY

#TRAINING
test_net = TEST_CNN_LSTM()
criterion = nn.CrossEntropyLoss()
#criterion = nn.BCELoss()
optimizer = optim.SGD(test_net.parameters(), lr=0.01, momentum=0.9)


print('Start training...')
for epoch in range(100): 
    print("Epoch:", epoch)
    running_loss = 0.0
    #for i in range(TIMESTEP):
    inputs = test_images
    labels = test_labels

    optimizer.zero_grad() 
    outputs = test_net(inputs)
    #print("Out:", len(outputs), outputs)
    #print("Labels:", len(labels), labels)
    loss = criterion(outputs, labels)
    loss.backward() 
    
    optimizer.step()

    running_loss += loss.item()
    print("Loss:", running_loss)
print('...Training finished')


Class defined
Start training...
('Epoch:', 0)
('Loss:', 1.0909373760223389)
('Epoch:', 1)
('Loss:', 1.089163064956665)
('Epoch:', 2)
('Loss:', 1.0857974290847778)
('Epoch:', 3)
('Loss:', 1.0811439752578735)
('Epoch:', 4)
('Loss:', 1.0754539966583252)
('Epoch:', 5)
('Loss:', 1.0689092874526978)
('Epoch:', 6)
('Loss:', 1.061870813369751)
('Epoch:', 7)
('Loss:', 1.0541836023330688)
('Epoch:', 8)
('Loss:', 1.045945167541504)
('Epoch:', 9)
('Loss:', 1.0371663570404053)
('Epoch:', 10)
('Loss:', 1.0278629064559937)
('Epoch:', 11)
('Loss:', 1.01815927028656)
('Epoch:', 12)
('Loss:', 1.0078754425048828)
('Epoch:', 13)
('Loss:', 0.9970199465751648)
('Epoch:', 14)
('Loss:', 0.9857150316238403)
('Epoch:', 15)
('Loss:', 0.9738418459892273)
('Epoch:', 16)
('Loss:', 0.9615076780319214)
('Epoch:', 17)
('Loss:', 0.9487428665161133)
('Epoch:', 18)
('Loss:', 0.93562251329422)
('Epoch:', 19)
('Loss:', 0.9222139120101929)
('Epoch:', 20)
('Loss:', 0.9085963368415833)
('Epoch:', 21)
('Loss:', 0.8947837948799

In [None]:
#FOR TESTING PURPOSES
test = torch.zeros(10,24,24)
test[1] = torch.randn(24,24)
test = test.view(10,-1)
#concat = torch.cat([x for x in test],0)
print(test.size())
#print(test)

In [None]:
#
#print(test_images[0])
#concatenation = torch.cat((test_images, test_images))

#for x in range(test_images.size(0)):
#    concatenation = torch.cat((test_images[x]))
#print(concatenation.size())
