In [None]:
import hashlib
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
import torchvision
from torchvision import datasets
from torchvision.utils import make_grid
from torchvision.transforms import ToTensor, Compose, Resize

from philosofool.torch.nn_models import (
    ResidualBlock, ResidualNetwork, NeuralNetwork,
    Generator, Discriminator,
    compute_convolution_dims, conv_dims_1d
)
from philosofool.torch.nn_loop import TrainingLoop, JSONLogger, CompositeLogger, StandardOutputLogger, GANLoop, TrainingLoop

In [None]:
def show_image(img):
    """Show the image implied by an input tensor.

    The input is assumed to be normalized; it is mapped back with ``img / 2 + 0.5``
    (the inverse of a mean-0.5 / std-0.5 normalization) and is treated as
    channel-first, since the axes are transposed to channel-last for matplotlib.

    Args:
        img: a 3-D tensor on any device; it may still be attached to an autograd graph.
    """
    img = img / 2 + 0.5     # unnormalize
    # .cpu() is a no-op for CPU tensors and avoids the TypeError that
    # Tensor.numpy() raises for tensors living on a CUDA device.
    npimg = img.detach().cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

## NN Models


In [None]:
def read_text(path) -> str:
    """Return the full contents of the text file at ``path``."""
    with open(path, 'r') as handle:
        contents = handle.read()
    return contents

def get_fashionMINST_data(batch_size: int = 64) -> tuple[DataLoader, DataLoader, dict]:
    """Download (if needed) FashionMNIST and wrap both splits in shuffled loaders.

    Args:
        batch_size: number of samples per batch for both loaders.

    Returns:
        ``(train_dataloader, test_dataloader, classes)`` where ``classes`` maps
        the integer label to its human-readable name.
    """
    training_data = datasets.FashionMNIST(
        root='data', train=True, download=True, transform=ToTensor()
    )
    test_data = datasets.FashionMNIST(
        root='data', train=False, download=True, transform=ToTensor()
    )
    train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
    classes = {
        0: 'T-shirt/top',
        1: 'Trouser',
        2: 'Pullover',
        3: 'Dress',
        4: 'Coat',
        5: 'Sandal',
        6: 'Shirt',
        7: 'Sneaker',
        8: 'Bag',
        # The dataset's label for class 9 is "Ankle boot", not "Ankle".
        9: 'Ankle boot'}
    return train_dataloader, test_dataloader, classes

