In [26]:
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset
from torchvision.models import resnet18
from tqdm import tqdm
import os
import random
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
from collections import OrderedDict

In [3]:
class TwoCropsTransform:
    """Turn one input into a pair of independently transformed views.

    Calling the instance returns ``(t1(x), t2(x))``; ``t1`` is always applied
    before ``t2``.
    """

    def __init__(self, t1, t2):
        self.t1, self.t2 = t1, t2

    def __call__(self, x):
        first_view = self.t1(x)
        second_view = self.t2(x)
        return first_view, second_view

class ConsecutivePairDataset(Dataset):
    """Pair each sample's input with the following sample's input.

    Item ``i`` is ``(base[i][0], base[i + 1][0])``; the second element of each
    base item (e.g. a label) is discarded. The dataset has ``len(base) - 1``
    items, and 0 when ``base`` holds fewer than two samples.
    """

    def __init__(self, base_dataset):
        self.base = base_dataset

    def __len__(self):
        # Clamp so an empty base yields 0 instead of an invalid length of -1.
        return max(len(self.base) - 1, 0)

    def __getitem__(self, idx):
        n = len(self)
        # Normalize negative indices here: passing -1 straight to the base
        # would pair its last sample with base[0] instead of its predecessor.
        if idx < 0:
            idx += n
        if not 0 <= idx < n:
            raise IndexError(f"index out of range for ConsecutivePairDataset of length {n}")
        a, _ = self.base[idx]
        b, _ = self.base[idx + 1]
        return a, b

def off_diagonal(x: torch.Tensor) -> torch.Tensor:
    """Return the off-diagonal entries of a square matrix as a 1-D tensor.

    Entries are returned in row-major order with every ``x[i, i]`` skipped, so
    an ``(n, n)`` input yields ``n * (n - 1)`` values (none for ``n <= 1``).

    Raises:
        AssertionError: if ``x`` is not a 2-D square matrix.
    """
    # Validate before reading x.shape[0] so 0-d inputs fail with this message
    # rather than an IndexError.
    assert x.dim() == 2 and x.shape[0] == x.shape[1], "off_diagonal requires square matrix"
    n = x.shape[0]
    if n == 0:
        # view(-1, 1) below is ambiguous for zero elements and would raise.
        return x.flatten()
    # Diagonal entry i sits at flat index i * (n + 1). Dropping the final element
    # and viewing as (n - 1, n + 1) puts each remaining diagonal entry in column 0,
    # which the slice then removes.
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

In [4]:

class VICRegLoss(nn.Module):
    """VICReg objective built from invariance, variance and covariance terms.

    ``forward(z1, z2)`` takes the projections of two paired views (rows are
    samples, columns are embedding dimensions). It returns ``(loss, stats)``:
    ``loss`` is the differentiable weighted total, and ``stats`` maps
    ``"repr"``, ``"std"``, ``"cov"`` and ``"total"`` to Python floats.

    Args:
        sim_coeff: weight of the MSE between ``z1`` and ``z2``.
        std_coeff: weight of the per-dimension hinge ``relu(1 - std)``, averaged
            over dimensions and summed over both branches.
        cov_coeff: weight of the squared off-diagonal covariance entries,
            divided by the embedding size and summed over both branches.
        eps: added to the variance before the square root.

    The covariance divides by ``batch - 1``, so batches need at least two rows.
    NOTE(review): the std term uses the biased variance (``unbiased=False``)
    while the covariance uses ``batch - 1``; confirm that the mismatch is
    intentional.
    """

    def __init__(self, sim_coeff=25.0, std_coeff=25.0, cov_coeff=1.0, eps=1e-4):
        super().__init__()
        self.sim_coeff, self.std_coeff, self.cov_coeff = sim_coeff, std_coeff, cov_coeff
        self.eps = eps

    def _std_hinge(self, centered):
        """Mean over columns of ``relu(1 - std)`` for an already-centered batch."""
        per_dim_std = torch.sqrt(centered.var(dim=0, unbiased=False) + self.eps)
        return torch.mean(F.relu(1.0 - per_dim_std))

    @staticmethod
    def _cov_penalty(centered, batch_size):
        """Squared off-diagonal covariance entries, summed and divided by the column count."""
        cov = (centered.T @ centered) / (batch_size - 1)
        return off_diagonal(cov).pow(2).sum() / centered.size(1)

    def forward(self, z1, z2):
        # Invariance is measured on the raw (uncentered) projections.
        invariance = F.mse_loss(z1, z2)

        centered_1 = z1 - z1.mean(dim=0)
        centered_2 = z2 - z2.mean(dim=0)
        batch_size = z1.size(0)  # z1's batch size normalizes both covariance matrices

        variance = self._std_hinge(centered_1) + self._std_hinge(centered_2)
        covariance = (self._cov_penalty(centered_1, batch_size)
                      + self._cov_penalty(centered_2, batch_size))

        total = (self.sim_coeff * invariance
                 + self.std_coeff * variance
                 + self.cov_coeff * covariance)

        stats = dict(
            repr=invariance.item(),
            std=variance.item(),
            cov=covariance.item(),
            total=total.item(),
        )
        return total, stats

