In [None]:
!pip install 'git+https://github.com/rphln/Cardboard4.git@code'

In [None]:
from pathlib import Path

from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger
from torch import Tensor
from torch.nn import Conv2d, Module, MSELoss, PixelShuffle, Sequential, Tanh
from torch.optim import Adam
from torchvision.transforms import Normalize

from cardboard4 import TensorPairsDataModule, mean_psnr, mean_ssim

In [None]:
# Choose where the train/test data and W&B output live: Google Drive when
# running on Colab, otherwise a local `var/` directory.
try:
    from google.colab import drive
except ImportError:
    ROOT = Path("var/")
else:
    # Only the import decides whether we are on Colab. A failing mount must
    # surface as its own error instead of silently falling back to `var/`.
    drive.mount("/content/drive/", force_remount=True)
    ROOT = Path("/content/drive/MyDrive/")

# Outside `finally`, so `ROOT` is always bound here and an earlier error is
# never masked by a NameError.
ROOT.mkdir(parents=True, exist_ok=True)

In [None]:
!cp '/content/drive/MyDrive/.netrc' '/root/.netrc'

In [None]:
class ESPCN(Sequential):
    """Efficient sub-pixel convolutional network (ESPCN) for super-resolution.

    Layout:

        normalize -> conv(F1, N1) + tanh -> conv(F2, N2) + tanh
                  -> conv(F3, N3) -> pixel shuffle by ``SCALE``

    Because this is a ``Sequential``, ``forward`` runs the child modules in
    registration order (``normalize``, ``stem``, ``head``). The attribute
    assignment order in ``__init__`` therefore matters.

    Every convolution uses ``padding="same"``, so an input of shape
    ``(B, N0, H, W)`` yields an output of shape
    ``(B, N0, H * SCALE, W * SCALE)``.

    Hyper-parameters are class attributes so subclasses can override them.
    ``N3`` must equal ``N0 * SCALE ** 2`` for the pixel shuffle to map back
    to ``N0`` channels.
    """

    # Upscaling factor and number of image channels.
    SCALE = 4
    N0 = 3

    # Kernel size / output channels of the first feature-extraction conv.
    F1 = 5
    N1 = 64

    # Kernel size / output channels of the second feature-extraction conv.
    F2 = 3
    N2 = 32

    # Kernel size / output channels of the sub-pixel conv. PixelShuffle turns
    # N0 * SCALE**2 channels into an N0-channel image that is SCALE times larger.
    F3 = 3
    N3 = N0 * (SCALE ** 2)

    def __init__(self):
        super().__init__()

        # Per-channel input standardisation. The values are presumably
        # training-set statistics (TODO confirm). The output is NOT
        # de-normalised: the network maps normalised input straight to the
        # target space.
        self.normalize = Normalize(
            std=[0.2931, 0.2985, 0.2946], mean=[0.7026, 0.6407, 0.6265]
        )

        # Feature extraction. It is kept separate from `head` so an optimiser
        # can give the two parts different learning rates (LitModel does).
        self.stem = Sequential(
            Conv2d(self.N0, self.N1, self.F1, padding="same"),
            Tanh(),
            Conv2d(self.N1, self.N2, self.F2, padding="same"),
            Tanh(),
        )

        # Sub-pixel upsampling. The shuffle factor comes from SCALE, not a
        # literal, so it always agrees with N3.
        self.head = Sequential(
            Conv2d(self.N2, self.N3, self.F3, padding="same"),
            PixelShuffle(self.SCALE),
        )

