In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam

import os
import time
import math
import numpy as np
import skimage.io as io

from di_dataset2 import DepthImageDataset, collate_batch

In [None]:
# Device configuration: prefer the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# Hyperparameters

In [None]:
latent_dim = 200
num_epochs = 100
batch_size = 32
learning_rate = 1e-3

save_model = True
load_model = False

save_model_file = "../../../vae_models/vae_resnet_full.pth"
load_model_file = "../../../vae_models/vae_resnet.pth"

# Load Dataset

In [None]:
# TFRecord shard locations, relative to the working directory.
# NOTE(review): the test folder is nested inside the training folder — confirm
# DepthImageDataset does not recurse into sub-folders, or test shards would
# leak into training.
tfrecord_folder = '../../../rl_data/tfrecord'
tfrecord_test_folder = tfrecord_folder + '/test'

In [None]:
# One dataset per split (176 training tfrecords, 20 test tfrecords).
# `batch_size` is handed to DepthImageDataset itself, which presumably emits
# ready-made batches (the DataLoaders below use batch_size=1) — confirm.
train_dataset, test_dataset = (
    DepthImageDataset(tfrecord_folder=folder, batch_size=batch_size)
    for folder in (tfrecord_folder, tfrecord_test_folder)
)

In [None]:
# Number of batches per split, used below only for progress / ETA reporting.
# Counting exactly needs a full pass over the TFRecords, e.g.
#     sum(1 for _ in train_dataset)  -> 11223 batches (of 32) for all train shards
#     sum(1 for _ in test_dataset)   -> 1305 batches (of 32) for all test shards
# so the values are hard-coded. NOTE(review): the current (69, 66) presumably
# reflect a reduced shard set — update them whenever the shard set changes.
len_train_dataset = 69
len_test_dataset = 66

In [None]:
# The datasets already yield whole batches, so each DataLoader step takes a
# single dataset item and hands it to `collate_batch` (presumably to unpack it).
train_loader, test_loader = (
    DataLoader(dataset=split, batch_size=1, collate_fn=collate_batch)
    for split in (train_dataset, test_dataset)
)

In [None]:
# Pull one batch to sanity-check shapes. The builtin next() is required:
# DataLoader iterators no longer provide a `.next()` method.
dataiter = iter(train_loader)
image_batch, *_ = next(dataiter)  # image, height, width, depth
image_batch.shape

In [None]:
def imshow(image):
    """Display a single image tensor with scikit-image.

    Singleton dimensions are squeezed away before display (e.g. a (1, H, W)
    single-channel image becomes (H, W)). The tensor is detached and copied
    to host memory first, so model outputs that still require grad or live on
    a GPU can be passed directly.

    NOTE(review): newer scikit-image releases deprecate io.imshow / io.show;
    consider switching to matplotlib.
    """
    io.imshow(image.detach().cpu().squeeze().numpy())
    io.show()

# Preview the first four images of the sample batch; `image` stays bound to
# the last one shown so its shape is displayed below.
for image in image_batch[:4]:
    imshow(image)

image.shape

In [None]:
# Sanity-check the pixel statistics (mean, variance) of the sample batch.
(image_batch.squeeze().mean(),
 image_batch.squeeze().var())

# Define Variational Autoencoder

Model architecture adapted from the CM-VAE in Microsoft's AirSim Drone Racing VAE Imitation repository: https://github.com/microsoft/AirSim-Drone-Racing-VAE-Imitation/blob/master/racing_models/cmvae.py

### Define Model

In [None]:
# NOTE(review): `VAE` is not defined or imported anywhere in this notebook; its
# class definition must be executed before this cell, otherwise a fresh kernel
# raises NameError here.
# NOTE(review): `latent_dim` is not passed to VAE() — confirm the model's own
# latent size matches the hyperparameter cell.
vae_model = VAE()

if load_model:
    # map_location lets a checkpoint saved from a GPU be deserialised on a
    # CPU-only machine; load_state_dict then copies the weights into the
    # model's own parameters.
    vae_model.load_state_dict(torch.load(load_model_file, map_location=device))

# Stay in training mode even after loading a checkpoint, so resumed training
# runs any dropout / batch-norm layers in training behaviour. The evaluation
# cells switch to eval() themselves.
vae_model.train()

optimiser = Adam(vae_model.parameters(), lr=learning_rate)

In [None]:
from torchsummary import summary

# torchsummary builds its dummy input on "cuda" by default whenever CUDA is
# available, which crashes if the model's weights are on the CPU. Use the
# device the model actually lives on instead.
summary_device = 'cuda' if next(vae_model.parameters()).is_cuda else 'cpu'
# Input size is (channels, height, width) — presumably one depth channel at 270x480.
summary(vae_model, (1, 270, 480), device=summary_device)

# Training

