In [1]:
import os
import cv2
import copy
import math
import argparse
import numpy as np
from time import time
from tqdm import tqdm
from easydict import EasyDict

import torch
import torch.distributed as dist
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

import models

# Training-time augmentation: random horizontal flips only. ToTensor() yields
# float tensors in [0, 1]; train_one_epoch rescales them to [-1, 1].
transform_train = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

# Download (if needed) and load the CIFAR10 training split.
train_set = torchvision.datasets.CIFAR10(
    root="./data/", train=True, download=True, transform=transform_train
)

# CIFAR10 classes, for reference only (training below is unconditional):
# ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

######## CHECK METADATA #########
metadata = EasyDict(
    {
        "image_size": 32,
        "num_classes": 10,
        "train_images": 50000,
        "val_images": 10000,
        "num_channels": 3,
    }
)
############# PARAMETERS #############
DIFFUSION_STEPS = 100  # timesteps of the Gaussian diffusion process
BATCH_SIZE = 32  # OR 128
LEARNING_RATE = 0.0001
EMA_W = 0.75  # Exponential Moving Average weight (loss display and weight EMA)
EPOCHS = 100
PRETRAINED_MODEL = ""  # path to a state_dict to resume from; "" trains from scratch
SAVE_MODEL = True
SAVE_MODEL_PATH = "./saved_models/"
SAVE_IMAGES_PATH = "./images/"

# Create every output directory up front: cv2.imwrite does not create missing
# directories and only signals failure through its return value.
os.makedirs(SAVE_MODEL_PATH, exist_ok=True)
os.makedirs(SAVE_IMAGES_PATH, exist_ok=True)


Files already downloaded and verified


In [10]:

class loss_logger:
    """Keep the history of training losses plus an exponentially smoothed value.

    Args:
        max_steps: Total number of steps expected; only shown in the progress line.
    """

    def __init__(self, max_steps):
        self.max_steps = max_steps
        self.start_time = time()
        self.ema_w = EMA_W
        self.loss = []  # every raw value passed to log(), in order
        self.ema_loss = None  # smoothed loss; stays None until the first log()

    def log(self, v, display=False):
        """Record loss ``v``, refresh the smoothed loss and optionally print progress."""
        self.loss.append(v)
        previous = self.ema_loss
        self.ema_loss = v if previous is None else self.ema_w * previous + (1 - self.ema_w) * v

        if not display:
            return
        hours = (time() - self.start_time) / 3600
        print(
            f"Steps: {len(self.loss)}/{self.max_steps} \t loss (ema): {self.ema_loss:.3f} "
            f"\t Time elapsed: {hours:.3f} hr"
        )


