# P(s) for separate chromosomal arms

In [None]:
import pprint
import os
import sys
import collections

import bioframe
import click
import cooler
import cooltools
import cooltools.expected
import matplotlib.pyplot as plt
import matplotlib.gridspec
from matplotlib.lines import Line2D
import numpy as np
from numpy.lib.function_base import average
from pairlib.scalings import norm_scaling
import pandas as pd
import pairlib
import pairlib.scalings
import pairtools
from diskcache import Cache
from itertools import combinations

from pandas.io.pytables import IndexCol

In [None]:
# Notebook parameters. NOTE(review): looks like a papermill-style parameters
# cell whose values are overridden at run time — confirm before restructuring.
# Path to the pairs file passed to pairlib.scalings.compute_scaling.
pairs_path = ''
# Output image path passed to plot_scalings (saved via savefig).
out_path = ''
# NOTE(review): `labels` is reassigned to [] in a later cell and rebuilt from
# chromosome/arm names, so this value is not used for the final plot.
labels = ['938_cdc20td_cdc20']
# Assembly name used for bioframe chromsizes / centromere lookups.
assembly = 'sacCer3'
# Chromosomes removed from chromsizes and from a custom centromere file.
exclude_chroms = ['chrXII', 'chrM']
# Optional centromere file (chrom, start, end per line); when empty the
# centromeres are fetched with bioframe.fetch_centromeres(assembly).
centromeres_path = ''
# Compute the average trans contact level and draw it as dashed lines.
show_average_trans = True
# Figure title.
title = 'cdc20-td: separate arms'
# Normalize each curve with pairlib's norm_scaling_factor (anchor 1 kb).
normalized = True
# Add a lower panel showing the local log-log slope of each curve.
plot_slope = True
# NOTE(review): not referenced in any visible cell of this notebook.
no_cache = True

In [None]:
chromsizes = bioframe.fetch_chromsizes(assembly, filter_chroms=False, as_bed=True)
chromsizes = chromsizes[~chromsizes.chrom.isin(exclude_chroms)]

if centromeres_path:
    # Custom centromere file: whitespace-delimited "chrom start end" lines
    # (space- or tab-separated; extra columns are ignored). Each centromere is
    # reduced to the midpoint of its interval, which is used as the split point.
    centromeres = {}
    with open(centromeres_path) as file:
        for line in file:
            cols = line.split()
            # skip blank lines and '#' comment/header lines instead of crashing
            if not cols or cols[0].startswith('#'):
                continue
            if cols[0] not in exclude_chroms:
                centromeres[cols[0]] = (int(cols[1]) + int(cols[2])) // 2
else:
    centromeres = bioframe.fetch_centromeres(assembly)
    centromeres.set_index('chrom', inplace=True)
    centromeres = centromeres.mid.to_dict()

# Split every chromosome at its centromere so each arm becomes its own region.
arms = bioframe.split(chromsizes, centromeres)
# Trim centromere- and telomere-proximal sequence. bioframe's expand applies
# the pad to *each* side of an interval, so a negative pad of 80 kb removes
# 80 kb from both ends of every arm (160 kb per arm in total).
arm_trim = int(8e4)
arms = bioframe.ops.expand(arms, -arm_trim)
# drop arms shorter than 2 * arm_trim, which are now empty or inverted
regions = arms[arms.start < arms.end].reset_index()

scalings = {}
avg_trans_levels = {}

regions

In [None]:
# Re-split the (untrimmed) chromosomes at their centromeres and take, per
# chromosome, the first piece as the left arm and the second as the right arm.
arms = bioframe.split(chromsizes, centromeres).groupby('chrom')
left, right = (arms.nth(arm_idx).reset_index() for arm_idx in (0, 1))

In [None]:
# Per-region cis pair counts by distance bin, plus trans contact levels, read
# from the pairs file: 128 distance bins spanning 10 bp .. 1 Gb, processed in
# chunks of 10 million pairs.
cis_scalings, trans_levels = pairlib.scalings.compute_scaling(
    pairs_path,
    regions,
    chromsizes,
    dist_range=(10, 1_000_000_000),
    n_dist_bins=128,
    chunksize=10_000_000,
)

