### required packages

In [1]:
import math
from pathlib import Path
from functools import partial, wraps
from collections import namedtuple
from multiprocessing import cpu_count

import numpy as np
import pandas as pd
from sklearn.preprocessing import QuantileTransformer

import torch
from torch import nn, einsum
from torch.cuda.amp import autocast
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torch.optim import Adam

from torch_geometric.datasets import MovieLens1M
from torch_geometric.data import HeteroData

from einops import rearrange, reduce, repeat

from PIL import Image
from tqdm.auto import tqdm
from ema_pytorch import EMA

from accelerate import Accelerator

  from .autonotebook import tqdm as notebook_tqdm


### helper functions

In [2]:
def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

print_once = once(print)

def cast_tuple(t, length = 1):
    if isinstance(t, tuple):
        return t
    return ((t,) * length)

def divisible_by(numer, denom):
    return (numer % denom) == 0

def identity(t, *args, **kwargs):
    return t

def cycle(dl):
    while True:
        for data in dl:
            yield data

def has_int_squareroot(num):
    return (math.sqrt(num) ** 2) == num

def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

# ATTENTION

### constants

In [3]:
AttentionConfig = namedtuple('AttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

### main class

In [4]:
class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        flash = False,
        scale = None
    ):
        super().__init__()
        self.dropout = dropout
        self.scale = scale
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = AttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(False, True, True)

    def flash_attn(self, q, k, v):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        if exists(self.scale):
            default_scale = q.shape[-1]
            q = q * (self.scale / default_scale)

        q, k, v = map(lambda t: t.contiguous(), (q, k, v))

        # Check if there is a compatible device for flash attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p = self.dropout if self.training else 0.
            )

        return out

    def forward(self, q, k, v):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        if self.flash:
            return self.flash_attn(q, k, v)

        scale = default(self.scale, q.shape[-1] ** -0.5)

        # similarity

        sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale

        # attention

        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # aggregate values

        out = einsum(f"b h i j, b h j d -> b h i d", attn, v)

        return out

### row attention (per user or movie adjacency list)

In [5]:
class RowAttention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 32,
        num_mem_kv = 128,
        flash = False
    ):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads

        self.attend = Attend(flash = flash)

        self.mem_kv = nn.Parameter(torch.randn(2, heads, num_mem_kv, dim_head))
        self.to_qkv = nn.Conv2d(
            in_channels = dim, 
            out_channels = hidden_dim * 3, 
            kernel_size = 1, 
            bias = False)
        self.to_out = nn.Conv2d(
            in_channels = hidden_dim, 
            out_channels = dim, 
            kernel_size = 1)

    def forward(self, x):
        b, c, h, w = x.shape
        
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b x) h y d', h = self.heads), qkv)

        mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b=b*h), self.mem_kv)
        k, v = map(partial(torch.cat, dim=-2), ((mk, k), (mv, v)))

        out = self.attend(q, k, v)

        out = rearrange(out, '(b x) h y d -> b (h d) x y', x = h, y = w)
        return self.to_out(out)

# GUASSIAN DIFFUSION PIPELINE

### helpers and beta schedules

In [6]:
def normalize_to_neg_one_to_one(img):
    return img * 2 - 1

def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5

def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

def linear_beta_schedule(timesteps):
    """
    linear schedule, proposed in original ddpm paper
    """
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)

def cosine_beta_schedule(timesteps, s = 0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps
    alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5):
    """
    sigmoid schedule
    proposed in https://arxiv.org/abs/2212.11972 - Figure 8
    better for images > 64x64, when used during training
    """
    steps = timesteps + 1
    t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps
    v_start = torch.tensor(start / tau).sigmoid()
    v_end = torch.tensor(end / tau).sigmoid()
    alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

### main gaussian diffusion class

