In [2]:
import os
import pandas as pd
from tqdm.notebook import tqdm
from sentence_transformers import InputExample
import numpy as np
import random
# Seed Python's `random` module and NumPy's global RNG with the same value.
# `random` is used by the triplet sampling later in this notebook; the draws it
# produces depend on how many random calls were made since this cell last ran.
seed = 42
random.seed(seed)
np.random.seed(seed)

In [2]:
VOC_NAMES = ["Alpha", "Beta", "Delta", "Gamma", "Omicron"]

# Read one CSV per VoC and tag every row with the VoC's position in VOC_NAMES,
# which serves as the integer class label.
voc_frames = []
for i, voc_name in enumerate(VOC_NAMES):
    voc_df = pd.read_csv(f"data/unique_{voc_name}_2k.csv")
    voc_df['label'] = i
    voc_frames.append(voc_df)

# Concatenate once rather than growing `df` inside the loop (each in-loop concat
# re-copies all rows gathered so far), then drop the accession_id and date
# columns and renumber the rows from 0.
df = (
    pd.concat(voc_frames)
    .drop(columns=['accession_id', 'date'])
    .reset_index(drop=True)
)
display(df)

Unnamed: 0,sequence,label
0,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,0
1,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,0
2,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,0
3,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,0
4,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,0
...,...,...
9995,MFVFLVLLPLVSSQCVNLITRTQSYTNSFTRGVYYPDKVFRSSVXX...,4
9996,TQLPPAYTNSFTRGVYYPDKVFRSSVLHSTQDLFLPFFSNVTWFHA...,4
9997,TQLPPAYTNSFTRGVYYPDKVFRSSVLHSTQDLFLPFFSNVTWFHA...,4
9998,MFVFLVLLPLVSSQCVNLITRTQSYTNSFTRGVYYPDKVFRSSVLH...,4


In [3]:
# Group the sequences by VoC name: seq[name] is the list of sequences whose label
# equals that name's index in VOC_NAMES, in the same order as the rows of `df`.
# A vectorized selection per label replaces the row-by-row Python loop.
seq = {
    voc: df.loc[df["label"] == label_id, "sequence"].tolist()
    for label_id, voc in enumerate(VOC_NAMES)
}

print(seq["Omicron"][0])

MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSSVLHSTQDLFLPFFSNVTWFHVIHVSGTNGTKRFDNPVLPFNDGVYFASIEKSNIIRGWIFGTTLDSKTQSLLIVNNATNVVIKVCEFQFCNDPFLGVYYHKNNKSWMESEFRVYSSANNCTFEYVSQPFLMDLEGKQGNFKNLREFVFKNIDGYFKIYSKHTPIIIVRDLPQGFSALEPLVDLPIGINITRFQTLLALHRSYLTPGDSSSGWTAGAAAYYVGYLQPRTFLLKYNENGTITDAVDCALDPLSETKCTLKSFTVEKGIYQTSNFRVQPTESIVRFPNITNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNLAPFFTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGNIADYNYKLPDDFTGCVIAWNSNKLDSKVSGNYNYLYRLFRKSNLKPFERDISTEIYQAGNKPCNGVAGFNCYFPLRSYSFRPTYGVGHQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNFNFNGLKGTGVLTESNKKFLPFQQFGRDIADTTDAVRDPQTLEILDITPCSFGGVSVITPGTNTSNQVAVLYQGVNCTEVPVAIHADQLTPTWRVYSTGSNVFQTRAGCLIGAEYVNNSYECDIPIGAGICASYQTQTKSHRRARSVASQSIIAYTMSLGAENSVAYSNNSIAIPTNFTISVTTEILPVSMTKTSVDCTMYICGDSTECSNLLLQYGSFCTQLKRALTGIAVEQDKNTQEVFAQVKQIYKTPPIKYFGGFNFSQILPDPSKPSKRSFIEDLLFNKVTLADAGFIKQYGDCLGDIAARDLICAQKFKGLTVLPPLLTDEMIAQYTSALLAGTITSGWTFGAGAALQIPFAMQMAYRFNGIGVTQNVLYENQKLIANQFNSAIGKIQDSLSSTASALGKLQDVVNHNAQALNTLVKQLSSKFGAISSVLNDIFSRLDKVEAEVQIDRLITGR

# Construct Contrastive Dataset

In [9]:
# Positive pairs: every unordered pair of entries taken from the same VoC list,
# labelled 1.
pos_seq1, pos_seq2, pos_labels = [], [], []

