<a href="https://colab.research.google.com/github/sean-halpin/diffusion_models/blob/main/min_diffusion_dct_clip.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!rm -rf /content/images/*

In [3]:
!mkdir -p /content/images/

In [4]:
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# from torchvision.datasets import MNIST
from torchvision import transforms
from torchvision.utils import make_grid, save_image
from tqdm import tqdm

## CIFAR-10

### UNet

In [7]:
"""
Simple Unet Structure.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv3(nn.Module):
    """Three 3x3 conv -> GroupNorm(8) -> ReLU stages with an optional residual.

    ``main`` projects the input to ``out_channels``; ``conv`` applies two more
    stages at that width. Every conv uses stride 1 and padding 1, so the
    spatial size is unchanged.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced (GroupNorm with 8 groups needs this
            to be divisible by 8).
        is_res: When True the result is ``(h + conv(h)) / 1.414`` with
            ``h = main(x)``; otherwise it is ``conv(h)``.
    """

    def __init__(
        self, in_channels: int, out_channels: int, is_res: bool = False
    ) -> None:
        super().__init__()

        def stage(c_in: int) -> list:
            # One conv -> norm -> activation unit producing out_channels.
            return [
                nn.Conv2d(c_in, out_channels, 3, 1, 1),
                nn.GroupNorm(8, out_channels),
                nn.ReLU(),
            ]

        self.main = nn.Sequential(*stage(in_channels))
        self.conv = nn.Sequential(*stage(out_channels), *stage(out_channels))
        self.is_res = is_res

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.main(x)
        if not self.is_res:
            return self.conv(h)
        # 1.414 ~= sqrt(2); presumably rescales the residual sum to keep the
        # activation magnitude roughly constant.
        return (h + self.conv(h)) / 1.414


class UnetDown(nn.Module):
    """Encoder stage: a ``Conv3`` block followed by 2x2 max-pooling.

    For a 4-D input ``(B, in_channels, H, W)`` the output is
    ``(B, out_channels, H // 2, W // 2)``.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super(UnetDown, self).__init__()
        self.model = nn.Sequential(Conv3(in_channels, out_channels), nn.MaxPool2d(2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


class UnetUp(nn.Module):
    """Decoder stage: concatenate with a skip tensor, upsample 2x, refine.

    ``x`` and ``skip`` are concatenated along dim 1, so ``in_channels`` must
    equal their combined channel count. A stride-2 transposed convolution
    doubles H and W, then two ``Conv3`` blocks refine at ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super(UnetUp, self).__init__()
        upsample = nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
        refine = [Conv3(out_channels, out_channels) for _ in range(2)]
        self.model = nn.Sequential(upsample, *refine)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        merged = torch.cat((x, skip), 1)
        return self.model(merged)


class TimeSiren(nn.Module):
    """Embed one scalar per sample as ``lin2(sin(lin1(t)))``.

    The input is flattened to ``(N, 1)`` regardless of its shape, so the
    output is ``(N, emb_dim)`` where ``N`` is the number of input elements.
    """

    def __init__(self, emb_dim: int) -> None:
        super(TimeSiren, self).__init__()
        self.lin1 = nn.Linear(1, emb_dim, bias=False)  # sinusoid frequencies
        self.lin2 = nn.Linear(emb_dim, emb_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        column = x.view(-1, 1)
        return self.lin2(torch.sin(self.lin1(column)))


class NaiveUnet(nn.Module):
    """Small UNet that maps an image batch and per-sample timesteps to an output.

    For a 32x32 input the encoder goes 32 -> 16 -> 8 -> 4 (three ``UnetDown``
    stages), ``to_vec`` average-pools the 4x4 map to 1x1, and the decoder goes
    1 -> 4 -> 8 -> 16 -> 32 with skip connections. The ``TimeSiren`` embedding
    of ``t`` is added at the bottleneck and after ``up1``; the final conv sees
    the last decoder features concatenated with the stem features.

    NOTE(review): pooling/upsampling round-trip exactly only when H and W are
    presumably multiples of 32 — confirm before using other image sizes.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the output.
        n_feat: Base feature width (GroupNorm(8, .) layers need it divisible by 8).
    """

    def __init__(self, in_channels: int, out_channels: int, n_feat: int = 256) -> None:
        super(NaiveUnet, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_feat = n_feat
        wide = 2 * n_feat

        # Submodules are created in this exact order; attribute names are the
        # state_dict keys used by saved checkpoints.
        self.init_conv = nn.Sequential(Conv3(in_channels, n_feat, is_res=True))

        self.down1 = UnetDown(n_feat, n_feat)
        self.down2 = UnetDown(n_feat, wide)
        self.down3 = UnetDown(wide, wide)

        self.to_vec = nn.Sequential(nn.AvgPool2d(4), nn.ReLU())
        self.timeembed = TimeSiren(wide)

        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(wide, wide, 4, 4),
            nn.GroupNorm(8, wide),
            nn.ReLU(),
        )
        self.up1 = UnetUp(2 * wide, wide)
        self.up2 = UnetUp(2 * wide, n_feat)
        self.up3 = UnetUp(wide, n_feat)
        self.out = nn.Conv2d(wide, self.out_channels, 3, 1, 1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        stem = self.init_conv(x)

        skip1 = self.down1(stem)
        skip2 = self.down2(skip1)
        skip3 = self.down3(skip2)

        pooled = self.to_vec(skip3)
        # One embedding vector per sample, broadcast over the spatial dims.
        t_emb = self.timeembed(t).view(-1, self.n_feat * 2, 1, 1)

        bottleneck = self.up0(pooled + t_emb)

        h = self.up1(bottleneck, skip3) + t_emb
        h = self.up2(h, skip2)
        h = self.up3(h, skip1)

        return self.out(torch.cat((h, stem), 1))

### Denoising Diffusion Probabilistic Model

In [8]:
from typing import Dict, Tuple


import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


class DDPM(nn.Module):
    """DDPM training/sampling wrapper around a noise-prediction network.

    Implements the training objective (Algorithm 1) and ancestral sampling
    (Algorithm 2) of Ho et al. (2020) over the linear beta schedule produced
    by ``ddpm_schedules``.

    Args:
        eps_model: Network called as ``eps_model(x_t, t)``, where ``t`` holds
            one value per sample equal to the step index divided by ``n_T``;
            its output is compared against noise shaped like ``x_t``.
        betas: ``(beta1, beta2)`` endpoints of the linear beta schedule.
        n_T: Number of diffusion steps.
        criterion: Loss between the true and predicted noise. When None, a
            fresh ``nn.MSELoss()`` is created for this instance.
    """

    def __init__(
        self,
        eps_model: nn.Module,
        betas: Tuple[float, float],
        n_T: int,
        criterion: Optional[nn.Module] = None,
    ) -> None:
        super(DDPM, self).__init__()
        self.eps_model = eps_model

        # Buffers follow the module across .to(device) and are saved in state_dict.
        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)

        self.n_T = n_T
        # Default built here, not in the signature: a default argument is
        # evaluated once, so every instance would otherwise share one module.
        self.criterion = nn.MSELoss() if criterion is None else criterion

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the noise-prediction loss for a batch of clean samples ``x``.

        Draws ``t ~ Uniform{1, ..., n_T}`` and ``eps ~ N(0, I)`` per sample,
        forms ``x_t = sqrt(alphabar_t) * x + sqrt(1 - alphabar_t) * eps`` and
        scores ``eps_model(x_t, t / n_T)`` against ``eps`` (Algorithm 1).
        Works for any input rank; dim 0 is the batch dimension.
        """
        _ts = torch.randint(1, self.n_T + 1, (x.shape[0],)).to(x.device)
        eps = torch.randn_like(x)

        # Per-sample coefficients shaped (B, 1, ..., 1) to broadcast over the
        # remaining dims of x, whatever its rank.
        coef_shape = (-1,) + (1,) * (x.dim() - 1)
        x_t = (
            self.sqrtab[_ts].view(coef_shape) * x
            + self.sqrtmab[_ts].view(coef_shape) * eps
        )
        return self.criterion(eps, self.eps_model(x_t, _ts / self.n_T))

    @torch.no_grad()
    def sample(self, n_sample: int, size, device) -> torch.Tensor:
        """Generate ``n_sample`` samples of shape ``size`` (Algorithm 2).

        Starts from ``x_T ~ N(0, I)`` and for ``i = n_T, ..., 1`` applies
        ``x <- (x - eps * (1 - alpha_i) / sqrt(1 - alphabar_i)) / sqrt(alpha_i)
        + sqrt(beta_i) * z`` with ``z ~ N(0, I)`` (``z = 0`` on the final step).
        Runs without autograd so the ``n_T``-step loop never builds a graph.
        """
        x_i = torch.randn(n_sample, *size).to(device)  # x_T ~ N(0, I)

        for i in range(self.n_T, 0, -1):
            z = torch.randn(n_sample, *size).to(device) if i > 1 else 0
            t = torch.tensor(i / self.n_T).to(device).repeat(n_sample, 1)
            eps = self.eps_model(x_i, t)
            x_i = (
                self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i])
                + self.sqrt_beta_t[i] * z
            )

        return x_i


def ddpm_schedules(beta1: float, beta2: float, T: int) -> Dict[str, torch.Tensor]:
    """Pre-compute the linear-beta DDPM schedule for step indices ``0..T``.

    ``beta_t`` rises linearly from ``beta1`` (index 0) to ``beta2`` (index T).
    Every returned tensor has shape ``(T + 1,)`` and is indexed by step.

    Raises:
        AssertionError: unless ``0 < beta1 < beta2 < 1`` and ``T >= 1``.
    """
    # The lower bound matters: beta1 < 0 makes sqrt(beta_t) NaN, and
    # beta1 == 0 makes 1 - alphabar_0 zero, giving 0/0 in mab_over_sqrtmab.
    assert 0.0 < beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)"
    assert T >= 1, "T must be a positive number of steps"

    beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1
    sqrt_beta_t = torch.sqrt(beta_t)
    alpha_t = 1 - beta_t
    log_alpha_t = torch.log(alpha_t)
    # Cumulative product computed as exp(cumsum(log)).
    alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()

    sqrtab = torch.sqrt(alphabar_t)
    oneover_sqrta = 1 / torch.sqrt(alpha_t)

    sqrtmab = torch.sqrt(1 - alphabar_t)
    mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab

    return {
        "alpha_t": alpha_t,  # \alpha_t
        "oneover_sqrta": oneover_sqrta,  # 1/\sqrt{\alpha_t}
        "sqrt_beta_t": sqrt_beta_t,  # \sqrt{\beta_t}
        "alphabar_t": alphabar_t,  # \bar{\alpha_t}
        "sqrtab": sqrtab,  # \sqrt{\bar{\alpha_t}}
        "sqrtmab": sqrtmab,  # \sqrt{1-\bar{\alpha_t}}
        "mab_over_sqrtmab": mab_over_sqrtmab_inv,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
    }

