In [None]:
import os
import glob
import sys
import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from torchvision import datasets
from torch.utils.data import Dataset, DataLoader
from torch import optim
from torch import nn
from torch.backends import cudnn
import skimage.io
from comet_ml import Experiment

In [None]:
# experiment = Experiment(api_key=os.environ["COMET_API_KEY"], project_name="obama_video")

In [None]:
# Let cuDNN benchmark candidate convolution algorithms and cache the fastest;
# this pays off when input shapes stay the same from batch to batch.
torch.backends.cudnn.benchmark = True

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
img_size = 256
channels = 3
seq_length = 20
latent_dim = 128

In [None]:
class VideoDataset(Dataset):
    def __init__(self, root_dir, seq_length, transform=None):
        self.root_dir = root_dir
        self.seq_length = seq_length
        self.transform = transform
        self.filenames = sorted(glob.glob(os.path.join(root_dir, "*.png")))
    def __len__(self):
        return len(self.filenames) - (seq_length - 1)
    def __getitem__(self, idx):
        images = [skimage.io.imread(self.filenames[idx+i]) for i in range(seq_length)]
        if self.transform:
            images = list(map(self.transform, images))
        else:
            images = list(map(transforms.ToTensor(), images))
        return torch.stack(images)

In [None]:
dataset = VideoDataset("/home/santiago/Downloads/obama/images/", 20, transform=transforms.Compose([
    transforms.ToPILImage(),
    transforms.CenterCrop(1080),
    transforms.Resize(img_size),
    transforms.ToTensor()
]))

In [None]:
batch_size = 4
workers = 4
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers)

In [None]:
class VideoAutoencoder(nn.Module):
    def __init__(self, img_size, latent_dim):
        self.img_size = img_size
        self.ds_size = self.img_size // 2**5
        self.latent_dim = latent_dim
        super(VideoAutoencoder, self).__init__()
        
        self.enc_conv = nn.Sequential(
            nn.Conv2d(3, 16, 3, 1, 1),
            nn.Conv2d(16, 16, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.1),
            nn.BatchNorm2d(16, 0.8),
            
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.Conv2d(32, 32, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.1),
            nn.BatchNorm2d(32, 0.8),
            
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.Conv2d(64, 64, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.1),
            nn.BatchNorm2d(64, 0.8),
            
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.Conv2d(128, 128, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.1),
            nn.BatchNorm2d(128, 0.8),
            
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.Conv2d(256, 256, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.1),
            nn.BatchNorm2d(256, 0.8)
        )
        self.enc_proj = nn.Sequential(
            nn.Linear(256*self.ds_size**2, latent_dim),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.1),
#             nn.BatchNorm2d(latent_dim, 0.8)
        )
        self.enc_lstm = nn.LSTM(latent_dim, latent_dim, 1, dropout=0.1)
        
        self.dec_lstm = nn.LSTMCell(latent_dim, latent_dim)
        self.dec_proj = nn.Sequential(
#             nn.BatchNorm2d(128, 0.8),
            nn.Linear(latent_dim, 256*self.ds_size**2),
            nn.ReLU(inplace=True)
        )
        self.dec_conv = nn.Sequential(
            nn.BatchNorm2d(256, 0.8),
            nn.ConvTranspose2d(256, 256, 3, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(256, 128, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            
            nn.BatchNorm2d(128, 0.8),
            nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(128, 64, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            
            nn.BatchNorm2d(64, 0.8),
            nn.ConvTranspose2d(64, 64, 3, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(64, 32, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            
            nn.BatchNorm2d(32, 0.8),
            nn.ConvTranspose2d(32, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(32, 16, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            
            nn.BatchNorm2d(16, 0.8),
            nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(16, 3, 3, stride=1, padding=1),
            nn.Sigmoid()
        )
    
    def encode(self, x):
        x = x.permute(1, 0, 2, 3, 4)
        enc_reps = []
        for step in x.split(1):
            step = step[0]
            conv = self.enc_conv(step)
            conv = conv.view(conv.shape[0], 256*self.ds_size**2)
            proj = self.enc_proj(conv)
            enc_reps.append(proj)
        out, states = self.enc_lstm(torch.stack(enc_reps))
        return states[0][0]  # hidden state
    
    def decode(self, h, steps):
        init = torch.zeros_like(h)
        c = torch.zeros_like(h)
        decoded = []
        for i in range(steps):
            h, c = self.dec_lstm(init, (h, c))
            proj = self.dec_proj(step)
            proj = proj.view(proj.shape[0], 256, self.ds_size, self.ds_size)
            dec = self.dec_conv(proj)
            decoded.append(dec)
        return torch.stack(reversed(decoded)).permute(1, 0, 2, 3, 4)
    
    def forward(self, x):
        h = self.encode(x)
        y = self.decode(h, x.shape[1])
        loss = F.l1_loss(x, y) ** 1.1
        return loss

In [None]:
model = VideoAutoencoder(img_size, latent_dim).cuda()

In [None]:
print(model)

In [None]:
optimizer = optim.Adam(model.parameters(), lr=0.0001)

In [None]:
epochs = 1

In [None]:
for epoch in range(epochs):
    for i, data in enumerate(dataloader):
        inputs = Variable(data.cuda())
        optimizer.zero_grad()
        loss = model.forward(inputs)
        print(loss)