In [2]:
import os
# Restrict CUDA to a single device. This runs before the next cell imports torch;
# with one visible device, the "cuda:0" used later refers to that device.
# NOTE(review): the physical GPU index "1" is machine-specific — confirm for your host.
os.environ["CUDA_VISIBLE_DEVICES"] = "1" 

In [3]:
import torch
import sys; sys.path.insert(0, '..')


In [4]:
import torch
from torch.utils.data import TensorDataset, DataLoader

In [5]:
import argparse

In [6]:
import os
import numpy as np
import torch
import warnings
import argparse
import sys

In [8]:
from main import *

In [9]:
def parse_args(argv=None):
    """Build the experiment configuration namespace.

    Args:
        argv: Optional list of command-line style tokens to parse, e.g.
            ``["--epochs", "3"]``. Defaults to an empty list, so that inside
            a notebook the kernel's own ``sys.argv`` is never parsed and every
            option takes its default value.

    Returns:
        argparse.Namespace with one attribute per option declared below.

    Side effects:
        Silences all Python warnings via ``warnings.filterwarnings("ignore")``.
    """
    parser = argparse.ArgumentParser(description='.')
    parser.add_argument("--arch", type=str, help="Neural network architecture")
    parser.add_argument(
        "--class-cond",
        action="store_true",
        default=False,
        help="train class-conditioned diffusion model",
    )
    parser.add_argument(
        "--diffusion-steps",
        type=int,
        default=1000,
        help="Number of timesteps in diffusion process",
    )
    parser.add_argument(
        "--sampling-steps",
        type=int,
        default=250,
        help="Number of timesteps used when sampling from the model",
    )
    parser.add_argument(
        "--ddim",
        action="store_true",
        default=False,
        help="Sampling using DDIM update step",
    )
    # dataset
    parser.add_argument("--dataset", type=str)
    parser.add_argument("--data-dir", type=str, default="./dataset/")
    # optimizer
    parser.add_argument(
        "--batch-size", type=int, default=128, help="batch-size per gpu"
    )
    parser.add_argument("--lr", type=float, default=0.0001)
    parser.add_argument("--epochs", type=int, default=500)
    parser.add_argument("--ema_w", type=float, default=0.9995)
    # sampling/finetuning
    parser.add_argument("--pretrained-ckpt", type=str, help="Pretrained model ckpt")
    parser.add_argument(
        "--delete-keys",
        nargs="+",
        help="Checkpoint keys to drop before loading (e.g. shape-mismatched layers)",
    )
    parser.add_argument(
        "--sampling-only",
        action="store_true",
        default=False,
        help="No training, just sample images (will save them in --save-dir)",
    )
    parser.add_argument(
        "--num-sampled-images",
        type=int,
        default=50000,
        help="Number of images required to sample from the model",
    )

    # misc
    parser.add_argument("--save-dir", type=str, default="./trained_models/")
    parser.add_argument("--local_rank", default=0, type=int)
#     parser.add_argument("--seed", default=112233, type=int)

    # An explicit (possibly empty) list keeps argparse away from sys.argv,
    # which in a notebook holds the kernel launcher's own arguments.
    args = parser.parse_args(args=[] if argv is None else list(argv))
    warnings.filterwarnings("ignore")

    return args

In [10]:
# Build the default configuration namespace; the following cells override individual fields.
args = parse_args()

In [11]:
# Experiment-specific overrides applied on top of the parse_args() defaults.
args.arch = "UNet"
args.dataset = "celeba"
# NOTE(review): `epoch` (singular) is not an option declared by parse_args(); the
# declared option is `epochs`, assigned a few lines below. This looks like a
# leftover — confirm nothing in `main` reads args.epoch before removing it.
args.epoch = 30  
args.save_dir = './models/'
args.data_dir = '../data/'
args.ddim = True 
args.epochs = 100
# args.batch_size = 64

In [12]:
# Enable class conditioning: the UNet is only given num_classes when this flag is set.
args.class_cond = True

In [13]:
# Display the effective configuration after the overrides above.
args

Namespace(arch='UNet', batch_size=128, class_cond=True, data_dir='../data/', dataset='celeba', ddim=True, delete_keys=None, diffusion_steps=1000, ema_w=0.9995, epoch=30, epochs=100, local_rank=0, lr=0.0001, num_sampled_images=50000, pretrained_ckpt=None, sampling_only=False, sampling_steps=250, save_dir='./models/')

