In [None]:
import scipy as sp
import numpy as np
import matplotlib.pylab as pl
from mpl_toolkits.mplot3d import Axes3D  # noqa

from sklearn.manifold import TSNE

import metric_learn
import numpy as np
from sklearn.datasets import make_classification, make_regression

# For optimal transport operations:
import ot
# For computing graph distances:
from scipy.sparse.csgraph import dijkstra
from scipy.sparse import csr_matrix
from sklearn.neighbors import kneighbors_graph

# For pre-processing, normalization
from sklearn.preprocessing import StandardScaler, normalize

In [None]:
scatac=np.load("../data/scatac_feat.npy") 
scrna=np.load("../data/scrna_feat.npy")
print("Dimensions of input datasets are: ", "X= ", scatac.shape, " y= ", scrna.shape)

In [None]:
X1, X2  = normalize(scatac, norm="l2"), normalize(scrna, norm="l2")

In [None]:
y1=np.loadtxt("../data/SNAREseq_atac_types.txt")
y2=np.loadtxt("../data/SNAREseq_rna_types.txt")

In [None]:
# def calcCov(x, y):
#     mean_x, mean_y = x.mean(), y.mean()
#     n = len(x)
#     return sum((x - mean_x) * (y - mean_y)) / n
# def cov(data):
#     rows, cols = data.shape
#     cov_mat = np.zeros((cols, cols))
 
#     for i in range(cols):
 
#         for j in range(cols):
#             # store the value in the matrix
#             cov_mat[i][j] = calcCov(data[:, i], data[:, j])
 
#     return cov_mat

In [None]:
# def mahalanobis_distance(p1,p2,X): #p1 is model, p2 is the test point
#     # X is inverse cov matrix
#     distance = np.dot(np.dot(np.subtract(p2,p1).T,np.array(X)),np.subtract(p2,p1))
#     return distance

In [None]:
# X1cov = cov(X1)
# X1cov = np.linalg.inv(X1cov)
# distance = mahalanobis_distance(X1[0],X1[1],X1cov)
# C1 = np.zeros((X1.shape[0],X1.shape[0]))

# for i in range(X1.shape[0]):
#     for j in range(X1.shape[0]):
#         C1[i][j] = mahalanobis_distance(X1[i],X1[j],X1cov)

In [None]:
# X2cov = cov(X2) 
# C2 = np.zeros((X2.shape[0],X2.shape[0]))

# for i in range(X2.shape[0]):
#     for j in range(X2.shape[0]):
#         C2[i][j] = mahalanobis_distance(X2[i],X2[j],X2cov)

In [None]:

# setting up LMNN
X1_metric_learn = metric_learn.RCA()

# fit the data!
X1_metriclearn = X1_metric_learn.fit_transform(X1, y1)


In [None]:
metric_X1  = X1_metric_learn.get_metric()

In [None]:
C1 = np.zeros((X1.shape[0],X1.shape[0]))

for i in range(X1.shape[0]):
    for j in range(X1.shape[0]):
        C1[i][j] = metric_X1(X1[i],X1[j])

In [None]:

# setting up LMNN
X2_metric_learn = metric_learn.RCA()

# fit the data!
X2_metriclearn = X2_metric_learn.fit_transform(X2, y2)


In [None]:
metric_X2  = X2_metric_learn.get_metric()

In [None]:
C2 = np.zeros((X2.shape[0],X2.shape[0]))

for i in range(X2.shape[0]):
    for j in range(X2.shape[0]):
        C2[i][j] = metric_X2(X2[i],X2[j])

In [None]:
from sklearn.manifold import MDS
import numpy as np
import matplotlib.pylab as pl
import torch

import ot
from ot.gromov import gromov_wasserstein2
rng = np.random.RandomState(42)

def min_weight_gw(C1, C2, a2, nb_iter_max=100, lr=1e-2):
    """ solve min_a GW(C1,C2,a, a2) by gradient descent"""

    # use pyTorch for our data
    C1_torch = torch.tensor(C1)
    C2_torch = torch.tensor(C2)

    a0 = rng.rand(C1.shape[0])  # random_init
    a0 /= a0.sum()  # on simplex
    a1_torch = torch.tensor(a0).requires_grad_(True)
    a2_torch = torch.tensor(a2)

    loss_iter = []

    for i in range(nb_iter_max):

        loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch)

        loss_iter.append(loss.clone().detach().cpu().numpy())
        loss.backward()

        #print("{:03d} | {}".format(i, loss_iter[-1]))

        # performs a step of projected gradient descent
        with torch.no_grad():
            grad = a1_torch.grad
            a1_torch -= grad * lr   # step
            a1_torch.grad.zero_()
            a1_torch.data = ot.utils.proj_simplex(a1_torch)

    a1 = a1_torch.clone().detach().cpu().numpy()

    return a1, loss_iter


