# Isotropic and Steerable NCA (Single seed experiments)

*Copyright 2023 Google LLC*

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

[https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [None]:
#@title Imports and Notebook Utilities

import os
import io
import PIL.Image, PIL.ImageDraw
import base64
import zipfile
import json
import requests
import numpy as np
import matplotlib.pylab as pl
import glob

from IPython.display import Image, HTML, clear_output
from tqdm import tqdm_notebook, tnrange

os.environ['FFMPEG_BINARY'] = 'ffmpeg'
import moviepy.editor as mvp
from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter


def imread(url, max_size=None, mode=None):
  '''Loads an image as a float32 array with values scaled by 1/255.

  Args:
    url: string; either an 'http:'/'https:' URL to download or a path
      handed to `PIL.Image.open`.
    max_size: if given, the image is shrunk in place (aspect ratio kept) so
      that neither side exceeds this many pixels.
    mode: optional PIL mode (e.g. 'RGB') to convert the image to.

  Returns:
    np.float32 array of the pixel data divided by 255.

  Raises:
    requests.HTTPError: if the download returns an error status.
  '''
  if url.startswith(('http:', 'https:')):
    # wikimedia requires a user agent
    headers = {
      "User-Agent": "Requests in Colab/0.0 (https://colab.research.google.com/; no-reply@google.com) requests/0.0"
    }
    r = requests.get(url, headers=headers)
    # fail here with a clear HTTP error instead of letting PIL choke on an
    # error page
    r.raise_for_status()
    f = io.BytesIO(r.content)
  else:
    f = url
  img = PIL.Image.open(f)
  if max_size is not None:
    # Image.ANTIALIAS was an alias of LANCZOS and was removed in Pillow 10.
    img.thumbnail((max_size, max_size), PIL.Image.LANCZOS)
  if mode is not None:
    img = img.convert(mode)
  img = np.float32(img)/255.0
  return img

def np2pil(a):
  '''Wraps an array in a PIL image.

  float32/float64 input is clipped to [0, 1] and rescaled to uint8 first;
  any other dtype is handed to PIL unchanged.
  '''
  is_float = a.dtype in [np.float32, np.float64]
  if is_float:
    clipped = np.clip(a, 0, 1)
    a = np.uint8(clipped*255)
  return PIL.Image.fromarray(a)

def imwrite(f, a, fmt=None):
  '''Encodes an image array and writes it to a path or a file object.

  Args:
    f: destination; a string path or a writable binary file object.
    a: image array (anything `np.asarray` accepts).
    fmt: PIL format name. Ignored when `f` is a path: the format is then
      taken from the file extension ('jpg' is mapped to 'jpeg').
  '''
  a = np.asarray(a)
  if isinstance(f, str):
    fmt = f.rsplit('.', 1)[-1].lower()
    if fmt == 'jpg':
      fmt = 'jpeg'
    # the context manager closes (and flushes) the file we opened ourselves
    with open(f, 'wb') as out:
      np2pil(a).save(out, fmt, quality=95)
  else:
    # caller-owned file object: write to it but leave it open
    np2pil(a).save(f, fmt, quality=95)

def imencode(a, fmt='jpeg'):
  '''Returns the encoded bytes of an image array.

  Arrays with 3 dimensions and 4 channels in the last one are always
  encoded as PNG, whatever `fmt` says.
  '''
  a = np.asarray(a)
  four_channel = a.ndim == 3 and a.shape[-1] == 4
  buf = io.BytesIO()
  imwrite(buf, a, 'png' if four_channel else fmt)
  return buf.getvalue()

def im2url(a, fmt='jpeg'):
  '''Returns a base64 `data:` URL holding the encoded image.

  Args:
    a: image array (anything `np.asarray` accepts).
    fmt: requested encoding; overridden to 'png' for 4-channel images,
      mirroring what `imencode` does.
  '''
  a = np.asarray(a)
  # imencode switches 4-channel images to PNG; apply the same rule here so
  # the MIME type in the URL matches the bytes that were actually produced.
  if len(a.shape) == 3 and a.shape[-1] == 4:
    fmt = 'png'
  encoded = imencode(a, fmt)
  base64_byte_string = base64.b64encode(encoded).decode('ascii')
  return 'data:image/' + fmt.upper() + ';base64,' + base64_byte_string

def imshow(a, fmt='jpeg'):
  '''Displays an image array inline in the notebook.'''
  encoded = imencode(a, fmt)
  display(Image(data=encoded))

def tile2d(a, w=None):
  '''Arranges a stack of equally sized tiles into one 2D mosaic.

  a: [n, tile_h, tile_w, ...] stack of tiles.
  w: tiles per row; defaults to ceil(sqrt(n)).
  Missing tiles in the last row are filled with zeros.
  '''
  a = np.asarray(a)
  tile_n = len(a)
  if w is None:
    w = int(np.ceil(np.sqrt(tile_n)))
  tile_h, tile_w = a.shape[1:3]
  # zero tiles needed to complete the last row
  missing = (w-tile_n)%w
  pad_spec = [(0, missing)]+[(0, 0)]*(a.ndim-1)
  a = np.pad(a, pad_spec, 'constant')
  rows = len(a)//w
  grid = a.reshape([rows, w]+list(a.shape[1:]))
  # [rows, w, tile_h, tile_w, ...] -> [rows, tile_h, w, tile_w, ...]
  grid = np.rollaxis(grid, 2, 1)
  mosaic_shape = [tile_h*rows, tile_w*w]+list(grid.shape[4:])
  return grid.reshape(mosaic_shape)

def zoom(img, scale=4):
  '''Upscales an image by repeating each element `scale` times along the
  first two axes.'''
  for axis in (0, 1):
    img = np.repeat(img, scale, axis)
  return img

class VideoWriter:
  '''Writes frames to a video file through moviepy's FFMPEG_VideoWriter.

  The underlying writer is created lazily on the first `add()` call, using
  that frame's height and width. When used as a context manager with the
  default filename, the resulting video is displayed on exit.
  '''

  def __init__(self, filename='_autoplay.mp4', fps=30.0, **kw):
    # extra keyword arguments are forwarded to FFMPEG_VideoWriter
    self.params = dict(filename=filename, fps=fps, **kw)
    self.writer = None

  def add(self, img):
    '''Appends one frame; float frames are clipped to [0, 1] and converted
    to uint8, 2D frames are replicated to 3 channels.'''
    frame = np.asarray(img)
    if self.writer is None:
      height, width = frame.shape[:2]
      self.writer = FFMPEG_VideoWriter(size=(width, height), **self.params)
    if frame.dtype in [np.float32, np.float64]:
      frame = np.uint8(frame.clip(0, 1)*255)
    if frame.ndim == 2:
      frame = np.repeat(frame[..., None], 3, -1)
    self.writer.write_frame(frame)

  def close(self):
    '''Closes the underlying writer, if one was created.'''
    if self.writer:
      self.writer.close()

  def __enter__(self):
    return self

  def __exit__(self, *kw):
    self.close()
    autoplay = self.params['filename'] == '_autoplay.mp4'
    if autoplay:
      self.show()

  def show(self, **kw):
    '''Closes the writer and displays the video file inline.'''
    self.close()
    video_path = self.params['filename']
    display(mvp.ipython_display(video_path, **kw))

#!nvidia-smi -L

gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In [None]:
!pip install einops

In [None]:
import torch
import torch.nn.functional as F
import torchvision.models as models

from functools import partial

from einops import rearrange

# Make CUDA float32 the default tensor type: tensors created below without an
# explicit device/dtype are allocated on the GPU.
torch.set_default_tensor_type('torch.cuda.FloatTensor')

In [None]:
# 3x3 stencils applied per channel by the perception functions.

# identity: passes the centre cell through unchanged
ident = torch.tensor([
    [0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0],
])
# horizontal Sobel derivative (its transpose gives the vertical one)
sobel_x = torch.tensor([
    [-1.0, 0.0, 1.0],
    [-2.0, 0.0, 2.0],
    [-1.0, 0.0, 1.0],
])
# 9-point laplacian; weights sum to zero
lap = torch.tensor([
    [1.0,   2.0, 1.0],
    [2.0, -12.0, 2.0],
    [1.0,   2.0, 1.0],
])
# laplacian over 6 neighbours (two opposite corners dropped)
lap6 = torch.tensor([
    [0.0,   2.0, 2.0],
    [2.0, -12.0, 2.0],
    [2.0,   2.0, 0.0],
])
# normalized 3x3 blur; weights sum to one
gauss = torch.tensor([
    [1.0, 2.0, 1.0],
    [2.0, 4.0, 2.0],
    [1.0, 2.0, 1.0],
])/16.0

def perchannel_conv(x, filters):
  '''Applies every filter to every channel independently.

  x: [batch, channels, h, w]
  filters: [filter_n, h, w]
  Returns [batch, channels*filter_n, h, w]; the responses of one input
  channel to all filters are adjacent in the output. The input is padded by
  one cell on each side with wrap-around ('circular') borders.
  '''
  batch, channels, height, width = x.shape
  # fold channels into the batch axis so each one is convolved on its own
  planes = x.reshape(batch*channels, 1, height, width)
  wrapped = F.pad(planes, [1, 1, 1, 1], 'circular')
  responses = F.conv2d(wrapped, filters[:,None])
  return responses.reshape(batch, -1, height, width)

# Models legend

Select the model of interest from the list below.

*   laplacian: Isotropic NCA model.
*   lap6: Isotropic NCA, (trained and/or evaluated) on a hexagonal grid.
*   lap_gradnorm: Isotropic NCA variant discussed in the blogpost.
*   steerable: Angle-based Steerable NCA.
*   gradient: Gradient-based Steerable NCA.
*   steerable_nolap: Angle-based Steerable NCA ablation without the laplacian filter.

In [None]:
#@title Minimalistic Neural {vertical-output: true, run: "auto"}

model_type = "steerable_nolap" #@param ['laplacian', 'lap6', 'lap_gradnorm', 'steerable', 'gradient', 'steerable_nolap']

# Number of trailing state channels that hold a per-cell angle (0 or 1).
ANGLE_CHN = 0
# Binary footprint of the laplacian stencil; get_alive_mask uses it as the
# neighbourhood of a cell.
nhood_kernel = (lap != 0.0).to(torch.float32)


def _sobel_grad(state):
  '''Per-channel Sobel responses, interleaved as [gx0, gy0, gx1, gy1, ...].'''
  return perchannel_conv(state, torch.stack([sobel_x, sobel_x.T]))


def _rotate_grad(grad, c, s):
  '''Transforms perceived gradient vectors into local cell coordinates.

  grad: interleaved (gx, gy) channel pairs as returned by _sobel_grad.
  c, s: cosine and sine of the cell direction, broadcast over channels.
  Returns all rotated x components followed by all rotated y components.
  '''
  gx, gy = grad[:,::2], grad[:,1::2]
  return torch.cat([gx*c+gy*s, gy*c-gx*s], 1)


# Each branch defines `perception(state)`, mapping the [b, chn, h, w] state
# to the per-cell feature channels fed to the update network. The channel
# order of every variant is kept as is, since trained weights depend on it.
if model_type in ('steerable', 'steerable_nolap'):
  ANGLE_CHN = 1  # last state channel is angle and should be treated
                 # differently
  # 'steerable_nolap' is the ablation without the laplacian features
  _perceive_lap = model_type == 'steerable'

  def perception(state):
    state, angle = state[:,:-1], state[:,-1:]
    # (a disabled experiment also let cells feel the gauss-averaged
    # direction of their neighbours)
    rot_grad = _rotate_grad(_sobel_grad(state), angle.cos(), angle.sin())
    if not _perceive_lap:
      return torch.cat([state, rot_grad], 1)
    state_lap = perchannel_conv(state, lap[None,:])
    return torch.cat([state, rot_grad, state_lap], 1)

elif model_type == 'gradient':
  def perception(state):
    grad = _sobel_grad(state)
    # gradient of the last channel determines the cell direction
    grad, direction = grad[:,:-2], grad[:,-2:]
    # clip(1.0) bounds the divisor from below: vectors longer than 1 are
    # normalized to unit length, shorter ones are left untouched
    direction = direction/direction.norm(dim=1, keepdim=True).clip(1.0)
    c, s = direction[:,:1], direction[:,1:2]
    rot_grad = _rotate_grad(grad, c, s)
    state_lap = perchannel_conv(state, lap[None,:])
    return torch.cat([state, state_lap, rot_grad], 1)

elif model_type == 'lap_gradnorm':
  def perception(state):
    grad = _sobel_grad(state)
    gx, gy = grad[:,::2], grad[:,1::2]
    state_lap = perchannel_conv(state, lap[None,:])
    # add norm of gradients (the epsilon keeps sqrt differentiable at zero)
    return torch.cat([state, state_lap, (gx*gx+gy*gy+1e-8).sqrt()], 1)

elif model_type == 'laplacian':
  def perception(state):
    state_lap = perchannel_conv(state, lap[None,:])
    return torch.cat([state, state_lap], 1)

elif model_type == 'lap6':
  nhood_kernel = (lap6 != 0.0).to(torch.float32)
  def perception(state):
    state_lap = perchannel_conv(state, lap6[None,:])
    return torch.cat([state, state_lap], 1)

else:
  # raise instead of assert: asserts are stripped when running with -O
  raise ValueError(f"unknown model_type: {model_type!r}")


# total number of state channels per cell
CHN = 16
# number of channels left once the trailing angle channel(s) are excluded
SCALAR_CHN = CHN-ANGLE_CHN

# if you want to try experiments with synchronous NCA, you can set the value 
# below to 1.0
DEFAULT_UPDATE_RATE = 0.5

def get_alive_mask(x):
  '''Boolean [b, 1, h, w] mask of cells that have a mature cell (channel 3
  above 0.1) somewhere within the `nhood_kernel` footprint.'''
  alpha = x[:,3:4]
  mature = (alpha>0.1).to(torch.float32)
  mature_neighbours = perchannel_conv(mature, nhood_kernel[None,:])
  return mature_neighbours>0.5

class CA(torch.nn.Module):
  '''Neural cellular automaton update rule.

  Each step perceives the state through the global `perception` function,
  passes the result through two 1x1 convolutions and adds the output to a
  random subset of cells.

  Args:
    chn: number of state channels, including the angle channel if any.
    hidden_n: kept for interface compatibility; the value is ignored and
      replaced by a width derived from the number of perceived channels.
  '''

  def __init__(self, chn=CHN, hidden_n=128):
    super().__init__()
    self.chn = chn
    # determine the number of perceived channels
    perc_n = perception(torch.zeros([1, chn, 8, 8])).shape[1]
    # approximately equalize the param number btw model variants
    hidden_n = 8*1024//(perc_n+chn)
    hidden_n = (hidden_n+31)//32*32  # round up to a multiple of 32
    print('perc_n:', perc_n, 'hidden_n:', hidden_n)

    self.w1 = torch.nn.Conv2d(perc_n, hidden_n, 1)
    self.w2 = torch.nn.Conv2d(hidden_n, chn, 1, bias=False)
    # zero output weights: the untrained rule leaves the state unchanged
    self.w2.weight.data.zero_()

  def forward(self, x, update_rate=DEFAULT_UPDATE_RATE):
    '''Performs one CA step on x: [b, chn, h, w].

    update_rate is the probability with which each cell applies its update
    in this step (1.0 gives a synchronous CA).
    '''
    alive = get_alive_mask(x)
    y = perception(x)
    y = self.w2(torch.relu(self.w1(y)))
    b, c, h, w = y.shape
    # floor(U[0,1)+rate) is 1 with probability `rate`; sample it on the
    # input's device so the module does not depend on the default device
    update_mask = (torch.rand(b, 1, h, w, device=x.device, dtype=x.dtype)
                   +update_rate).floor()
    x = x + y*update_mask
    # derive the split from this model's own channel count rather than the
    # global CHN, so CA(chn=k) handles its angle channel correctly
    scalar_n = self.chn-ANGLE_CHN
    if scalar_n == self.chn:
      x = x*alive
    else:
      # scalar channels die outside the alive mask; the angle wraps to [0, 2pi)
      x = torch.cat([x[:,:scalar_n]*alive, x[:,scalar_n:]%(np.pi*2.0)], 1)
    return x

  def seed(self, n, sz=128, angle=None, seed_size=1):
    '''Creates n initial states of size sz x sz.

    All scalar channels from index 3 on are set to 1 inside a
    seed_size x seed_size square starting at the grid centre. With an angle
    channel, angles are random everywhere; `angle`, if given, overrides the
    last channel inside the seed square.
    '''
    device = self.w1.weight.device
    x = torch.zeros(n, self.chn, sz, sz, device=device)
    scalar_n = self.chn-ANGLE_CHN
    if scalar_n != self.chn:
      x[:,-1] = torch.rand(n, sz, sz, device=device)*np.pi*2.0
    r, s = sz//2, seed_size
    x[:,3:scalar_n,r:r+s, r:r+s] = 1.0
    if angle is not None:
      x[:,-1,r:r+s, r:r+s] = angle
    return x

def to_rgb(x):
  '''Composites the first three channels over a white background, using
  channel 3 as alpha (the colour channels are taken as already multiplied
  by alpha).'''
  alpha = x[:,3:4]
  white_background = 1.0-alpha
  return white_background+x[:,:3]

# Build a throwaway CA just to report how many parameters it has.
param_n = sum([weights.numel() for weights in CA().parameters()])
print('CA param count:', param_n)

In [None]:

def make_concentric_discrete(h, w, n):
  '''Hard-edged concentric rings on an [h, w] grid spanning [-1, 1]^2.

  The sign of sin (even n) or cos (odd n) of radius*n*pi is scaled to +-0.5
  and returned as float32.
  '''
  # reference https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
  xs = np.linspace(-1.0, 1.0, w)[None, :]
  ys = np.linspace(-1.0, 1.0, h)[:, None]
  radius = np.sqrt(xs*xs+ys*ys)
  wave = np.cos if n % 2 else np.sin
  rings = np.sign(wave(radius*n*np.pi))
  return rings.astype(np.float32) * 0.5

def make_concentric(h, w, n):
  '''Concentric rings with blurred edges on an [h, w] grid over [-1, 1]^2.

  Same pattern as the hard-edged version, except that where the wave's
  magnitude is below `grad_start` the value ramps linearly through zero
  instead of jumping. Returns float32 values in [-0.5, 0.5].
  '''
  # reference https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
  grad_start = 0.2
  xs = np.linspace(-1.0, 1.0, w)[None, :]
  ys = np.linspace(-1.0, 1.0, h)[:, None]
  radius = np.sqrt(xs*xs+ys*ys)
  wave = np.cos if n % 2 else np.sin
  period = wave(radius*n*np.pi)
  # 1 inside the transition band around each zero crossing, 0 elsewhere
  in_ramp = (np.abs(period) < grad_start).astype(np.float32)
  hard_part = np.sign(period) * (1-in_ramp)
  soft_part = period * in_ramp / grad_start
  return (hard_part + soft_part).astype(np.float32) * 0.5

H = W = 48

# Preview the blurred ring masks for 1 to 4 oscillations (shifted from
# [-0.5, 0.5] to [0, 1] for display).
for _ring_n in (1, 2, 3, 4):
  mask = make_concentric(H, W, _ring_n)
  imshow(mask+0.5)
#imshow(aux_target[1].cpu()/2.0+0.5)

In [None]:

# Not used currently, since we only need to make one stripe.
def make_stripes(h, w, n):
  '''Soft-edged stripe patterns on an [h, w] grid spanning [-1, 1]^2.

  Returns (x_mask, y_mask): float32 [h, w] arrays in [-0.5, 0.5] that vary
  only along x and only along y respectively. sin is used for even n, cos
  for odd n.
  '''
  grad_start = 0.2
  wave = np.cos if n % 2 else np.sin

  def soft_sign_wave(coord):
    # sign of the wave, with a linear ramp where its magnitude is small
    period = wave(coord*n*np.pi)
    in_ramp = (np.abs(period) < grad_start).astype(np.float32)
    blended = np.sign(period) * (1-in_ramp) + period * in_ramp / grad_start
    return blended.astype(np.float32) * 0.5

  xs = np.linspace(-1.0, 1.0, w)[None, :]
  ys = np.linspace(-1.0, 1.0, h)[:, None]
  # expand both coordinate vectors to full [h, w] grids
  x_grid = xs*np.ones_like(ys)
  y_grid = ys*np.ones_like(xs)
  return soft_sign_wave(x_grid), soft_sign_wave(y_grid)

H = W = 48

# Preview the horizontal and vertical stripe masks for 1 to 4 oscillations
# (shifted from [-0.5, 0.5] to [0, 1] for display).
for _stripe_n in (1, 2, 3, 4):
  x_mask, y_mask = make_stripes(H, W, _stripe_n)
  imshow(x_mask+0.5)
  imshow(y_mask+0.5)

# Aux L Type legend

What kind of auxiliary channels are going to be present in the target image:

*   noaux: No auxiliary channels. The 'spiderweb' target was trained with it.
*   binary: One stripe mask as auxiliary channel. The 'lizard' target was trained with it.
*   minimal: binary + a concentric auxiliary channel. The 'heart' target was trained with it.
*   extended: more stripe and concentric modes for auxiliary channels. Not used in any publication.

Select the target image and what auxiliary channels are of interest below.


In [None]:

TARGET_P = "lizard" #@param ['circle','lizard', 'heart', 'smiley', 'lollipop', 'unicorn', 'spiderweb']
AUX_L_TYPE = "binary" #@param ['noaux', 'binary', 'minimal', 'extended']

# Build `target`: an [h, w, 4] float32 image whose colour channels are
# multiplied by its alpha channel.
if  TARGET_P == 'circle':
  def make_circle(h, w):
    '''Binary float32 [h, w] mask of a centred disc of radius 0.9 on a grid
    spanning [-1, 1]^2.'''
    x = np.linspace(-1.0, 1.0, w)[None, :]
    y = np.linspace(-1.0, 1.0, h)[:, None]
    center = np.zeros([2,1,1])
    r = np.array([0.9])[:,None]
    x, y = (x-center[0])/r, (y-center[1])/r
    mask = (x*x+y*y < 1.0).astype(np.float32)
    return mask

  H = W = 48
  mask = make_circle(H, W)
  IS_COLORED = False
  if IS_COLORED:
    r = np.linspace(0,1,H)[:,None]*mask
    g = np.linspace(0,1,W)[None,:]*mask
  else:
    r = g = np.zeros(mask.shape)
  target_colors = np.stack([r, g, np.zeros(mask.shape)], -1)

  target = np.concatenate([target_colors, mask[...,None]],-1).astype(np.float32)
  imshow(target)
  target[:,:,:3] *= target[:,:,3:]
else:
  # [0] keeps only the first code point of the emoji string (some entries
  # carry a trailing variation selector), which is what the file name uses
  emoji = {'lizard': '🦎',
           'heart': '❤️',
           'smiley': '😁',
           'lollipop': '🍭',
           'unicorn': '🦄', # overfits to the grid
           'spiderweb': '🕸️'
           }[TARGET_P][0]
  
  code = hex(ord(emoji))[2:].lower()
  url = 'https://github.com/googlefonts/noto-emoji/blob/main/png/128/emoji_u%s.png?raw=true'%code
  target = imread(url, 48)
  imshow(target)
  target[:,:,:3] *= target[:,:,3:]

# To channels-first layout, with p empty cells of padding on every side.
p = 12
target = F.pad(torch.tensor(target).permute(2, 0, 1), [p, p, p, p, 0, 0])
W = target.shape[1]

Wp = Hp = target.shape[1]
if AUX_L_TYPE != 'noaux':
  aux_target_l = []

  print(target[3:].shape)
  # +-0.5 split between the top and bottom halves, restricted to the target
  # alpha (the alpha is applied once more when the list is stacked below)
  y_mask = torch.linspace(-1,1,W)[:,None].sign()*target[3]*0.5
  aux_target_l += [y_mask]


  if AUX_L_TYPE == "extended":
    x_mask = torch.linspace(-1,1,W)[None,:].sign()*target[3]*0.5
    # x_mask is already a tensor; wrapping it in torch.tensor() again only
    # made a copy and triggered a copy-construct warning
    aux_target_l += [x_mask]
    aux_target_l += [torch.tensor(make_concentric(Hp, Wp, 2)),
                    torch.tensor(make_concentric(Hp, Wp, 3)),
                    torch.tensor(make_concentric(Hp, Wp, 4))]
  elif AUX_L_TYPE == "minimal":
    aux_target_l += [torch.tensor(make_concentric(Hp, Wp, 4))]
  aux_target = torch.stack(aux_target_l)*target[3:4]

  imshow(aux_target[0].cpu()+0.5)

  for at in aux_target:
    imshow((1. - target[3] + target[3]*(at+0.5)).cpu())

  print(target.shape, aux_target.shape)
  # auxiliary channels are appended after the 4 image channels
  target = torch.cat([target, aux_target])
else:
  # Keep `aux_target` defined, with zero channels, so that later cells that
  # loop over aux_target.shape[-3] also run without auxiliary channels.
  aux_target = target[:0]

model_suffix = model_type + "_" + TARGET_P + "_" + AUX_L_TYPE
print(model_suffix)

In [None]:
# Use the public entry point: the `functional_tensor` module is private and
# is no longer importable in recent torchvision releases.
from torchvision.transforms.functional import gaussian_blur
def sharpen_filter(img):
  '''Sharpens an image batch by unsharp masking.

  Twice the difference between the image and its 5x5 gaussian blur
  (sigma 1) is added back to the image.
  '''
  blured = gaussian_blur(img, [5, 5], [1.0, 1.0])
  return img + (img-blured)*2.0


# separating the logic, since xy_grid is not inherent of the Loss.
# Only the 'lap6' model lives on a hexagonal grid.
hex_grid = model_type == 'lap6'
if hex_grid:
  # hex2xy has the two hexagonal lattice basis vectors as rows; xy2hex is
  # its inverse and is also read by InvariantLoss when hex_grid is set.
  s = np.sqrt(3)/2.0
  hex2xy = np.float32([[1.0, 0.0], 
                      [0.5, s]])
  xy2hex = torch.tensor(np.linalg.inv(hex2xy))

  # W x W grid of (x, y) coordinates in [-1, 1]
  x = torch.linspace(-1, 1, W)
  y, x = torch.meshgrid(x, x)
  xy_grid = torch.stack([x, y], -1)
  # This grid will be needed later on, in the step functions.
  # Coordinates are mapped through xy2hex and wrapped back into [-1, 1).
  xy_grid = (xy_grid@xy2hex+1.0)%2.0-1.0

class InvariantLoss:
  '''Image loss that is invariant to rotation (and optionally reflection).

  Target and batch are resampled on a polar grid (radius x angle). The
  squared difference to the target is then evaluated for every circular
  shift along the angle axis at once through an FFT cross-correlation, and
  the best shift is used as the loss.

  Args:
    target: [channels, h, w] target image.
    mirror: also evaluate the shifts of the angle-reversed target.
    sharpen: apply `sharpen_filter` to target and batch before sampling.
    hex_grid: map the sampling grid through the global `xy2hex` matrix,
      which must be defined when this is set.
  '''

  def __init__(self, target, mirror=False, sharpen=True, hex_grid=False):
    self.sharpen = sharpen
    self.mirror = mirror
    self.channel_n = target.shape[0]
    W = target.shape[-1]
    device = target.device
    self.r = r = torch.linspace(0.5/W, 1, W//2, device=device)[:,None]
    # torch.range is deprecated. W*pi is never an integer for W > 0, so
    # arange yields the same samples 0, 1, ..., floor(W*pi).
    self.angle = a = torch.arange(0, W*np.pi, device=device)/(W/2)
    # [1, radius, angle, 2] sampling grid in normalized [-1, 1] coordinates
    self.polar_xy = torch.stack([r*a.cos(), r*a.sin()], -1)[None,:]
    if hex_grid:
      self.polar_xy = (self.polar_xy@xy2hex+1.0)%2.0-1.0

    target = target[None,:]
    if self.sharpen:
      target = sharpen_filter(target)
    # align_corners=False is grid_sample's default; spelling it out keeps
    # the behaviour and silences the warning about the implicit default
    self.polar_target = F.grid_sample(target, self.polar_xy,
                                      align_corners=False)
    # conjugated spectrum: multiplying by it gives a cross-correlation
    self.fft_target = torch.fft.rfft(self.polar_target).conj()
    self.polar_target_sqnorm = self.polar_target.square().sum(-1, keepdim=True)

  def calc_losses(self, batch, extra_outputs=False):
    '''Returns per-sample losses for every angular shift.

    batch: [b, >=channel_n, h, w]; only the first channel_n channels are
      used.
    Returns [b, angle_n] losses ([b, 2*angle_n] with mirror). With
    extra_outputs, also returns the (sharpened) batch and its polar
    resampling.
    '''
    batch = batch[:, :self.channel_n]
    if self.sharpen:
      batch = sharpen_filter(batch)
    polar_batch = F.grid_sample(
        batch, self.polar_xy.repeat(len(batch), 1, 1, 1), align_corners=False)
    X = torch.fft.rfft(polar_batch)
    n = polar_batch.shape[-1]
    # dot product of batch and target for every circular angular shift
    xy = torch.fft.irfft(X*self.fft_target, n)
    if self.mirror:
      # undoing the conjugation correlates with the angle-reversed target
      xy = torch.cat([xy, torch.fft.irfft(X*self.fft_target.conj(), n)], -1)
    xx = polar_batch.square().sum(-1, keepdim=True)
    yy = self.polar_target_sqnorm
    # |x-y|^2 = |x|^2 + |y|^2 - 2<x,y>, summed over the angle axis
    sqdiff = (xx+yy-2.0*xy)
    losses = sqdiff.mean([1, 2])
    if extra_outputs:
      return losses, batch, polar_batch
    else:
      return losses

  def __call__(self, batch):
    '''Mean over the batch of each sample's loss at its best shift.'''
    return self.calc_losses(batch).min(-1)[0].mean()

  def plot_losses(self, x):
    '''Plots the loss of a single state x ([channels, h, w]) against the
    shift angle on a polar axis, overlaid on the rendered state; the minimum
    is marked in red.'''
    losses = self.calc_losses(x[None,:])[0].cpu()
    fig = pl.figure(figsize=(10, 10))
    ax0 = fig.add_subplot(111)
    vis = to_rgb(x[None,:4])[0].permute(1, 2, 0).cpu().clip(0, 1)
    ax0.imshow(vis, alpha=0.5)
    ax0.axis("off")
    ax = fig.add_subplot(111, polar=True, label='polar')
    ax.set_theta_zero_location('N')
    ax.set_theta_direction(-1)
    ax.set_facecolor("None")
    ang = self.angle.cpu()
    if not self.mirror:
      ax.plot(ang, losses, linewidth=3.0)
    else:
      # first half: plain shifts, second half: mirrored shifts
      ax.plot(ang, losses[:len(ang)], linewidth=3.0)
      ax.plot(ang, losses[len(ang):], linewidth=3.0)
    min_i = losses.argmin()
    pl.plot(ang[min_i%len(ang)], losses[min_i], 'or', markersize=12)


# The isotropic variants are matched against mirrored versions of the target
# as well. The selectable model is named 'lap_gradnorm'; the old 'gradnorm'
# entry could never match, so that variant silently trained without
# mirroring.
mirror = model_type in ['lap_gradnorm','laplacian','lap6']
target_loss_f = InvariantLoss(target, mirror=mirror, hex_grid=hex_grid)
# show the polar resampling of the target the loss compares against
vis = to_rgb(target_loss_f.polar_target)[0].permute(1, 2, 0).cpu()
imshow(zoom(vis))

# sanity check: loss of the target against itself for every rotation
target_loss_f.plot_losses(target)



In [None]:
#@title setup training
# Fresh model, empty loss history and a pool of 256 seed states that the
# training loop samples its batches from and writes the results back to.
ca = CA() 
loss_log = []
with torch.no_grad():
  pool = ca.seed(256, W)
opt = torch.optim.Adam(ca.parameters(), 1e-3)
#lr_sched = torch.optim.lr_scheduler.MultiStepLR(opt, [1000, 3000, 20000], 0.3)
# for the experiment with auxiliary loss
#lr_sched = torch.optim.lr_scheduler.MultiStepLR(opt, [3000, 10000], 0.3)
# learning rate cycles between 1e-5 and 1e-3 ('triangular2' mode), ramping
# up over 2000 steps
lr_sched = torch.optim.lr_scheduler.CyclicLR(
    opt, 1e-5, 1e-3, step_size_up=2000, mode='triangular2', cycle_momentum=False)


In [None]:
#@title training loop {vertical-output: true}
def make_circle_masks(n, h, w):
  '''Returns n random disc masks as a float32 [n, h, w] array.

  On a grid spanning [-1, 1]^2, each disc has its centre drawn uniformly
  from [-0.5, 0.5]^2 and its radius from [0.1, 0.4]; cells inside are 1.
  '''
  xs = np.linspace(-1.0, 1.0, w)[None, None, :]
  ys = np.linspace(-1.0, 1.0, h)[None, :, None]
  # draw order (centres first, then radii) fixes the random stream
  cx, cy = np.random.uniform(-0.5, 0.5, [2, n, 1, 1])
  radius = np.random.uniform(0.1, 0.4, [n, 1, 1])
  dx, dy = (xs-cx)/radius, (ys-cy)/radius
  inside = dx*dx+dy*dy < 1.0
  return inside.astype(np.float32)


for i in range(50000):
  with torch.no_grad():
    # sample a batch of 8 states from the pool
    batch_idx = np.random.choice(len(pool), 8, replace=False)
    x = pool[batch_idx]

    # replace the first sample with a fresh seed: on every step early on,
    # on every 6th step later
    if len(loss_log) < 4000:
      seed_rate = 1
    else:
      # exp because of decrease of step_n
      #seed_rate = 3
      seed_rate = 6
    if i%seed_rate==0:
      x[:1] = ca.seed(1, W)
      
    # periodically erase a random disc from the last sample
    #damage_rate = 3 # for spiderweb and heart
    damage_rate = 6  # for lizard?
    if i%damage_rate==0:
      # follow the batch's device instead of hard-coding "cuda"
      mask = torch.from_numpy(make_circle_masks(1, W, W)[:,None]).to(x.device)
      if hex_grid:
        mask = F.grid_sample(mask, xy_grid[None,:].repeat([len(mask), 1, 1, 1]), mode='bicubic')

      x[-1:] *= (1.0 - mask)
  
    # EXTRA:
    # if all the cells have died, reset the sample.
    if len(loss_log) % 10 == 0:
      all_cells_dead_mask = (torch.sum(x[1:, 3:4],(1,2,3)) < 1e-6).float()[:,None,None,None]
      if all_cells_dead_mask.sum() > 1e-6:
        print("got here.")
        x[1:] = all_cells_dead_mask * ca.seed(7, W) + (1. - all_cells_dead_mask) * x[1:]


  #step_n = np.random.randint(32, 128)
  # new!  
  # everything worked but the unicorn pattern was constantly imploding with this.
  #step_n = np.random.randint(96, 128)
  step_n = np.random.randint(64, 96)
  overflow_loss = 0.0
  diff_loss = 0.0
  target_loss = 0.0
  for k in range(step_n):
    px = x
    x = ca(x)
    # penalize the amount of change per step
    diff_loss += (x-px).abs().mean()
    # penalize scalar channels that leave [-2, 2]
    overflow_loss += (x-x.clamp(-2.0, 2.0))[:,:SCALAR_CHN].square().sum()

    # experimenting to address implosions:
    # the target loss is also evaluated right after the first step
    # (a separate auxiliary target loss used to be added here; it is disabled)
    if k == 0:
      target_loss += target_loss_f(x[:,:target.shape[0]])

  target_loss += target_loss_f(x[:,:target.shape[0]])
  # average of the first-step and final-step target losses
  target_loss /= 2.
  diff_loss = diff_loss*10.0
  loss = target_loss+overflow_loss+diff_loss

  with torch.no_grad():
    loss.backward()
    for p in ca.parameters():
      p.grad /= (p.grad.norm()+1e-8)   # normalize gradients 
    opt.step()
    opt.zero_grad()
    lr_sched.step()

    pool[batch_idx] = x                # update pool
    
    loss_log.append(loss.item())
    if i%32==0:
      clear_output(True)
      pl.plot(loss_log, '.', alpha=0.1)
      pl.yscale('log')
      pl.ylim(np.min(loss_log), loss_log[0])
      pl.show()
      imgs = to_rgb(x)
      if hex_grid:
        imgs = F.grid_sample(imgs, xy_grid[None,:].repeat([len(imgs), 1, 1, 1]), mode='bicubic')
      imgs = imgs.permute([0, 2, 3, 1]).cpu()

      imshow(zoom(tile2d(imgs, 4), 2))

      if AUX_L_TYPE != "noaux":
        # show each auxiliary channel composited over white with the alpha
        alphas = x[:,3].cpu()
        for extra_i in range(aux_target.shape[-3]):
          imgs = 1. - alphas + alphas*(x[:,4+extra_i].cpu() + 0.5)

          if hex_grid:
            imgs = F.grid_sample(
                imgs[:,None], xy_grid[None,:].repeat([len(imgs), 1, 1, 1]).cpu(), 
                mode='bicubic')[:,0]
          imshow(zoom(tile2d(imgs, 8), 1))



    if i%10 == 0:
      # get_last_lr() is the supported accessor; calling get_lr() outside
      # of step() is deprecated
      print('\rstep_n:', len(loss_log),
        ' loss:', loss.item(), 
        ' lr:', lr_sched.get_last_lr()[0], end='')
    if len(loss_log) % 500 == 0:
      # periodic checkpoint of the weights only (state_dict)
      model_name = model_suffix + "_{:07d}.pt".format(len(loss_log))
      print(model_name)
      torch.save(ca.state_dict(), model_name)


In [None]:
# File name for the current model: suffix plus zero-padded training step count.
model_name = model_suffix + "_{:07d}.pt".format(len(loss_log))
print(model_name)

In [None]:
# how to save a model
# NOTE: this pickles the whole CA module, whereas the periodic checkpoints
# written by the training loop contain only ca.state_dict().
torch.save(ca, model_name)
from google.colab import files
files.download(model_name) 

In [None]:
# how to save all checkpoints
print(model_suffix)
# zip every file whose name starts with "<model_suffix>_" and download the
# archive through Colab
shell_command = f"zip {model_suffix}.zip {model_suffix}_*"
!$shell_command
from google.colab import files
files.download(f"{model_suffix}.zip") 

In [None]:
# optionally, load a model from a checkpoint.
model_ckpt = "FILL.pt"
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
_loaded = torch.load(model_ckpt)
if isinstance(_loaded, torch.nn.Module):
  # a whole pickled module (saved with torch.save(ca, ...))
  ca = _loaded
else:
  # a state_dict, as written by the periodic checkpoints of the training
  # loop; model_type must match the one the checkpoint was trained with
  ca = CA()
  ca.load_state_dict(_loaded)

In [None]:
# Grow 8 seeds on a 96x96 grid and show the final states.
with torch.no_grad():
  x = ca.seed(8, 96)
  #x = ca.seed(9, 72)
  count = 0
  for k in tnrange(300, leave=False): # was 1000
    # steps per iteration double every 30 iterations, capped at 32
    step_n = min(2**(k//30), 32)
    for i in range(step_n):
      x = ca(x)
    count += step_n
  print(count)

  imgs = to_rgb(x)
  if hex_grid:
    imgs = F.grid_sample(imgs, xy_grid[None,:].repeat([len(imgs), 1, 1, 1]), mode='bicubic')
  imgs = imgs.permute([0, 2, 3, 1]).cpu()

  imshow(zoom(tile2d(imgs, 4), 2))

  # aux_target only exists when auxiliary channels were requested; without
  # this guard the cell fails with a NameError for 'noaux'
  if AUX_L_TYPE != "noaux":
    alphas = x[:,3].cpu()
    for extra_i in range(aux_target.shape[-3]):
      imgs = 1. - alphas + alphas*(x[:,4+extra_i].cpu() + 0.5)

      if hex_grid:
        imgs = F.grid_sample(
            imgs[:,None], xy_grid[None,:].repeat([len(imgs), 1, 1, 1]).cpu(), 
            mode='bicubic')[:,0]
      imshow(zoom(tile2d(imgs, 8), 1))



## Snapshot of a loss for an intermediate step

In [None]:
# Grow a single seed on a 72x72 grid for 500 steps, then plot its loss
# against the target for every rotation.
with torch.no_grad():
  x = ca.seed(1, 72)
  count = 0
  for k in tnrange(500, leave=False):
    x = ca(x)

  target_loss_f.plot_losses(x[0])


In [None]:
#@title Video: All in the same grid {vertical-output: true}
with VideoWriter() as vid, torch.no_grad():
  #x = ca.seed(1, 128)
  sz = 256
  # take the channel count from the model instead of hard-coding 16
  x = torch.zeros([1, ca.chn, sz, sz])
  # this is with steerables!
  if model_type in ['steerable', 'steerable_nolap']:
    x[:,-1] = torch.rand(sz, sz)*(2.0*np.pi)
  # plant 5 single-cell seeds at random positions, at least 20 cells away
  # from the border
  for _ in range(5):
    seed_i, seed_j = np.random.randint(sz-40, size=2)+20
    x[:,3:SCALAR_CHN,seed_i:seed_i+1,seed_j:seed_j+1] = 1.0
  count = 0
  for k in tnrange(400, leave=False):
    # steps per frame double every 30 frames, capped at 128
    step_n = min(2**(k//30), 128)
    for i in range(step_n):
      x = ca(x)
    count += step_n
    img = to_rgb(x)[0].permute(1, 2, 0).cpu()
    vid.add(zoom(img, 2))
print(count)

In [None]:
#@title Video: different grids {vertical-output: true}
# side length of each evaluation grid
eval_grid_size = 72
#eval_grid_size = 96
# cap on the number of CA steps per video frame
video_max_speed = 128
#video_max_speed = 32
# Long run
# n_frames = 500
# fast run
n_frames = 300
with VideoWriter() as vid, torch.no_grad():
  #x = ca.seed(16, 96)
  # 16 independent seeds, each on its own grid, tiled into one frame
  x = ca.seed(16, eval_grid_size)
  count = 0
  for k in tnrange(n_frames, leave=False): # was 1000
    # steps per frame double every 30 frames, up to video_max_speed
    step_n = min(2**(k//30), video_max_speed)
    for i in range(step_n):
      x = ca(x)
    count += step_n
    img = to_rgb(x).permute(0, 2, 3, 1).cpu()
    vid.add(zoom(tile2d(img), 2))
print(count)