In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported .py
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from google.colab import drive

# Mount Google Drive at /content/drive; force_remount=True requests a fresh
# mount even when one is already present.
drive.mount("/content/drive", force_remount=True)

Mounted at /content/drive


In [None]:
import os
import sys

# change this to your corresponding folder
GOOGLE_DRIVE_PATH_AFTER_MYDRIVE = 'stat-453'
# Absolute path: the drive is mounted at /content/drive, so anchoring here keeps
# file access and imports working even if the working directory changes later.
GOOGLE_DRIVE_PATH = os.path.join('/content', 'drive', 'MyDrive', GOOGLE_DRIVE_PATH_AFTER_MYDRIVE)

# Sanity check: show the project folder contents (fails loudly if the path is wrong).
print(os.listdir(GOOGLE_DRIVE_PATH))


# Add to sys so we can import .py files.
# Guarded so that re-running this cell does not pile up duplicate entries.
if GOOGLE_DRIVE_PATH not in sys.path:
    sys.path.append(GOOGLE_DRIVE_PATH)

['data', '__pycache__', 'models', 'model.py', 'noise_testing', 'videos', 'diffusion.py', 'stat_453', 'experiment.json', 'STAT 453.gslides', 'stat_453_test']


In [None]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.utils.data as data_utils

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print('Training on', DEVICE)

Training on cuda


# Prepare dataset

In [None]:
# Location of the 'data/images' folder inside the Drive project directory.
# NOTE(review): none of the later cells visible here read DATASET_PATH (CIFAR-10
# is downloaded to './data' instead) — confirm whether it is still needed.
DATASET_PATH = os.path.join(GOOGLE_DRIVE_PATH, 'data/images')

# Bare expression so the notebook displays the resolved path.
DATASET_PATH

'drive/MyDrive/stat-453/data/images'

In [None]:
# Side length the CIFAR-10 images are resized to, and the test-loader batch size.
image_size = 64
batch_size = 48
# The train loader uses its own, smaller batch size (value kept from the original cell).
train_batch_size = 28

# Shared preprocessing for both splits: resize, convert to a tensor, then
# normalize each channel with mean 0.5 / std 0.5 (maps [0, 1] to [-1, 1]).
cifar_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# `dataset` is the test split (train=False); `train_dataset` is the train split.
dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=cifar_transform)
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=cifar_transform)

# The two labels were previously swapped relative to the values they printed.
print("Number of samples in the test dataset: ", len(dataset))
print("Number of classes in the dataset: ", len(dataset.classes))

test_dataloader = data_utils.DataLoader(dataset,
                                        batch_size=batch_size,
                                        shuffle=True,
                                        num_workers=2,
                                        )

train_dataloader = data_utils.DataLoader(train_dataset,
                                         batch_size=train_batch_size,
                                         shuffle=True,
                                         num_workers=2,
                                         )

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:04<00:00, 36705177.90it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Total number of classes in the dataset:  10000
Number of different samples in the dataset:  10


In [None]:
# uncomment to plot the samples in the dataset

# import matplotlib.pyplot as plt
# import numpy as np
# import torchvision.utils as vutils

# real_batch, _ = next(iter(train_dataloader))
# plt.figure(figsize=(8,8))
# plt.axis("off")
# plt.title("Training Images")
# plt.imshow(np.transpose(vutils.make_grid(real_batch[:16],
#                                          nrow = 4,
#                                          padding=2,
#                                          normalize=True),(1,2,0)))

# Test

In [None]:
!pip install timm
!pip install detectors

Collecting timm
  Downloading timm-0.9.16-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m15.5 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->timm)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->timm)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->timm)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch->timm)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch->timm)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch->timm)
  Using cach

In [None]:
# NOTE(review): `detectors` is never referenced by name; presumably importing it
# registers the "resnet34_cifar10" architecture with timm — confirm before removing.
import detectors
import timm

# Two separate instances of the same pretrained network: `model` later has its
# `fc` layer replaced and is handed to the FID metric, while `model1` is passed
# unchanged to the Inception Score metric.
model = timm.create_model("resnet34_cifar10", pretrained=True)
model1 = timm.create_model("resnet34_cifar10", pretrained=True)

Downloading: "https://huggingface.co/edadaltocg/resnet34_cifar10/resolve/main/pytorch_model.bin" to /root/.cache/torch/hub/checkpoints/resnet34_cifar10.pth
100%|██████████| 81.3M/81.3M [00:00<00:00, 345MB/s]


In [None]:
import torch.nn as nn

# Swap the `fc` layer for a pass-through so `model` returns whatever was
# previously fed into `fc` instead of the `fc` output.
model.fc = nn.Identity()
# Bare expression so the notebook displays the resulting module structure.
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): Identity()
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_

