In [None]:
#@title Installation of gpu-cpm

# `-O` overwrites a previous download instead of saving gpu-cpm.zip.1, and
# `unzip -o` overwrites without asking, so re-running this cell neither
# installs a stale archive nor blocks on an interactive "replace?" prompt.
!wget -q -O gpu-cpm.zip https://computational-immunology.org/gpu-cpm.zip && unzip -q -o gpu-cpm.zip && pip3 install ./gpu-cpm

In [None]:
#@title Imports
import gpucpm
import random
from skimage.segmentation import mark_boundaries
from PIL import Image
import numpy as np
from timeit import default_timer as timer
import matplotlib.pyplot as plt
import ipywidgets as widgets
from google.colab import output
from IPython.display import Javascript
import matplotlib.cm as cm
import pandas as pd


In [None]:
#@title Sorting simulation code
def make_pic(types, ids, dimension):
    pic = np.ones((dimension,dimension,3)) * 255
    pic[types==1] = [200, 200, 200]
    pic[types==2] = [80,80,80]

    pic2 = mark_boundaries(pic, ids, color=(10,10,10), mode="inner")
    pic = (pic + pic2)/2
    pic = mark_boundaries(pic, types, color=(0,0,0), mode="inner")
    pic = mark_boundaries(pic, types!=0, color=(0,0,0), mode="outer")
    return pic.astype("uint8")

def create_sorting_simulation(dimension):
    """Build a two-type cell-sorting CPM seeded inside a central disc.

    Cell types 1 and 2 get an area constraint (``lambda_area=10``,
    ``target_area=40``) and the adhesion energies below; type 0 appears to be
    the surrounding medium. Cells of both types are then seeded with
    ``add_cell`` at distinct random pixels inside a disc of radius
    ``dimension / 3`` centred on the lattice.

    The number of cells per type scales with lattice area: 500 per type on a
    256x256 lattice (2000 on 512x512, 8000 on 1024x1024, ...). The previous
    ``(dimension // 256) ** 2 * 500`` gave the same counts for multiples of
    256 but truncated to zero cells for any lattice smaller than 256 (e.g. 128).

    Args:
        dimension: side length of the square lattice.

    Returns:
        The configured ``gpucpm.Cpm`` instance (this notebook calls
        ``push_to_gpu()`` on it afterwards).
    """
    temperature = 10
    # Cell slots reserved in the simulator; always >= the number seeded below.
    cell_capacity_per_type = dimension**2 // 80
    nr_of_cells = cell_capacity_per_type * 2 + 1
    simulation = gpucpm.Cpm(dimension, 2, nr_of_cells, temperature, False)

    simulation.set_constraints(cell_type = 1, lambda_area = 10, target_area = 40)
    simulation.set_constraints(cell_type = 1, other_cell_type = 1, adhesion = 14)
    simulation.set_constraints(cell_type = 0, other_cell_type = 1, adhesion = 16)

    simulation.set_constraints(cell_type = 2, lambda_area = 10, target_area = 40)
    simulation.set_constraints(cell_type = 2, other_cell_type = 2, adhesion = 2)
    simulation.set_constraints(cell_type = 2, other_cell_type = 1, adhesion = 11)
    simulation.set_constraints(cell_type = 2, other_cell_type = 0, adhesion = 16)

    radius = dimension / 3
    center = dimension / 2

    # 500 cells per type per 256x256 of area. Integer arithmetic reproduces the
    # old counts exactly for multiples of 256 while staying non-zero below 256.
    cells_per_type = dimension**2 * 500 // 256**2

    occupied = np.zeros((dimension, dimension), dtype=bool)
    for cell_type in (1, 2):
        for _ in range(cells_per_type):
            # Rejection-sample a free pixel inside the disc. The disc covers
            # ~0.35 * dimension**2 pixels while only ~0.015 * dimension**2
            # cells are requested, so few retries are needed.
            while True:
                x = random.randint(0, dimension - 1)
                y = random.randint(0, dimension - 1)
                inside_disc = (x - center)**2 + (y - center)**2 < radius**2
                if inside_disc and not occupied[x, y]:
                    simulation.add_cell(cell_type, x, y)
                    occupied[x, y] = True
                    break

    return simulation



In [None]:
#@title Create cell sorting simulation
dimension = 256  # lattice side length in pixels
sim = create_sorting_simulation(dimension)
sim.push_to_gpu()

# Kernel launch geometry. InteractiveSim.update_sim also reads these globals.
# NOTE: this cell previously held a per-dimension tuning table (16 threads per
# block, 4-32 positions per thread, block_sync/global_sync flags). It was
# unconditionally overwritten by the two values below, and its sync flags were
# never passed to sim.run(), so it was dead configuration and was removed.
# These are the values that were actually in effect.
threads_per_block = 8
positions_per_thread = 4

