Copyright 2023 Recursion

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Building and Benchmarking a Biological Map with JUMP-Cell Painting Consortium Cell Profiler Features

In this notebook, we create and benchmark maps based on the JUMP-CP imaging dataset featurized with Cell Profiler. The focus will be on creating simple EFAAR pipelines and benchmarking them with the framework to see how changes in the pipeline affect the metrics.

The primary components of an EFAAR Pipeline are 

- **Embed:** A high dimensional featurization of a biological perturbation
- **Filter:** A preprocessing step where embeddings are discarded based on statistical or machine learning thresholds, external metadata, or features
- **Align:** Process to combine the embeddings into the same high dimensional space while increasing the signal and decreasing the noise in the data
- **Aggregate:** Combining any replicates of perturbations into a single representation for comparison 
- **Relate:** Any metric that allows pairwise comparison between perturbations

Additionally, we will examine the presence of chromosomal proximity bias in the data and alter the pipeline to decrease its effect on relationships. 

### Imports

In [None]:
# Uncomment the following lines to run this notebook in google Colab

# !git clone https://github.com/recursionpharma/EFAAR_benchmarking.git
# import sys
# sys.path.insert(0, "/content/EFAAR_benchmarking")
# import os
# os.chdir("EFAAR_benchmarking/notebooks")
# !gsutil cp gs://rxrx-cytodata2023-public/cpg_0016_crispr_0.parquet data/cpg_0016_crispr_0.parquet
# !gsutil cp gs://rxrx-cytodata2023-public/cpg_0016_crispr_1.parquet data/cpg_0016_crispr_1.parquet
# !gsutil cp gs://rxrx-cytodata2023-public/cpg_0016_crispr_2.parquet data/cpg_0016_crispr_2.parquet
# !gsutil cp gs://rxrx-cytodata2023-public/cpg_0016_crispr_3.parquet data/cpg_0016_crispr_3.parquet
# !gsutil cp gs://rxrx-cytodata2023-public/cpg_0016_crispr_4.parquet data/cpg_0016_crispr_4.parquet
# !gsutil cp gs://rxrx-cytodata2023-public/cpg_0016_crispr_5.parquet data/cpg_0016_crispr_5.parquet
# !gsutil cp gs://rxrx-cytodata2023-public/cpg_0016_crispr_6.parquet data/cpg_0016_crispr_6.parquet
# !gsutil cp gs://rxrx-cytodata2023-public/cpg_0016_crispr_7.parquet data/cpg_0016_crispr_7.parquet
# !gsutil cp gs://rxrx-cytodata2023-public/cpg_0016_crispr_8.parquet data/cpg_0016_crispr_8.parquet
# !gsutil cp gs://rxrx-cytodata2023-public/cpg_0016_crispr_9.parquet data/cpg_0016_crispr_9.parquet

In [None]:
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple, Dict, Any, Optional

import pandas as pd
pd.options.mode.chained_assignment = None
import numpy as np
import scipy as sp
import sklearn
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.covariance import EllipticEnvelope
import matplotlib.pyplot as plt
from scipy.cluster import hierarchy
from scipy.spatial.distance import pdist
from skimage.measure import block_reduce

import seaborn as sns
import plotly.graph_objects as go
from sklearn.utils import Bunch
import scipy.linalg as linalg

from efaar_benchmarking.benchmarking import benchmark as bm
from efaar_benchmarking.utils import get_benchmark_metrics

import matplotlib as mpl
import matplotlib.pyplot as plt

sns.set_theme()

### Loading
Load the cpg0016 CRISPR data. The metadata is not available via an API, so it has been copied from the Cell Painting gallery for ease of use. The features, however, can be easily pulled from the S3 bucket provided by the JUMP consortium. We cache the raw data locally as parquet files after loading to avoid repeatedly pulling from the remote bucket.

In [None]:
# URL template for one plate's profile parquet in the Cell Painting Gallery S3
# bucket. The `{Metadata_*}` placeholders are filled in from the matching
# metadata columns of a plate row (see `load_feature_data`).
CP_FEATURE_FORMATTER = (
    "s3://cellpainting-gallery/cpg0016-jump/"
    "{Metadata_Source}/workspace/profiles/"
    "{Metadata_Batch}/{Metadata_Plate}/{Metadata_Plate}.parquet"
)

# Columns that describe a well rather than measure it. Every column NOT listed
# here is treated as a feature column by the pipeline functions below.
CPG_METADATA_COLS = [
    "Metadata_Source",
    "Metadata_Plate",
    "Metadata_Well",
    "Metadata_JCP2022",
    "Metadata_Batch",
    "Metadata_PlateType",
    "Metadata_NCBI_Gene_ID",
    "Metadata_Symbol",
]


def load_cpg_crispr_well_metadata():
    """Build the per-well metadata table for the CRISPR plates.

    Reads the plate, well and CRISPR tables from the local ``data/`` directory,
    keeps only plates whose ``Metadata_PlateType`` is ``CRISPR``, attaches the
    wells of those plates, and finally attaches the CRISPR annotations through
    ``Metadata_JCP2022``.

    Returns
    -------
    pd.DataFrame
        One row per well/CRISPR-annotation match (inner joins throughout).
    """
    plate_table = pd.read_csv("data/plate.csv.gz")
    only_crispr_plates = plate_table.query("Metadata_PlateType=='CRISPR'")
    well_table = pd.read_csv("data/well.csv.gz")
    crispr_table = pd.read_csv("data/crispr.csv.gz")

    plate_keys = ["Metadata_Source", "Metadata_Plate"]
    wells_on_crispr_plates = well_table.merge(only_crispr_plates, on=plate_keys)
    return wells_on_crispr_plates.merge(crispr_table, on="Metadata_JCP2022")


def _load_plate_features(path: str):
    """Return the feature table for one plate, using a local parquet cache.

    The cache file lives in ``data/`` and is named after the last component of
    ``path``. On a cache miss the parquet is read from ``path`` with anonymous
    access and written to the cache before being returned.
    """
    cache_file = "data/" + path.split("/")[-1]
    try:
        plate_df = pd.read_parquet(cache_file)
    except FileNotFoundError:
        # Not cached yet: fetch remotely, then persist for the next call.
        plate_df = pd.read_parquet(path, storage_options={"anon": True})
        plate_df.to_parquet(cache_file)
    return plate_df


def load_feature_data(metadata_df: pd.DataFrame, max_workers=4) -> pd.DataFrame:
    """Load feature data from Cell Painting Gallery from metadata dataframe.

    One parquet file is loaded per distinct plate found in ``metadata_df``.
    Plates are fetched concurrently, but the per-plate frames are concatenated
    in plate order (not in download-completion order), so the row order of the
    result is reproducible from run to run.

    Parameters
    ----------
    metadata_df : pd.DataFrame
        Well metadata dataframe. Must contain the ``Metadata_Source``,
        ``Metadata_Batch``, ``Metadata_Plate`` and ``Metadata_PlateType``
        columns.
    max_workers : int, optional
        Number of threads used to load plates in parallel, by default 4

    Returns
    -------
    pd.DataFrame
        Well features

    Raises
    ------
    ValueError
        Raised by ``pd.concat`` if ``metadata_df`` contains no plates.
    """
    crispr_plates = metadata_df[
        ["Metadata_Source", "Metadata_Batch", "Metadata_Plate", "Metadata_PlateType"]
    ].drop_duplicates()
    plate_paths = [
        CP_FEATURE_FORMATTER.format(**row.to_dict()) for _, row in crispr_plates.iterrows()
    ]
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # `map` yields results in submission order, unlike `as_completed`.
        plate_frames = list(executor.map(_load_plate_features, plate_paths))
    return pd.concat(plate_frames)


def build_combined_data(metadata: pd.DataFrame, features: pd.DataFrame) -> pd.DataFrame:
    """Attach well features to well metadata.

    The two tables are inner-joined on the source / plate / well triple, so
    wells missing from either table are absent from the result.

    Parameters
    ----------
    metadata : pd.DataFrame
        Well metadata
    features : pd.DataFrame
        Well features

    Returns
    -------
    pd.DataFrame
        Combined dataframe
    """
    well_keys = ["Metadata_Source", "Metadata_Plate", "Metadata_Well"]
    return pd.merge(metadata, features, on=well_keys)


