In [1]:
import os
from os.path import expanduser

import glob
import yaml
import wandb
import numpy as np
from tqdm.auto import tqdm
from imageio import mimwrite
from functools import partial
import matplotlib.pyplot as plt
from collections import OrderedDict
from skimage.metrics import structural_similarity

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import make_grid, save_image
try:
    from torchvision.transforms.functional import resize, InterpolationMode
    interp = InterpolationMode.NEAREST
except:
    from torchvision.transforms.functional import resize
    interp = 0

from main import dict2namespace
from models.ema import EMAHelper
import models.eval_models as eval_models
from models import (
    get_sigmas,
    anneal_Langevin_dynamics,
    anneal_Langevin_dynamics_consistent,
    ddpm_sampler,
    ddim_sampler,
    FPNDM_sampler
)
from runners.ncsn_runner import get_model, conditioning_fn
from load_model_from_ckpt import load_model, get_sampler, init_samples
from datasets import get_dataset, data_transform, inverse_data_transform

In [2]:
# Home directory; the experiment and data paths below are built relative to it.
home = expanduser("~")
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
def scale(arr):
    """Min-max normalise ``arr`` to the range [0, 1].

    Works on anything with ``.min()``/``.max()`` and broadcasting arithmetic
    (NumPy arrays and torch tensors are both used in this notebook).

    A constant input (max == min) would divide by zero and yield NaNs.
    It is instead mapped to an all-zero result.
    """
    lo, hi = arr.min(), arr.max()
    span = hi - lo
    if span == 0:
        # Multiply by 0.0 so the result is floating point, like the normal path.
        return (arr - lo) * 0.0
    return (arr - lo) / span

# Set the directories for the downloaded model and data

In [4]:
# SET THESE!!!
# GDRIVE_URL: Google Drive folder holding the experiment (smmnist_big_5c5_unetm_b2).
#   Only the commented-out gdown download cell uses it.
# EXP_PATH: where that download would be saved. Only the commented-out glob in the
#   checkpoint cell references it; the checkpoint is actually taken from the W&B artifact.
# DATA_PATH: root directory passed to get_dataset.
GDRIVE_URL = "https://drive.google.com/drive/folders/1bM6wqU_kymoljz5uYQRCYNup_8adBfLH" # smmnist_big_5c5_unetm_b2
EXP_PATH = os.path.join(home, "scratch/MCVD_SMMNIST_pred")
DATA_PATH = os.path.join(home, "scratch/Datasets/MNIST")

# Download experiment (model checkpoint, config, etc.)

In [5]:
# # GDRIVE_URL = GDRIVE_URL.removesuffix("?usp=sharing")
# !gdown --fuzzy {GDRIVE_URL} -O {EXP_PATH}/ --folder

# Load model checkpoint

In [6]:
# Start a W&B run for this inference session. The results table is logged to it later.
wandb.init(project="masked-conditional-video-diffusion", entity="wandb", job_type="inference")

# Fetch the versioned model-checkpoint artifact. `artifact_dir` is the local folder
# it is downloaded into; the next cells read `checkpoint.pt` from there.
artifact = wandb.use_artifact(
    'wandb/masked-conditional-video-diffusion/checkpoint-revived-sun-29-1f792ve5:v328', type='model'
)
artifact_dir = artifact.download()

