# TF binding perturbation (CIS-BP)

In [None]:
from typing import Dict, List, Optional, Tuple, Union
from MOODS import parsers, tools
from mpl_toolkits.axes_grid1 import make_axes_locatable
from pybedtools import BedTool
from scipy.stats import zscore
from glob import glob

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

import subprocess
import warnings
import os

warnings.filterwarnings("ignore")  # ignore warning messages

In [None]:
# root folder of the input files; every input path in this notebook is joined onto it
DATADIR = "data"
variants_file = os.path.join(DATADIR, "LDL_variants/LDLvar_credset.xlsx")
# read variants XLSX file; cells equal to '*', 'NA', 'N/A' or 'nan' are loaded as missing
variants = pd.read_excel(variants_file, na_values=['*', 'NA', 'N/A', 'nan'])
variants.head()

In [None]:
# keep only the columns used in this analysis; rows lacking a chromosome,
# a position or either allele are discarded, then columns get short names
cols = ["rsid", "CHR", "HG19 BP", "A1", "A2", "fdr_adj", "MotifRaptor analysis"]
variants = variants.loc[:, cols].dropna(subset=["CHR", "HG19 BP", "A1", "A2"])
variants.columns = ["ID", "CHR", "POS", "REF", "ALT", "SCORE", "TEST"]
variants = variants.reset_index(drop=True)
variants

In [None]:
# cast CHR and POS columns data to right data type (int)
def correct_chrom(chrom: Union[int, float, str]) -> Union[int, str]:
    """Normalise a chromosome label.

    Autosomes (1-22) are returned as ``int``; sex chromosomes are returned
    as upper-case ``"X"`` / ``"Y"``.

    Args:
        chrom: chromosome value. Numeric values are accepted as Python or
            NumPy integers and floats, because the numeric type pandas
            hands back depends on the column contents.

    Returns:
        The autosome number as ``int``, or ``"X"`` / ``"Y"``.

    Raises:
        ValueError: if the value is neither an autosome number nor X/Y.
    """
    is_number = isinstance(chrom, (int, float, np.integer, np.floating))
    # bool is a subclass of int: never read True/False as a chromosome
    if is_number and not isinstance(chrom, bool) and 0 < chrom < 23:
        return int(chrom)
    if isinstance(chrom, str) and chrom.upper() in ["X", "Y"]:
        return chrom.upper()
    raise ValueError(f"Unknown chromosome {chrom}")
# address the column by name: integer keys on a row Series (x[1]) are
# deprecated positional lookups in recent pandas
variants["CHR"] = variants["CHR"].map(correct_chrom)

def correct_position(pos: float) -> int:
    """Cast a genomic position to ``int``.

    Args:
        pos: position value (numeric).

    Returns:
        ``int(pos)``.

    Raises:
        ValueError: if ``pos`` is negative.
    """
    is_negative = pos < 0
    if is_negative:
        raise ValueError(f"Forbidden position: {pos}")
    return int(pos)

# address the column by name: integer keys on a row Series (x[2]) are
# deprecated positional lookups in recent pandas
variants["POS"] = variants["POS"].map(correct_position)
variants

In [None]:
# remove indels from variants report (artifacts): keep only rows whose REF
# and ALT alleles are both a single character. Columns are addressed by
# name, since integer keys on a row Series (x[3], x[4]) are deprecated
# positional lookups in recent pandas.
variants = variants[
    (variants["REF"].str.len() == 1) & (variants["ALT"].str.len() == 1)
].reset_index(drop=True)
variants

In [None]:
# assign IDs to SNPs with NaN on ID column
def assign_snpid(snpid: str, chrom: str, pos: int, ref: str, alt: str) -> str:
    """Return ``snpid``, or a synthetic ID when it is missing.

    An ID counts as missing when its string form is ``"nan"``; the
    replacement is ``"<chrom>_<pos>_<ref>_<alt>"``.
    """
    has_id = str(snpid) != "nan"
    return snpid if has_id else f"{chrom}_{pos}_{ref}_{alt}"

# walk the named columns in lockstep: integer keys on a row Series
# (x[0] ... x[4]) are deprecated positional lookups in recent pandas
variants["ID"] = [
    assign_snpid(snpid, chrom, pos, ref, alt)
    for snpid, chrom, pos, ref, alt in zip(
        variants["ID"],
        variants["CHR"],
        variants["POS"],
        variants["REF"],
        variants["ALT"],
    )
]
variants

