In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import torch
import torch.utils.data
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from model import Model

class PermuteStack(object):
    """Fold the channel axis of a ``(D, H, W, C)`` tensor into the frame axis.

    Calling the transform on a tensor of shape ``(d, h, w, 3)`` moves the last
    (channel) axis to the front and merges it with the first axis, giving a
    tensor of shape ``(3 * d, h, w)`` laid out channel-major (all entries of
    channel 0 first, then channel 1, then channel 2).
    """

    def __init__(self):
        """This transform takes no configuration."""

    def __call__(self, sample):
        depth, height, width, channels = sample.shape
        assert channels == 3
        # permute returns a strided view, so the axes are merged with
        # reshape (which copies when needed) rather than view.
        channel_first = sample.permute(3, 0, 1, 2)
        return channel_first.reshape(channels * depth, height, width)

class ToFloat32(object):
    """Cast a tensor to ``torch.float32`` and divide every element by 255."""

    def __init__(self):
        """This transform takes no configuration."""

    def __call__(self, sample):
        as_float = sample.to(dtype=torch.float32)
        return as_float / 255.0

class AdaptiveResize(object):
    """Resize a tensor's trailing two axes with adaptive average pooling.

    Args:
        output_size: target size, forwarded unchanged to
            ``nn.AdaptiveAvgPool2d``.
    """

    def __init__(self, output_size):
        # The pooling module is built once here and reused on every call.
        pooling = nn.AdaptiveAvgPool2d(output_size)
        self.pool = pooling

    def __call__(self, sample):
        resized = self.pool(sample)
        return resized

class Unstack(object):
    """Split a stacked ``(3 * D, H, W)`` tensor into ``(3, D, H, W)``.

    The leading axis must have a length divisible by 3; it is divided into
    three equal consecutive groups, which become the new first axis.
    """

    def __init__(self):
        """This transform takes no configuration."""

    def __call__(self, sample):
        stacked, height, width = sample.shape
        assert stacked % 3 == 0
        depth = stacked // 3
        return sample.view(3, depth, height, width)

def get_same_index(target, label):
    """Return the positions in ``target`` whose entry equals ``label``.

    Args:
        target: an indexable sequence supporting ``len()`` and integer
            indexing.
        label: the value each entry is compared against with ``==``.

    Returns:
        A list of integer positions, in ascending order; empty when nothing
        matches.
    """
    return [position
            for position in range(len(target))
            if target[position] == label]

# --- Run configuration -------------------------------------------------------
frames_per_clip = 8       # passed to both the model and the HMDB51 datasets
step_between_clips = 1    # passed to the HMDB51 datasets
batch_size = 16
shuffle = True
# Prefer the first GPU, but fall back to the CPU so the notebook still runs on
# a machine without CUDA instead of failing at the first `.to(device)` call.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = Model(max_scale=4,
              steps_per_scale=int(25e3),
              lr=1e-3,
              frames_per_clip=frames_per_clip)

# Per-clip preprocessing, applied in order:
#   PermuteStack   (D, H, W, 3) -> (3*D, H, W)
#   ToFloat32      cast to float32 and divide by 255
#   AdaptiveResize average-pool the last two axes to 64 x 64
#   Unstack        (3*D, 64, 64) -> (3, D, 64, 64)
data_transform = transforms.Compose([PermuteStack(),
                                     ToFloat32(),
                                     AdaptiveResize((64, 64)),
                                     Unstack(),
                                     ])


def make_hmdb51_loader(root, annotation_path):
    """Build an HMDB51 dataset and its DataLoader from the shared settings.

    Both data sources use the same clip length, clip stride, transform,
    batch size and shuffle flag; only the two paths differ, so they are the
    only parameters.

    Args:
        root: directory holding the videos for this data source.
        annotation_path: directory holding the matching annotation files.

    Returns:
        A ``(dataset, loader)`` tuple.
    """
    dataset = torchvision.datasets.HMDB51(root=root,
                                          annotation_path=annotation_path,
                                          frames_per_clip=frames_per_clip,
                                          step_between_clips=step_between_clips,
                                          transform=data_transform
                                          )
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         shuffle=shuffle,
                                         )
    return dataset, loader


hmdb51_data_0, data_loader_0 = make_hmdb51_loader("hmdb51/data_0",
                                                  "hmdb51/annotation_0")
hmdb51_data_1, data_loader_1 = make_hmdb51_loader("hmdb51/data_1",
                                                  "hmdb51/annotation_1")

In [None]:
# Draw a single batch from loader 0 for a visual sanity check; only the first
# of the three items in the batch is used here.
video_0, _, _ = next(iter(data_loader_0))
label_0 = torch.zeros(video_0.shape[0])
# Index 0 along the second per-clip axis of the first clip, moved to
# channels-last so imshow can draw it.
plt.imshow(video_0[0, :, 0].permute(1, 2, 0))
print(f"Video batch shape (N, C, D, H, W) :{video_0.shape}, Labels:{label_0}")

In [None]:
# Draw a single batch from loader 1 for a visual sanity check; only the first
# of the three items in the batch is used here.
video_1, _, _ = next(iter(data_loader_1))
label_1 = torch.ones(video_1.shape[0])
# Index 0 along the second per-clip axis of the first clip, moved to
# channels-last so imshow can draw it.
plt.imshow(video_1[0, :, 0].permute(1, 2, 0))
print(f"Video batch shape (N, C, D, H, W) :{video_1.shape}, Labels:{label_1}")

In [None]:
def _endless_batches(loader):
    """Yield batches from ``loader`` forever, starting a new pass when one ends.

    ``itertools.cycle`` is deliberately not used: it would keep every batch
    of the first pass in memory and replay that same order on later passes.

    Raises:
        RuntimeError: if a complete pass over ``loader`` yields nothing,
            which would otherwise make this generator spin forever.
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            raise RuntimeError("DataLoader produced no batches")


# Keep ONE batch stream per loader alive for the whole run. Calling
# `next(iter(loader))` on every step builds a brand-new iterator (and a new
# sampler) each time and only ever consumes its first batch, which is wasteful
# and never makes an orderly pass over the data.
batches_0 = _endless_batches(data_loader_0)
batches_1 = _endless_batches(data_loader_1)

for step_i in range(int(400e3)):
    # Labels are sized from the batch itself, so a shorter final batch at the
    # end of a pass still gets a matching (N, 1) label tensor.
    video_0, _, _ = next(batches_0)
    label_0 = torch.zeros((len(video_0), 1))
    video_1, _, _ = next(batches_1)
    label_1 = torch.ones((len(video_1), 1))
    model.train_step(video_0.to(device), label_0.to(device), video_1.to(device), label_1.to(device))

model.save()