# In [4]:  (Jupyter cell marker from the notebook export)
import os
import os.path
import argparse
import math
import torch
import torchvision.utils as vutils
from datetime import datetime
from einops import rearrange
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from balls import Balls
from dalle_video import DALLE
from utils import *

# Command-line configuration for evaluating / visualizing a trained model.
# NOTE(review): the default paths below are machine-specific and will usually
# need overriding.
parser = argparse.ArgumentParser()

# --- runtime / data loading ---
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--epochs', type=int, default=500)
parser.add_argument('--clip', type=float, default=1.0)
parser.add_argument('--image_size', type=int, default=64)

# --- paths ---
parser.add_argument('--checkpoint_path', default='')
# parser.add_argument('--pretrained_dvae_path', default='/home/t-pbansal/video-dalle/checkpoints/dvae_model.pt')
parser.add_argument('--pretrained_dvae_path', default='/ecstorage/video-dalle/checkpoints/best_model.pt')
parser.add_argument('--data_path', default='/home/t-pbansal/video-dalle/data/bouncing_3balls')

# --- transformer decoder ---
parser.add_argument('--num_dec_blocks', type=int, default=4)
parser.add_argument('--vocab_size', type=int, default=256)
parser.add_argument('--d_model', type=int, default=192)
parser.add_argument('--num_heads', type=int, default=4)
parser.add_argument('--dropout', type=float, default=0.1)

# --- slot / encoder settings ---
# parser.add_argument('--num_iterations', type=int, default=7)
parser.add_argument('--num_iterations', type=int, default=5)
parser.add_argument('--num_slots', type=int, default=4)
# parser.add_argument('--num_slots', type=int, default=8)
parser.add_argument('--cnn_hidden_size', type=int, default=32)
parser.add_argument('--slot_size', type=int, default=64)
# parser.add_argument('--slot_size', type=int, default=32)
parser.add_argument('--mlp_hidden_size', type=int, default=128)
parser.add_argument('--img_channels', type=int, default=3)
parser.add_argument('--pos_channels', type=int, default=4)
parser.add_argument('--sample_length', type=int, default=20)
parser.add_argument('--num_gt_steps', type=int, default=10)
parser.add_argument('--num_generation_steps', type=int, default=10)
parser.add_argument('--num_vis', type=int, default=8)

# --- Gumbel-softmax temperature schedule ---
parser.add_argument('--tau_start', type=float, default=1.0)
parser.add_argument('--tau_final', type=float, default=0.1)
parser.add_argument('--tau_epochs', type=int, default=3)

# --- flags ---
parser.add_argument('--hard', action='store_true')
parser.add_argument('--sample_batch', action='store_true')

parser.add_argument('--memory_type', type=str, choices=['slots', 'conv', 'vector'], default='slots')

parser.add_argument('--fast_run', action='store_true', default=False)

# parse_known_args() instead of parse_args(): a Jupyter kernel is launched with
# extra argv entries (e.g. `-f <connection-file>`), which would make
# parse_args() abort with "unrecognized arguments". Unknown entries are
# reported rather than silently dropped so CLI typos remain visible.
args, unknown_args = parser.parse_known_args()
if unknown_args:
    print(f'Ignoring unrecognized arguments: {unknown_args}')

# Seed torch first so dataset and model construction below are reproducible.
torch.manual_seed(args.seed)

# Validation split, read in a fixed order (shuffle=False); incomplete final
# batches are dropped (drop_last=True).
val_dataset = Balls(root=args.data_path, mode='val')
loader_kwargs = dict(
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    pin_memory=True,
    drop_last=True,
)
val_loader = DataLoader(val_dataset, **loader_kwargs)

# Build the model, restore weights and bookkeeping fields from the checkpoint
# (loaded onto CPU first), then move the model to the GPU.
model = DALLE(args)
print(f'Loading from checkpoint {args.checkpoint_path}')
checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
start_epoch = checkpoint['epoch']
best_val_loss = checkpoint['best_val_loss']
best_epoch = checkpoint['best_epoch']
model.load_state_dict(checkpoint['model'])
model = model.cuda()

