In [None]:
import os

# Step up one directory so that the relative paths and project imports used
# below resolve from the parent folder. Re-running this cell moves up again.
os.chdir(os.pardir)

In [None]:
import random
import torch.nn.functional as F
import numpy as np
# load package
import matplotlib as mpl
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator,FormatStrFormatter,MaxNLocator
from tqdm.auto import tqdm
import argparse
import builtins
from datetime import datetime
import math
import os
import random
import shutil
import wandb
import time
import warnings
import numpy as np
from functools import partial
from utils import utils
# from sklearn.neighbors import KNeighborsClassifier

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision import models
# import torchvision.models as torchvision_models
# from torch.utils.tensorboard import SummaryWriter

import sogclr.builder
import sogclr.loader
import sogclr.optimizer
import sogclr.folder # imagenet

In [None]:
# Turn autograd off for the whole session: this notebook only runs inference.
torch.set_grad_enabled(False)

In [None]:
# Run identifier; interpolated into the checkpoint path and the result file names.
model_idx = 1

In [None]:
# Checkpoint file that the model-loading cell reads for the selected run.
model_path = f"glofnd/{model_idx}/checkpoint_0199.pth.tar"

In [None]:
# Use the first CUDA device for every `.cuda()` call in this notebook.
torch.cuda.set_device(0)

# data: dataset root folder; arch: torchvision model name; batch_size: loader batch size.
data, arch, batch_size = "imagenet", "resnet50", 4096

In [None]:
# Build the two-view augmented dataset and a sequential (unshuffled) loader.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
image_size = 224

normalize = transforms.Normalize(mean=mean, std=std)

# SimCLR-style augmentation pipeline; the same recipe is used for both views.
color_jitter = transforms.RandomApply(
    [transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)],  # not strengthened
    p=0.8,
)
gaussian_blur = transforms.RandomApply([sogclr.loader.GaussianBlur([.1, 2.])], p=1.0)
augmentation1 = [
    transforms.RandomResizedCrop(image_size, scale=(0.08, 1.)),
    color_jitter,
    transforms.RandomGrayscale(p=0.2),
    gaussian_blur,
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
]

traindir = os.path.join(data, 'train')
two_crops = sogclr.loader.TwoCropsTransform(
    transforms.Compose(augmentation1),
    transforms.Compose(augmentation1),
)
train_dataset = sogclr.folder.ImageFolder(traindir, two_crops)

# No custom sampler; with shuffle=False the loader walks the dataset in order.
train_sampler = None

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=6,
    pin_memory=True,
    sampler=train_sampler,
    drop_last=False,
)

In [None]:
# Instantiate the torchvision backbone named by `arch` without pretrained weights;
# the weights are loaded from the checkpoint further down in this cell.
model = models.__dict__[arch](pretrained=False)

def _build_mlp(num_layers, input_dim, mlp_dim, output_dim, last_bn=True):
    mlp = []
    for l in range(num_layers):
        dim1 = input_dim if l == 0 else mlp_dim
        dim2 = output_dim if l == num_layers - 1 else mlp_dim

        mlp.append(nn.Linear(dim1, dim2, bias=False))

        if l < num_layers - 1:
            mlp.append(nn.BatchNorm1d(dim2))
            mlp.append(nn.ReLU(inplace=True))
        elif last_bn:
            # follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157
            # for simplicity, we further removed gamma in BN
            mlp.append(nn.BatchNorm1d(dim2, affine=False))

    return nn.Sequential(*mlp)

# Replace the classification head (`fc`) with a 2-layer projection MLP
# (hidden width 2048, output width 128) so the checkpoint's encoder weights fit.
if hasattr(model, 'fc'):
    linear_keyword = 'fc'
    hidden_dim = model.fc.weight.shape[1]
    # model.fc = nn.Identity()
    model.fc = _build_mlp(2, hidden_dim, 2048, 128)
else:
    raise ValueError(f"Unsupported model architecture: {arch}")