a0_est, loss_iter0 = min_weight_gw(C1, C2, ot.unif(C1.shape[0]), nb_iter_max=10, lr=1e-2)


pl.figure(2)
pl.plot(loss_iter0)
pl.title("Loss along iterations")

print("Estimated weights : ", a0_est)

In [None]:
a0_est.shape

In [None]:

def plot_graph(x, C, color='C0', s=None):
    for j in range(C.shape[0]):
        for i in range(j):
            if C[i, j] > 0:
                pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k')
    pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9)

    
T_unif = ot.gromov_wasserstein(C2, C1, ot.unif(C1.shape[0]), ot.unif(C2.shape[0]))
label_unif = T_unif.argmax(1)

T_est = ot.gromov_wasserstein(C2, C1, ot.unif(C1.shape[0]), a0_est)
label_est = T_est.argmax(1)

# get 2d position for nodes
x1 = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C1)

# pl.figure(3, (10, 5))
# pl.clf()
# pl.subplot(1, 2, 1)
# plot_graph(x1, C1, color=label_unif)
# pl.title("Graph clustering unif. weights")
# pl.axis("off")
# pl.subplot(1, 2, 2)
# plot_graph(x1, C1, color=label_est)
# pl.title("Graph clustering est. weights")
# pl.axis("off")

In [None]:
T_unif.shape

In [None]:
T_est.shape

In [None]:
label_unif

In [None]:
def graph_compession_gw(nb_nodes, C2, a2, nb_iter_max=100, lr=1e-2):
    """ solve min_a GW(C1,C2,a, a2) by gradient descent"""

    # use pyTorch for our data

    C2_torch = torch.tensor(C2)
    a2_torch = torch.tensor(a2)

    a0 = rng.rand(nb_nodes)  # random_init
    a0 /= a0.sum()  # on simplex
    a1_torch = torch.tensor(a0).requires_grad_(True)
    C0 = np.eye(nb_nodes)
    C1_torch = torch.tensor(C0).requires_grad_(True)

    loss_iter = []

    for i in range(nb_iter_max):

        loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch)

        loss_iter.append(loss.clone().detach().cpu().numpy())
        loss.backward()

        #print("{:03d} | {}".format(i, loss_iter[-1]))

        # performs a step of projected gradient descent
        with torch.no_grad():
            grad = a1_torch.grad
            a1_torch -= grad * lr   # step
            a1_torch.grad.zero_()
            a1_torch.data = ot.utils.proj_simplex(a1_torch)

            grad = C1_torch.grad
            C1_torch -= grad * lr   # step
            C1_torch.grad.zero_()
            C1_torch.data = torch.clamp(C1_torch, 0, 1)

    a1 = a1_torch.clone().detach().cpu().numpy()
    C1 = C1_torch.clone().detach().cpu().numpy()

    return a1, C1, loss_iter


# nb_nodes = C2.shape[0]
# a0_est2, C0_est2, loss_iter2 = graph_compession_gw(nb_nodes, C2, ot.unif(C2.shape[0]),
#                                                    nb_iter_max=300, lr=5e-2)

# pl.figure(4)
# pl.plot(loss_iter2)
# pl.title("Loss along iterations")


# print("Estimated weights : ", a0_est2)

# pl.figure(6, (10, 3.5))
# pl.clf()
# pl.imshow(C0_est2, vmin=0, vmax=1)
# pl.title('Estimated C0 matrix')
# pl.colorbar()

In [None]:
# a0_est2.shape

In [None]:
# C0_est2.shape

In [None]:
# eval function

import numpy as np
import random, math, os, sys
import matplotlib.pyplot as plt
from sklearn.preprocessing import normalize
from sklearn.metrics import roc_auc_score, silhouette_samples
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier

def calc_frac_idx(x1_mat,x2_mat):
    """
    Returns fraction closer than true match for each sample (as an array)
    """
    fracs = []
    x = []
    nsamp = x1_mat.shape[0]
    rank=0
    for row_idx in range(nsamp):
        euc_dist = np.sqrt(np.sum(np.square(np.subtract(x1_mat[row_idx,:], x2_mat)), axis=1))
        true_nbr = euc_dist[row_idx]
        sort_euc_dist = sorted(euc_dist)
        rank =sort_euc_dist.index(true_nbr)
        frac = float(rank)/(nsamp -1)

        fracs.append(frac)
        x.append(row_idx+1)

    return fracs,x

