In [5]:
import os
import cv2
import copy
import math
import argparse
import numpy as np
from time import time
from tqdm import tqdm
from easydict import EasyDict

import torch
import torch.distributed as dist
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

import unet_model

# ----------------------------- DATA -----------------------------
# Training-time preprocessing: a random horizontal flip followed by conversion
# to a tensor. No Normalize step is applied here on purpose: train_one_epoch
# asserts that pixel values lie in [0, 1] and rescales them to [-1, 1] itself.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# CIFAR-10 training split; downloaded into ./data if it is not there yet.
train_set = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)

# ----------------------------- METADATA -----------------------------
# Static description of the dataset, consumed when the model is built.
metadata = EasyDict(
    dict(
        image_size=32,
        num_classes=10,
        train_images=50000,
        val_images=10000,
        num_channels=3,
    )
)

# ----------------------------- PARAMETERS -----------------------------
DIFFUSION_STEPS = 100            # passed to unet_model.GaussianDiffusion in main()
BATCH_SIZE = 32                  # 128 is the other value that was considered
LEARNING_RATE = 1e-4             # AdamW learning rate
EMA_W = 0.75                     # weight kept on the old value in the EMA of model weights
EPOCHS = 100                     # number of passes over the training set
PRETRAINED_MODEL = ""            # path to a state_dict to resume from; "" trains from scratch
SAVE_MODEL = True                # checkpointing on/off switch
SAVE_MODEL_PATH = "./saved_models/"  # checkpoint directory (trailing slash is significant)





Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:41<00:00, 4150923.97it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


In [None]:

class loss_logger:
    """Record per-step loss values and maintain their exponential moving average.

    Attributes:
        max_steps: total number of steps expected; only used in the progress line.
        loss: list of every raw loss value passed to ``log`` so far.
        start_time: wall-clock time (``time()``) at construction.
        ema_loss: smoothed loss, or ``None`` until the first call to ``log``.
        ema_w: weight kept on the previous smoothed value at each update.
    """

    def __init__(self, max_steps, ema_w=0.9):
        """Create an empty logger.

        Args:
            max_steps: total number of steps the run is expected to take.
            ema_w: smoothing weight for the moving average. Defaults to 0.9,
                the value that used to be hard-coded here.
        """
        self.max_steps = max_steps
        self.loss = []
        self.start_time = time()
        self.ema_loss = None
        self.ema_w = ema_w

    def log(self, v, display=False):
        """Append one loss value and refresh the moving average.

        Args:
            v: loss value for the current step.
            display: when true, print a one-line progress summary.
        """
        self.loss.append(v)
        if self.ema_loss is None:
            # First observation seeds the average so it does not start biased
            # towards zero.
            self.ema_loss = v
        else:
            self.ema_loss = self.ema_w * self.ema_loss + (1 - self.ema_w) * v

        if display:
            elapsed_hours = (time() - self.start_time) / 3600
            print(
                f"Steps: {len(self.loss)}/{self.max_steps} \t loss (ema): {self.ema_loss:.3f} "
                + f"\t Time elapsed: {elapsed_hours:.3f} hr"
            )


def train_one_epoch(
    model,
    dataloader,
    diffusion,
    optimizer,
    logger,
    lrs,
    ema_dict,
    device,
):
    """Run one full pass over ``dataloader``, optimising the noise-prediction loss.

    For every batch the images are rescaled from [0, 1] to [-1, 1], a random
    timestep is drawn per image, ``diffusion.sample_from_forward_process``
    produces the noised input and its noise target, and the model is trained
    with a mean-squared error between its output and that target. After each
    optimiser step the entries of ``ema_dict`` are blended with the model's
    current ``state_dict`` using the module-level ``EMA_W``.

    Args:
        model: network called as ``model(xt, t, y=None)``.
        dataloader: iterable yielding ``(images, labels)`` pairs; images must
            already be in [0, 1] (asserted). Labels are ignored.
        diffusion: object exposing ``timesteps`` and
            ``sample_from_forward_process(images, t)``.
        optimizer: optimiser over ``model``'s parameters.
        logger: object with ``log(value, display=...)``.
        lrs: optional scheduler stepped once per batch; ``None`` disables it.
        ema_dict: mapping with the same keys as ``model.state_dict()``.
        device: device the batch and timesteps are moved to.

    Returns:
        The ``ema_dict`` passed in, with its values updated.
    """
    model.train()
    for batch_idx, (batch_images, _targets) in enumerate(dataloader):
        assert (0 <= batch_images.min().item()) and (batch_images.max().item() <= 1)

        # The diffusion process works on pixels in [-1, 1].
        x0 = 2 * batch_images.to(device) - 1
        cond = None  # unconditional training: class labels are not fed to the model

        # One timestep index per image in the batch.
        timesteps = torch.randint(
            diffusion.timesteps, (len(x0),), dtype=torch.int64
        ).to(device)

        noisy, noise = diffusion.sample_from_forward_process(x0, timesteps)
        noise_hat = model(noisy, timesteps, y=cond)
        loss = (noise_hat - noise).pow(2).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if lrs is not None:
            lrs.step()

        # Blend the freshly updated weights into the running average.
        current_state = model.state_dict()
        for key in list(ema_dict):
            ema_dict[key] = EMA_W * ema_dict[key] + (1 - EMA_W) * current_state[key]

        # Print a progress line on every 100th batch (including the first).
        logger.log(loss.item(), display=(batch_idx % 100 == 0))

    return ema_dict

