In [None]:
import pandas as pd

# Directory (relative to this notebook) that holds the paired train/test splits.
splitted_data_path = "../../data/splitted_data/paired"

# Both split files are tab-separated; read_table uses "\t" as its default separator.
df_train = pd.read_table(f"{splitted_data_path}/train.tsv")
df_test = pd.read_table(f"{splitted_data_path}/test.tsv")

In [None]:
# Build the lookup structures used to check the test split against the train split.
# Name of the helper column that identifies a TCR by its two CDR3 columns.
tcr_key = "tcr_key"

# The epitope lookup only needs the raw training column, so build it first.
epitopes_in_train = set(df_train['Epitope'])

# Key format: "<TRA_CDR3>_<TRB_CDR3>"; astype(str) stringifies each value before joining.
df_train[tcr_key] = (
    df_train['TRA_CDR3'].astype(str)
    + '_'
    + df_train['TRB_CDR3'].astype(str)
)
cdr3_in_train = set(df_train[tcr_key])

# The test split gets the same key column so its rows can be looked up in cdr3_in_train.
df_test[tcr_key] = (
    df_test['TRA_CDR3'].astype(str)
    + '_'
    + df_test['TRB_CDR3'].astype(str)
)

In [None]:
# Row-level check of the 'task' label against train-set membership
def verify_task(row):
    """Return whether ``row['task']`` matches the label implied by the train set.

    The expected label is derived from two membership tests against the
    module-level lookups ``epitopes_in_train`` and ``cdr3_in_train``:

    * epitope seen, TCR key seen     -> 'TPP1'
    * epitope seen, TCR key unseen   -> 'TPP2'
    * epitope unseen, TCR key unseen -> 'TPP3'

    Args:
        row: Mapping-like object (e.g. a DataFrame row) indexable by
            'Epitope', by the column name stored in ``tcr_key`` and by 'task'.

    Returns:
        The result of comparing the expected label with ``row['task']``.
        For the remaining combination (epitope unseen, TCR key seen) no label
        is defined, so the result is always False and 'task' is not read.
    """
    seen_epitope = row['Epitope'] in epitopes_in_train
    seen_tcr = row[tcr_key] in cdr3_in_train

    # (epitope seen?, TCR key seen?) -> label the row is expected to carry
    label_by_membership = {
        (True, True): 'TPP1',
        (True, False): 'TPP2',
        (False, False): 'TPP3',
    }
    expected_label = label_by_membership.get((seen_epitope, seen_tcr))

    if expected_label is None:
        return False  # combination without a defined label
    return expected_label == row['task']

# Run verify_task once per test row (axis="columns" passes each row to the function)
df_test['is_correct'] = df_test.apply(verify_task, axis="columns")

In [None]:
# Count how many test rows passed / failed the task-label check
correctness_summary = df_test['is_correct'].value_counts()

# Collect the rows whose 'task' label failed the check and report them, if any
incorrect_rows = df_test[df_test['is_correct'].eq(False)]
if incorrect_rows.empty:
    print("Task property seems to be right")
else:
    print("Incorrectly set tasks:")
    print(incorrect_rows[['Epitope', 'TRA_CDR3', 'TRB_CDR3', 'task']])

print("Correctness summary:")
print(correctness_summary)