# AFF Disease Prediction Using BEHRT

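This notebook prepares the data for predicting AFF disease (code '138') with a BEHRT model pre-trained with masked language modeling (MLM): it loads the preprocessed records, builds per-patient disease sequences, samples them, and saves train, validation, and test splits.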

In [None]:

# Import necessary libraries
import pickle
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split


## Step 1: Load Preprocessed Data

In [None]:

# Read the pickled NSP dataset (Option 1) and wrap it in a DataFrame.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
with open('../task/T20_BFC_BEHRT_group_data_sickFinal_NSP_op1.pkl', 'rb') as f:
    data = pickle.load(f)
group_data_sickFinal_NSP1 = pd.DataFrame(data)

# Partition the records on disease code '138' (AFF): matching rows vs. all other rows.
is_aff_record = group_data_sickFinal_NSP1['d'] == '138'
NDP_T20_BFC_ArtrialF = group_data_sickFinal_NSP1[is_aff_record]
NDP_T20_BFC_ArtrialF_no = group_data_sickFinal_NSP1[~is_aff_record]


## Step 2: Filter and Prepare Data

In [None]:

# Prepare AFF disease group: work on a copy so the slice taken from the
# full dataset is not modified when the YEAR column is converted.
df_138 = NDP_T20_BFC_ArtrialF.copy()
df_138['YEAR'] = pd.to_datetime(df_138['YEAR'], format='%Y')

# Per-patient earliest year, as Series indexed by ID.
# Both lookups must be keyed by ID because `Series.map` resolves values through
# the index of the Series it is given. The previous `transform('min')` result
# was indexed by the rows of df_138, so mapping IDs through it looked up
# patient IDs in a row index and returned unrelated years (or NaT).
year_by_patient = df_138.groupby('ID')['YEAR']
first_starting_year = year_by_patient.min()
specific_disease_year = year_by_patient.min()

# NOTE(review): df_138 is built from NDP_T20_BFC_ArtrialF, i.e. only the
# code-'138' rows, so the two Series above are identical and the window
# [first_starting_year, specific_disease_year) used by get_disease_sequence
# is empty. Presumably first_starting_year was meant to come from each
# patient's full record history -- confirm against the original pipeline.

# One row per patient, carrying both reference years.
result_df = pd.DataFrame({'ID': df_138['ID'].unique()})
result_df['first_starting_year'] = result_df['ID'].map(first_starting_year)
result_df['specific_disease_year'] = result_df['ID'].map(specific_disease_year)


## Step 3: Define Sequence Calculation Functions

In [None]:

# Function to calculate disease sequence with 'SEP' token
def get_disease_sequence(row):
    """Build the token sequence of disease codes for one patient.

    Reads the module-level ``df_138`` frame and keeps the rows whose ``ID``
    equals ``row['ID']`` and whose ``YEAR`` lies in the half-open window
    ``[row['first_starting_year'], row['specific_disease_year'])``.

    Parameters
    ----------
    row : mapping
        Must provide the keys ``'ID'``, ``'first_starting_year'`` and
        ``'specific_disease_year'``.

    Returns
    -------
    list of str
        The ``d`` values of the selected rows converted with ``str``, in the
        order the rows appear in ``df_138`` (no sorting is applied), with a
        ``'SEP'`` token inserted whenever a row's year differs from the year
        of the row before it.
    """
    same_patient = df_138['ID'] == row['ID']
    on_or_after_start = df_138['YEAR'] >= row['first_starting_year']
    before_disease = df_138['YEAR'] < row['specific_disease_year']
    visits = df_138[same_patient & on_or_after_start & before_disease]

    tokens = []
    previous_year = None
    for position, (year, code) in enumerate(zip(visits['YEAR'], visits['d'])):
        # A change of year relative to the preceding row starts a new segment.
        if position > 0 and previous_year != year:
            tokens.append('SEP')
        tokens.append(str(code))
        previous_year = year
    return tokens

# Build one token sequence per patient row, then store it as a new column.
disease_sequences = result_df.apply(get_disease_sequence, axis=1)
result_df['disease_sequenceF'] = disease_sequences


## Step 4: Sampling Data

In [None]:

# 'length' is the number of 'SEP' tokens in each sequence; keep only the
# patients whose sequence contains at least two of them.
result_df['length'] = result_df['disease_sequenceF'].apply(lambda tokens: tokens.count('SEP'))
has_enough_separators = result_df['length'] >= 2
result_df = result_df[has_enough_separators].reset_index(drop=True)


def _sample_half(group):
    """Return a reproducible 50% sample of one 'first_starting_year' group."""
    return group.sample(frac=0.5, random_state=42)


# Stratified sampling: draw half of the rows within every 'first_starting_year' group.
# NOTE(review): recent pandas versions deprecate groupby.apply operating on the
# grouping column -- confirm 'first_starting_year' is still present in the
# result on the pandas version in use.
sampled_df = result_df.groupby('first_starting_year', group_keys=False).apply(_sample_half)

# 'age_B' is the first token of the sequence (None for an empty sequence).
# NOTE(review): whether that first token actually encodes an age cannot be
# confirmed from this notebook -- verify against the preprocessing step.
sampled_df['age_B'] = sampled_df['disease_sequenceF'].apply(
    lambda tokens: tokens[0] if len(tokens) > 0 else None
)


## Step 5: Prepare Train, Validation, and Test Splits

In [None]:

# Separate the sequence column (target) from the remaining columns (features).
X = sampled_df.drop(columns=['disease_sequenceF'])
Y = sampled_df['disease_sequenceF']

# 70% train / 30% hold-out, stratified on 'age_B'.
X_train, X_temp, y_train, y_temp = train_test_split(
    X,
    Y,
    test_size=0.3,
    random_state=42,
    stratify=X['age_B'],
)
# Split the hold-out evenly into validation and test, again stratified on 'age_B'.
X_valid, X_test, y_valid, y_test = train_test_split(
    X_temp,
    y_temp,
    test_size=0.5,
    random_state=42,
    stratify=X_temp['age_B'],
)


def _recombine(features, target):
    """Put a feature frame and its target column side by side in one frame."""
    return pd.concat([features, target], axis=1)


# Re-attach the sequences so each split can be saved as a single frame.
Train = _recombine(X_train, y_train)
Valid = _recombine(X_valid, y_valid)
Test = _recombine(X_test, y_test)


## Step 6: Save Train, Validation, and Test Sets

In [None]:

# Pickle each split to its own file, in train / validation / test order.
split_outputs = (
    ('../task/d138_5y_onset_Train2.pkl', Train),
    ('../task/d138_5y_onset_Valid2.pkl', Valid),
    ('../task/d138_5y_onset_Test2.pkl', Test),
)
for split_path, split_frame in split_outputs:
    with open(split_path, 'wb') as f:
        pickle.dump(split_frame, f)


## Step 7: Merge with BFC Data and Analyze Age Groups

In [None]:

# Read the pickled BFC object and join it to the training split on 'ID'.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
with open('./BFC.pkl', 'rb') as f:
    BFC = pickle.load(f)
merged_df = pd.merge(Train, BFC, on='ID')

# Convert 'age_B' to numbers; values that cannot be parsed become NaN.
age_numeric = pd.to_numeric(merged_df['age_B'], errors='coerce')
merged_df['age_B'] = age_numeric

# Left-closed bins: [0, 20), [20, 40), [40, 65), [65, 110).
bins = [0, 20, 40, 65, 110]
labels = ['below 20', '20s-30s', '40s-64s', 'over 65s']
merged_df['AgeGroup'] = pd.cut(age_numeric, bins=bins, labels=labels, right=False)

# Report how many training rows fall into each age group.
age_group_counts = merged_df['AgeGroup'].value_counts()
print(age_group_counts)
