In [1]:
import json
import csv

In [2]:
# Viruses scored as "positives" downstream (tp / fp / fn per virus).
# Presumably these are the viruses present in the reference database - TODO confirm.
database_viruses = [
    "chickenpox", "dengue", "ebola", "herpes",
    "kyasanur", "marburg", "measles", "sars-cov-2",
]

# Viruses scored as "negatives" downstream (tn / fp per virus).
# Presumably these are absent from the reference database - TODO confirm.
other_viruses = [
    "crimea-congo", "hantavirus", "influenza", "junin",
    "lassa", "machupo", "papiloma", "rotavirus",
]

In [3]:
def read_raw_results_from_file(filename):
    """Load and return the JSON document stored in ``filename``.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        The deserialized content, exactly as produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale default encoding.
    with open(filename, "r", encoding="utf-8") as f:
        return json.load(f)

In [4]:
# Parse raw_results.json (path is relative to the current working directory).
# Later cells index the result as raw_results[platform][str(threshold)][virus].
raw_results = read_raw_results_from_file("raw_results.json")

In [5]:
# Top-level keys of raw_results that are processed; each one also gets its own
# output CSV. NOTE(review): the numeric suffix presumably encodes a PacBio
# error-rate setting - TODO confirm.
platforms = ["pacbio0", "pacbio5", "pacbio10", "pacbio15"]

In [6]:
# Initialize processed results with zeroed counters, laid out as
#   processed_results[platform][threshold]["positives"][virus] -> tp / fp / fn
#   processed_results[platform][threshold]["negatives"][virus] -> tn / fp
# Thresholds are the integers 0..21 (stored as int keys).
processed_results = {}
for platform in platforms:
    per_threshold = {}
    for threshold in range(0, 22):
        per_threshold[threshold] = {
            "positives": {
                virus: {"tp": 0, "fp": 0, "fn": 0} for virus in database_viruses
            },
            "negatives": {
                virus: {"tn": 0, "fp": 0} for virus in other_viruses
            },
        }
    processed_results[platform] = per_threshold


In [13]:
# Sort rank for a single call: "Multiple" sorts first, any concrete virus name
# (rank 1 by default) second, and "None" last.
virus_to_rank = {
    "Multiple": 0,
    "None": 2
}


def _combine_calls(read_res, rc_res):
    """Merge the two calls made for one read into a single label.

    Returns "Multiple" when the two calls differ and neither is "None";
    otherwise returns the better-ranked call according to ``virus_to_rank``.
    The result is "None" only when both calls are "None".
    """
    best, other = sorted([read_res, rc_res], key=lambda call: virus_to_rank.get(call, 1))
    if best != "None" and other != "None" and best != other:
        return "Multiple"
    return best


# Count outcomes per (platform, threshold, virus). Counts are tallied locally
# and then stored, so re-running this cell does not accumulate on top of the
# values written by a previous run.
for platform in platforms:
    for threshold in range(0, 22):
        threshold_key = str(threshold)  # raw_results is keyed by the threshold as a string
        if threshold_key not in raw_results[platform]:
            continue
        calls_by_virus = raw_results[platform][threshold_key]
        counters = processed_results[platform][threshold]

        for virus in database_viruses:
            if virus not in calls_by_virus:
                continue
            # Assumes each entry holds exactly two parallel lists, the calls for
            # the read first and for its reverse complement second - TODO confirm
            # against whatever produced raw_results.json.
            read_results, read_rc_results = calls_by_virus[virus].values()
            tp, fn, fp = 0, 0, 0
            for read_res, rc_res in zip(read_results, read_rc_results):
                res_virus = _combine_calls(read_res, rc_res)
                if res_virus == virus:
                    tp += 1
                elif res_virus == "None":
                    fn += 1
                elif res_virus == "Multiple":
                    # An ambiguous call is credited as both a hit and a false alarm.
                    # NOTE(review): this also adds a tp when neither call equals
                    # `virus` - confirm that this is intended.
                    tp += 1
                    fp += 1
                else:
                    fp += 1
            counters["positives"][virus].update({"tp": tp, "fp": fp, "fn": fn})

        for virus in other_viruses:
            if virus not in calls_by_virus:
                continue
            read_results, read_rc_results = calls_by_virus[virus].values()
            tn, fp = 0, 0
            for read_res, rc_res in zip(read_results, read_rc_results):
                # For a virus scored as a negative, only "no call at all" is a
                # true negative. Any other label is a false positive, including a
                # call for one specific virus, which was previously left uncounted
                # and therefore inflated the specificity.
                if _combine_calls(read_res, rc_res) == "None":
                    tn += 1
                else:
                    fp += 1
            counters["negatives"][virus].update({"tn": tn, "fp": fp})

In [8]:
def _ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0 when the denominator is 0."""
    return numerator / denominator if denominator else 0


# Derive precision / recall / F1 for the "positives" entries and specificity
# for the "negatives" entries, storing them next to the raw counters.
for platform, curr_dict in processed_results.items():
    for threshold, threshold_dict in curr_dict.items():
        for virus, virus_dict in threshold_dict["positives"].items():
            # Read the counters by key rather than unpacking .values(): this cell
            # adds more keys to the same dict, so positional unpacking would raise
            # ValueError the second time the cell is run.
            tp, fp, fn = virus_dict["tp"], virus_dict["fp"], virus_dict["fn"]
            precision = _ratio(tp, tp + fp)
            recall = _ratio(tp, tp + fn)
            f1 = _ratio(2 * (precision * recall), precision + recall)
            virus_dict.update({
                "precision": precision,
                "recall": recall,
                "f1": f1
            })
        for virus, virus_dict in threshold_dict["negatives"].items():
            tn, fp = virus_dict["tn"], virus_dict["fp"]
            # A virus with no counted reads gets specificity 0 (not NaN).
            specificity = _ratio(tn, tn + fp)
            virus_dict.update({
                "specificity": specificity
            })

In [9]:
# Write one CSV per platform to data/processed_results_<platform>.csv.
# Rows are the thresholds 0..21. Columns are, in order: the threshold, recall
# for each database virus, precision for each database virus, and specificity
# for each other virus.
for platform in platforms:
    # newline="" is required by the csv module so that it controls the line
    # endings itself; without it, extra blank rows appear on Windows.
    with open(f"data/processed_results_{platform}.csv", "w", newline="") as f:
        writer = csv.writer(f)
        # NOTE: the database virus names appear twice in the header; the first
        # group holds recall and the second group holds precision.
        writer.writerow(["Threshold"] + database_viruses + database_viruses + other_viruses)
        for threshold in range(0, 22):
            positives = processed_results[platform][threshold]["positives"]
            negatives = processed_results[platform][threshold]["negatives"]
            row = [threshold]
            row += [positives[virus]["recall"] for virus in database_viruses]
            row += [positives[virus]["precision"] for virus in database_viruses]
            row += [negatives[virus]["specificity"] for virus in other_viruses]
            writer.writerow(row)