In [3]:
import os
import torch
import torch.nn as nn
from torch import optim
import numpy as np
import scipy
from torchvision import datasets, models, transforms
from torchvision.utils import save_image
import torchvision
import torch.utils.data
from matplotlib import pyplot as plt
from tqdm import tqdm
from drive.MyDrive.Colab_Notebooks.Diffuse_553.ddpm_553.utils import *
from drive.MyDrive.Colab_Notebooks.Diffuse_553.ddpm_553.UNet import UNet
import logging
from torch.utils.tensorboard import SummaryWriter
# Root logger setup: INFO level, "<time> - <LEVEL>: <message>" lines with a
# 12-hour clock timestamp.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s: %(message)s",
    datefmt="%I:%M:%S",
)

In [2]:
# Mount Google Drive inside the Colab runtime.
# NOTE(review): the import cell loads `utils` and `UNet` from
# `drive.MyDrive...`, and this cell's execution count (2) is lower than the
# import cell's (3) -- presumably this cell has to be run first; confirm and
# consider moving it above the imports.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
!pip install pytorch_fid

In [None]:
import argparse

# Build an (initially empty) argparse namespace; parse_known_args() is used so
# that any arguments already present in sys.argv are ignored instead of
# raising. The run configuration is then attached as plain attributes.
parser = argparse.ArgumentParser()
args = parser.parse_known_args()[0]
vars(args).update(
    run_name="DDPM_Uncondtional",
    epochs=5,
    batch_size=128,
    image_size=32,
    dataset_path="./drive/MyDrive/Colab_Notebooks/Diffuse_553/ddpm_553/datasets",
    subset_path="./drive/MyDrive/Colab_Notebooks/Diffuse_553/ddpm_553/datasets/cifar10_subset_images",
    device="cuda",
    lr=3e-4,
    loss_type='mse',  # 'mse' or 'l1'
)

