In [1]:
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torch.nn as nn
from piqa import SSIM
from tqdm import tqdm
import wandb
import matplotlib.pyplot as plt
import os
import csv

from datasets.image_vae_dataset import ImageVAEDataset, ImageVAEDatasetFull
from models.visual_autoencoder import VisualAutoencoder
import config as CFG


  from .autonotebook import tqdm as notebook_tqdm


## VisualAutoencoder

Dataloader

In [2]:
# Caption-annotated datasets: each split reads its annotation file from
# <datapath>/lang_annotations/auto_lang_ann.npy.
# num_workers / prefetch_factor are not passed, so DataLoader defaults apply.
caption_path_training = f'{CFG.datapath_training}/lang_annotations/auto_lang_ann.npy'
train_dataset = ImageVAEDataset(
    dataset_path=CFG.datapath_training,
    caption_path=caption_path_training,
)
train_loader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True)

caption_path_val = f'{CFG.datapath_val}/lang_annotations/auto_lang_ann.npy'
val_dataset = ImageVAEDataset(
    dataset_path=CFG.datapath_val,
    caption_path=caption_path_val,
)
val_loader = DataLoader(val_dataset, batch_size=CFG.batch_size, shuffle=True)

# "Full" datasets: built from the dataset path only (no caption file argument).
train_dataset_full = ImageVAEDatasetFull(dataset_path=CFG.datapath_training)
train_loader_full = DataLoader(train_dataset_full, batch_size=CFG.batch_size, shuffle=True)

val_dataset_full = ImageVAEDatasetFull(dataset_path=CFG.datapath_val)
val_loader_full = DataLoader(val_dataset_full, batch_size=CFG.batch_size, shuffle=True)

Test the dataloaders

In [None]:
# Smoke test: pull one batch from each training loader and print the shapes
# of the two image entries.
# Batches are indexed by key here, matching how train() and reconstruct()
# read them (data['img_static']); attribute access fails on a plain dict batch.
for loader in (train_loader, train_loader_full):
    for batch in loader:
        print("img_static: ", batch['img_static'].shape)
        print("img_gripper: ", batch['img_gripper'].shape)
        break  # one batch per loader is enough


Train VAE model

In [None]:
# Hyper-parameters for the training cell.
encoding_size = 512    # passed to VisualAutoencoder(...)
learning_rate = 1e-4   # Adam learning rate

class SSIMLoss(SSIM):
    """SSIM turned into a loss: ``forward`` returns ``1 - SSIM(x, y)``.

    A higher similarity score from the parent class therefore gives a
    lower loss value.
    """

    def forward(self, x, y):
        """Return one minus the parent's SSIM score for ``x`` and ``y``."""
        similarity = super().forward(x, y)
        return 1. - similarity

def visualize_reconstruction(original, reconstructed):
    """Build a two-panel figure: the original image next to its reconstruction.

    Args:
        original: 3-D image tensor; permuted with ``(1, 2, 0)`` before display
            (presumably channels-first -> channels-last; confirm with the dataset).
        reconstructed: 3-D image tensor of the model output; detached and
            permuted the same way.

    Returns:
        The matplotlib figure. The current pyplot figure is closed before
        returning; the figure object is handed back for logging (e.g. to wandb).
    """
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        ('Original', original),
        ('Reconstructed', reconstructed.detach()),
    )
    for ax, (title, image) in zip(axs, panels):
        ax.imshow(image.permute(1, 2, 0))
        ax.set_title(title)
    plt.close()
    return fig


