# STRAT - Short Tandem Repeat Analysis Tool

## 3. Collect statistics on on-target reads - gaps

### 3.1 Imports

In [1]:
from csv import QUOTE_NONE
import numpy as np
import pandas as pd
from PIL import Image, ImageDraw
from string2string.alignment import NeedlemanWunsch

### 3.2 Arguments

In [2]:
# Repeat unit in forward-strand orientation; the reverse-strand form is
# derived from it (see MOTIFS). Several functions below index into it by
# position, so keep it upper-case A/C/G/T.
motif = 'CAG'

# pcr2persons guppy
# input_path = '/opt/data/pcr2persons/output/guppy/guppy.ontarget.tsv'
# output_path = '/opt/data/pcr2persons/output/guppy/'

# pcr2persons dorado
# input_path = '/opt/data/pcr2persons/output/dorado/dorado.ontarget.tsv'
# output_path = '/opt/data/pcr2persons/output/dorado/'

# jovan guppy
# input_path = '/opt/data/jovan/output/guppy/guppy.ontarget.tsv'
# output_path = '/opt/data/jovan/output/guppy/'

# jovan dorado
# input_path = '/opt/data/jovan/output/dorado/dorado.ontarget.tsv'
# output_path = '/opt/data/jovan/output/dorado/'

# dm108 guppy
# input_path = '/opt/data/dm108/output/guppy/guppy.ontarget.tsv'
# output_path = '/opt/data/dm108/output/guppy/'

# bc3_1 guppy  (active dataset: keep exactly one input/output pair uncommented)
# output_path is used as a plain string prefix for output files, so it must end with '/'
input_path = '/opt/data/bc3_1/output/guppy/guppy.ontarget.tsv'
output_path = '/opt/data/bc3_1/output/guppy/'

# bc3_1 dorado
# input_path = '/opt/data/bc3_1/output/dorado/dorado.ontarget.tsv'
# output_path = '/opt/data/bc3_1/output/dorado/'

# bc3_2 guppy
# input_path = '/opt/data/bc3_2/output/guppy/guppy.ontarget.tsv'
# output_path = '/opt/data/bc3_2/output/guppy/'

# bc3_2 dorado
# input_path = '/opt/data/bc3_2/output/dorado/dorado.ontarget.tsv'
# output_path = '/opt/data/bc3_2/output/dorado/'

# bc3_3 guppy
# input_path = '/opt/data/bc3_3/output/guppy/guppy.ontarget.tsv'
# output_path = '/opt/data/bc3_3/output/guppy/'

# bc3_3 dorado
# input_path = '/opt/data/bc3_3/output/dorado/dorado.ontarget.tsv'
# output_path = '/opt/data/bc3_3/output/dorado/'

### 3.3 Constants

In [3]:
def rev_comp(seq, comps):
    """Return the reverse complement of *seq* as a string.

    Each character of the reversed sequence is mapped through *comps*;
    characters without an entry (e.g. 'N' or lower-case bases) are kept as-is.
    """
    complemented = [comps.get(base, base) for base in reversed(seq)]
    return ''.join(complemented)


# Values expected in the 'direction' column (used as MOTIFS keys).
DIRECTIONS = ['fwd', 'rev']

# Base complements; rev_comp leaves any character not listed here unchanged.
COMPLEMENT = {
    'A': 'T',
    'C': 'G',
    'G': 'C',
    'T': 'A'
}

# Repeat motif as it reads on each strand: forward as configured, and the
# reverse complement for 'rev' reads.
MOTIFS = {
    'fwd': motif,
    'rev': rev_comp(motif, COMPLEMENT)
}

# Column names of the headerless on-target TSV, in file order.
COLUMNS = [
    'direction',
    'id',
    'prefix_flank',
    'ins',
    'suffix_flank',
    'prefix_flank_q',
    'ins_q',
    'suffix_flank_q',
]