# n * (n - 1) / 2 pairs per list, summed over all lists (progress-bar total).
total_iterations = sum(n * (n - 1) // 2 for n in map(len, seq.values()))

with tqdm(total=total_iterations, desc="Processing sequences") as pbar:
    for same_voc_seqs in seq.values():
        for left_pos, left_seq in enumerate(same_voc_seqs):
            # Pair each entry only with the ones after it so no pair is repeated.
            for right_seq in same_voc_seqs[left_pos + 1:]:
                pos_seq1.append(left_seq)
                pos_seq2.append(right_seq)
                pos_labels.append(1)
                pbar.update(1)

pos_examples = pd.DataFrame({'seq1': pos_seq1, 'seq2': pos_seq2, 'label': pos_labels})
display(pos_examples)

Processing sequences:   0%|          | 0/9995000 [00:00<?, ?it/s]

Unnamed: 0,seq1,seq2,label
0,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,1
1,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,1
2,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,1
3,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,1
4,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,1
...,...,...,...
9994995,TQLPPAYTNSFTRGVYYPDKVFRSSVLHSTQDLFLPFFSNVTWFHA...,MFVFLVLLPLVSSQCVNLITRTQSYTNSFTRGVYYPDKVFRSSVLH...,1
9994996,TQLPPAYTNSFTRGVYYPDKVFRSSVLHSTQDLFLPFFSNVTWFHA...,MFVFLVLLPLVSSQCVNLITRTQSYTNSFTRGVYYPDKVFRSSVLH...,1
9994997,TQLPPAYTNSFTRGVYYPDKVFRSSVLHSTQDLFLPFFSNVTWFHA...,MFVFLVLLPLVSSQCVNLITRTQSYTNSFTRGVYYPDKVFRSSVLH...,1
9994998,TQLPPAYTNSFTRGVYYPDKVFRSSVLHSTQDLFLPFFSNVTWFHA...,MFVFLVLLPLVSSQCVNLITRTQSYTNSFTRGVYYPDKVFRSSVLH...,1


In [10]:
# Negative pairs: sequences from two different VoCs, labelled 0.
# Only every NEG_STRIDE-th sequence of the second VoC is paired with each sequence
# of the first one, which keeps the number of negatives close to the positives.
NEG_STRIDE = 4

neg_seq1, neg_seq2, neg_labels = [], [], []

# Every unordered pair of distinct VoCs, in VOC_NAMES order.
voc_pairs = [
    (VOC_NAMES[voc1], VOC_NAMES[voc2])
    for voc1 in range(len(VOC_NAMES))
    for voc2 in range(voc1 + 1, len(VOC_NAMES))
]

# Exact number of pairs the loop below emits: the strided selection keeps
# ceil(len / NEG_STRIDE) items, so count them with range() instead of `len // 4`,
# which under-counts whenever a list length is not a multiple of the stride.
total_iterations = sum(
    len(seq[first_voc]) * len(range(0, len(seq[second_voc]), NEG_STRIDE))
    for first_voc, second_voc in voc_pairs
)

with tqdm(total=total_iterations, desc="Processing") as pbar:
    for first_voc, second_voc in voc_pairs:
        # Indices 0, NEG_STRIDE, 2 * NEG_STRIDE, ... of the second VoC's list.
        sampled_second = seq[second_voc][::NEG_STRIDE]
        for s1 in seq[first_voc]:
            for s2 in sampled_second:
                neg_seq1.append(s1)
                neg_seq2.append(s2)
                neg_labels.append(0)
                pbar.update(1)

neg_examples = pd.DataFrame({'seq1': neg_seq1, 'seq2': neg_seq2, 'label': neg_labels})
display(neg_examples)

Processing:   0%|          | 0/10000000 [00:00<?, ?it/s]

Unnamed: 0,seq1,seq2,label
0,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,0
1,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNFTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,0
2,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNFTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,0
3,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,0
4,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNFTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,0
...,...,...,...
9999995,MFVFLVLLPLVSSQCVNFTNRTQLPSAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLITRTQSYTNSFTRGVYYPDKVFRSSVLH...,0
9999996,MFVFLVLLPLVSSQCVNFTNRTQLPSAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLITRIQSYTNSFTRGVYYPDKVFRSSVLH...,0
9999997,MFVFLVLLPLVSSQCVNFTNRTQLPSAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLITRTQSYTNSFTRGVYYPDKVFRSSVLH...,0
9999998,MFVFLVLLPLVSSQCVNFTNRTQLPSAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,0


# Train / Dev / Test Split

In [18]:
from sklearn.model_selection import train_test_split

# Stack positive and negative pairs into a single frame with a fresh 0..N-1 index.
merged_df = pd.concat([pos_examples, neg_examples], ignore_index=True)

# Features are the two sequence columns; the target is the pair label.
X = merged_df[['seq1', 'seq2']]
y = merged_df['label']

# Both cuts shuffle, stratify on the label and use the same fixed random_state.
# NOTE(review): rows are pairs, so a given sequence can occur in pairs that land
# in different splits; confirm this is acceptable for the evaluation.
split_options = dict(shuffle=True, random_state=42)

# First cut: 70% train, 30% held out.
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.3, stratify=y, **split_options
)
# Second cut: 33% of the held-out part becomes test and the rest dev, i.e. about
# 20.1% dev and 9.9% test overall (approximately the intended 70:20:10).
X_dev, X_test, y_dev, y_test = train_test_split(
    X_temp, y_temp, test_size=0.33, stratify=y_temp, **split_options
)

