In [None]:
import pandas as pd

# Run parameters. Each one may already be defined when this notebook starts
# (presumably injected by a caller that runs it — confirm); any that are not
# defined fall back to the defaults below.
if 'precision' not in locals():
  precision = 'gene'

if 'splitted_data_path' not in locals():
  splitted_data_path = f"../../data/splitted_data/{precision}/beta"

if 'train_file_name' not in locals():
  train_file_name = 'train.tsv'

if 'test_file_name' not in locals():
  test_file_name = 'test.tsv'

if 'validation_file_name' not in locals():
  validation_file_name = 'validation.tsv'

df_train = pd.read_csv(f"{splitted_data_path}/{train_file_name}", sep="\t")
df_test = pd.read_csv(f"{splitted_data_path}/{test_file_name}", sep="\t")
df_validate = pd.read_csv(f"{splitted_data_path}/{validation_file_name}", sep="\t")

# An epitope or TRB CDR3 counts as "seen" if it occurs in train OR validation,
# so the reference frame is their union. The name df_train is kept for the
# downstream cells even though it now also contains the validation rows.
# ignore_index=True gives the combined frame a fresh 0..n-1 index instead of
# the duplicated labels that the two separately-read files would produce.
df_train = pd.concat([df_train, df_validate], ignore_index=True)

In [None]:
# Everything present in the reference data (df_train), stored as sets so that
# the per-row membership checks below are average O(1) lookups.
epitopes_in_train, trb_cdr3_in_train = (
    set(df_train[column]) for column in ('Epitope', 'TRB_CDR3')
)

In [None]:
# Task label implied by (epitope seen in train?, TRB CDR3 seen in train?).
# All four combinations are listed, so every row maps to exactly one label.
EXPECTED_TASK_BY_SEEN = {
    (True, True): 'TPP1',    # seen epitope, seen TCR
    (True, False): 'TPP2',   # seen epitope, unseen TCR
    (False, False): 'TPP3',  # unseen epitope, unseen TCR
    (False, True): 'TPP4',   # unseen epitope, seen TCR
}


def verify_task(row, epitopes_seen=None, trb_cdr3_seen=None):
    """Return True if ``row['task']`` equals the TPP label implied by the seen sets.

    Args:
        row: A test-set row with ``'Epitope'``, ``'TRB_CDR3'`` and ``'task'``
            entries, e.g. the Series passed by ``DataFrame.apply(..., axis=1)``
            or any mapping with those keys.
        epitopes_seen: Collection of epitopes that count as seen. Defaults to
            the module-level ``epitopes_in_train`` (looked up at call time).
        trb_cdr3_seen: Collection of TRB CDR3 sequences that count as seen.
            Defaults to the module-level ``trb_cdr3_in_train`` (looked up at
            call time).

    Returns:
        bool: True when the stored task label matches the expected one.
    """
    if epitopes_seen is None:
        epitopes_seen = epitopes_in_train
    if trb_cdr3_seen is None:
        trb_cdr3_seen = trb_cdr3_in_train

    # `in` yields plain bools, so the key always hits one of the four entries.
    seen_key = (row['Epitope'] in epitopes_seen, row['TRB_CDR3'] in trb_cdr3_seen)
    return EXPECTED_TASK_BY_SEEN[seen_key] == row['task']

# Evaluate verify_task on every test row (axis="columns" == axis=1, i.e. one
# call per row) and store the per-row verdict as a new column.
row_verdicts = df_test.apply(verify_task, axis="columns")
df_test['is_correct'] = row_verdicts

In [None]:
# Meaning of each task label in terms of seen/unseen TCR (TRB CDR3) and epitope.
TASK_DESCRIPTIONS = {
    'TPP1': 'seen tcr & seen epitopes',
    'TPP2': 'unseen tcr & seen epitopes',
    'TPP3': 'unseen tcr & unseen epitope',
    'TPP4': 'seen tcr & unseen epitope',
}

# One pass over the task column instead of four separate comparisons.
task_counts = df_test['task'].value_counts()
number_of_TPP1 = task_counts.get('TPP1', 0)
number_of_TPP2 = task_counts.get('TPP2', 0)
number_of_TPP3 = task_counts.get('TPP3', 0)
number_of_TPP4 = task_counts.get('TPP4', 0)

# df_train already contains the validation rows (concatenated when loading),
# so subtract them to report the pure train size.
n_train = len(df_train) - len(df_validate)
n_validate = len(df_validate)
n_test = len(df_test)
n_total = len(df_train) + n_test
# Share of all rows held out for testing; NaN rather than ZeroDivisionError
# when nothing was loaded.
test_ratio = n_test / n_total if n_total else float('nan')

print(f"train data has {n_train} entries")
print(f"validation data has {n_validate} entries")
print(f"test data has {n_test} entries")
for task, description in TASK_DESCRIPTIONS.items():
  print(f"test data has {task_counts.get(task, 0)} {task} tasks ({description}).")
# The non-test share covers train AND validation, so label it that way and
# round both parts for readability.
print(f"the (train+validation)/test ratio is {1 - test_ratio:.2f}/{test_ratio:.2f}")

In [None]:
# How many test rows carry a consistent (True) vs. inconsistent (False) label.
correctness_summary = df_test['is_correct'].value_counts()

# Element-wise comparison on purpose (not `is False`): selects the rows whose
# task label disagrees with the train/validation overlap.
mismatch_mask = df_test['is_correct'] == False  # noqa: E712
incorrect_rows = df_test[mismatch_mask]
if incorrect_rows.empty:
  print("Classification is correct.")
else:
  print("Incorrectly set tasks:")
  print(incorrect_rows[['Epitope', 'TRB_CDR3', 'task']])

print("Correctness summary:")
print(correctness_summary)