In [None]:
# Install the Cardboard4 package straight from GitHub, at the `code` ref.
# NOTE(review): presumably this provides the `cardboard4` module imported below — confirm.
!pip install 'git+https://github.com/rphln/Cardboard4.git@code'

from os import cpu_count

In [None]:
from pathlib import Path

import torch
from ignite.contrib.handlers import ProgressBar, WandBLogger
from ignite.engine import Events, create_supervised_evaluator, create_supervised_trainer
from ignite.handlers import (
    Checkpoint,
    DiskSaver,
    EarlyStopping,
    global_step_from_engine,
)
from ignite.metrics import Loss, RunningAverage
from sklearn.model_selection import train_test_split
from torch.nn import Conv2d, MSELoss, PixelShuffle, Sequential, Tanh
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.transforms import Normalize

from cardboard4 import TensorPairsDataset, mean_psnr, mean_ssim

In [None]:
# Pick the storage root: Google Drive when running on Colab, a local `var/`
# directory otherwise.
try:
    # Only the import is guarded, so a failure while mounting the drive is not
    # mistaken for "not running on Colab".
    from google.colab import drive
except ImportError:
    ROOT = Path("var/")
else:
    drive.mount("/content/drive/", force_remount=True)
    ROOT = Path("/content/drive/MyDrive/")

# Kept outside the `try` statement so it only runs once ROOT is bound; inside a
# `finally` clause it would raise NameError on top of any mount error.
ROOT.mkdir(parents=True, exist_ok=True)

In [None]:
# Mirror the dataset directory from Google Drive into the local `var/` directory.
# `--ignore-existing` skips files that are already present at the destination,
# so re-running the cell does not copy them again.
!rsync --archive --ignore-existing --human-readable --info progress2 '/content/drive/MyDrive/rphln-danbooru2020-small' 'var/'

In [None]:
# Copy the `.netrc` file stored on Google Drive to `/root/.netrc`.
# NOTE(review): presumably this supplies the Weights & Biases credentials used by the logger — confirm.
!cp -r '/content/drive/MyDrive/.netrc' '/root/.netrc'

In [None]:
class ESPCN(Sequential):
    """Three-convolution sub-pixel network that upscales images by ``SCALE``.

    The class is a ``Sequential``, so the children registered in ``__init__``
    are applied in registration order: ``normalize``, then ``stem``, then
    ``head``.
    """

    # Upscaling factor applied by the final pixel shuffle.
    SCALE = 4
    # Number of channels of the input image.
    N0 = 3

    # Kernel size and output channels of the first convolution.
    F1 = 5
    N1 = 64

    # Kernel size and output channels of the second convolution.
    F2 = 3
    N2 = 32

    # Kernel size and output channels of the last convolution. The pixel
    # shuffle folds every SCALE**2 channels into one, leaving N0 channels.
    # NOTE: evaluated once in this class body; a subclass that overrides SCALE
    # or N0 must override N3 as well.
    F3 = 3
    N3 = N0 * (SCALE ** 2)

    def __init__(self):
        """Build the normalization, feature-extraction and upscaling stages."""
        super().__init__()

        # Per-channel input standardization.
        # NOTE(review): the statistics are presumably those of the training
        # set — confirm their origin.
        self.normalize = Normalize(
            std=[0.2931, 0.2985, 0.2946], mean=[0.7026, 0.6407, 0.6265]
        )

        # Feature extraction at the input resolution ("same" padding keeps the
        # spatial size unchanged).
        self.stem = Sequential(
            Conv2d(self.N0, self.N1, self.F1, padding="same"),
            Tanh(),
            Conv2d(self.N1, self.N2, self.F2, padding="same"),
            Tanh(),
        )
        # Produce N3 channels and rearrange them into the upscaled image. The
        # shuffle factor comes from SCALE so it cannot drift from N3.
        self.head = Sequential(
            Conv2d(self.N2, self.N3, self.F3, padding="same"),
            PixelShuffle(self.SCALE),
        )

In [None]:
# Hyperparameters of the run, kept in a single mapping.
config = dict(BATCH_SIZE=2048, LEARNING_RATE=0.001, PATIENCE=100)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directory holding the `train` and `test` splits of the dataset.
dataset_root = Path("var") / "rphln-danbooru2020-small"

test_with = TensorPairsDataset(dataset_root / "test")
train_with = TensorPairsDataset(dataset_root / "train")

# Hold out the trailing 20% of the training set for validation; without
# shuffling the split is the same on every run.
# NOTE(review): scikit-learn indexes the dataset item by item here, which
# presumably materializes both parts in memory — confirm this is intended.
train_with, validate_with = train_test_split(train_with, test_size=0.2, shuffle=False)

