In [None]:
# Print the GPU(s) visible to this runtime before doing any work.
!nvidia-smi

In [None]:
!pip install 'git+https://github.com/rphln/Cardboard4.git@code' 'kornia==0.5.11'

In [None]:
from os import environ
from pathlib import Path

import torch
from ignite.contrib.handlers import ProgressBar, WandBLogger
from ignite.engine import Events, create_supervised_evaluator, create_supervised_trainer
from ignite.handlers import (
    Checkpoint,
    DiskSaver,
    EarlyStopping,
    LRScheduler,
    global_step_from_engine,
)
from ignite.metrics import Loss, RunningAverage
from sklearn.model_selection import train_test_split
from torch import Tensor
from torch.nn.functional import mse_loss
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

from cardboard4 import TensorPairsDataset, mean_psnr, mean_ssim, MultiScaleSSIM
from cardboard4.models import ResidualNetwork

from kornia.color import rgb_to_grayscale
from kornia.filters import gaussian_blur2d

In [None]:
# Choose the directory that holds datasets, W&B runs and checkpoints:
# Google Drive when running on Colab, a local `var/` directory otherwise.
try:
    # Only the import is guarded, so that a failure while mounting the drive is
    # not mistaken for "not running on Colab".
    from google.colab import drive
except ImportError:
    ROOT = Path("var/")
else:
    drive.mount("/content/drive/", force_remount=True)
    ROOT = Path("/content/drive/MyDrive/")

# Deliberately not in a `finally` clause: if mounting raised, `ROOT` would be
# undefined (or stale from an earlier run) and the resulting error would mask
# the real failure.
ROOT.mkdir(parents=True, exist_ok=True)

In [None]:
# Copy the dataset from Google Drive into `/dev/shm`, skipping files that are
# already present there, so the training cells can read it from `/dev/shm`.
# NOTE(review): the source path is hard-coded to the Colab Drive mount and does
# not follow `ROOT`; this cell only works on Colab.
!rsync --archive --ignore-existing --human-readable --info progress2 '/content/drive/MyDrive/rphln-danbooru2020-small/' '/dev/shm/'

In [None]:
# Copy the `.netrc` stored on Drive into root's home directory
# (presumably it holds the Weights & Biases credentials — TODO confirm).
!cp -r '/content/drive/MyDrive/.netrc' '/root/.netrc'
# Run `wandb sync` on the `latest-run` directory kept on Drive.
!wandb sync '/content/drive/MyDrive/wandb/latest-run/'

In [None]:
# Hyper-parameters of the run; the same dictionary is handed to the W&B logger
# as the run configuration further down.
config = {
    "BATCH_SIZE": 16,
    "LEARNING_RATE": 1e-3,
    # Passed to `EarlyStopping` as its `patience`.
    "PATIENCE": 50,
}

# Training is pinned to the GPU; without CUDA the `.to(device)` calls fail fast
# instead of silently training on the CPU.
device = torch.device("cuda")

# Set to True to resume the W&B run below from its saved checkpoint; leave it
# False to start a fresh run.
RESUME = False

if not RESUME:
    checkpoint = None
else:
    # Tell W&B to continue logging into the existing run instead of creating a
    # new one.
    environ["WANDB_RESUME"] = "allow"
    environ["WANDB_RUN_ID"] = "365ogza8"

    # `map_location` makes the restored tensors land on the training device
    # regardless of the device they were saved from.
    checkpoint = torch.load(
        ROOT
        / "wandb"
        / f"run-20210722_175244-{environ['WANDB_RUN_ID']}"
        / "files"
        / "checkpoint_192_loss=-0.0061.pt",
        map_location=device,
    )

# The datasets are read from the copy under `/dev/shm` rather than from
# `ROOT / "rphln-danbooru2020-small"`, which holds the same `test` and `train`
# sub-directories.
shared_memory = Path("/dev/shm")

test_with = TensorPairsDataset(shared_memory / "test")
train_with = TensorPairsDataset(shared_memory / "train")

def pencil_sketch(image: Tensor, eps: float = 1e-24) -> Tensor:
    """Divide the greyscale image by a blurred copy of itself.

    The input is converted to greyscale, blurred with a 21x21 Gaussian kernel
    (sigma 1.0 on both axes), and the element-wise ratio of the two is clamped
    to ``[0, 1]``.

    Args:
        image: RGB image tensor, as accepted by kornia's ``rgb_to_grayscale``.
        eps: Small constant added to both numerator and denominator so that
            the ratio is defined where the blurred image is zero.

    Returns:
        The clamped greyscale-over-blur ratio.
    """
    luminance = rgb_to_grayscale(image)
    smoothed = gaussian_blur2d(luminance, kernel_size=(21, 21), sigma=(1.0, 1.0))

    ratio = (luminance + eps) / (smoothed + eps)
    return ratio.clamp(0.0, 1.0)


def criterion(x, y):
    """Training loss: pixel MSE plus MSE between the pencil sketches.

    Args:
        x: First image batch (the model output when used by the trainer).
        y: Second image batch (the target when used by the trainer).

    Returns:
        The sum of ``mse_loss(x, y)`` and the MSE between ``pencil_sketch(x)``
        and ``pencil_sketch(y)``.
    """
    pixel_term = mse_loss(x, y)
    sketch_term = mse_loss(pencil_sketch(x), pencil_sketch(y))
    return pixel_term + sketch_term

