# Fig 2. Benchmark of host identification performance

####  Dependencies:
    - kraken2==2.1.3
    - centrifuger
    - metamaps==0.1
    - megan==6.25.9
    - minimap2
    - argo==0.1.0
    - pandas
    - scipy
    - taxonkit
    - biopython

#### Inputs:
    - fig2/data/*-arg.fa (ARG-containing reads from the HQ and LQ datasets)
    - fig2/data/reference (reference genomes and metadata of the top 25 pathogens)

In [1]:
# %%bash
# mkdir -p fig2/tmp/kraken fig2/tmp/centrifuger fig2/tmp/metamaps/ fig2/tmp/megan fig2/tmp/minimap
# for file in fig2/data/*.fa
# do
#     filename=${file%.fa}
#     filename=${filename##*/}

#     ## kraken
#     kraken2 $file --output fig2/tmp/kraken/$filename.output --db db/kraken --threads 48

#     ## centrifuger
#     centrifuger -u $file -x db/centrifuger -t 48 > fig2/tmp/centrifuger/$filename.tsv

#     ## metamaps
#     metamaps mapDirectly --all -r db/metamaps/DB.fa -q $file -o fig2/tmp/metamaps/$filename -t 48 --maxmemory 20
#     metamaps classify --mappings fig2/tmp/metamaps/$filename --DB db/metamaps -t 48

#     ## megan
#     minimap2 -ax map-ont --split-prefix=$filename --sam-hit-only -I 8G -t 48 db/refseq.fna.gz $file > fig2/tmp/megan/$filename.sam
#     sam2rma -i fig2/tmp/megan/$filename.sam -mdb db/megan -r $file -o fig2/tmp/megan/$filename.rma -c false -lg -alg longReads -ram readCount -t 48
#     rma2info -i fig2/tmp/megan/$filename.rma -o fig2/tmp/megan/$filename.txt -r2c Taxonomy

#     ## minimap
#     minimap2 -cx map-ont -t 48 db/refseq.fna.gz $file > fig2/tmp/minimap/$filename.paf
#     minimap2 -x ava-ont -t 48 $file $file > fig2/tmp/minimap/$filename-self.paf
# done

In [2]:
import pandas as pd
import glob
import subprocess

from collections import defaultdict, Counter
from scipy.sparse import csr_matrix
from Bio import SeqIO
from argo.utils import *

## ground truth: map every reference sequence accession to its species (genus + epithet)
metadata = pd.read_table('fig2/data/reference/data_summary.tsv')
name_tokens = metadata['Organism Scientific Name'].str.split(' ')
metadata['species'] = name_tokens.str.get(0) + ' ' + name_tokens.str.get(1)

records = []
for fna in glob.glob('fig2/data/reference/*/*.fna'):
    # assembly accession is rebuilt from the second '_'-separated token of the file name
    assembly = 'GCF_' + fna.split('/')[-1].split('_')[1]
    with open(fna) as handle:
        records.extend([seq.id, assembly] for seq in SeqIO.parse(handle, 'fasta'))

accession_table = pd.DataFrame(records, columns=['accession', 'Assembly Accession'])
accession2species = (
    accession_table.merge(metadata)
    .set_index('accession')['species']
    .to_dict()
)