In [7]:
class GaussianDiffusion(nn.Module):
    def __init__(
            self,
            model,
            *,
            image_size,
            timesteps=1000,
            sampling_timesteps=None,
            objective='pred_noise',
            beta_schedule='cosine',
            schedule_fn_kwargs=dict(),
            ddim_sampling_eta=0.,
            auto_normalize=True,
            offset_noise_strength=0.,  # https://www.crosslabs.org/blog/diffusion-with-offset-noise
            min_snr_loss_weight=False,  # https://arxiv.org/abs/2303.09556
            min_snr_gamma=5
    ):
        super().__init__()
        assert not hasattr(model, 'random_or_learned_sinusoidal_cond') or not model.random_or_learned_sinusoidal_cond

        self.model = model
        self.self_condition = None

        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        assert isinstance(image_size, (tuple, list)) and len(
            image_size) == 2, 'image size must be a integer or a tuple/list of two integers'
        self.image_size = image_size

        self.objective = objective

        assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, \
            'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'

        if beta_schedule == 'linear':
            beta_schedule_fn = linear_beta_schedule
        elif beta_schedule == 'cosine':
            beta_schedule_fn = cosine_beta_schedule
        elif beta_schedule == 'sigmoid':
            beta_schedule_fn = sigmoid_beta_schedule
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')

        betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)

        # sampling related parameters

        self.sampling_timesteps = default(sampling_timesteps,
                                          timesteps)  # default num sampling timesteps to number of timesteps at training

        assert self.sampling_timesteps <= timesteps
        self.is_ddim_sampling = self.sampling_timesteps < timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        # helper function to register buffer from float64 to float32

        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others

        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)

        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)

        register_buffer('posterior_variance', posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain

        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        # offset noise strength - in blogpost, they claimed 0.1 was ideal

        self.offset_noise_strength = offset_noise_strength

        # derive loss weight
        # snr - signal noise ratio

        snr = alphas_cumprod / (1 - alphas_cumprod)

        # https://arxiv.org/abs/2303.09556

        maybe_clipped_snr = snr.clone()
        if min_snr_loss_weight:
            maybe_clipped_snr.clamp_(max=min_snr_gamma)

        if objective == 'pred_noise':
            register_buffer('loss_weight', maybe_clipped_snr / snr)
        elif objective == 'pred_x0':
            register_buffer('loss_weight', maybe_clipped_snr)
        elif objective == 'pred_v':
            register_buffer('loss_weight', maybe_clipped_snr / (snr + 1))

        # auto-normalization of data [0, 1] -> [-1, 1] - can turn off by setting it to be False

        self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity
        self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity

    @property
    def device(self):
        return self.betas.device

    def predict_start_from_noise(self, x_t, t, noise):
        if isinstance(noise, torch.Tensor):
            noise[:, 1:, :, :] = 0
        return (
                extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
                extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        return (
                (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
                extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    def predict_v(self, x_start, t, noise):
        return (
                extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
                extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
        )

    def predict_start_from_v(self, x_t, t, v):
        return (
                extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
                extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
                extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
                extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def model_predictions(self, x, t, clip_x_start=False, rederive_pred_noise=False):
        model_output = self.model(x, t)
        zeros = torch.zeros_like(x, device=model_output.device)
        zeros[:, 0, :, :] = model_output
        model_output = zeros
        maybe_clip = partial(torch.clamp, min=-1., max=1.) if clip_x_start else identity

        if self.objective == 'pred_noise':
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, pred_noise)
            x_start = maybe_clip(x_start)

            if clip_x_start and rederive_pred_noise:
                pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_x0':
            x_start = model_output
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_v':
            v = model_output
            x_start = self.predict_start_from_v(x, t, v)
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        return ModelPrediction(pred_noise, x_start)

    def p_mean_variance(self, x, t, x_self_cond=None, clip_denoised=True):
        preds = self.model_predictions(x, t, x_self_cond)
        x_start = preds.pred_x_start

        if clip_denoised:
            x_start.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.inference_mode()
    def p_sample(self, x, t: int, self_cond=None):
        b, *_, device = *x.shape, self.device
        batched_times = torch.full((b,), t, device=device, dtype=torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x=x, t=batched_times, x_self_cond=self_cond,
                                                                          clip_denoised=True)
        noise = torch.randn_like(x) if t > 0 else 0.  # no noise if t == 0
        if isinstance(noise, torch.Tensor):
            noise[:, 1:, :, :] = 0
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_img, x_start

    @torch.inference_mode()
    def p_sample_loop(self, x_start, return_all_timesteps=False):
        batch, device = x_start.shape[0], self.device

        img = x_start
        imgs = [img]

        x_start = None

        for t in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
            img, _ = self.p_sample(img, t, self_cond=None)
            imgs.append(img)

        ret = img if not return_all_timesteps else torch.stack(imgs, dim=1)

        ret = self.unnormalize(ret)
        return ret

    @torch.inference_mode()
    def ddim_sample(self, x_start, return_all_timesteps=False):
        batch, device, total_timesteps, sampling_timesteps, eta, objective = x_start.shape[
            0], self.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

        times = torch.linspace(-1, total_timesteps - 1,
                               steps=sampling_timesteps + 1)  # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))  # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        img = x_start
        imgs = [img]

        for time, time_next in tqdm(time_pairs, desc='sampling loop time step'):
            time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
            pred_noise, _, *_ = self.model_predictions(img, time_cond, clip_x_start=True, rederive_pred_noise=True)

            if time_next < 0:
                img = x_start
                imgs.append(img)
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise = torch.randn_like(img)
            noise[:, 1:, :, :] = 0

            img = x_start * alpha_next.sqrt() + \
                  c * pred_noise + \
                  sigma * noise

            imgs.append(img)

        ret = img if not return_all_timesteps else torch.stack(imgs, dim=1)

        ret = self.unnormalize(ret)
        return ret

    @torch.inference_mode()
    def sample(self, x_start, return_all_timesteps=False):
        sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
        return sample_fn(x_start, return_all_timesteps=return_all_timesteps)

    @torch.inference_mode()
    def interpolate(self, x1, x2, t=None, lam=0.5):
        b, *_, device = *x1.shape, x1.device
        t = default(t, self.num_timesteps - 1)

        assert x1.shape == x2.shape

        t_batched = torch.full((b,), t, device=device)
        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))

        img = (1 - lam) * xt1 + lam * xt2

        x_start = None

        for i in tqdm(reversed(range(0, t)), desc='interpolation sample time step', total=t):
            self_cond = x_start if self.self_condition else None
            img, x_start = self.p_sample(img, i, self_cond)

        return img

    @autocast(enabled=False)
    def q_sample(self, x_start, t, noise=None):
        b, c, h, w = x_start.shape
        noise = default(noise, lambda: torch.randn_like(x_start))
        noise[:, 1:, :, :] = 0

        return (
                extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, t, noise=None, offset_noise_strength=None):
        b, c, h, w = x_start.shape

        noise = default(noise, lambda: torch.randn_like(x_start))

        # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise

        offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)

        if offset_noise_strength > 0.:
            offset_noise = torch.randn(x_start.shape[:2], device=self.device)
            noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')

        # noise sample
        mask = x_start[:,0,:,:] == -10
        x_start[:,0,:,:][mask] = 0
        x = self.q_sample(x_start=x_start, t=t, noise=noise)

        # if doing self-conditioning, 50% of the time, predict x_start from current set of times
        # and condition with unet with that
        # this technique will slow down training by 25%, but seems to lower FID significantly

        # predict and take gradient step
        # print(x.shape, t.shape)
        model_out = self.model(x, t)
        model_out[mask] = 0

        if self.objective == 'pred_noise':
            target = noise[:, 0, :, :]
        elif self.objective == 'pred_x0':
            target = x_start
        elif self.objective == 'pred_v':
            v = self.predict_v(x_start, t, noise)
            target = v
        else:
            raise ValueError(f'unknown objective {self.objective}')
        
        loss = F.mse_loss(model_out, target, reduction='none')
        loss = reduce(loss, 'b ... -> b', 'mean')

        loss = loss * extract(self.loss_weight, t, loss.shape)
        return loss.mean()

    def forward(self, img, *args, **kwargs):
        b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
        assert h == img_size[0] and w == img_size[1], f'height and width of image must be {img_size}'
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

        return self.p_losses(img, t, *args, **kwargs)

# TRAINER

In [8]:
class Trainer(object):
    def __init__(
            self,
            diffusion_model,
            folder,
            *,
            train_batch_size=16,
            gradient_accumulate_every=1,
            train_lr=1e-4,
            train_num_steps=100000,
            ema_update_every=10,
            ema_decay=0.995,
            adam_betas=(0.9, 0.99),
            save_and_sample_every=1000,
            num_samples=25,
            n_subsamples=1600,
            min_edges_per_subsample=10,
            results_folder='./results',
            amp=False,
            mixed_precision_type='fp16',
            split_batches=True,
            inception_block_idx=2048,
            max_grad_norm=1.,
            save_best_and_latest_only=False
    ):
        super().__init__()

        # accelerator

        self.accelerator = Accelerator(
            split_batches=split_batches,
            mixed_precision=mixed_precision_type if amp else 'no'
        )

        # model

        self.model = diffusion_model
        is_ddim_sampling = diffusion_model.is_ddim_sampling

        # sampling and training hyperparameters

        assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
        self.num_samples = num_samples
        self.save_and_sample_every = save_and_sample_every

        self.batch_size = train_batch_size
        self.gradient_accumulate_every = gradient_accumulate_every
        assert (train_batch_size * gradient_accumulate_every) >= 16, \
            f'your effective batch size (train_batch_size x gradient_accumulate_every) should be at least 16 or above'

        self.train_num_steps = train_num_steps
        self.image_size = diffusion_model.image_size

        self.max_grad_norm = max_grad_norm

        # dataset and dataloader
        self.ds = ProcessedMovieLens(folder, n_subsamples, n_unique_per_sample=min_edges_per_subsample,
                                     dataset_transform=rating_transform, download=True)

        dl = DataLoader(self.ds, batch_size=train_batch_size, shuffle=True, pin_memory=True, num_workers=cpu_count())

        dl = self.accelerator.prepare(dl)
        self.dl = cycle(dl)

        # optimizer

        self.opt = Adam(diffusion_model.parameters(), lr=train_lr, betas=adam_betas)

        # for logging results in a folder periodically

        if self.accelerator.is_main_process:
            self.ema = EMA(diffusion_model, beta=ema_decay, update_every=ema_update_every)
            self.ema.to(self.device)

        self.results_folder = Path(results_folder)
        self.results_folder.mkdir(exist_ok=True)

        # step counter state

        self.step = 0

        # prepare model, dataloader, optimizer with accelerator

        self.model, self.opt = self.accelerator.prepare(self.model, self.opt)

        self.save_best_and_latest_only = save_best_and_latest_only

    @property
    def device(self):
        return self.accelerator.device

    def save(self, milestone):
        if not self.accelerator.is_local_main_process:
            return

        data = {
            'step': self.step,
            'model': self.accelerator.get_state_dict(self.model),
            'opt': self.opt.state_dict(),
            'ema': self.ema.state_dict(),
            'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None
        }

        torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))

    def load(self, milestone):
        accelerator = self.accelerator
        device = accelerator.device

        data = torch.load(str(self.results_folder / f'model-{milestone}.pt'), map_location=device)

        model = self.accelerator.unwrap_model(self.model)
        model.load_state_dict(data['model'])

        self.step = data['step']
        self.opt.load_state_dict(data['opt'])
        if self.accelerator.is_main_process:
            self.ema.load_state_dict(data["ema"])

        if 'version' in data:
            print(f"loading from version {data['version']}")

        if exists(self.accelerator.scaler) and exists(data['scaler']):
            self.accelerator.scaler.load_state_dict(data['scaler'])

    def train(self):
        accelerator = self.accelerator
        device = accelerator.device

        with tqdm(initial=self.step, total=self.train_num_steps, disable=not accelerator.is_main_process) as pbar:

            while self.step < self.train_num_steps:

                total_loss = 0.

                for _ in range(self.gradient_accumulate_every):
                    data = next(self.dl).to(device)

                    with self.accelerator.autocast():
                        loss = self.model(data)
                        loss = loss / self.gradient_accumulate_every
                        total_loss += loss.item()

                    self.accelerator.backward(loss)

                pbar.set_description(f'loss: {total_loss:.4f}')

                accelerator.wait_for_everyone()
                accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

                self.opt.step()
                self.opt.zero_grad()

                accelerator.wait_for_everyone()

                self.step += 1
                if accelerator.is_main_process:
                    self.ema.update()

                    if self.step != 0 and divisible_by(self.step, self.save_and_sample_every):
                        self.ema.ema_model.eval()

                        with torch.inference_mode():
                            eval_data = next(self.dl).to(device)
                            b, c, h, w = eval_data.shape
                            random_rating = torch.randn(b, h, w)
                            eval_data[:, 0, :, :] = random_rating

                            val_loss = self.model(eval_data)
                            val_sample = self.ema.ema_model.sample(eval_data)
                            print(f"Validation Loss: {val_loss.item()}")
                            np.save(
                                f"{self.results_folder}/sample-{self.step}.npy",
                                val_sample[:, 0, :, :].cpu().detach().numpy()
                            )
                            self.save(self.step)

                pbar.update(1)

        accelerator.print('training complete')

