# Benchmark tr-solve segmentation accuracy on a set of pre-segmented sequences

## Setup

In [1]:
import portion as P
import itertools
from pathlib import Path

In [2]:
import os

# Path to the tr-solve binary. Set the TRSOLVE environment variable to point
# at a different build; otherwise the author's local release build is used.
trsolve = os.environ.get(
    "TRSOLVE",
    "/Users/edolzhenko/projects/2023/Q3/tr-solve/target/release/tr-solve",
)

## Segment sequences with tr-solve

In [3]:
%%bash -s "$trsolve"

# $1 is the tr-solve binary path handed over from the Python variable above.
# Exit non-zero if awk or tr-solve fails, instead of carrying on as if the
# segmentation succeeded.
set -euo pipefail
trsolve=$1
# Feed tr-solve only the first two columns (sequence and motif list) of the
# expected file; the third column is the expected segmentation it is scored
# against later.
awk '{print $1, $2}' expected_segs.txt | "$trsolve" > observed_segs.txt

## Load expected and observed segmentations

In [4]:
def merge_intervals(intervals):
    """Return the union of every interval in *intervals*.

    An empty iterable produces the empty interval, ``P.empty()``.
    """
    merged = P.empty()
    for piece in intervals:
        merged |= piece
    return merged


def parse_segmentation(encoding):
    """Parse a segmentation string into one span per motif.

    ``encoding`` is a comma-separated list of ``MOTIF_START_END`` records,
    e.g. ``"CAG_0_15,A_15_21"``. Each record becomes the half-open interval
    ``[START, END)``; records that share a motif are merged into one
    (possibly multi-piece) ``portion`` interval.

    Returns a dict mapping motif -> interval. ``None`` or an empty string
    yields ``{}``.
    """
    if not encoding:
        return {}
    labels = []
    for rec in encoding.split(","):
        if not rec:
            # Tolerate stray/trailing commas instead of failing to unpack "".
            continue
        # rsplit keeps any "_" inside the motif name part of the motif.
        motif, start, end = rec.rsplit("_", 2)
        labels.append((motif, P.closedopen(int(start), int(end))))
    # itertools.groupby only groups adjacent items, so sort by motif first.
    labels.sort(key=lambda rec: rec[0])
    segmentation = {}
    for motif, group in itertools.groupby(labels, key=lambda rec: rec[0]):
        segmentation[motif] = merge_intervals([span for _, span in group])
    return segmentation


def parse_rec(line):
    """Split one whitespace-delimited record line into its fields.

    A line has either two columns (``SEQ MOTIFS``) or three
    (``SEQ MOTIFS SEGMENTATION``); MOTIFS is a comma-separated list.

    Returns ``(seq, motifs, seg)`` where ``motifs`` is always a list of motif
    strings and ``seg`` is the raw segmentation string, or ``None`` for a
    two-column line.

    Raises ValueError if the line does not have two or three columns.
    """
    fields = line.split()
    if len(fields) == 2:
        seq, motifs = fields
        seg = None
    elif len(fields) == 3:
        seq, motifs, seg = fields
    else:
        raise ValueError(f"expected 2 or 3 columns, got {len(fields)}: {line!r}")
    # Split motifs in both branches so callers always receive a list.
    return seq, motifs.split(","), seg


def load_recs(path):
    """Load a record file into ``{sequence: segmentation string or None}``.

    Every non-blank line is parsed with ``parse_rec``; the motif column is
    discarded. If a sequence appears on more than one line, the last one wins.
    """
    recs = {}
    with open(path, "r", encoding="utf-8") as file:
        for line in file:
            # Skip blank lines (e.g. a trailing empty line) rather than
            # letting parse_rec reject them.
            if not line.strip():
                continue
            seq, _motifs, seg = parse_rec(line)
            recs[seq] = seg
    return recs


In [5]:
# Expected segmentations come from the reference file; observed ones are the
# tr-solve output written by the bash cell above.
expected_segs, observed_segs = (
    load_recs(Path(name)) for name in ("expected_segs.txt", "observed_segs.txt")
)

## Compare observed and expected segmentations (motif-aware Jaccard index)

In [6]:
def get_len(intervals):
    """Return the total length covered by the pieces of *intervals*.

    Adds up ``upper - lower`` for each piece, so overlapping pieces would be
    counted twice; it is intended for unions of disjoint, finite spans such
    as those built by ``parse_segmentation``.
    """
    total = 0
    for piece in intervals:
        total += piece.upper - piece.lower
    return total


def compare_segs(seg_a, seg_b):
    """Score how well two segmentation strings agree, as a Jaccard index.

    Both arguments are encodings accepted by ``parse_segmentation`` (or
    ``None``). For each motif present in either segmentation, the lengths of
    the intersection and of the union of that motif's spans are measured. The
    score is the sum of the intersections divided by the sum of the unions,
    so a position only counts as agreeing when both segmentations assign it
    to the same motif.

    Returns a float in [0, 1]. If neither segmentation covers any position
    (e.g. both are ``None``), the two are trivially identical and 1.0 is
    returned rather than dividing by zero.
    """
    spans_a = parse_segmentation(seg_a)
    spans_b = parse_segmentation(seg_b)
    intersect_len = 0
    union_len = 0
    for motif in set(spans_a) | set(spans_b):
        # A motif missing from one side contributes only to the union.
        span_a = spans_a.get(motif, P.empty())
        span_b = spans_b.get(motif, P.empty())
        intersect_len += get_len(span_a & span_b)
        union_len += get_len(span_a | span_b)
    if union_len == 0:
        return 1.0
    return intersect_len / union_len

In [7]:
# One line per sequence: sequence, observed segmentation, expected
# segmentation, and the Jaccard score between the two.
for seq in observed_segs:
    obs_seg, exp_seg = observed_segs[seq], expected_segs[seq]
    print(seq, obs_seg, exp_seg, compare_segs(obs_seg, exp_seg))

CAGCAGCAGCAGCAGAAAAAA CAG_0_15,A_15_21 CAG_0_15,A_15_21 1.0
TCTATGCAACCAACTTTCTGTTAGTCATAGTACCCCAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAATAGAAATGTGTTTAAGAATTCCTCAATAAG AGA_37_97 AGA_37_99 0.967741935483871
TCTATGCAACCAACTTTCTGTTAGTCATAGTACCAGAAGAAGAAGAAGAAGAAGAAGAAGAAAGAAGAGGAAGAGGAAGAGGAAGAGGAAGAGAAGAGGAAGAGGAAGAGGAAGAGGAAGAGGAAGAGGAAGAGGAAGAGGAAGAGAAGAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAGAAGAAGAATAGAAATGTGTTTAAGAATTCCTCAATAAG AGA_34_63,GAAGAG_63_146,AGA_146_211 AGA_34_63,GAAGAG_63_146,AGA_146_213 0.9888268156424581
TCTATGCAACCAACTTTCTGTGAAGAAGAAGAAGAAGAAGAAGAAGAAGGAGGAGGAGGAGGAGGAGGAGGAGGAGGAGGAGGAGGAGGAGGAGGAGGAGGAGGAGGAGGAGGAGGAGGAGGAGGAGGAGGAGGAGGAGAAGGAGAAGGAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAATAGAAATGTGTTTAAGAATTCCTCAATAAG AGA_20_47,AGG_47_149,AGA_149_221 AGA_23_47,AGG_47_149,AGA_149_223 0.9753694581280788
