# Optimal transport-based machine learning to match specific expression patterns in omics data

This notebook shows how to combine (a) optimal transport and (b) matching or co-clustering procedures to match two data sets.

The methodology is presented in the paper [Optimal transport-based machine learning to match specific expression patterns in omics data](https://arxiv.org/pdf/2107.11192.pdf) by 
T. T. Y. Nguyen, O. Bouaziz, W. Harchaoui, C. Neri and A. Chambaz. 


## Imports and installs

In [1]:
# Ignore this cell if the corresponding packages are already installed

#!pip install coclust
#!pip install scikit-learn

In [2]:
import numpy as np
from wtot import wtot
from match_coclust import matching, SCC1_star, SCC1, SCC2_star, SCC2
import pandas as pd




# Example 1: Real data 

The real data set was kindly made public by Langfelder et al. (see their articles published in [Nature Neuroscience](https://europepmc.org/article/med/26900923) and [PLOS ONE](https://pubmed.ncbi.nlm.nih.gov/29324753/)).

## Data loading

Load the real data, then convert them to matrices. Only the first 200 rows and the first 3 columns of each table are kept.

In [3]:
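# Pass a single separator argument: recent pandas versions raise a ValueError
# when both `sep` and `delimiter` are given.
data_micro = pd.read_csv('./datasets/LFC_Cortex_mirna.txt', sep='\t')  # miRNA table
data_mess = pd.read_csv('./datasets/LFC_Cortex_mrna.txt', sep='\t')  # mRNA table

x = data_micro.values
y = data_mess.values

# Keep the first 200 rows and the first 3 columns of each matrix
x = x[:200, :3]
y = y[:200, :3]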
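
The loading step can be tried without the real files. The sketch below assumes a small tab-separated table held in memory (the column names and values are made up for illustration) and follows the same read, convert, slice sequence as the cell above.

```python
import io

import pandas as pd

# Hypothetical tab-separated content standing in for a log-fold-change file.
raw = "c1\tc2\tc3\tc4\n0.1\t0.2\t0.3\t0.4\n-0.5\t0.0\t0.5\t1.0\n1.5\t2.5\t3.5\t4.5\n"

# `sep` and `delimiter` are aliases in pandas, so only one of them is passed.
table = pd.read_csv(io.StringIO(raw), sep="\t")

# Convert to a NumPy array, then keep the first 2 rows and the first 3 columns.
matrix = table.values
subset = matrix[:2, :3]

print(subset.shape)
```

The slice sizes here (2 rows, 3 columns) are only chosen to fit the toy table.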

In [4]:
# Hyperparameters
m = 1
n = 3

## Algorithm WTOT_matching and WTOT_coclust
### First step (WTOT_...)

We compute the optimal transport matrix, the optimal transformation, and an estimator of the "weights" (see the paper).

In [5]:
results = wtot(x, y, m=m, n=n, batch_size_x=64, batch_size_y=64)

# Values of the optimal transport matrix, the optimal transformation, and the "weights"
pi_np = results['P']
theta = results['theta']
w = results['w']
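
To give an intuition of what a transport matrix looks like, here is a minimal sketch of generic entropic optimal transport (Sinkhorn iterations) between two small point clouds. It is not the `wtot` procedure: it assumes uniform marginals and a squared Euclidean cost, and it estimates neither a transformation nor weights. The names `sinkhorn_plan`, `reg`, `n_iter`, `x_toy` and `y_toy` are illustrative.

```python
import numpy as np


def sinkhorn_plan(x, y, reg=1.0, n_iter=200):
    """Entropic transport plan between two point clouds with uniform marginals."""
    # Squared Euclidean cost between every row of x and every row of y.
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=2)
    kernel = np.exp(-cost / reg)
    a = np.full(x.shape[0], 1.0 / x.shape[0])  # source marginal
    b = np.full(y.shape[0], 1.0 / y.shape[0])  # target marginal
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        # Alternate scalings so that row sums approach a and column sums approach b.
        u = a / (kernel @ v)
        v = b / (kernel.T @ u)
    return u[:, None] * kernel * v[None, :]


# Toy data: y_toy contains the points of x_toy in reverse order, slightly shifted.
x_toy = 3.0 * np.eye(4)
y_toy = x_toy[::-1] + 0.05

plan = sinkhorn_plan(x_toy, y_toy)
print(plan.shape)
print(plan.argmax(axis=1))  # column receiving the most mass from each row
```

In this toy case almost all the mass of row `i` goes to the column holding the shifted copy of point `i`, which is the kind of structure the second step exploits.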

### Second step 
#### Matching

In [6]:
results_match = matching(pi_np)

In [7]:
# the collection calM
N_m = results_match['N_m']
print('The indices of the miRNAs associated to the first mRNA of the list:', N_m[0], '.\n')
# the collection calN
M_n = results_match['M_n']
print('The indices of the mRNAs associated to the first miRNA of the list:', M_n[0], '.\n')

The set of columns is associated to the first row: {0, 71, 72, 169, 12, 177, 122, 156}
The set of rows is associated to the first column [0, 2, 13, 45, 47, 64, 66, 70, 78, 79, 101, 119, 155, 167, 173, 176, 195]
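
As an illustration of how row and column associations can be read off a transport matrix, the sketch below applies a simple thresholding rule. The rule is hypothetical (keep the entries holding at least a fraction `frac` of their row or column maximum) and is not claimed to be what `matching` does; `threshold_matching` and `toy_plan` are illustrative names.

```python
import numpy as np


def threshold_matching(plan, frac=0.5):
    """Hypothetical rule: keep entries holding at least `frac` of their row (or column) maximum."""
    cols_per_row = [set(np.flatnonzero(row >= frac * row.max()).tolist()) for row in plan]
    rows_per_col = [np.flatnonzero(col >= frac * col.max()).tolist() for col in plan.T]
    return cols_per_row, rows_per_col


# Toy transport matrix with 3 rows and 4 columns.
toy_plan = np.array([
    [0.20, 0.15, 0.00, 0.01],
    [0.01, 0.02, 0.30, 0.00],
    [0.00, 0.01, 0.05, 0.25],
])

cols_per_row, rows_per_col = threshold_matching(toy_plan)
print(cols_per_row[0])  # columns associated with the first row
print(rows_per_col[2])  # rows associated with the third column
```

With this toy matrix the first row is associated with columns 0 and 1, and the third column with row 1 only.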


#### Co-clustering

In [8]:
### WTOT-SCC1
SCC1_res = SCC1(pi_np)

### WTOT-SCC1*
SCC1_star_res = SCC1_star(pi_np, 4)

### WTOT-SCC2
SCC2_res = SCC2(pi_np)

### WTOT-SCC2*
SCC2_star_res = SCC2_star(pi_np, 4)
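
For readers unfamiliar with co-clustering, the sketch below splits the rows and columns of a nonnegative matrix into two co-clusters with a simplified spectral bipartition (normalization by row and column sums, then the sign of the second pair of singular vectors). It assumes that "SCC" refers to spectral co-clustering, which this notebook does not state, and it is not the implementation behind `SCC1`, `SCC2` or their starred variants. The names `spectral_bipartition` and `toy` are illustrative.

```python
import numpy as np


def spectral_bipartition(mat):
    """Split the rows and columns of a nonnegative matrix into two co-clusters."""
    d_rows = mat.sum(axis=1)
    d_cols = mat.sum(axis=0)
    # Normalize: D_rows^{-1/2} * mat * D_cols^{-1/2} (row and column sums must be positive).
    scaled = mat / np.sqrt(d_rows)[:, None] / np.sqrt(d_cols)[None, :]
    u, _, vt = np.linalg.svd(scaled, full_matrices=False)
    # The second pair of singular vectors carries the two-way partition.
    z_rows = u[:, 1] / np.sqrt(d_rows)
    z_cols = vt[1, :] / np.sqrt(d_cols)
    return (z_rows > 0).astype(int), (z_cols > 0).astype(int)


# Toy matrix with two planted blocks: rows 0-2 with columns 0-1, rows 3-4 with columns 2-4.
toy = np.full((5, 5), 0.01)
toy[:3, :2] = 1.0
toy[3:, 2:] = 1.0

row_labels, col_labels = spectral_bipartition(toy)
print(row_labels, col_labels)
```

Which block receives label 0 or 1 depends on the sign convention of the SVD, so only the grouping is meaningful: rows and columns of the same planted block share a label.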


# Example 2: Synthetic data

We now present an illustration based on simulated data.

## Data simulation

In [9]:
### An example of synthetic data
datas = np.load('./datasets/sample_A4.npz', allow_pickle=True)  # configuration A4 of the first simulation study
datas = datas['dats']

id_sample = 1
x         = datas[id_sample]['x']
y         = datas[id_sample]['y']
labels_x  = datas[id_sample]['labels_x']
labels_y  = datas[id_sample]['labels_y']
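
The access pattern above suggests that the archive stores an array of dictionaries, one per simulated sample. The sketch below builds and reloads an archive with that layout, which also shows why `allow_pickle = True` is needed. The layout is inferred from the cell above, and the sizes and values are made up; an in-memory buffer stands in for a file on disk.

```python
import io

import numpy as np

rng = np.random.default_rng(0)

# Hypothetical samples: each one is a dict holding two matrices and their labels.
replicates = [
    {
        "x": rng.normal(size=(6, 2)),
        "y": rng.normal(size=(4, 2)),
        "labels_x": np.array([0, 0, 0, 1, 1, 1]),
        "labels_y": np.array([0, 0, 1, 1]),
    }
    for _ in range(3)
]

# Store the dicts in an object array, one dict per entry.
store = np.empty(len(replicates), dtype=object)
for i, rep in enumerate(replicates):
    store[i] = rep

buffer = io.BytesIO()
np.savez(buffer, dats=store)
buffer.seek(0)

# Object arrays are pickled inside the archive, hence allow_pickle=True on load.
loaded = np.load(buffer, allow_pickle=True)["dats"]
sample = loaded[1]
print(sample["x"].shape, sample["labels_y"])
```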

### First step (WTOT_...)

We compute the optimal transport matrix, the optimal transformation, and an estimator of the "weights".

In [10]:
results = wtot(x, y, m=2, n=1, batch_size_x=64, batch_size_y=64)

# Values of the optimal transport matrix, the optimal transformation, and the "weights"
pi_np = results['P']
theta = results['theta']
w = results['w']

### Second step 
#### Matching

In [11]:
results_match = matching(pi_np)

In [12]:
# the collection of calM
N_m = results_match['N_m']
print('The indices of the columns associated to the first row:', N_m[0], '.\n')
# the collection of calN
M_n = results_match['M_n']
print('The indices of the rows associated to the first column:', M_n[0], '.\n')

The set of columns is associated to the first row: {290, 228, 294, 166, 200, 170, 204, 239, 112, 50, 243, 20, 52, 277, 151}
The set of rows is associated to the first column [99, 103, 111, 145, 178, 211, 215, 220, 286]
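
Since the simulated data come with labels, one way to sanity-check a matching is to count how often a matched row and column carry the same label. The sketch below is an illustrative metric, not the evaluation criterion of the paper. It assumes that `labels_x` indexes the rows of the transport matrix, that `labels_y` indexes its columns, and that a correct match pairs equal labels; `label_agreement` and the toy inputs are hypothetical.

```python
def label_agreement(cols_per_row, labels_rows, labels_cols):
    """Fraction of matched (row, column) pairs whose two labels coincide."""
    pairs = [(i, j) for i, cols in enumerate(cols_per_row) for j in cols]
    if not pairs:
        return 0.0
    hits = sum(labels_rows[i] == labels_cols[j] for i, j in pairs)
    return hits / len(pairs)


# Hypothetical matching and ground-truth labels for 3 rows and 4 columns.
toy_matches = [{0, 1}, {2}, {1, 3}]
toy_labels_rows = [0, 1, 1]
toy_labels_cols = [0, 0, 1, 1]

score = label_agreement(toy_matches, toy_labels_rows, toy_labels_cols)
print(score)  # 4 of the 5 matched pairs agree: 0.8
```

Under the stated assumptions, the same function could be called with a list of column sets such as `N_m` together with `labels_x` and `labels_y`.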


#### Co-clustering

In [13]:
### WTOT-SCC1
SCC1_res = SCC1(pi_np)
### WTOT-SCC1*
SCC1_star_res = SCC1_star(pi_np, 4)
### WTOT-SCC2
SCC2_res = SCC2(pi_np)
### WTOT-SCC2*
SCC2_star_res = SCC2_star(pi_np, 4)