In [4]:
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
from tqdm import tqdm

# Custom libraries
import utils
from diffusion import DiffusionModel

In [13]:
#! Settings
# Compute device: Apple Metal (MPS) backend when available, CPU otherwise.
# On an NVIDIA machine, "cuda:0" can be substituted here.
DEVICE = (
    torch.device("mps")
    if torch.backends.mps.is_available()
    else torch.device("cpu")
)

# --- Optimisation ---
NO_EPOCHS = 10_000   # upper bound on training iterations
BATCH_SIZE = 5       # number of images in the training batch
LR = 1e-3            # Adam learning rate
MARGIN = 4e-3        # training stops once the loss falls below this value

# --- Logging ---
PRINT_FREQUENCY = 100  # print the loss every this many iterations
VERBOSE = True         # NOTE(review): not read by any visible cell - confirm it is still needed

# --- Data ---
IMAGE_SHAPE = (64, 64)  # spatial size (H, W) of the images fed to the model

# --- Output ---
SAVE_MODEL_DIR = "/Volumes/ML/Simple-Diffusion-Model-From-Scratch/"

In [7]:
# Sanity check: prints True when PyTorch can use the Metal (MPS) backend on this machine.
mps_available = torch.backends.mps.is_available()
print(mps_available)


True


In [10]:
#! Prepare data
# One entry per training image. The position in this list is the image's label,
# so adding or removing a path keeps images and labels in sync automatically.
IMAGE_PATHS = [
    "data/players/beckham.jpg",
    "data/players/messi.jpg",
    "data/players/ronaldo.png",
    "data/players/haaland.jpg",
    "data/players/silva.jpg",
]

image_tensors = []
for path in IMAGE_PATHS:
    # The context manager closes the underlying file as soon as the image has
    # been converted, instead of leaving the handle to the garbage collector.
    with Image.open(path) as img:
        image_tensors.append(utils.transform(img.convert('RGB')))

# Stack the per-image tensors along a new leading (batch) dimension.
batch = torch.stack(image_tensors).to(DEVICE)

# Labels 0..N-1 as a float column vector, one row per image.
label = torch.arange(len(IMAGE_PATHS)).reshape(-1, 1).float().to(DEVICE)

print("Input Batch Images:", batch.shape, batch.dtype, batch.device)
print("Input Batch Labels:", label.shape, label.dtype, label.device)

Input Batch Images: torch.Size([5, 3, 64, 64]) torch.float32 mps:0
Input Batch Labels: torch.Size([5, 1]) torch.float32 mps:0


In [11]:
class UNet(nn.Module):
    """U-Net style noise predictor built from ``utils.Block`` stages.

    The network applies an input convolution, a chain of down-scaling blocks,
    a mirrored chain of up-scaling blocks that each receive the matching
    down-path output concatenated on the channel axis, and a final 1x1
    convolution back to the image channel count.

    Args:
        channels: Feature widths of the successive stages. ``channels[0]`` is
            the width after the input convolution; each following entry adds
            one down block and one up block.
        time_emb_dim: Time-embedding size handed to every ``utils.Block``.
            Defaults to 100, the value the original code passed (the
            previously declared ``time_emb_dim = 32`` was never used).
        img_channels: Number of channels of the input and output images.

    Raises:
        ValueError: If ``channels`` is empty.

    Note:
        Blocks are created with label conditioning disabled and ``forward``
        does not forward ``labels`` to them, so the model is unconditional
        even when callers pass ``labels=...``.
    """

    def __init__(self, channels=(64, 128, 256, 512, 1024), time_emb_dim=100, img_channels=3):
        super().__init__()
        channels = tuple(channels)
        if not channels:
            raise ValueError("channels must contain at least one entry")

        # Widths for the up path: the down-path widths in reverse order.
        up_channels = channels[::-1]
        # Fourth positional argument of utils.Block; the original always passed False.
        use_labels = False

        self.conv1 = nn.Conv2d(img_channels, channels[0], 3, padding=1)
        self.relu = nn.ReLU()

        # DownScale: one block per consecutive pair of widths.
        self.downscales = nn.ModuleList(
            [
                utils.Block(c_in, c_out, time_emb_dim, use_labels)
                for c_in, c_out in zip(channels, channels[1:])
            ]
        )

        # UpScale: mirrored widths.
        # NOTE(review): forward() concatenates the skip tensor before calling each
        # up block, so utils.Block(downsample=False) presumably expects 2 * c_in
        # input channels - confirm against utils.Block.
        self.upscales = nn.ModuleList(
            [
                utils.Block(c_in, c_out, time_emb_dim, use_labels, downsample=False)
                for c_in, c_out in zip(up_channels, up_channels[1:])
            ]
        )

        # Output: 1x1 convolution from the first stage width back to image channels.
        self.output = nn.Conv2d(channels[0], img_channels, 1)

    def forward(self, x, t, **kwargs):
        """Run the network on ``x`` at timestep ``t``.

        Args:
            x: Input image batch; its channel count must match ``img_channels``.
            t: Timestep tensor passed unchanged to every block.
            **kwargs: Accepted for call-site compatibility (e.g. ``labels``);
                currently ignored.

        Returns:
            The output of the final 1x1 convolution.
        """
        x = self.relu(self.conv1(x))

        # Down path: keep every stage output for the skip connections.
        residual_inputs = []
        for downscale in self.downscales:
            x = downscale(x, t)
            residual_inputs.append(x)

        # Up path: consume the stored outputs in reverse (last stored, first used).
        for upscale in self.upscales:
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)
            x = upscale(x, t)

        return self.output(x)