## read ids of every input read set, keyed by file stem (e.g. 'HQ-arg')
ids = defaultdict(list)
for fasta in glob.glob('fig2/data/*.fa'):
    stem = fasta.split('.fa')[0].split('/')[-1]
    with open(fasta) as handle:
        ids[stem].extend(seq.id for seq in SeqIO.parse(handle, 'fasta'))

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def get_taxonomy(taxid):
    """Resolve taxids to rank-annotated lineages with ``taxonkit reformat``.

    Args:
        taxid: iterable of taxids (str or int). Each one is written on its own
            line to taxonkit's stdin.

    Returns:
        dict mapping ``int(taxid)`` to a ``;``-joined string of
        ``"<lineage taxid>|<name>"`` entries, ordered as in the ``-f`` format
        (``{k}`` ... ``{s}``). Empty fields at the end of a taxonkit output
        line are stripped, so the last entry is the most specific rank that
        was reported. An empty input returns ``{}`` without calling taxonkit.

    Raises:
        subprocess.CalledProcessError: if taxonkit exits with a non-zero status.
    """
    taxids = [str(t) for t in taxid]
    if not taxids:
        # Nothing to resolve; feeding taxonkit a lone newline yields a blank line.
        return {}

    output = subprocess.run([
        'taxonkit', 'reformat',
        '--taxid-field', '1',
        '--show-lineage-taxids',
        '--fill-miss-rank',
        '--miss-taxid-repl', '0',
        '--miss-rank-repl', 'unclassified',
        '--trim',
        '-f', '{k}\t{p}\t{c}\t{o}\t{f}\t{g}\t{s}'],
        input='\n'.join(taxids) + '\n', text=True, capture_output=True, check=True)

    taxonomy = {}
    for line in output.stdout.rstrip().split('\n'):
        # Output fields: [query taxid, 7 rank names, 7 lineage taxids].
        # rstrip() removes trailing empty fields (e.g. ranks left empty by --trim);
        # the name and taxid columns then shrink by the same amount, so pairing
        # names[i] with lineage_taxids[i] stays aligned.
        fields = line.rstrip().split('\t')
        if not fields[0]:
            continue
        names, lineage_taxids = fields[1:len(fields) - 7], fields[8:]
        taxonomy[int(fields[0])] = ';'.join(
            f'{tid}|{name}' for name, tid in zip(names, lineage_taxids)
        )

    return taxonomy

def parser(df, method, file, species_out=None, statistics_out=None):
    """Summarise one method's classification outcomes on one read set.

    ``df`` must have an ``est`` column (predicted species name; the string
    ``'unclassified'`` or a missing value when no call was made) and a
    ``species`` column (ground-truth species). Missing predictions (NaN/None)
    are counted as unclassified, not as assignments to an absent species.

    Two frames are built and appended:
      * per-species table: read count (``size``), ``unclassified`` and
        ``misclassified`` counts, plus ``method`` and ``file`` columns;
      * one-row summary: fractions of all reads that are ``unclassified``,
        ``misclassified`` (prediction != truth, unclassified included),
        ``tp`` (wrong call, but to a species present among the truth labels)
        and ``fp`` (wrong call to a species absent from the truth labels).

    Args:
        df: per-read table with ``est`` and ``species`` columns.
        method: label written to the ``method`` column.
        file: read-set name; the text before the first ``-`` becomes ``file``.
        species_out: list receiving the per-species frame; defaults to the
            notebook-level ``species`` list.
        statistics_out: list receiving the summary frame; defaults to the
            notebook-level ``statistics`` list.

    Returns:
        ``(sp, st)``: the per-species frame and the summary frame.
    """
    if species_out is None:
        species_out = species
    if statistics_out is None:
        statistics_out = statistics

    # A missing prediction means "no call", not a call to some absent species.
    est = df.est.fillna('unclassified')
    truth = df.species
    sample = file.split('-')[0]

    unclassified = est == 'unclassified'
    wrong = est != truth
    wrong_call = wrong & ~unclassified
    to_present_species = est.isin(set(truth))

    species2unclassified = df[unclassified].groupby('species').size().to_dict()
    species2misclassified = df[wrong].groupby('species').size().to_dict()
    species2tp = df[wrong_call & to_present_species].groupby('species').size().to_dict()
    species2fp = df[wrong_call & ~to_present_species].groupby('species').size().to_dict()

    sp = df.groupby('species', as_index=False).size()
    sp['unclassified'] = sp.species.map(species2unclassified).fillna(0)
    sp['misclassified'] = sp.species.map(species2misclassified).fillna(0)
    sp = sp.assign(method=method, file=sample)

    # Guard the denominator so an empty read set yields zeros instead of raising.
    n_reads = max(len(df), 1)
    st = pd.DataFrame([[
        sum(species2unclassified.values()) / n_reads,
        sum(species2misclassified.values()) / n_reads,
        sum(species2tp.values()) / n_reads,
        sum(species2fp.values()) / n_reads
    ]], columns = ['unclassified', 'misclassified', 'tp', 'fp']).assign(method=method, file=sample)

    species_out.append(sp)
    statistics_out.append(st)
    return sp, st

In [4]:
# read sets to benchmark, plus accumulators that parser() appends its frames to
files = [f'{quality}-arg' for quality in ('HQ', 'LQ')]
species = []
statistics = []

