In [1]:
# General library imports
from typing import Iterable, Optional, Sequence

import pandas as pd
import torch
import torchvision
from torch import nn
from torch.utils import data as t_data
from torch.utils.data import DataLoader
from torchvision import transforms


# Custom code imports
from DataSet.lightning_cifar import CIFARDataModule
from lightning_trainer import LitModelTrainer
from generate_dreams.render_engine import generate_dream


In [2]:
# Set global params

# Device used for the model and all tensors: CUDA when available, otherwise CPU.
torch_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The first `subset_size` CIFAR-10 training images are dreamed. Both the original
# and the dreamed versions are run through the model, and the resulting logits are
# collected into a DataFrame that is pickled at the end of the notebook.
# Images are processed `batch_size` at a time.
subset_size = 256
batch_size = 64

# The dataloader is built with drop_last=True. A subset_size that is not a
# multiple of batch_size would therefore silently skip the trailing images,
# so fail fast instead.
assert subset_size % batch_size == 0, (
    f"subset_size ({subset_size}) must be a multiple of batch_size ({batch_size})"
)

# Checkpoint of the trained model that is used for dreaming and whose logits are inspected.
model_loc = "../trained_models/CIFAR/base_model/version_0/checkpoints/model-epoch=15.ckpt"

# Root directory of the CIFAR-10 dataset (it is downloaded here if missing).
cifar_base_loc = '../../../data/cifar'

# Transforms applied to the images before dreaming: ToTensor only
# (PIL image -> float tensor scaled to [0, 1], per the torchvision docs).
# NOTE(review): there is no Normalize step. Confirm this matches the
# preprocessing the checkpoint was trained with.
transform_dreams = transforms.Compose([
    transforms.ToTensor()
])

In [3]:
# Load model and data(sub)set
# Restore the trained Lightning module from `model_loc`, then take the wrapped
# network (its `.model` attribute) and place it on the selected device.
# nn.Module.to() moves the module and returns it, so one expression suffices.
_ModelCheckpoint = LitModelTrainer.load_from_checkpoint(model_loc)
model: nn.Module = _ModelCheckpoint.model.to(torch_device)




In [4]:

# Human-readable name for each logit position: class_labels[i] names logit i.
class_labels = [
    "airplanes", "cars", "birds", "cats", "deer",
    "dogs", "frogs", "horses", "ships", "trucks",
]

# Clean (non-dreamed) CIFAR-10 training split, fetched on first use.
_clean_data = torchvision.datasets.CIFAR10(
    root=cifar_base_loc,
    train=True,
    transform=transform_dreams,
    download=True,
)

# Deterministic subset: the first `subset_size` images, in dataset order.
_data_idxs = list(range(subset_size))
_clean_subset = t_data.Subset(_clean_data, _data_idxs)

# Un-shuffled loader over the subset, so batch k holds images
# [k * batch_size, (k + 1) * batch_size). Incomplete trailing batches are dropped.
parallell_dataloader = DataLoader(
    dataset=_clean_subset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,
    num_workers=0,
)


Files already downloaded and verified


