In [1]:
import pandas as pd

# Read split files
# All split CSVs live in one directory; each variable must load the file of the SAME
# partition (test -> test_*.csv, val -> val_*.csv).
SPLITS_DIR = '/hpi/fs00/scratch/tobias.fiedler/hotprot_data/splits'
flip_test = pd.read_csv(f'{SPLITS_DIR}/test_FLIP.csv')
flip_val = pd.read_csv(f'{SPLITS_DIR}/val_FLIP.csv')
flip_train = pd.read_csv(f'{SPLITS_DIR}/train_FLIP.csv')
epa_test = pd.read_csv(f'{SPLITS_DIR}/test_median.csv')
epa_val = pd.read_csv(f'{SPLITS_DIR}/val_median.csv')
epa_train = pd.read_csv(f'{SPLITS_DIR}/train_median.csv')

# Extract the unique sequences of the evaluation partitions (train is not needed for the
# leakage analysis below)
flip_test_sequences = set(flip_test['sequence'].to_numpy())
flip_val_sequences = set(flip_val['sequence'].to_numpy())
epa_test_sequences = set(epa_test['sequence'].to_numpy())
epa_val_sequences = set(epa_val['sequence'].to_numpy())

In [2]:
# Read ESM2 pretraining validation clusters.
# Each non-blank header line is split on whitespace. The text after the first "_" of
# token 0 is used as the protein ID, and of token 1 as the cluster ID (presumably
# UniRef100 / UniRef50 identifiers — TODO confirm against the header file).
# `protein_id` is used instead of `id` so the builtin `id` is not shadowed for the rest
# of the notebook.
esm_eval_clusters = dict()  # cluster ID -> list of protein IDs in that cluster
esm_eval_ids = set()        # all protein IDs in the ESM2 validation partition
with open("/hpi/fs00/home/leon.hermann/hotprot/uniref201803_ur100_valid_headers.txt") as txt_file:
    for line in txt_file:
        # split() also drops the line terminator (including a CRLF "\r").
        parts = line.split()
        if not parts:
            continue  # skip blank lines, e.g. a trailing empty line
        protein_id = parts[0].split("_")[1]
        cluster = parts[1].split("_")[1]
        esm_eval_ids.add(protein_id)
        esm_eval_clusters.setdefault(cluster, []).append(protein_id)

In [3]:
# Complete dataset from FLIP with all measurements
def read_fasta(filepath="/hpi/fs00/home/leon.hermann/hotprot/full_dataset_sequences.fasta"):
    """Parse a FASTA file into a list of record dicts.

    Each header line (starting with ">") opens a new record. All following lines up to
    the next header are concatenated into that record's sequence.

    Args:
        filepath: Path to the FASTA file.

    Returns:
        List of dicts (in file order) with keys:
            "id": first space-separated header token without ">" and cut at the first "_".
            "header": the full header line without its line terminator.
            "temp": float parsed from the text after "=" in the second space-separated
                header token (presumably the measured melting temperature — TODO confirm).
            "sequence": concatenation of the record's sequence lines.

    Raises:
        ValueError: if non-blank sequence data appears before the first header.
    """
    dataset = []
    entry = None  # record currently being filled; None until the first header is seen
    with open(filepath) as fasta:
        for raw_line in fasta:
            line = raw_line.rstrip("\r\n")
            if line.startswith(">"):
                # A new header closes the previous record.
                if entry is not None:
                    dataset.append(entry)
                header_tokens = line.split(" ")
                entry = {
                    "id": header_tokens[0].replace(">", "").split("_")[0],
                    "header": line,
                    "temp": float(header_tokens[1].split("=")[1]),
                    "sequence": "",
                }
            else:
                if entry is None:
                    if not line.strip():
                        continue  # ignore blank lines before the first header
                    raise ValueError(f"Sequence data before first FASTA header in {filepath!r}")
                entry["sequence"] += line.strip()

    # Records are only flushed when the next header starts, so flush the final one here.
    if entry is not None:
        dataset.append(entry)
    return dataset

# Load every FLIP record (id, header, temp, sequence) from read_fasta's default FASTA path.
flip_dataset = read_fasta()

In [4]:
# Link ESM2 validation IDs to FLIP sequences: every FLIP record whose ID occurs in the
# ESM2 validation headers contributes its amino-acid sequence to this set.
eval_proteins = {
    record["sequence"]
    for record in flip_dataset
    if record["id"] in esm_eval_ids
}

In [5]:
def data_leakage(sequences, esm_eval_proteins):
    """Return the fraction of `sequences` that lie in the ESM2 validation partition.

    Proteins in the intersection of a test/val split with the ESM2 pretraining validation
    partition will not show data leakage in the evaluation. Despite the function name,
    the returned value is therefore the fraction of *leakage-free* sequences.

    Args:
        sequences: Iterable of sequences from a test/val split. Duplicates are counted once.
        esm_eval_proteins: Iterable of sequences in the ESM2 validation partition.

    Returns:
        len(unique sequences ∩ esm_eval_proteins) / len(unique sequences) as a float,
        or NaN when `sequences` is empty (the fraction is undefined there).
    """
    unique_sequences = set(sequences)
    if not unique_sequences:
        return float("nan")
    leakage_free = unique_sequences.intersection(esm_eval_proteins)
    return len(leakage_free) / len(unique_sequences)

In [6]:
# Fraction of leakage-free proteins (sequences inside the ESM2 validation partition)
# for the val and test partitions of the EPA and FLIP splits.
epa_val__data_leakage, epa_test__data_leakage = (
    data_leakage(split_sequences, eval_proteins)
    for split_sequences in (epa_val_sequences, epa_test_sequences)
)
flip_val__data_leakage, flip_test__data_leakage = (
    data_leakage(split_sequences, eval_proteins)
    for split_sequences in (flip_val_sequences, flip_test_sequences)
)
print("Leakage free proteins in EPA split (val / test): ", epa_val__data_leakage, " / ", epa_test__data_leakage)
print("Leakage free proteins in FLIP split (val / test): ", flip_val__data_leakage, " / ", flip_test__data_leakage)

Leakage free proteins in EPA split (val / test):  1.0  /  1.0
Leakage free proteins in FLIP split (val / test):  0.09006734006734007  /  0.09285258455647734