In [5]:
## kraken
for filename in files:
    row = []
    with open(f'fig2/tmp/kraken/{filename}.output') as f:
        for line in f:
            ls = line.rstrip().split('\t')
            # keep [read id, taxid, length] (fields 2-4 of the Kraken2 per-read output)
            row.append([ls[1], int(ls[2]), int(ls[3])])

    # Reads missing from the Kraken2 output get taxid 0. The left merge turns column 1
    # into float as soon as one NaN is introduced, and a float taxid would be sent to
    # taxonkit as e.g. "562.0"; cast back to int as the MetaMaps/MEGAN cells do.
    df = pd.merge(pd.DataFrame(ids[filename]), pd.DataFrame(row), how='left').fillna(0)
    df[1] = df[1].astype(int)
    # est = last '|'-separated token of the lineage, i.e. the most specific rank name
    df['est'] = df[1].map(get_taxonomy(df[1].astype(str).unique())).str.split('|').str.get(-1)
    df['species'] = df[0].str.split('-').str.get(0).map(accession2species)

    parser(df, 'Kraken2', filename)

In [6]:
## centrifuger
for filename in files:
    df = pd.read_table(f'fig2/tmp/centrifuger/{filename}.tsv')
    # left-join on read id so every input read is present; reads without a row get taxid 0
    df = pd.merge(pd.DataFrame(ids[filename]), df, how='left', left_on=0, right_on='readID').fillna(0)
    # A NaN from the merge makes taxID float ("562.0" after astype(str)); cast back to int
    # as the MetaMaps/MEGAN cells do before querying taxonkit.
    df['taxID'] = df['taxID'].astype(int)
    df['est'] = df['taxID'].map(get_taxonomy(df['taxID'].astype(str).unique())).str.split('|').str.get(-1)
    df['species'] = df[0].str.split('-').str.get(0).map(accession2species)

    parser(df, 'Centrifuger', filename)

In [7]:
## metamaps
for sample in files:
    # headerless reads2Taxon table: column 0 = read id (merge key), column 1 = taxid
    assigned = pd.read_table(f'fig2/tmp/metamaps/{sample}.EM.reads2Taxon', header=None)
    reads = pd.DataFrame(ids[sample]).merge(assigned, how='left').fillna(0)
    reads[1] = reads[1].astype(int)

    lineage_of = get_taxonomy(reads[1].astype(str).unique())
    reads['est'] = reads[1].map(lineage_of).str.split('|').str.get(-1)
    reads['species'] = reads[0].str.split('-').str.get(0).map(accession2species)

    parser(reads, 'MetaMaps', sample)

In [8]:
## megan
for sample in files:
    # rma2info read-to-class table: column 0 = read id (merge key), column 1 = taxid
    megan_calls = pd.read_table(f'fig2/tmp/megan/{sample}.txt', header=None)
    per_read = pd.DataFrame(ids[sample]).merge(megan_calls, how='left').fillna(0)
    taxid_col = per_read[1].astype(int)
    per_read[1] = taxid_col

    lineages_by_taxid = get_taxonomy(taxid_col.astype(str).unique())
    per_read['est'] = taxid_col.map(lineages_by_taxid).str.split('|').str.get(-1)
    per_read['species'] = per_read[0].str.split('-').str.get(0).map(accession2species)

    parser(per_read, 'MEGAN-LR', sample)

In [9]:
## minimap2+BH: assign each read the species of its single best (highest AS) alignment
assembly2species = pd.read_table('db/refseq.assembly2species.tsv').set_index('assembly').species.to_dict()
accession2assembly = pd.read_table('db/refseq.accession2assembly.tsv').set_index('accession').assembly.to_dict()

