## Setup

In [None]:
!pip install cmapPy
!pip install umap-learn

In [None]:
import pandas as pd
import numpy as np
import gzip
import re
from cmapPy.pandasGEXpress.parse import parse
from matplotlib import pyplot as plt
import umap
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics.pairwise import cosine_similarity


In [None]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

## The data

### Download
Download the expression data and the necessary metadata.

In [None]:
!wget https://ftp.ncbi.nlm.nih.gov/geo/series/GSE70nnn/GSE70138/suppl/GSE70138_Broad_LINCS_inst_info_2017-03-06.txt.gz

In [None]:
!wget https://ftp.ncbi.nlm.nih.gov/geo/series/GSE70nnn/GSE70138/suppl/GSE70138_Broad_LINCS_Level2_GEX_n345976x978_2017-03-06.gctx.gz
!gunzip GSE70138_Broad_LINCS_Level2_GEX_n345976x978_2017-03-06.gctx.gz

### Import the metadata

Please note the columns of particular interest
- inst_id: ID col, used to join the metadata with the data
- cell_id: cell line
- det_plate: can be used as the experimental batch, i.e. all the wells with the same det_plate were run together

And especially columns relating to the perturbation:
- pert_iname: human-readable name of perturbant
- pert_dose: concentration of perturbant
- pert_time: time between perturbation and measurement
- pert_type: whether this perturbant is a negative control or treatment, and whether it's a small compound or genetic perturbation

In [None]:
inst_info = pd.read_csv("GSE70138_Broad_LINCS_inst_info_2017-03-06.txt.gz", sep='\t')
inst_info.head()


In case you were wondering, the -666 values are how the LINCS L1000 data indicates NA values.
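
A minimal sketch of how such a sentinel could be turned into proper missing values with pandas. The toy table and its values are invented for illustration; only the column names follow the metadata above, and it assumes the sentinel is stored as a number rather than as text.

```python
import numpy as np
import pandas as pd

# toy metadata; all values are invented for illustration
toy_info = pd.DataFrame({
    "inst_id": ["a", "b", "c"],
    "pert_dose": [10.0, -666.0, 0.12],
    "pert_time": [24, 24, -666],
})

# replace the -666 sentinel with NaN so pandas treats it as missing
cleaned = toy_info.replace(-666, np.nan)

# count missing values per column
print(cleaned.isna().sum())
```

Once the sentinel is converted, functions such as `mean()` or `groupby()` skip these entries instead of treating -666 as a real dose or time.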

### Import the expression data

Here we will work with LINCS L1000 data, level 2.

> Note that level 2 is _not_ the usually recommended starting point; L1000 data is more often used after normalization, processing (such as imputation), batch correction (level 4), and aggregation (level 5). However, we use level 2 here because:
- it surfaces many of the realistic challenges of working with _any_ expression data
- it is much smaller, which is useful for a quick lab


This data contains a matrix with
"landmark genes" in the rows
and "samples" or "wells" in the columns.
The values are the measured abundance of
the mRNAs derived from the given
landmark gene in a given sample.

Landmark genes are a subset of ~1000 genes
that are sufficient to predict much of the
variation in all 20k genes. Only these 1000
are measured in the L1000 data for cost reasons. This contrasts with _sequencing_-based transcriptomics, which is not selective for a subset of target genes.

Find more information [in the paper](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5990023/).

In [None]:
gene_abundance = parse("GSE70138_Broad_LINCS_Level2_GEX_n345976x978_2017-03-06.gctx",
                       convert_neg_666=True).data_df

In [None]:
gene_abundance.head()

#### Data distribution
Transcriptomics data is generally not normally distributed, while many of the losses and metrics we use in ML work better with normally distributed data. The L1000 data is not count data, but we can still see that it is not normally distributed.

In [None]:
# note the long tail
_ = plt.hist(gene_abundance.iloc[:,:300].values.flatten(), bins=300)
plt.ylabel("Count")
plt.xlabel("Gene Abundance")
plt.show()

#### Normalization
> Note: the following normalization and scaling differs from the recommendation for L1000 for the sake of simplicity and to show general steps that apply to many transcriptomics types. For L1000 outside of this lab, you can simply take the level 4 or 5 pre-normalized data, or take a look at the authors' [write up of "Level 3 - Normalization (NORM)"](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5990023/#S40title). For sequencing data, a few good options are to either use raw counts and adjust loss & metrics for a negative binomial distribution _or_ to perform a variance stabilizing transformation with e.g. [DESeq2](https://bioconductor.org/packages/release/bioc/html/DESeq2.html).

In [None]:
# log transform (the pseudo-count goes inside the log to guard against zeros)
gene_normalized = np.log(gene_abundance.values + 0.001)
# scale each sample to same total abundance
gene_normalized = gene_normalized / np.mean(gene_normalized, axis=0) * np.mean(gene_normalized)
# center at 0
gene_normalized = gene_normalized - np.mean(gene_normalized)


In [None]:
_ = plt.hist(gene_normalized[:,:30000].flatten(), bins=300)
plt.ylabel("Count")
plt.xlabel("Log-Centred Gene Abundance")
plt.show()

While the transformations have not resulted in a perfect normal distribution, it is bell shaped and does not have extreme outliers. We'll consider this good enough to work with.
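
If a number is preferred over eyeballing the histogram, skewness is one simple check. The sketch below uses synthetic log-normal values as a stand-in (not the L1000 data) and a hypothetical helper `skewness`; values near 0 indicate a roughly symmetric distribution.

```python
import numpy as np

rng = np.random.default_rng(0)
# synthetic long-tailed "abundances" (log-normal); NOT the L1000 data
raw = rng.lognormal(mean=5.0, sigma=1.0, size=10_000)


def skewness(x):
    """sample skewness: third standardized moment (about 0 if symmetric)"""
    x = np.asarray(x, dtype=float)
    centered = x - x.mean()
    return np.mean(centered ** 3) / np.std(x) ** 3


# same style of log transform as above, with a small pseudo-count
logged = np.log(raw + 0.001)
print(f"skewness raw: {skewness(raw):.2f}, after log: {skewness(logged):.2f}")
```

The same helper could be applied to `gene_abundance.values.flatten()` and `gene_normalized.flatten()` to compare before and after, keeping in mind that missing values would need to be dropped first.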

Now let's look at subgroups and structures in the data with a UMAP.

In [None]:
# visualization tends to be clearer (and faster)
# when we don't take the _full_ perturbation data
N = 10000

# setup and run UMAP
reducer = umap.UMAP()
embedding = reducer.fit_transform(gene_normalized[:,:N].T)
embedding.shape

In [None]:
# rearrange and join with metadata for seaborn
dat = pd.DataFrame(embedding)
dat["inst_id"] = gene_abundance.columns[:N]
dat = dat.merge(inst_info, on='inst_id')


In [None]:
dat.head()

#### Visualize normalized, but unaligned

First we see that cell line, unsurprisingly, has a dominant effect

In [None]:
ax = sns.scatterplot(dat, x=0, y=1, hue="cell_id", s=3)
sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
plt.xlabel("UMAP 0")
plt.ylabel("UMAP 1")
plt.show()

And within each cell line, the batch (or plate) has a large effect

In [None]:
ax = sns.scatterplot(dat, x=0, y=1, hue="det_plate", s=3, palette="Paired")
sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
plt.xlabel("UMAP 0")
plt.ylabel("UMAP 1")
plt.show()

Good to know. We'll come back to the batch effect after embedding the data.

### Embedding
#### Why?
Transcriptomics data can be _very_ noisy (especially single cell data, as a large fraction of transcripts are near the sensitivity/detection limit). Moreover, most transcriptomics assays include 20k genes, often with high gene-gene correlations. Thus, embedding the transcriptomics data into a lower dimensional space can reduce noise, simplify comparisons and increase the utility of the data. We acknowledge that the selection of 978 landmark genes already works _somewhat like an embedding_ and that this exact lab would work without an embedding step; however, for broader applicability to other types of transcriptomics (and especially other, either noisier or unstructured, bio-assay types) we wanted to include this step.

#### How?
While the field of transcriptomics foundation models is under
rapid development, e.g. ([scGPT](https://www.nature.com/articles/s41592-024-02201-0), [scFoundation](https://www.biorxiv.org/content/10.1101/2023.05.29.542705v3), [Geneformer](https://www.nature.com/articles/s41586-023-06139-9), [Universal Cell Embeddings](https://www.biorxiv.org/content/10.1101/2023.11.28.568918v1)), the foundation models are generally sequencing specific and few have been independently benchmarked to date. From the benchmarks that
have been done, it is clear that [scVI](https://www.nature.com/articles/s41592-018-0229-2), a variational autoencoder, remains a
strong baseline. None of the above are designed for L1000 data, and for today we need something that is fast to train,
so we'll take a vanilla **variational autoencoder** for simplicity.

#### Data loaders

In [None]:
# simple tabular data loader
class L1000Dataset(Dataset):
    def __init__(self, data):
        self.X = data

    def __len__(self):
        return len(self.X.T)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = (self.X[:,idx], )
        return sample

In [None]:
joint_dataset = L1000Dataset(gene_normalized)
total_len = len(joint_dataset)
len_train = int(total_len * .9)
len_val = total_len - len_train

> Note. In this lab we are not concerned with _generalization_.
We're simply using the embedding to reduce noise, and measuring
similarity (within set). Most ML projects with transcriptomics
data _will_ be concerned with generalization and should consider
a split that reserves whole experimental batches
(such as 'gem wells' or 'plates')
of data for the test set.
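
A minimal sketch of such a batch-aware split, assuming a hypothetical array `plates` that holds one plate label per sample (in the real data this could come from `det_plate`). Whole plates are assigned to the test set, so no plate contributes to both sides.

```python
import numpy as np

rng = np.random.default_rng(0)
# hypothetical plate label for each of 1000 samples (20 plates of 50 wells)
plates = np.repeat([f"plate_{i:02d}" for i in range(20)], 50)

# hold out ~10% of the *plates*, not 10% of the samples
unique_plates = np.unique(plates)
n_test = max(1, int(round(0.1 * len(unique_plates))))
test_plates = rng.choice(unique_plates, size=n_test, replace=False)

# sample indices for each side of the split
test_mask = np.isin(plates, test_plates)
train_idx = np.flatnonzero(~test_mask)
test_idx = np.flatnonzero(test_mask)

print(len(train_idx), "training samples,", len(test_idx), "test samples")
```

The resulting index arrays could be passed to `torch.utils.data.Subset` in place of `random_split`; `sklearn.model_selection.GroupShuffleSplit` offers the same idea ready-made.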

In [None]:
training_dataset, validation_dataset = random_split(joint_dataset, lengths=[len_train, len_val])

In [None]:
batch_size = 512
training_dataloader = DataLoader(training_dataset, batch_size=batch_size,
                                 shuffle=True, num_workers=0)
validation_dataloader = DataLoader(validation_dataset, batch_size=batch_size,
                                   shuffle=True, num_workers=0)

#### The Variational Autoencoder Model
This is an autoencoder, which has an encoder that compresses the information
from the input down to a smaller latent space, and then a decoder that expands back to the input dimensions. It is trained
to reconstruct the input as precisely as possible.
For the _variational_ part, a sampling
step is added: the encoder predicts a mean and a log-variance,
from which a sample is drawn and fed to the decoder during training.
This allows VAEs to work as a _generative_ model, for which they are better
known. However VAEs are also useful for embedding as the sampling step
both has a regularizing function and importantly causes a smoother
latent space.
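
The sampling step can be sketched in plain NumPy, independent of the model defined in this lab. The toy `mu` and `log_var` values are invented; the point is only that `mu + eps * std` with standard normal `eps` yields samples with the requested mean and standard deviation, while keeping the randomness outside of the predicted parameters.

```python
import numpy as np

rng = np.random.default_rng(0)


def reparameterize(mu, log_var, rng):
    """draw z ~ N(mu, sigma^2) as mu + sigma * eps, with eps ~ N(0, 1)"""
    std = np.exp(0.5 * log_var)
    eps = rng.standard_normal(size=std.shape)
    return mu + eps * std


# toy encoder output for one sample with a 2-dimensional latent space
mu = np.array([1.0, -2.0])
log_var = np.log(np.array([0.25, 4.0]))  # i.e. standard deviations of 0.5 and 2.0

# many draws, to check the empirical mean and sd against the inputs
samples = np.stack([reparameterize(mu, log_var, rng) for _ in range(20_000)])
print("empirical mean:", samples.mean(axis=0))
print("empirical sd:  ", samples.std(axis=0))
```

When the trained model is used purely for embedding, the sampling is skipped and the predicted mean is taken as the embedding.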

> Our implementation is borrowed from: https://github.com/lyeoni/pytorch-mnist-VAE/blob/master/pytorch-mnist-VAE.ipynb and https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py; both of which are good resources on VAEs generally.

In [None]:
class VAE(nn.Module):
    def __init__(self, x_dim, h_dim, z_dim):
        super(VAE, self).__init__()
        self.fc_1 = nn.Linear(x_dim, h_dim)
        self.fc_2a = nn.Linear(h_dim, z_dim)
        self.fc_2b = nn.Linear(h_dim, z_dim)
        self.fc_3 = nn.Linear(z_dim, h_dim)
        self.fc_4 = nn.Linear(h_dim, x_dim)

    def encoder(self, x):
        h = F.relu(self.fc_1(x))
        return self.fc_2a(h), self.fc_2b(h) # mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu # z sample

    def decoder(self, z):
        h = F.relu(self.fc_3(z))
        return self.fc_4(h)

    def forward(self, x):
        mu, log_var = self.encoder(x.view(-1, 978))
        z = self.reparameterize(mu, log_var)
        return self.decoder(z), mu, log_var


vae = VAE(x_dim=978, h_dim=512, z_dim=256)
vae.to(device)

In [None]:
optimizer = optim.Adam(vae.parameters(), lr=1e-4)
# return reconstruction error + KL divergence losses
def loss_function(recon_x, x, mu, log_var):
    mse = F.mse_loss(recon_x, x.view(-1, 978))
    kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

    return mse + kld_loss * 0.0002

In [None]:
def train(epoch):
    vae.train()
    train_loss = 0
    for batch_idx, (data, ) in enumerate(training_dataloader):
        data = data.to(device)
        optimizer.zero_grad()

        recon_batch, mu, log_var = vae(data)
        loss = loss_function(recon_batch, data, mu, log_var)

        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if batch_idx % 100 == 0:
            # loss_function already averages over the batch, so report it as is
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(training_dataloader.dataset),
                100. * batch_idx / len(training_dataloader), loss.item()))
    # average the per-batch mean losses over the number of batches
    print(f'====> Epoch: {epoch} Average loss: {train_loss / len(training_dataloader):.4f}')


In [None]:
def val():
    vae.eval()
    val_loss= 0
    with torch.no_grad():
        for (data,) in validation_dataloader:
            data = data.to(device)
            recon, mu, log_var = vae(data)

            # accumulate the per-batch mean loss
            val_loss += loss_function(recon, data, mu, log_var).item()

    # loss_function averages within a batch, so average over the number of batches
    val_loss /= len(validation_dataloader)
    print(f'====> Val set loss: {val_loss:.4f}')

In [None]:
# this will take several minutes, a great time
# to read or catch up on the theory
for epoch in range(1, 3):
    train(epoch)
    val()

Check if output distribution is reasonable

In [None]:
for (data, ) in validation_dataloader:
    break

In [None]:
with torch.no_grad():
    res = vae(data.to(device))

In [None]:
# do we get a bell curve that matches the original distribution?
_ = plt.hist(res[0].cpu().numpy().flatten(), bins=50)
plt.ylabel("Count")
plt.xlabel("Predicted Log-Centred Gene Abundance")
plt.show()

In [None]:
# compared to exact input?
plt.scatter(data.cpu().numpy().flatten(), res[0].cpu().numpy().flatten())
plt.ylabel("Predicted")
plt.xlabel("Ground Truth")
plt.show()

We have a notable correlation, especially for more extreme values.
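
To put a number on "notable", one option is the Pearson correlation between the flattened inputs and reconstructions. The sketch below uses synthetic stand-in arrays (`truth` and `predicted` are invented, not model output) to show the call.

```python
import numpy as np

rng = np.random.default_rng(0)
# stand-ins for the flattened ground truth and reconstruction; NOT model output
truth = rng.normal(size=5_000)
predicted = 0.6 * truth + rng.normal(scale=0.5, size=5_000)

# np.corrcoef returns the 2x2 correlation matrix; take the off-diagonal entry
pearson_r = np.corrcoef(truth, predicted)[0, 1]
print(f"Pearson r = {pearson_r:.3f}")
```

In this notebook the same call could be applied to `data.cpu().numpy().flatten()` and `res[0].cpu().numpy().flatten()`.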

In a performance oriented project we would perform hyperparameter tuning
using downstream metrics to select the best model. But for now,
we'll continue with this one and embed all the data.



In [None]:
out = []
# all the data, unshuffled
joint_dataloader = DataLoader(joint_dataset, batch_size=batch_size,
                                 shuffle=False, num_workers=0)
for (data, ) in joint_dataloader:
    with torch.no_grad():
        out.append(vae(data.to(device))[1].cpu().numpy())  # 1 is the embedding, AKA mu
vae_embedding = np.concatenate(out)

### Experimental batch correction
Even high throughput biological assays have much smaller capacity than the number of samples we wish to measure in drug discovery. Thus, the data is gathered in 'batches', for instance, everything that fits on a 384 well plate at the same time. These batches necessarily share some technical (e.g. sequencing depth) and biological (e.g. exact cell age, response to time of day) co-variation that is often greater than the effect of some individual perturbants.

How best to make downstream modelling tasks robust to the batches
in which experimental data was
collected or even integrate data from disparate studies is an evolving field
(e.g. [Harmony](https://www.nature.com/articles/s41592-019-0619-0),
[TVN](https://www.biorxiv.org/content/10.1101/161422v1.abstract),
[sysVI](https://www.biorxiv.org/content/10.1101/2023.11.03.565463v2),
[InfoCORE](https://arxiv.org/abs/2312.00718) to name only a few).
Here, we simply center each batch on the control mean and scale it by the
control standard deviation.
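
A toy sketch of this idea in NumPy, before the full version on the real metadata. The two plates, their shift and scale, and the choice of the first 10 wells as controls are all invented for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
# toy data: 2 plates x 50 wells x 3 features
n_per_plate = 50
plate = np.repeat(["plate_A", "plate_B"], n_per_plate)
# pretend the first 10 wells of each plate are negative controls
is_control = np.tile(np.arange(n_per_plate) < 10, 2)
features = rng.normal(size=(2 * n_per_plate, 3))
# give plate_B an artificial batch effect (different scale and offset)
features[plate == "plate_B"] = features[plate == "plate_B"] * 3.0 + 5.0

corrected = np.empty_like(features)
for name in np.unique(plate):
    in_plate = plate == name
    controls = features[in_plate & is_control]
    # center on the control mean, scale by the control sd (per feature)
    corrected[in_plate] = (
        features[in_plate] - controls.mean(axis=0)
    ) / controls.std(axis=0)

# after correction, the controls of every plate have mean 0 and sd 1
for name in np.unique(plate):
    c = corrected[(plate == name) & is_control]
    print(name, c.mean(axis=0).round(3), c.std(axis=0).round(3))
```

Because only the controls define the shift and scale, real perturbation effects are kept as deviations from the control distribution rather than being normalized away.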

In [None]:
# first, helper functions
def find_controls(inst_info_batch):
    """parses the metadata to identify which controls can be used for centering"""
    # the plates (batches) have _either_
    # compounds with DMSO ("ctl_vehicle") as a negative, centering control
    comp_control = 'ctl_vehicle'
    comp_treatment = 'trt_cp'
    comp_pert_types = (comp_control, comp_treatment)
    # OR genetic perturbations with a non-targeting guide ("ctl_vector") as a negative, centering control
    genetic_control = 'ctl_vector'
    genetic_treatment = 'trt_xpr'
    genetic_null = 'ctl_untrt'
    genetic_pert_types = (genetic_null, genetic_control, genetic_treatment)
    # what did we actually get?
    batch_pert_types = tuple(np.sort(inst_info_batch.pert_type.unique()))
    if batch_pert_types == comp_pert_types:
        return inst_info_batch.pert_type == comp_control
    elif batch_pert_types == genetic_pert_types:
        return inst_info_batch.pert_type == genetic_control
    else:
        raise ValueError(
            f"unknown perturbation types {tuple(inst_info_batch.pert_type.unique())}"
        )


def center_scale_on_controls(dat_batch_tup, inst_info_batch_tup):
    """centers each batch on the control mean and scales by the control
    standard deviation.

    Thus the (average) control will have a mean of 0 and sd of 1 afterwards,
    while perturbed samples may vary."""

    # drop batch that came from groupby
    _, dat_batch = dat_batch_tup
    _, inst_info_batch = inst_info_batch_tup
    # select controls
    controls = dat_batch.loc[find_controls(inst_info_batch)]
    # calculate mean & sd for each 'feature'
    control_mean = np.mean(controls, axis=0)
    control_sd = np.std(controls, axis=0)
    # normalize all data (subtract mean, divide by sd)
    return (dat_batch - control_mean) / control_sd




In [None]:
# put meta data (inst_info) and normalized gene abundance
# into pandas dfs with matching order & index
ordered_inst_info = inst_info.set_index("inst_id")
ordered_inst_info = ordered_inst_info.loc[gene_abundance.columns, :]

In [None]:
# batch-correct the VAE embedding (samples in rows, latent dimensions in columns)
splitable = pd.DataFrame(vae_embedding)
splitable.index = ordered_inst_info.index

In [None]:
# loop through batches and center
out = []
for dat_batch, inst_info_batch in zip(splitable.groupby(ordered_inst_info['det_plate']),
                                     ordered_inst_info.groupby('det_plate')):
    out.append(center_scale_on_controls(dat_batch, inst_info_batch))
# re-concatenate batches and restore original order
bc_vae_embedding = pd.concat(out).loc[ordered_inst_info.index]


Let's see how the batch correction changed the groupings in the data.

In [None]:
# UMAP, as above
embedding = reducer.fit_transform(bc_vae_embedding[:N])
dat = pd.DataFrame(embedding)
dat["inst_id"] = gene_abundance.columns[:N]
dat = dat.merge(inst_info, on='inst_id')

In [None]:
# cell types colored
ax = sns.scatterplot(dat, x=0, y=1, hue="cell_id", s=3)
sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))

In [None]:
ax = sns.scatterplot(dat, x=0, y=1, hue="det_plate", s=3)
sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))

The mixing is imperfect, but nevertheless the data points are far better mixed than before embedding and batch correction.

### Apply embeddings to target deconvolution
We will try to use these embeddings to identify
(or in this case verify) potential targets
(or disrupted pathways) for everolimus, an inhibitor of
MTOR.

We can work with the hypothesis that a drug inhibiting
a protein target will induce a similar phenotype to a
genetic knock out of said protein target. Whether this
phenotype is measured by phenomics, transcriptomics,
or any other high-throughput assay, if we can make
centered embeddings, we can simply find similar phenotypes
with cosine similarity and consider these candidate targets.
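
A small sketch of why the centering matters for cosine similarity. All vectors are invented 3-dimensional toy "embeddings": a shared baseline plus a perturbation-specific effect. The helper `cosine` is defined here for illustration.

```python
import numpy as np


def cosine(a, b):
    """cosine similarity between two 1-D vectors"""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


# toy values, invented for illustration
baseline = np.array([10.0, 10.0, 10.0])  # shared, perturbation-independent signal
effect_drug = np.array([1.0, -1.0, 0.0])
effect_knockout_same = np.array([0.9, -1.1, 0.1])   # similar effect to the drug
effect_knockout_other = np.array([-1.0, 0.5, 1.0])  # unrelated effect

drug = baseline + effect_drug
ko_same = baseline + effect_knockout_same
ko_other = baseline + effect_knockout_other

# uncentered: the shared baseline dominates, so everything looks similar
print(cosine(drug, ko_same), cosine(drug, ko_other))
# centered on the baseline: only the perturbation effects are compared
print(cosine(drug - baseline, ko_same - baseline),
      cosine(drug - baseline, ko_other - baseline))
```

Without centering both knockouts appear almost identical to the drug; after centering only the knockout with a matching effect keeps a high similarity.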


In [None]:
def mean_cosine_similarity(mat):
    """calculates the mean of non-self pairwise cosine similarities in mat"""
    res = cosine_similarity(mat)
    upper_triangular_mask = np.triu(np.ones(res.shape, dtype=bool), 1)
    return np.mean(res[upper_triangular_mask])

In [None]:
# confirm our compound of interest is present in data
compound = "everolimus"
ordered_inst_info.loc[ordered_inst_info.pert_iname == compound].pert_id.value_counts()

Before we can dive in to actually comparing things, we have some decisions to make. In particular, which cell line and which dose should we take? We'll
also have to clean up some data inconsistencies.
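
The two clean-up steps, sketched on a toy table. The suffixed cell IDs and the doses are invented for illustration; only the column names follow the metadata.

```python
import pandas as pd

# toy metadata; the suffixed cell IDs and doses are invented for illustration
toy = pd.DataFrame({
    "cell_id": ["A375", "A375.311", "YAPC", "YAPC.311"],
    "pert_dose": [0.12, 0.1234, 1.11, 1.1111],
})

# drop everything from the first "." onwards so name variants share one label
toy["cell_line"] = toy["cell_id"].str.replace(r"\..*", "", regex=True)
# round doses so near-identical concentrations fall into the same group
toy["pert_dose_rounded"] = toy["pert_dose"].round(2)

print(toy.groupby(["cell_line", "pert_dose_rounded"]).size())
```

After both steps, rows that differ only in naming suffix or in trailing digits of the dose end up in the same group.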

In [None]:
# combine non-exact matches in cell line naming, and save set for selecting cell lines
# with available knockouts
ordered_inst_info["cell_line"] = np.array([re.sub(r'\..*', '', x) for x in ordered_inst_info.cell_id])
cell_lines_2_keep = set(ordered_inst_info[ordered_inst_info.pert_type == "trt_xpr"].cell_line.unique())
cell_lines_2_keep

In [None]:
# select all instances of compound perturbation on potential comparable cell lines
mask = (ordered_inst_info.pert_iname == compound) & ordered_inst_info.cell_id.isin(cell_lines_2_keep)
comp_dat = bc_vae_embedding.loc[mask]
comp_meta = ordered_inst_info.loc[mask].copy()  # copy so a column can be added safely
# equalize rounding so dosages match
comp_meta['pert_dose_rounded'] = np.round(comp_meta.pert_dose, 2)

In [None]:
# calculate the cosine similarity between all replicates within a given
# cell line and dosage.
# the idea here is that if replicates don't correlate, then there's likely
# no consistent effect for compound <-> gene comparisons.
# in biological terms, these might be cases where the compound is dosed
# too low to have an effect or the protein target is not expressed
flat = comp_dat.groupby([comp_meta.cell_id, comp_meta.pert_dose_rounded]).apply(
    mean_cosine_similarity)
flat = pd.DataFrame(flat).reset_index()
# reshape to cell line X dose
pivoted = pd.pivot(flat, columns='pert_dose_rounded', index='cell_id')
pivoted.head()

In [None]:
sns.heatmap(pivoted, vmin=-1, vmax=1, cmap="seismic")

This compound has a fairly consistent effect across cell
lines, nevertheless we will focus on a couple of the strongest.

#### Find the most similar genetic knock outs


In [None]:
def query_gene_knockouts(compound, target_cell_line, min_dose, max_dose,
                         embedding, metadata):
    """finds genes in L1000 dataset that are most similar to query compound
    under specified cell line and dosage constraints"""

    # prep binary masks to later select...
    # ... cell line
    cell_mask = metadata.cell_line == target_cell_line
    # ... all genes
    genes_mask = metadata.pert_type == "trt_xpr"
    # ... compound of interest
    comp_mask = metadata.pert_iname == compound
    # ... doses of interest
    dose_mask = (min_dose <= metadata.pert_dose) & (metadata.pert_dose <=  max_dose)

    # get data subsets
    comp_dat = embedding.loc[cell_mask & comp_mask & dose_mask]

    sub_info = metadata.loc[cell_mask & genes_mask]
    sub_dat = embedding.loc[cell_mask & genes_mask]

    # aggregate all genes
    g_references = sub_dat.groupby(sub_info.pert_iname).mean()
    # get query compound
    comp_query = comp_dat.mean()

    # calculate cosine similarities
    ready = pd.concat([g_references, pd.DataFrame(comp_query).T])
    return pd.Series(cosine_similarity(ready)[-1], index=ready.index)

Now we'll query for the gene knockouts that show the most similar perturbation
effect to everolimus, in some of the cell lines and dosage that showed the
most reproducibility for everolimus above. Adjust cell line and dosages
as you see fit. Other compounds with less global effects are likely more
sensitive to such choices.

### Result time!

In [None]:
# top hits in cell line A549
query_gene_knockouts(compound, "A549", min_dose=0, max_dose=2,
                     embedding=bc_vae_embedding, metadata=ordered_inst_info).sort_values(ascending=False)[1:11]  # 0th index is always the compound itself, skip it

In [None]:
# top hits in A375
query_gene_knockouts(compound, "A375", min_dose=.11, max_dose=2,
                     embedding=bc_vae_embedding, metadata=ordered_inst_info).sort_values(ascending=False)[1:11]

In [None]:
# top hits in YAPC
query_gene_knockouts(compound, "YAPC", min_dose=0, max_dose=0.15,
                     embedding=bc_vae_embedding, metadata=ordered_inst_info).sort_values(ascending=False)[1:11]

#### Thoughts
1. It's good to see MTOR robustly at or near the top of the list.

2. Other consistent hits should be checked if they could be indirect effects
(e.g. in the pathway with MTOR); or whether they might be off target effects.
It is likely helpful to consult a database such as [signor](https://signor.uniroma2.it/), [reactome](https://reactome.org/) or [corum](https://mips.helmholtz-muenchen.de/corum/) to find expected indirect interactions.

**Congratulations! You made it!**

-----------------


## Additional exercises
If you have extra time, we'd encourage you to do one of the following


1) Optimization.
This lab was designed to be a minimalistic but nevertheless end-to-end target deconvolution analysis. At each step, we consistently chose a simple
and "good enough" option, but did not try to optimize any of them.

The exercise: pick your favorite step of
this lab (e.g. normalization, batch correction, embedding, similarity measurement) and try to optimize it.

2) Extend functionality.
This lab was minimalistic, there are surely many things that could
be added to strengthen the analysis.

The exercise: make a prioritized and justified list of what you would
add next.

3) Follow up experiments on potential off target effects.
We ended with a list of candidate drug targets and didn't make final
decisions on what was a chance correlation, an indirect effect, or an
actual target.

The exercise: previous labs have given you tools to assess drug-protein
interactions. Use these, or at least consider what you would best use,
to check some of the candidate interactions further.