In [14]:
# Single-process rendezvous settings; these are the variables read by the
# init_method="env://" call to torch.distributed.init_process_group further down.
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29522"  # change this number to run several instances at once

In [15]:
def cleanup():
    """Tear down the default distributed process group if one is active.

    Does nothing when distributed support is unavailable or no process group
    has been initialised, so the function can be called from any cell (and
    more than once) without raising.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
    
# setup(int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]))

In [16]:
# Look the metadata up through args.dataset so it cannot drift from the
# dataset configured above.
metadata = get_metadata(args.dataset)
# Override the class count: this notebook works with two classes.
metadata.num_classes = 2

In [17]:
# Display the dataset metadata, including the num_classes override.
metadata

{'image_size': 64,
 'num_classes': 2,
 'train_images': 109036,
 'val_images': 12376,
 'num_channels': 3}

In [18]:
# Device setup: let cuDNN auto-tune its algorithms and bind this process to
# the GPU selected by args.local_rank.
torch.backends.cudnn.benchmark = True
args.device = "cuda:{}".format(args.local_rank)
torch.cuda.set_device(args.device)
# torch.manual_seed(args.seed + args.local_rank)
# np.random.seed(args.seed + args.local_rank)
if args.local_rank == 0:
    print(args)

# Create model and diffusion process
# NOTE(review): this freshly constructed `model` is replaced further down by
# `model = load_model(model_name)`.
model = unets.__dict__[args.arch](
    image_size=metadata.image_size,
    in_channels=metadata.num_channels,
    out_channels=metadata.num_channels,
    num_classes=metadata.num_classes if args.class_cond else None,
).to(args.device)
if args.local_rank == 0:
    print(
        "We are assuming that model input/ouput pixel range is [-1, 1]. Please adhere to it."
    )
diffusion = GuassianDiffusion(args.diffusion_steps, args.device)
# NOTE(review): the optimizer is not referenced by any cell visible in this notebook.
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)



Namespace(arch='UNet', batch_size=128, class_cond=True, data_dir='../data/', dataset='celeba', ddim=True, delete_keys=None, device='cuda:0', diffusion_steps=1000, ema_w=0.9995, epoch=30, epochs=100, local_rank=0, lr=0.0001, num_sampled_images=50000, pretrained_ckpt=None, sampling_only=False, sampling_steps=250, save_dir='./models/')
We are assuming that model input/ouput pixel range is [-1, 1]. Please adhere to it.


In [19]:
# Checkpoint file name, resolved by load_model relative to its model_location
# argument (default './models/'). Placeholder — replace with a real checkpoint name.
model_name = 'ENTER_MODEL_NAME.pt'

In [20]:
def load_model(model_name, model_location = './models/'):
    """Instantiate the configured UNet and load pretrained weights into it.

    Reads the notebook globals ``args`` (arch, class_cond, device,
    delete_keys) and ``metadata`` (image_size, num_channels, num_classes).

    Args:
        model_name: File name of the checkpoint.
        model_location: Directory that contains the checkpoint.

    Returns:
        The model, placed on ``args.device``, with the checkpoint applied
        non-strictly (keys missing on either side are reported, not fatal).
    """
    # Use args.device for both placement and map_location instead of relying
    # on a separate module-level `device` that is defined in a later cell.
    model = unets.__dict__[args.arch](
        image_size=metadata.image_size,
        in_channels=metadata.num_channels,
        out_channels=metadata.num_channels,
        num_classes=metadata.num_classes if args.class_cond else None,
    ).to(args.device)
    # os.path.join also works when model_location has no trailing separator.
    pretrained_ckpt = os.path.join(model_location, model_name)
    print(f"Loading pretrained model from {pretrained_ckpt}")
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    ckpt_state = fix_legacy_dict(torch.load(pretrained_ckpt, map_location=args.device))
    model_state = model.state_dict()
    # args.delete_keys is None unless explicitly configured.
    for k in args.delete_keys or []:
        print(
            f"Deleting key {k} becuase its shape in ckpt ({ckpt_state[k].shape}) doesn't match "
            + f"with shape in model ({model_state[k].shape})"
        )
        del ckpt_state[k]
    model.load_state_dict(ckpt_state, strict=False)
    # Symmetric difference: keys present in exactly one of checkpoint / model.
    print(
        "Mismatched keys in ckpt and model: ",
        set(ckpt_state.keys()) ^ set(model_state.keys()),
    )
    print(f"Loaded pretrained model from {pretrained_ckpt}")
    return model

In [21]:
# Module-level alias of args.device; load_model reads it via `.to(device)` and
# the GuassianDiffusion constructor call below receives it as well.
device = args.device

In [None]:
# Replace the freshly constructed model from the setup cell with the pretrained checkpoint.
model = load_model(model_name)

In [23]:
from main import *

In [24]:
# Re-create the diffusion process. The arguments match the earlier setup cell,
# since `device` was copied from args.device.
diffusion = GuassianDiffusion(args.diffusion_steps, device)
# optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

In [25]:
# Number of CUDA devices visible to this process.
# NOTE(review): `ngpus` is not referenced by any cell visible in this notebook.
ngpus = torch.cuda.device_count()

In [26]:
# Initialise the default process group from the RANK / WORLD_SIZE /
# MASTER_ADDR / MASTER_PORT environment variables set earlier. The guard makes
# the cell idempotent: initialising an already-initialised group raises.
if not torch.distributed.is_initialized():
    torch.distributed.init_process_group(backend="nccl", init_method="env://")
# model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)

In [27]:
# No-op as written: `device` was itself copied from args.device in an earlier cell.
args.device = device

In [28]:
# Flags forwarded by sample_N_images_with_attribute_switching: ddim_from goes to
# diffusion.sample_from_reverse_process, ddim_to to both
# diffusion.sample_remain_from_reverse_process calls.
# NOTE(review): presumably they select the DDIM update rule for each phase — confirm in `main`.
args.ddim_from = True
args.ddim_to = True

In [29]:
def sample_N_images_with_attribute_switching(
    N,
    model,
    diffusion,
    xT=None,
    sampling_steps=250,
    batch_size=64,
    num_channels=3,
    image_size=32,
    num_classes=None,
    args=None,
    time = 750,
    image_from = None
):
    """Sample N image pairs that share a partial trajectory but differ in class label.

    Each batch is first passed through ``diffusion.sample_from_reverse_process``
    with label ``y``. The returned intermediate is then completed twice with
    ``diffusion.sample_remain_from_reverse_process``: once with the same label
    ("orig") and once with the switched label ("other"). Results are gathered
    across all processes of the default process group, which must already be
    initialised.

    Args:
        N : Number of image pairs to return.
        model : Diffusion model.
        diffusion : Diffusion process providing the two sampling methods above.
        xT : Starting noise batch. If None, fresh Gaussian noise of shape
            (batch_size, num_channels, image_size, image_size) is drawn for
            every batch. If given, it is reused for every batch and its
            length defines the per-process batch size.
        sampling_steps : Number of sampling steps.
        batch_size : Per-process batch size used when xT is None.
        num_channels : Number of channels in the image.
        image_size : Image size (assuming square images).
        num_classes : Number of classes (needed when labels are drawn at random).
        args : Namespace providing device, class_cond, ddim_from and ddim_to. Required.
        time : Forwarded as ``time=`` to both diffusion calls.
            NOTE(review): presumably the timestep at which the label is
            switched — confirm against the diffusion implementation.
        image_from : If not None, the label assigned to every sample; otherwise
            labels are drawn uniformly from [0, num_classes).

    Returns:
        Tuple ``(samples_orig, samples_other, labels)``. The two sample arrays
        are uint8 with the channel axis moved last (assuming the diffusion
        process returns NCHW batches) and hold N images each. ``labels`` holds
        the N original labels, or is None when ``args.class_cond`` is false.
    """
    samples_orig, samples_other, labels, num_samples = [], [], [], 0
    num_processes, group = dist.get_world_size(), dist.group.WORLD
    # Size the progress bar from the batch that is actually sampled, not from
    # args.batch_size, which may differ from the batch_size argument.
    per_step = (batch_size if xT is None else len(xT)) * num_processes
    with tqdm(total=math.ceil(N / per_step)) as pbar:
        while num_samples < N:
            # Draw new noise for each batch; assigning to xT itself would make
            # every later batch reuse the first batch's noise.
            if xT is None:
                noise = (
                    torch.randn(batch_size, num_channels, image_size, image_size)
                    .float()
                    .to(args.device)
                )
            else:
                noise = xT

            if args.class_cond:
                if image_from is None:
                    y = torch.randint(num_classes, (len(noise),), dtype=torch.int64).to(
                        args.device
                    )
                else:
                    y = torch.ones((len(noise),), dtype=torch.int64).to(
                        args.device
                    ) * image_from
                # Switched label: 1 where the original label is 0, else 0.
                y_new = (y == 0).to(torch.int64)
            else:
                # Without class conditioning there is no label to switch.
                y = None
                y_new = None

            gen_images_t = diffusion.sample_from_reverse_process(
                model, noise, sampling_steps, {"y": y}, args.ddim_from, time = time
            )
            gen_images_orig = diffusion.sample_remain_from_reverse_process(
                model, gen_images_t, sampling_steps, {"y": y}, args.ddim_to,
                time = time, time_prev = None
            )
            gen_images_other = diffusion.sample_remain_from_reverse_process(
                model, gen_images_t, sampling_steps, {"y": y_new}, args.ddim_to,
                time = time, time_prev = None
            )

            samples_list_orig = [torch.zeros_like(gen_images_orig) for _ in range(num_processes)]
            samples_list_other = [torch.zeros_like(gen_images_other) for _ in range(num_processes)]

            if args.class_cond:
                labels_list = [torch.zeros_like(y) for _ in range(num_processes)]
                dist.all_gather(labels_list, y, group)
                labels.append(torch.cat(labels_list).detach().cpu().numpy())

            dist.all_gather(samples_list_orig, gen_images_orig, group)
            dist.all_gather(samples_list_other, gen_images_other, group)

            samples_orig.append(torch.cat(samples_list_orig).detach().cpu().numpy())
            samples_other.append(torch.cat(samples_list_other).detach().cpu().numpy())

            num_samples += len(noise) * num_processes
            pbar.update(1)

    def _to_uint8_channels_last(batches):
        # Concatenate batches, move channels last, keep the first N images and
        # map [-1, 1] to [0, 255]. Clipping saturates out-of-range values
        # instead of letting the uint8 cast wrap them around.
        images = np.concatenate(batches).transpose(0, 2, 3, 1)[:N]
        return np.clip(127.5 * (images + 1), 0, 255).astype(np.uint8)

    samples_orig = _to_uint8_channels_last(samples_orig)
    samples_other = _to_uint8_channels_last(samples_other)
    # Truncate labels to N as well so they stay aligned with the images.
    return (
        samples_orig,
        samples_other,
        np.concatenate(labels)[:N] if args.class_cond else None,
    )



In [30]:
import PIL.Image

In [None]:
# Reload the pretrained model and the dataset metadata right before sampling.
# NOTE(review): this repeats earlier cells; it matters only if `model` or
# `metadata` were modified in between.
model = load_model(model_name)
# Look the metadata up through args.dataset so it cannot drift from the configured dataset.
metadata = get_metadata(args.dataset)
metadata.num_classes = 2

In [33]:
seed = 0
gridw=6
gridh=5
# One sampled image per grid cell: deriving the batch size from the grid keeps
# the reshape below valid if the grid dimensions are changed.
args.batch_size = gridw * gridh
name = 'tmp'
# NOTE(review): output file names do not include `time`, so adding more values
# to this list would overwrite the files written for earlier values.
for time in [640]:
    for image_from in [0, 1]:
        # Re-seed so both starting labels are sampled from the same noise.
        torch.manual_seed(seed)
        samples_orig, samples_other, labels = sample_N_images_with_attribute_switching(
            N=args.batch_size,
            model=model,
            diffusion=diffusion,
            xT=None,
            sampling_steps=args.sampling_steps,
            batch_size=args.batch_size,
            num_channels=metadata.num_channels,
            image_size=metadata.image_size,
            num_classes=metadata.num_classes,
            args=args,
            time=time,
            image_from=image_from,
        )
        image_list = [samples_orig, samples_other]
        name_list = [f'_orig_{image_from}', f'_switch_{(image_from + 1) % 2}']
        for suffix, images in zip(name_list, image_list):
            dest_path = name + suffix + '.pdf'
            # Tile the (gridh * gridw, H, W, C) batch into one
            # (gridh * H, gridw * W, C) image: split the batch axis into grid
            # rows and columns, then interleave rows with H and columns with W.
            grid = images.reshape(gridh, gridw, *images.shape[1:])
            grid = grid.transpose(0, 2, 1, 3, 4)
            grid = grid.reshape(
                gridh * metadata.image_size,
                gridw * metadata.image_size,
                metadata.num_channels,
            )
            PIL.Image.fromarray(grid, 'RGB').save(dest_path)
        print('Done.')

100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:31<00:00, 31.42s/it]


Done.


100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:28<00:00, 28.70s/it]

Done.



