In [1]:
import os
import glob
import cv2

import torch
import torchvision

from dataset import WeizmannHumanActionVideo
from image_autoencoder import ImageAutoEncoder

In [2]:
# Select the compute device: the first CUDA GPU when PyTorch reports CUDA as
# available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("device: ", device)

device:  cuda:0


In [3]:
"""
trans_data = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(),
     torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
"""

# Transform objects for the data and the labels.
# NOTE(review): trans_data is built here but the dataset below is constructed
# with trans_data=None, so the transform is currently unused -- confirm
# whether that is intentional before wiring it in.
trans_data = torchvision.transforms.ToTensor()
trans_label = None

# Training portion of the Weizmann human-action video dataset.
dataset = WeizmannHumanActionVideo(
    trans_data=None,
    trans_label=trans_label,
    train=True,
)

In [4]:
# Train/test split. The train fraction is 1.0, so every sample lands in the
# training subset and the test subset is empty.
n_samples = len(dataset)
train_size = int(1.0 * n_samples)
test_size = n_samples - train_size

train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, test_size],
)

print("train: ", len(train_dataset))
print("test: ", len(test_dataset))

train:  93
test:  0


In [5]:
# Loader / training hyperparameters.
# NOTE: the reconstruction cell squeezes dim 0 of every batch, which only
# removes the batch axis when batch_size is 1 -- keep it at 1 unless that
# cell is changed as well.
batch_size = 1
# NOTE(review): n_epochs is not referenced by any cell visible in this notebook.
n_epochs = 10

In [6]:
# Loader over the training subset. shuffle=True draws the samples in a new
# random order on every pass; batches are assembled by 4 worker processes.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

# Test loader, currently disabled (the split in this notebook leaves the test
# subset empty). Kept as real comments rather than a bare string literal: a
# string as the cell's last expression is echoed back as the cell output.
# test_loader = torch.utils.data.DataLoader(
#     test_dataset,
#     batch_size=batch_size,
#     shuffle=True,
#     num_workers=1,
# )

'\ntest_loader = torch.utils.data.DataLoader(test_dataset, \n                                           batch_size=batch_size, \n                                           shuffle=True, \n                                           num_workers=1)\n'

In [7]:
# Build the autoencoder on the selected device and restore its trained weights.
# map_location=device remaps the checkpoint's tensors onto that device, so a
# checkpoint saved from a GPU session still loads when this notebook has
# fallen back to the CPU (a plain torch.load would try to restore the tensors
# to the device they were saved from).
model_path = "trained_models/image_autoencoder.pth"
model = ImageAutoEncoder(n_channel=3, dim_zm=2, dim_zc=2).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))

<All keys matched successfully>

In [8]:
# Put the model in evaluation mode before running inference (its BatchNorm
# layers then normalise with their tracked running statistics instead of
# per-batch statistics). eval() returns the module itself, so as the cell's
# last expression it also displays the architecture.
model.eval()

ImageAutoEncoder(
  (encoder_zm): Sequential(
    (0): Conv2d(3, 256, kernel_size=(4, 4), stride=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Conv2d(256, 128, kernel_size=(4, 4), stride=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): Conv2d(128, 64, kernel_size=(4, 4), stride=(1, 1))
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
    (9): Conv2d(64, 64, kernel_size=(4, 4), stride=(1, 1))
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): LeakyReLU(negative_slope=0.2, inplace=True)
    (12): Conv2d(64, 1, kernel_size=(4, 4), stride=(1, 1))
    (13): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=T

In [9]:
from torchvision.utils import save_image

In [18]:
# Save a side-by-side reconstruction image for one (randomly drawn) video.
epoch = 0

# save_image does not create missing directories; make sure the target exists.
os.makedirs("results", exist_ok=True)

with torch.no_grad():
    for batch_data, _ in train_loader:
        # x: 4D-Tensor (video_len, channel, height, width); the loader's batch
        # axis (size 1) is squeezed away so the frames act as the batch.
        x = torch.squeeze(batch_data, dim=0).to(device)

        # NOTE(review): the recorded run prints shape (67, 2) for the second
        # returned value, which looks like a latent code rather than an image
        # batch -- confirm the return order of ImageAutoEncoder.forward.
        x_hat_z, x_hat_zc, zc = model(x)

        n = x.size(0)
        print(n)
        print(x.shape, x_hat_z.shape, x_hat_zc.shape)

        # nrow=n puts n images on each grid row: the first row holds the input
        # frames, the second row their reconstructions.
        comparison = torch.cat([x, x_hat_z])
        save_image(comparison.cpu(), 'results/reconstruction_' + str(epoch) + '.png', nrow=n)

        # Only the first batch is visualised; stop here instead of running the
        # model over every remaining video and discarding the results.
        break

67
torch.Size([67, 3, 96, 96]) torch.Size([67, 3, 96, 96]) torch.Size([67, 2])