In [5]:
class Projector(nn.Module):
    """MLP projection head: (Linear -> BatchNorm1d -> ReLU) twice, then Linear.

    The layers live in ``self.net`` (an ``nn.Sequential`` indexed 0-6), so the
    state-dict keys are ``net.0.*`` … ``net.6.*``. Keep that layout so saved
    checkpoints keep loading.
    """

    def __init__(self, in_dim=512, hidden_dim=2048, out_dim=512):
        super().__init__()
        layers = []
        # Build the two hidden blocks in order so parameter init consumes RNG
        # in the same sequence as a hand-written Sequential would.
        for fan_in in (in_dim, hidden_dim):
            layers += [
                nn.Linear(fan_in, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(inplace=True),
            ]
        layers.append(nn.Linear(hidden_dim, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x`` (expected ``(batch, in_dim)``) to ``(batch, out_dim)``."""
        return self.net(x)

class Predictor(nn.Module):
    """Two-layer MLP head: Linear -> BatchNorm1d -> ReLU -> Linear.

    The layers are stored in ``self.net`` as ``net.0`` … ``net.3``, which is the
    layout used by saved checkpoints.
    """

    def __init__(self, in_dim=512, hidden_dim=512, out_dim=512):
        super().__init__()
        hidden_block = [
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
        ]
        self.net = nn.Sequential(*hidden_block, nn.Linear(hidden_dim, out_dim))

    def forward(self, x):
        """Map ``x`` (expected ``(batch, in_dim)``) to ``(batch, out_dim)``."""
        return self.net(x)

In [6]:
def get_encoder(feature_dim=512, pretrained=False):
    """Build a single-channel ResNet-18 trunk that returns pooled features.

    The 3-channel stem is replaced by a 1-channel 7x7/stride-2 conv and the
    classification head by ``nn.Identity``. The model therefore outputs the
    512-d pooled ResNet-18 features.

    Args:
        feature_dim: unused; kept for call compatibility. The output size is
            fixed by the ResNet-18 architecture.
        pretrained: load torchvision's ImageNet weights. The new 1-channel stem
            is then initialized from the pretrained RGB kernels summed over the
            colour channels instead of being left randomly initialized.

    Returns:
        nn.Module mapping ``(batch, 1, H, W)`` images to ``(batch, 512)``.
    """
    # `weights=` replaces torchvision's deprecated `pretrained=` keyword.
    m = resnet18(weights="IMAGENET1K_V1" if pretrained else None)
    rgb_stem = m.conv1
    m.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)  # 3 -> 1 input channel
    if pretrained:
        # For a gray image (R = G = B), convolving with the channel-summed kernel
        # gives exactly the pretrained RGB kernel's response.
        with torch.no_grad():
            m.conv1.weight.copy_(rgb_stem.weight.sum(dim=1, keepdim=True))
    m.fc = nn.Identity()
    return m

def get_transforms(size=28):
    """Return the augmentation pipelines for the two pre-training views.

    Both views share one ``T.Compose`` object. It applies a random resized crop
    (80-100% of the area) to ``size``, a horizontal flip, colour jitter
    (p=0.5) and random grayscale (p=0.1), then converts to a tensor normalized
    with mean 0.5 / std 0.5.

    NOTE(review): FashionMNIST images are single-channel, so the hue/saturation
    jitter and RandomGrayscale are likely no-ops here; confirm.
    """
    steps = [
        T.RandomResizedCrop(size=size, scale=(0.8, 1.0)),
        T.RandomHorizontalFlip(),
        T.RandomApply([T.ColorJitter(0.2, 0.2, 0.2, 0.1)], p=0.5),
        T.RandomGrayscale(p=0.1),
        T.ToTensor(),
        T.Normalize((0.5,), (0.5,)),
    ]
    shared_pipeline = T.Compose(steps)
    return shared_pipeline, shared_pipeline

In [7]:
class _AugmentedPairDataset(Dataset):
    """Apply ``t1`` to the first and ``t2`` to the second element of each pair.

    Defined at module level because a function-local class cannot be pickled,
    and DataLoader workers started with the 'spawn' method need to pickle it.
    """

    def __init__(self, pair_ds, t1, t2):
        self.pair = pair_ds
        self.t1 = t1
        self.t2 = t2

    def __len__(self):
        return len(self.pair)

    def __getitem__(self, idx):
        a, b = self.pair[idx]
        return self.t1(a), self.t2(b)


def _build_pretrain_dataset(args, t1, t2):
    """Return the two-view FashionMNIST training set selected by ``args``.

    With ``args.use_consecutive`` each item is ``(t1(img[i]), t2(img[i + 1]))``;
    otherwise each item is ``((t1(img), t2(img)), label)``.
    """
    if args.use_consecutive:
        base_raw = torchvision.datasets.FashionMNIST(
            root=args.data_root, train=True, download=True, transform=None)
        return _AugmentedPairDataset(ConsecutivePairDataset(base_raw), t1, t2)
    return torchvision.datasets.FashionMNIST(
        root=args.data_root, train=True, download=True,
        transform=TwoCropsTransform(t1, t2))


def train_snd_vic(args):
    """Pre-train a ResNet-18 encoder + projector on FashionMNIST with VICReg.

    A predictor head is trained alongside on detached projections, so its loss
    never reaches the encoder or projector. A checkpoint is written to
    ``args.save_dir`` every ``args.save_every`` epochs and after the last epoch.
    A non-positive ``save_every`` means "final epoch only". Each checkpoint
    holds the encoder/projector/predictor state dicts and the 0-based epoch.

    Args:
        args: namespace produced by ``parse_args``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}")

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    if device == 'cuda':
        torch.cuda.manual_seed_all(args.seed)

    t1, t2 = get_transforms(size=args.size)
    dataset = _build_pretrain_dataset(args, t1, t2)

    # NOTE(review): each consecutive item already pairs idx with idx + 1, so
    # shuffling would not break the pairing; confirm that a fixed batch order is
    # intended for --use-consecutive.
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=not args.use_consecutive,
                        num_workers=args.num_workers, drop_last=True)
    # drop_last can leave zero batches; avoid dividing by zero in the epoch summary.
    n_batches = max(len(loader), 1)

    encoder = get_encoder(pretrained=args.pretrained).to(device)
    projector = Projector(in_dim=512, hidden_dim=args.hidden_dim, out_dim=args.out_dim).to(device)
    predictor = Predictor(in_dim=args.out_dim, hidden_dim=args.pred_hidden, out_dim=args.out_dim).to(device)

    vicreg_loss = VICRegLoss(sim_coeff=args.sim_coeff, std_coeff=args.std_coeff, cov_coeff=args.cov_coeff)

    target_params = list(encoder.parameters()) + list(projector.parameters())
    optimizer_target = torch.optim.Adam(target_params, lr=args.lr, weight_decay=args.wd)
    optimizer_pred = torch.optim.Adam(predictor.parameters(), lr=args.pred_lr, weight_decay=args.pred_wd)

    # Only the encoder/projector learning rate is cosine-annealed.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_target, T_max=args.epochs)
    os.makedirs(args.save_dir, exist_ok=True)

    for epoch in range(args.epochs):
        encoder.train()
        projector.train()
        predictor.train()
        running_vic, running_distil = 0.0, 0.0
        pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{args.epochs}", leave=False)
        for batch_idx, batch in enumerate(pbar):
            if args.use_consecutive:
                x1, x2 = batch
            else:
                (x1, x2), _ = batch  # labels are ignored during pre-training
            x1, x2 = x1.to(device), x2.to(device)
            h1, h2 = encoder(x1), encoder(x2)
            z1, z2 = projector(h1), projector(h2)

            # Predictor step. Inputs and targets are detached, so this loss only
            # updates the predictor's parameters.
            # NOTE(review): each projection is regressed onto itself (p1 -> z1), so
            # the predictor learns an identity map. If a BYOL-style cross-view
            # target (p1 -> z2) was intended, swap the targets; confirm.
            z1_det, z2_det = z1.detach(), z2.detach()
            p1, p2 = predictor(z1_det), predictor(z2_det)
            distil_loss = F.mse_loss(p1, z1_det) + F.mse_loss(p2, z2_det)
            optimizer_pred.zero_grad()
            distil_loss.backward()
            optimizer_pred.step()

            # Encoder/projector step on the VICReg objective.
            loss_vic, stats_vic = vicreg_loss(z1, z2)
            optimizer_target.zero_grad()
            loss_vic.backward()
            optimizer_target.step()

            running_vic += stats_vic["total"]
            running_distil += distil_loss.item()

            pbar.set_postfix({"vic": f"{running_vic/(batch_idx+1):.4f}", "distil": f"{running_distil/(batch_idx+1):.4f}"})

        scheduler.step()
        is_last_epoch = epoch == args.epochs - 1
        is_periodic = args.save_every > 0 and (epoch + 1) % args.save_every == 0
        if is_periodic or is_last_epoch:
            ckpt = {
                'encoder': encoder.state_dict(),
                'projector': projector.state_dict(),
                'predictor': predictor.state_dict(),
                'epoch': epoch
            }
            torch.save(ckpt, os.path.join(args.save_dir, f"snd_vic_epoch{epoch+1}.pth"))
            print(f"Saved checkpoint epoch {epoch+1}")

        print(f"Epoch {epoch+1:03d} | VIC {running_vic/n_batches:.4f} | DIST {running_distil/n_batches:.4f}")

    print("Training finished on FashionMNIST.")

def parse_args(args=None):
    """Parse the SND-VIC command-line options.

    Args:
        args: list of CLI tokens; ``None`` lets argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace whose attribute names are the option names with
        dashes turned into underscores (e.g. ``batch_size``).
    """
    parser = argparse.ArgumentParser(description="SND-VIC (FashionMNIST example)")

    # (flag, type, default) for each valued option, in declaration/help order.
    valued_options = [
        ('--epochs', int, 100),
        ('--batch-size', int, 256),
        ('--lr', float, 1e-3),
        ('--wd', float, 1e-6),
        ('--hidden-dim', int, 2048),
        ('--out-dim', int, 512),
        ('--pred-hidden', int, 512),
        ('--pred-lr', float, 1e-3),
        ('--pred-wd', float, 1e-6),
        ('--sim-coeff', float, 25.0),
        ('--std-coeff', float, 25.0),
        ('--cov-coeff', float, 1.0),
        ('--save-every', int, 50),
        ('--save-dir', str, './checkpoints_fmnist'),
        ('--num-workers', int, 4),
        ('--data-root', str, './data'),
        ('--size', int, 28),
    ]
    for flag, kind, default in valued_options:
        parser.add_argument(flag, type=kind, default=default)

    for switch in ('--pretrained', '--use-consecutive'):
        parser.add_argument(switch, action='store_true')

    parser.add_argument('--seed', type=int, default=42)
    return parser.parse_args(args)

In [8]:
# Notebook run configuration. Only --batch-size and --num-workers differ from the
# parse_args defaults; the other flags are restated so the setup is visible here.
run_overrides = {
    '--epochs': '100',
    '--batch-size': '128',
    '--lr': '1e-3',
    '--wd': '1e-6',
    '--hidden-dim': '2048',
    '--out-dim': '512',
    '--pred-hidden': '512',
    '--pred-lr': '1e-3',
    '--sim-coeff': '25.0',
    '--std-coeff': '25.0',
    '--cov-coeff': '1.0',
    '--save-every': '50',
    '--num-workers': '0',
}
args = parse_args([token for flag, value in run_overrides.items() for token in (flag, value)])

train_snd_vic(args)

Device: cuda


                                                                                          

Epoch 001 | VIC 25.9375 | DIST 0.1005


                                                                                          

Epoch 002 | VIC 18.9964 | DIST 0.0673


                                                                                          

Epoch 003 | VIC 16.5539 | DIST 0.0841


                                                                                          

Epoch 004 | VIC 15.1649 | DIST 0.0880


                                                                                          

Epoch 005 | VIC 14.3200 | DIST 0.0759


                                                                                          

Epoch 006 | VIC 13.7456 | DIST 0.0616


                                                                                          

Epoch 007 | VIC 13.2666 | DIST 0.0495


                                                                                          

Epoch 008 | VIC 12.8585 | DIST 0.0403


                                                                                          

Epoch 009 | VIC 12.5742 | DIST 0.0341


                                                                                           

Epoch 010 | VIC 12.2926 | DIST 0.0294


                                                                                           

Epoch 011 | VIC 12.0719 | DIST 0.0261


                                                                                           

Epoch 012 | VIC 11.9061 | DIST 0.0237


                                                                                           

Epoch 013 | VIC 11.7341 | DIST 0.0218


                                                                                           

Epoch 014 | VIC 11.5947 | DIST 0.0208


                                                                                           

Epoch 015 | VIC 11.4274 | DIST 0.0196


                                                                                           

Epoch 016 | VIC 11.2933 | DIST 0.0188


                                                                                           

Epoch 017 | VIC 11.1912 | DIST 0.0186


                                                                                           

Epoch 018 | VIC 11.1220 | DIST 0.0182


                                                                                           

Epoch 019 | VIC 10.9650 | DIST 0.0174


                                                                                           

Epoch 020 | VIC 10.8828 | DIST 0.0173


                                                                                           

Epoch 021 | VIC 10.7918 | DIST 0.0167


                                                                                           

Epoch 022 | VIC 10.7038 | DIST 0.0166


                                                                                           

Epoch 023 | VIC 10.6728 | DIST 0.0162


                                                                                           

Epoch 024 | VIC 10.5606 | DIST 0.0162


                                                                                           

Epoch 025 | VIC 10.4849 | DIST 0.0156


                                                                                           

Epoch 026 | VIC 10.4201 | DIST 0.0157


                                                                                           

Epoch 027 | VIC 10.3597 | DIST 0.0155


                                                                                           

Epoch 028 | VIC 10.3059 | DIST 0.0153


                                                                                           

Epoch 029 | VIC 10.2056 | DIST 0.0150


                                                                                           

Epoch 030 | VIC 10.1614 | DIST 0.0147


                                                                                           

Epoch 031 | VIC 10.1439 | DIST 0.0149


                                                                                           

Epoch 032 | VIC 10.0832 | DIST 0.0147


                                                                                           

Epoch 033 | VIC 10.0468 | DIST 0.0144


                                                                                           

Epoch 034 | VIC 10.0014 | DIST 0.0144


                                                                                          

Epoch 035 | VIC 9.9567 | DIST 0.0142


                                                                                          

Epoch 036 | VIC 9.8889 | DIST 0.0139


                                                                                          

Epoch 037 | VIC 9.8371 | DIST 0.0140


                                                                                          

Epoch 038 | VIC 9.8150 | DIST 0.0139


                                                                                          

Epoch 039 | VIC 9.7516 | DIST 0.0137


                                                                                          

Epoch 040 | VIC 9.7474 | DIST 0.0137


                                                                                          

Epoch 041 | VIC 9.7002 | DIST 0.0137


                                                                                          

Epoch 042 | VIC 9.6255 | DIST 0.0134


                                                                                          

Epoch 043 | VIC 9.6153 | DIST 0.0134


                                                                                          

Epoch 044 | VIC 9.5706 | DIST 0.0135


                                                                                          

Epoch 045 | VIC 9.5583 | DIST 0.0130


                                                                                          

Epoch 046 | VIC 9.4953 | DIST 0.0131


                                                                                          

Epoch 047 | VIC 9.4985 | DIST 0.0131


                                                                                          

Epoch 048 | VIC 9.4525 | DIST 0.0132


                                                                                          

Epoch 049 | VIC 9.4225 | DIST 0.0130


                                                                                          

Saved checkpoint epoch 50
Epoch 050 | VIC 9.3845 | DIST 0.0126


                                                                                          

Epoch 051 | VIC 9.3742 | DIST 0.0125


                                                                                          

Epoch 052 | VIC 9.3412 | DIST 0.0128


                                                                                          

Epoch 053 | VIC 9.3041 | DIST 0.0124


                                                                                          

Epoch 054 | VIC 9.2456 | DIST 0.0125


                                                                                          

Epoch 055 | VIC 9.2556 | DIST 0.0123


                                                                                          

Epoch 056 | VIC 9.2396 | DIST 0.0123


                                                                                          

Epoch 057 | VIC 9.1926 | DIST 0.0123


                                                                                          

Epoch 058 | VIC 9.1806 | DIST 0.0122


                                                                                          

Epoch 059 | VIC 9.1648 | DIST 0.0120


                                                                                          

Epoch 060 | VIC 9.1126 | DIST 0.0120


                                                                                          

Epoch 061 | VIC 9.0889 | DIST 0.0119


                                                                                          

Epoch 062 | VIC 9.0475 | DIST 0.0118


                                                                                          

Epoch 063 | VIC 9.0500 | DIST 0.0119


                                                                                          

Epoch 064 | VIC 9.0503 | DIST 0.0118


                                                                                          

Epoch 065 | VIC 8.9901 | DIST 0.0117


                                                                                          

Epoch 066 | VIC 8.9729 | DIST 0.0114


                                                                                          

Epoch 067 | VIC 8.9876 | DIST 0.0116


                                                                                          

Epoch 068 | VIC 8.9547 | DIST 0.0115


                                                                                          

Epoch 069 | VIC 8.9028 | DIST 0.0115


                                                                                          

Epoch 070 | VIC 8.8971 | DIST 0.0115


                                                                                          

Epoch 071 | VIC 8.9073 | DIST 0.0113


                                                                                          

Epoch 072 | VIC 8.8643 | DIST 0.0111


                                                                                          

Epoch 073 | VIC 8.8704 | DIST 0.0113


                                                                                          

Epoch 074 | VIC 8.8647 | DIST 0.0113


                                                                                          

Epoch 075 | VIC 8.8359 | DIST 0.0111


                                                                                          

Epoch 076 | VIC 8.8112 | DIST 0.0110


                                                                                          

Epoch 077 | VIC 8.7843 | DIST 0.0112


                                                                                          

Epoch 078 | VIC 8.7947 | DIST 0.0110


                                                                                          

Epoch 079 | VIC 8.7989 | DIST 0.0111


                                                                                          

Epoch 080 | VIC 8.7866 | DIST 0.0109


                                                                                          

Epoch 081 | VIC 8.7539 | DIST 0.0109


                                                                                          

Epoch 082 | VIC 8.7559 | DIST 0.0109


                                                                                          

Epoch 083 | VIC 8.7447 | DIST 0.0108


                                                                                          

Epoch 084 | VIC 8.7330 | DIST 0.0108


                                                                                          

Epoch 085 | VIC 8.7245 | DIST 0.0106


                                                                                          

Epoch 086 | VIC 8.7130 | DIST 0.0106


                                                                                          

Epoch 087 | VIC 8.7040 | DIST 0.0106


                                                                                          

Epoch 088 | VIC 8.7089 | DIST 0.0106


                                                                                          

Epoch 089 | VIC 8.7171 | DIST 0.0106


                                                                                          

Epoch 090 | VIC 8.7060 | DIST 0.0105


                                                                                          

Epoch 091 | VIC 8.6980 | DIST 0.0105


                                                                                          

Epoch 092 | VIC 8.6802 | DIST 0.0105


                                                                                          

Epoch 093 | VIC 8.6895 | DIST 0.0104


                                                                                          

Epoch 094 | VIC 8.6600 | DIST 0.0103


                                                                                          

Epoch 095 | VIC 8.6771 | DIST 0.0104


                                                                                          

Epoch 096 | VIC 8.6634 | DIST 0.0103


                                                                                          

Epoch 097 | VIC 8.6661 | DIST 0.0102


                                                                                          

Epoch 098 | VIC 8.6602 | DIST 0.0102


                                                                                          

Epoch 099 | VIC 8.6652 | DIST 0.0103


                                                                                           

Saved checkpoint epoch 100
Epoch 100 | VIC 8.6568 | DIST 0.0103
Training finished on FashionMNIST.




In [28]:
def get_encoder(feature_dim=512, pretrained=False):
    """Build a single-channel ResNet-18 trunk that returns pooled features.

    This evaluation cell redefines the training-time builder. It must produce
    the same architecture so that the checkpoint's encoder state dict loads.

    The 3-channel stem is replaced by a 1-channel 7x7/stride-2 conv and the
    classification head by ``nn.Identity``. The model therefore outputs the
    512-d pooled ResNet-18 features.

    Args:
        feature_dim: unused; kept for call compatibility. The output size is
            fixed by the ResNet-18 architecture.
        pretrained: load torchvision's ImageNet weights. The new 1-channel stem
            is then initialized from the pretrained RGB kernels summed over the
            colour channels instead of being left randomly initialized.

    Returns:
        nn.Module mapping ``(batch, 1, H, W)`` images to ``(batch, 512)``.
    """
    m = resnet18(weights="IMAGENET1K_V1" if pretrained else None)
    rgb_stem = m.conv1
    m.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)  # 3 -> 1 input channel
    if pretrained:
        # For a gray image (R = G = B), convolving with the channel-summed kernel
        # gives exactly the pretrained RGB kernel's response.
        with torch.no_grad():
            m.conv1.weight.copy_(rgb_stem.weight.sum(dim=1, keepdim=True))
    m.fc = nn.Identity()
    return m