In [None]:
# keep only the variants whose TEST column equals "V", renumbering the rows
variants = variants.loc[variants["TEST"] == "V"].reset_index(drop=True)
variants

In [None]:
# recover 61 bp genomic sequences centred around the SNP position
GENOME = os.path.join(DATADIR, "hg19/hg19.fa")
snpids = variants.ID.tolist()
chroms = variants.CHR.tolist()
positions = variants.POS.tolist()
# BED-style interval [POS - 31, POS + 30): 61 bases wide; assuming POS is
# 1-based and the interval 0-based half-open, the SNP lands at index 30
coordinates = {
    snpid: f"chr{chrom}\t{pos - 31}\t{pos + 30}"
    for snpid, chrom, pos in zip(snpids, chroms, positions)
}
bed = BedTool("\n".join(coordinates.values()), from_string=True)  # create BED object
bed = bed.sequence(fi=GENOME)  # recover sequences
# close the FASTA handle once the sequence lines are read
with open(bed.seqfn) as seqfile:
    seqlines = [
        line.strip().upper() for line in seqfile if not line.startswith(">")
    ]
# sequences are paired with IDs by order, so a count mismatch (duplicate IDs,
# or an interval with no sequence returned) would silently shift every pair
if len(seqlines) != len(snpids):
    raise ValueError(
        f"Recovered {len(seqlines)} sequences for {len(snpids)} variants"
    )
sequences = dict(zip(snpids, seqlines))
sequences

In [None]:
# create reference and alternative sequence reports
def compute_alt_sequences(
    sequences: Dict[str, str], ref_alleles: List[str], alt_alleles: List[str]
) -> Dict[str, str]:
    """Build the alternative-allele version of every reference sequence.

    The i-th entry of ``sequences`` (in dict order) is paired with
    ``ref_alleles[i]`` and ``alt_alleles[i]``. The character at index 30 must
    equal the reference allele and is replaced by the alternative allele.

    Returns:
        Dict mapping each SNP ID to its sequence carrying the alternative allele.
    """
    sequences_alt = {}
    for i, (snpid, refseq) in enumerate(sequences.items()):
        # the SNP must sit at index 30 and match the declared reference allele
        assert refseq[30] == ref_alleles[i]
        seqalt = "".join((refseq[:30], alt_alleles[i], refseq[31:]))
        assert len(seqalt) == len(refseq)
        sequences_alt[snpid] = seqalt
    return sequences_alt

def write_fasta(sequences: Dict[str, str], outfasta: str) -> None:
    """Write ``sequences`` to ``outfasta`` as FASTA records.

    Each entry becomes a ``>ID`` header line followed by one sequence line.
    The function then asserts that the file exists and is not empty.
    """
    with open(outfasta, mode="w") as outfile:
        outfile.writelines(
            f">{snpid}\n{sequence}\n" for snpid, sequence in sequences.items()
        )
    written = os.path.isfile(outfasta) and os.stat(outfasta).st_size > 0
    assert written

OUTDIR = "output"
fastafolder = os.path.join(OUTDIR, "fasta")
# makedirs also creates OUTDIR itself; os.mkdir fails when the parent is missing
os.makedirs(fastafolder, exist_ok=True)
ref_alleles = variants.REF.tolist()
alt_alleles = variants.ALT.tolist()
# write FASTA files
write_fasta(sequences, os.path.join(fastafolder, "ref.fa"))
write_fasta(
    compute_alt_sequences(sequences, ref_alleles, alt_alleles),
    os.path.join(fastafolder, "alt.fa")
)

In [None]:
# read Cis-BP motifs and compute the corresponding PFMs
def read_motif_table(motif_table_file: str) -> Dict[str, str]:
    """Map field 3 to field 6 of every line of a whitespace-separated table.

    Every line is used (no header is skipped); when a key occurs more than
    once the last line wins.
    # presumably field 3 is the motif ID and field 6 the TF name - confirm
    # against the table layout
    """
    motif_table = {}
    with open(motif_table_file, mode="r") as infile:
        for line in infile:
            fields = line.strip().split()
            motif_table[fields[3]] = fields[6]
    return motif_table