### Train DDPM on CIFAR

In [32]:
import matplotlib.pyplot as plt
from typing import Dict, Optional, Tuple
from sympy import Ci
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from torchvision.datasets import CIFAR10
from torchvision import transforms
from torchvision.utils import save_image, make_grid

from pathlib import Path


def create_model(
    load_pth=False,
    device: str = "cuda:0",
    checkpoint_path: str = "ddpm_cifar.pth",
):
    """Build the CIFAR-10 DDPM: NaiveUnet(3, 3, n_feat=128), betas 1e-4..0.02, 2000 steps.

    Args:
        load_pth: When True and ``checkpoint_path`` is an existing file, load
            its state dict into the new model.
        device: Device the model is moved to before it is returned.
        checkpoint_path: Checkpoint file consulted when ``load_pth`` is True.

    Returns:
        The ``DDPM`` module, placed on ``device``.
    """
    ddpm = DDPM(eps_model=NaiveUnet(3, 3, n_feat=128), betas=(1e-4, 0.02), n_T=2000)
    if load_pth:
        ckpt = Path(checkpoint_path)
        if ckpt.is_file():
            print("Loading Model from File")
            # map_location lets a checkpoint saved on one device (e.g. a GPU)
            # be restored on whatever device was requested, including CPU.
            ddpm.load_state_dict(torch.load(ckpt, map_location=device))
        else:
            print(f"No checkpoint found at {ckpt}; using freshly initialized weights")
    ddpm.to(device)
    return ddpm

