Preliminary

In [None]:
import sys
import torch

sys.path.append("..")

from sslearn.archs import ResNet
from sslearn.models.pretraining import BYOL
from sslearn.models.finetuning import Classifier
from sslearn.training.validators import TopKNN, Accuracy
from sslearn.training.schedulers import CosineAnnealingLinearWarmup
from sslearn.training import Trainer
from utils import load_cifar10, plot_results

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Paths handed to the dataset loader, the trainer and the plotting helper.
data_root = "static/datasets"
save_path = "weights"
plot_path = "static/plots"

Encoder definition

In [None]:
# Backbone used by both stages of this notebook: the same object is passed
# to pretrain() and afterwards to finetune().
# NOTE(review): `cifar10=True` presumably selects a CIFAR-10-specific
# ResNet-18 variant — confirm against sslearn.archs.ResNet.
encoder = ResNet(channels_in=3, model_name="resnet-18", cifar10=True)

Pretraining

In [None]:
def pretrain(encoder):
    """Pretrain ``encoder`` with BYOL for 100 epochs and plot the results.

    Builds three CIFAR-10 loaders via ``load_cifar10`` ("train", "valid" and
    "index"), wraps the encoder in a ``BYOL`` model, trains it with Adam under
    a ``CosineAnnealingLinearWarmup`` schedule while validating with
    ``TopKNN``, and finally plots the training losses and the validation
    metric to ``plot_path``.

    Args:
        encoder: Backbone passed to the ``BYOL`` constructor.

    Returns:
        None. The losses and validation metrics are only plotted.
    """
    num_epochs = 100
    warmup_epochs = 10

    train_loader = load_cifar10(data_root, train=True, batch_size=512, shuffle=True)
    valid_loader = load_cifar10(data_root, train=False, batch_size=1024)
    # Unshuffled pass over the training split, exposed under the "index" key.
    index_loader = load_cifar10(data_root, train=True, batch_size=1024)
    loaders = {
        "train": train_loader,
        "valid": valid_loader,
        "index": index_loader,
    }

    # The scheduler and BYOL are configured in optimizer steps, not epochs.
    steps_per_epoch = len(train_loader)
    num_steps = num_epochs * steps_per_epoch
    num_warmup_steps = warmup_epochs * steps_per_epoch

    byol = BYOL(
        encoder,
        total_iters=num_steps,
        hidden_dim=2048,
        head_dim=256,
        decay_base=0.99,
    )

    knn_validator = TopKNN(loaders, device=device)
    optimizer = torch.optim.Adam(byol.parameters(), lr=3e-4, weight_decay=1e-6)
    lr_schedule = CosineAnnealingLinearWarmup(
        optimizer, num_warmup_steps, num_steps, min_lr=3e-7
    )

    runner = Trainer(optimizer, lr_schedule, knn_validator)
    train_losses, metrics = runner.train(
        byol, train_loader, num_epochs, save_path, device=device
    )

    plot_results(train_losses, num_epochs, title="Pretraining", y_str="Loss", path=plot_path)

    metric_label = knn_validator.metric_str.capitalize()
    plot_results(
        metrics,
        num_epochs,
        title="Pretraining validation",
        y_str=metric_label,
        color="orange",
        path=plot_path,
    )

In [None]:
# Run the pretraining stage. pretrain() returns nothing, so any effect on
# `encoder` must happen in place.
# NOTE(review): presumably BYOL updates the encoder's weights in place — confirm.
pretrain(encoder)

Finetuning

In [None]:
def finetune(encoder):
    """Fine-tune a 10-class ``Classifier`` built on ``encoder`` and plot the results.

    Builds "train" and "valid" CIFAR-10 loaders via ``load_cifar10``, trains
    the classifier for 50 epochs with SGD (momentum 0.9) under a
    ``CosineAnnealingLinearWarmup`` schedule while validating with
    ``Accuracy``, and finally plots the training losses and the validation
    metric to ``plot_path``.

    Args:
        encoder: Backbone passed to the ``Classifier`` constructor.

    Returns:
        None. The losses and validation metrics are only plotted.
    """
    model = Classifier(encoder, num_classes=10)

    dataloaders = {
        "train": load_cifar10(data_root, train=True, batch_size=512, shuffle=True),
        "valid": load_cifar10(data_root, train=False, batch_size=1024),
    }

    epochs = 50
    # The scheduler is configured in optimizer steps, not epochs.
    steps_per_epoch = len(dataloaders["train"])
    total_iters = epochs * steps_per_epoch
    warmup_steps = 5 * steps_per_epoch

    validator = Accuracy(dataloaders, device=device)
    optim = torch.optim.SGD(model.parameters(), lr=1e-4, weight_decay=1e-4, momentum=0.9)
    scheduler = CosineAnnealingLinearWarmup(optim, warmup_steps, total_iters, min_lr=1e-7)

    trainer = Trainer(optim, scheduler, validator)
    # Pass the device explicitly, as pretrain() does, so training runs on the
    # same device the validator was given instead of on whatever
    # Trainer.train defaults to (that default is not visible in this file).
    losses, valid_metrics = trainer.train(
        model, dataloaders["train"], epochs, save_path, device=device
    )

    plot_results(losses, epochs, title="Finetuning", y_str="Loss", path=plot_path)
    plot_results(valid_metrics, epochs, title="Finetuning validation",
                 y_str=validator.metric_str.capitalize(), color="orange", path=plot_path)

In [None]:
# Run the fine-tuning stage on the same `encoder` object that was passed to
# pretrain() earlier in the notebook.
finetune(encoder)