In [5]:
# from skimage import io
from torchvision.datasets import LFWPairs

lfw_dataset = LFWPairs(
    root='/scratch/sj4020/lfw/',
    transform=None,
    target_transform=None,
    download=True,
    split="10fold"
)


Files already downloaded and verified


In [6]:
print(f"Number of images: {len(lfw_dataset)}")

Number of images: 6000


In [7]:
image1, image2,target = lfw_dataset[0]
print(f"Image size: {image1.size}")
print(f"Target name: {target}")

Image size: (250, 250)
Target name: 1


In [8]:
import numpy as np
targets = np.asarray(lfw_dataset.targets)
num_target_1 = (targets == 1).sum()
print("Number of target 1 in LFW Pairs dataset: ", num_target_1)

Number of target 1 in LFW Pairs dataset:  3000


In [9]:
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

# Define the transforms
transform = transforms.Compose([    transforms.ToTensor()])

# Load the dataset
lfw_pairs_dataset = ImageFolder('/scratch/sj4020/lfw/', transform=transform)

# Calculate the mean and standard deviation
loader = torch.utils.data.DataLoader(lfw_pairs_dataset, batch_size=1, shuffle=False)
mean = 0.
std = 0.
for images, _ in loader:
    batch_samples = images.size(0)
    images = images.view(batch_samples, images.size(1), -1)
    mean += images.mean(2).sum(0)
    std += images.std(2).sum(0)

mean /= len(loader.dataset)
std /= len(loader.dataset)

print("Mean:", mean)
print("Standard Deviation:", std)

Mean: tensor([0.4332, 0.3757, 0.3340])
Standard Deviation: tensor([0.2711, 0.2446, 0.2346])


In [10]:
import torch
from torch.utils.data import Dataset, ConcatDataset
import augmentations as aug
transforms = aug.TrainTransform()
class CustomDataset1(Dataset):
    def __init__(self, data):
        self.data = data
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        # get the first image from the data point and apply the desired transforms
        image1 = self.data[index][0]
        if self.data[index][2]==1:
            image2 = self.data[index][1]
            img1 =transforms.transform(image1)
            img2 =transforms.transform_prime(image2)

            # return the transformed image1 along with the target label
            return img1, img2,self.data[index][2]
        return None

class CustomDataset2(Dataset):
    def __init__(self, data):
        self.data = data
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        # get the second image from the data point and apply the desired transforms
        image2 = self.data[index][1]
        if self.data[index][2]==0:
            img1 =transforms.transform(image2)
            img2 =transforms.transform_prime(image2)
        
            # return the transformed image2 along with the target label
            return img1, img2,self.data[index][2]
        return None
    
class CustomDataset3(Dataset):
    def __init__(self, data):
        self.data = data
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        # get the first image from the data point and apply the desired transforms
        image1 = self.data[index][0]
        if self.data[index][2]==0:
            img1 =transforms.simple_transform(image1)
            img2 =transforms.transform_prime(image1)

            # return the transformed image1 along with the target label
            return img1, img2,self.data[index][2]
        return None

# combine the two datasets using ConcatDataset
combined_dataset = ConcatDataset([CustomDataset1(lfw_dataset), CustomDataset2(lfw_dataset),CustomDataset3(lfw_dataset)])
batch_size=256
num_workers=4

from torch.utils.data._utils.collate import default_collate
def custom_collate(batch):
    batch = [sample for sample in batch if sample is not None]
    return default_collate(batch)
# create the dataloader using the combined dataset
dataloader = torch.utils.data.DataLoader(
    combined_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    collate_fn=custom_collate
)


In [11]:
for batch_idx, (data1, data2, target) in enumerate(dataloader):
    print('Batch Index:', batch_idx)
    print('Data 1:', data1.shape)
    print('Data 2:', data2.shape)
    print('Target:', target.shape)
    
    # stop after printing a few batches
    if batch_idx == 2:
        break

Batch Index: 0
Data 1: torch.Size([126, 3, 224, 224])
Data 2: torch.Size([126, 3, 224, 224])
Target: torch.Size([126])
Batch Index: 1
Data 1: torch.Size([135, 3, 224, 224])
Data 2: torch.Size([135, 3, 224, 224])
Target: torch.Size([135])
Batch Index: 2
Data 1: torch.Size([129, 3, 224, 224])
Data 2: torch.Size([129, 3, 224, 224])
Target: torch.Size([129])