def cisbp2pfm(motif: str) -> Tuple[np.ndarray, int]:
    """Read a motif file and return its matrix transposed.

    The first line (header) is skipped and the first field of every other
    line is dropped; the remaining fields are parsed as floats. Blank lines
    are ignored.

    Args:
        motif: path to the whitespace-separated motif file.

    Returns:
        ``(pfm, motiflen)``: ``pfm`` is a plain ``ndarray`` in which each
        input line is one column, and ``motiflen`` is the number of data
        lines read. A file with no data lines gives an array of size 0 and
        a length of 0.
    """
    with open(motif, mode="r") as infile:
        infile.readline()  # skip header
        rows = [
            list(map(float, line.split()[1:])) for line in infile if line.strip()
        ]
    motiflen = len(rows)
    # plain ndarray rather than np.matrix, which NumPy no longer recommends
    pfm = np.asarray(rows, dtype=float).T
    return pfm, motiflen

CISBPMOTIFS = os.path.join(DATADIR, "TFs/cisbp")
cisbpmotifs = glob(os.path.join(CISBPMOTIFS, "pwms_all_motifs/*.txt"))
cisbp_table = read_motif_table(
    os.path.join(CISBPMOTIFS, "TF_Information_all_motifs_plus.txt")
)
# the PFM files are written here; np.savetxt does not create missing folders
os.makedirs(os.path.join(CISBPMOTIFS, "pfms"), exist_ok=True)
pfms = {}
for motif in cisbpmotifs:
    # strip only the extension: str.replace("txt", "pfm") would also rewrite
    # a "txt" occurring inside the motif ID itself
    motifid = os.path.splitext(os.path.basename(motif))[0]
    basename = f"{motifid}.pfm"
    motifname = "_".join([motifid, cisbp_table[motifid]])
    pfm, motiflen = cisbp2pfm(motif)
    if pfm.size == 0:
        continue  # motif file without matrix rows
    pfmfile = os.path.join(CISBPMOTIFS, "pfms", basename)
    np.savetxt(pfmfile, pfm, fmt="%.6f", delimiter="\t")
    # motif ID -> (ID_TFname, PFM file path, motif length)
    pfms[motifid] = (motifname, pfmfile, motiflen)
pfms

In [None]:
# compute PWMs from PFMs
# Both constants are passed to parsers.pfm_to_log_odds inside compute_pwm.
# NOTE(review): the name P is bound again in a later cell as the MOODS
# p-value threshold (same value, 0.0001).
P = 0.0001  # pseudocount
BG = tools.flat_bg(4)  # uniform background distribution

def minmax_vals(pwm: np.ndarray) -> Tuple[float, float]:
    """Return the lowest and highest total score a matrix can give.

    The minimum (maximum) is taken along axis 0 for every column, and the
    per-column values are summed.
    """
    column_minima = pwm.min(axis=0)
    column_maxima = pwm.max(axis=0)
    return np.sum(column_minima), np.sum(column_maxima)

def compute_pwm(pfmfile: str, outfolder: str) -> Tuple[str, float, float]:
    """Convert a PFM file into a log-odds PWM file.

    The matrix is computed by ``parsers.pfm_to_log_odds`` with the
    module-level background ``BG`` and pseudocount ``P``. It is written
    tab-separated into ``outfolder`` under the same file stem with a
    ``.pwm`` extension. ``outfolder`` is created when missing.

    Args:
        pfmfile: path of the input matrix; must end in ``.pfm``.
        outfolder: folder receiving the ``.pwm`` file.

    Returns:
        ``(pwmfile, minval, maxval)``: the written path and the two values
        returned by ``minmax_vals`` for the matrix.

    Raises:
        ValueError: if ``pfmfile`` does not end in ``.pfm``.
    """
    # swap only the extension: str.replace("pfm", "pwm") would also rewrite
    # a "pfm" occurring inside the file stem
    stem, extension = os.path.splitext(os.path.basename(pfmfile))
    if extension != ".pfm":
        raise ValueError(f"Expected a .pfm file, got {pfmfile}")
    os.makedirs(outfolder, exist_ok=True)
    pwm = np.matrix(parsers.pfm_to_log_odds(pfmfile, BG, P))
    minval, maxval = minmax_vals(pwm)
    pwmfile = os.path.join(outfolder, f"{stem}.pwm")
    np.savetxt(pwmfile, pwm, fmt="%.6f", delimiter="\t")
    return pwmfile, minval, maxval

# one PWM per PFM entry: motif ID -> (PWM file path, min score, max score)
pwms = {
    motifid: compute_pwm(pfmfile, os.path.join(CISBPMOTIFS, "pwms"))
    for motifid, (_, pfmfile, _) in pfms.items()
}
pwms