# MOVIELENS DATASET

In [9]:
MOVIE_HEADERS = ["movieId", "title", "genres"]
USER_HEADERS = ["userId", "gender", "age", "occupation", "zipCode"]
RATING_HEADERS = ['userId', 'movieId', 'rating', 'timestamp']

### rating transforms

In [10]:
def rating_transform(data):
    ratings = data[2, :]
    ratings[ratings == 1] = -5
    ratings[ratings == 2] = -2
    ratings[ratings == 3] = 0
    ratings[ratings == 4] = 2
    data[2, :] = ratings
    return data

class RatingQuantileTransform(object):
    def __init__(self):
        self.qt_transformer = QuantileTransformer(n_quantiles=5, output_distribution="normal")

    def __call__(self, data):
        ratings = data[2, :]
        data[2, :] = torch.Tensor(self.qt_transformer.fit_transform(ratings.reshape(-1, 1))).T
        return data

### self-processing dataset features

In [11]:
class RawMovieLens1M(MovieLens1M):
    def __init__(self, root, transform=None, pre_transform=None, force_reload=False):
        super(RawMovieLens1M, self).__init__(root, transform, pre_transform, force_reload)

    def _process_genres(self, df):
        l = df["genres"].str.get_dummies('|').values
        max_genres = l.sum(axis=1).max()
        idx_list = []
        for i in range(l.shape[0]):
            idxs = np.where(l[i, :] == 1)[0] + 1
            missing = max_genres - len(idxs)
            if missing > 0:
                idxs = np.array(list(idxs) + missing * [0])
            idx_list.append(idxs)
        return np.stack(idx_list)

    def process(self) -> None:
        import pandas as pd

        data = HeteroData()

        # Process movie data:
        df = pd.read_csv(
            self.raw_paths[0],
            sep='::',
            header=None,
            index_col='movieId',
            names=MOVIE_HEADERS,
            encoding='ISO-8859-1',
            engine='python',
        )
        movie_mapping = {idx: i for i, idx in enumerate(df.index)}

        genres = self._process_genres(df)
        genres = torch.from_numpy(genres).to(torch.float)

        data['movie'].x = genres

        # Process user data:
        df = pd.read_csv(
            self.raw_paths[1],
            sep='::',
            header=None,
            index_col='userId',
            names=USER_HEADERS,
            dtype='str',
            encoding='ISO-8859-1',
            engine='python',
        )
        user_mapping = {idx: i for i, idx in enumerate(df.index)}

        age = df['age'].str.get_dummies().values.argmax(axis=1)[:, None]
        age = torch.from_numpy(age).to(torch.float)

        gender = df['gender'].str.get_dummies().values[:, 0][:, None]
        gender = torch.from_numpy(gender).to(torch.float)

        occupation = df['occupation'].str.get_dummies().values.argmax(axis=1)[:, None]
        occupation = torch.from_numpy(occupation).to(torch.float)

        data['user'].x = torch.cat([age, gender, occupation], dim=-1)

        self.int_user_data = df

        # Process rating data:
        df = pd.read_csv(
            self.raw_paths[2],
            sep='::',
            header=None,
            names=RATING_HEADERS,
            encoding='ISO-8859-1',
            engine='python',
        )

        src = [user_mapping[idx] for idx in df['userId']]
        dst = [movie_mapping[idx] for idx in df['movieId']]
        edge_index = torch.tensor([src, dst])
        data['user', 'rates', 'movie'].edge_index = edge_index

        rating = torch.from_numpy(df['rating'].values).to(torch.long)
        data['user', 'rates', 'movie'].rating = rating

        time = torch.from_numpy(df['timestamp'].values)
        data['user', 'rates', 'movie'].time = time

        data['movie', 'rated_by', 'user'].edge_index = edge_index.flip([0])
        data['movie', 'rated_by', 'user'].rating = rating
        data['movie', 'rated_by', 'user'].time = time

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self.save([data], self.processed_paths[0])