# Throwaway instance (not moved to DEVICE) used only to inspect the architecture.
model = UNet()

# Total number of scalar entries across all parameter tensors.
total_params = 0
for param in model.parameters():
    total_params += param.numel()
print("Num params: ", total_params)

# Last expression of the cell: shows the module tree.
model

Num params:  62633667


UNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu): ReLU()
  (downscales): ModuleList(
    (0): Block(
      (time_embedding): SinusoidalPositionEmbeddings()
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (final): Conv2d(128, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (bnorm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bnorm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (time_mlp): Linear(in_features=100, out_features=128, bias=True)
      (relu): ReLU()
    )
    (1): Block(
      (time_embedding): SinusoidalPositionEmbeddings()
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (final): Conv2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (bnorm1): BatchNorm

In [12]:
#! Diffusion process (provides the forward-noising and backward steps used below)
diffusion_model = DiffusionModel()

#! Noise-prediction network, placed on the configured device
unet = UNet()
unet = unet.to(DEVICE)

#! Adam optimizer over every U-Net parameter, with the learning rate from the settings cell
optimizer = torch.optim.Adam(params=unet.parameters(), lr=LR)

In [14]:
#! Training
MODEL_DIR = f"{SAVE_MODEL_DIR}/trained-model.pt"

# Take sizes from the objects actually in use rather than from separate constants,
# so the timestep range and the number of sampled timesteps cannot drift out of sync.
num_images = batch.shape[0]
num_timesteps = diffusion_model.timesteps

# The blocks contain BatchNorm layers; make sure they are in training mode even when
# this cell is re-run after the sampling cell has switched the model to eval mode.
unet.train()

for epoch in tqdm(range(NO_EPOCHS)):
    # 1. One random timestep in [0, num_timesteps) per image in the batch
    t = torch.randint(0, num_timesteps, (num_images,)).to(DEVICE)

    # 2. Forward diffuse each image with its own timestep
    img_plus_noise, noise = diffusion_model.forward(batch, t, DEVICE)

    # 3. Predict the noise in this batch using the U-Net
    predicted_noise = unet(img_plus_noise, t, labels=label)

    # 4. MSE between predicted and actual noise (prediction first, target second)
    mse_loss = nn.functional.mse_loss(predicted_noise, noise)

    # 5-7. Reset gradients, back-propagate, update the weights
    optimizer.zero_grad()
    mse_loss.backward()
    optimizer.step()

    # Pull the scalar off the device once and reuse it for logging and the stop test.
    loss_value = mse_loss.item()

    # 8. Print the loss at the interval given by PRINT_FREQUENCY
    if epoch % PRINT_FREQUENCY == 0:
        print(f"Current epoch: {epoch} ------ loss: {loss_value} ----------")

    # 9. Stop as soon as the loss drops below MARGIN; the model is saved right after the loop
    if loss_value < MARGIN:
        print(f"loss less than margin---------- {loss_value} ------ {MARGIN}")
        break

# 10. Save the weights exactly once, whichever way the loop ended.
# Only the state_dict is written (previously the early-exit path wrote a state_dict that
# was immediately overwritten by a pickle of the whole module at the same path).
# Load with: net = UNet(); net.load_state_dict(torch.load(MODEL_DIR))
torch.save(unet.state_dict(), MODEL_DIR)

  0%|                                                                               | 1/10000 [00:07<19:52:02,  7.15s/it]

Current epoch: 0 ------ loss: 1.0241875648498535 ----------


  1%|▊                                                                               | 102/10000 [00:19<19:38,  8.40it/s]

Current epoch: 100 ------ loss: 0.1168253943324089 ----------


  2%|█▌                                                                              | 202/10000 [00:31<19:26,  8.40it/s]

Current epoch: 200 ------ loss: 0.21618153154850006 ----------


  3%|██▍                                                                             | 302/10000 [00:43<19:14,  8.40it/s]

Current epoch: 300 ------ loss: 0.047934677451848984 ----------


  4%|███▏                                                                            | 402/10000 [00:54<19:05,  8.38it/s]

Current epoch: 400 ------ loss: 0.05663153529167175 ----------


  5%|████                                                                            | 502/10000 [01:06<18:54,  8.37it/s]

Current epoch: 500 ------ loss: 0.05388564243912697 ----------


  6%|████▊                                                                           | 602/10000 [01:18<18:42,  8.37it/s]

Current epoch: 600 ------ loss: 0.07494785636663437 ----------


  7%|█████▌                                                                          | 702/10000 [01:30<18:31,  8.37it/s]

Current epoch: 700 ------ loss: 0.0516185462474823 ----------


  8%|██████▍                                                                         | 802/10000 [01:42<18:19,  8.37it/s]

Current epoch: 800 ------ loss: 0.03696610778570175 ----------


  9%|███████▏                                                                        | 902/10000 [01:54<18:09,  8.35it/s]

Current epoch: 900 ------ loss: 0.08149661123752594 ----------


 10%|███████▉                                                                       | 1002/10000 [02:06<17:58,  8.34it/s]

Current epoch: 1000 ------ loss: 0.10120085626840591 ----------


 11%|████████▋                                                                      | 1102/10000 [02:18<17:46,  8.34it/s]

Current epoch: 1100 ------ loss: 0.022844165563583374 ----------


 12%|█████████▍                                                                     | 1202/10000 [02:30<17:32,  8.36it/s]

Current epoch: 1200 ------ loss: 0.027476130053400993 ----------


 13%|██████████▎                                                                    | 1302/10000 [02:42<17:22,  8.34it/s]

Current epoch: 1300 ------ loss: 0.1263360232114792 ----------


 14%|███████████                                                                    | 1402/10000 [02:54<17:09,  8.35it/s]

Current epoch: 1400 ------ loss: 0.018123619258403778 ----------


 15%|███████████▊                                                                   | 1502/10000 [03:06<16:56,  8.36it/s]

Current epoch: 1500 ------ loss: 0.03420677408576012 ----------


 16%|████████████▋                                                                  | 1602/10000 [03:18<16:44,  8.36it/s]

Current epoch: 1600 ------ loss: 0.1729242503643036 ----------


 17%|█████████████▍                                                                 | 1702/10000 [03:30<16:32,  8.36it/s]

Current epoch: 1700 ------ loss: 0.03872454911470413 ----------


 18%|██████████████▏                                                                | 1802/10000 [03:42<16:20,  8.36it/s]

Current epoch: 1800 ------ loss: 0.024403605610132217 ----------


 19%|███████████████                                                                | 1902/10000 [03:54<16:10,  8.35it/s]

Current epoch: 1900 ------ loss: 0.012686564587056637 ----------


 20%|███████████████▊                                                               | 2002/10000 [04:06<16:02,  8.31it/s]

Current epoch: 2000 ------ loss: 0.016804512590169907 ----------


 21%|████████████████▌                                                              | 2102/10000 [04:18<15:57,  8.25it/s]

Current epoch: 2100 ------ loss: 0.014153103344142437 ----------


 22%|█████████████████▍                                                             | 2202/10000 [04:30<15:35,  8.34it/s]

Current epoch: 2200 ------ loss: 0.10618115216493607 ----------


 23%|██████████████████▏                                                            | 2302/10000 [04:42<15:24,  8.33it/s]

Current epoch: 2300 ------ loss: 0.08412453532218933 ----------


 24%|██████████████████▉                                                            | 2402/10000 [04:54<15:20,  8.25it/s]

Current epoch: 2400 ------ loss: 0.1007036417722702 ----------


 25%|███████████████████▊                                                           | 2502/10000 [05:06<14:56,  8.36it/s]

Current epoch: 2500 ------ loss: 0.023087112233042717 ----------


 26%|████████████████████▌                                                          | 2602/10000 [05:18<14:52,  8.29it/s]

Current epoch: 2600 ------ loss: 0.030260130763053894 ----------


 27%|█████████████████████▎                                                         | 2702/10000 [05:30<14:58,  8.12it/s]

Current epoch: 2700 ------ loss: 0.025073489174246788 ----------


 28%|██████████████████████▏                                                        | 2802/10000 [05:43<14:38,  8.19it/s]

Current epoch: 2800 ------ loss: 0.01083206757903099 ----------


 29%|██████████████████████▉                                                        | 2902/10000 [05:55<14:25,  8.20it/s]

Current epoch: 2900 ------ loss: 0.01791832223534584 ----------


 30%|███████████████████████▋                                                       | 3002/10000 [06:07<14:00,  8.32it/s]

Current epoch: 3000 ------ loss: 0.04295645281672478 ----------


 31%|████████████████████████▌                                                      | 3102/10000 [06:19<13:45,  8.36it/s]

Current epoch: 3100 ------ loss: 0.04536391422152519 ----------


 32%|█████████████████████████▎                                                     | 3202/10000 [06:31<13:36,  8.33it/s]

Current epoch: 3200 ------ loss: 0.0292953047901392 ----------


 33%|██████████████████████████                                                     | 3302/10000 [06:43<13:26,  8.30it/s]

Current epoch: 3300 ------ loss: 0.010251770727336407 ----------


 34%|██████████████████████████▉                                                    | 3402/10000 [06:55<13:09,  8.35it/s]

Current epoch: 3400 ------ loss: 0.030828949064016342 ----------


 35%|███████████████████████████▋                                                   | 3502/10000 [07:07<12:59,  8.33it/s]

Current epoch: 3500 ------ loss: 0.019933601841330528 ----------


 36%|████████████████████████████▍                                                  | 3602/10000 [07:19<12:45,  8.36it/s]

Current epoch: 3600 ------ loss: 0.012167390435934067 ----------


 37%|█████████████████████████████▏                                                 | 3702/10000 [07:31<12:41,  8.27it/s]

Current epoch: 3700 ------ loss: 0.010965677909553051 ----------


 38%|██████████████████████████████                                                 | 3802/10000 [07:44<12:38,  8.17it/s]

Current epoch: 3800 ------ loss: 0.009088519029319286 ----------


 39%|██████████████████████████████▊                                                | 3902/10000 [07:56<12:09,  8.36it/s]

Current epoch: 3900 ------ loss: 0.01367809809744358 ----------


 40%|███████████████████████████████▌                                               | 4002/10000 [08:08<12:02,  8.30it/s]

Current epoch: 4000 ------ loss: 0.009212230332195759 ----------


 41%|████████████████████████████████▍                                              | 4102/10000 [08:20<12:09,  8.08it/s]

Current epoch: 4100 ------ loss: 0.015087682753801346 ----------


 42%|█████████████████████████████████▏                                             | 4202/10000 [08:32<11:36,  8.32it/s]

Current epoch: 4200 ------ loss: 0.06755384057760239 ----------


 43%|█████████████████████████████████▉                                             | 4302/10000 [08:44<11:26,  8.30it/s]

Current epoch: 4300 ------ loss: 0.023335803300142288 ----------


 44%|██████████████████████████████████▊                                            | 4402/10000 [08:56<11:27,  8.14it/s]

Current epoch: 4400 ------ loss: 0.0227817315608263 ----------


 45%|███████████████████████████████████▌                                           | 4502/10000 [09:08<10:57,  8.36it/s]

Current epoch: 4500 ------ loss: 0.01676989533007145 ----------


 46%|████████████████████████████████████▎                                          | 4602/10000 [09:20<10:46,  8.35it/s]

Current epoch: 4600 ------ loss: 0.04832926020026207 ----------


 47%|█████████████████████████████████████▏                                         | 4702/10000 [09:32<10:36,  8.33it/s]

Current epoch: 4700 ------ loss: 0.006865852512419224 ----------


 48%|█████████████████████████████████████▉                                         | 4802/10000 [09:44<10:33,  8.20it/s]

Current epoch: 4800 ------ loss: 0.01840043254196644 ----------


 49%|██████████████████████████████████████▋                                        | 4902/10000 [09:57<10:16,  8.27it/s]

Current epoch: 4900 ------ loss: 0.023862699046730995 ----------


 50%|███████████████████████████████████████▌                                       | 5002/10000 [10:09<10:00,  8.32it/s]

Current epoch: 5000 ------ loss: 0.021072885021567345 ----------


 51%|████████████████████████████████████████▎                                      | 5102/10000 [10:21<09:47,  8.34it/s]

Current epoch: 5100 ------ loss: 0.012067663483321667 ----------


 52%|█████████████████████████████████████████                                      | 5202/10000 [10:33<09:40,  8.27it/s]

Current epoch: 5200 ------ loss: 0.009158196859061718 ----------


 53%|█████████████████████████████████████████▉                                     | 5302/10000 [10:45<09:36,  8.15it/s]

Current epoch: 5300 ------ loss: 0.005934223998337984 ----------


 54%|██████████████████████████████████████████▋                                    | 5402/10000 [10:57<09:23,  8.16it/s]

Current epoch: 5400 ------ loss: 0.026496678590774536 ----------


 55%|███████████████████████████████████████████▍                                   | 5502/10000 [11:09<08:59,  8.33it/s]

Current epoch: 5500 ------ loss: 0.015065928921103477 ----------


 56%|████████████████████████████████████████████▎                                  | 5602/10000 [11:22<08:54,  8.23it/s]

Current epoch: 5600 ------ loss: 0.009372154250741005 ----------


 57%|█████████████████████████████████████████████                                  | 5702/10000 [11:34<08:37,  8.30it/s]

Current epoch: 5700 ------ loss: 0.007059585303068161 ----------


 58%|█████████████████████████████████████████████▊                                 | 5802/10000 [11:46<08:35,  8.14it/s]

Current epoch: 5800 ------ loss: 0.008480897173285484 ----------


 59%|██████████████████████████████████████████████▋                                | 5902/10000 [11:58<08:20,  8.19it/s]

Current epoch: 5900 ------ loss: 0.01043330505490303 ----------


 60%|███████████████████████████████████████████████▍                               | 6002/10000 [12:10<08:01,  8.30it/s]

Current epoch: 6000 ------ loss: 0.007892036810517311 ----------


 61%|████████████████████████████████████████████████▏                              | 6102/10000 [12:22<07:53,  8.23it/s]

Current epoch: 6100 ------ loss: 0.007636270485818386 ----------


 62%|████████████████████████████████████████████████▉                              | 6202/10000 [12:34<07:35,  8.35it/s]

Current epoch: 6200 ------ loss: 0.005974422208964825 ----------


 63%|█████████████████████████████████████████████████▊                             | 6302/10000 [12:46<07:25,  8.30it/s]

Current epoch: 6300 ------ loss: 0.014197250828146935 ----------


 64%|██████████████████████████████████████████████████▌                            | 6402/10000 [12:58<07:18,  8.20it/s]

Current epoch: 6400 ------ loss: 0.009560744278132915 ----------


 65%|███████████████████████████████████████████████████▎                           | 6502/10000 [13:10<07:01,  8.29it/s]

Current epoch: 6500 ------ loss: 0.019313201308250427 ----------


 66%|████████████████████████████████████████████████████▏                          | 6602/10000 [13:22<06:51,  8.25it/s]

Current epoch: 6600 ------ loss: 0.005205343943089247 ----------


 67%|████████████████████████████████████████████████████▉                          | 6702/10000 [13:34<06:46,  8.11it/s]

Current epoch: 6700 ------ loss: 0.014810634776949883 ----------


 68%|█████████████████████████████████████████████████████▋                         | 6802/10000 [13:46<06:36,  8.06it/s]

Current epoch: 6800 ------ loss: 0.008771334774792194 ----------


 69%|██████████████████████████████████████████████████████▌                        | 6902/10000 [13:59<06:13,  8.30it/s]

Current epoch: 6900 ------ loss: 0.007470885291695595 ----------


 70%|███████████████████████████████████████████████████████▎                       | 7002/10000 [14:11<06:00,  8.32it/s]

Current epoch: 7000 ------ loss: 0.01909034140408039 ----------


 71%|████████████████████████████████████████████████████████                       | 7102/10000 [14:23<05:47,  8.34it/s]

Current epoch: 7100 ------ loss: 0.007710203994065523 ----------


 72%|████████████████████████████████████████████████████████▉                      | 7202/10000 [14:34<05:35,  8.34it/s]

Current epoch: 7200 ------ loss: 0.006972841918468475 ----------


 73%|█████████████████████████████████████████████████████████▋                     | 7302/10000 [14:46<05:22,  8.35it/s]

Current epoch: 7300 ------ loss: 0.04235837981104851 ----------


 74%|██████████████████████████████████████████████████████████▍                    | 7402/10000 [14:58<05:11,  8.35it/s]

Current epoch: 7400 ------ loss: 0.013154572807252407 ----------


 75%|██████████████████████████████████████████████████████████▉                    | 7466/10000 [15:06<05:02,  8.37it/s]

loss less than margin---------- 0.0038284542970359325 ------ 0.004


 75%|██████████████████████████████████████████████████████████▉                    | 7466/10000 [15:07<05:07,  8.23it/s]


In [15]:
#! Reverse Diffusion
# Label value handed to the sampler as `labels` and embedded in the output file names.
# Labels follow the order of the training batch (0 = first image, 1 = second, ...).
# NOTE(review): UNet.forward accepts `labels` but never uses it, so this value
# currently only affects the file names - confirm whether conditioning is intended.
LABEL = 1

In [17]:
import os

# Directory that receives one frame per reverse step. Created up front so the
# save call below cannot fail on a machine where the folder does not exist yet.
output_dir = f"{SAVE_MODEL_DIR}/output-img"
os.makedirs(output_dir, exist_ok=True)

# Loop-invariant setup: switch the network to eval mode once and build the
# label tensor once instead of on every step.
unet.eval()
label_tensor = torch.tensor(LABEL, dtype=torch.float, device=DEVICE).reshape(-1, 1)
num_steps = diffusion_model.timesteps

with torch.no_grad():
    # 1. Start from pure Gaussian noise with the configured image size
    img = torch.randn((1, 3) + IMAGE_SHAPE).to(DEVICE)

    # 2. Walk the timesteps from last to first; `total` gives tqdm a real progress bar,
    #    which it cannot infer from a reversed() iterator.
    for i in tqdm(reversed(range(num_steps)), total=num_steps):
        t = torch.full((1,), i, dtype=torch.long, device=DEVICE)

        # Reverse diffusion: one denoising step
        img = diffusion_model.backward(img, t, unet, labels=label_tensor)

        # Save an upscaled snapshot of the current state, named by label and timestep
        frame = utils.reverse_transform(img[0])
        frame = frame.resize((256, 256))
        frame.save(f"{output_dir}/L{LABEL}-{i}.png")

300it [00:05, 51.31it/s]


In [18]:
# Display the number of diffusion timesteps configured on the model; the sampling
# loop iterates over exactly this many steps.
diffusion_model.timesteps

300