In [1]:
import jax
import jax.numpy as jnp
import sklearn
import sklearn.datasets

import optax

import matplotlib.pyplot as plt
from IPython.display import clear_output, display

from ott import datasets
from ott.geometry import costs, pointcloud

from ott.tools import sinkhorn_divergence

import jax
import jax.numpy as jnp
from ott.geometry.geometry import Geometry
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
import scipy
import numpy as np

from typing import Any, Optional

import matplotlib.pyplot as plt
from matplotlib import colors

from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem, potentials
from ott.solvers import linear
from ott.tools import progot
import scipy

In [2]:
from scipy.stats import pearsonr, spearmanr
from scipy.spatial import distance
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics.cluster import adjusted_rand_score, adjusted_mutual_info_score
import numpy as np
import json
import torch
from scipy.spatial import distance
import sys
import importlib
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
# Directory holding the preprocessed E12.5 / E13.5 mouse-embryo slice pair.
filehandle_embryo = '/scratch/gpfs/ph3641/mouse_embryo_preprocessed/E12.5_E13.5/'

# Prepend both locations to the module search path in a single step. The WDM
# checkout ends up first and the data directory second, which is the order two
# successive insert(0, ...) calls produce.
sys.path[:0] = ['/home/ph3641/WDM/WDM/', filehandle_embryo]

# Use the GPU when torch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'On device: {device}')

# Double-precision torch dtype made available to the cells below.
dtype = torch.float64

def plot_slice_value(value_vec, spatial, vmax=None, vmin=None, center=False, center_val=None, s=6):
    """
    Scatter-plot a slice with each spot colored according to its value in value_vec.

    Parameters:
    value_vec: one value per spot (e.g. a growth vector); used as the color of each point
    spatial: array read as spatial[:, 0] (x) and spatial[:, 1] (y)
    vmax, vmin: color-scale bounds. With center=False they are passed to
        plt.scatter unchanged (None included). With center=True a bound left
        as None is replaced by the max / min of value_vec.
    center: if True, widen the color range on both sides of [vmin, vmax]
    center_val: margin subtracted from vmin and added to vmax when center is
        True; if None, half of (vmax - vmin) is used instead
    s: marker size passed to plt.scatter

    Returns:
    None. Shows the figure with plt.show().
    """
    if center:
        # Widening needs concrete bounds. Previously, calling with center=True
        # and the default vmin/vmax of None raised a TypeError on the arithmetic.
        if vmin is None:
            vmin = np.min(value_vec)
        if vmax is None:
            vmax = np.max(value_vec)
        margin = (vmax - vmin) / 2 if center_val is None else center_val
        cmin, cmax = vmin - margin, vmax + margin
    else:
        cmin, cmax = vmin, vmax

    fig = plt.figure()
    sc = plt.scatter(spatial[:, 0], spatial[:, 1], c=value_vec, cmap='YlGn', s=s, vmin=cmin, vmax=cmax) # or 'RdYlGn'
    cbar = plt.colorbar(sc)
    cbar.ax.tick_params(labelsize=20)
    # Flip the y axis and hide the axis decorations.
    plt.gca().invert_yaxis()
    plt.axis('off')

    # Rescale the figure to a width of 20 inches, keeping its aspect ratio.
    fig_size = fig.get_size_inches()
    new_width = 20.0
    new_height = new_width * (fig_size[1] / fig_size[0])
    fig.set_size_inches(new_width, new_height)
    plt.axis('equal')
    plt.show()


@jax.jit
def sinkhorn_loss(
    x: jnp.ndarray, y: jnp.ndarray, epsilon: float = 0.001
) -> float:
    """Sinkhorn divergence between the point sets x and y under uniform weights.

    Every point of x gets weight 1/len(x) and every point of y gets weight
    1/len(y). The divergence is computed by
    ``sinkhorn_divergence.sinkhorn_divergence`` on a ``PointCloud`` geometry
    with regularization ``epsilon``, and the ``divergence`` attribute of its
    second return value is returned.
    """
    n_x, n_y = len(x), len(y)
    weights_x = jnp.ones(n_x) / n_x
    weights_y = jnp.ones(n_y) / n_y

    # Only the second element of the returned pair is needed here.
    _, div_out = sinkhorn_divergence.sinkhorn_divergence(
        pointcloud.PointCloud,
        x,
        y,
        epsilon=epsilon,
        a=weights_x,
        b=weights_y,
    )
    return div_out.divergence


