In [5]:
from google.colab import drive

# Everything written under this mount point lives in Google Drive, so files
# changed by the notebook are kept in your Drive rather than on the Colab VM.
DRIVE_MOUNT_POINT = '/content/drive/'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [6]:
# Work from the project folder on the mounted Drive so the relative paths used
# below (requirements.txt, data/, models/, samples_*/) resolve inside it.
# NOTE(review): hard-coded Drive location; adjust if the project lives elsewhere.
%cd /content/drive/MyDrive/cpen455-project/

/content/drive/MyDrive/cpen455-project


In [None]:
!pip install -r requirements.txt

Collecting bidict (from -r requirements.txt (line 1))
  Downloading bidict-0.23.1-py3-none-any.whl.metadata (8.7 kB)
Collecting pytorch-fid (from -r requirements.txt (line 7))
  Downloading pytorch_fid-0.3.0-py3-none-any.whl.metadata (5.3 kB)
Collecting dotenv (from -r requirements.txt (line 11))
  Downloading dotenv-0.9.9-py2.py3-none-any.whl.metadata (279 bytes)
Collecting python-dotenv (from dotenv->-r requirements.txt (line 11))
  Downloading python_dotenv-1.1.0-py3-none-any.whl.metadata (24 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.0.1->pytorch-fid->-r requirements.txt (line 7))
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.0.1->pytorch-fid->-r requirements.txt (line 7))
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.0.1->pytorch-fid->-r r

In [7]:
import argparse
import logging
import os
import random
import time
from pprint import pprint

import dotenv
import numpy as np
import torch
import torch.optim as optim
import wandb
from pytorch_fid.fid_score import calculate_fid_given_paths
from torch.optim import lr_scheduler
from torchvision import datasets, transforms
from tqdm import tqdm

dotenv.load_dotenv()

True

In [4]:
import importlib
import utils, model, dataset, generation_evaluation

# Re-execute the project modules in place (in the order listed) so edits to the
# .py files take effect without restarting the kernel.
for _project_module in (utils, model, dataset, generation_evaluation):
    importlib.reload(_project_module)
del _project_module

# NOTE(review): star imports hide where names such as PixelCNN, rescaling or
# CPEN455Dataset come from; kept because the rest of the notebook relies on them.
from utils import *
from model import *
from dataset import *
import generation_evaluation

In [None]:
import logging

# Root-logger setup: each record is prefixed with its timestamp, and by default
# only WARNING-level (or more severe) messages are emitted.
logging.basicConfig(level=logging.WARNING, format='%(asctime)s %(message)s')

In [None]:
from dataclasses import dataclass
from typing import Optional

@dataclass
class Config:
    """Hyper-parameters and paths for the PixelCNN training notebook.

    Field order is significant: it defines the positional order of ``Config(...)``.
    """
    en_wandb: bool = False               # log metrics and samples to Weights & Biases
    tag: str = "default"                 # appended to the wandb run name
    sampling_interval: int = 5           # sample images and compute FID every N epochs
    data_dir: str = "data"               # dataset root; FID reference images are read from <data_dir>/test
    save_dir: str = "models"             # checkpoint directory
    sample_dir: str = "samples"          # where generated samples are written
    dataset: str = "cpen455"             # selects the loaders: "mnist*", "cifar10", "cifar100" or "cpen455"
    save_interval: int = 10              # save a checkpoint every N epochs
    load_params: Optional[str] = None    # path of a state_dict to resume from; None trains from scratch
    obs: tuple = (3, 32, 32)             # per-image shape; obs[0] is the number of input channels
    nr_resnet: int = 1
    nr_filters: int = 40
    nr_logistic_mix: int = 5             # mixture components of the discretized logistic output
    lr: float = 0.0002
    lr_decay: float = 0.999995           # StepLR gamma, applied once per epoch
    batch_size: int = 64
    sample_batch_size: int = 32
    max_epochs: int = 5000
    seed: int = 1

In [None]:
args = Config(
    batch_size=16,
    sample_batch_size=16,
    sampling_interval=5,
    save_interval=10,
    max_epochs=500,
    en_wandb=True,
    tag="Middle_1",
    sample_dir="samples_middle"
)

pprint(args.__dict__)
check_dir_and_create(args.save_dir)

# reproducibility: seed every RNG the pipeline may draw from
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Checkpoint naming records the dataset and whether training resumed from saved parameters.
model_name = 'pcnn_' + args.dataset + "_" + ('load_model' if args.load_params is not None else 'from_scratch')
model_path = args.save_dir + '/' + model_name + '/'

job_name = "PCNN_Training_" + "dataset:" + args.dataset + "_" + args.tag

if args.en_wandb:
    # start a new wandb run to track this script
    wandb.init(
        # set entity to specify your username or team name
        # entity="qihangz-work",
        # set the wandb project where this run will be logged
        project="CPEN455HW",
        # group=Group Name
        name=job_name,
    )
    wandb.config.current_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
    wandb.config.update(args)

# set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# DataLoader options shared by every split.
# - num_workers=2 loads batches in background worker processes, which works on
#   Linux/Colab. If you hit pickling errors with the dataset on another machine,
#   set it to 0.
# - pin_memory only matters when batches are copied to a CUDA device, so it is
#   enabled only there.
# - drop_last=True keeps every batch at exactly args.batch_size images.
kwargs = {'num_workers': 2, 'pin_memory': device.type == 'cuda', 'drop_last': True}

{'batch_size': 16,
 'data_dir': 'data',
 'dataset': 'cpen455',
 'en_wandb': True,
 'load_params': None,
 'lr': 0.0002,
 'lr_decay': 0.999995,
 'max_epochs': 500,
 'nr_filters': 40,
 'nr_logistic_mix': 5,
 'nr_resnet': 1,
 'obs': (3, 32, 32),
 'sample_batch_size': 16,
 'sample_dir': 'samples_middle',
 'sampling_interval': 5,
 'save_dir': 'models',
 'save_interval': 10,
 'seed': 1,
 'tag': 'Middle_1'}


