In [1]:
import os

# Switch the working directory to the parent folder exactly once.
# The resolved path is cached in PROJECT_ROOT so that re-running this cell
# does not climb one directory higher on every execution.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath("..")
os.chdir(PROJECT_ROOT)

In [2]:
import random
from sklearn.manifold import TSNE
import torch.nn.functional as F
import numpy as np
# load package
import matplotlib as mpl
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator,FormatStrFormatter,MaxNLocator
from tqdm.auto import tqdm
import argparse
import builtins
from datetime import datetime
import math
import os
import random
import shutil
import wandb
import time
import warnings
import numpy as np
from functools import partial
from utils import utils
# from sklearn.neighbors import KNeighborsClassifier

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision import models as torchvision_models
# import torchvision.models as torchvision_models
# from torch.utils.tensorboard import SummaryWriter

import sogclr.builder
import sogclr.loader
import sogclr.optimizer
import sogclr.folder # imagenet

In [None]:
# Make GPU index 9 the current CUDA device for this notebook.
torch.cuda.set_device(9)
# Checkpoints to embed. The paths are relative, so they resolve against the
# current working directory.
model_paths = [
    "baselines/tuning_20241028_153626_dcl_0_0_0_0_-1_0/3/checkpoint_0199.pth.tar",
]
# Dataset root; the data-loading cell reads its 'train' sub-directory.
data = "/data/datasets/imagenet100"
# Name looked up in the torchvision model registry when the encoders are built.
arch = "resnet50"
# Number of images per batch handed to the DataLoader.
batch_size = 2048

In [None]:
# Data loading code
# Per-channel statistics handed to transforms.Normalize.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
image_size = 224

normalize = transforms.Normalize(mean=mean, std=std)

# Single-view pipeline: random resized crop -> tensor -> normalisation.
augmentation1 = transforms.Compose(
    [
        transforms.RandomResizedCrop(image_size),
        transforms.ToTensor(),
        normalize,
    ]
)

traindir = os.path.join(data, 'train')
train_dataset = sogclr.folder.ImageFolder(traindir, augmentation1)

# shuffle=False keeps dataset order; drop_last=False keeps the final,
# possibly smaller, batch.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=6,
    pin_memory=True,
)

In [None]:
def _build_mlp(num_layers, input_dim, mlp_dim, output_dim, last_bn=True):
    mlp = []
    for l in range(num_layers):
        dim1 = input_dim if l == 0 else mlp_dim
        dim2 = output_dim if l == num_layers - 1 else mlp_dim

        mlp.append(nn.Linear(dim1, dim2, bias=False))

        if l < num_layers - 1:
            mlp.append(nn.BatchNorm1d(dim2))
            mlp.append(nn.ReLU(inplace=True))
        elif last_bn:
            # follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157
            # for simplicity, we further removed gamma in BN
            mlp.append(nn.BatchNorm1d(dim2, affine=False))

    return nn.Sequential(*mlp)

# Rebuild one encoder (backbone + projection head) per checkpoint and load
# its weights. Every model is left on the CPU in eval mode.
encoder_prefix = 'module.base_encoder.'