def train_cifar10(
    ddpm: Optional[nn.Module] = None,
    n_epoch: int = 1,
    device: str = "cuda:0",
    sample_dir: str = "/content/images",
    checkpoint_path: str = "./ddpm_cifar.pth",
) -> None:
    """Train ``ddpm`` on CIFAR-10, saving a sample grid and a checkpoint each epoch.

    Training images are scaled to [-1, 1]. After every epoch, 8 generated
    samples (top two rows) and the first 8 images of the last training batch
    (bottom two rows) are saved as a 4-column grid to
    ``{sample_dir}/ddpm_sample_cifar{epoch}.png``, and the model's state dict
    is written to ``checkpoint_path``.

    Args:
        ddpm: Model to train; when None a new one is built with
            ``create_model`` on ``device``.
        n_epoch: Number of passes over the training set.
        device: Device that batches (and a newly created model) are placed on.
        sample_dir: Directory for the per-epoch sample grids; created if missing.
        checkpoint_path: File the state dict is saved to after each epoch.
    """
    if ddpm is None:
        print("Creating Model")
        # Forward the device so a new model lands where the batches will be.
        ddpm = create_model(device=device)

    tf = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    dataset = CIFAR10("./data", train=True, download=True, transform=tf)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=1)
    optim = torch.optim.Adam(ddpm.parameters(), lr=1e-5)

    Path(sample_dir).mkdir(parents=True, exist_ok=True)

    for i in range(n_epoch):
        print(f"Epoch {i} : ")
        ddpm.train()

        pbar = tqdm(dataloader)
        loss_ema = None
        for x, _ in pbar:
            optim.zero_grad()
            x = x.to(device)
            loss = ddpm(x)
            loss.backward()
            # Exponential moving average (decay 0.9) for a smoother progress readout.
            loss_ema = (
                loss.item() if loss_ema is None else 0.9 * loss_ema + 0.1 * loss.item()
            )
            pbar.set_description(f"loss: {loss_ema:.4f}")
            optim.step()

        ddpm.eval()
        with torch.no_grad():
            xh = ddpm.sample(8, (3, 32, 32), device)
            xset = torch.cat([xh, x[:8]], dim=0)
            grid = make_grid(xset, normalize=True, value_range=(-1, 1), nrow=4)
            save_image(grid, f"{sample_dir}/ddpm_sample_cifar{i}.png")
            torch.save(ddpm.state_dict(), checkpoint_path)

