In [None]:
import collections
from matplotlib import pyplot as plt
from matplotlib import gridspec as gs

In [None]:
# Root directory of the project workspace; get_dists() reads its input BED
# files from "<WORKSPACE>/data/".
WORKSPACE = "/data/gordanlab/samkim/dna-repair-tf"

In [None]:
# all-mut-profile.sh is run on the cluster instead of from this notebook; the call is kept below for reference only.
# import subprocess
# subprocess.call(["{}/all-mut-profile.sh".format(WORKSPACE)])

In [None]:
def get_dists(mut_dataset_name, filter_by="", dhs=True):
    """Tally mutations by distance for one dataset, overall and per TF.

    Reads
    ``<WORKSPACE>/data/ssm.open.[<filter_by>_][no]DHS_<mut_dataset_name>_centered.bed``.
    Every non-blank line must contain exactly five whitespace-separated
    fields. The second field is parsed as an integer distance and the fifth
    is used as the TF name; the first, third and fourth fields are ignored.
    Blank lines (for example a trailing empty line) are skipped.

    Args:
        mut_dataset_name: Dataset label inserted into the file name.
        filter_by: Optional prefix inserted, followed by ``_``, in front of
            ``DHS`` / ``noDHS`` in the file name. Empty string means no prefix.
        dhs: If true the ``DHS`` file is read, otherwise the ``noDHS`` file.

    Returns:
        A ``(counts, counts_by_tf)`` tuple. ``counts`` is a
        ``defaultdict(int)`` mapping distance -> number of lines with that
        distance. ``counts_by_tf`` is a plain dict mapping TF name -> a
        ``defaultdict(int)`` with the same layout, restricted to that TF.

    Raises:
        OSError: If the input file cannot be opened.
        ValueError: If a non-blank line does not have exactly five fields or
            its distance field is not an integer.
    """
    path = "{}/data/ssm.open.{}{}DHS_{}_centered.bed".format(
        WORKSPACE, filter_by + ("_" if filter_by else ""), ("" if dhs else "no"), mut_dataset_name)

    counts = collections.defaultdict(int)
    counts_by_tf = {}
    with open(path) as f:
        # Stream the file once; there is no need to hold every record in memory.
        for line in f:
            fields = line.split()
            if not fields:
                # Blank line: nothing to count, and unpacking it would raise.
                continue
            _, dist, _, _, tf = fields
            dist = int(dist)

            counts[dist] += 1
            if tf not in counts_by_tf:
                counts_by_tf[tf] = collections.defaultdict(int)  # initialize
            counts_by_tf[tf][dist] += 1

    return counts, counts_by_tf

In [None]:
def plot_dists_all_tfs(counts, mut_dataset_name=None):
    """Draw a single line plot of mutation count against distance.

    Args:
        counts: Mapping from distance to mutation count. Keys are converted
            with ``int()`` and plotted in ascending order.
        mut_dataset_name: Optional dataset label; when given it is prefixed
            to the plot title.

    Prints "List is empty" and returns without plotting when ``counts`` has
    no keys.
    """
    distances = [int(key) for key in counts.keys()]
    distances.sort()

    # Nothing to draw: report it and bail out before touching pyplot.
    if not distances:
        print("List is empty")
        return

    totals = [counts[position] for position in distances]

    plt.plot(distances, totals)
    plt.xlim(-1000, 1000)
    plt.ylim(0, None)
    plt.xlabel("Distance from TFBS center (bp)")
    plt.ylabel("Number of mutations")
    plt.title(
        "Mutation profile for all TFs"
        if mut_dataset_name is None
        else "{} mutation profile for all TFs".format(mut_dataset_name)
    )
    plt.show()

