In [None]:
import tensorflow as tf
import numpy as np
from scipy.ndimage.filters import gaussian_filter

In [None]:
from utils import *
import inception

In [None]:
# Fetch the Inception model files via the project-local `inception` helper
# (NOTE(review): presumably a no-op when the files already exist — confirm in inception.py).
inception.download_model()

In [None]:
# Build the Inception5H wrapper. Later cells use it through the global `model`
# (its graph, layers/features, get_feed_dict, get_gradient and the
# preprocess_image / depreprocess_image helpers).
model = inception.Inception5H()

In [None]:
def get_optimal_tile_size(image_dim, tile_size):
    """Return a tile length that splits ``image_dim`` into near-equal tiles.

    ``tile_size`` is only a target. The axis is divided into
    ``round(image_dim / tile_size)`` tiles (never fewer than one), and the
    integer length of one such tile is returned.

    Args:
        image_dim: Length of the image along one axis, in pixels.
        tile_size: Desired tile length, in pixels.

    Returns:
        ``image_dim // tile_count`` where ``tile_count >= 1``.
    """
    tile_count = int(round(image_dim / tile_size))
    if tile_count < 1:
        # The axis is too short for even one target-size tile: use it whole.
        tile_count = 1
    return image_dim // tile_count

In [None]:
def _random_tile_offset(tile_size):
    """Random start in [-3/4, -1/4) of a tile, or 0 when that range is empty."""
    low = -3 * tile_size // 4
    high = -tile_size // 4
    if low >= high:
        # Happens for tile_size <= 1; np.random.randint(low, high) would
        # raise ValueError on the empty range.
        return 0
    return np.random.randint(low, high)


def calculate_gradient(image, gradient, tile_size):
    """Evaluate ``gradient`` for ``image`` tile by tile and stitch the result.

    The image is split along axes 1 and 2 into tiles of roughly ``tile_size``
    pixels (see ``get_optimal_tile_size``). Each tile is fed to the network
    separately. The tile grid starts at a random negative offset: one draw
    for axis 1 and a fresh draw for axis 2 on every row of tiles. This
    presumably keeps tile seams from falling in the same place every call.
    Each tile's gradient is divided by its own standard deviation before
    being written into the full-size result.

    Args:
        image: Array indexed as ``image[:, x, y, :]``; axes 1 and 2 are tiled.
            NOTE(review): presumably (batch, height, width, channels) as
            returned by ``model.preprocess_image`` — confirm.
        gradient: Tensor evaluated with ``sess.run`` for every tile.
        tile_size: Target tile length in pixels.

    Returns:
        Array with the shape and dtype of ``image`` holding the per-tile
        normalized gradients.

    Relies on the notebook-level globals ``model`` and ``sess``.
    """
    grads = np.zeros_like(image)

    x_max = image.shape[1]
    y_max = image.shape[2]

    # A tile length of 0 would stop the loops below from advancing, so
    # clamp it to at least one pixel.
    x_tile_size = max(1, get_optimal_tile_size(x_max, tile_size))
    y_tile_size = max(1, get_optimal_tile_size(y_max, tile_size))

    x_start = _random_tile_offset(x_tile_size)

    while x_start < x_max:
        x_end = x_start + x_tile_size
        # Clip the (possibly partly out-of-bounds) tile to the image.
        lim_x_start = max(0, x_start)
        lim_x_end = min(x_end, x_max)

        # A fresh random offset along axis 2 for every row of tiles.
        y_start = _random_tile_offset(y_tile_size)

        while y_start < y_max:
            y_end = y_start + y_tile_size
            lim_y_start = max(0, y_start)
            lim_y_end = min(y_end, y_max)

            img = image[:, lim_x_start:lim_x_end, lim_y_start:lim_y_end, :]

            feed_dict = model.get_feed_dict(img)
            g = sess.run(gradient, feed_dict)
            # Per-tile normalization; the epsilon avoids dividing by a zero std.
            g /= (np.std(g) + 1e-8)
            grads[:, lim_x_start:lim_x_end, lim_y_start:lim_y_end, :] += g
            y_start = y_end

        x_start = x_end

    return grads