def read_parquets_from_gcs(n_files: int = 10):
    """Load the pre-built cpg0016 CRISPR well data, caching each shard locally.

    Shards are named ``cpg_0016_crispr_{i}.parquet`` for ``i`` in
    ``range(n_files)``. Each shard is read from ``data/`` when present;
    otherwise it is downloaded from the public ``rxrx-cytodata2023-public``
    bucket and written to ``data/`` for later runs.

    Parameters
    ----------
    n_files : int, optional
        Number of shards to read, by default 10

    Returns
    -------
    pd.DataFrame
        Concatenation of all shards, in shard order. The shards' own index
        labels are kept (no ``ignore_index``).
    """
    data = []
    for i in range(n_files):
        local_path = f"data/cpg_0016_crispr_{i}.parquet"
        try:
            data.append(pd.read_parquet(local_path))
        except FileNotFoundError:
            subset = pd.read_parquet(f"gs://rxrx-cytodata2023-public/cpg_0016_crispr_{i}.parquet")
            subset.to_parquet(local_path)
            data.append(subset)
        print(f"Read in parquet {i+1} of {n_files}")
    return pd.concat(data)

In [None]:
raw_well_data = read_parquets_from_gcs()

# Creating and benchmarking an initial map
Let's create and benchmark an initial map with a simple EFAAR pipeline to serve as a baseline

- **Embed:** Cell Profiler Embeddings from JUMP cpg0016
- **Filter:** Dropping feature columns that contain NaN values
- **Align:** Centering and scaling the individual features followed by PCA keeping a fraction of the variance 
- **Aggregate:** The mean over the well replicates of each gene KO
- **Relate:** Cosine similarity between the perturbations



In [None]:
def preprocess_data(
    data: pd.DataFrame, metadata_cols: list = CPG_METADATA_COLS, drop_image_cols: bool = True
) -> pd.DataFrame:
    """Preprocess data by dropping feature columns with nan values,
    and optionally dropping the Image columns

    Rows are never added or removed. Columns are selected directly from
    ``data`` rather than by splitting and re-joining on the index, so the
    result is correct even when the index of ``data`` is not unique (as it can
    be after concatenating several files).

    Parameters
    ----------
    data : pd.DataFrame
        Data to preprocess
    metadata_cols : List[str], optional
        Metadata columns, by default CPG_METADATA_COLS
    drop_image_cols : bool, optional
        Whether to drop feature columns whose name starts with ``Image_``,
        by default True

    Returns
    -------
    pd.DataFrame
        Processed dataframe: the metadata columns followed by the retained
        feature columns.
    """
    metadata_cols = list(metadata_cols)
    feature_cols = [col for col in data.columns if col not in metadata_cols]

    # Keep only the feature columns that have no missing value in any row.
    kept_cols = list(data[feature_cols].dropna(axis=1).columns)
    if drop_image_cols:
        kept_cols = [col for col in kept_cols if not col.startswith("Image_")]
    return data[metadata_cols + kept_cols]


def pca_transform_data(
    data,
    metadata_cols: list = CPG_METADATA_COLS,
    variance=0.98,
    pca_fit_fraction=1.0,
) -> pd.DataFrame:
    """Align data by center-scaling the features and then transforming by PCA

    Parameters
    ----------
    data : pd.DataFrame
        Data to transform
    metadata_cols : List[str], optional
        Metadata columns, by default CPG_METADATA_COLS
    variance : float, optional
        Passed to ``PCA`` as its first argument (``n_components``),
        by default 0.98
    pca_fit_fraction : float, optional
        Fraction of wells to sample (with a fixed seed) for fitting the PCA,
        by default 1.0. All wells are transformed regardless.

    Returns
    -------
    pd.DataFrame
        Metadata columns followed by the principal components, which are
        labelled with consecutive integers starting at 0.
    """
    metadata = data[metadata_cols]
    features = data[[col for col in data.columns if col not in metadata_cols]]

    # Build a fresh frame from the scaler output instead of assigning in place
    # into a column subset of `data`; this avoids dtype-dependent setitem
    # behavior and does not rely on chained-assignment warnings being muted.
    scaled = pd.DataFrame(
        StandardScaler().fit_transform(features),
        index=features.index,
        columns=features.columns,
    )

    pca = PCA(variance)
    pca.fit(scaled.sample(frac=pca_fit_fraction, random_state=42))
    components = pd.DataFrame(pca.transform(scaled), index=metadata.index)

    # Both frames share the same index object, so a column-wise concat lines
    # rows up positionally and cannot duplicate rows on a non-unique index
    # (which an index-on-index `join` would do).
    return pd.concat([metadata, components], axis=1)

In [None]:
transformed_well_data = pca_transform_data(preprocess_data(raw_well_data), pca_fit_fraction=0.05)

In [None]:
features = transformed_well_data[[col for col in transformed_well_data.columns if col not in CPG_METADATA_COLS]]
aggregated_data = features.groupby(transformed_well_data["Metadata_Symbol"], as_index=True).mean()
aggregated_data.index.name = "gene"

#### Relationships
Visualizing relationships of some well known gene sets that we expect to cluster together

In [None]:
def make_pairwise_cos(
    df: pd.DataFrame,
    convert: bool = True,
    dtype: type = np.float16,
) -> pd.DataFrame:
    """
    Turn a samples X features dataframe into a square samples X samples dataframe
    holding the cosine similarity between every pair of rows.

    Similarities are computed as one minus the cosine distance and clipped to
    the closed interval [-1, 1].

    Inputs
    ------
    - df = pd.DataFrame. Rows are samples; its index labels both axes of the result
    - convert = bool. Whether to cast the result to `dtype`
    - dtype = type. Data type to cast to when `convert` is True
    """
    vectors = df.values
    cosine_distance = sp.spatial.distance.cdist(vectors, vectors, metric="cosine")
    similarity = np.clip(1 - cosine_distance, -1, 1)
    if convert:
        similarity = similarity.astype(dtype)
    labels = df.index
    return pd.DataFrame(similarity, index=labels, columns=labels)

In [None]:
def cluster_perturbations(data: pd.DataFrame) -> pd.DataFrame:
    """
    Takes a square data frame, and uses euclidean distance between its rows to
    reorder the rows and columns into clusters. The index and column values for
    the data frame must be identical to each other.

    The order is the leaf order of a hierarchical clustering (scipy `linkage`
    with its default method) of the rows.

    Raises
    ------
    ValueError
        If the index and the columns of `data` are not equal.
    """
    if list(data.index) != list(data.columns):
        raise ValueError("index and columns for data must be equal")

    tree = hierarchy.dendrogram(
        hierarchy.linkage(pdist(data)),
        no_plot=True,
    )
    # 'leaves' holds integer row positions, which is what `.iloc` needs.
    # ('ivl' holds the leaf *labels* as strings and cannot be used with .iloc.)
    order = tree["leaves"]
    data = data.iloc[order]
    # Apply the same ordering to the columns.
    data = data[data.index]
    return data


def plot_cosine_rectangle_data(pldata: pd.DataFrame, title: str='plot', h=600, w=600):
    """
    Show a clustered heatmap of `pldata`.

    `pldata` is passed through `cluster_perturbations` first, so it must be a
    square dataframe whose index and columns match. Every value is plotted on
    a fixed [-1, 1] diverging color scale; `h` and `w` set the figure height
    and width, and `title` its title.
    """
    ordered = cluster_perturbations(data=pldata)

    heatmap = go.Heatmap(
        z=ordered.values,
        x=ordered.columns,
        y=ordered.index,
        colorscale='RdBu_r',
        zmin=-1,
        zmax=1,
    )
    fig = go.Figure(data=heatmap)
    fig.update_layout(height=h, width=w, title_text=f"{title}")
    fig.show()

