In [2]:
# Automatically re-import edited project modules (utils, display, model, ...) before each cell runs.
%load_ext autoreload
%autoreload 2

In [1]:
import pytorch_lightning as pl
import torch
from torch import nn
import torch.nn.functional as F
import einops
from pathlib import Path
from pprint import pprint
from torch import optim

  from .autonotebook import tqdm as notebook_tqdm


In [61]:
from utils import VideoWriter, transform
from display import NCAGrid
from lightning_module import NCALightningModule
from dataset import NCADataModule, GoalNCADataModule
from model import Updater, Perceiver, NCA, GoalNCA

In [70]:
# --- Reproducibility ---
SEED = 42
# Seed torch's RNGs (all devices) before any model is constructed so weight init and the
# (presumably stochastic, see CELL_FIRE_RATE) cell updates repeat across runs.
# NOTE: pl.seed_everything(SEED) would additionally seed Python's `random` and NumPy,
# if the datamodule or callbacks depend on them.
torch.manual_seed(SEED)

# --- Seed cache & training loop ---
SEED_CACHE = Path("./seed_goal")  # on-disk seed pool shared by the datamodule and the cache callbacks
SEED_CACHE_SIZE = 64  # must be >= BATCH_SIZE, presumably so drop_last never leaves zero batches (may not be strictly necessary)
BATCH_SIZE = 16
TRAIN_STEP = 50  # passed as train_step to NCALightningModule and as simulate_step to VisualizeRun
assert SEED_CACHE_SIZE >= BATCH_SIZE, "SEED_CACHE_SIZE must be at least BATCH_SIZE"

SEED_CACHE.mkdir(exist_ok=True, parents=True)

# --- Grid & cell dynamics (forwarded to GoalNCA) ---
GRID_SIZE = (40, 40)  # (height, width): indexed as h=GRID_SIZE[0], w=GRID_SIZE[1] further down
CELL_FIRE_RATE = 0.5
CLIP_VALUE = [-10, 10]
ALIVE_THRESHOLD = 0.1
USE_ALIVE_CHANNEL = True
THUMBNAIL_SIZE = 32  # This controls the size of the target image

# --- Channel layout ---
NUM_HIDDEN_CHANNELS = 20
NUM_STATIC_CHANNELS = 1
NUM_TARGET_CHANNELS = 3
# The trailing "+1" is presumably the alive channel (USE_ALIVE_CHANNEL) — confirm against GoalNCA.
TOTAL_CHANNELS = NUM_HIDDEN_CHANNELS + NUM_STATIC_CHANNELS + NUM_TARGET_CHANNELS + 1
# Channels the Updater writes back (static channels excluded); also the Perceiver's input width.
OUTPUT_CHANNELS = NUM_HIDDEN_CHANNELS + NUM_TARGET_CHANNELS + 1

In [13]:
# The target emojis were sampled once from ./pic/32 and frozen below so runs are
# reproducible. To resample, use random.sample (without replacement): random.choices
# samples WITH replacement and can pick the same emoji twice.
#
# import random
# from pathlib import Path
#
# pop = Path("./pic/32")
# target_image_list = random.sample(list(pop.iterdir()), k=10)

target_image_list = [
    'pic/32/emoji_u1f9d4_1f3fe_200d_2642.png',
    'pic/32/emoji_u1f469_1f3fe.png',
    'pic/32/emoji_u1f468_200d_1f469_200d_1f467_200d_1f467.png',
    'pic/32/emoji_u1f9c8.png',
    'pic/32/emoji_u1fab5.png',
    'pic/32/emoji_u1f31c.png',
    'pic/32/emoji_u1f19a.png',
    'pic/32/emoji_u1f508.png',
    'pic/32/emoji_u1f469_1f3ff_200d_1f373.png',
    'pic/32/emoji_u1f469_200d_1f469_200d_1f466_200d_1f466.png',
]

# NUM_GOALS is derived from this list's length, so a duplicate entry would make two goals
# share the same target image.
assert len(set(target_image_list)) == len(target_image_list), "duplicate target images"

In [128]:
# One goal per target image; forwarded to GoalNCA as num_goals.
NUM_GOALS: int = len(target_image_list)

As in `growing_nca`, we define the grid, the target emojis, and the NCA model architecture. Unlike `growing_nca`, `goal_nca` can grow into any of several target emojis; which one it forms is selected by a goal signal supplied by the environment.

In [127]:
# Data module over the on-disk seed cache and the selected target emojis.
# clear_cache=True asks it to start from an empty cache on every run.
lit_dm = GoalNCADataModule(
    seed_cache_dir=SEED_CACHE,
    grid_size=GRID_SIZE,
    num_hidden_channels=NUM_HIDDEN_CHANNELS,
    num_target_channels=NUM_TARGET_CHANNELS,
    num_static_channels=NUM_STATIC_CHANNELS,
    target_image_path=target_image_list,
    batch_size=BATCH_SIZE,
    THUMBNAIL_SIZE=THUMBNAIL_SIZE,  # size of the target image
    clear_cache=True,
)

