In [5]:
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset
from torchvision.models import resnet18
from tqdm import tqdm
import os
import random
import numpy as np

In [4]:
class TwoCropsTransform:
    def __init__(self, t1, t2):
        self.t1 = t1
        self.t2 = t2
    def __call__(self, x):
        return self.t1(x), self.t2(x)

class ConsecutivePairDataset(Dataset):

    def __init__(self, base_dataset):
        self.base = base_dataset
    def __len__(self):
        return len(self.base) - 1
    def __getitem__(self, idx):
        a, _ = self.base[idx]
        b, _ = self.base[idx + 1]
        return a, b

def off_diagonal(x: torch.Tensor) -> torch.Tensor:

    n = x.shape[0]
    assert x.dim() == 2 and x.shape[0] == x.shape[1], "off_diagonal requires square matrix"
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

class VICRegLoss(nn.Module):
    def __init__(self, sim_coeff=25.0, std_coeff=25.0, cov_coeff=1.0, eps=1e-4):
        super().__init__()
        self.sim_coeff = sim_coeff
        self.std_coeff = std_coeff
        self.cov_coeff = cov_coeff
        self.eps = eps

    def forward(self, z1, z2):

        repr_loss = F.mse_loss(z1, z2)

        z1 = z1 - z1.mean(dim=0)
        z2 = z2 - z2.mean(dim=0)

        std_z1 = torch.sqrt(z1.var(dim=0, unbiased=False) + self.eps)
        std_z2 = torch.sqrt(z2.var(dim=0, unbiased=False) + self.eps)
        std_loss = torch.mean(F.relu(1.0 - std_z1)) + torch.mean(F.relu(1.0 - std_z2))

        B = z1.size(0)
        cov_z1 = (z1.T @ z1) / (B - 1)
        cov_z2 = (z2.T @ z2) / (B - 1)
        cov_loss = (off_diagonal(cov_z1).pow(2).sum() / z1.size(1)) + \
                   (off_diagonal(cov_z2).pow(2).sum() / z2.size(1))

        loss = self.sim_coeff * repr_loss + self.std_coeff * std_loss + self.cov_coeff * cov_loss

        stats = {
            "repr": repr_loss.item() if isinstance(repr_loss, torch.Tensor) else float(repr_loss),
            "std": std_loss.item() if isinstance(std_loss, torch.Tensor) else float(std_loss),
            "cov": cov_loss.item() if isinstance(cov_loss, torch.Tensor) else float(cov_loss),
            "total": loss.item() if isinstance(loss, torch.Tensor) else float(loss)
        }
        return loss, stats

class Projector(nn.Module):
    def __init__(self, in_dim=512, hidden_dim=2048, out_dim=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim)
        )
    def forward(self, x):
        return self.net(x)

class Predictor(nn.Module):
    def __init__(self, in_dim=512, hidden_dim=512, out_dim=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim)
        )
    def forward(self, x):
        return self.net(x)

In [5]:
def get_encoder(feature_dim=512, pretrained=False):
    m = resnet18(pretrained=pretrained)
    m.fc = nn.Identity() 
    return m


