# FUGW Barycenter

In [14]:
import sys

if "google.colab" in sys.modules:
    !pip install -q git+https://github.com/ott-jax/ott@main

In [15]:
import gdist
from nilearn import datasets, image, plotting, surface

import jax
import jax.numpy as jnp
import numpy as np

import matplotlib as mpl
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt

from ott.geometry import pointcloud
from ott.problems.quadratic import gw_barycenter as gwb
from ott.solvers.quadratic import gw_barycenter as gwb_solver

In [25]:
n_subjects = 4

contrasts = [
    "sentence reading vs checkerboard",
    "sentence listening",
    "calculation vs sentences",
    "left vs right button press",
    "checkerboard",
]


brain_data = datasets.fetch_localizer_contrasts(
    contrasts,
    n_subjects=n_subjects,
    get_anats=True,
)

subjects_ = []
for i in range(n_subjects):
    sub

['/Users/plbar/nilearn_data/brainomics_localizer/brainomics_data/S01/cmaps_VisualSentencesVsCheckerboard.nii.gz',
 '/Users/plbar/nilearn_data/brainomics_localizer/brainomics_data/S01/cmaps_AuditorySentences.nii.gz',
 '/Users/plbar/nilearn_data/brainomics_localizer/brainomics_data/S01/cmaps_Auditory&VisualCalculationVsSentences.nii.gz',
 '/Users/plbar/nilearn_data/brainomics_localizer/brainomics_data/S01/cmaps_LeftAuditory&VisualClickVsRightAuditory&VisualClick.nii.gz',
 '/Users/plbar/nilearn_data/brainomics_localizer/brainomics_data/S01/cmaps_Checkerboard.nii.gz',
 '/Users/plbar/nilearn_data/brainomics_localizer/brainomics_data/S02/cmaps_VisualSentencesVsCheckerboard.nii.gz',
 '/Users/plbar/nilearn_data/brainomics_localizer/brainomics_data/S02/cmaps_AuditorySentences.nii.gz',
 '/Users/plbar/nilearn_data/brainomics_localizer/brainomics_data/S02/cmaps_Auditory&VisualCalculationVsSentences.nii.gz',
 '/Users/plbar/nilearn_data/brainomics_localizer/brainomics_data/S02/cmaps_LeftAuditory&Vis

In [17]:
fsaverage3 = datasets.fetch_surf_fsaverage(mesh="fsaverage3")


def load_images_and_project_to_surface(image_paths):
    """Util function for loading and projecting volumetric images."""
    images = [image.load_img(img) for img in image_paths]
    surface_images = [
        np.nan_to_num(surface.vol_to_surf(img, fsaverage3.pial_left))
        for img in images
    ]

    return np.stack(surface_images)

In [18]:
def compute_geometry_from_mesh(mesh_path):
    """Util function to compute matrix of geodesic distances of a mesh."""
    (coordinates, triangles) = surface.load_surf_mesh(mesh_path)
    geometry = gdist.local_gdist_matrix(
        coordinates.astype(np.float64), triangles.astype(np.int32)
    ).toarray()
    return geometry


fsaverage3_pial_left_geometry = compute_geometry_from_mesh(fsaverage3.pial_left)
source_geometry = fsaverage3_pial_left_geometry
target_geometry = fsaverage3_pial_left_geometry
source_geometry.shape

(642, 642)

In [19]:
ndim = 3
ndim_f = 4
bar_size = 10
num_per_segment = (7, 12)
epsilon = 1.0
tau_a = 0.75
tau_b = 0.75

rng = jax.random.PRNGKey(0)


def random_pc(
    n: int, d: int, rng: jax.Array, m: int = None
) -> pointcloud.PointCloud:
    rng1, rng2 = jax.random.split(rng, 2)
    x = jax.random.normal(rng1, (n, d))
    y = x if m is None else jax.random.normal(rng2, (m, d))
    return pointcloud.PointCloud(x, y)


rng1, *rngs = jax.random.split(rng, len(num_per_segment) + 1)
y = jnp.concatenate(
    [random_pc(n, d=ndim_f, rng=rng).x for n, rng in zip(num_per_segment, rngs)]
)
rngs = jax.random.split(rng1, len(num_per_segment))
y_fused = jnp.concatenate(
    [random_pc(n, d=ndim_f, rng=rng).x for n, rng in zip(num_per_segment, rngs)]
)

In [20]:
prob = gwb.GWBarycenterProblem(
    y=y,
    y_fused=y_fused,
    num_per_segment=num_per_segment,
    fused_penalty=1.0,
    scale_cost="max_cost",
    tau_a=tau_a,
    tau_b=tau_b,
    gw_unbalanced_correction=True,
)

In [21]:
solver = gwb_solver.GromovWassersteinBarycenter(
    store_inner_errors=True, epsilon=epsilon
)

x_init = jax.random.normal(rng, (bar_size, ndim_f))
cost_init = pointcloud.PointCloud(x_init).cost_matrix

fugw = solver(prob, bar_size=bar_size, bar_init=(cost_init, x_init))

In [22]:
fugw.x.shape

(10, 4)