In [25]:
import numpy as np
from Bio import SeqIO
import os
import pandas as pd
import torch
import seaborn as sns
import math
import scipy.stats as stats
from collections import defaultdict, Counter
import matplotlib.pyplot as plt
from numpy import trapz
from sklearn.metrics import auc
from pathlib import Path
from functools import reduce
sns.set(font_scale=1.5, style="ticks")

In [2]:
runs_folder = Path("../../../runs")
data_folder = Path("../../../data/gisaid/cov")

In [3]:
from scipy.stats import shapiro
def check_normality(data):
    stat, p_value = shapiro(data)
    print(f"Shapiro-Wilk Test Statistic: {stat}")
    print(f"P-value: {p_value}")
    # Interpret the result
    if p_value > 0.05:
        print("The data appears to be normally distributed (fail to reject H0).")
    else:
        print("The data does not appear to be normally distributed (reject H0).")

In [4]:
def get_coverage(path, seq2freq, topk=100):
    df = pd.read_csv(path)
    preds = Counter(df["prediction"]).most_common()
    topks = []
    sum_of_freqs = []
    sum_of_freq = 0.0
    for topk in range(len(preds)):
        sum_of_freq += seq2freq.get(preds[topk][0], 0.0)
        topks.append(topk)
        sum_of_freqs.append(sum_of_freq)
    return topks, sum_of_freqs


def get_evescape_coverage(preds, seq2freq, topk=100):
    topks = []
    sum_of_freqs = []
    sum_of_freq = 0.0
    
    for _topk in range(len(preds)):
        sum_of_freq += seq2freq.get(preds[_topk], 0.0)
        topks.append(_topk)
        sum_of_freqs.append(sum_of_freq)
    if len(topks) < topk:
        sum_of_freqs.extend([sum_of_freqs[-1]] * (topk - len(topks)))
    return list(range(topk)), sum_of_freqs

In [5]:
def get_precision(path, dominant_seqs, topk=100):
    df = pd.read_csv(path)
    preds = Counter(df["prediction"]).most_common(topk)
    pred_seqs = set([x[0] for x in preds])
    assert len(pred_seqs) == topk
    
    precision = 0
    for seq in pred_seqs:
        if seq in dominant_seqs:
            precision += 1
    
    return precision / topk

In [6]:
def get_summary(res, val_key="coverage"):
    new_res = defaultdict(list)
    models = list(set(res["method"]))
    for model in models:
        _res = res[res["method"] == model]
        median_res = np.median(_res[val_key])
        std_res = np.std(_res[val_key])
        new_res["method"].append(model)
        new_res["median"].append(median_res)
        new_res["std"].append(std_res)

    return pd.DataFrame(new_res)

In [26]:
testing_window=3
temperature = 1.0

res = defaultdict(list)

res_precision = defaultdict(list)

generated_num=500

