In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.models

import numpy as np


import skimage
import skimage.io as sio
import skimage.transform

from srnca.nca import NCA

import matplotlib.pyplot as plt
import matplotlib

import matplotlib.animation
import IPython

In [None]:
# Convolutional part of an ImageNet-pretrained VGG16; it is only used as a
# fixed feature extractor for the style (Gram matrix) loss.
vgg16 = torchvision.models.vgg16(pretrained=True).features

# The extractor is never optimized, so freeze it: gradients still flow
# *through* it to the images, but no gradients are computed or accumulated
# for its own weights on every backward pass.
vgg16.eval()
for _vgg_param in vgg16.parameters():
    _vgg_param.requires_grad_(False)

# Fixed 3x3 perception kernels applied to every grid channel.

# passes the cell's own value through unchanged
identity = torch.tensor([[0., 0., 0.],
                         [0., 1., 0.],
                         [0., 0., 0.]])
# vertical / horizontal gradient kernels; note the weights are uniform
# (no doubled center row/column), despite the "sobel" name
sobel_h = torch.tensor([[-1., -1., -1.],
                        [0., 0., 0.],
                        [1., 1., 1.]])
sobel_w = torch.tensor([[-1., 0., 1.],
                        [-1., 0., 1.],
                        [-1., 0., 1.]])
# sum of the eight surrounding cells
moore = torch.tensor([[1., 1., 1.],
                      [1., 0., 1.],
                      [1., 1., 1.]])
# discrete Laplacian; the weights sum to zero
laplacian = torch.tensor([[1., 2., 1.],
                          [2., -12., 2.],
                          [1., 2., 1.]])
def compute_grams(imgs):
    """Compute Gram matrices of VGG16 activations for a batch of images.

    Args:
        imgs: tensor shaped (batch, channels, height, width). The
            normalization constants below assume values in [0, 1] --
            TODO confirm against callers.

    Returns:
        A list with one tensor per entry of ``style_layers``. Each tensor is
        shaped (batch, features, features) and holds the feature-by-feature
        products summed over all spatial positions and divided by
        ``height * width`` of that layer's activation.
    """

    # indices into the module-level `vgg16` sequence whose outputs are used
    style_layers = [1, 6, 11, 18, 25]

    # Normalization as in https://github.com/google-research/self-organising-systems
    # These are the per-channel ImageNet mean / std that torchvision's
    # pretrained models expect their inputs to be normalized with.
    # Built on the input's device so non-CPU inputs do not hit a device mismatch.
    mean = torch.tensor([0.485, 0.456, 0.406], device=imgs.device)[:, None, None]
    std = torch.tensor([0.229, 0.224, 0.225], device=imgs.device)[:, None, None]
    x = (imgs - mean) / std

    grams = []
    # only run the network as deep as the last layer that is needed
    for i, layer in enumerate(vgg16[:max(style_layers) + 1]):
        x = layer(x)
        if i in style_layers:

            h, w = x.shape[-2:]
            # Work on a copy so the Gram computation does not share storage
            # with the activation that keeps flowing through the network
            # (presumably a guard against in-place layers).
            y = x.clone()

            gram = torch.einsum('bchw, bdhw -> bcd', y, y) / (h * w)
            grams.append(gram)

    return grams

def compute_style_loss(grams_pred, grams_target):
    """Sum of mean squared differences between paired Gram matrices.

    Args:
        grams_pred: sequence of predicted Gram tensors.
        grams_target: sequence of target Gram tensors, paired with
            ``grams_pred`` by position (pairing stops at the shorter one).

    Returns:
        The accumulated loss; the float ``0.0`` if there are no pairs.
    """
    per_layer = (
        (predicted - wanted).square().mean()
        for predicted, wanted in zip(grams_pred, grams_target)
    )
    # start at 0.0 so the result for empty input is a plain float
    return sum(per_layer, 0.0)


def read_image(url, max_size=None):
    """Read an image and rescale it so that its largest value is 1.

    Args:
        url: location passed straight to ``skimage.io.imread``.
        max_size: if not None, the image is resized to
            ``(max_size, max_size)``; the aspect ratio is not preserved.

    Returns:
        A float32 numpy array. An image whose maximum is 0 is returned
        unscaled instead of being divided by zero.
    """

    img = sio.imread(url)

    if max_size is not None:
        img = skimage.transform.resize(img, (max_size, max_size))

    # Cast *before* dividing: dividing a float32 array by a float64 scalar
    # (the max of a resized image) promotes the result to float64 under
    # NumPy 2 promotion rules, which later clashes with float32 weights.
    img = np.asarray(img, dtype=np.float32)

    peak = img.max()
    if peak != 0:
        img = img / peak

    return img