In [None]:
# set data
def _make_loader(dataset):
    """Wrap ``dataset`` in a shuffling DataLoader using ``args.batch_size`` and the shared ``kwargs``."""
    return torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, **kwargs)


# Only the cpen455 branch builds a validation split; other datasets leave this as None.
val_loader = None

if "mnist" in args.dataset:
    ds_transforms = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor(), rescaling, replicate_color_channel])
    train_loader = _make_loader(datasets.MNIST(args.data_dir, download=True, train=True, transform=ds_transforms))
    test_loader = _make_loader(datasets.MNIST(args.data_dir, train=False, transform=ds_transforms))

elif "cifar" in args.dataset:
    cifar_datasets = {"cifar10": datasets.CIFAR10, "cifar100": datasets.CIFAR100}
    if args.dataset not in cifar_datasets:
        # Doubled braces are literal braces in str.format (single braces would raise KeyError).
        raise Exception('{} dataset not in {{cifar10, cifar100}}'.format(args.dataset))
    ds_transforms = transforms.Compose([transforms.ToTensor(), rescaling])
    cifar_cls = cifar_datasets[args.dataset]
    train_loader = _make_loader(cifar_cls(args.data_dir, train=True, download=True, transform=ds_transforms))
    test_loader = _make_loader(cifar_cls(args.data_dir, train=False, transform=ds_transforms))

elif "cpen455" in args.dataset:
    ds_transforms = transforms.Compose([transforms.Resize((32, 32)), rescaling])
    train_loader = _make_loader(CPEN455Dataset(root_dir=args.data_dir, mode='train', transform=ds_transforms))
    test_loader = _make_loader(CPEN455Dataset(root_dir=args.data_dir, mode='test', transform=ds_transforms))
    val_loader = _make_loader(CPEN455Dataset(root_dir=args.data_dir, mode='validation', transform=ds_transforms))

else:
    raise Exception('{} dataset not in {{mnist, cifar, cpen455}}'.format(args.dataset))

In [None]:
def train_or_test(model, data_loader, optimizer, loss_op, device, args, epoch, mode = 'training'):
    """Run one pass of ``model`` over ``data_loader`` and track the average bits-per-dimension.

    In ``'training'`` mode the model is put in train mode and every batch takes an
    optimizer step. Any other mode (the notebook uses ``'test'`` and ``'val'``)
    puts the model in eval mode and disables autograd, so no graph is built.

    Args:
        model: network called as ``model(images, labels)``.
        data_loader: iterable of ``(images, labels)`` batches.
        optimizer: optimizer stepped in training mode (unused otherwise).
        loss_op: ``loss_op(real, model_output)`` returning a scalar loss tensor. The
            BPD normalisation below assumes it is the NLL in nats summed over the
            batch -- TODO confirm against discretized_mix_logistic_loss.
        device: device the images are moved to (labels are passed through as-is).
        args: config providing ``obs`` (per-image shape) and ``en_wandb``.
        epoch: epoch index, logged alongside the metric.
        mode: ``'training'`` or an evaluation label; also the wandb key prefix.
    """
    logging.debug('mode: {}'.format(mode))
    is_training = mode == 'training'
    if is_training:
        model.train()
    else:
        model.eval()

    dims_per_image = np.prod(args.obs)
    loss_tracker = mean_tracker()

    # Only build the autograd graph when we are going to back-propagate through it.
    with torch.set_grad_enabled(is_training):
        for batch_idx, (model_input, label) in enumerate(tqdm(data_loader)):
            logging.debug('batch_idx: {}'.format(batch_idx))
            model_input = model_input.to(device)
            model_output = model(model_input, label)

            loss = loss_op(model_input, model_output)
            # Normalise by the actual batch size so a short final batch is not under-reported.
            deno = model_input.size(0) * dims_per_image * np.log(2.)
            loss_tracker.update(loss.item() / deno)
            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    if args.en_wandb:
        # A single log call keeps the metric and its epoch on the same wandb step;
        # wandb.init() has already authenticated, so no login is needed here.
        wandb.log({mode + "-Average-BPD": loss_tracker.get_mean(), mode + "-epoch": epoch})


In [None]:
args.obs = (3, 32, 32)
input_channels = args.obs[0]

loss_op   = lambda real, fake : discretized_mix_logistic_loss(real, fake)
sample_op = lambda x : sample_from_discretized_mix_logistic(x, args.nr_logistic_mix)

model = PixelCNN(nr_resnet=args.nr_resnet, nr_filters=args.nr_filters,
            input_channels=input_channels, nr_logistic_mix=args.nr_logistic_mix)
model = model.to(device)

if args.load_params:
    # map_location lets a checkpoint written on a GPU runtime be restored on a CPU-only one.
    model.load_state_dict(torch.load(args.load_params, map_location=device))
    print('model parameters loaded')

optimizer = optim.Adam(model.parameters(), lr=args.lr)
# step_size=1 and scheduler.step() once per epoch: lr is multiplied by lr_decay every epoch.
scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=args.lr_decay)

FID_DIMS = 192        # Inception feature dimension used for FID
FID_BATCH_SIZE = 32


def _run_phase(data_loader, mode, epoch):
    """Run one train_or_test pass of the global model/optimizer over ``data_loader``."""
    train_or_test(model=model,
                  data_loader=data_loader,
                  optimizer=optimizer,
                  loss_op=loss_op,
                  device=device,
                  args=args,
                  epoch=epoch,
                  mode=mode)


