In [31]:
%matplotlib inline

import os
import pickle
import matplotlib.pyplot as plt
import numpy as np

# Directories that are scanned (non-recursively) for pickled statistics files.
input_paths = [
    '/MMCI/TM/scratch/pebert/diploid_assembly/output/fastq_validation',
    '/MMCI/TM/scratch/pebert/diploid_assembly/side_tracks/read_stats'
]

# Full paths of all "*.pck" files found directly under the input directories,
# in the order the directories are listed above.
input_files = [
    os.path.join(folder, entry)
    for folder in input_paths
    for entry in os.listdir(folder)
    if entry.endswith('.pck')
]

# Root folder below which the figures are written.
output_root = '/MMCI/TM/scratch/pebert/diploid_assembly/output/plots/fastq_stats'

def plot_gc_content(title, binned_values):
    """Draw a bar chart of read counts per G+C content bin.

    Parameters
    ----------
    title : str
        Text used as the figure title.
    binned_values : dict
        Maps a G+C bin to the number of reads in that bin. Keys are
        converted to int8 and counts to int32, so both must fit these types
        (keys are presumably percent values - TODO confirm against the
        script producing the statistics).

    Returns
    -------
    tuple
        ``(fig, [title_artist])``; the list is meant to be handed to
        ``savefig(..., extra_artists=...)``.
    """
    fig, ax = plt.subplots(figsize=(6, 4))

    # Sorted bin positions and the read count belonging to each of them.
    gc_bins = np.sort(np.array(list(binned_values), dtype=np.int8))
    reads_per_bin = np.array(
        [binned_values[gc_bin] for gc_bin in gc_bins], dtype=np.int32
    )
    n_reads = reads_per_bin.sum()

    ax.bar(gc_bins, reads_per_bin, width=1)
    ax.set_xlabel('G+C content (% ; binned)', fontsize=12)
    ax.set_ylabel('Read count', fontsize=12)

    # Summary annotation, placed in axes-relative coordinates.
    ax.text(
        0.7, 0.8,
        s='Total reads:\n{}'.format(n_reads),
        transform=ax.transAxes,
        fontdict={'fontsize': 12},
    )

    # The title is moved above its default position.
    title_artist = ax.set_title(title, fontsize=12)
    title_artist.set_position([0.5, 1.2])

    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

    return fig, [title_artist]
    