def train(log_every=2000,
          checkpoint_dir="checkpoints/image_vae/image_vae_static_full_long"):
    """Train the global ``model`` on ``train_loader_full`` with the reconstruction loss.

    Relies on the notebook-level globals ``model``, ``optimizer``,
    ``criterion``, ``train_loader_full`` and ``CFG`` (``epochs``, ``device``),
    and on an active wandb run.

    Args:
        log_every: Every ``log_every`` optimisation steps (including step 0)
            a reconstruction figure is logged to wandb and a checkpoint of
            ``model.state_dict()`` is written.
        checkpoint_dir: Directory that receives the ``.pt`` checkpoints; it is
            created if it does not exist.
    """
    # torch.save does not create missing parent directories.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Make sure the model is in training mode even if an earlier cell called
    # .eval() on it.
    model.train()

    step = 0             # global optimisation step, used as the wandb x-axis
    checkpoint_step = 0  # running index used in checkpoint file names

    for epoch in range(CFG.epochs):
        print(f"Epoch {epoch+1}/{CFG.epochs}")

        for data in tqdm(train_loader_full):
            # Only the 'img_static' entry of the batch is used.
            inputs = data['img_static'].to(CFG.device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # The total loss is currently just the reconstruction loss
            # (the SSIM term is disabled).
            reconstruction_loss = criterion(outputs, inputs)
            loss = reconstruction_loss

            loss.backward()
            optimizer.step()

            # Both metrics belong to the same step, so log them together.
            wandb.log(
                {
                    "Total Loss": loss.item(),
                    "Reconstruction Loss": reconstruction_loss.item(),
                },
                step=step,
            )

            if step % log_every == 0:
                figure = visualize_reconstruction(
                    inputs[0].detach().cpu(),
                    outputs[0].detach().cpu(),
                )
                wandb.log({"Training Images": wandb.Image(figure)}, step=step)
                torch.save(
                    model.state_dict(),
                    f"{checkpoint_dir}/image_vae_static_full_high_res-{checkpoint_step:03d}.pt",
                )
                checkpoint_step += 1

            step += 1

# Start the wandb run, build the model, optimizer and losses, then train.
wandb.init(project="ImageVAE")
model = VisualAutoencoder(encoding_size)
model.to(CFG.device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()
# Place the SSIM loss on CFG.device like the model instead of a hard-coded
# .cuda(), which raises on machines without CUDA. The SSIM term is currently
# not used inside train().
SSIM_criterion = SSIMLoss().to(CFG.device)

train()

Validate image VAE

In [None]:
# NOTE: this redefines (and shadows) the helper of the same name from the
# training cell; the only difference is the plt.show() call below.
def visualize_reconstruction(original, reconstructed):
    """Display the original image next to its reconstruction and return the figure.

    Args:
        original: 3-D image tensor; permuted with ``(1, 2, 0)`` before display
            (presumably channels-first -> channels-last; confirm with the dataset).
        reconstructed: 3-D image tensor of the model output; detached and
            permuted the same way.

    Returns:
        The matplotlib figure, after it has been shown and the current pyplot
        figure has been closed.
    """
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    left_ax, right_ax = axs[0], axs[1]

    left_ax.imshow(original.permute(1, 2, 0))
    left_ax.set_title('Original')

    detached = reconstructed.detach()
    right_ax.imshow(detached.permute(1, 2, 0))
    right_ax.set_title('Reconstructed')

    plt.show()
    plt.close()
    return fig

def reconstruct(model, num_images=15):
    """Plot reconstructions for samples of the first batch from ``val_loader``.

    Relies on the notebook-level globals ``val_loader``, ``CFG.device`` and
    ``visualize_reconstruction``.

    Args:
        model: Module called as ``model(inputs)`` on the 'img_static' batch.
        num_images: Maximum number of samples to plot. Capped at the batch
            size so small batches do not raise an IndexError.
    """
    for data in val_loader:
        inputs = data['img_static'].to(CFG.device)

        # Inference only: no autograd graph is needed for plotting.
        with torch.no_grad():
            outputs = model(inputs)

        shown = min(num_images, len(inputs))
        for sample_idx in range(shown):
            visualize_reconstruction(
                inputs[sample_idx].detach().cpu(),
                outputs[sample_idx].detach().cpu(),
            )
        break  # only the first batch is visualised
    
# Rebuild the autoencoder, restore saved weights and plot reconstructions.
# NOTE(review): depending on the torch version / weights_only setting,
# torch.load may unpickle arbitrary objects; only load trusted checkpoints.
# NOTE(review): this path is not the directory train() writes to; confirm it
# is the intended checkpoint.
encoding_size = 512
checkpoint_path = 'checkpoints/image_vae/image_vae_static/image_vae_static_high_res-039.pt'
model = VisualAutoencoder(encoding_size).to(CFG.device)
model.load_state_dict(torch.load(checkpoint_path, map_location=CFG.device))
model.eval()  # eval() returns the module itself, so no re-assignment is needed
reconstruct(model)
