# DeepDream

## Setup

In [None]:
import random
import math
import torch
import torch.nn.functional as F
import PIL
import copy
from tqdm.auto import tqdm
import torchvision.models as models
import transformers
from transformers import AutoImageProcessor, ResNetForImageClassification
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# silence transformers' warnings before anything is loaded from the hub
transformers.logging.set_verbosity_error()
# pre-trained VGG19 from torchvision: the network the dreams are computed with
vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
# ResNet-50 from Hugging Face: a second classifier and the source of the ImageNet labels
resnet = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
# class index -> human-readable label (looked up with the argmax of the logits)
id2label = resnet.config.id2label
# skip the processor's normalisation so the image tensors stay human readable when plotted
# NOTE(review): both classifiers were presumably trained on normalised inputs, so
# predictions on these un-normalised tensors may be less reliable — confirm intended
do_normalize = False

## Helpers

In [None]:
# plot the images and optional labels in a grid
def plot(*args, n_col=None, figsize=None, axis_off=True):
    '''Plot images in a grid, each optionally followed by a string label.

    Args:
        *args: image tensors, each optionally followed by one ``str`` label that
            becomes that image's title. A 4-D tensor is treated as a batch and
            only its first image is shown; a 3-channel channels-first image is
            permuted to channels-last for ``imshow``.
        n_col: number of grid columns (default: every image on one row).
        figsize: matplotlib figure size (default: 4 x 4 inches per cell).
        axis_off: hide the axes of every grid cell when True.

    Raises:
        ValueError: for a label that does not directly follow an unlabeled image,
            an image that is not 3-D (after taking the first of a batch), or an
            argument that is neither a tensor nor a string.
    '''
    # collect [image, label] pairs; the label stays None for unlabeled images
    elems = []
    for el in args:
        if isinstance(el, torch.Tensor):
            if el.dim() == 4:
                el = el[0]  # a batch: show its first image only
            if el.dim() != 3:
                raise ValueError(f"Wrong image shape {el.shape}")
            if el.shape[0] == 3:
                el = el.permute(1, 2, 0)  # CHW -> HWC for imshow
            # imshow needs a plain CPU tensor: drop autograd and leave the GPU
            elems.append([el.detach().cpu(), None])
        elif isinstance(el, str):
            if not elems or elems[-1][1] is not None:
                raise ValueError(f"Label {el!r} must directly follow an image")
            elems[-1][1] = el
        else:
            raise ValueError(f"Wrong element ({type(el)}) {repr(el)}")
    if not elems:
        return  # nothing to draw (and avoids dividing by a zero column count)

    n_col = n_col or len(elems)
    n_row = math.ceil(len(elems) / n_col)
    figsize = figsize or (n_col * 4, n_row * 4)
    # squeeze=False always returns a 2-D array of Axes, so one flat view addresses
    # every cell whatever the grid shape (unpacking a (row, col) tuple into
    # ndarray.__getitem__ as two arguments raises TypeError)
    _, ax = plt.subplots(n_row, n_col, figsize=figsize, squeeze=False)
    cells = ax.ravel()
    for cell, (img, label) in zip(cells, elems):
        if label:
            cell.set_title(label)
        cell.imshow(img)
    if axis_off:
        for cell in cells:
            cell.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# classify things
@torch.no_grad()
def classify(img, model, labels=None):
    '''Return the top-1 label and its softmax probability for ``img``.

    Args:
        img: input tensor for ``model``; presumably a batch of one image
            (``.item()`` below fails for larger batches) — TODO confirm.
        model: a module returning logits directly, or a Hugging Face model whose
            output carries them in ``.logits``.
        labels: mapping from class index to label; defaults to the global
            ``id2label``.

    Returns:
        ``(label, confidence)`` where ``confidence`` is a 0-d tensor holding the
        highest softmax probability.
    '''
    if labels is None:
        labels = id2label
    # the model may have been moved to another device (e.g. the GPU) after
    # loading, so send the input to wherever its parameters live
    param = next(model.parameters(), None)
    if param is not None:
        img = img.to(param.device)
    # predict in eval mode so layers such as dropout are deterministic, then
    # restore whatever mode the caller had the model in
    was_training = model.training
    model.eval()
    try:
        logits = model(img)
    finally:
        model.train(was_training)
    if not isinstance(logits, torch.Tensor):
        logits = logits.logits
    confidence = logits.softmax(dim=-1).max()
    label = labels[logits.argmax(dim=-1).item()]
    return label, confidence

def all_classify(img):
    '''Summarise the top-1 prediction of each classifier, one line per model.

    Each line reads ``<model name>: <label> (<confidence>%)``; labels longer
    than 20 characters are cut to 20 and suffixed with ``...``.
    '''
    summary = []
    for name, model in (('VGG19', vgg), ('ResNet', resnet)):
        label, confidence = classify(img, model)
        # truncate so the label fits the padded 23-character column
        label = label if len(label) <= 20 else label[:20] + '...'
        summary.append(f'{name:6}: {label:23} ({confidence*100:.2f}%)')
    return '\n'.join(summary)

def plotl(*args):
    '''Plot ``(image, label)`` pairs, titling each with its label plus every classifier's prediction.'''
    flat = []
    for image, caption in args:
        flat += [image, f'{caption}\n{all_classify(image)}']
    plot(*flat)

## Load images

