In [1]:
import os
from os.path import expanduser

import glob
import yaml
import wandb
import numpy as np
import ml_collections
from pathlib import Path
from tqdm.auto import tqdm
from imageio import mimwrite
from functools import partial
import matplotlib.pyplot as plt
from collections import OrderedDict
from skimage.metrics import structural_similarity

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.transforms as T
from torchvision.utils import make_grid, save_image
try:
    from torchvision.transforms.functional import resize, InterpolationMode
    interp = InterpolationMode.NEAREST
except:
    from torchvision.transforms.functional import resize
    interp = 0

from main import dict2namespace
from models.ema import EMAHelper
import models.eval_models as eval_models
from models import (
    get_sigmas,
    anneal_Langevin_dynamics,
    anneal_Langevin_dynamics_consistent,
    ddpm_sampler,
    ddim_sampler,
    FPNDM_sampler
)
from runners.ncsn_runner import get_model, conditioning_fn
from load_model_from_ckpt import load_model, get_sampler, init_samples
from datasets import get_dataset, data_transform, inverse_data_transform

In [2]:
def get_data_configs() -> ml_collections.ConfigDict:
    """Return the `data` section of the experiment config.

    Declares a single-channel 64x64 'StochasticMovingMNIST' setup with 5
    predicted frames, 5 conditioning frames and no future frames.
    """
    fields = dict(
        channels=1,
        dataset='StochasticMovingMNIST',
        gaussian_dequantization=False,
        image_size=64,
        logit_transform=False,
        num_digits=2,
        num_frames=5,
        num_frames_cond=5,
        num_frames_future=0,
        num_workers=0,
        prob_mask_cond=0.0,
        prob_mask_future=0.0,
        prob_mask_sync=False,
        random_flip=True,
        rescaled=True,
        step_length=0.1,
        uniform_dequantization=False,
    )
    section = ml_collections.ConfigDict()
    for name, value in fields.items():
        setattr(section, name, value)
    return section


def get_fast_fid_configs() -> ml_collections.ConfigDict:
    """Return the `fast_fid` section (checkpoint sweep range and sample counts)."""
    fields = dict(
        batch_size=1000,
        begin_ckpt=5000,
        end_ckpt=300000,
        ensemble=False,
        freq=5000,
        n_steps_each=0,
        num_samples=1000,
        pr_nn_k=3,
        step_lr=0.0,
        verbose=False,
    )
    section = ml_collections.ConfigDict()
    for name, value in fields.items():
        setattr(section, name, value)
    return section


def get_model_configs() -> ml_collections.ConfigDict:
    """Return the `model` section: 'unetmore' architecture, EMA and noise schedule.

    `scheduler` is 'DDPM' and the sigma schedule is linear from 0.02 to 1e-4.
    """
    fields = dict(
        arch='unetmore',
        attn_resolutions=[8, 16, 32],
        ch_mult=[1, 2, 3, 4],
        cond_emb=False,
        conditional=True,
        depth='deep',
        dropout=0.1,
        ema=True,
        ema_rate=0.999,
        gamma=False,
        n_head_channels=64,
        ngf=64,
        noise_in_cond=False,
        nonlinearity='swish',
        normalization='InstanceNorm++',
        num_classes=1000,
        num_res_blocks=2,
        output_all_frames=False,
        sigma_begin=0.02,
        sigma_dist='linear',
        sigma_end=0.0001,
        spade=False,
        spade_dim=128,
        spec_norm=False,
        time_conditional=True,
        type='v1',
        scheduler='DDPM',
    )
    section = ml_collections.ConfigDict()
    for name, value in fields.items():
        setattr(section, name, value)
    return section


def get_optim_configs() -> ml_collections.ConfigDict:
    """Return the `optim` section (Adam, lr 2e-4, 1000 warmup steps, grad clip 1.0)."""
    fields = dict(
        amsgrad=False,
        beta1=0.9,
        eps=1e-08,
        grad_clip=1.0,
        lr=0.0002,
        optimizer='Adam',
        warmup=1000,
        weight_decay=0.0,
    )
    section = ml_collections.ConfigDict()
    for name, value in fields.items():
        setattr(section, name, value)
    return section