In [None]:
# Hand-picked gene sets, keyed by a short label. Each value is a list of gene
# symbols built by splitting a comma-separated string.
gene_sets = {
    'PROTEASOME': 'PSMB2,PSMB7,PSMB4,PSMA7,PSMA4,PSMB6,PSMA5,PSMB3,PSMA6,PSMA1,PSMB1,PSMA3'.split(','),
    'EXOSOME': 'DIS3,EXOSC4,EXOSC8,EXOSC7,EXOSC9,EXOSC5'.split(','),
    'VATPASE': 'ATP6V1B2,ATP6V1H,ATP6V1D,ATP6V1A,ATP6V1F,ATP6V1E1'.split(','),
    'AUTOPHAGY': 'ATG12,ATG5'.split(','),
    'REPLICATION_FACTOR_COMPLEX': 'RFC3,RFC4,RFC2,RFC5'.split(','),  
    'DYNEIN': 'DYNC1I2,DYNC1H1,DYNC1LI1,DYNC1LI2'.split(','),
    'RNA': 'POLA2,POLA1,POLR2L,POLR2B,POLR2I,POLR2G,POLR2C'.split(','), 
    'EGFR': 'PRKCE,BRAF,HRAS,SHC1,RAF1,EGFR,MAPK1'.split(','),
    'TGFB/ACTIVIN': 'ACVR1B,TGFBR2,TGFBR1'.split(','),
    }

# Flat list of every gene in every set, in set order (duplicates are kept).
cluster_genes = [g for gene_set in gene_sets.values() for g in gene_set]

In [None]:
plot_cosine_rectangle_data(make_pairwise_cos(aggregated_data.loc[cluster_genes]), title = "JUMP Gene Sets - Initial EFAAR Pipeline")

### Benchmarking
Plots provide a nice visualization but we can quantify how well the map is recapitulating biology with these metrics

#### Multivariate Benchmarks
The multivariate benchmarks compare annotated sets of known relationships against a distribution of randomly selected embeddings and compute recall based on the percentiles of the random distribution. In this notebook we will use the 5th and 95th percentiles. Hence, embeddings that don't have biological signal are expected to have a recall of around 0.1 (the red dashed line in the plots below).

In [None]:
metadata = aggregated_data.reset_index()[['gene']]
features = aggregated_data.reset_index(drop=True)
results_dict = bm(Bunch(metadata=metadata, features=features), pert_label_col='gene', run_count=1)

In [None]:
results_dict['Seeds_1608637542_3421126067']['HuMAP']

In [None]:
benchmark_results_df = get_benchmark_metrics(results_dict)
benchmark_results_df['map_version'] = 'JUMP_initial_EFAAR'
sns.barplot(data=benchmark_results_df, x='source' ,y='recall')
plt.axhline(y=0.1, color='red', linestyle='--') 
plt.title('Recall at 5th and 95th Percentiles')
plt.show()

#### Univariate Benchmarks
For the univariate benchmarks we can compare a metric that measures the consistency of the well-level replicates of a perturbation against an empirical null where the metric is computed on random sets of replicates. This computation can be very sensitive to the null distribution. Incorporating the experiment design in the null is important but can become computationally expensive very quickly. The following is an example of the idea, but a production implementation needs more refinement and compute. For example, we are not matching plates or well addresses, which can make replicates look more similar to each other.

In [None]:
def avg_angle(df):
    """Return the mean pairwise angle (in radians) between the rows of `df`.

    The angle between two rows is the arccosine of their cosine similarity.
    Each unordered pair of distinct rows is counted once.

    Parameters
    ----------
    df : pd.DataFrame
        Numeric samples X features frame (anything exposing ``.values``).

    Returns
    -------
    float
        Mean angle over all row pairs, or NaN when `df` has fewer than two
        rows (there is no pair to measure).
    """
    vectors = np.asarray(df.values, dtype=np.float64)
    n_rows = vectors.shape[0]
    if n_rows < 2:
        return np.nan

    # Cosine similarity = dot product of the L2-normalized rows. All-zero rows
    # are left as zeros (norm treated as 1), giving them similarity 0.
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    unit = vectors / norms

    # Round-off can push a similarity marginally outside [-1, 1], where arccos
    # returns NaN; clip first so near-identical rows yield an angle of ~0.
    cosine_sim = np.clip(unit @ unit.T, -1.0, 1.0)
    return np.arccos(cosine_sim[np.tril_indices(n_rows, k=-1)]).mean()

def generate_null_distribution(feature_data, n_samples = 10000, cardinality = 5):
    """Build a null distribution for the univariate `avg_angle` metric.

    Draws `n_samples` random subsets of `cardinality` rows from `feature_data`
    and returns the list of `avg_angle` values, one per subset.
    """
    return [
        avg_angle(feature_data.sample(n=cardinality))
        for _ in range(n_samples)
    ]


def compute_cosine_sim_metric(
    data: pd.DataFrame,
    metadata_cols=CPG_METADATA_COLS,
    cardinality=5,
    n_shuffles=5,
    random_state=None,
):
    """Compute a univariate replicate-consistency metric on well level data

    For every gene with exactly `cardinality` wells, the metric is the mean
    pairwise angle between its wells (see `avg_angle`; it is stored under the
    historical column name ``avg_cossim``). An empirical null is built by
    shuffling the gene labels across the same wells `n_shuffles` times and
    recomputing the metric per (shuffled) label. The reported p-value is the
    fraction of null values strictly smaller than the gene's own value.

    Wells labelled ``no-guide``, ``non-targeting`` or ``PLK1`` are excluded.

    Parameters
    ----------
    data : pd.DataFrame
        Well level data to process
    metadata_cols : List[str], optional
        Metadata columns, by default CPG_METADATA_COLS
    cardinality : int, optional
        Only genes with exactly this many wells are scored, by default 5
    n_shuffles : int, optional
        Number of label shuffles pooled into the null, by default 5
    random_state : int, np.random.RandomState, np.random.Generator or None, optional
        Seed or random state for the label shuffles. With the default None the
        shuffles draw from numpy's global random state.

    Returns
    -------
    pd.DataFrame
        One row per scored gene with columns ``Metadata_Symbol``,
        ``avg_cossim`` and ``avg_cossim_pval``.

    Raises
    ------
    ValueError
        If `n_shuffles` is smaller than 1.
    """
    if n_shuffles < 1:
        raise ValueError("n_shuffles must be at least 1")

    data = data.query('Metadata_Symbol != "no-guide" and Metadata_Symbol != "non-targeting" and Metadata_Symbol != "PLK1"')
    gene_count = data.groupby('Metadata_Symbol').Metadata_Well.count()
    genes = list(gene_count[gene_count == cardinality].index)
    data = data.loc[data['Metadata_Symbol'].isin(genes)]

    # Group positionally by a label array instead of joining label and feature
    # frames on the index, which would duplicate rows on a non-unique index.
    labels = data['Metadata_Symbol'].to_numpy()
    feature_cols = [
        col for col in data.columns if col not in metadata_cols and col != 'Metadata_Symbol'
    ]
    features = data[feature_cols]

    query_metrics = features.groupby(labels).apply(avg_angle)
    query_metrics.name = "avg_cossim"
    query_metrics = query_metrics.rename_axis("Metadata_Symbol").reset_index()

    # An integer seed becomes a single RandomState shared by all shuffles so
    # that successive shuffles differ; other values are handed to pandas as-is.
    if isinstance(random_state, (int, np.integer)):
        rng = np.random.RandomState(random_state)
    else:
        rng = random_state

    null = []
    shuffled = labels
    for _ in range(n_shuffles):
        # Each round reshuffles the previous round's labels.
        shuffled = pd.Series(shuffled).sample(frac=1, random_state=rng).to_numpy()
        null.extend(features.groupby(shuffled).apply(avg_angle).to_numpy())
    sorted_null = np.sort(null)
    query_metrics['avg_cossim_pval'] = np.searchsorted(sorted_null, query_metrics.avg_cossim) / len(sorted_null)
    return query_metrics


In [None]:
initial_univariate_metrics = compute_cosine_sim_metric(transformed_well_data)

In [None]:
num01genes = len(initial_univariate_metrics.query("avg_cossim_pval<=0.01"))
num05genes = len(initial_univariate_metrics.query("avg_cossim_pval<=0.05"))
print(f"{num01genes} significant genes at a 0.01 significance threshold")
print(f"{num05genes} significant genes at a 0.05 significance threshold")

