# AFF Disease Prediction Using BEHRT - Preprocessing for MLM

This notebook demonstrates the preprocessing steps for AFF disease prediction using BEHRT: loading patient information and disease history, building the disease-code vocabulary, and creating the training datasets for MLM and NSP.

In [None]:

import pandas as pd
import numpy as np
import pickle

# Directory holding the raw CSV exports
csv_dir = '/path/to/csv/'


def _read_and_cache(csv_name, pkl_name):
    """Read one CSV from `csv_dir` and pickle the frame into the working directory.

    Returns the loaded DataFrame so the caller can keep using it.
    """
    frame = pd.read_csv(csv_dir + csv_name)
    with open(pkl_name, 'wb') as f:
        pickle.dump(frame, f)
    return frame


# Disease codes, disease history and patient information, in that order
Dis_codeF = _read_and_cache("DS.csv", 'DS.pkl')
T20 = _read_and_cache("T20.csv", 'T20.pkl')
BFC = _read_and_cache("BFC.csv", 'BFC.pkl')

# Attach patient information to every disease-history row.
# A left join keeps history rows even when the patient has no BFC entry.
T20_BFC = pd.merge(T20, BFC, how='left', on='ID')

# AGE2 = AGE shifted by the number of years between YEAR and 2002.
# NOTE(review): 2002 is presumably the year in which AGE was recorded - confirm.
record_year = T20_BFC['YEAR'].astype('int64')
recorded_age = T20_BFC['AGE'].astype('int64')
T20_BFC['AGE2'] = record_year - 2002 + recorded_age


## Step 1: Count Disease Occurrences Per Year

In [None]:

# Number of distinct disease codes per (year, patient, age) combination.
# The trailing reset_index() keeps the positional index as an extra 'index' column.
records_by_year = T20_BFC[['d', 'ID', 'YEAR', 'AGE']].groupby(
    ['YEAR', 'ID', 'AGE'], as_index=False
)
year_cnt = records_by_year.agg({"d": "nunique"}).reset_index()

# Wide view: one row per patient, one column per (measure, year); gaps become 0
year_cnt2 = (
    year_cnt
    .pivot(index="ID", columns=["YEAR"], values=["d", "AGE"])
    .fillna(0)
    .reset_index()
)
print(year_cnt2.head())

# Largest number of distinct codes any patient has within a single year
max_len = int(year_cnt['d'].max())


## Step 2: Padding and Vocabulary Creation

In [None]:

# Padding function
def padding(x, max_length=None):
    """Right-pad the list `x` with zeros, in place, up to a fixed length.

    Parameters
    ----------
    x : list
        Sequence to pad. It is extended in place and also returned.
    max_length : int, optional
        Target length. When omitted, the notebook-level `max_len` is used,
        which is the original behaviour. Passing it explicitly removes the
        dependency on that global being defined by an earlier cell.

    Returns
    -------
    list
        The same list object as `x`, now `max_length` items long. The padding
        items are the float zeros produced by `np.zeros`.

    Raises
    ------
    ValueError
        If `x` is already longer than the target length.
    """
    target = max_len if max_length is None else max_length
    missing = target - len(x)
    if missing < 0:
        # np.zeros would also raise ValueError here, but with an unhelpful
        # "negative dimensions" message; say what is actually wrong.
        raise ValueError(
            "sequence of length %d exceeds the padding length %d" % (len(x), target)
        )
    x.extend(np.zeros(missing))
    return x

# Vocabulary creation function
def voc(data):
    """Build a code -> integer-ID vocabulary from the `d` column of `data`.

    IDs start at 1 so that 0 stays free for padding. Codes are numbered in
    sorted order, so the same data always yields the same vocabulary. Numbering
    by `set` iteration order is not stable for strings across kernel restarts,
    which would make saved token IDs incompatible between runs.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain the columns 'YEAR', 'ID' and 'd'.

    Returns
    -------
    dict
        Mapping from each distinct value of `d` to its 1-based ID.
    """
    # Rows without a YEAR or ID were dropped by the former groupby on those
    # columns; keep ignoring them so the set of codes is unchanged.
    keyed_rows = data.dropna(subset=['YEAR', 'ID'])
    codes = keyed_rows['d'].unique().tolist()

    try:
        ordered_codes = sorted(codes)
    except TypeError:
        # Mixed or missing values cannot be compared directly; fall back to a
        # type-then-text ordering, which is still deterministic.
        ordered_codes = sorted(codes, key=lambda code: (type(code).__name__, str(code)))

    vocab = {code: idx for idx, code in enumerate(ordered_codes, start=1)}  # 0 for padding
    print('Vocabulary size:', len(vocab))
    return vocab