# Load checkpoint
if os.path.isfile(model_path):
    print(f"=> Loading checkpoint '{model_path}'")
    checkpoint = torch.load(model_path, map_location="cpu")
    # Accept either a wrapped checkpoint ({'state_dict': ...}) or a bare state dict.
    state_dict = checkpoint.get('state_dict', checkpoint)
    # Per-sample thresholds stored alongside the model; used by the analysis below.
    lda = state_dict['module.lambda_threshold.lda']
    # Keep only the base-encoder weights and strip their prefix. A true prefix
    # match is used so a key merely containing the text elsewhere is not mangled.
    encoder_prefix = 'module.base_encoder.'
    state_dict = {
        k[len(encoder_prefix):]: v
        for k, v in state_dict.items()
        if k.startswith(encoder_prefix)
    }
    # strict=True: any missing or unexpected key raises instead of being ignored.
    msg = model.load_state_dict(state_dict, strict=True)
    print(f"=> Loaded checkpoint with missing keys: {msg.missing_keys}")
    # assert all([linear_keyword in k for k in msg.missing_keys]), "Missing keys should be linear layers"
    model.eval()
else:
    print(f"Warning: Checkpoint {model_path} not found!")
    # Include the path in the exception so the traceback is self-explanatory.
    raise FileNotFoundError(f"Checkpoint not found: {model_path}")

In [None]:
# Kernel density estimate of the lda values loaded from the checkpoint.
fig, ax = plt.subplots(figsize=(10, 6), dpi=300)
# mpl.rcParams['pdf.fonttype'] = 42
# mpl.rcParams['ps.fonttype'] = 42
# mpl.rcParams['font.family'] = 'Arial'

v_plot = lda.squeeze().numpy()
sns.kdeplot(v_plot, color='blue', legend=False, linewidth=12, ax=ax)

# Only the x axis is informative here; hide the density scale.
ax.yaxis.set_visible(False)
ax.tick_params(axis='x', labelsize=40)
ax.set_xlabel("\u03BB", fontsize=45)
ax.set_xlim(0.1, 0.6)
# plt.xticks(np.arange(0.1, 0.7, 0.1))
fig.tight_layout()

# plt.savefig(f"plots/lda_199.pdf", format="pdf")

In [None]:
from tqdm.auto import tqdm

# Per-batch accumulators; concatenated into single tensors after the loop.
hidden_list1, hidden_list2 = [], []
indices, labels_list = [], []

with torch.inference_mode():
    model.cuda()

    for images, labels, index in tqdm(train_loader):
        # Move both views of the batch to the GPU (images[0] and images[1]).
        for view in range(2):
            images[view] = images[view].cuda(non_blocking=True)

        # Forward pass under mixed precision.
        with torch.amp.autocast('cuda'):
            hidden1 = model(images[0])
            hidden2 = model(images[1])

        # L2-normalise each row so that dot products are cosine similarities.
        hidden1 = F.normalize(hidden1, p=2, dim=1)
        hidden2 = F.normalize(hidden2, p=2, dim=1)

        # Store everything on the CPU to keep GPU memory free.
        hidden_list1.append(hidden1.cpu())
        hidden_list2.append(hidden2.cpu())
        indices.append(index.cpu())
        labels_list.append(labels.cpu())

# The network is not needed once the features are extracted.
del model

hidden1 = torch.cat(hidden_list1, dim=0)
hidden2 = torch.cat(hidden_list2, dim=0)
indices = torch.cat(indices, dim=0)
labels = torch.cat(labels_list, dim=0)



In [None]:
# Estimate, for every sample, the (1 - alpha) quantile of its similarities to a
# random subset of the dataset, and compare against the stored lda thresholds.
alpha = 0.01             # top fraction of similarities above the threshold
small_batch_size = 4096  # number of query rows processed per outer step
large_batch_size = 128   # number of reference samples per inner step
max_count = 100000       # size of the random reference subset per outer step
rng = np.random.default_rng(0)
# seen[k] is set to True once the row for dataset index k has been written.
seen = torch.zeros_like(lda, dtype=torch.bool)

# Column 0: 'lower'-interpolated quantile, column 1: 'higher'-interpolated quantile.
approx_lda = torch.ones((len(lda), 2), dtype=torch.float32)
batch_lda = []
lda_fn = []
batch_fn = []
approx_fn = []
oracle_fn = []

