In [None]:
import pandas as pd
from sklearn.metrics import average_precision_score, roc_auc_score
import numpy as np
import os
import sys
from sklearn.linear_model import LogisticRegressionCV
from gimmemotifs.motif import read_motifs
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
from sklearn.preprocessing import scale, MinMaxScaler, minmax_scale
from glob import glob
import random
import networkx as nx
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import qnorm
import re

# Logging: send loguru output to stderr at DEBUG level so the
# logger.debug()/logger.info() progress messages used below are visible.
from loguru import logger
logger.remove()  # remove previously configured handlers before adding ours
logger.add(sys.stderr, level="DEBUG")

# Render matplotlib figures inline in the notebook.
%matplotlib inline

# Train models

Warning: this notebook uses a lot of memory; it has not been optimized for performance yet.

Required layout:

```
train_dir/
    peaks/
        TF1.cell_type1.narrowPeak
        TF2.cell_type1.narrowPeak
        TF1.cell_type2.narrowPeak
    remap_overlap/
    {date}_trained/
```

# Config: paths and input files

In [None]:
import shutil

BASEDIR = "/bheuts/ANANSE/hg19/model_training/ANANSE-CAGE"

# Directory for training data
train_dir = f"{BASEDIR}"  # for layout see above

# Output directory for models (and benchmark tables)
out_dir = f"{BASEDIR}/2022-05-04_trained"

# Output directory for the per-TF peak / CAGE enhancer overlap tables
out_overlap = f"{BASEDIR}/remap_overlap"

# CAGE enhancer regions (BED); written by the CAGE pre-processing cell
ref_bed = f"{BASEDIR}/CAGE.enhancers.bed"

# Results of motif scan, code does currently not include generating this:
# !gimme scan {ref_bed} -g "/ceph/rimlsfnwi/data/molbio/martens/bheuts/ANANSE/genome/hg19/hg19.fa" -T > {BASEDIR}/CAGE.enhancers.motifs.txt
motif_scan_file = f"{BASEDIR}/CAGE.enhancers.motifs.txt"

# ReMap 2022 (hg19) coverage bigWig
coverage_bw = f"{BASEDIR}/remap2022.hg19.w50.bw"

# Cell type name mapping, to match cell types of CAGE to cell type from ReMap
ct_map = {
    "HepG2": "Hep-G2",
    "HELA": "HeLa-S3",
#     "LNCAP":"LNCaP",
    "H1-hESC": "hESC",
    "K562": "K-562",
    "H9-hESC": "hESC",
    "HEPG2": "Hep-G2",
#     "MCF7": "MCF-7"
}

# ReMap cell types to use: every mapped cell type plus extra ones (GM12878)
cell_types = {"GM12878", *ct_map.values()}

# True: delete previous outputs / caches and recompute everything.
# False: keep existing outputs so the cache checks further down are honoured
# (previously the outputs were wiped unconditionally, defeating those checks).
force_rerun = True

if force_rerun:
    # Remove old files (pure Python instead of `! rm`, so paths need no shell quoting)
    for old_dir in (out_dir, out_overlap):
        if os.path.exists(old_dir):
            shutil.rmtree(old_dir)
    if os.path.exists(f"{train_dir}/all_tfs_y_true.feather"):
        os.remove(f"{train_dir}/all_tfs_y_true.feather")

os.makedirs(out_dir, exist_ok=True)
os.makedirs(out_overlap, exist_ok=True)

# Config: model and evaluation settings

In [None]:
# Base model: a single estimator instance that the benchmark / training cells
# re-fit for every factor and feature set.
# model = RandomForestClassifier(n_jobs=-1, n_estimators=500, class_weight="balanced")
model = LogisticRegressionCV(class_weight="balanced", n_jobs=24)

# Evaluation metrics: name -> scoring function(y_true, y_score)
scores = {"pr_auc": average_precision_score, "roc_auc": roc_auc_score}

# Specify test chromosome by regex. Region ids look like "chr1:100-300"; the
# trailing [^\d] keeps "chr1" from also matching chr10-chr19.
# Raw string: "\d" inside a normal string literal is an invalid escape sequence.
test_chrom_regex = r"chr(1|8|21)[^\d]"

# CAGE data pre-processing

In [None]:
# CAGEfightR bidirectional sites QC: distribution of the per-site sum of
# log10(value + 1) over all samples.
# NOTE(review): read from the working directory, not BASEDIR — confirm location.
df = pd.read_table("bidirectional.merged_headers.txt")

# Vectorised log10(x + 1) on all sample columns (no copy of df is modified)
qc_log_signal = np.log10(df.set_index("Id") + 1)

fig, ax = plt.subplots()
ax.hist(qc_log_signal.sum(axis=1), range=[0, 2])
ax.set_xlabel("Sum over samples of log10(value + 1)")
ax.set_ylabel("Number of bidirectional sites")
ax.set_title("CAGEfightR bidirectional sites QC");

In [None]:
# Centre the bidirectional sites and use a fixed window (200 bp) around each centre.
enhancer_window = 200
half_window = enhancer_window // 2

# Site ids look like "chr1:1000-1500"; parse them into chrom / start / end.
regions = df["Id"].str.extract(r"^(?P<chrom>[^:]+):(?P<start>\d+)-(?P<end>\d+)$")
malformed = df.loc[regions.isnull().any(axis=1), "Id"]
if len(malformed) > 0:
    raise ValueError(f"Could not parse {len(malformed)} site ids, e.g. {malformed.iloc[0]!r}")
regions = regions.astype({"start": int, "end": int})

center = (regions["start"] + regions["end"]) // 2
# Clip at 0 so sites close to a chromosome start do not give negative BED coordinates.
regions["start"] = (center - half_window).clip(lower=0)
regions["end"] = center + half_window

regions.to_csv(ref_bed, header=False, index=False, sep="\t")

loc = regions["chrom"] + ":" + regions["start"].astype(str) + "-" + regions["end"].astype(str)

# Expression matrix indexed by the centred region ids (the same ids the overlap
# tables built from ref_bed use). `df` itself is left untouched so this cell
# can be re-run; previously it dropped "Id" from df in place.
data = df.drop(columns="Id")
data.index = pd.Index(loc)
# print(data)

# Read motifs and factors

Read a GimmeMotifs motif database and load the factors associated with each motif. Only factors that are TFs according to [Lovering et al. 2020](https://www.biorxiv.org/content/10.1101/2020.10.28.359232v2.full) are used.


In [None]:
# Lovering et al. 2020 supplementary table (second sheet): curated human TFs.
# Keep non-pseudogene symbols and explicitly exclude EP300 and EZH2
# (presumably because they are co-factors, not sequence-specific TFs).
lovering_table = pd.read_excel(
    "https://www.biorxiv.org/content/biorxiv/early/2020/12/07/2020.10.28.359232/DC1/embed/media-1.xlsx",
    sheet_name=1,
    engine='openpyxl',
)
not_pseudogene = lovering_table["Pseudogene"].isnull()
valid_factors = [
    symbol
    for symbol in lovering_table.loc[not_pseudogene, "HGNC approved gene symbol"].values
    if symbol not in ("EP300", "EZH2")
]
print(f"{len(valid_factors)} TFs")

# Create overlap of TF (ReMap) peaks with CAGE enhancers

In [None]:
# Display the selected ReMap cell types (GM12878 plus the values of ct_map).
cell_types

In [None]:
for cell_type in cell_types:
    for fname in glob(f"{train_dir}/peaks/*.{cell_type}.narrowPeak"):
        tf, cell_type = fname.split("/")[-1].split(".")[:2]
        if not (ct_map.get(cell_type, cell_type) in cell_types and tf in valid_factors):
            logger.debug(f"skipping {tf} {cell_type}")
            continue
  
        if not os.path.exists(f"{out_overlap}/{tf}.{cell_type}.enhancers.txt"):
            logger.debug(f"converting {tf} {cell_type}")
            !bedtools intersect -a {ref_bed} -b {train_dir}/peaks/{tf}.{cell_type}.narrowPeak  -c > {out_overlap}/{tf}.{cell_type}.enhancers.bed
            !cat {out_overlap}/{tf}.{cell_type}.enhancers.bed | sed 's/\t/:/' | sed 's/\t/-/' > {out_overlap}/{tf}.{cell_type}.enhancers.txt

In [None]:
# Map each factor (upper-cased) to the names of its associated motifs in the
# GimmeMotifs database. With indirect=True every association type is used,
# not only the "direct" ones.
motifs = read_motifs(as_dict=True)
indirect = True
f2m = {}
for name, motif in motifs.items():
    # `motif_factors`, not `factors`: the global `factors` is defined below.
    for k, motif_factors in motif.factors.items():
        if k != "direct" and not indirect:
            logger.debug(f"skipping {k} factors of {name}")
            continue
        for factor in motif_factors:
            f2m.setdefault(factor.upper(), []).append(name)

# Filter for valid TFs (set for O(1) membership tests)
lovering_tfs = set(valid_factors)
f2m = {k: v for k, v in f2m.items() if k in lovering_tfs}

# Only use TFs for which we have data. The TF name is parsed from the file
# basename, so dots elsewhere in the path cannot corrupt it; sorted for a
# deterministic order.
factors = sorted({os.path.basename(x).split(".")[0] for x in glob(f"{out_overlap}/*enhancers.txt")})
valid_factors = [f for f in factors if f in lovering_tfs]

print(len(valid_factors), "factors")

In [None]:
# Create one big dataframe with all TF peak overlap with the reference enhancer
# set: one column "<TF>.<cell type>" per peak file, one row per CAGE enhancer.
y_true_file = f"{train_dir}/all_tfs_y_true.feather"

# Only peak files whose basename contains one of the selected cell types
# (escaped, so names are matched literally).
ct_pattern = "|".join(re.escape(ct) + "[.]" for ct in cell_types)
fnames = sorted(
    fname for fname in glob(f"{train_dir}/peaks/*narrowPeak")
    if re.search(ct_pattern, os.path.basename(fname))
)

if force_rerun or not os.path.exists(y_true_file):
    overlap_columns = {}
    for fname in tqdm(fnames):
        # Files are "<TF>.<cell type>.narrowPeak". Splitting the basename keeps
        # hyphenated TF symbols such as NKX2-1 intact; the old (\w+)\. regex on
        # the full path parsed them as "1" and silently dropped them.
        factor, cell_type = os.path.basename(fname).split(".")[:2]
        if factor not in factors:
            continue
        column = f"{factor}.{cell_type}"
        try:
            overlap_columns[column] = pd.read_table(
                f"{out_overlap}/{column}.enhancers.txt", index_col=0, names=[column]
            ).iloc[:, 0]
        except Exception as e:
            # Best effort: an unreadable overlap table only drops this column.
            logger.warning(f"{fname}: {e}")
    # Build the frame in one go; inserting column by column fragments it.
    y_true = pd.concat(overlap_columns, axis=1) if overlap_columns else pd.DataFrame()
    y_true.reset_index().to_feather(y_true_file)
else:
    logger.debug("reading y_true")
    y_true = pd.read_feather(y_true_file)
    y_true = y_true.set_index(y_true.columns[0])

y_true = y_true.clip(upper=1)  # peak, yes or no

# Remove TFs that have 0 true peaks
no_peak_columns = y_true.columns[~(y_true == 1).any()]
for column in no_peak_columns:
    logger.info(f"dropping {column}: no overlapping peaks (sum={y_true[column].sum()})")
y_true = y_true.drop(columns=no_peak_columns)

# ReMap coverage

This represents the average binding of TFs across cell types.

In [None]:
remap_file = f"{out_dir}/reference.coverage.txt"
if force_rerun or not os.path.exists(remap_file):
    !coverage_table -p {ref_bed} -d {coverage_bw} > {remap_file}
remap_cov = pd.read_table(remap_file, sep="\t", comment="#", index_col=0)
remap_cov.rename(columns={remap_cov.columns[0]:"average"}, inplace=True)
remap_cov["average"] = remap_cov["average"] / remap_cov["average"].max()

# Enhancer data

Transform and normalize the CAGE enhancer expression data (TPMs): log1p transform, quantile normalization, then min-max scaling of each column to [0, 1].

In [None]:
# Normalise CAGE expression into a new frame instead of overwriting `data`.
# Otherwise re-running this cell applies log1p twice and appends ".CAGE" twice.
cage_norm = qnorm.quantile_normalize(np.log1p(data))
# minmax_scale scales each column to [0, 1] and returns a bare ndarray; rewrap it.
cage_norm = pd.DataFrame(minmax_scale(cage_norm), index=cage_norm.index, columns=cage_norm.columns)
cage_norm = cage_norm.add_suffix(".CAGE")

# Cell-type specific feature tables, keyed by mark name; columns are "<cell type>.<mark>".
tables = {
    "CAGE": cage_norm
}

# Create base files for X

In [None]:
# Motif scan scores per reference region (columns are motif names; per factor,
# the columns listed in f2m are averaged later).
gimme = pd.read_table(motif_scan_file, comment="#", index_col=0)

# Base feature matrix shared by all factors: normalised ReMap coverage ("average").
X_base = remap_cov

# First pass for benchmark

In [None]:
# CAGE feature columns ("<cell type>.CAGE") that per-factor frames can use.
cage_columns = sorted(tables["CAGE"].columns)
print(cage_columns)

In [None]:
# Feature sets to benchmark / train. Each tuple lists feature column names of the
# per-factor frame; the model name is "_".join(sorted(tuple)), e.g.
# ("average", "motif", "CAGE") -> "CAGE_average_motif".
params = [
    ("motif", "CAGE"),
    ("average", "CAGE"),
    ("average", "motif", "CAGE"),
]

In [None]:
# Benchmark rows: [factor, test_cell_type, model, score, value].
# Kept in its own cell: the first-pass benchmark skips factors already present,
# so re-running that cell resumes instead of starting over.
ct_benchmark = []

In [None]:
# First-pass benchmark (leave one cell type out): for every TF with peaks in
# more than one cell type, train on the other cell types (training chromosomes)
# and evaluate on the held-out cell type (test chromosomes only).
max_train_negatives = 100_000  # cap on sampled negative training regions
min_peaks = 50                 # minimum positives required in train and in test
random_seed = 42               # makes the negative down-sampling reproducible

# Factor = column name up to the first ".". str.split is used because the
# str.replace regex default changed in pandas 2.0: the old "\..*" pattern is
# then matched literally, nothing is stripped, and no factor gets benchmarked.
factor_counts = y_true.columns.str.split(".").str[0].value_counts()
factors = factor_counts[factor_counts > 1].index

test_idx = y_true.index[y_true.index.str.contains(test_chrom_regex)]
train_idx = y_true.index[~y_true.index.str.contains(test_chrom_regex)]

marks = list(tables.keys())
already_done = {row[0] for row in ct_benchmark}  # factors benchmarked in an earlier run
common_f = [f for f in factors if f in valid_factors]
for factor in common_f:
    if factor not in valid_factors or factor not in f2m:
        logger.debug(f"Skipping {factor}, not a TF or no motif known")
        continue
    if factor in already_done:
        logger.debug(f"Skipping {factor}, already done")
        continue
    logger.info(f"Model benchmark: {factor}")

    # Cell types with peaks for this factor. Own name, so the global
    # `cell_types` set used by the data-preparation cells is not overwritten.
    factor_cols = y_true.columns[y_true.columns.str.startswith(f"{factor}.")]
    factor_cts = [col.split(".")[-1] for col in factor_cols if col.split(".")[-1] != "LNCaP"]
    print(factor_cts)
    if len(factor_cts) < 2:
        # Can't check in other cell-types
        continue

    motif_frame = gimme[f2m[factor]].mean(1).to_frame("motif")
    X_factor = X_base.join(motif_frame)

    # Stack one frame per cell type: shared features (average, motif), the
    # cell-type specific marks renamed "<ct>.<mark>" -> "<mark>", and the label.
    ct_frames = []
    for ct in factor_cts:
        mark_frames = [tables[mark][[f"{ct}.{mark}"]] for mark in marks]
        ct_frame = pd.concat(mark_frames + [y_true[f"{factor}.{ct}"].rename("y_true")], axis=1)
        ct_frame.columns = ct_frame.columns.str.replace(f"{ct}.", "", regex=False)
        ct_frame["cell_type"] = ct
        ct_frames.append(X_factor.join(ct_frame))
    X = pd.concat(ct_frames)

    # The chromosome split does not depend on the held-out cell type.
    train = X.loc[X.index.intersection(train_idx), :]
    test = X.loc[X.index.intersection(test_idx), :]

    for test_cell in factor_cts:
        print(test_cell)
        X_train = train[train["cell_type"] != test_cell].drop(columns=["cell_type"])
        X_test = test[test["cell_type"] == test_cell].drop(columns=["cell_type"])
        logger.debug(f"X_train: {str(X_train.shape)}")

        # Keep all positives; down-sample negatives to keep training tractable.
        X_train = X_train.reset_index(drop=True)
        negatives = X_train[X_train["y_true"] == 0]
        if negatives.shape[0] >= max_train_negatives:
            X_train = pd.concat((
                X_train[X_train["y_true"] == 1],
                negatives.sample(max_train_negatives, random_state=random_seed),
            ))
        # 1-D label vectors. A one-column DataFrame makes sklearn warn, and
        # `.sum()[0]` relied on a deprecated positional Series lookup.
        y_train = X_train["y_true"]
        y_test = X_test["y_true"]

        if y_train.sum() < min_peaks or y_test.sum() < min_peaks:
            print("skipping, not enough peaks")
            continue

        X_train = X_train.drop(columns=["y_true"])
        X_test = X_test.drop(columns=["y_true"])

        print(f"{factor}\t{test_cell}\tFitting models...")
        # Fraction of positives: the expected PR AUC of a random ranking
        positive_rate = y_test.mean()
        print(f"{factor}\t{test_cell}\tbaseline\tpr_auc\t{positive_rate:.3f}")
        ct_benchmark.append([factor, test_cell, "baseline", "pr_auc", positive_rate])
        for param_set in params:
            param_columns = sorted(param_set)
            model_name = "_".join(param_columns)
            model.fit(X_train[param_columns], y_train)
            y_pred = model.predict_proba(X_test[param_columns])[:, 1]
            for name, func in scores.items():
                score = func(y_test, y_pred)
                ct_benchmark.append([factor, test_cell, model_name, name, score])
                print(f"{factor}\t{test_cell}\t{model_name}\t{name}\t{score:.3f}")

        # score baselines: a single feature used directly as the prediction
        for base in ["CAGE", "average", "motif"]:
            for name, func in scores.items():
                score = func(y_test, X_test[base])
                ct_benchmark.append([factor, test_cell, f"{base}.baseline", name, score])
                if base in ["average", "CAGE"]:
                    # Relevant baseline: average binding across cell types
                    print(f"{factor}\t{test_cell}\t{base}.baseline\t{name}\t{score:.3f}")

    # Checkpoint the results after every factor.
    ct_benchmark_lr = pd.DataFrame(ct_benchmark, columns=["factor", "test_cell_type", "model", "score", "value"])
    ct_benchmark_lr.to_csv(f"{out_dir}/benchmark.txt", sep="\t")

In [None]:
# First-pass benchmark summary (PR AUC): save the full table, then plot the
# per-model distribution with models ordered by median PR AUC.
ct_benchmark_lr = pd.DataFrame(ct_benchmark, columns=["factor", "test_cell_type", "model", "score", "value"])
ct_benchmark_lr.to_csv(f"{out_dir}/benchmark_pr_fp.txt", sep="\t")

pr_rows = ct_benchmark_lr[ct_benchmark_lr["score"] == "pr_auc"]
# Aggregate only the numeric "value" column. The previous frame-wide .median()
# raises TypeError on the string columns in pandas >= 2.0.
pr_median = pr_rows.groupby("model")["value"].median().sort_values()
order = pr_median.index.values

fig, ax = plt.subplots()
sns.boxplot(data=pr_rows, x="value", y="model", order=order, ax=ax)
ax.set_xlabel("Precision-Recall AUC")
ax.set_xlim(0, 1)
fig.tight_layout()
fig.savefig("PR_AUC_first_pass_Benchmark.pdf")
pr_median

In [None]:
# First-pass benchmark summary (ROC AUC): save the full table, then plot the
# per-model distribution with models ordered by median ROC AUC.
ct_benchmark_lr = pd.DataFrame(ct_benchmark, columns=["factor", "test_cell_type", "model", "score", "value"])
ct_benchmark_lr.to_csv(f"{out_dir}/benchmark_roc_fp.txt", sep="\t")

roc_rows = ct_benchmark_lr[ct_benchmark_lr["score"] == "roc_auc"]
# Aggregate only the numeric "value" column. The previous frame-wide .median()
# raises TypeError on the string columns in pandas >= 2.0.
roc_median = roc_rows.groupby("model")["value"].median().sort_values()
order = roc_median.index.values

fig, ax = plt.subplots()
sns.boxplot(data=roc_rows, x="value", y="model", order=order, ax=ax)
ax.set_xlabel("Receiver-operator characteristic (ROC) AUC")
ax.set_xlim(0, 1)
fig.tight_layout()
fig.savefig("ROC_AUC_first_pass_Benchmark.pdf")
roc_median

# Second pass: training full models

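Trained models are saved as pickles using joblib. **TODO**: check whether this is the safest / most suitable format. Loading a pickle can execute arbitrary code, so only load model files from trusted locations.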

In [None]:
import joblib

# Second pass: for every TF (and feature set), train one model on all of its
# cell types and all chromosomes. Then train one general model on a sample
# of every TF's training data.
n_general_samples = 5000  # rows per TF added to the general model's training set
random_seed = 42

# Factor = column name up to the first "." (every factor, including those with
# a single cell type). str.split avoids str.replace, whose regex default changed
# in pandas 2.0 and would leave full "<TF>.<ct>" names, so no model was trained.
factors = y_true.columns.str.split(".").str[0].value_counts().index
marks = list(tables.keys())

general_frames = []
for factor in factors:
    if factor not in valid_factors or factor not in f2m:
        logger.debug(f"Skipping {factor}, not a TF or no motif known")
        continue
    print(factor)
    # Own name, so the global `cell_types` set is not overwritten.
    factor_cols = y_true.columns[y_true.columns.str.startswith(f"{factor}.")]
    factor_cts = [col.split(".")[-1] for col in factor_cols if col.split(".")[-1] != "LNCaP"]
    if not factor_cts:
        continue

    motif_frame = gimme[f2m[factor]].mean(1).to_frame("motif")
    X_factor = X_base.join(motif_frame)

    # Stack one frame per cell type: shared features, the cell-type specific
    # marks renamed "<ct>.<mark>" -> "<mark>", and the label.
    ct_frames = []
    for ct in factor_cts:
        mark_frames = [tables[mark][[f"{ct}.{mark}"]] for mark in marks]
        ct_frame = pd.concat(mark_frames + [y_true[f"{factor}.{ct}"].rename("y_true")], axis=1)
        ct_frame.columns = ct_frame.columns.str.replace(f"{ct}.", "", regex=False)
        ct_frame["cell_type"] = ct
        ct_frames.append(X_factor.join(ct_frame))
    X = pd.concat(ct_frames)

    # Use all positive regions and all negative regions, negatives in random
    # order. Rows with a missing label are left out. The previous
    # sample(len(X) - n_positives) did the same, but raised when labels
    # contained NaN.
    positives = X[X["y_true"] == 1]
    negatives = X[X["y_true"] == 0].sample(frac=1, random_state=random_seed)
    X = pd.concat((positives, negatives))

    # Sample for the general model (capped, so small factors do not raise)
    general_frames.append(X.sample(min(n_general_samples, len(X)), random_state=random_seed))
    y = X["y_true"]  # 1-D labels for sklearn
    X = X.drop(columns=["y_true"])
    X = X.rename(columns={"remap.w50": "average"})

    for param_set in params:
        print(f"{factor}\tFitting model...")
        param_columns = sorted(param_set)
        model.fit(X[param_columns], y)
        dirname = os.path.join(out_dir, "_".join(param_columns))
        os.makedirs(dirname, exist_ok=True)
        joblib.dump(model, os.path.join(dirname, f"{factor}.pkl"))

if not general_frames:
    raise RuntimeError("No factor-specific data collected; cannot fit the general model.")
X_general = pd.concat(general_frames)
y = X_general["y_true"]
X = X_general.drop(columns=["y_true"])
X = X.rename(columns={"remap.w50": "average"})

# Make a general model, that is not TF specific
for param_set in params:
    print("Fitting general model...")
    param_columns = sorted(param_set)
    model.fit(X[param_columns], y)
    dirname = os.path.join(out_dir, "_".join(param_columns))
    # Create the directory here as well. Previously this relied on a
    # TF-specific model having been saved into it first.
    os.makedirs(dirname, exist_ok=True)
    joblib.dump(model, os.path.join(dirname, "general.pkl"))


# Third pass: benchmarking the trained general models

In [None]:
import joblib

# Third pass: evaluate the saved general (not TF-specific) models on the test
# chromosomes of every cell type, for TFs with peaks in more than one cell type.
min_peaks = 50  # minimum positives required in the test set

# Factor = column name up to the first "." (see the pandas 2.0 note in the
# first-pass cell: str.replace would no longer strip the suffix).
factor_counts = y_true.columns.str.split(".").str[0].value_counts()
factors = factor_counts[factor_counts > 1].index

test_idx = y_true.index[y_true.index.str.contains(test_chrom_regex)]
marks = list(tables.keys())

# Load each general model once. They go into a dict, so the shared training
# `model` object is not overwritten. These pickles were written by this notebook;
# never load pickles from an untrusted source (unpickling can run code).
general_models = {}
for param_set in params:
    param_columns = sorted(param_set)
    model_name = "_".join(param_columns)
    fname = os.path.join(out_dir, model_name, "general.pkl")
    general_models[model_name] = (param_columns, joblib.load(fname))

common_f = [f for f in factors if f in valid_factors]
for factor in common_f:
    if factor not in valid_factors or factor not in f2m:
        logger.debug(f"Skipping {factor}, not a TF or no motif known")
        continue
    logger.info(f"Model benchmark: {factor}")

    # Own name, so the global `cell_types` set is not overwritten.
    factor_cols = y_true.columns[y_true.columns.str.startswith(f"{factor}.")]
    factor_cts = [col.split(".")[-1] for col in factor_cols if col.split(".")[-1] != "LNCaP"]
    print(factor_cts)
    if len(factor_cts) < 2:
        # Can't check in other cell-types
        continue

    motif_frame = gimme[f2m[factor]].mean(1).to_frame("motif")
    X_factor = X_base.join(motif_frame)

    ct_frames = []
    for ct in factor_cts:
        mark_frames = [tables[mark][[f"{ct}.{mark}"]] for mark in marks]
        ct_frame = pd.concat(mark_frames + [y_true[f"{factor}.{ct}"].rename("y_true")], axis=1)
        ct_frame.columns = ct_frame.columns.str.replace(f"{ct}.", "", regex=False)
        ct_frame["cell_type"] = ct
        ct_frames.append(X_factor.join(ct_frame))
    X = pd.concat(ct_frames)

    # Test-chromosome rows only. intersection() tolerates region ids that are
    # absent from X, whereas .loc[test_idx] raises KeyError for them.
    test = X.loc[X.index.intersection(test_idx), :]
    for test_cell in factor_cts:
        X_test = test[test["cell_type"] == test_cell].drop(columns=["cell_type"])
        y_test = X_test["y_true"]

        if y_test.sum() < min_peaks:
            print("skipping, not enough peaks")
            continue

        X_test = X_test.drop(columns=["y_true"])

        for model_name, (param_columns, general_model) in general_models.items():
            y_pred = general_model.predict_proba(X_test[param_columns])[:, 1]
            for name, func in scores.items():
                score = func(y_test, y_pred)
                ct_benchmark.append([factor, test_cell, f"{model_name}.general", name, score])
                print(f"{factor}\t{test_cell}\t{model_name}.general\t{name}\t{score:.3f}")

    # Checkpoint the results after every factor.
    ct_benchmark_lr = pd.DataFrame(ct_benchmark, columns=["factor", "test_cell_type", "model", "score", "value"])
    ct_benchmark_lr.to_csv(f"{out_dir}/benchmark.with_general.txt", sep="\t")

In [None]:
# Benchmark summary including the general models (PR AUC): save the full
# table, then plot per-model distributions ordered by median PR AUC.
ct_benchmark_lr = pd.DataFrame(ct_benchmark, columns=["factor", "test_cell_type", "model", "score", "value"])
ct_benchmark_lr.to_csv(f"{out_dir}/benchmark_pr_gm.txt", sep="\t")

pr_rows = ct_benchmark_lr[ct_benchmark_lr["score"] == "pr_auc"]
# Aggregate only the numeric "value" column. The previous frame-wide .median()
# raises TypeError on the string columns in pandas >= 2.0.
pr_median = pr_rows.groupby("model")["value"].median().sort_values()
order = pr_median.index.values

fig, ax = plt.subplots()
sns.boxplot(data=pr_rows, x="value", y="model", order=order, ax=ax)
ax.set_xlabel("Precision-Recall AUC")
ax.set_xlim(0, 1)
fig.tight_layout()
fig.savefig("PR_AUC_full_Benchmark.pdf")
pr_median

In [None]:
# Benchmark summary including the general models (ROC AUC): save the full
# table, then plot per-model distributions ordered by median ROC AUC.
ct_benchmark_lr = pd.DataFrame(ct_benchmark, columns=["factor", "test_cell_type", "model", "score", "value"])
ct_benchmark_lr.to_csv(f"{out_dir}/benchmark_roc_gm.txt", sep="\t")

roc_rows = ct_benchmark_lr[ct_benchmark_lr["score"] == "roc_auc"]
# Aggregate only the numeric "value" column. The previous frame-wide .median()
# raises TypeError on the string columns in pandas >= 2.0.
roc_median = roc_rows.groupby("model")["value"].median().sort_values()
order = roc_median.index.values

fig, ax = plt.subplots()
sns.boxplot(data=roc_rows, x="value", y="model", order=order, ax=ax)
ax.set_xlabel("Receiver-operator characteristic (ROC) AUC")
ax.set_xlim(0, 1)
fig.tight_layout()
fig.savefig("ROC_AUC_full_Benchmark.pdf")
roc_median

#### 