In [None]:
def plot_dists_by_tf(counts_by_tf, mut_dataset_name=None, num_rows=5, num_cols=6):
    """Plot per-TF mutation profiles as a grid of small line plots.

    TFs are ordered by their total mutation count, largest first, and the
    first ``num_rows * num_cols`` of them are drawn, filling the grid row by
    row. Grid cells left over when there are fewer TFs stay empty.

    Args:
        counts_by_tf: Mapping from TF name to a mapping of distance ->
            mutation count.
        mut_dataset_name: Accepted for signature parity with the other
            plotting helpers; not used by this function.
        num_rows: Number of grid rows.
        num_cols: Number of grid columns.
    """
    ordered_counts = sorted(counts_by_tf.items(), key=lambda t: sum(t[1].values()), reverse=True)

    # squeeze=False keeps `axs` two-dimensional even for a single row or
    # column; with the default, a 1xN / Nx1 grid yields a 1-D array (or a bare
    # Axes for 1x1) and 2-D indexing would fail.
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(num_cols*1.5, num_rows*1.5),
                            sharex='col', sharey='row', squeeze=False)

    # axs.flat walks the grid row by row; zip stops as soon as either the TFs
    # or the grid cells run out.
    for ax, (tf, tf_counts) in zip(axs.flat, ordered_counts):
        X = sorted(int(dist) for dist in tf_counts.keys())
        y = [tf_counts[dist] for dist in X]

        ax.plot(X, y)
        ax.set_xlim(-1000, 1000)
        ax.set_ylim(0)
        ax.set_title(tf)

    fig.tight_layout()
    plt.show()

In [None]:
def plot_dists(counts, counts_by_tf, mut_dataset_name=None, figsize=(10, 14), h1=0.2, h2=0.6, w2=0.5,
               num_rows=5, num_cols=6):
    """Draw the overall mutation profile above a grid of per-TF profiles.

    The figure is split vertically with height ratio 2:5. The top part holds
    one line plot of ``counts`` placed in the middle four of six columns. The
    bottom part holds a ``num_rows`` x ``num_cols`` grid with one small line
    plot per TF, ordered by total mutation count (largest first) and filled
    row by row; TFs that do not fit in the grid are not drawn.

    Args:
        counts: Mapping from distance to mutation count over all TFs.
        counts_by_tf: Mapping from TF name to a mapping of distance ->
            mutation count.
        mut_dataset_name: Optional dataset label prefixed to the top title.
        figsize: Figure size passed to ``plt.figure``.
        h1: ``hspace`` passed to ``plt.subplots_adjust``.
        h2: ``hspace`` of the per-TF grid.
        w2: ``wspace`` of the per-TF grid.
        num_rows: Number of rows in the per-TF grid.
        num_cols: Number of columns in the per-TF grid.
    """
    plt.figure(figsize=figsize)
    plt.subplots_adjust(hspace=h1)
    outer = gs.GridSpec(2, 1, height_ratios=[2, 5])
    gs1 = gs.GridSpecFromSubplotSpec(1, 6, subplot_spec=outer[0])
    gs2 = gs.GridSpecFromSubplotSpec(num_rows, num_cols, subplot_spec=outer[1], hspace=h2, wspace=w2)

    X1 = sorted([int(dist) for dist in counts.keys()])
    y1 = [counts[dist] for dist in X1]

    ax1 = plt.subplot(gs1[:, 1:-1])
    ax1.plot(X1, y1)
    ax1.set_xlim(-1000, 1000)
    ax1.set_ylim(0, None)
    ax1.set_xlabel("Distance from TFBS center (bp)")
    ax1.set_ylabel("Number of mutations")
    if mut_dataset_name is None:
        ax1.set_title("Mutation profile for all TFs")
    else:
        ax1.set_title("{} mutation profile for all TFs".format(mut_dataset_name))

    ordered_counts = sorted(counts_by_tf.items(), key=lambda t: sum(t[1].values()), reverse=True)
    last_row = num_rows - 1
    # Only as many TFs as there are grid cells are drawn. The loop variable is
    # named tf_counts so it does not shadow the `counts` parameter.
    for index, (tf, tf_counts) in enumerate(ordered_counts[:num_rows * num_cols]):
        row, col = divmod(index, num_cols)
        X2 = sorted([int(dist) for dist in tf_counts.keys()])
        y2 = [tf_counts[dist] for dist in X2]

        ax2 = plt.subplot(gs2[row, col])
        ax2.plot(X2, y2)
        ax2.set_xlim(-1000, 1000)
        ax2.set_ylim(0)
        ax2.set_title(tf)

        # Show tick labels only along the bottom row and left column of the
        # grid. The position is taken from the loop itself rather than from
        # Axes.is_last_row() / Axes.is_first_col(), which are not available in
        # recent Matplotlib releases.
        if row != last_row:
            plt.setp(ax2.get_xticklabels(), visible=False)
        if col != 0:
            plt.setp(ax2.get_yticklabels(), visible=False)

    plt.show()

