In [1]:
%matplotlib inline

In [2]:
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
import torchvision
from PIL import GifImagePlugin, Image, ImageDraw, ImageFont
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10, MNIST

from generate_dreams.render_engine import generate_dream

plt.rcParams['figure.dpi'] = 300


In [3]:

# Use the GPU when CUDA is available, otherwise stay on the CPU.
torch_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Dataset switch: any value containing "MNIST" selects MNIST, everything else CIFAR.
use_dataset:str = "CIFAR"
# use_dataset:str = "MNIST"

# Checkpoint locations, keyed by the name each model gets in `models`.
# Commented entries are alternative checkpoints that can be enabled by hand.
_model_loc_list_cifar = {
    "basemodel": './modelfolder/cifar_base.pkl',
    # "PGD5": './modelfolder/cifar_pgd5.pkl',
    # "PGDDream-16itlr1e-3": 'modelfolder/cifar_PGD_dream16_1e-3.pkl',
    # "CW20-0.1": 'modelfolder/cifar_CW20.pkl'
}
_model_loc_list_mnist = {
    "basemodel": './modelfolder/mnist_base.pkl',
    # "PGD5": './modelfolder/mnist_pgd5.pkl',
}
_model_loc_list = _model_loc_list_mnist if "MNIST" in use_dataset else _model_loc_list_cifar

# Load every selected checkpoint onto `torch_device` and call .eval() on it.
# NOTE(review): torch.load unpickles the file, so only point this at trusted checkpoints.
# NOTE(review): the loaded object is used directly as the model, i.e. the files presumably
# hold whole pickled modules rather than state dicts -- confirm against the training code.
models = {}
for _modelname, _modelloc in _model_loc_list.items():
    _model = torch.load(_modelloc, map_location=torch_device)
    _model.eval()
    models[_modelname] = _model

# Output settings.
# save_data = True
save_data = False
save_loc = 'imgs/test'
save_suffix = '_tt'

# Batch size handed to the data loaders.
bsize = 256

In [4]:
# ToTensor is the only preprocessing step; _to_pil converts tensors back to PIL images.
_transforms = transforms.Compose([transforms.ToTensor()])
_to_pil = transforms.ToPILImage()

# Both training sets are prepared (downloaded when missing) no matter which one
# `use_dataset` selects, so both loaders exist afterwards.
_CIFAR_data = CIFAR10(root='../../data/cifar/', train=True, download=True, transform=_transforms)
_MNIST_data = MNIST(root='../../data/mnist/', train=True, download=True, transform=_transforms)

# shuffle=False keeps the batch order fixed between runs.
cifar_dataloader = DataLoader(_CIFAR_data, batch_size=bsize, shuffle=False, num_workers=0)
mnist_dataloader = DataLoader(_MNIST_data, batch_size=bsize, shuffle=False, num_workers=0)

# Same selection rule as for the model checkpoints.
dataloader = mnist_dataloader if "MNIST" in use_dataset else cifar_dataloader



Files already downloaded and verified


In [5]:
# img_b = next(iter(dataloader))
# _x, _y = img_b


In [6]:
# Take the first batch from the loader and split it into images and labels.
img_b = next(iter(dataloader))
_x, _y = img_b

# Class ids present in the batch (torch.unique returns them sorted by default).
_unique = torch.unique(_y, dim=0)

# Index of the first sample of each class. Row c of the comparison matrix marks the
# positions where the label equals _unique[c]; argmax returns the first maximal entry,
# i.e. the earliest match. This replaces a flip + scatter_ trick that relied on the
# write order of duplicate indices, which PyTorch documents as non-deterministic.
_first_class_idxs = (_unique.unsqueeze(1) == _y.unsqueeze(0)).int().argmax(dim=1)

# One image per class, ordered like `_unique`.
_dream_imgs = _x[_first_class_idxs]
# Labels are taken from the classes actually found instead of a hard-coded arange(10),
# so _dream_lbl[i] always belongs to _dream_imgs[i] even if a class is absent.
_dream_lbl = _unique.to(torch_device)

# Keep a single image/label pair with a leading batch dimension of 1.
# NOTE(review): position 8 is called "boat" here -- presumably the CIFAR-10 ship class; confirm.
boatimg, boatlbl = _dream_imgs[8].unsqueeze(0), _dream_lbl[8].unsqueeze(0)
# boatimg, boatlbl = _dream_imgs[5].unsqueeze(0), _dream_lbl[5].unsqueeze(0)

# From here on img_b is the single-image batch, no longer the full loader batch.
img_b = (boatimg, boatlbl)


In [7]:
# boatimg = _to_pil(boatimg)
# with open("imgs/test/verboat.jpeg", 'wb') as f:
#     boatimg.save(fp=f, format="jpeg", quality=100)

In [8]:
# boatimg = _x[5]
# boatimg = _to_pil(boatimg)
# with open("imgs/5.jpeg", 'wb') as f:
#     boatimg.save(fp=f , format="jpeg", quality=100)

In [9]:
# if save_gif:
#     pil_img_array[0].save(gif_save_loc + 'lr1e3_smooth_cat_pgd5.gif', format='GIF',
#         append_images=pil_img_array[1:],
#         duration=50,
#         interlace=False,
#         save_all=True,
#         loop=0)

