# Train/validation/test split creation
In this notebook we walk through how we create the train, validation, and test datasets.
In the context of this project, a dataset is a set of `(protein sequence, thermostability)` pairs.
A given protein sequence may have multiple thermostability measurements in the FLIP dataset.

## Read Meltome dataset
The source of the dataset file `full_dataset_sequences.fasta` is [here](https://github.com/J-SNACKKB/FLIP/tree/main/splits/meltome). Among other information, it contains protein sequences and their corresponding thermostability (melting point) measurements.

In [1]:
def read_fasta(filepath="data/full_dataset_sequences.fasta"):
    """Parse the Meltome FASTA file into a list of entry dicts."""
    dataset = []
    entry = None
    with open(filepath) as fasta:
        for line in fasta:
            line = line.rstrip("\n")
            if line.startswith(">"):
                # A new header starts a new entry, so store the previous one first
                if entry is not None:
                    dataset.append(entry)
                header_tokens = line.split(" ")
                entry = {
                    # Protein ID: first header token without ">" and without its "_..." suffix
                    "id": header_tokens[0][1:].split("_")[0],
                    "header": line,
                    # Melting point: value part of the "key=value" second header token
                    "temp": float(header_tokens[1].split("=")[1]),
                    "sequence": "",
                }
            elif entry is not None:
                # Sequences may span several lines
                entry["sequence"] += line
    # The last entry is not followed by another header, so append it here
    if entry is not None:
        dataset.append(entry)
    return dataset


flip_dataset = read_fasta()
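The header handling can be tried on a small in-memory example. The sketch below is a hypothetical, stream-based variant of `read_fasta` (the names `parse_fasta` and `_make_record` and the sample records are made up). It assumes each header has the form `>ID_suffix key=value`, as the parsing above implies, and shows that multi-line sequences are joined and that the final record is kept.

```python
import io


def parse_fasta(handle):
    """Yield (id, temp, sequence) tuples from a FASTA text stream.

    Hypothetical helper: assumes headers of the form ">ID_suffix key=value".
    """
    header, chunks = None, []
    for line in handle:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        if line.startswith(">"):
            if header is not None:
                yield _make_record(header, chunks)
            header, chunks = line, []
        else:
            chunks.append(line)
    # Emit the last record, which is not followed by another header
    if header is not None:
        yield _make_record(header, chunks)


def _make_record(header, chunks):
    tokens = header[1:].split(" ")
    protein_id = tokens[0].split("_")[0]
    temp = float(tokens[1].split("=")[1])
    return protein_id, temp, "".join(chunks)


# Made-up records; the first sequence spans two lines
sample = io.StringIO(
    ">P00001_made_up temp=50.5\n"
    "MKT\n"
    "AYI\n"
    ">P00002_made_up temp=61.0\n"
    "MSE\n"
)
records = list(parse_fasta(sample))
print(records)  # [('P00001', 50.5, 'MKTAYI'), ('P00002', 61.0, 'MSE')]
```

Collecting the sequence lines in a list and joining them once avoids repeated string concatenation on long sequences.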

## Read ESM validation protein clusters
ESM validation protein clusters are the UniRef100 clusters that were held out during the training of ESM2 and ESMFold.
Since we are essentially doing transfer learning on top of ESMFold outputs, we use these clusters as our validation/test data as well. This way, none of the proteins we validate or test on were seen by ESMFold during training, which avoids data leakage.
More info and a download link can be found [here](https://github.com/facebookresearch/esm#pre-training-dataset-split--).


In [2]:
esm_eval_clusters = dict()  # cluster ID -> list of member protein IDs
esm_eval_ids = set()  # all protein IDs in the held-out clusters
with open("./data/uniref201803_ur100_valid_headers.txt") as txt_file:
    for line in txt_file:
        # Each line: "<prefix>_<protein ID> <prefix>_<cluster ID>"
        parts = line.split(" ")
        protein_id = parts[0].split("_")[1]
        cluster = parts[1].split("_")[1].strip()
        esm_eval_ids.add(protein_id)
        esm_eval_clusters.setdefault(cluster, []).append(protein_id)
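The cluster file is turned into a mapping from cluster ID to member protein IDs. A compact sketch of the same idea with `collections.defaultdict`, assuming (as the parsing above does) that each line holds a member ID and a cluster ID separated by whitespace, each with a prefix before the first `_`. The helper `read_clusters` and the sample lines are hypothetical:

```python
from collections import defaultdict


def read_clusters(lines):
    """Return (cluster ID -> member protein IDs, set of all member IDs).

    Hypothetical helper; assumes lines like "<prefix>_<protein ID> <prefix>_<cluster ID>".
    """
    clusters = defaultdict(list)
    member_ids = set()
    for line in lines:
        if not line.strip():
            continue  # tolerate blank lines, e.g. a trailing newline
        member, cluster = line.split()
        protein_id = member.split("_")[1]
        cluster_id = cluster.split("_")[1]
        clusters[cluster_id].append(protein_id)
        member_ids.add(protein_id)
    return dict(clusters), member_ids


# Made-up example lines
toy_lines = [
    "M_P00001 C_A\n",
    "M_P00002 C_A\n",
    "M_P00003 C_B\n",
]
toy_cluster_map, toy_ids = read_clusters(toy_lines)
print(toy_cluster_map)  # {'A': ['P00001', 'P00002'], 'B': ['P00003']}
```

Because `read_clusters` accepts any iterable of lines, an open file handle works as well; `defaultdict(list)` removes the explicit "create the list if the key is missing" check.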

## Create the held-out dataset (validation/test) and the train ID set

In [3]:
held_out_dataset = []
train_dataset = []
train_ids = set()
all_ids = set()
# Reuse the entries parsed above instead of reading the FASTA file a second time
for entry in flip_dataset:
    protein_id = entry["id"]
    all_ids.add(protein_id)
    if protein_id in esm_eval_ids:
        held_out_dataset.append(entry)
    else:
        train_dataset.append(entry)
        train_ids.add(protein_id)
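Since the split is made on protein IDs, a quick sanity check is useful: no ID should end up in both the training and the held-out data, and because the same sequence can occur under different IDs, it is also worth counting sequences shared between the two. A minimal sketch; the helper `overlap_report` and the toy entries are hypothetical:

```python
def overlap_report(train, held_out):
    """Count IDs and sequences that occur in both lists of entry dicts."""
    shared_ids = {e["id"] for e in train} & {e["id"] for e in held_out}
    shared_seqs = {e["sequence"] for e in train} & {e["sequence"] for e in held_out}
    return len(shared_ids), len(shared_seqs)


# Toy entries: the sequence "MSE" appears under two different IDs
toy_train = [{"id": "P1", "sequence": "MKT"}, {"id": "P2", "sequence": "MSE"}]
toy_held_out = [{"id": "P3", "sequence": "MSE"}]
print(overlap_report(toy_train, toy_held_out))  # (0, 1)
```

On the real data this would be called as `overlap_report(train_dataset, held_out_dataset)`. A non-zero sequence count means an identical sequence sits on both sides of the split under different IDs.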

## Split held out dataset by cluster
Since the hyperparameters of our model are tuned for best performance on the validation set, scores on that set are optimistically biased. We therefore also introduce a test set, which is used only for the final evaluation of the model and never for hyperparameter tuning.

To avoid similar protein sequences appearing in both the validation and the test set, we split the held-out set (i.e., the non-training set) randomly by cluster: all proteins of a cluster go to the same set. Each cluster is assigned to the validation set with probability 2/3 and to the test set with probability 1/3.

### Execute the split

In [4]:
import random

test_ids = set()
val_ids = set()
random.seed(42)  # fixed seed so the split is reproducible

# Assign whole clusters: about 1/3 of the clusters go to test, the rest to validation
for protein_ids in esm_eval_clusters.values():
    is_test = random.random() <= 1 / 3
    for protein_id in protein_ids:
        (test_ids if is_test else val_ids).add(protein_id)

test_dataset = [item for item in held_out_dataset if item["id"] in test_ids]
val_dataset = [item for item in held_out_dataset if item["id"] in val_ids]
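A sketch of a check that the split really happens per cluster (no cluster has members in more than one set) and that roughly a third of the clusters land in the test set. The helper `check_cluster_split` and the toy clusters are hypothetical; the assignment loop reuses the rule from the cell above, but with its own `random.Random` instance so the notebook's global random state is left untouched.

```python
import random


def check_cluster_split(clusters, val_ids, test_ids):
    """Fail if a cluster's members are spread over several sets;
    return the fraction of clusters assigned to the test set."""
    n_test = 0
    for cluster_id, members in clusters.items():
        labels = {
            "test" if m in test_ids else "val" if m in val_ids else "none"
            for m in members
        }
        assert len(labels) == 1, f"cluster {cluster_id} is split: {labels}"
        n_test += labels == {"test"}
    return n_test / len(clusters)


# Toy data: 3000 hypothetical clusters with two members each
rng = random.Random(0)
toy_clusters = {f"C{i}": [f"P{i}a", f"P{i}b"] for i in range(3000)}
toy_val_ids, toy_test_ids = set(), set()
for members in toy_clusters.values():
    # Same assignment rule as in the cell above
    is_test = rng.random() <= 1 / 3
    (toy_test_ids if is_test else toy_val_ids).update(members)

fraction = check_cluster_split(toy_clusters, toy_val_ids, toy_test_ids)
print(f"test fraction: {fraction:.3f}")  # should be close to 1/3
```

On the real data the call would be `check_cluster_split(esm_eval_clusters, val_ids, test_ids)`. Note that the fraction counts clusters, not proteins, so the share of test proteins can differ when cluster sizes vary.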

## Store train.csv, val.csv and test.csv

In [8]:
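def storeMetadata(ds, name: str):
    """Write the (sequence, melting point) rows of `ds` to ./data/<name>.csv."""
    with open(f"./data/{name}.csv", "w") as f:
        # Columns are separated by ", " (a comma followed by a space)
        f.write("sequence, melting point\n")
        for entry in ds:
            f.write(f'{entry["sequence"]}, {entry["temp"]}\n')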


def print_ds_infos(name: str, ds: list):
    unique_seqs = set([entry["sequence"] for entry in ds])
    print(f'{"-"*5}Info for {name} set{"-"*5}')
    print("Num sequences:", len(ds))
    print("Num sequences (len < 3000):", len([entry for entry in ds if len(entry["sequence"]) < 3000]))
    print("Num unique sequences: ", len(unique_seqs))
    print(
        "Num unique sequences (len < 3000): ",
        len([seq for seq in unique_seqs if len(seq) < 3000]),
    )




# print_ds_infos("train", train_dataset)
# print_ds_infos("val", val_dataset)
# print_ds_infos("test", test_dataset)

storeMetadata(train_dataset, "train")
storeMetadata(val_dataset, "val")
storeMetadata(test_dataset, "test")
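`storeMetadata` separates the columns with `, ` (comma plus space), so a reader has to drop the space after each comma. A sketch of loading such a file back with the standard `csv` module; `load_split` is a hypothetical helper, and it assumes sequences never contain commas:

```python
import csv
import io


def load_split(handle):
    """Read (sequence, melting point) rows written by storeMetadata."""
    reader = csv.reader(handle, skipinitialspace=True)
    next(reader)  # skip the "sequence, melting point" header
    return [(seq, float(temp)) for seq, temp in reader]


# In-memory stand-in for a file such as ./data/train.csv
demo = io.StringIO("sequence, melting point\nMKTAYI, 50.5\nMSE, 61.0\n")
rows = load_split(demo)
print(rows)  # [('MKTAYI', 50.5), ('MSE', 61.0)]
```

For a real file, pass an open handle, e.g. `with open("./data/train.csv", newline="") as f: rows = load_split(f)`. With pandas, `pd.read_csv(path, skipinitialspace=True)` likewise strips the leading space from the column name and the values.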

# Create our dataset with median measurements
Since our dataset contains multiple melting point measurements per protein, there is no single true value for a given protein. Instead of showing the model every measurement of a protein, we can pool them into their median, so that training does not pull the prediction back and forth between different target values.
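One property that makes the median a sensible pooling choice is its robustness to outlying measurements. A worked example with made-up melting points for a single protein: adding one outlier moves the mean by 4.75 °C but the median by only 0.5 °C.

```python
from statistics import mean, median

# Hypothetical melting point measurements (°C) of one protein
base = [50.0, 51.0, 52.0]
with_outlier = base + [70.0]

print(median(base), median(with_outlier))  # 51.0 51.5
print(mean(base), mean(with_outlier))      # 51.0 55.75
```

With an even number of measurements both `statistics.median` and `np.median` return the average of the two middle values, which is why the pooled value here is 51.5.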

In [9]:
from collections import defaultdict

import numpy as np


def poolMeasurements(ds):
    """
    Pool all measurements of a protein (identified by its sequence) into their median
    """
    measurementsPerProtein = defaultdict(list)
    for sample in ds:
        measurementsPerProtein[sample["sequence"]].append(sample["temp"])
    pooledDs = [
        {"temp": np.median(measurements), "sequence": sequence}
        for sequence, measurements in measurementsPerProtein.items()
    ]
    return pooledDs


train_median_ds = poolMeasurements(train_dataset)
val_median_ds = poolMeasurements(val_dataset)
test_median_ds = poolMeasurements(test_dataset)

storeMetadata(train_median_ds, "train_EPA")
storeMetadata(val_median_ds, "val_EPA")
storeMetadata(test_median_ds, "test_EPA")
print_ds_infos("train_median", train_median_ds)
print_ds_infos("val_median", val_median_ds)
print_ds_infos("test_median", test_median_ds)

-----Info for train_median set-----
Num sequences: 30988
Num sequences (len < 3000): 30844
Num unique sequences:  30988
Num unique sequences (len < 3000):  30844
-----Info for val_median set-----
Num sequences: 2129
Num sequences (len < 3000): 2113
Num unique sequences:  2129
Num unique sequences (len < 3000):  2113
-----Info for test_median set-----
Num sequences: 1091
Num sequences (len < 3000): 1082
Num unique sequences:  1091
Num unique sequences (len < 3000):  1082


# Convert FLIP dataset split to our format

In [10]:
with open("data/mixed_split.csv", "r") as f:
    # columns: sequence,target,set,validation
    train = []
    val = []
    test = []

    for i, line in enumerate(f):
        if i == 0:
            continue  # skip the header row
        sequence, target, split, validation = line.strip().split(",")

        if split == "train":
            # Train-split rows with validation == "True" form the validation set
            if validation == "True":
                val.append({"sequence": sequence, "temp": target})
            else:
                train.append({"sequence": sequence, "temp": target})
        elif split == "test":
            test.append({"sequence": sequence, "temp": target})
        else:
            raise ValueError(f"Invalid set: {split!r}")

storeMetadata(train, "train_FLIP")
storeMetadata(val, "val_FLIP")
storeMetadata(test, "test_FLIP")
print_ds_infos("train_FLIP", train)
print_ds_infos("val_FLIP", val)
print_ds_infos("test_FLIP", test)
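Because the FLIP file has a header row, the same conversion can be written with `csv.DictReader`, which names each field by its column. A sketch, assuming the columns `sequence,target,set,validation` noted in the cell above and that `validation` holds the string `True` for validation rows, as the comparison above implies; `split_flip` and the demo rows are hypothetical:

```python
import csv
import io


def split_flip(handle):
    """Sort FLIP rows into train/val/test lists of {"sequence", "temp"} dicts."""
    splits = {"train": [], "val": [], "test": []}
    for row in csv.DictReader(handle):
        entry = {"sequence": row["sequence"], "temp": row["target"]}
        if row["set"] == "train":
            # Train-split rows flagged as validation go to the validation set
            splits["val" if row["validation"] == "True" else "train"].append(entry)
        elif row["set"] == "test":
            splits["test"].append(entry)
        else:
            raise ValueError(f"Invalid set: {row['set']!r}")
    return splits


# Made-up rows in the assumed column layout
demo = io.StringIO(
    "sequence,target,set,validation\n"
    "MKT,50.5,train,False\n"
    "MSE,61.0,train,True\n"
    "MAA,45.2,test,\n"
)
splits = split_flip(demo)
print({name: len(rows) for name, rows in splits.items()})  # {'train': 1, 'val': 1, 'test': 1}
```

For the real file: `with open("data/mixed_split.csv", newline="") as f: splits = split_flip(f)`. Unlike a plain `split(",")`, `DictReader` also copes with quoted fields and does not depend on the column order.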

-----Info for train_FLIP set-----
Num sequences: 22335
Num sequences (len < 3000): 22225
Num unique sequences:  18435
Num unique sequences (len < 3000):  18349
-----Info for val_FLIP set-----
Num sequences: 2482
Num sequences (len < 3000): 2466
Num unique sequences:  2376
Num unique sequences (len < 3000):  2362
-----Info for test_FLIP set-----
Num sequences: 3134
Num sequences (len < 3000): 3115
Num unique sequences:  3134
Num unique sequences (len < 3000):  3115
