### Importing required packages and functions

In [None]:
import sys
sys.path.insert(0, '..')

import os

from functools import reduce
from config.paths import Path
from config.constants import Constants
from dataset.frame_dataset import FrameDataset
from logger.train_logger import TrainLogger

In [None]:
def get_factors(n):
    """Return the set of all positive divisors of ``n``.

    Divisors are collected in pairs ``(i, n // i)`` by trial division, so
    only candidates up to the square root of ``n`` are tested.

    Args:
        n: A positive integer.

    Returns:
        A ``set`` containing every positive integer that divides ``n``.

    Raises:
        ValueError: If ``n`` is smaller than 1 (zero and negative numbers
            have no finite set of positive divisors to return).
    """
    if n < 1:
        raise ValueError(
            'get_factors() requires a positive integer, got {!r}'.format(n))

    factors = set()
    candidate = 1
    # Compare candidate**2 against n instead of taking a float square root,
    # so the bound stays exact even for integers too large for a double.
    while candidate * candidate <= n:
        if n % candidate == 0:
            factors.update((candidate, n // candidate))
        candidate += 1
    return factors

### Configurable Parameters for Experiments

In [None]:
# Clip length used for this experiment; can be 1, 3, 5, 10.
SECONDS_PER_CLIP = Constants.SecondsPerClip.THREE_SEC
# Preferred batch size; the batch size actually used is the divisor of the
# dataset length that lies closest to this value.
CLOSEST_BATCH_SIZE = 15
HOME_PATH = Path.DATA_HOME
MODEL_PATH = Path.AUTOENCODER_MODEL_PATH
# One checkpoint path per autoencoder half, both derived from the same
# template by filling in the clip length and the module name.
WRITE_ENCODER_PATH, WRITE_DECODER_PATH = (
    MODEL_PATH.format(sec=SECONDS_PER_CLIP, module=half)
    for half in ('encoder', 'decoder')
)

In [None]:
# Dataset fed to the training DataLoader, built for the configured clip length.
frame_dataset = FrameDataset(SECONDS_PER_CLIP)

In [None]:
# Only divisors of the dataset length are considered, so the dataset splits
# into equally sized batches; among them pick the one nearest to the
# preferred batch size.
dataset_size = len(frame_dataset)
factors = get_factors(dataset_size)
BATCH_SIZE = min(factors, key=lambda size: abs(size - CLOSEST_BATCH_SIZE))
print('{} \n {}'.format(factors, BATCH_SIZE))

### Define Autoencoder model

In [None]:
import torch
import torch.nn as nn

from torch.optim import Adagrad
from torch.autograd import Variable
import torch.utils.data as data

# True when a CUDA device is available; gates the .cuda() transfers below.
use_cuda = torch.cuda.is_available()

In [None]:
class VideoEncoder(nn.Module):
    """Encoder half of the autoencoder.

    A 2-layer LSTM that takes 512-dimensional inputs and produces
    2048-dimensional hidden states.
    """

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=512, hidden_size=2048, num_layers=2)

    def forward(self, x):
        """Encode the sequence ``x``.

        Args:
            x: Input tensor accepted by the LSTM; its last dimension must
                be 512.

        Returns:
            The ``(output, hidden)`` pair exactly as produced by the LSTM,
            where ``hidden`` is the LSTM's final ``(h_n, c_n)`` state.
        """
        return self.lstm(x)

In [None]:
class VideoDecoder(nn.Module):
    """Decoder half of the autoencoder.

    A 2-layer LSTM over 2048-dimensional inputs followed by a linear layer
    that projects each 2048-dimensional LSTM output down to 512 dimensions.
    """

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=2048, hidden_size=2048, num_layers=2)
        self.linear = nn.Linear(in_features=2048, out_features=512)

    def forward(self, x, hidden):
        """Decode ``x`` starting from the LSTM state ``hidden``.

        Args:
            x: Input tensor for the LSTM; its last dimension must be 2048.
            hidden: Initial ``(h_0, c_0)`` state handed to the LSTM.

        Returns:
            The LSTM outputs after the linear projection (last dimension
            512). The LSTM's final state is discarded.
        """
        lstm_output, _ = self.lstm(x, hidden)
        return self.linear(lstm_output)

In [None]:
def save_model(model_state_dict, optimizer_state_dict, epoch, path):
    """Write a training checkpoint to ``path`` using ``torch.save``.

    The checkpoint is a dict with three entries: ``'epoch'``,
    ``'state_dict'`` (the model state) and ``'optimizer'`` (the optimizer
    state).

    Args:
        model_state_dict: Model state to store under ``'state_dict'``.
        optimizer_state_dict: Optimizer state to store under ``'optimizer'``.
        epoch: Epoch number to store under ``'epoch'``.
        path: Destination passed to ``torch.save``.
    """
    torch.save(
        {
            'epoch': epoch,
            'state_dict': model_state_dict,
            'optimizer': optimizer_state_dict,
        },
        path,
    )

