In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from datetime import datetime

import torch

from utils.dataset import ConditioningDataset
from nca import ConditionedNCA
from conditioned_trainer import ConditionedNCATrainer
from utils.utils import load_target_style_image, dotdict

import os
import shutil
import json

In [None]:
# Train on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Running training on", device)

##### Set training arguments

In [None]:
# Settings for this training run: input locations, logging location,
# model/optimisation hyper-parameters and loss weights.
args_dict = dict(
    # Input data and output locations.
    conditioning_dataset='../../../data/random_faces/',
    target_style_image='../../../data/style_images/picasso.jpg',
    exp_name='img_size_128_appearance_weight_5_content_weight_01_overflow_weight_10',
    log_dir='logs',
    # Model and optimisation settings.
    num_hidden_channels=16,
    img_size=128,
    batch_size=8,
    learning_rate=0.001,
    epochs=100000,
    cell_fire_rate=0.5,
    pool_size=1024,
    damage_radius=3,
    # Loss configuration.
    appearance_loss_type='OT',
    appearance_loss_weight=5.0,
    content_loss_weight=0.1,
    overflow_loss_weight=10.0,
)

# Wrap the settings with the project's dotdict helper; later cells read
# them as attributes (e.g. args.log_dir, args.exp_name).
args = dotdict(args_dict)

##### Create logging dir

In [None]:
# Set to True to delete and recreate an existing results directory for
# this experiment name.
overwrite = False

# Results for a run live under <log_dir>/<exp_name>.
outdir = os.path.join(args.log_dir, args.exp_name)

if os.path.exists(outdir):
    # Raise explicitly rather than assert: assertions are stripped when
    # Python runs with -O / PYTHONOPTIMIZE, which would let the rmtree
    # below delete existing results even though overwrite is False.
    if not overwrite:
        raise FileExistsError(f"Results for experiment '{args.exp_name}' "
                              "already exist. Please set the overwrite "
                              "variable to True if you wish to "
                              "overwrite these results.")
    print("Overwriting previous results")
    shutil.rmtree(outdir)

os.makedirs(outdir)
print(f"Created results directory: {outdir}")

# Record the overwrite flag alongside the other settings, then persist the
# full configuration next to the results.
args_dict['overwrite'] = overwrite
args_log_file = os.path.join(outdir, 'args.json')
with open(args_log_file, 'w', encoding='utf-8') as f:
    json.dump(args_dict, f, ensure_ascii=False, indent=4)
print(f"Saved arguments log file to {args_log_file}")

##### Load training data

In [None]:
# Load the conditioning images and the style target, both at the
# configured image size.
dataset = ConditioningDataset(
    args.conditioning_dataset,
    image_size=args.img_size,
)
target_style_image = load_target_style_image(
    args.target_style_image,
    size=args.img_size,
)

# Kept as the cell's final expression, so its return value is shown as the
# cell output. NOTE(review): the result is not reassigned, so this
# presumably moves the dataset in place - confirm in utils.dataset.
dataset.to(device)

In [None]:
# Number of hidden channels, taken from the run settings.
NUM_HIDDEN_CHANNELS = args.num_hidden_channels

# Build the conditioned NCA to match the dataset's target size.
nca = ConditionedNCA(
    target_shape=dataset.target_size,
    num_hidden_channels=NUM_HIDDEN_CHANNELS,
    living_channel_dim=3,
    cell_fire_rate=args.cell_fire_rate,
)

# Place the model on the selected device and show its structure.
nca = nca.to(device)

print(nca)

In [None]:
import matplotlib.pyplot as plt
from IPython.display import clear_output, Markdown
from collections import defaultdict 
import numpy as np