# Sequence columns and the length columns lengths() writes for them (paired by position).
COLUMNS_SEQ_EXT = ['ins_ext']
COLUMNS_LEN_EXT = ['len_ins_ext']

# Needleman-Wunsch scoring used by align()/fit_target().
MATCH_WEIGHT = 10  # weight for a match
MISMATCH_WEIGHT = -8  # weight for a mismatch
GAP_WEIGHT = -9  # weight for a gap
# NOTE(review): with gap_char='' the callers detect gap columns as ' ', which
# relies on string2string padding gaps with spaces -- confirm for the installed version.
NW = NeedlemanWunsch(
    match_weight=MATCH_WEIGHT,  # weight for a match
    mismatch_weight=MISMATCH_WEIGHT,  # weight for a mismatch
    gap_weight=GAP_WEIGHT,  # weight for a gap
    gap_char=''  # character to use for a gap
)

# Fill colour per base in draw(); ' ' is the gap column kept by the alignment step.
COLORS = {
    'A': '#3DA853',  # green
    'C': '#4285F4',  # blue
    'G': '#F8BC07',  # yellow
    'T': '#EA4334',  # red
    ' ': 'white'
}

### 3.4 Functions

In [4]:
def load(input_path, columns):
    """Load a headerless, tab-separated reads file as all-string columns.

    Args:
        input_path: Path (or file-like object) of the TSV to read.
        columns: Column names to assign, in file order.

    Returns:
        DataFrame with one string column per entry of *columns*.

    Notes:
        ``quoting=QUOTE_NONE`` keeps quote characters (which can occur in
        quality strings) as literal data.  ``na_filter=False`` stops pandas
        from turning empty fields into NaN and from treating literal tokens
        such as ``NA``/``NaN``/``null`` as missing.  Those tokens are valid
        short sequences or quality strings here, and NaN would break the
        string slicing done by later steps.
    """
    df = pd.read_csv(
        input_path,
        sep='\t',
        header=None,
        dtype=str,
        quoting=QUOTE_NONE,
        na_filter=False,
    )
    df.columns = columns

    return df


def extend_ins(row):
    """Pull flank bases into the insert where a boundary completes a double motif.

    A window of ``len(2 * motif)`` characters (motif chosen by the read's
    'direction') is slid across each insert/flank boundary.  At the left
    boundary, the first split whose prefix_flank tail plus insert head spells
    two consecutive motif copies wins; the borrowed tail is prepended.  At the
    right boundary, the same test uses the (possibly extended) insert tail plus
    the suffix_flank head, and the borrowed head is appended.  On both sides the
    longest borrow is tried first, and at most one extension is applied per side.

    Returns:
        The (possibly) extended insert string.
    """
    prefix, ins, suffix = row['prefix_flank'], row['ins'], row['suffix_flank']
    unit = MOTIFS[row['direction']]
    doubled = unit + unit
    width = len(doubled)

    # Left boundary: borrow width-1 .. 1 characters from the end of the prefix.
    for borrow in range(width - 1, 0, -1):
        tail = prefix[-borrow:]
        if tail + ins[:width - borrow] == doubled:
            ins = tail + ins
            break

    # Right boundary: borrow width-1 .. 1 characters from the start of the suffix.
    for borrow in range(width - 1, 0, -1):
        keep = width - borrow
        if ins[-keep:] + suffix[:borrow] == doubled:
            ins = ins + suffix[:borrow]
            break

    return ins


def lengths(df, columns_seq, columns_len):
    """Add string-length columns to *df* in place and return the same frame.

    Sequence and length column names are paired by position; for each pair the
    character length of every value is stored (missing values stay NaN).
    """
    for seq_col, len_col in zip(columns_seq, columns_len):
        seq_lengths = df[seq_col].str.len()
        df[len_col] = seq_lengths
    return df


