In [1]:
!mkdir -p contents_custom

In [2]:
from typing import Dict, Tuple

import random
import numpy as np
import torch
import torch.nn as nn
from accelerate import Accelerator
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image, make_grid

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import torch


def generate_order(T):
    half_point = T // 2 - 1
    backward = list(range(half_point, -1, -1))
    forward = list(range(half_point + 1, T))
    print(forward, backward)
    order = []
    for i in range(half_point + 1):
        order.append(forward[i])
        order.append(backward[i])

    if T % 2 == 1:
        order.append(forward[-1])

    return order

In [4]:
def row_schedules(order: int, T: int) -> Dict[str, torch.Tensor]:
    """Returns order and schedule for each row in the image."""

    assert order == 1 or order == 2, "Order must be either 1 or 2"

    half_point = T // 2 - 1
    backward = list(range(13, -1, -1))
    forward = list(range(13 + 1, 28))
    rows = []
    if order == 1:
        for i in range(half_point + 1):
            rows.append(forward[i])
            rows.append(backward[i])
        if T % 2 == 1:
            rows.append(forward[half_point + 1])
    if order == 2:
        for i in range(1, half_point + 2):
            rows.append(backward[-i])
            rows.append(forward[-i])
        if T % 2 == 1:
            rows.append(forward[-half_point - 2])
    else:
        # Unique random order
        rows = random.sample(range(28), 28)

    return {"rows_t": torch.tensor(rows)}

In [6]:
class CNNBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        *,
        expected_shape,
        act=nn.GELU,
        kernel_size=7,
    ):
        super().__init__()

        self.net = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.LayerNorm((out_channels, *expected_shape)),
            act(),
        )

    def forward(self, x):
        return self.net(x)

In [7]:
class CNN(nn.Module):
    def __init__(
        self,
        in_channels,
        expected_shape=(28, 28),
        n_hidden=(64, 128, 64),
        kernel_size=7,
        last_kernel_size=3,
        time_embeddings=16,
        act=nn.GELU,
    ) -> None:
        super().__init__()
        last = in_channels

        self.blocks = nn.ModuleList()
        for hidden in n_hidden:
            self.blocks.append(
                CNNBlock(
                    last,
                    hidden,
                    expected_shape=expected_shape,
                    kernel_size=kernel_size,
                    act=act,
                )
            )
            last = hidden

        # The final layer, we use a regular Conv2d to get the
        # correct scale and shape (and avoid applying the activation)
        self.blocks.append(
            nn.Conv2d(
                last,
                in_channels,
                last_kernel_size,
                padding=last_kernel_size // 2,
            )
        )

        ## This part is literally just to put the single scalar "t" into the CNN
        ## in a nice, high-dimensional way:
        self.time_embed = nn.Sequential(
            nn.Linear(time_embeddings * 2, 128),
            act(),
            nn.Linear(128, 128),
            act(),
            nn.Linear(128, 128),
            act(),
            nn.Linear(128, n_hidden[0]),
        )
        frequencies = torch.tensor(
            [0] + [2 * np.pi * 1.5**i for i in range(time_embeddings - 1)]
        )
        self.register_buffer("frequencies", frequencies)

    def time_encoding(self, t: int) -> torch.Tensor:
        phases = torch.concat(
            (
                torch.sin(t[:, None] * self.frequencies[None, :]),
                torch.cos(t[:, None] * self.frequencies[None, :]) - 1,
            ),
            dim=1,
        )

        return self.time_embed(phases)[:, :, None, None]

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # Shapes of input:
        #    x: (batch, chan, height, width)
        #    t: (batch,)

        embed = self.blocks[0](x)
        # ^ (batch, n_hidden[0], height, width)

        # Add information about time along the diffusion process
        #  (Providing this information by superimposing in latent space)
        embed += self.time_encoding(t)
        #         ^ (batch, n_hidden[0], 1, 1) - thus, broadcasting
        #           to the entire spatial domain

        for block in self.blocks[1:]:
            embed = block(embed)

        return embed

