In [1]:
from google.colab import drive
# Mount Google Drive at /content/drive/.
# Working from the mounted directory keeps file changes saved in your Drive.
drive.mount('/content/drive/')

Mounted at /content/drive/


In [2]:
# Work from the project folder on the mounted Drive: the relative paths used
# later in this notebook (requirements.txt, data, models, samples) resolve against it.
%cd /content/drive/MyDrive/cpen455-project/

/content/drive/MyDrive/cpen455-project


In [3]:
!pip install -r requirements.txt

Collecting bidict (from -r requirements.txt (line 1))
  Downloading bidict-0.23.1-py3-none-any.whl.metadata (8.7 kB)
Collecting pytorch-fid (from -r requirements.txt (line 7))
  Downloading pytorch_fid-0.3.0-py3-none-any.whl.metadata (5.3 kB)
Collecting dotenv (from -r requirements.txt (line 11))
  Downloading dotenv-0.9.9-py2.py3-none-any.whl.metadata (279 bytes)
Collecting python-dotenv (from dotenv->-r requirements.txt (line 11))
  Downloading python_dotenv-1.1.0-py3-none-any.whl.metadata (24 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.0.1->pytorch-fid->-r requirements.txt (line 7))
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.0.1->pytorch-fid->-r requirements.txt (line 7))
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.0.1->pytorch-fid->-r r

In [None]:
import argparse
import logging
import os
import time
from pprint import pprint

import dotenv
import numpy as np
import torch
import torch.optim as optim
import wandb
from pytorch_fid.fid_score import calculate_fid_given_paths
from torch.optim import lr_scheduler
from torchvision import datasets, transforms
from tqdm import tqdm
dotenv.load_dotenv()

True

In [None]:
import importlib
import utils, model, dataset, generation_evaluation

# Reload every project module so that edits made to the .py files since the
# last import take effect without restarting the kernel.
for _project_module in (utils, model, dataset, generation_evaluation):
    importlib.reload(_project_module)
del _project_module

# Pull the (re)loaded names into the notebook namespace.
from utils import *
from model import *
from dataset import *
import generation_evaluation

In [None]:
import logging

# Prefix every record with a timestamp; the root logger starts at WARNING.
logging.basicConfig(level=logging.WARNING, format='%(asctime)s %(message)s')

In [None]:
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class Config:
    """Settings for one PixelCNN training run in this notebook.

    Every field has a default, so a run only needs to override what differs,
    e.g. ``Config(batch_size=16, en_wandb=True)``.
    """
    # Enables the wandb.init / wandb.log calls in this notebook.
    en_wandb: bool = False
    # Free-form suffix appended to the wandb job name.
    tag: str = "default"
    # Samples are generated (and FID computed) every `sampling_interval` epochs.
    sampling_interval: int = 5
    # Root directory handed to the dataset classes; FID reads `<data_dir>/test`.
    data_dir: str = "data"
    # Directory created at start-up for model checkpoints.
    save_dir: str = "models"
    # Directory passed to the sampler and used as the generated-image side of FID.
    sample_dir: str = "samples"
    # Selects the data pipeline by substring: "mnist", "cifar..." or "cpen455".
    dataset: str = "cpen455"
    # A checkpoint is written every `save_interval` epochs.
    save_interval: int = 10
    # Path of a state dict to load before training; None trains from scratch.
    load_params: Optional[str] = None
    # Observation shape; obs[0] is used as the model's input channel count and
    # prod(obs) enters the bits-per-dimension normaliser.
    obs: Tuple[int, int, int] = (3, 32, 32)
    # The three values below are forwarded to the PixelCNN constructor.
    nr_resnet: int = 1
    nr_filters: int = 40
    nr_logistic_mix: int = 5
    # Adam learning rate.
    lr: float = 0.0002
    # Multiplicative learning-rate decay (StepLR gamma, stepped once per epoch).
    lr_decay: float = 0.999995
    # Batch size of the train/test/validation data loaders.
    batch_size: int = 64
    # Batch size handed to the sampler.
    sample_batch_size: int = 32
    # Number of epochs the training loop runs.
    max_epochs: int = 5000
    # Seed applied to torch and numpy for reproducibility.
    seed: int = 1