# Model, optimiser and the Ignite engines. The order of the statements below
# determines the order in which handlers fire on the trainer, so keep it.
model = ResidualNetwork().to(device)

optimizer = Adam(model.parameters(), lr=config["LEARNING_RATE"])

# Supervised training loop built by Ignite around the model and `criterion`.
trainer = create_supervised_trainer(model, optimizer, criterion, device)

# The scheduler is stepped once per completed epoch; with `step_size=50` and
# `gamma=0.1` the learning rate is scaled by 0.1 every 50 epochs.
scheduler = StepLR(optimizer, step_size=50, gamma=0.1, verbose=True)
trainer.add_event_handler(Events.EPOCH_COMPLETED, LRScheduler(scheduler))

# PSNR and SSIM are wrapped in `Loss` as well, so they go through the same
# metric machinery as the training criterion.
metrics = {
    "loss": Loss(criterion),
    "psnr": Loss(mean_psnr),
    "ssim": Loss(mean_ssim),
}

# Two separate evaluator engines that share the same `metrics` dictionary
# (and therefore the same metric instances).
validation = create_supervised_evaluator(model, metrics, device)
testing = create_supervised_evaluator(model, metrics, device)

# Running average of the trainer output, exposed as the trainer's "loss"
# metric (the identity `output_transform` passes the output through unchanged).
average = RunningAverage(output_transform=lambda loss: loss)
average.attach(trainer, "loss")

# Progress bar on the trainer, displaying all of its metrics.
progress = ProgressBar()
progress.attach(trainer, metric_names="all")

# Unshuffled 80/20 split of the training set into training and validation
# parts. Note that `train_with` is rebound here to the training part only.
train_with, validate_with = train_test_split(train_with, test_size=0.2, shuffle=False)


def _build_loader(dataset):
    """Wrap `dataset` in a `DataLoader` with the settings shared by all splits.

    NOTE(review): no loader shuffles its data, and `drop_last=True` also
    discards the final partial batch of the validation and testing sets —
    confirm that both are intended.
    """
    return DataLoader(
        dataset,
        config["BATCH_SIZE"],
        drop_last=True,
        pin_memory=True,
    )


training_data_loader = _build_loader(train_with)
validation_data_loader = _build_loader(validate_with)
testing_data_loader = _build_loader(test_with)


@trainer.on(Events.EPOCH_COMPLETED)
def compute_metrics():
    """Run the validation engine over the validation loader after every training epoch."""
    validation.run(validation_data_loader)


@trainer.on(Events.COMPLETED)
def compute_testing_metrics():
    """Run the testing engine over the testing loader once, when training completes."""
    testing.run(testing_data_loader)


# Weights & Biases logger for this run; its files are written under `ROOT`.
run_tags = [type(model).__name__, "PencilSketch", "Check"]

logger = WandBLogger(
    name="Residual@PencilSketch",
    project="Cardboard4",
    tags=run_tags,
    config=config,
    dir=ROOT,
    save_code=True,
)

# Let W&B track the model (`log="all"`) together with the criterion.
logger.watch(model, criterion, log="all")

# Log the metrics of every engine at the end of each of its epochs. All three
# are indexed by the *trainer's* step so that they share one x-axis.
for source_engine, log_tag in (
    (trainer, "training"),
    (validation, "validation"),
    (testing, "testing"),
):
    logger.attach_output_handler(
        engine=source_engine,
        event_name=Events.EPOCH_COMPLETED,
        tag=log_tag,
        metric_names="all",
        global_step_transform=global_step_from_engine(trainer),
    )

# Score built from the "loss" metric with sign -1.0; shared by early stopping
# and checkpoint ranking.
neg_loss_score = Checkpoint.get_default_score_fn("loss", -1.0)

# Early stopping is evaluated each time a validation run completes and, when
# triggered, terminates the trainer.
halt = EarlyStopping(
    patience=config["PATIENCE"],
    score_function=neg_loss_score,
    trainer=trainer,
)
validation.add_event_handler(Events.COMPLETED, halt)

# Everything that is stored in a checkpoint; the keys are the ones used when
# restoring from a checkpoint.
checkpointed_objects = {
    "scheduler": scheduler,
    "halt": halt,
    "model": model,
    "optimizer": optimizer,
    "trainer": trainer,
    "validator": validation,
}

# Checkpoints are written into the W&B run directory, keeping `n_saved=10`
# of them ranked by `neg_loss_score`.
checkpointer = Checkpoint(
    to_save=checkpointed_objects,
    save_handler=DiskSaver(dirname=logger.run.dir, require_empty=False),
    score_function=neg_loss_score,
    score_name="loss",
    n_saved=10,
    include_self=True,
    global_step_transform=global_step_from_engine(trainer),
)
validation.add_event_handler(Events.COMPLETED, checkpointer)

In [None]:
# When resuming, restore every saved component from the checkpoint, in the
# same order as before. The keys match the ones used when saving.
if checkpoint:
    restorable = {
        "scheduler": scheduler,
        "halt": halt,
        "model": model,
        "optimizer": optimizer,
        "trainer": trainer,
        "validator": validation,
    }
    for state_key, component in restorable.items():
        component.load_state_dict(checkpoint[state_key])

# Train for at most 1 500 epochs; early stopping may end the run sooner.
trainer.run(training_data_loader, max_epochs=1_500)