# Train/validation split creation
In this notebook we walk through how we create the train and validation datasets.
In the context of this project, a dataset is a set of `(protein sequence, thermostability)` pairs.
A given protein sequence may have several thermostability measurements in the FLIP dataset, and each measurement forms its own pair.
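
Because one sequence can appear with several melting-point measurements, it can help to inspect the data grouped by sequence. The following is a minimal sketch on made-up toy pairs (not real FLIP records) that collects every measurement per sequence; the grouping is only for inspection, since the notebook itself keeps each measurement as a separate pair.

```python
from collections import defaultdict
from statistics import mean

# Hypothetical toy (sequence, melting point) pairs; not taken from FLIP
pairs = [
    ("MKTAYIAK", 52.4),
    ("MKTAYIAK", 54.0),
    ("GSHMLEDP", 61.2),
]

# Collect all measurements for each distinct sequence
measurements = defaultdict(list)
for sequence, temp in pairs:
    measurements[sequence].append(temp)

# Show every sequence with its measurements and their mean
for sequence, temps in measurements.items():
    print(sequence, temps, round(mean(temps), 2))
```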

## Read FLIP dataset
The dataset file `full_dataset_sequences.fasta` can be obtained [here](https://github.com/J-SNACKKB/FLIP/tree/main/splits/meltome). Among other information, it contains protein sequences and their thermostability (melting point) measurements.
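
The parser below relies on a specific header layout: the first whitespace-separated token is `<protein id>_<suffix>` and the second is `<key>=<melting point>`. As a sketch of just that step, here is a hypothetical `parse_header` helper applied to a made-up header that follows this assumed layout (the real FLIP headers may carry other fields):

```python
def parse_header(header: str) -> tuple[str, float]:
    """Split an assumed '>ID_SUFFIX KEY=VALUE ...' header into (id, temp)."""
    tokens = header.lstrip('>').split()
    protein_id = tokens[0].split('_')[0]   # keep the part before the first '_'
    temp = float(tokens[1].split('=')[1])  # value after '=' in the second token
    return protein_id, temp

# Hypothetical header; the exact FLIP header contents may differ
print(parse_header(">P12345_HepG2 TARGET=48.7 SET=train"))  # ('P12345', 48.7)
```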

In [None]:
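def read_fasta(filepath='../data/full_dataset_sequences.fasta'):
    """Parse the FLIP FASTA file into a list of entry dicts.

    Each entry holds the protein 'id' (first header token before the first '_'),
    the raw 'header', the melting point 'temp' and the full 'sequence'.
    """
    dataset = []
    entry = None
    with open(filepath) as fasta:
        for line in fasta:
            line = line.rstrip()
            if line.startswith('>'):
                # A new header starts a new record; store the previous one first
                if entry is not None:
                    dataset.append(entry)
                header_tokens = line.split(' ')
                entry = {
                    'id': header_tokens[0][1:].split('_')[0],
                    'header': line,
                    # The second header token is expected to be '<key>=<melting point>'
                    'temp': float(header_tokens[1].split('=')[1]),
                    'sequence': '',
                }
            elif entry is not None:
                # Sequences may be wrapped over several lines; concatenate them
                entry['sequence'] += line
    # Append the last record, which is not followed by another header
    if entry is not None:
        dataset.append(entry)
    return dataset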

flip_dataset = read_fasta()

## Read ESM validation protein ids
ESM validation protein ids are the UniRef50 cluster representative ids that were held out during training of ESM2 and ESMFold.
Since we are essentially doing transfer learning on top of the ESMFold outputs, we use these proteins as our validation set. This avoids the data leakage that would arise if ESMFold had seen, during its own training, proteins that we later use for validation.
More information and a download link can be found [here](https://github.com/facebookresearch/esm#pre-training-dataset-split--).


In [None]:
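# Each non-empty line is a UniRef50 header such as 'UniRef50_<accession>';
# keep the accession after the underscore
esm_eval_ids = set()
with open('../data/uniref201803_ur50_valid_headers.txt') as txt_file:
    for line in txt_file:
        line = line.strip()
        if line:
            esm_eval_ids.add(line.split('_')[1])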

## Create validation dataset and unfiltered train set

In [None]:
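eval_dataset = []              # entries whose protein id is an ESM validation id
train_unfiltered_dataset = []  # all other entries, filtered by cluster below
train_ids = set()
all_ids = set()
# Reuse the dataset parsed above instead of reading the file a second time
for entry in flip_dataset:
    protein_id = entry['id']
    all_ids.add(protein_id)
    if protein_id in esm_eval_ids:
        eval_dataset.append(entry)
    else:
        train_unfiltered_dataset.append(entry)
        train_ids.add(protein_id)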
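
A quick sanity check for this step is that every entry ends up in exactly one of the two lists and that no protein id appears on both sides. Below is a self-contained sketch of that check; `split_by_ids` and the toy entries are hypothetical and only mirror the rule used above.

```python
def split_by_ids(entries, eval_ids):
    """Partition entries into (eval, train) by whether entry['id'] is in eval_ids."""
    eval_part = [e for e in entries if e['id'] in eval_ids]
    train_part = [e for e in entries if e['id'] not in eval_ids]
    return eval_part, train_part

# Hypothetical toy entries; P2 has two measurements
entries = [
    {'id': 'P1', 'temp': 50.0},
    {'id': 'P2', 'temp': 61.0},
    {'id': 'P2', 'temp': 63.5},
    {'id': 'P3', 'temp': 44.2},
]
eval_part, train_part = split_by_ids(entries, eval_ids={'P2'})

# Every entry lands on exactly one side ...
assert len(eval_part) + len(train_part) == len(entries)
# ... and no protein id is shared between the sides
assert not {e['id'] for e in eval_part} & {e['id'] for e in train_part}
print(len(eval_part), len(train_part))  # 2 2
```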

## Filter train dataset by cluster
To avoid having similar protein sequences in both the train and the validation set, we remove from the train set every protein that shares a cluster with a validation protein, based on the [FLIP clustering](https://github.com/J-SNACKKB/FLIP/blob/main/splits/meltome/splits.zip).
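
The filtering rule can be written as a small pure function: a train protein is dropped whenever its cluster also contains at least one validation protein. Here is a sketch with a hypothetical `filter_train_ids` helper on made-up cluster and protein ids:

```python
def filter_train_ids(train_ids, eval_ids, clusters):
    """Return train_ids minus every id that shares a cluster with an eval id.

    clusters maps a cluster id to the set of protein ids it contains.
    """
    kept = set(train_ids)
    for members in clusters.values():
        if members & eval_ids:   # cluster contains at least one eval protein
            kept -= members      # so none of its members may stay in train
    return kept

clusters = {
    'c1': {'P1', 'P2'},   # mixed cluster: P1 (train) is similar to P2 (eval)
    'c2': {'P3'},         # train-only cluster
    'c3': {'P4', 'P5'},   # eval-only cluster
}
train = filter_train_ids({'P1', 'P3'}, eval_ids={'P2', 'P4', 'P5'}, clusters=clusters)
print(sorted(train))  # ['P3']
```

Proteins that do not appear in any cluster are kept, since only members of mixed clusters are removed.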

### Read clusters and associated protein ids

In [None]:
from collections import defaultdict

# Map each cluster id to the set of our FLIP protein ids it contains
clusters = defaultdict(set)
with open("../data/meltome_PIDE20_clusters.tsv", "r") as f:
    next(f)  # skip the header row
    for line in f:
        line = line.rstrip("\n")
        if not line:
            continue
        cluster_id, protein_id = line.split("\t")
        protein_id = protein_id.split('_')[0]
        # Ignore cluster members that are not part of our dataset
        if protein_id in all_ids:
            clusters[cluster_id].add(protein_id)

### Execute filtering

In [None]:
# A cluster is "mixed" if it contains at least one validation protein.
# All train proteins of a mixed cluster are removed, so that no remaining
# train protein shares a cluster with a validation protein.
for protein_ids in clusters.values():
    if protein_ids & esm_eval_ids:
        train_ids -= protein_ids

train_dataset = [item for item in train_unfiltered_dataset if item["id"] in train_ids]
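
After filtering, the guarantee can be checked directly by listing the clusters that still contain both a train and a validation protein. Below is a sketch with a hypothetical `shared_clusters` helper on toy data:

```python
def shared_clusters(train_ids, eval_ids, clusters):
    """Return the cluster ids that contain both a train and an eval protein."""
    return {
        cluster_id
        for cluster_id, members in clusters.items()
        if members & train_ids and members & eval_ids
    }

clusters = {'c1': {'P1', 'P2'}, 'c2': {'P3'}}
print(shared_clusters({'P1', 'P3'}, {'P2'}, clusters))  # {'c1'} before filtering
print(shared_clusters({'P3'}, {'P2'}, clusters))        # set() after filtering
```

Applied to the notebook's own `train_ids`, `esm_eval_ids` and `clusters`, this would be expected to return an empty set.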

## Store val.csv and train.csv

In [None]:
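import os

def storeMetadata(ds, name: str):
    """Write the (sequence, melting point) pairs of `ds` to ../data/s_s/<name>.csv.

    Columns are separated by ', ' (a comma followed by a space).
    """
    os.makedirs("../data/s_s", exist_ok=True)
    with open(f"../data/s_s/{name}.csv", "w") as f:
        f.write("sequence, melting point\n")
        for entry in ds:
            f.write(f'{entry["sequence"]}, {entry["temp"]}\n')

storeMetadata(train_dataset, "train")
storeMetadata(eval_dataset, "val")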
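
Because `storeMetadata` separates columns with a comma followed by a space, a plain CSV reader would keep that leading space in the second column name and values. The standard-library `csv` module's `skipinitialspace=True` option drops it; below is a sketch on an in-memory stand-in for one of the written files (toy values):

```python
import csv
import io

# In-memory stand-in for a file written by storeMetadata (toy values)
text = "sequence, melting point\nMKTAYIAK, 52.4\nGSHMLEDP, 61.2\n"

reader = csv.reader(io.StringIO(text), skipinitialspace=True)
header = next(reader)
rows = [(seq, float(temp)) for seq, temp in reader]
print(header)  # ['sequence', 'melting point']
print(rows)    # [('MKTAYIAK', 52.4), ('GSHMLEDP', 61.2)]
```

`pandas.read_csv` accepts the same `skipinitialspace=True` argument when loading these files into a DataFrame.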