with torch.inference_mode():
    for i in tqdm(range(0, len(hidden1), small_batch_size)):
        # Every per-row tensor must be sliced with the SAME window [i, i + B);
        # mixing in model_idx here yields mismatched row counts.
        rows = slice(i, i + small_batch_size)
        row_indices = indices[rows]
        # assumes lda[row_indices] broadcasts against sim's rows, e.g. (B, 1) — TODO confirm
        i_lda = lda[row_indices].cuda()
        i_label = labels[rows].cuda()
        b1 = hidden1[rows].cuda()
        b2 = hidden2[rows].cuda()
        # NOTE(review): values are drawn from `indices` but used below as positions into
        # hidden1/hidden2/labels; this is only consistent if indices[k] == k — verify.
        sampled_indices = torch.from_numpy(rng.choice(indices, size=min(max_count, len(indices)), replace=False))
        # for loop in batches
        sim_list = []
        curr_batch_lda = []
        curr_lda_fn = []
        curr_batch_fn = []
        curr_oracle_fn = []
        for j in range(0, len(sampled_indices), large_batch_size):
            b_indices = sampled_indices[j:j+large_batch_size]
            b1_large = hidden1[b_indices].cuda()
            b2_large = hidden2[b_indices].cuda()
            labels_large = labels[b_indices].cuda().repeat(2)
            assert labels_large.shape[0] == 2 * b1_large.shape[0]
            # Similarities of both query views against both reference views.
            sim1 = torch.cat([torch.mm(b1, b1_large.t()), torch.mm(b1, b2_large.t())], dim=1)
            sim2 = torch.cat([torch.mm(b2, b1_large.t()), torch.mm(b2, b2_large.t())], dim=1)

            # NOTE(review): sim has 4 * len(b_indices) columns while the label mask
            # below has 2 * len(b_indices); confirm this is intended.
            sim = torch.cat([sim1, sim2], dim=1)

            # Quantile computed from this reference chunk only.
            bl = sim.float().quantile(1-alpha, dim=1, keepdim=True).type_as(i_lda).cuda()
            curr_batch_fn.append(
                (sim > bl).float().cpu())
            curr_lda_fn.append(
                (sim > i_lda).float().cpu())
            curr_oracle_fn.append(
                (labels_large.unsqueeze(0) == i_label.unsqueeze(1)).float().cpu())
            curr_batch_lda.append(bl.cpu())
            sim_list.append(sim.cpu())

        sim_list = torch.cat(sim_list, dim=1)
        curr_batch_lda = torch.cat(curr_batch_lda, dim=1)
        curr_batch_fn = torch.cat(curr_batch_fn, dim=1)
        curr_lda_fn = torch.cat(curr_lda_fn, dim=1)
        curr_oracle_fn = torch.cat(curr_oracle_fn, dim=1)

        # Quantile over the whole sampled subset, bracketed from below and above.
        q_low = sim_list.float().quantile(1-alpha, dim=1, keepdim=True, interpolation='lower').type_as(lda).cpu()
        q_high = sim_list.float().quantile(1-alpha, dim=1, keepdim=True, interpolation='higher').type_as(lda).cpu()

        curr_approx_fn = (sim_list > q_low).float().cpu()

        # Free the large similarity matrix before the next outer step.
        del sim_list

        batch_lda.append(curr_batch_lda)
        lda_fn.append(curr_lda_fn)
        batch_fn.append(curr_batch_fn)
        approx_fn.append(curr_approx_fn)
        oracle_fn.append(curr_oracle_fn)

        approx_lda[row_indices, 0] = q_low.squeeze(1)
        approx_lda[row_indices, 1] = q_high.squeeze(1)
        seen[row_indices] = True
# NOTE(review): left disabled because lda may hold more entries than there are
# samples (a later cell drops its last row) — confirm before enabling.
# assert torch.all(seen), "Some indices are not seen"

In [None]:
# Drop the large feature tensors and hand cached GPU memory back to the driver.
del hidden1, hidden2, indices, labels
torch.cuda.empty_cache()

In [None]:
# Persist the quantile estimates so the analysis cells can be re-run without
# repeating the expensive loop above.
folder = "results/lambdas"
# torch.save does not create missing directories; make sure the target exists.
os.makedirs(folder, exist_ok=True)
torch.save(approx_lda, f"{folder}/approx_lda_{model_idx}.pth")
torch.save(batch_lda, f"{folder}/batch_lda_{model_idx}.pth")

In [None]:
# Reload the cached quantile estimates written by the save cell.
folder = "results/lambdas"
approx_lda, batch_lda = (
    torch.load(f"{folder}/approx_lda_{model_idx}.pth"),
    torch.load(f"{folder}/batch_lda_{model_idx}.pth"),
)