def run_progot(
    x: jnp.ndarray, y: jnp.ndarray, cost_fn: costs.TICost, **kwargs: Any
) -> progot.ProgOTOutput:
    """Solve the linear OT problem between point clouds x and y with ProgOT.

    A ``PointCloud`` geometry is built from ``x``, ``y`` and ``cost_fn`` and
    wrapped in a ``LinearProblem``. ``kwargs`` are forwarded unchanged to the
    ``progot.ProgOT`` constructor, and the output of calling that estimator on
    the problem is returned.
    """
    problem = linear_problem.LinearProblem(
        pointcloud.PointCloud(x, y, cost_fn=cost_fn)
    )
    return progot.ProgOT(**kwargs)(problem)

# Number of steps requested from progot.get_alpha_schedule in the ProgOT benchmark.
K = 4
# Ground cost passed to run_progot in the ProgOT benchmark.
cost_fn = costs.SqEuclidean()

On device: cpu


In [5]:
# Filehandles for the Mouse Embryo datasets (Chen et al 2022)
# One directory per pair of consecutive embryonic stages.
filehandles_embryo = [
    '/scratch/gpfs/ph3641/mouse_embryo_preprocessed/E9.5_E10.5/',
    '/scratch/gpfs/ph3641/mouse_embryo_preprocessed/E10.5_E11.5/',
    '/scratch/gpfs/ph3641/mouse_embryo_preprocessed/E11.5_E12.5/',
    '/scratch/gpfs/ph3641/mouse_embryo_preprocessed/E12.5_E13.5/',
    '/scratch/gpfs/ph3641/mouse_embryo_preprocessed/E13.5_E14.5/',
    '/scratch/gpfs/ph3641/mouse_embryo_preprocessed/E14.5_E15.5/',
    '/scratch/gpfs/ph3641/mouse_embryo_preprocessed/E15.5_E16.5/',
]

# Per-slice spatial coordinates and expression features, in loading order.
spatial_list = []
expr_list = []

for i, pair_dir in enumerate(filehandles_embryo):
    # The first directory contributes both of its slices; every later
    # directory contributes slice 2 only (presumably because its slice 1 is
    # the same stage as the previous directory's slice 2 -- TODO confirm).
    slice_ids = (1, 2) if i == 0 else (2,)

    for slice_id in slice_ids:
        spatial = np.load(f'{pair_dir}slice{slice_id}_coordinates.npy')
        spatial_list.append(spatial)

    for slice_id in slice_ids:
        expr = np.load(f'{pair_dir}slice{slice_id}_feature.npy')
        expr_list.append(expr)



In [None]:
# Make the project sources in ../src importable for the imports below.
sys.path.insert(0, '../src/')

import ott
from ott.geometry import pointcloud, geometry
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr
import rank_annealing
import HR_OT
import importlib
import FRLC
import util
# Re-import the two solver modules so on-disk edits are picked up without a kernel restart.
importlib.reload(HR_OT)
importlib.reload(FRLC)

# Seed NumPy and torch so the random subsampling below is repeatable.
np.random.seed(seed=1)
torch.manual_seed(1)

# Exponent applied to the pairwise Euclidean distances that form the dense cost matrix.
p = 1

# Number of leading points dropped from BOTH slices before any solver runs.
# Some sample sizes are prime or have too few divisors for the rank schedule;
# removing an insignificant number of points fixes that:
#   for 50k, 70k removing 2 points
#   for 102k removing 4 points
#   for 113k removing 2 points
# The trimmed arrays are used consistently by every solver below.
N_TRIM = 2