In [None]:
class LitModel(LightningModule):
    """Lightning wrapper that trains an image-to-image ``model`` with ``criterion``.

    Logs ``Training/Loss`` during training, ``Validation/{Loss,PSNR,SSIM}``
    during validation, and ``Testing/{Loss,PSNR,SSIM}`` during testing.

    Args:
        model: Network mapping an input batch ``x`` to a prediction comparable
            with the target ``y``. If it exposes ``stem`` and ``head``
            sub-modules (as ``ESPCN`` does), the head is trained with a
            reduced learning rate. Otherwise all parameters share one rate.
        criterion: Loss module, called as ``criterion(prediction, target)``.
        learning_rate: Adam learning rate (for ``stem`` when the model is split).
        head_lr_factor: Multiplier applied to ``learning_rate`` for ``head``.
    """

    def __init__(
        self,
        model: Module,
        criterion: Module,
        learning_rate: float = 0.001,
        head_lr_factor: float = 0.1,
    ):
        super().__init__()

        self.model = model
        self.criterion = criterion

        self.learning_rate = learning_rate
        self.head_lr_factor = head_lr_factor

    def forward(self, x: Tensor) -> Tensor:
        """Run the wrapped model on ``x``."""
        return self.model(x)

    def training_step(self, batch, index):
        """Compute, log and return the training loss for an ``(x, y)`` batch."""
        x, y = batch
        z = self(x)

        loss = self.criterion(z, y)
        self.log("Training/Loss", loss)

        return loss

    def _evaluate(self, batch, stage: str) -> None:
        """Log loss, PSNR and SSIM for one ``(x, y)`` batch under the ``stage`` prefix."""
        x, y = batch
        z = self(x)

        self.log(f"{stage}/Loss", self.criterion(z, y))
        self.log(f"{stage}/PSNR", mean_psnr(z, y))
        self.log(f"{stage}/SSIM", mean_ssim(z, y))

    def validation_step(self, batch, index):
        """Log validation metrics for one batch."""
        self._evaluate(batch, "Validation")

    def test_step(self, batch, index):
        """Log test metrics for one batch."""
        self._evaluate(batch, "Testing")

    def configure_optimizers(self):
        """Build Adam, giving ``head`` a smaller learning rate when the model is split.

        The 0.1 default ratio is presumably taken from the ESPCN paper, which
        trains the last layer with a 10x smaller rate (TODO confirm). When
        split, parameters outside ``stem``/``head`` are not passed to the
        optimiser.
        """
        stem = getattr(self.model, "stem", None)
        head = getattr(self.model, "head", None)

        if stem is None or head is None:
            # Generic models without a stem/head split: one parameter group.
            return Adam(self.model.parameters(), lr=self.learning_rate)

        parameters = [
            {"params": stem.parameters(), "lr": self.learning_rate},
            {
                "params": head.parameters(),
                "lr": self.learning_rate * self.head_lr_factor,
            },
        ]

        return Adam(parameters)

In [None]:
# Every run hyper-parameter lives here; the dict is passed to W&B so it is
# recorded with the run.
config = {
    "BATCH_SIZE": 1024,
    "LEARNING_RATE": 0.001,
    "MAX_EPOCHS": 1_000,
    "SEED": 42,
}

# Seed the random/NumPy/torch RNGs before the model is built, so weight
# initialisation is repeatable. GPU kernels may still be non-deterministic.
seed_everything(config["SEED"])

model = LitModel(
    model=ESPCN(),
    criterion=MSELoss(),
    learning_rate=config["LEARNING_RATE"],
)

data = TensorPairsDataModule(
    batch_size=config["BATCH_SIZE"],
    test_with=ROOT / "test",
    train_with=ROOT / "train",
)

wandb = WandbLogger(project="Cardboard4", log_model=True, config=config, save_dir=ROOT)
wandb.watch(model)

# NOTE(review): `gpus` and `weights_save_path` are deprecated in newer
# Lightning releases (use `accelerator`/`devices` and `default_root_dir`).
# The package install above is unpinned, so confirm the installed version.
trainer = Trainer(
    max_epochs=config["MAX_EPOCHS"],
    gpus=1,
    precision=16,
    logger=wandb,
    weights_save_path=wandb.experiment.dir,
)

trainer.fit(model, datamodule=data)
# Evaluate the best checkpoint from `fit`, stated explicitly rather than
# relying on the version-dependent implicit default.
trainer.test(datamodule=data, ckpt_path="best")