In [None]:
from collections.abc import Callable
import hashlib
import inspect
import itertools
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
import torchvision
from torchvision import datasets
from torchvision.utils import make_grid
from torchvision.transforms import ToTensor, Compose, Resize

from philosofool.torch.nn_models import (
    ResidualBlock, ResidualNetwork, NeuralNetwork,
    Generator, Discriminator,
    compute_convolution_dims, conv_dims_1d
)
from philosofool.torch.nn_loop import (
    TrainingLoop, JSONLogger, CompositeLogger, StandardOutputLogger, GANLoop, TrainingLoop,
    EndOnBatchCallback, SnapshotCallback, VerboseTrainingCallback
)
from philosofool.torch.visualize import show_image

## NN Models


In [None]:
def read_text(path, encoding: str = 'utf-8') -> str:
    """Return the full contents of the text file at ``path``.

    Args:
        path: path of the file to read.
        encoding: text encoding used to decode the file. Defaults to UTF-8 so the
            result does not depend on the platform's locale encoding, which is what
            ``open`` uses when no encoding is given.
    """
    with open(path, 'r', encoding=encoding) as f:
        return f.read()

def get_fashionMINST_data(batch_size: int = 64, root: str = 'data') -> tuple[DataLoader, DataLoader, dict]:
    """Download (if needed) FashionMNIST and wrap both splits in shuffling DataLoaders.

    Args:
        batch_size: batch size used by both loaders.
        root: directory torchvision downloads the dataset to / reads it from.

    Returns:
        ``(train_dataloader, test_dataloader, classes)`` where ``classes`` maps each
        integer label to its human-readable name.
    """
    training_data = datasets.FashionMNIST(
        root=root, train=True, download=True, transform=ToTensor()
    )
    test_data = datasets.FashionMNIST(
        root=root, train=False, download=True, transform=ToTensor()
    )
    train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
    classes = {
        0: 'T-shirt/top',
        1: 'Trouser',
        2: 'Pullover',
        3: 'Dress',
        4: 'Coat',
        5: 'Sandal',
        6: 'Shirt',
        7: 'Sneaker',
        8: 'Bag',
        # The FashionMNIST class name is "Ankle boot"; it was previously truncated to "Ankle".
        9: 'Ankle boot'}
    return train_dataloader, test_dataloader, classes

