In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.models

import numpy as np

import matplotlib.pyplot as plt

import skimage
import skimage.io as sio
import skimage.transform

In [None]:
# Fixed 3x3 perception kernels. Each one is applied to every grid channel
# independently (see `perceive`).

# Pass-through kernel: a single 1 in the centre.
identity = torch.zeros(3, 3)
identity[1, 1] = 1.0

# Vertical / horizontal difference kernels. The rows and columns are weighted
# uniformly (the Prewitt form) rather than with the 1-2-1 weighting usually
# associated with the name "Sobel".
sobel_h = torch.tensor([
    [-1., -1., -1.],
    [ 0.,  0.,  0.],
    [ 1.,  1.,  1.],
])
sobel_w = sobel_h.t().contiguous()

# Sum over the 8 surrounding cells (centre excluded).
moore = torch.ones(3, 3) - identity

# Discrete Laplacian stencil; the entries sum to zero.
laplacian = torch.tensor([
    [1.,   2., 1.],
    [2., -12., 2.],
    [1.,   2., 1.],
])

# Convolutional trunk of an ImageNet-pretrained VGG16, used purely as a fixed
# feature extractor for the Gram-matrix style loss.
# It is put in eval mode and its weights are frozen: gradients still flow
# *through* the network to its input, but autograd no longer computes (and
# accumulates) weight gradients for VGG itself on every backward pass.
vgg16 = torchvision.models.vgg16(pretrained=True).features.eval().requires_grad_(False)

def compute_grams(imgs):
    """Compute VGG16 Gram matrices that summarise the texture of ``imgs``.

    Args:
        imgs: float tensor shaped (batch, channels, height, width). It is
            normalised with per-channel RGB statistics, so ``channels`` must
            be 3 (or 1, which broadcasts to 3). In this notebook the inputs
            are scaled to [0, 1]. ``imgs`` must live on the same device as
            the module-level ``vgg16``.

    Returns:
        A list with one tensor per style layer, each shaped (batch, C, C)
        where C is that layer's channel count: the channel-by-channel inner
        products of the activations, averaged over spatial positions.
    """
    # Indices into the VGG16 feature stack whose activations are summarised.
    style_layers = (1, 6, 11, 18, 25)

    # ImageNet channel mean / std: torchvision's pretrained classifiers were
    # trained on inputs normalised this way, so the same normalisation is
    # applied before extracting features.
    # The constants are created on the input's device so that non-CPU inputs
    # do not trigger a device-mismatch error.
    mean = torch.tensor([0.485, 0.456, 0.406], device=imgs.device)[:, None, None]
    std = torch.tensor([0.229, 0.224, 0.225], device=imgs.device)[:, None, None]
    x = (imgs - mean) / std

    grams = []
    # Only run the network as deep as the last layer that is needed.
    for i, layer in enumerate(vgg16[:max(style_layers) + 1]):
        x = layer(x)
        if i in style_layers:
            h, w = x.shape[-2:]
            # Copy kept from the original implementation, where it was added
            # as a workaround for an autograd in-place modification error.
            # NOTE(review): presumably related to in-place layers in VGG;
            # confirm before removing.
            y = x.clone()

            # (b, c, h, w) x (b, d, h, w) -> (b, c, d), averaged over h*w.
            gram = torch.einsum('bchw, bdhw -> bcd', y, y) / (h * w)
            grams.append(gram)

    return grams

def compute_style_loss(grams_pred, grams_target):
    """Sum of per-layer mean squared differences between two Gram lists.

    Args:
        grams_pred: sequence of Gram tensors for the generated images.
        grams_target: sequence of Gram tensors for the reference image,
            paired position-by-position with ``grams_pred``.

    Returns:
        The accumulated loss (a scalar tensor), or the float ``0.0`` when no
        pairs are given.
    """
    per_layer = (
        (pred - ref).square().mean()
        for pred, ref in zip(grams_pred, grams_target)
    )
    # Start from the float 0.0 so an empty input yields 0.0 rather than int 0.
    return sum(per_layer, 0.0)


def read_image(url, max_size=None):
    """Load an image and scale it so that its largest value is 1.

    Args:
        url: location passed straight to ``skimage.io.imread``.
        max_size: if given, the image is resized to exactly
            ``(max_size, max_size)``; the aspect ratio is not preserved.

    Returns:
        A float32 NumPy array with the image divided by its maximum value.
        An all-zero image is returned unchanged (all zeros) instead of being
        turned into NaNs by a division by zero.
    """
    img = sio.imread(url)

    if max_size is not None:
        img = skimage.transform.resize(img, (max_size, max_size))

    # Convert first, then take the max of the float32 array: dividing a
    # float32 array by a float64 NumPy scalar (what `resize` output yields)
    # promotes the result to float64 under NumPy 2 promotion rules, which
    # would not match the float32 networks downstream.
    img = np.asarray(img, dtype=np.float32)
    peak = img.max()
    if peak > 0:
        img = img / peak

    return img