## Improving the Map
Let's iterate on the EFAAR pipeline to try to improve both the univariate and multivariate benchmarks

#### Filtering
With CellProfiler features, our features are identifiable, so we can filter directly based on them. With other types of embeddings, this could be based on statistical or machine learning models predicting whether a well should be removed.

In [None]:
cols = raw_well_data.columns

In [None]:
[col for col in cols if "ImageQuality_MeanIntensity" in col]

In [None]:
[col for col in cols if "Object_Number" in col]

In [None]:
raw_well_data['Image_ImageQuality_MeanIntensity_OrigDNA'].plot(kind='kde')
plt.title('Mean Intensity')
plt.show()

In [None]:
raw_well_data['Cytoplasm_Number_Object_Number'].plot(kind='kde')
plt.title('Object Number')
plt.show()

In [None]:
def filter_by_image_intensity(df, contamination=0.01):
    """Filter out wells whose image mean intensities are outliers.

    An `EllipticEnvelope` is fitted on every column of `df` whose name contains
    ``ImageQuality_MeanIntensity``; rows it labels as inliers are kept.

    Parameters
    ----------
    df : pd.DataFrame
        Well level data containing the ``ImageQuality_MeanIntensity`` columns
    contamination : float, optional
        Passed to `EllipticEnvelope` as its ``contamination`` argument,
        by default 0.01

    Returns
    -------
    pd.DataFrame
        The rows of `df` labelled as inliers.
    """
    # Look the columns up on `df` itself rather than on a notebook-level
    # variable, so the function does not depend on cell execution order or on
    # columns that may already have been dropped from `df`.
    intensity_cols = [col for col in df.columns if "ImageQuality_MeanIntensity" in col]
    envelope = EllipticEnvelope(contamination=contamination, random_state=42)
    labels = envelope.fit_predict(df[intensity_cols])
    # fit_predict labels inliers with 1.
    return df.loc[labels == 1]

In [None]:
def filter_by_cell_count(df, lower_threshold, upper_threshold):
    """Keep the rows whose cytoplasm and nuclei object counts both lie within
    [lower_threshold, upper_threshold] (bounds included)."""
    cytoplasm_ok = df['Cytoplasm_Number_Object_Number'].between(lower_threshold, upper_threshold)
    nuclei_ok = df['Nuclei_Number_Object_Number'].between(lower_threshold, upper_threshold)
    return df.loc[cytoplasm_ok & nuclei_ok]

In [None]:
def filter_and_preprocess_data(
    data: pd.DataFrame,
    metadata_cols: list = CPG_METADATA_COLS,
    drop_image_cols: bool = True,
    lower_threshold=50,
    upper_threshold=350,
    contamination=0.01,
) -> pd.DataFrame:
    """Preprocess data by dropping feature columns with nan values, filtering
    wells on image intensity and cell count, dropping the Number_Object
    columns, and optionally dropping the Image columns

    Columns are selected directly from ``data`` rather than by splitting and
    re-joining on the index, so rows are not duplicated when the index of
    ``data`` is not unique.

    Parameters
    ----------
    data : pd.DataFrame
        Data to preprocess
    metadata_cols : List[str], optional
        Metadata columns, by default CPG_METADATA_COLS
    drop_image_cols : bool, optional
        Whether to drop Image columns, by default True
    lower_threshold : int, optional
        Smallest object count a well may have to be kept, by default 50
    upper_threshold : int, optional
        Largest object count a well may have to be kept, by default 350
    contamination : float, optional
        Contamination passed to `filter_by_image_intensity`, by default 0.01

    Returns
    -------
    pd.DataFrame
        Processed dataframe
    """
    metadata_cols = list(metadata_cols)
    feature_cols = [col for col in data.columns if col not in metadata_cols]
    complete_feature_cols = list(data[feature_cols].dropna(axis=1).columns)
    data = data[metadata_cols + complete_feature_cols]

    # Well-level filters. They need the Image_* and *Object_Number columns, so
    # those columns are only removed afterwards.
    data = filter_by_image_intensity(data, contamination=contamination)
    data = filter_by_cell_count(
        data, lower_threshold=lower_threshold, upper_threshold=upper_threshold
    )

    data = data.drop(columns=[col for col in data.columns if col.endswith("Object_Number")])

    if drop_image_cols:
        image_cols = [col for col in data.columns if col.startswith("Image_")]
        data = data.drop(columns=image_cols)
    return data

In [None]:
# Run preprocessing on the raw data
print(len(raw_well_data), "wells before filtering")
processed_data = filter_and_preprocess_data(raw_well_data)
print(len(processed_data), "wells after filterng")

### Alignment

In [None]:
def _compute_modified_cov(data: pd.DataFrame):
        return np.cov(data, rowvar=False, ddof=1) + 0.5 * np.eye(
            np.shape(data)[1], dtype=np.float32
        )


def _standardize_by_group(features: pd.DataFrame, groups: pd.Series) -> pd.DataFrame:
    """Center-scale every feature column separately within each group of rows."""
    group_values = groups.to_numpy()
    scaled = np.empty(features.shape, dtype=np.float64)
    for group in pd.unique(group_values):
        in_group = group_values == group
        scaled[in_group] = StandardScaler().fit_transform(features.loc[in_group])
    return pd.DataFrame(scaled, index=features.index, columns=features.columns)


def tvn_transform(
    data: pd.DataFrame,
    metadata_cols: list = CPG_METADATA_COLS,
    variance=0.98,
    pca_fit_fraction = 1.0
) -> pd.DataFrame:
    """Transform data by per-plate scaling, PCA, per-plate scaling again, and
    finally whitening each batch.

    Steps, in order:

    1. Center-scale every feature within each ``Metadata_Plate``.
    2. Fit a PCA on a (seeded) sample of wells and project all wells.
    3. Center-scale every principal component within each ``Metadata_Plate``.
    4. For each ``Metadata_Batch``, right-multiply its wells by the inverse
       square root of that batch's regularized covariance
       (see `_compute_modified_cov`).

    Parameters
    ----------
    data : pd.DataFrame
        Data to transform
    metadata_cols : list, optional
        Metadata columns, by default CPG_METADATA_COLS
    variance : float, optional
        Variance to keep after PCA, by default 0.98
    pca_fit_fraction : float, optional
        Fraction of wells to sample for PCA, by default 1.0

    Returns
    -------
    pd.DataFrame
        Metadata columns followed by the transformed components, which are
        labelled with consecutive integers starting at 0.
    """
    metadata = data[metadata_cols]
    features = data[[col for col in data.columns if col not in metadata_cols]]

    features = _standardize_by_group(features, metadata.Metadata_Plate)

    pca = PCA(variance)
    pca.fit(features.sample(frac=pca_fit_fraction, random_state=42))
    features = pd.DataFrame(pca.transform(features), index=metadata.index)

    features = _standardize_by_group(features, metadata.Metadata_Plate)

    # Whiten batch by batch. Each batch only reads its own rows of `features`,
    # so results can be collected in a separate array. Rows are selected with
    # a boolean mask (not a string-built query), which works for any batch
    # label and does not depend on the index being unique.
    whitened = np.empty(features.shape, dtype=np.float64)
    batch_values = metadata.Metadata_Batch.to_numpy()
    for batch in pd.unique(batch_values):
        in_batch = batch_values == batch
        batch_features = features.loc[in_batch]
        source_cov = _compute_modified_cov(batch_features)
        source_cov_half_inv = linalg.fractional_matrix_power(source_cov, -0.5)
        whitened[in_batch] = np.matmul(batch_features.to_numpy(), source_cov_half_inv)

    features = pd.DataFrame(whitened, index=metadata.index, columns=features.columns)
    # Both frames share one index object, so this lines rows up positionally.
    return pd.concat([metadata, features], axis=1)

In [None]:
# Run alignment on the processed data
# This takes a while, so we again only use a fraction of the data to fit the PCA
tvn_transformed_well_data = tvn_transform(processed_data, pca_fit_fraction= .1)