def orient_inserts(row):
    """Return the extended insert in forward-strand orientation.

    Reads labelled 'rev' are reverse-complemented; any other direction is
    returned unchanged.
    """
    seq = row['ins_ext']
    if row['direction'] != 'rev':
        return seq
    return rev_comp(seq, COMPLEMENT)


def subsample(df, sample_size=10, threshold=1, random_state=None):
    """Randomly sample up to *sample_size* reads per extended-insert length.

    Args:
        df: Reads with a 'len_ins_ext' column.
        sample_size: Maximum number of rows kept for each distinct length.
        threshold: Lengths with fewer than this many rows are dropped entirely.
        random_state: Optional seed/RandomState forwarded to ``DataFrame.sample``
            for reproducible subsampling (default: unseeded, as before).

    Returns:
        The concatenated per-length samples.  If no length qualifies (including
        empty input), an empty frame with *df*'s columns is returned instead of
        raising from ``pd.concat([])``.
    """
    parts = []
    # groupby visits each distinct length once instead of re-masking the frame
    # per length; rows with a missing length are skipped (as they never matched).
    for _, same_len in df.groupby('len_ins_ext', sort=True):
        if len(same_len) >= threshold:
            n = min(len(same_len), sample_size)
            parts.append(same_len.sample(n=n, random_state=random_state))

    if not parts:
        return df.iloc[0:0]
    return pd.concat(parts)


def group(df, cutoff=300):
    """Count reads per distinct (oriented insert, extended length) pair.

    Pairs whose length exceeds *cutoff* are discarded; the result has columns
    'ins_oriented', 'len_ins_ext', 'count' and is sorted by length, then by
    sequence.
    """
    keys = ['ins_oriented', 'len_ins_ext']
    counts = df.groupby(keys)['id'].count().rename('count').reset_index()
    within_cutoff = counts['len_ins_ext'] <= cutoff
    counts = counts.loc[within_cutoff]
    return counts.sort_values(['len_ins_ext', 'ins_oriented'])


def align(row, nw=NW):
    """Align the oriented insert against a pure-motif reference of matching size.

    The reference is ``ceil(len_ins_ext / len(motif))`` copies of the motif.
    Only alignment columns where the reference has a base are kept, so bases
    inserted relative to the reference are dropped and reference positions the
    insert skips are kept as the padding character.

    NOTE(review): this assumes ``nw.get_alignment`` returns ' | '-joined columns
    with gaps padded to ' ' (as string2string does with ``gap_char=''``) --
    confirm against the installed version.

    Returns:
        The insert projected onto the motif reference grid.
    """
    # Use the configured motif length instead of a hard-coded 3 so
    # non-trinucleotide repeats are sized correctly.
    n_units = int(np.ceil(row['len_ins_ext'] / len(motif)))
    source = n_units * motif
    target = row['ins_oriented']
    aligned_source, aligned_target = nw.get_alignment(source, target, return_score_matrix=False)
    source_cols = aligned_source.split(' | ')
    target_cols = aligned_target.split(' | ')
    return ''.join(tc for sc, tc in zip(source_cols, target_cols) if sc != ' ')


def fit_target(t, nw=NW):
    """Fit one non-motif fragment onto the nearest whole number of motif copies.

    The fragment is aligned to ``round(len(t) / len(motif))`` motif copies and
    only columns where that reference has a base are kept.  ``None`` and empty
    fragments are returned unchanged.

    NOTE(review): gap columns are recognised as ' ' in the reference, which
    relies on string2string padding gaps with spaces when ``gap_char=''`` --
    confirm against the installed version.
    """
    if t is None or len(t) == 0:
        return t

    # Size the reference by the configured motif length, not a hard-coded 3.
    n_units = int(np.round(len(t) / len(motif)))
    source = n_units * motif
    aligned_source, aligned_target = nw.get_alignment(source, t, return_score_matrix=False)
    source_cols = aligned_source.split(' | ')
    target_cols = aligned_target.split(' | ')
    # Distinct names avoid shadowing the parameter `t` inside the generator.
    return ''.join(tc for sc, tc in zip(source_cols, target_cols) if sc != ' ')