# Re-attach the label column to each split and renumber the rows from 0.
train_df = pd.concat([X_train, y_train], axis=1).reset_index(drop=True)
dev_df = pd.concat([X_dev, y_dev], axis=1).reset_index(drop=True)
test_df = pd.concat([X_test, y_test], axis=1).reset_index(drop=True)

# Report (rows, columns) for every split.
for shape_caption, split_frame in (
    ("Train set shape:", train_df),
    ("Dev set shape:", dev_df),
    ("Test set shape:", test_df),
):
    print(shape_caption, split_frame.shape)

Train set shape: (13996500, 3)
Dev set shape: (4018995, 3)
Test set shape: (1979505, 3)


# Save Contrastive Dataset

In [29]:
import pickle
# Write each split to its own pickle file in the current working directory.
for pkl_path, split_frame in (
    ('contrastive_train.pkl', train_df),
    ('contrastive_dev.pkl', dev_df),
    ('contrastive_test.pkl', test_df),
):
    with open(pkl_path, 'wb') as file:
        pickle.dump(split_frame, file)

In [83]:


VOC_NAMES = ["Alpha", "Beta", "Delta", "Gamma", "Omicron"]


def _pop_random(pool):
    """Remove and return one element of ``pool`` picked with ``random.randint``."""
    return pool.pop(random.randint(0, len(pool) - 1))


# One mutable pool of sequences per VoC. Entries are popped as they are used, so
# each pool entry ends up in at most one triplet.
sequences = []
for voc_name in VOC_NAMES:
    sequences.append(pd.read_csv(f"data/unique_{voc_name}_2k.csv")["sequence"].tolist())

# In one round every pool is the anchor/positive source once per other pool
# (2 pops each time) and the negative source once per other pool (1 pop each
# time): 3 * (k - 1) pops per pool per round, i.e. 12 for the five VoCs above.
SEQS_PER_ROUND = 3 * (len(sequences) - 1)

examples = []
r = 1
# Start a round only if *every* pool can supply a full round. Checking just the
# first pool would let a shorter pool run empty mid-round and make randint raise.
# NOTE: draws come from the global `random` state, so the exact triplets depend on
# the random calls made since that state was last seeded.
while SEQS_PER_ROUND > 0 and min(len(pool) for pool in sequences) >= SEQS_PER_ROUND:
    print("round", r)
    for a_p_list_id in range(len(sequences)):
        for n_list_id in range(len(sequences)):
            if n_list_id == a_p_list_id:
                continue  # the negative must come from a different VoC
            anchor_positive_list = sequences[a_p_list_id]
            negative_list = sequences[n_list_id]
            # Draw order (anchor, positive, negative) is kept fixed so a given
            # random state always yields the same triplets.
            anchor = _pop_random(anchor_positive_list)
            positive = _pop_random(anchor_positive_list)
            negative = _pop_random(negative_list)
            triplet = [anchor, positive, negative]
            examples.append(InputExample(texts=triplet))
    r += 1



round 1
round 2
round 3
round 4
round 5
round 6
round 7
round 8
round 9
round 10
round 11
round 12
round 13
round 14
round 15
round 16
round 17
round 18
round 19
round 20
round 21
round 22
round 23
round 24
round 25
round 26
round 27
round 28
round 29
round 30
round 31
round 32
round 33
round 34
round 35
round 36
round 37
round 38
round 39
round 40
round 41
round 42
round 43
round 44
round 45
round 46
round 47
round 48
round 49
round 50
round 51
round 52
round 53
round 54
round 55
round 56
round 57
round 58
round 59
round 60
round 61
round 62
round 63
round 64
round 65
round 66
round 67
round 68
round 69
round 70
round 71
round 72
round 73
round 74
round 75
round 76
round 77
round 78
round 79
round 80
round 81
round 82
round 83
round 84
round 85
round 86
round 87
round 88
round 89
round 90
round 91
round 92
round 93
round 94
round 95
round 96
round 97
round 98
round 99
round 100
round 101
round 102
round 103
round 104
round 105
round 106
round 107
round 108
round 109
round 110
round 11

In [84]:
# Number of triplets collected in `examples` by the sampling loop.
print(len(examples))

3320