def get_sampling_configs() -> ml_collections.ConfigDict:
    """Return the `sampling` section (batch size, metric switches, subsample steps)."""
    fields = dict(
        batch_size=100,
        ckpt_id=0,
        clip_before=True,
        consistent=True,
        data_init=False,
        denoise=True,
        fid=False,
        final_only=True,
        fvd=True,
        init_prev_t=-1.0,
        inpainting=False,
        interpolation=False,
        max_data_iter=100000,
        n_interpolations=15,
        n_steps_each=0,
        num_frames_pred=20,
        num_samples4fid=10000,
        num_samples4fvd=10000,
        one_frame_at_a_time=False,
        preds_per_test=1,
        ssim=True,
        step_lr=0.0,
        subsample=1000,
        train=False,
    )
    section = ml_collections.ConfigDict()
    for name, value in fields.items():
        setattr(section, name, value)
    return section


def get_test_configs() -> ml_collections.ConfigDict:
    """Return the `test` section (batch size and checkpoint range 5000..300000)."""
    fields = dict(
        batch_size=100,
        begin_ckpt=5000,
        end_ckpt=300000,
    )
    section = ml_collections.ConfigDict()
    for name, value in fields.items():
        setattr(section, name, value)
    return section


def get_training_configs() -> ml_collections.ConfigDict:
    """Return the `training` section: batch size, iteration budget and logging cadence.

    Each key is assigned exactly once. `checkpoint_freq` is 50; an earlier
    duplicate assignment of 100 was dead code, immediately overwritten.
    """
    config = ml_collections.ConfigDict()

    config.L1 = False
    config.batch_size = 64
    # Kept at its first-assignment position so `to_dict()` key order is unchanged.
    config.checkpoint_freq = 50
    config.log_all_sigmas = False
    config.log_freq = 50
    config.n_epochs = 500
    config.n_iters = 3000001
    config.sample_freq = 50000
    config.snapshot_freq = 1000
    config.snapshot_sampling = True
    config.val_freq = 100
    config.checkpoint_dir = "smmnist_cat"

    return config


def get_config() -> ml_collections.ConfigDict:
    """Assemble the full experiment config from the per-section builders.

    Besides the nested sections, sets `device` to a `torch.device` (CUDA when
    `torch.cuda.is_available()`, otherwise CPU) and `start_at` to 0.
    """
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    entries = (
        ('data', get_data_configs()),
        ('device', torch.device(device_name)),
        ('fast_fid', get_fast_fid_configs()),
        ('model', get_model_configs()),
        ('optim', get_optim_configs()),
        ('sampling', get_sampling_configs()),
        ('test', get_test_configs()),
        ('training', get_training_configs()),
        ('start_at', 0),
    )
    config = ml_collections.ConfigDict()
    for key, value in entries:
        setattr(config, key, value)
    return config

In [3]:
def scale(arr):
    """Min-max normalise `arr` (numpy array or torch tensor) into [0, 1].

    A constant input has zero range; instead of producing NaN from 0/0, the
    result is all zeros with the same floating type true division would give.
    """
    m, M = arr.min(), arr.max()
    span = M - m
    if span == 0:
        # `* 0.0` keeps shape and promotes ints to float like `/` would.
        return (arr - m) * 0.0
    return (arr - m) / span


def ls(path):
    """List the entries of directory `path` (a `pathlib.Path`) in sorted order."""
    return sorted(child for child in path.iterdir())