def train_one_epoch(
    model,
    dataloader,
    diffusion,
    optimizer,
    logger,
    lrs,
    ema_dict,
    device,
    ema_w=None,
):
    """Run one epoch of noise-prediction diffusion training and update the weight EMA.

    Args:
        model: Noise-prediction network, called as ``model(xt, t)``.
        dataloader: Iterable of ``(images, labels)`` batches with pixels in [0, 1].
            Labels are ignored (unconditional training).
        diffusion: Object exposing ``timesteps`` and
            ``sample_from_forward_process(images, t) -> (xt, eps)``.
        optimizer: Optimizer over ``model``'s parameters.
        logger: Object with a ``log(value, display=bool)`` method.
        lrs: Optional LR scheduler stepped once per batch; ``None`` to skip.
        ema_dict: Mapping with the same keys as ``model.state_dict()`` holding
            the EMA copy of the weights.
        device: Device the batches are moved to.
        ema_w: EMA decay for the weights; defaults to the module constant ``EMA_W``.

    Returns:
        ``ema_dict`` (the same mapping, with its values replaced).
    """
    decay = EMA_W if ema_w is None else ema_w
    model.train()
    for step, (images, _labels) in enumerate(dataloader):
        assert (images.max().item() <= 1) and (0 <= images.min().item())

        # must use [-1, 1] pixel range for images (labels unused: class_cond = False)
        images = 2 * images.to(device) - 1
        t = torch.randint(diffusion.timesteps, (len(images),), dtype=torch.int64).to(
            device
        )
        xt, eps = diffusion.sample_from_forward_process(images, t)
        pred_eps = model(xt, t)
        loss = ((pred_eps - eps) ** 2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if lrs is not None:
            lrs.step()

        # Update the EMA copy of the weights. Only floating-point entries are
        # averaged; non-float entries (integer buffers such as counters) are
        # copied verbatim so their dtype is not silently promoted to float.
        with torch.no_grad():
            current = model.state_dict()
            for name in ema_dict:
                if ema_dict[name].is_floating_point():
                    ema_dict[name] = decay * ema_dict[name] + (1 - decay) * current[name]
                else:
                    ema_dict[name] = current[name].clone()

        logger.log(loss.item(), display=(step % 100 == 0))

    return ema_dict

In [None]:
'''Auxiliary function to handle the model inference and generate images from the diffusion process.'''

def generate_N_images(
    N,
    model,
    diffusion,
    xT=None,
    sampling_steps=250,
    batch_size=32,
    num_channels=3,
    image_size=32,
):
    """Generate ``N`` images from a diffusion model by running its reverse process.

    Works in a plain single-process session and under an initialised
    ``torch.distributed`` process group; in the latter case every rank samples
    a batch and the batches are all-gathered.

    Args:
        N: Number of images to return.
        model: Denoising network, passed through to the diffusion sampler.
            If it is an ``nn.Module`` it is put in eval mode while sampling and
            its previous train/eval mode is restored afterwards.
        diffusion: Object exposing
            ``sample_from_reverse_process(model, xT, sampling_steps, model_kwargs)``.
        xT: Optional fixed starting noise, reused for every batch. When ``None``,
            fresh Gaussian noise of shape
            ``(batch_size, num_channels, image_size, image_size)`` is drawn per batch.
        sampling_steps: Number of reverse-process steps.
        batch_size: Images per batch per process (ignored when ``xT`` is given).
        num_channels: Number of channels in the image.
        image_size: Image size (assuming square images).

    Returns:
        ``np.uint8`` array of shape ``(N, H, W, C)`` with values in [0, 255].
        Sampling is unconditional, so no labels are returned.
        Assumes the sampler returns NCHW tensors in roughly [-1, 1].

    Raises:
        ValueError: If a sampling batch would contain zero images.
    """
    if N <= 0:
        return np.zeros((0, image_size, image_size, num_channels), dtype=np.uint8)

    is_module = isinstance(model, torch.nn.Module)
    first_param = next(model.parameters(), None) if is_module else None
    if first_param is not None:
        device = first_param.device  # sample where the model lives
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Only talk to torch.distributed when a process group actually exists;
    # get_world_size() raises otherwise (e.g. in a plain notebook kernel).
    distributed = dist.is_available() and dist.is_initialized()
    num_processes = dist.get_world_size() if distributed else 1
    group = dist.group.WORLD if distributed else None

    per_step = (len(xT) if xT is not None else batch_size) * num_processes
    if per_step <= 0:
        raise ValueError("each sampling batch must contain at least one image")

    was_training = is_module and model.training
    if is_module:
        model.eval()  # no dropout while sampling
    chunks, num_samples = [], 0
    try:
        with torch.no_grad(), tqdm(total=math.ceil(N / per_step)) as pbar:
            while num_samples < N:
                # Draw fresh noise for every batch unless the caller pinned xT;
                # reusing one draw would start every batch from the same noise.
                if xT is not None:
                    x_start = xT
                else:
                    x_start = torch.randn(
                        batch_size, num_channels, image_size, image_size, device=device
                    )
                gen_images = diffusion.sample_from_reverse_process(
                    model, x_start, sampling_steps, {"y": None},  # unconditional
                )
                if distributed:
                    gathered = [torch.zeros_like(gen_images) for _ in range(num_processes)]
                    dist.all_gather(gathered, gen_images, group)
                    gen_images = torch.cat(gathered)
                chunks.append(gen_images.detach().cpu().numpy())
                num_samples += per_step
                pbar.update(1)
    finally:
        if was_training:
            model.train()

    images = np.concatenate(chunks).transpose(0, 2, 3, 1)[:N]
    # Map [-1, 1] to [0, 255]; clip first so out-of-range samples saturate
    # instead of wrapping around in the uint8 cast.
    return np.clip(127.5 * (images + 1), 0, 255).astype(np.uint8)

def load_model(model, model_path, map_location=None):
    """Load a saved ``state_dict`` into ``model`` in place and return it.

    Args:
        model: Module whose parameters and buffers are overwritten.
        model_path: Path (or file-like object) accepted by ``torch.load``.
        map_location: Where to map the stored tensors while loading. Defaults to
            the device of ``model``'s first parameter (CPU if it has none), so a
            checkpoint written on a GPU can be loaded on a CPU-only machine.

    Returns:
        ``model``, for convenient chaining.
    """
    if map_location is None:
        first_param = next(model.parameters(), None)
        map_location = first_param.device if first_param is not None else "cpu"
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(model_path, map_location=map_location))
    return model