In [None]:
# Merge the per-step chunks into one tensor each, stacking along the row axis.
# Each name is rebound in turn, so the chunk list is released as soon as its
# concatenation exists. Not idempotent: a second run would pass a tensor,
# not a list of tensors, to torch.cat.
batch_lda = torch.cat(batch_lda, dim=0)
lda_fn = torch.cat(lda_fn, dim=0)
batch_fn = torch.cat(batch_fn, dim=0)
approx_fn = torch.cat(approx_fn, dim=0)
oracle_fn = torch.cat(oracle_fn, dim=0)

In [None]:
# Move the tensors to the GPU, squeeze singleton dimensions, and drop the last
# row of lda and approx_lda.
# NOTE(review): presumably lda holds one more entry than there are samples, so the
# trimmed tensors line up row-for-row with batch_lda — confirm.
# Not idempotent: every re-run of this cell drops one more row.
lda = lda.cuda().squeeze()[:-1]
approx_lda = approx_lda.cuda().squeeze()[:-1]
batch_lda = batch_lda.cuda()

In [None]:
# Sanity check: display the shape of the per-chunk quantile tensor.
batch_lda.shape

In [None]:
# For each sample keep whichever bracketing quantile (lower or higher) lies
# closer to its lda value; ties go to the higher one.
with torch.inference_mode():
    q_low, q_high = approx_lda[:, 0], approx_lda[:, 1]
    low_is_closer = torch.abs(lda - q_low) < torch.abs(q_high - lda)
    a_lda = torch.where(low_is_closer, q_low, q_high)

In [None]:
# Absolute deviation between lda and its nearest empirical quantile.
error = (lda.squeeze() - a_lda.squeeze()).abs()

# mae, rmse, mae_scaled (to 0-100 scale)
mae = error.mean().item()
rmse = error.pow(2).mean().sqrt().item()
mae_scaled = 50 * mae

print(f"MAE: {mae:.4f}, RMSE: {rmse:.4f}")
print(f"MAE scaled: {mae_scaled:.4f}")

In [None]:
# Same comparison as for lda, but for the per-chunk quantiles in batch_lda.
# Batch-specific names are used on purpose: reusing `a_lda`, `q_low` and `q_high`
# here would replace the 1-D per-sample tensors with 2-D ones and break the later
# cells that compute `lda - a_lda`.
with torch.inference_mode():
    # Column vectors so they broadcast across the columns of batch_lda.
    q_low_col = approx_lda[:,0].unsqueeze(1)
    q_high_col = approx_lda[:,1].unsqueeze(1)
    batch_a_lda = torch.where(
        torch.abs(batch_lda - q_low_col) < torch.abs(q_high_col - batch_lda),
        q_low_col,
        q_high_col,
    )
batch_error = torch.abs(batch_lda - batch_a_lda)
# mae, rmse, mae_scaled (to 0-100 scale)
mae = batch_error.mean().item()
rmse = torch.sqrt((batch_error ** 2).mean()).item()
mae_scaled = mae * 50
print(f"MAE: {mae:.4f}, RMSE: {rmse:.4f}")
print(f"MAE scaled: {mae_scaled:.4f}")

In [None]:
# Convert the tensors used by the plots below to NumPy arrays on the host.
# Not idempotent: a NumPy array has no `.cpu()` method, so run this only once.
error, lda, a_lda = (t.cpu().numpy() for t in (error, lda, a_lda))

In [None]:
# Line plot of the absolute error |lda - a_lda|, one point per entry of `error`.
plt.figure(figsize=(8, 6))
plt.plot(error, label='|lda - a_lda|', color='blue', linewidth=2)

# Describe what is actually plotted instead of the template placeholder text.
plt.xlabel('Sample index', fontsize=14)
plt.ylabel('Absolute error', fontsize=14)
plt.title('Absolute error between lda and a_lda', fontsize=16)

In [None]:
# Mean signed difference between lda and a_lda (the sign shows the direction of any bias).
(lda.squeeze() - a_lda.squeeze()).mean()

In [None]:
# Kernel density estimate of the signed difference lda - a_lda.
fig, ax = plt.subplots(figsize=(10, 6), dpi=300)
# mpl.rcParams['pdf.fonttype'] = 42
# mpl.rcParams['ps.fonttype'] = 42
# mpl.rcParams['font.family'] = 'Arial'

