In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import time
import os
import copy

# ---- Design parameters for the neural network ----

# Image geometry: square RGB frames in, square feature map out of the conv stack.
RGB_CHANNELS = 3
IMAGE_INPUT_SIZE = 228
IMAGE_AFTER_CONV_SIZE = 5
# Sizing rule for 3x3 kernels with n = number of conv+pool stages:
#   len_in = 2^n * len_out + sum[i=1..n](2^i)
# CONV_LAYER_LENGTH = 5   (disabled)

# Sequence / recurrent settings.
TIMESTEPS = 10
BATCH_SIZE = 1  # only batch_size = 1 has been exercised so far
NR_LSTM_UNITS = 2  # handed to nn.LSTM as its third positional argument (num_layers)

# The flattened feature map serves as both the LSTM input width and hidden width.
LSTM_IO_SIZE = IMAGE_AFTER_CONV_SIZE ** 2

#USE RANDOM IMAGES TO SET UP WORKING EXAMPLE
class TEST_CNN_LSTM(nn.Module):
    """Per-frame CNN feature extractor feeding an LSTM and a 3-class head.

    Every frame of the input sequence passes through five stages of
    (3x3 conv -> ReLU -> 2x2 max-pool) that end in a single-channel map.
    For 228x228 frames the spatial size shrinks 228 -> 113 -> 55 -> 26 -> 12 -> 5,
    giving a 5x5 map. The flattened maps are consumed, one per time step, by an
    LSTM with NR_LSTM_UNITS stacked layers; each LSTM output is then classified
    by three linear layers.

    Reads the module-level constants LSTM_IO_SIZE, NR_LSTM_UNITS and BATCH_SIZE.
    """

    def __init__(self):
        super(TEST_CNN_LSTM, self).__init__()
        # Convolutional stack; channel progression 3 -> 6 -> 16 -> 6 -> 3 -> 1.
        self.conv1 = nn.Conv2d(3, 6, 3)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(16, 6, 3)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.conv4 = nn.Conv2d(6, 3, 3)
        self.pool4 = nn.MaxPool2d(2, 2)
        self.conv5 = nn.Conv2d(3, 1, 3)
        self.pool5 = nn.MaxPool2d(2, 2)
        # Input width == hidden width == flattened feature-map size.
        self.lstm = nn.LSTM(LSTM_IO_SIZE,
                            LSTM_IO_SIZE,
                            NR_LSTM_UNITS)
        # Classifier head: LSTM_IO_SIZE -> 120 -> 20 -> 3 logits.
        self.fc1 = nn.Linear(LSTM_IO_SIZE, 120)
        self.fc2 = nn.Linear(120, 20)
        self.fc3 = nn.Linear(20, 3)

        # Fixed random initial (h_0, c_0); drawn once here and reused unchanged
        # by every forward() call.
        self._hidden = (torch.randn(NR_LSTM_UNITS, BATCH_SIZE, LSTM_IO_SIZE),
                        torch.randn(NR_LSTM_UNITS, BATCH_SIZE, LSTM_IO_SIZE))

    def forward(self, x):
        """Classify every frame of one image sequence.

        Args:
            x: tensor of frames shaped (seq_len, 3, H, W). Any numeric dtype is
                accepted; it is cast to float. H and W must reduce to a feature
                map with LSTM_IO_SIZE elements, otherwise the LSTM rejects the
                input.

        Returns:
            Tensor of shape (seq_len * BATCH_SIZE, 3) holding unnormalised
            class scores, one row per time step.
        """
        # Integer inputs (e.g. tensors built from integer numpy arrays) must be
        # converted before they can be convolved with float weights.
        x = x.float()

        # Treat the time axis as the conv batch axis so all frames are processed
        # in one pass instead of a Python loop with in-place tensor writes.
        stages = ((self.conv1, self.pool1),
                  (self.conv2, self.pool2),
                  (self.conv3, self.pool3),
                  (self.conv4, self.pool4),
                  (self.conv5, self.pool5))
        for conv, pool in stages:
            x = pool(F.relu(conv(x)))

        # (seq_len, 1, h, w) -> (seq_len, BATCH_SIZE, features); the sequence
        # length comes from the input rather than a global constant.
        sequence = x.reshape(x.size(0), BATCH_SIZE, -1)

        # Keep the stored initial state on the same device as the features.
        initial_state = tuple(state.to(sequence.device) for state in self._hidden)
        x, _ = self.lstm(sequence, initial_state)

        x = x.view(-1, LSTM_IO_SIZE)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