In [None]:
# Remove unassigned pairs, whose start/end positions are negative. A
# coordinate of exactly 0 is a valid region boundary (chromosome start) and
# must be kept, hence ">= 0" rather than "> 0".
_coord_cols = ['start1', 'end1', 'start2', 'end2']
cis_scalings = cis_scalings[(cis_scalings[_coord_cols] >= 0).all(axis=1)]

In [None]:
# Sum pair counts and n_bp2 over everything sharing a region (chrom1, start1)
# and distance bin, giving one row per region and distance bin.
sc_agg = (
    cis_scalings
    .groupby(['chrom1', 'start1', 'min_dist', 'max_dist'])
    .agg({'n_pairs': 'sum', 'n_bp2': 'sum'})
    .reset_index()
)

# Chromosomes that have data, in genome (chromsizes) order. A plain set would
# iterate in hash order, which varies between runs and makes the curve order,
# colours and labels of the final plot non-reproducible.
_present_chroms = set(sc_agg.chrom1)
avail_chroms = [chrom for chrom in chromsizes.chrom if chrom in _present_chroms]
# any chromosome not listed in chromsizes still gets plotted, deterministically
avail_chroms += sorted(_present_chroms - set(avail_chroms))

chrom_scalings = {chrom: {'left': None, 'right': None} for chrom in avail_chroms}

all_scalings = []
all_avg_trans_levels = []
labels = []

In [None]:
def calc_pair_freqs(scalings, trans_levels, calc_avg_trans, normalized, anchor=int(1e3)):
    """
    Convert aggregated pair counts into contact frequencies per distance bin.

    Parameters
    ----------
    scalings : pandas.DataFrame
        One row per distance bin with columns ``min_dist``, ``max_dist``,
        ``n_pairs`` and ``n_bp2``.
    trans_levels : pandas.DataFrame
        Columns ``n_pairs`` and ``n_bp2``; only read when `calc_avg_trans`.
    calc_avg_trans : bool
        If True, also compute the average trans level as
        sum(n_pairs) / sum(n_bp2) over all rows of `trans_levels`.
    normalized : bool
        If True, divide the frequencies (and the average trans level, when it
        was computed) by ``pairlib.scalings.norm_scaling_factor`` at `anchor`.
    anchor : int, optional
        Distance (bp) passed as the normalization anchor; defaults to 1 kb.

    Returns
    -------
    ((dist_bin_mids, pair_frequencies), avg_trans)
        Geometric bin midpoints and frequencies, restricted to bins whose
        (un-normalized) frequency is positive, and the average trans level
        (None when `calc_avg_trans` is False).
    """
    dist_bin_mids = np.sqrt(scalings.min_dist * scalings.max_dist)
    pair_frequencies = scalings.n_pairs / scalings.n_bp2
    # keep only bins with a positive frequency: drops zeros and NaN from 0/0
    mask = pair_frequencies > 0

    avg_trans = None
    if calc_avg_trans:
        # float64 sums avoid integer overflow on large pair counts
        avg_trans = (
            trans_levels.n_pairs.astype('float64').sum()
            / trans_levels.n_bp2.astype('float64').sum()
        )

    if normalized:
        norm_fact = pairlib.scalings.norm_scaling_factor(
            dist_bin_mids, pair_frequencies, anchor=anchor
        )
        pair_frequencies = pair_frequencies / norm_fact
        # the trans level is only present when requested; None cannot be divided
        if avg_trans is not None:
            avg_trans = avg_trans / norm_fact

    return (dist_bin_mids[mask], pair_frequencies[mask]), avg_trans

