In [1]:
import torch
import sys; sys.path.insert(0, '..')

## Preparing data

In [2]:
#Enter data path: please download celebA dataset or FairFace data set and make pt file with corresponding attribute.
sampled_train = torch.load()
sampled_label = torch.load()

In [3]:
import torch
from torch.utils.data import TensorDataset, DataLoader

In [4]:
sampled_dataset  = TensorDataset(sampled_train, sampled_label)

## Training Diffusion

In [5]:
import argparse

In [6]:
import os
import numpy as np
import warnings
import argparse
import sys

# warnings.filterwarnings("ignore")
# if os.getcwd()[-3:]=='run':
#     os.chdir('..')

# os.environ["CUDA_VISIBLE_DEVICES"] = "1" # Possible GPUS

In [7]:
from main import *

In [8]:
def parse_args():
    parser = argparse.ArgumentParser(description='.')
    parser.add_argument("--arch", type=str, help="Neural network architecture")
    parser.add_argument(
        "--class-cond",
        action="store_true",
        default=False,
        help="train class-conditioned diffusion model",
    )
    parser.add_argument(
        "--diffusion-steps",
        type=int,
        default=1000,
        help="Number of timesteps in diffusion process",
    )
    parser.add_argument(
        "--sampling-steps",
        type=int,
        default=250,
        help="Number of timesteps in diffusion process",
    )
    parser.add_argument(
        "--ddim",
        action="store_true",
        default=False,
        help="Sampling using DDIM update step",
    )
    # dataset
    parser.add_argument("--dataset", type=str)
    parser.add_argument("--data-dir", type=str, default="./dataset/")
    # optimizer
    parser.add_argument(
        "--batch-size", type=int, default=128, help="batch-size per gpu"
    )
    parser.add_argument("--lr", type=float, default=0.0001)
    parser.add_argument("--epochs", type=int, default=500)
    parser.add_argument("--ema_w", type=float, default=0.9995)
    # sampling/finetuning
    parser.add_argument("--pretrained-ckpt", type=str, help="Pretrained model ckpt")
    parser.add_argument("--delete-keys", nargs="+", help="Pretrained model ckpt")
    parser.add_argument(
        "--sampling-only",
        action="store_true",
        default=False,
        help="No training, just sample images (will save them in --save-dir)",
    )
    parser.add_argument(
        "--num-sampled-images",
        type=int,
        default=50000,
        help="Number of images required to sample from the model",
    )

    # misc
    parser.add_argument("--save-dir", type=str, default="./trained_models/")
    parser.add_argument("--local_rank", default=0, type=int)
    parser.add_argument("--seed", default=112233, type=int)

    args = parser.parse_args(args=[])
    warnings.filterwarnings("ignore")

    return args

In [9]:
args = parse_args()

In [10]:
args.arch = "UNet"
args.dataset = "celeba"
args.epoch = 30  
args.save_dir = './models/'
args.data_dir = '../data/'
args.ddim = True 
args.epochs = 100

In [11]:
args

Namespace(arch='UNet', batch_size=128, class_cond=False, data_dir='../data/', dataset='celeba', ddim=True, delete_keys=None, diffusion_steps=1000, ema_w=0.9995, epoch=30, epochs=100, local_rank=0, lr=0.0001, num_sampled_images=50000, pretrained_ckpt=None, sampling_only=False, sampling_steps=250, save_dir='./models/', seed=112233)

In [12]:
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29506"

In [13]:
def cleanup():
    dist.destroy_process_group()
    
# setup(int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]))

In [14]:
metadata = get_metadata(args.dataset)
metadata.num_classes = 2
# metadata.train_images = len(sampled_label)
# metadata.val_images = 3641#len(test_sampled_label)

In [15]:
metadata

{'image_size': 64,
 'num_classes': 2,
 'train_images': 109036,
 'val_images': 12376,
 'num_channels': 3}

In [16]:
# metadata.image_size = 64
data_type = 'gender'
args.batch_size = 128
args.lr=0.00005