print("Class defined")

# Synthetic data: constant-valued frames (all 0s, all 1s, all 2s) cycled over
# ten time steps, so the frame's fill value equals its label in the
# "DIFFICULT" labelling below.
#rand_arr = np.random.rand(TIMESTEPS,RGB_CHANNELS,IMAGE_INPUT_SIZE,IMAGE_INPUT_SIZE)
arr_1, arr_2, arr_3 = (
    np.full((1, RGB_CHANNELS, IMAGE_INPUT_SIZE, IMAGE_INPUT_SIZE), fill_value)
    for fill_value in (0, 1, 2)
)
# Three full 0/1/2 cycles plus one extra 0-frame -> ten frames in total.
arr_full = np.concatenate([arr_1, arr_2, arr_3] * 3 + [arr_1])
print("Shape", np.shape(arr_full))
test_images = torch.from_numpy(arr_full)
test_labels = torch.tensor([0, 1, 2] * 3 + [0]) #DIFFICULT
#test_labels = torch.tensor([0,0,0,1,1,1,2,2,2,2])#EASY

#TRAINING
# Seed the RNG before the model is built so that the weight initialisation and
# the model's random initial LSTM state are identical on every fresh run.
SEED = 0
torch.manual_seed(SEED)

test_net = TEST_CNN_LSTM()
criterion = nn.CrossEntropyLoss()
#criterion = nn.BCELoss()
optimizer = optim.SGD(test_net.parameters(), lr=0.1, momentum=0.9)


print('Start training...')
for epoch in range(100):
    print("Epoch:", epoch)
    # The whole sequence is the only training sample, so each epoch is exactly
    # one forward/backward pass and one optimizer step.
    inputs = test_images
    labels = test_labels

    optimizer.zero_grad()
    outputs = test_net(inputs)
    print("Out:", len(outputs),outputs.size(),  outputs)
    print("Labels:", len(labels),labels.size() , labels)
    loss = criterion(outputs, labels)
    loss.backward()

    optimizer.step()

    # One step per epoch: the epoch's loss is simply this step's loss, so no
    # accumulator is needed.
    running_loss = loss.item()
    print("Loss:", running_loss)
print('...Training finished')


Class defined
Shape (10, 3, 228, 228)
Start training...
Epoch: 0
Out: 10 torch.Size([10, 3]) tensor([[-0.1845, -0.0397, -0.1013],
        [-0.1814, -0.0252, -0.1032],
        [-0.1815, -0.0232, -0.1024],
        [-0.1815, -0.0238, -0.1059],
        [-0.1817, -0.0261, -0.1071],
        [-0.1822, -0.0285, -0.1072],
        [-0.1824, -0.0301, -0.1073],
        [-0.1824, -0.0308, -0.1075],
        [-0.1824, -0.0312, -0.1076],
        [-0.1823, -0.0315, -0.1078]], grad_fn=<AddmmBackward>)
Labels: 10 torch.Size([10]) tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0])
Loss: 1.107919454574585
Epoch: 1
Out: 10 torch.Size([10, 3]) tensor([[-0.1722, -0.0481, -0.1087],
        [-0.1697, -0.0338, -0.1094],
        [-0.1698, -0.0319, -0.1076],
        [-0.1698, -0.0325, -0.1110],
        [-0.1698, -0.0346, -0.1124],
        [-0.1704, -0.0370, -0.1126],
        [-0.1705, -0.0386, -0.1126],
        [-0.1706, -0.0393, -0.1128],
        [-0.1706, -0.0397, -0.1129],
        [-0.1705, -0.0400, -0.1130]], grad_fn=<Add

Loss: 1.0883498191833496
Epoch: 16
Out: 10 torch.Size([10, 3]) tensor([[ 0.1804, -0.3134, -0.3138],
        [ 0.1635, -0.2770, -0.2880],
        [ 0.1650, -0.2790, -0.2726],
        [ 0.1664, -0.2838, -0.2667],
        [ 0.1657, -0.2903, -0.2617],
        [ 0.1653, -0.2948, -0.2586],
        [ 0.1652, -0.2976, -0.2569],
        [ 0.1653, -0.2991, -0.2561],
        [ 0.1653, -0.2999, -0.2558],
        [ 0.1653, -0.3003, -0.2557]], grad_fn=<AddmmBackward>)
Labels: 10 torch.Size([10]) tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0])
Loss: 1.087776780128479
Epoch: 17
Out: 10 torch.Size([10, 3]) tensor([[ 0.1722, -0.3064, -0.3155],
        [ 0.1542, -0.2682, -0.2880],
        [ 0.1559, -0.2707, -0.2717],
        [ 0.1575, -0.2758, -0.2652],
        [ 0.1568, -0.2827, -0.2598],
        [ 0.1564, -0.2874, -0.2563],
        [ 0.1564, -0.2904, -0.2546],
        [ 0.1565, -0.2921, -0.2537],
        [ 0.1566, -0.2930, -0.2534],
        [ 0.1566, -0.2935, -0.2532]], grad_fn=<AddmmBackward>)