def calc_domainAveraged_FOSCTTM(x1_mat, x2_mat):
    """
    Outputs average FOSCTTM measure (averaged over both domains)
    Get the fraction matched for all data points in both directions
    Averages the fractions in both directions for each data point
    """
    
    fracs1,xs = calc_frac_idx(x1_mat, x2_mat)
    fracs2,xs = calc_frac_idx(x2_mat, x1_mat)
    fracs = []
    for i in range(len(fracs1)):
        fracs.append((fracs1[i]+fracs2[i])/2)  
    return fracs

def calc_sil(x1_mat,x2_mat,x1_lab,x2_lab):
    """
    Returns silhouette score for datasets with cell clusters
    """
    sil = []
    sil_d0 = []
    sil_d3 = []
    sil_d7 = []
    sil_d11 = []
    sil_npc = []

    x = np.concatenate((x1_mat,x2_mat))
    lab = np.concatenate((x1_lab,x2_lab))

    sil_score = silhouette_samples(x,lab)

    nsamp = x.shape[0]
    for i in range(nsamp):
        if(lab[i]==1):
            sil_d0.append(sil_score[i])
        elif(lab[i]==2):
            sil_d3.append(sil_score[i])
        elif(lab[i]==3):
            sil_d7.append(sil_score[i])
        elif(lab[i]==4):
            sil_d11.append(sil_score[i])
        elif(lab[i]==5):
            sil_npc.append(sil_score[i])

    avg = np.mean(sil_score)
    d0 = sum(sil_d0)/len(sil_d0)
    d3 = sum(sil_d3)/len(sil_d3)
    d7 = sum(sil_d7)/len(sil_d7)
    d11 = sum(sil_d11)/len(sil_d11)
    npc = sum(sil_npc)/len(sil_npc)

    return avg,d0,d3,d7,d11,npc

def binarize_labels(label,x):
    """
    Helper function for calc_auc
    """
    bin_lab = np.array([1] * len(x))
    idx = np.where(x == label)

    bin_lab[idx] = 0
    return bin_lab


def calc_auc(x1_mat, x2_mat, x1_lab, x2_lab):
    """
    calculate avg. ROC AUC scores for transformed data when there are >=2 number of clusters.
    """

    nsamp = x1_mat.shape[0]

    auc = []
    auc_d0 = []
    auc_d3 = []
    auc_d7 = []
    auc_d11 = []
    auc_npc = []

    for row_idx in range(nsamp):
        euc_dist = np.sqrt(np.sum(np.square(np.subtract(x1_mat[row_idx,:], x2_mat)), axis=1))
        y_scores = euc_dist
        y_true = binarize_labels(x1_lab[row_idx],x2_lab)

        auc_score = roc_auc_score(y_true, y_scores)
        auc.append(auc_score)
        
        if(x1_lab[row_idx]==0):
            auc_d0.append(auc_score)
        elif(x1_lab[row_idx]==1):
            auc_d3.append(auc_score)
        elif(x1_lab[row_idx]==2):
            auc_d7.append(auc_score)
        elif(x1_lab[row_idx]==3):
            auc_d11.append(auc_score)
        elif(x1_lab[row_idx]==4):
            auc_npc.append(auc_score)

    avg = sum(auc)/len(auc)
    d0 = sum(auc_d0)/len(auc_d0)
    d3 = sum(auc_d3)/len(auc_d3)
    d7 = sum(auc_d7)/len(auc_d7)
    d11 = sum(auc_d11)/len(auc_d11)
    npc = sum(auc_npc)/len(auc_npc)

    return avg,d0,d3,d7,d11,npc

def transfer_accuracy(domain1, domain2, type1, type2, n):
    """
    Metric from UnionCom: "Label Transfer Accuracy"
    """
    knn = KNeighborsClassifier(n_neighbors=n)
    knn.fit(domain2, type2)
    type1_predict = knn.predict(domain1)
    np.savetxt("type1_predict.txt", type1_predict)
    count = 0
    for label1, label2 in zip(type1_predict, type1):
        if label1 == label2:
            count += 1
    return count / len(type1)

In [None]:
# projection

