Preliminary

In [1]:
import sys
import torch

sys.path.append("..")

from sslearn.archs import ResNet
from sslearn.models.pretraining import SwAV
from sslearn.models.finetuning import Classifier
from sslearn.training.validators import TopKNN, Accuracy
from sslearn.training.schedulers import CosineAnnealingLinearWarmup
from sslearn.training import Trainer
from utils import load_cifar10, plot_results

# Use the GPU when PyTorch reports one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Relative locations used by the cells below.
data_root = "static/datasets"  # handed to load_cifar10
save_path = "weights"          # handed to Trainer.train
plot_path = "static/plots"     # handed to plot_results

  warn(f"Failed to load image Python extension: {e}")


Encoder definition

In [2]:
# Single encoder instance handed to both pretrain() and finetune() below.
# NOTE(review): `cifar10=True` presumably selects a CIFAR-10-specific variant
# of the backbone -- confirm against sslearn.archs.ResNet.
encoder = ResNet(
    channels_in=3,
    model_name="resnet-18",
    cifar10=True,
)

Pretraining

In [4]:
def pretrain(encoder):
    """Pretrain ``encoder`` with SwAV on CIFAR-10 and plot the training curves.

    The encoder is wrapped in a ``SwAV`` model and trained for 50 epochs with
    Adam under a ``CosineAnnealingLinearWarmup`` schedule. A ``TopKNN``
    validator is built from the train/valid/index loaders and given to the
    ``Trainer``. The losses and validation metrics returned by
    ``Trainer.train`` are then drawn with ``plot_results``.

    Args:
        encoder: Backbone passed to ``SwAV``.

    Returns:
        None.
    """
    swav = SwAV(
        encoder,
        hidden_dim=2048,
        head_dim=128,
        temperature=0.1,
        num_prototypes=3000,
        freeze_iters=300,
        global_crop_info=[(2, 32)],
        local_crop_info=[],
    )

    loaders = {
        "train": load_cifar10(data_root, train=True, batch_size=512, shuffle=True),
        "valid": load_cifar10(data_root, train=False, batch_size=1024),
        "index": load_cifar10(data_root, train=True, batch_size=1024),
    }

    # Schedule lengths are expressed as multiples of the train loader length.
    epochs = 50
    steps_per_epoch = len(loaders["train"])
    warmup_steps = 5 * steps_per_epoch
    total_iters = epochs * steps_per_epoch

    knn_validator = TopKNN(loaders, device=device)
    optimizer = torch.optim.Adam(swav.parameters(), lr=1e-3, weight_decay=1e-6)
    lr_schedule = CosineAnnealingLinearWarmup(
        optimizer, warmup_steps, total_iters, min_lr=1e-6
    )

    trainer = Trainer(optimizer, lr_schedule, knn_validator)
    losses, valid_metrics = trainer.train(
        swav, loaders["train"], epochs, save_path, device=device
    )

    plot_results(losses, epochs, title="Pretraining", y_str="Loss", path=plot_path)

    # The y-axis label comes from the validator's own metric name.
    metric_label = knn_validator.metric_str.capitalize()
    plot_results(
        valid_metrics,
        epochs,
        title="Pretraining validation",
        y_str=metric_label,
        color="orange",
        path=plot_path,
    )

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Run SwAV pretraining on the shared encoder. pretrain() returns nothing.
# NOTE(review): the finetuning cell reuses this same `encoder` object, so it
# presumably relies on the weights being updated in place here -- confirm.
pretrain(encoder)

Finetuning

In [None]:
def finetune(encoder):
    """Finetune ``encoder`` as a 10-class CIFAR-10 classifier and plot the curves.

    The encoder is wrapped in a ``Classifier`` and trained for 50 epochs with
    Adam under ``torch.optim.lr_scheduler.CosineAnnealingLR``, whose ``T_max``
    is ``epochs * len(train loader)``. An ``Accuracy`` validator is built from
    the loaders and given to the ``Trainer``. The losses and validation
    metrics returned by ``Trainer.train`` are then drawn with ``plot_results``.

    Args:
        encoder: Backbone passed to ``Classifier``.

    Returns:
        None.
    """
    model = Classifier(encoder, hidden_dim=2048, num_classes=10)

    dataloaders = {
        # NOTE(review): only the train loader receives `device`; the valid
        # loader here and all loaders in pretrain() do not -- confirm intended.
        "train" : load_cifar10(data_root, train=True, batch_size=1024, shuffle=True, device=device),
        "valid" : load_cifar10(data_root, train=False, batch_size=1024),
    }

    epochs = 50
    total_iters = epochs * len(dataloaders["train"])

    validator = Accuracy(dataloaders, device=device)
    optim = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, total_iters)

    trainer = Trainer(optim, scheduler, validator)
    # Pass the device explicitly, as pretrain() does, so training runs on the
    # same device the validator was configured with rather than on whatever
    # default Trainer.train falls back to.
    losses, valid_metrics = trainer.train(model, dataloaders["train"], epochs, save_path, device=device)

    plot_results(losses, epochs, title="Finetuning", y_str="Loss", path=plot_path)
    plot_results(valid_metrics, epochs, title="Finetuning validation",
                 y_str=validator.metric_str.capitalize(), color="orange", path=plot_path)

In [None]:
# Finetune a classifier on top of the same `encoder` object that was passed to
# pretrain() above. finetune() returns nothing.
finetune(encoder)