### Aggregation
Combine representations by taking the median across the replicates

In [None]:
features = tvn_transformed_well_data[[col for col in tvn_transformed_well_data.columns if col not in CPG_METADATA_COLS]]
aggregated_tvn_data = features.groupby(tvn_transformed_well_data["Metadata_Symbol"], as_index=True).median()
aggregated_tvn_data.index.name = "gene"

In [None]:
# inspect the data
aggregated_tvn_data.head()

### Relationships

In [None]:
plot_cosine_rectangle_data(make_pairwise_cos(aggregated_tvn_data.loc[list(set(cluster_genes).intersection(aggregated_tvn_data.index))]), 
                           title = "JUMP Gene Sets - Improved EFAAR Pipeline")

#### Multivariate Benchmarks

In [None]:
data_dict={}

In [None]:
metadata = aggregated_tvn_data.reset_index()[['gene']]
features = aggregated_tvn_data.reset_index(drop=True)
data_dict['improved_EFAAR'] = Bunch(metadata=metadata, features=features)
results = bm(data_dict['improved_EFAAR'], pert_label_col='gene', run_count=1)

In [None]:
new_results_df = get_benchmark_metrics(results)
new_results_df['map_version'] = 'JUMP_improved_EFAAR'
benchmark_results_df = pd.concat([benchmark_results_df, new_results_df])
sns.barplot(data=benchmark_results_df, x='source' ,y='recall', hue='map_version')
plt.axhline(y=0.1, color='red', linestyle='--') 
plt.title('Recall at 5th and 95th Percentiles')
plt.show()

# Proximity Bias in CRISPR Knockout Data

### Visualizing Proximity Bias in JUMP Data
We create a genome wide heatmap ordered by gene chromosome position

In [None]:
def _get_data_path(filename):
    return "data/"+filename


def _chr_to_int(chr):
    if chr == "x":
        return 24
    elif chr == "y":
        return 25
    return int(chr)


def _chrom_int(chrom):
    """Map a Series of chromosome names such as ``chr7`` or ``chrX`` to integer
    codes, element by element, via `_chr_to_int`."""
    names = chrom.copy()
    # Text after the first "chr", lower-cased: "chrX" -> "x", "chr12" -> "12".
    suffixes = names.str.split("chr").str[1].str.lower()
    return suffixes.map(_chr_to_int)

# Chromosome names accepted by the loaders below: chr1..chr22, chrX and chrY.
VALID_CHROMS = [f"chr{i}" for i in range(1, 23)] + ["chrX", "chrY"]

def _load_centromeres() -> pd.DataFrame:
    """Load centromere coordinates, one row per chromosome.

    Reads ``centromeres_hg38.tsv`` from the data directory and collapses the
    rows of each chromosome to the smallest ``chromStart`` and the largest
    ``chromEnd``, returned as ``centromere_start`` / ``centromere_end``.
    """
    position_cols = ["chromStart", "chromEnd"]
    raw = pd.read_csv(
        _get_data_path("centromeres_hg38.tsv"),
        sep="\t",
        usecols=["chrom"] + position_cols,
    )
    raw[position_cols] = raw[position_cols].astype(int)
    per_chrom = raw.groupby("chrom", as_index=False)
    return per_chrom.agg(
        centromere_start=pd.NamedAgg(column="chromStart", aggfunc="min"),
        centromere_end=pd.NamedAgg(column="chromEnd", aggfunc="max"),
    )


def _load_chromosomes(centromeres: Optional[pd.DataFrame] = None) -> pd.DataFrame:
    """Load chromosome extents, indexed by chromosome name.

    Reads ``hg38_scaffolds.tsv`` from the data directory, keeps the rows whose
    ``chrom`` is in `VALID_CHROMS`, and adds an integer ``chrom_int`` column
    by which the result is sorted. When a non-empty `centromeres` dataframe is
    given, its columns are left-merged in on ``chrom``.
    """
    position_cols = ["chromStart", "chromEnd"]
    scaffolds = pd.read_csv(
        _get_data_path("hg38_scaffolds.tsv"),
        sep="\t",
        usecols=["chrom"] + position_cols,
    )
    scaffolds[position_cols] = scaffolds[position_cols].astype(int)

    is_valid = scaffolds.chrom.isin(VALID_CHROMS)
    chroms = scaffolds.loc[is_valid].rename(columns={"chromStart": "start", "chromEnd": "end"})
    chroms["chrom_int"] = _chrom_int(chroms.chrom)

    # Centromere columns are optional
    has_centromeres = isinstance(centromeres, pd.DataFrame) and not centromeres.empty
    if has_centromeres:
        chroms = chroms.merge(centromeres, on="chrom", how="left")

    return chroms.set_index("chrom").sort_values("chrom_int", ascending=True)


def _load_bands() -> pd.DataFrame:
    """Load cytogenetic bands, one row per (chromosome, band name).

    Reads ``hg38_cytoband.tsv.gz`` from the data directory. For each band the
    result holds its smallest start, largest end, the first character of its
    name as ``band_chrom_arm``, and an integer ``chrom_int`` code.
    """
    raw = pd.read_csv(
        _get_data_path("hg38_cytoband.tsv.gz"),
        sep="\t",
        usecols=["name", "#chrom", "chromStart", "chromEnd"],
    )
    raw = raw.rename(columns={"#chrom": "chrom"})

    per_band = raw.groupby(["chrom", "name"], as_index=False)
    bands = per_band.agg(
        band_start=pd.NamedAgg(column="chromStart", aggfunc="min"),
        band_end=pd.NamedAgg(column="chromEnd", aggfunc="max"),
        band_chrom_arm=pd.NamedAgg(column="name", aggfunc=lambda names: names.str[:1].min()),
    )
    bands["chrom_int"] = _chrom_int(bands.chrom)
    return bands


def _load_genes(chromosomes: Optional[pd.DataFrame] = None) -> pd.DataFrame:
    """Load gene coordinates, indexed by gene symbol.

    Reads ``ncbirefseq_hg38.tsv.gz`` from the data directory, keeps rows on
    `VALID_CHROMS`, and collapses the rows of each gene to its smallest
    ``txStart`` and largest ``txEnd``. Genes named ``LOC<digits>...`` and genes
    that appear on more than one chromosome are removed. The result is sorted
    by chromosome, start and end.

    Parameters
    ----------
    chromosomes : pd.DataFrame, optional
        Chromosome table with ``chrom_int``, ``centromere_start`` and
        ``centromere_end`` columns. When given and non-empty, the columns
        ``chrom_arm_int`` (0/1), ``chrom_arm`` (``p``/``q``) and
        ``chrom_arm_name`` are added.

    Returns
    -------
    pd.DataFrame
        Gene table indexed by ``gene``.
    """
    genes = pd.read_csv(
        _get_data_path("ncbirefseq_hg38.tsv.gz"), sep="\t", usecols=["name2", "chrom", "txStart", "txEnd"]
    ).rename(columns={"name2": "gene"})

    # Copy so the new column is added to an independent frame, not a slice.
    genes = genes.loc[genes.chrom.isin(VALID_CHROMS)].copy()
    genes["chrom_int"] = _chrom_int(genes.chrom)
    genes = genes.groupby("gene", as_index=False).agg(
        start=("txStart", "min"),
        end=("txEnd", "max"),
        chrom_int=("chrom_int", "min"),
        chrom_count=("chrom", "nunique"),
        chrom=("chrom", "first"),
    )

    # Filter out (pseudo-)genes of unknown function and genes that show up on multiple chromosomes.
    # The pattern requires a digit after "LOC": the previous "^LOC*" matched any
    # symbol starting with "LO" and therefore also removed genes such as LOX.
    genes = genes.loc[~genes.gene.str.contains(r"^LOC\d", regex=True)]
    genes = genes.loc[genes.chrom_count == 1].drop(columns="chrom_count").set_index("gene")
    genes = genes.sort_values(["chrom_int", "start", "end"], ascending=True)

    # Use the middle of the centromere to decide which arm (0/1) a gene is on
    if isinstance(chromosomes, pd.DataFrame) and not chromosomes.empty:
        by_chrom_int = chromosomes.copy().set_index("chrom_int")
        centromere_mid = ((by_chrom_int.centromere_start + by_chrom_int.centromere_end) / 2).to_dict()
        # A gene is on arm 1 when its end lies past the centromere midpoint.
        past_centromere = genes["end"] > genes["chrom_int"].map(centromere_mid)
        genes["chrom_arm_int"] = past_centromere.astype(int)

        # NOTE: Assumes that p is the first arm
        genes["chrom_arm"] = genes["chrom_arm_int"].apply(lambda x: "p" if x == 0 else "q")
        genes["chrom_arm_name"] = genes["chrom"] + genes["chrom_arm"]
    return genes