def image_to_tensor(img):
    """Convert a numpy image into a 4-D tensor with a leading batch axis.

    Args:
        img: numpy array, either 2-D ``(height, width)`` or 3-D
            ``(height, width, channels)``.

    Returns:
        A tensor shaped ``(1, 1, height, width)`` for 2-D input or
        ``(1, channels, height, width)`` for 3-D input.

    Raises:
        ValueError: if ``img`` is neither 2-D nor 3-D.
    """

    if img.ndim == 2:
        # single-channel image: add batch and channel axes
        return torch.tensor(img[np.newaxis, np.newaxis, ...])

    if img.ndim == 3:
        # channels-last -> channels-first, then add the batch axis
        return torch.tensor(img.transpose(2, 0, 1)[np.newaxis, ...])

    raise ValueError(
        f"expected a 2-D or 3-D image array, got shape {tuple(img.shape)}"
    )

def tensor_to_image(my_tensor, index=0):
    """Extract one batch element as a channels-last numpy image.

    Args:
        my_tensor: tensor shaped (batch, channels, height, width).
        index: which batch element to extract.

    Returns:
        A numpy array shaped (height, width, channels).
    """

    # detach from autograd and move to host memory first: `.numpy()` only
    # works on CPU tensors (for CPU input `.cpu()` is a no-op)
    img = my_tensor[index, ...].permute(1, 2, 0).detach().cpu().numpy()

    return img

def perceive(x, filters):
    """Apply every 3x3 filter to every channel with wrap-around borders.

    Args:
        x: tensor shaped (batch, channels, height, width).
        filters: tensor shaped (number_filters, 3, 3).

    Returns:
        Tensor shaped (batch, channels * number_filters, height, width);
        the responses of one input channel to all filters are adjacent.
    """
    n_batch, n_channels, n_rows, n_cols = x.shape

    # fold channels into the batch axis so each one is filtered on its own
    single_channel = x.reshape(n_batch * n_channels, 1, n_rows, n_cols)

    # one cell of circular padding on every side keeps the spatial size
    wrapped = F.pad(single_channel, (1, 1, 1, 1), mode="circular")

    kernels = filters[:, None, :, :]
    responses = F.conv2d(wrapped, kernels)

    return responses.reshape(n_batch, -1, n_rows, n_cols)

def plot_grid(grid):
    """Draw the first grid of a batch and remember the image artist.

    The artist returned by ``imshow`` is stored in the global
    ``subplot_0`` so that ``update_fig`` can replace its pixel data.

    Args:
        grid: tensor accepted by ``tensor_to_image``.

    Returns:
        The ``(fig, ax)`` pair of the newly created figure.
    """
    global subplot_0

    figure, axes = plt.subplots(1, 1, figsize=(7.5, 7.5), facecolor="white")

    subplot_0 = axes.imshow(tensor_to_image(grid), interpolation="nearest")

    # blank out the tick labels on both axes
    axes.set_yticklabels('')
    axes.set_xticklabels('')

    plt.tight_layout()

    return figure, axes

def update_fig(i):
    """Animation callback: advance the global grid one NCA step and redraw.

    Reads and replaces the global ``grid`` using the global ``nca``, then
    pushes the new frame into the global image artist ``subplot_0``.

    Args:
        i: frame index supplied by the animation; unused.
    """

    global subplot_0
    global grid

    # Display only: without no_grad every frame would extend the autograd
    # graph hanging off `grid`, growing memory over the whole animation.
    with torch.no_grad():
        grid = nca(grid)

    grid_display = tensor_to_image(grid)

    subplot_0.set_array(grid_display)
    

In [None]:
def seed_all(my_seed=42):
    """Seed the torch and numpy global random number generators.

    Args:
        my_seed: seed value handed to both generators.
    """
    for seed_generator in (torch.manual_seed, np.random.seed):
        seed_generator(my_seed)
    
