In [None]:
import os
import sys

import matplotlib.pyplot as plt
import tensorflow as tf
from IPython.display import clear_output

from model_tf import StyleContentModel, StyleContentLoss, init_images
from model_tf import get_train_step
from utils import load_im, save_im, preview

In [None]:
def plot(im, title=None, figsize=(20, 10)):
    """Display the first image of a batch in a single large figure.

    Args:
        im: Indexable image batch; ``im[0]`` must be something ``ax.imshow``
            accepts. NOTE(review): presumably (1, H, W, C) as produced by
            ``init_images`` -- confirm.
        title: Optional axes title (e.g. the current iteration).
        figsize: Matplotlib figure size in inches.
    """
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.imshow(im[0])
    if title is not None:
        ax.set_title(title)
    ax.axis("off")  # pixel-index ticks add nothing to an image preview
    plt.show()
    # plot() is called repeatedly during optimisation; close explicitly so figures
    # do not accumulate under backends that don't auto-close after show().
    plt.close(fig)
    
def get_file_name(i, content_path=None, style_path=None, out_dir="generated"):
    """Build the output path for a generated image.

    The file name joins the content and style image stems with a suffix, e.g.
    ``generated/beaver_guernica_100.jpg``.

    Args:
        i: Suffix appended to the name -- an iteration number or a label such as "final".
        content_path: Content image path; defaults to the notebook's ``CONTENT_PATH``.
        style_path: Style image path; defaults to the notebook's ``STYLE_PATH``.
        out_dir: Directory component of the returned path (the directory is not created here).

    Returns:
        The joined output path as a string.
    """
    if content_path is None:
        content_path = CONTENT_PATH
    if style_path is None:
        style_path = STYLE_PATH

    def stem(path):
        # splitext strips only the final extension, so dotted names such as
        # "my.photo.jpg" keep their full stem ("my.photo").
        return os.path.splitext(os.path.basename(path))[0]

    file_name = f"{stem(content_path)}_{stem(style_path)}_{i}.jpg"
    return os.path.join(out_dir, file_name)

def style_transfer(content_im_np, style_im_np, style_layers, content_layers, backbone_name, n_iter=1000,
                   style_weight=None, content_weight=None, learning_rate=None,
                   display_every=100, save_every=1000):
    """Optimise an image towards the content of one input and the style of another.

    Prints a progress dot per step, periodically clears the cell output and
    previews the current image, periodically saves it, and finally displays
    and saves the fully optimised image.

    Args:
        content_im_np: Content image as returned by ``load_im``.
        style_im_np: Style image as returned by ``load_im``.
        style_layers: Backbone layer names used for the style targets.
        content_layers: Backbone layer names used for the content targets.
        backbone_name: Backbone passed to ``StyleContentModel``.
        n_iter: Number of optimisation steps.
        style_weight: Style-loss weight; defaults to the notebook's ``STYLE_WEIGHT``.
        content_weight: Content-loss weight; defaults to the notebook's ``CONTENT_WEIGHT``.
        learning_rate: Adam learning rate; defaults to the notebook's ``LEARNING_RATE``.
        display_every: Preview the image every this many steps (0 disables).
        save_every: Save a snapshot every this many steps (0 disables).
    """
    # Fall back to the config-cell constants so existing calls behave as before,
    # while letting callers override them without mutating notebook globals.
    if style_weight is None:
        style_weight = STYLE_WEIGHT
    if content_weight is None:
        content_weight = CONTENT_WEIGHT
    if learning_rate is None:
        learning_rate = LEARNING_RATE

    # --- Prepare model and targets ---
    content_im, style_im, im = init_images(content_im_np, style_im_np)
    model = StyleContentModel(style_layers, content_layers, backbone_name=backbone_name)
    style_targets = model(style_im)["style"]
    content_targets = model(content_im)["content"]

    style_content_loss = StyleContentLoss(
        content_targets, style_targets, style_weight, content_weight
    )
    optim = tf.optimizers.Adam(learning_rate=learning_rate)
    train_step = get_train_step(model, optim, style_content_loss)

    # --- Optimise ---
    for i in range(n_iter):
        train_step(im)
        sys.stdout.write(".")
        sys.stdout.flush()
        # Guard on the interval first so 0 disables the action instead of dividing by zero.
        if display_every and i % display_every == 0:
            clear_output()
            print(f"iter: {i}")
            plot(im)

        if save_every and i % save_every == 0:
            save_im(im, get_file_name(i))

    # The periodic preview may not have shown the last step's result, so
    # always display the final image before saving it.
    clear_output()
    print(f"iter: {n_iter} (final)")
    plot(im)
    save_im(im, get_file_name("final"))

In [None]:
# --- Input / output paths ---
CONTENT_PATH = "../beaver.jpg"   # loaded as the content image
STYLE_PATH = "../guernica.jpg"   # loaded as the style image (alternative: "girl_pearl.jpg")
GEN_PATH = "generated/test.jpg"  # NOTE(review): not referenced by the visible cells

# --- Image size and optimisation hyper-parameters ---
MAX_DIM = 128            # max_dim passed to load_im for both input images
STYLE_WEIGHT = 2.0e10    # style-loss weight passed to StyleContentLoss
CONTENT_WEIGHT = 1.0e2   # content-loss weight passed to StyleContentLoss
LEARNING_RATE = 2e-2     # Adam learning rate

In [None]:
content_im_np = load_im(CONTENT_PATH, max_dim=MAX_DIM)
style_im_np = load_im(STYLE_PATH, max_dim=MAX_DIM)

# Preview both inputs side by side so the pairing can be checked before the long run.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
for ax, image, title in zip(
    axes,
    (content_im_np, style_im_np),
    (f"Content: {CONTENT_PATH}", f"Style: {STYLE_PATH}"),
):
    ax.imshow(image)
    ax.set_title(title)
    ax.axis("off")
plt.show()

In [None]:
# Layer names from the ResNet50 backbone selected below.
CONTENT_LAYERS = ["conv2_block3_out"]

# Only early backbone stages feed the style representation. Deeper block
# outputs (conv3_block2_out..conv3_block4_out, conv4_block1_out..conv4_block6_out,
# conv5_block1_out..conv5_block3_out) can be appended to this list.
STYLE_LAYERS = [
    "conv2_block1_out",
    "conv2_block2_out",
    "conv2_block3_out",
    "conv3_block1_out",
]

  

backbone_name = "ResNet50"
N_ITER = 10_000  # total optimisation steps for this run

style_transfer(
    content_im_np,
    style_im_np,
    style_layers=STYLE_LAYERS,
    content_layers=CONTENT_LAYERS,
    backbone_name=backbone_name,
    n_iter=N_ITER,
)