threads = dimension // positions_per_thread
blocks = threads // threads_per_block

print("total threads: {} threads/block: {} blocks: {}".format(threads**2, threads_per_block**2, blocks**2))

# Initial 100 MCS under the strong area constraint (lambda_area=10) that
# create_sorting_simulation configured.
sim.run(cell_sync=0, block_sync=0, global_sync=0,
    threads_per_block = threads_per_block,
    positions_per_thread = positions_per_thread,
    positions_per_checkerboard = 2,
    updates_per_checkerboard_switch = 1,
    updates_per_barrier = 1,
    iterations=100,
    inner_iterations=1)
sim.synchronize()

# Weaken the area constraint for the interactive phase that follows.
sim.set_constraints(cell_type = 1, lambda_area = 1, target_area = 40)
sim.set_constraints(cell_type = 2, lambda_area = 1, target_area = 40)


In [None]:
#@title Define interactive simulation widget

class InteractiveSim:
  """Colab widget that steps a gpu-cpm simulation and redraws it in a loop.

  ``register_js`` installs a browser-side ``Updater`` that, while running,
  repeatedly invokes the kernel callback ``'update'`` (waiting 50 ms after each
  call returns). Each call advances the simulation by ``ticks`` MCS, redraws
  the lattice, appends one statistics row and refreshes any extra panels.

  All simulation access goes through ``self.sim``, so an instance never
  touches a different simulation that happens to be bound to a global name.
  """

  def __init__(self, sim):
    """Create the widgets for ``sim`` and register the JS update callback.

    Args:
      sim: the gpucpm simulation object to drive and display.
    """
    self.sim = sim

    # Simulation / bookkeeping state.
    self.running = False
    self.current_mcs = 0
    self.visualise = None
    self.stat_trackers = []
    self.stats = []

    # Extra controls / panels registered via add_interface / add_visualisation.
    self.extra_interface_elements = []
    self.extra_visualisation_elements = []
    self.extra_visualisations = []

    self.out = widgets.Output(layout=widgets.Layout(width='400px', height='400px'))
    self.start_button = widgets.Button(description='start',disabled=False,tooltip='Start',icon='play')
    self.stop_button = widgets.Button(description='stop',disabled=False,tooltip='stop',icon='stop')

    self.slider = widgets.IntSlider(min=10, max=1000, step=10, value=100, description="MCSs/frame:")
    # MCS advanced per frame; kept in sync with the slider by on_slider_change.
    self.ticks = self.slider.value
    self.slider.observe(self.on_slider_change, names='value')

    self.start_button.on_click(self.start)
    self.stop_button.on_click(self.end)

    self.interface = widgets.VBox([widgets.HBox([self.start_button, self.stop_button]), self.slider, self.out])

    # Registered last, once every attribute that `update` reads exists.
    self.register_js()

  def _rebuild_interface(self):
    """Re-lay-out the widget tree after extra controls or panels were added."""
    self.interface.children = (
        [widgets.HBox([self.start_button, self.stop_button]), self.slider]
        + self.extra_interface_elements
        + [widgets.VBox([self.out] + self.extra_visualisation_elements)]
    )

  def add_visualisation(self, visualizer):
    """Add a panel that is redrawn every frame by calling ``visualizer(df)``.

    ``df`` is the statistics DataFrame from ``get_dataframe``; the callable's
    output is captured into its own 400x400 ``widgets.Output``.
    """
    self.extra_visualisations.append(visualizer)
    self.extra_visualisation_elements.append(widgets.Output(layout=widgets.Layout(width='400px', height='400px')))
    self._rebuild_interface()

  def on_slider_change(self, change):
    """Slider observer: use the new value as MCS per frame."""
    self.ticks = change['new']

  def set_visualisation(self, vis):
    """Replace the lattice renderer and redraw immediately.

    ``vis(types, ids, dimension)`` must return a PIL Image, because
    ``update_img`` calls ``.resize`` on the result.
    """
    self.visualise = vis
    self.update_img()

  def start(self, a):
    """Start-button handler: enable updates and start the JS loop."""
    self.running = True
    output.eval_js("x.start()")

  def end(self, a):
    """Stop-button handler: disable updates and stop the JS loop."""
    self.running = False
    output.eval_js("x.end()")

  def register_js(self):
    """Expose ``self.update`` as the ``'update'`` callback and define ``window.x``.

    The JS ``Updater.loop`` awaits ``invokeFunction('update')`` and schedules
    the next call 50 ms later until ``end()`` clears its running flag.
    """
    output.register_callback("update", self.update)

    display(Javascript("""
    class Updater {
      constructor() {
        this._running = false;
      }
      start() {
        if (this._running)
          return;
        this._running = true;
        this.loop();
      }
      async loop() {
        if (!this._running)
          return;
        await google.colab.kernel.invokeFunction('update', [], {});
        setTimeout(this.loop.bind(this), 50);
      }
      end() {
        this._running = false;
      }
    }

    window.x = new Updater();
    """))

  def visualise_default(self, types, ids, dimension):
    """Default renderer: a rainbow colour per cell id, white where ``id == 0``.

    Cell outlines are blended in at half strength. ``types`` and ``dimension``
    are accepted to match the custom-visualiser signature but are unused.
    """
    rgb = cm.rainbow(ids % 255) * 255
    rgb = rgb[:, :, :3]  # drop the alpha channel
    outlined = mark_boundaries(rgb, ids, color=(10, 10, 10), mode="inner")
    rgb = (rgb + outlined) / 2
    rgb[ids == 0] = 255
    return Image.fromarray(rgb.astype(np.uint8))

  def update(self):
    """Advance and redraw one frame; called from the JS loop via ``'update'``.

    Order matters:
    - The lattice is pulled and drawn (``update_img``) before trackers run,
      so they see freshly pulled state.
    - This frame's stats row is appended before the extra panels render, so
      they never get a frame-old DataFrame (or an empty one on frame one).
    """
    if not self.running:
      return
    self.update_sim()
    self.update_img()
    self.update_stats()
    self.update_vis()

  def update_vis(self):
    """Redraw each extra panel with the current statistics DataFrame."""
    df = self.get_dataframe()
    for out, vis in zip(self.extra_visualisation_elements, self.extra_visualisations):
      with out:
        out.clear_output(wait=True)
        vis(df)

  def update_sim(self):
    """Run ``self.ticks`` MCS on ``self.sim``, synchronize, and count them.

    Launch geometry comes from the notebook-level ``threads_per_block`` and
    ``positions_per_thread`` (the latter doubled here).
    """
    self.sim.run(cell_sync=0, block_sync=0, global_sync=0,
                 threads_per_block=threads_per_block,
                 positions_per_thread=positions_per_thread * 2,
                 positions_per_checkerboard=4,
                 updates_per_checkerboard_switch=1,
                 updates_per_barrier=1,
                 iterations=self.ticks,
                 inner_iterations=1, shared=0, partial_dispatch=1)
    self.sim.synchronize()
    self.current_mcs += self.ticks

  def update_img(self):
    """Pull the lattice from the GPU and show it, scaled to 300x300, in ``self.out``."""
    self.sim.pull_from_gpu()
    state = self.sim.get_state()
    # Each lattice value encodes type * 2**24 + cell id.
    types = state // 2**24
    ids = state % 2**24
    size = np.shape(state)[0]

    render = self.visualise if self.visualise else self.visualise_default
    frame = render(types, ids, size).resize((300, 300), resample=Image.NEAREST)
    with self.out:
      self.out.clear_output(wait=True)
      display(frame)

  def update_stats(self):
    """Append one row ``{'mcs': current_mcs, **tracker(self.sim) ...}`` to ``stats``.

    Each tracker must return a dict; on key collisions later trackers win.
    """
    row = {'mcs': self.current_mcs}
    for tracker in self.stat_trackers:
      row.update(tracker(self.sim))
    self.stats.append(row)

  def get_dataframe(self):
    """Return all recorded statistics rows as a DataFrame (one row per frame)."""
    return pd.DataFrame(self.stats)

  def add_interface(self, description, min, max, default, updater):
    """Add an integer slider that calls ``updater(self.sim, value)`` on change."""
    slider = widgets.IntSlider(min=min, max=max, step=1, value=default, description=description)

    def on_slider_change(change):
      updater(self.sim, change['new'])

    slider.observe(on_slider_change, names='value')
    self.extra_interface_elements.append(slider)
    self._rebuild_interface()

  def add_stat_tracker(self, tracker):
    """Register ``tracker(sim) -> dict`` to contribute columns to each stats row."""
    self.stat_trackers.append(tracker)


In [None]:
#@title Run interactive simulation
# Wrap the warmed-up simulation from the previous cell in the interactive widget.
s = InteractiveSim(sim)
# Draw the initial lattice before the update loop is started.
s.update_img()
# Last expression of the cell: renders the widget (start/stop, MCS slider, view).
s.interface