def get_transforms(size=32):
    aug = T.Compose([
        T.RandomResizedCrop(size=size, scale=(0.2, 1.0)),
        T.RandomHorizontalFlip(),
        T.RandomApply([T.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
        T.RandomGrayscale(p=0.2),
        T.GaussianBlur(kernel_size=3),
        T.ToTensor(),
        T.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
    ])

    return aug, aug

In [6]:
def train_snd_vic(args):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}")

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    if device == 'cuda':
        torch.cuda.manual_seed_all(args.seed)

    t1, t2 = get_transforms(size=args.size)
    base_trainset = torchvision.datasets.CIFAR10(root=args.data_root, train=True, download=True, transform=T.ToTensor())

    if args.use_consecutive:

        base_raw = torchvision.datasets.CIFAR10(root=args.data_root, train=True, download=True, transform=None)
        pair_dataset = ConsecutivePairDataset(base_raw)

        class AugmentedPairDataset(Dataset):
            def __init__(self, pair_ds, t1, t2):
                self.pair = pair_ds
                self.t1 = t1
                self.t2 = t2
            def __len__(self):
                return len(self.pair)
            def __getitem__(self, idx):
                a, b = self.pair[idx]

                return (self.t1(a), self.t2(b))
        dataset = AugmentedPairDataset(pair_dataset, t1, t2)
    else:

        aug_dataset = torchvision.datasets.CIFAR10(root=args.data_root, train=True, download=True,
                                                   transform=TwoCropsTransform(t1, t2))
        dataset = aug_dataset

    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=not args.use_consecutive,
                        num_workers=0, drop_last=True)

    encoder = get_encoder(pretrained=args.pretrained).to(device)
    projector = Projector(in_dim=512, hidden_dim=args.hidden_dim, out_dim=args.out_dim).to(device)
    predictor = Predictor(in_dim=args.out_dim, hidden_dim=args.pred_hidden, out_dim=args.out_dim).to(device)


    vicreg_loss = VICRegLoss(sim_coeff=args.sim_coeff, std_coeff=args.std_coeff, cov_coeff=args.cov_coeff)


    target_params = list(encoder.parameters()) + list(projector.parameters())
    optimizer_target = torch.optim.Adam(target_params, lr=args.lr, weight_decay=args.wd)
    optimizer_pred = torch.optim.Adam(predictor.parameters(), lr=args.pred_lr, weight_decay=args.pred_wd)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_target, T_max=args.epochs)

    os.makedirs(args.save_dir, exist_ok=True)

    for epoch in range(args.epochs):
        encoder.train(); projector.train(); predictor.train()
        running_vic = 0.0
        running_distil = 0.0
        pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{args.epochs}", leave=False)
        for batch_idx, batch in enumerate(pbar):

            if args.use_consecutive:
                x1, x2 = batch
            else:
                (x1, x2), _ = batch

            x1 = x1.to(device)
            x2 = x2.to(device)

            h1 = encoder(x1)
            h2 = encoder(x2)
            z1 = projector(h1)
            z2 = projector(h2)

            with torch.no_grad():

                z1_det = z1.detach()
                z2_det = z2.detach()

            p1 = predictor(z1_det)
            p2 = predictor(z2_det)

            distil_loss = F.mse_loss(p1, z1_det) + F.mse_loss(p2, z2_det)
            optimizer_pred.zero_grad()
            distil_loss.backward()
            optimizer_pred.step()

            loss_vic, stats_vic = vicreg_loss(z1, z2)
            optimizer_target.zero_grad()
            loss_vic.backward()
            optimizer_target.step()

            running_vic += stats_vic["total"]
            running_distil += distil_loss.item()

            pbar.set_postfix({
                "vic": f"{running_vic/(batch_idx+1):.4f}",
                "distil": f"{running_distil/(batch_idx+1):.4f}"
            })

        scheduler.step()

        if (epoch + 1) % args.save_every == 0 or epoch == args.epochs - 1:
            ckpt = {
                'encoder': encoder.state_dict(),
                'projector': projector.state_dict(),
                'predictor': predictor.state_dict(),
                'epoch': epoch
            }
            torch.save(ckpt, os.path.join(args.save_dir, f"snd_vic_epoch{epoch+1}.pth"))
            print(f"Saved checkpoint epoch {epoch+1}")

        print(f"Epoch {epoch+1:03d} | VIC {running_vic/len(loader):.4f} | DIST {running_distil/len(loader):.4f}")

    print("Training finished.")

def parse_args(args=None):
    parser = argparse.ArgumentParser(description="SND-VIC (CIFAR10 example)")
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--wd', type=float, default=1e-6)
    parser.add_argument('--hidden-dim', type=int, default=2048)
    parser.add_argument('--out-dim', type=int, default=512)
    parser.add_argument('--pred-hidden', type=int, default=512)
    parser.add_argument('--pred-lr', type=float, default=1e-3)
    parser.add_argument('--pred-wd', type=float, default=1e-6)
    parser.add_argument('--sim-coeff', type=float, default=25.0)
    parser.add_argument('--std-coeff', type=float, default=25.0)
    parser.add_argument('--cov-coeff', type=float, default=1.0)
    parser.add_argument('--save-every', type=int, default=50)
    parser.add_argument('--save-dir', type=str, default='./checkpoints')
    parser.add_argument('--batch-norm-sync', action='store_true', help='(unused) placeholder')
    parser.add_argument('--num-workers', type=int, default=4)
    parser.add_argument('--data-root', type=str, default='./data')
    parser.add_argument('--size', type=int, default=32)
    parser.add_argument('--pretrained', action='store_true')
    parser.add_argument('--use-consecutive', action='store_true',
                        help='Use consecutive dataset pairs (s_t, s_{t+1}) if available; otherwise use two augmentations')
    parser.add_argument('--seed', type=int, default=42)
    return parser.parse_args(args)