In [None]:
# scan reference and alternative sequences using MOODS
MOODSSCAN = "moods-dna.py"
def moods_scan(
    fastafile: str, pwmfiles: str, outfile: str, pvalue: Optional[float] = 1
) -> None:
    """Run the MOODS command-line scanner over a FASTA file.

    Args:
        fastafile: FASTA file to scan (``-s``).
        pwmfiles: matrix file(s) given to ``-S``. The text is inserted into
            a shell command line unquoted, so a glob pattern is expanded by
            the shell.
        outfile: path given to ``-o``.
        pvalue: threshold given to ``-p``.

    Raises:
        RuntimeError: if the command exits with a non-zero status; the
            message carries the exit code and the captured stderr.
    """
    command = (
        f"{MOODSSCAN} -S {pwmfiles} -s {fastafile} -p {pvalue} -o {outfile} --batch"
    )
    # shell=True is kept on purpose so that a glob in `pwmfiles` is expanded
    result = subprocess.run(
        command,
        shell=True,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.PIPE,
        text=True,
    )
    # explicit check instead of `assert`, which is stripped under -O and
    # reported nothing about why the scan failed
    if result.returncode != 0:
        raise RuntimeError(
            f"MOODS scan failed (exit code {result.returncode}): "
            f"{result.stderr.strip()}"
        )

scandir = os.path.join(OUTDIR, "moods_scan_results")
# makedirs also creates OUTDIR; os.mkdir fails when the parent is missing
os.makedirs(scandir, exist_ok=True)
pwmfiles = os.path.join(CISBPMOTIFS, "pwms/*.pwm")
for fastafile in glob(os.path.join(fastafolder, "*.fa")):
    # swap only the extension: str.replace("fa", "txt") would also rewrite
    # an "fa" occurring inside the file stem
    stem = os.path.splitext(os.path.basename(fastafile))[0]
    outfile = os.path.join(scandir, f"{stem}.txt")
    moods_scan(fastafile, pwmfiles, outfile)

In [None]:
# scan reference and alternative sequences using MOODS and recover significant hits
# p-value threshold for the scan; this rebinds the module-level name P
P = 0.0001

scandir = os.path.join(OUTDIR, "moods_scan_results_hits")
# makedirs also creates OUTDIR; os.mkdir fails when the parent is missing
os.makedirs(scandir, exist_ok=True)
pwmfiles = os.path.join(CISBPMOTIFS, "pwms/*.pwm")
for fastafile in glob(os.path.join(fastafolder, "*.fa")):
    # swap only the extension: str.replace("fa", "txt") would also rewrite
    # an "fa" occurring inside the file stem
    stem = os.path.splitext(os.path.basename(fastafile))[0]
    outfile = os.path.join(scandir, f"{stem}.txt")
    moods_scan(fastafile, pwmfiles, outfile, pvalue=P)

In [None]:
# create the reports from scanning results
# snps occur at position 31; used below as the exclusive upper bound on hit
# start positions (assumes the scanner reports 0-based starts - TODO confirm)
SNPPOS = 31

def recover_motifname(motif: str) -> str:
    """Return ``motif`` with every ``.pwm`` occurrence removed."""
    return motif.replace(".pwm", "")

def filter_motif_position(scanres: pd.DataFrame, pfms: Dict[str, Tuple]) -> pd.DataFrame:
    """Keep only the hits whose start position puts the motif over the SNP.

    A row is kept when ``SNPPOS - motiflen <= start < SNPPOS``.

    Args:
        scanres: scan results; the motif name is read from the column at
            index 1 and the start position from the column at index 2.
        pfms: motif name -> tuple whose element 2 is the motif length.

    Returns:
        The rows of ``scanres`` passing the filter, original index kept.

    Raises:
        KeyError: if a motif in ``scanres`` is missing from ``pfms``.
    """
    # columns are taken by position explicitly with iloc: integer keys on a
    # row Series (x[1], x[2]) are deprecated positional lookups in pandas
    motifs = scanres.iloc[:, 1]
    starts = scanres.iloc[:, 2]
    motiflens = motifs.map(lambda name: pfms[name][2])
    # vectorised bounds check instead of one list membership test per row
    overlaps = (starts >= SNPPOS - motiflens) & (starts < SNPPOS)
    return scanres[overlaps]