Labels: 10 torch

Loss: 1.0779263973236084
Epoch: 32
Out: 10 torch.Size([10, 3]) tensor([[ 0.0933, -0.2624, -0.3470],
        [ 0.0279, -0.1728, -0.2682],
        [ 0.0206, -0.1752, -0.2269],
        [ 0.0225, -0.1857, -0.2092],
        [ 0.0217, -0.1981, -0.1962],
        [ 0.0224, -0.2073, -0.1892],
        [ 0.0235, -0.2133, -0.1858],
        [ 0.0243, -0.2170, -0.1840],
        [ 0.0249, -0.2192, -0.1833],
        [ 0.0254, -0.2205, -0.1829]], grad_fn=<AddmmBackward>)
Labels: 10 torch.Size([10]) tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0])
Loss: 1.076852560043335
Epoch: 33
Out: 10 torch.Size([10, 3]) tensor([[ 0.1098, -0.2757, -0.3604],
        [ 0.0365, -0.1795, -0.2746],
        [ 0.0274, -0.1814, -0.2304],
        [ 0.0290, -0.1922, -0.2115],
        [ 0.0281, -0.2051, -0.1978],
        [ 0.0290, -0.2146, -0.1906],
        [ 0.0301, -0.2209, -0.1869],
        [ 0.0309, -0.2247, -0.1851],
        [ 0.0316, -0.2270, -0.1843],
        [ 0.0321, -0.2283, -0.1839]], grad_fn=<AddmmBackward>)
Labels: 10 torch

Loss: 1.0475783348083496
Epoch: 48
Out: 10 torch.Size([10, 3]) tensor([[ 0.4372, -0.4518, -0.7280],
        [ 0.0482, -0.1442, -0.3712],
        [-0.0176, -0.1338, -0.2436],
        [-0.0234, -0.1540, -0.1985],
        [-0.0221, -0.1794, -0.1729],
        [-0.0166, -0.1988, -0.1609],
        [-0.0104, -0.2124, -0.1564],
        [-0.0047, -0.2210, -0.1566],
        [-0.0009, -0.2265, -0.1569],
        [ 0.0015, -0.2298, -0.1571]], grad_fn=<AddmmBackward>)
Labels: 10 torch.Size([10]) tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0])
Loss: 1.0441036224365234
Epoch: 49
Out: 10 torch.Size([10, 3]) tensor([[ 0.4794, -0.4744, -0.7695],
        [ 0.0403, -0.1358, -0.3720],
        [-0.0324, -0.1242, -0.2348],
        [-0.0387, -0.1460, -0.1872],
        [-0.0365, -0.1733, -0.1608],
        [-0.0296, -0.1940, -0.1497],
        [-0.0207, -0.2081, -0.1497],
        [-0.0142, -0.2174, -0.1502],
        [-0.0099, -0.2233, -0.1507],
        [-0.0074, -0.2269, -0.1510]], grad_fn=<AddmmBackward>)
Labels: 10 torc

Loss: 0.9704068899154663
Epoch: 64
Out: 10 torch.Size([10, 3]) tensor([[ 1.7985, -1.0978, -1.9984],
        [-0.3141,  0.1594, -0.2785],
        [-0.3589,  0.0745, -0.0683],
        [-0.2531, -0.0487, -0.0532],
        [-0.1380, -0.1642, -0.0768],
        [-0.0395, -0.2553, -0.1108],
        [ 0.0293, -0.3155, -0.1385],
        [ 0.0738, -0.3531, -0.1577],
        [ 0.1023, -0.3762, -0.1715],
        [ 0.1198, -0.3901, -0.1802]], grad_fn=<AddmmBackward>)
