##### Copyright 2020 The TensorFlow Authors. [Licensed under the Apache License, Version 2.0](#scrollTo=ByZjmtFgB_Y5).

In [None]:
%install-location $cwd/swift-install
%install '.package(url: "https://github.com/tensorflow/swift-models", .branch("main"))' ModelSupport
// Print the ANSI "erase display" escape sequence (ESC [ 2 J); presumably this clears the
// package-installation log from the cell output — TODO confirm.
print("\u{001B}[2J")

In [None]:
// #@title Licensed under the Apache License, Version 2.0 (the "License"); { display-mode: "form" }
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

<table class="tfo-notebook-buttons" align="left">
 <td>
  <a target="_blank" href="https://colab.research.google.com/github/tensorflow/swift-models/blob/main/Examples/GrowingNeuralCellularAutomata/GrowingNeuralCellularAutomata.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
 </td>
 <td>
  <a target="_blank" href="https://github.com/tensorflow/swift-models/blob/main/Examples/GrowingNeuralCellularAutomata/GrowingNeuralCellularAutomata.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
 </td>
</table>

# Growing Neural Cellular Automata

This is an implementation in Swift for TensorFlow of the experiments described in ["Growing Neural Cellular Automata"](https://distill.pub/2020/growing-ca/) by Alexander Mordvintsev, Ettore Randazzo, Eyvind Niklasson, and Michael Levin. Currently, only Experiment 1 has been completed.

In this publication, cellular automata have a rule that is trained via gradient descent to cause a single cell to grow into a larger image, stabilize at a final shape, and repair damage to that image. The rule used for updates on each step is defined by a simple neural network, trained using gradient descent to produce a rule that can grow into a target image.

## Device setup and model parameters

We'll start by importing the appropriate modules:

In [None]:
import Foundation
import TensorFlow
import ModelSupport

Next, we'll configure the accelerator the tensor operations will run on. For best compatibility (TPU + GPU), you can use XLA through Swift for TensorFlow's X10 backend by switching to the commented-out `Device.defaultXLA` line. The cell below selects the eager-mode runtime, which may provide better performance on GPUs at present:

In [None]:
// The device that the tensors and the model in this notebook are placed on. The eager-mode
// runtime is selected here; to run through XLA instead, comment out this line and uncomment
// the `Device.defaultXLA` line below.
let device = Device.defaultTFEager
// let device = Device.defaultXLA
// device

To aid us in displaying images within the notebook, we'll use Swift's Python interoperability to set up an image display function.

In [None]:
import PythonKit

// Include the helper file that defines `IPythonDisplay` (expected to be supplied by the
// notebook kernel — TODO confirm), then enable inline matplotlib rendering.
%include "EnableIPythonDisplay.swift"
IPythonDisplay.shell.enable_matplotlib("inline")
// Python module used by `showImageFile(_:)` to render images in the cell output.
let display = Python.import("IPython.core.display")

/// Displays the image stored at `filename` inline in the notebook.
///
/// The file is opened in binary mode through Python, and its raw bytes are handed to
/// `display.Image` for rendering.
///
/// - Parameter filename: Path of the image file to show.
func showImageFile(_ filename: String) {
  let file = Python.open(filename, "rb")
  // Close the handle explicitly instead of relying on Python's garbage collector.
  defer { file.close() }
  display.Image(file.read()).display()
}

The following contains all model parameters used during training:

In [None]:
// The height and width (in pixels) to use when resizing the input image.
let imageSize = 40
// The padding (in pixels) added on every side of the input image after resizing.
let padding = 16
// The number of state channels for each cell. The first four are read as the RGBA color
// components; the alpha value lives at channel index 3.
let stateChannels = 16
// The batch size during training.
let batchSize = 4
// The fraction of cells to fire at each update.
let cellFireRate: Float = 0.5
// The number of training iterations.
let iterations = 2000
// The minimum number of cell-rule steps applied in a single training iteration.
let minimumSteps = 64
// The maximum number of cell-rule steps applied in a single training iteration (inclusive).
let maximumSteps = 96
// The number of steps to run through during inference.
let inferenceSteps = 200

## Configuring the cell update rule

The cell update rule is computed by a neural network that takes in the current state (batch size x height x width x state channels) and outputs a new state to use for the next time step. At its first stage, horizontal and vertical Sobel kernels are applied to the 3x3 neighborhood around a cell, and those results, as well as the cell's current state, are passed into the network. By default, a cell's state consists of red, green, blue, and alpha color components along with 12 hidden parameters.

The network itself has two 1x1 convolutional layers, with a ReLU activation after the first. Only a fraction of the cells are updated at a given time step, and any cell with an alpha value less than 10% is considered "dead" and ignored.

The following diagram from ["Growing Neural Cellular Automata"](https://distill.pub/2020/growing-ca/) explains the model structure and function:

![Cell rule model](https://distill.pub/2020/growing-ca/figures/model.svg)


The first component we'll implement will be the perception function. We'll cache the horizontal and vertical Sobel kernel tensors so that they can be reused on each iteration and save on host -> device transfers.

In [None]:
// Base 3x3 kernels with shape [3, 3, 1, 1]. The scalars are laid out row by row, and both
// Sobel kernels are divided by 8.
let horizontalSobelKernel =
  Tensor<Float>(
    shape: [3, 3, 1, 1],
    scalars: [
      -1.0, 0.0, 1.0,
      -2.0, 0.0, 2.0,
      -1.0, 0.0, 1.0,
    ],
    on: device) / 8.0
let verticalSobelKernel =
  Tensor<Float>(
    shape: [3, 3, 1, 1],
    scalars: [
      -1.0, -2.0, -1.0,
      0.0, 0.0, 0.0,
      1.0, 2.0, 1.0,
    ],
    on: device) / 8.0
let identityKernel =
  Tensor<Float>(
    shape: [3, 3, 1, 1],
    scalars: [
      0.0, 0.0, 0.0,
      0.0, 1.0, 0.0,
      0.0, 0.0, 0.0,
    ],
    on: device)

// Each kernel is broadcast across all state channels so it can be applied per channel.
let horizontalSobelFilter = horizontalSobelKernel.broadcasted(to: [3, 3, stateChannels, 1])
let verticalSobelFilter = verticalSobelKernel.broadcasted(to: [3, 3, stateChannels, 1])
let identityFilter = identityKernel.broadcasted(to: [3, 3, stateChannels, 1])

// The three filters are stacked along the last axis, giving three outputs per state channel.
let perceptionFilter = Tensor(
  concatenating: [horizontalSobelFilter, verticalSobelFilter, identityFilter],
  alongAxis: 3)

/// Computes what each cell perceives of its 3x3 neighborhood.
///
/// Applies `perceptionFilter` to `input` as a depthwise convolution with unit strides and
/// `.same` padding.
///
/// - Parameter input: The cell state to perceive.
/// - Returns: The result of the depthwise convolution.
@differentiable
func perceive(_ input: Tensor<Float>) -> Tensor<Float> {
  depthwiseConv2D(
    input,
    filter: perceptionFilter,
    strides: (1, 1, 1, 1),
    padding: .same)
}

As a convenience, we'll implement extensions to Tensor that separate out just the RGBA color channels from the larger cell state, as well as mask active cells:

In [None]:
extension Tensor where Scalar: Numeric {
  /// The first four entries along the last axis of a rank-3 or rank-4 tensor.
  ///
  /// All other axes are kept in full. Any other rank fails the precondition.
  @differentiable(where Scalar: TensorFlowFloatingPoint)
  var colorComponents: Tensor {
    precondition(rank == 3 || rank == 4)
    if rank == 4 {
      return slice(
        lowerBounds: [0, 0, 0, 0],
        sizes: [shape[0], shape[1], shape[2], 4])
    } else {
      return slice(
        lowerBounds: [0, 0, 0],
        sizes: [shape[0], shape[1], 4])
    }
  }

  /// Builds a mask that is one wherever `condition` holds and zero elsewhere.
  ///
  /// - Parameter condition: Closure that receives this tensor and returns a Boolean tensor
  ///   marking the selected elements.
  /// - Returns: A tensor created like `self`, holding ones at the selected positions and
  ///   zeros everywhere else.
  func mask(condition: (Tensor) -> Tensor<Bool>) -> Tensor {
    let selected = condition(self)
    let zeros = Tensor(zerosLike: self)
    let ones = Tensor(onesLike: self)
    return zeros.replacing(with: ones, where: selected)
  }
}

Next, we need the ability to mask off only the "living" cells (those with an alpha channel above 0.1) and their neighbors:

In [None]:
/// Marks the cells that count as "living".
///
/// A position is marked when the largest alpha value (channel index 3) in its 3x3
/// neighborhood is greater than 0.1. The mask itself is built inside `withoutDerivative`.
///
/// - Parameter input: A rank-4 cell state whose last axis holds the state channels.
/// - Returns: A single-channel tensor that is one at living positions and zero elsewhere.
@differentiable
func livingMask(_ input: Tensor<Float>) -> Tensor<Float> {
  // Extract just the alpha channel, keeping a trailing axis of size one.
  let alpha = input.slice(
    lowerBounds: [0, 0, 0, 3],
    sizes: [input.shape[0], input.shape[1], input.shape[2], 1])
  // Max pooling with a 3x3 window and unit strides gives each cell its neighborhood maximum.
  let neighborhoodMaximum = maxPool2D(
    alpha, filterSize: (1, 3, 3, 1), strides: (1, 1, 1, 1), padding: .same)
  return withoutDerivative(at: input) { _ in
    neighborhoodMaximum.mask { $0 .> 0.1 }
  }
}

The cell update rule itself is encapsulated in a custom Layer, and can be called like a function. The steps in the rule are applied within `callAsFunction()`, and they follow the diagram above: 

In [None]:
/// The trainable update rule applied to every cell at each time step.
///
/// The rule perceives each cell's neighborhood, passes the result through two 1x1
/// convolutions with a ReLU in between, applies the resulting change to a random subset of
/// cells, and finally zeroes out cells that are not living.
struct CellRule: Layer {
  /// Threshold compared against uniform random samples to decide which cells update.
  @noDerivative let fireRate: Float
  /// First 1x1 convolution, mapping the perception channels to 128 features.
  var conv1: Conv2D<Float>
  /// Second 1x1 convolution, mapping the 128 features back to the state channels.
  var conv2: Conv2D<Float>

  /// Creates a rule for cells with `stateChannels` channels.
  ///
  /// - Parameters:
  ///   - stateChannels: The number of channels in each cell's state.
  ///   - fireRate: The fraction of cells to fire at each update.
  init(stateChannels: Int, fireRate: Float) {
    self.fireRate = fireRate
    // Perception yields three values per state channel.
    self.conv1 = Conv2D<Float>(
      filterShape: (1, 1, stateChannels * 3, 128))
    // The output convolution has no bias and a zero-initialized filter.
    self.conv2 = Conv2D<Float>(
      filterShape: (1, 1, 128, stateChannels),
      useBias: false,
      filterInitializer: zeros())
  }

  /// Advances the cell state by one step.
  ///
  /// - Parameter input: The current rank-4 cell state.
  /// - Returns: The cell state after one update, with non-living cells zeroed.
  @differentiable
  func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
    // Compute the proposed change for every cell.
    let hidden = relu(conv1(perceive(input)))
    let delta = conv2(hidden)

    // Draw one uniform sample per cell; only cells whose sample is below `fireRate` update.
    let noise = Tensor<Float>(
      randomUniform: [input.shape[0], input.shape[1], input.shape[2], 1], on: input.device)
    let fireMask = withoutDerivative(at: input) { _ in
      noise.mask { $0 .< fireRate }
    }

    let candidate = input + (delta * fireMask)

    // A cell survives only if it is living both before and after the update.
    let alive = livingMask(input) * livingMask(candidate)
    return candidate * alive
  }
}

## Training the cell rule

The training loop starts from a single black cell in the center of the grid and applies the cell rule for between `minimumSteps` and `maximumSteps` steps. The resulting state is then compared along the red, green, blue, and alpha channels against a target image, and the loss is calculated via mean squared error. A gradient is determined from this and the Adam optimizer updates the cell rule.

The first step in this process is initializing our cell rule model and Adam optimizer, then moving both onto the appropriate accelerator device:

In [None]:
// Create the cell rule from the notebook's parameters and move it to the selected device.
var cellRule = CellRule(stateChannels: stateChannels, fireRate: cellFireRate)
cellRule.move(to: device)
// Build an Adam optimizer for the model, then replace it with a copy placed on the same device.
var optimizer = Adam(for: cellRule, learningRate: 2e-3)
optimizer = Adam(copying: optimizer, to: device)

We'll load our target image into a Tensor, pad it, and convert that to a batch:

In [None]:
// Download the target emoji and keep a local copy so it can be loaded and displayed below.
// Note that `try!` aborts execution if the download or the write fails.
let imageData = try! Data(contentsOf: URL(string: "https://github.com/googlefonts/noto-emoji/raw/master/png/128/emoji_u1f98e.png")!)
try! imageData.write(to: URL(fileURLWithPath: "lizard.png"))

// Load the image with premultiplied alpha, resize it to `imageSize` x `imageSize`, then copy
// its tensor to the device and divide by 255.
let hostInputImage = Image(contentsOf: URL(fileURLWithPath: "lizard.png")).premultipliedAlpha()
let resizedHostInputImage = hostInputImage.resized(to: (imageSize, imageSize))
let inputImage = Tensor(copying: resizedHostInputImage.tensor, to: device) / 255.0
// Pad the first two axes by `padding` on both sides; the last axis is left unpadded.
let paddedImage = inputImage.padded(forSizes: [
  (before: padding, after: padding), (before: padding, after: padding), (before: 0, after: 0),
])
// Broadcast the padded target to a batch of `batchSize` along a new leading axis.
let paddedImageBatch = paddedImage.broadcasted(to: [
  batchSize, paddedImage.shape[0], paddedImage.shape[1], paddedImage.shape[2],
])

// Save the padded target, scaled back by 255 and overlaid on white, as output/targetimage.png.
try paddedImage.scaled(by: 255.0).overlaidOnWhite()
  .saveImage(directory: "output", name: "targetimage", format: .png)

// Show both the downloaded original and the padded target.
showImageFile("lizard.png")
showImageFile("output/targetimage.png")

The initial cell state is set up once and then re-used:

In [None]:
// Start from zeros shaped like the padded target, then extend the last axis by
// `stateChannels - 4` entries (this assumes the target has 4 channels — TODO confirm).
var initialState = Tensor(zerosLike: paddedImage).padded(forSizes: [
  (before: 0, after: 0), (before: 0, after: 0), (before: 0, after: stateChannels - 4),
])
// Seed the center cell by setting channel index 3 (the alpha channel) to 1.
initialState[initialState.shape[0] / 2][initialState.shape[1] / 2][3] = Tensor<Float>(1.0, on: device)
// Broadcast the seed state to a batch of `batchSize` along a new leading axis.
let initialBatch = initialState.broadcasted(to: [
  batchSize, initialState.shape[0], initialState.shape[1], initialState.shape[2],
])
// NOTE(review): presumably forces the pending lazy-tensor work above to run now — confirm.
LazyTensorBarrier()

We normalize gradients to stabilize training:

In [None]:
/// Rescales each tensor in `gradient` independently by its own L2 norm.
///
/// Each tensor is divided by the square root of the sum of its squared elements, plus a
/// small constant (1e-8) that keeps the division defined when the norm is zero.
///
/// - Parameter gradient: The gradient with respect to a `CellRule`.
/// - Returns: A copy of `gradient` with every `Tensor<Float>` rescaled.
func normalizeGradient(_ gradient: CellRule.TangentVector) -> CellRule.TangentVector {
  var normalized = gradient
  let tensorPaths = gradient.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self)
  for path in tensorPaths {
    let component = gradient[keyPath: path]
    let magnitude = sqrt(component.squared().sum())
    normalized[keyPath: path] = component / (magnitude + 1e-8)
  }

  return normalized
}

Due to the way that the X10 backend traces out subgraphs, we need to prevent the iterated cell computation from being fully unrolled on the backward pass. To do this, we'll introduce a passthrough function that has a custom derivative which stops the trace at that point:

In [None]:
/// Passes `input` through unchanged.
///
/// This function exists only so that a custom derivative can be attached to it; the forward
/// computation is the identity.
///
/// - Parameter input: The tensor to pass through.
/// - Returns: `input`, unmodified.
@inlinable
@differentiable
func clipBackwardsTrace(_ input: Tensor<Float>) -> Tensor<Float> {
  input
}

/// Custom derivative of `clipBackwardsTrace`.
///
/// The value is the unchanged input. The pullback calls `LazyTensorBarrier()` and then
/// returns the incoming gradient untouched.
///
/// - Parameter input: The tensor passed to `clipBackwardsTrace`.
/// - Returns: The original value together with its pullback.
@inlinable
@derivative(of: clipBackwardsTrace)
func _vjpClipBackwardsTrace(
  _ input: Tensor<Float>
) -> (value: Tensor<Float>, pullback: (Tensor<Float>) -> Tensor<Float>) {
  let pullback: (Tensor<Float>) -> Tensor<Float> = { upstream in
    LazyTensorBarrier()
    return upstream
  }
  return (value: input, pullback: pullback)
}

In order to visualize the evolution of the cell states over time with a given rule, we'll set up a function that can perform inference for a series of time steps and record the result directly as an animated GIF:

In [None]:
/// Runs `rule` repeatedly from `seed` and saves the evolving state as an animated image.
///
/// After each step, the first batch element's color components are multiplied by 255 and
/// stored as one frame.
///
/// - Parameters:
///   - seed: The starting cell state, with a leading batch axis.
///   - rule: The cell rule applied at every step.
///   - steps: The number of steps, and therefore frames, to record.
///   - directory: The directory passed to `saveAnimatedImage`.
///   - filename: The name passed to `saveAnimatedImage`.
/// - Throws: Any error thrown while saving the animated image.
func recordGrowth(
  seed: Tensor<Float>, rule: CellRule, steps: Int, directory: String, filename: String
) throws {
  var frames = [Tensor<Float>]()
  var current = seed
  LazyTensorBarrier()
  for _ in 0..<steps {
    current = rule(current)
    // Only the first element of the batch is recorded.
    let firstSample = current[0]
    LazyTensorBarrier()
    let frame = firstSample.colorComponents * 255.0
    frames.append(frame)
  }
  try frames.saveAnimatedImage(directory: directory, name: filename, delay: 1)
}

Finally, we train the model in a loop for `iterations`, and capture and display a representative end state every 10 iterations.

Note: currently Swift Jupyter kernels do not support inline display of images during training, so the animated results will be displayed once the calculation has finished. The intermediate results are present as GIFs in the output/ directory (on Colab, accessible via the file browser on the left-hand side of the screen).

In [None]:
// Train the cell rule: on every iteration, grow the automaton from the seed batch, compare
// the result with the target image, and update the rule along the normalized gradient.
for iteration in 0..<iterations {
  let startTime = Date()
  // Grow for a random number of steps within the configured (inclusive) range.
  let steps = Int.random(in: minimumSteps...maximumSteps)
  let (loss, ruleGradient) = valueWithGradient(at: cellRule) { model -> Tensor<Float> in
    var state = initialBatch
    for _ in 0..<steps {
      // Identity passthrough whose custom pullback calls `LazyTensorBarrier()`.
      state = clipBackwardsTrace(state)
      state = model(state)
      LazyTensorBarrier()
    }

    // Mean squared error between the grown color components and the target batch.
    return meanSquaredError(predicted: state.colorComponents, expected: paddedImageBatch)
  }
  optimizer.update(&cellRule, along: normalizeGradient(ruleGradient))
  LazyTensorBarrier()

  let lossScalar = loss.scalarized()
  print(
    "Iteration: \(iteration), loss: \(lossScalar), log loss: \(log10(lossScalar)), time: \(Date().timeIntervalSince(startTime)) s")

  // Note: currently Swift Jupyter kernels cannot display images while a calculation is ongoing.
  /*
  if (iteration % 10) == 0 {
    let filename = String(format: "iteration%03d", iteration)
    try recordGrowth(
      seed: initialState.expandingShape(at: 0), rule: cellRule, steps: inferenceSteps,
      directory: "output", filename: filename)
    display.clear_output(wait: true)
    showImageFile("output/\(filename).gif")
  }
  */

  // Reduce the learning rate tenfold after every 2000 iterations. With `iterations` set to
  // 2000, this only triggers after the final update.
  if ((iteration + 1) % 2000) == 0 {
    optimizer.learningRate = optimizer.learningRate * 0.1
  }
}

Once the model has trained, we can perform inference and capture the evolution of the cell state over time:

In [None]:
// Grow from the single-cell seed (given a leading batch axis of size one) for
// `inferenceSteps` steps using the trained rule, then display the recorded animation.
try recordGrowth(
  seed: initialState.expandingShape(at: 0), rule: cellRule, steps: inferenceSteps,
  directory: "output", filename: "growth")
showImageFile("output/growth.gif")