def image_to_tensor(img):
    """Convert an image array into a batched, channels-first tensor.

    Args:
        img: NumPy array shaped (height, width) or (height, width, channels).

    Returns:
        A tensor shaped (1, 1, height, width) for 2-D input or
        (1, channels, height, width) for 3-D input. The data is copied.

    Raises:
        ValueError: if ``img`` is neither 2- nor 3-dimensional.
    """
    rank = len(img.shape)

    if rank == 2:
        # Add both the batch and the channel axis.
        batched = img[np.newaxis, np.newaxis, ...]
    elif rank == 3:
        # (H, W, C) -> (C, H, W), then add the batch axis.
        batched = img.transpose(2, 0, 1)[np.newaxis, ...]
    else:
        # Previously this fell through to an UnboundLocalError.
        raise ValueError(
            f"expected a 2-D or 3-D image array, got shape {tuple(img.shape)}")

    return torch.tensor(batched)

def tensor_to_image(my_tensor, index=0):
    """Extract one item of a batched tensor as a channels-last NumPy image.

    Args:
        my_tensor: tensor shaped (batch, channels, height, width).
        index: which batch item to extract.

    Returns:
        NumPy array shaped (height, width, channels).
    """
    # detach() drops the autograd graph; cpu() is a no-op for CPU tensors and
    # makes the conversion work for tensors living on an accelerator, where
    # .numpy() would otherwise raise.
    img = my_tensor[index, ...].permute(1, 2, 0).detach().cpu().numpy()

    return img

def perceive(x, filters):
    """Apply every 3x3 filter to every channel of ``x`` with wrap-around edges.

    Args:
        x: tensor shaped (batch, channels, height, width).
        filters: tensor shaped (n_filters, 3, 3). It is converted to the
            device and dtype of ``x`` before use.

    Returns:
        Tensor shaped (batch, channels * n_filters, height, width). The
        responses are ordered channel-major: output index
        ``c * n_filters + f`` is filter ``f`` applied to input channel ``c``.
    """
    batch, channels, height, width = x.shape

    # Match the kernels to the input so grids on another device, or in
    # another floating dtype, do not make conv2d fail.
    kernels = filters.to(device=x.device, dtype=x.dtype).unsqueeze(1)

    # Fold channels into the batch axis so each channel is filtered on its own.
    planes = x.reshape(batch * channels, 1, height, width)
    # One cell of circular padding keeps the spatial size under a 3x3 kernel
    # and makes the grid behave like a torus.
    planes = F.pad(planes, (1, 1, 1, 1), mode="circular")

    responses = F.conv2d(planes, kernels)

    return responses.reshape(batch, -1, height, width)
    
class NCA(nn.Module):
    """Neural cellular automaton: fixed 3x3 perception plus a learned per-cell update.

    Each step filters every grid channel with a bank of hand-written 3x3
    kernels, feeds the stacked responses through two 1x1 convolutions (i.e. a
    small per-cell MLP) and adds the result to the grid.

    Args:
        number_channels: number of channels in the grid.
        number_filters: how many of the built-in kernels to use, taken in the
            order identity, sobel_h, sobel_w, moore, laplacian (1 to 5).
        number_hidden: width of the hidden 1x1 convolution.

    Raises:
        ValueError: if ``number_filters`` is outside 1..5.
    """

    def __init__(self, number_channels=1, number_filters=5, number_hidden=32):
        super().__init__()

        available = torch.stack([identity, sobel_h, sobel_w, moore, laplacian])
        # conv_0 is sized from number_filters, so the filter bank must have
        # exactly that many kernels or the first forward pass fails.
        if not 1 <= number_filters <= len(available):
            raise ValueError(
                f"number_filters must be between 1 and {len(available)}, "
                f"got {number_filters}")

        self.number_channels = number_channels
        self.conv_0 = nn.Conv2d(number_channels * number_filters, number_hidden,
                                kernel_size=1)
        self.conv_1 = nn.Conv2d(number_hidden, number_channels,
                                kernel_size=1, bias=False)

        # Registered as a buffer so the kernels move with the module on
        # .to()/.cuda(); non-persistent so the state_dict keys are unchanged.
        self.register_buffer("filters", available[:number_filters].clone(),
                             persistent=False)

        # Trick from Mordvintsev: a zero-initialised output layer makes the
        # untrained automaton leave the grid unchanged.
        nn.init.zeros_(self.conv_1.weight)

        # update step size
        self.dt = 1.0
        # range the grid is clamped to after every step
        self.max_value = 1.0
        self.min_value = 0.0

    def forward(self, grid, update_rate=0.5):
        """Advance the automaton by one stochastic step.

        Args:
            grid: tensor shaped (batch, number_channels, height, width).
            update_rate: probability that any given grid element is updated
                in this step; the mask is drawn independently per element.

        Returns:
            The updated grid, clamped to [min_value, max_value].
        """
        update_mask = (torch.rand_like(grid) < update_rate).to(grid.dtype)
        perception = perceive(grid, self.filters)

        # The ReLU between the two 1x1 convolutions is required: without it
        # they collapse into a single linear map of the perception vector and
        # the hidden layer adds no expressive power.
        hidden = F.relu(self.conv_0(perception))
        delta = self.conv_1(hidden)

        return torch.clamp(grid + self.dt * delta * update_mask,
                           self.min_value, self.max_value)

    def get_init_grid(self, batch_size=8, dim=128):
        """Return an all-zero grid shaped (batch_size, number_channels, dim, dim).

        The grid is created on the same device as the module's filters.
        """
        return torch.zeros(batch_size, self.number_channels, dim, dim,
                           device=self.filters.device)