### Begin Training

In [33]:
import torch
import gc

# Collect unreferenced Python objects first so any CUDA tensors they held are
# released, then hand PyTorch's cached-but-unused GPU blocks back to the driver.
gc.collect()
# synchronize() raises on hosts without CUDA, so only touch the GPU when present.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
# For a detailed breakdown: print(torch.cuda.memory_summary(abbreviated=False))

In [34]:
ddpm = create_model(device="cuda:0", load_pth=True)  # resume from ddpm_cifar.pth when it exists

In [35]:
train_cifar10(ddpm=ddpm, n_epoch=150)  # continue training the model built above for 150 epochs

Files already downloaded and verified
Epoch 0 : 


loss: 0.1391: 100%|██████████| 391/391 [01:18<00:00,  4.98it/s]


Epoch 1 : 


loss: 0.0895: 100%|██████████| 391/391 [01:18<00:00,  4.98it/s]


Epoch 2 : 


loss: 0.0745: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 3 : 


loss: 0.0636: 100%|██████████| 391/391 [01:18<00:00,  4.98it/s]


Epoch 4 : 


loss: 0.0568: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 5 : 


loss: 0.0523: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 6 : 


loss: 0.0523: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 7 : 


loss: 0.0481: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 8 : 


loss: 0.0437: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 9 : 


loss: 0.0423: 100%|██████████| 391/391 [01:18<00:00,  4.98it/s]