### custom dataset

In [12]:
class ProcessedMovieLens(Dataset):
    PROCESSED_ML_SUBPATH = "/processed/data.pt"

    def __init__(self, root, n_subsamples=100, n_unique_per_sample=10, dataset_transform=None, transform=None,
                 download=True):
        if download:
            self.ml_1m = RawMovieLens1M(root, force_reload=True)
            self.ml_1m.process()

        self.n_unique_per_sample = n_unique_per_sample
        self.n_subsamples = n_subsamples
        self.transform = transform
        self.dataset_transform = dataset_transform
        print(root + self.PROCESSED_ML_SUBPATH)
        self.processed_data = torch.load(root + self.PROCESSED_ML_SUBPATH)
        self.processed_ratings = self._preprocess_ratings(self.processed_data)

    def _preprocess_ratings(self, data):
        edges = data[0][('user', 'rates', 'movie')]
        edge_ratings = torch.concatenate([edges["edge_index"], edges["rating"].reshape((1, -1))])
        return self.dataset_transform(edge_ratings)

    def __getitem__(self, idx):
        n_unique = self.n_unique_per_sample
        edge_ratings = self.processed_ratings
        movie_feats, user_feats = self.processed_data[0]["movie"]["x"], self.processed_data[0]["user"]["x"]

        _, edge_size = edge_ratings.shape
        indices = torch.randint(0, edge_size, (n_unique,))
        sampled_edges = torch.ones((1, 1))
        while len(sampled_edges[0, :].unique()) < n_unique or len(sampled_edges[1, :].unique()) < n_unique:
            indices = torch.randint(0, edge_size, (n_unique,))
            sampled_edges = edge_ratings[:, indices]

        xs = edge_ratings[0, indices]
        ys = edge_ratings[1, indices]

        indices_xs = torch.where(torch.isin(edge_ratings[0, :], xs))[0]
        indices_ys = torch.where(torch.isin(edge_ratings[1, :], ys))[0]
        subsample_edges = edge_ratings[:, np.intersect1d(indices_xs, indices_ys)].T

        subsample_movie_feats = movie_feats[subsample_edges[:, 1], :]
        subsample_user_feats = user_feats[subsample_edges[:, 0], :]
        subsample_user_movie_feats = torch.cat([subsample_movie_feats, subsample_user_feats], dim=1)

        broadcasted_movie_feats, broadcasted_user_feats = (
            torch.broadcast_to(
                movie_feats[ys.unique().sort().values, :].T.reshape((-1, 1, n_unique)).swapaxes(1, 2),
                (-1, n_unique, n_unique)
            ).swapaxes(1, 2),
            torch.broadcast_to(
                user_feats[xs.unique().sort().values, :].T.reshape((-1, 1, n_unique)).swapaxes(1, 2),
                (-1, n_unique, n_unique)
            )
        )

        rating_matrix = torch.Tensor(pd.DataFrame(subsample_edges).pivot(columns=[1], index=[0]).fillna(-10).to_numpy())
        item = torch.cat([
            rating_matrix.reshape((1, n_unique, n_unique)),
            broadcasted_movie_feats,
            broadcasted_user_feats
        ], dim=0)

        return item

    def __len__(self):
        return self.n_subsamples

