In [1]:
from pathlib import Path
import sys
import json

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
import torch

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns
sns.set_theme(font_scale=1.5, palette='Set2')
sns.set_style('whitegrid')
from tqdm import tqdm

%load_ext autoreload
%autoreload 2

sys.path.insert(0, '..')
from dsdna_mpra import config, plots


ENCODE data annotation and preprocessing are performed in [`cre_classifier_dataset_preparation.py`](../scripts/cre_classifier_dataset_preparation.py).

The CRE classifier model is trained and saved using [`cre_classifier_training.py`](../scripts/cre_classifier_training.py).

Predictions of CRE classes for the ENCODE test set and viral tiles are generated using [`cre_classifier_inference.py`](../scripts/cre_classifier_inference.py).


### Performance on the ENCODE CRE Test Set


In [2]:
# Per-sequence classifier predictions for the ENCODE CRE test set.
test_predictions = pd.read_csv(config.RESULTS_DIR / 'encode_validation_classification.csv')

# Encode predicted and true class names as their integer position in
# config.ENCODE_CRE_TYPES.
cre_type_to_index = config.ENCODE_CRE_TYPES.index
class_pred = list(map(cre_type_to_index, test_predictions.predicted_class))
class_gt = list(map(cre_type_to_index, test_predictions.real_class))


In [3]:
# Draw the test-set confusion matrix three ways (raw counts, normalised over
# predictions, normalised over true labels); each variant is saved as a PDF
# figure and as a CSV table.
n_classes = len(config.ENCODE_CRE_TYPES)
class_ticks = np.arange(n_classes)
for normalization_type, figure_name in [
    [None, 'fig_S3_confusion_matrix'],
    ['pred', 'fig_S3B_cre_class_precision'],
    ['true', 'fig_S3C_cre_class_recall'],
]:
    # Pass `labels` explicitly so the matrix is always n_classes x n_classes.
    # Without it, a class that never occurs in class_gt/class_pred is dropped
    # and the tick labels and CSV columns below no longer line up.
    pred_matrix = confusion_matrix(
        class_gt, class_pred, labels=class_ticks, normalize=normalization_type
    )
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_title(f"CRE Classifier ('{normalization_type}'-normalization)", fontsize=17)
    if normalization_type is not None:
        # Normalised matrices hold fractions, so use a fixed [0, 1] colour scale.
        img = ax.imshow(pred_matrix, cmap="Reds", vmax=1, vmin=0)
    else:
        img = ax.imshow(pred_matrix, cmap="Reds")
    ax.set_xticks(class_ticks, config.ENCODE_CRE_TYPES, rotation=90, fontsize=12)
    ax.set_yticks(class_ticks, config.ENCODE_CRE_TYPES, fontsize=12)
    ax.set_ylabel('True', fontsize=15)
    ax.set_xlabel('Predicted', fontsize=15)
    ax.grid(False)
    cbar_ax = fig.add_axes([.92, 0.2, 0.02, 0.6])
    cbar = fig.colorbar(img, cax=cbar_ax)
    cbar.set_label('Number / Fraction of ENCODE CREs', fontsize=15)
    fig.savefig(config.FIGURES_DIR / f"{figure_name}.pdf", format="pdf", bbox_inches="tight")
    # Export the same matrix as a table: rows are true classes, columns predicted.
    pred_matrix = pd.DataFrame(pred_matrix, columns=config.ENCODE_CRE_TYPES)
    pred_matrix.insert(0, 'CRE-type', config.ENCODE_CRE_TYPES)
    pred_matrix.to_csv(config.RESULTS_DIR / f"{figure_name}.csv", index=False)
    # Close this specific figure so the loop does not accumulate open figures.
    plt.close(fig)


### Classifier predictions for viral tiles active in K562 cells


Proportion of each predicted class in viral CREs and in the ENCODE test set of real CRE sequences.


In [4]:
def _add_short_class(predictions):
    """Add a categorical `predicted_class_short` column to `predictions` in place.

    Any predicted class whose name contains 'shuffled' is collapsed into
    'Undetermined'; other names are kept as they are. The column is then cast
    to a categorical with categories config.ENCODE_CRE_TYPES_SHORT, so names
    outside that list become NaN.
    """
    short_class = np.where(
        predictions.predicted_class.str.contains('shuffled'),
        'Undetermined', predictions.predicted_class,
    )
    predictions['predicted_class_short'] = (
        pd.Series(short_class, index=predictions.index)
        .astype('category')
        .cat.set_categories(config.ENCODE_CRE_TYPES_SHORT)
    )