def train(target, nca, max_steps, lr=1e-3, max_ca_steps=8):
    """Optimize ``nca`` so that its output matches the style of ``target``.

    Each step runs the automaton for a random number of iterations
    (``max_ca_steps`` plus 1 to 11) and minimizes the Gram-matrix style
    loss between the resulting grids and ``target``.

    Args:
        target: image tensor; its second-to-last dimension sets the grid size.
        nca: module to train; must provide ``get_init_grid`` and ``parameters``.
        max_steps: number of optimizer steps.
        lr: initial Adam learning rate (multiplied by 0.3 after a third of
            the steps).
        max_ca_steps: minimum number of automaton iterations per step.
    """

    # report progress roughly eight times over the run
    display_every = max_steps // 8 + 1

    optimizer = torch.optim.Adam(nca.parameters(), lr=lr)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [max_steps//3], 0.3)

    grids = nca.get_init_grid(batch_size=64, dim=target.shape[-2])

    # The target is constant, so its Gram matrices are computed once and
    # without autograd bookkeeping instead of on every step.
    with torch.no_grad():
        grams_target = compute_grams(target)

    for step in range(max_steps):

        # NOTE(review): nothing is ever written back into `grids`, so every
        # step starts from this same batch -- confirm whether a sample pool
        # with write-back was intended.
        with torch.no_grad():
            x = grids

            if step % 8 == 0:
                # periodically reset the first grid to the initial state
                x[:1] = nca.get_init_grid(batch_size=1, dim=x.shape[-2])

        optimizer.zero_grad()

        for _ in range(np.random.randint(1,12) + max_ca_steps):
            x = nca(x)

        grams_pred = compute_grams(x)

        loss = compute_style_loss(grams_pred, grams_target)

        loss.backward()

        optimizer.step()
        lr_scheduler.step()

        if step % display_every == 0:
            print(f"loss at step {step} = {loss:.4e}")
        

In [None]:
def soft_clamp(x):
    """Smoothly squash values into (0, 1) with a logistic centered at 0.5.

    Computes ``1 / (1 + exp(-4 * (x - 0.5)))`` through ``torch.sigmoid``,
    which stays finite (values and gradients) for inputs where the explicit
    ``exp`` would overflow.
    """
    return torch.sigmoid(4.0 * (x - 0.5))

class NCA(nn.Module):
    """Neural cellular automaton that can be fitted to a target texture.

    One step filters every grid channel with fixed 3x3 kernels
    (``perceive``), maps the result through two 1x1 convolutions to obtain
    an update, adds the update to a random subset of grid entries and
    squashes the grid with ``soft_clamp``.
    """

    def __init__(self, number_channels=1, number_filters=5, number_hidden=32):
        """Build the update network and the fixed perception filters.

        Args:
            number_channels: number of grid channels.
            number_filters: how many of the fixed kernels (identity,
                sobel_h, sobel_w, moore, laplacian -- in that order) to use.
            number_hidden: width of the hidden 1x1 convolution.

        Raises:
            ValueError: if ``number_filters`` is not between 1 and 5.
        """
        super().__init__()

        self.number_channels = number_channels
        self.number_filters = number_filters
        self.number_hidden = number_hidden

        self.conv_0 = nn.Conv2d(self.number_channels * self.number_filters,
                self.number_hidden, kernel_size=1)
        self.conv_1 = nn.Conv2d(self.number_hidden, self.number_channels,
                kernel_size=1, bias=False)

        all_filters = torch.stack([identity, sobel_h, sobel_w,
                moore, laplacian])

        if not 1 <= self.number_filters <= len(all_filters):
            raise ValueError(
                f"number_filters must be between 1 and {len(all_filters)}, "
                f"got {self.number_filters}"
            )

        # conv_0 expects number_channels * number_filters inputs, so the
        # filter bank must contain exactly number_filters kernels.
        # Non-persistent buffer: it follows .to(device) with the module but
        # is left out of the state dict, so saved parameters stay compatible.
        self.register_buffer("filters",
                all_filters[:self.number_filters].clone(), persistent=False)

        # zero weights and no bias: a fresh model produces a zero update
        nn.init.zeros_(self.conv_1.weight)

        self.dt = 0.5 #nn.Parameter(torch.Tensor([[[[1.0]]]]))
        #self.add_module("dt", dt)
        self.max_value = 1.0
        self.min_value = 0.0

        self.squash = soft_clamp

    def forward(self, grid, update_rate=0.75):
        """Advance the grid by one automaton step.

        Args:
            grid: tensor shaped (batch, number_channels, height, width).
            update_rate: probability with which each individual grid entry
                receives its update in this step.

        Returns:
            The squashed grid after the step, same shape as ``grid``.
        """

        # independent random on/off decision for every entry of the grid
        update_mask = (torch.rand_like(grid) < update_rate) * 1.0
        perception = perceive(grid, self.filters)

        # NOTE(review): conv_0 and conv_1 are applied back to back with no
        # activation in between, so together they form a single affine map
        # of the perception vector -- confirm whether that is intended.
        new_grid = self.conv_0(perception)
        new_grid = self.conv_1(new_grid)

        new_grid = grid + self.dt * new_grid * update_mask

        return self.squash(new_grid)

    def get_init_grid(self, batch_size=8, dim=128):
        """Return an all-zero starting grid.

        The grid is created with the dtype and device of the model weights.

        Args:
            batch_size: number of grids.
            dim: height and width of each grid.

        Returns:
            Zero tensor shaped (batch_size, number_channels, dim, dim).
        """
        weight = self.conv_0.weight

        return torch.zeros(batch_size, self.number_channels, dim, dim,
                dtype=weight.dtype, device=weight.device)

    def initialize_optimizer(self, lr, max_steps):
        """Create the Adam optimizer and its step learning-rate schedule.

        Args:
            lr: initial learning rate.
            max_steps: planned number of steps; the learning rate is
                multiplied by 0.3 after ``max_steps // 3`` steps.
        """

        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                self.optimizer, [max_steps // 3], 0.3)

    def fit(self, target, max_steps=10, lr=1e-3, max_ca_steps=16, batch_size=8):
        """Train the automaton to match the style of ``target``.

        Each step runs the automaton for ``max_ca_steps`` plus 1 to 15
        iterations and minimizes the Gram-matrix style loss between the
        resulting grids and ``target``.

        Args:
            target: image tensor; its second-to-last dimension sets the
                grid size.
            max_steps: number of optimizer steps.
            lr: initial learning rate.
            max_ca_steps: minimum number of automaton iterations per step.
            batch_size: number of grids evolved per step.
        """

        self.batch_size = batch_size
        # report progress roughly eight times over the run
        display_every = max_steps // 8 + 1

        self.initialize_optimizer(lr, max_steps)

        grids = self.get_init_grid(batch_size=self.batch_size, dim=target.shape[-2])

        # The target is constant, so its Gram matrices are computed once and
        # without autograd bookkeeping instead of on every step.
        with torch.no_grad():
            grams_target = compute_grams(target)

        for step in range(max_steps):

            # NOTE(review): nothing is written back into `grids`, so every
            # step starts from this same batch -- confirm whether a sample
            # pool with write-back was intended.
            with torch.no_grad():
                x = grids

                if step % 8 == 0:
                    # periodically reset the first grid to the initial state
                    x[:1] = self.get_init_grid(batch_size=1, dim=x.shape[-2])

            self.optimizer.zero_grad()

            for _ in range(np.random.randint(1,16) + max_ca_steps):
                # call the module (not .forward) so registered hooks run
                x = self(x)

            grams_pred = compute_grams(x)

            loss = compute_style_loss(grams_pred, grams_target)
            loss.backward()

            self.optimizer.step()
            self.lr_scheduler.step()

            if step % display_every == 0:
                print(f"loss at step {step} = {loss:.4e}")

    def count_parameters(self):
        """Return the total number of elements over all parameters."""

        return sum(param.numel() for param in self.parameters())

    def save_parameters(self, save_path):
        """Save the module's state dict to ``save_path``."""

        torch.save(self.state_dict(), save_path)

    def load_parameters(self, load_path):
        """Load a state dict from ``load_path`` into this module.

        SECURITY: ``torch.load`` unpickles the file, which can execute
        arbitrary code -- only load files from trusted sources.
        """

        # Load onto the CPU first so files saved from another device can be
        # read anywhere; load_state_dict copies into the existing parameters.
        state_dict = torch.load(load_path, map_location="cpu")

        self.load_state_dict(state_dict)

In [None]:
# Alternative training images (uncomment one to use it instead):
#url = "https://www.nasa.gov/centers/ames/images/content/72511main_cellstructure8.jpeg"
#url = "../data/images/jwst_segment_alignment.jpg"

url = "../data/images/orbia_magma.png"

# load at 128x128 and keep at most the first three channels
img = read_image(url, max_size=128)[:,:,:3]


# (1, channels, height, width) tensor used as the training target
target = image_to_tensor(img)
img = tensor_to_image(target)
print(target.shape)
seed_all(13)

nca = NCA(number_channels=3, number_hidden=256)

# view the training image
plt.figure()
plt.imshow(img, cmap="gray")
plt.show()

In [None]:
# train the NCA on the target texture
# (the target's mean / max / min are printed first as a quick range check)
print(target.mean(), target.max(), target.min())
nca.fit(target, max_steps=128, max_ca_steps=20, lr = 1e-3)

In [None]:
# number of animation frames; update_fig advances the grid once per frame
num_frames = 128

# start from a single fresh grid and draw it
grid = nca.get_init_grid(batch_size=1, dim=128)
fig, ax = plot_grid(grid)

plt.close("all")

# build the animation, then render it as an interactive HTML widget
texture_animation = matplotlib.animation.FuncAnimation(
    fig,
    update_fig,
    frames=num_frames,
    interval=100,
)
IPython.display.HTML(texture_animation.to_jshtml())