class Visualiser():
    """Collects per-step training metrics and renders progress in the notebook.

    Metric curves and an image grid are redrawn every ``show_every`` epochs;
    a short Markdown summary of the latest metric values is emitted every
    5 epochs.

    Relies on two names from the notebook namespace: the ``display`` function
    that IPython provides, and the global ``args`` (for ``args.epochs``).
    """

    def __init__(self, show_every, curr_epoch=0):
        """Create a visualiser.

        Args:
            show_every: Plots are redrawn whenever the current epoch is a
                positive multiple of this value.
            curr_epoch: Epoch counter to start from.
        """
        self.show_every = show_every
        # Start from the caller-supplied epoch instead of a hard-coded 0.
        self.curr_epoch = curr_epoch
        # Maps metric name -> list of every value recorded so far.
        self.metrics = defaultdict(list)

    def step(self, curr_epoch, batch, outputs, targets, metrics):
        """Record one step's metrics and refresh the notebook output if due.

        Args:
            curr_epoch: Index of the step being reported.
            batch: Input images, forwarded to ``_plot_images``.
            outputs: Model outputs, forwarded to ``_plot_images``.
            targets: Target images, forwarded to ``_plot_images``.
            metrics: Mapping from metric name to its value for this step.
        """
        self.curr_epoch = curr_epoch
        for metric in metrics:
            # Metrics whose name contains 'grad' are not tracked.
            if 'grad' in metric:
                continue
            self.metrics[metric].append(metrics[metric])

        if self.curr_epoch % self.show_every == 0 and self.curr_epoch > 0:
            # wait=True defers the clear until new output is ready, which
            # avoids flicker between redraws.
            clear_output(True)
            self._plot_metrics()
            self._plot_images(batch, outputs, targets)

        if self.curr_epoch % 5 == 0 and self.curr_epoch > 0:
            display_str = "\n\n".join([f"**{k} loss** = {self.metrics[k][-1]}" for k in self.metrics])
            display_str += f"\n**Step**: {self.curr_epoch} / {args.epochs} \n\n"
            display(Markdown(display_str), display_id='stats')

    def _plot_metrics(self):
        """Draw one scatter plot per tracked metric on a two-row grid."""
        plt.figure(figsize=(10, 6))
        for i, (metric, values) in enumerate(self.metrics.items()):
            # Two rows; enough columns to fit every metric.
            plt.subplot(2, (len(self.metrics) + 1)//2, i+1)
            plt.scatter(range(len(values)), values, label=metric)
            plt.xlabel('Epoch')
            plt.ylabel(metric)
            plt.legend()
        plt.show()

    def _plot_images(self, batch, outputs, targets):
        """Show inputs, outputs and targets as three rows of images.

        The number of columns is ``batch.shape[0]``. Each image is transposed
        with axes ``(1, 2, 0)`` before being passed to ``imshow``.
        NOTE(review): this presumably means channel-first (C, H, W) arrays
        that NumPy can convert - confirm against the trainer.
        """
        num_images = batch.shape[0]
        # squeeze=False keeps `axes` two-dimensional even for a single image,
        # so the axes[row, col] indexing below is always valid.
        fig, axes = plt.subplots(3, num_images, figsize=(15, 6), squeeze=False)

        for i in range(num_images):
            axes[0, i].imshow(np.transpose(batch[i], (1, 2, 0)))
            axes[0, i].set_title(f'Batch {i+1}')
            axes[0, i].axis('off')

            axes[1, i].imshow(np.transpose(outputs[i], (1, 2, 0)))
            axes[1, i].set_title(f'Output {i+1}')
            axes[1, i].axis('off')

            axes[2, i].imshow(np.transpose(targets[i], (1, 2, 0)))
            axes[2, i].set_title(f'Target {i+1}')
            axes[2, i].axis('off')

        plt.tight_layout()
        plt.show()


In [None]:
# Redraw plots every 5 epochs.
visualiser = Visualiser(show_every=5)

# Wire the model, data, style target and loss settings into the trainer.
trainer = ConditionedNCATrainer(
    nca,
    dataset,
    target_style_image,
    nca_steps=[48, 96],
    lr=args.learning_rate,
    pool_size=args.pool_size,
    num_damaged=0,
    log_base_path=outdir,
    damage_radius=args.damage_radius,
    appearance_loss_type=args.appearance_loss_type,
    appearance_loss_weight=args.appearance_loss_weight,
    content_loss_weight=args.content_loss_weight,
    overflow_loss_weight=args.overflow_loss_weight,
    device=device,
    visualiser=visualiser,
)

In [None]:
# Run training. A manual interrupt or a CUDA out-of-memory error is
# reported instead of propagating, so the save below still runs.
try:
    trainer.train(
        batch_size=args.batch_size,
        epochs=args.epochs,
    )
except (KeyboardInterrupt, torch.cuda.OutOfMemoryError) as stop_reason:
    print(stop_reason)
    print('Saving latest model checkpoint...')

# Runs both after normal completion and after a handled interruption.
nca.save(f"{outdir}/ConditionedNCA.pt")