In [None]:
# To run this notebook in Colab you first need to install the srnca repo:
# uncomment the three lines below and run this cell.

#! git clone https://github.com/rivesunder/SRNCA
#%cd SRNCA
#! pip install -e .

In [None]:
# optional switches

crop_image = False  # when True, the crop cell cuts a 64x64 patch out of img
use_cuda = True     # when True, the model is moved to "cuda" before training
img_dim = 128       # passed to read_image as max_size

In [None]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.models

import numpy as np


import skimage
import skimage.io as sio
import skimage.transform

from srnca.nca import NCA
from srnca.utils import image_to_tensor, tensor_to_image, read_image, seed_all

import matplotlib.pyplot as plt
import matplotlib

matplotlib.rcParams["animation.embed_limit"] = 256

import matplotlib.animation
import IPython

In [None]:
def plot_grid(grid):
    """Show `grid` as an image on a freshly created figure.

    The image artist created here is stored in the module-level name
    `subplot_0`, which `update_fig` uses to refresh the picture.

    Args:
        grid: value accepted by `tensor_to_image`; the converted result is
            what gets displayed.

    Returns:
        (fig, ax): the new matplotlib figure and its single axes.
    """
    global subplot_0

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7.5, 7.5), facecolor="white")

    # nearest-neighbour interpolation: no smoothing between grid cells
    subplot_0 = ax.imshow(tensor_to_image(grid), interpolation="nearest")

    # blank out the tick labels on both axes
    for set_labels in (ax.set_yticklabels, ax.set_xticklabels):
        set_labels('')

    plt.tight_layout()
    return fig, ax

def update_fig(i):
    """Advance the global `grid` by one `nca` step and redraw it.

    Used in this notebook as the per-frame callback handed to
    `matplotlib.animation.FuncAnimation`.

    Args:
        i: frame index supplied by the animation; not used.
    """
    global grid

    # Display-only stepping: gradients are not needed here. Without no_grad,
    # each new grid would keep autograd history reaching back through every
    # previous frame, so memory use would grow with the number of frames.
    with torch.no_grad():
        grid = nca(grid)

    # subplot_0 is the image artist created by plot_grid
    subplot_0.set_array(tensor_to_image(grid))


In [None]:
# seed for repeatability
# exp_counter numbers the experiments: the hyperparameter-sampling cell
# increments it and uses it to build exp_tag
exp_counter = 0
my_seed = 42
# seed_all comes from srnca.utils; presumably it seeds every random number
# generator the notebook relies on -- TODO confirm
seed_all(my_seed)

In [None]:
# sample hyperparameters
# NOTE: the "default values" cell that follows reassigns lr, number_channels,
# number_hidden, number_filters, batch_size and max_ca_steps, so the values
# drawn here only take effect when that cell is skipped.

# learning rate is 10**-k, with k drawn from {3, 4, 5}
lr_exponent = np.random.randint(3, 6)
lr = 10.0 ** -lr_exponent

# each remaining hyperparameter is drawn from its own list of candidates
# (the order of the draws matters for reproducibility under a fixed seed)
channel_choices = [6, 9, 12, 15]
number_channels = np.random.choice(channel_choices)

hidden_choices = [16, 32, 64, 96]
number_hidden = np.random.choice(hidden_choices)

ca_step_choices = [20, 30, 40]
max_ca_steps = np.random.choice(ca_step_choices)

batch_size_choices = [2, 4]
batch_size = np.random.choice(batch_size_choices)

filter_choices = [4, 5, 6]
number_filters = np.random.choice(filter_choices)

# every run of this cell starts a new, sequentially numbered experiment
exp_counter += 1
exp_tag = f"exp_{exp_counter:04}"

In [None]:
# default values
# When the cells are run in order, these assignments replace the values drawn
# in the hyperparameter-sampling cell.
lr = 1e-3               # passed to nca.fit
number_channels = 9     # passed to the NCA constructor
number_hidden = 96      # passed to the NCA constructor
number_filters = 4      # passed to the NCA constructor
batch_size = 2          # passed to nca.fit
max_ca_steps = 20       # passed to nca.fit
update_rate = 0.75      # passed to the NCA constructor

In [None]:
url = "https://www.nasa.gov/centers/ames/images/content/72511main_cellstructure8.jpeg"

# alternative training images
#url = "https://www.nasa.gov/sites/default/files/thumbnails/image/telescope_alignment_evaluation_image_labeled.png"
#url = "https://spaceplace.nasa.gov/jupiter/en/jupiter5.en.jpg"
#url = "../data/images/jwst_segment_alignment.jpg"

# load the image and keep only the first three entries of its last axis
# (drops e.g. an alpha channel, if one is present)
img = read_image(url, max_size=img_dim)
img = img[:, :, :3]

# build the training target, then convert it back so that the preview below
# shows what tensor_to_image recovers from that target
target = image_to_tensor(img)
img = tensor_to_image(target)