for filename in files:
    row = []
    with open(f'fig2/tmp/minimap/{filename}.paf') as f:
        for line in f:
            ls = line.rstrip().split()
            # [read id, species of the target accession (None if unmapped), alignment score];
            # ls[14] is expected to hold the AS:i tag (minimap2 -c tag layout)
            row.append([ls[0], assembly2species.get(accession2assembly.get(ls[5])), int(ls[14].split('AS:i:')[-1])])

    # Best hit per read: sort by AS and keep the first *row* of each read.
    # groupby().first() would take the first non-null value of every column separately,
    # pairing a top hit whose species is None with a lower-scoring alignment's species.
    # A stable sort makes tie-breaking deterministic (file order).
    df = pd.DataFrame(row).sort_values(2, ascending=False, kind='stable')
    df = df.drop_duplicates(subset=[0], keep='first')
    # unaligned reads get 0 here; the .str accessor below turns non-strings into NaN
    df = pd.merge(pd.DataFrame(ids[filename]), df, how='left').fillna(0)

    df['est'] = df[1].str.split('|').str.get(-1)
    df['species'] = df[0].str.split('-').str.get(0).map(accession2species)

    parser(df, 'minimap2+BH', filename)

In [10]:
## minimap2+EM and minimap2+RO, plus per-cluster purity/size and ARG-coverage tables
assembly2species = pd.read_table('db/refseq.assembly2species.tsv').set_index('assembly').species.to_dict()
accession2assembly = pd.read_table('db/refseq.accession2assembly.tsv').set_index('accession').assembly.to_dict()