def relative_score(score: float, minscore: float, maxscore: float) -> float:
    """Rescale ``score`` so that ``minscore`` maps to 0 and ``maxscore`` to 1."""
    offset = score - minscore
    return offset / (maxscore - minscore)

def construct_report(
    report: str, pfms: Dict[str, Tuple], pwms: Dict[str, Tuple]
) -> pd.DataFrame:
    """Load a scan result file and add a relative score to every hit.

    The file is read as headerless CSV; column 6 is dropped and the rest are
    named SNPID, MOTIF, POS, STRAND, SCORE, SEQUENCE. Motif names are cleaned
    with ``recover_motifname`` and rows are filtered with
    ``filter_motif_position``.

    Args:
        report: path to the scan result file.
        pfms: motif name -> tuple, forwarded to ``filter_motif_position``.
        pwms: motif name -> tuple whose elements 1 and 2 are the minimum and
            maximum scores used to rescale SCORE.

    Returns:
        The filtered table with an extra RELATIVESCORE column.
    """
    scanres = pd.read_csv(report, header=None).drop([6], axis=1)
    scanres.columns = ["SNPID", "MOTIF", "POS", "STRAND", "SCORE", "SEQUENCE"]
    # columns are addressed by name: integer keys on a row Series are
    # deprecated positional lookups in recent pandas
    scanres["MOTIF"] = scanres["MOTIF"].map(recover_motifname)
    # filter motif hits not overlapping the SNP; copy so the new column is
    # added to an independent frame rather than to a slice
    scanres = filter_motif_position(scanres, pfms).copy()
    # compute relative scores (a plain list also works for an empty table)
    scanres["RELATIVESCORE"] = [
        relative_score(score, pwms[motif][1], pwms[motif][2])
        for motif, score in zip(scanres["MOTIF"], scanres["SCORE"])
    ]
    return scanres

# build one report per allele from the scan results
scandir = os.path.join(OUTDIR, "moods_scan_results")
scanreport_ref, scanreport_alt = (
    construct_report(os.path.join(scandir, fname), pfms, pwms)
    for fname in ("ref.txt", "alt.txt")
)
scanreport_ref.head(), scanreport_alt.head()

In [None]:
# merge the two reports and compute the disruption score
def disruption_score(score_ref: float, score_alt: float) -> float:
    """Return the alternative score minus the reference score."""
    delta = score_alt - score_ref
    return delta

def label_disruption(disruption: float) -> str:
    """Name the sign of a disruption score.

    Returns ``"increase"`` for positive values, ``"decrease"`` for negative
    values and ``"equal"`` otherwise.
    """
    if disruption > 0:
        return "increase"
    if disruption < 0:
        return "decrease"
    return "equal"

# suffixes name the overlapping columns at merge time (SCORE_REF,
# SEQUENCE_REF, RELATIVESCORE_REF, SCORE_ALT, ...) instead of renaming
# every column by position afterwards
scanreport = scanreport_ref.merge(
    scanreport_alt,
    on=["SNPID", "MOTIF", "POS", "STRAND"],
    suffixes=("_REF", "_ALT"),
)
# compute disruption score on whole columns, addressed by name: integer keys
# on a row Series (x[6], x[9], x[-1]) are deprecated positional lookups
scanreport["DISRUPTION"] = disruption_score(
    scanreport["RELATIVESCORE_REF"], scanreport["RELATIVESCORE_ALT"]
)
# label the SNP effect on TF binding according to the disruption score
scanreport["LABEL"] = scanreport["DISRUPTION"].map(label_disruption)
scanreport.head()

In [None]:
# check whether motif matches are hits in the reference, alternative or both sequences
def recover_hits(hits: str) -> List[str]:
    """Load a scan result file and return one key per SNP-overlapping hit.

    The file is read as headerless CSV; column 6 is dropped and the rest are
    named SNPID, MOTIF, POS, STRAND, SCORE, SEQUENCE. Rows are filtered with
    ``filter_motif_position`` using the module-level ``pfms``.

    Args:
        hits: path to the scan result file.

    Returns:
        List of ``"<SNPID>_<MOTIF>_<POS>_<STRAND>"`` keys, one per kept row;
        an empty list when no row is kept.
    """
    table = pd.read_csv(hits, header=None).drop([6], axis=1)
    table.columns = ["SNPID", "MOTIF", "POS", "STRAND", "SCORE", "SEQUENCE"]
    # columns are addressed by name: integer keys on a row Series are
    # deprecated positional lookups in recent pandas
    table["MOTIF"] = table["MOTIF"].map(recover_motifname)
    # filter motif hits not overlapping the SNP
    table = filter_motif_position(table, pfms)
    # a comprehension returns [] for an empty table, whereas
    # list(DataFrame.apply(...)) would return the column names
    return [
        f"{snpid}_{motif}_{pos}_{strand}"
        for snpid, motif, pos, strand in zip(
            table["SNPID"], table["MOTIF"], table["POS"], table["STRAND"]
        )
    ]

