In [None]:
%load_ext autoreload

In [None]:
import warnings
from dataclasses import dataclass

import pytorch_lightning as pl
import torch
import torch.optim as optim
from PIL import ImageFile
from torch.utils.data import Dataset

from modules import NT_Xent, TransformsSimCLR, SimCLR, ImageDataset

# Suppress every warning for the whole process.
# NOTE(review): this also hides deprecation/data warnings from torch, PIL and
# Lightning; consider narrowing to the specific categories that are noisy.
warnings.filterwarnings('ignore')
# Let PIL load image files that are truncated instead of erroring on them
# (useful for scraped/partially-downloaded image datasets).
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [None]:
@dataclass
class Config:
    """Hyper-parameters and data locations for the SimCLR training run.

    Every attribute is a dataclass field, so any of them can be overridden
    at construction time, e.g. ``Config(batch_size=32, train_path='/data')``.
    """

    # Seed for global RNG seeding.
    seed: int = 116
    batch_size: int = 16
    epochs: int = 10

    # Output dimension of the SimCLR projection head.
    projection_dim: int = 256
    # Image size passed to the SimCLR augmentation pipeline.
    img_size: int = 512

    # Temperature of the NT-Xent contrastive loss.
    temperature: float = 0.5

    # Annotated so these are real dataclass fields; without annotations they
    # were plain class attributes that Config(...) could not override.
    train_path: str = './data/train'
    valid_path: str = './data/valid'

In [None]:
class SimCLRModel(pl.LightningModule):
    """LightningModule training a SimCLR network with the NT-Xent loss.

    Args:
        args: Configuration object. The attributes read are
            ``projection_dim``, ``batch_size``, ``temperature``, ``img_size``,
            ``train_path`` and ``valid_path``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        self.model = SimCLR(args.projection_dim)
        # NT_Xent receives the batch size at construction; presumably it
        # assumes every batch has exactly that many pairs, which is why both
        # dataloaders use drop_last=True.
        self.criterion = NT_Xent(args.batch_size, args.temperature)

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch):
        """Return the NT-Xent loss for one batch of paired views (x_i, x_j)."""
        x_i, x_j = batch
        # self.model returns a pair; only the second element (z, presumably
        # the projection-head output) feeds the contrastive loss.
        _, z_i = self.model(x_i)
        _, z_j = self.model(x_j)
        return self.criterion(z_i, z_j)

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        return {'loss': loss, 'log': {'train_loss': loss}}

    def validation_step(self, batch, batch_idx):
        return {'val_loss': self._shared_step(batch)}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses and log the mean."""
        # With drop_last=True a validation set smaller than batch_size yields
        # no batches; torch.stack on an empty list would raise.
        if not outputs:
            return {}
        avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
        return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss}}

    def configure_optimizers(self):
        return optim.AdamW(self.parameters(), lr=1e-3)

    def _make_loader(self, path, shuffle):
        """Build a DataLoader over the images in ``path`` with SimCLR augmentations."""
        dataset = ImageDataset(
            path,
            # Read img_size from self.args: this module never populates
            # self.hparams, so self.hparams.img_size would fail.
            transform=TransformsSimCLR(self.args.img_size),
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.args.batch_size,
            # Keep every batch at exactly batch_size (see NT_Xent above).
            drop_last=True,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        return self._make_loader(self.args.train_path, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.args.valid_path, shuffle=False)

In [None]:
# Run configuration with all default values; override fields here if needed.
args = Config()

In [None]:
# Seed all RNGs before building the model so runs are reproducible.
pl.seed_everything(args.seed)
model = SimCLRModel(args)

trainer = pl.Trainer(
    # Use one GPU when available; fall back to CPU instead of failing at
    # Trainer construction on machines without CUDA.
    gpus=1 if torch.cuda.is_available() else 0,
    max_epochs=args.epochs,
    gradient_clip_val=1.0,
    deterministic=True,
)
trainer.fit(model)