In [7]:
args = parse_args([
    '--epochs', '100',
    '--batch-size', '128',
    '--lr', '1e-3',
    '--wd', '1e-6',
    '--hidden-dim', '2048',
    '--out-dim', '512',
    '--pred-hidden', '512',
    '--pred-lr', '1e-3',
    '--sim-coeff', '25.0',
    '--std-coeff', '25.0',
    '--cov-coeff', '1.0',
    '--save-every', '50'
])

train_snd_vic(args)


Device: cuda
Files already downloaded and verified
Files already downloaded and verified


                                                                                          

Epoch 001 | VIC 35.0601 | DIST 0.1425


                                                                                          

Epoch 002 | VIC 32.7457 | DIST 0.0634


                                                                                          

Epoch 003 | VIC 31.8982 | DIST 0.0540


                                                                                          

Epoch 004 | VIC 31.4026 | DIST 0.0455


                                                                                          

Epoch 005 | VIC 30.9576 | DIST 0.0382


                                                                                          

Epoch 006 | VIC 30.5554 | DIST 0.0311


                                                                                          

Epoch 007 | VIC 30.2546 | DIST 0.0255


                                                                                          

Epoch 008 | VIC 29.8735 | DIST 0.0204


                                                                                          

Epoch 009 | VIC 29.5996 | DIST 0.0164


                                                                                           

Epoch 010 | VIC 29.3588 | DIST 0.0138


                                                                                           

Epoch 011 | VIC 29.0790 | DIST 0.0120


                                                                                           

Epoch 012 | VIC 28.8282 | DIST 0.0105


                                                                                           

Epoch 013 | VIC 28.7243 | DIST 0.0093


                                                                                           

Epoch 014 | VIC 28.4232 | DIST 0.0086


                                                                                           

Epoch 015 | VIC 28.2403 | DIST 0.0079


                                                                                           

Epoch 016 | VIC 28.0461 | DIST 0.0075


                                                                                           

Epoch 017 | VIC 27.8727 | DIST 0.0070


                                                                                           

Epoch 018 | VIC 27.7345 | DIST 0.0065


                                                                                           

Epoch 019 | VIC 27.6046 | DIST 0.0063


                                                                                           

Epoch 020 | VIC 27.4261 | DIST 0.0061


                                                                                           

Epoch 021 | VIC 27.2905 | DIST 0.0058


                                                                                           

Epoch 022 | VIC 27.1904 | DIST 0.0057


                                                                                           

Epoch 023 | VIC 26.9754 | DIST 0.0056


                                                                                           

Epoch 024 | VIC 26.8993 | DIST 0.0056


                                                                                           

Epoch 025 | VIC 26.7607 | DIST 0.0054


                                                                                           

Epoch 026 | VIC 26.6955 | DIST 0.0053


                                                                                           

Epoch 027 | VIC 26.5978 | DIST 0.0052


                                                                                           

Epoch 028 | VIC 26.4050 | DIST 0.0053


                                                                                           

Epoch 029 | VIC 26.3074 | DIST 0.0052


                                                                                           

Epoch 030 | VIC 26.1553 | DIST 0.0052


                                                                                           

Epoch 031 | VIC 26.1058 | DIST 0.0052


                                                                                           

Epoch 032 | VIC 25.9867 | DIST 0.0051


                                                                                           

Epoch 033 | VIC 25.9227 | DIST 0.0051


                                                                                           

Epoch 034 | VIC 25.8498 | DIST 0.0051


                                                                                           

Epoch 035 | VIC 25.7634 | DIST 0.0051


                                                                                           

