Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 116 additions & 113 deletions vignettes/examples/deep_dream.R
Original file line number Diff line number Diff line change
@@ -1,32 +1,57 @@


# Setup

library(keras)
library(tensorflow)


# Download (or reuse the cached copy of) the picture to dream on, and choose
# the file-name stem used when the result is written to disk.
base_image_path <- get_file("paris.jpg", "https://i.imgur.com/aGBdQyK.jpg")
result_prefix <- "sky_dream"

# Layers whose activations we try to maximize, mapped to the weight each
# layer carries in the final loss.
# You can tweak these settings to obtain new visual effects.
layer_settings <- list(
  mixed4 = 1.0,
  mixed5 = 1.5,
  mixed6 = 2.0,
  mixed7 = 2.5
)

# Playing with these hyperparameters will also allow you to achieve new effects
# Gradient ascent step size
step <- 0.01
# Number of scales at which to run gradient ascent
num_octave <- 3
# Size ratio between scales
octave_scale <- 1.4
# Number of ascent steps per scale
iterations <- 20
# Loss threshold passed to the ascent loop as `max_loss`
max_loss <- 15

# Utility functions -------------------------------------------------------
# This is our base image:
plot(magick::image_read(base_image_path))

# Util function to open, resize, and format pictures into tensors that Inception V3 can process
# Let's set up some image preprocessing/deprocessing utilities:
preprocess_image <- function(image_path) {
  # Load the picture and turn it into a numeric array.
  pixels <- image_to_array(image_load(image_path))
  # Prepend a batch axis of length 1 in front of the existing dimensions.
  batch <- array_reshape(pixels, dim = c(1, dim(pixels)))
  # Apply the Inception V3 input preprocessing and return the result.
  inception_v3_preprocess_input(batch)
}

# Util function to convert a tensor into a valid image
deprocess_image <- function(img) {
# NOTE(review): this body appears to interleave two different versions of the
# code (an array-based "deprocess" followed by a TensorFlow-based "load and
# preprocess"). Confirm which half is intended before relying on it.
#
# Step 1: reshape `img` to (dim 2, dim 3, 3) of the input, i.e. drop the
# leading axis.
img <- array_reshape(img, dim = c(dim(img)[[2]], dim(img)[[3]], 3))
# Undoes the scaling applied by `inception_v3_preprocess_input`:
# divide by 2, add 0.5, multiply by 255.
img <- img / 2
img <- img + 0.5
img <- img * 255

# Clamp to [0, 255]. The dims are saved first and reassigned afterwards
# (presumably because `pmax()` with a scalar first argument does not keep
# them -- TODO confirm).
dims <- dim(img)
img <- pmax(0, pmin(img, 255))
dim(img) <- dims
# Util function to open, resize and format pictures
# into appropriate arrays.
# NOTE(review): from here on `img` is overwritten, so the result of the steps
# above is discarded. `image_path` is not a parameter of this function; it
# would have to exist in an enclosing environment.
img = tf$keras$preprocessing$image$load_img(image_path)
img = tf$keras$preprocessing$image$img_to_array(img)
# Add a leading batch axis.
img = tf$expand_dims(img, axis=0L)
img = inception_v3_preprocess_input(img)
img
}