nca = NCA(
    number_channels=number_channels,
    number_hidden=number_hidden,
    number_filters=number_filters,
    update_rate=update_rate,
)

# view the training image
plt.figure()
plt.imshow(img)
plt.show()

# summary of the settings in use; model-side values are read back from nca
hyperparam_msg = (
    f"hyperparameters: \n    lr = {lr}, \n"
    f"    number_channels = {nca.number_channels}\n"
    f"    number_hidden   = {nca.number_hidden}\n"
    f"    max_ca_steps    = {max_ca_steps}\n"
    f"    batch_size      = {batch_size}\n"
    f"    number_filters  = {nca.number_filters}\n"
    f"    update_rate  = {nca.update_rate}\n"
)

print(f"{exp_tag}, nca parameter count: {nca.count_parameters()}")
print(hyperparam_msg)

In [None]:
# optional, crop image

if crop_image:
    dim = 64
    crop_x, crop_y = 256, 256

    # Keep the crop window inside the image. With the requested offsets, an
    # image smaller than offset + dim along an axis (e.g. one loaded with
    # max_size=128) would otherwise be sliced to an empty array.
    crop_x = max(0, min(crop_x, img.shape[0] - dim))
    crop_y = max(0, min(crop_y, img.shape[1] - dim))
    img = img[crop_x:crop_x+dim, crop_y:crop_y+dim, :]

    # Rebuild the training target from the cropped image; without this the
    # crop only changed the preview and training still used the full image.
    # NOTE(review): assumes image_to_tensor accepts the array returned by
    # tensor_to_image -- confirm against srnca.utils.
    target = image_to_tensor(img)

    # view the training image
    plt.figure()
    plt.imshow(img)
    plt.show()

In [None]:
# optional, move model to gpu

if use_cuda:
    if torch.cuda.is_available():
        nca.to_device("cuda")
    else:
        # Requesting "cuda" on a machine without a usable CUDA device is
        # expected to fail, so keep the model where it is and say so.
        print("use_cuda is set but CUDA is not available; keeping the model on the cpu")

In [None]:
# train for textures
print("begin training")

# the value returned by fit is kept as exp_log for use by later cells
exp_log = nca.fit(
    target,
    max_steps=4096,
    max_ca_steps=max_ca_steps,
    lr=lr,
    exp_tag=exp_tag,
    batch_size=batch_size,
)

In [None]:
num_frames = 200

# run the animation on the cpu, starting from a fresh initial grid
nca.to_device("cpu")
grid = nca.get_init_grid(batch_size=1, dim=128)

fig, ax = plot_grid(grid)

# close the figure so only the animation is shown, not a static plot as well
plt.close("all")

# update_fig advances the global grid once per frame
texture_anim = matplotlib.animation.FuncAnimation(
    fig,
    update_fig,
    frames=num_frames,
    interval=100,
)
IPython.display.HTML(texture_anim.to_jshtml())

In [None]:
#optional: save animation
# update_fig advances the global grid in place, so after the inline animation
# has been rendered the grid is already many steps in. Start again from an
# initial grid (same arguments as the animation cell) so that the saved file
# also begins at the initial state.
grid = nca.get_init_grid(batch_size=1, dim=128)
matplotlib.animation.FuncAnimation(fig, update_fig, frames=num_frames, interval=100).save("texture_ca.gif")

In [None]:
# visualize training curves

# sorted() gives a reproducible plotting order; os.listdir order is arbitrary
for my_exp_log in sorted(os.listdir("./")):
    if not my_exp_log.endswith("dict.npy"):
        continue

    # SECURITY: allow_pickle=True unpickles the file contents, which can run
    # arbitrary code -- only load logs produced by your own training runs.
    # The loaded array is reshaped to one element and that element (a dict,
    # judging by the .keys()/["step"]/["loss"] accesses) is taken out.
    my_data = np.load(my_exp_log, allow_pickle=True).reshape(1)[0]

    # every entry other than the two curves is reported as a hyperparameter
    for my_key, my_value in my_data.items():
        if my_key in ("step", "loss"):
            continue
        try:
            my_text = f"{my_value:.4f}"
        except (TypeError, ValueError):
            # entries that do not support a float format (e.g. strings or
            # lists) are printed as they are instead of aborting the loop
            my_text = f"{my_value}"
        print(f"hyperparam {my_key}: {my_text}")

    plt.figure(figsize=(10,7))
    plt.plot(my_data["step"], my_data["loss"], "o")

    # Title each figure with the log file it was drawn from. (exp_log, the
    # result of the most recent nca.fit call, was used here before, which
    # labelled every figure identically and required training to have run.)
    plt.title(f"Training curve {my_exp_log}", fontsize=22)
    plt.ylabel("style loss", fontsize=18)
    plt.xlabel("training step", fontsize=18)
    plt.show()