_add_short_class(test_predictions)
# Keep only test sequences whose true class is not a shuffled control.
real_test_predictions = test_predictions[~test_predictions['real_class'].str.contains('shuffled')]

virus_predictions = pd.read_csv(config.RESULTS_DIR / 'k562_active_tiles_classification.csv')
virus_predictions['virus'] = (
    (virus_predictions.family + ', ' + virus_predictions.strain)
    .astype('category')
    .cat.set_categories(config.VIRUSES)
)
_add_short_class(virus_predictions)

# Tiles per (virus, predicted class), pivoted to classes x viruses.
# The count column is named explicitly so the pivot does not depend on the
# pandas-version-specific default name of the value_counts() result.
class_counts = (
    virus_predictions.value_counts(['virus', 'predicted_class_short'])
    .to_frame(name='count')
    .reset_index()
    .pivot(columns='virus', index='predicted_class_short', values='count')
    .fillna(0)
)
class_counts['ENCODE test set'] = real_test_predictions.value_counts('predicted_class_short', sort=False)
# Normalise every column to fractions before exporting.
class_counts.div(class_counts.sum()).to_csv(config.RESULTS_DIR / "fig_2D_cre_classifier_class_fractions.csv")


MPRA activity levels for CREs grouped by predicted class.


In [5]:
# Paired-tile table restricted to tiles that have a classifier prediction,
# annotated with the predicted (short) CRE class.
paired_tiles = pd.read_csv(config.RESULTS_DIR / "virus_paired_tiles_cds_overlap.csv")
paired_tiles = paired_tiles[paired_tiles.tile_id.isin(virus_predictions.tile_id)].reset_index(drop=True)
# De-duplicate the lookup on tile_id first (as the other merges against
# virus_predictions in this notebook do); otherwise a tile id that occurs more
# than once in virus_predictions would multiply the paired_tiles rows.
tile_classes = virus_predictions.drop_duplicates('tile_id')[['tile_id', 'predicted_class_short']]
paired_tiles = paired_tiles.merge(tile_classes, on='tile_id', how='left')


In [6]:
thresholds = pd.read_csv(config.RESULTS_DIR / 'thresholds_log2_1p.csv')

# Horizontal distance between the violins of neighbouring CRE classes.
VIOLIN_SPACING = 7
class_positions = [
    class_index * VIOLIN_SPACING
    for class_index in range(len(config.ENCODE_CRE_TYPES_SHORT))
]

# One panel per cell line (2 x 3 grid); within a panel, one violin per
# predicted CRE class showing the larger of the forward/reverse activity.
fig, axes = plt.subplots(figsize=(20, 14), nrows=2, ncols=3, layout="tight")
for cell_index, cell in enumerate(config.CELL_LINES[::-1]):
    row, col = divmod(cell_index, 3)
    ax = axes[row, col]
    cell_tiles = paired_tiles[paired_tiles.cell == cell]
    for pred_class, class_df in cell_tiles.groupby('predicted_class_short', observed=False):
        plots.violin(
            ax, class_df[['fwd_lfc', 'rev_lfc']].max(1),
            config.ENCODE_CRE_TYPES_SHORT.index(pred_class) * VIOLIN_SPACING,
            width_factor=3, box_width=.4, text=False
        )
    # Activity threshold of this cell line.
    ax.axhline(thresholds[thresholds.cell == cell].threshold.iloc[0], linestyle='--', color='red')
    ax.set_ylim([0, 8])
    ax.grid(False)
    ax.set_title(cell)
    if col == 0:
        ax.set_ylabel(r'$\log_2 (\frac{RNA}{DNA} + 1)$')
    if row == 1:
        # Set tick positions together with the labels so each class name sits
        # under its violin; set_xticklabels alone labels whatever ticks the
        # automatic locator happened to choose.
        ax.set_xticks(class_positions, config.ENCODE_CRE_TYPES_SHORT, rotation=90)
fig.savefig(config.FIGURES_DIR / 'fig_2E_cre_classes_activity_by_cell_line.pdf', format="pdf", bbox_inches="tight")
plt.close(fig)


MPRA activity of CREs in each class across different viruses.


