In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell runs and edits to the project's .py files
# (utils, nca, conditioned_trainer) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from utils.dataset import ConditioningDataset
from nca import ConditionedNCA
from conditioned_trainer import ConditionedNCATrainer

import torch

import numpy as np

import matplotlib.pyplot as plt
import torch
from einops import rearrange

from utils.utils import load_image
from utils.utils import load_target_style_image


In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Running training on", device)

In [None]:
# Conditioning images, loaded at 64x64 from the random_faces directory.
dataset = ConditioningDataset("../../../data/random_faces/", image_size=64)

# Build the automaton; its grid shape is taken from the dataset's target size.
NUM_HIDDEN_CHANNELS = 16
nca = ConditionedNCA(
    target_shape=dataset.target_size,
    num_hidden_channels=NUM_HIDDEN_CHANNELS,
    living_channel_dim=3,
    cell_fire_rate=0.5,
)

# Style reference image, loaded at the same 64-pixel size as the dataset.
target_style_image = load_target_style_image(
    "../../../data/style_images/picasso.jpg",
    size=64,
)

# The trainer bundles the model, the dataset and the device; the visualizer
# further down reads them back from it.
trainer = ConditionedNCATrainer(
    nca,
    dataset,
    target_style_image,
    nca_steps=[48, 96],
    lr=1e-3,
    pool_size=1024,
    num_damaged=0,
    log_base_path="tensorboard_logs",
    damage_radius=3,
    device=device,
)

# Restore previously saved weights (presumably a trained checkpoint --
# TODO confirm what `ConditionedNCA.load` restores).
nca.load("models/ConditionedNCA_2024-04-12 14:09:22.097166.pt")

In [None]:
# Move the model and the dataset onto the selected device.
nca = nca.to(device)
# NOTE(review): the return value is discarded, so this assumes
# `ConditioningDataset.to` moves its tensors in place -- TODO confirm.
dataset.to(device)

In [None]:
from threading import Event, Thread

import cv2
import numpy as np
import torch
from einops import rearrange
from ipycanvas import Canvas, hold_canvas  # noqa
from ipywidgets import Button, HBox, VBox

from utils.utils import create_2d_circular_mask, rgb


def to_numpy_rgb(x):
    """Convert ``x`` to a channel-last NumPy image.

    ``x`` is passed through the project's ``rgb`` helper, detached, copied to
    the CPU and converted to NumPy; singleton dimensions are then squeezed
    out and the remaining ``(c, x, y)`` axes are reordered to ``(x, y, c)``.
    """
    colour = rgb(x)
    as_array = colour.detach().cpu().numpy()
    channel_first = np.squeeze(as_array)
    return rearrange(channel_first, "c x y -> x y c")


