In [46]:
import sys
import time
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from ot.gromov import gwloss, init_matrix
from sklearn.decomposition import PCA

import matplotlib.pyplot as plt
import seaborn as sn
from IPython import display
from matplotlib import animation

import ott
from ott.problems.quadratic import quadratic_problem
from ott.solvers.quadratic import gromov_wasserstein

In [5]:
from SCOT.src import evals
from SCOT.src.scotv1 import SCOT

In [7]:
X = np.load("SCOT/data/SNARE/SNAREseq_atac_feat.npy")
y = np.load("SCOT/data/SNARE/SNAREseq_rna_feat.npy")

print("Dimensions of input datasets:")
print(
    "X =",
    X.shape,
    "=> ie",
    X.shape[0],
    "samples belonging to a chromatin accessibility feature space of dimension",
    X.shape[1],
)
print(
    "y =",
    y.shape,
    "=> ie",
    y.shape[0],
    "samples belonging to a gene expression feature space of dimension",
    y.shape[1],
)

Dimensions of input datasets:
X = (1047, 19) => ie 1047 samples belonging to a chromatin accessibility feature space of dimension 19
y = (1047, 10) => ie 1047 samples belonging to a gene expression feature space of dimension 10


In [9]:
k = 40
epsilon = 1e-3

In [10]:
potscot = SCOT(X, y)

start = time.time()
X_shifted_pot, y_shifted_pot = potscot.align(
    k=k, e=epsilon, normalize=True, norm="l2", verbose=False
)  # POT
end = time.time()

print("Execution time: ", round(end - start, 2), "s")



Execution time:  31.17 s


In [11]:
potscot.gwdist

0.016519975128480156

In [14]:
class OTTSCOT(SCOT):
    def find_correspondences(self, e: float, verbose: bool = True) -> None:
        geom_xx = ott.geometry.geometry.Geometry(self.Cx)
        geom_yy = ott.geometry.geometry.Geometry(self.Cy)
        prob = quadratic_problem.QuadraticProblem(
            geom_xx, geom_yy, a=self.p, b=self.q
        )

        solver = gromov_wasserstein.GromovWasserstein(
            epsilon=e, threshold=1e-9, max_iterations=1000
        )

        T = solver(prob).matrix

        constC, hC1, hC2 = init_matrix(
            self.Cx, self.Cy, self.p, self.q, loss_fun="square_loss"
        )
        self.gwdist = gwloss(constC, hC1, hC2, np.array(T))
        self.coupling = T

        if (
            np.isnan(self.coupling).any()
            or np.any(~self.coupling.any(axis=1))
            or np.any(~self.coupling.any(axis=0))
            or sum(sum(self.coupling)) < 0.95
        ):
            self.flag = False
        else:
            self.flag = True

In [15]:
ottscot = OTTSCOT(X, y)

start = time.time()
X_shifted, y_shifted = ottscot.align(
    k=k, e=epsilon, normalize=True, norm="l2", verbose=False
)  # OTT
end = time.time()

print("Execution time: ", round(end - start, 2), "s")

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Execution time:  214.05 s


In [16]:
print(potscot.gwdist, ottscot.gwdist)

0.016519975128480156 0.017811801691579823


In [18]:
from SGW.lib.sgw_numpy import sgw_cpu

In [24]:
print(Xs.shape, Xt.shape, X.shape, y.shape)

(300, 2) (300, 1) (1047, 19) (1047, 10)


In [28]:
%%time
outx = sgw_cpu(X, y, nproj=2000, tolog=True)
print(outx)

(7.955944738298696e+18, {'time_sink_': 0.04156494140625, 'time_gw_1D': 0.8067829608917236, 'gw_1d_details': {'g1d': 0.8064517974853516, 't1': 0.26371216773986816, 't2': 0.26117467880249023}})
CPU times: user 2.82 s, sys: 51.1 ms, total: 2.88 s
Wall time: 849 ms


In [40]:
n_samples=300
Xs=np.random.rand(n_samples,5)
Xt=np.random.rand(n_samples,5)
P=np.random.randn(5,500) # 500 projections are used

In [41]:
%%time
sgw_cpu(Xs,Xt,P=P)

CPU times: user 1.16 s, sys: 49.4 ms, total: 1.21 s
Wall time: 145 ms


0.0019520239196300262

In [43]:
%%time
def ott_gw(Xs, Xt, p, q, e):
    geom_xx = ott.geometry.geometry.Geometry(Xs)
    geom_yy = ott.geometry.geometry.Geometry(Xt)
    prob = quadratic_problem.QuadraticProblem(
        geom_xx, geom_yy, a=p, b=q
    )
    solver = gromov_wasserstein.GromovWasserstein(
       epsilon=e, threshold=1e-9, max_iterations=1000
    )
    T = solver(prob).matrix
    constC, hC1, hC2 = init_matrix(
            Xs, Xt, p, q, loss_fun="square_loss"
        )
    gwdist = gwloss(constC, hC1, hC2, np.array(T))
    return gw, T
ra, rb = ott_gw(Xs,Xt,None,None, epsilon)
rb

TypeError: Incompatible shapes for dot: got (300, 5) and (300,).