In [None]:
def make_plots(all_counts, all_counts_by_tf, all_names, name):
    """Plot the combined mutation-profile figure for one dataset.

    Args:
        all_counts: Mapping from dataset name to its overall distance counts.
        all_counts_by_tf: Mapping from dataset name to its per-TF counts.
        all_names: List of dataset names, used to resolve numeric selectors.
        name: Either a dataset name, or an int / numeric string that is
            treated as an index into ``all_names``.

    When a numeric selector is not below ``len(all_names)`` the function
    prints "There are no more datasets." and returns without plotting.
    """
    selected_by_index = isinstance(name, int) or name.isnumeric()
    if selected_by_index:
        index = int(name)
        if index >= len(all_names):
            print("There are no more datasets.")
            return
        name = all_names[index]

    # The combined figure replaces separate calls to plot_dists_all_tfs and
    # plot_dists_by_tf.
    plot_dists(all_counts[name], all_counts_by_tf[name], name)

In [None]:
# Dataset labels to load; each one is passed to get_dists() to pick its input file.
all_names = [
    "BLCA", "BRCA", "COAD", "COCA", "HNSC",
    "LUAD", "LUSC", "MELA", "SKCA", "SKCM",
]

# Results of get_dists() for every dataset, keyed by dataset label.
all_counts, all_counts_by_tf = {}, {}
for name in all_names:
    all_counts[name], all_counts_by_tf[name] = get_dists(name, filter_by="NC2")

# Cursor into all_names; the plotting cells below read it and then advance it.
name_index = 0

In [None]:
%matplotlib inline

# Plot the dataset at position `name_index` in `all_names`, then advance the
# cursor so the next cell shows the following dataset.
make_plots(all_counts, all_counts_by_tf, all_names, name_index)
name_index += 1

In [None]:
%matplotlib inline

# Plot the dataset at position `name_index` in `all_names`, then advance the
# cursor so the next cell shows the following dataset.
make_plots(all_counts, all_counts_by_tf, all_names, name_index)
name_index += 1

In [None]:
%matplotlib inline

# Plot the dataset at position `name_index` in `all_names`, then advance the
# cursor so the next cell shows the following dataset.
make_plots(all_counts, all_counts_by_tf, all_names, name_index)
name_index += 1

In [None]:
%matplotlib inline

# Plot the dataset at position `name_index` in `all_names`, then advance the
# cursor so the next cell shows the following dataset.
make_plots(all_counts, all_counts_by_tf, all_names, name_index)
name_index += 1

In [None]:
%matplotlib inline

# Plot the dataset at position `name_index` in `all_names`, then advance the
# cursor so the next cell shows the following dataset.
make_plots(all_counts, all_counts_by_tf, all_names, name_index)
name_index += 1

In [None]:
%matplotlib inline

# Plot the dataset at position `name_index` in `all_names`, then advance the
# cursor so the next cell shows the following dataset.
make_plots(all_counts, all_counts_by_tf, all_names, name_index)
name_index += 1

In [None]:
%matplotlib inline

# Plot the dataset at position `name_index` in `all_names`, then advance the
# cursor so the next cell shows the following dataset.
make_plots(all_counts, all_counts_by_tf, all_names, name_index)
name_index += 1

In [None]:
%matplotlib inline

# Plot the "MELA" dataset, selected by name rather than by cursor. "MELA" is
# all_names[7], which is also what `name_index` points at when the cells are
# run in order. The cursor is still advanced to stay in step.
make_plots(all_counts, all_counts_by_tf, all_names, "MELA")
name_index += 1

In [None]:
%matplotlib inline

# Plot the "SKCA" dataset, selected by name rather than by cursor. "SKCA" is
# all_names[8], which is also what `name_index` points at when the cells are
# run in order. The cursor is still advanced to stay in step.
make_plots(all_counts, all_counts_by_tf, all_names, "SKCA")
name_index += 1

In [None]:
%matplotlib inline

# Plot the dataset at position `name_index` in `all_names`, then advance the
# cursor. Once the cursor has passed the end of `all_names`, make_plots only
# prints a notice instead of plotting.
make_plots(all_counts, all_counts_by_tf, all_names, name_index)
name_index += 1