def visualize(video, recon, recon_transformer=None, N=8):
    """Flatten ground truth, predictions and residuals into one image batch.

    For each of the first ``N`` clips the output holds, in order: the
    ground-truth frames, the ``recon`` frames, the residual ``video - recon``
    and, when ``recon_transformer`` is given, its frames followed by the
    residual ``video - recon_transformer``.

    Args:
        video: 5-D tensor unpacked as (B, T, C, H, W).
        recon: prediction with the same shape as ``video``.
        recon_transformer: optional second prediction, same shape as ``video``.
        N: number of leading clips to include.

    Returns:
        Tensor of shape (-1, C, H, W); every consecutive run of T frames is
        one panel (ground truth, prediction, or residual).
    """
    _, _, C, H, W = video.shape
    truth = video[:N].unsqueeze(1)
    panels = [truth]
    predictions = [recon] if recon_transformer is None else [recon, recon_transformer]
    for prediction in predictions:
        prediction = prediction[:N].unsqueeze(1)
        panels.append(prediction)
        panels.append(truth - prediction)
    return torch.cat(panels, dim=1).view(-1, C, H, W)

def visualize_attn(video, attn, N=2):
    """Lay out each frame next to its per-slot attention maps.

    For every (clip, time step) among the first ``N`` clips the output holds
    the frame, the same frame again, then the ``attn`` tiles for that step.

    NOTE(review): the layout originally documented here was
    [predicted frame, attended frame, slot1, slot2, ...], but the code
    repeats the predicted frame in the second column — confirm intent.

    Args:
        video: 5-D tensor unpacked as (B, T, C, H, W).
        attn: assumed (B, T, S, C, H, W) so it concatenates with the frames
            along dim -4 — TODO confirm against the caller.
        N: number of leading clips to include.

    Returns:
        Tensor of shape (-1, C, H, W).
    """
    _, _, C, H, W = video.shape
    frames = video[:N].unsqueeze(2)
    tiles = torch.cat([frames, frames, attn[:N]], dim=-4)
    return tiles.view(-1, C, H, W)

def visualize_gen(video, recon, N=8):
    """Interleave target images and reconstructions as 3-channel tiles.

    Fix: the original body referred to an undefined ``image`` (the parameter
    is ``video``), so every call raised UnboundLocalError.

    Args:
        video: 4-D tensor unpacked as (B, C, H, W); C must be 1 or 3 so it can
            be expanded to 3 channels.
        recon: reconstruction with the same shape as ``video``.
        N: number of leading samples to include.

    Returns:
        Tensor of shape (2 * min(N, B), 3, H, W): for each sample, the target
        tile followed by its reconstruction.
    """
    _, _, H, W = video.shape
    # expand() broadcasts single-channel inputs to RGB without copying.
    image = video[:N].expand(-1, 3, H, W).unsqueeze(dim=1)
    recon = recon[:N].expand(-1, 3, H, W).unsqueeze(dim=1)

    return torch.cat((image, recon), dim=1).view(-1, 3, H, W)


# Visualize the transformer's generations for the first validation batch.
# eval() disables dropout (args.dropout) and no_grad() avoids building an
# autograd graph during generation.
model.eval()
with torch.no_grad():
    for val_batch, full_video in enumerate(val_loader):
        full_video = full_video.cuda(non_blocking=True)
        video = full_video[:, :args.sample_length]
        recon_transformer = model.generate(video)
        # Compare against the last recon_transformer.shape[1] ground-truth frames.
        num_generated = recon_transformer.shape[1]
        gen_vis = visualize(video[:, -num_generated:], recon_transformer, N=args.num_vis)
        gen_vis_grid = vutils.make_grid(gen_vis, nrow=num_generated, pad_value=1)
        # make_grid returns a (C, H, W) tensor on the input's device; imshow
        # needs an (H, W, C) array on the CPU. Residual panels can be negative
        # and imshow clips RGB floats to [0, 1] anyway, so clamp explicitly.
        # NOTE(review): `plt` is not imported in this cell; presumably it comes
        # from `from utils import *` — confirm.
        plt.imshow(gen_vis_grid.permute(1, 2, 0).clamp(0, 1).cpu().numpy())
        break

# Cell output: (10000, 20, 64, 64, 3)