In [59]:
%%capture

from typing import Dict, List, Optional
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
import numpy as np
import scipy.ndimage as nd
import psutil

from utils import print_probs, transform, inverse_transform

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pretrained ResNet-50 in evaluation mode. Only the input image is optimized, so
# the weights are frozen: backward passes then skip parameter gradients entirely.
# NOTE: newer torchvision releases deprecate `pretrained=` in favor of `weights=`.
net = models.resnet50(pretrained=True).to(device)
net.eval()
net.requires_grad_(False)

# Each entry of `layers` is an nn.Sequential made of the first `depth` top-level
# children of the network (it shares net's modules; nothing is copied).
# For resnet50, a depth can range from 0 to 9.
layer_depths = [3, 4, 5, 8]
children = list(net.children())
backbone = nn.Sequential(*children)
layers = [backbone[:depth] for depth in layer_depths]

In [63]:
# ImageNet per-channel statistics; the dream loop uses them to clamp pixels
# back into the valid range after each step.
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# Gradient-ascent settings: step size and number of passes over `layers`.
learning_rate, n_iterations = 0.05, 20

# Multi-octave settings (in the visible notebook, only the commented-out octave
# code refers to these).
n_octaves, octave_scale = 4, 1.4

In [64]:
# Load and preprocess the input image into a batch of one on `device`
# (assumes utils.transform returns a (C, H, W) tensor — unsqueeze adds the batch axis).
with Image.open("images/rj-flower.jpg") as pil_img:
    # Force 3 channels: the rest of the notebook assumes RGB (e.g. the clamp loop
    # iterates over c in range(3)), so grayscale/RGBA/palette files would break it.
    img = transform(pil_img.convert("RGB")).to(device)
img = img.unsqueeze(0)

#img = torch.rand(1, 3, 500, 500)

# Keep an untouched, graph-free copy of the input for comparison (and as the
# starting point of the dream loop).
img_copy = img.clone().detach()

# # normalize learning rate
# learning_rate = learning_rate / (len(layers)*n_octaves)

# TODO: multi-octave processing is disabled; n_octaves / octave_scale only
# matter once the code below is re-enabled.
# Each octave_img is a zoomed-in (i.e. lower-res) version of the previous image
# octave_imgs = [img[0].cpu().numpy()]
# for i in range(n_octaves-1):
#     new_octave_img = nd.zoom(octave_imgs[-1], (1, 1.0/octave_scale, 1.0/octave_scale),
#                          order=2)
#     octave_imgs.append(new_octave_img)
# for i in range(len(octave_imgs)):
#     octave_imgs[i] = torch.tensor(octave_imgs[i]).unsqueeze(0).float().to(device)
# # Make the list low to high res
# octave_imgs.reverse()

In [65]:
# Gradient-ascent "dream" loop: each iteration visits every truncated network in
# `layers` and nudges the image so that network's squared activations grow.
# Start from the untouched input so re-running this cell is idempotent (the
# display code at the end of the cell rebinds `img` to a PIL image).
img = img_copy.clone()
max_jitter = 32  # maximum random shift in pixels along each spatial axis
jitter_rng = np.random.default_rng(0)  # seeded so the jitter sequence is reproducible
for _ in range(n_iterations):
    for layer in layers:
        # Random roll ("jitter"), symmetric in [-max_jitter, max_jitter]; undone below.
        y_jitter, x_jitter = (int(v) for v in jitter_rng.integers(-max_jitter, max_jitter + 1, size=2))
        img = torch.roll(img, shifts=(y_jitter, x_jitter), dims=(-2, -1)).detach()
        img.requires_grad_(True)
        activations = layer(img)
        # Minimizing the negative mean squared activation maximizes the activation.
        loss = -(activations[0] ** 2).mean()
        # Gradient w.r.t. the image only; unlike loss.backward() this does not
        # accumulate .grad into the network's parameters on every step.
        (g,) = torch.autograd.grad(loss, img)
        with torch.no_grad():
            # Scale-free step: divide by mean |g| (eps avoids 0/0 on a zero gradient).
            g = g / (g.abs().mean() + 1e-8)
            img = img - learning_rate * g
            # Clamp each channel to the normalized image of the [0, 1] pixel range,
            # assuming `transform` normalizes as (x - mean) / std
            # (from https://github.com/eriklindernoren/PyTorch-Deep-Dream).
            for c in range(3):
                m, s = mean[c], std[c]
                img[0, c] = torch.clamp(img[0, c], -m / s, (1 - m) / s)
        # Undo the jitter.
        img = torch.roll(img, shifts=(-y_jitter, -x_jitter), dims=(-2, -1))

# Display the difference between the original and the dreamed image. Subtract
# on the same device first, then move to CPU: the original subtracted a
# `device` tensor and a CPU tensor, which fails when `device` is CUDA.
diff = (img_copy[0] - img[0].detach()).cpu()
diff = inverse_transform(diff)
diff.show()

# Display and save the dreamed image (detached so conversion never sees autograd state).
# Note: this rebinds `img` from a tensor to the converted image.
img = inverse_transform(img[0].detach().cpu())
img.show()
img.save("temp.jpg")