In [None]:
import dataclasses
from collections.abc import Iterator, Mapping
from types import MappingProxyType
from typing import Any, Literal, Optional

import jax
import jax.numpy as jnp
import sklearn
import sklearn.datasets

import optax

import matplotlib.pyplot as plt
from IPython.display import clear_output, display

from ott import datasets
from ott.geometry import costs, pointcloud

from ott.tools import sinkhorn_divergence

import jax
import jax.numpy as jnp
from ott.geometry.geometry import Geometry
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
import scipy
import numpy as np

from typing import Any, Optional

import matplotlib.pyplot as plt
from matplotlib import colors

from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem, potentials
from ott.solvers import linear
from ott.tools import progot
import scipy

import pandas as pd
import scanpy as sc
import numpy as np
import torch
import sys
sys.path.insert(0, '../src/')
import importlib
import FRLC
from FRLC import FRLC_opt
import HR_OT
importlib.reload(HR_OT)

import torch.multiprocessing as mp


In [2]:
@jax.jit
def sinkhorn_loss(
    x: jnp.ndarray, y: jnp.ndarray, epsilon: float = 0.001
) -> float:
    """Computes transport between (x, a) and (y, b) via Sinkhorn algorithm."""
    a = jnp.ones(len(x)) / len(x)
    b = jnp.ones(len(y)) / len(y)
    
    _, out = sinkhorn_divergence.sinkhorn_divergence(
        pointcloud.PointCloud, x, y, epsilon=epsilon, a=a, b=b
    )
    
    return out.divergence


def run_progot(
    x: jnp.ndarray, y: jnp.ndarray, cost_fn: costs.TICost, **kwargs: Any
) -> progot.ProgOTOutput:
    geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn)
    prob = linear_problem.LinearProblem(geom)
    estim = progot.ProgOT(**kwargs)
    out = estim(prob)
    return out

K = 4
cost_fn = costs.SqEuclidean()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'On device: {device}')

dtype = torch.float64


On device: cpu


In [3]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Subset
import torch
import torchvision.models as models
import os

In [4]:
# Define image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images for CNN input
    transforms.ToTensor(),
])

# Load ImageNet dataset from extracted path # /ILSVRC/Data/CLS-LOC/test
imagenet_dataset = datasets.ImageFolder(root="/scratch/gpfs/DATASETS/imagenet/ilsvrc_2012_classification_localization/train",
                                        transform=transform)

# Create DataLoader for batching
imagenet_loader = DataLoader(imagenet_dataset, batch_size=32, shuffle=True, num_workers=4)

print(f"Loaded {len(imagenet_dataset)} images from ImageNet!")


Loaded 1281167 images from ImageNet!


In [None]:
!mkdir -p ~/.cache/torch/hub/checkpoints
!mv /home/ph3641/HierarchicalRefinement/HR_OT/HR_OT/notebooks/resnet50-0676ba61.pth ~/.cache/torch/hub/checkpoints/

In [None]:
model_path = os.path.expanduser("~/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth")

# Load pretrained ResNet model
model = models.resnet50()
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.fc = torch.nn.Identity()  # Remove classification layer to extract features
model.eval()  # Set to evaluation mode

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Compute embeddings
def extract_features(dataloader, model):
    embeddings = []
    num_img = 0
    with torch.no_grad():
        for idx, (images, _) in enumerate(dataloader):
            num_img += len(images)
            if idx % 100 == 0:
                print(f'image idx {idx}, images: {num_img}')
            images = images.to(device)
            features = model(images)
            embeddings.append(features.cpu().numpy())
    return np.vstack(embeddings)  # Stack all embeddings

print('extracting embeddings!')
embeddings = extract_features(imagenet_loader, model)

print(embeddings.shape)

In [None]:
import pickle
save_dir = "/scratch/gpfs/ph3641/hr_ot/"

os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "embeddings.pkl")

with open(save_path, "wb") as f:
    pickle.dump(embeddings, f)

print(f"Embeddings saved successfully to {save_path}")

In [4]:
import pickle
emb_dir = '/scratch/gpfs/ph3641/hr_ot/embeddings.pkl'

# Load embeddings from the pickle file
with open(emb_dir, "rb") as f:
    embeddings = pickle.load(f)

print(f"Embeddings loaded successfully! Shape: {embeddings.shape}")

Embeddings loaded successfully! Shape: (1281167, 2048)


In [5]:
import rank_annealing

# 1. Making it even, remove 1 image
embeddings = embeddings[1:,:]

# 2. Get to a close even number which when divided is non-prime
print(f'num embeddings: {embeddings.shape[0]}')
N = embeddings.shape[0] // 2
q = 640500
k = (N-q)*2

embed_sliced = embeddings[:-k]
n = embed_sliced.shape[0] // 2

rank_schedule = rank_annealing.optimal_rank_schedule( n , hierarchy_depth = 6, max_Q = int(2**11), max_rank = 64 )

num embeddings: 1281166
Optimized rank-annealing schedule: [7, 50, 1830]


In [6]:
import torch

num_samples = embed_sliced.shape[0]

# Shuffle indices
indices = torch.randperm(num_samples)

# Split into two tensors
X = embeddings[indices[:n]]  # First 50%
Y = embeddings[indices[n:]]  # Second 50%

del embeddings, indices, embed_sliced

print(f"X shape: {X.shape}, Y shape: {Y.shape}")

X shape: (640500, 2048), Y shape: (640500, 2048)


In [8]:
import math
import functools
import operator
import rank_annealing
import validation
importlib.reload(rank_annealing)
importlib.reload(HR_OT)

# Squared Euclidean cost p=2 or Euclidean if p=1
p = 1
K = 2

# Initialize dictionaries to store costs and sample sizes
costs = {
    'HROT_LR': {'samples': [], 'costs': []},
    'Sinkhorn': {'samples': [], 'costs': []},
    'ProgOT': {'samples': [], 'costs': []}
}

X, Y = np.array(X).astype(np.float32), np.array(Y).astype(np.float32)

batch_sizes = [128, 256, 512, 1024]

for B in batch_sizes:
    cost = validation.minibatch_sinkhorn_ot_without_replacement(X, Y, B)
    print(f'-----Mini-batch cost for batch size B = {B}: <C,P> = {cost}-----')


Mini-batch Sinkhorn: 100%|██████████| 5004/5004 [03:20<00:00, 24.96it/s]


-----Mini-batch cost for batch size B = 128: <C,P> = 21.889437173100873-----


Mini-batch Sinkhorn: 100%|██████████| 2502/2502 [03:48<00:00, 10.95it/s]


-----Mini-batch cost for batch size B = 256: <C,P> = 21.11325375353404-----


Mini-batch Sinkhorn: 100%|██████████| 1251/1251 [07:04<00:00,  2.94it/s]


-----Mini-batch cost for batch size B = 512: <C,P> = 20.335602532378395-----


Mini-batch Sinkhorn: 100%|██████████| 626/626 [18:20<00:00,  1.76s/it]

-----Mini-batch cost for batch size B = 1024: <C,P> = 19.585924760983012-----



