In [18]:
import os
import glob
import pandas as pd
import numpy as np
import collections as col

"""
What does this do?
Processes the output of "process_bng_hybrid.py" after aligning scaffolds
back to reference assembly. This notebook computes the approximate scaffold
placement and assigns the confidence score for the scaffolded contig alignments.

This process is a prerequisite for computing the "coverage mask" (accessibility)
based on the scaffolded contig alignments (as opposed to unscaffolded contig
alignments, after MAPQ thresholding).

Prints stats about pct. of high/medium confidence scaffolded contig alignments.
"""

scaffolds_path = '/home/local/work/data/hgsvc/figSX_panels/bng_scaffolds/v12'
scfalign_path = '/home/local/work/data/hgsvc/figSX_panels/bng_scfaling/v12'
cache_file = '/home/local/work/data/hgsvc/figSX_panels/bng_scfaling/cache.h5'

print_stats = True

align_columns = [
    'chrom',
    'start',
    'end',
    'name',
    'mapq',
    'orientation'
]

def load_scaffolds(assembly):
    
    pattern = os.path.join(scaffolds_path, '{}*.tsv'.format(assembly))
    scf_file = glob.glob(pattern)
    assert len(scf_file) == 1, 'Mulit scaffold match'
    df = pd.read_csv(scf_file[0], sep='\t', header=0, index_col=None)
    return df


def extract_contig_length(name_entry):
    
    ctg_info = name_entry.split('@')[-1]
    positions = ctg_info.split(':')[-1]
    start, end = positions.split('-')
    return int(end) - int(start)


def load_alignments(align_file):
    
    align_path = os.path.join(scfalign_path, align_file)
    df = pd.read_csv(align_path, sep='\t', names=align_columns, header=None)
    df = df.loc[df['name'] != 'dummy', :].copy()
    if df.empty:
        return df

    df['scaffold'] = df['name'].apply(lambda x: x.split('@')[0])
    df['order_number'] = df['name'].apply(lambda x: int(x.split('@')[2]))
    df['aligned_length'] = (df['end'] - df['start']).astype('int32')
    df['midpoint'] = df['start'] + df['aligned_length'] // 2
    df['contig_length'] = df['name'].apply(extract_contig_length)
    df.sort_values(['chrom', 'start', 'end'], axis=0, ascending=True, inplace=True)
    return df


def compute_scaffold_midpoint_distances(scaffold, scaffold_layouts):
    
    select_scaffold = scaffold_layouts['object'] == scaffold
    select_sequence = scaffold_layouts['component'] == 'sequence'
    # scaffold length includes also assembly gaps
    scaffold_length = int(scaffold_layouts.loc[scaffold_layouts['name'] == scaffold, 'length'])
    
    this_scf = scaffold_layouts.loc[select_scaffold & select_sequence, :].copy()
    this_scf['midpoint'] = this_scf['start'] + this_scf['length'] // 2
    weighted_midpoint = int(np.average(this_scf['midpoint'], weights=this_scf['length']))
    this_scf['midpoint_dist'] = this_scf['midpoint'] - weighted_midpoint
    
    sequence_lengths = dict((row['name'], row['length']) for idx, row in this_scf.iterrows())
    sequence_lengths['scaffold'] = scaffold_length
    sequence_lengths['scaffold_no_gaps'] = sum([v for k, v in sequence_lengths.items() if k != 'scaffold'])

    return this_scf, weighted_midpoint, sequence_lengths


def compute_covered_sequence(alignments, midpoint, left_margin, right_margin):
    
    lower_bound = midpoint - left_margin
    upper_bound = midpoint + right_margin
    
    select_left = alignments['end'] > lower_bound
    select_right = alignments['start'] < upper_bound
    
    total_overlap = 0
    indices = []
    for idx, row in alignments.loc[select_left & select_right, :].iterrows():
        indices.append(idx)
        overlap = min(upper_bound, row['end']) - max(lower_bound, row['start'])
        assert overlap > 0, 'Something is wrong'
        total_overlap += overlap
    if total_overlap == 0:
        raise ValueError('No alignments selected: {} to {}\n\n{}'.format(lower_bound, upper_bound, alignments))
    
    return total_overlap, indices


def compute_alignment_midpoint_distances(contig_alignments):
    
    try:
        aln_wt_midpoint = int(np.average(
            contig_alignments['midpoint'],
            weights=contig_alignments['aligned_length']
        ))
    except ZeroDivisionError:
        print(contig_alignments)
        raise
    
    aln_midpoints = col.defaultdict(dict)
    for idx, rows in contig_alignments.groupby('name'):
        wt_midpoint = int(np.average(
            rows['midpoint'],
            weights=rows['aligned_length']
        ))
        # Super-Scaffold_363@chr17@1@frw@cluster12_scaffold_95@ctg:0-524872
        contig_name = idx.split('@')[-2]
        aln_midpoints[contig_name]['midpoint'] = wt_midpoint
        aln_midpoints[contig_name]['midpoint_dist'] = wt_midpoint - aln_wt_midpoint

    return aln_midpoints