### feature embedding head

In [13]:
class MovieLensFeatureEmb(nn.Module):
    MAX_N_GENRES = 6
    
    def __init__(self,
                 age_dim = 4,
                 gender_dim = 3,
                 occupation_dim = 8,
                 genre_dim = 16,
        ):
        super().__init__()
        
        self.age_embedding = nn.Embedding(
            num_embeddings = 6,
            embedding_dim = age_dim
        )
        self.gender_embedding = nn.Embedding(
            num_embeddings = 2,
            embedding_dim = gender_dim
        )
        self.occupation_embedding = nn.Embedding(
            num_embeddings = 21,
            embedding_dim = occupation_dim
        )
        self.genre_embedding = nn.Embedding(
            num_embeddings = 19,
            embedding_dim = genre_dim
        )
        
        self.age_dim = age_dim
        self.gender_dim = gender_dim
        self.occupation_dim = occupation_dim
        self.genre_dim = genre_dim
        
    @property
    def embed_dim(self):
        return self.age_dim + self.gender_dim + self.occupation_dim + self.genre_dim + 1
    
    def forward(self, x):
        # dims = [ft, user, movie]
        # ft = [1 rating, 6 genres, --> these are all bogus rn --> 1 age, 1 gender, 1 occupation]
        # x.shape = (b, f, n, m)
        assert x.shape[1] == 4 + self.MAX_N_GENRES

        ratings = x[:, 0:1]
        genres = x[:, 1:7].long()
        age = x[:, 7].long()
        gender = x[:, 8].long()
        occupation = x[:, 9].long()
        
        genre_emb = self.genre_embedding(genres).swapdims(1, -1).sum(dim=-1)
        age_emb = self.age_embedding(age).permute(0, 3, 1, 2)
        gender_emb = self.gender_embedding(gender).permute(0, 3, 1, 2)
        occupation_emb = self.occupation_embedding(occupation).permute(0, 3, 1, 2)

        full_embeds = torch.cat([ratings, genre_emb, age_emb, gender_emb, occupation_emb], dim=1)
        assert full_embeds.shape[1] == self.embed_dim

        return full_embeds