def get_food100_data(batch_size: int = 64, num_workers: int = 8, root: str = 'data'):
    """Download (if needed) Food-101, resize images to 256x256, and wrap both splits in DataLoaders.

    Args:
        batch_size: batch size used by both loaders.
        num_workers: worker processes per loader.
        root: directory torchvision downloads the dataset to / reads it from.

    Returns:
        ``(train_dataloader, test_dataloader, classes)`` where ``classes`` maps the
        integer index to the class name on the matching line of
        ``<root>/food-101/meta/classes.txt``.
    """
    transform = Compose([Resize((256, 256)), ToTensor()])
    training_data = datasets.Food101(
        root=root, split='train', download=True, transform=transform
    )
    test_data = datasets.Food101(
        root=root, split='test', download=True, transform=transform
    )
    train_dataloader = DataLoader(
        training_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_dataloader = DataLoader(
        test_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    # Read the class list from the same root the dataset was downloaded to.
    classes_path = os.path.join(root, 'food-101', 'meta', 'classes.txt')
    # splitlines() plus the emptiness filter drops the blank entry that a trailing
    # newline would otherwise produce with split('\n') (an extra, nameless "class").
    class_names = [name for name in read_text(classes_path).splitlines() if name]
    classes = dict(enumerate(class_names))

    return train_dataloader, test_dataloader, classes


# Active dataset: FashionMNIST. Swap which line is commented out to use Food-101 instead.
(train_dataloader, test_dataloader, classes) = get_fashionMINST_data()
# train_dataloader, test_dataloader, classes = get_food100_data()

In [None]:
# Peek at one test batch to discover the input geometry the models below are sized from.
X, y = next(iter(test_dataloader))
print(f"Shape of X [N, C, H, W]: {X.shape}")
print(f"Shape of y: {y.shape} {y.dtype}")

# Spatial (H, W) size of one example, and its channel count.
input_dims = tuple(int(x) for x in X.shape[2:])
in_channels = int(X.shape[1])


In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


In [None]:
loss_fn = nn.CrossEntropyLoss()
# (model_name, NeuralNetwork kwargs) pairs. Non-baseline names follow
# h{n_hidden_layers}_w{width}; the third entry was mislabeled "w94" for width 96.
configurations = [
    ('baseline', {'n_hidden_layers': 1, 'width': 256}),
    ('h2_w128', {'n_hidden_layers': 2, 'width': 128}),
    ('h3_w96', {'n_hidden_layers': 3, 'width': 96}),
    ('h4_w64', {'n_hidden_layers': 4, 'width': 64}),
]

In [None]:
def model_experiment(configurations, epochs: int = 1, log_prefix: str = 'food100_'):
    """Train one fully connected ``NeuralNetwork`` per configuration and log its progress.

    Relies on the notebook globals ``train_dataloader``, ``test_dataloader``,
    ``classes``, ``device`` and ``loss_fn``.

    Args:
        configurations: iterable of ``(model_name, config)`` pairs; ``config`` is
            forwarded as keyword arguments to ``NeuralNetwork``.
        epochs: number of epochs passed to ``TrainingLoop.fit``.
        log_prefix: prefix of each JSON log file under ``data/logs/``. The default
            keeps the historical ``food100_`` names; pass e.g. ``'fashion_mnist_'``
            when running on FashionMNIST so the logs are not mislabeled.
    """
    sample_batch, _ = next(iter(test_dataloader))
    # Flattened size of one example: channels * height * width.
    n_features = int(np.prod(sample_batch.shape[1:]))
    for model_name, config in configurations:
        model = NeuralNetwork(n_features, len(classes), **config)
        model.to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        log_path = os.path.join('data', 'logs', log_prefix + model_name + '.json')
        logger = CompositeLogger(JSONLogger(log_path), StandardOutputLogger(500))
        training_loop = TrainingLoop(model, optimizer, loss_fn, logger)
        print(model_name)
        training_loop.fit(train_dataloader, test_dataloader, epochs=epochs)

# model_experiment(configurations)

In [None]:
def _make_model_path(model: nn.Module, model_name: str, directory: str | None) -> str:
    string_hash = hashlib.sha1(bytes(str(model).encode())).hexdigest()
    full_path = model_name + '_' + string_hash+ '.model'
    if directory:
        full_path = os.path.join(directory, full_path)
    return os.path.normpath(full_path)


def save_model(model: nn.Module, model_name: str, directory: str | None) -> str:
    """Pickle the whole ``model`` to a path derived from its string repr.

    The repr hash in the file name keeps differently structured models from
    overwriting each other. Missing parent directories are created.

    Returns:
        The path the model was written to.
    """
    full_path = _make_model_path(model, model_name, directory)
    parent_dir = os.path.dirname(full_path)
    # dirname is '' when no directory was given; os.makedirs('') would raise.
    # exist_ok also avoids a check-then-create race.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(model, full_path)
    return full_path

def load_model(model: nn.Module, model_name: str, directory: str | None, map_location=None) -> nn.Module:
    """Load and return the module previously written by ``save_model``.

    Args:
        model: a module with the same architecture; only its repr is used, to
            rebuild the file path.
        model_name: name passed to ``save_model``.
        directory: directory passed to ``save_model``.
        map_location: forwarded to ``torch.load`` (e.g. ``'cpu'`` to load a
            GPU-saved model on a CPU-only machine).

    Returns:
        The unpickled ``nn.Module``. The previous version discarded it.
    """
    full_path = _make_model_path(model, model_name, directory)
    # save_model pickles the entire module, so weights_only must be False
    # (torch>=2.6 defaults to True and would refuse it). This runs arbitrary
    # pickle code: only load files you created or trust.
    return torch.load(full_path, map_location=map_location, weights_only=False)


In [None]:
# Size the network from the active dataset (input_dims / in_channels from the sample
# batch, classes from the loader helper) rather than hard-coding Food-101's
# 256x256 inputs and 101 classes, which do not match FashionMNIST when it is loaded.
resid_model = ResidualNetwork(input_dims, in_channels, len(classes)).to(device)


In [None]:
# Baseline training setup for the residual network. Fitting and saving are left
# commented out so the cell only builds the objects.
json_logger = JSONLogger('food100_conv_baseline.json')
stdout_logger = StandardOutputLogger(100)
logger = CompositeLogger(json_logger, stdout_logger)
optimizer = torch.optim.Adam(resid_model.parameters(), lr=1e-3)

training_loop = TrainingLoop(resid_model, optimizer, loss_fn, logger)
# training_loop.fit(train_dataloader, test_dataloader, 5)
# save_model(resid_model, 'food100_conv_baseline', 'data/models')

## GAN

Let's make anime faces.

This was a project for learning to work with GAN models.
"GAN" stands for generative adversarial network, and it's a process for training models
that can simulate the distributions of known examples.
A GAN works by having a generator and discriminator learn simultaneously.
If you're not familiar, search it on the internet. There are a lot of good explanations.

### Issues

Let me mention a bit of what I learned while working with GANs.
First, issues I ran into:
1. Mode collapse: the generator produces the same image for all inputs.
Obviously, the goal is for the generator to produce diverse outputs
from diverse inputs.
Mode collapse is a situation in which the generator learns a single output
and won't leave that spot. 
2. Stalling (I'm not sure if there's a technical term for this): the generator
makes clear improvement and then stops making progress, or regresses.
3. Lack of a metric.
This is simple: there's no easy number to associate with progress,
so it's not possible to tell when the model has converged.

### Lessons

The discriminator and generator need to be well-matched adversaries, so to speak.
When one had significantly greater capacity than the other, it led to mode collapse.
Additionally, both need a lot of capacity for this problem, generating 64x64 anime faces.
I tried some things that didn't work very well. 
I thought maybe just getting the models' learning rates right, relative to one another,
might solve the problem.
It did, sort of, but without enough capacity, there was stalling.
By the end, I cranked up capacity a lot in order to produce somewhat convincing images.

Stalling was pretty persistent. It's difficult, without metrics, to say when you have stalled.

This all led to a lot of babysitting. That's the pejorative term for having to watch over models manually.
The lack of a metric was a major cause. Without a metric, it wasn't possible to implement early stopping,
programmatically select the hyperparameters that led to progress, or assess the effects of hyperparameters.

I ended up literally looking at the generated outputs and selecting based on apparent diversity,
then training for several loops. Usually this stalled and I had to start over. 
I plan to add capacity and come back to this to see if I can get a very good model
in the future. 

In [None]:
# Re-declared so the GAN section can run on its own: GPU when visible, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
class AnimeDataset:
    """Map-style dataset of unlabeled images read from a single directory.

    Every directory entry is treated as an image file.

    Args:
        image_dir: directory containing the image files.
        transform: optional callable applied to each decoded image tensor.
        max_items: optional cap on how many images are exposed; ``None`` uses all.
    """

    def __init__(self, image_dir, transform=None, max_items=None):
        # Sorted because os.listdir order is arbitrary; this keeps indices (and the
        # subset kept by max_items) stable across runs and machines.
        self.img_labels = sorted(os.listdir(image_dir))
        self.img_dir = image_dir
        self.transform = transform
        self._max_items = max_items

    def __len__(self):
        if self._max_items is None:
            return len(self.img_labels)
        # Never report more items than exist, or __getitem__ would raise IndexError.
        return min(self._max_items, len(self.img_labels))

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels[idx])
        # Force 3-channel RGB. The 3-channel Normalize and the image generator assume
        # RGB, and grayscale/RGBA files would otherwise decode with 1 or 4 channels.
        image = torchvision.io.decode_image(img_path, mode=torchvision.io.ImageReadMode.RGB)
        if self.transform:
            image = self.transform(image)
        return image

# Convert decoded images to float32, shift them with mean/std 0.5 per channel,
# then resize to the 64x64 faces the GAN is trained on.
anime_transform = Compose([
    torchvision.transforms.ConvertImageDtype(torch.float32),
    torchvision.transforms.Normalize(mean=(.5, .5, .5), std=(.5, .5, .5)),
    Resize((64, 64)),
])
anime_dataset = AnimeDataset('data/anime_faces/images', transform=anime_transform, max_items=None)

anime_loader = DataLoader(anime_dataset, batch_size=32, shuffle=True, num_workers=6)


In [None]:
from time import time

def optimize_loader(dataset, batch_sizes=(16, 32), worker_counts=range(3, 12), n_samples: int = 1280):
    """Find a fast DataLoader configuration to save time when training.

    For every ``(batch_size, num_workers)`` pair, time loading roughly ``n_samples``
    examples and pushing them through a small ``Discriminator`` on ``device``
    (a notebook global). The defaults reproduce the original sweep.

    Returns:
        The fastest ``(batch_size, num_workers)`` pair, or ``None`` if nothing ran.
    """
    from time import perf_counter  # monotonic, high-resolution clock for intervals

    best_time = np.inf
    best = None
    for batch_size in batch_sizes:
        for n_workers in worker_counts:
            model = Discriminator(30).to(device)
            data_loader = DataLoader(dataset, batch_size, shuffle=True, num_workers=n_workers)
            start = perf_counter()
            # Only throughput matters here, so skip building autograd graphs.
            with torch.no_grad():
                for i, batch in enumerate(data_loader):
                    batch = batch.to(device)
                    if i * batch_size >= n_samples:
                        break
                    model(batch)
            # CUDA kernels run asynchronously; wait for them so the timing is honest.
            if device.type == 'cuda':
                torch.cuda.synchronize()
            elapsed = perf_counter() - start
            if elapsed < best_time:
                best = (batch_size, n_workers)
                best_time = elapsed
            print(f"for {batch_size} with {n_workers} took {elapsed}")
    print(f"Best time was {best} with {best_time}")
    return best

# optimize_loader(dataset=anime_dataset)


In [None]:
# custom weights initialization https://docs.pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

def weights_init(m):
    """DCGAN-style initialization, meant to be passed to ``module.apply``.

    Modules whose class name contains ``Conv`` get weights drawn from N(0, 0.02).
    Modules whose class name contains ``BatchNorm`` get weights drawn from
    N(1, 0.02) and a zero bias. Matching is by class-name substring; every other
    module is left untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
def build_gan_loop(learning_rate: float, n_generator_features: int, n_discriminator_features: int, generator_input_size: int = 100, dropout: float = 0.) -> GANLoop:
    """Build a ``GANLoop`` from a flat set of hyperparameters.

    Both networks share ``dropout`` and are initialized with ``weights_init``.
    Each gets its own Adam optimizer with ``lr=learning_rate`` and betas
    (0.5, 0.999). The loss is ``BCEWithLogitsLoss``.

    Keep the plain-class annotations: ``TuneGANModel`` reads them to cast grid values.
    """
    adam_betas = (.5, .999)
    generator = Generator(generator_input_size, n_generator_features, dropout=dropout)
    discriminator = Discriminator(n_discriminator_features, dropout=dropout)
    for network in (generator, discriminator):
        network.apply(weights_init)

    return GANLoop(
        generator,
        discriminator,
        torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=adam_betas),
        torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=adam_betas),
        nn.BCEWithLogitsLoss(),
    )


class TuneGANModel:
# TODO: TuneGANModel is close to a generic HP tuner.
#       Not implemented is a validation set for learning on
#       labeled data. Add that, then move this to nn_loop.
    def __init__(self, build_loop: Callable[..., GANLoop], paramgrid: dict):
        self.build = build_loop
        self.paramgrid = paramgrid

    def select_parameters(self, n_combinations: int) -> list[dict]:
        """Return n randomly selected hyperparameter combinations from the parameter grid."""
        parameters_array = np.array(list((itertools.product(*self.paramgrid.values()))))
        parameters_indexes = np.random.choice(
            np.arange(0, parameters_array.shape[0], step=1, dtype=np.int64),
            min(n_combinations, parameters_array.shape[0]),
            replace=False)
        selected_parameters = parameters_array[parameters_indexes]
        return [self._to_parameter_dict(params) for params in selected_parameters]

    def _to_parameter_dict(self, parameters):
        """Map parameters to parameter names and conform type to self.build annotations."""
        build_parameters = inspect.signature(self.build).parameters
        type_dict = {name: param.annotation for name, param in build_parameters.items() if type(param.annotation) is type}
        out = {}
        for param_name, param_value in zip(self.paramgrid, parameters):
            param_type = type_dict.get(param_name, type(param_value))
            out[param_name] = param_type(param_value)
        return out

    def tune_model(self, data: DataLoader, n_models: int, epochs: int, max_steps: int, callbacks: list | None):
        selected_parameters = self.select_parameters(n_models)
        for parameters in selected_parameters:
            print(parameters)
            loop = self.build(**parameters)

            callbacks = callbacks if callbacks is not None else []

            loop.fit(data, epochs, callbacks)


In [None]:
class TestTuneGANMOdel:
    """Smoke test: TuneGANModel selects grid points and casts them per build_gan_loop's annotations."""

    def test_select_params(self):
        # Grid keys must match build_gan_loop's parameter names ('learning_rate', not
        # 'learning_rates'); otherwise the annotation-driven casting is never exercised.
        model = TuneGANModel(build_gan_loop, {'dropout': [.01, .2], 'learning_rate': [.01, .0001], 'generator_input_size': [1, 2]})
        selected = model.select_parameters(2)
        assert len(selected) == 2
        for params in selected:
            assert 'dropout' in params and 'learning_rate' in params
            gen_size_type = type(params['generator_input_size'])
            assert gen_size_type is int, f"Expected int, but it is {gen_size_type}"
            lr_type = type(params['learning_rate'])
            assert lr_type is float, f"Expected float, but it is {lr_type}"

TestTuneGANMOdel().test_select_params()


In [None]:
# 20 log-spaced learning rates from 1e-3 down to 10**-4.5 (about 3.2e-5).
learning_rates = np.logspace(-3, -4.5, num=20)
feature_sizes = [24, 36, 40, 44, 52, 60]

parameter_grid = dict(
    learning_rate=learning_rates.tolist(),
    n_generator_features=list(feature_sizes),
    n_discriminator_features=list(feature_sizes),
    generator_input_size=[150, 300],
)

gan_tuner = TuneGANModel(build_gan_loop, parameter_grid)

In [None]:
# Scratch check of a candidate generator stack: push a 1x1 latent through it.
# Spatial size goes 1 -> 2 -> 12 -> 30 -> 64, so the result is torch.Size([1, 3, 64, 64]).
input_size = 100
features_size = 16
probe_stack = nn.Sequential(
    nn.ConvTranspose2d(input_size, features_size * 8, kernel_size=4, stride=2, padding=1),
    nn.ConvTranspose2d(features_size * 8, features_size * 4, kernel_size=8, stride=4, padding=0),
    nn.ConvTranspose2d(features_size * 4, features_size * 2, kernel_size=8, stride=2, padding=0),
    nn.ConvTranspose2d(features_size * 2, features_size * 2, kernel_size=8, stride=2, padding=1),
    nn.Conv2d(features_size * 2, 3, kernel_size=1, stride=1, padding=0),
)
probe_stack(torch.randn(1, input_size, 1, 1)).shape

In [None]:
# Below, we train several models and select hyperparameters which were successful
# for further iterations.

end_batch_early = False
max_steps = (
    int(2**13 / anime_loader.batch_size)
    if end_batch_early
    else len(anime_loader.dataset) // anime_loader.batch_size
)
# Report and snapshot four times per run; optionally stop each run after max_steps batches.
callbacks = [
    VerboseTrainingCallback(max_steps // 4),
    SnapshotCallback(8, interval=max_steps // 4),
] + ([EndOnBatchCallback(max_steps)] if end_batch_early else [])

gan_tuner.tune_model(anime_loader, 10, 1, max_steps, callbacks)

In [None]:
# Use good parameters from above
# NOTE: if we had a metric, this could be programmatic...

params = dict(
    learning_rate=0.00013538761800225446,
    n_generator_features=52,
    n_discriminator_features=44,
    generator_input_size=150,
)
loop = build_gan_loop(**params)
n_batches = len(anime_dataset) // anime_loader.batch_size
# Report progress and take a snapshot every quarter epoch.
callbacks = [VerboseTrainingCallback(n_batches // 4), SnapshotCallback(8, interval=n_batches // 4)]


In [None]:
# Long training run with the hand-picked hyperparameters. Positional arguments are
# presumably (data, epochs, callbacks), matching the call in TuneGANModel.tune_model;
# confirm against GANLoop.fit.
loop.fit(anime_loader, 10, callbacks)