In [4]:
import pytorch_lightning as pl
import torch
from torch import nn
import torch.nn.functional as F
import einops
from pathlib import Path
from pprint import pprint
from torch import optim

In [5]:
# Enable IPython's autoreload extension so that edits to the local project
# modules imported by this notebook are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules before executing each cell.
%autoreload 2

In [6]:
import shutil

# Remove the logs of previous runs so that this run starts from a clean
# "./lightning_logs" directory.
# On a fresh checkout (or after a previous cleanup) the directory does not
# exist; that is not an error for a cleanup step, so only the
# "directory missing" case is tolerated. Any other failure (e.g. a
# permission problem) still surfaces.
try:
    shutil.rmtree("./lightning_logs")
except FileNotFoundError:
    pass

In [7]:
from utils import VideoWriter, transform
from display import NCAGrid # Resolve
from lightning_module import NCALightningModule
from dataset import NCADataModule # Resolve
from model import Updater, Perceiver, NCA

In [8]:
# --- Seed cache -------------------------------------------------------------
# Directory that stores all seed caches; created up front if it is missing.
SEED_CACHE = Path("./seed")
SEED_CACHE.mkdir(parents=True, exist_ok=True)
# Must be at least BATCH_SIZE to avoid drop_last (might not be necessary, though).
SEED_CACHE_SIZE = 64

# --- Training ---------------------------------------------------------------
BATCH_SIZE = 32
TRAIN_STEP = 32

# --- Grid / cell behaviour --------------------------------------------------
GRID_SIZE = (40, 40)
CELL_FIRE_RATE = 0.5
CLIP_VALUE = [-10, 10]
ALIVE_THRESHOLD = 0.1
# When False, all cells are assumed to be alive.
USE_ALIVE_CHANNEL = True
# Controls the size of the target image.
THUMBNAIL_SIZE = 32

# --- Channel layout ---------------------------------------------------------
NUM_HIDDEN_CHANNELS = 5
NUM_STATIC_CHANNELS = 0
NUM_TARGET_CHANNELS = 3
# NOTE(review): the extra "+ 1" is presumably the alive channel — confirm
# against the NCA model before changing USE_ALIVE_CHANNEL.
TOTAL_CHANNELS = (
    NUM_HIDDEN_CHANNELS
    + NUM_STATIC_CHANNELS
    + NUM_TARGET_CHANNELS
    + 1
)

In [9]:
# Data module wired up from the configuration constants defined in the
# config cell; the target image path is the only value given inline.
lit_dm = NCADataModule(
    # Seed cache location and cache handling.
    # NOTE(review): clear_cache=True presumably wipes previously cached
    # seeds — confirm in NCADataModule.
    seed_cache_dir=SEED_CACHE,
    clear_cache=True,
    # Grid geometry and channel layout.
    grid_size=GRID_SIZE,
    num_hidden_channels=NUM_HIDDEN_CHANNELS,
    num_target_channels=NUM_TARGET_CHANNELS,
    num_static_channels=NUM_STATIC_CHANNELS,
    # Target image and batching.
    target_image_path="./pic/32/emoji_u0037_20e3.png",
    thumbnail_size=THUMBNAIL_SIZE,
    batch_size=BATCH_SIZE,
)

In [10]:
# Update network: takes 10 input channels (the same number the Perceiver in
# this notebook is configured to output) and produces TOTAL_CHANNELS outputs.
updater_net = Updater(
    in_channels=10,
    out_channels=TOTAL_CHANNELS,
)
# Bare last expression so the notebook renders the module structure.
updater_net

Updater(
  (out): Sequential(
    (0): Conv2d(10, 64, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    (3): ReLU()
    (4): Conv2d(64, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
)

In [11]:
# Perception network: reads all TOTAL_CHANNELS state channels and emits 10
# channels (the same number the Updater in this notebook takes as input).
perceiver_net = Perceiver(
    in_channels=TOTAL_CHANNELS,
    out_channels=10,
    groups=1,
)
# Bare last expression so the notebook renders the module structure.
perceiver_net

Perceiver(
  (model): Conv2d(9, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)

In [12]:
# Assemble the cellular-automaton model from the perception and update
# networks plus the configuration constants.
nca_2d = NCA(
    # Sub-networks built in the previous cells.
    perceiver=perceiver_net,
    updater=updater_net,
    # Channel layout.
    num_hidden_channels=NUM_HIDDEN_CHANNELS,
    num_target_channels=NUM_TARGET_CHANNELS,
    num_static_channels=NUM_STATIC_CHANNELS,
    use_alive_channel=USE_ALIVE_CHANNEL,
    # Cell behaviour settings taken from the config cell.
    cell_fire_rate=CELL_FIRE_RATE,
    clip_value=CLIP_VALUE,
    alive_threshold=ALIVE_THRESHOLD,
)

In [13]:
# Lightning wrapper around the NCA model; it is given the same seed cache
# directory that the data module and the caching callbacks use.
lit_model = NCALightningModule(
    model=nca_2d,
    train_step=TRAIN_STEP,
    seed_cache_dir=SEED_CACHE,
    seed_cache_size=SEED_CACHE_SIZE,
)

In [19]:
from lightning_module.callback import get_num_generator, CacheBestSeed, CacheCorruptedSeed, VisualizeBestSeed, VisualizeRun

In [20]:
# Corruption function handed to the CacheCorruptedSeed callback below; it is
# built for the full grid height/width with a radius of 5.
corrupt_func = transform.create_corrupt_2d_circular(
    h=GRID_SIZE[0],
    w=GRID_SIZE[1],
    radius=5,
)

In [22]:
# NOTE(review): pytorch_lightning is already imported in the first cell; this
# repeated import is redundant but harmless.
import pytorch_lightning as pl

# A single generator instance is shared by both seed-caching callbacks.
num_gen = get_num_generator(SEED_CACHE_SIZE)

# Callbacks are instantiated in the same order in which they are registered.
training_callbacks = [
    CacheBestSeed(cache_dir=SEED_CACHE, num_generator=num_gen),
    VisualizeBestSeed(),
    VisualizeRun(interval=3, simulate_step=TRAIN_STEP),
    CacheCorruptedSeed(
        cache_dir=SEED_CACHE,
        num_generator=num_gen,
        loss_threshold=0.15,
        corrupt_func=corrupt_func,
    ),
]

# Dataloaders are rebuilt every epoch (reload_dataloaders_every_n_epochs=1).
trainer = pl.Trainer(
    max_epochs=1000,
    reload_dataloaders_every_n_epochs=1,
    callbacks=training_callbacks,
)

# Launch training of the Lightning module on the NCA data module.
trainer.fit(lit_model, lit_dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type    | Params
----------------------------------
0 | model | NCA     | 6.2 K 
1 | loss  | MSELoss | 0     
----------------------------------
6.2 K     Trainable params
0         Non-trainable params
6.2 K     Total params
0.025     Total estimated model params size (MB)


Epoch 89:   0%|                                                             | 0/2 [00:00<?, ?it/s, loss=0.189, v_num=1]