for year_and_month in ("2021-07", "2021-10", "2022-01", "2022-04", ): # 
    year, month = year_and_month.split("-")
    year = int(year)
    index = (year - 2020) * 4 + (int(month) - 1) // 3  # e.g. 2020-01~03 == 0, 2021-04~06 == 1

    print(">>>", year_and_month)

    for location in ("europe", "north_america", "asia", "south_america", "oceania", "africa"):
        training_path = data_folder / f"spike_rbd_processed_continents/2019-12_to_{year_and_month}_1M/all/continents/human_minBinSize100_minLen223_maxLen223_{location}_location.fasta"

        total_sample_size = 0
        total_seq_num = 0
        for record in SeqIO.parse(training_path, "fasta"):
            desc = record.description.split()[1].strip().split("|")
            desc = {x.split("=")[0]: x.split("=")[1] for x in desc}
            cnt = round(float(desc["freq"]) * float(desc["bin_size"]))
            total_sample_size += cnt
            total_seq_num += 1
        
        gt_path = data_folder / f"spike_rbd_processed_continents/2020-01_to_2023-12_3M/all/human_minBinSize100_minLen223_maxLen223_location_region1_bins/{index}.fasta"
        seq2freq = dict()
        seq2count = dict()
        for record in SeqIO.parse(gt_path, "fasta"):
            desc = record.description.split()[1].strip().split("|")
            desc = {x.split("=")[0]: x.split("=")[1] for x in desc}
            if desc["location"] != location:
                continue
            seq2freq[str(record.seq)] = float(desc["freq"])
            seq2count[str(record.seq)] = round(float(desc["freq"]) * float(desc["bin_size"]))
        seq2freq = {k: v / sum(seq2count.values()) for k, v in seq2count.items()}

        name2rev_nll_path = {
                    "Transmission": runs_folder / f"cov_continents/2019-12_to_{year_and_month}_1M/transmission/generations_beam_search_500_3M/{location}/temp_1.0/lightning_logs/version_0/predictions.csv",
                    "Transmission(eig_topk_3)": runs_folder / f"cov_continents/2019-12_to_{year_and_month}_1M/transmission_ablation/no_reg_top3_eig/generations_beam_search_500_3M/{location}/temp_1.0/lightning_logs/version_0/predictions.csv",
                    "Add_Embed": runs_folder / f"cov_continents/2019-12_to_{year_and_month}_1M/concat/generations_beam_search_500_3M/{location}/temp_1.0/lightning_logs/version_0/predictions.csv",
                    "Prepend": runs_folder / f"cov_continents/2019-12_to_{year_and_month}_1M/prepend/generations_beam_search_500_3M/{location}/temp_1.0/lightning_logs/version_0/predictions.csv",
                    "Global": runs_folder / f"cov_continents/2019-12_to_{year_and_month}_1M/global/generations_beam_search_500_3M/temp_1.0/lightning_logs/version_0/predictions.csv",
                    "LoRA": runs_folder / f"cov_continents/2019-12_to_{year_and_month}_1M/finetune_lora/generations_beam_search_500_3M/{location}/temp_1.0/lightning_logs/version_0/predictions.csv",
                    "Finetune": runs_folder / f"cov_continents/2019-12_to_{year_and_month}_1M/finetune/generations_beam_search_500_3M/{location}/temp_1.0/lightning_logs/version_0/predictions.csv",
                    "Param_share": runs_folder / f"cov_continents/2019-12_to_{year_and_month}_1M/param_share/generations_beam_search_500_3M/{location}/temp_1.0/lightning_logs/version_0/predictions.csv",
                }
        
        # Last
        seq2count_last = defaultdict(int)
        for last_index in range(index - 4, index): # 
            last_path = data_folder / f"spike_rbd_processed_continents/2020-01_to_2023-12_3M/all/human_minBinSize100_minLen223_maxLen223_location_region1_bins/{last_index}.fasta"
            for record in SeqIO.parse(last_path, "fasta"):
                desc = record.description.split()[1].strip().split("|")
                desc = {x.split("=")[0]: x.split("=")[1] for x in desc}
                # assert desc["location"] == location
                if desc["location"] != location:
                    continue
                seq2count_last[str(record.seq)] += round(float(desc["freq"]) * float(desc["bin_size"]))
        seq2count_last_sorted = sorted(seq2count_last.items(), key=lambda x: x[1], reverse=True)
        seq2count_last_sorted = [x[0] for x in seq2count_last_sorted]
        topks, sum_of_freqs = get_evescape_coverage(seq2count_last_sorted[:generated_num], seq2freq, topk=generated_num)
        assert len(sum_of_freqs) == len(topks), "len(sum_of_freqs)=%d, len(topks)=%d" % (len(sum_of_freqs), len(topks))
        res["coverage"].extend(sum_of_freqs)
        res["location"].extend([location] * len(topks))
        res["method"].extend(["Last"] * len(topks))
        res["date"].extend([year_and_month] * len(topks))
        res["topk"].extend(topks)
        res["total_sample_size"].extend([total_sample_size] * len(topks))
        res["total_seq_num"].extend([total_seq_num] * len(topks))
        # Last

        for model in name2rev_nll_path:
            path = name2rev_nll_path[model]
            topks, sum_of_freqs = get_coverage(path, seq2freq)
            # print(model, set(pd.read_csv(path)["src_time"]))
            assert set(pd.read_csv(path)["src_time"]) == set(pd.read_csv(name2rev_nll_path["Transmission"])["src_time"])
            assert len(sum_of_freqs) == len(topks)
            res["coverage"].extend(sum_of_freqs)
            res["location"].extend([location] * len(topks))
            res["method"].extend([model] * len(topks))
            res["date"].extend([year_and_month] * len(topks))
            res["topk"].extend(topks)
            res["total_sample_size"].extend([total_sample_size] * len(topks))
            res["total_seq_num"].extend([total_seq_num] * len(topks))
        
        # evescape:
        evescape_score_path = runs_folder / f"cov_continents/2019-12_to_{year_and_month}_1M/evescape/all_seqs_EVEscape_scores.csv"
        evescape_seq_path = runs_folder / f"cov_continents/2019-12_to_{year_and_month}_1M/evescape/human_minBinSize100_lenQuantile0.2.fasta"
        id2seq = {}
        for record in SeqIO.parse(evescape_seq_path, "fasta"):
            id2seq[record.id] = str(record.seq)
        df = pd.read_csv(evescape_score_path)

        seqs = [id2seq[x] for x in df["src_id"]]
        topks, sum_of_freqs = get_evescape_coverage(seqs[:generated_num], seq2freq, topk=generated_num)
        assert len(sum_of_freqs) == len(topks)
        res["coverage"].extend(sum_of_freqs)
        res["location"].extend([location] * len(topks))
        res["method"].extend(["Evescape"] * len(topks))
        res["date"].extend([year_and_month] * len(topks))
        res["topk"].extend(topks)
        res["total_sample_size"].extend([total_sample_size] * len(topks))
        res["total_seq_num"].extend([total_seq_num] * len(topks))
        
        # ========== evescape (rerank-last):
        df["seqs"] = seqs
        in_last = np.asarray([True if seq2count_last.get(seq, 0) >= 1 else False for seq in seqs ])
        _df = df[in_last]
        _df.sort_values("EVEscape score_sigmoid", ascending=False, inplace=True)
        topks, sum_of_freqs = get_evescape_coverage(_df["seqs"][:generated_num].tolist(), seq2freq, topk=generated_num)
        assert len(sum_of_freqs) == len(topks)
        res["coverage"].extend(sum_of_freqs)
        res["location"].extend([location] * len(topks))
        res["method"].extend(["Evescape (rerank last)"] * len(topks))
        res["date"].extend([year_and_month] * len(topks))
        res["topk"].extend(topks)
        res["total_sample_size"].extend([total_sample_size] * len(topks))
        res["total_seq_num"].extend([total_seq_num] * len(topks))
        # ========== evescape (rerank-last):