Epoch 036 | VIC 25.6082 | DIST 0.0051


                                                                                           

Epoch 037 | VIC 25.5633 | DIST 0.0051


                                                                                           

Epoch 038 | VIC 25.4124 | DIST 0.0052


                                                                                           

Epoch 039 | VIC 25.4168 | DIST 0.0051


                                                                                           

Epoch 040 | VIC 25.2818 | DIST 0.0051


                                                                                           

Epoch 041 | VIC 25.2505 | DIST 0.0051


                                                                                           

Epoch 042 | VIC 25.1225 | DIST 0.0051


                                                                                           

Epoch 043 | VIC 25.1312 | DIST 0.0051


                                                                                           

Epoch 044 | VIC 25.0270 | DIST 0.0052


                                                                                           

Epoch 045 | VIC 24.9056 | DIST 0.0052


                                                                                           

Epoch 046 | VIC 24.8283 | DIST 0.0052


                                                                                           

Epoch 047 | VIC 24.8317 | DIST 0.0052


                                                                                           

Epoch 048 | VIC 24.7115 | DIST 0.0052


                                                                                           

Epoch 049 | VIC 24.6363 | DIST 0.0053


                                                                                           

Saved checkpoint epoch 50
Epoch 050 | VIC 24.5474 | DIST 0.0052


                                                                                           

Epoch 051 | VIC 24.5912 | DIST 0.0053


                                                                                           

Epoch 052 | VIC 24.4691 | DIST 0.0052


                                                                                           

Epoch 053 | VIC 24.4024 | DIST 0.0053


                                                                                           

Epoch 054 | VIC 24.3063 | DIST 0.0053


                                                                                           

Epoch 055 | VIC 24.2959 | DIST 0.0053


                                                                                           

Epoch 056 | VIC 24.2312 | DIST 0.0053


                                                                                           

Epoch 057 | VIC 24.1276 | DIST 0.0054


                                                                                           

Epoch 058 | VIC 24.0846 | DIST 0.0053


                                                                                           

Epoch 059 | VIC 24.0363 | DIST 0.0055


                                                                                           

Epoch 060 | VIC 23.9919 | DIST 0.0054


                                                                                           

Epoch 061 | VIC 23.9472 | DIST 0.0055


                                                                                           

Epoch 062 | VIC 23.8840 | DIST 0.0055


                                                                                           

Epoch 063 | VIC 23.8370 | DIST 0.0056


                                                                                           

Epoch 064 | VIC 23.8194 | DIST 0.0054


                                                                                           

Epoch 065 | VIC 23.7389 | DIST 0.0055


                                                                                           

Epoch 066 | VIC 23.6632 | DIST 0.0055


                                                                                           

Epoch 067 | VIC 23.6680 | DIST 0.0055


                                                                                           

Epoch 068 | VIC 23.6194 | DIST 0.0055


                                                                                           

Epoch 069 | VIC 23.5449 | DIST 0.0056


                                                                                           

Epoch 070 | VIC 23.5481 | DIST 0.0055


                                                                                           

Epoch 071 | VIC 23.4380 | DIST 0.0055


                                                                                           

Epoch 072 | VIC 23.4668 | DIST 0.0055


                                                                                           

Epoch 073 | VIC 23.4312 | DIST 0.0055


                                                                                           

Epoch 074 | VIC 23.3347 | DIST 0.0055


                                                                                           

Epoch 075 | VIC 23.2773 | DIST 0.0056


                                                                                           

Epoch 076 | VIC 23.3397 | DIST 0.0055


                                                                                           

Epoch 077 | VIC 23.2797 | DIST 0.0056


                                                                                           

Epoch 078 | VIC 23.2521 | DIST 0.0055


                                                                                           

Epoch 079 | VIC 23.1676 | DIST 0.0055


                                                                                           

Epoch 080 | VIC 23.2091 | DIST 0.0055


                                                                                           

Epoch 081 | VIC 23.1318 | DIST 0.0055


                                                                                           

Epoch 082 | VIC 23.0979 | DIST 0.0055


                                                                                           

Epoch 083 | VIC 23.1163 | DIST 0.0056


                                                                                           

Epoch 084 | VIC 23.0041 | DIST 0.0055


                                                                                           