def check_overlaps(contig_alignments, aln_wt_midpoint, left_margin, right_margin):

    # since we don't know the relative orientation of the scaffold,
    # check which selection covers more aligned sequence

    is_reverse = False
    
    forward_overlaps, forward_indices = compute_covered_sequence(
        contig_alignments,
        aln_wt_midpoint,
        left_margin,
        right_margin
    )
    
    reverse_overlaps, reverse_indices = compute_covered_sequence(
        contig_alignments,
        aln_wt_midpoint,
        right_margin,
        left_margin
    )

    if forward_overlaps == 0 and reverse_overlaps == 0:
        raise ValueError('No overlaps')

    if forward_overlaps > reverse_overlaps:
        selected_alignments = forward_indices
    elif forward_overlaps < forward_overlaps:
        selected_alignments = reverse_indices
        is_reverse = True
    else:
        # seems not to matter
        selected_alignments = forward_indices
    return selected_alignments, is_reverse


def exclude_isolated_blocks(contig_alignments):
    
    # compute all vs all midpoint distances
    avg_distances = []
    for idx, row in contig_alignments.iterrows():
        
        not_self = contig_alignments.index != idx
        avg_dist = (contig_alignments.loc[not_self, 'midpoint'] - row['midpoint']).abs().sum() / not_self.sum()
        avg_distances.append((avg_dist, idx))
        
    avg_distances = sorted(avg_distances, reverse=True)
    exclude_indices = set()
    for pos, (d, i) in enumerate(avg_distances):
        try:
            if d > avg_distances[pos+1][0]:
                exclude_indices.add(i)
                break
        except IndexError:
            break
        exclude_indices.add(i)

    scattered = contig_alignments.index.isin(exclude_indices)
        
    aln_midpoint = int(np.average(
        contig_alignments.loc[~scattered, 'midpoint'],
        weights=contig_alignments.loc[~scattered, 'aligned_length']
    ))

    return aln_midpoint