In [29]:
# Evaluation setup: restore the pre-trained encoder and build un-augmented loaders.
# Reuse `args` from the training cell so the path matches the final checkpoint
# written by train_snd_vic.
ckpt_path = os.path.join(args.save_dir, f"snd_vic_epoch{args.epochs}.pth")
device = "cuda" if torch.cuda.is_available() else "cpu"
k = 5  # number of neighbours for the k-NN probe

encoder = get_encoder(pretrained=False).to(device)
# The checkpoint holds only state dicts and an int epoch, so the restricted
# weights_only unpickler can load it without executing arbitrary pickled code.
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
encoder.load_state_dict(ckpt['encoder'])
encoder.eval()
print(f"Načítaný checkpoint z epochy {ckpt['epoch']+1}")  # "Loaded checkpoint from epoch N"

# Same normalization as the training pipeline, without augmentation.
transform = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5,), (0.5,))
])

trainset = torchvision.datasets.FashionMNIST(root=args.data_root, train=True, download=True, transform=transform)
testset = torchvision.datasets.FashionMNIST(root=args.data_root, train=False, download=True, transform=transform)

train_loader = DataLoader(trainset, batch_size=256, shuffle=False, num_workers=2)
test_loader = DataLoader(testset, batch_size=256, shuffle=False, num_workers=2)

  ckpt = torch.load(ckpt_path, map_location=device)