Labels: 10 torch.Size([10]) tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0])
Loss: 0.966265857219696
Epoch: 65
Out: 10 torch.Size([10, 3]) tensor([[ 1.9195, -1.1531, -2.1063],
        [-0.3612,  0.1924, -0.2628],
        [-0.3802,  0.0838, -0.0602],
        [-0.2491, -0.0575, -0.0576],
        [-0.1180, -0.1839, -0.0893],
        [-0.0081, -0.2830, -0.1290],
        [ 0.0676, -0.3478, -0.1604],
        [ 0.1159, -0.3878, -0.1819],
        [ 0.1465, -0.4122, -0.1968],
        [ 0.1653, -0.4269, -0.2063]], grad_fn=<AddmmBackward>)
Labels: 10 torch

Loss: 0.9302506446838379
Epoch: 79
Out: 10 torch.Size([10, 3]) tensor([[ 3.0119, -1.5769, -3.2778],
        [-1.1968,  0.7903, -0.0519],
        [-0.7864,  0.2835,  0.0739],
        [-0.4126, -0.0235, -0.0424],
        [-0.1664, -0.2266, -0.1227],
        [-0.0124, -0.3604, -0.1726],
        [ 0.0892, -0.4460, -0.2078],
        [ 0.1491, -0.4963, -0.2288],
        [ 0.1850, -0.5264, -0.2413],
        [ 0.2068, -0.5446, -0.2489]], grad_fn=<AddmmBackward>)
Labels: 10 torch.Size([10]) tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0])
Loss: 0.9276407957077026
Epoch: 80
Out: 10 torch.Size([10, 3]) tensor([[ 3.0782, -1.5862, -3.3708],
        [-1.2189,  0.8312, -0.0848],
        [-0.7848,  0.2871,  0.0581],
        [-0.4083, -0.0322, -0.0455],
        [-0.1658, -0.2339, -0.1224],
        [-0.0157, -0.3655, -0.1694],
        [ 0.0823, -0.4489, -0.2021],
        [ 0.1400, -0.4978, -0.2216],
        [ 0.1744, -0.5267, -0.2333],
        [ 0.1952, -0.5441, -0.2404]], grad_fn=<AddmmBackward>)
Labels: 10 torc

Loss: 0.9000028371810913
Epoch: 94
Out: 10 torch.Size([10, 3]) tensor([[ 3.5319e+00, -1.4763e+00, -4.1927e+00],
        [-1.5759e+00,  1.2788e+00, -2.4092e-01],
        [-6.5550e-01,  1.2313e-01,  2.1875e-02],
        [-2.4937e-01, -2.3763e-01, -6.5999e-02],
        [-8.0332e-02, -3.9728e-01, -9.5794e-02],
        [-1.0829e-04, -4.7567e-01, -1.1082e-01],
        [ 4.3181e-02, -5.1700e-01, -1.2029e-01],
        [ 6.6459e-02, -5.3880e-01, -1.2588e-01],
        [ 7.8802e-02, -5.5013e-01, -1.2895e-01],
        [ 8.5686e-02, -5.5630e-01, -1.3073e-01]], grad_fn=<AddmmBackward>)
Labels: 10 torch.Size([10]) tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0])
Loss: 0.8987377285957336
Epoch: 95
Out: 10 torch.Size([10, 3]) tensor([[ 3.5473e+00, -1.4475e+00, -4.2493e+00],
        [-1.6139e+00,  1.3243e+00, -2.5173e-01],
        [-6.5163e-01,  1.1746e-01,  2.0825e-02],
        [-2.4352e-01, -2.4596e-01, -6.5757e-02],
        [-7.9729e-02, -4.0126e-01, -9.3660e-02],
        [-3.1510e-03, -4.7647e-01, -1.0747e-01

In [2]:
#FOR TESTING PURPOSES
# Scratch check of how a stack of 2-D maps flattens to one row per map.
# Uses `torch` as imported in the notebook's first cell.
test = torch.zeros(10, 24, 24)
test[1].copy_(torch.randn(24, 24))  # only slot 1 receives non-zero values
test = test.reshape(test.size(0), -1)
#concat = torch.cat([x for x in test],0)
print(test.shape)
#print(test)

torch.Size([10, 576])


In [None]:
# Scratch: disabled experiments with torch.cat on test_images (not executed).
#print(test_images[0])
#concatenation = torch.cat((test_images, test_images))

#for x in range(test_images.size(0)):
#    concatenation = torch.cat((test_images[x]))
#print(concatenation.size())