v_plot = (lda.squeeze() - a_lda.squeeze())
sns.kdeplot(v_plot, color='blue', legend=False, linewidth=12, ax=ax)

# Only the x axis is informative here; hide the density scale.
ax.yaxis.set_visible(False)
ax.tick_params(axis='x', labelsize=40)
ax.set_xlabel("\u03BB", fontsize=45)
ax.set_xlim(-0.5, 0.5)
# plt.xticks(np.arange(0.1, 0.7, 0.1))
fig.tight_layout()

# plt.savefig(f"plots/lda_199.pdf", format="pdf")

In [None]:
# Kernel density estimate of the absolute error stored in `error`.
fig, ax = plt.subplots(figsize=(9, 6), dpi=300)
# mpl.rcParams['pdf.fonttype'] = 42
# mpl.rcParams['ps.fonttype'] = 42
# mpl.rcParams['font.family'] = 'Arial'

v_plot = error
sns.kdeplot(v_plot, color='blue', legend=False, linewidth=12, ax=ax)

# Only the x axis is informative here; hide the density scale.
ax.yaxis.set_visible(False)
ax.tick_params(axis='x', labelsize=40)
ax.set_xlabel("\u03BB", fontsize=45)
ax.set_xlim(0, 0.5)
# plt.xticks(np.arange(0.1, 0.7, 0.1))
fig.tight_layout()

# plt.savefig(f"plots/lda_199.pdf", format="pdf")

In [None]:
# Overlay the density of lda (blue) with that of a_lda (red).
fig, ax = plt.subplots(figsize=(10, 6), dpi=300)
# mpl.rcParams['pdf.fonttype'] = 42
# mpl.rcParams['ps.fonttype'] = 42
# mpl.rcParams['font.family'] = 'Arial'

v_plot = lda.squeeze()
v2_plot = a_lda.squeeze()
sns.kdeplot(v_plot, color='blue', legend=False, linewidth=12, ax=ax)
sns.kdeplot(v2_plot, color='red', legend=False, linewidth=12, ax=ax)

# Only the x axis is informative here; hide the density scale.
ax.yaxis.set_visible(False)
ax.tick_params(axis='x', labelsize=40)
ax.set_xlabel("\u03BB", fontsize=45)
ax.set_xlim(0.1, 0.6)
# plt.xticks(np.arange(0.1, 0.7, 0.1))
fig.tight_layout()

# plt.savefig(f"plots/lda_199.pdf", format="pdf")

In [None]:
import random

# Overlay the densities of lda, a_lda and a random selection of batch_lda columns.
fig = plt.figure(figsize=(10, 6), dpi=300)
# mpl.rcParams['pdf.fonttype'] = 42
# mpl.rcParams['ps.fonttype'] = 42
# mpl.rcParams['font.family'] = 'Arial'

v_plot = lda.squeeze()
v2_plot = a_lda.squeeze()
v3_plot = batch_lda.cpu().numpy()

# Pick at most 50 distinct columns. A dedicated seeded generator makes the figure
# reproducible, and capping the count avoids a ValueError on narrow inputs.
n_curves = min(50, v3_plot.shape[1])
randomlist = random.Random(0).sample(range(v3_plot.shape[1]), n_curves)

# NOTE(review): 'Real' labels lda and '$\lambda$' labels a_lda; confirm the two
# labels are not swapped.
ax = sns.kdeplot(v_plot, color='blue', legend=True, linewidth=12, label='Real')
sns.kdeplot(v2_plot, color='red', legend=True, linewidth=12, ax=ax, label=r'$\lambda$')
for pos, k in enumerate(randomlist):
    # Label only the first curve so the legend holds a single "FNC" entry.
    sns.kdeplot(v3_plot[:,k], legend=True, linewidth=12, ax=ax, alpha=0.5,
                label="FNC" if pos == 0 else None)

ax.yaxis.set_visible(False)
plt.tick_params(axis='x', labelsize=40)
plt.xlabel("\u03BB", fontsize=45)
plt.xlim(0.1, 0.6)
# plt.xticks(np.arange(0.1, 0.7, 0.1))
# Let matplotlib collect the labeled lines itself; unlabeled curves are skipped.
ax.legend(loc='upper right')
plt.tight_layout()

# plt.savefig(f"plots/lda_199.pdf", format="pdf")