In [1]:
import os
from functools import partial

import ml_collections
import numpy as np
import torch
import wandb
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from datasets import (
    get_dataset, data_transform, inverse_data_transform
)

from models import eval_models
from models.ema import EMAHelper
from models.unet import UNet_SMLD, UNet_DDPM
from models.fvd.fvd import (
    get_fvd_feats, frechet_distance, load_i3d_pretrained
)
from models import (
    ddpm_sampler,
    ddim_sampler,
    FPNDM_sampler,
    anneal_Langevin_dynamics,
    anneal_Langevin_dynamics_consistent,
    anneal_Langevin_dynamics_inpainting,
    anneal_Langevin_dynamics_interpolation
)
from models.better.ncsnpp_more import UNetMore_DDPM

from losses import get_optimizer, warmup_lr
from losses.dsm import anneal_dsm_score_estimation

from load_model_from_ckpt import init_samples as initialize_samples
from runners.ncsn_runner import conditioning_fn

In [2]:
def get_data_configs() -> ml_collections.ConfigDict:
    """Return the dataset settings as a fresh ``ConfigDict``.

    The values select the ``StochasticMovingMNIST`` dataset with single-channel
    64x64 images, 5 frames and 5 conditioning frames.
    """
    return ml_collections.ConfigDict({
        'channels': 1,
        'dataset': 'StochasticMovingMNIST',
        'gaussian_dequantization': False,
        'image_size': 64,
        'logit_transform': False,
        'num_digits': 2,
        'num_frames': 5,
        'num_frames_cond': 5,
        'num_frames_future': 0,
        'num_workers': 0,
        'prob_mask_cond': 0.0,
        'prob_mask_future': 0.0,
        'prob_mask_sync': False,
        'random_flip': True,
        'rescaled': True,
        'step_length': 0.1,
        'uniform_dequantization': False,
    })


def get_fast_fid_configs() -> ml_collections.ConfigDict:
    """Return the ``fast_fid`` settings as a fresh ``ConfigDict``."""
    return ml_collections.ConfigDict({
        'batch_size': 1000,
        'begin_ckpt': 5000,
        'end_ckpt': 300000,
        'ensemble': False,
        'freq': 5000,
        'n_steps_each': 0,
        'num_samples': 1000,
        'pr_nn_k': 3,
        'step_lr': 0.0,
        'verbose': False,
    })


def get_model_configs() -> ml_collections.ConfigDict:
    """Return the model settings as a fresh ``ConfigDict``.

    Notable values: ``arch='unetmore'``, ``scheduler='DDPM'``, EMA enabled
    with rate 0.999, and ``gamma`` noise disabled.
    """
    return ml_collections.ConfigDict({
        'arch': 'unetmore',
        'attn_resolutions': [8, 16, 32],
        'ch_mult': [1, 2, 3, 4],
        'cond_emb': False,
        'conditional': True,
        'depth': 'deep',
        'dropout': 0.1,
        'ema': True,
        'ema_rate': 0.999,
        'gamma': False,
        'n_head_channels': 64,
        'ngf': 64,
        'noise_in_cond': False,
        'nonlinearity': 'swish',
        'normalization': 'InstanceNorm++',
        'num_classes': 1000,
        'num_res_blocks': 2,
        'output_all_frames': False,
        'sigma_begin': 0.02,
        'sigma_dist': 'linear',
        'sigma_end': 0.0001,
        'spade': False,
        'spade_dim': 128,
        'spec_norm': False,
        'time_conditional': True,
        'type': 'v1',
        'scheduler': 'DDPM',
    })


def get_optim_configs() -> ml_collections.ConfigDict:
    """Return the optimizer settings (Adam, lr 2e-4, 1000 warmup steps) as a ``ConfigDict``."""
    return ml_collections.ConfigDict({
        'amsgrad': False,
        'beta1': 0.9,
        'eps': 1e-08,
        'grad_clip': 1.0,
        'lr': 0.0002,
        'optimizer': 'Adam',
        'warmup': 1000,
        'weight_decay': 0.0,
    })


def get_sampling_configs() -> ml_collections.ConfigDict:
    """Return the sampling / evaluation settings as a fresh ``ConfigDict``."""
    return ml_collections.ConfigDict({
        'batch_size': 100,
        'ckpt_id': 0,
        'clip_before': True,
        'consistent': True,
        'data_init': False,
        'denoise': True,
        'fid': False,
        'final_only': True,
        'fvd': True,
        'init_prev_t': -1.0,
        'inpainting': False,
        'interpolation': False,
        'max_data_iter': 100000,
        'n_interpolations': 15,
        'n_steps_each': 0,
        'num_frames_pred': 20,
        'num_samples4fid': 10000,
        'num_samples4fvd': 10000,
        'one_frame_at_a_time': False,
        'preds_per_test': 1,
        'ssim': True,
        'step_lr': 0.0,
        'subsample': 1000,
        'train': False,
    })