def select_alignments_for_placement(
    contig_alignments,
    seq_lens,
    scf_left_margin, 
    scf_right_margin, 
    is_single_split):
    
    # compute weighted midpoint for all alignments
    aln_midpoint = int(np.average(
        contig_alignments['midpoint'],
        weights=contig_alignments['aligned_length']
    ))

    # allow some slack for potential variants
    left_margin = int(scf_left_margin + scf_left_margin * 0.05)
    right_margin = int(scf_right_margin + scf_right_margin * 0.05)
        
    try:
        selected_alignments, is_reverse = check_overlaps(
            contig_alignments,
            aln_midpoint,
            left_margin,
            right_margin
        )
    except ValueError:
        if is_single_split:
            # reset midpoint to best aligned block
            best_alignment = contig_alignments['aligned_length'] * contig_alignments['mapq']
            best_index = best_alignment.index[best_alignment == best_alignment.max()]
            aln_midpoint = int(contig_alignments.loc[best_index, 'start']) + int(contig_alignments.loc[best_index, 'aligned_length'] // 2)
            selected_alignments, is_reverse = check_overlaps(
                contig_alignments,
                aln_midpoint,
                left_margin,
                right_margin
            )
        elif contig_alignments.shape[0] < 3:
            # not enough aligned contigs to make an attempt to save this
            # scaffold
            raise
        else:
            # make one attempt to drop the one alignment (or multiple if they all
            # have the same midpoint; may happen for low-scoring fragments) that
            # has largest average distance to all other alignments;
            # if alignments are too scattered,
            # this won't change anything and still result in low confidence
            aln_midpoint = exclude_isolated_blocks(contig_alignments)
            selected_alignments, is_reverse = check_overlaps(
                contig_alignments,
                aln_midpoint,
                left_margin,
                right_margin
            )
    
    # compute alignment midpoint distances
    aln_midpoint_dists = compute_alignment_midpoint_distances(
        contig_alignments.loc[selected_alignments, :].copy()
    )
    
    # set the boundaries of the scaffold to the maximum extension
    min_start = contig_alignments.loc[selected_alignments, 'start'].min()
    max_end = contig_alignments.loc[selected_alignments, 'end'].max()
    
    aligned_length = contig_alignments.loc[selected_alignments, 'aligned_length'].sum()
    # check that the contig is not placed based on a marginal amount of sequence
    # unlikely to happen, but better safe than sorry...
    scf_length_nogap = seq_lens['scaffold_no_gaps']
    if aligned_length < scf_length_nogap * 0.5:
        raise ValueError('Insufficient aligned length for placing scaffold')
    
    if is_reverse:
        scf_left_margin, scf_right_margin = scf_right_margin, scf_left_margin
    
    scf_midpoint = aln_midpoint
    
    scf_left_margin = min(min_start, scf_midpoint - scf_left_margin)
    scf_right_margin = max(max_end, scf_midpoint - scf_right_margin)
    
    return selected_alignments, aln_midpoint_dists, scf_midpoint, scf_left_margin, scf_right_margin
 

def compute_displacement_stats(scaffold, scaffold_length, aln_midpoint_dists):
    # since the scaffold placement does not take structural variation into
    # account, omit this for now - can be quite off
    raise RuntimeError('Do not use this funtion')
    contig_displacements = []
    for idx, contig in scaffold.iterrows():
        ctg = contig['name']
        if ctg not in aln_midpoint_dists:
            continue
        delta = abs(abs(contig['midpoint_dist']) - abs(aln_midpoint_dists[ctg]['midpoint_dist']))
        contig_displacements.append(delta)

    contig_displacements = np.array(contig_displacements, dtype=np.int32)
    total_displacement = contig_displacements.sum()
    rel_displacement = (total_displacement / scf_length * 100).round(2) 
    return total_displacement, rel_displacement


def determine_scattered_or_single_split(alignments, select_scf_aln):

    # if scattered over multiple chromosomes, assign all low confidence
    is_scattered = alignments.loc[select_scf_aln, 'chrom'].nunique() > 1
    if is_scattered:
        alignments.loc[select_scf_aln, 'loc_confidence'] = 'low'
        alignments.loc[select_scf_aln, 'scaffold_strand'] = '.'

    # if single contig scaffold, opt for assigning high confidence
    is_single_contig = alignments.loc[select_scf_aln, 'chrom'].shape[0] == 1
    if is_single_contig:
        alignments.loc[select_scf_aln, 'loc_confidence'] = 'high'
        orientation = alignments.loc[select_scf_aln, 'orientation'].values[0]
        alignments.loc[select_scf_aln, 'scaffold_strand'] = orientation

    # single split contig
    # example case: HG00096 / H1 / Super-Scaffold_101491
    # cen or peri-cen, alignment scattered, HET region on chr1q12
    num_contigs = alignments.loc[select_scf_aln, 'name'].nunique()
    num_chroms = alignments.loc[select_scf_aln, 'chrom'].nunique()
    is_single_split_contig = num_contigs == 1 and num_chroms == 1

    return is_scattered, is_single_contig, is_single_split_contig
    
    
def determine_scaffold_placement_confidence(alignments, scaffolds):
    
    alignments['loc_confidence'] = 'undetermined'
    alignments['scaffold_strand'] = 'undetermined'
    alignments['scaffold_start'] = -1
    alignments['scaffold_end'] = -1
        
    for scf in set(alignments['scaffold'].values):
        
        select_scf_aln = alignments['scaffold'] == scf
        
        is_scattered, is_single, is_single_split_contig = determine_scattered_or_single_split(
            alignments,
            select_scf_aln
        )
        if is_scattered:
            # side effect: alignments set to low in call above
            continue
        if is_single and not is_single_split_contig:
            # side effect: alignment set to high in call above
            continue
        
        # complicated part: heuristic to determine approximate placement
        # of scaffold relative to hg38
        scaffold, scf_midpoint, seq_lengths = compute_scaffold_midpoint_distances(scf, scaffolds)
        
        # left margin: 0 ... MID
        left_margin = scf_midpoint - 1
        # right margin: MID ... END
        right_margin = seq_lengths['scaffold'] - scf_midpoint
        
        contig_aln = alignments.loc[select_scf_aln, :].copy()
        try:
            selected_aln, aln_midpoint_dists, scf_midpoint, left_margin, right_margin = select_alignments_for_placement(
                contig_aln,
                seq_lengths,
                left_margin,
                right_margin,
                is_single_split_contig
            )
        except ValueError:
            # impossible to find good placement, set everything to low confidence
            # example: Super-Scaffold_101620 / HG00096 H1 scattered on chr4
            alignments.loc[select_scf_aln, 'loc_confidence'] = 'low'
            alignments.loc[select_scf_aln, 'scaffold_strand'] = '.'
            continue
        
        aligned_order = contig_aln.loc[selected_aln, 'order_number'].drop_duplicates(keep='first', inplace=False)
                
        # note that this implies that all contigs
        # have been aligned/used, although not necessarily
        # to 100% of their size
        if np.array_equal(aligned_order.values, scaffold['order'].values):
            # scaffold orientation -> forward
            alignments.loc[selected_aln, 'loc_confidence'] = 'high'
            alignments.loc[selected_aln, 'scaffold_strand'] = '+'
        elif np.array_equal(aligned_order.values, np.flip(scaffold['order'].values)):
            # scaffold orientation -> reverse
            alignments.loc[selected_aln, 'loc_confidence'] = 'high'
            alignments.loc[selected_aln, 'scaffold_strand'] = '-'
        else:
            majority_orientation = contig_aln.loc[selected_aln, 'orientation'].value_counts().sort_values(inplace=False, ascending=False)
            majority_orientation = majority_orientation.index[0]
            alignments.loc[selected_aln, 'loc_confidence'] = 'medium'
            alignments.loc[selected_aln, 'scaffold_strand'] = majority_orientation
        
        alignments.loc[selected_aln, 'scaffold_start'] = scf_midpoint - left_margin
        alignments.loc[selected_aln, 'scaffold_end'] = scf_midpoint + right_margin
        
        # in case contigs have not been used for scaffold placement
        unused_contigs = contig_aln.index[~contig_aln.index.isin(selected_aln)]
        if not unused_contigs.empty:
            alignments.loc[unused_contigs, 'loc_confidence'] = 'low'
            alignments.loc[unused_contigs, 'scaffold_strand'] = '.'

    if (alignments['loc_confidence'] == 'undetermined').any():
        raise RuntimeError('Undetermined alignment confidence')
            
    return alignments



if not os.path.isfile(cache_file):

    collector = col.defaultdict(list)

    for scfalign in sorted(os.listdir(scfalign_path)):
        if not scfalign.endswith('.bed'):
            continue
        assembly = scfalign.split('_map-to_')[0]
        scaffolds = load_scaffolds(assembly)
        alignments = load_alignments(scfalign)
        if alignments.empty:
            continue
        alignments = determine_scaffold_placement_confidence(alignments, scaffolds)

        collector[assembly].append(alignments)

    with pd.HDFStore(cache_file, 'w', complevel=5) as hdf:
        for k, v in collector.items():
            merged = pd.concat(v, axis=0, ignore_index=False)
            merged.sort_values(['chrom', 'start', 'end'], axis=0, inplace=True)
            hdf.put(k.replace('.', '_').replace('-', ''), merged, format='fixed')


out_path = '/home/local/work/data/hgsvc/figSX_panels/bng_scfaling/processed'
            
aligned_length = col.defaultdict(col.Counter)
score_map = {
    'high': 1000,
    'medium': 500,
    'low': 0
}
with pd.HDFStore(cache_file, 'r') as hdf:
    for k in hdf.keys():
        sample = k.split('_')[0].strip('/')
        if 'h1un' in k:
            hap = 'H1'
        else:
            hap = 'H2'
        if 'pbsq2ccs' in k:
            tech = 'HiFi'
        else:
            tech = 'CLR'
        df = hdf[k]
        df['score'] = df['loc_confidence'].apply(lambda x: score_map[x])
        bed = df[['chrom', 'start', 'end', 'name', 'score', 'orientation']].copy()
        bed.sort_values(['chrom', 'start', 'end'], axis=0, inplace=True)
        if 'h1un' in k:
            dump_file = k.strip('/').replace('_h1un_', '.h1-un.').replace('pbsq2', 'pbsq2-')
        else:
            dump_file = k.strip('/').replace('_h2un_', '.h2-un.').replace('pbsq2', 'pbsq2-')
        dump_file += '_map-to_GRCh38.scf-conf.bed'
        dump_path = os.path.join(out_path, dump_file)

        with open(dump_path, 'w') as dump:
            _ = dump.write('#')
            bed.to_csv(dump, header=True, index=False, sep='\t')
        
        if print_stats:
            grouped = df.groupby('loc_confidence')['aligned_length'].sum()
            for k, v in grouped.items():
                aligned_length[(sample, tech, hap)][k] += v
    
collector = col.defaultdict(list)
total_aligned = col.defaultdict(list)
# compute per sample stats
for (s, t, h), counts in aligned_length.items():
    total = sum(counts.values())
    total_aligned[t].append(total)
    total_aligned['any'].append(total)
    for loc_conf, count in counts.items():
        rel = round(count / total * 100, 2)
        collector[(t, loc_conf)].append(rel)
        collector[('any', loc_conf)].append(rel)

if print_stats:
    for (tech, loc_conf), stats in collector.items():
        print(tech, loc_conf, ' ', np.median(stats).round(2))
    for tech, total in total_aligned.items():
        print(tech, ' ', np.median(total).round(2))


CLR high   65.69
any high   69.0
CLR low   0.62
any low   0.39
CLR medium   31.74
any medium   28.48
HiFi high   79.03
HiFi low   0.3
HiFi medium   20.37
CLR   2778884550.5
any   2785558664.0
HiFi   2790407215.5