Epoch 10 : 


loss: 0.0422: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 11 : 


loss: 0.0394: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 12 : 


loss: 0.0421: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 13 : 


loss: 0.0373: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 14 : 


loss: 0.0358: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 15 : 


loss: 0.0354: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 16 : 


loss: 0.0326: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 17 : 


loss: 0.0368: 100%|██████████| 391/391 [01:17<00:00,  5.01it/s]


Epoch 18 : 


loss: 0.0341: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 19 : 


loss: 0.0339: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 20 : 


loss: 0.0342: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 21 : 


loss: 0.0392: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 22 : 


loss: 0.0296: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 23 : 


loss: 0.0301: 100%|██████████| 391/391 [01:17<00:00,  5.01it/s]


Epoch 24 : 


loss: 0.0302: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 25 : 


loss: 0.0342: 100%|██████████| 391/391 [01:17<00:00,  5.01it/s]


Epoch 26 : 


loss: 0.0308: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 27 : 


loss: 0.0322: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 28 : 


loss: 0.0298: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 29 : 


loss: 0.0311: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 30 : 


loss: 0.0304: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 31 : 


loss: 0.0327: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 32 : 


loss: 0.0304: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 33 : 


loss: 0.0309: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 34 : 


loss: 0.0306: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 35 : 


loss: 0.0306: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 36 : 


loss: 0.0275: 100%|██████████| 391/391 [01:18<00:00,  4.98it/s]


Epoch 37 : 


loss: 0.0325: 100%|██████████| 391/391 [01:18<00:00,  4.98it/s]


Epoch 38 : 


loss: 0.0289: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 39 : 


loss: 0.0283: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 40 : 


loss: 0.0290: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 41 : 


loss: 0.0315: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 42 : 


loss: 0.0285: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 43 : 


loss: 0.0275: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 44 : 


loss: 0.0283: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 45 : 


loss: 0.0282: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 46 : 


loss: 0.0278: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 47 : 


loss: 0.0273: 100%|██████████| 391/391 [01:17<00:00,  5.01it/s]


Epoch 48 : 


loss: 0.0282: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 49 : 


loss: 0.0299: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 50 : 


loss: 0.0303: 100%|██████████| 391/391 [01:17<00:00,  5.03it/s]


Epoch 51 : 


loss: 0.0281: 100%|██████████| 391/391 [01:17<00:00,  5.02it/s]


Epoch 52 : 


loss: 0.0265: 100%|██████████| 391/391 [01:17<00:00,  5.02it/s]


Epoch 53 : 


loss: 0.0304: 100%|██████████| 391/391 [01:17<00:00,  5.02it/s]


Epoch 54 : 


loss: 0.0300: 100%|██████████| 391/391 [01:17<00:00,  5.02it/s]


Epoch 55 : 


loss: 0.0273: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 56 : 


loss: 0.0276: 100%|██████████| 391/391 [01:17<00:00,  5.02it/s]


Epoch 57 : 


loss: 0.0288: 100%|██████████| 391/391 [01:17<00:00,  5.02it/s]


Epoch 58 : 


loss: 0.0274: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 59 : 


loss: 0.0260: 100%|██████████| 391/391 [01:17<00:00,  5.02it/s]


Epoch 60 : 


loss: 0.0284: 100%|██████████| 391/391 [01:17<00:00,  5.02it/s]


Epoch 61 : 


loss: 0.0269: 100%|██████████| 391/391 [01:17<00:00,  5.01it/s]


Epoch 62 : 


loss: 0.0278: 100%|██████████| 391/391 [01:17<00:00,  5.02it/s]


Epoch 63 : 


loss: 0.0254: 100%|██████████| 391/391 [01:17<00:00,  5.02it/s]


Epoch 64 : 


loss: 0.0284: 100%|██████████| 391/391 [01:17<00:00,  5.01it/s]


Epoch 65 : 


loss: 0.0272: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 66 : 


loss: 0.0274: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 67 : 


loss: 0.0251: 100%|██████████| 391/391 [01:17<00:00,  5.02it/s]


Epoch 68 : 