args.name ='bs%s_lr%s_%s'%(args.batch_size , args.lr, data_type)
args.class_cond = True

In [17]:
args.name

'bs128_lr5e-05_gender'

In [18]:
torch.backends.cudnn.benchmark = True
args.device = "cuda:{}".format(args.local_rank)
torch.cuda.set_device(args.device)
torch.manual_seed(args.seed + args.local_rank)
np.random.seed(args.seed + args.local_rank)
if args.local_rank == 0:
    print(args)

# Creat model and diffusion process
model = unets.__dict__[args.arch](
    image_size=metadata.image_size,
    in_channels=metadata.num_channels,
    out_channels=metadata.num_channels,
    num_classes=metadata.num_classes if args.class_cond else None,
).to(args.device)
if args.local_rank == 0:
    print(
        "We are assuming that model input/ouput pixel range is [-1, 1]. Please adhere to it."
    )
diffusion = GuassianDiffusion(args.diffusion_steps, args.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)



Namespace(arch='UNet', batch_size=128, class_cond=True, data_dir='../data/', dataset='celeba', ddim=True, delete_keys=None, device='cuda:0', diffusion_steps=1000, ema_w=0.9995, epoch=30, epochs=100, local_rank=0, lr=5e-05, name='bs128_lr5e-05_gender', num_sampled_images=50000, pretrained_ckpt=None, sampling_only=False, sampling_steps=250, save_dir='./models/', seed=112233)
We are assuming that model input/ouput pixel range is [-1, 1]. Please adhere to it.


In [20]:
# load pre-trained model
if args.pretrained_ckpt:
    print(f"Loading pretrained model from {args.pretrained_ckpt}")
    d = fix_legacy_dict(torch.load(args.pretrained_ckpt, map_location=args.device))
    dm = model.state_dict()
    if args.delete_keys:
        for k in args.delete_keys:
            print(
                f"Deleting key {k} becuase its shape in ckpt ({d[k].shape}) doesn't match "
                + f"with shape in model ({dm[k].shape})"
            )
            del d[k]
    model.load_state_dict(d, strict=False)
    print(
        f"Mismatched keys in ckpt and model: ",
        set(d.keys()) ^ set(dm.keys()),
    )
    print(f"Loaded pretrained model from {args.pretrained_ckpt}")
# distributed training
ngpus = torch.cuda.device_count()
print("ngpus",ngpus)
if ngpus > 1:
    if args.local_rank == 0:
        print(f"Using distributed training on {ngpus} gpus.")
    args.batch_size = args.batch_size // ngpus
    
    
torch.distributed.init_process_group(backend="nccl", init_method="env://") 
model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)

# sampling
if args.sampling_only:
    sampled_images, labels = sample_N_images(
        args.num_sampled_images,
        model,
        diffusion,
        None,
        args.sampling_steps,
        args.batch_size,
        metadata.num_channels,
        metadata.image_size,
        metadata.num_classes,
        args,
    )
    np.savez(
        os.path.join(
            args.save_dir,
            f"{args.arch}_{args.dataset}-{args.sampling_steps}-sampling_steps-{len(sampled_images)}_images-class_condn_{args.class_cond}_{args.name}.npz",
        ),
        sampled_images,
        labels,
    )
    sys.exit()



ngpus 4
Using distributed training on 4 gpus.


In [21]:
# ngpus = torch.cuda.device_count()

# Load dataset
# train_set = get_dataset(args.dataset, args.data_dir, metadata)
train_set = sampled_dataset
sampler = DistributedSampler(train_set) if ngpus > 1 else None
train_loader = DataLoader(
    train_set,
    batch_size=args.batch_size,
    shuffle=sampler is None,#####
    sampler=sampler,
    num_workers=4,
    pin_memory=True,
)
if args.local_rank == 0:
    print(
        f"Training dataset loaded: Number of batches: {len(train_loader)}, Number of images: {len(train_set)}"
    )
logger = loss_logger(len(train_loader) * args.epochs)