In [None]:
# Settings for this run; anything not listed keeps its Config default.
args = Config(
    batch_size=16,
    sample_batch_size=16,
    sampling_interval=25,
    save_interval=500,
    max_epochs=500,
    en_wandb=True,
    tag="EarlyMiddleEnd_1"
)

pprint(args.__dict__)
check_dir_and_create(args.save_dir)

# reproducibility
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# The checkpoint name records whether training resumed from saved parameters.
model_name = 'pcnn_' + args.dataset + "_"
model_name += 'load_model' if args.load_params is not None else 'from_scratch'
model_path = args.save_dir + '/' + model_name + '/'

job_name = "PCNN_Training_" + "dataset:" + args.dataset + "_" + args.tag

if args.en_wandb:
    # start a new wandb run to track this script
    wandb.init(
        # set entity to specify your username or team name
        # entity="qihangz-work",
        # set the wandb project where this run will be logged
        project="CPEN455HW",
        # group=Group Name
        name=job_name,
    )
    wandb.config.current_time = time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))
    # Pass a plain dict of the settings: wandb documents dict input for config.update.
    wandb.config.update(vars(args))

#set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Data-loader settings shared by every loader in this notebook.
# num_workers is 2 here, which suits ubuntu/linux/colab when loading data is
# otherwise too slow. If the dataset raises pickling errors on your machine,
# set num_workers to 0.
# drop_last keeps every batch at exactly args.batch_size, which the
# bits-per-dimension normaliser in train_or_test assumes.
# pin_memory is only enabled when a CUDA device is in use.
kwargs = {'num_workers':2, 'pin_memory':device.type == 'cuda', 'drop_last':True}

{'batch_size': 16,
 'data_dir': 'data',
 'dataset': 'cpen455',
 'en_wandb': True,
 'load_params': None,
 'lr': 0.0002,
 'lr_decay': 0.999995,
 'max_epochs': 500,
 'nr_filters': 40,
 'nr_logistic_mix': 5,
 'nr_resnet': 1,
 'obs': (3, 32, 32),
 'sample_batch_size': 16,
 'sample_dir': 'samples',
 'sampling_interval': 25,
 'save_dir': 'models',
 'save_interval': 500,
 'seed': 1,
 'tag': 'EarlyMiddleEnd_1'}


In [None]:
# set data
def _make_loader(ds, shuffle):
    """Wrap dataset `ds` in a DataLoader using args.batch_size and the shared loader kwargs."""
    return torch.utils.data.DataLoader(ds, batch_size=args.batch_size,
                                       shuffle=shuffle, **kwargs)


# Only the training loader is shuffled. The evaluation loaders keep a fixed
# order so that, with drop_last=True, the same samples are evaluated every epoch.
# NOTE(review): only the cpen455 branch creates `val_loader`; code that uses
# `val_loader` will fail for the mnist/cifar datasets.
if "mnist" in args.dataset:
    ds_transforms = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor(), rescaling, replicate_color_channel])
    train_loader = _make_loader(datasets.MNIST(args.data_dir, download=True, train=True,
                                               transform=ds_transforms), shuffle=True)
    test_loader = _make_loader(datasets.MNIST(args.data_dir, train=False,
                                              transform=ds_transforms), shuffle=False)

elif "cifar" in args.dataset:
    ds_transforms = transforms.Compose([transforms.ToTensor(), rescaling])
    cifar_classes = {"cifar10": datasets.CIFAR10, "cifar100": datasets.CIFAR100}
    if args.dataset not in cifar_classes:
        # Literal braces must be doubled, otherwise str.format treats them as a
        # replacement field and raises KeyError instead of this message.
        raise ValueError('{} dataset not in {{cifar10, cifar100}}'.format(args.dataset))
    cifar_cls = cifar_classes[args.dataset]
    train_loader = _make_loader(cifar_cls(args.data_dir, train=True, download=True,
                                          transform=ds_transforms), shuffle=True)
    test_loader = _make_loader(cifar_cls(args.data_dir, train=False,
                                         transform=ds_transforms), shuffle=False)