In [None]:
def plot_scalings(scalings, avg_trans_levels, plot_slope, labels, title, out_path):
    """
    Plot scaling curves from a list of (bin, pair frequencies) tuples.

    Parameters
    ----------
    scalings : list of (dist_bin_mids, pair_frequencies)
        One curve per entry; both elements are equal-length array-likes.
    avg_trans_levels : list of (float or None), or None
        Per-curve average trans level, drawn as a dashed horizontal line in the
        curve's colour. None entries (or a None list) and non-finite values
        are skipped.
    plot_slope : bool
        If True, add a lower panel with the local log-log slope of each curve.
    labels : list of str
        Legend label for each curve (indexed in step with `scalings`).
    title : str
        Figure title.
    out_path : str
        File the figure is saved to at 300 dpi; nothing is saved if empty.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.
    """
    fig = plt.figure(figsize=(6, 10))
    gs = matplotlib.gridspec.GridSpec(2, 1, height_ratios=[3, 1.5])
    scale_ax = fig.add_subplot(gs[0, 0])
    slope_ax = fig.add_subplot(gs[1, 0]) if plot_slope else None

    drew_trans = False
    for idx, (dist_bin_mids, pair_frequencies) in enumerate(scalings):
        # accept pandas Series or plain arrays alike
        dist_bin_mids = np.asarray(dist_bin_mids, dtype=float)
        pair_frequencies = np.asarray(pair_frequencies, dtype=float)

        line, = scale_ax.loglog(
            dist_bin_mids,
            pair_frequencies,
            label=labels[idx],
            lw=1,
            alpha=0.5,
        )
        color = line.get_color()

        trans_level = avg_trans_levels[idx] if avg_trans_levels is not None else None
        if trans_level is not None and np.isfinite(trans_level):
            scale_ax.axhline(trans_level, ls='dashed', c=color, lw=1, alpha=0.5)
            drew_trans = True

        if slope_ax is not None:
            # slope between consecutive bins, placed at their geometric mean
            slope_ax.semilogx(
                np.sqrt(dist_bin_mids[1:] * dist_bin_mids[:-1]),
                np.diff(np.log10(pair_frequencies)) / np.diff(np.log10(dist_bin_mids)),
                label=labels[idx],
                c=color,
                lw=1,
                alpha=0.5,
            )

    scale_ax.grid(lw=0.5, color='gray')
    scale_ax.set_aspect(1.0)
    scale_ax.set_xlim(1e3, 1e6)
    scale_ax.set_xlabel('genomic separation (bp)')
    scale_ax.set_ylabel('contact frequency')

    handles, legend_labels = scale_ax.get_legend_handles_labels()
    # only advertise the dashed trans line if at least one was actually drawn
    if drew_trans:
        handles.append(Line2D([0], [0], color='black', lw=1, ls='dashed'))
        legend_labels.append('average trans')
    scale_ax.legend(handles, legend_labels, loc=(1.025, 0.5), frameon=False)

    if slope_ax is not None:
        slope_ax.grid(lw=0.5, color='gray')
        slope_ax.set_xlim(1e3, 1e6)
        slope_ax.set_ylim(-3.0, 0.0)
        slope_ax.set_aspect(1.0)
        slope_ax.set_xlabel('distance (bp)')
        slope_ax.set_ylabel('log-log slope')

    fig.suptitle(title)
    fig.tight_layout()
    fig.subplots_adjust(top=0.95)

    # savefig('') raises, and the notebook's default out_path is empty
    if out_path:
        fig.savefig(out_path, dpi=300)
    return fig

In [None]:
all_scalings = []
all_avg_trans_levels = []
# reset together with the curve lists so re-running this cell keeps labels aligned
labels = []

# Silence divide/invalid warnings (empty bins, log of zero) only for this work
# instead of changing numpy's global error state.
with np.errstate(divide='ignore', invalid='ignore'):
    for chrom in avail_chroms:
        sc_chrom = sc_agg[sc_agg.chrom1 == chrom]

        # Regions were produced by splitting each chromosome at
        # centromeres[chrom], so classify them by which side of the centromere
        # they start on instead of matching a hard-coded trimmed start coordinate.
        cen = centromeres.get(chrom)
        if cen is None:
            # chromosome was not split into arms
            arm_parts = [(chrom, sc_chrom)]
        else:
            is_left = sc_chrom.start1 < cen
            arm_parts = [
                (f'{chrom}: left', sc_chrom[is_left]),
                (f'{chrom}: right', sc_chrom[~is_left]),
            ]

        for arm_label, arm_sc in arm_parts:
            if arm_sc.empty:
                # arm has no data, e.g. it was removed as too short when the
                # regions were trimmed; skip it rather than plot an empty curve
                continue
            sc, avg_trans = calc_pair_freqs(arm_sc, trans_levels, show_average_trans, normalized)
            all_scalings.append(sc)
            all_avg_trans_levels.append(avg_trans)
            labels.append(arm_label)

    plot_scalings(all_scalings, all_avg_trans_levels, plot_slope, labels, title, out_path)