In [None]:
# open the raw pictures, then preprocess each one into the tensor the models consume
original_vangogh, original_sky, original_kelpie = (
    PIL.Image.open(path) for path in ("vangogh.jpg", "sky.jpg", "kelpie.jpg")
)
vangogh, sky, kelpie = (
    processor(picture, return_tensors="pt", do_normalize=do_normalize)['pixel_values']
    for picture in (original_vangogh, original_sky, original_kelpie)
)
plot(vangogh, sky, kelpie)

## VGG from torchvision

In [None]:
# hyperparameters (also the defaults of shallowdream)
learning_rate = 0.01  # step size of each gradient-ascent update
epochs = 1000         # number of update steps
log_every = 100       # print the loss once every this many steps
# run on the GPU when one is available
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# the VGG model on torchvision is much easier to look into than ResNet from huggingface:
# displaying `vgg.features` (the cell's last expression) lists its layers, and the
# positions in that list are the `layer` indices that shallowdream hooks below
vgg.features

In [None]:
# hook some activation layer
def hook(layer, k, mem=None, *, return_handle=False):
    '''Record a module's forward output into a dict.

    Registers a forward hook on ``layer`` that stores its output in ``mem[k]``
    on every forward pass, overwriting the previous value.

    Args:
        layer: the ``torch.nn.Module`` to observe.
        k: key under which the output is stored.
        mem: dict to store into; a fresh one is created when omitted.
        return_handle: when True, also return the hook's handle so the caller
            can detach the hook with ``handle.remove()``.

    Returns:
        ``mem``, or ``(mem, handle)`` when ``return_handle`` is True.
    '''
    if mem is None:
        mem = {}

    def f(module, input, output):
        mem[k] = output

    # the handle is the only way to unregister the hook; without it the hook
    # stays attached to ``layer`` (and keeps writing into ``mem``) for good
    handle = layer.register_forward_hook(f)
    return (mem, handle) if return_handle else mem

In [None]:
def shallowdream(start, layer=35, m=vgg, learning_rate=learning_rate, epochs=epochs):
    '''Gradient-ascend an image so it amplifies one layer of ``m.features``.

    Each step runs ``m`` on the image, takes the norm of the output of
    ``m.features[layer]`` as the loss, moves the image along the gradient of that
    loss, and clips it to [0, 1].

    Args:
        start: image tensor to start from; it is copied, never modified.
        layer: index into ``m.features`` whose output norm is maximised.
        m: model exposing ``features`` (defaults to the global VGG19).
            NOTE: ``m`` is moved to ``device`` in place.
        learning_rate: step size of each update.
        epochs: number of update steps.

    Returns:
        The dreamed image on ``device``, detached from autograd.
    '''
    # private copy on the target device so the caller's tensor is left untouched
    dream = start.detach().to(device, copy=True).requires_grad_()
    m = m.to(device)  # nn.Module.to() moves the module in place
    mem = {}

    def _store(module, inputs, output):
        # NOTE(review): if the next module in ``features`` works in place (e.g. an
        # in-place ReLU), this stored tensor is modified after the hook fires —
        # confirm that is the activation intended for the loss
        mem[layer] = output

    # keep the handle and always remove the hook: otherwise every call leaves a
    # hook behind that keeps firing (and holding activations) on later forwards
    handle = m.features[layer].register_forward_hook(_store)
    try:
        for epoch in tqdm(range(epochs)):
            m(dream)
            loss = mem[layer].norm()
            # gradient w.r.t. the image only: no .grad is computed or accumulated
            # on the model's parameters
            grad, = torch.autograd.grad(loss, dream)
            with torch.no_grad():
                dream.add_(grad, alpha=learning_rate).clamp_(0., 1.)
            # logging
            if epoch % log_every == log_every - 1:
                print(f"{epoch: 4} {loss=}")
    finally:
        handle.remove()
    return dream.detach()

# a short dream on the Van Gogh picture, amplifying layer 28 of vgg.features
dream = shallowdream(vangogh, layer=28, learning_rate=0.02, epochs=20)
dream = dream.to('cpu')
plot(vangogh, 'starting point', dream, 'dream')

## dream by layer

In [None]:
# sweep layer indices 1..35 of vgg.features, dreaming from the same picture each time
start = kelpie
for layer in range(1, 36):
    dream = shallowdream(start, layer=layer, learning_rate=0.02, epochs=30)
    dream = dream.to('cpu')
    plot(start, 'starting point', dream, f'dream {layer=}')

## dream of creatures?

In [None]:
start = kelpie
# other layers worth trying: 25 (building?), 30 (mountain?), 34 (bigger creatures?)
layer = 27  # creatures?
# same layer, increasingly long dreams: 10, 60, ..., 260 steps
for duration in range(10, 300, 50):
    dream = shallowdream(start, layer=layer, learning_rate=0.02, epochs=duration)
    dream = dream.to('cpu')
    plot(start, 'starting point', dream, f'dream {duration=}')

## random demos

In [None]:
# every picture, every layer index 1..35
for start in (kelpie, sky, vangogh):
    for layer in range(1, 36):
        dream = shallowdream(start, layer=layer, learning_rate=0.02, epochs=30)
        dream = dream.to('cpu')
        plot(start, 'starting point', dream, f'dream {layer=}')

## with noise

In [None]:
# dream from random noise instead of a photo
# NOTE(review): torch.randn has zero mean, so about half of these values are negative,
# i.e. outside the [0, 1] range shallowdream clips to — confirm this is intended
start = torch.randn_like(vangogh) * 0.2
for layer in range(1, 36):
    dream = shallowdream(start, layer=layer, learning_rate=0.02, epochs=80)
    dream = dream.to('cpu')
    plot(start, 'starting point', dream, f'dream {layer=}')