elif "cpen455" in args.dataset:
    ds_transforms = transforms.Compose([transforms.Resize((32, 32)), rescaling])
    train_loader = _make_loader(CPEN455Dataset(root_dir=args.data_dir,
                                               mode='train',
                                               transform=ds_transforms), shuffle=True)
    test_loader = _make_loader(CPEN455Dataset(root_dir=args.data_dir,
                                              mode='test',
                                              transform=ds_transforms), shuffle=False)
    val_loader = _make_loader(CPEN455Dataset(root_dir=args.data_dir,
                                             mode='validation',
                                             transform=ds_transforms), shuffle=False)
else:
    raise ValueError('{} dataset not in {{mnist, cifar, cpen455}}'.format(args.dataset))

In [None]:
def train_or_test(model, data_loader, optimizer, loss_op, device, args, epoch, mode = 'training'):
    """Run one pass over `data_loader`, optimising the model only in training mode.

    Args:
        model: network called as ``model(model_input, label)``.
        data_loader: iterable yielding ``(model_input, label)`` pairs.
        optimizer: optimizer stepped after every batch when ``mode == 'training'``.
        loss_op: callable ``loss_op(model_input, model_output)`` returning a scalar
            tensor (``.item()`` and ``.backward()`` are called on it).
        device: device the input batch is moved to.
        args: settings object; reads ``batch_size``, ``obs`` and ``en_wandb``.
        epoch: epoch index, logged alongside the metric.
        mode: ``'training'`` enables train mode and optimisation; any other value
            puts the model in eval mode and is used as the metric-name prefix.

    The tracked per-batch value is ``loss / (batch_size * prod(obs) * ln 2)``.
    When ``args.en_wandb`` is set, its mean and the epoch are logged to wandb
    under ``"<mode>-Average-BPD"`` and ``"<mode>-epoch"``.
    """
    logging.debug('mode: {}'.format(mode))
    is_training = mode == 'training'
    if is_training:
        model.train()
    else:
        model.eval()

    # Uses the configured batch size, not the actual one, so it is only exact
    # when every batch is full (the loaders in this notebook use drop_last=True).
    deno =  args.batch_size * np.prod(args.obs) * np.log(2.)
    loss_tracker = mean_tracker()

    # Evaluation passes never call backward(), so skip building autograd graphs
    # for them; this saves memory and time without changing the loss values.
    with torch.set_grad_enabled(is_training):
        for batch_idx, item in enumerate(tqdm(data_loader)):
            logging.debug('batch_idx: {}'.format(batch_idx))
            model_input, label = item
            model_input = model_input.to(device)
            model_output = model(model_input, label)

            loss = loss_op(model_input, model_output)
            loss_tracker.update(loss.item()/deno)
            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    if args.en_wandb:
        # One log call keeps the metric and its epoch on the same wandb step.
        # No wandb.login() here: wandb.log needs a run started by wandb.init(),
        # so logging in again on every pass adds nothing.
        wandb.log({mode + "-Average-BPD" : loss_tracker.get_mean(),
                   mode + "-epoch": epoch})


In [None]:
args.obs = (3, 32, 32)
input_channels = args.obs[0]

# Settings passed to calculate_fid_given_paths below.
FID_BATCH_SIZE = 32
FID_DIMS = 192


def loss_op(real, fake):
    """Forward to discretized_mix_logistic_loss(real, fake)."""
    return discretized_mix_logistic_loss(real, fake)


def sample_op(x):
    """Forward to sample_from_discretized_mix_logistic with the configured mixture count."""
    return sample_from_discretized_mix_logistic(x, args.nr_logistic_mix)


# NOTE(review): this rebinds the name `model`, which an earlier cell also uses
# for the imported `model` module; re-running that cell rebinds it back.
model = PixelCNN(nr_resnet=args.nr_resnet, nr_filters=args.nr_filters,
            input_channels=input_channels, nr_logistic_mix=args.nr_logistic_mix)
model = model.to(device)

if args.load_params:
    # map_location lets a checkpoint saved on one device load on the current one
    # (e.g. a GPU checkpoint on a CPU-only runtime).
    model.load_state_dict(torch.load(args.load_params, map_location=device))
    print('model parameters loaded')

optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=args.lr_decay)

for epoch in tqdm(range(args.max_epochs)):
    train_or_test(model = model,
                  data_loader = train_loader,
                  optimizer = optimizer,
                  loss_op = loss_op,
                  device = device,
                  args = args,
                  epoch = epoch,
                  mode = 'training')

    # decrease learning rate
    scheduler.step()

    # Evaluate on the held-out splits; the mode string is the metric prefix.
    for eval_mode, eval_loader in (('test', test_loader), ('val', val_loader)):
        train_or_test(model = model,
                      data_loader = eval_loader,
                      optimizer = optimizer,
                      loss_op = loss_op,
                      device = device,
                      args = args,
                      epoch = epoch,
                      mode = eval_mode)

    if epoch % args.sampling_interval == 0:
        print('......sampling......')
        sampled_images = generation_evaluation.my_sample(model, args.sample_dir, args.sample_batch_size, args.obs, sample_op)

        gen_data_dir = args.sample_dir
        ref_data_dir = args.data_dir +'/test'
        paths = [gen_data_dir, ref_data_dir]
        # Reset every time so a failed FID computation can neither leave the
        # name undefined nor silently reuse the score from an earlier epoch.
        fid_score = None
        try:
            fid_score = calculate_fid_given_paths(paths, FID_BATCH_SIZE, device, dims=FID_DIMS)
            print("Dimension {:d} works! fid score: {}".format(FID_DIMS, fid_score))
        except Exception as err:
            # `except Exception` keeps the best-effort behaviour but no longer
            # swallows KeyboardInterrupt, so the run can still be stopped here.
            print("Dimension {:d} fails!".format(FID_DIMS))
            logging.warning('FID computation failed: %r', err)

        if args.en_wandb:
            # Log all samples and the FID in one call, so they share a wandb
            # step and the FID is recorded once instead of once per image.
            sample_log = {"samples": [wandb.Image(img, caption="epoch {}".format(epoch))
                                      for img in sampled_images]}
            if fid_score is not None:
                sample_log["FID"] = fid_score
            wandb.log(sample_log)

    if (epoch + 1) % args.save_interval == 0:
        # Save under the configured directory ("models" by default, so the
        # default checkpoint path is unchanged).
        os.makedirs(args.save_dir, exist_ok=True)
        torch.save(model.state_dict(),
                   os.path.join(args.save_dir, '{}_{}.pth'.format(model_name, epoch)))


  WeightNorm.apply(module, name, dim)
  0%|          | 0/500 [00:00<?, ?it/s]
  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<01:35,  2.71it/s][A
  1%|          | 3/259 [00:00<00:41,  6.10it/s][A
  2%|▏         | 5/259 [00:00<00:31,  8.01it/s][A
  3%|▎         | 7/259 [00:00<00:27,  9.04it/s][A
  3%|▎         | 9/259 [00:01<00:25,  9.91it/s][A
  4%|▍         | 11/259 [00:01<00:23, 10.62it/s][A
  5%|▌         | 13/259 [00:01<00:22, 10.93it/s][A
  6%|▌         | 15/259 [00:01<00:21, 11.29it/s][A
  7%|▋         | 17/259 [00:01<00:20, 11.60it/s][A
  7%|▋         | 19/259 [00:01<00:20, 11.49it/s][A
  8%|▊         | 21/259 [00:02<00:20, 11.64it/s][A
  9%|▉         | 23/259 [00:02<00:19, 11.83it/s][A
 10%|▉         | 25/259 [00:02<00:19, 11.89it/s][A
 10%|█         | 27/259 [00:02<00:19, 11.91it/s][A
 11%|█         | 29/259 [00:02<00:19, 12.10it/s][A
 12%|█▏        | 31/259 [00:02<00:19, 11.74it/s][A
 13%|█▎        | 33/259 [00:03<00:19, 11.66it/s][A


......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/5 [00:00<?, ?it/s][A
 20%|██        | 1/5 [00:00<00:01,  2.53it/s][A
 60%|██████    | 3/5 [00:00<00:00,  6.52it/s][A
100%|██████████| 5/5 [00:00<00:00,  7.32it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:16,  1.05s/it][A
 12%|█▏        | 2/17 [00:01<00:07,  2.03it/s][A
 24%|██▎       | 4/17 [00:01<00:03,  4.14it/s][A
 35%|███▌      | 6/17 [00:01<00:01,  5.81it/s][A
 47%|████▋     | 8/17 [00:01<00:01,  7.24it/s][A
 59%|█████▉    | 10/17 [00:01<00:00,  8.26it/s][A
 71%|███████   | 12/17 [00:02<00:00,  8.96it/s][A
 82%|████████▏ | 14/17 [00:02<00:00,  9.56it/s][A
100%|██████████| 17/17 [00:02<00:00,  6.84it/s]
  0%|          | 1/500 [02:01<16:54:20, 121.97s/it]

Dimension 192 works! fid score: 91.19650200523657


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  6%|▌         | 16/259 [00:02<00:35,  6.89it/s][A
  7%|▋         | 17/259 [00:02<00:34,  6.93it/s][A
  7%|▋         | 18/259 [00:02<00:34,  7.01it/s][A
  7%|▋         | 19/259 [00:02<00:33,  7.25it/s][A
  8%|▊         | 20/259 [00:02<00:36,  6.46it/s][A
  8%|▊         | 21/259 [00:03<00:36,  6.46it/s][A
  8%|▊         | 22/259 [00:03<00:36,  6.53it/s][A
  9%|▉         | 23/259 [00:03<00:34,  6.82it/s][A
  9%|▉         | 24/259 [00:03<00:33,  7.03it/s][A
 10%|▉         | 25/259 [00:03<00:33,  7.07it/s][A
 10%|█         | 26/259 [00:03<00:32,  7.11it/s][A
 10%|█         | 27/259 [00:03<00:33,  7.01it/s][A
 11%|█         | 28/259 [00:04<00:32,  7.22it/s][A
 11%|█         | 29/259 [00:04<00:31,  7.36it/s][A
 12%|█▏        | 30/259 [00:04<00:32,  7.10it/s][A
 12%|█▏        | 31/259 [00:04<00:31,  7.32it/s][A
 12%|█▏        | 32/259 [00:04<00:31,  7.22it/s][A
 13%|█▎        | 33/259 [00:04<00:31,  7.22it/s][A

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/5 [00:00<?, ?it/s][A
 20%|██        | 1/5 [00:00<00:01,  2.35it/s][A
 40%|████      | 2/5 [00:00<00:00,  4.07it/s][A
100%|██████████| 5/5 [00:00<00:00,  6.38it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:24,  1.51s/it][A
 12%|█▏        | 2/17 [00:01<00:10,  1.46it/s][A
 18%|█▊        | 3/17 [00:01<00:05,  2.36it/s][A
 24%|██▎       | 4/17 [00:01<00:03,  3.33it/s][A
 35%|███▌      | 6/17 [00:02<00:02,  5.28it/s][A
 47%|████▋     | 8/17 [00:02<00:01,  6.79it/s][A
 59%|█████▉    | 10/17 [00:02<00:00,  7.86it/s][A
 71%|███████   | 12/17 [00:02<00:00,  8.74it/s][A
 82%|████████▏ | 14/17 [00:02<00:00,  9.28it/s][A
100%|██████████| 17/17 [00:02<00:00,  5.67it/s]
  5%|▌         | 26/500 [17:51<8:27:19, 64.22s/it]

Dimension 192 works! fid score: 78.7496675571691


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
 67%|██████▋   | 174/259 [00:19<00:07, 10.88it/s][A
 68%|██████▊   | 176/259 [00:19<00:07, 11.15it/s][A
 69%|██████▊   | 178/259 [00:19<00:07, 11.18it/s][A
 69%|██████▉   | 180/259 [00:19<00:07, 11.23it/s][A
 70%|███████   | 182/259 [00:19<00:07, 10.79it/s][A
 71%|███████   | 184/259 [00:19<00:06, 10.83it/s][A
 72%|███████▏  | 186/259 [00:20<00:06, 10.83it/s][A
 73%|███████▎  | 188/259 [00:20<00:06, 10.87it/s][A
 73%|███████▎  | 190/259 [00:20<00:06, 10.99it/s][A
 74%|███████▍  | 192/259 [00:20<00:06, 10.95it/s][A
 75%|███████▍  | 194/259 [00:20<00:06, 10.68it/s][A
 76%|███████▌  | 196/259 [00:21<00:05, 10.81it/s][A
 76%|███████▋  | 198/259 [00:21<00:05, 10.95it/s][A
 77%|███████▋  | 200/259 [00:21<00:05, 10.90it/s][A
 78%|███████▊  | 202/259 [00:21<00:05, 10.90it/s][A
 79%|███████▉  | 204/259 [00:21<00:05, 10.50it/s][A
 80%|███████▉  | 206/259 [00:22<00:04, 10.71it/s][A
 80%|████████  | 208/259 [00:22<00

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/5 [00:00<?, ?it/s][A
 20%|██        | 1/5 [00:00<00:01,  2.53it/s][A
 60%|██████    | 3/5 [00:00<00:00,  5.90it/s][A
100%|██████████| 5/5 [00:00<00:00,  6.72it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:00<00:12,  1.29it/s][A
 12%|█▏        | 2/17 [00:01<00:09,  1.55it/s][A
 24%|██▎       | 4/17 [00:01<00:03,  3.36it/s][A
 35%|███▌      | 6/17 [00:01<00:02,  4.99it/s][A
 41%|████      | 7/17 [00:01<00:01,  5.67it/s][A
 47%|████▋     | 8/17 [00:01<00:01,  6.29it/s][A
 53%|█████▎    | 9/17 [00:02<00:01,  6.88it/s][A
 65%|██████▍   | 11/17 [00:02<00:00,  8.32it/s][A
 76%|███████▋  | 13/17 [00:02<00:00,  9.12it/s][A
 88%|████████▊ | 15/17 [00:02<00:00,  9.64it/s][A
100%|██████████| 17/17 [00:02<00:00,  6.24it/s]
 10%|█         | 51/500 [34:14<8:10:29, 65.54s/it]

Dimension 192 works! fid score: 81.79802159127567



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<01:31,  2.81it/s][A
  1%|          | 2/259 [00:00<01:01,  4.17it/s][A
  1%|          | 3/259 [00:00<00:48,  5.24it/s][A
  2%|▏         | 4/259 [00:00<00:44,  5.79it/s][A
  2%|▏         | 5/259 [00:00<00:40,  6.31it/s][A
  2%|▏         | 6/259 [00:01<00:39,  6.44it/s][A
  3%|▎         | 7/259 [00:01<00:35,  7.11it/s][A
  3%|▎         | 8/259 [00:01<00:36,  6.86it/s][A
  3%|▎         | 9/259 [00:01<00:37,  6.62it/s][A
  4%|▍         | 11/259 [00:01<00:31,  7.76it/s][A
  5%|▍         | 12/259 [00:01<00:33,  7.41it/s][A
  5%|▌         | 13/259 [00:01<00:32,  7.53it/s][A
  5%|▌         | 14/259 [00:02<00:35,  6.88it/s][A
  6%|▌         | 15/259 [00:02<00:35,  6.83it/s][A
  6%|▌         | 16/259 [00:02<00:34,  6.99it/s][A
  7%|▋         | 17/259 [00:02<00:32,  7.34it/s][A
  7%|▋         | 18/259 [00:02<00:31,  7.58it/s][A
  7%|▋         | 19/259 [00:02<00:30,  7.96it/s][A
  8%|▊         | 20/259 [00:0

KeyboardInterrupt: 

In [None]:
# NOTE(review): bare name lookup with no other effect. It displays whatever
# `sample` is bound to (presumably a name brought in by one of the star
# imports; confirm) or raises NameError. Looks like leftover scratch work.
sample

In [None]:
# Lower the root logger's threshold to DEBUG so that logging.debug(...) calls,
# such as the per-batch messages in train_or_test, pass the logger's level
# check from here on. Expect very verbose output during training.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)