def count_trainable_parameters(model):
    """Count scalar elements over the parameters of `model` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def count_parameters(model):
    """Count scalar elements over all parameters of `model`, trainable or frozen."""
    return sum(map(torch.numel, model.parameters()))

In [4]:
config = get_config()
# W&B gets the plain-dict view of the config with the torch.device entry left out.
config_dict = {key: value for key, value in config.to_dict().items() if key != "device"}

wandb.init(
    project="masked-conditional-video-diffusion",
    entity="wandb",
    job_type="inference",
    config=config_dict,
)

# Fetch the trained checkpoint artifact and download it locally.
artifact = wandb.use_artifact(
    'wandb/masked-conditional-video-diffusion/checkpoint-glistening-snake-87-yfq7vyx1:v9',
    type='model',
)
model_artifact_dir = artifact.download()

[34m[1mwandb[0m: Currently logged in as: [33mgeekyrakshit[0m ([33mwandb[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Downloading large artifact checkpoint-glistening-snake-87-yfq7vyx1:v9, 426.95MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.1


In [5]:
def load_model(ckpt_path):
    """Build the score network from the global `config` and load a checkpoint into it.

    NOTE: this shadows the `load_model` imported from `load_model_from_ckpt` in
    the imports cell, and it reads the module-level `config` instead of taking
    it as an argument.

    Args:
        ckpt_path: path to a checkpoint loaded with `torch.load`. It must be
            indexable: element 0 is the model state dict and the last element
            is the EMA helper state.

    Returns:
        The network in eval mode. It is wrapped in `torch.nn.DataParallel`
        whenever `config.device` is not CPU. When `config.model.ema` is set,
        the EMA helper is applied to it.

    Warns:
        UserWarning: if loading with `strict=False` reports missing or
        unexpected keys, meaning some weights were not restored.
    """
    import warnings

    prefix = 'module.'
    scorenet = get_model(config)
    if config.device != torch.device('cpu'):
        scorenet = torch.nn.DataParallel(scorenet)
        # NOTE(review): torch.load unpickles arbitrary objects; only use trusted checkpoints.
        states = torch.load(ckpt_path, map_location=config.device)
    else:
        states = torch.load(ckpt_path, map_location='cpu')
        # On CPU the model is not wrapped, so drop the DataParallel key prefix.
        # Only a leading 'module.' is removed; str.replace would also rewrite
        # keys that merely contain it (e.g. 'submodule.weight').
        states[0] = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in states[0].items()
        )
    result = scorenet.load_state_dict(states[0], strict=False)
    # strict=False tolerates mismatches; report them instead of silently
    # running with partially-initialised weights.
    missing = list(getattr(result, 'missing_keys', []))
    unexpected = list(getattr(result, 'unexpected_keys', []))
    if missing or unexpected:
        warnings.warn(
            f"load_model: {len(missing)} missing key(s) {missing[:5]}, "
            f"{len(unexpected)} unexpected key(s) {unexpected[:5]} in {ckpt_path}"
        )
    if config.model.ema:
        ema_helper = EMAHelper(mu=config.model.ema_rate)
        ema_helper.register(scorenet)
        ema_helper.load_state_dict(states[-1])
        # presumably copies the EMA shadow weights into `scorenet` — confirm in models.ema
        ema_helper.ema(scorenet)
    scorenet.eval()
    return scorenet

In [6]:
# The downloaded model artifact holds a single file, `checkpoint.pt`.
ckpt_path = os.path.join(model_artifact_dir, "checkpoint.pt")
# Resolves to the notebook-defined `load_model` (cell 5), not the imported one.
scorenet = load_model(ckpt_path)

In [7]:
class CloudDataset:
    """Fixed-length frame windows cut from a list of `.npy` frame stacks.

    Each file is loaded with `np.load`, optionally min-max scaled and flipped
    (`0.5 - x`), and then every run of `num_frames` consecutive entries along
    axis 0 becomes one sample. Samples are resized and center-cropped to
    `size x size` on access.
    """

    def __init__(self, files, num_frames=4, scale=True, size=64):
        """
        Args:
            files: iterable of paths readable by `np.load`.
            num_frames: window length along axis 0 (frames per sample).
            scale: if True, map each file's array to `0.5 - minmax(array)`.
            size: output height and width after resize + center crop.

        Raises:
            ValueError: if `files` is empty, so no windows can be built.
        """
        self.num_frames = num_frames
        self.size = size
        self.tfms = T.Compose([
            # Resize to an H:W = 1:1.7 canvas, then keep the central square.
            T.Resize((size, int(size * 1.7))),
            T.CenterCrop(size)
        ])
        data = []
        for file in tqdm(files):
            one_day = np.load(file)
            one_day = 0.5 - self._scale(one_day) if scale else one_day
            # Assumes squeeze() leaves a 3-D (frames, H, W) array — TODO confirm.
            # The windows are (N, H, W, num_frames); the transpose moves the
            # window axis next to N, giving (N, num_frames, H, W).
            wds = np.lib.stride_tricks.sliding_window_view(
                one_day.squeeze(),
                num_frames,
                axis=0
            ).transpose((0, 3, 1, 2))
            data.append(wds)
        if not data:
            raise ValueError("CloudDataset needs at least one input file")
        self.data = np.concatenate(data, axis=0)

    @staticmethod
    def _scale(arr):
        "Scales values of array in [0,1]; a constant array maps to all zeros."
        m, M = arr.min(), arr.max()
        span = M - m
        if span == 0:
            # Avoid 0/0 -> NaN; `* 0.0` keeps the float dtype `/` would produce.
            return (arr - m) * 0.0
        return (arr - m) / span

    def __getitem__(self, idx):
        """Return `(clip, clip)` with clip shaped (num_frames, 1, size, size)."""
        # Resize/CenterCrop treat the frame axis as channels; the singleton
        # channel axis is inserted afterwards.
        data = self.tfms(torch.from_numpy(self.data[idx]))
        data = torch.unsqueeze(data, dim=-3)
        return data, data

    def __len__(self):
        return len(self.data)

    def save(self, fname="cloud_frames.npy"):
        """Write all stacked windows to `fname` with `np.save`."""
        np.save(fname, self.data)

In [8]:
# Download the dataset artifact: a directory of arrays that CloudDataset reads with np.load.
artifact = wandb.use_artifact('capecape/gtc/np_dataset:v0', type='dataset')
data_artifact_dir = artifact.download()

# Each sample is a window of (predicted + conditioning) frames.
# NOTE(review): `ls` returns every directory entry, so a non-array file in the
# artifact directory would make np.load fail — confirm the directory is clean.
dataset = CloudDataset(
    ls(Path(data_artifact_dir)),
    num_frames=config.data.num_frames + config.data.num_frames_cond,
    size=config.data.image_size
)

# Deterministic order (shuffle=False); drop_last keeps every batch full-size.
data_loader = DataLoader(
    dataset,
    batch_size=config.training.batch_size,
    shuffle=False,
    num_workers=config.data.num_workers,
    drop_last=True
)
# A single batch is used for inference; x and y are the same tensor
# (CloudDataset.__getitem__ returns `data, data`).
test_x, test_y = next(iter(data_loader))

print(test_x.shape)

[34m[1mwandb[0m: Downloading large artifact np_dataset:v0, 3816.62MB. 30 files... 
[34m[1mwandb[0m:   30 of 30 files downloaded.  
Done. 0:0:0.1


  0%|          | 0/30 [00:00<?, ?it/s]

torch.Size([64, 10, 1, 64, 64])


In [9]:
# Map the raw batch into the model's input space. The result gets a new name
# so re-running this cell does not apply `data_transform` a second time.
test_x_transformed = data_transform(config, test_x)
# Split each clip into the frames to predict (`real`), the conditioning
# frames (`cond`) and the conditioning mask. The printed shapes are
# (batch, frames, H, W) for both `real` and `cond`.
real, cond, cond_mask = conditioning_fn(
    config, test_x_transformed,
    num_frames_pred=config.data.num_frames,
    prob_mask_cond=getattr(config.data, 'prob_mask_cond', 0.0),
    prob_mask_future=getattr(config.data, 'prob_mask_future', 0.0)
)

print(real.shape, cond.shape)

torch.Size([64, 5, 64, 64]) torch.Size([64, 5, 64, 64])


In [10]:
# Starting samples for the reverse process, one per clip in `real`.
init = init_samples(len(real), config)
# Sampler built from the config (presumably selected by config.model.scheduler, 'DDPM' here — confirm).
sampler = get_sampler(config)

In [11]:
%%time
# Sample predicted frames conditioned on `cond`. `subsample=100` is passed
# explicitly rather than taken from config.sampling.subsample (1000); the
# verbose log reports progress out of 100 steps.
pred = sampler(
    init, scorenet, cond=cond, cond_mask=cond_mask, subsample=100, verbose=True
)

DDPM: 1/100, grad_norm: 143.343017578125, image_norm: 129.55447387695312, grad_mean_norm: 327.7379455566406
DDPM: 10/100, grad_norm: 143.14578247070312, image_norm: 130.61659240722656, grad_mean_norm: 321.5141296386719
DDPM: 20/100, grad_norm: 143.03793334960938, image_norm: 132.10137939453125, grad_mean_norm: 317.95855712890625
DDPM: 30/100, grad_norm: 142.97877502441406, image_norm: 133.49734497070312, grad_mean_norm: 320.0403747558594
DDPM: 40/100, grad_norm: 143.80072021484375, image_norm: 134.91636657714844, grad_mean_norm: 318.9917297363281
DDPM: 50/100, grad_norm: 146.82933044433594, image_norm: 135.5860137939453, grad_mean_norm: 324.7160949707031
DDPM: 60/100, grad_norm: 155.2665557861328, image_norm: 133.97149658203125, grad_mean_norm: 330.50665283203125
DDPM: 70/100, grad_norm: 176.037109375, image_norm: 128.89035034179688, grad_mean_norm: 347.5442199707031
DDPM: 80/100, grad_norm: 227.30291748046875, image_norm: 120.12443542480469, grad_mean_norm: 364.6986389160156
DDPM: 90/

In [12]:
# One W&B table row per clip: frame strips, frame-averaged LPIPS and SSIM,
# and a clip-level PSNR.
table = wandb.Table(
    columns=[
        "Initial-Frames",
        "Predicted-Frames",
        "Real-Frames",
        "LPIPS",
        "Structural-Similarity",
        "Peak-Signal-To-Noise-Ratio"
    ]
)
# LPIPS input preprocessing: 128x128 RGB tensors normalized with mean=std=0.5.
model_transforms = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5))
])
model_lpips = eval_models.PerceptualLoss(
    model='net-lin', net='alex', device=config.device
)
to_pil = transforms.ToPILImage()
# For Moving-MNIST datasets, SSIM is computed on rounded frames.
round_for_ssim = config.data.dataset.upper() in ["STOCHASTICMOVINGMNIST", "MOVINGMNIST"]

# Move to host memory once: `.numpy()` raises on CUDA tensors, and converting
# once avoids re-copying the whole batch for every frame.
init_np = init.detach().cpu().numpy()
pred_np = pred.detach().cpu().numpy()
real_np = real.detach().cpu().numpy()
# `real` went through `data_transform` (conditioning cell) and `pred` is sampled
# in that same space. ToPILImage converts floats via `mul(255).byte()`, so map
# both back to data space before building the images used for LPIPS/SSIM.
pred_data = inverse_data_transform(config, pred.detach().cpu())
real_data = inverse_data_transform(config, real.detach().cpu())

for idx in tqdm(range(len(real))):
    # NOTE(review): this column shows the sampler's starting samples (`init`),
    # not the conditioning frames `cond` — confirm that is intended.
    init_images = [wandb.Image(frame) for frame in np.expand_dims(init_np[idx], -1)]
    predicted_images = [
        wandb.Image((scale(frame) * 255).astype(np.uint8))
        for frame in np.expand_dims(pred_np[idx], -1)
    ]
    real_images = [wandb.Image(frame) for frame in np.expand_dims(real_np[idx], -1)]

    # Clip-level PSNR after min-max scaling each clip independently to [0, 1].
    psnr_value = 20 * torch.log10(
        1.0 / torch.sqrt(F.mse_loss(
            scale(torch.from_numpy(real_np[idx])),
            scale(torch.from_numpy(pred_np[idx]))
        ))
    )

    lpips_total, ssim_total = 0.0, 0.0
    for j in range(config.data.num_frames):
        # Single-channel (1, H, W) frames in data space.
        pred_frame = pred_data[idx, j].unsqueeze(0)
        real_frame = real_data[idx, j].unsqueeze(0)
        pred_rgb = to_pil(pred_frame).convert("RGB")
        real_rgb = to_pil(real_frame).convert("RGB")
        # Metric only: no_grad stops LPIPS from building and retaining autograd graphs.
        with torch.no_grad():
            lpips_total += model_lpips.forward(
                model_transforms(real_rgb).unsqueeze(0).to(config.device),
                model_transforms(pred_rgb).unsqueeze(0).to(config.device)
            ).item()

        if round_for_ssim:
            pred_gray = np.asarray(to_pil(torch.round(pred_frame)).convert("RGB").convert('L'))
            real_gray = np.asarray(to_pil(torch.round(real_frame)).convert("RGB").convert('L'))
        else:
            pred_gray = np.asarray(pred_rgb.convert('L'))
            real_gray = np.asarray(real_rgb.convert('L'))

        ssim_total += structural_similarity(
            pred_gray,
            real_gray,
            data_range=255,
            gaussian_weights=True,
            use_sample_covariance=False
        )

    table.add_data(
        init_images, predicted_images, real_images,
        lpips_total / float(config.data.num_frames),
        float(ssim_total) / float(config.data.num_frames),
        psnr_value.item()
    )


wandb.log({"Predictions": table})

Setting up Perceptual loss...


  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Loading model from: /home/jupyter/mcvd-pytorch/models/weights/v0.1/alex.pth
...[net-lin [alex]] initialized
...Done


  0%|          | 0/64 [00:00<?, ?it/s]

In [13]:
# End the W&B run started by `wandb.init` in the config cell.
wandb.finish()