# Looping/ iterating over datasets; some of these break the kernel,
# so in practice need to comment out ones (Sinkhorn, ProgOT, etc) which cannot scale
# past a certain number of points.

for i in range(2, len(filehandles_embryo)):

    print(f'Pair: {filehandles_embryo[i]} ')

    # Expression data
    x0, x1 = expr_list[i], expr_list[i+1]
    print(x1.shape)
    # Spatial data
    s0, s1 = spatial_list[i], spatial_list[i+1]

    # Subsample whichever slice is larger so both have n points. The x0 branch
    # is new: previously a larger x0 was left untouched and tripped the
    # assertion below. It draws from the RNG only in that case, so runs where
    # x1 is the larger slice reproduce exactly as before.
    n = min(x0.shape[0], x1.shape[0])
    if x0.shape[0] > n:
        selected_indices = np.random.choice(x0.shape[0], size=n, replace=False)
        x0 = x0[selected_indices]
        s0 = s0[selected_indices]
    if x1.shape[0] > n:
        selected_indices = np.random.choice(x1.shape[0], size=n, replace=False)
        x1 = x1[selected_indices]
        s1 = s1[selected_indices]

    assert x1.shape[0] == x0.shape[0], "Expression dataset sizes do not match!"
    assert s1.shape[0] == s0.shape[0], "Spatial dataset sizes do not match!"

    print(f'Sample size = {n}')

    # Trim once and use the trimmed points everywhere: for the rank schedule,
    # the dense cost matrix, ProgOT and FRLC. Previously only the tensors were
    # trimmed, so the rank schedule was built for n points while the clouds
    # held n - 2, and the ProgOT plan (n x n) could not be multiplied with the
    # cost matrix ((n - 2) x (n - 2)).
    x0_trim, x1_trim = x0[N_TRIM:], x1[N_TRIM:]
    n_eff = n - N_TRIM

    # Wasserstein-cost only
    _X = torch.tensor(x0_trim).float().to(device)
    _Y = torch.tensor(x1_trim).float().to(device)

    # n, hierarchy_depth = 6, max_Q = int(2**10), max_rank = 16 up to 18k
    # n, hierarchy_depth = 6, max_Q = int(2**10), max_rank = 32 for 30k
    # n, hierarchy_depth = 6, max_Q = int(2**10), max_rank = 64 for up to 70k
    # n, hierarchy_depth = 6, max_Q = int(2**10), max_rank = 128 for 102k-113k

    # ---- HierarchicalRefinement ----
    try:
        rank_schedule = rank_annealing.optimal_rank_schedule(
            n_eff, hierarchy_depth=6, max_Q=int(2**10), max_rank=128
        )

        hrot_lr = HR_OT.HierarchicalRefinementOT.init_from_point_clouds(
            _X, _Y, rank_schedule, base_rank=1, device=device
        )
        F = hrot_lr.run(return_as_coupling=False)
        cost_hrot_lr = hrot_lr.compute_OT_cost()
        print(cost_hrot_lr)

    except Exception as e:
        print(f'Refinement failed for sample size {n}: {e}')

    # ---- Dense pairwise cost matrix (needed by LOT, Sinkhorn and the ProgOT cost) ----
    # A failure here, or in LOT below, disables the remaining dense solvers for
    # this pair but no longer skips FRLC, which only needs the factored cost.
    dense_ok = False
    C = None
    try:
        C = torch.cdist(_X, _Y) ** p
        # Convert once instead of once per solver.
        C_np = C.cpu().numpy()
        dense_ok = True

    except Exception as e:
        print(f'Failed to load cost mat for sample size {n}: {e}')

    # ---- LOT ----
    if dense_ok:
        try:
            geom_xy = ott.geometry.geometry.Geometry(cost_matrix=C_np)
            ot_prob = linear_problem.LinearProblem(geom_xy)
            solver = sinkhorn_lr.LRSinkhorn(rank=40)
            ot_lr = solver(ot_prob)
            cost_lot = ot_lr.primal_cost
            print(f'cost: {cost_lot:.4f}')

        except Exception as e:
            print(f'LOT failed for sample size {n}: {e}')
            # As before, a LOT failure skips the heavier dense solvers.
            dense_ok = False

    # ---- Sinkhorn ----
    if dense_ok:
        try:
            geom = Geometry(C_np)
            ot_problem = linear_problem.LinearProblem(geom)
            solver = sinkhorn.Sinkhorn()
            ot_solution = solver(ot_problem)
            P_sinkhorn = ot_solution.matrix
            cost_sink = (C_np * P_sinkhorn).sum()
            print(f'Cost of Sinkhorn: {cost_sink}')
        except Exception as e:
            print(f'Sinkhorn failed for sample size {n}: {e}')

    # ---- ProgOT ----
    if dense_ok:
        try:
            # Same trimmed points as C, so the plan and the cost matrix align.
            x_train, y_train = x0_trim, x1_trim
            alphas = progot.get_alpha_schedule("exp", num_steps=K)
            out = run_progot(x_train, y_train, cost_fn, alphas=alphas, epsilons=None)
            P_progOT_default = out.get_output(-1).matrix
            cost_progOT = (C_np * P_progOT_default).sum()
            print(f'Cost of ProgOT: {cost_progOT}')

        except Exception as e:
            print(f'ProgOT failed for sample size {n}: {e}')

    # ---- FRLC ----
    try:
        C1, C2 = util.low_rank_distance_factorization(_X, _Y, r=40, eps=0.04, device=device)
        # Normalize appropriately
        # NOTE(review): `**1/2` parses as `(x ** 1) / 2`, i.e. max/2, not a
        # square root. It is left unchanged because the printed cost is
        # rescaled by c**2 below and the recorded results were produced with
        # this scaling -- confirm whether `** (1/2)` was intended.
        c = ( C1.max()**1/2 ) * ( C2.max()**1/2 )
        C1, C2 = C1/c, C2/c
        C_factors = (C1.to(_X.dtype), C2.to(_X.dtype))
        Q, R, diagG, errs = FRLC.FRLC_LR_opt(C_factors,
                                             A_factors=None,
                                             B_factors=None,
                                             gamma=30,
                                             r=40,
                                             max_iter=60,
                                             device=device,
                                             min_iter=25,
                                             max_inneriters_balanced=100,
                                             max_inneriters_relaxed=40,
                                             diagonalize_return=True,
                                             printCost=True, tau_in=100000,
                                             dtype=_X.dtype)

        # Undo the normalization of both factors. Double quotes outside so the
        # nested 'W_cost' key also parses on Python versions before 3.12.
        print(f"FRLC cost: {errs['W_cost'][-1] * c**2}")

    except Exception as e:
        print(f'FRLC failed for sample size {n}: {e}')
    
    

In [None]:

# Recorded transport cost per solver, keyed by solver name. For each solver,
# 'samples' lists the dataset sizes and 'costs' the matching cost values.
# NOTE(review): this rebinds the name `costs`, which earlier cells import from
# ott.geometry and use as `costs.SqEuclidean()`; re-running those lines after
# this cell needs the import to be executed again.
costs = {
    method: {'samples': list(sample_sizes), 'costs': list(ot_costs)}
    for method, sample_sizes, ot_costs in (
        (
            'HROT_LR',
            (5913, 18408, 30124, 51363, 77367, 102519, 113350),
            (21.8088, 14.8126, 16.1396, 14.5741, 13.7851, 14.2901, 12.7880),
        ),
        ('Sinkhorn', (5913, 18408), (21.9137, 14.8893)),
        ('ProgOT', (5913, 18408), (22.5607, 15.3539)),
        (
            'FRLC',
            (5913, 18408, 30124, 51363, 77367, 102519, 113350),
            (23.1443, 16.0926, 17.7380, 15.4707, 14.6422, 15.5055, 14.0034),
        ),
    )
}