def fit(row):
    """Fit the oriented insert onto the motif grid, fragment by fragment.

    Exact motif occurrences are kept verbatim.  Each stretch between them is
    replaced by ``fit_target`` of that stretch, and the pieces are re-joined
    with the motif.
    """
    target = row['ins_oriented']
    # str.split always returns at least one element (possibly ''), so no
    # separate empty-result branch is needed; fit_target passes '' through.
    return motif.join(fit_target(fragment) for fragment in target.split(motif))


def ungroup(df):
    """Expand grouped rows back into one row per read.

    Each row of *df* is repeated 'count' times, keeping the current row order.

    Returns:
        DataFrame with columns 'ins_aligned' and 'len_ins_aligned' and a fresh
        0..n-1 index.  Empty input yields an empty frame with those columns.
    """
    # Positional repetition: unlike df.loc[label], this works with duplicate
    # index labels, and avoids a Python loop over every individual read.
    positions = np.repeat(np.arange(len(df)), df['count'].to_numpy())
    expanded = df.iloc[positions][['ins_aligned', 'len_ins_aligned']]
    return expanded.reset_index(drop=True)


def prepare_for_plotting(df):
    """Build per-position statistics over aligned inserts.

    For every position i (0-based) up to the longest aligned insert, one row
    holds:
      * a count per character found at position i among inserts that reach it;
      * 'insert_count': number of inserts whose aligned length is exactly i + 1;
      * 'coverage_count': number of inserts with aligned length >= i + 1.

    Characters absent at a position are filled with 0.  Empty input returns an
    empty frame instead of failing on ``range(nan)``.
    """
    if df.empty:
        return pd.DataFrame(columns=['insert_count', 'coverage_count'])

    seqs = df['ins_aligned']
    seq_lens = df['len_ins_aligned']
    results = []
    for i in range(int(seqs.str.len().max())):
        # Computed once per position and reused for both the base counts and coverage.
        covering = seq_lens >= i + 1
        row = dict(seqs[covering].str[i].value_counts())
        row['insert_count'] = int((seq_lens == i + 1).sum())
        row['coverage_count'] = int(covering.sum())
        results.append(row)

    return pd.DataFrame(results).fillna(0).astype(int)


def draw(df, colors, output_image, repeat_unit='CAG'):
    """Render per-position statistics (from prepare_for_plotting) to an image.

    Layout: each row of *df* becomes a 6-px-wide column in a 1000-px-high image.
      * Top 500 px: stacked fractions of 'A', 'C', 'G', 'T', ' ' among reads
        covering that position.  The base expected from *repeat_unit* at that
        phase is drawn black; other bases use *colors*.  Columns missing from
        *df* count as 0.
      * From y = 501 down: a 3-px black bar whose height is
        log(insert_count) / log(max insert_count), drawn only when that ratio
        exceeds 0.5.
      * A grey vertical line marks the start of every repeat unit.

    Args:
        df: Per-position frame with 'insert_count', 'coverage_count' and base columns.
        colors: Mapping from base (including ' ') to fill colour.
        output_image: Destination path passed to ``Image.save``.
        repeat_unit: Expected repeat sequence (defaults to the previously
            hard-coded 'CAG'); sets both the black base and the grid period.
    """
    col_width = 6
    half = 500  # pixel height of the composition panel
    width = col_width * len(df)
    height = 1000
    max_log = np.log(df['insert_count'].max())
    print(height, width)
    image = Image.new('RGB', (width, height), 'white')
    canvas = ImageDraw.Draw(image)
    period = len(repeat_unit)

    # Positional enumeration: the column x-offset depends on position, not index label.
    for i, (_, row) in enumerate(df.iterrows()):
        x0 = col_width * i

        # Lower panel: log-scaled insert-end bar (ratio undefined when max_log <= 0).
        bar_top = height // 2 + 1
        log_count = 0 if row['insert_count'] == 0 else np.log(row['insert_count'])
        ratio = log_count / max_log if max_log > 0 else 0.0
        size = ratio if ratio > 0.5 else 0.0
        bar = int(np.round(half * size))
        if bar > 0:
            canvas.rectangle([x0, bar_top, x0 + 2, bar_top + bar - 1], fill='black')

        # Upper panel: stacked base composition, starting at y = 0.
        y = 0
        for base in ('A', 'C', 'G', 'T', ' '):
            cnt = row.get(base, 0)
            freq = int(np.round(half * cnt / row['coverage_count']))
            color = 'black' if base == repeat_unit[i % period] else colors[base]
            if freq > 0:
                canvas.rectangle([x0, y, x0 + col_width - 1, y + freq - 1], fill=color)
            y += freq

        if i % period == 0:
            canvas.line([(x0, 0), (x0, half)], fill='grey')

    image.save(output_image)

