<a href="https://colab.research.google.com/github/pittner8/DeepDreamColab/blob/main/DeepDream.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import matplotlib as mpl
from PIL import Image, ExifTags
import imageio
from tensorflow.python.ops.math_ops import _OverrideBinaryOperatorHelper
from tensorflow.python.training.tracking import base

#===============
#imports for UI
import ipywidgets as widgets
from IPython.display import display, clear_output
from ipywidgets import HBox, Label, GridBox, Layout
#===============

def open_img(img_path, max_size = None):
  """Load an image file and return its RGB pixel data as a numpy array.

  The image is first passed through rotate_img (EXIF-based orientation fix).
  If max_size is given, it is then shrunk in place with PIL's
  Image.thumbnail so that it fits into that (width, height) box.

  Args:
    img_path: path of the image file to open.
    max_size: optional (width, height) bound; None (or empty) keeps the
      original size.

  Returns:
    A numpy array holding the RGB pixel data.

  Raises:
    SystemExit: if the file cannot be opened, decoded, rotated, resized or
      converted. The underlying error is chained as ``__cause__``.
  """
  try:
    # The context manager closes the underlying file once the pixel data has
    # been copied into the numpy array (rebinding `img` inside does not
    # affect which object the `with` statement closes).
    with Image.open(img_path) as img:
      # some images, e.g. smartphone pictures, are not rotated properly by default
      img = rotate_img(img)
      if max_size:
        img.thumbnail(max_size)

      if img.mode != "RGB":
        img = img.convert("RGB")

      return np.array(img)
  except Exception as err:
    # Narrower than a bare `except:` so KeyboardInterrupt etc. still propagate;
    # chaining keeps the real reason available for debugging.
    raise SystemExit("{0} Bild konnte nicht geöffnet werden".format(u"\u274C")) from err

def save_img(img, save_path : str, img_name="dreamImg"):
  """Write ``img`` as ``<img_name>.jpg`` into the directory ``save_path``.

  The outcome is reported in the global ``out`` output widget instead of
  being raised, so a failed save does not abort the calling UI handler.

  Args:
    img: image data handed to imageio.imwrite.
    save_path: target directory.
    img_name: file name without extension; surrounding whitespace is
      stripped and an empty name is rejected.
  """
  import os

  try:
    # strip() (not just lstrip()) so trailing blanks don't end up in the file name.
    img_name = img_name.strip()
    if not img_name:
      raise ValueError("leerer Bildname")
    file_path = os.path.join(save_path, img_name + ".jpg")
    imageio.imwrite(file_path, img)
    out.append_stdout("{0} Bild wurde erfolgreich gespeichert\n".format(u"\u2705"))
  except Exception as e:
    # Best effort: report the reason instead of silently discarding it.
    out.append_stdout("{0} Bild konnte nicht gespeichert werden ({1})\n".format(u"\u274C", e))

def deprocess(img):
  """Map an image from the [-1, 1] model range back to 8-bit pixel values.

  Applies x -> 255 * (x + 1) / 2, which reverts the InceptionV3 input
  scaling, and casts the result to uint8 (tf.cast drops the fractional part).
  """
  shifted = img + 1.0
  pixel_values = shifted * 255 / 2.0
  return tf.cast(pixel_values, dtype=tf.uint8)

# rotate_img rotates the image so that it is upright
def rotate_img(image):
  """Return ``image`` transposed according to its EXIF Orientation tag.

  Uses PIL's public ImageOps.exif_transpose, which handles all eight EXIF
  orientations. That includes the mirrored ones (2, 4, 5, 7), which a plain
  rotation cannot fix. This is best effort: if the EXIF data is missing or
  cannot be processed, the image is returned unchanged.
  """
  from PIL import ImageOps

  try:
    return ImageOps.exif_transpose(image)
  except (AttributeError, KeyError, IndexError, TypeError, ValueError, OSError):
    # cases: image has no (or malformed) EXIF data -> keep it as it is
    return image



def show(img):
  """Draw ``img`` with matplotlib's imshow and display the current figure."""
  plt.imshow(img)
  plt.show()