Epoch 085 | VIC 22.9893 | DIST 0.0055


                                                                                           

Epoch 086 | VIC 23.0323 | DIST 0.0056


                                                                                           

Epoch 087 | VIC 22.9818 | DIST 0.0055


                                                                                           

Epoch 088 | VIC 23.0082 | DIST 0.0054


                                                                                           

Epoch 089 | VIC 22.9596 | DIST 0.0055


                                                                                           

Epoch 090 | VIC 22.9880 | DIST 0.0055


                                                                                           

Epoch 091 | VIC 22.9347 | DIST 0.0055


                                                                                           

Epoch 092 | VIC 22.9845 | DIST 0.0055


                                                                                           

Epoch 093 | VIC 22.9507 | DIST 0.0054


                                                                                           

Epoch 094 | VIC 22.9703 | DIST 0.0054


                                                                                           

Epoch 095 | VIC 22.9837 | DIST 0.0054


                                                                                           

Epoch 096 | VIC 22.9531 | DIST 0.0054


                                                                                           

Epoch 097 | VIC 22.9243 | DIST 0.0053


                                                                                           

Epoch 098 | VIC 22.8852 | DIST 0.0054


                                                                                           

Epoch 099 | VIC 22.8738 | DIST 0.0053


                                                                                            

Saved checkpoint epoch 100
Epoch 100 | VIC 22.9141 | DIST 0.0053
Training finished.


In [3]:
import torch
from torchvision.models import resnet18
import torch.nn as nn

def get_encoder(feature_dim=512, pretrained=False):
    m = resnet18(pretrained=pretrained)
    m.fc = nn.Identity()
    return m

ckpt = torch.load("./checkpoints/snd_vic_epoch100.pth", map_location="cpu")

encoder = get_encoder()
encoder.load_state_dict(ckpt['encoder'])
encoder.eval()
print("Encoder načítaný a pripravený na testovanie.")


  ckpt = torch.load("./checkpoints/snd_vic_epoch100.pth", map_location="cpu")


Encoder načítaný a pripravený na testovanie.


In [7]:
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim

transform = T.Compose([
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=4)

class LinearClassifier(nn.Module):
    def __init__(self, encoder, num_classes=10):
        super().__init__()
        self.encoder = encoder
        for p in self.encoder.parameters():
            p.requires_grad = False
        self.fc = nn.Linear(512, num_classes)
    def forward(self, x):
        with torch.no_grad():
            feats = self.encoder(x)
        return self.fc(feats)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = LinearClassifier(encoder).to(device)
optimizer = optim.Adam(model.fc.parameters(), lr=1e-3)


Files already downloaded and verified
Files already downloaded and verified


In [3]:
for epoch in range(10):
    model.train()
    total_loss = 0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        logits = model(imgs)
        loss = F.cross_entropy(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1} | Loss {total_loss/len(train_loader):.4f}")


Epoch 1 | Loss 0.9186
Epoch 2 | Loss 0.8219
Epoch 3 | Loss 0.8048
Epoch 4 | Loss 0.7936
Epoch 5 | Loss 0.7873
Epoch 6 | Loss 0.7807
Epoch 7 | Loss 0.7744
Epoch 8 | Loss 0.7724
Epoch 9 | Loss 0.7696
Epoch 10 | Loss 0.7667


In [4]:
model.eval()
correct = total = 0
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        preds = model(imgs).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print(f"Test accuracy: {100 * correct / total:.2f}%")

Test accuracy: 72.05%


In [8]:
train_features, train_labels = [], []
with torch.no_grad():
    for imgs, labels in train_loader:
        feats = encoder(imgs.to(device))
        train_features.append(feats.cpu())
        train_labels.append(labels)
train_features = torch.cat(train_features)
train_labels = torch.cat(train_labels)


k = 10
correct = total = 0
with torch.no_grad():
    for imgs, labels in test_loader:
        feats = encoder(imgs.to(device)).cpu()
        sims = feats @ train_features.T
        topk = sims.topk(k, dim=1).indices
        preds = torch.mode(train_labels[topk], dim=1).values
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print(f"k-NN accuracy (k={k}): {100 * correct / total:.2f}%")

k-NN accuracy (k=10): 65.32%