In [None]:
def optimize_image(image, layer, num_iterations, step_size, tile_size, plot_gradient = False):
    """Run gradient ascent on ``image`` to amplify the activations of ``layer``.

    Args:
        image: Input image, as accepted by ``plot_image`` and
            ``model.preprocess_image``.
        layer: Layer handed to ``model.get_gradient`` to build the gradient op.
        num_iterations: Number of ascent steps.
        step_size: Multiplier applied to the smoothed gradient at each step.
        tile_size: Target tile length forwarded to ``calculate_gradient``.
        plot_gradient: If True, print the step number and plot the smoothed
            gradient after every step.

    Returns:
        The optimized image, passed back through ``model.depreprocess_image``.

    Side effects: prints and plots the image before and after optimization.
    """
    print('Before Optimization')
    plot_image(image)

    work = model.preprocess_image(image)
    grad_op = model.get_gradient(layer)

    for step in range(num_iterations):
        raw_grad = calculate_gradient(work, grad_op, tile_size)

        # Blur the gradient at three scales and sum the results. The base
        # blur grows linearly from 0.5 over the iterations. A scalar sigma
        # blurs every axis of the array, including the trailing one
        # (presumably the colour channels).
        base_sigma = (step * 4.0) / num_iterations + 0.5
        blurred = [
            gaussian_filter(raw_grad, sigma=s)
            for s in (base_sigma, base_sigma*2, base_sigma*0.5)
        ]
        smooth_grad = blurred[0] + blurred[1] + blurred[2]

        work += step_size*smooth_grad

        if plot_gradient:
            print(f'Iteration: {step+1}')
            plot_gradients(smooth_grad)

    result = model.depreprocess_image(work)

    print('After optimization')
    plot_image(result)

    return result

In [None]:
def recursive_optimize(image, layer, n_octave=3, num_iterations=10, step_size=3, tile_size=400, size_factor=0.7, blend = 0.2, plot_gradient = False):
    """Multi-scale ("octave") wrapper around ``optimize_image``.

    While ``n_octave > 0``, the image is downscaled by ``size_factor``. The
    smaller image is optimized recursively, upscaled back to ``image.shape``
    and blended with the current image. The blend is then optimized at this
    size.

    Args:
        image: Input image.
        layer: Layer whose activations are amplified.
        n_octave: Number of downscaled levels below this one; 0 optimizes
            only at the current size.
        num_iterations, step_size, tile_size: Forwarded to ``optimize_image``.
        size_factor: Factor passed to ``resize_image`` for each downscale.
        blend: Weight of the current-level image in the blend; the upscaled
            optimized image gets ``1 - blend``.
        plot_gradient: Forwarded to ``optimize_image`` at every level.

    Returns:
        The result of ``optimize_image`` on the (blended) image at this level.
    """
    if n_octave > 0:
        # Downscale the image.
        image_downscaled = resize_image(image, factor=size_factor)

        # Optimize the downscaled image recursively. plot_gradient is passed
        # through so the setting applies to every octave, not only the
        # outermost one.
        image_optimized = recursive_optimize(
            image_downscaled,
            layer,
            n_octave=n_octave - 1,
            num_iterations=num_iterations,
            step_size=step_size,
            tile_size=tile_size,
            size_factor=size_factor,
            blend=blend,
            plot_gradient=plot_gradient,
        )

        # Upscale the optimized image back to the current size.
        image_upscaled = resize_image(image_optimized, shape=image.shape)

        # Blend the current image with the upscaled optimized one.
        image = blend*image + (1-blend)*image_upscaled

    print(f'Recursive step: {n_octave + 1}')

    # Optimize the (possibly blended) image at this level.
    new_image = optimize_image(image, layer, num_iterations, step_size, tile_size, plot_gradient)

    return new_image

In [None]:
# Load the content image and display its shape. load_image comes from the
# star import (presumably `utils`); factor=1.0 presumably means no
# rescaling — confirm.
image = load_image('./images/content/content8.jpg', factor = 1.0)
image.shape

In [None]:
# Pick the layer whose activations will be amplified. The bare expression
# displays model.features[3] (NOTE(review): presumably a description of
# the same layer — confirm in inception.py).
layer = model.layers[3]
model.features[3]

In [None]:
# Run the optimization inside a context manager, so the session is closed
# even if the long run raises or is interrupted.
# calculate_gradient reads the notebook-global name `sess`, so the session
# must stay bound to that name.
with tf.Session(graph=model.graph) as sess:
    new_image = recursive_optimize(image, layer)

In [None]:
# Release the TensorFlow session used for the optimization run.
sess.close()

In [None]:
# Write the generated image to disk. save_image comes from the star import
# (presumably `utils`).
save_image('./images/generated/gen26.jpg', new_image)