In [11]:
def main():
    """Train the unconditional CIFAR10 UNet diffusion model.

    Builds the UNet and the Gaussian diffusion from the module-level config,
    optionally resumes from ``PRETRAINED_MODEL``, and trains for ``EPOCHS``
    epochs while tracking an EMA copy of the weights.

    When ``SAVE_MODEL`` is true it writes ``model_<n>.pt`` and
    ``ema_dict_<n>.pt`` into ``SAVE_MODEL_PATH``. A pair is saved every 10
    completed epochs and once more after the last epoch. ``<n>`` is the number
    of completed epochs, so the final pair is ``model_{EPOCHS}.pt``.

    Returns:
        0 on completion.
    """
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    torch.manual_seed(123)
    np.random.seed(123)

    # Attention resolutions (in pixels) converted to downsampling factors
    # relative to the input: "32,16,8" -> (1, 2, 4) for 32x32 images.
    # NOTE(review): presumably the form models.UNetModel expects — confirm.
    attention_resolutions = "32,16,8"
    attention_ds = tuple(
        metadata.image_size // int(res) for res in attention_resolutions.split(",")
    )

    model = models.UNetModel(
        image_size=metadata.image_size,
        in_channels=metadata.num_channels,
        out_channels=metadata.num_channels,
        model_channels=64,
        channel_mult=(1, 2, 2, 2),
        num_res_blocks=3,
        dropout=0.1,
        num_classes=None,  # We're not using class labels
        use_checkpoint=False,
        use_fp16=False,
        num_heads=4,
        attention_resolutions=attention_ds,
        num_head_channels=64,
        num_heads_upsample=-1,
        use_scale_shift_norm=True,
        resblock_updown=True,
        use_new_attention_order=True,
    ).to(device)

    diffusion_model = models.GaussianDiffusion(DIFFUSION_STEPS, device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    if PRETRAINED_MODEL:
        print(f"Loading pretrained model from {PRETRAINED_MODEL}")
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        model.load_state_dict(torch.load(PRETRAINED_MODEL, map_location=device))
    else:
        print("Training from scratch")

    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    print(f"Training dataset loaded: Number of batches: {len(train_loader)}, Number of images: {len(train_set)}")
    logger = loss_logger(len(train_loader) * EPOCHS)

    # EMA copy of the weights, initialised from the (possibly pretrained) model.
    ema_dict = copy.deepcopy(model.state_dict())

    def _save_checkpoint(num_epochs, ema_state):
        """Write the EMA and raw weights tagged with the number of completed epochs."""
        torch.save(ema_state, os.path.join(SAVE_MODEL_PATH, f"ema_dict_{num_epochs}.pt"))
        torch.save(model.state_dict(), os.path.join(SAVE_MODEL_PATH, f"model_{num_epochs}.pt"))

    save_every = 10  # epochs between periodic checkpoints
    for epoch in range(1, EPOCHS + 1):  # 1-based: `epoch` = epochs completed after this one
        ema_dict = train_one_epoch(model, train_loader, diffusion_model, optimizer, logger, None, ema_dict, device)
        if epoch % save_every == 0:
            print(f"Epoch {epoch} completed")
            if SAVE_MODEL:
                _save_checkpoint(epoch, ema_dict)

    # Final checkpoint, unless the periodic save above already wrote it.
    # Using EPOCHS (not the loop variable) also avoids a NameError when EPOCHS == 0.
    if SAVE_MODEL and EPOCHS % save_every != 0:
        _save_checkpoint(EPOCHS, ema_dict)

    return 0


if __name__ == "__main__":
    main()

cuda:0
Training from scratch
Training dataset loaded: Number of batches: 1563, Number of images: 50000
Steps: 1/156300 	 loss (ema): 1.000 	 Time elapsed: 0.000 hr
Steps: 101/156300 	 loss (ema): 0.284 	 Time elapsed: 0.007 hr
Steps: 201/156300 	 loss (ema): 0.102 	 Time elapsed: 0.013 hr
Steps: 301/156300 	 loss (ema): 0.089 	 Time elapsed: 0.020 hr
Steps: 401/156300 	 loss (ema): 0.077 	 Time elapsed: 0.027 hr
Steps: 501/156300 	 loss (ema): 0.068 	 Time elapsed: 0.033 hr
Steps: 601/156300 	 loss (ema): 0.059 	 Time elapsed: 0.040 hr
Steps: 701/156300 	 loss (ema): 0.067 	 Time elapsed: 0.047 hr
Steps: 801/156300 	 loss (ema): 0.063 	 Time elapsed: 0.053 hr
Steps: 901/156300 	 loss (ema): 0.066 	 Time elapsed: 0.060 hr
Steps: 1001/156300 	 loss (ema): 0.064 	 Time elapsed: 0.067 hr
Steps: 1101/156300 	 loss (ema): 0.055 	 Time elapsed: 0.073 hr
Steps: 1201/156300 	 loss (ema): 0.055 	 Time elapsed: 0.080 hr
Steps: 1301/156300 	 loss (ema): 0.064 	 Time elapsed: 0.087 hr
Steps: 1401/1

In [None]:
"""Inference mode to generate images from the trained model"""
# Rebuild the UNet with the training configuration, load the trained weights
# and write a strip of generated samples into SAVE_IMAGES_PATH.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(123)  # make the generated samples reproducible

# Must match the architecture used for training, or the state_dict will not load.
attention_resolutions = "32,16,8"
attention_ds = [metadata.image_size // int(res) for res in attention_resolutions.split(",")]

model = models.UNetModel(
    image_size=metadata.image_size,
    in_channels=metadata.num_channels,
    out_channels=metadata.num_channels,
    model_channels=64,
    channel_mult=(1, 2, 2, 2),
    num_res_blocks=3,
    dropout=0.1,
    num_classes=None,  # We're not using class labels
    use_checkpoint=False,
    use_fp16=False,
    num_heads=4,
    attention_resolutions=tuple(attention_ds),
    num_head_channels=64,
    num_heads_upsample=-1,
    use_scale_shift_norm=True,
    resblock_updown=True,
    use_new_attention_order=True,
).to(device)

# Weights after EPOCHS epochs; with the default config this is
# "./saved_models/model_100.pt".
load_model(model, os.path.join(SAVE_MODEL_PATH, f"model_{EPOCHS}.pt"))
model.eval()  # disable dropout while sampling

diffusion_model = models.GaussianDiffusion(DIFFUSION_STEPS, device)

# Generate images.
# NOTE(review): the diffusion is built with DIFFUSION_STEPS timesteps, so the
# number of reverse steps is capped at that (it was 250). Requesting more steps
# than the schedule has presumably just repeats timesteps — confirm against
# models.GaussianDiffusion.sample_from_reverse_process.
samples = generate_N_images(
    64,
    model,
    diffusion_model,
    xT=None,
    sampling_steps=DIFFUSION_STEPS,
    batch_size=32,
    num_channels=metadata.num_channels,
    image_size=metadata.image_size,
)

# cv2.imwrite reports failure (e.g. a missing directory) only via its return
# value, so make sure the directory exists and check the result.
os.makedirs(SAVE_IMAGES_PATH, exist_ok=True)
image_path = os.path.join(
    SAVE_IMAGES_PATH, f"{train_set.__class__.__name__}{DIFFUSION_STEPS}.jpeg"
)
# Lay the samples side by side in one row and flip RGB -> BGR for OpenCV.
strip_bgr = np.ascontiguousarray(np.concatenate(samples, axis=1)[:, :, ::-1])
if not cv2.imwrite(image_path, strip_bgr):
    raise IOError(f"cv2.imwrite could not write {image_path}")