def get_test_configs() -> ml_collections.ConfigDict:
    """Return the test settings (batch size and checkpoint range) as a ``ConfigDict``."""
    return ml_collections.ConfigDict({
        'batch_size': 100,
        'begin_ckpt': 5000,
        'end_ckpt': 300000,
    })


def get_training_configs() -> ml_collections.ConfigDict:
    """Return the training-loop settings as a fresh ``ConfigDict``.

    Includes the batch size, number of epochs, the various logging/snapshot
    frequencies and the checkpoint directory name.
    """
    return ml_collections.ConfigDict({
        'L1': False,
        'batch_size': 64,
        'checkpoint_freq': 100,
        'log_all_sigmas': False,
        'log_freq': 50,
        'n_epochs': 10,
        'n_iters': 3000001,
        'sample_freq': 50000,
        'snapshot_freq': 1000,
        'snapshot_sampling': True,
        'val_freq': 100,
        'checkpoint_dir': "smmnist_cat",
    })


def get_config() -> ml_collections.ConfigDict:
    """Assemble the full experiment configuration.

    Each sub-section comes from its ``get_*_configs`` helper; ``device`` is
    CUDA when available and CPU otherwise, and ``start_at`` is 0.
    """
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return ml_collections.ConfigDict({
        'data': get_data_configs(),
        'device': torch.device(device_name),
        'fast_fid': get_fast_fid_configs(),
        'model': get_model_configs(),
        'optim': get_optim_configs(),
        'sampling': get_sampling_configs(),
        'test': get_test_configs(),
        'training': get_training_configs(),
        'start_at': 0,
    })

In [3]:
def count_trainable_parameters(model):
    """Return the total element count of ``model``'s parameters that have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def count_parameters(model):
    """Return the total element count of all of ``model``'s parameters, trainable or not."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

In [4]:
# Build the experiment configuration and a plain-dict copy of it for W&B.
config = get_config()
config_dict = config.to_dict()
# The device entry is dropped from the copy handed to wandb.init
# (presumably because a torch.device is not serialisable — TODO confirm).
config_dict.pop("device", None)

# Start the W&B run; the run config is the dict built above.
wandb.init(
    project="masked-conditional-video-diffusion",
    entity="wandb",
    job_type="test",
    config=config_dict
)

# Handle to the run's config, used later to record derived values.
wandb_config = wandb.config