Načítaný checkpoint z epochy 100


In [30]:
def extract_features(loader, model, device=None):
    """Run ``model`` over ``loader`` without gradients and collect its outputs.

    Args:
        loader: iterable of ``(x, y)`` batches.
        model: module mapping ``x`` to features with the batch on dim 0. The
            caller is responsible for putting it in eval mode.
        device: where inputs are sent. Defaults to the device of the model's
            first parameter, or CPU for a parameter-less model.

    Returns:
        ``(features, labels)`` NumPy arrays. Features are flattened to
        ``(N, D)``; labels are concatenated along the batch dimension.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    features, labels = [], []
    with torch.no_grad():
        for x, y in tqdm(loader, desc="Extracting features"):
            h = model(x.to(device))
            # reshape (unlike view) also handles non-contiguous model outputs.
            features.append(h.reshape(h.size(0), -1).cpu())
            labels.append(y)
    return torch.cat(features).numpy(), torch.cat(labels).numpy()

# Frozen-feature extraction for both splits (train first, then test).
(train_feats, train_labels), (test_feats, test_labels) = (
    extract_features(split_loader, encoder) for split_loader in (train_loader, test_loader)
)

print(f"Train feats: {train_feats.shape}, Test feats: {test_feats.shape}")

# k-NN probe: fit on training features, score on the test split.
knn = KNeighborsClassifier(n_neighbors=k, n_jobs=-1).fit(train_feats, train_labels)
preds = knn.predict(test_feats)
acc = accuracy_score(test_labels, preds)
print(f"\nk-NN presnosť (k={k}): {acc*100:.2f}%")  # "k-NN accuracy"

Extracting features: 100%|██████████| 235/235 [00:20<00:00, 11.25it/s]
Extracting features: 100%|██████████| 40/40 [00:09<00:00,  4.20it/s]


Train feats: (60000, 512), Test feats: (10000, 512)

k-NN presnosť (k=5): 81.45%