models = []
for model_path in model_paths:
    # Fail before constructing the network if the checkpoint is absent.
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"Checkpoint {model_path} not found!")

    model = torchvision_models.__dict__[arch](pretrained=False)
    if not hasattr(model, 'fc'):
        raise ValueError(f"Unsupported model architecture: {arch}")

    # Replace the final fully connected layer with a 2-layer projection head
    # (hidden width 2048, output width 128).
    linear_keyword = 'fc'
    hidden_dim = model.fc.weight.shape[1]
    model.fc = _build_mlp(2, hidden_dim, 2048, 128)

    # Load checkpoint
    # NOTE: torch.load unpickles its input; only load checkpoints you trust.
    print(f"=> Loading checkpoint '{model_path}'")
    checkpoint = torch.load(model_path, map_location="cpu")
    state_dict = checkpoint.get('state_dict', checkpoint)

    # Not every checkpoint carries a lambda threshold; use None when absent
    # instead of raising KeyError.
    lda = state_dict.get('module.lambda_threshold.lda')

    # Keep only the base encoder's weights. Strip the prefix by position so
    # that only a leading match is removed.
    state_dict = {
        key[len(encoder_prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(encoder_prefix)
    }
    msg = model.load_state_dict(state_dict, strict=True)
    print(f"=> Loaded checkpoint with missing keys: {msg.missing_keys}")
    model.eval()

    models.append(model)

In [None]:
from tqdm.auto import tqdm

# Run every loaded model over the whole train loader and collect the
# L2-normalised outputs together with the per-sample `index` values yielded
# by the loader (presumably dataset indices; confirm against
# sogclr.folder.ImageFolder).
h_list = []
i_list = []
for model in models:
    hidden_list1 = []
    indices = []

    with torch.inference_mode():
        model.cuda()
        try:
            # Wrapping the loader lets tqdm advance one step per batch, which
            # matches its total of len(train_loader), and closes the bar when
            # the loop ends.
            for images, labels, index in tqdm(train_loader, leave=False):
                images = images.cuda(non_blocking=True)

                # compute output
                with torch.cuda.amp.autocast(True):
                    hidden1 = model(images)
                hidden1 = F.normalize(hidden1, p=2, dim=1)
                hidden_list1.append(hidden1.cpu())
                indices.append(index.cpu())
        finally:
            # Always move the model off the GPU, even if a batch fails.
            model.cpu()

    h_list.append(torch.cat(hidden_list1, dim=0))
    i_list.append(torch.cat(indices, dim=0))

# Stack per-model results along a new leading dimension.
h_list = torch.stack(h_list, dim=0)
i_list = torch.stack(i_list, dim=0)


In [61]:
import matplotlib.pyplot as plt

def plot(projections, meta, kl, alpha):
    """Scatter-plot 2-D projections coloured by label and save them as a PNG.

    Args:
        projections: 2-D indexable object supporting ``projections[mask, col]``
            with a boolean row mask; columns 0 and 1 are used as x and y.
        meta: array of per-point labels, compared element-wise against each
            unique label to build the row mask.
        kl: currently unused; kept so existing callers keep working.
        alpha: tag interpolated into the output file name
            ``results/embeddings/alpha_{alpha}.png``. It is not the marker
            opacity, which is fixed at 0.8.

    The figure is written to disk and then closed; nothing is returned.
    """
    # Get unique labels
    unique_labels = np.unique(meta)

    # savefig does not create missing directories.
    os.makedirs("results/embeddings", exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        # plt.get_cmap is used because plt.cm.get_cmap was removed in
        # Matplotlib 3.9; it returns the same colormap resampled to one
        # entry per label.
        cmap = plt.get_cmap("jet", len(unique_labels))

        for idx, label in enumerate(unique_labels):
            mask = (meta == label)

            # An integer index picks the idx-th colour directly. Passing
            # idx / n as a float is re-scaled and truncated inside
            # matplotlib, which can map two neighbouring labels to the same
            # entry.
            color = cmap(idx)

            ax.scatter(
                projections[mask, 0],
                projections[mask, 1],
                color=[color],
                marker='o',
                label=f'Label {label}',
                alpha=0.8,
                s=25,
            )

        ax.axis('off')  # Turn off the axes
        fig.tight_layout()
        fig.savefig(f"results/embeddings/alpha_{alpha}.png", format="png", dpi=300)
    finally:
        # Close the figure even on failure so repeated calls do not
        # accumulate open figures.
        plt.close(fig)


In [None]:
from pathlib import Path

def get_pth_files(folder_path):
    """Return the ``*.pth`` files located directly inside ``folder_path``.

    Args:
        folder_path: directory to search, as a ``str`` or ``Path``.
            Sub-directories are not searched.

    Returns:
        A list of ``Path`` objects in sorted order, so repeated runs process
        the files in the same sequence regardless of filesystem ordering.
        The list is empty when nothing matches.
    """
    return sorted(Path(folder_path).glob("*.pth"))

# Collect the .pth files from the results folder (they are loaded and plotted
# in a later cell) and print the matched paths.
folder = "results/embeddings"
pth_files = get_pth_files(folder)
print(pth_files)

In [None]:
# Echo the collected paths as this cell's output.
pth_files

In [None]:
# Render one figure per saved file. The loaded dict gets its own name so it
# does not overwrite the global `data` dataset path that the data-loading
# cell reads; reusing `data` here breaks that cell when it is re-run.
# NOTE: torch.load unpickles its input; only point this at files you produced.
for pth_file in pth_files:
    payload = torch.load(pth_file)
    plot(payload["proj"], payload["meta"].numpy(), payload["kl"], payload["name"])