In [None]:
# get cifar_10 data
def get_data_cifar10(args):
    """Build the CIFAR-10 training dataset and a shuffled DataLoader over it.

    Args:
        args: namespace providing ``dataset_path`` (root directory passed to
            ``torchvision.datasets.CIFAR10``; the data must already be there
            because ``download=False``) and ``batch_size``.

    Returns:
        tuple: ``(dataloader, dataset_un)`` -- the shuffled loader and the
        underlying dataset it iterates over.
    """
    # Named `preprocess` (not `transforms`) so the `transforms` module imported
    # at the top of the notebook is not shadowed inside this function.
    preprocess = torchvision.transforms.Compose([
        #torchvision.transforms.Resize(40),  # args.image_size + 1/4 *args.image_size
        #torchvision.transforms.RandomResizedCrop(args.image_size, scale=(0.8, 1.0)),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        # Per channel (x - 0.5) / 0.5: maps ToTensor's [0, 1] output to [-1, 1].
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset_un = torchvision.datasets.CIFAR10(
        root=args.dataset_path, train=True, transform=preprocess, download=False
    )
    # Fully qualified so this function does not rely on a bare `DataLoader`
    # name being pulled in by the `utils` star import.
    dataloader = torch.utils.data.DataLoader(
        dataset_un, batch_size=args.batch_size, shuffle=True
    )
    return dataloader, dataset_un

In [None]:
# Download the Flowers102 dataset into the current directory. The returned
# dataset object is not assigned to a name.
# NOTE(review): nothing else visible in this notebook uses Flowers102 --
# confirm whether this cell is still needed.
torchvision.datasets.Flowers102(root="./", download=True)

In [None]:
## -------------run only for the first time !!---------------

# Create the directory for CIFAR10 subset images (same path as args.subset_path)
os.makedirs(args.subset_path, exist_ok=True)

# Create the directory for CIFAR10 generated images
os.makedirs('./drive/MyDrive/Colab_Notebooks/Diffuse_553/ddpm_553/datasets/cifar10_generate_images', exist_ok=True)

# generate a subset of cifar_10
dataloader, dataset_un = get_data_cifar10(args)
n = 10000  # create subset
cifar10_subset, _ = torch.utils.data.random_split(dataset_un, [n, len(dataset_un) - n])
cifar10_subset_dataloader_un = torch.utils.data.DataLoader(
    cifar10_subset, batch_size=args.batch_size, shuffle=True
)

# save images to path
for i, (image, _) in enumerate(cifar10_subset):
    # The dataset transform normalizes pixels to [-1, 1], but save_image clamps
    # its input to [0, 1]; without undoing the normalization every pixel below
    # mid-gray would be written as black. Map back to [0, 1] first.
    save_image(image * 0.5 + 0.5, f'{args.subset_path}/image_{i}.png')


In [None]:
class Diffusion:
    """Forward noising and reverse sampling helpers for a DDPM-style model.

    On construction the beta schedule is computed and, from it,
    ``alpha = 1 - beta`` and its cumulative product ``alpha_hat``. All three
    are 1-D tensors of length ``noise_steps`` indexed by integer timestep.

    Args:
        noise_steps: length of the noise schedule.
        beta_start: first value of the linear beta schedule.
        beta_end: last value of the linear beta schedule.
        img_size: height and width of the square images created by ``sample``.
        device: device on which the schedule tensors and samples are placed.
    """

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device
        # Called with defaults, i.e. the linear schedule is the one in use.
        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def prepare_noise_schedule(self, use_cosine=False, s=0.008):
        """Return the beta schedule as a 1-D tensor of length ``noise_steps``.

        Args:
            use_cosine: if true, derive betas from the squared-cosine
                ``alpha_hat`` curve ``f(t) = cos(((t / T) + s) / (1 + s) * pi / 2) ** 2``
                normalised by ``f(0)``, with each beta capped at 0.999;
                otherwise use ``torch.linspace(beta_start, beta_end, noise_steps)``.
            s: offset used by the cosine curve.

        Returns:
            torch.Tensor: betas in the default floating dtype for both branches.
        """
        if use_cosine:
            steps = np.arange(self.noise_steps + 1)
            curve = np.cos((steps / self.noise_steps + s) / (1 + s) * np.pi / 2) ** 2
            alphas = curve / curve[0]
            betas = np.minimum(1 - alphas[1:] / alphas[:-1], 0.999)
            # numpy computes in float64; cast so the schedule has the same dtype
            # as the linear branch instead of silently promoting noised images
            # to float64.
            return torch.tensor(betas, dtype=torch.get_default_dtype())
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        """Noise a batch ``x`` to the per-sample timesteps ``t``.

        Computes ``sqrt(alpha_hat[t]) * x + sqrt(1 - alpha_hat[t]) * eps`` with
        ``eps`` drawn from a standard normal of the same shape as ``x``. The
        per-sample coefficients are broadcast over three trailing dimensions,
        so ``x`` is expected to be 4-D.

        Args:
            x: batch of images.
            t: 1-D integer tensor of timesteps, one per sample.

        Returns:
            tuple: ``(noised_x, eps)``.
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        eps = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * eps, eps

    def sample_timesteps(self, n):
        """Draw ``n`` integer timesteps uniformly from ``[1, noise_steps - 1]``.

        ``torch.randint`` excludes ``high``. This is the same range that
        ``sample`` iterates over, so schedule index 0 is used by neither.
        """
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n):
        """Generate ``n`` images by iterating the reverse process.

        Starts from standard normal noise and, for ``i`` from
        ``noise_steps - 1`` down to 1, applies
        ``x = 1/sqrt(alpha) * (x - beta / sqrt(1 - alpha_hat) * eps_pred) + sqrt(beta) * z``
        where ``eps_pred = model(x, t)`` and ``z`` is fresh Gaussian noise
        (zero on the final step, ``i == 1``).

        The model is switched to eval mode for sampling and to train mode
        afterwards.

        Args:
            model: callable taking ``(x, t)`` and returning predicted noise.
            n: number of images to generate.

        Returns:
            torch.Tensor: ``uint8`` tensor of shape ``(n, 3, img_size, img_size)``
            with values in ``[0, 255]``.
        """
        logging.info(f"Sampling {n} new images....")
        model.eval()
        with torch.no_grad():
            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                # Same timestep for every sample in the batch.
                t = torch.full((n,), i, dtype=torch.long, device=self.device)
                predicted_noise = model(x, t)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    # No noise is added on the last denoising step.
                    noise = torch.zeros_like(x)
                x = 1 / torch.sqrt(alpha) * (x - (beta / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        model.train()
        # Map from [-1, 1] to [0, 1], then to 8-bit pixel values.
        x = (x.clamp(-1, 1) + 1) / 2
        x = (x * 255).type(torch.uint8)
        return x

In [None]:
#print('# of samples for ulabeled, train, and test, {}'.format(len(dataset_un)))
#print('Classes in train: {}'.format(dataset_un.classes))

In [None]:
# Rebuild the CIFAR-10 loader and draw a random n-image subset of the dataset;
# `cifar10_subset_dataloader_un` is the loader the training loop reads from.
dataloader, dataset_un = get_data_cifar10(args)
n = 10000  # number of images kept in the subset
subset_lengths = [n, len(dataset_un) - n]
cifar10_subset = torch.utils.data.random_split(dataset_un, subset_lengths)[0]
cifar10_subset_dataloader_un = DataLoader(
    cifar10_subset,
    batch_size=args.batch_size,
    shuffle=True,
)

In [None]:
def train(args):
    """Train the UNet noise predictor on the CIFAR-10 subset.

    Batches come from the module-level ``cifar10_subset_dataloader_un``. If a
    checkpoint already exists in this run's model directory, its weights are
    loaded before training; the same file is overwritten after every epoch.
    After each epoch a batch of images is also sampled and written to the
    run's results directory.

    Args:
        args: namespace with ``run_name``, ``device``, ``lr``, ``epochs``,
            ``image_size`` and ``loss_type`` (``'mse'`` or ``'l1'``).

    Raises:
        ValueError: if ``args.loss_type`` is neither ``'mse'`` nor ``'l1'``.
    """
    # Validate up front so a bad setting fails before any work is done
    # (previously an unknown value left `unet_loss` undefined).
    loss_classes = {'mse': nn.MSELoss, 'l1': nn.L1Loss}
    if args.loss_type not in loss_classes:
        raise ValueError(f"Unsupported loss_type {args.loss_type!r}; expected 'mse' or 'l1'")
    unet_loss = loss_classes[args.loss_type]()
    metric_name = args.loss_type.upper()  # "MSE" / "L1": progress-bar key and scalar tag

    setup_logging(args.run_name)
    device = args.device
    model = UNet().to(device)

    project_dir = "./drive/MyDrive/Colab_Notebooks/Diffuse_553/ddpm_553"
    # Resume from the same file this function saves to (it lives under
    # args.run_name), rather than from a separately hard-coded run directory.
    ckpt_path = os.path.join(project_dir, "models", args.run_name, "uncondition_ckpt_large_cifar.pt")
    if os.path.exists(ckpt_path):
        # map_location lets the checkpoint load even if it was saved from a
        # different device than the one requested now.
        ckpt = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(ckpt)

    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    diffusion = Diffusion(img_size=args.image_size, device=device)
    logger = SummaryWriter(os.path.join("runs", args.run_name))
    steps_per_epoch = len(cifar10_subset_dataloader_un)

    try:
        for epoch in range(args.epochs):
            logging.info(f"Starting epoch {epoch}:")
            pbar = tqdm(cifar10_subset_dataloader_un)
            for i, (images, _) in enumerate(pbar):
                images = images.to(device)
                t = diffusion.sample_timesteps(images.shape[0]).to(device)
                x_t, noise = diffusion.noise_images(images, t)
                predicted_noise = model(x_t, t)
                loss = unet_loss(noise, predicted_noise)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.set_postfix(**{metric_name: loss.item()})
                logger.add_scalar(metric_name, loss.item(), global_step=epoch * steps_per_epoch + i)

            # Sample as many images as the last training batch contained.
            sampled_images = diffusion.sample(model, n=images.shape[0])
            # NOTE(review): the 150 offset in the file name is kept as-is;
            # presumably it continues numbering from earlier runs -- confirm.
            save_images(sampled_images, os.path.join(project_dir, "results", args.run_name, f"{150+epoch}.jpg"))
            torch.save(model.state_dict(), ckpt_path)
    finally:
        # Flush and release the TensorBoard writer even if training is interrupted.
        logger.close()


In [None]:
# Quick check: prints "1" when the trained checkpoint file exists.
checkpoint_file = "./drive/MyDrive/Colab_Notebooks/Diffuse_553/ddpm_553/models/DDPM_Uncondtional/uncondition_ckpt_large_cifar.pt"
checkpoint_found = os.path.exists(checkpoint_file)
if checkpoint_found:
    print("1")

In [None]:
# Start training with the configuration held in `args` when this code runs as
# the main module.
if __name__ == "__main__":
    train(args)

In [None]:
# Shell escape: run `nvidia-smi` and show its output in the cell.
!nvidia-smi

In [None]:
# Load the saved checkpoint into a fresh UNet and draw `sample_number` images
# from it in a single batch.
sample_number = 400
device = "cuda"

model = UNet()
model = model.to(device)
ckpt = torch.load("./drive/MyDrive/Colab_Notebooks/Diffuse_553/ddpm_553/models/DDPM_Uncondtional/uncondition_ckpt_large_cifar.pt")
model.load_state_dict(ckpt)

diffusion = Diffusion(img_size=32, device=device)
sampled_images = diffusion.sample(model, n=sample_number)
print(sampled_images.shape)

In [None]:
# Tile the sampled images into one grid image, `num_cols` images per row.
num_cols = 8
num_images, channels, height, width = sampled_images.shape
num_rows = -(-num_images // num_cols)  # ceiling division: a partial last row still counts

# If the sample count is not a multiple of num_cols, pad with black images so
# every row has the same width (row-by-row concatenation would otherwise fail
# on the shorter last row).
missing = num_rows * num_cols - num_images
grid_source = sampled_images
if missing:
    filler = sampled_images.new_zeros((missing, channels, height, width))
    grid_source = torch.cat([sampled_images, filler], dim=0)

# (rows, cols, C, H, W) -> (C, rows, H, cols, W) -> (C, rows*H, cols*W):
# images sit side by side within a row, and rows are stacked vertically.
concatenated_image = (
    grid_source.reshape(num_rows, num_cols, channels, height, width)
    .permute(2, 0, 3, 1, 4)
    .reshape(channels, num_rows * height, num_cols * width)
)

# Plot the concatenated image (channels-last for imshow)
plt.figure(figsize=(10, 8))
plt.imshow(concatenated_image.permute(1, 2, 0).cpu().numpy())
plt.axis('off')
plt.show()



In [None]:
from pytorch_fid import fid_score

# The two image folders compared by the FID computation: the CIFAR-10
# reference subset and the images generated by the model.
cifar10_subset_path = './drive/MyDrive/Colab_Notebooks/Diffuse_553/ddpm_553/datasets/cifar10_subset_images'
generated_images_path = './drive/MyDrive/Colab_Notebooks/Diffuse_553/ddpm_553/datasets/cifar10_generate_images'

# Write each generated image to disk as image_<i>.png
for i in range(sample_number):
    # sampled_images holds 0-255 pixel values; rescale to floats in [0, 1]
    # before handing the tensor to save_image.
    sampled_image = sampled_images[i].to(torch.float32).div(255.0)
    save_image(sampled_image, f'{generated_images_path}/image_{i}.png')


In [None]:
# FID between the reference-subset folder and the generated-images folder.
# NOTE(review): only `sample_number` generated images are compared against the
# reference subset using 2048-dimensional features -- confirm that this sample
# size is adequate for a meaningful score.
fid = fid_score.calculate_fid_given_paths(
    [cifar10_subset_path, generated_images_path],
    batch_size=128,
    device='cuda',
    dims=2048,
)
print('FID score:', fid)