deepdream.py
import numpy as np
import scipy.ndimage as nd
import torch
from torch.autograd import Variable

from util import showtensor


def objective_L2(dst, guide_features):
    # L2 objective: the gradient of the squared L2 norm of the activations
    # is proportional to the activations themselves, so return them directly.
    return dst.data
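

# Sketch of an alternative objective (not part of the original file): a guided
# objective in the style of the classic DeepDream "guide image" trick, written
# to match the distance(dst, guide_features) interface used by make_step below.
# It assumes guide_features holds the guide image's activations at the same
# layer, with shape (1, C, Hg, Wg).
def objective_guide(dst, guide_features):
    x = dst.data[0].cpu().numpy().copy()
    y = guide_features.data[0].cpu().numpy().copy()
    ch, h, w = x.shape
    x = x.reshape(ch, -1)
    y = y.reshape(ch, -1)
    A = x.T.dot(y)                # correlation between each position and each guide vector
    diff = y[:, A.argmax(1)]      # best-matching guide feature for every position
    diff = torch.Tensor(np.array([diff.reshape(ch, h, w)]))
    if torch.cuda.is_available():
        diff = diff.cuda()
    return diff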
def make_step(img, model, control=None, distance=objective_L2):
    # Gradient-ascent iterations on a single octave of the image.
    # ImageNet normalization statistics, used to clip the image to a valid range.
    mean = np.array([0.485, 0.456, 0.406]).reshape([3, 1, 1])
    std = np.array([0.229, 0.224, 0.225]).reshape([3, 1, 1])

    learning_rate = 2e-2
    max_jitter = 32
    num_iterations = 20
    show_every = 10
    end_layer = 3
    guide_features = control

    for i in range(num_iterations):
        # Apply a random jitter shift to reduce grid artifacts.
        shift_x, shift_y = np.random.randint(-max_jitter, max_jitter + 1, 2)
        img = np.roll(np.roll(img, shift_x, -1), shift_y, -2)

        model.zero_grad()
        img_tensor = torch.Tensor(img)
        if torch.cuda.is_available():
            img_variable = Variable(img_tensor.cuda(), requires_grad=True)
        else:
            img_variable = Variable(img_tensor, requires_grad=True)

        # Forward pass up to the chosen layer, then backpropagate the objective.
        act_value = model.forward(img_variable, end_layer)
        diff_out = distance(act_value, guide_features)
        act_value.backward(diff_out)

        # Normalize the step size by the mean absolute gradient, then ascend.
        ratio = np.abs(img_variable.grad.data.cpu().numpy()).mean()
        learning_rate_use = learning_rate / ratio
        img_variable.data.add_(img_variable.grad.data * learning_rate_use)

        img = img_variable.data.cpu().numpy()  # b, c, h, w
        # Undo the jitter shift and clip to the valid normalized pixel range.
        img = np.roll(np.roll(img, -shift_x, -1), -shift_y, -2)
        img[0, :, :, :] = np.clip(img[0, :, :, :], -mean / std,
                                  (1 - mean) / std)

        if i == 0 or (i + 1) % show_every == 0:
            showtensor(img)
    return img
def dream(model,
          base_img,
          octave_n=6,
          octave_scale=1.4,
          control=None,
          distance=objective_L2):
    # Build an image pyramid: octaves[0] is the full-resolution image and each
    # subsequent entry is downscaled by octave_scale.
    octaves = [base_img]
    for i in range(octave_n - 1):
        octaves.append(
            nd.zoom(
                octaves[-1], (1, 1, 1.0 / octave_scale, 1.0 / octave_scale),
                order=1))

    # Process octaves from smallest to largest, carrying the generated detail
    # up to the next resolution.
    detail = np.zeros_like(octaves[-1])
    for octave, octave_base in enumerate(octaves[::-1]):
        h, w = octave_base.shape[-2:]
        if octave > 0:
            # Upscale the detail produced at the previous (smaller) octave.
            h1, w1 = detail.shape[-2:]
            detail = nd.zoom(
                detail, (1, 1, 1.0 * h / h1, 1.0 * w / w1), order=1)
        input_oct = octave_base + detail
        print(input_oct.shape)
        out = make_step(input_oct, model, control, distance=distance)
        detail = out - octave_base
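

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original script). It assumes a model wrapper
# whose forward(x, end_layer) returns the activations of an intermediate
# layer, and an input already normalized with the ImageNet mean/std and laid
# out as a (1, 3, H, W) array. The `resnet50` wrapper and 'sky.jpg' below are
# hypothetical placeholders, not names confirmed by this file.
#
# if __name__ == '__main__':
#     from PIL import Image
#     import torchvision.transforms as transforms
#     from resnet import resnet50   # hypothetical wrapper supporting forward(x, end_layer)
#
#     model = resnet50(pretrained=True)
#     model.eval()
#     if torch.cuda.is_available():
#         model = model.cuda()
#
#     preprocess = transforms.Compose([
#         transforms.Resize((224, 224)),
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                              std=[0.229, 0.224, 0.225]),
#     ])
#     img = Image.open('sky.jpg').convert('RGB')
#     base_img = preprocess(img).unsqueeze(0).numpy()   # (1, 3, H, W)
#
#     dream(model, base_img)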