def plot_read_length_distribution(title, read_lengths):
    """Draw a bar chart of sequenced bases per read-length bin.

    Reads are grouped into the half-open length bins [0, 1k), [1k, 5k),
    [5k, 10k), ... , [95k, 100k) and a final "UL" bin covering
    [100k, 10M). Reads of 10 Mbp or longer fall into no bin and are not
    counted. A text box next to the axes reports total and "UL" read
    counts, base pairs (Gbp) and the base pairs divided by ``gsize``.

    Parameters
    ----------
    title : str
        Text used as the figure title.
    read_lengths : dict
        Maps a read length (integer, bp) to the number of reads having
        exactly that length.

    Returns
    -------
    tuple
        ``(fig, [text_artist, title_artist])``; the list is meant to be
        handed to ``savefig(..., extra_artists=...)`` because the text box
        is placed outside of the axes area.
    """
    # Divisor for the "x" (fold coverage) values in the text box;
    # presumably an approximate human genome size - TODO confirm.
    gsize = 3 * 1e9

    # Use int64 everywhere: length * count products easily exceed the int32
    # range, and int32 scalar arithmetic would wrap around.
    sorted_lengths = sorted(read_lengths.keys())
    length_values = np.array(sorted_lengths, dtype=np.int64)
    read_counts = np.array(
        [read_lengths[length] for length in sorted_lengths], dtype=np.int64
    )
    base_counts = read_counts * length_values

    # Bin boundaries: lower is inclusive, higher is exclusive.
    lower = [0, 1000] + list(range(5000, 105000, 5000))
    higher = lower[1:] + [10000000]

    base_count_bins = []
    read_count_bins = []
    for low, high in zip(lower, higher):
        # Comparisons already yield boolean arrays (np.bool no longer exists).
        selector = (length_values >= low) & (length_values < high)
        base_count_bins.append(int(base_counts[selector].sum()))
        read_count_bins.append(int(read_counts[selector].sum()))

    fig, ax = plt.subplots(figsize=(4, 3))
    x_pos = list(range(len(lower)))
    # Each bar is labelled with its upper boundary in kbp; only every third
    # label is shown to avoid clutter, and the last bin is labelled "UL".
    x_labels = list(map(str, [x // 1000 for x in lower[1:]]))
    x_labels = [x if i % 3 == 0 else ' ' for i, x in enumerate(x_labels)]
    x_labels.append('UL')

    ax.bar(x_pos, base_count_bins, width=0.8, align='center')
    ax.set_xlabel('Read length < X kbp', fontsize=12)
    ax.set_xticks(x_pos)
    ax.set_xticklabels(x_labels, fontsize=10)
    ax.set_ylabel('Base count', fontsize=12)

    total_read_count = sum(read_count_bins)
    total_base_pair = sum(base_count_bins)

    # The last bin holds the reads of at least 100 kbp.
    ultralong_read_count = read_count_bins[-1]
    ultralong_base_pair = base_count_bins[-1]

    text = 'Reads: {}\n'.format(total_read_count)
    text += ' ~ {} Gbp\n'.format(round(total_base_pair / 1e9, 1))
    text += ' ~ {}x\n'.format(round(total_base_pair / gsize, 2))
    text += 'UL: {}\n'.format(ultralong_read_count)
    text += ' ~ {} Gbp\n'.format(round(ultralong_base_pair / 1e9, 1))
    text += ' ~ {}x'.format(round(ultralong_base_pair / gsize, 2))

    # x=1 in axes coordinates puts the text box right of the plot area.
    txt = ax.text(1, 0.4, s=text, transform=ax.transAxes, fontdict={'fontsize': 12})
    tt = ax.set_title(title, fontsize=12)
    tt.set_position([0.5, 1.2])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    return fig, [txt, tt]
        

def plot_sequenced_bases(title, base_count):
    """Draw a bar chart with one bar per nucleotide (A, C, G, T).

    Parameters
    ----------
    title : str
        Text used as the figure title.
    base_count : mapping
        Must provide a value for each of the keys 'A', 'C', 'G' and 'T';
        the value is used as bar height. Other keys are ignored.

    Returns
    -------
    tuple
        ``(fig, [title_artist])``; the list is meant to be handed to
        ``savefig(..., extra_artists=...)``.
    """
    fig, ax = plt.subplots(figsize=(4, 3))

    nucleotides = ['A', 'C', 'G', 'T']
    bar_heights = [base_count[nuc] for nuc in nucleotides]
    bar_positions = list(range(4))

    ax.bar(bar_positions, bar_heights, width=0.8, align='center')
    ax.set_xticks(bar_positions)
    ax.set_xticklabels(nucleotides, fontsize=10)
    ax.set_xlabel('Nucleotide', fontsize=12)
    ax.set_ylabel('# sequenced', fontsize=12)

    # The title is moved above its default position.
    title_artist = ax.set_title(title, fontsize=12)
    title_artist.set_position([0.5, 1.2])

    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

    return fig, [title_artist]


# File type tag (taken from the statistics file name) -> label used as the
# first line of the plot title.
title_filetype = {
    'cmp': 'Raw',
    'tag-h1': 'Hap1',
    'tag-h2': 'Hap2',
    'tag-un': 'Untag',
    'tag-h1-un': 'Hap1-Un',
    'tag-h2-un': 'Hap2-Un',
}

# Sample identifier -> human-readable label used as the second line of the
# plot title; identifiers without an entry are shown unchanged.
prettify_sample = {
    'HG00514.sequel2.pb-ccs': 'HG.514 PB-Sq2 CCS',
    'HG00514.sequel2.pb-clr': 'HG.514 PB-Sq2 CLR',
    'HG00514.sequel2.pb-clr-25': 'HG.514 PB-Sq2 CLR (0.25)',
}

# Switches selecting which figures are created per input file.
plot_gc_cont = True
plot_read_length = True
plot_seq_bases = True

# If False, figures are created and closed again without being written.
save_figures = True

# Create the selected figures for every HG00514 statistics file.
for infile in input_files:
    filename = os.path.basename(infile)
    if 'HG00514' not in filename:
        continue
    if 'EDEVI' in infile:
        if not ('subreads' in filename or filename.endswith('.ccs.pck')):
            # No sample/title can be derived for this file. Skip it instead
            # of silently reusing the values of the previous iteration
            # (or failing with a NameError on the very first file).
            print('Skipping file with unrecognized name: {}'.format(infile))
            continue
        # File type: last (up to) three "_"-separated parts of the file
        # name without its extension, with dots replaced by underscores.
        filetype = filename.rsplit('.', 1)[0].rsplit('_', 2)[-3:]
        filetype = '_'.join(filetype)
        filetype = filetype.replace('.', '_')
        plot_filetype = 'HG.514 PB-Sq2 subreads'
        sample = filename.split('.')[2]
    else:
        # File type tag is the third dot-separated field from the end;
        # a tag missing from title_filetype raises a KeyError.
        filetype = infile.rsplit('.', 3)[-3]
        plot_filetype = title_filetype[filetype]
        # Sample identifier: first three dot-separated fields of the name.
        sample = '.'.join(filename.split('.')[:3])

    plot_sample = prettify_sample.get(sample, sample)
    plot_title = '{}\n{}'.format(plot_filetype, plot_sample)

    # NOTE: pickle can execute arbitrary code while loading; only read
    # statistics files from trusted locations.
    with open(infile, 'rb') as dump_file:
        dump = pickle.load(dump_file)

    # Common prefix of all figure files created for this input file.
    outpath = os.path.join(output_root, sample, '{}_{}_'.format(sample, filetype.lower()))
    os.makedirs(os.path.dirname(outpath), exist_ok=True)

    if plot_gc_cont:
        fig, exart = plot_gc_content(plot_title, dump['gc_bins'])
        if save_figures:
            fig.savefig(outpath + 'gc-cont.png', dpi=300, bbox_inches='tight', extra_artists=exart)
        # Close explicitly so that figures do not pile up in memory.
        plt.close(fig)

    if plot_read_length:
        fig, exart = plot_read_length_distribution(plot_title, dump['len_stats'])
        if save_figures:
            fig.savefig(outpath + 'read-len.png', dpi=300, bbox_inches='tight', extra_artists=exart)
        plt.close(fig)

    if plot_seq_bases:
        fig, exart = plot_sequenced_bases(plot_title, dump['nuc_stats'])
        if save_figures:
            fig.savefig(outpath + 'seq-bases.png', dpi=300, bbox_inches='tight', extra_artists=exart)
        plt.close(fig)