resize_img <- function(img, size) {
image_array_resize(img, size[[1]], size[[2]])

# Convert a preprocessed image batch back into displayable pixel values.
#
# x: numeric array holding a single image with a leading batch axis, i.e.
#    dims (1, height, width, 3).
#
# Returns a uint8 tensor of shape (height, width, 3) with values in [0, 255].
deprocess_image <- function(x) {
  # Drop the batch axis. The target shape is taken from `x` itself rather
  # than from a variable outside the function, so the function works for any
  # input and not only for the array that happens to be called `img`.
  x_dims <- dim(x)
  x <- array_reshape(x, dim = c(x_dims[[2]], x_dims[[3]], 3))
  # Undo inception v3 preprocessing
  x <- x / 2.
  x <- x + 0.5
  x <- x * 255.
  # Clip to the valid range [0, 255], then convert to uint8
  x <- tf$clip_by_value(x, 0L, 255L) %>% tf$cast(dtype = 'uint8')
  x
}

save_img <- function(img, fname) {
Expand All @@ -35,131 +60,109 @@ save_img <- function(img, fname) {
}


# Model ----------------------------------------------

# You won't be training the model, so this command disables all training-specific operations.
k_set_learning_phase(0)

# Builds the Inception V3 network without its classification top (`include_top = FALSE`), keeping only the convolutional base. The model will be loaded with pretrained ImageNet weights.
# Build an InceptionV3 model loaded with pre-trained ImageNet weights
model <- application_inception_v3(weights = "imagenet",
include_top = FALSE)

# Named list mapping layer names to a coefficient quantifying how much the layer's activation contributes to the loss you'll seek to maximize. Note that the layer names are hardcoded in the built-in Inception V3 application. You can list all layer names using `summary(model)`.
layer_contributions <- list(
mixed2 = 0.2,
mixed3 = 3,
mixed4 = 2,
mixed5 = 1.5
)

# You'll define the loss by adding layer contributions to this scalar variable
loss <- k_variable(0)
for (layer_name in names(layer_contributions)) {
coeff <- layer_contributions[[layer_name]]
# Get the symbolic outputs of each "key" layer (we gave them unique names).
outputs_dict = list()
for (layer_name in names(layer_settings)) {
coeff <- layer_settings[[layer_name]]
# Retrieves the layer's output
activation <- get_layer(model, layer_name)$output
scaling <- k_prod(k_cast(k_shape(activation), "float32"))
# Retrieves the layer's output
loss <- loss + (coeff * k_sum(k_square(activation)) / scaling)
outputs_dict[[layer_name]] <- activation
}

# The dream image is the model's input tensor
dream <- model$input

# Computes the gradients of the dream with regard to the loss
grads <- k_gradients(loss, dream)[[1]]

# Normalizes the gradients (important trick)
grads <- grads / k_maximum(k_mean(k_abs(grads)), 1e-7)

outputs <- list(loss, grads)

# Sets up a Keras function to retrieve the value of the loss and gradients, given an input image
fetch_loss_and_grads <- k_function(list(dream), outputs)

eval_loss_and_grads <- function(x) {
outs <- fetch_loss_and_grads(list(x))
loss_value <- outs[[1]]
grad_values <- outs[[2]]
list(loss_value, grad_values)
# Set up a model that returns the activation values for every target layer
# (as a named list)
feature_extractor = keras_model(inputs = model$inputs,
outputs = outputs_dict)

# Weighted deep-dream loss for one image.
#
# Runs `input_image` through `feature_extractor` and, for every layer listed
# in `layer_settings`, adds that layer's coefficient times the sum of squared
# activations divided by the product of the activation tensor's shape.
# Both `feature_extractor` and `layer_settings` are read from the enclosing
# environment.
compute_loss <- function(input_image) {
  activations <- feature_extractor(input_image)
  # Label the outputs so they can be looked up by layer name.
  names(activations) <- names(layer_settings)
  total <- tf$zeros(shape = list())
  for (layer_name in names(layer_settings)) {
    weight <- layer_settings[[layer_name]]
    act <- activations[[layer_name]]
    # Normalize by the number of elements in the activation tensor.
    n_elements <- tf$reduce_prod(tf$cast(tf$shape(act), "float32"))
    total <- total + weight * tf$reduce_sum(tf$square(act)) / n_elements
  }
  total
}

# Set up the gradient ascent loop for one octave
# One gradient-ascent update of the image.
#
# img:           image tensor to update.
# learning_rate: multiplier applied to the normalized gradient.
#
# Returns list(loss, updated image), where the loss is the value computed
# before the update.
gradient_ascent_step <- function(img, learning_rate) {
  # Record the forward pass so the loss can be differentiated w.r.t. `img`.
  with(tf$GradientTape() %as% tape, {
    tape$watch(img)
    loss <- compute_loss(img)
  })
  raw_grads <- tape$gradient(loss, img)
  # Normalize by the mean absolute gradient; the 1e-6 floor avoids dividing
  # by zero.
  normalized <- raw_grads / tf$maximum(tf$reduce_mean(tf$abs(raw_grads)), 1e-6)
  stepped_img <- img + learning_rate * normalized
  list(loss, stepped_img)
}

# Run gradient ascent -----------------------------------------------------

# This function runs gradient ascent for a number of iterations.
gradient_ascent <-
function(x, iterations, step, max_loss = NULL) {
for (i in 1:iterations) {
c(loss_value, grad_values) %<-% eval_loss_and_grads(x)
if (!is.null(max_loss) && loss_value > max_loss)
break
cat("...Loss value at", i, ":", loss_value, "\n")
x <- x + (step * grad_values)
}
x
# Run up to `iterations` gradient-ascent steps on `img`.
#
# img:           image to update.
# iterations:    maximum number of steps; 0 returns `img` unchanged.
# learning_rate: step size forwarded to `gradient_ascent_step()`.
# max_loss:      optional threshold; the loop stops as soon as a step reports
#                a loss above it. NULL disables the check.
#
# Returns the image after the last executed step. The step that exceeds
# `max_loss` is still applied before the loop stops.
gradient_ascent_loop <- function(img, iterations, learning_rate, max_loss = NULL) {
  # seq_len() yields an empty sequence for 0 iterations, whereas
  # 1:iterations would iterate over c(1, 0).
  for (i in seq_len(iterations)) {
    step_result <- gradient_ascent_step(img, learning_rate)
    img <- step_result[[2]]
    # Convert the loss once so it can be compared and printed.
    loss_value <- as.array(step_result[[1]])
    if (!is.null(max_loss) && loss_value > max_loss) {
      break
    }
    cat("...Loss value at step", i, ":", loss_value, "\n")
  }
  img
}

# Playing with these hyperparameters will let you achieve new effects.
# Gradient ascent step size
step <- 0.01
# Number of scales at which to run gradient ascent
num_octave <- 3
# Size ratio between scales
octave_scale <- 1.4
# Number of ascent steps to run at each scale
iterations <- 20
# If the loss grows larger than 10, we will interrupt the gradient-ascent process to avoid ugly artifacts.
max_loss <- 10

# Fill this with the path to the image you want to use.
base_image_path <- "/tmp/mypic.jpg"

# Loads the base image into an array
img <-
preprocess_image(base_image_path)
# Run the training loop, iterating over different octaves
original_img = preprocess_image(base_image_path)

# Prepares a list of shape tuples defining the different scales at which to run gradient ascent
original_shape <- dim(img)[-1]
successive_shapes <-
list(original_shape)
original_shape <- dim(original_img)[2:3]

successive_shapes <- list(original_shape)

for (i in 1:num_octave) {
shape <- as.integer(original_shape / (octave_scale ^ i))
successive_shapes[[length(successive_shapes) + 1]] <-
shape
successive_shapes[[length(successive_shapes) + 1]] <- shape
}
# Reverses the list of shapes so they're in increasing order
successive_shapes <-
rev(successive_shapes)

original_img <- img
# Reverses the list of shapes so they're in increasing order
successive_shapes <- rev(successive_shapes[1:3])
# Resizes the array of the image to the smallest scale
shrunk_original_img <-
resize_img(img, successive_shapes[[1]])
shrunk_original_img <- tf$image$resize(original_img, successive_shapes[[1]])

img = tf$identity(original_img) # Make a copy

for (shape in successive_shapes) {
cat("Processing image shape", shape, "\n")
for (i in 1:length(successive_shapes)) {
shape = successive_shapes[[i]]
cat("Processing octave", i, "with shape", shape, "\n")
# Scales up the dream image
img <- resize_img(img, shape)
img <- tf$image$resize(img, shape)
# Runs gradient ascent, altering the dream
img <- gradient_ascent(img,
iterations = iterations,
step = step,
max_loss = max_loss)
img <- gradient_ascent_loop(img,
iterations = iterations,
learning_rate = step,
max_loss = max_loss)
# Scales up the smaller version of the original image: it will be pixellated
upscaled_shrunk_original_img <-
resize_img(shrunk_original_img, shape)
tf$image$resize(shrunk_original_img, shape)
# Computes the high-quality version of the original image at this size
same_size_original <-
resize_img(original_img, shape)
tf$image$resize(original_img, shape)
# The difference between the two is the detail that was lost when scaling up
lost_detail <-
same_size_original - upscaled_shrunk_original_img
# Reinjects lost detail into the dream
img <- img + lost_detail
shrunk_original_img <-
resize_img(original_img, shape)
save_img(img, fname = sprintf("dream_at_scale_%s.png",
paste(shape, collapse = "x")))
tf$image$resize(original_img, shape)
tf$keras$preprocessing$image$save_img(paste(result_prefix,'.png',sep = ''), deprocess_image(img$numpy()))
}

# Display the image that was written to disk
plot(magick::image_read(paste0(result_prefix, ".png")))