def merge_hits(hits_ref: List[str], hits_alt: List[str]) -> Dict[str, str]:
    """Classify every hit key by the sequence(s) in which it occurs.

    Returns:
        Dict mapping each key to ``"R"`` (reference only), ``"A"``
        (alternative only) or ``"B"`` (both).
    """
    ref_keys, alt_keys = set(hits_ref), set(hits_alt)
    hits = dict.fromkeys(ref_keys - alt_keys, "R")
    hits.update(dict.fromkeys(alt_keys - ref_keys, "A"))
    hits.update(dict.fromkeys(ref_keys & alt_keys, "B"))
    return hits

def assign_hit(snpid: str, motif: str, pos: int, strand: str, hits: Dict[str, str]) -> str:
    """Look up the hit class of one match.

    The key is ``"<snpid>_<motif>_<pos>_<strand>"``; ``"N"`` is returned when
    the key is absent from ``hits``.
    """
    key = f"{snpid}_{motif}_{pos}_{strand}"
    return hits.get(key, "N")

hits = merge_hits(
    recover_hits(os.path.join(OUTDIR, "moods_scan_results_hits/ref.txt")),
    recover_hits(os.path.join(OUTDIR, "moods_scan_results_hits/alt.txt"))
)
# walk the named columns in lockstep: integer keys on a row Series
# (x[0] ... x[3]) are deprecated positional lookups in recent pandas
scanreport["HIT"] = [
    assign_hit(snpid, motif, pos, strand, hits)
    for snpid, motif, pos, strand in zip(
        scanreport["SNPID"],
        scanreport["MOTIF"],
        scanreport["POS"],
        scanreport["STRAND"],
    )
]
# sanity check: the number of rows classified R/A/B must equal the number of hit keys
assert scanreport["HIT"].isin(["R", "A", "B"]).sum() == len(hits)
scanreport

In [None]:
# interpolate FPKM data in the report
def recover_fpkm(counts_file: str) -> Dict[str, float]:
    """Read a whitespace-separated quantification table.

    The first line (header) is skipped. For every other non-blank line,
    field 0 is mapped to field 6 parsed as ``float``.
    # presumably field 0 is the gene symbol and field 6 the FPKM value -
    # confirm against the file layout

    Args:
        counts_file: path to the table.

    Returns:
        Dict mapping field 0 to the float value of field 6; when a key
        repeats, the last line wins.
    """
    fpkm = {}
    with open(counts_file, mode="r") as infile:
        infile.readline()  # skip header
        for line in infile:
            fields = line.split()
            if not fields:
                continue  # a blank (e.g. trailing) line carries no record
            fpkm[fields[0]] = float(fields[6])
    return fpkm

def interpolate_fpkm(motifname: str, fpkm: Dict[str, float]) -> float:
    """Look up the expression value for the TF named in ``motifname``.

    The lookup key is the text after the last underscore of ``motifname``,
    upper-cased.

    Args:
        motifname: motif label; only the part after the last ``_`` is used.
        fpkm: key -> expression value.

    Returns:
        The stored value, or 0 when the key is absent or the value is missing.
    """
    gene = motifname.split("_")[-1].upper()
    value = fpkm.get(gene, 0)
    # explicit missing-value test instead of comparing str(value) to "nan"
    return 0 if pd.isna(value) else value

fpkm = recover_fpkm(os.path.join(DATADIR, "HepG2_rnaseq/ENCFF103FSL_genesymbols.tsv"))
# map over the named MOTIF column: integer keys on a row Series (x[1]) are
# deprecated positional lookups in recent pandas
scanreport["FPKM"] = scanreport["MOTIF"].map(
    lambda motif: interpolate_fpkm(pfms[motif][0], fpkm)
)
scanreport["FPKM_ZSCORE"] = zscore(scanreport["FPKM"])
scanreport.head()