In [None]:
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.3.2-py3-none-any.whl (841 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/841.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m266.2/841.5 kB[0m [31m7.9 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━[0m [32m778.2/841.5 kB[0m [31m11.3 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m841.5/841.5 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.11.2-py3-none-any.whl (26 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.11.2 torchmetrics-1.3.2


In [None]:
# Evaluate the EMA diffusion model on the CIFAR-10 test split for every guidance
# strength listed in experiment.json and write the averaged FID / Inception
# scores back to that file after each strength.
#
# NOTE(review): the star imports can silently overwrite notebook globals if
# diffusion.py / model.py define names such as `model` or `DEVICE`; presumably
# they provide `Improved_CFG_Diffusion` and `UNet` — confirm and import those
# explicitly.
# from utils import *
from diffusion import *
from model import *

# Explicit imports come after the star imports so these names always win.
import json

import torch
from torch import nn, optim
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

cifar10_dict = dataset.class_to_idx
dtype = dataset[0][0].dtype  # dtype of one preprocessed test image
device = DEVICE  # lowercase alias kept for code that still refers to it
num_classes = len(dataset.classes)
beta = 0.995  # decay handed to get_ema_multi_avg_fn

models_dir = os.path.join(GOOGLE_DRIVE_PATH, "models")
results_path = os.path.join(GOOGLE_DRIVE_PATH, 'experiment.json')

# load diffusion
# map_location lets the checkpoints load on whichever device this runtime has,
# regardless of the device they were saved from.
unet = UNet(num_classes = num_classes, device = DEVICE)
unet.load_state_dict(torch.load(os.path.join(models_dir, "cifar_ckpt.pt"), map_location=DEVICE))
unet = unet.to(DEVICE)

ema_model = optim.swa_utils.AveragedModel(unet, multi_avg_fn=optim.swa_utils.get_ema_multi_avg_fn(beta))
ema_model.load_state_dict(torch.load(os.path.join(models_dir, "cifar_ema_ckpt.pt"), map_location=DEVICE))
ema_model = ema_model.to(DEVICE)

# Pass the device so batches are moved next to the model before the forward pass.
# NOTE(review): update_bn calls the model with the image batch only — confirm
# that matches UNet.forward if the network contains BatchNorm layers.
optim.swa_utils.update_bn(train_dataloader, ema_model, device=DEVICE)

# NOTE(review): 1024 is presumably the number of diffusion steps — confirm
# against Improved_CFG_Diffusion's signature.
diffusion = Improved_CFG_Diffusion(1024, (image_size, image_size), dtype = dtype, device = DEVICE)

fid_score = 0
inception_score = 0
num = 0

# The pretrained networks are used purely as fixed feature extractors, so put
# them in eval mode: otherwise BatchNorm would use per-batch statistics and
# keep updating its running stats while the metrics are being computed.
model = model.to(DEVICE).eval()
fid = FrechetInceptionDistance(feature=model, normalize = False)
fid = fid.to(DEVICE)

model1 = model1.eval()
inception = InceptionScore(feature=model1, normalize = False)
inception = inception.to(DEVICE)

# Context manager so the results file handle is closed once parsed.
with open(results_path) as f:
  data = json.load(f)


for guidance_strength_raw in data.keys():

  # Entries whose scores are both above the -1 sentinel are already done.
  if (data[guidance_strength_raw]["fid"] > -1) and (data[guidance_strength_raw]["inception"] > -1):
    continue

  fid_score = 0
  inception_score = 0
  num = 0
  guidance_strength = float(guidance_strength_raw)

  # NOTE(review): FID / IS are computed on each batch separately and then
  # averaged, which is not the same as one score over the whole test set —
  # confirm this is the intended protocol.
  for i, (x, y) in enumerate(test_dataloader):
    x = x.to(DEVICE)
    y = y.to(DEVICE)

    ema_model.eval()
    with torch.no_grad():
      # create a random noise to generate image
      noise = torch.randn((x.shape[0], 3, image_size, image_size), dtype = dtype, device = DEVICE)

      # generate images conditioned on the real batch's labels
      fake = diffusion.inference(ema_model, noise, label = y, guidance_strength = guidance_strength)

      # assumes diffusion.inference returns images in [0, 1], so this rescale
      # matches the [-1, 1] range of the normalized real images — TODO confirm
      fake = (fake - 0.5) / 0.5

      fid.update(x, real=True)
      fid.update(fake, real=False)
      fid_score += fid.compute()
      fid.reset()

      inception.update(fake)
      inception_score += inception.compute()[0]
      inception.reset()
      num += 1

    # Report the running mean (not the running sum) every 50 batches.
    if (i % 50 == 0):
      print(f"At batch {i} - running mean Inception Score with Guidance Strength = {guidance_strength}: {inception_score / num}")
      print(f"At batch {i} - running mean FID Score with Guidance Strength = {guidance_strength}: {fid_score / num}")


  fid_score /= num
  inception_score /= num
  print(f"Inception Score with Guidance Strength = {guidance_strength}: {inception_score}")
  print(f"FID Score with Guidance Strength = {guidance_strength}: {fid_score}")
  data[guidance_strength_raw]["fid"] = fid_score.item()
  data[guidance_strength_raw]["inception"] = inception_score.item()
  # Persist after every guidance strength so an interrupted run keeps its progress.
  with open(results_path, 'w') as f:
    json.dump(data, f)





  self.pid = os.fork()


At epoch 0 - Inception Score with Guidance Strength = 0.0: 2.3016562461853027
At epoch 0 - FID Score with Guidance Strength = 0.0: 1.3937277793884277


In [None]:
  # Manual save step: show the current FID value, then store the current scores
  # under the last-processed guidance strength and rewrite experiment.json.
  print(fid_score.cpu().detach().numpy())

  entry = data[guidance_strength_raw]
  entry["fid"] = float(fid_score.cpu().numpy())
  entry["inception"] = float(inception_score.cpu().numpy())

  results_file = os.path.join(GOOGLE_DRIVE_PATH, 'experiment.json')
  with open(results_file, 'w') as f:
    json.dump(data, f)