loss: 0.0299: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 69 : 


loss: 0.0296: 100%|██████████| 391/391 [01:17<00:00,  5.01it/s]


Epoch 70 : 


loss: 0.0283: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 71 : 


loss: 0.0275: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 72 : 


loss: 0.0290: 100%|██████████| 391/391 [01:17<00:00,  5.02it/s]


Epoch 73 : 


loss: 0.0266: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 74 : 


loss: 0.0268: 100%|██████████| 391/391 [01:17<00:00,  5.01it/s]


Epoch 75 : 


loss: 0.0295: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 76 : 


loss: 0.0256: 100%|██████████| 391/391 [01:17<00:00,  5.02it/s]


Epoch 77 : 


loss: 0.0261: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 78 : 


loss: 0.0280: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 79 : 


loss: 0.0248: 100%|██████████| 391/391 [01:17<00:00,  5.02it/s]


Epoch 80 : 


loss: 0.0258: 100%|██████████| 391/391 [01:18<00:00,  4.98it/s]


Epoch 81 : 


loss: 0.0275: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 82 : 


loss: 0.0270: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 83 : 


loss: 0.0246: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 84 : 


loss: 0.0271: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 85 : 


loss: 0.0262: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 86 : 


loss: 0.0250: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 87 : 


loss: 0.0265: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 88 : 


loss: 0.0248: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 89 : 


loss: 0.0253: 100%|██████████| 391/391 [01:18<00:00,  4.98it/s]


Epoch 90 : 


loss: 0.0253: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 91 : 


loss: 0.0256: 100%|██████████| 391/391 [01:18<00:00,  4.97it/s]


Epoch 92 : 


loss: 0.0278: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 93 : 


loss: 0.0268: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 94 : 


loss: 0.0273: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 95 : 


loss: 0.0258: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 96 : 


loss: 0.0246: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 97 : 


loss: 0.0264: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 98 : 


loss: 0.0246: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 99 : 


loss: 0.0256: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 100 : 


loss: 0.0277: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 101 : 


loss: 0.0253: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 102 : 


loss: 0.0262: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 103 : 


loss: 0.0260: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 104 : 


loss: 0.0251: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 105 : 


loss: 0.0255: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 106 : 


loss: 0.0263: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 107 : 


loss: 0.0261: 100%|██████████| 391/391 [01:18<00:00,  4.98it/s]


Epoch 108 : 


loss: 0.0268: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 109 : 


loss: 0.0255: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 110 : 


loss: 0.0261: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 111 : 


loss: 0.0258: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 112 : 


loss: 0.0256: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 113 : 


loss: 0.0231: 100%|██████████| 391/391 [01:18<00:00,  4.98it/s]


Epoch 114 : 


loss: 0.0248: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 115 : 


loss: 0.0276: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 116 : 


loss: 0.0262: 100%|██████████| 391/391 [01:18<00:00,  4.98it/s]


Epoch 117 : 


loss: 0.0259: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 118 : 


loss: 0.0241: 100%|██████████| 391/391 [01:18<00:00,  4.98it/s]


Epoch 119 : 


loss: 0.0264: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 120 : 


loss: 0.0264: 100%|██████████| 391/391 [01:18<00:00,  4.98it/s]


Epoch 121 : 


loss: 0.0252: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 122 : 


loss: 0.0241: 100%|██████████| 391/391 [01:18<00:00,  4.96it/s]


Epoch 123 : 


loss: 0.0253: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 124 : 


loss: 0.0251: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 125 : 


loss: 0.0258: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 126 : 


loss: 0.0262: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 127 : 


loss: 0.0256: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 128 : 


loss: 0.0243: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 129 : 


loss: 0.0258: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 130 : 


loss: 0.0264: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 131 : 


loss: 0.0263: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch 132 : 


loss: 0.0237: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 133 : 


loss: 0.0252: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 134 : 


loss: 0.0237: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 135 : 


loss: 0.0249: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 136 : 


loss: 0.0246: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 137 : 


loss: 0.0234: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 138 : 


loss: 0.0240: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 139 : 


loss: 0.0257: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 140 : 


