# Library imports

In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import argparse
import os
import random
from torch.autograd import Variable
from torch.utils.data import DataLoader
import numpy as np
import time
from skimage.measure import compare_ssim
from tqdm import trange
import easydict

# Local Module
import pssim.pytorch_ssim as pytorch_ssim
import utils.data_utils as data_utils
import utils.utils as utils
import utils.utils_3d as utils_3d

# Option parameters

In [2]:
# All run options live on one attribute-access dict that later cells read.
opt = easydict.EasyDict({})

# Optimisation
opt.lr = 0.0005
opt.beta1 = 0.9
opt.batch_size = 7
opt.optimizer = optim.Adam
opt.niter = 60
opt.epoch_size = 5000
opt.beta = 0.0001

# Paths / run identity
opt.log_dir = 'logs'
opt.model_dir = ''              # non-empty => resume from '<model_dir>/model.pth'
opt.name = ''
opt.data_root = 'data'
opt.data_type = 'sequence'

# Data
opt.image_width = 64
opt.channels = 1
opt.dataset = 'smmnist'
opt.num_digits = 2
opt.data_threads = 0            # DataLoader num_workers; 0 loads in the main process

# Sequence lengths
opt.n_past = 8
opt.n_future = 10
opt.n_eval = opt.n_past + opt.n_future          # 18; derived so the three stay consistent
opt.max_step = opt.n_past + opt.n_future + 2    # NOTE(review): meaning of the +2 margin not visible here

# Model
opt.rnn_size = 32
opt.predictor_rnn_layers = 8
opt.model = 'crevnet'

# Random seed setting: seed every RNG source that is imported in this notebook
# so Restart & Run All reproduces the same run.
opt.seed = 1
random.seed(opt.seed)
np.random.seed(opt.seed)        # numpy was imported but previously left unseeded
torch.manual_seed(opt.seed)
torch.cuda.manual_seed_all(opt.seed)

# Legacy tensor type passed to data_utils.normalize_data for every batch.
dtype = torch.cuda.FloatTensor

# Folder naming

In [3]:
# Decide where logs, generated samples and plots for this run are written.
if opt.model_dir == '':
    # Fresh run: derive a descriptive sub-directory from key hyper-parameters.
    name = (f'model_mnist=layers_{opt.predictor_rnn_layers}'
            f'=seq_len_{opt.n_eval}'
            f'=batch_size_{opt.batch_size}')
    dataset_dir = (f'{opt.dataset}-{int(opt.num_digits)}'
                   if opt.dataset == 'smmnist' else f'{opt.dataset}')
    opt.log_dir = f'{opt.log_dir}/{dataset_dir}/{name}'
else:
    # Resuming: adopt the options stored in the checkpoint, but keep the
    # optimizer class and checkpoint directory chosen in this session.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    saved_model = torch.load(f'{opt.model_dir}/model.pth')
    optimizer, model_dir = opt.optimizer, opt.model_dir
    opt = saved_model['opt']
    opt.optimizer = optimizer
    opt.model_dir = model_dir
    opt.log_dir = f'{opt.log_dir}/continued'

for subdir in ('gen', 'plots'):
    os.makedirs(f'{opt.log_dir}/{subdir}/', exist_ok=True)

# Model setup

In [4]:
from AutoEncoder.AutoEncoder import AutoEncoder
from RPM.ReversiblePredictor import ReversiblePredictor

# NOTE(review): when resuming (opt.model_dir != ''), only saved_model['opt'] is
# used by the folder-naming cell; both networks below are still built from
# scratch. Restore their weights from `saved_model` if resuming should continue
# training rather than restart it.
frame_predictor = ReversiblePredictor(input_size=opt.rnn_size,
                                      hidden_size=opt.rnn_size,
                                      output_size=opt.rnn_size,
                                      n_layers=opt.predictor_rnn_layers,
                                      batch_size=opt.batch_size)

encoder = AutoEncoder(nBlocks=[4, 5, 3],
                      nStrides=[1, 2, 2],
                      nChannels=None,
                      init_ds=2,
                      dropout_rate=0.,
                      affineBN=True,
                      in_shape=[opt.channels, opt.image_width, opt.image_width],
                      mult=2)