#Projecting the first domain onto the second domain
y_aligned_from_normalized=X2
weights_from_normalized=np.sum(T_unif,axis = 0)
X_aligned_from_normalized=np.matmul(T_unif, X2) / weights_from_normalized[:, None]

In [None]:
# We will use the average FOSCTTM measure implemented in evals.py for evaluation (metric used in the publication Demetci et al 2021)
# This measure reports the fraction of samples closer to a sample than its true match (FOSCTTM), averaged over all samples. 
fracs_normalized=calc_domainAveraged_FOSCTTM(X_aligned_from_normalized, y_aligned_from_normalized)
print("Average FOSCTTM score for this alignment with X onto Y is: ", np.mean(fracs_normalized))

In [None]:
#Plotting sorted FOSCTTM to show the distributions of FOSCTTM across cells:

import matplotlib.pyplot as plt
legend_label="SCOT alignment FOSCTTM \n average value: "+str(np.mean(fracs_normalized)) #Put average FOSCTTM in the legend
plt.plot(np.arange(len(fracs_normalized)), np.sort(fracs_normalized), "r--", label=legend_label)
plt.legend()
plt.xlabel("Cells")
plt.ylabel("Sorted FOSCTTM")
plt.show()

In [None]:
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

#Reduce the dimensionality of the aligned domains to two (2D) via PCA for the sake of visualization:
pca=PCA(n_components=2)
Xy_pca=pca.fit_transform(np.concatenate((X_aligned_from_normalized, y_aligned_from_normalized), axis=0))
X_pca=Xy_pca[0: 1047,]
y_pca=Xy_pca[1047:,]

#Plot aligned domains, samples colored by domain identity:
plt.scatter(X_pca[:,0], X_pca[:,1], c="k", s=15, label="Chromatin Accessibility")
plt.scatter(y_pca[:,0], y_pca[:,1], c="r", s=15, label="Gene Expression")
plt.legend()
plt.title("Colored based on domains")
plt.show()

plt.figure
#Plot aligned domains, samples colored by domain identity:
plt.scatter(X_pca[:,0], X_pca[:,1], c="k", s=15, label="Chromatin Accessibility")
plt.legend()
plt.title("Colored based on domains")
plt.show()

plt.figure
#Plot aligned domains, samples colored by domain identity:
plt.scatter(y_pca[:,0], y_pca[:,1], c="r", s=15, label="Gene Expression")
plt.legend()
plt.title("Colored based on domains")
plt.show()



In [None]:
#Plot aligned domains, samples colored by cell types:
cellTypes_atac=np.loadtxt("../data/SNAREseq_atac_types.txt")
cellTypes_rna=np.loadtxt("../data/SNAREseq_rna_types.txt")

colormap = plt.get_cmap('rainbow', 4) 
plt.scatter(X_pca[:,0], X_pca[:,1], c=cellTypes_atac, s=15, cmap=colormap)
plt.scatter(y_pca[:,0], y_pca[:,1], c=cellTypes_rna, s=15, cmap=colormap)
# plt.colorbar()
cbar=plt.colorbar()

# approximately center the colors on the colorbar when adding cell type labels
tick_locs = (np.arange(1,5)+0.75) *3/4 
cbar.set_ticks(tick_locs)
cbar.set_ticklabels(["H1", "GM", "BJ", "K562"]) #cell-type labels
plt.title("Colored based on cell type identity")
plt.show()

In [None]:
originalX_pca=pca.fit_transform(X1)
originaly_pca=pca.fit_transform(X2)

#Visualization of the global geometry
fig, (ax1, ax2)= plt.subplots(1,2)
ax1.scatter(originalX_pca[:,0], originalX_pca[:,1], c="k", s=15)
ax1.set_title("Chromatin Accessibiliy Domain \n *before* Alignment")
ax2.scatter(originaly_pca[:,0], originaly_pca[:,1], c="r", s=15)
ax2.set_title("Gene Expression Domain \n *before* Alignment")
plt.show()

In [None]:
#Visualization of the cell type clusters in original domains *before* alignment
fig, (ax1, ax2)= plt.subplots(1,2)

fig1= ax1.scatter(originalX_pca[:,0], originalX_pca[:,1], c=cellTypes_atac, s=15, cmap=colormap)
ax1.set_title("Chromatin Accessibiliy \n *before* Alignment")

fig2= ax2.scatter(originaly_pca[:,0], originaly_pca[:,1],  c=cellTypes_rna, s=15, cmap=colormap)
ax2.set_title("Gene Expression Domain \n *before* Alignment")