# DENOISING MODEL (SUBGRAPH DIT)

### constants

In [14]:
ModelPrediction =  namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])

### small helper modules

In [15]:
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        return F.normalize(x, dim=1) * self.g * (x.shape[1] ** 0.5)

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = RMSNorm(dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

### sinusoidal positional embeds

In [16]:
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim, theta = 10000):
        super().__init__()
        self.dim = dim
        self.theta = theta

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(self.theta) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

class RandomOrLearnedSinusoidalPosEmb(nn.Module):
    """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """
    """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """

    def __init__(self, dim, is_random = False):
        super().__init__()
        assert divisible_by(dim, 2)
        half_dim = dim // 2
        self.dim = dim + 1
        self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random)

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
        fouriered = torch.cat((x, fouriered), dim = -1)
        return fouriered

### sugraph DiT block

In [17]:
class SubgraphDiTBlock(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        ffn_hidden_dims = None,
        ffn_activation = nn.SiLU(),
        attn_heads = 4,
        attn_dim_head = 32,
        time_emb_dim = None,
        flash = False
	):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        
        # mlps for timestep conditioning using adaptive layer norm
        self.time_mlp_pre = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, in_dim*8)
        ) if exists(time_emb_dim) else None
        
        self.time_mlp_post = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, in_dim*2)
        ) if exists(time_emb_dim) else None
        
        # layer norm for input
        self.input_ln = RMSNorm(in_dim)
        
        # row and column attention
        self.row_attn = RowAttention(in_dim, attn_heads, attn_dim_head, flash)
        self.col_attn = RowAttention(in_dim, attn_heads, attn_dim_head, flash)
        
        # post residual layer norm
        self.concat_ln = RMSNorm(in_dim*2)
        
        # pointwise feedforward network for attention + residual
        if exists(ffn_hidden_dims):
            ffn_in_dims = [in_dim*2] + ffn_hidden_dims[:-1]
            ffn_out_dims = ffn_hidden_dims
            self.ffn = nn.Sequential(
                *[
                    nn.Sequential(
                        nn.Conv2d(dim_in, dim_out, 1),
                        ffn_activation
                    ) for dim_in, dim_out in zip(ffn_in_dims, ffn_out_dims)
                 ],
                nn.Conv2d(ffn_in_dims[-1], in_dim*2, 1)
            )
        else:
            self.ffn = nn.Conv2d(in_dim*2, in_dim*2, 1)
        
        # final downsample for ffn output + residual
        self.downsample = nn.Conv2d(in_dim*2, out_dim, 1)
        
    def forward(self, x, time_emb = None):
        input_residual = x   
        
        if exists(self.time_mlp_pre) and exists(time_emb):
            time_emb_pre = self.time_mlp_pre(time_emb)
            time_emb_pre = rearrange(time_emb_pre, 'b c -> b c 1 1')
            
            pre_attn_scale, pre_attn_shift, post_row_attn_scale, post_col_attn_scale, pre_ffn_scale1, pre_ffn_scale2, pre_ffn_shift1, pre_ffn_shift2 = time_emb_pre.chunk(8, dim = 1)
            
            pre_ffn_scale = torch.cat([pre_ffn_scale1, pre_ffn_scale2], dim = 1)
            pre_ffn_shift = torch.cat([pre_ffn_shift1, pre_ffn_shift2], dim = 1)
            
            post_ffn_scale = self.time_mlp_post(time_emb)
            post_ffn_scale = rearrange(post_ffn_scale, 'b c -> b c 1 1')
        
        x = self.input_ln(x)
        x = x*(1 + pre_attn_scale) + pre_attn_shift
        
        row_attn = self.row_attn(x)
        row_attn = row_attn*(1 + post_row_attn_scale)
        
        col_attn = self.col_attn(x.transpose(-2, -1)).transpose(-2, -1)
        col_attn = col_attn*(1 + post_col_attn_scale)
        
        row_attn = row_attn + input_residual*0.5
        col_attn = col_attn + input_residual*0.5
        concatenated = torch.cat([row_attn, col_attn], dim=1)
        
        x = self.concat_ln(concatenated)
        x = x*(1 + pre_ffn_scale) + pre_ffn_shift
        x = self.ffn(x)
        x = x*(1 + post_ffn_scale)
        
        x = x + concatenated
        out = self.downsample(x)
        return out