In [None]:
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# # tensor = torch.tensor([1.0, 2.0, 3.0], device=device)

In [None]:
def main():
    """Build the U-Net and diffusion process, train on CIFAR-10, and checkpoint.

    Reads the module-level configuration (``metadata``, ``DIFFUSION_STEPS``,
    ``BATCH_SIZE``, ``LEARNING_RATE``, ``EPOCHS``, ``PRETRAINED_MODEL``,
    ``SAVE_MODEL``, ``SAVE_MODEL_PATH``) and the module-level ``train_set``.

    When ``SAVE_MODEL`` is true, the EMA weights and the raw model weights are
    written to ``SAVE_MODEL_PATH`` every 10th epoch and after the final epoch.

    Returns:
        int: always 0.
    """
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    # torch.cuda.set_device only accepts CUDA devices; calling it with the CPU
    # fallback raises, so guard it.
    if device.type == "cuda":
        torch.cuda.set_device(device)
    torch.manual_seed(123)
    np.random.seed(123)

    # Convert the feature-map resolutions at which attention is applied into
    # downsampling factors relative to the input image size.
    attention_resolutions = "32,16,8"
    attention_ds = [
        metadata.image_size // int(res) for res in attention_resolutions.split(",")
    ]

    model = unet_model.UNetModel(
        image_size=metadata.image_size,
        in_channels=metadata.num_channels,
        # train_one_epoch subtracts the noise target (derived from the
        # num_channels-channel images) from the model output, so the output
        # must have num_channels channels, not num_classes.
        # NOTE(review): assumes `out_channels` is the channel count of the
        # network output, as the name suggests -- confirm in unet_model.
        out_channels=metadata.num_channels,
        model_channels=64,
        channel_mult=(1, 2, 2, 2),
        num_res_blocks=3,
        dropout=0.1,
        num_classes=None,  # We're not using class labels
        use_checkpoint=False,
        use_fp16=False,
        num_heads=4,
        attention_resolutions=tuple(attention_ds),
        num_head_channels=64,
        num_heads_upsample=-1,
        use_scale_shift_norm=True,
        resblock_updown=True,
        use_new_attention_order=True,
    ).to(device)

    diffusion_model = unet_model.GaussianDiffusion(DIFFUSION_STEPS, device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    if PRETRAINED_MODEL == "":
        print("Training from scratch")
    else:
        print(f"Loading pretrained model from {PRETRAINED_MODEL}")
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        model.load_state_dict(torch.load(PRETRAINED_MODEL, map_location=device))

    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    print(f"Training dataset loaded: Number of batches: {len(train_loader)}, Number of images: {len(train_set)}")
    logger = loss_logger(len(train_loader) * EPOCHS)

    # The EMA starts as an independent copy of the initial (or loaded) weights.
    ema_dict = copy.deepcopy(model.state_dict())

    if SAVE_MODEL:
        # torch.save does not create missing directories.
        os.makedirs(SAVE_MODEL_PATH, exist_ok=True)

    last_epoch = EPOCHS - 1
    for epoch in range(EPOCHS):
        ema_dict = train_one_epoch(model, train_loader, diffusion_model, optimizer, logger, None, ema_dict, device)
        # Report/checkpoint every 10th epoch, and always after the final one so
        # the fully trained weights are not lost.
        if epoch % 10 == 0 or epoch == last_epoch:
            print(f"Epoch {epoch} completed")
            if SAVE_MODEL:
                torch.save(ema_dict, os.path.join(SAVE_MODEL_PATH, f"ema_dict_{epoch}.pt"))
                torch.save(model.state_dict(), os.path.join(SAVE_MODEL_PATH, f"model_{epoch}.pt"))

    return 0

# Entry point: start training when this file is executed directly (or the
# cell is run in a notebook kernel), but not when it is imported.
if __name__ == "__main__":
    main()