cbar=fig.colorbar(fig2)
cbar.set_ticks(tick_locs)
cbar.set_ticklabels(["H1", "GM", "BJ", "K562"]) #cell-type labels

In [9]:
import torch
import torch.nn.functional as F
from pytorch_metric_learning import miners, losses
import torch.optim as optim

# Define the input tensor
A = torch.randn(100, 100)

# Fix the first 50 rows of A
fixed_rows = A[:50, :]

# Define the variable tensor for the last 50 rows of A
var_rows = torch.randn(50, 100)

# Define the Miner
miner = miners.MultiSimilarityMiner(epsilon=0.1)

# Define the ProxyAnchorLoss
loss_fn = losses.ProxyAnchorLoss(num_classes=50, embedding_size=100)

# Define the optimizer
optimizer = optim.SGD([var_rows], lr=0.01)

# Perform the optimization
num_iterations = 1000
for i in range(num_iterations):
    optimizer.zero_grad()
    
    # Mine the hard negatives
    hard_pairs = miner(var_rows, torch.tensor([i for i in range(50)]))
    
    # Calculate the loss
    loss = loss_fn(var_rows, torch.tensor([i for i in range(50)]), hard_pairs)
    
    # Add a penalty to the loss to encourage the output to be close to the fixed rows
    penalty = F.mse_loss(var_rows, fixed_rows)
    loss += penalty
    
    # Perform backpropagation and optimization step
    loss.backward()
    optimizer.step()

# Print the optimized output
print(var_rows)


tensor([[ 0.7396, -0.0272, -0.2027,  ..., -0.6204,  1.3295,  1.1549],
        [-1.3363, -0.6620,  1.0288,  ...,  0.8893, -1.4013, -2.2258],
        [-1.0180, -1.7433, -0.7344,  ...,  1.5548, -0.1579,  0.0903],
        ...,
        [ 0.7112, -0.5546,  2.0906,  ..., -0.2433,  1.1326,  0.3307],
        [-0.7007, -0.3370,  1.1479,  ...,  0.1687,  2.4663, -0.7770],
        [ 0.8166,  0.6741, -2.1385,  ...,  0.6308,  1.0232, -0.3835]])


In [11]:
import torch
import torch.nn.functional as F
from pytorch_metric_learning import miners, losses
import torch.optim as optim

from torch import nn, optim
import pytorch_metric_learning.losses as loss
import pytorch_metric_learning.miners as miners
import matplotlib.pylab as pl

from sklearn.cluster import KMeans


# Define the input tensor
A = torch.randn(100, 100)
B = torch.randn(100,100)

# perform k-means clustering with k=2
kmeans = KMeans(n_clusters=2)
kmeans.fit(A.numpy())

# Fix the first 50 rows of A
fixed_rows = A[:50, :]

# Define the variable tensor for the last 50 rows of A
var_rows = torch.randn(50, 100)

# Define the Miner
miner = miners.MultiSimilarityMiner(epsilon=0.1)

# Define the ProxyAnchorLoss
loss_fn = losses.ProxyAnchorLoss(num_classes=50, embedding_size=100)

# Define the optimizer
optimizer = optim.SGD([var_rows], lr=0.01)

# Perform the optimization
num_iterations = 1000
for i in range(num_iterations):
    optimizer.zero_grad()
    
    # Mine the hard negatives
    hard_pairs = miner(var_rows, torch.tensor([i for i in range(50)]))
    
    # Calculate the loss
    loss = loss_fn(var_rows, torch.tensor([i for i in range(50)]), hard_pairs)
    
    # Add a penalty to the loss to encourage the output to be close to the fixed rows
    penalty = F.mse_loss(var_rows, fixed_rows)
    loss += penalty
    
    # Perform backpropagation and optimization step
    loss.backward()
    optimizer.step()

# Print the optimized output
print(var_rows)




tensor([[-0.2575, -0.4642,  1.6648,  ...,  0.3470, -0.9651,  0.7426],
        [ 0.9870, -0.2222,  0.3301,  ..., -0.3669,  0.0924,  0.2013],
        [-1.4439, -0.4751, -0.0994,  ...,  0.4002,  2.0583, -1.3915],
        ...,
        [ 0.2864, -1.1707, -0.3838,  ...,  0.3727, -0.7677,  0.0377],
        [-1.1144, -0.7462, -0.2922,  ...,  0.2014,  0.3760,  1.9066],
        [ 1.2016,  3.7003, -0.2942,  ..., -1.2272, -1.2431, -0.2215]])