def get_dream_model(layer):
  """Build a feature-extraction model on top of ImageNet-trained InceptionV3.

  Args:
    layer: a single InceptionV3 layer name (e.g. "mixed3") or an iterable of
      layer names whose activations the model should output.

  Returns:
    A tf.keras.Model that maps InceptionV3's input to the requested layer
    outputs, in the given order.

  Raises:
    ValueError: if no layer name is given (checked before building the
      base model).
  """
  # A bare string would otherwise be iterated character by character.
  layer_names = [layer] if isinstance(layer, str) else list(layer)
  if not layer_names:
    raise ValueError("at least one layer name is required")

  base_model = tf.keras.applications.InceptionV3(include_top=False, weights='imagenet')

  # Maximize the activations of these layers
  outputs = [base_model.get_layer(name).output for name in layer_names]

  # Create the feature extraction model
  return tf.keras.Model(inputs=base_model.input, outputs=outputs)


def calc_loss(img, model):
  """Return the DeepDream loss: the sum over model outputs of their mean activation.

  Args:
    img: a single image tensor without batch dimension; a batch of size 1
      is built from it here.
    model: a Keras model returning either one tensor or a list/tuple of
      tensors.

  Returns:
    A scalar tensor.
  """
  img_batch = tf.expand_dims(img, axis=0)
  layer_activations = model(img_batch)
  # A model with a single output may return a bare tensor instead of a list.
  # Check the type rather than len(): len() of a tensor is its batch size,
  # which only coincidentally equals 1 here.
  if not isinstance(layer_activations, (list, tuple)):
    layer_activations = [layer_activations]

  losses = [tf.math.reduce_mean(act) for act in layer_activations]
  return tf.reduce_sum(losses)

def random_roll(img, maxroll):
  """Circularly shift ``img`` along its first two axes by a random offset.

  Both offsets are drawn uniformly from [-maxroll, maxroll). Rolling before
  tiling keeps tile boundaries from landing in the same place every step.

  Returns:
    (shift, rolled_img), where shift is the int32 offset vector so the caller
    can undo the roll with tf.roll(..., shift=-shift).
  """
  offsets = tf.random.uniform(
      shape=[2], minval=-maxroll, maxval=maxroll, dtype=tf.int32)
  return offsets, tf.roll(img, shift=offsets, axis=[0, 1])



#============================================================
#============================================================
class TiledGradients(tf.Module):
  """Compute normalized DeepDream gradients of an image, tile by tile.

  The image is randomly rolled first, split into square tiles of
  ``tile_size`` pixels, and the loss gradient of each tile is accumulated
  into one full-size gradient, which is then rolled back.
  """

  def __init__(self, model):
    # Feature-extraction model whose activations calc_loss maximizes.
    self.model = model

  @tf.function(
      input_signature=(
        tf.TensorSpec(shape=[None,None,3], dtype=tf.float32),
        tf.TensorSpec(shape=[], dtype=tf.int32),)
  )
  def __call__(self, img, tile_size=512):
    """Return the std-normalized gradient of the loss w.r.t. ``img``.

    Args:
      img: float32 image tensor of shape (height, width, 3).
      tile_size: edge length of the square tiles; also the maximum random
        roll offset.

    Returns:
      A tensor with the same shape as ``img``.
    """
    shift, img_rolled = random_roll(img, tile_size)

    # Initialize the image gradients to zero.
    gradients = tf.zeros_like(img_rolled)

    # The input signature leaves height and width unknown while tracing, so
    # img_rolled.shape[0] / shape[1] are None. tf.range(0, None, step) treats
    # 0 as the limit and yields an empty range, so only the tile at (0, 0)
    # would ever be processed. Use the dynamic shape instead.
    img_size = tf.shape(img_rolled)

    # Skip the last tile, unless there's only one tile.
    xs = tf.range(0, img_size[0], tile_size)[:-1]
    if not tf.cast(len(xs), bool):
      xs = tf.constant([0])
    ys = tf.range(0, img_size[1], tile_size)[:-1]
    if not tf.cast(len(ys), bool):
      ys = tf.constant([0])

    for x in xs:
      for y in ys:
        # Calculate the gradients for this tile.
        with tf.GradientTape() as tape:
          # This needs gradients relative to `img_rolled`.
          # `GradientTape` only watches `tf.Variable`s by default.
          tape.watch(img_rolled)

          # Extract a tile out of the image.
          img_tile = img_rolled[x:x+tile_size, y:y+tile_size]
          loss = calc_loss(img_tile, self.model)

        # Update the image gradients for this tile.
        gradients = gradients + tape.gradient(loss, img_rolled)

    # Undo the random shift applied to the image and its gradients.
    gradients = tf.roll(gradients, shift=-shift, axis=[0,1])

    # Normalize the gradients (epsilon avoids division by zero).
    gradients /= tf.math.reduce_std(gradients) + 1e-8

    return gradients