for filename in files:
    ## collect candidate alignments from the read-vs-RefSeq PAF
    alignments = []
    # best AS, (1 - divergence) and identity seen so far for each (read, species) pair
    scores = defaultdict(lambda: defaultdict(lambda: {'AS': 0, 'DE': 0, 'ID': 0}))

    with open(f'fig2/tmp/minimap/{filename}.paf') as f:
        for line in f:
            ls = line.rstrip().split('\t')
            qseqid, sseqid = ls[0], ls[5]
            # NOTE(review): raises AttributeError if the accession has no species mapping
            lineage = assembly2species.get(accession2assembly.get(sseqid)).split('|')[-1]

            # Tags are read by fixed position (minimap2 -c layout assumed): ls[14] holds AS:i;
            # de:f is ls[20], or ls[19] for tp:A:S / tp:A:i records (presumably because
            # s2:i is only written for primary alignments).
            AS_MAX, AS = scores[qseqid][lineage].get('AS', 0), int(ls[14].split('AS:i:')[-1])
            DE_MAX, DE = scores[qseqid][lineage].get('DE', 0), 1 - float((ls[19] if ls[16] in {'tp:A:S', 'tp:A:i'} else ls[20]).split('de:f:')[-1])
            # PAF columns 10/11: residue matches / alignment block length
            ID_MAX, ID = scores[qseqid][lineage].get('ID', 0), int(ls[9]) / int(ls[10])

            ## keep an alignment only if it improves at least one metric for its (read, species) pair
            if AS > AS_MAX or DE > DE_MAX or ID > ID_MAX:
                scores[qseqid][lineage]['AS'] = max(AS_MAX, AS)
                scores[qseqid][lineage]['DE'] = max(DE_MAX, DE)
                scores[qseqid][lineage]['ID'] = max(ID_MAX, ID)
                alignments.append([qseqid, sseqid, AS, DE, ID, lineage])

    ## filter out low-score alignments: per read, keep alignments whose AS is within 0.5%
    ## or within 50 of the read's best AS (either suffices), one (highest-AS) per species
    duplicates = set()
    max_scores = defaultdict(lambda: {'AS': 0, 'DE': 0, 'ID': 0})

    for alignment in alignments:
        max_scores[alignment[0]]['AS'] = max(max_scores[alignment[0]]['AS'], alignment[2])
        max_scores[alignment[0]]['DE'] = max(max_scores[alignment[0]]['DE'], alignment[3])
        max_scores[alignment[0]]['ID'] = max(max_scores[alignment[0]]['ID'], alignment[4])

    sa = []
    for alignment in sorted(alignments, key=lambda alignment: (alignment[0], alignment[2], alignment[3], alignment[4]), reverse=True):
        if (
            max(alignment[2] / 0.995, alignment[2] + 50) > max_scores[alignment[0]]['AS']
        ):
            if (alignment[0], alignment[-1]) not in duplicates:
                sa.append(alignment)
                duplicates.add((alignment[0], alignment[-1]))

    ## EM over the binary read x species compatibility matrix
    max_iteration = 1000
    epsilon = 1e-10
    em = {}
    ## create a matrix then fill
    qseqids, lineages = np.unique([alignment[0] for alignment in sa]), np.unique([alignment[-1] for alignment in sa])
    qseqid2index = {qseqid: index for index, qseqid in enumerate(qseqids)}
    lineage2index = {lineage: index for index, lineage in enumerate(lineages)}

    rows = [qseqid2index[alignment[0]] for alignment in sa]
    cols = [lineage2index[alignment[-1]] for alignment in sa]
    matrix = csr_matrix((np.ones(len(rows)), (rows, cols)), shape=(len(qseqids), len(lineages)), dtype=int)

    ## run EM using the count matrix as input
    n_reads, n_mappings = matrix.shape

    ## init: uniform species abundances
    p_mappings = np.ones((1, n_mappings)) / n_mappings
    p_mappings_hist = p_mappings.copy()

    iteration = 0
    while iteration < max_iteration:
        iteration += 1

        ## e-step: per-read distribution over compatible species, weighted by abundance
        p_reads = matrix.multiply(p_mappings) / matrix.dot(p_mappings.T)

        ## m-step: abundances = mean of the per-read distributions
        p_mappings = np.sum(p_reads, axis=0) / n_reads

        ## check convergence (L1 change of the abundances)
        if np.sum(np.abs(p_mappings - p_mappings_hist)) < epsilon:
            break

        ## update hist
        np.copyto(p_mappings_hist, p_mappings)

    ## per read: indices of the species with maximal probability (more than one on ties)
    assignments = []
    for p_read in p_reads.tocsr():
        p_read = p_read.toarray().squeeze()
        assignments.append(np.where(p_read == p_read.max())[0].tolist())

    ties = defaultdict(set)
    for qseqid, lineage in enumerate(assignments):
        if len(assignment := lineages[lineage]) > 1:
            ties[tuple(assignment)].add(qseqids[qseqid])
        else:
            em[qseqids[qseqid]] = assignment[0]

    ## resolve ties for equal probability cases using mean AS, DE and ID
    if ties:
        qset = set.union(*(set(qseqid) for qseqid in ties.values()))
        alignments = [alignment for alignment in sa if alignment[0] in qset]

        for tied_lineages, tied_qseqids in ties.items():
            targets = [alignment for alignment in alignments if alignment[0] in tied_qseqids and alignment[-1] in tied_lineages]

            scores = defaultdict(lambda: defaultdict(list))
            for target in targets:
                scores[target[-1]]['AS'].append(target[2])
                scores[target[-1]]['DE'].append(target[3])
                scores[target[-1]]['ID'].append(target[4])

            ## if all the same, choose the one with known species name (not "sp.")
            target = sorted([
                [
                    np.mean(score['AS']),
                    np.mean(score['DE']),
                    np.mean(score['ID']),
                    not bool(re.search(r' sp\.$| sp\. | sp[0-9]+', lineage.split(';')[-1])),
                    lineage
                ] for lineage, score in scores.items()
            ], reverse=True)[0][-1]

            for qseqid in tied_qseqids:
                em[qseqid] = target

    df = pd.DataFrame(ids[filename])
    df['est'] = df[0].map(em)
    df['species'] = df[0].str.split('-').str.get(0).map(accession2species)

    parser(df, 'minimap2+EM', filename)

    ## SC (reported as minimap2+RO): cluster reads by all-vs-all overlaps, then pick species per cluster
    overlaps = filter_overlap(file=f"fig2/tmp/minimap/{filename}-self.paf")
    # overlap[-1] is used as a divergence (identity = 1 - overlap[-1] below);
    # drop overlaps more than 2.5x the median divergence
    DV = np.median([overlap[-1] for overlap in overlaps])
    overlaps = [overlap for overlap in overlaps if overlap[-1] <= 2.5 * DV]

    nodes = np.unique(ids[filename])
    node2index = {node: index for index, node in enumerate(nodes)}
    identities = defaultdict(lambda: 0)

    for overlap in overlaps:
        # compare with None: index 0 is a valid node and must not be treated as missing
        if (row := node2index.get(overlap[0])) is not None and (col := node2index.get(overlap[1])) is not None:
            identities[(row, col)] = identities[(col, row)] = max(1 - overlap[-1], identities.get((row, col), 0), identities.get((col, row), 0))

    ## symmetric read-by-read identity matrix, clustered with mcl
    matrix = dok_matrix((len(nodes), len(nodes)))
    rows, cols = zip(*identities.keys())
    matrix[rows, cols] = list(identities.values())
    clusters = mcl(matrix, max_iterations=1000, inflation=2, expansion=2)
    index2label = {index: label for label, indexes in enumerate(clusters) for index in indexes}
    labels = np.array([index2label.get(index) for index in range(len(nodes))])
    clusters = [nodes[labels==label] for label in np.unique(labels)[::-1]]

    sc = dict()
    for index, cluster in enumerate(clusters):
        elements = set(cluster)

        ## append scores of covered reads
        subsets = defaultdict(set)
        scores = defaultdict(lambda: defaultdict(dict))
        alignments = [alignment for alignment in sa if alignment[0] in elements]
        for alignment in alignments:
            subsets[alignment[-1]].add(alignment[0])
            scores[alignment[-1]][alignment[0]] = alignment[2]

        ## get covers; with several covering species, each read takes the one with its highest AS
        covers = set_cover(elements, subsets, scores)
        if len(covers) >= 2:
            qseqid2lineage = dict()
            score = defaultdict(lambda: 0)
            for qseqid in set.union(*[subsets[cover] for cover in covers]):
                for cover in covers:
                    AS = scores[cover].get(qseqid, 0)
                    if AS > score.get(qseqid, 0):
                        qseqid2lineage[qseqid] = cover
                        score[qseqid] = AS
            sc.update(qseqid2lineage)

        if len(covers)==1:
            sc.update({qseqid: covers[0] for qseqid in subsets[covers[0]]})

    df = pd.DataFrame(ids[filename])
    df['est'] = df[0].map(sc)
    df['species'] = df[0].str.split('-').str.get(0).map(accession2species)

    parser(df, 'minimap2+RO', filename)

    ## record cluster size and purity (share of the most common source species among members)
    r = []
    for cluster in clusters:
        tmp = [assembly2species.get(accession2assembly.get(x.split('-')[0])) for x in cluster]
        r.append([Counter(tmp).most_common(1)[0][1] / len(tmp), len(tmp)])
    pd.DataFrame(r, columns = ['purity', 'size']).to_csv(f'fig2/{filename}.size.tsv', index=False, sep='\t')

    ## record cluster scov: subject coverage per read ('raw') vs. union of its cluster's hits ('clustered')
    arg = pd.read_table(f'sarg/{filename}.sarg.tsv', header=None, names=['qseqid', 'qlen', 'sseqid', 'sstart', 'send', 'slen'])
    arg['scov'] = (arg['send'] - arg['sstart']) / arg.slen
    sscovs = defaultdict(lambda: defaultdict(set))
    for _, i in arg.iterrows():
        sscovs[i.qseqid][(i.sseqid, i.slen)].add((i.sstart, i.send))

    r = []
    for elements in clusters:
        merged_scovs = defaultdict(set)
        # reads without any ARG hit contribute no intervals
        for scov in [sscovs.get(element, {}) for element in elements]:
            for sseqid, coordinates in scov.items():
                merged_scovs[sseqid].update(coordinates)

        for sseqid, coordinates in merged_scovs.items():
            ## merge overlapping subject intervals
            coordinates = list(sorted(coordinates))
            merged_coordinates = [list(coordinates[0])]

            for coordinate in coordinates[1:]:
                if coordinate[0] <= merged_coordinates[-1][1]:
                    merged_coordinates[-1][1] = max(merged_coordinates[-1][1], coordinate[1])
                else:
                    merged_coordinates.append(list(coordinate))

            for element in elements:
                r.append([sseqid[0], element, sum([coordinate[1] - coordinate[0] for coordinate in merged_coordinates]) / sseqid[1]])

    pd.merge(
        pd.DataFrame(r, columns = ['sseqid', 'qseqid', 'scov-c']).assign(mode = 'clustered'),
        arg.assign(mode = 'raw')[['qseqid', 'sseqid', 'scov', 'mode']],
        on = ['qseqid', 'sseqid'], how='right').to_csv(f'fig2/{filename}.scov.tsv', index=False, sep='\t')

In [11]:
## write the per-species and summary tables accumulated by parser()
for frames, path in ((species, 'fig2/species.tsv'), (statistics, 'fig2/statistics.tsv')):
    pd.concat(frames).to_csv(path, sep='\t', index=False)