<a href="https://colab.research.google.com/github/samuel23taku/NoteBooks/blob/main/Conditional_Diffusion_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.transforms import  ToTensor, Normalize

transform= transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5, ))
])

training = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=transform
)

train_dataloader = DataLoader(
    training,batch_size=1000,shuffle=True
)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:02<00:00, 4823887.46it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 127899.13it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 1215392.82it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 6094219.06it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw






# Model

In [2]:
from diffusers import UNet2DModel
import  torch.nn as nn

class ClassConditionedUNet(nn.Module):
    def __init__(self, num_classes=10,class_emb_size=4):
        super().__init__()

        self.class_emb = nn.Embedding(num_classes, class_emb_size)

        self.model = UNet2DModel(
            sample_size=28,
            in_channels=1 + class_emb_size, #concat  the embedding to get new from the gray scale(1) channel
            out_channels=1,
            layers_per_block=2,
            block_out_channels=(32,64,128),
            down_block_types = (
            "DownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
        ),

            up_block_types=(
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "UpBlock2D"
            )
        )

    def forward(self, x,t,class_labels):
        batch_size, channels, height, width = x.size()

        class_condition = self.class_emb(class_labels)
        class_condition = class_condition.view(batch_size, class_condition.shape[1], 1,1).expand(batch_size,class_condition.shape[1], width,height)

        net_input = torch.cat((x,class_condition),dim=1)
        return self.model(net_input,t).sample


The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

# Training

In [18]:
from diffusers import DDPMScheduler
from torch.optim import AdamW
from diffusers.optimization import get_cosine_schedule_with_warmup
from accelerate import Accelerator
from tqdm.auto import tqdm
import os
noise_scheduler = DDPMScheduler(num_train_timesteps=1_000)

os.makedirs("./ModelWeights",exist_ok=True)

def training_loop(epochs,model,noise_scheduler,training_dataloader):
    loss = nn.MSELoss()
    optimizer = AdamW(model.parameters(), lr=1e-4)
    lr_scheduler = get_cosine_schedule_with_warmup(optimizer=optimizer,num_warmup_steps=500,num_training_steps=(len(training_dataloader)*50))

    accelerator = Accelerator(
        mixed_precision="fp16",
        gradient_accumulation_steps=1
    )

    model ,optimizer, training_dataloader,lr_scheduler = accelerator.prepare(model,optimizer,training_dataloader,lr_scheduler)

    losses = []
    for epoch in range(epochs):
      for x, y in tqdm(training_dataloader):
        x = x.to(accelerator.device)
        y = y.to(accelerator.device)

        noise = torch.randn_like(x)
        timesteps = torch.randint(0,99, (x.shape[0],)).long().to(accelerator.device)
        noisy_x = noise_scheduler.add_noise(x, noise, timesteps)

        with accelerator.accumulate(model):
          noise_pred = model(noisy_x,timesteps, y)
          loss_error = loss(noise_pred, noise)
          losses.append(loss_error)
          accelerator.backward(loss_error)
          logs = {"loss": loss_error.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "Epoch": epoch+1 }
          accelerator.log(logs,step=epoch +1)

          # Saving weights

        if accelerator.is_main_process:
          unwrapped_model = accelerator.unwrap_model(model)
          torch.save(unwrapped_model.state_dict(),f"./ModelWeights/weights_epoch_{epoch+1}")

          torch.save(optimizer.state_dict(), f"./ModelWeights/epoch_{epoch+1}_optimizer.pt")
          torch.save(lr_scheduler.state_dict(), f"./ModelWeights/scheduler.pt")
      average_loss = sum(losses[-100:]) / 100
      print(f"Finished epoch with average loss of :{average_loss}")
    accelerator.end_training()



            # Save optimizer state
torch.cuda.empty_cache()
model = ClassConditionedUNet()
training_loop(epochs=1,model=model,noise_scheduler=noise_scheduler,training_dataloader=train_dataloader)

  0%|          | 0/60 [00:00<?, ?it/s]

Finished epoch with average loss of :0.6712813973426819


# Sampling