In [12]:
def exclude_bias_and_norm(p):
    return p.ndim == 1


def off_diagonal(x):
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()


In [13]:
import argparse
from pathlib import Path
class Args:
    def __init__(self):
        self.arch = 'resnet50'
        self.mlp = '8192-8192-8192'
        self.sim_coeff=25
        self.std_coeff=25
        self.cov_coeff=1
        self.batch_size = 512
        self.log_freq_time=50
        self.base_lr=0.2
        self.exp_dir=Path("./exp")
        self.epochs=30
        self.wd=1e-6

args = Args()
# namespace = argparse.Namespace(**vars(args))


In [14]:
import resnet
from torch import nn
class VICReg(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.num_features = int(args.mlp.split("-")[-1])
        self.backbone, self.embedding = resnet.__dict__[args.arch](zero_init_residual=True)
        self.projector = Projector(args, self.embedding)

    def forward(self, x, y):
        x = self.projector(self.backbone(x))
        y = self.projector(self.backbone(y))

        repr_loss = F.mse_loss(x, y)

        # x = torch.cat(x, dim=0)
        # y = torch.cat(y, dim=0)
        x = x - x.mean(dim=0)
        y = y - y.mean(dim=0)

        std_x = torch.sqrt(x.var(dim=0) + 0.0001)
        std_y = torch.sqrt(y.var(dim=0) + 0.0001)
        std_loss = torch.mean(F.relu(1 - std_x)) / 2 + torch.mean(F.relu(1 - std_y)) / 2

        cov_x = (x.T @ x) / (self.args.batch_size - 1)
        cov_y = (y.T @ y) / (self.args.batch_size - 1)
        cov_loss = off_diagonal(cov_x).pow_(2).sum().div(
            self.num_features
        ) + off_diagonal(cov_y).pow_(2).sum().div(self.num_features)

        loss = (
            self.args.sim_coeff * repr_loss
            + self.args.std_coeff * std_loss
            + self.args.cov_coeff * cov_loss
        )
        return loss



def Projector(args, embedding):
    mlp_spec = f"{embedding}-{args.mlp}"
    layers = []
    f = list(map(int, mlp_spec.split("-")))
    for i in range(len(f) - 2):
        layers.append(nn.Linear(f[i], f[i + 1]))
        layers.append(nn.BatchNorm1d(f[i + 1]))
        layers.append(nn.ReLU(True))
    layers.append(nn.Linear(f[-2], f[-1], bias=False))
    return nn.Sequential(*layers)


In [15]:
model = VICReg(args).cuda()

In [16]:
# ckpt=torch.load("exp/model.pth")

In [17]:
ckpt=torch.load("exp/resnet50_fullckpt.pth")
state_dict = {k.replace('module.', ''): v for k, v in ckpt["model"].items()}
model.load_state_dict(state_dict)

<All keys matched successfully>

In [18]:
# model.load_state_dict(ckpt["model"])

In [19]:
def adjust_learning_rate(args, optimizer, loader, step):
    max_steps = args.epochs * len(loader)
    warmup_steps = 10 * len(loader)
    base_lr = args.base_lr * args.batch_size / 256
    if step < warmup_steps:
        lr = base_lr * step / warmup_steps
    else:
        step -= warmup_steps
        max_steps -= warmup_steps
        q = 0.5 * (1 + math.cos(math.pi * step / max_steps))
        end_lr = base_lr * 0.001
        lr = base_lr * q + end_lr * (1 - q)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    return lr

In [20]:
from torch import nn, optim
class LARS(optim.Optimizer):
    def __init__(
        self,
        params,
        lr,
        weight_decay=0,
        momentum=0.9,
        eta=0.001,
        weight_decay_filter=None,
        lars_adaptation_filter=None,
    ):
        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            momentum=momentum,
            eta=eta,
            weight_decay_filter=weight_decay_filter,
            lars_adaptation_filter=lars_adaptation_filter,
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g["params"]:
                dp = p.grad

                if dp is None:
                    continue

                if g["weight_decay_filter"] is None or not g["weight_decay_filter"](p):
                    dp = dp.add(p, alpha=g["weight_decay"])

                if g["lars_adaptation_filter"] is None or not g[
                    "lars_adaptation_filter"
                ](p):
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    q = torch.where(
                        param_norm > 0.0,
                        torch.where(
                            update_norm > 0, (g["eta"] * param_norm / update_norm), one
                        ),
                        one,
                    )
                    dp = dp.mul(q)

                param_state = self.state[p]
                if "mu" not in param_state:
                    param_state["mu"] = torch.zeros_like(p)
                mu = param_state["mu"]
                mu.mul_(g["momentum"]).add_(dp)

                p.add_(mu, alpha=-g["lr"])


In [21]:
import time
import json
import torch.nn.functional as F
import math

In [None]:

optimizer = LARS(
    model.parameters(),
    lr=0,
    weight_decay=args.wd,
    weight_decay_filter=exclude_bias_and_norm,
    lars_adaptation_filter=exclude_bias_and_norm,
)
args.exp_dir.mkdir(parents=True, exist_ok=True)
stats_file = open(args.exp_dir / "stats.txt", "a", buffering=1)
# print(" ".join(sys.argv))
# print(" ".join(sys.argv), file=stats_file)

# if (args.exp_dir / "model.pth").is_file():
#     if args.rank == 0:
#         print("resuming from checkpoint")
#     ckpt = torch.load(args.exp_dir / "model.pth", map_location="cpu")
#     start_epoch = ckpt["epoch"]
#     model.load_state_dict(ckpt["model"])
#     optimizer.load_state_dict(ckpt["optimizer"])
# else:
start_epoch = 0

start_time = last_logging = time.time()
scaler = torch.cuda.amp.GradScaler()
for epoch in range(start_epoch, args.epochs):


    for step, (x, y, _) in enumerate(dataloader, start=epoch * len(dataloader)):
        x = x.cuda(non_blocking=True)
        y = y.cuda(non_blocking=True)

        lr = adjust_learning_rate(args, optimizer, dataloader, step)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = model.forward(x, y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        current_time = time.time()
        if current_time - last_logging > args.log_freq_time:
            stats = dict(
                epoch=epoch,
                step=step,
                loss=loss.item(),
                time=int(current_time - start_time),
                lr=lr,
            )
            print(json.dumps(stats))
            print(json.dumps(stats), file=stats_file)
            last_logging = current_time
    state = dict(
        epoch=epoch + 1,
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
    )
    torch.save(state, args.exp_dir / "model.pth")

torch.save(model.backbone.state_dict(), args.exp_dir / "resnet_backbone.pth")

{"epoch": 0, "step": 56, "loss": 22.21867561340332, "time": 51, "lr": 0.03154929577464789}
{"epoch": 1, "step": 99, "loss": 21.072938919067383, "time": 102, "lr": 0.05577464788732395}
{"epoch": 1, "step": 131, "loss": 20.112300872802734, "time": 155, "lr": 0.07380281690140846}
{"epoch": 2, "step": 161, "loss": 20.30962562561035, "time": 205, "lr": 0.09070422535211269}
{"epoch": 2, "step": 190, "loss": 19.579139709472656, "time": 255, "lr": 0.10704225352112676}
{"epoch": 3, "step": 234, "loss": 19.24996566772461, "time": 306, "lr": 0.13183098591549297}
{"epoch": 4, "step": 289, "loss": 18.1663875579834, "time": 356, "lr": 0.1628169014084507}
{"epoch": 4, "step": 348, "loss": 18.557510375976562, "time": 407, "lr": 0.19605633802816905}
{"epoch": 5, "step": 404, "loss": 18.93300437927246, "time": 459, "lr": 0.22760563380281693}
{"epoch": 6, "step": 459, "loss": 17.464313507080078, "time": 510, "lr": 0.25859154929577466}
{"epoch": 7, "step": 516, "loss": 17.816692352294922, "time": 560, "lr

In [21]:
model.eval()

VICReg(
  (backbone): ResNet(
    (padding): ConstantPad2d(padding=(1, 1, 1, 1), value=0.0)
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(2, 2), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): R

In [35]:
x1,y,target=lfw_dataset[400]

In [36]:
x2,y,target=lfw_dataset[502]


In [37]:
target

0

In [38]:
x = transforms.simple_transform(x1).cuda()
y = transforms.simple_transform(x2).cuda()
x=torch.stack([x], dim=0)
y=torch.stack([y], dim=0)


In [39]:
embed_x = model.projector(model.backbone(x))
embed_y = model.projector(model.backbone(y))

In [40]:
 F.mse_loss(embed_x, embed_y)

tensor(0.5709, device='cuda:0', grad_fn=<MseLossBackward0>)