In [7]:
n_families = len(config.DSDNA_FAMILIES)
# Row order of the heatmap grid: it follows the category order given to
# class_activity['family'] below, which is what groupby iterates over.
family_order = config.DSDNA_FAMILIES[::-1]
# Number of distinct strains per family, used as relative row heights.
# value_counts() sorts by count, so reindex to the row order above; a family
# absent from paired_tiles gets a nominal height of 1 so that there is always
# one ratio per grid row.
n_strains_per_family = (
    paired_tiles.drop_duplicates(['family', 'strain'])['family']
    .value_counts()
    .reindex(family_order, fill_value=1)
)
# Median cell_rank per (family, strain, cell, predicted class). Select the
# column before aggregating: `.median('cell_rank')` would pass the string as
# the `numeric_only` flag rather than as a column selector.
class_activity = (
    paired_tiles
    .groupby(['family', 'strain', 'cell', 'predicted_class_short'], observed=True)['cell_rank']
    .median()
    .reset_index()
)
class_activity['virus'] = (
    (class_activity.family + ', ' + class_activity.strain)
    .astype('category')
    .cat.set_categories(config.VIRUSES)
)
class_activity['family'] = class_activity['family'].astype('category').cat.set_categories(family_order)

# 'Reds' colormap with every RGBA channel pulled 10% towards 1 (a lighter
# palette); missing values are drawn in light grey.
orig_cmap = plt.get_cmap('Reds')
mild_cmap = orig_cmap(np.linspace(0, 1, 256))
mild_cmap = 0.9 * mild_cmap + 0.1  # blend with white to desaturate
mild_cmap = np.clip(mild_cmap, 0, 1)
mild_cmap = ListedColormap(mild_cmap)
mild_cmap.set_bad(color='lightgrey')

fig, axes = plt.subplots(
    figsize=(23, 15), nrows=n_families, ncols=len(config.CELL_LINES),
    height_ratios=n_strains_per_family.to_numpy(),
)
for cell_idx, cell in enumerate(config.CELL_LINES):
    # viruses x CRE classes table of median ranks for this cell line
    cell_activity = class_activity[class_activity.cell == cell].pivot_table(
        values='cell_rank', index=['family', 'virus', 'strain'],
        columns='predicted_class_short', observed=True,
    ).reset_index()
    for fam_idx, (family, family_activity) in enumerate(cell_activity.groupby('family', observed=False)):
        ax = axes[fam_idx, cell_idx]
        # Only the top row carries the cell-line title.
        title_args = {'label': f"{cell}", 'fontsize': 15} if fam_idx == 0 else None
        img = plots.heatmap_with_stats(
            ax, family_activity[['virus'] + config.ENCODE_CRE_TYPES_SHORT].set_index('virus'),
            imshow_args={'cmap': mild_cmap, 'vmin': .75, 'vmax': 1, 'norm': None},
            title_args=title_args, text_fontsize=12
        )
        # Show x tick labels only on the top row and y tick labels only in
        # the first column.
        if fam_idx != 0:
            ax.tick_params(axis='x', which='both', top=False, labeltop=False)
        if cell_idx != 0:
            ax.tick_params(axis='y', which='both', left=False, labelleft=False)
fig.subplots_adjust(right=1.2, wspace=.1)
# Shared colorbar, built from the last heatmap drawn (all share vmin/vmax).
cbar_ax = fig.add_axes([1.25, 0.15, 0.02, 0.7])
cbar = fig.colorbar(img, cax=cbar_ax)
cbar.set_label('Cell activity rank')
fig.savefig(config.FIGURES_DIR / 'fig_S3G_cre_classes_activity_by_cell_line_by_virus.pdf', format="pdf", bbox_inches="tight")
plt.close(fig)


Proportions of predicted classes within coding and non-coding regions.


In [8]:
# Number of viral tiles per (is_cds, predicted class), pivoted to
# classes x is_cds. The count column is named explicitly so the pivot does not
# depend on the pandas-version-specific default name of value_counts().
class_cds_counts = (
    virus_predictions.value_counts(['is_cds', 'predicted_class_short'])
    .to_frame(name='count')
    .reset_index()
    .pivot(columns='is_cds', index='predicted_class_short', values='count')
)
class_cds_counts.to_csv(config.RESULTS_DIR / "fig_3H_cre_classifier_in_cds_counts.csv")
# Same table with every is_cds column normalised to fractions.
class_cds_counts.div(class_cds_counts.sum()).to_csv(config.RESULTS_DIR / "fig_3H_cre_classifier_in_cds_fractions.csv")


Proportions of CREs located in coding sequences for each CRE class across different viruses.