>>> 2021-07


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _df.sort_values("EVEscape score_sigmoid", ascending=False, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _df.sort_values("EVEscape score_sigmoid", ascending=False, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _df.sort_values("EVEscape score_sigmoid", ascending=False, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_gu

>>> 2021-10


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _df.sort_values("EVEscape score_sigmoid", ascending=False, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _df.sort_values("EVEscape score_sigmoid", ascending=False, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _df.sort_values("EVEscape score_sigmoid", ascending=False, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_gu

>>> 2022-01


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _df.sort_values("EVEscape score_sigmoid", ascending=False, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _df.sort_values("EVEscape score_sigmoid", ascending=False, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _df.sort_values("EVEscape score_sigmoid", ascending=False, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_gu

>>> 2022-04


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _df.sort_values("EVEscape score_sigmoid", ascending=False, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _df.sort_values("EVEscape score_sigmoid", ascending=False, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _df.sort_values("EVEscape score_sigmoid", ascending=False, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_gu

In [27]:
res = pd.DataFrame(res)

In [28]:
res

Unnamed: 0,coverage,location,method,date,topk,total_sample_size,total_seq_num
0,0.012437,europe,Last,2021-07,0,1445792,6240
1,0.038220,europe,Last,2021-07,1,1445792,6240
2,0.944364,europe,Last,2021-07,2,1445792,6240
3,0.944372,europe,Last,2021-07,3,1445792,6240
4,0.944405,europe,Last,2021-07,4,1445792,6240
...,...,...,...,...,...,...,...
144207,0.806375,africa,Evescape (rerank last),2022-04,495,120091,4653
144208,0.806375,africa,Evescape (rerank last),2022-04,496,120091,4653
144209,0.806375,africa,Evescape (rerank last),2022-04,497,120091,4653
144210,0.806375,africa,Evescape (rerank last),2022-04,498,120091,4653


In [29]:
_summary_all = []
for k in (99, 299, 499):
    _summary = get_summary(res[res["topk"] == k])
    _summary = _summary.rename(columns={"median": f"top-{k}"})
    _summary = _summary.drop(columns=['std'])
    _summary_all.append(_summary)

merged_df = reduce(lambda left, right: left.merge(right, on='method'), _summary_all)
merged_df = merged_df.round(3)
print(merged_df)

                      method  top-99  top-299  top-499
0                  Add_Embed   0.828    0.875    0.889
1                Param_share   0.816    0.870    0.889
2   Transmission(eig_topk_3)   0.852    0.894    0.909
3     Evescape (rerank last)   0.118    0.408    0.664
4                       LoRA   0.828    0.855    0.863
5               Transmission   0.841    0.888    0.901
6                       Last   0.819    0.874    0.906
7                    Prepend   0.832    0.867    0.891
8                     Global   0.828    0.855    0.862
9                   Finetune   0.839    0.875    0.882
10                  Evescape   0.001    0.043    0.050


# Country_level

In [None]:
testing_window=3
temperature = 1.0
k=500
min_count=1000

res = defaultdict(list)
res_precision = defaultdict(list)

for year_and_month in ("2021-07", "2021-10", "2022-01", "2022-04"):
    year, month = year_and_month.split("-")
    year = int(year)
    index = (year - 2020) * 4 + (int(month) - 1) // 3

    print(">>>", year_and_month)
    training_path = data_folder / f"spike_rbd_processed_countries/2020-01_to_{year_and_month}_1M/all/human_minBinSize100_minLen223_maxLen223_location_region2.fasta"
    location2total_sample_size = defaultdict(int)
    location2total_seq_num = defaultdict(int)
    for record in SeqIO.parse(training_path, "fasta"):
        desc = record.description.split()[1].strip().split("|")
        desc = {x.split("=")[0]: x.split("=")[1] for x in desc}
        cnt = round(float(desc["freq"]) * float(desc["bin_size"]))
        location = desc["location"]
        location2total_sample_size[location] += cnt
        location2total_seq_num[location] += 1

    for location in "africa/south_africa asia/china asia/india asia/indonesia asia/israel asia/japan asia/south_korea europe/austria europe/belgium europe/czech_republic europe/denmark europe/france europe/germany europe/ireland europe/italy europe/luxembourg europe/netherlands europe/norway europe/poland europe/russia europe/slovenia europe/spain europe/sweden europe/switzerland europe/turkey europe/united_kingdom north_america/canada north_america/mexico north_america/usa oceania/australia south_america/brazil south_america/peru".split():
        gt_path = data_folder / f"spike_rbd_processed_countries/2020-01_to_2023-12_3M/all/human_minBinSize100_minLen223_maxLen223_location_region2_bins/{int(index)}.fasta"
        seq2freq = dict()
        seq2count = dict()
        for record in SeqIO.parse(gt_path, "fasta"):
            desc = record.description.split()[1].strip().split("|")
            desc = {x.split("=")[0]: x.split("=")[1] for x in desc}
            if desc["location"] != location:
                continue
            seq2freq[str(record.seq)] = float(desc["freq"])
            seq2count[str(record.seq)] = round(float(desc["freq"]) * float(desc["bin_size"]))

        total_sample_size = sum(seq2count.values())
        if total_sample_size < min_count:
            continue
        seq2freq = {k: v / total_sample_size for k, v in seq2count.items()}

        name2rev_nll_path = {
            # , temperature
                    "Add_Embed": runs_folder / f"cov_countries/2020-01_to_{year_and_month}_1M/concat/generations_beam_search_500_{testing_window}M/{location.split('/')[0]}_{location.split('/')[1]}/temp_{temperature}/lightning_logs/version_0/predictions.csv",
                    "Prepend": runs_folder / f"cov_countries/2020-01_to_{year_and_month}_1M/prepend/generations_beam_search_500_{testing_window}M/{location.split('/')[0]}_{location.split('/')[1]}/temp_{temperature}/lightning_logs/version_1/predictions.csv",
                    "Global": runs_folder / f"cov_countries/2020-01_to_{year_and_month}_1M/global/generations_beam_search_500_{testing_window}M/temp_{temperature}/lightning_logs/version_0/predictions.csv",
                    "LoRA": runs_folder / f"cov_countries/2020-01_to_{year_and_month}_1M/finetune_lora/generations_beam_search_500_{testing_window}M/{location.split('/')[0]}_{location.split('/')[1]}/temp_{temperature}/lightning_logs/version_0/predictions.csv",
                    "Finetune": runs_folder / f"cov_countries/2020-01_to_{year_and_month}_1M/finetune/generations_beam_search_500_{testing_window}M/{location.split('/')[0]}_{location.split('/')[1]}/temp_{temperature}/lightning_logs/version_1/predictions.csv",
                    "Param_share": runs_folder / f"cov_countries/2020-01_to_{year_and_month}_1M/param_share/generations_beam_search_500_{testing_window}M/{location.split('/')[0]}_{location.split('/')[1]}/temp_{temperature}/lightning_logs/version_0/predictions.csv",
                    "Transmission(G2L)": runs_folder / f"cov_countries/2020-01_to_{year_and_month}_1M/transmission_hierachy/country_to_continent/agg_complete_3/generations_beam_search_500_{testing_window}M/{location.split('/')[0]}_{location.split('/')[1]}/temp_{temperature}/lightning_logs/version_1/predictions.csv",
                    "Transmission(G2G)": runs_folder / f"cov_countries/2020-01_to_{year_and_month}_1M/transmission_hierachy/continent_to_continent/agg_complete_3/generations_beam_search_500_{testing_window}M/{location.split('/')[0]}_{location.split('/')[1]}/temp_{temperature}/lightning_logs/version_1/predictions.csv",
                }
        
        # Last
        seq2count_last = defaultdict(int)
        for last_index in range(index - 4, index): # 4 * 3M=12M
            last_path = data_folder / f"spike_rbd_processed_countries/2020-01_to_2023-12_3M/all/human_minBinSize100_minLen223_maxLen223_location_region2_bins/{int(last_index)}.fasta"
            if not os.path.exists(last_path):
                print(last_path)
                continue
            for record in SeqIO.parse(last_path, "fasta"):
                desc = record.description.split()[1].strip().split("|")
                desc = {x.split("=")[0]: x.split("=")[1] for x in desc}
                if desc["location"] != location:
                    continue
                seq2count_last[str(record.seq)] += round(float(desc["freq"]) * float(desc["bin_size"]))
        seq2count_last_sorted = sorted(seq2count_last.items(), key=lambda x: x[1], reverse=True)
        seq2count_last_sorted = [x[0] for x in seq2count_last_sorted]
        topks, sum_of_freqs = get_evescape_coverage(seq2count_last_sorted[:k], seq2freq, topk=k)
        assert len(sum_of_freqs) == len(topks)
        res["coverage"].extend(sum_of_freqs)
        res["location"].extend([location] * len(topks))
        res["method"].extend(["Last"] * len(topks))
        res["date"].extend([year_and_month] * len(topks))
        res["topk"].extend(topks)
        res["total_sample_size"].extend([location2total_sample_size[location]] * len(topks))
        res["total_seq_num"].extend([location2total_seq_num[location]] * len(topks))
        # Last
        
        for model in name2rev_nll_path:
            path = name2rev_nll_path[model]

            if not os.path.exists(path):
                print(path)
                continue

            # assert set(pd.read_csv(path)["src_time"]) == set(pd.read_csv(name2rev_nll_path["Transmission(G2L)"])["src_time"])
            
            topks, sum_of_freqs = get_coverage(path, seq2freq)
            assert len(sum_of_freqs) == len(topks), "len(sum_of_freqs)=%d, len(topks)=%d" % (len(sum_of_freqs), len(topks))

            res["coverage"].extend(sum_of_freqs)
            res["location"].extend([location] * len(topks))
            res["method"].extend([model] * len(topks))
            res["date"].extend([year_and_month] * len(topks))
            res["topk"].extend(topks)
            res["total_sample_size"].extend([location2total_sample_size[location]] * len(topks))
            res["total_seq_num"].extend([location2total_seq_num[location]] * len(topks))
        
        
        # evescape:
        evescape_score_path = runs_folder / f"cov_countries/2020-01_to_{year_and_month}_1M/evescape" / "all_seqs_EVEscape_scores.csv"
        evescape_seq_path = runs_folder / f"cov_countries/2020-01_to_{year_and_month}_1M/evescape" / "human_minBinSize100_lenQuantile0.2.fasta"
        id2seq = {}
        for record in SeqIO.parse(evescape_seq_path, "fasta"):
            id2seq[record.id] = str(record.seq)
        
        df = pd.read_csv(evescape_score_path)

        seqs = [id2seq[x] for x in df["src_id"]]
        topks, sum_of_freqs = get_evescape_coverage(seqs[:k], seq2freq, topk=k)
        assert len(sum_of_freqs) == len(topks)
        res["coverage"].extend(sum_of_freqs)
        res["location"].extend([location] * len(topks))
        res["method"].extend(["Evescape"] * len(topks))
        res["date"].extend([year_and_month] * len(topks))
        res["topk"].extend(topks)
        res["total_sample_size"].extend([location2total_sample_size[location]] * len(topks))
        res["total_seq_num"].extend([location2total_seq_num[location]] * len(topks))
        
        # ========== evescape (rerank-last):
        df["seqs"] = seqs
        in_last = np.asarray([True if seq2count_last.get(seq, 0) >= 1 else False for seq in seqs ])
        _df = df[in_last]
        _df.sort_values("EVEscape score_sigmoid", ascending=False, inplace=True)
        topks, sum_of_freqs = get_evescape_coverage(_df["seqs"][:k].tolist(), seq2freq, topk=k)
        assert len(sum_of_freqs) == len(topks)
#         print(len(topks), len(sum_of_freqs))
        res["coverage"].extend(sum_of_freqs)
        res["location"].extend([location] * len(topks))
        res["method"].extend(["Evescape (rerank last)"] * len(topks))
        res["date"].extend([year_and_month] * len(topks))
        res["topk"].extend(topks)
        res["total_sample_size"].extend([location2total_sample_size[location]] * len(topks))
        res["total_seq_num"].extend([location2total_seq_num[location]] * len(topks))
        # ========== evescape (rerank-last)

>>> 2021-07


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _df.sort_values("EVEscape score_sigmoid", ascending=False, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _df.sort_values("EVEscape score_sigmoid", ascending=False, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _df.sort_values("EVEscape score_sigmoid", ascending=False, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_gu

>>> 2021-10


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _df.sort_values("EVEscape score_sigmoid", ascending=False, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _df.sort_values("EVEscape score_sigmoid", ascending=False, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _df.sort_values("EVEscape score_sigmoid", ascending=False, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_gu

In [14]:
res = pd.DataFrame(res)

In [24]:
_summary_all = []
for k in (99, 299, 499):
    _summary = get_summary(res[res["topk"] == k])
    _summary = _summary.rename(columns={"median": f"top-{k}"})
    _summary = _summary.drop(columns=['std'])
    _summary_all.append(_summary)

merged_df = reduce(lambda left, right: left.merge(right, on='method'), _summary_all)
merged_df = merged_df.round(3)
print(merged_df)

                    method  top-99  top-299  top-499
0                Add_Embed   0.853    0.879    0.894
1              Param_share   0.818    0.870    0.884
2   Evescape (rerank last)   0.461    0.842    0.887
3                     LoRA   0.854    0.883    0.885
4        Transmission(G2L)   0.872    0.925    0.934
5                     Last   0.841    0.882    0.898
6                  Prepend   0.849    0.879    0.891
7        Transmission(G2G)   0.859    0.898    0.923
8                   Global   0.853    0.882    0.884
9                 Finetune   0.855    0.882    0.889
10                Evescape   0.000    0.003    0.021