loss: 0.0255: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 141 : 


loss: 0.0252: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch 142 : 


loss: 0.0226: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch 143 : 


loss: 0.0238: 100%|██████████| 391/391 [01:18<00:00,  4.97it/s]


Epoch 144 : 


loss: 0.0243: 100%|██████████| 391/391 [01:18<00:00,  4.96it/s]


Epoch 145 : 


loss: 0.0220: 100%|██████████| 391/391 [01:18<00:00,  4.97it/s]


Epoch 146 : 


loss: 0.0255: 100%|██████████| 391/391 [01:19<00:00,  4.94it/s]


Epoch 147 : 


loss: 0.0254: 100%|██████████| 391/391 [01:18<00:00,  4.95it/s]


Epoch 148 : 


loss: 0.0263: 100%|██████████| 391/391 [01:18<00:00,  4.97it/s]


Epoch 149 : 


loss: 0.0235: 100%|██████████| 391/391 [01:18<00:00,  4.97it/s]


### Testing: sampling from the model, a CLIP scoring demo, and GPU cleanup


In [None]:
!pip install transformers

In [None]:
from PIL import Image
import requests

from transformers import CLIPProcessor, CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# A bounded timeout plus an explicit status check make network problems fail
# here, instead of hanging or surfacing later as an image-decoding error.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True)

# Inference only: no autograd graph is needed to score image-text similarity.
with torch.no_grad():
    outputs = model(**inputs)
logits_per_image = outputs.logits_per_image  # image-text similarity score, one column per prompt
probs = logits_per_image.softmax(dim=1)  # softmax over the prompts gives label probabilities

In [46]:
probs  # display the CLIP softmax probabilities for ["a photo of a cat", "a photo of a dog"]

tensor([[0.9949, 0.0051]], grad_fn=<SoftmaxBackward0>)

In [None]:
# Draw 8 samples from the trained model; their values are in the training range [-1, 1].
with torch.no_grad():
    x = ddpm.sample(8, (3, 32, 32), "cuda:0")
    grid = make_grid(x.cpu(), normalize=True, value_range=(-1, 1), nrow=4)
    # Overview of all samples (make_grid with normalize=True already maps to [0, 1]).
    plt.imshow(grid.permute(1, 2, 0))
    plt.show()
    for img in x.cpu():
        # imshow expects (H, W, C) floats in [0, 1]. permute turns CHW into HWC;
        # a plain .T would also swap H and W and transpose the picture.
        plt.imshow(((img.permute(1, 2, 0) + 1) / 2).clamp(0, 1))
        plt.show()

In [None]:
del ddpm  # drop this notebook's reference to the model; its GPU memory is reclaimable only once no other references remain

In [None]:
!pip install GPUtil

from GPUtil import showUtilization as gpu_usage
# gpu_usage()   

In [None]:
# Reset the CUDA context on device 0 via numba: select it, close the current
# thread's context(s), then select the device again.
# NOTE(review): the closed context is the one PyTorch also uses in this process;
# PyTorch CUDA calls in this kernel may fail afterwards, so plan to restart the
# runtime before training again — confirm with your torch/numba versions.
from numba import cuda
cuda.select_device(0)
cuda.close()
cuda.select_device(0)

In [None]:
!pip install GPUtil

import torch
from GPUtil import showUtilization as gpu_usage
from numba import cuda

def free_gpu_cache():
    """Report GPU usage, release cached memory, reset device 0, and report again.

    Empties PyTorch's caching allocator, then uses numba to select device 0,
    close its CUDA context, and re-select it. Usage is printed via GPUtil
    before and after.

    NOTE(review): closing the context through numba also affects the context
    PyTorch relies on; further PyTorch CUDA work in this process may fail —
    expect to restart the runtime afterwards, and confirm before relying on it.
    """

    def report(title):
        print(title)
        gpu_usage()

    report("Initial GPU Usage")

    torch.cuda.empty_cache()

    # Order matters: select, close, then re-select so device 0 is current again.
    cuda.select_device(0)
    cuda.close()
    cuda.select_device(0)

    report("GPU Usage after emptying the cache")


free_gpu_cache()