In [None]:
def load_model(path):
    """Read a checkpoint file and return its model and optimizer state.

    The file is expected to hold a dict with ``'state_dict'`` and
    ``'optimizer'`` entries. When no CUDA device is available, all stored
    tensors are mapped to the CPU so that a checkpoint written on a GPU
    machine can still be loaded.

    Args:
        path: Location of the checkpoint, passed to ``torch.load``.

    Returns:
        A ``(model_state_dict, optimizer_state_dict)`` tuple.

    Raises:
        KeyError: If the checkpoint lacks ``'state_dict'`` or ``'optimizer'``.
    """
    # Without a remap, tensors saved from a GPU cannot be deserialized on a
    # CPU-only machine.
    map_location = None if torch.cuda.is_available() else 'cpu'
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(path, map_location=map_location)
    return checkpoint['state_dict'], checkpoint['optimizer']

In [None]:
# Build both halves of the autoencoder and the reconstruction loss.
encoder = VideoEncoder()
decoder = VideoDecoder()

loss_function = nn.MSELoss()

# Move the modules to the GPU before the optimizers are constructed, so the
# optimizers are created over the parameters that are actually trained.
if use_cuda:
    encoder = encoder.cuda()
    decoder = decoder.cuda()
    loss_function = loss_function.cuda()

# Each optimizer must be built over its own module's parameters; otherwise
# that module is never updated by its optimizer's step().
encoder_optimizer = Adagrad(encoder.parameters())
decoder_optimizer = Adagrad(decoder.parameters())

encoder_checkpoint_path = HOME_PATH + WRITE_ENCODER_PATH
decoder_checkpoint_path = HOME_PATH + WRITE_DECODER_PATH

# Make sure the directories that will receive the checkpoints exist.
for checkpoint_path in (encoder_checkpoint_path, decoder_checkpoint_path):
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# Resume from earlier checkpoints when they are present.
if os.path.isfile(encoder_checkpoint_path):
    encoder_state, encoder_optimizer_state = load_model(encoder_checkpoint_path)
    encoder.load_state_dict(encoder_state)
    encoder_optimizer.load_state_dict(encoder_optimizer_state)
    print('Encoder model found, loading saved state...')

if os.path.isfile(decoder_checkpoint_path):
    decoder_state, decoder_optimizer_state = load_model(decoder_checkpoint_path)
    decoder.load_state_dict(decoder_state)
    try:
        decoder_optimizer.load_state_dict(decoder_optimizer_state)
    except ValueError:
        # Optimizer state saved for a different set of parameters (e.g. by a
        # run whose decoder optimizer was built over the encoder) cannot be
        # restored; keep the freshly initialised optimizer state instead.
        print('Saved decoder optimizer state does not match the decoder '
              'parameters, starting from a fresh optimizer state...')
    print('Decoder model found, loading saved state...')

In [None]:
# Training driver: run up to `epochs` passes over the dataset, checkpoint
# both modules after every epoch that improves the loss, and stop at the
# first epoch that does not.
epochs = 50
print_every = 10

prev_epoch_loss = float('inf')
frame_dataloader = data.DataLoader(frame_dataset, batch_size=BATCH_SIZE, shuffle=True,
                                   num_workers=1)
train_logger = TrainLogger(BATCH_SIZE, print_every, len(frame_dataloader))
# Used only to scale the losses in the progress print below.
batches_per_epoch = len(frame_dataset) / BATCH_SIZE

for e in range(epochs):
    for i, frame in enumerate(frame_dataloader):
        encoder.zero_grad()
        decoder.zero_grad()

        # The autoencoder reconstructs its own input, so the same tensor
        # serves as both the encoder input and the loss target.
        clip = Variable(frame)
        if use_cuda:
            clip = clip.cuda()
        # The DataLoader stacks samples along dim 0, while nn.LSTM (without
        # batch_first) expects the sequence axis first. The two axes must be
        # swapped with transpose(); view() would only reinterpret the memory
        # and interleave frames of different clips.
        # Assumes each dataset item is (seq_len, 512) — TODO confirm against
        # FrameDataset.
        clip = clip.transpose(0, 1).contiguous()

        encoder_output, encoder_hidden = encoder(clip)
        decoder_output = decoder(encoder_output, encoder_hidden)
        loss = loss_function(decoder_output, clip)
        # NOTE(review): update() presumably returns the loss accumulated over
        # the current epoch — confirm against TrainLogger.
        epoch_loss = train_logger.update(e, i, decoder_output, clip, loss)
        loss.backward()
        encoder_optimizer.step()
        decoder_optimizer.step()

    # Stop as soon as an epoch fails to improve on the previous one.
    if epoch_loss >= prev_epoch_loss:
        break

    save_model(encoder.state_dict(), encoder_optimizer.state_dict(), e,
               HOME_PATH+WRITE_ENCODER_PATH)
    save_model(decoder.state_dict(), decoder_optimizer.state_dict(), e,
               HOME_PATH+WRITE_DECODER_PATH)
    print('\n', (prev_epoch_loss/batches_per_epoch),
          (epoch_loss/batches_per_epoch))
    prev_epoch_loss = epoch_loss
    train_logger.flush()