In [1]:
import os
import glob
import cv2

from time import sleep

import torch
import torchvision

import matplotlib.pyplot as plt
%matplotlib inline

from dataset import WeizmannHumanActionVideo
from image_autoencoder import ImageAutoEncoder
from lstm_autoencoder import VideoAutoEncoder

In [2]:
#  use gpu if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

## Dataset 

In [3]:
"""
trans_data = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(),
     torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
"""

trans_data = torchvision.transforms.ToTensor()
trans_label = None

dataset = WeizmannHumanActionVideo(trans_data=None, trans_label=trans_label, train=True)

## Train-test split

In [4]:
train_size = int(1.0 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

In [5]:
print("train: ", len(train_dataset))
print("test: ", len(test_dataset))

train:  93
test:  0


## Training

**Dataloader**

In [6]:
batch_size=1

In [7]:
train_loader = torch.utils.data.DataLoader(train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True, 
                                           num_workers=4)

"""
test_loader = torch.utils.data.DataLoader(test_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True, 
                                           num_workers=1)
"""

'\ntest_loader = torch.utils.data.DataLoader(test_dataset, \n                                           batch_size=batch_size, \n                                           shuffle=True, \n                                           num_workers=1)\n'

In [8]:
# type(data)

**Iterative algorithm (SGD)**

In [9]:
n_epochs=10

In [10]:
model = VideoAutoEncoder(input_size=16, hidden_size=16).to(device)

In [11]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

criterion = torch.nn.MSELoss()

In [12]:
for epoch in range(n_epochs):
    model.train()
    train_loss = 0
    
    for batch_id, (batch_data, _) in enumerate(train_loader):
        # print(torch.cuda.memory_allocated(device))
        
        # batch_data: 5D-Tensor (batch_size=1, video_len, channel, height, width)
        x = batch_data.to(device) 
               
        optimizer.zero_grad()
    
        # x_hat: (video_len, channel, height, width)
        # zc: (video_len, dim_zc)        
        x_hat = model(x) 
        
        print(x_hat)

        loss = criterion(x_hat, x)
        loss.backward() # compute accumulated gradients
        
        train_loss += loss.item()

        optimizer.step()
                
        print("epoch : {}/{}, batch : {}/{}, loss = {:.6f}".format(
            epoch + 1, n_epochs, batch_id, int(len(train_dataset)/batch_size), loss.item()))   
        del x_hat
    print("epoch : {}/{}, loss = {:.4f}".format(epoch + 1, n_epochs, train_loss / len(train_loader)))

input_vec.shape:  torch.Size([1, 68, 16])
zm:  torch.Size([68, 2]) zc:  torch.Size([68, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 0/93, loss = 0.061691
input_vec.shape:  torch.Size([1, 65, 16])
zm:  torch.Size([65, 2]) zc:  torch.Size([65, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 1/93, loss = 0.051892
input_vec.shape:  torch.Size([1, 57, 16])
zm:  torch.Size([57, 2]) zc:  torch.Size([57, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 2/93, loss = 0.052335
input_vec.shape:  torch.Size([1, 49, 16])
zm:  torch.Size([49, 2]) zc:  torch.Size([49, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 3/93, loss = 0.037794
input_vec.shape:  torch.Size([1, 53, 16])
zm:  torch.Size([53, 2]) zc:  torch.Size([53, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 4/93, loss = 0.031063
input_vec.shape:  torch.Size([1, 54, 16])
zm:  torch.Size([54, 2]) zc:  torch.Size([54, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 5/93, loss = 0.039673
input_vec.shape:  torch.Size([1, 61, 16])
zm:  torch.Size([61, 2]) zc:  torch.Size([61, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 6/93, loss = 0.030629
input_vec.shape:  torch.Size([1, 78, 16])
zm:  torch.Size([78, 2]) zc:  torch.Size([78, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 7/93, loss = 0.024891
input_vec.shape:  torch.Size([1, 67, 16])
zm:  torch.Size([67, 2]) zc:  torch.Size([67, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 8/93, loss = 0.022346
input_vec.shape:  torch.Size([1, 61, 16])
zm:  torch.Size([61, 2]) zc:  torch.Size([61, 2])
epoch : 1/10, batch : 9/93, loss = 0.023670
input_vec.shape:  torch.Size([1, 56, 16])
zm:  torch.Size([56, 2]) zc:  torch.Size([56, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 10/93, loss = 0.023707
input_vec.shape:  torch.Size([1, 73, 16])
zm:  torch.Size([73, 2]) zc:  torch.Size([73, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 11/93, loss = 0.022754
input_vec.shape:  torch.Size([1, 40, 16])
zm:  torch.Size([40, 2]) zc:  torch.Size([40, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 12/93, loss = 0.019051
input_vec.shape:  torch.Size([1, 36, 16])
zm:  torch.Size([36, 2]) zc:  torch.Size([36, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 13/93, loss = 0.018483
input_vec.shape:  torch.Size([1, 48, 16])
zm:  torch.Size([48, 2]) zc:  torch.Size([48, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 14/93, loss = 0.018535
input_vec.shape:  torch.Size([1, 48, 16])
zm:  torch.Size([48, 2]) zc:  torch.Size([48, 2])
epoch : 1/10, batch : 15/93, loss = 0.018168
input_vec.shape:  torch.Size([1, 52, 16])
zm:  torch.Size([52, 2]) zc:  torch.Size([52, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 16/93, loss = 0.019286
input_vec.shape:  torch.Size([1, 51, 16])
zm:  torch.Size([51, 2]) zc:  torch.Size([51, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 17/93, loss = 0.013351
input_vec.shape:  torch.Size([1, 72, 16])
zm:  torch.Size([72, 2]) zc:  torch.Size([72, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 18/93, loss = 0.016024
input_vec.shape:  torch.Size([1, 62, 16])
zm:  torch.Size([62, 2]) zc:  torch.Size([62, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 19/93, loss = 0.016205
input_vec.shape:  torch.Size([1, 42, 16])
zm:  torch.Size([42, 2]) zc:  torch.Size([42, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 20/93, loss = 0.012707
input_vec.shape:  torch.Size([1, 36, 16])
zm:  torch.Size([36, 2]) zc:  torch.Size([36, 2])
epoch : 1/10, batch : 21/93, loss = 0.013432
input_vec.shape:  torch.Size([1, 72, 16])
zm:  torch.Size([72, 2]) zc:  torch.Size([72, 2])
epoch : 1/10, batch : 22/93, loss = 0.010991
input_vec.shape:  torch.Size([1, 52, 16])
zm:  torch.Size([52, 2]) zc:  torch.Size([52, 2])
epoch : 1/10, batch : 23/93, loss = 0.010503
input_vec.shape:  torch.Size([1, 43, 16])
zm:  torch.Size([43, 2]) zc:  torch.Size([43, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 24/93, loss = 0.011257
input_vec.shape:  torch.Size([1, 56, 16])
zm:  torch.Size([56, 2]) zc:  torch.Size([56, 2])
epoch : 1/10, batch : 25/93, loss = 0.009426
input_vec.shape:  torch.Size([1, 79, 16])
zm:  torch.Size([79, 2]) zc:  torch.Size([79, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 26/93, loss = 0.011463
input_vec.shape:  torch.Size([1, 54, 16])
zm:  torch.Size([54, 2]) zc:  torch.Size([54, 2])
epoch : 1/10, batch : 27/93, loss = 0.009734
input_vec.shape:  torch.Size([1, 56, 16])
zm:  torch.Size([56, 2]) zc:  torch.Size([56, 2])
epoch : 1/10, batch : 28/93, loss = 0.009656
input_vec.shape:  torch.Size([1, 28, 16])
zm:  torch.Size([28, 2]) zc:  torch.Size([28, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 29/93, loss = 0.016158
input_vec.shape:  torch.Size([1, 60, 16])
zm:  torch.Size([60, 2]) zc:  torch.Size([60, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 30/93, loss = 0.010955
input_vec.shape:  torch.Size([1, 39, 16])
zm:  torch.Size([39, 2]) zc:  torch.Size([39, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 31/93, loss = 0.010394
input_vec.shape:  torch.Size([1, 42, 16])
zm:  torch.Size([42, 2]) zc:  torch.Size([42, 2])
epoch : 1/10, batch : 32/93, loss = 0.009566
input_vec.shape:  torch.Size([1, 51, 16])
zm:  torch.Size([51, 2]) zc:  torch.Size([51, 2])
epoch : 1/10, batch : 33/93, loss = 0.007513
input_vec.shape:  torch.Size([1, 42, 16])
zm:  torch.Size([42, 2]) zc:  torch.Size([42, 2])
epoch : 1/10, batch : 34/93, loss = 0.011007
input_vec.shape:  torch.Size([1, 51, 16])
zm:  torch.Size([51, 2]) zc:  torch.Size([51, 2])
epoch : 1/10, batch : 35/93, loss = 0.014325
input_vec.shape:  torch.Size([1, 48, 16])
zm:  torch.Size([48, 2]) zc:  torch.Size([48, 2])
epoch : 1/10, batch : 36/93, loss = 0.007760
input_vec.shape:  torch.Size([1, 44, 16])
zm:  torch.Size([44, 2]) zc:  torch.Size([44, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 37/93, loss = 0.008817
input_vec.shape:  torch.Size([1, 41, 16])
zm:  torch.Size([41, 2]) zc:  torch.Size([41, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 38/93, loss = 0.011207
input_vec.shape:  torch.Size([1, 45, 16])
zm:  torch.Size([45, 2]) zc:  torch.Size([45, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 39/93, loss = 0.012596
input_vec.shape:  torch.Size([1, 43, 16])
zm:  torch.Size([43, 2]) zc:  torch.Size([43, 2])
epoch : 1/10, batch : 40/93, loss = 0.010677
input_vec.shape:  torch.Size([1, 45, 16])
zm:  torch.Size([45, 2]) zc:  torch.Size([45, 2])
epoch : 1/10, batch : 41/93, loss = 0.008469
input_vec.shape:  torch.Size([1, 43, 16])
zm:  torch.Size([43, 2]) zc:  torch.Size([43, 2])
epoch : 1/10, batch : 42/93, loss = 0.009594
input_vec.shape:  torch.Size([1, 67, 16])
zm:  torch.Size([67, 2]) zc:  torch.Size([67, 2])
epoch : 1/10, batch : 43/93, loss = 0.007644
input_vec.shape:  torch.Size([1, 59, 16])
zm:  torch.Size([59, 2]) zc:  torch.Size([59, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 44/93, loss = 0.009128
input_vec.shape:  torch.Size([1, 63, 16])
zm:  torch.Size([63, 2]) zc:  torch.Size([63, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 45/93, loss = 0.007187
input_vec.shape:  torch.Size([1, 62, 16])
zm:  torch.Size([62, 2]) zc:  torch.Size([62, 2])
epoch : 1/10, batch : 46/93, loss = 0.006887
input_vec.shape:  torch.Size([1, 57, 16])
zm:  torch.Size([57, 2]) zc:  torch.Size([57, 2])
epoch : 1/10, batch : 47/93, loss = 0.006566
input_vec.shape:  torch.Size([1, 49, 16])
zm:  torch.Size([49, 2]) zc:  torch.Size([49, 2])
epoch : 1/10, batch : 48/93, loss = 0.007076
input_vec.shape:  torch.Size([1, 45, 16])
zm:  torch.Size([45, 2]) zc:  torch.Size([45, 2])
epoch : 1/10, batch : 49/93, loss = 0.009836
input_vec.shape:  torch.Size([1, 54, 16])
zm:  torch.Size([54, 2]) zc:  torch.Size([54, 2])
epoch : 1/10, batch : 50/93, loss = 0.011414
input_vec.shape:  torch.Size([1, 55, 16])
zm:  torch.Size([55, 2]) zc:  torch.Size([55, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 51/93, loss = 0.006343
input_vec.shape:  torch.Size([1, 31, 16])
zm:  torch.Size([31, 2]) zc:  torch.Size([31, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 52/93, loss = 0.016058
input_vec.shape:  torch.Size([1, 39, 16])
zm:  torch.Size([39, 2]) zc:  torch.Size([39, 2])
epoch : 1/10, batch : 53/93, loss = 0.007996
input_vec.shape:  torch.Size([1, 53, 16])
zm:  torch.Size([53, 2]) zc:  torch.Size([53, 2])
epoch : 1/10, batch : 54/93, loss = 0.006356
input_vec.shape:  torch.Size([1, 60, 16])
zm:  torch.Size([60, 2]) zc:  torch.Size([60, 2])
epoch : 1/10, batch : 55/93, loss = 0.004851
input_vec.shape:  torch.Size([1, 54, 16])
zm:  torch.Size([54, 2]) zc:  torch.Size([54, 2])
epoch : 1/10, batch : 56/93, loss = 0.006646
input_vec.shape:  torch.Size([1, 46, 16])
zm:  torch.Size([46, 2]) zc:  torch.Size([46, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 57/93, loss = 0.009529
input_vec.shape:  torch.Size([1, 46, 16])
zm:  torch.Size([46, 2]) zc:  torch.Size([46, 2])
epoch : 1/10, batch : 58/93, loss = 0.005207
input_vec.shape:  torch.Size([1, 52, 16])
zm:  torch.Size([52, 2]) zc:  torch.Size([52, 2])
epoch : 1/10, batch : 59/93, loss = 0.014235
input_vec.shape:  torch.Size([1, 38, 16])
zm:  torch.Size([38, 2]) zc:  torch.Size([38, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 60/93, loss = 0.008501
input_vec.shape:  torch.Size([1, 67, 16])
zm:  torch.Size([67, 2]) zc:  torch.Size([67, 2])
epoch : 1/10, batch : 61/93, loss = 0.006371
input_vec.shape:  torch.Size([1, 47, 16])
zm:  torch.Size([47, 2]) zc:  torch.Size([47, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 62/93, loss = 0.007809
input_vec.shape:  torch.Size([1, 64, 16])
zm:  torch.Size([64, 2]) zc:  torch.Size([64, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 63/93, loss = 0.005959
input_vec.shape:  torch.Size([1, 70, 16])
zm:  torch.Size([70, 2]) zc:  torch.Size([70, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 64/93, loss = 0.005222
input_vec.shape:  torch.Size([1, 39, 16])
zm:  torch.Size([39, 2]) zc:  torch.Size([39, 2])
epoch : 1/10, batch : 65/93, loss = 0.008535
input_vec.shape:  torch.Size([1, 62, 16])
zm:  torch.Size([62, 2]) zc:  torch.Size([62, 2])
epoch : 1/10, batch : 66/93, loss = 0.007346
input_vec.shape:  torch.Size([1, 45, 16])
zm:  torch.Size([45, 2]) zc:  torch.Size([45, 2])
epoch : 1/10, batch : 67/93, loss = 0.011401
input_vec.shape:  torch.Size([1, 39, 16])
zm:  torch.Size([39, 2]) zc:  torch.Size([39, 2])
epoch : 1/10, batch : 68/93, loss = 0.005404
input_vec.shape:  torch.Size([1, 61, 16])
zm:  torch.Size([61, 2]) zc:  torch.Size([61, 2])
epoch : 1/10, batch : 69/93, loss = 0.005775
input_vec.shape:  torch.Size([1, 57, 16])
zm:  torch.Size([57, 2]) zc:  torch.Size([57, 2])
epoch : 1/10, batch : 70/93, loss = 0.004638
input_vec.shape:  torch.Size([1, 63, 16])
zm:  torch.Size([63, 2]) zc:  torch.Size([63, 2])
epoch : 1/10, batch : 71/93, loss = 0.004

  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 72/93, loss = 0.006381
input_vec.shape:  torch.Size([1, 73, 16])
zm:  torch.Size([73, 2]) zc:  torch.Size([73, 2])
epoch : 1/10, batch : 73/93, loss = 0.015451
input_vec.shape:  torch.Size([1, 55, 16])
zm:  torch.Size([55, 2]) zc:  torch.Size([55, 2])
epoch : 1/10, batch : 74/93, loss = 0.005373
input_vec.shape:  torch.Size([1, 37, 16])
zm:  torch.Size([37, 2]) zc:  torch.Size([37, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


epoch : 1/10, batch : 75/93, loss = 0.005963
input_vec.shape:  torch.Size([1, 51, 16])
zm:  torch.Size([51, 2]) zc:  torch.Size([51, 2])
epoch : 1/10, batch : 76/93, loss = 0.012251
input_vec.shape:  torch.Size([1, 41, 16])
zm:  torch.Size([41, 2]) zc:  torch.Size([41, 2])
epoch : 1/10, batch : 77/93, loss = 0.006480
input_vec.shape:  torch.Size([1, 55, 16])
zm:  torch.Size([55, 2]) zc:  torch.Size([55, 2])
epoch : 1/10, batch : 78/93, loss = 0.004360
input_vec.shape:  torch.Size([1, 60, 16])
zm:  torch.Size([60, 2]) zc:  torch.Size([60, 2])
epoch : 1/10, batch : 79/93, loss = 0.008661
input_vec.shape:  torch.Size([1, 40, 16])
zm:  torch.Size([40, 2]) zc:  torch.Size([40, 2])
epoch : 1/10, batch : 80/93, loss = 0.005735
input_vec.shape:  torch.Size([1, 56, 16])
zm:  torch.Size([56, 2]) zc:  torch.Size([56, 2])
epoch : 1/10, batch : 81/93, loss = 0.012781
input_vec.shape:  torch.Size([1, 63, 16])
zm:  torch.Size([63, 2]) zc:  torch.Size([63, 2])
epoch : 1/10, batch : 82/93, loss = 0.008

epoch : 2/10, batch : 42/93, loss = 0.009010
input_vec.shape:  torch.Size([1, 61, 16])
zm:  torch.Size([61, 2]) zc:  torch.Size([61, 2])
epoch : 2/10, batch : 43/93, loss = 0.006455
input_vec.shape:  torch.Size([1, 50, 16])
zm:  torch.Size([50, 2]) zc:  torch.Size([50, 2])
epoch : 2/10, batch : 44/93, loss = 0.005792
input_vec.shape:  torch.Size([1, 31, 16])
zm:  torch.Size([31, 2]) zc:  torch.Size([31, 2])
epoch : 2/10, batch : 45/93, loss = 0.011661
input_vec.shape:  torch.Size([1, 70, 16])
zm:  torch.Size([70, 2]) zc:  torch.Size([70, 2])
epoch : 2/10, batch : 46/93, loss = 0.003552
input_vec.shape:  torch.Size([1, 40, 16])
zm:  torch.Size([40, 2]) zc:  torch.Size([40, 2])
epoch : 2/10, batch : 47/93, loss = 0.005535
input_vec.shape:  torch.Size([1, 78, 16])
zm:  torch.Size([78, 2]) zc:  torch.Size([78, 2])
epoch : 2/10, batch : 48/93, loss = 0.008981
input_vec.shape:  torch.Size([1, 54, 16])
zm:  torch.Size([54, 2]) zc:  torch.Size([54, 2])
epoch : 2/10, batch : 49/93, loss = 0.005

epoch : 3/10, batch : 9/93, loss = 0.012222
input_vec.shape:  torch.Size([1, 57, 16])
zm:  torch.Size([57, 2]) zc:  torch.Size([57, 2])
epoch : 3/10, batch : 10/93, loss = 0.003771
input_vec.shape:  torch.Size([1, 53, 16])
zm:  torch.Size([53, 2]) zc:  torch.Size([53, 2])
epoch : 3/10, batch : 11/93, loss = 0.004710
input_vec.shape:  torch.Size([1, 61, 16])
zm:  torch.Size([61, 2]) zc:  torch.Size([61, 2])
epoch : 3/10, batch : 12/93, loss = 0.005590
input_vec.shape:  torch.Size([1, 45, 16])
zm:  torch.Size([45, 2]) zc:  torch.Size([45, 2])
epoch : 3/10, batch : 13/93, loss = 0.009970
input_vec.shape:  torch.Size([1, 40, 16])
zm:  torch.Size([40, 2]) zc:  torch.Size([40, 2])
epoch : 3/10, batch : 14/93, loss = 0.004371
input_vec.shape:  torch.Size([1, 67, 16])
zm:  torch.Size([67, 2]) zc:  torch.Size([67, 2])
epoch : 3/10, batch : 15/93, loss = 0.004734
input_vec.shape:  torch.Size([1, 48, 16])
zm:  torch.Size([48, 2]) zc:  torch.Size([48, 2])
epoch : 3/10, batch : 16/93, loss = 0.0037

KeyboardInterrupt: 