In [None]:
def loss_function(x, x_hat, mean, log_var):
    """Negative ELBO of a VAE with a Bernoulli decoder, summed over the batch.

    Args:
        x: target images; nn.BCELoss requires values in [0, 1].
        x_hat: reconstructions (probabilities in [0, 1]) with the same shape as x.
        mean: latent means produced by the encoder.
        log_var: latent log-variances, same shape as ``mean``.

    Returns:
        Scalar tensor: summed binary cross-entropy plus the KL divergence
        between N(mean, exp(log_var)) and the standard-normal prior.
    """
    recon_term = nn.BCELoss(reduction='sum')(x_hat, x)
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over every latent dim.
    kl_term = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return recon_term + kl_term

In [None]:
# Size of one training epoch, for reference. `len_train_dataset` counts
# batches, so samples = batches * batch_size (approximate if a final batch is
# partial). Uses `batch_size` rather than a literal 32 so the figure stays
# correct when the hyperparameter changes.
n_training_samples = len_train_dataset * batch_size
n_iterations = len_train_dataset  # optimisation steps per epoch
n_training_samples, n_iterations

In [None]:
# Train the VAE for `num_epochs`, logging the running average loss and an ETA.
time_iteration = []      # wall-clock seconds of every optimisation step, all epochs
total_iter_time = 0.0    # running sum of time_iteration (avoids re-averaging the list)
log_every = 1            # print progress every `log_every` steps
show_every_epochs = 10   # display an input/reconstruction pair every N epochs
# Feed batches to wherever the model's parameters live (CPU unless moved).
model_device = next(vae_model.parameters()).device

for epoch in range(num_epochs):
    # The model may have been left in eval mode (e.g. by an evaluation cell).
    vae_model.train()
    overall_loss = 0.0
    since = time.time()

    # Rebuilt every epoch — presumably because DepthImageDataset can only be
    # iterated once (TODO confirm).
    train_dataset = DepthImageDataset(tfrecord_folder=tfrecord_folder, batch_size=batch_size)
    train_loader = DataLoader(dataset=train_dataset, batch_size=1, collate_fn=collate_batch)

    last_batch = None  # (input, reconstruction) of the most recent step
    for i, (image, *_) in enumerate(train_loader):
        since_iter = time.time()
        image = image.to(model_device)

        optimiser.zero_grad()

        # VAE forward pass
        mu, logvar, x_hat = vae_model(image)

        # Loss (summed over the batch)
        loss = loss_function(image, x_hat, mu, logvar)

        # Update weights
        loss.backward()
        optimiser.step()

        # Accumulate a Python float: adding the tensor itself would keep every
        # step's autograd graph alive for the whole epoch.
        overall_loss += loss.item()
        last_batch = (image, x_hat)

        iter_time = time.time() - since_iter
        time_iteration.append(iter_time)
        total_iter_time += iter_time
        iter_time_mean = total_iter_time / len(time_iteration)

        if (i + 1) % log_every == 0:
            time_elapsed = time.time() - since
            time_left = max(0.0, iter_time_mean * (n_iterations - (i + 1)))
            print(f"Epoch: {epoch+1}/{num_epochs}, Step: {i+1}/{n_iterations}, Avg loss: {overall_loss/((i+1)*batch_size):.3f}, time: {time_elapsed:.2f}, Avg. per iter {iter_time_mean:.2f}, Est. time left {time_left:.2f}")

    # Periodically compare an input with its reconstruction (skipped if the
    # loader produced no batches this epoch).
    if last_batch is not None and (epoch + 1) % show_every_epochs == 0:
        last_image, last_recon = last_batch
        imshow(last_image[0].cpu())
        imshow(last_recon[0].detach().cpu())
            

In [None]:
if save_model:
    # torch.save does not create missing parent directories, so make sure the
    # checkpoint folder exists (a bare filename has no directory to create).
    save_dir = os.path.dirname(save_model_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(vae_model.state_dict(), save_model_file)

# Evaluation

In [None]:
# Evaluation mode for inference (affects any dropout / batch-norm layers).
vae_model.eval()

# Reconstruct one held-out batch and compare it with the inputs.
with torch.no_grad():
    images, *_ = next(iter(test_loader))
    _, _, x_hat = vae_model(images)

    # Mean and variance of inputs vs. reconstructions, then the batch shape.
    for stat in (torch.mean, torch.var):
        print(stat(images), stat(x_hat))
    print(images.shape)

    # Show every input image followed by its reconstruction.
    for idx, original in enumerate(images):
        imshow(original)
        imshow(x_hat[idx])

In [None]:
# Same comparison on a training batch, to contrast with the held-out results.
vae_model.eval()

# NOTE(review): `train_loader` here is the instance rebuilt during the last
# training epoch; if DepthImageDataset can only be iterated once, this next()
# may raise StopIteration — confirm.
with torch.no_grad():
    images, *_ = next(iter(train_loader))
    # VAE forward pass
    _, _, x_hat = vae_model(images)

    # Mean and variance of inputs vs. reconstructions, then the batch shape.
    for stat in (torch.mean, torch.var):
        print(stat(images), stat(x_hat))
    print(images.shape)

    # Show every input image followed by its reconstruction.
    for idx, original in enumerate(images):
        imshow(original)
        imshow(x_hat[idx])