In [None]:
def seed_all(my_seed=42):
    """Seed the PyTorch and NumPy global random number generators.

    Args:
        my_seed: integer seed handed to both libraries.
    """
    for seed_fn in (torch.manual_seed, np.random.seed):
        seed_fn(my_seed)
    
def train(target, nca, max_steps, lr=1e-3, max_ca_steps=8):
    """Train ``nca`` so that its output matches the texture statistics of ``target``.

    Args:
        target: reference image tensor shaped (1, channels, height, width);
            the grid side length is taken from ``target.shape[-2]``.
        nca: the cellular automaton to optimise; it must provide
            ``get_init_grid`` and be callable on a grid.
        max_steps: number of optimiser steps.
        lr: initial Adam learning rate; it is multiplied by 0.3 once, after
            ``max_steps // 3`` steps.
        max_ca_steps: number of automaton steps unrolled before the loss is
            evaluated.

    Prints the loss roughly eight times over the run. Returns nothing.
    """
    display_every = max_steps // 8 + 1

    optimizer = torch.optim.Adam(nca.parameters(), lr=lr)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[max_steps // 3], gamma=0.3)

    grids = nca.get_init_grid(batch_size=64, dim=target.shape[-2])

    # The target never changes, so its Gram matrices are computed once and
    # without an autograd graph instead of being rebuilt on every step.
    with torch.no_grad():
        grams_target = compute_grams(target)

    for step in range(max_steps):

        with torch.no_grad():
            # NOTE(review): the whole pool is used as the batch and evolved
            # states are never written back to `grids`, so every step starts
            # from the initial grids. Sampling a sub-batch and storing the
            # detached result back would turn this into a persistent pool;
            # left as is to keep the current training behaviour.
            x = grids

            # Periodically reset the first pool entry to a fresh grid
            # (in-place on `grids`, since `x` aliases it).
            if step % 8 == 0:
                x[:1] = nca.get_init_grid(batch_size=1, dim=x.shape[-2])

        optimizer.zero_grad()

        # Unroll the automaton; gradients flow through all CA steps.
        for ca_step in range(max_ca_steps):
            x = nca(x)

        loss = compute_style_loss(compute_grams(x), grams_target)

        loss.backward()

        optimizer.step()
        lr_scheduler.step()

        if step % display_every == 0:
            print(f"loss at step {step} = {loss.item():.4e}")
        

In [None]:
# Texture image the automaton is trained to imitate.
url = "https://www.nasa.gov/centers/ames/images/content/72511main_cellstructure8.jpeg"

# Load the image resized to 128x128 and scaled by its maximum, then convert
# it to a (1, channels, height, width) tensor for the style loss.
img = read_image(url, max_size=128)
target = image_to_tensor(img)

# Seed torch and numpy before building the model so that its random
# initialisation is repeatable.
seed_all(1337)

# One grid channel per colour channel of the target.
nca = NCA(number_channels=3)

# view the training image
plt.figure()
plt.imshow(img)
plt.show()

In [None]:
# train for textures: run 256 optimiser steps fitting `nca` to the VGG Gram
# statistics of `target`; the loss is printed periodically
train(target, nca, max_steps=256)

In [None]:
# visualize what the nca model puts out

# Start from a fresh grid that is larger than the 128x128 training size.
grid = nca.get_init_grid(batch_size=1, dim=256)

# Show the grid after 0, 2 and 34 cumulative automaton steps.
# This is pure inference, so it runs under no_grad: otherwise autograd would
# keep the graph of every step alive for the whole rollout.
with torch.no_grad():
    for steps_to_run in (0, 2, 32):
        for step in range(steps_to_run):
            grid = nca(grid)

        img = tensor_to_image(grid)

        plt.figure()
        plt.imshow(img)
        plt.show()

In [None]:
# what do the perception convolutions look like?
# Apply the five fixed kernels to the training image; the result holds one
# response map per (image channel, kernel) pair.
filters = torch.stack([identity, sobel_h, sobel_w, moore, laplacian])
perception = perceive(target, filters)

#image/filter dimensions
# NOTE(review): `img` is whatever was last assigned to that name in an
# earlier cell, which may not be the training image - confirm this is the
# shape meant to be shown.
print(perception.shape, img.shape)

#visualize perception
# One figure per response map of the first (only) batch item.
for ii, response_map in enumerate(perception[0]):

    plt.figure()
    plt.imshow(response_map.detach())
    plt.title(f"channel {ii}")
    plt.show()