In [10]:
def calculate_lpdist(im1:torch.Tensor, im2:torch.Tensor, ord=2):
    """Return the Lp norm of ``im1 - im2`` taken over the last three dimensions.

    Args:
        im1: First tensor; needs at least three dimensions.
        im2: Second tensor, subtracted from ``im1``.
        ord: Order of the norm, passed to ``torch.linalg.vector_norm`` (default 2).

    Returns:
        A tensor of norms with the last three dimensions reduced away.
    """
    return torch.linalg.vector_norm(im1 - im2, ord=ord, dim=(-3, -2, -1))


In [11]:
def print_stats(t1: torch.Tensor, t2: torch.Tensor, idx):
    """Summarise the per-sample L2 and L-infinity distances between two batches.

    Despite its name, this function prints nothing; it returns the statistics.

    Args:
        t1: First batch. The last three dimensions are reduced, and the size of
            the leading dimension is reported as ``"shape"``.
        t2: Second batch, subtracted from ``t1``.
        idx: Identifier stored unchanged under ``"idx"`` in the result.

    Returns:
        dict with the keys ``idx``, ``shape``, ``avg_l2``, ``max_l2``, ``min_l2``,
        ``avg_li``, ``max_li`` and ``min_li``; all distance values are Python floats.
    """
    # Compare on the CPU, outside autograd, so the inputs may live on different devices.
    diffs = t1.detach().cpu() - t2.detach().cpu()

    reduce_dims = (-3, -2, -1)
    l2 = torch.linalg.vector_norm(diffs, ord=2, dim=reduce_dims)
    li = torch.linalg.vector_norm(diffs, ord=float('inf'), dim=reduce_dims)

    # The standard deviations are intentionally not computed: they were never returned,
    # and for a single-sample batch torch.std yields NaN and emits a warning.
    return {
        "idx": idx,
        "shape": l2.shape[0],
        "avg_l2": l2.mean().item(),
        "max_l2": l2.max().item(),
        "min_l2": l2.min().item(),

        "avg_li": li.mean().item(),
        "max_li": li.max().item(),
        "min_li": li.min().item(),
        }

In [12]:
# Iteration schedule handed to generate_dream; the commented lines are alternatives.
# iterlist = np.arange(0, 256, 32)
iterlist = np.arange(0, 129, 1)
# iterlist = np.arange(0, 6, 2)
# iterlist = np.arange(0, 512, 64)

# Detached copies of the selected image and label on the compute device.
origs, lbls = (tensor.detach().clone().to(torch_device) for tensor in img_b)

# Learning rate forwarded to generate_dream as opt_lr.
lr = 1e-1

# NOTE(review): generate_dream is a project helper; it presumably yields one batch of
# images per entry in `iterlist` -- confirm against generate_dreams.render_engine.
dreams = generate_dream(
    model=models["basemodel"],
    batch=img_b,
    device=torch_device,
    opt_lr=lr,
    iterations=iterlist,
    parametrization="tanh",
)

# Distance statistics of every returned dream relative to the original image.
statlist = []
for idx, dream in enumerate(dreams):
    stats = print_stats(origs, dream, idx)
    statlist.append(stats)
    # print(f"stats for iter {iterlist[idx]}: {json.dumps(stats, indent=2)}")

# Average distances as flat lists, used further down to build the table rows.
totablel2 = [entry["avg_l2"] for entry in statlist]
totableli = [entry["avg_li"] for entry in statlist]

# Base name (without extension) of the statistics file.
_save_name = f"single_{use_dataset}_{bsize}_{lr}"

In [13]:
# Indices 2, 4, ..., 16: the entries of the statistics lists that go into the table.
intervalrange = np.arange(start=2, stop=17, step=2)
print(intervalrange)

[ 2  4  6  8 10 12 14 16]


In [14]:

# Keep only the averages whose position in the list appears in `intervalrange`.
l2listfloats = [value for step, value in enumerate(totablel2) if step in intervalrange]
lilistfloats = [value for step, value in enumerate(totableli) if step in intervalrange]

# Render each series as the cells of a LaTeX table row: "& x.xxx " per value,
# followed by the row terminator "\\".
l2str = "".join(f"& {value:.3f} " for value in l2listfloats)
print(l2str + "\\\\")

listr = "".join(f"& {value:.3f} " for value in lilistfloats)
print(listr + "\\\\")

& 3.272 & 4.891 & 6.228 & 7.411 & 8.495 & 9.494 & 10.404 & 11.241 \\
& 0.109 & 0.202 & 0.290 & 0.377 & 0.447 & 0.509 & 0.592 & 0.663 \\


In [15]:
# Write the per-iteration statistics to <save_loc>/<_save_name>.json.
# The output directory is created first so a fresh checkout does not fail with
# FileNotFoundError.
# NOTE(review): this file is written even when `save_data` is False -- confirm whether
# the flag is meant to gate this dump as well.
os.makedirs(save_loc, exist_ok=True)
with open((save_loc + "/" + _save_name + ".json"), 'w') as f:
    json.dump(statlist, f, indent=4)

In [16]:

# if save_data:
#     for name, dream_img in dreams_to_save.items():
#         with open(name + ".jpeg", 'wb') as f:
#             # dream_img.save(fp=f , format="bmp", resolution=50)
#             dream_img.save(fp=f , format="jpeg", quality=100)


# display(dreams[0.01][0][8])