In [None]:
%load_ext autoreload
%autoreload 2
import sys, os

In [None]:
import glob, os
import mediapy as media
import torch
from torch.utils.data import DataLoader

from load_model_from_ckpt import load_model, get_readout_sampler, init_samples
from datasets import get_dataset, data_transform, inverse_data_transform
from runners.ncsn_runner import conditioning_fn

from os.path import expanduser
home = expanduser("~")  # user's home directory (not referenced by the visible cells)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load model checkpoint

In [None]:
model_path = '/ccn2/u/thekej/ucf10132_big192_288_4c4_unetm_spade/logs/'
# glob() returns matches in arbitrary filesystem order, so indexing [0] could pick
# any checkpoint when several exist. Use the most recently written one, and fail with
# a clear message (instead of a bare IndexError) when the directory has none.
ckpt_candidates = glob.glob(os.path.join(model_path, "checkpoint_*.pt"))
if not ckpt_candidates:
    raise FileNotFoundError(f"No checkpoint_*.pt files found in {model_path}")
ckpt_path = max(ckpt_candidates, key=os.path.getmtime)

# load model
scorenet, config = load_model(ckpt_path, device)
# get sampler
sampler = get_readout_sampler(config)

In [None]:
# Echo the run setup: device, the checkpoint chosen above, and the loaded config.
print(device, ckpt_path, config, sep="\n")

# Load data

In [None]:
from datasets.physion import PhysionDataset

In [None]:
def get_dataset(config,
                data_path='/ccn2/u/thekej/phys_readouts_mcvd_all/shard_0001.hdf5',
                frames_per_sample=48,
                simulation=False):
    """Build the Physion evaluation dataset used by this notebook.

    NOTE: this shadows the ``get_dataset`` imported from ``datasets`` at the top of
    the notebook; cells executed after this one call this version.

    Args:
        config: model config; only ``config.data.image_size`` is read here.
        data_path: HDF5 shard to load (defaults to the shard this notebook targets).
        frames_per_sample: frames per clip, forwarded to ``PhysionDataset``.
        simulation: forwarded to ``PhysionDataset``; was previously hard-coded to
            ``False`` with a "change this" note, so it is now a parameter.

    Returns:
        The ``PhysionDataset`` instance.
    """
    return PhysionDataset(data_path,
                          frames_per_sample=frames_per_sample,
                          image_size=config.data.image_size, train=False, random_time=True,
                          random_horizontal_flip=False,
                          complete=True,
                          simulation=simulation)

In [None]:
def inverse_transform(config, X):
    """Map a model-space tensor back to displayable values clamped to [0, 1].

    Steps, in order: move ``X`` to the CPU; add ``config.image_mean`` back when the
    config defines one; apply a sigmoid if ``config.data.logit_transform`` is set,
    otherwise map [-1, 1] to [0, 1] if ``config.data.rescaled`` is set; clamp.
    """
    frames = X.to('cpu')

    if hasattr(config, 'image_mean'):
        mean = config.image_mean.to(frames.device)
        frames = frames + mean.unsqueeze(0)  # broadcast over a new leading dim

    data_cfg = config.data
    if data_cfg.logit_transform:
        frames = frames.sigmoid()
    elif data_cfg.rescaled:
        frames = frames.add(1.).div(2.)

    return frames.clamp(0.0, 1.0)

In [None]:
# Calls the notebook-local get_dataset defined in the cell above, which shadows the
# get_dataset imported from `datasets`.
dataset = get_dataset(config)

In [None]:
# Single-clip batches in dataset order; grab the first batch for the demo below.
test_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=config.data.num_workers,
    drop_last=True,
)
test_iter = iter(test_loader)
test_x, test_y = next(test_iter)
print(test_x.shape)

In [None]:
# Keep the raw batch `test_x` untouched and write the transformed copy to a new name.
# Rebinding `test_x` in place meant re-running this cell applied data_transform twice.
test_x_model = data_transform(config, test_x)

# Split the transformed batch into target frames, conditioning frames and the mask.
real, cond, cond_mask = conditioning_fn(
    config, test_x_model,
    num_frames_pred=config.data.num_frames,
    prob_mask_cond=getattr(config.data, 'prob_mask_cond', 0.0),
    prob_mask_future=getattr(config.data, 'prob_mask_future', 0.0),
)
print(real.shape, cond.shape)

# Load initial samples

In [None]:
# Initial samples sized by the leading dimension of `real` (init_samples is project code).
init = init_samples(real.shape[0], config)

In [None]:
# Preview entry `i` of the initial samples before running the sampler.
i = 0  # which entry along dim 0 of `init` to show
print(init.shape, init[i].shape)
media.show_images(init[i])

# Predict

In [None]:
NUM_ROLLOUT_STEPS = 10  # number of autoregressive prediction rounds

preds = []
for _ in range(NUM_ROLLOUT_STEPS):
    # Fresh initial samples every round.
    init = init_samples(len(real), config)
    pred, gamma, beta, mid = sampler(init, scorenet, cond=cond, cond_mask=cond_mask,
                                     subsample=100, verbose=False)
    # Store each round's output: later cells read `preds[0]` and `torch.stack(preds)`,
    # which failed while this list was never filled.
    # NOTE(review): keeps pred[0], presumably the single batch entry (the loader uses
    # batch_size=1) — TODO confirm. Detached/on CPU so rounds don't pile up on the GPU.
    preds.append(pred[0].detach().cpu())
    media.show_images(inverse_transform(config, pred[0][::3]))
    # Autoregressive rollout: this round's prediction conditions the next round.
    # NOTE(review): assumes pred is shape-compatible with cond — confirm. This rebinds
    # the global `cond`, so re-running the cell continues from the last prediction.
    cond = pred

In [None]:
def show_video(frames, config):
    """Display `frames` with mediapy after mapping them back to [0, 1] via inverse_transform."""
    media.show_images(inverse_transform(config, frames))

In [None]:
# Inspect the first stored prediction; raises IndexError if `preds` is empty.
print(preds[0].shape)
show_video(preds[0], config)

In [None]:
# Debug view of `cond`. Once the prediction loop has run, `cond` holds that loop's
# last `pred` (it assigns `cond = pred`), not the conditioning from conditioning_fn.
# NOTE(review): index 30 is taken along dim 0; the loader uses batch_size=1, so if
# dim 0 is the batch dim this raises IndexError — confirm the intended index.
show_video(cond[30], config)

In [None]:
# Stack the entries of `preds` along a new leading dimension.
p = torch.stack(preds, dim=0)
print(p[0].shape)  # equals preds[0].shape: stack requires equal shapes
p.shape

In [None]:
import imageio
import numpy as np


NUM_GIFS = 4  # how many entries of `p` to preview and export

# min() avoids an IndexError when fewer than NUM_GIFS rollouts were stacked.
for i in range(min(NUM_GIFS, len(p))):
    # Map model-space values back to [0, 1].
    frames = inverse_transform(config, p[i])
    # permute(0, 2, 3, 1) moves axis 1 last — presumably channels -> HWC; TODO confirm.
    media.show_images(frames.permute(0, 2, 3, 1))
    # Convert explicitly to uint8 rather than relying on imageio's implicit (lossy,
    # warning-emitting) float conversion; values are already clamped to [0, 1].
    frames_uint8 = (frames.numpy().transpose(0, 2, 3, 1) * 255).round().astype(np.uint8)

    # Save the frames as a GIF
    imageio.mimsave('%d.gif' % i, frames_uint8)