### subgraph DiT

In [18]:
class SubgraphDiT(nn.Module):
    def __init__(
        self,
        hidden_dims = None,
        out_dim = 1,
        feature_embedding = MovieLensFeatureEmb(),
        time_embedding = RandomOrLearnedSinusoidalPosEmb,
        ffn_hidden_dims = None,
        ffn_activation = nn.SiLU(),
        attn_heads = 4,
        attn_dim_head = 32,
        final_ffn_hidden_dims = None,
        final_ffn_activation = nn.SiLU(),
        final_attn_heads = 4,
        final_attn_dim_head = 32,
        flash = False
    ):
        super().__init__()
        self.feature_dim = feature_embedding.embed_dim
        self.feature_embedder = feature_embedding
        
        self.time_embedding = time_embedding(self.feature_dim)
        self.time_embedder = nn.Sequential(
            self.time_embedding,
            nn.Linear(self.time_embedding.dim, 4*self.feature_dim),
            nn.SiLU(),
            nn.Linear(4*self.feature_dim, 4*self.feature_dim)
        )
        self.blocks = nn.ModuleList([])
        
        mid_dim = self.feature_dim
        for i, hidden_dim in enumerate(hidden_dims):
            self.blocks.append(SubgraphDiTBlock(
                in_dim = mid_dim,
                out_dim = hidden_dim,
                ffn_hidden_dims = ffn_hidden_dims[i] if exists(ffn_hidden_dims) else None,
                ffn_activation = ffn_activation[i] if isinstance(ffn_activation, list) else ffn_activation,
                attn_heads = attn_heads[i] if isinstance(attn_heads, list) else attn_heads,
                attn_dim_head = attn_dim_head[i] if isinstance(attn_dim_head, list) else attn_dim_head,
                time_emb_dim = 4*self.feature_dim,
                flash = flash
            ))
            mid_dim = hidden_dim
            
        self.blocks.append(SubgraphDiTBlock(
            in_dim = mid_dim,
            out_dim = out_dim,
            ffn_hidden_dims = final_ffn_hidden_dims,
            ffn_activation = final_ffn_activation,
            attn_heads = final_attn_heads,
            attn_dim_head = final_attn_dim_head,
            time_emb_dim = 4*self.feature_dim,
            flash = flash
        ))
            
        self.layer_norm = RMSNorm(out_dim)
        self.to_out = nn.Conv2d(out_dim, out_dim, 1)
        
    def forward(self, x, time):
        time_emb = self.time_embedder(time)
        x = self.feature_embedder(x)
        for block in self.blocks:
            x = block(x, time_emb)
            x = self.layer_norm(x)
        return self.to_out(x)
        