def _sample_and_score(epoch):
    """Generate samples, compute FID against ``<data_dir>/test`` and log both to wandb if enabled."""
    print('......sampling......')
    sampled_images = generation_evaluation.my_sample(model, args.sample_dir, args.sample_batch_size, args.obs, sample_op)

    paths = [args.sample_dir, args.data_dir + '/test']
    # Stays None when FID fails, so a stale score from an earlier epoch is never logged.
    fid_score = None
    try:
        fid_score = calculate_fid_given_paths(paths, FID_BATCH_SIZE, device, dims=FID_DIMS)
        print("Dimension {:d} works! fid score: {}".format(FID_DIMS, fid_score))
    except Exception as exc:  # best effort: a failed FID must not abort a long run (Ctrl-C still does)
        print("Dimension {:d} fails!".format(FID_DIMS))
        logging.warning('FID computation failed: %r', exc)

    if args.en_wandb:
        # Log the images themselves, one per label, in a single step.
        log_entry = {"samples": [wandb.Image(img, caption="{} epoch {}".format(label, epoch))
                                 for label, img in sampled_images.items()]}
        if fid_score is not None:
            log_entry["FID"] = fid_score
        wandb.log(log_entry)


for epoch in tqdm(range(args.max_epochs)):
    _run_phase(train_loader, 'training', epoch)

    # decrease learning rate
    scheduler.step()
    _run_phase(test_loader, 'test', epoch)

    # Not every dataset provides a validation split.
    if val_loader is not None:
        _run_phase(val_loader, 'val', epoch)

    if epoch % args.sampling_interval == 0:
        _sample_and_score(epoch)

    if (epoch + 1) % args.save_interval == 0:
        os.makedirs(args.save_dir, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(args.save_dir, '{}_{}.pth'.format(model_name, epoch)))


  WeightNorm.apply(module, name, dim)
  0%|          | 0/500 [00:00<?, ?it/s]
  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:16<1:11:00, 16.51s/it][A
  1%|          | 3/259 [00:17<20:19,  4.76s/it]  [A
  2%|▏         | 4/259 [00:18<14:19,  3.37s/it][A
  2%|▏         | 5/259 [00:20<12:49,  3.03s/it][A
  2%|▏         | 6/259 [00:21<09:25,  2.24s/it][A
  3%|▎         | 7/259 [00:23<09:43,  2.32s/it][A
  3%|▎         | 8/259 [00:24<07:15,  1.73s/it][A
  3%|▎         | 9/259 [00:27<08:29,  2.04s/it][A
  4%|▍         | 10/259 [00:27<06:08,  1.48s/it][A
  4%|▍         | 11/259 [00:29<07:37,  1.84s/it][A
  5%|▍         | 12/259 [00:30<05:29,  1.34s/it][A
  5%|▌         | 13/259 [00:32<07:09,  1.75s/it][A
  5%|▌         | 14/259 [00:33<05:19,  1.30s/it][A
  6%|▌         | 15/259 [00:35<06:56,  1.71s/it][A
  6%|▌         | 16/259 [00:36<05:27,  1.35s/it][A
  7%|▋         | 17/259 [00:38<06:56,  1.72s/it][A
  7%|▋         | 18/259 [00:39<05:22,  1.34s/it][A

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3


Downloading: "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/pt_inception-2015-12-05-6726825d.pth

  0%|          | 0.00/91.2M [00:00<?, ?B/s][A
 24%|██▍       | 21.9M/91.2M [00:00<00:00, 229MB/s][A
 65%|██████▍   | 59.0M/91.2M [00:00<00:00, 323MB/s][A
100%|██████████| 91.2M/91.2M [00:00<00:00, 282MB/s]

  0%|          | 0/2 [00:00<?, ?it/s][A
100%|██████████| 2/2 [00:00<00:00,  3.07it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:27,  1.74s/it][A
 12%|█▏        | 2/17 [00:01<00:12,  1.21it/s][A
 18%|█▊        | 3/17 [00:02<00:07,  2.00it/s][A
 29%|██▉       | 5/17 [00:02<00:03,  3.65it/s][A
 35%|███▌      | 6/17 [00:02<00:02,  4.46it/s][A
 41%|████      | 7/17 [00:02<00:01,  5.18it/s][A
 53%|█████▎    | 9/17 [00:02<00:01,  6.76it/s][A
 65%|██████▍   | 11/17 [00:02<00:00,  7.84it/s][A
 76%|███████▋  | 13/17 [00:03<00:00,  8.65it/s][A
 88%|█████

Dimension 192 works! fid score: 26.357485329779625



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<01:09,  3.70it/s][A
  1%|          | 2/259 [00:00<00:48,  5.29it/s][A
  2%|▏         | 4/259 [00:00<00:32,  7.91it/s][A
  2%|▏         | 5/259 [00:00<00:30,  8.35it/s][A
  2%|▏         | 6/259 [00:00<00:29,  8.62it/s][A
  3%|▎         | 8/259 [00:00<00:25,  9.78it/s][A
  4%|▍         | 10/259 [00:01<00:23, 10.47it/s][A
  5%|▍         | 12/259 [00:01<00:22, 10.88it/s][A
  5%|▌         | 14/259 [00:01<00:21, 11.26it/s][A
  6%|▌         | 16/259 [00:01<00:21, 11.41it/s][A
  7%|▋         | 18/259 [00:01<00:19, 12.20it/s][A
  8%|▊         | 20/259 [00:01<00:20, 11.49it/s][A
  8%|▊         | 22/259 [00:02<00:22, 10.70it/s][A
  9%|▉         | 24/259 [00:02<00:21, 10.87it/s][A
 10%|█         | 26/259 [00:02<00:20, 11.50it/s][A
 11%|█         | 28/259 [00:02<00:19, 12.05it/s][A
 12%|█▏        | 30/259 [00:02<00:21, 10.66it/s][A
 12%|█▏        | 32/259 [00:03<00:20, 10.83it/s][A
 13%|█▎        | 34/259 [0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
 50%|█████     | 1/2 [00:00<00:00,  5.56it/s][A
100%|██████████| 2/2 [00:00<00:00,  6.05it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:25,  1.60s/it][A
 18%|█▊        | 3/17 [00:01<00:06,  2.06it/s][A
 29%|██▉       | 5/17 [00:01<00:03,  3.51it/s][A
 41%|████      | 7/17 [00:02<00:02,  4.88it/s][A
 53%|█████▎    | 9/17 [00:02<00:01,  6.04it/s][A
 65%|██████▍   | 11/17 [00:02<00:00,  7.04it/s][A
 76%|███████▋  | 13/17 [00:02<00:00,  7.94it/s][A
 82%|████████▏ | 14/17 [00:02<00:00,  8.25it/s][A
 88%|████████▊ | 15/17 [00:02<00:00,  8.45it/s][A
100%|██████████| 17/17 [00:03<00:00,  5.45it/s]
  1%|          | 6/500 [14:24<13:16:40, 96.76s/it]

Dimension 192 works! fid score: 46.28623452935197



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<01:22,  3.13it/s][A
  1%|          | 2/259 [00:00<00:50,  5.10it/s][A
  2%|▏         | 4/259 [00:00<00:32,  7.78it/s][A
  2%|▏         | 6/259 [00:00<00:28,  8.81it/s][A
  3%|▎         | 7/259 [00:00<00:31,  8.10it/s][A
  3%|▎         | 8/259 [00:01<00:31,  7.85it/s][A
  3%|▎         | 9/259 [00:01<00:33,  7.46it/s][A
  4%|▍         | 10/259 [00:01<00:33,  7.36it/s][A
  4%|▍         | 11/259 [00:01<00:33,  7.33it/s][A
  5%|▍         | 12/259 [00:01<00:34,  7.24it/s][A
  5%|▌         | 13/259 [00:01<00:33,  7.24it/s][A
  5%|▌         | 14/259 [00:01<00:33,  7.30it/s][A
  6%|▌         | 15/259 [00:02<00:33,  7.36it/s][A
  6%|▌         | 16/259 [00:02<00:33,  7.31it/s][A
  7%|▋         | 17/259 [00:02<00:32,  7.54it/s][A
  7%|▋         | 18/259 [00:02<00:34,  6.98it/s][A
  7%|▋         | 19/259 [00:02<00:34,  6.86it/s][A
  8%|▊         | 20/259 [00:02<00:34,  6.90it/s][A
  8%|▊         | 21/259 [00

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
100%|██████████| 2/2 [00:00<00:00,  6.53it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:25,  1.62s/it][A
 18%|█▊        | 3/17 [00:01<00:06,  2.02it/s][A
 29%|██▉       | 5/17 [00:02<00:03,  3.49it/s][A
 35%|███▌      | 6/17 [00:02<00:02,  4.19it/s][A
 41%|████      | 7/17 [00:02<00:02,  4.89it/s][A
 53%|█████▎    | 9/17 [00:02<00:01,  6.37it/s][A
 65%|██████▍   | 11/17 [00:02<00:00,  7.47it/s][A
 71%|███████   | 12/17 [00:02<00:00,  7.89it/s][A
 76%|███████▋  | 13/17 [00:02<00:00,  8.16it/s][A
 88%|████████▊ | 15/17 [00:03<00:00,  8.95it/s][A
100%|██████████| 17/17 [00:03<00:00,  5.36it/s]
  2%|▏         | 11/500 [18:38<9:31:30, 70.12s/it]

Dimension 192 works! fid score: 44.22560922782329



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<01:22,  3.12it/s][A
  1%|          | 2/259 [00:00<01:00,  4.24it/s][A
  1%|          | 3/259 [00:00<00:49,  5.12it/s][A
  2%|▏         | 4/259 [00:00<00:43,  5.83it/s][A
  2%|▏         | 5/259 [00:00<00:42,  5.94it/s][A
  2%|▏         | 6/259 [00:01<00:39,  6.47it/s][A
  3%|▎         | 8/259 [00:01<00:33,  7.56it/s][A
  3%|▎         | 9/259 [00:01<00:34,  7.19it/s][A
  4%|▍         | 10/259 [00:01<00:34,  7.30it/s][A
  4%|▍         | 11/259 [00:01<00:34,  7.24it/s][A
  5%|▍         | 12/259 [00:01<00:34,  7.08it/s][A
  5%|▌         | 13/259 [00:02<00:35,  6.95it/s][A
  5%|▌         | 14/259 [00:02<00:34,  7.12it/s][A
  6%|▌         | 15/259 [00:02<00:34,  7.04it/s][A
  6%|▌         | 16/259 [00:02<00:33,  7.22it/s][A
  7%|▋         | 17/259 [00:02<00:36,  6.59it/s][A
  7%|▋         | 18/259 [00:02<00:33,  7.22it/s][A
  7%|▋         | 19/259 [00:02<00:31,  7.52it/s][A
  8%|▊         | 20/259 [00:

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
 50%|█████     | 1/2 [00:00<00:00,  6.01it/s][A
100%|██████████| 2/2 [00:00<00:00,  6.27it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:24,  1.52s/it][A
 12%|█▏        | 2/17 [00:01<00:10,  1.44it/s][A
 24%|██▎       | 4/17 [00:01<00:04,  3.14it/s][A
 29%|██▉       | 5/17 [00:01<00:03,  3.95it/s][A
 35%|███▌      | 6/17 [00:02<00:02,  4.70it/s][A
 47%|████▋     | 8/17 [00:02<00:01,  6.20it/s][A
 53%|█████▎    | 9/17 [00:02<00:01,  6.79it/s][A
 59%|█████▉    | 10/17 [00:02<00:00,  7.25it/s][A
 65%|██████▍   | 11/17 [00:02<00:00,  7.77it/s][A
 71%|███████   | 12/17 [00:02<00:00,  8.06it/s][A
 82%|████████▏ | 14/17 [00:02<00:00,  8.85it/s][A
 88%|████████▊ | 15/17 [00:03<00:00,  8.87it/s][A
100%|██████████| 17/17 [00:03<00:00,  5.38it/s]
  3%|▎         | 16/500 [22:50<8:52:40, 66.03s/it]

Dimension 192 works! fid score: 30.886889934830172



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<01:21,  3.17it/s][A
  1%|          | 2/259 [00:00<00:51,  5.04it/s][A
  1%|          | 3/259 [00:00<00:40,  6.31it/s][A
  2%|▏         | 5/259 [00:00<00:31,  8.05it/s][A
  3%|▎         | 7/259 [00:00<00:27,  9.31it/s][A
  3%|▎         | 8/259 [00:01<00:26,  9.37it/s][A
  4%|▍         | 10/259 [00:01<00:24, 10.07it/s][A
  5%|▍         | 12/259 [00:01<00:23, 10.44it/s][A
  5%|▌         | 14/259 [00:01<00:22, 10.76it/s][A
  6%|▌         | 16/259 [00:01<00:22, 10.65it/s][A
  7%|▋         | 18/259 [00:01<00:22, 10.93it/s][A
  8%|▊         | 20/259 [00:02<00:21, 10.96it/s][A
  8%|▊         | 22/259 [00:02<00:21, 10.93it/s][A
  9%|▉         | 24/259 [00:02<00:21, 11.17it/s][A
 10%|█         | 26/259 [00:02<00:21, 10.94it/s][A
 11%|█         | 28/259 [00:02<00:21, 10.68it/s][A
 12%|█▏        | 30/259 [00:03<00:22, 10.02it/s][A
 12%|█▏        | 32/259 [00:03<00:24,  9.32it/s][A
 13%|█▎        | 33/259 [0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
100%|██████████| 2/2 [00:00<00:00,  6.42it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:00<00:04,  3.88it/s][A
 12%|█▏        | 2/17 [00:01<00:12,  1.24it/s][A
 24%|██▎       | 4/17 [00:01<00:04,  2.74it/s][A
 35%|███▌      | 6/17 [00:01<00:02,  4.18it/s][A
 47%|████▋     | 8/17 [00:02<00:01,  5.51it/s][A
 59%|█████▉    | 10/17 [00:02<00:01,  6.67it/s][A
 71%|███████   | 12/17 [00:02<00:00,  7.41it/s][A
 76%|███████▋  | 13/17 [00:02<00:00,  7.43it/s][A
 82%|████████▏ | 14/17 [00:02<00:00,  7.48it/s][A
 88%|████████▊ | 15/17 [00:02<00:00,  7.77it/s][A
100%|██████████| 17/17 [00:02<00:00,  5.71it/s]
  4%|▍         | 21/500 [27:07<8:46:24, 65.94s/it]

Dimension 192 works! fid score: 24.402886934897413



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<01:20,  3.20it/s][A
  1%|          | 2/259 [00:00<00:52,  4.91it/s][A
  1%|          | 3/259 [00:00<00:41,  6.20it/s][A
  2%|▏         | 5/259 [00:00<00:30,  8.29it/s][A
  3%|▎         | 7/259 [00:00<00:26,  9.51it/s][A
  3%|▎         | 8/259 [00:01<00:26,  9.48it/s][A
  4%|▍         | 10/259 [00:01<00:24, 10.29it/s][A
  5%|▍         | 12/259 [00:01<00:23, 10.42it/s][A
  5%|▌         | 14/259 [00:01<00:22, 10.69it/s][A
  6%|▌         | 16/259 [00:01<00:23, 10.56it/s][A
  7%|▋         | 18/259 [00:01<00:22, 10.83it/s][A
  8%|▊         | 20/259 [00:02<00:22, 10.86it/s][A
  8%|▊         | 22/259 [00:02<00:23, 10.02it/s][A
  9%|▉         | 24/259 [00:02<00:25,  9.17it/s][A
 10%|▉         | 25/259 [00:02<00:25,  9.03it/s][A
 10%|█         | 27/259 [00:02<00:23,  9.70it/s][A
 11%|█         | 28/259 [00:03<00:24,  9.43it/s][A
 11%|█         | 29/259 [00:03<00:26,  8.58it/s][A
 12%|█▏        | 30/259 [0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
100%|██████████| 2/2 [00:00<00:00,  6.34it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:26,  1.67s/it][A
 12%|█▏        | 2/17 [00:01<00:11,  1.33it/s][A
 18%|█▊        | 3/17 [00:01<00:06,  2.20it/s][A
 24%|██▎       | 4/17 [00:01<00:04,  3.12it/s][A
 35%|███▌      | 6/17 [00:02<00:02,  4.93it/s][A
 41%|████      | 7/17 [00:02<00:01,  5.70it/s][A
 47%|████▋     | 8/17 [00:02<00:01,  6.17it/s][A
 53%|█████▎    | 9/17 [00:02<00:01,  6.76it/s][A
 59%|█████▉    | 10/17 [00:02<00:00,  7.26it/s][A
 65%|██████▍   | 11/17 [00:02<00:00,  6.69it/s][A
 71%|███████   | 12/17 [00:02<00:00,  7.07it/s][A
 76%|███████▋  | 13/17 [00:03<00:00,  7.17it/s][A
 82%|████████▏ | 14/17 [00:03<00:00,  7.11it/s][A
 88%|████████▊ | 15/17 [00:03<00:00,  6.89it/s][A
100%|██████████| 17/17 [00:03<00:00,  4.71it/s]
  5%|▌         | 26/500 [31:26<8:42:32, 66.15s/it]

Dimension 192 works! fid score: 50.45411611047611



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<02:27,  1.75it/s][A
  1%|          | 2/259 [00:00<01:21,  3.14it/s][A
  1%|          | 3/259 [00:00<01:01,  4.14it/s][A
  2%|▏         | 4/259 [00:01<00:53,  4.77it/s][A
  2%|▏         | 5/259 [00:01<00:47,  5.32it/s][A
  2%|▏         | 6/259 [00:01<00:44,  5.67it/s][A
  3%|▎         | 7/259 [00:01<00:41,  6.02it/s][A
  3%|▎         | 8/259 [00:01<00:40,  6.15it/s][A
  3%|▎         | 9/259 [00:01<00:40,  6.13it/s][A
  4%|▍         | 10/259 [00:01<00:39,  6.35it/s][A
  5%|▍         | 12/259 [00:02<00:30,  8.00it/s][A
  5%|▌         | 13/259 [00:02<00:30,  8.19it/s][A
  5%|▌         | 14/259 [00:02<00:30,  7.96it/s][A
  6%|▌         | 15/259 [00:02<00:36,  6.68it/s][A
  6%|▌         | 16/259 [00:02<00:33,  7.34it/s][A
  7%|▋         | 17/259 [00:02<00:34,  7.08it/s][A
  7%|▋         | 18/259 [00:02<00:34,  6.92it/s][A
  7%|▋         | 19/259 [00:03<00:34,  7.05it/s][A
  8%|▊         | 21/259 [00:0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
 50%|█████     | 1/2 [00:00<00:00,  4.87it/s][A
100%|██████████| 2/2 [00:00<00:00,  4.59it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:27,  1.70s/it][A
 12%|█▏        | 2/17 [00:01<00:11,  1.27it/s][A
 18%|█▊        | 3/17 [00:02<00:06,  2.01it/s][A
 24%|██▎       | 4/17 [00:02<00:04,  2.83it/s][A
 29%|██▉       | 5/17 [00:02<00:03,  3.49it/s][A
 35%|███▌      | 6/17 [00:02<00:02,  4.46it/s][A
 41%|████      | 7/17 [00:02<00:01,  5.30it/s][A
 53%|█████▎    | 9/17 [00:02<00:01,  6.76it/s][A
 65%|██████▍   | 11/17 [00:02<00:00,  7.85it/s][A
 76%|███████▋  | 13/17 [00:03<00:00,  8.57it/s][A
 88%|████████▊ | 15/17 [00:03<00:00,  8.92it/s][A
100%|██████████| 17/17 [00:03<00:00,  4.87it/s]
  6%|▌         | 31/500 [35:49<8:45:19, 67.21s/it]

Dimension 192 works! fid score: 44.02896468741757



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<01:10,  3.68it/s][A
  1%|          | 2/259 [00:00<00:48,  5.34it/s][A
  1%|          | 3/259 [00:00<00:38,  6.63it/s][A
  2%|▏         | 4/259 [00:00<00:37,  6.87it/s][A
  2%|▏         | 5/259 [00:00<00:32,  7.74it/s][A
  2%|▏         | 6/259 [00:00<00:30,  8.35it/s][A
  3%|▎         | 7/259 [00:00<00:28,  8.73it/s][A
  3%|▎         | 8/259 [00:01<00:27,  9.04it/s][A
  3%|▎         | 9/259 [00:01<00:27,  9.24it/s][A
  4%|▍         | 10/259 [00:01<00:27,  8.97it/s][A
  5%|▍         | 12/259 [00:01<00:24,  9.91it/s][A
  5%|▌         | 13/259 [00:01<00:25,  9.57it/s][A
  5%|▌         | 14/259 [00:01<00:25,  9.62it/s][A
  6%|▌         | 16/259 [00:01<00:24, 10.10it/s][A
  7%|▋         | 18/259 [00:02<00:22, 10.50it/s][A
  8%|▊         | 20/259 [00:02<00:22, 10.62it/s][A
  8%|▊         | 22/259 [00:02<00:22, 10.48it/s][A
  9%|▉         | 24/259 [00:02<00:23, 10.18it/s][A
 10%|█         | 26/259 [00:0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
 50%|█████     | 1/2 [00:00<00:00,  5.69it/s][A
100%|██████████| 2/2 [00:00<00:00,  5.86it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:25,  1.57s/it][A
 18%|█▊        | 3/17 [00:01<00:06,  2.08it/s][A
 24%|██▎       | 4/17 [00:01<00:04,  2.84it/s][A
 29%|██▉       | 5/17 [00:01<00:03,  3.67it/s][A
 35%|███▌      | 6/17 [00:02<00:02,  4.56it/s][A
 41%|████      | 7/17 [00:02<00:01,  5.35it/s][A
 47%|████▋     | 8/17 [00:02<00:01,  6.19it/s][A
 53%|█████▎    | 9/17 [00:02<00:01,  6.75it/s][A
 59%|█████▉    | 10/17 [00:02<00:00,  7.09it/s][A
 65%|██████▍   | 11/17 [00:02<00:00,  7.69it/s][A
 71%|███████   | 12/17 [00:02<00:00,  7.96it/s][A
 76%|███████▋  | 13/17 [00:02<00:00,  8.34it/s][A
 82%|████████▏ | 14/17 [00:02<00:00,  8.58it/s][A
 88%|████████▊ | 15/17 [00:03<00:00,  8.94it/s][A
100%|██████████| 17/17 [00:03<00:00,  5.22it/s]
  7%|▋         | 36/500 [40:12<8:43:09, 67.65s/it]

Dimension 192 works! fid score: 28.82040247318594



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<01:31,  2.82it/s][A
  1%|          | 2/259 [00:00<00:56,  4.53it/s][A
  1%|          | 3/259 [00:00<00:52,  4.89it/s][A
  2%|▏         | 4/259 [00:00<00:42,  6.06it/s][A
  2%|▏         | 5/259 [00:00<00:36,  6.88it/s][A
  2%|▏         | 6/259 [00:00<00:33,  7.62it/s][A
  3%|▎         | 7/259 [00:01<00:30,  8.20it/s][A
  3%|▎         | 8/259 [00:01<00:29,  8.53it/s][A
  3%|▎         | 9/259 [00:01<00:29,  8.51it/s][A
  4%|▍         | 10/259 [00:01<00:29,  8.36it/s][A
  4%|▍         | 11/259 [00:01<00:29,  8.49it/s][A
  5%|▍         | 12/259 [00:01<00:28,  8.60it/s][A
  5%|▌         | 14/259 [00:01<00:25,  9.69it/s][A
  6%|▌         | 15/259 [00:01<00:25,  9.44it/s][A
  6%|▌         | 16/259 [00:02<00:25,  9.56it/s][A
  7%|▋         | 18/259 [00:02<00:24,  9.79it/s][A
  7%|▋         | 19/259 [00:02<00:24,  9.75it/s][A
  8%|▊         | 20/259 [00:02<00:27,  8.77it/s][A
  8%|▊         | 21/259 [00:0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
 50%|█████     | 1/2 [00:00<00:00,  4.51it/s][A
100%|██████████| 2/2 [00:00<00:00,  4.35it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:29,  1.84s/it][A
 12%|█▏        | 2/17 [00:01<00:12,  1.21it/s][A
 18%|█▊        | 3/17 [00:02<00:06,  2.03it/s][A
 24%|██▎       | 4/17 [00:02<00:04,  2.90it/s][A
 29%|██▉       | 5/17 [00:02<00:03,  3.79it/s][A
 41%|████      | 7/17 [00:02<00:01,  5.58it/s][A
 47%|████▋     | 8/17 [00:02<00:01,  6.30it/s][A
 53%|█████▎    | 9/17 [00:02<00:01,  6.77it/s][A
 59%|█████▉    | 10/17 [00:02<00:00,  7.39it/s][A
 65%|██████▍   | 11/17 [00:02<00:00,  7.43it/s][A
 71%|███████   | 12/17 [00:03<00:00,  7.99it/s][A
 76%|███████▋  | 13/17 [00:03<00:00,  8.24it/s][A
 88%|████████▊ | 15/17 [00:03<00:00,  9.09it/s][A
100%|██████████| 17/17 [00:03<00:00,  4.83it/s]
  8%|▊         | 41/500 [44:41<8:43:35, 68.44s/it]

Dimension 192 works! fid score: 42.63737607712062



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<01:24,  3.07it/s][A
  1%|          | 2/259 [00:00<00:50,  5.12it/s][A
  1%|          | 3/259 [00:00<00:39,  6.49it/s][A
  2%|▏         | 4/259 [00:00<00:34,  7.43it/s][A
  2%|▏         | 5/259 [00:00<00:32,  7.71it/s][A
  3%|▎         | 7/259 [00:00<00:27,  9.17it/s][A
  3%|▎         | 8/259 [00:01<00:27,  9.28it/s][A
  3%|▎         | 9/259 [00:01<00:27,  9.08it/s][A
  4%|▍         | 10/259 [00:01<00:27,  9.19it/s][A
  4%|▍         | 11/259 [00:01<00:26,  9.19it/s][A
  5%|▍         | 12/259 [00:01<00:29,  8.42it/s][A
  5%|▌         | 13/259 [00:01<00:27,  8.79it/s][A
  5%|▌         | 14/259 [00:01<00:27,  8.90it/s][A
  6%|▌         | 15/259 [00:01<00:33,  7.33it/s][A
  6%|▌         | 16/259 [00:02<00:32,  7.38it/s][A
  7%|▋         | 17/259 [00:02<00:32,  7.46it/s][A
  7%|▋         | 18/259 [00:02<00:30,  7.97it/s][A
  8%|▊         | 20/259 [00:02<00:26,  9.11it/s][A
  8%|▊         | 21/259 [00:

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
 50%|█████     | 1/2 [00:00<00:00,  4.95it/s][A
100%|██████████| 2/2 [00:00<00:00,  4.66it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:26,  1.64s/it][A
 12%|█▏        | 2/17 [00:01<00:11,  1.35it/s][A
 18%|█▊        | 3/17 [00:01<00:06,  2.11it/s][A
 24%|██▎       | 4/17 [00:02<00:04,  2.91it/s][A
 29%|██▉       | 5/17 [00:02<00:03,  3.48it/s][A
 35%|███▌      | 6/17 [00:02<00:02,  4.14it/s][A
 41%|████      | 7/17 [00:02<00:02,  4.68it/s][A
 47%|████▋     | 8/17 [00:02<00:01,  5.38it/s][A
 53%|█████▎    | 9/17 [00:02<00:01,  5.64it/s][A
 59%|█████▉    | 10/17 [00:02<00:01,  6.29it/s][A
 65%|██████▍   | 11/17 [00:03<00:00,  6.49it/s][A
 71%|███████   | 12/17 [00:03<00:00,  6.92it/s][A
 76%|███████▋  | 13/17 [00:03<00:00,  7.41it/s][A
 82%|████████▏ | 14/17 [00:03<00:00,  7.74it/s][A
 88%|████████▊ | 15/17 [00:03<00:00,  8.23it/s][A
100%|██████████| 17/17 [00:03<00:00,  4.58it/s]
  9%|▉         |

Dimension 192 works! fid score: 24.34350053270988



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<01:27,  2.96it/s][A
  1%|          | 2/259 [00:00<00:52,  4.91it/s][A
  1%|          | 3/259 [00:00<00:41,  6.12it/s][A
  2%|▏         | 5/259 [00:00<00:31,  8.13it/s][A
  3%|▎         | 7/259 [00:00<00:27,  9.32it/s][A
  3%|▎         | 8/259 [00:01<00:26,  9.43it/s][A
  3%|▎         | 9/259 [00:01<00:26,  9.27it/s][A
  4%|▍         | 11/259 [00:01<00:24, 10.11it/s][A
  5%|▌         | 13/259 [00:01<00:24, 10.05it/s][A
  6%|▌         | 15/259 [00:01<00:23, 10.46it/s][A
  7%|▋         | 17/259 [00:01<00:22, 10.74it/s][A
  7%|▋         | 19/259 [00:02<00:22, 10.89it/s][A
  8%|▊         | 21/259 [00:02<00:22, 10.65it/s][A
  9%|▉         | 23/259 [00:02<00:21, 10.75it/s][A
 10%|▉         | 25/259 [00:02<00:22, 10.59it/s][A
 10%|█         | 27/259 [00:02<00:22, 10.51it/s][A
 11%|█         | 29/259 [00:02<00:21, 10.74it/s][A
 12%|█▏        | 31/259 [00:03<00:21, 10.63it/s][A
 13%|█▎        | 33/259 [00

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
 50%|█████     | 1/2 [00:00<00:00,  5.69it/s][A
100%|██████████| 2/2 [00:00<00:00,  5.88it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:23,  1.45s/it][A
 12%|█▏        | 2/17 [00:01<00:09,  1.53it/s][A
 18%|█▊        | 3/17 [00:01<00:05,  2.45it/s][A
 24%|██▎       | 4/17 [00:01<00:03,  3.42it/s][A
 29%|██▉       | 5/17 [00:01<00:02,  4.47it/s][A
 35%|███▌      | 6/17 [00:01<00:02,  5.34it/s][A
 47%|████▋     | 8/17 [00:02<00:01,  7.00it/s][A
 53%|█████▎    | 9/17 [00:02<00:01,  7.57it/s][A
 59%|█████▉    | 10/17 [00:02<00:00,  7.78it/s][A
 65%|██████▍   | 11/17 [00:02<00:00,  8.23it/s][A
 71%|███████   | 12/17 [00:02<00:00,  8.42it/s][A
 76%|███████▋  | 13/17 [00:02<00:00,  8.75it/s][A
 82%|████████▏ | 14/17 [00:02<00:00,  8.78it/s][A
 88%|████████▊ | 15/17 [00:02<00:00,  8.59it/s][A
100%|██████████| 17/17 [00:03<00:00,  5.41it/s]
 10%|█         | 51/500 [53:47<8:40:18, 69.53s/it]

Dimension 192 works! fid score: 24.223940360573888



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<02:06,  2.05it/s][A
  1%|          | 2/259 [00:00<01:19,  3.23it/s][A
  1%|          | 3/259 [00:00<01:01,  4.15it/s][A
  2%|▏         | 4/259 [00:00<00:52,  4.82it/s][A
  2%|▏         | 5/259 [00:01<00:49,  5.14it/s][A
  2%|▏         | 6/259 [00:01<00:44,  5.67it/s][A
  3%|▎         | 7/259 [00:01<00:41,  6.03it/s][A
  3%|▎         | 8/259 [00:01<00:40,  6.15it/s][A
  3%|▎         | 9/259 [00:01<00:41,  6.02it/s][A
  4%|▍         | 10/259 [00:01<00:40,  6.18it/s][A
  4%|▍         | 11/259 [00:02<00:40,  6.06it/s][A
  5%|▍         | 12/259 [00:02<00:38,  6.46it/s][A
  5%|▌         | 13/259 [00:02<00:39,  6.17it/s][A
  5%|▌         | 14/259 [00:02<00:46,  5.31it/s][A
  6%|▌         | 15/259 [00:02<00:45,  5.38it/s][A
  6%|▌         | 16/259 [00:03<00:45,  5.34it/s][A
  7%|▋         | 17/259 [00:03<00:49,  4.88it/s][A
  7%|▋         | 18/259 [00:03<00:51,  4.66it/s][A
  7%|▋         | 19/259 [00:0

KeyboardInterrupt: 

In [None]:
# NOTE(review): leftover scratch cell -- it only echoes whatever `sample` is bound to
# (presumably a name brought in by the star imports above; verify). Safe to delete.
sample

In [None]:
# Lower the root logger's threshold so the logging.debug() calls in train_or_test
# (current mode and batch index) are emitted.
logger = logging.root  # same object that logging.getLogger() returns
logger.setLevel('DEBUG')

In [11]:
# Run the standalone generation/evaluation script in a subprocess. Judging by its
# recorded output, it samples images for each class label and reports the FID
# against the reference images.
! python3 generation_evaluation.py

  WeightNorm.apply(module, name, dim)
Label: Class0
Label: Class1
Label: Class2
Label: Class3
#generated images: 100, #reference images: 519
100% 1/1 [00:00<00:00,  1.71it/s]
100% 5/5 [00:03<00:00,  1.45it/s]
Dimension 192 works! fid score: 25.48530164102886
Average fid score: 25.48530164102886