# Move the networks to the GPU *before* building their optimizers, as the
# torch.optim documentation recommends, so each optimizer is constructed on the
# final device-resident parameters. Calling .cuda() again later is harmless.
frame_predictor.cuda()
encoder.cuda()

frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

# Step-decay both learning rates. The defaults reproduce the original schedule
# (multiply by 0.2 every 50 scheduler.step() calls); set opt.lr_step_size /
# opt.lr_gamma to override without editing this cell.
lr_step_size = getattr(opt, 'lr_step_size', 50)
lr_gamma = getattr(opt, 'lr_gamma', 0.2)
scheduler1 = torch.optim.lr_scheduler.StepLR(frame_predictor_optimizer, step_size=lr_step_size, gamma=lr_gamma)
scheduler2 = torch.optim.lr_scheduler.StepLR(encoder_optimizer, step_size=lr_step_size, gamma=lr_gamma)

mse_criterion = nn.MSELoss()

# Transfer to GPU

In [5]:
# Place both networks and the loss module on the default CUDA device; the last
# expression's value (the loss module) is what the cell displays.
for network in (frame_predictor, encoder):
    network.cuda()
mse_criterion.cuda()

MSELoss()

# Dataset loading

In [6]:
train_data, test_data = data_utils.load_dataset(opt)

# Train and test loaders share identical settings. drop_last=True keeps every
# batch at exactly opt.batch_size. NOTE(review): presumably required because the
# predictor is constructed with a fixed batch_size; confirm.
loader_kwargs = dict(num_workers=opt.data_threads,
                     batch_size=opt.batch_size,
                     shuffle=True,
                     drop_last=True,
                     pin_memory=False)
train_loader = DataLoader(train_data, **loader_kwargs)
test_loader = DataLoader(test_data, **loader_kwargs)

0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data\MNIST\raw\train-images-idx3-ubyte.gz


9920512it [00:02, 3386718.75it/s]                                                                                      


Extracting data\MNIST\raw\train-images-idx3-ubyte.gz


0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data\MNIST\raw\train-labels-idx1-ubyte.gz


32768it [00:00, 62667.29it/s]                                                                                          
0it [00:00, ?it/s]

Extracting data\MNIST\raw\train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data\MNIST\raw\t10k-images-idx3-ubyte.gz


1654784it [00:01, 1180975.45it/s]                                                                                      
0it [00:00, ?it/s]

Extracting data\MNIST\raw\t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data\MNIST\raw\t10k-labels-idx1-ubyte.gz


8192it [00:00, 18761.29it/s]                                                                                           

Extracting data\MNIST\raw\t10k-labels-idx1-ubyte.gz
Processing...
Done!





# Batch generation

In [7]:
def _endless_normalized_batches(get_loader, split_name):
    """Cycle over a data loader forever, yielding each batch after normalization.

    Args:
        get_loader: zero-argument callable returning the loader. It is called
            once per pass, so a loader rebound in the notebook is picked up on
            the next pass (matching a direct global lookup).
        split_name: label used in the error message.

    Yields:
        data_utils.normalize_data(opt, dtype, sequence) for every batch.

    Raises:
        RuntimeError: if a complete pass yields no batches. Without this check
            the generator would spin forever inside ``while True``. With
            drop_last=True this happens when the split holds fewer than
            opt.batch_size sequences.
    """
    while True:
        n_batches = 0
        for sequence in get_loader():
            n_batches += 1
            yield data_utils.normalize_data(opt, dtype, sequence)
        if n_batches == 0:
            raise RuntimeError(
                f'{split_name} loader yielded no batches; with drop_last=True the '
                f'split must contain at least batch_size={opt.batch_size} sequences'
            )


def get_training_batch():
    """Endless generator of normalized training batches (reshuffled each pass)."""
    yield from _endless_normalized_batches(lambda: train_loader, 'train')


def get_testing_batch():
    """Endless generator of normalized test batches (reshuffled each pass)."""
    yield from _endless_normalized_batches(lambda: test_loader, 'test')


training_batch_generator = get_training_batch()
testing_batch_generator = get_testing_batch()