### model definition

In [22]:
model = SubgraphDiT(
    hidden_dims=[16, 8, 4],
    ffn_hidden_dims=[[32, 32], [16, 16], [8, 8]],
    ffn_activation=[nn.SiLU(), nn.SiLU(), nn.SiLU()],
    attn_heads=[4, 4, 4],
    attn_dim_head=[32, 16, 8],
    final_ffn_hidden_dims=[4, 4],
    final_ffn_activation=nn.SiLU(),
    final_attn_heads=4,
    final_attn_dim_head=8,
    flash=False
)

# EXECUTION

In [23]:
diffusion = GaussianDiffusion(model, image_size=10)
trainer = Trainer(diffusion, "./movie_lens", train_num_steps=int(1e4))

dataloader_config = DataLoaderConfiguration(split_batches=True)
Processing...
Done!


./movie_lens/processed/data.pt


In [24]:
trainer.train()

  0%|          | 0/10000 [00:01<?, ?it/s]

Input Shape torch.Size([16, 10, 10, 10])
Time Shape torch.Size([16])
Time Emb Shape torch.Size([16, 128])
Feature Embed Shape torch.Size([16, 32, 10, 10])
Feature map shape:  torch.Size([16, 32, 10, 10])
Feature map shape after input layer norm:  torch.Size([16, 32, 10, 10])
Shape after pre_attn modulate:  torch.Size([16, 32, 10, 10])
Shape after row attention:  torch.Size([16, 32, 10, 10])
Shape after post row attention:  torch.Size([16, 32, 10, 10])
Shape after col attention:  torch.Size([16, 32, 10, 10])
Shape after post col attention:  torch.Size([16, 32, 10, 10])
Shape after row attn residual:  torch.Size([16, 32, 10, 10])
Shape after col attn residual:  torch.Size([16, 32, 10, 10])
Shape after concatenation:  torch.Size([16, 64, 10, 10])
Shape after concat layer norm:  torch.Size([16, 64, 10, 10])
Shape after pre_ffn modulate:  torch.Size([16, 64, 10, 10])
Shape after ffn:  torch.Size([16, 64, 10, 10])
Shape after post ffn modulate:  torch.Size([16, 64, 10, 10])
Shape after resid




IndexError: The shape of the mask [16, 10, 10] at index 1 does not match the shape of the indexed tensor [16, 1, 10, 10] at index 1