### 3.5 Main

In [5]:
# Load the headerless on-target TSV (all columns as strings)
df = load(input_path, COLUMNS)
print(f'Loaded: {len(df)}')

# Extend each insert with flank bases where a boundary completes two motif copies
df['ins_ext'] = df.apply(extend_ins, axis=1)
print(f"Extended: {sum(df['ins'] != df['ins_ext'])}")

# Add length columns (len_ins_ext)
df = lengths(df, COLUMNS_SEQ_EXT, COLUMNS_LEN_EXT)

# Orient inserts: reverse-complement 'rev' reads so all are in forward-motif orientation
df['ins_oriented'] = df.apply(orient_inserts, axis=1)

# Subsample: keep every insert length with at least 1 read (default threshold) and
# sample up to 1,000,000 reads per length -- i.e. effectively keep all reads here
df_sub = subsample(df, 1000000)
print(f'Subsampled to: {len(df_sub)}')

# Group identical (oriented sequence, extended length) pairs; drop lengths above 600
dfg = group(df_sub, 600)
print(f'Grouped: {len(dfg)}')

Loaded: 69287
Extended: 63618
Subsampled to: 69287
Grouped: 12068


In [6]:
%%time
# Align: fit each grouped insert onto the motif grid (see fit / fit_target)
dfg['ins_aligned'] = dfg.apply(fit, axis=1)

CPU times: user 20.8 s, sys: 0 ns, total: 20.8 s
Wall time: 20.8 s


In [7]:
# Add lengths for aligned inserts
dfg['len_ins_aligned'] = dfg['ins_aligned'].str.len()

# Sort grouped inserts by aligned size, then by sequence
dfg = dfg.sort_values(['len_ins_aligned', 'ins_aligned'])

# Ungroup: repeat each aligned insert 'count' times (one row per read)
dfa = ungroup(dfg)
print(f'Ungrouped: {len(dfa)}')

# Prepare for plotting: per-position base counts, insert-end counts and coverage
dfp = prepare_for_plotting(dfa)
print(f'Plot: {len(dfp)}')

Ungrouped: 69249
Plot: 597


In [8]:
# MATCH_WEIGHT = 10  # weight for a match
# MISMATCH_WEIGHT = -8  # weight for a mismatch
# GAP_WEIGHT = -9  # weight for a gap
# NW = NeedlemanWunsch(
#     match_weight=MATCH_WEIGHT,  # weight for a match
#     mismatch_weight=MISMATCH_WEIGHT,  # weight for a mismatch
#     gap_weight=GAP_WEIGHT,  # weight for a gap
#     gap_char=''  # character to use for a gap
# )
# row = {
#     'len_ins_ext': 8,
#     'ins_oriented': 'CAGTTTC'
# }
# align(row, NW)

In [9]:
# Render the per-position plot; output_path is a string prefix and must end with '/'
draw(dfp, COLORS, f'{output_path}test1.png')

1000 3582