def get_chromosome_info_as_dfs() -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Load the genome annotation tables used for chromosome-ordered analyses:
        - Genes, with start and end and the chromosome arm they are on
        - Chromosomes, with the centromere start and end genomic coordinates
        - Cytogenetic bands, with their name and start and end genomic coordinates

    Returns
    -------
    Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]
        Genes, chromosomes, cytogenetic bands
    """

    # TODO: refactor into more composable functions
    band_df = _load_bands()
    centromere_df = _load_centromeres()
    chrom_df = _load_chromosomes(centromeres=centromere_df)
    gene_df = _load_genes(chromosomes=chrom_df)

    return gene_df, chrom_df, band_df

def get_chromosome_info_as_dicts(
    legacy_bands: bool = False,
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    """
    Dictionary version of `get_chromosome_info_as_dfs`, kept for compatibility
    with legacy notebooks.

    Parameters
    ----------
    legacy_bands : bool, optional
        When True, the band dictionary maps each region name (chromosome name
        followed by band name) to a ``(chrom, band_start, band_end)`` tuple.
        When False (default), it maps each row label of the band table to a
        dictionary of that row's columns.

    Returns
    -------
    Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]
        Dict corresponding to genes, chromosomes, and cytogenetic bands
    """
    gene_df, chrom_df, band_df = get_chromosome_info_as_dfs()

    # Composite "chromosome + arm" key, e.g. for grouping genes by arm
    gene_df["arm"] = gene_df["chrom"] + gene_df["chrom_arm"]
    genes_by_symbol = gene_df.to_dict(orient="index")

    chroms_by_name = chrom_df.to_dict(orient="index")

    # Composite "chromosome + band name" key
    band_df["region"] = band_df["chrom"] + band_df["name"]
    if not legacy_bands:
        bands = band_df.to_dict(orient="index")
    else:
        legacy_view = band_df[["region", "chrom", "band_start", "band_end"]].set_index("region")
        bands = {str(region): tuple(row) for region, row in legacy_view.iterrows()}
    return genes_by_symbol, chroms_by_name, bands


In [None]:
gene_dict, chrom_dict, band_dict = get_chromosome_info_as_dicts()
genes = list(gene_dict.keys())

In [None]:
# Number of feature columns, recorded before `gene` becomes a regular column.
n_feats = aggregated_tvn_data.shape[1]
aggregated_tvn_data = aggregated_tvn_data.reset_index()
# Rows whose gene has an hg38 annotation. A boolean mask is used instead of
# interpolating the whole gene list into a query string, which is slow to parse
# and breaks if a symbol contains a quote character.
idx = aggregated_tvn_data.index[aggregated_tvn_data["gene"].isin(genes)]
print(f'Full data has {aggregated_tvn_data.shape[0]} genes, {len(idx)} of which are in hg38 annotations')
data_t = aggregated_tvn_data.rename({'Metadata_Symbol': 'gene'}, axis=1)
data_t = data_t.loc[idx].reset_index(drop=True)

# Add in chromosome information
data_t['chromosome'] = data_t.gene.apply(lambda x: gene_dict[x]['chrom'] if x in gene_dict else "no info" )
data_t['chr_idx'] = data_t.gene.apply(lambda x: gene_dict[x]['chrom_int'] if x in gene_dict else "no info" )
data_t['chromosome_arm'] = data_t.gene.apply(lambda x: gene_dict[x]['arm'] if x in gene_dict else "no info" )
data_t['gene_bp'] = data_t.gene.apply(lambda x: gene_dict[x]["start"] if x in gene_dict else "no info" )

In [None]:
data_t.head()

In [None]:
# Keep the annotation columns followed by the integer-labelled feature columns.
# NOTE: this rebinds the notebook-level name `cols`, which an earlier cell set
# to `raw_well_data.columns`; code that reads that global after this cell has
# run sees this list instead.
cols = ['gene', 'chromosome', 'chr_idx', 'chromosome_arm', 'gene_bp',] + list(range(0,n_feats))
data_t = data_t.loc[:, cols]
# Order genes by chromosome code, then by start position within the chromosome.
data_t = data_t.sort_values(['chr_idx', 'gene_bp']).reset_index(drop=True)
# Move all annotations into the index so that only feature columns remain.
data_t = data_t.set_index(['gene', 'chromosome', 'chr_idx', 'chromosome_arm', 'gene_bp',])

In [None]:
def crunch_square_df(
    sims: pd.DataFrame,
    crunch_factor: int,
) -> pd.DataFrame:
    """
    Shrink the `sims` dataframe by `crunch_factor` along both axes so that it
    can be visualized at a reasonable size. Each `crunch_factor` X
    `crunch_factor` block of values is replaced by its mean, and each block is
    labelled with the index value of its first row.

    NOTE(review): when the size of `sims` is not a multiple of `crunch_factor`
    the trailing blocks are padded by `block_reduce` before averaging (with its
    default fill value, 0 per the scikit-image documentation) -- confirm this
    is acceptable for the edge blocks.

    Inputs:
    -------
    - sims: pd.DataFrame() with matching row and column indices
    - crunch_factor: int to compress the data by.
    """

    # True at positions 0, crunch_factor, 2 * crunch_factor, ...
    starts_block = [position % crunch_factor == 0 for position in range(len(sims.index))]
    block_labels = sims.index[starts_block]  # type: ignore
    block_means = block_reduce(sims.values, (crunch_factor, crunch_factor), np.mean)

    return pd.DataFrame(block_means, index=block_labels, columns=block_labels)


def plot_heatmap(
    sims: pd.DataFrame,
    f_name: Optional[str] = None,
    format: str = "png",
    crunch_factor: int = 1,
    show_chr_lines: bool = True,
    show_cent_lines: bool = True,
    show_chroms: bool = True,
    show_chrom_arms: bool = False,
    figsize: tuple = (20, 20),
    title: Optional[str] = None,
    label_locy: Optional[float] = None,
    lab_s: int = 12,
    drop_chry: bool = True,
    lw: float = 0.5,
    lab_rot: int = 0,
):
    """
    Plotting function for heatmaps (full-genome or subsets) can be sorted/clustered or ordered by chromosome

    Inputs:
    - sims: Square dataframe with matching row and column indices and similarity values. This can be "split" along
            the diagonal to show different datasets. Index should include `chromosome` and `chromosome_arm` if
            ordering by genomic position. Each row/column should represent one gene ordered by genomic position
    - f_name: file name; when given, the finished figure is written to this path
    - format: file format for saving a file (only used when `f_name` is given)
    - crunch_factor: if > 1 will apply a average smoothing to reduce the size of the output file
    - show_chr_lines: whether to show lines at chromosome boundaries
    - show_cent_lines: whether to show lines at centromeres
    - show_chroms: Whether to label chromosomes on top and right
    - show_chrom_arms: whether to label chromosome arms on top and right
    - figsize: size of the resulting figure
    - title: plot title
    - label_locy: location of labels in y
    - lab_s: Font size of labels
    - drop_chry: Whether to remove Chromosome Y values
    - lw: line width
    - lab_rot: rotation of labels

    Values are mapped to colors on a fixed [-1, 1] scale with the "RdBu_r" colormap. The figure is left as the
    current matplotlib figure; nothing is returned.
    """
    color_norm = mpl.colors.Normalize(vmin=-1, vmax=1)
    cmap = mpl.colormaps["RdBu_r"]

    # Remove chrY before anything is rendered, so the image and the annotations
    # computed below are derived from exactly the same rows/columns.
    if drop_chry and "chromosome" in sims.index.names:
        noy_idx = sims.index.get_level_values("chromosome") != "chrY"
        sims = sims.loc[noy_idx, noy_idx]  # type: ignore

    if crunch_factor > 1:
        # Downsample the data to make the file size less crazy: every
        # `crunch_factor` x `crunch_factor` block is replaced by its mean.
        sims = crunch_square_df(sims, crunch_factor=crunch_factor)

    plt.figure(figsize=figsize)
    plt.imshow(cmap(color_norm(sims.values)))

    needs_chr_pos = show_chr_lines or show_chroms
    needs_arm_pos = show_cent_lines or show_chrom_arms
    if needs_chr_pos or needs_arm_pos:
        # Row position of every gene, alongside its index levels
        index_df = sims.index.to_frame(index=False).reset_index().rename({"index": "pos"}, axis=1)
    if needs_chr_pos:
        # Last row position of each chromosome, in plotting order
        chr_pos = index_df.groupby("chromosome").pos.max().sort_values()
        # Midpoint between the previous boundary (0 for the first) and this one
        chr_starts = np.insert(chr_pos.values[:-1], 0, 0)
        chr_mids = dict(zip(chr_pos.index, (chr_starts + chr_pos.values) / 2))
    if needs_arm_pos:
        # Last row position of each chromosome arm, in plotting order
        cent_pos = index_df.groupby("chromosome_arm").pos.max().sort_values()
        # Filter to just p-arms: the end of a p-arm marks the centromere
        cent_pos_p = cent_pos[[x[-1] == "p" for x in cent_pos.index]]
        cent_starts = np.insert(cent_pos.values[:-1], 0, 0)
        cent_mids = dict(zip(cent_pos.index, (cent_starts + cent_pos.values) / 2))

    # Hide X and Y axes labels and tick marks
    ax = plt.gca()
    ax.xaxis.set_tick_params(labelbottom=False)
    ax.yaxis.set_tick_params(labelleft=False)
    ax.set_xticks([])
    ax.set_yticks([])

    # Capture the image extent before drawing lines across it
    xm, xM = ax.get_xlim()
    ym, yM = ax.get_ylim()

    if show_chr_lines:
        for x in chr_pos.values:
            # + 0.5 puts the line on the edge between two cells
            plt.plot([x + 0.5, x + 0.5], [ym, yM], color="k", ls="-", lw=lw)
            plt.plot([xm, xM], [x + 0.5, x + 0.5], color="k", ls="-", lw=lw)
    if show_cent_lines:
        for x in cent_pos_p.values:
            plt.plot([x + 0.5, x + 0.5], [ym, yM], color="k", ls=":", lw=lw)
            plt.plot([xm, xM], [x + 0.5, x + 0.5], color="k", ls=":", lw=lw)

    s = sims.shape[0]
    if label_locy is None and (show_chroms or show_chrom_arms):
        # Default: just above the top edge of the image
        label_locy = -0.008 * s

    if show_chroms:
        # Label chromosomes on top/right to not clash with coords
        for ch, mid in chr_mids.items():
            short_name = ch.replace("chr", "")
            # Labels on top
            ax.text(mid, label_locy, short_name, ha="center", va="bottom", rotation=lab_rot, size=lab_s)
            # Labels on the right
            ax.text(s + 0.008 * s, mid, short_name, ha="left", va="center", size=lab_s)

    if show_chrom_arms:
        # Label chromosome arms on top/right
        for cent, mid in cent_mids.items():
            # Labels across the top
            ax.text(mid, label_locy, cent, ha="left", rotation=lab_rot, size=lab_s)
            # Labels on the right
            ax.text(s + 0.001 * s, mid, cent, ha="left", size=lab_s)

    plt.title(title, size="xx-large", y=1.1)
    plt.gcf().set_facecolor("white")

    if f_name is not None:
        # "tight" keeps the labels drawn outside the axes inside the saved file
        plt.savefig(f_name, format=format, bbox_inches="tight")

In [None]:
# Whole-genome cosine-similarity heatmap of the map before arm centering,
# averaged in 10 x 10 blocks.
plot_heatmap(
    sims=make_pairwise_cos(data_t),
    f_name='data/cpg0016_split_prenorm.svg',
    format='svg',
    crunch_factor=10,
    title='JUMP Whole Genome Heatmap',
)

### Using Arm Subtraction to Correct for Proximity Bias
We add an additional alignment step to our pipeline in which, for each transformed embedding, we subtract the mean embedding of its chromosome arm. The arm means are computed from genes with low expression in U2OS cells (zFPKM < -3 by default).

In [None]:
# Load U2OS expression data and collapse it to one median zFPKM value per gene
u2os_exp = (
    pd.read_csv('data/u2os.csv')
    .groupby('gene', as_index=False)
    .zfpkm.median()
)

In [None]:
def build_arm_centering_df(
    data: pd.DataFrame,
    metadata_cols: List[str],
    arm_column="chromosome_arm",
    subset_query="zfpkm <-3",
    min_num_gene=20,
) -> pd.DataFrame:
    """Compute the per-arm mean feature vector used for arm centering

    Parameters
    ----------
    data : pd.DataFrame
        One row per gene, holding both metadata and feature columns
    metadata_cols : List[str]
        Names of the non-feature columns; `arm_column` is added to them
        internally when it is not already listed (the argument is not modified)
    arm_column : str, optional
        Column holding the chromosome-arm label, by default "chromosome_arm"
    subset_query : str, optional
        `DataFrame.query` expression selecting the rows that contribute to the
        means, by default "zfpkm <-3"
    min_num_gene : int, optional
        Row-count threshold per arm. An arm is returned only when the selected
        subset holds strictly more than this many rows for it, by default 20

    Returns
    -------
    pd.DataFrame
        Indexed by arm label, one column per feature, containing the mean of
        each feature over the selected rows of that arm
    """
    selected = data.query(subset_query)

    # Everything that is not a feature, always including the arm label itself.
    non_feature_cols = metadata_cols if arm_column in metadata_cols else metadata_cols + [arm_column]

    arm_labels = selected[arm_column]
    arm_means = selected.drop(non_feature_cols, axis="columns").groupby(arm_labels).mean()

    # Number of selected rows per arm, used to discard sparsely populated arms.
    rows_per_arm = selected.groupby(arm_column)[non_feature_cols[0]].size()
    return arm_means[rows_per_arm > min_num_gene]


def perform_arm_centering(
    data: pd.DataFrame,
    metadata_cols: List[str],
    arm_centering_df: pd.DataFrame,
    arm_column: str = "chromosome_arm",
) -> pd.DataFrame:
    """Subtract the per-arm centering vector from the features of every row

    Parameters
    ----------
    data : pd.DataFrame
        Metadata and feature columns, one row per perturbation
    metadata_cols : List[str]
        Names of the metadata columns; must contain `arm_column`
    arm_centering_df : pd.DataFrame
        Centering vectors indexed by arm label, one column per feature
    arm_column : str, optional
        Metadata column holding the chromosome-arm label, by default
        "chromosome_arm"

    Returns
    -------
    pd.DataFrame
        The metadata columns joined with the centered features. Rows whose arm
        has no entry in `arm_centering_df` keep their original feature values.
        `data` itself is not modified.
    """
    meta_part = data[metadata_cols]
    centered = data.drop(metadata_cols, axis="columns")
    arm_labels = meta_part[arm_column]

    for arm_name in arm_centering_df.index:
        # Rows located on this arm
        on_arm = arm_labels == arm_name
        centered[on_arm] = centered[on_arm] - arm_centering_df.loc[arm_name]

    return meta_part.join(centered)

In [None]:
# Move the positional metadata back into columns and attach the per-gene
# expression value; genes without an expression entry are kept (left join).
data = (
    data_t
    .reset_index()
    .merge(u2os_exp, how='left', left_on='gene', right_on='gene')
)

In [None]:
# Per-arm mean feature vectors, using the default gene subset and threshold
arm_metadata_cols = ['gene', 'chromosome', 'chr_idx', 'chromosome_arm', 'gene_bp', 'zfpkm']
arm_centering_df = build_arm_centering_df(data, metadata_cols=arm_metadata_cols)

In [None]:
# Subtract each arm's centering vector from the features of the genes on that arm
arm_metadata_cols = ['gene', 'chromosome', 'chr_idx', 'chromosome_arm', 'gene_bp', 'zfpkm']
arm_centered_data = perform_arm_centering(
    data,
    metadata_cols=arm_metadata_cols,
    arm_centering_df=arm_centering_df,
)

In [None]:
# Whole-genome cosine-similarity heatmap of the arm-centered map, averaged in
# 10 x 10 blocks. All metadata columns go into the index so that only features
# enter the similarity computation.
arm_metadata_cols = ['gene', 'chromosome', 'chr_idx', 'chromosome_arm', 'gene_bp', 'zfpkm']
plot_heatmap(
    make_pairwise_cos(arm_centered_data.set_index(arm_metadata_cols)),
    f_name='data/cpg0016_arm_centered.svg',
    format='svg',
    crunch_factor=10,
    title='JUMP Whole Genome Arm Centered Heatmap',
)

### Recomputing Metrics

In [None]:
# Split the arm-centered table into a metadata / features bunch, register it
# next to the other maps and run the benchmark on it.
arm_metadata_cols = ['gene', 'chromosome', 'chr_idx', 'chromosome_arm', 'gene_bp', 'zfpkm']
metadata = arm_centered_data[arm_metadata_cols]
features = arm_centered_data.drop(columns=arm_metadata_cols)

arm_centered_bunch = Bunch(metadata=metadata, features=features)
data_dict['arm_centered_EFAAR'] = arm_centered_bunch
arm_centered_results = bm(arm_centered_bunch, pert_label_col='gene', run_count=1)

In [None]:
# Tag the arm-centered metrics with their map name and append them to the
# results collected for the earlier maps.
arm_results_df = get_benchmark_metrics(arm_centered_results)
arm_results_df['map_version'] = 'JUMP_arm_centered_EFAAR'
benchmark_results_df = pd.concat([benchmark_results_df, arm_results_df])

# Recall per benchmark source, one bar per map version, with a dashed red
# reference line at 0.1.
recall_ax = sns.barplot(data=benchmark_results_df, x='source', y='recall', hue='map_version')
recall_ax.axhline(y=0.1, color='red', linestyle='--')
recall_ax.set_title('Recall at 5th and 95th Percentiles')
plt.show()

### Arm Stratified Metrics
We also compute the multi-perturbation metrics stratified by chromosome arm, because some known relationships may have been recalled only because the two genes happen to lie on the same chromosome arm.

In [None]:
from efaar_benchmarking.constants import RANDOM_SEED, BENCHMARK_SOURCES, N_NULL_SAMPLES
from efaar_benchmarking.utils import (
    generate_null_cossims,
    generate_query_cossims,
    get_benchmark_data,
    get_feats_w_indices,
)

def compute_within_cross_arm_pairwise_metrics(
    data: Bunch,
    pert_label_col: str = "gene",
    pct_thresholds: list = [0.05, 0.95],
) -> tuple:
    """Compute known biology benchmarks stratified by whether the pairs of genes
    are on the same chromosome arm or not.

    For every source in `BENCHMARK_SOURCES`, the annotated gene pairs are split
    into pairs sharing a chromosome arm and pairs on different arms, and recall
    against one shared null distribution is computed for each group. Pairs in
    which either gene has no chromosome annotation are ignored.

    Parameters
    ----------
    data : Bunch
        Metadata-features bunch
    pert_label_col : str, optional
        Column in the metadata that defines the perturbation, by default "gene"
    pct_thresholds : list, optional
        Percentile thresholds for the recall computation, by default [0.05, 0.95].
        The list is only read, never modified.

    Returns
    -------
    tuple
        Results for within-arm and cross-arm pairs, respectively. Each element
        is a dict mapping the benchmark source to its recall value.

    Notes
    -----
    Re-seeds NumPy's global random state with `RANDOM_SEED`.
    """
    # Neither the chromosome annotations nor the feature matrix depend on the
    # benchmark source, so they are built once rather than once per source.
    gene_dict, _, _ = get_chromosome_info_as_dicts()
    feats = get_feats_w_indices(data, pert_label_col)

    def _arm_of(gene):
        # Arm label of `gene`, or the "no info" marker filtered out below
        return gene_dict[gene]["arm"] if gene in gene_dict else "no info"

    np.random.seed(RANDOM_SEED)
    within = {}
    between = {}
    for source in BENCHMARK_SOURCES:
        random_seed_pair = np.random.randint(2**32, size=2)

        # `assign` returns a new frame, leaving the object handed out by
        # `get_benchmark_data` untouched.
        gt_data = get_benchmark_data(source)
        gt_data = gt_data.assign(
            entity1_chrom=gt_data.entity1.apply(_arm_of),
            entity2_chrom=gt_data.entity2.apply(_arm_of),
        )
        gt_data = gt_data.query("entity1_chrom != 'no info' and entity2_chrom != 'no info'")

        df_gg_null = generate_null_cossims(
            feats,
            feats,
            rseed_entity1=random_seed_pair[0],
            rseed_entity2=random_seed_pair[1],
            n_entity1=N_NULL_SAMPLES,
            n_entity2=N_NULL_SAMPLES,
        )

        within_gt_subset = gt_data.query("entity1_chrom == entity2_chrom")
        between_gt_subset = gt_data.query("entity1_chrom != entity2_chrom")

        df_gg_within = generate_query_cossims(feats, feats, within_gt_subset)
        df_gg_between = generate_query_cossims(feats, feats, between_gt_subset)

        # Both strata are scored against the same null distribution
        within[source] = _compute_recall(df_gg_null, df_gg_within, pct_thresholds)
        between[source] = _compute_recall(df_gg_null, df_gg_between, pct_thresholds)
    return within, between

def _compute_recall(null_cossims, query_cossims, pct_thresholds) -> float:
    """Fraction of query similarities that fall in the tails of the null distribution

    Each query value is converted to a percentile, i.e. the fraction of null
    values strictly below it. A query counts as recalled when its percentile is
    at or below the smallest threshold or at or above the largest one.

    Parameters
    ----------
    null_cossims : array-like
        Similarities forming the null distribution (need not be sorted)
    query_cossims : array-like
        Similarities of the known pairs being evaluated
    pct_thresholds : sequence of float
        Percentile thresholds; only the minimum and maximum are used

    Returns
    -------
    float
        Recall in [0, 1], or NaN when `query_cossims` is empty (recall is
        undefined without any query pair)
    """
    n_queries = len(query_cossims)
    if n_queries == 0:
        # e.g. a benchmark source with no pair in the requested stratum
        return float("nan")

    null_sorted = np.sort(null_cossims)
    percentiles = np.searchsorted(null_sorted, query_cossims) / len(null_sorted)
    in_tails = (percentiles <= np.min(pct_thresholds)) | (percentiles >= np.max(pct_thresholds))
    return float(np.count_nonzero(in_tails) / n_queries)

In [None]:
# Within-arm / between-arm recall for every map collected in `data_dict`
arm_stratified_results = {
    map_name: compute_within_cross_arm_pairwise_metrics(bunch)
    for map_name, bunch in data_dict.items()
}

# Flatten to one (map, stratum, source, recall) record per value. Each result
# is a (within, between) pair of dicts keyed by benchmark source.
result_records = [
    (map_label, stratum, source, recall)
    for map_label, per_stratum in arm_stratified_results.items()
    for stratum, per_source in zip(("within arm", "between arms"), per_stratum)
    for source, recall in per_source.items()
]

stratified_results_df = pd.DataFrame.from_records(
    result_records,
    columns=["Map Name", "arms", "source", "recall"],
)


In [None]:
# `catplot` is figure-level and draws one facet per benchmark source. The overall
# title is therefore set on the figure: `plt.title` would only replace the title
# of the last facet.
sns.catplot(data=stratified_results_df, x="arms", y="recall", hue="Map Name", col="source", kind="bar")
plt.suptitle("Arm Stratified Benchmarks", y=1.05)
plt.show()