# ema model
args.ema_dict = copy.deepcopy(model.state_dict())
    

Training dataset loaded: Number of batches: 899, Number of images: 28759


In [22]:
def sample_N_images(
    N,
    model,
    diffusion,
    xT=None,
    sampling_steps=250,
    batch_size=64,
    num_channels=3,
    image_size=32,
    num_classes=None,
    args=None,
):
    """use this function to sample any number of images from a given
        diffusion model and diffusion process.

    Args:
        N : Number of images
        model : Diffusion model
        diffusion : Diffusion process
        xT : Starting instantiation of noise vector.
        sampling_steps : Number of sampling steps.
        batch_size : Batch-size for sampling.
        num_channels : Number of channels in the image.
        image_size : Image size (assuming square images).
        num_classes : Number of classes in the dataset (needed for class-conditioned models)
        args : All args from the argparser.

    Returns: Numpy array with N images and corresponding labels.
    """
    samples, labels, num_samples = [], [], 0
    num_processes, group = dist.get_world_size(), dist.group.WORLD
    with tqdm(total=math.ceil(N / (args.batch_size * num_processes))) as pbar:
        while num_samples < N:
            if xT is None:
                xT = (
                    torch.randn(batch_size, num_channels, image_size, image_size)
                    .float()
                    .to(args.device)
                )
            if args.class_cond:
                y = torch.randint(num_classes, (len(xT),), dtype=torch.int64).to(
                    args.device
                )
            else:
                y = None
            gen_images = diffusion.sample_from_reverse_process(
                model, xT, sampling_steps, {"y": y}, args.ddim
            )
            samples_list = [torch.zeros_like(gen_images) for _ in range(num_processes)]
            if args.class_cond:
                labels_list = [torch.zeros_like(y) for _ in range(num_processes)]
                dist.all_gather(labels_list, y, group)
                labels.append(torch.cat(labels_list).detach().cpu().numpy())

            dist.all_gather(samples_list, gen_images, group)
            samples.append(torch.cat(samples_list).detach().cpu().numpy())
            num_samples += len(xT) * num_processes
            pbar.update(1)
    samples = np.concatenate(samples).transpose(0, 2, 3, 1)[:N]
    samples = (127.5 * (samples + 1)).astype(np.uint8)
    return (samples, np.concatenate(labels) if args.class_cond else None)



In [23]:
ngpus = torch.cuda.device_count()

In [None]:
for epoch in range(args.epochs):
    if sampler is not None:
        sampler.set_epoch(epoch)
    #####
    train_one_epoch(model, train_loader, diffusion, optimizer, logger, None, args)
    
    if not epoch % 1:
        sampled_images, sampled_label = sample_N_images(
            64,
            model,
            diffusion,
            None,
            args.sampling_steps,
            args.batch_size,
            metadata.num_channels,
            metadata.image_size,
            metadata.num_classes,
            args,
        )
        if args.local_rank == 0:
            cv2.imwrite(
                os.path.join(
                    args.save_dir,
                    f"{args.arch}_{args.dataset}-{args.diffusion_steps}_steps-{args.sampling_steps}-sampling_steps-class_condn_{args.class_cond}_{args.name}.png",
                ),
                np.concatenate(sampled_images, axis=1)[:, :, ::-1],
            )
            print(sampled_label)
    if args.local_rank == 0:
        torch.save(
            model.state_dict(),
            os.path.join(
                args.save_dir,
                f"{args.arch}_{args.dataset}-epoch_{args.epochs}-timesteps_{args.diffusion_steps}-class_condn_{args.class_cond}_{args.name}.pt",
            ),
        )
        torch.save(
            args.ema_dict,
            os.path.join(
                args.save_dir,
                f"{args.arch}_{args.dataset}-epoch_{args.epochs}-timesteps_{args.diffusion_steps}-class_condn_{args.class_cond}_ema_{args.ema_w}_{args.name}.pt",
            ),
        )

        
# cleanup()


Steps: 1/89900 	 loss (ema): 0.999 	 Time elapsed: 0.001 hr