[34m[1mwandb[0m: Currently logged in as: [33mgeekyrakshit[0m ([33mwandb[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
# Load the train and test datasets from the local 'mnist_dataset/' directory.
dataset, test_dataset = get_dataset(
    'mnist_dataset/',
    config,
    video_frames_pred=config.data.num_frames,
    start_at=config.start_at
)

# Shuffled training loader using the training batch size.
train_loader = DataLoader(
    dataset,
    batch_size=config.training.batch_size,
    shuffle=True,
    num_workers=config.data.num_workers
)

# The test loader also uses the *training* batch size and is shuffled;
# drop_last=True discards a final incomplete batch.
test_loader = DataLoader(
    test_dataset,
    batch_size=config.training.batch_size,
    shuffle=True,
    num_workers=config.data.num_workers,
    drop_last=True
)

# Flattened per-frame input size (image_size^2 * channels), recorded on both
# the local config and the W&B run config.
wandb_config.input_dim = config.input_dim = config.data.image_size ** 2 * config.data.channels

Dataset length: 60000
Dataset length: 256


In [6]:
# Instantiate the network on the configured device, wrap it in DataParallel,
# and create the optimizer over its parameters.
scorenet = UNetMore_DDPM(config).to(config.device)
scorenet = torch.nn.DataParallel(scorenet)
optimizer = get_optimizer(config, scorenet.parameters())

# Record the parameter counts; commit=False leaves them pending so they are
# not committed by this call on its own.
wandb.log({
    "Parameters": count_parameters(scorenet),
    "Trainable Parameters": count_trainable_parameters(scorenet)
}, commit=False)

In [7]:
# Report the CUDA devices visible to this kernel, or note the CPU fallback.
if not torch.cuda.is_available():
    print("Running on CPU!")
else:
    num_devices = torch.cuda.device_count()
    print(f"Number of GPUs : {num_devices}")
    for i in range(num_devices):
        print(torch.cuda.get_device_properties(i))

Number of GPUs : 1
_CudaDeviceProperties(name='NVIDIA A100-SXM4-40GB', major=8, minor=0, total_memory=40354MB, multi_processor_count=108)


In [8]:
# When EMA is enabled, create the helper with the configured rate and register
# the model with it. `ema_helper` is only defined in that case.
if config.model.ema:
    ema_helper = EMAHelper(mu=config.model.ema_rate)
    ema_helper.register(scorenet)

# Unwrap the DataParallel container (if any) to reach the underlying module.
net = scorenet.module if hasattr(scorenet, 'module') else scorenet

In [9]:
# The run is conditional when at least one conditioning frame is configured.
conditional = config.data.num_frames_cond > 0
# Placeholders for conditioning tensors; both start out unset.
cond, test_cond = None, None
# Number of future frames, defaulting to 0 when the field is absent.
future = getattr(config.data, "num_frames_future", 0)
# Number of initial samples: at most 36, and never more than one training batch.
n_init_samples = min(36, config.training.batch_size)
# Shape of the initial sample tensor: (N, channels * num_frames, H, W).
init_samples_shape = (
    n_init_samples,
    config.data.channels * config.data.num_frames,
    config.data.image_size,
    config.data.image_size
)

In [10]:
# Draw the initial samples for the configured scheduler:
#   * SMLD            -> uniform noise in [0, 1) passed through data_transform
#   * DDPM/DDIM/FPNDM -> standard Gaussian noise, or mean-centred Gamma noise
#                        when config.model.gamma is set
# No branch handles any other scheduler name, so `init_samples` stays unbound
# in that case.
if config.model.scheduler == "SMLD":
    init_samples = data_transform(
        config,
        torch.rand(init_samples_shape, device=config.device)
    )
elif config.model.scheduler in ["DDPM", "DDIM", "FPNDM"]:
    if getattr(config.model, 'gamma', False):
        used_k, used_theta = net.k_cum[0], net.theta_t[0]
        # `torch.distributions.gamma` is a module and is not callable; the
        # distribution class is `Gamma(concentration, rate)`.
        z = torch.distributions.Gamma(
            torch.full(init_samples_shape, used_k),
            torch.full(init_samples_shape, 1 / used_theta)
        ).sample().to(config.device)
        # Gamma(k, rate=1/theta) has mean k * theta; subtract it to centre the noise.
        init_samples = z - used_k * used_theta
    else:
        init_samples = torch.randn(init_samples_shape, device=config.device)

In [11]:
# Select the sampling function for the configured scheduler.
# SMLD uses annealed Langevin dynamics (the "consistent" variant when
# config.sampling.consistent is set); the other schedulers bind `config` to
# their sampler with functools.partial. No branch handles any other scheduler
# name, so `sampler` stays unbound in that case.
if config.model.scheduler == "SMLD":
    consistent = getattr(config.sampling, 'consistent', False)
    sampler = anneal_Langevin_dynamics_consistent if consistent else anneal_Langevin_dynamics
elif config.model.scheduler == "DDPM":
    sampler = partial(ddpm_sampler, config=config)
elif config.model.scheduler == "DDIM":
    sampler = partial(ddim_sampler, config=config)
elif config.model.scheduler == "FPNDM":
    sampler = partial(FPNDM_sampler, config=config)

In [12]:
def train_step(x, y, step):
    """Run one optimisation step on a single batch.

    Uses the notebook-level ``config``, ``scorenet``, ``optimizer``,
    ``conditional`` and (when EMA is enabled) ``ema_helper``.

    Args:
        x: batch of inputs; moved to ``config.device`` and passed through
            ``data_transform`` before use.
        y: second element of the loader batch; not used here.
        step: step index forwarded to ``warmup_lr``.

    Returns:
        Tuple ``(loss, grad_norm, lr)``: the loss and the value returned by
        ``clip_grad_norm_`` as Python numbers, plus whatever ``warmup_lr``
        returned.
    """
    optimizer.zero_grad()
    warmup_steps = getattr(config.optim, 'warmup', 0)
    lr = warmup_lr(optimizer, step, warmup_steps, config.optim.lr)
    scorenet.train()

    frames = data_transform(config, x.to(config.device))
    target, cond_frames, cond_mask = conditioning_fn(
        config,
        frames,
        num_frames_pred=config.data.num_frames,
        prob_mask_cond=getattr(config.data, 'prob_mask_cond', 0.0),
        prob_mask_future=getattr(config.data, 'prob_mask_future', 0.0),
        conditional=conditional,
    )

    loss_kwargs = dict(
        labels=None,
        cond=cond_frames,
        cond_mask=cond_mask,
        loss_type=getattr(config.training, 'loss_type', 'a'),
        gamma=getattr(config.model, 'gamma', False),
        L1=getattr(config.training, 'L1', False),
        hook=None,
        all_frames=getattr(config.model, 'output_all_frames', False),
    )
    loss = anneal_dsm_score_estimation(scorenet, target, **loss_kwargs)
    loss.backward()

    # Clip gradients before the update; without a configured limit the
    # threshold is infinite.
    max_norm = getattr(config.optim, 'grad_clip', np.inf)
    grad_norm = torch.nn.utils.clip_grad_norm_(scorenet.parameters(), max_norm)
    optimizer.step()

    if config.model.ema:
        ema_helper.update(scorenet)

    return loss.item(), grad_norm.item(), lr

In [13]:
def validation_step(epoch):
    """Compute the DSM loss on one test batch and log it to W&B.

    Evaluates an EMA copy of ``scorenet`` when ``config.model.ema`` is set,
    otherwise ``scorenet`` itself (which is switched to eval mode). Nothing is
    returned; the loss is only logged, and only when a W&B run is active.

    Args:
        epoch: value logged as ``validation/epoch`` and used as the W&B step.
    """
    if config.model.ema:
        eval_net = ema_helper.ema_copy(scorenet)
    else:
        eval_net = scorenet
    eval_net.eval()

    # A new iterator is created on every call, so only its first batch is used.
    batch, _ = next(iter(test_loader))
    batch = data_transform(config, batch.to(config.device))
    target, cond_frames, cond_mask = conditioning_fn(
        config,
        batch,
        num_frames_pred=config.data.num_frames,
        prob_mask_cond=getattr(config.data, 'prob_mask_cond', 0.0),
        prob_mask_future=getattr(config.data, 'prob_mask_future', 0.0),
        conditional=conditional,
    )

    with torch.no_grad():
        val_loss = anneal_dsm_score_estimation(
            eval_net,
            target,
            labels=None,
            cond=cond_frames,
            cond_mask=cond_mask,
            loss_type=getattr(config.training, 'loss_type', 'a'),
            gamma=getattr(config.model, 'gamma', False),
            L1=getattr(config.training, 'L1', False),
            hook=None,
            all_frames=getattr(config.model, 'output_all_frames', False),
        )

    if wandb.run is None:
        return

    # NOTE(review): this logs with step=epoch while the training loop logs
    # with its global step counter — confirm W&B accepts these entries when
    # the two are interleaved in one run.
    metrics = {
        "validation/epoch": epoch,
        "validation/loss": val_loss.item(),
    }
    wandb.log(metrics, step=epoch)

In [14]:
def save_model(scorenet, epoch):
    """Write a training checkpoint to disk and log it as a W&B artifact.

    The checkpoint is the list ``[model_state, optimizer_state, epoch, step]``,
    with the EMA state appended when ``config.model.ema`` is set. It is written
    to ``<config.training.checkpoint_dir>/checkpoint.pt``, overwriting any
    previous file. ``config``, ``optimizer``, ``ema_helper`` and ``step`` are
    read from the notebook's global namespace.

    Args:
        scorenet: model whose ``state_dict()`` is saved.
        epoch: epoch number stored in the checkpoint and used in the artifact
            alias ``epoch-<epoch>``.
    """
    states = [scorenet.state_dict(), optimizer.state_dict(), epoch, step]
    # This is a free function, so the config is the notebook-level `config`
    # (there is no `self` here).
    if config.model.ema:
        states.append(ema_helper.state_dict())

    # Make sure the target directory exists before writing into it.
    os.makedirs(config.training.checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(config.training.checkpoint_dir, 'checkpoint.pt')
    torch.save(states, checkpoint_path)

    if wandb.run is not None:
        artifact = wandb.Artifact(
            f'checkpoint-{wandb.run.name}-{wandb.run.id}', type='model'
        )
        artifact.add_file(checkpoint_path)
        wandb.log_artifact(artifact, aliases=["latest", f"epoch-{epoch}"])

In [None]:
# Global optimisation-step counter; passed to train_step (for LR warmup),
# used as the W&B step, and stored in checkpoints by save_model.
step = 0

for epoch in range(1, config.training.n_epochs + 1):
    train_pbar = tqdm(
        enumerate(train_loader),
        total=len(train_loader),
        desc=f"Training Epoch {epoch}"
    )
    for batch, (x, y) in train_pbar:
        loss, grad_norm, lr = train_step(x, y, step)
        if wandb.run is not None:
            wandb.log({
                "train/step": step,
                "lr": lr,
                "grad_norm": grad_norm,
                "train/loss": loss,
            }, step=step)
        # Advance the counter on every batch, whether or not a W&B run is
        # active; otherwise warmup_lr would keep receiving step 0 when
        # logging is disabled.
        step += 1
    # Checkpoint once per epoch.
    save_model(scorenet, epoch)

Training Epoch 1:   0%|          | 0/938 [00:00<?, ?it/s]

In [None]:
# End the W&B run that was started with wandb.init earlier in this notebook.
wandb.finish()