[34m[1mwandb[0m: Currently logged in as: [33mgeekyrakshit[0m ([33mwandb[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Downloading large artifact checkpoint-revived-sun-29-1f792ve5:v328, 426.94MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.1


In [7]:
def load_model(ckpt_path, config_path, device=device):
    """Build the score network described by a YAML config and load its weights.

    NOTE: this redefines (and shadows) the ``load_model`` imported from
    ``load_model_from_ckpt`` in the first cell.

    Args:
        ckpt_path: Checkpoint file readable by ``torch.load``. Element 0 is used
            as the model state dict and the last element as the EMA state.
        config_path: Path to the YAML experiment config.
        device: Target device, either a ``torch.device`` or a string such as
            ``"cpu"`` or ``"cuda"``.

    Returns:
        ``(scorenet, config)``: the model in eval mode (wrapped in
        ``torch.nn.DataParallel`` on non-CPU devices) and the parsed config
        namespace, with ``config.device`` set to the normalised device.
    """
    # Normalise the device. Otherwise a plain string such as "cpu" (or "cpu:0")
    # compares unequal to torch.device('cpu') and would take the GPU branch.
    device = torch.device(device)

    with open(config_path, "r") as f:
        # FullLoader is kept for compatibility with existing configs.
        # Only load trusted config files.
        config = dict2namespace(yaml.load(f, Loader=yaml.FullLoader))
    config.device = device

    scorenet = get_model(config)
    if device.type != 'cpu':
        scorenet = torch.nn.DataParallel(scorenet)
        states = torch.load(ckpt_path, map_location=device)
    else:
        states = torch.load(ckpt_path, map_location='cpu')
        # The un-wrapped CPU model expects keys without the DataParallel
        # "module." prefix. Strip only the *leading* prefix: str.replace would
        # also corrupt inner names such as "encoder.submodule.weight".
        prefix = 'module.'
        states[0] = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in states[0].items()
        )
    # strict=False: missing/unexpected keys are ignored instead of raising.
    # NOTE(review): this also hides a complete key mismatch; consider inspecting
    # the returned missing_keys/unexpected_keys.
    scorenet.load_state_dict(states[0], strict=False)

    if config.model.ema:
        # Presumably replaces the model weights with the EMA weights stored in the
        # last checkpoint element (see EMAHelper.ema) — confirm.
        ema_helper = EMAHelper(mu=config.model.ema_rate)
        ema_helper.register(scorenet)
        ema_helper.load_state_dict(states[-1])
        ema_helper.ema(scorenet)
    scorenet.eval()
    return scorenet, config

In [8]:
# Alternative: use the first checkpoint from a local EXP_PATH download instead of the W&B artifact.
# ckpt_path = glob.glob(os.path.join(EXP_PATH, "checkpoint_*.pt"))[0]
ckpt_path = os.path.join(artifact_dir, "checkpoint.pt")
# NOTE(review): the config is read from a repo-relative path, not from the artifact.
# Make sure it matches the config the checkpoint was trained with.
config_path = "./smmnist_cat/logs/config.yml"
# `load_model` here is the version defined in the previous cell; it shadows the
# function imported from load_model_from_ckpt.
scorenet, config = load_model(ckpt_path, config_path, device)
# Sampling routine selected from the config by get_sampler.
sampler = get_sampler(config)

# Load data

In [9]:
# Build the train/test datasets under DATA_PATH from the experiment config, with
# `video_frames_pred` set to config.data.num_frames. Only `test_dataset` is used below.
dataset, test_dataset = get_dataset(
    DATA_PATH, config, video_frames_pred=config.data.num_frames
)

Dataset length: 60000
Dataset length: 256


In [10]:
# Unshuffled loader over the test set; only its first batch is evaluated in this notebook.
# Because drop_last=True, a test set smaller than batch_size would yield no batches,
# and `next` below would raise StopIteration.
test_loader = DataLoader(
    test_dataset,
    batch_size=config.training.batch_size,
    shuffle=False,
    num_workers=config.data.num_workers,
    drop_last=True
)
test_iter = iter(test_loader)
# The loader yields 2-tuples. `test_y` is not used further in this notebook.
test_x, test_y = next(test_iter)

# The recorded run prints torch.Size([64, 10, 1, 64, 64]).
# Presumably (batch, frames, channels, H, W) — confirm.
print(test_x.shape)

torch.Size([64, 10, 1, 64, 64])


In [11]:
# Map the raw batch into the model's input space with data_transform, then split
# each clip into the frames to predict (`real`) and the conditioning inputs
# (`cond`, `cond_mask`).
# The raw `test_x` is left untouched so re-running this cell does not apply
# data_transform twice.
test_x_model = data_transform(config, test_x)
real, cond, cond_mask = conditioning_fn(
    config, test_x_model,
    num_frames_pred=config.data.num_frames,
    # Config keys are optional; masking probabilities default to 0.0 when absent.
    prob_mask_cond=getattr(config.data, 'prob_mask_cond', 0.0),
    prob_mask_future=getattr(config.data, 'prob_mask_future', 0.0)
)

print(real.shape, cond.shape)

torch.Size([64, 5, 64, 64]) torch.Size([64, 5, 64, 64])


# Load initial samples

In [13]:
# Starting samples for the sampler, one per clip in the batch (`len(real)`).
# NOTE(review): presumably random noise shaped like `real` — confirm in init_samples.
init = init_samples(len(real), config)

# Predict

In [14]:
%%time
# Run the config-selected sampler from `init`, conditioned on `cond`/`cond_mask`
# from the conditioning cell, with `scorenet` as the model.
# subsample=100: the log below reports 100 steps ("DDPM: 1/100" ...), so this
# presumably sets the number of sub-sampled denoising steps — confirm.
pred = sampler(
    init, scorenet, cond=cond, cond_mask=cond_mask, subsample=100, verbose=True
)

DDPM: 1/100, grad_norm: 143.1238555908203, image_norm: 129.24099731445312, grad_mean_norm: 315.03326416015625
DDPM: 10/100, grad_norm: 143.14053344726562, image_norm: 130.4148406982422, grad_mean_norm: 323.0744934082031
DDPM: 20/100, grad_norm: 143.20840454101562, image_norm: 131.68173217773438, grad_mean_norm: 319.9212951660156
DDPM: 30/100, grad_norm: 143.70172119140625, image_norm: 133.24813842773438, grad_mean_norm: 323.2720031738281
DDPM: 40/100, grad_norm: 144.8486328125, image_norm: 134.3368377685547, grad_mean_norm: 321.5518493652344
DDPM: 50/100, grad_norm: 148.37176513671875, image_norm: 135.39781188964844, grad_mean_norm: 316.765625
DDPM: 60/100, grad_norm: 157.77597045898438, image_norm: 136.05177307128906, grad_mean_norm: 322.7695617675781
DDPM: 70/100, grad_norm: 180.20025634765625, image_norm: 136.34881591796875, grad_mean_norm: 319.5494689941406
DDPM: 80/100, grad_norm: 235.12551879882812, image_norm: 136.62008666992188, grad_mean_norm: 321.6923522949219
DDPM: 90/100, g

In [15]:
# Compare the sampled clips with the ground truth and log a W&B table.
# Each row holds one clip: initial samples, predicted frames, real frames,
# per-clip mean LPIPS and SSIM, and clip PSNR.
table = wandb.Table(
    columns=[
        "Initial-Frames",
        "Predicted-Frames",
        "Real-Frames",
        "LPIPS",
        "Structural-Similarity",
        "Peak-Signal-To-Noise-Ratio"
    ]
)
# LPIPS input pipeline: resize to 128x128, convert to a tensor, map [0, 1] -> [-1, 1].
model_transforms = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5))
])
model_lpips = eval_models.PerceptualLoss(
    model='net-lin', net='alex', device=device
)

# Convert each batch to host NumPy once. Tensor.numpy() raises on CUDA tensors
# (and on tensors that require grad), and converting inside the loop copied the
# whole batch again on every iteration.
init_np = init.detach().cpu().numpy()
pred_np = pred.detach().cpu().numpy()
real_np = real.detach().cpu().numpy()

# `real` (and therefore the sampler output) is in the space produced by
# data_transform. Map both back before computing metrics. ToPILImage converts
# float tensors with mul(255).byte() and does not clamp, so out-of-range values
# would wrap around. The PSNR below also assumes a peak value of 1.0.
# NOTE(review): assumes inverse_data_transform returns values in [0, 1] — confirm.
pred_img = inverse_data_transform(config, pred.detach().cpu())
real_img = inverse_data_transform(config, real.detach().cpu())

num_frames = config.data.num_frames
round_for_ssim = config.data.dataset.upper() in ["STOCHASTICMOVINGMNIST", "MOVINGMNIST"]
to_pil = transforms.ToPILImage()

with torch.no_grad():
    for idx in tqdm(range(len(real))):
        init_images = [wandb.Image(frame) for frame in np.expand_dims(init_np[idx], -1)]
        predicted_images = [wandb.Image(frame) for frame in np.expand_dims(pred_np[idx], -1)]
        real_images = [wandb.Image(frame) for frame in np.expand_dims(real_np[idx], -1)]

        # PSNR over the whole clip with a peak signal of 1.0 (inf for identical clips).
        psnr_value = 20 * torch.log10(
            1.0 / torch.sqrt(F.mse_loss(real_img[idx], pred_img[idx]))
        )

        lpis_value, ssim_value = 0.0, 0.0
        for j in range(num_frames):
            # Single frame with a leading channel axis, as ToPILImage expects.
            pred_t = pred_img[idx][j].unsqueeze(0)
            real_t = real_img[idx][j].unsqueeze(0)
            pred_frame = to_pil(pred_t).convert("RGB")
            real_frame = to_pil(real_t).convert("RGB")
            pred_lpips = model_transforms(pred_frame).unsqueeze(0).to(device)
            real_lpips = model_transforms(real_frame).unsqueeze(0).to(device)
            lpis_value += model_lpips.forward(real_lpips, pred_lpips).item()

            if round_for_ssim:
                # For the Moving-MNIST datasets, SSIM is computed on rounded frames.
                pred_frame_gray = np.asarray(
                    to_pil(torch.round(pred_t)).convert("RGB").convert('L')
                )
                real_frame_gray = np.asarray(
                    to_pil(torch.round(real_t)).convert("RGB").convert('L')
                )
            else:
                pred_frame_gray = np.asarray(pred_frame.convert('L'))
                real_frame_gray = np.asarray(real_frame.convert('L'))

            ssim_value += structural_similarity(
                pred_frame_gray,
                real_frame_gray,
                data_range=255,
                gaussian_weights=True,
                use_sample_covariance=False
            )

        table.add_data(
            init_images, predicted_images, real_images,
            lpis_value / float(num_frames),
            float(ssim_value) / float(num_frames),
            psnr_value.item()
        )


# uint8 videos min-max scaled over the whole batch. Assuming (batch, frames, H, W)
# inputs, each video is (frames, 1, H, W). They are currently not logged (see the
# commented-out keys below).
initial_videos = [
    wandb.Video(video)
    for video in np.expand_dims(
        (scale(init_np) * 255.0).astype("uint8"), 2
    )
]

predicted_videos = [
    wandb.Video(video)
    for video in np.expand_dims(
        (scale(pred_np) * 255.0).astype("uint8"), 2
    )
]

real_videos = [
    wandb.Video(video)
    for video in np.expand_dims(
        (scale(real_np) * 255.0).astype("uint8"), 2
    )
]

wandb.log({
    # "Real-Videos": real_videos,
    # "Initial-Videos": initial_videos,
    # "Predicted-Videos": predicted_videos,
    "Predictions": table
})

Setting up Perceptual loss...


  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "


Loading model from: /home/jupyter/mcvd-pytorch/models/weights/v0.1/alex.pth
...[net-lin [alex]] initialized
...Done


  0%|          | 0/64 [00:00<?, ?it/s]

In [16]:
# Mark the W&B run started at the top of the notebook as finished; this also
# completes uploading of logged data such as the "Predictions" table.
wandb.finish()