In [None]:
# Width of the feature map between the two sub-networks:
# Perceiver(out_channels=...) feeds Updater(in_channels=...).
# NOTE(review): 63 is unusual next to the Updater's 64-wide hidden layers shown in its
# saved output — confirm the value is intentional.
NET_HIDDEN_STATE: int = 63

In [129]:
# Perception stage: reads the OUTPUT_CHANNELS updatable channels and emits a
# NET_HIDDEN_STATE-wide feature map for the Updater.
# NOTE(review): the saved output below (Conv2d(9, 10, ...)) predates the current config
# (OUTPUT_CHANNELS = 24, NET_HIDDEN_STATE = 63); Restart & Run All to refresh it.
perceiver_net = Perceiver(
    in_channels=OUTPUT_CHANNELS,
    out_channels=NET_HIDDEN_STATE,
    groups=1,
)
perceiver_net

Perceiver(
  (model): Conv2d(9, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)

In [130]:
# Update stage: maps the NET_HIDDEN_STATE perception features back to OUTPUT_CHANNELS.
# NOTE(review): the saved output below (first layer 10 -> 64, last layer 64 -> 9) predates
# the current config (NET_HIDDEN_STATE = 63, OUTPUT_CHANNELS = 24); re-run to refresh it.
updater_net = Updater(
    in_channels=NET_HIDDEN_STATE,
    out_channels=OUTPUT_CHANNELS,
)
updater_net

Updater(
  (out): Sequential(
    (0): Conv2d(10, 64, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    (3): ReLU()
    (4): Conv2d(64, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
)

In [131]:
# Goal-conditioned NCA assembled from the perceiver/updater pair above, with one
# goal per target image (NUM_GOALS).
nca_2d = GoalNCA(
    num_hidden_channels=NUM_HIDDEN_CHANNELS,
    num_target_channels=NUM_TARGET_CHANNELS,
    num_static_channels=NUM_STATIC_CHANNELS,
    use_alive_channel=USE_ALIVE_CHANNEL,
    perceiver=perceiver_net,
    updater=updater_net,
    num_goals=NUM_GOALS,
    cell_fire_rate=CELL_FIRE_RATE,
    clip_value=CLIP_VALUE,
    alive_threshold=ALIVE_THRESHOLD,
)

In [132]:
# Lightning wrapper: trains nca_2d over TRAIN_STEP steps per rollout and shares the
# same seed cache (SEED_CACHE, SEED_CACHE_SIZE) as the data module and callbacks.
lit_model = NCALightningModule(
    model=nca_2d,
    train_step=TRAIN_STEP,
    seed_cache_dir=SEED_CACHE,
    seed_cache_size=SEED_CACHE_SIZE,
)

Training the model

In [150]:
from lightning_module.callback import get_num_generator, VisualizeBestSeed, VisualizeRun, GoalCacheBestSeed

In [155]:
# Corruption function for GRID_SIZE grids; judging by the helper's name it damages a
# circular region of radius 3 (confirm in utils.transform). Only referenced by the
# CacheCorruptedSeed callback, which is commented out in the training cell.
corrupt_func = transform.create_corrupt_2d_circular(
    h=GRID_SIZE[0],
    w=GRID_SIZE[1],
    radius=3,
)

In [None]:
# `pl` is already imported in the top import cell; no re-import needed here.

MAX_EPOCHS = 1000
VISUALIZE_INTERVAL = 3  # forwarded to VisualizeRun as `interval`

# Shared by the seed-cache callbacks below; sized to the seed cache.
num_gen = get_num_generator(SEED_CACHE_SIZE)

trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    # Rebuild dataloaders every epoch — presumably so the data module re-reads the seed
    # cache that GoalCacheBestSeed updates on disk. TODO confirm.
    reload_dataloaders_every_n_epochs=1,
    callbacks=[
        GoalCacheBestSeed(cache_dir=SEED_CACHE, num_generator=num_gen),
        VisualizeBestSeed(),
        VisualizeRun(interval=VISUALIZE_INTERVAL, simulate_step=TRAIN_STEP),
        # Disabled. To re-enable, CacheCorruptedSeed must also be imported — the
        # lightning_module.callback import cell above does not import it.
        # CacheCorruptedSeed(cache_dir=SEED_CACHE, num_generator=num_gen, loss_threshold=0.15, corrupt_func=corrupt_func)
    ],
)

trainer.fit(lit_model, lit_dm)