# Options shared by every loader. `cpu_count()` may return None and
# `persistent_workers=True` requires at least one worker, hence the fallback.
loader_options = {
    "batch_size": config["BATCH_SIZE"],
    "pin_memory": True,
    "num_workers": cpu_count() or 1,
    "persistent_workers": True,
}

# NOTE(review): batches are served in dataset order (no `shuffle=True`) —
# confirm this is intentional for training.
training_data_loader = DataLoader(train_with, drop_last=True, **loader_options)

# Evaluation keeps the final partial batch: dropping it would leave samples out
# of the reported metrics, and would yield no batch at all for a split smaller
# than the batch size.
validation_data_loader = DataLoader(validate_with, drop_last=False, **loader_options)
testing_data_loader = DataLoader(test_with, drop_last=False, **loader_options)

criterion = MSELoss()

# `Module.to` returns the module itself, so creation and placement can chain.
model = ESPCN().to(device)

# Two parameter groups: the head is trained with a tenth of the stem's
# learning rate.
base_lr = config["LEARNING_RATE"]
parameters = [
    dict(params=model.stem.parameters(), lr=base_lr),
    dict(params=model.head.parameters(), lr=base_lr * 0.1),
]
optimizer = Adam(parameters)

# Every metric is a batch-averaged function wrapped in ignite's `Loss`.
metrics = dict(
    loss=Loss(criterion),
    psnr=Loss(mean_psnr),
    ssim=Loss(mean_ssim),
)

trainer = create_supervised_trainer(model, optimizer, criterion, device)

# Both evaluators share the same model and the same metric objects.
validation = create_supervised_evaluator(model, metrics, device)
testing = create_supervised_evaluator(model, metrics, device)

# Expose a running average of the trainer's output under the name "loss".
average = RunningAverage(output_transform=lambda output: output)
average.attach(trainer, "loss")

# One persistent progress bar per engine, in the order trainer, validation,
# testing.
for engine in (trainer, validation, testing):
    ProgressBar(persist=True).attach(engine, metric_names="all")


def compute_metrics():
    """Run the validation engine over the validation data loader."""
    validation.run(validation_data_loader)


def compute_testing_metrics():
    """Run the testing engine over the testing data loader."""
    testing.run(testing_data_loader)


# Validate after every training epoch; test once, when training completes.
trainer.add_event_handler(Events.EPOCH_COMPLETED, compute_metrics)
trainer.add_event_handler(Events.COMPLETED, compute_testing_metrics)


# Create the W&B directory up front.
# NOTE(review): depending on the wandb version, a `dir` that does not exist may
# make it fall back to a temporary directory — confirm against the installed
# version.
wandb_dir = ROOT / "wandb"
wandb_dir.mkdir(parents=True, exist_ok=True)

logger = WandBLogger(
    project="Cardboard4",
    config=config,
    save_code=True,
    dir=wandb_dir,
)
logger.watch(model, criterion, log="all")

# Log every metric of each engine when it finishes an epoch. The step is always
# taken from the trainer so the three series share one x-axis.
for tag, engine in (
    ("training", trainer),
    ("validation", validation),
    ("testing", testing),
):
    logger.attach_output_handler(
        engine=engine,
        event_name=Events.EPOCH_COMPLETED,
        tag=tag,
        metric_names="all",
        global_step_transform=global_step_from_engine(trainer),
    )

# Score function built from the "loss" metric with a factor of -1.0, so that a
# lower loss gives a higher score.
neg_loss_score = Checkpoint.get_default_score_fn("loss", -1.0)

# Stops the trainer once the score has not improved for PATIENCE validation
# runs.
halt = EarlyStopping(
    patience=config["PATIENCE"],
    score_function=neg_loss_score,
    trainer=trainer,
)

# Keeps only the best-scoring checkpoint, stored in the W&B run directory.
checkpoint = Checkpoint(
    to_save=dict(
        halt=halt,
        model=model,
        optimizer=optimizer,
        trainer=trainer,
        validator=validation,
    ),
    save_handler=DiskSaver(dirname=logger.run.dir, require_empty=False),
    score_function=neg_loss_score,
    score_name="loss",
    n_saved=1,
    include_self=True,
    global_step_transform=global_step_from_engine(trainer),
)

# Both handlers fire when a validation run completes: early stopping first,
# then checkpointing.
for handler in (halt, checkpoint):
    validation.add_event_handler(Events.COMPLETED, handler)

trainer.run(training_data_loader, max_epochs=3_000)