In [None]:
import matplotlib, os
os.environ["QT_QPA_PLATFORM"] = "offscreen"
matplotlib.use("agg")

import collections, sys
from matplotlib import pyplot as plt
from matplotlib import gridspec as gs

In [None]:
# ---- Run configuration ----
# Project root; input BED files are read from WORKSPACE/data and figures are
# written under WORKSPACE/figures.
WORKSPACE = "/data/gordanlab/samkim/dna-repair-tf"
RUN_ID = "proximal-DHS_mergedbg"  # run ID
WHICH_DATA = "small_dhs"  # data group name
# Suffix to filename (currently only for merged and merged-bg runs)
SUFFIX = "pro"

In [None]:
### Get mutation counts per position, and per position per TF ###
def count_mutations(lines):
    """Tally mutation records by distance, overall and per TF, in one pass.

    Each non-blank line must have exactly five whitespace-separated fields:
    the 2nd is parsed as an integer distance and the 5th is the TF name (the
    4th, the mutation, is not used for counting; the 1st and 3rd are ignored).
    Blank lines, such as a trailing newline at the end of a file, are skipped.

    Returns:
        (counts, counts_by_tf): `counts` maps distance -> number of records;
        `counts_by_tf` maps TF -> {distance -> number of records}. The
        distance-keyed mappings are `collections.defaultdict(int)`.

    Raises:
        ValueError: a non-blank line does not have five fields or its
        distance field is not an integer; the message names the line.
    """
    counts = collections.defaultdict(int)
    counts_by_tf = {}
    for line_no, line in enumerate(lines, start=1):
        fields = line.split()
        if not fields:
            continue  # tolerate blank lines instead of failing the unpack
        try:
            _, dist, _, _mut, tf = fields
            dist = int(dist)
        except ValueError as err:
            raise ValueError(
                "line {}: expected 5 fields with an integer distance, got {!r}".format(
                    line_no, line
                )
            ) from err
        counts[dist] += 1
        tf_counts = counts_by_tf.get(tf)
        if tf_counts is None:
            tf_counts = counts_by_tf[tf] = collections.defaultdict(int)
        tf_counts[dist] += 1
    return counts, counts_by_tf


def get_dists(mut_dataset_name=None, suffix=None):
    """Read one centered mutation BED file and count mutations per position.

    The file read is
    ``{WORKSPACE}/data/ssm.open.{RUN_ID}_{mut_dataset_name}[_{suffix}]_centered.bed``
    (the ``_{suffix}`` part is omitted when `suffix` is None).

    Returns:
        (counts, counts_by_tf) as produced by `count_mutations`, or a pair of
        empty dicts when `mut_dataset_name` is None (no file is read).
    """
    if mut_dataset_name is None:
        return {}, {}

    path = "{}/data/ssm.open.{}_{}{}_centered.bed".format(
        WORKSPACE, RUN_ID, mut_dataset_name, "" if suffix is None else "_" + suffix
    )
    with open(path) as f:
        return count_mutations(f)

In [None]:
### Create mutation profile plots ###
def _plot_counts(ax, counts):
    """Plot one mutation profile on `ax`: count per distance, in distance order."""
    dists = sorted(int(dist) for dist in counts)
    ax.plot(dists, [counts[dist] for dist in dists])


def plot_dists(
    counts,
    counts_by_tf,
    counts_2=None,
    counts_2_by_tf=None,
    name=None,
    figsize=(10, 14),
    h1=0.2,
    h2=0.6,
    w2=0.5,
    save_fig=False,
):
    """Draw mutation profiles for all TFs combined and for the top TFs.

    The top panel plots `counts` (distance -> number of mutations) over
    -1000..1000 bp. Below it, a 5 x 6 grid shows one panel per TF from
    `counts_by_tf` (TF -> {distance -> count}), ordered by total mutation
    count and limited to the 30 largest. When both `counts_2` and
    `counts_2_by_tf` are given, they are drawn as a second line on the top
    panel and on every TF panel whose TF is in `counts_2_by_tf` (in this
    notebook: first dataset = default or bound DHS, second = unbound DHS).

    Args:
        name: label used in the top title and in the saved file name.
        figsize: matplotlib figure size (width, height) in inches.
        h1: vertical spacing between the top panel and the TF grid.
        h2, w2: vertical / horizontal spacing inside the TF grid.
        save_fig: if True, save a PNG under ``{WORKSPACE}/figures/temp/``;
            otherwise call ``plt.show()``.
    """
    n_rows, n_cols = 5, 6  # layout of the per-TF grid
    has_second = counts_2 is not None and counts_2_by_tf is not None

    fig = plt.figure(figsize=figsize)
    fig.subplots_adjust(hspace=h1)
    outer = gs.GridSpec(2, 1, height_ratios=[2, 5])
    gs1 = gs.GridSpecFromSubplotSpec(1, 6, subplot_spec=outer[0])
    gs2 = gs.GridSpecFromSubplotSpec(
        n_rows, n_cols, subplot_spec=outer[1], hspace=h2, wspace=w2
    )

    # Large graph of all TFs, spanning the middle 4 of 6 columns
    ax1 = fig.add_subplot(gs1[:, 1:-1])
    _plot_counts(ax1, counts)  # default, or bound DHS
    if has_second:
        _plot_counts(ax1, counts_2)  # unbound DHS
    ax1.set_xlim(-1000, 1000)
    ax1.set_ylim(0, None)
    ax1.set_xlabel("Distance from TFBS center (bp)")
    ax1.set_ylabel("Number of mutations")
    if name is None:
        ax1.set_title("Mutation profile for all TFs")
    else:
        ax1.set_title("{} mutation profile for all TFs".format(name))

    # Small graphs of the TFs with the most mutations, filled row by row
    ordered_counts = sorted(
        counts_by_tf.items(), key=lambda t: sum(t[1].values()), reverse=True
    )
    shown = ordered_counts[: n_rows * n_cols]
    for idx, (tf, tf_counts) in enumerate(shown):
        row, col = divmod(idx, n_cols)
        ax2 = fig.add_subplot(gs2[row, col])
        _plot_counts(ax2, tf_counts)  # default, or bound DHS
        if has_second and tf in counts_2_by_tf:
            _plot_counts(ax2, counts_2_by_tf[tf])  # unbound DHS
        ax2.set_xlim(-1000, 1000)
        ax2.set_ylim(0)
        ax2.set_title(tf)
        # Keep x tick labels only on the lowest filled panel of each column
        # (the grid may be partly empty) and y tick labels only on the first
        # column. Computed from the loop position rather than with
        # Axes.is_last_row()/is_first_col(), which matplotlib 3.6 removed.
        if idx + n_cols < len(shown):
            plt.setp(ax2.get_xticklabels(), visible=False)
        if col != 0:
            plt.setp(ax2.get_yticklabels(), visible=False)

    # Save figure
    if save_fig:
        # Abbreviations are applied to the file name only, so WORKSPACE is
        # never rewritten by them.
        stem = (
            "{}_{}{}".format(RUN_ID, name, SUFFIX)
            .replace("proximal", "prox")
            .replace("distal", "dist")
            .replace("DHS_", "DHS-")
        )
        fig.savefig(
            "{}/figures/temp/{}.png".format(WORKSPACE, stem),
            dpi="figure",
            transparent=True,
            bbox_inches="tight",
            pad_inches=0,
        )
    else:
        plt.show()