def get_food100_data(batch_size: int = 64, num_workers: int = 8) -> tuple[DataLoader, DataLoader, dict]:
    """Download (if needed) Food101, resized to 256x256, and wrap both splits in loaders.

    Args:
        batch_size: number of samples per batch for both loaders.
        num_workers: worker processes used by each loader.

    Returns:
        ``(train_dataloader, test_dataloader, classes)`` where ``classes`` maps
        the line index in ``data/food-101/meta/classes.txt`` to the class name.
    """
    transform = Compose([Resize((256, 256)), ToTensor()])
    training_data = datasets.Food101(
        root='data', split='train', download=True, transform=transform
    )
    test_data = datasets.Food101(
        root='data', split='test', download=True, transform=transform
    )
    train_dataloader = DataLoader(
        training_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_dataloader = DataLoader(
        test_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    classes_path = os.path.join(os.getcwd(), 'data/food-101/meta/classes.txt')
    # splitlines() plus the blank filter keeps a trailing newline in the file
    # from becoming an extra, empty class (which would inflate len(classes)).
    class_names = [line for line in read_text(classes_path).splitlines() if line.strip()]
    classes = dict(enumerate(class_names))

    return train_dataloader, test_dataloader, classes


# Load the dataset used by the cells below; swap in the commented line to use Food101 instead.
train_dataloader, test_dataloader, classes = get_fashionMINST_data()
# train_dataloader, test_dataloader, classes = get_food100_data()

In [None]:
# Peek at one batch to report tensor shapes; the loop stops after the first batch.
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

# Spatial size (trailing dimensions) and channel count taken from that sample batch.
input_dims = tuple(int(x) for x in X.shape[2:])
in_channels = int(X.shape[1])


In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


In [None]:
# Loss shared by the classification experiments below.
loss_fn = nn.CrossEntropyLoss()
# (name, NeuralNetwork kwargs) pairs; the name is used in the log file name and is
# built as h<n_hidden_layers>_w<width> so a log can be traced back to its settings.
configurations = [
    ('baseline', {'n_hidden_layers': 1, 'width': 256}),
    ('h2_w128', {'n_hidden_layers': 2, 'width': 128}),
    ('h3_w96', {'n_hidden_layers': 3, 'width': 96}),
    ('h4_w64', {'n_hidden_layers': 4, 'width': 64})
]

In [None]:
def model_experiment(configurations, epochs: int = 1, log_prefix: str = 'data/logs/food100_'):
    """Train one NeuralNetwork per configuration and log each run.

    Relies on the module-level ``train_dataloader``, ``test_dataloader``,
    ``classes``, ``device`` and ``loss_fn``.

    Args:
        configurations: iterable of ``(model_name, kwargs)`` pairs; ``kwargs``
            are forwarded to ``NeuralNetwork``.
        epochs: number of epochs passed to ``TrainingLoop.fit``.
        log_prefix: path prefix of the JSON log; the file is
            ``log_prefix + model_name + '.json'``. The default keeps the
            original file names, but the prefix should name the dataset
            that is actually loaded.
    """
    # One sample batch tells us the flattened input size: channels * height * width.
    X, _ = next(iter(test_dataloader))
    input_dims = int(X.shape[2] * X.shape[3])
    input_channels = int(X.shape[1])
    for model_name, config in configurations:
        model = NeuralNetwork(input_dims * input_channels, len(classes), **config)
        model.to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        logger = CompositeLogger(JSONLogger(log_prefix + model_name + '.json'), StandardOutputLogger(500))
        training_loop = TrainingLoop(model, optimizer, loss_fn, logger)
        print(model_name)
        training_loop.fit(train_dataloader, test_dataloader, epochs=epochs)

# model_experiment(configurations)

In [None]:
def _make_model_path(model: nn.Module, model_name: str, directory: str | None) -> str:
    string_hash = hashlib.sha1(bytes(str(model).encode())).hexdigest()
    full_path = model_name + '_' + string_hash+ '.model'
    if directory:
        full_path = os.path.join(directory, full_path)
    return os.path.normpath(full_path)


def save_model(model: nn.Module, model_name: str, directory: str | None):
    """Save model, using string repr to assure that models are not duplicated once trained.

    The whole module object is written with ``torch.save`` to the path built by
    ``_make_model_path``; the target directory is created when missing.

    Args:
        model: the module to save.
        model_name: prefix of the file name.
        directory: target directory, or ``None`` to save into the working directory.
    """
    full_path = _make_model_path(model, model_name, directory)
    parent = os.path.dirname(full_path)
    # With no directory the parent is '' and os.makedirs('') would raise,
    # so only create a directory when there is one to create.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(model, full_path)

def load_model(model: nn.Module, model_name: str, directory: str | None) -> nn.Module:
    """Load the model previously written by ``save_model`` for an identical architecture.

    Args:
        model: an instance whose ``str()`` matches the saved model; it is only
            used to rebuild the file name, its weights are not touched.
        model_name: prefix of the file name.
        directory: directory the model was saved in, or ``None``.

    Returns:
        The module object stored in the file.

    Raises:
        FileNotFoundError: if no file exists for this name/architecture.
    """
    full_path = _make_model_path(model, model_name, directory)
    # save_model pickles the whole module, which the weights-only loader rejects.
    # NOTE(review): unpickling executes code from the file; only load trusted files.
    return torch.load(full_path, weights_only=False)


In [None]:
# NOTE(review): the (256, 256) size and 101 outputs appear to match get_food100_data(),
# while in_channels comes from whichever dataset was loaded above — confirm they agree.
resid_model = ResidualNetwork((256, 256), in_channels, 101).to(device)


In [None]:
# Adam optimizer for the residual model; logging combines a JSONLogger and a StandardOutputLogger.
optimizer = torch.optim.Adam(resid_model.parameters(), lr=1e-3)
logger = CompositeLogger(JSONLogger('food100_conv_baseline.json'), StandardOutputLogger(100))

# The fit/save calls are commented out, so this cell only constructs the loop.
training_loop = TrainingLoop(resid_model, optimizer, loss_fn, logger)
# training_loop.fit(train_dataloader, test_dataloader, 5)
# save_model(resid_model, 'food100_conv_baseline', 'data/models')

## GAN

Let's make anime faces.

In [None]:
# Same device selection as the earlier cell, repeated at the start of the GAN section.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
class AnimeDataset:
    """Map-style dataset over the image files found in one directory.

    Items are images only (no labels). ``max_items`` optionally limits how many
    of the files are exposed through ``len()``.
    """
    def __init__(self, image_dir, transform=None, max_items=None):
        # Sorted so that index -> file is stable across runs and machines;
        # os.listdir returns entries in arbitrary order.
        self.img_labels = sorted(os.listdir(image_dir))
        self.img_dir = image_dir
        self.transform = transform
        self._max_items = max_items

    def __len__(self):
        available = len(self.img_labels)
        if self._max_items is None:
            return available
        # Never report more items than exist: a sampler would otherwise draw
        # indices past the end of img_labels and __getitem__ would raise IndexError.
        return max(0, min(self._max_items, available))

    def __getitem__(self, idx):
        """Decode the idx-th file and apply the transform, if any."""
        img_path = os.path.join(self.img_dir, self.img_labels[idx])
        image = torchvision.io.decode_image(img_path)
        if self.transform:
            image = self.transform(image)
        return image

# Anime-face images: convert to float32, normalize each channel with mean/std 0.5, then resize to 64x64.
anime_dataset = AnimeDataset(
    'data/anime_faces/images',
    transform=Compose([
        torchvision.transforms.ConvertImageDtype(torch.float32),
        torchvision.transforms.Normalize((.5, .5, .5), (.5, .5, .5)),
        Resize((64, 64))]
    ),
    max_items=None)

# Shuffled batches of 64 images (the dataset yields images only, no labels).
anime_loader = DataLoader(anime_dataset, 64, shuffle=True)


In [None]:
# custom weights initialization https://docs.pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

def weights_init(m):
    """Re-initialize a layer in place, chosen by its class name.

    Layers whose class name contains ``Conv`` get weights drawn from N(0, 0.02).
    Layers whose class name contains ``BatchNorm`` get weights from N(1, 0.02)
    and a zero bias. Every other layer is left untouched.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_type:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
def make_adversaries(n_features: int = 16):
    """Build a (generator, discriminator) pair with ``weights_init`` applied to both.

    The generator is created with an input size of 100; both networks use
    ``n_features`` and have dropout disabled.
    """
    adversaries = (
        Generator(100, n_features, dropout=0.),
        Discriminator(n_features, dropout=0.),
    )
    for network in adversaries:
        network.apply(weights_init)
    return adversaries

In [None]:
class CaptureGeneratedImages:
    """Accumulate batches of generator output so they can be saved together."""
    def __init__(self, n_images: int):
        self.n_images = n_images
        self._captured = []

    def capture(self, generator: Generator):
        """Generate ``n_images`` images from fresh random inputs and store them.

        The inputs are moved to the module-level ``device``.
        """
        random_inputs = torch.randn((self.n_images, generator.input_size, 1, 1)).to(device)
        # Inference only: without no_grad each stored tensor would keep its
        # whole autograd graph alive for as long as it sits in the list.
        with torch.no_grad():
            tensors = generator(random_inputs)
        self._captured.append(tensors.detach())

    def save(self, path):
        """Write the list of captured tensors to ``path`` with ``torch.save``."""
        torch.save(self._captured, path)


In [None]:
class EndOnBatch:
    """After a number of batches, proceed to the next epoch.

    Registered for the ``'batch_end'`` event; returns ``'end_batch'`` on the
    configured batch number and ``None`` on every other batch.
    """
    def __init__(self, end_on: int):
        self.on = 'batch_end'
        self.end_on = end_on

    def __call__(self, loop, batch: int, **kwargs):
        return 'end_batch' if batch == self.end_on else None

def test_end_on_batch():
    """Check that EndOnBatch signals only on its configured batch."""
    callback = EndOnBatch(1)
    assert callback.on == 'batch_end'
    signal_on_target = callback(None, batch=1, gen_loss=.1)
    signal_after_target = callback(None, batch=2)
    assert signal_on_target == 'end_batch'
    assert signal_after_target is None


class SnapshotCallback:
    """Collect snapshots of generator output on an interval.

    The same random input is reused for every snapshot of a given input size,
    so successive snapshots show how one set of images evolves.

    Args:
        n_images: number of images per snapshot.
        on: event name this callback is registered for, e.g. ``'batch_end'``
            or ``'epoch_end'``.
        interval: take a snapshot when the batch/epoch counter is a multiple
            of this value; must be at least 1.

    Raises:
        ValueError: if ``interval`` is smaller than 1.
    """
    def __init__(self, n_images: int, on: str, interval: int):
        if interval < 1:
            # The interval is a modulo divisor; 0 would only fail later, on the first call.
            raise ValueError(f"interval must be a positive integer, got {interval}")
        self.interval = interval
        self.n_images = n_images
        self.on = on
        self.snapshots = []
        self._random_inputs = {}

    def __call__(self, loop, **kwargs):
        interval = self._get_interval(kwargs)
        if interval % self.interval != 0:
            return
        random_input = self._random_input(loop.generator.input_size).to(loop._device)
        # Snapshots are for inspection only; building and retaining an autograd
        # graph for each of them would make memory grow with every snapshot.
        with torch.no_grad():
            result = loop.generator(random_input)
        self.snapshots.append(result.detach())

    def _get_interval(self, kwargs):
        """Return the counter matching the event this callback listens to."""
        if self.on == 'batch_end':
            return kwargs['batch']
        if self.on == 'epoch_end':
            return kwargs['epoch']
        # for example, to call on 'fit_end': assume interval is one.
        return 1

    def _random_input(self, size: int):
        """Return the cached random input for ``size``, creating it on first use."""
        if size in self._random_inputs:
            return self._random_inputs[size]
        random_input = torch.randn((self.n_images, size, 1, 1))
        self._random_inputs[size] = random_input
        return random_input


def test_snapshot_callback():
    """Exercise SnapshotCallback against a small GANLoop in eval mode."""
    generator, discriminator = make_adversaries(6)
    gen_optimizer = torch.optim.SGD(generator.parameters(), lr=.01, momentum=0)
    dis_optimizer = torch.optim.SGD(discriminator.parameters(), lr=.01, momentum=0)
    loop = GANLoop(generator, discriminator, gen_optimizer, dis_optimizer, nn.BCEWithLogitsLoss())
    loop.generator.eval()

    # Batch-level callback: fires on multiples of 2 only.
    per_batch = SnapshotCallback(n_images=1, interval=2, on='batch_end')
    assert per_batch.on == 'batch_end'
    losses = {'gen_loss': .1, 'dis_loss': .1}
    per_batch(loop, batch=2, **losses)
    assert len(per_batch.snapshots ) == 1, "Snapshots should update when called on batch interval."
    per_batch(loop, batch=3, **losses)
    assert len(per_batch.snapshots ) == 1, "Snapshots should update only on interval."
    per_batch(loop, batch=2, **losses)
    assert len(per_batch.snapshots) == 2, "Snapshots should update when called on batch interval."
    # The same cached random input must give the same images both times.
    first, second = per_batch.snapshots[0], per_batch.snapshots[1]
    assert torch.allclose(first.detach(), second.detach())

    # Epoch-level callback: one snapshot holding n_images images.
    per_epoch = SnapshotCallback(n_images=8, interval=1, on='epoch_end')
    per_epoch(loop, epoch=1)
    assert per_epoch.snapshots[0].shape[0] == per_epoch.n_images == 8

class StandardOutputCallback:
    """Print all keyword values passed at batch end, once every ``interval`` batches."""
    def __init__(self, interval: int):
        self.interval = interval
        self.on = 'batch_end'

    def __call__(self, loop, **kwargs):
        # A non-zero remainder means this batch is not on the interval.
        if kwargs['batch'] % self.interval:
            return
        print(', '.join(f"{key}: {value}" for key, value in kwargs.items()))

def test_standard_ouput_callback():
    """Smoke-test StandardOutputCallback on an off-interval and an on-interval batch."""
    printer = StandardOutputCallback(interval=2)
    assert printer.on == 'batch_end'
    for call_kwargs in ({'batch': 1}, {'batch': 2, 'gen_loss': .1, 'dis_loss': .2}):
        printer(None, **call_kwargs)

# Run the inline callback checks whenever this cell executes.
test_end_on_batch()
test_snapshot_callback()
test_standard_ouput_callback()

In [None]:
def _model_path(path, index):
    return f'{path}_{index}.pth'

def save_tuned_model(loop: GANLoop, path: str, index: int, meta: dict | None = None):
    """Checkpoint ``loop`` to ``<path>_<index>.pth``, creating the directory if needed.

    Args:
        loop: the loop to checkpoint via ``loop.save_checkpoint``.
        path: path prefix of the checkpoint, with or without a directory part.
        index: suffix distinguishing checkpoints that share a prefix.
        meta: optional metadata passed through to ``save_checkpoint``.
    """
    directory = os.path.dirname(path)
    # A bare file-name prefix has no directory part; os.makedirs('') would raise.
    if directory:
        os.makedirs(directory, exist_ok=True)
    loop.save_checkpoint(_model_path(path, index), meta)

def load_tuned_model(path, index) -> tuple[GANLoop, dict]:
    """Load the loop and metadata checkpointed at ``<path>_<index>.pth``."""
    checkpoint_file = _model_path(path, index)
    restored_loop, restored_meta = GANLoop.load_checkpoint(checkpoint_file)
    return restored_loop, restored_meta

def plot_loop_losses(g_loss, d_loss):
    """Plot generator and discriminator losses against step index on one axis."""
    fig, ax = plt.subplots(1, 1, figsize=(6, 3))
    curves = ((g_loss, 'generator loss'), (d_loss, 'discriminator loss'))
    for series, label in curves:
        ax.plot(range(len(series)), series, label=label)
    fig.legend()

def tune_learning_rate(rates: list, n_combinations: int, max_steps: int, epochs: int = 1):
    """Train short GAN runs for randomly chosen learning-rate pairs.

    Pairs are sampled without replacement from the cartesian product
    ``rates x rates``. Each run builds a fresh generator/discriminator, trains
    on the module-level ``anime_loader``, checkpoints the loop under
    ``gan_tuning/model_<index>.pth`` and shows an image grid and loss curves.

    Args:
        rates: candidate learning rates, used for both networks.
        n_combinations: maximum number of (generator, discriminator) pairs to try.
        max_steps: batch number at which ``EndOnBatch`` cuts each epoch short;
            a falsy value (``None`` or 0) leaves epochs uncapped.
        epochs: number of epochs passed to ``loop.fit``.

    Returns:
        dict mapping the pair's index in the product to a dict with keys
        ``'gen_rate'``, ``'dis_rate'``, ``'images'`` (epoch-end snapshots of
        8 images) and ``'snapshots'`` (periodic single-image snapshots).
    """
    import itertools

    rate_pairs = list(itertools.product(rates, rates))
    rate_indexes = np.random.choice(
        len(rate_pairs), min(n_combinations, len(rate_pairs)), replace=False)

    # Steps are counted in batches, so the bound is the number of batches in
    # the loader rather than the number of images in the dataset.
    batches_per_epoch = len(anime_loader)
    steps = min(batches_per_epoch, max_steps) if max_steps else batches_per_epoch
    # The callbacks use these as modulo divisors, so they must never be 0.
    snapshot_interval = max(1, steps // 8)
    print_interval = max(1, steps // 4)

    results = {}
    for idx in rate_indexes:
        # Plain Python numbers keep the results and checkpoint metadata free of numpy scalars.
        idx = int(idx)
        gen_rate, dis_rate = (float(rate) for rate in rate_pairs[idx])
        print(f"Starting with {idx} generator rate {gen_rate}, discriminator rate {dis_rate}")
        generator, discriminator = make_adversaries(32)

        loss = nn.BCEWithLogitsLoss()
        loop = GANLoop(
            generator,
            discriminator,
            torch.optim.Adam(generator.parameters(), lr=gen_rate, betas=(.5, .999)),
            torch.optim.Adam(discriminator.parameters(), lr=dis_rate, betas=(.5, .999)),
            loss
        )

        images = SnapshotCallback(8, 'epoch_end', 1)
        snapshots = SnapshotCallback(1, 'batch_end', snapshot_interval)
        callbacks = [StandardOutputCallback(print_interval), snapshots, images]

        if max_steps:
            callbacks.append(EndOnBatch(steps))

        loop.fit(anime_loader, epochs, callbacks)
        results[idx] = {'gen_rate': gen_rate, 'dis_rate': dis_rate, 'images': images.snapshots, 'snapshots': snapshots.snapshots}
        save_tuned_model(loop, 'gan_tuning/model', idx, results[idx])
        # show_image already calls plt.show(), so no extra call is needed after it.
        show_image(make_grid(images.snapshots[-1].to('cpu'), nrow=4))
        plot_loop_losses(loop.history['gen_loss'], loop.history['dis_loss'])
        plt.show()

        print()
    return results


# 20 log-spaced rates from 1e-3 down to 10**-4.5; try up to 10 random rate pairs with max_steps=100.
learning_rates = np.logspace(-3, -4.5, num=20)
results = tune_learning_rate(learning_rates.tolist(), 10, 100)


In [None]:
# For each tuned run: show the first epoch-end image grid, then the periodic
# single-image snapshots concatenated into one grid.
for k, r in results.items():
    print(k)
    show_image(make_grid(r['images'][0].to('cpu'), nrow=4))
    show_image(make_grid(torch.concat(r['snapshots']).to('cpu')))

In [None]:
# Summarize the (generator, discriminator) learning rates per run; both are
# converted so the display shows plain Python numbers rather than a mix of types.
{int(k): (float(v['gen_rate']), float(v['dis_rate'])) for k, v in results.items()}

In [None]:
# Index of the tuning run to keep training; it is the suffix of the checkpoint file names below.
selected_loop = 112
try:
    # Prefer the checkpoint written by the save cell further down, if it exists.
    loop, meta = GANLoop.load_checkpoint((f'well_tuned_{selected_loop}.pth'))
except FileNotFoundError:
    print("Not found.")
    # Fall back to the checkpoint written during learning-rate tuning.
    loop, meta = load_tuned_model('gan_tuning/model', selected_loop)
# Random generator inputs kept in a variable so later cells reuse the same batch.
random_inputs = torch.randn((8, loop.generator.input_size, 1, 1)).to(device)


In [None]:
epochs = 20
# Report losses about five times per epoch. The divisor does not change between
# steps, and max(1, ...) avoids a modulo-by-zero when there are fewer than five batches.
report_every = max(1, len(anime_dataset) // anime_loader.batch_size // 5)
for epoch in range(epochs):
    # Preview the generator on the fixed random inputs before each epoch.
    show_image(make_grid(loop.generator(random_inputs).cpu()))
    print(f"Epoch: {epoch}")
    for step, (gen_loss, dis_loss) in enumerate(loop.step(anime_loader)):
        if step % report_every == 0:
            print(f"Gen loss: {gen_loss}, dis loss: {dis_loss}")

# Final preview after the last epoch.
show_image(make_grid(loop.generator(random_inputs).cpu()))


In [None]:
# Persist the further-trained loop under the file name the load cell above looks for first.
loop.save_checkpoint(f'well_tuned_{selected_loop}.pth', meta)

In [None]:
# Show a single generated image: the fourth from the end of the batch.
show_image(loop.generator(random_inputs.to(device)).cpu()[-4])

In [None]:
def report_result(result: dict):
    """Plot the generator and discriminator loss curves stored in ``result``.

    Args:
        result: mapping with ``'gen_loss'`` and ``'dis_loss'`` sequences; both
            are plotted against the index range of ``'gen_loss'``.
    """
    x = np.arange(0, len(result['gen_loss']))
    plt.plot(x, result['gen_loss'], label='generator loss')
    plt.plot(x, result['dis_loss'], label='dis_loss')
    # The labels above are only drawn once a legend is requested.
    plt.legend()

# NOTE(review): report_result reads 'gen_loss' and 'dis_loss', but the meta dict built in
# tune_learning_rate has no such keys — confirm that load_checkpoint supplies them.
report_result(meta)