class ConditionedNCAVisualizer:
    """Interactive ipycanvas viewer for a conditioned NCA.

    A background thread repeatedly draws the current state on a canvas and
    advances it by one ``grow`` step towards the currently selected goal.
    One button per entry of ``trainer.target_dataset`` selects the goal, a
    "Stop" button halts the animation, and clicking the canvas zeroes a
    circular patch of the state.
    """

    def __init__(
        self,
        trainer,
        image_size,
        canvas_scale=5,
        damage_radius: int = 5,
    ):
        """Build the canvas and the control buttons.

        Args:
            trainer: Object exposing ``nca``, ``target_dataset``, ``device``
                and ``to_rgb``; these are the only attributes used here.
            image_size: Side length, in cells, of the square state grid.
                Presumably must match the size the NCA was built with --
                TODO confirm.
            canvas_scale: Number of canvas pixels per grid cell.
            damage_radius: Radius, in cells, of the patch erased by a click.
        """
        self.trainer = trainer
        self.device = self.trainer.device
        self.damage_radius = damage_radius

        self.image_size = image_size
        self.canvas_scale = canvas_scale
        self.canvas_size = self.image_size * self.canvas_scale

        self.canvas = Canvas(width=self.canvas_size, height=self.canvas_size)
        self.canvas.on_mouse_down(self.handle_mouse_down)

        # Set to request that the animation loop exits.
        self.stopped = Event()
        # Thread currently running `loop`; None until the first start.
        self._worker = None

        # Start with the first dataset entry as goal and a fresh seed state.
        self.current_goal = self.trainer.target_dataset[0].to(self.device)
        self.current_state = self.trainer.nca.generate_seed(1).to(self.device)

        def button_fn(class_num):
            # Factory so that each button captures its own `class_num`.
            def start(btn):
                self.current_goal = self.trainer.target_dataset[class_num].to(
                    self.device
                )
                # While the animation is running only the goal changes; after
                # a stop the animation is restarted.
                if self.stopped.is_set():
                    self._start_loop()

            return start

        button_list = []
        for i in range(len(self.trainer.target_dataset)):
            button = Button(description=str(i))
            button.on_click(button_fn(i))
            button_list.append(button)

        self.vbox = VBox(button_list)

        self.stop_btn = Button(description="Stop")

        def stop(btn):
            # Event.set() is idempotent, so no need to check the flag first.
            self.stopped.set()

        self.stop_btn.on_click(stop)

    def _start_loop(self):
        """Start the animation thread, making sure at most one is running.

        If a loop is already animating this is a no-op. If a stop has been
        requested but the previous thread has not exited yet, it is joined
        first, so that clearing the stop flag cannot revive it next to the
        new thread.
        """
        worker = self._worker
        if worker is not None and worker.is_alive():
            if not self.stopped.is_set():
                return
            worker.join()
        self.stopped.clear()
        # Daemon thread: a running animation must not keep the interpreter
        # alive at shutdown.
        self._worker = Thread(target=self.loop, daemon=True)
        self._worker.start()

    def handle_mouse_down(self, xpos, ypos):
        """Zero a circular patch of the state around the clicked position.

        Args:
            xpos: Horizontal click position in canvas pixels.
            ypos: Vertical click position in canvas pixels.
        """
        # Convert canvas pixels back to grid cells.
        in_x = int(xpos / self.canvas_scale)
        in_y = int(ypos / self.canvas_scale)

        # Presumably a boolean (image_size, image_size) mask that is True
        # inside the circle -- TODO confirm against create_2d_circular_mask.
        mask = create_2d_circular_mask(
            self.image_size,
            self.image_size,
            (in_x, in_y),
            radius=self.damage_radius,
        )
        # Zero every channel of batch entry 0 at the masked positions.
        self.current_state[0][:, mask] *= 0.0

    def draw_image(self, rgb):
        """Draw one frame on the canvas.

        Args:
            rgb: Batch-first image with axes ``(b, c, w, h)``. It is scaled
                by 255 and passed to ``cv2.resize``, so it is presumably a
                NumPy array with values in [0, 1] -- TODO confirm what
                ``trainer.to_rgb`` returns. Note that inside this method the
                parameter shadows the imported ``rgb`` helper.
        """
        with hold_canvas(self.canvas):
            # Channel-last layout, with singleton dimensions squeezed out.
            rgb = np.squeeze(rearrange(rgb, "b c w h -> b w h c"))
            self.canvas.clear()  # Clear the old animation step
            # Nearest-neighbour upscaling keeps individual cells crisp.
            self.canvas.put_image_data(
                cv2.resize(
                    rgb * 255.0,
                    (self.canvas_size, self.canvas_size),
                    interpolation=cv2.INTER_NEAREST,
                ),
                0,
                0,
            )

    def loop(self):
        """Animate from a fresh seed until ``self.stopped`` is set.

        Each iteration waits up to 0.02 s for a stop request, draws the
        current state and then advances it by one ``grow`` step towards
        ``self.current_goal``. Runs without gradient tracking.
        """
        with torch.no_grad():
            self.current_state = self.trainer.nca.generate_seed(1).to(self.device)
            # Event.wait returns True as soon as a stop has been requested.
            while not self.stopped.wait(0.02):
                self.draw_image(self.trainer.to_rgb(self.current_state))
                self.current_state = self.trainer.nca.grow(
                    self.current_state, 1, self.current_goal
                )

    def visualize(self):
        """Start (or restart) the animation and display the widgets.

        Safe to call repeatedly: a running animation is left alone and a
        stopped one is restarted from a fresh seed.
        """
        self._start_loop()
        display(self.canvas, HBox([self.stop_btn, self.vbox]))  # noqa


In [None]:
# Build the viewer on top of the trainer; 64 is the image size used for the
# dataset above.
viz = ConditionedNCAVisualizer(trainer, image_size=64)

In [None]:
# Start the background animation thread and show the canvas together with
# the Stop button and the per-target buttons.
viz.visualize()