In [None]:
### Actually run the datasets ###
# Resolve WHICH_DATA to the list of cancer types to plot: a named group is
# looked up in the table below; any other value is used as a single cancer type.
all_names = {
    # All cancers
    "all": ["BLCA", "BRCA", "COAD", "COCA", "HNSC", "LUAD", "LUSC", "MELA", "READ", "SKCA", "SKCM"],
    # Only cancers with DHS data
    "dhs": ["BRCA", "COAD", "COCA", "LUAD", "LUSC", "MELA", "READ", "SKCA", "SKCM"],
    # Only skin cancers
    "skcm": ["MELA", "SKCA", "SKCM"],
    # Only small (<1 GB) cancers
    "small": ["BLCA", "COAD", "HNSC", "LUAD", "READ"],
    # Only small (<1 GB) cancers with DHS data
    "small_dhs": ["COAD", "LUAD", "READ"],
}.get(WHICH_DATA, [WHICH_DATA])

In [None]:
# Per-cancer results, keyed by cancer name, filled in by make_next_plot
all_counts = {}
all_counts_by_tf = {}
all_counts_2 = {}
all_counts_2_by_tf = {}
name_index = 0

def make_next_plot(name_index):
    """Load and plot the cancer type at position `name_index` of `all_names`.

    Indices at or past the end of `all_names` do nothing. The counts are
    stored in the module-level all_counts* dicts under the cancer name before
    plotting. For runs whose RUN_ID ends in "mergedbg", both the "_bound" and
    "_unbound" files are read; otherwise the default file is read and the
    second dataset is empty (get_dists(None, ...) returns empty dicts).
    """
    if name_index >= len(all_names):
        return

    cancer = all_names[name_index]
    merged_bg = RUN_ID.endswith("mergedbg")

    all_counts[cancer], all_counts_by_tf[cancer] = get_dists(
        cancer, SUFFIX + "_bound" if merged_bg else None
    )
    all_counts_2[cancer], all_counts_2_by_tf[cancer] = get_dists(
        cancer if merged_bg else None, SUFFIX + "_unbound"
    )
    plot_dists(
        all_counts[cancer],
        all_counts_by_tf[cancer],
        all_counts_2[cancer],
        all_counts_2_by_tf[cancer],
        cancer,
    )

In [None]:
%matplotlib inline

make_next_plot(name_index)
name_index += 1

In [None]:
%matplotlib inline

make_next_plot(name_index)
name_index += 1

In [None]:
%matplotlib inline

make_next_plot(name_index)
name_index += 1

In [None]:
%matplotlib inline

make_next_plot(name_index)
name_index += 1

In [None]:
%matplotlib inline

make_next_plot(name_index)
name_index += 1

In [None]:
%matplotlib inline

make_next_plot(name_index)
name_index += 1

In [None]:
%matplotlib inline

make_next_plot(name_index)
name_index += 1

In [None]:
%matplotlib inline

make_next_plot(name_index)
name_index += 1

In [None]:
%matplotlib inline

make_next_plot(name_index)
name_index += 1

In [None]:
%matplotlib inline

make_next_plot(name_index)
name_index += 1

In [None]:
%matplotlib inline

make_next_plot(name_index)
name_index += 1

In [None]:
%matplotlib inline

make_next_plot(name_index)
name_index += 1