In [None]:
# To run this notebook in Colab, you first need to install the SRNCA repo:
# uncomment the three lines below to clone it and install srnca in editable mode.

#! git clone https://github.com/rivesunder/SRNCA 
#%cd SRNCA
#! pip install -e .

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.models

import numpy as np


import skimage
import skimage.io as sio
import skimage.transform

from srnca.nca import NCA
from srnca.utils import image_to_tensor, tensor_to_image, read_image, seed_all

import matplotlib.pyplot as plt
import matplotlib

import matplotlib.animation
import IPython

In [None]:
def plot_grid(grid, figsize=(7.5, 7.5)):
    """Draw an NCA grid on a new figure and return ``(fig, ax)``.

    The image artist is stored in the module-level ``subplot_0`` so that
    ``update_fig`` can replace its pixel data frame by frame.

    Args:
        grid: NCA state; converted for display with ``tensor_to_image``.
        figsize: figure size in inches (default keeps the original 7.5 x 7.5).

    Returns:
        The ``(fig, ax)`` pair created by ``plt.subplots``.
    """
    global subplot_0

    fig, ax = plt.subplots(1, 1, figsize=figsize, facecolor="white")

    grid_display = tensor_to_image(grid)
    subplot_0 = ax.imshow(grid_display, interpolation="nearest")

    # Hide tick labels without installing a FixedFormatter (which is what
    # set_{x,y}ticklabels('') does, and which matplotlib may warn about when
    # no FixedLocator is set). Tick marks stay, as before.
    ax.tick_params(labelbottom=False, labelleft=False)

    plt.tight_layout()

    return fig, ax

def update_fig(i):
    """FuncAnimation callback: advance the global ``grid`` one NCA step and redraw.

    Reads and rebinds the module-level ``grid`` and updates the module-level
    ``subplot_0`` artist created by ``plot_grid``.

    Args:
        i: frame index supplied by ``FuncAnimation``; unused.

    Returns:
        A 1-tuple holding the updated artist (required if ``blit=True``,
        ignored otherwise).
    """
    global subplot_0
    global grid

    # Inference only. The model is trained via nca.fit, so its parameters
    # presumably require grad. Without no_grad, every frame would chain onto
    # the previous frame's autograd graph and memory would grow with frame count.
    with torch.no_grad():
        grid = nca(grid)
    grid_display = tensor_to_image(grid)

    subplot_0.set_array(grid_display)

    return (subplot_0,)
    

In [None]:
# Training image: a URL or local path passed to read_image.
url = "https://www.nasa.gov/centers/ames/images/content/72511main_cellstructure8.jpeg"

# Alternative local training images:
#url = "../data/images/orbia_magma.png"
#url = "../data/images/jwst_segment_alignment.jpg"

# Keep only the first three entries of the last axis.
# Assumes H x W x C layout, so this drops an alpha channel if present.
img = read_image(url, max_size=128)[:,:,:3]

target = image_to_tensor(img)
# Re-derive img from target so the preview reflects the tensor conversion.
img = tensor_to_image(target)
print(target.shape, img.shape)
seed_all(13)

nca = NCA(number_channels=3, number_hidden=96)

# View the training image. It is RGB, so a colormap would be ignored.
fig, ax = plt.subplots()
ax.imshow(img)
ax.set_title("training target")
ax.axis("off")
plt.show()

In [None]:
# Optional: move the model to the GPU only when one is available.
# This way "Restart & Run All" also succeeds on CPU-only kernels.
if torch.cuda.is_available():
    nca.to_device("cuda")

In [None]:
# Train the NCA on the target texture.
# First sanity-check the target's value range.
print(
    f"target mean={target.mean().item():.4f} "
    f"max={target.max().item():.4f} min={target.min().item():.4f}"
)

# NOTE(review): meanings inferred from the fit() keyword names; confirm against srnca.
MAX_STEPS = 10000      # presumably the number of optimisation steps
MAX_CA_STEPS = 20      # presumably the max CA updates unrolled per training step
LEARNING_RATE = 3e-5

nca.fit(target, max_steps=MAX_STEPS, max_ca_steps=MAX_CA_STEPS, lr=LEARNING_RATE)

In [None]:
num_frames = 360
# to_jshtml embeds every frame as base64 and drops frames once the payload
# exceeds rcParams["animation.embed_limit"] (MB). The 20 MB default can be too
# small for this many full-size frames, so raise it for this render only.
embed_limit_mb = 200

# update_fig reads and rebinds this module-level `grid`. Re-running this cell
# therefore re-initialises the animation state.
grid = nca.get_init_grid(batch_size=1, dim=128)

fig, ax = plot_grid(grid)

# Close the figure so the notebook does not also render a static copy.
plt.close("all")

anim = matplotlib.animation.FuncAnimation(fig, update_fig, frames=num_frames, interval=100)
with matplotlib.rc_context({"animation.embed_limit": embed_limit_mb}):
    anim_html = anim.to_jshtml()
IPython.display.HTML(anim_html)