In [5]:
def batch_logits_to_df(
    df: pd.DataFrame,
    normal_logits_batch,
    dream_logits_batch,
    batch_idx,
    labels_batch,
    nominal_batch_size: Optional[int] = None,
    logit_names: Optional[Sequence[str]] = None,
) -> pd.DataFrame:
    """Append long-format logit rows for one batch of images to ``df``.

    For every image in the batch and every logit position, three rows are
    produced, in this order:
    - ``"normal"``: the logit on the original image;
    - ``"dream"``: the logit on the dreamed image;
    - ``"diff"``: ``dream - normal``.

    Args:
        df: Frame accumulated so far. The new rows are appended after it.
        normal_logits_batch: Logits for the original images, indexed as
            ``[image][logit]`` (presumably a ``(batch, n_classes)`` tensor).
        dream_logits_batch: Same layout, for the dreamed images.
        batch_idx: Position of this batch in the dataloader. Combined with
            ``nominal_batch_size`` to form global image indices.
        labels_batch: One ground-truth label per image in the batch.
        nominal_batch_size: Images per batch used for the global index
            ``batch_idx * nominal_batch_size + position``. Defaults to the
            notebook-level ``batch_size``.
        logit_names: Name for each logit position. Defaults to the
            notebook-level ``class_labels``.

    Returns:
        A new DataFrame: ``df`` followed by the new rows, with a fresh
        ``0..n-1`` index. If the batch yields no rows, ``df`` is returned as-is.

    Raises:
        ValueError: If the normal and dream batches hold a different number of images.
        IndexError: If a logit vector is longer than ``logit_names``.
    """
    if nominal_batch_size is None:
        nominal_batch_size = batch_size  # notebook-level config value
    if logit_names is None:
        logit_names = class_labels  # notebook-level class names

    # Convert each input once (a single device->host copy). Calling `.item()`
    # per scalar forces a synchronisation for every value on GPU.
    normal_rows = torch.as_tensor(normal_logits_batch).detach().cpu().tolist()
    dream_rows = torch.as_tensor(dream_logits_batch).detach().cpu().tolist()
    labels = torch.as_tensor(labels_batch).detach().cpu().tolist()

    if len(normal_rows) != len(dream_rows):
        raise ValueError(
            f"normal and dream batches differ in size: {len(normal_rows)} vs {len(dream_rows)}"
        )

    first_img_idx = batch_idx * nominal_batch_size
    records = []
    for pos, (normal_logits, dream_logits) in enumerate(zip(normal_rows, dream_rows)):
        img_idx = first_img_idx + pos
        # Look the label up by position within this batch. Using
        # `img_idx % batch_size` picks the wrong label when a batch is larger
        # than batch_size.
        img_label = labels[pos]
        for logit_idx, (normal_logit, dream_logit) in enumerate(zip(normal_logits, dream_logits)):
            common = {
                "img_index": img_idx,
                "img_label": img_label,
                "logit_id": logit_names[logit_idx],
            }
            records.append({**common, "value_type": "normal", "logit_value": normal_logit})
            records.append({**common, "value_type": "dream", "logit_value": dream_logit})
            records.append({**common, "value_type": "diff", "logit_value": dream_logit - normal_logit})

    if not records:
        return df
    # Build the batch's rows once and concatenate once. Growing the frame with
    # one pd.concat per logit is quadratic in the number of rows.
    return pd.concat([df, pd.DataFrame(records)], axis=0, ignore_index=True)

In [17]:

model.eval()

# Empty accumulator with the column layout that batch_logits_to_df appends to.
logit_dataframe = pd.DataFrame(
    columns=["img_index", "img_label", "logit_id", "value_type", "logit_value"]
)


def optimizer_l(p):
    """Optimizer factory handed to generate_dream: Adam with lr=1e-2 over ``p``.

    Presumably ``p`` is the tensor(s) being dreamed. TODO: confirm in render_engine.
    """
    return torch.optim.Adam(p, lr=1e-2)


for batch_idx, batch in enumerate(parallell_dataloader):
    # Device copies of the images and labels, used for the comparison passes below.
    x, y = (tensor.to(torch_device) for tensor in batch)

    # NOTE(review): generate_dream receives the original `batch`, not the device
    # copies x/y. Confirm it handles device placement itself.
    x_dream = generate_dream(
        optimizer_l=optimizer_l,
        iterations=32,
        eps_limit=1,
        l2_weight=0.0,
        verbose=True,
        model=model,
        batch=batch,
    )

    # Comparison forward passes only; no autograd graph is needed.
    with torch.no_grad():
        orig_logits, dream_logits = model(x), model(x_dream)

    logit_dataframe = batch_logits_to_df(logit_dataframe, orig_logits, dream_logits, batch_idx, y)
    

mean change: 0.04646193981170654, max change: 0.3159600794315338, min change: 0.0
it 32, loss: -32.59113311767578, average l2 other targets: 430.7247009277344, weight factor of L2 dist: 0.0
mean change: 0.048030786216259, max change: 0.3094111382961273, min change: 0.0
it 32, loss: -35.06561279296875, average l2 other targets: 434.37506103515625, weight factor of L2 dist: 0.0
mean change: 0.047301582992076874, max change: 0.3106195628643036, min change: 0.0
it 32, loss: -33.59252166748047, average l2 other targets: 434.1109924316406, weight factor of L2 dist: 0.0
mean change: 0.0476551316678524, max change: 0.3114229440689087, min change: 0.0
it 32, loss: -35.602996826171875, average l2 other targets: 410.9549560546875, weight factor of L2 dist: 0.0


In [18]:
# Persist the collected logits. The filename records the dreaming settings
# (32 iterations, Adam lr=1e-2); keep it in sync if those change.
_logits_out_path = '32it_lr1e-2.pkl'
logit_dataframe.to_pickle(_logits_out_path)