In [8]:
class Row_Averaging(nn.Module):
    def __init__(
        self,
        gt,
        row_order: int,
        n_T: int,
        criterion: nn.Module = nn.MSELoss(),
    ) -> None:
        super().__init__()

        self.gt = gt

        row_schedule = row_schedules(row_order, n_T)

        # `register_buffer` will track these tensors for device placement, but
        # not store them as model parameters. This is useful for constants.
        self.register_buffer("rows_t", row_schedule["rows_t"])
        self.rows_t  # Exists! Set by register_buffer

        self.n_T = n_T
        self.criterion = criterion

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Row averaging diffusion"""

        t = torch.randint(1, self.n_T, (x.shape[0],), device=x.device)
        z_t = x
        for idx in t:
            rows_t = self.rows_t[:idx]
            z_t[:, :, rows_t, :] = torch.mean(
                z_t[:, :, rows_t, :], dim=2, keepdim=True
            ).expand_as(z_t[:, :, rows_t, :])

        # We should predict the "error term" from this z_t. Loss is what we return.

        return self.criterion(x, self.gt(z_t, t / self.n_T))

    def degrade(self, x: torch.Tensor, t: int) -> torch.Tensor:
        """Row averaging diffusion for a set time step"""

        rows_t = self.rows_t[: int(t[0].item())]

        z_t = x

        z_t[:, :, rows_t, :] = torch.mean(
            x[:, :, rows_t, :], dim=2, keepdim=True
        ).expand_as(x[:, :, rows_t, :])

        return z_t

    def sample(self, n_sample: int, dataset, size, device) -> torch.Tensor:
        """Algorithm 2 in Bansal et al. (2022)"""

        # num_images = len(dataset)
        # idx = random.sample(range(num_images), n_sample)
        # z_t = torch.stack([dataset[i][0].clone() for i in idx])
        # z_t = z_t.to(device)

        # for i in range(n_sample):
        #    rows_t = self.rows_t[: self.n_T]
        #    z_t[:,:,rows_t,:] = torch.mean(z_t[:,:,rows_t,:]).expand_as(z_t[:,:,rows_t,:])

        # z_t = torch.empty((n_sample,1,28,28))

        tensor_values = torch.FloatTensor(n_sample, 1, 28, 28).uniform_(-0.5, -0.2)
        z_t = torch.mean(tensor_values, dim=2, keepdim=True).expand_as(tensor_values)
        z_t = z_t.to(device)

        for i in range(self.n_T, 0, -1):
            i = torch.Tensor([i])
            t = i.expand_as(torch.empty(n_sample))
            t = t.to(device)
            x_hat = self.gt(z_t, t / self.n_T)
            z_t -= self.degrade(x_hat, t) - self.degrade(x_hat, t - 1)

        return z_t

In [10]:
tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0))])
dataset = MNIST("./data", train=True, download=True, transform=tf)
dataloader = DataLoader(
    dataset, batch_size=128, shuffle=True, num_workers=4, drop_last=True
)

In [11]:
gt = CNN(in_channels=1, expected_shape=(28, 28), n_hidden=(16, 32, 32, 16), act=nn.GELU)
# For testing: (16, 32, 32, 16)
# For more capacity (for example): (64, 128, 256, 128, 64)
row_avg = Row_Averaging(gt=gt, row_order=1, n_T=28)
optim = torch.optim.Adam(row_avg.parameters(), lr=2e-5)

In [12]:
accelerator = Accelerator()

# We wrap our model, optimizer, and dataloaders with `accelerator.prepare`,
# which lets HuggingFace's Accelerate handle the device placement and gradient accumulation.
row_avg, optim, dataloader = accelerator.prepare(row_avg, optim, dataloader)

In [13]:
for x, _ in dataloader:
    break

with torch.no_grad():
    row_avg(x)

In [14]:
n_epoch = 50
losses = []
avg_losses = []

for i in range(n_epoch):
    row_avg.train()

    pbar = tqdm(dataloader)  # Wrap our loop with a visual progress bar
    for x, _ in pbar:
        optim.zero_grad()

        loss = row_avg(x)

        loss.backward()
        # ^Technically should be `accelerator.backward(loss)` but not necessary for local training

        losses.append(loss.item())
        avg_loss = np.average(losses[max(len(losses) - 100, 0) :])

        pbar.set_description(
            f"loss: {avg_loss:.3g}"
        )  # Show running average of loss in progress bar

        optim.step()

    row_avg.eval()
    with torch.no_grad():
        xh = row_avg.sample(
            16, dataset, (1, 28, 28), accelerator.device
        )  # Can get device explicitly with `accelerator.device`
        grid = make_grid(xh, nrow=4)
        avg_losses.append(avg_loss)
        # Save samples to `./contents_custom` directory
        save_image(grid, f"./contents_custom/ddpm_sample_{i:04d}.png")

        # save model
        torch.save(row_avg.state_dict(), f"./ddpm_mnist.pth")

loss: 0.0673: 100%|██████████| 468/468 [00:40<00:00, 11.42it/s]

torch.Size([16, 1, 28, 28])





IndexError: index 1 is out of bounds for dimension 1 with size 1