In [9]:
# Long-format tile counts per (virus, predicted class, is_cds).
class_cds_counts = (
    virus_predictions
    .value_counts(['virus', 'predicted_class_short', 'is_cds'])
    .to_frame(name='count')
    .reset_index()
)
# Rows: (virus, is_cds); columns: predicted class.
pivoted_counts = class_cds_counts.pivot_table(
    index=['virus', 'is_cds'],
    columns='predicted_class_short',
    values='count',
    fill_value=0,
    observed=False,
)
# Divide each (virus, is_cds) row by its total to get class fractions.
group_totals = pivoted_counts.sum(axis=1)
class_fractions_within_cds = pivoted_counts.div(group_totals, axis=0)
# Move is_cds into the columns: one (class, is_cds) column pair per class.
fractions_table = class_fractions_within_cds.unstack(level='is_cds')
cols = fractions_table.columns
# Order columns by class (config order), True before False, keeping only
# the pairs that actually exist in the table.
new_cols = []
for cre_class in config.ENCODE_CRE_TYPES_SHORT:
    for in_cds in (True, False):
        if (cre_class, in_cds) in cols:
            new_cols.append((cre_class, in_cds))
fractions_table = fractions_table[new_cols]
fractions_table.to_csv(config.RESULTS_DIR / "fig_3I_fractions_of_classes_within_cds.csv")


Distribution of the number of transcription factor (TF) motif instances in each CRE class.


In [10]:
# Read the list of per-tile motif records, then index it by tile id.
motif_map_path = config.RESULTS_DIR / "malinois_K562_tf_motif_map.json"
with open(motif_map_path, 'r', encoding='utf-8') as f:
    tile_motif_records = json.load(f)
tile_motif_map = {record['tile_id']: record for record in tile_motif_records}


In [11]:
# For every predicted class: array with the number of entries in
# 'peak_positions' of each tile assigned to that class.
npeaks_per_class = {
    pred_class: np.array([
        len(tile_motif_map[tile_id]['peak_positions'])
        for tile_id in class_df.tile_id.values
    ])
    for pred_class, class_df in virus_predictions.groupby('predicted_class_short', observed=False)
}

cre_colors = dict([
    ('Promoter-like', 'firebrick'),
    ('Proximal', 'orange'),
    ('Distal', 'cornflowerblue'),
    ('CTCF-only', 'forestgreen'),
    ('Undetermined', 'grey'),
])
fig, ax = plt.subplots(figsize=(10, 7))
npeaks_grid = np.arange(9)
for pred_class, npeaks in npeaks_per_class.items():
    # Row k of the comparison marks tiles with npeaks >= k, so each value is
    # the fraction of tiles in the class carrying at least k motifs.
    at_least_k = npeaks_grid[:, None] <= npeaks[None, :]
    cum_density = at_least_k.mean(axis=1)
    ax.plot(
        npeaks_grid, cum_density,
        color=cre_colors[pred_class], linewidth=5, label=pred_class,
    )
ax.set_xlabel('Number of motifs per tile')
ax.set_ylabel('Cumulative density')
ax.legend(loc='upper right')
ax.grid(False)
fig.savefig(config.FIGURES_DIR / 'fig_3J_number_motifs_per_cre_class_tile.pdf', format='pdf', bbox_inches='tight')
plt.close()


Proportion of tiles containing transcription start sites (TSS) in different CRE classes.


In [12]:
# K562 rows of the CAGE-overlap table, annotated with virus name and with the
# predicted class of each tile.
cage_tiles = pd.read_csv(config.RESULTS_DIR / "virus_paired_tiles_cage_peaks_overlap.csv")
cage_tiles = cage_tiles[cage_tiles.cell == 'K562'].reset_index(drop=True)
cage_tiles['virus'] = (
    (cage_tiles.family + ', ' + cage_tiles.strain)
    .astype('category')
    .cat.set_categories(config.VIRUSES)
)
# The lookup is cast to plain strings so that the merged class column is not
# categorical and can take the extra 'Not CRE' label below.
tile_classes = virus_predictions.drop_duplicates('tile_id')[['tile_id', 'predicted_class_short']].astype(str)
cage_tiles = cage_tiles.merge(tile_classes, on='tile_id', how='left')
# Tiles without a match in virus_predictions are labelled 'Not CRE'.
cage_tiles['predicted_class_short'] = cage_tiles['predicted_class_short'].fillna('Not CRE')
cage_tiles['predicted_class_short'] = (
    cage_tiles['predicted_class_short']
    .astype('category')
    .cat.set_categories(config.ENCODE_CRE_TYPES_SHORT + ['Not CRE'])
)
# Name the count column explicitly so the exported CSV header does not depend
# on the pandas-version-specific default name of the value_counts() result.
class_cage_overlap = (
    cage_tiles.value_counts(['virus', 'is_cage_peak', 'predicted_class_short'], sort=False)
    .to_frame(name='count')
    .reset_index()
)
class_cage_overlap.to_csv(config.RESULTS_DIR / 'fig_3K_cre_classes_cage_overlap.csv', index=False)


