# Look into GW performance

The pure Gromov-Wasserstein (GW) solver converges to a bad local minimum when we use precise lineage information on the 330/390 min pair of time points. In this notebook, we illustrate this by showing that initializing GW from a better starting point, here the solution of a linear OT problem in PCA space, leads the algorithm to a better local minimum. Pure GW, which relies only on lineage information, can thus converge to bad local minima even with perfect lineage information; moslin avoids this by including some gene expression information.

## Preliminaries

### Import libraries

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from typing import Any, Dict, Literal, Optional, Tuple
import scanpy as sc
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt

import jax

jax.config.update("jax_enable_x64", True)

import pandas as pd
import moslin_utils as mu
from moslin_utils.constants import DATA_DIR, TIME_KEY

import ott

# import solvers
from ott.solvers.quadratic import gromov_wasserstein
from ott.solvers.linear import sinkhorn

# import initializers
from ott.initializers.quadratic import initializers

# import problems
from ott.problems.quadratic import quadratic_problem
from ott.problems.linear import linear_problem

### Set up paths

In [3]:
DATA_DIR = DATA_DIR / "packer_c_elegans"

### Define global settings

### Import data

In [4]:
adata = sc.read(DATA_DIR / "c_elegans.h5ad")
full_reference_tree = nx.read_gml(DATA_DIR / "ML_2023-11-06_packer_lineage_tree.gml")

### Define utility functions

In [5]:
def print_mean_error(tmat):
    """Compute the mean of the ancestor and descendant errors for a given coupling and print it."""

    # make sure this is a numpy array
    tmat = np.array(tmat)

    # get early, late, and mean errors
    early_error, late_error = gt_coupling.cost(tmat, late=False), gt_coupling.cost(tmat, late=True)
    mean_error = (early_error + late_error) / 2

    print(f"Early error = {early_error:.3f}, late error = {late_error:.3f} and mean_error = {mean_error:.3f}.")

## Look into GW performance

### Define parameters and prepare data

Define parameters

In [6]:
tp = [330, 390]

lineage_info = "precise"
store_inner_errors = True

max_outer_iterations = 50
max_inner_iterations = 30000
threshold = 1e-3
scale_cost = "mean"

# we'll define two different values for epsilon here
epsilon_linear = 0.05
epsilon_quad = 0.001

Let's look into the original grid search results for GW.

In [7]:
grid_df = pd.read_csv(DATA_DIR / "ML_2024-03-11_celegans_precise_and_abpxp.csv", index_col=0)

In [8]:
hyper_df = mu.ul.get_best_runs(df=grid_df, lineage_info=lineage_info, group_key="kind", group="GW", converged=True)
hyper_df

Removing 0/30 not converged runs.


Unnamed: 0,_wandb,late_cost,early_cost,deviation_from_balanced,_step,_runtime,converged,_timestamp,mean_error,tp,kind,alpha,tau_a,epsilon,scale_cost,lineage_info,max_inner_iterations,name
47,{'runtime': 26},0.109178,0.083764,0.0,0,27.500503,True,1699616000.0,0.096471,170-210,GW,1.0,1,0.01,mean,precise,30000.0,sandy-sweep-553
58,{'runtime': 121},0.16329,0.150787,2.220446e-16,0,121.807805,True,1699616000.0,0.157038,210-270,GW,1.0,1,0.001,mean,precise,30000.0,ethereal-sweep-542
57,{'runtime': 305},0.229756,0.197582,2.220446e-16,0,306.088845,True,1699616000.0,0.213669,270-330,GW,1.0,1,0.001,mean,precise,30000.0,laced-sweep-543
56,{'runtime': 126},0.613913,0.559957,2.220446e-16,0,126.642962,True,1699616000.0,0.586935,330-390,GW,1.0,1,0.001,mean,precise,30000.0,wise-sweep-544
55,{'runtime': 36},0.042624,0.047701,2.220446e-16,0,37.496566,True,1699616000.0,0.045162,390-450,GW,1.0,1,0.001,mean,precise,30000.0,fluent-sweep-545
54,{'runtime': 25},0.010022,0.009853,8.881784e-16,0,26.347877,True,1699616000.0,0.009937,450-510,GW,1.0,1,0.001,mean,precise,30000.0,rose-sweep-546


Start by preprocessing the data

In [9]:
# preprocess the data
adata = mu.pp.preprocess(adata, full_reference_tree, lineage_info=lineage_info)

# extract early and late time point information
early_time, late_time = tp

Get the ground truth coupling, and lineage distance matrices. 

In [10]:
gt_coupling, early_dist, late_dist, bdata = mu.tl.prepare_moscot(
    adata=adata,
    early_time=early_time,
    late_time=late_time,
)

# from the ground truth coupling, get the marginals
a, b = gt_coupling.early_marginal, gt_coupling.late_marginal

Let's make sure that dimensions make sense. 

In [11]:
assert early_dist.shape[0] == np.sum(bdata.obs[TIME_KEY] == early_time) == a.shape[0], "Shape mismatch for early cells"
assert late_dist.shape[0] == np.sum(bdata.obs[TIME_KEY] == late_time) == b.shape[0], "Shape mismatch for late cells"

### Solve a linear problem

Using OTT directly, initialize a geometry for the linear problem 

In [12]:
# initialize a Geometry in PCA space for the linear problem
X = bdata[bdata.obs[TIME_KEY] == early_time].obsm["X_pca"].copy()
Y = bdata[bdata.obs[TIME_KEY] == late_time].obsm["X_pca"].copy()

geom_xy = ott.geometry.pointcloud.PointCloud(x=X, y=Y, scale_cost=scale_cost, epsilon=epsilon_linear)
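
As a rough picture of what this geometry encodes: assuming `PointCloud` uses its default squared Euclidean cost and that `scale_cost="mean"` divides the cost matrix by its mean, the cost could be sketched in plain NumPy as follows. The function name and the toy arrays are illustrative and not part of OTT or moslin.

```python
import numpy as np


def scaled_sq_euclidean_cost(X, Y):
    """Pairwise squared Euclidean cost between rows of X and Y, divided by its mean."""
    sq_dists = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)
    return sq_dists / sq_dists.mean()


# hypothetical toy embeddings standing in for the early and late PCA coordinates
rng = np.random.default_rng(0)
X_toy = rng.normal(size=(5, 3))
Y_toy = rng.normal(size=(7, 3))

C_toy = scaled_sq_euclidean_cost(X_toy, Y_toy)
print(C_toy.shape, C_toy.mean())
```

Scaling by the mean makes a given value of epsilon comparable across pairs of time points, because the typical cost entry is then of order one.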

Make sure setting epsilon worked as expected. 

In [13]:
assert geom_xy.epsilon == epsilon_linear, "Not using the specific epsilon value"

Initialize a Linear problem. 

In [14]:
prob_lin = linear_problem.LinearProblem(geom=geom_xy, a=a, b=b)

Initialize a solver for this linear problem. 

In [15]:
linear_solver_kwargs = {"max_iterations": max_inner_iterations}

# Instantiate a jitted linear solver
sinkhorn_solver = jax.jit(sinkhorn.Sinkhorn(**linear_solver_kwargs))

Solve the Linear problem using the Sinkhorn solver.

In [16]:
sinkhorn_out = sinkhorn_solver(prob_lin)
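
For intuition, the Sinkhorn iterations behind this call can be sketched in a few lines of NumPy. This is a minimal textbook version on hypothetical toy data, not OTT's implementation: it runs a fixed number of iterations instead of checking a convergence threshold, and it works with the kernel directly rather than in log space.

```python
import numpy as np


def sinkhorn_plan(C, a, b, epsilon, n_iters=2000):
    """Entropic OT plan via plain Sinkhorn scaling (illustrative only)."""
    K = np.exp(-C / epsilon)  # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iters):
        u = a / (K @ v)  # rescale rows towards the marginal a
        v = b / (K.T @ u)  # rescale columns towards the marginal b
    return u[:, None] * K * v[None, :]


# hypothetical toy data: four cells on a line, each shifted slightly at the later time point
x = np.arange(4.0)
y = x + 0.1
C = (x[:, None] - y[None, :]) ** 2
C = C / C.mean()  # mimic scale_cost="mean"
a = np.full(4, 0.25)
b = np.full(4, 0.25)

plan = sinkhorn_plan(C, a, b, epsilon=0.05)
print(plan.sum(axis=0))  # column marginals
print(plan.argmax(axis=1))  # most likely late cell for each early cell
```

Because the last update rescales the columns, the column marginals are matched up to rounding, while the row marginals are only matched approximately after a finite number of iterations.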

Let's look into this solution a bit more. 

In [17]:
has_converged = bool(sinkhorn_out.converged)

print(f"{sinkhorn_out.n_iters} outer iterations were needed.")

print(f"The algorithm converged: {has_converged}")
print(f"The final regularized Sinkhorn cost is: {sinkhorn_out.reg_ot_cost:.3f}")

190 outer iterations were needed.
The algorithm converged: True
The final regularized Sinkhorn cost is: 0.565


We can also use this solution to compute a mean error. 

In [18]:
# extract the transition matrix
print_mean_error(sinkhorn_out.matrix)

If total mass - 1 is small, this may not significantly affect downstream results.
Total mass - 1: -0.0001719895810036176
If total mass - 1 is small, this may not significantly affect downstream results.
Total mass - 1: -0.00017198958100372863


Early error = 0.268, late error = 0.267 and mean_error = 0.267.


### Solve a quadratic problem

Initialize the geometries and a quadratic problem. 

In [19]:
# initialize Geometries from pre-computed lineage distance matrices
geom_xx = ott.geometry.geometry.Geometry(early_dist, scale_cost=scale_cost)
geom_yy = ott.geometry.geometry.Geometry(late_dist, scale_cost=scale_cost)

# initialize the Quadratic problem
prob_quad = quadratic_problem.QuadraticProblem(geom_xx=geom_xx, geom_yy=geom_yy, a=a, b=b)
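
The quadratic problem compares pairwise distances within each time point rather than distances across time points. Assuming the squared loss, the GW objective for a coupling can be sketched as below; the toy also shows why the problem is non-convex. The helper name and the three-point example are hypothetical.

```python
import numpy as np


def gw_cost(Dx, Dy, P):
    """Squared-loss GW objective: sum over i, j, k, l of (Dx[i,k] - Dy[j,l])^2 * P[i,j] * P[k,l]."""
    diff = Dx[:, None, :, None] - Dy[None, :, None, :]  # indexed as [i, j, k, l]
    return float(np.einsum("ijkl,ij,kl->", diff**2, P, P))


# three equally spaced points, compared with themselves
x = np.array([0.0, 1.0, 2.0])
D = np.abs(x[:, None] - x[None, :])

P_identity = np.eye(3) / 3  # match each point to itself
P_flip = np.fliplr(np.eye(3)) / 3  # match each point to its mirror image
P_mid = 0.5 * (P_identity + P_flip)  # average of the two couplings

print(gw_cost(D, D, P_identity), gw_cost(D, D, P_flip), gw_cost(D, D, P_mid))
```

Both the identity and the mirrored coupling preserve all pairwise distances and therefore have zero cost, while their average has a strictly positive cost. A convex objective could not behave this way, which is why the GW solution can depend on the initialization.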

Initialize a solver, and use it to solve this problem. 

In [20]:
# control some parameters of the inner Sinkhorn solver
linear_solver_kwargs = {"max_iterations": max_inner_iterations}
linear_ot_solver = sinkhorn.Sinkhorn(**linear_solver_kwargs)

# Instantiate a jitted Gromov-Wasserstein solver
gw_solver = jax.jit(
    gromov_wasserstein.GromovWasserstein(
        linear_ot_solver=linear_ot_solver,
        epsilon=epsilon_quad,
        store_inner_errors=store_inner_errors,
        threshold=threshold,
        max_iterations=max_outer_iterations,
    )
)

gw_out = gw_solver(prob_quad)

Inspect the output we got from this.

In [21]:
has_converged = bool(gw_out.linear_convergence[gw_out.n_iters - 1])

print(f"{gw_out.n_iters} outer iterations were needed.")
print(f"The last Sinkhorn iteration has converged: {has_converged}")
print(f"The outer loop of Gromov Wasserstein has converged: {gw_out.converged}")
print(f"The final regularized GW cost is: {gw_out.reg_gw_cost:.3f}")

9 outer iterations were needed.
The last Sinkhorn iteration has converged: True
The outer loop of Gromov Wasserstein has converged: True
The final regularized GW cost is: 0.006


Evaluate using the ground-truth coupling. 

In [22]:
print_mean_error(gw_out.matrix)

Early error = 0.560, late error = 0.614 and mean_error = 0.587.


The poor performance of pure GW (mean error of 0.587) reproduces what we saw for the 330-390 pair in the grid search above.

### Initialize quadratic problem with linear solution

Finally, let's use the linear OT solution to initialize the GW problem.

In [23]:
# define the inner linear solver
linear_solver_kwargs = {"max_iterations": max_inner_iterations}
linear_ot_solver = sinkhorn.Sinkhorn(**linear_solver_kwargs)

# Instantiate a jitted Gromov-Wasserstein solver with custom initialization
ot_gw_solver = jax.jit(
    gromov_wasserstein.GromovWasserstein(
        linear_ot_solver=linear_ot_solver,
        epsilon=epsilon_quad,
        store_inner_errors=store_inner_errors,
        threshold=threshold,
        max_iterations=max_outer_iterations,
        quad_initializer=initializers.QuadraticInitializer(
            init_coupling=sinkhorn_out.matrix,
        ),
    )
)

Let's call this solver.

In [24]:
ot_gw_out = ot_gw_solver(prob_quad)

Let's look into this solution.

In [26]:
has_converged = bool(ot_gw_out.linear_convergence[ot_gw_out.n_iters - 1])

print(f"{ot_gw_out.n_iters} outer iterations were needed.")
print(f"The last Sinkhorn iteration has converged: {has_converged}")
print(f"The outer loop of Gromov Wasserstein has converged: {ot_gw_out.converged}")
print(f"The final regularized GW cost is: {ot_gw_out.reg_gw_cost:.3f}")

5 outer iterations were needed.
The last Sinkhorn iteration has converged: True
The outer loop of Gromov Wasserstein has converged: True
The final regularized GW cost is: 0.006


Evaluate using the ground-truth coupling. 

In [27]:
print_mean_error(ot_gw_out.matrix)

Early error = 0.108, late error = 0.118 and mean_error = 0.113.


Indeed, the mean error has decreased substantially, from 0.587 to 0.113, compared to the default GW initialization we tested before.
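
The dependence on the starting point can be reproduced on a toy problem. The sketch below is a simplified entropic GW loop in the spirit of the usual scheme (linearize the objective around the current coupling, solve the resulting linear problem with Sinkhorn, repeat); it is not OTT's implementation. The three-point geometry and both initial couplings are hypothetical, chosen so that a mirrored matching is a poor but stable local minimum.

```python
import numpy as np


def sinkhorn(C, a, b, epsilon, n_iters=200):
    """Plain Sinkhorn scaling for the inner linear problem."""
    K = np.exp(-C / epsilon)
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iters):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return u[:, None] * K * v[None, :]


def gw_cost(Dx, Dy, P):
    """Squared-loss GW objective of the coupling P."""
    diff = Dx[:, None, :, None] - Dy[None, :, None, :]
    return float(np.einsum("ijkl,ij,kl->", diff**2, P, P))


def entropic_gw(Dx, Dy, a, b, init, epsilon=0.05, n_outer=20):
    """Outer loop: linearize the GW objective at P, then re-solve with Sinkhorn."""
    P = init
    const = (Dx**2 @ a)[:, None] + (Dy**2 @ b)[None, :]
    for _ in range(n_outer):
        C_lin = const - 2.0 * Dx @ P @ Dy.T  # linearized cost at the current coupling
        P = sinkhorn(C_lin, a, b, epsilon)
    return P


# slightly asymmetric points, so that mirroring them is almost, but not exactly, an isometry
x = np.array([0.0, 1.0, 2.2])
D = np.abs(x[:, None] - x[None, :])
a = b = np.full(3, 1 / 3)

# two hypothetical initial couplings: blurred versions of the correct and the mirrored matching
uniform = np.outer(a, b)
init_good = 0.6 * np.eye(3) / 3 + 0.4 * uniform
init_bad = 0.6 * np.fliplr(np.eye(3)) / 3 + 0.4 * uniform

P_good = entropic_gw(D, D, a, b, init_good)
P_bad = entropic_gw(D, D, a, b, init_bad)
cost_good, cost_bad = gw_cost(D, D, P_good), gw_cost(D, D, P_bad)
print(f"GW cost from good init: {cost_good:.5f}, from bad init: {cost_bad:.5f}")
```

In this toy, the run started near the mirrored matching stays there and ends with a higher GW cost, while the run started near the correct matching recovers it. The notebook's experiment follows the same logic, with the linear OT coupling playing the role of the informative starting point.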