# Create the vocabulary from every record in T20_BFC; voc() numbers the codes
# from 1 upwards so that 0 stays free for padding.
# NOTE(review): this runs before Step 3 casts `d` to str - if `d` is not already
# a string column, the vocabulary keys will not match the later string tokens. Confirm.
vocab2 = voc(T20_BFC)


## Step 3: Group Data for MLM and NSP

In [None]:

# Ages and codes become string tokens (this overwrites the columns on T20_BFC)
for column in ('AGE2', 'd'):
    T20_BFC[column] = T20_BFC[column].astype(str)

# One row per (patient, year), holding that year's codes and ages as lists
visits = T20_BFC.groupby(['ID', 'YEAR'])[['d', 'AGE2']]
grouped_diagnoses = visits.agg(list).reset_index()

# Close each yearly visit with a 'SEP' token; the age list gets a matching
# extra entry that repeats the visit's first age so both lists stay aligned.
grouped_diagnoses['d2'] = grouped_diagnoses['d'].map(lambda codes: codes + ['SEP'])
grouped_diagnoses['AGE3'] = grouped_diagnoses['AGE2'].map(lambda ages: ages + [ages[0]])


def _concat_visits(patient_rows):
    """Join one patient's per-year lists into a single code list and age list."""
    # Series.sum() on a column of lists concatenates them in row order.
    return pd.Series({
        'd2': patient_rows['d2'].sum(),
        'AGE': patient_rows['AGE3'].sum(),
    })


# One row per patient: the full code sequence and the parallel age sequence
grouped = (
    grouped_diagnoses[['ID', 'AGE3', 'd2']]
    .groupby('ID')
    .apply(_concat_visits)
    .reset_index()
)

# Persist the per-patient sequences for the next step
with open('T20_BFC_BEHRT_grouped_df.pkl', 'wb') as f:
    pickle.dump(grouped, f)


## Step 4: Create MLM and NSP Datasets

In [None]:

# Reload the per-patient sequences saved by the previous step
with open('T20_BFC_BEHRT_grouped_df.pkl', 'rb') as f:
    grouped_data = pd.DataFrame(pickle.load(f))

# fold2 is 1 when the last character of the ID is an odd digit, otherwise 0
test_index = grouped_data['ID'].astype(str).str[-1].isin(['1', '3', '5', '7', '9'])
grouped_data['fold2'] = test_index.astype('int64')


def _rows_in_fold(fold):
    """Return the rows of `grouped_data` in `fold`; the old index is kept as an 'index' column."""
    return grouped_data[grouped_data['fold2'] == fold].reset_index()


# Each NSP option uses the opposite fold from the MLM option with the same number
group_data_mlm1 = _rows_in_fold(0)
group_data_mlm2 = _rows_in_fold(1)
group_data_nsp1 = _rows_in_fold(1)
group_data_nsp2 = _rows_in_fold(0)

# Write each split to its own pickle file
split_files = {
    'T20_BFC_BEHRT_group_data_mlm_op1.pkl': group_data_mlm1,
    'T20_BFC_BEHRT_group_data_mlm_op2.pkl': group_data_mlm2,
    'T20_BFC_BEHRT_group_data_nsp_op1.pkl': group_data_nsp1,
    'T20_BFC_BEHRT_group_data_nsp_op2.pkl': group_data_nsp2,
}
for file_name, split_frame in split_files.items():
    with open(file_name, 'wb') as f:
        pickle.dump(split_frame, f)


## Step 5: Verify Data Splits

In [None]:

# Report the number of patients (rows) in each split
split_sizes = (
    ("MLM Dataset Option 1 Size:", group_data_mlm1),
    ("MLM Dataset Option 2 Size:", group_data_mlm2),
    ("NSP Dataset Option 1 Size:", group_data_nsp1),
    ("NSP Dataset Option 2 Size:", group_data_nsp2),
)
for label, split_frame in split_sizes:
    print(label, len(split_frame))