### TFBS Frequency in CRE Classes

Malinois K562 contribution scores for ENCODE CRE sequences are computed, and their TF motifs annotated, using [`tf_motif_annotation_encode_cres.py`](../scripts/tf_motif_annotation_encode_cres.py).


In [13]:
# Threshold for the 'k562' row of the Malinois threshold table.
thresholds_df = pd.read_csv(config.RESULTS_DIR / 'thresholds_malinois_log2_1p.csv')[['cell', 'threshold']]
K562_THRESHOLD = thresholds_df.loc[thresholds_df.cell == 'k562', 'threshold'].iloc[0]

# Viral tiles: keep those at or above the threshold that also have a
# classifier prediction, then attach the predicted class.
tfbs_counts_tiles = pd.read_csv(config.RESULTS_DIR / "malinois_K562_tfbs_counts_virus_tiles.csv")
passes_threshold = tfbs_counts_tiles.malinois_k562_lfc >= K562_THRESHOLD
has_prediction = tfbs_counts_tiles.tile_id.isin(virus_predictions.tile_id)
tile_classes = virus_predictions.drop_duplicates('tile_id')[['tile_id', 'predicted_class_short']]
tfbs_counts_tiles = tfbs_counts_tiles[passes_threshold & has_prediction].merge(
    tile_classes, on='tile_id', how='left'
)
tfbs_counts_tiles['virus'] = (
    tfbs_counts_tiles['virus'].astype('category').cat.set_categories(config.VIRUSES)
)

# ENCODE CREs: same threshold filter.
tfbs_counts_encode = pd.read_csv(config.RESULTS_DIR / "malinois_K562_tfbs_counts_encode_cres.csv")
tfbs_counts_encode = tfbs_counts_encode.loc[tfbs_counts_encode.malinois_k562_lfc >= K562_THRESHOLD]


In [14]:
# Mean TFBS count per TF for every (virus, predicted class), indexed by virus.
virus_means = (
    tfbs_counts_tiles
    .groupby(["virus", "predicted_class_short"], observed=False)[config.TF_GENES_K562]
    .mean()
    .reset_index()
    .set_index('virus')
)

# Mean TFBS count per TF for every ENCODE region type, plus an 'all CREs' row
# with the mean over all ENCODE rows.
encode_means = (
    tfbs_counts_encode.groupby('encode_region_type')[config.TF_GENES_K562].mean()
    .pipe(lambda df: pd.concat([
        df, pd.DataFrame([tfbs_counts_encode[config.TF_GENES_K562].mean()], index=['all CREs'])
    ]))
)

# Added to numerator and denominator so the log-ratio stays finite when a
# mean count is zero.
pseudocount = 1e-2
tfbs_enrichment = []
for cre_type, virus_cres in virus_means.groupby('predicted_class_short', observed=True):
    if cre_type == 'Undetermined':
        # This class is skipped: no ENCODE row is looked up for it.
        continue
    # log2(virus mean / ENCODE mean) per TF, using the ENCODE row whose label
    # equals this CRE type (raises KeyError if there is no such row).
    encode_reference = encode_means.loc[cre_type] + pseudocount
    ratio = np.log2(
        (virus_cres[config.TF_GENES_K562] + pseudocount).div(encode_reference)
    ).reset_index()
    ratio.insert(0, 'cre_type', cre_type)
    tfbs_enrichment.append(ratio)
pd.concat(tfbs_enrichment, ignore_index=True).to_csv(config.RESULTS_DIR / "fig_S4B_log2_virus_over_encode_ratio_by_cre_types_ps1e-2.csv", index=False)


Genome-wide MPRA activity, including predicted CRE types, is visualized using the [`plot_mpra_activity_genomewide.py`](../scripts/plot_mpra_activity_genomewide.py) script.