In [None]:
# plot SNP impact on TF binding
def select_snp_data(report: pd.DataFrame, snpid: str) -> pd.DataFrame:
    """Return the rows of ``report`` whose SNPID equals ``snpid``.

    All matching rows are returned; no per-motif reduction is applied.
    """
    belongs_to_snp = report.SNPID == snpid
    return report[belongs_to_snp]

def plot_tf_disruption(
    snpreport: pd.DataFrame, 
    pfms: Dict[str, Tuple],
    outfile: str,
    binding_threshold: Optional[float] = 0.9, 
    snpid: Optional[str] = None,
) -> None:
    """Scatter-plot the effect of one SNP on TF binding and save it as PNG.

    Every row of ``snpreport`` is drawn at (DISRUPTION, RELATIVESCORE_REF),
    coloured by FPKM, with a marker chosen from its HIT class. Rows of class
    R, A or B whose reference or alternative relative score exceeds
    ``binding_threshold`` are labelled with the text after the last
    underscore of the motif's name in ``pfms``.

    Args:
        snpreport: report rows for one SNP; must provide the columns SNPID,
            MOTIF, DISRUPTION, RELATIVESCORE_REF, RELATIVESCORE_ALT, HIT
            and FPKM.
        pfms: motif ID -> tuple whose element 0 is the motif name.
        outfile: path of the PNG file to write.
        binding_threshold: relative-score cut-off for labelling a hit.
        snpid: SNP name shown in the title; defaults to the first SNPID
            value of ``snpreport`` rather than to a notebook global.

    NOTE(review): the figure is not closed here; confirm whether it should
    stay open for inline display before adding ``plt.close``.
    """
    f, ax = plt.subplots(
        1, 1, figsize=(10, 8), dpi=300, facecolor="w", edgecolor="w"
    )
    norm = plt.Normalize(vmin=0, vmax=max(snpreport["FPKM"]))
    if snpid is None:
        snpid = snpreport["SNPID"].iloc[0]
    snpreport = snpreport.sort_values(["FPKM"], ascending=True)
    markers = {"R": "v", "A": "^", "B": "*", "N": "s"}
    # non-hits ("N") are drawn first and faded; the other classes go on top
    for hl in ["N", "R", "A", "B"]:
        snpreport_hit = snpreport[snpreport.HIT == hl]
        alpha = 0.25 if hl == "N" else 1
        s = ax.scatter(
            snpreport_hit["DISRUPTION"],
            snpreport_hit["RELATIVESCORE_REF"],
            s=50,  # marker size
            c=snpreport_hit["FPKM"],
            cmap=plt.get_cmap("coolwarm"),
            marker=markers[hl],  # one marker shape per hit class
            alpha=alpha,
            norm=norm,  # same colour scale for all four scatters
        )
        if hl in ["R", "A", "B"]:
            labelled = snpreport_hit[
                (snpreport_hit.RELATIVESCORE_REF > binding_threshold) |
                (snpreport_hit.RELATIVESCORE_ALT > binding_threshold)
            ]
            # fields are read by name: integer keys on a row Series
            # (x[10], x[6], x[1]) are deprecated positional lookups
            for _, hit in labelled.iterrows():
                ax.text(
                    hit["DISRUPTION"] + 0.05,
                    hit["RELATIVESCORE_REF"],
                    pfms[hit["MOTIF"]][0].split("_")[-1],
                    fontsize=10,
                )
    ax.set_xlabel("Disruption", size=14)
    ax.set_xlim([-1.1, 1.1])
    ax.set_ylabel("Binding affinity", size=14)
    ax.set_ylim([-0.1, 1.1])
    ax.set_title(f"TFs affected by SNP {snpid}", size=16)
    ax.grid(True)
    ax.axhline(.5, ls="--", color="gray")
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.5)
    cax.set_title("Expression FPKM", size=14)
    # the colour bar is built from the last scatter; all share `norm`
    f.colorbar(s, cax=cax, orientation="vertical")
    plt.savefig(outfile, format="PNG", dpi=300)

figuresdir = os.path.join(OUTDIR, "figures")
# makedirs also creates OUTDIR; os.mkdir fails when the parent is missing
os.makedirs(figuresdir, exist_ok=True)
# one figure per SNP
for snpid in snpids:
    snpreport = select_snp_data(scanreport.copy(), snpid)
    plot_tf_disruption(
        snpreport, pfms, os.path.join(figuresdir, f"TF_binding_impact_{snpid}.png")
    )