# =======================================================
def run_deep_dream_with_octaves(img, steps_per_octave=100, step_size=0.01, octaves=range(-2,3), octave_scale=1.3, get_tiled_gradients=None):
  """Run DeepDream over several octaves (image scales) and return the result.

  For every exponent ``o`` in ``octaves`` the working image is resized to
  original_size * octave_scale**o. It is then updated ``steps_per_octave``
  times by gradient ascent and clipped to [-1, 1]. Progress is reported in
  the global ``out`` output widget.

  Args:
    img: RGB image array of shape (height, width, 3).
    steps_per_octave: gradient-ascent steps per octave.
    step_size: factor applied to the gradients in each step.
    octaves: iterable of octave exponents.
    octave_scale: resize factor between consecutive octaves.
    get_tiled_gradients: callable ``(img, tile_size=...)`` returning gradients
      shaped like its image argument, e.g. a TiledGradients instance.

  Returns:
    A uint8 tensor with the same height and width as ``img``.

  Raises:
    ValueError: if ``get_tiled_gradients`` is not given.
  """
  if get_tiled_gradients is None:
    raise ValueError("get_tiled_gradients must be provided")

  base_shape = tf.shape(img)
  img = tf.keras.preprocessing.image.img_to_array(img)
  # InceptionV3 preprocessing scales pixel values to [-1, 1].
  img = tf.keras.applications.inception_v3.preprocess_input(img)
  img = tf.convert_to_tensor(img, dtype=tf.float32)

  # Tile size is derived once from the original image: a third of its larger
  # side for both portrait and landscape images. The original code used the
  # full height for portrait images. Clamp it to at least 1 so random_roll's
  # uniform range and tf.range's step stay valid for tiny images.
  height, width = img.shape[0], img.shape[1]
  tile_size = max(max(height, width) // 3, 1)

  for octave in octaves:
    # Scale the image based on the octave
    new_size = tf.cast(base_shape[:-1], tf.float32) * (octave_scale ** octave)
    img = tf.image.resize(img, tf.cast(new_size, tf.int32))

    for _ in range(steps_per_octave):
      gradients = get_tiled_gradients(img, tile_size=tile_size)
      img = img + gradients * step_size
      img = tf.clip_by_value(img, -1, 1)

    out.append_stdout("{0} Bin noch am Träumen ... {1}\n".format(u"\U0001F4A4", u"\U0001F4A4"))

  out.append_stdout("{0} Gleich fertig {1}\n".format(u"\U0001F3C1", u"\U0001F3C1"))

  # Back to 8-bit pixels and to the original resolution.
  result = deprocess(img)
  result = tf.image.resize(result, base_shape[:-1])
  result = tf.cast(result, tf.uint8)

  return result

#=============================================================================
# Colab UI
#============================================================================= 

# Output widget that receives all progress and error messages.
out = widgets.Output()

# Input widgets. Each Label is followed by its control; the two-column grid
# below places them side by side.
image_label = Label("Bild samt Pfad:", layout=Layout(width="auto", height="auto"))
image = widgets.Text(value='/home/bild.jpg',
                     layout=Layout(width='auto', height='auto'),
)

original_size_label = Label("Originalgröße beibehalten?", layout=Layout(width="auto", height="auto"))
check_keep_original_size = widgets.Checkbox(value=True, description="beibehalten")

resolution_label = Label("Neue Bildauflösung angeben:   (Nicht größer als die Originalauflösung wählen)", layout=Layout(width="auto", height="auto"))
# Values are (width, height) tuples; the dropdown starts disabled while the
# original size is kept.
resolution = widgets.Dropdown(
                              options = [("3000x1500", (3000, 1500)), ("1920x1080", (1920, 1080)), ("1280x720", (1280, 720))],
                              layout=Layout(width="auto", height="auto"),
                              disabled = check_keep_original_size.value)


save_path_label = Label("Speichern unter:", layout=Layout(width="auto", height="auto"))
save_path = widgets.Text(value='/home/',
                    layout=Layout(width='auto', height='auto'),
)

new_img_name_label = Label("Name des DreamBildes:", layout=Layout(width="auto", height="auto"))
new_img_name = widgets.Text(value="DreamImg",
                          layout = Layout(width="auto", height="auto"))

dream_layer_label = Label("Wähle eine Schicht zum träumen aus:", layout=Layout(width="auto", height="auto"))
dream_layer = widgets.Dropdown(
    options=['mixed{0}'.format(i) for i in range(0,11)],
    value='mixed3',
    layout=Layout(width='auto', height='auto'),
)

oktave_von_label = Label("Oktave von:",layout=Layout(width="auto", height="auto"))
oktave_von = widgets.IntText(value = -2, layout=Layout(width="auto", height="auto"))

oktave_bis_label = Label("Oktave bis:",layout=Layout(width="auto", height="auto"))
oktave_bis = widgets.IntText(value = 3, layout=Layout(width="auto", height="auto"))

# Layout keyword was misspelled "heights" (silently ignored); plain IntText
# ignores `min`, so a BoundedIntText is needed to enforce at least one step.
# NOTE: the variable name keeps its original spelling because other code reads it.
steps_per_octave_label = Label("Schritte pro Oktave:", layout=Layout(width="auto", height="auto"))
steps_per_ovtave = widgets.BoundedIntText(value = 100, min = 1, max = 10000, layout=Layout(width="auto", height="auto"))

enlargment_per_octave_label = Label("Vergrösserung pro Oktave:", layout=Layout(width="auto", height="auto"))
enlargment_per_octave = widgets.FloatSlider(value=1.3, step = 0.1 , min = 1., max = 2.)

step_size_label = Label("Schrittgröße/länge:", layout=Layout(width="auto", height="auto"))
step_size = widgets.FloatSlider(value=0.01, min=0.001, max=0.1, step=0.001, layout=Layout(width="auto", height="auto"), readout_format=".3f")

ui_children = [image_label, image, original_size_label, check_keep_original_size, resolution_label, resolution,
              save_path_label, save_path, new_img_name_label, new_img_name, dream_layer_label, dream_layer, oktave_von_label, oktave_von,
              oktave_bis_label, oktave_bis, steps_per_octave_label, steps_per_ovtave,
              enlargment_per_octave_label, enlargment_per_octave, step_size_label, step_size]

# Two equal columns: labels on the left, controls on the right.
grid = GridBox(children=ui_children,
        layout=Layout(
            width='50%',
            grid_template_rows='auto',
            grid_template_columns='50% 50%'
       ))



button = widgets.Button(description="Träumen")
display(grid, button, out)

# Functions for the Widget Events

def check_box_changed(b):
  """Sync the checkbox label and the resolution dropdown with the checkbox.

  While "keep original size" is ticked, the dropdown is disabled; otherwise
  it is enabled. The ``b`` change argument is not used.
  """
  keep_original = bool(check_keep_original_size.value)
  check_keep_original_size.description = "beibehalten" if keep_original else "nicht beibehalten"
  resolution.disabled = keep_original

def on_button_clicked(b):
  """Handle a click on the "Träumen" button.

  Reads the settings from the UI widgets, opens the chosen image, runs
  DeepDream on it and saves the result. Progress and errors are written to
  the global ``out`` output widget.

  Args:
    b: the clicked button (unused).
  """
  clear_screen()
  out.append_stdout("===== {0} Das Träumen beginnt {1} =====\n".format(u"\U0001F6CC", u"\U0001F971"))

  max_size = None if check_keep_original_size.value else resolution.value

  # Open the image before building the model so a wrong path is reported
  # without first constructing InceptionV3.
  try:
    img = open_img(image.value, max_size)
  except SystemExit as err:
    # open_img signals failure with SystemExit; show its message in the
    # output widget instead of letting it escape the widget callback.
    out.append_stdout("{0}\n".format(err))
    return

  dream_model = get_dream_model([dream_layer.value])
  get_tiled_gradients = TiledGradients(dream_model)

  # Accept the bounds in either order.
  l_oktave_von, l_oktave_bis = sorted((oktave_von.value, oktave_bis.value))

  img = run_deep_dream_with_octaves(img, steps_per_ovtave.value, step_size.value, range(l_oktave_von, l_oktave_bis), enlargment_per_octave.value, get_tiled_gradients)
  save_img(img=img.numpy(), save_path=save_path.value, img_name=new_img_name.value)


# helper functions

def clear_screen():
  """Clear the cell output and re-display the UI with a fresh output widget.

  Rebinds the module-level ``out`` so later messages go to the new widget.
  """
  global out
  clear_output(wait=True)
  fresh_out = widgets.Output()
  out = fresh_out
  display(grid, button, fresh_out)

# Assign the above functions to the events
button.on_click(on_button_clicked)
# Only react to changes of the checkbox value; without `names`, the handler
# also fires for every other trait change, including the description it sets.
check_keep_original_size.observe(check_box_changed, names="value")

GridBox(children=(Label(value='Bild samt Pfad:', layout=Layout(height='auto', width='auto')), Text(value='/hom…

Button(description='Träumen', style=ButtonStyle())

Output()