In [None]:
# Mount the user's Google Drive inside the Colab VM so project files are reachable.
from google.colab import drive
drive.mount('/content/drive')

# Base folder for the project's files. Keep the trailing slash: later cells build
# paths with plain string concatenation (e.g. file_path + 'itemid_to_vector.csv').
file_path = '/content/drive/MyDrive/Health Project/'

# **Load embeddings, build cohort, save data for training**

## Load libraries and setup environment

In [None]:
# Import libraries

!pip install tslearn
!pip install minisom
!pip install dtw-python
!pip install Levenshtein
!pip install optuna


from tslearn.metrics import cdist_dtw
from sklearn.cluster import AgglomerativeClustering
from minisom import MiniSom
from dtw import dtw

from datetime import timedelta
import os

import numpy as np
import pandas as pd
import seaborn as sns
import os

from sklearn.cluster import KMeans
from sklearn.cluster import DBSCAN
from sklearn import metrics
from sklearn.decomposition import PCA #Principal Component Analysis
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import NearestNeighbors

import re
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

from IPython.display import display, HTML, Image
%matplotlib inline

# Global matplotlib look for the notebook: ggplot theme with a large default font.
plt.style.use('ggplot')
plt.rcParams['font.size'] = 20

# Access data using Google BigQuery.
from google.colab import auth
from google.cloud import bigquery

import bigframes.pandas as bf
import matplotlib.pyplot as plt
import plotly as py
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
import plotly.io as pio
# Use Plotly's "colab" renderer as the default for figures shown in this notebook.
pio.renderers.default = "colab"

from gensim.models import Word2Vec

from IPython.display import clear_output

import torch

import copy
import datetime
import sys




In [None]:
# --- BigQuery / BigFrames session configuration -------------------------------
# Billing project, defined once and reused below (alternative: 'hellobigquery-431508').
project_id = 'loyal-mason-431106-n3'
dataset = 'mimiciv'

bf.options.bigquery.location = "US"
bf.options.bigquery.project = project_id

# Interactive Colab login.
auth.authenticate_user()

# Expose the project to Google Cloud client libraries through the environment.
os.environ["GOOGLE_CLOUD_PROJECT"] = project_id


## **Load embedding mapping from itemid to vectors**

In [None]:
# Load the per-itemid embedding table and turn it into a lookup:
#   str(itemid) -> list holding the remaining column values of that row.
df_itemid_to_vector = pd.read_csv(file_path + 'itemid_to_vector.csv')
vectors_by_itemid = df_itemid_to_vector.set_index('itemid').T.to_dict('list')
itemid_to_vector = {str(item_key): vector for item_key, vector in vectors_by_itemid.items()}
print(len(itemid_to_vector))


## **Find patients with AD related ICD codes.**

Load the `diagnoses_icd` table and keep only the rows whose ICD code is in our list of AD-related codes:

In [None]:
# Select every MIMIC-IV diagnosis row whose ICD code is in the list below and attach
# the discharge time (`dischtime`) of the admission the diagnosis belongs to.
# The result is a BigFrames DataFrame; row count and the first 10 rows are shown.
query = """
  SELECT
    d.*,
    a.dischtime AS discharge_time
  FROM
    `physionet-data.mimiciv_hosp.diagnoses_icd` AS d
  INNER JOIN
    `physionet-data.mimiciv_hosp.admissions` AS a
  ON
    d.hadm_id = a.hadm_id
  WHERE
    d.icd_code IN ('G300', 'G301', 'G308', 'G309', 'F0280', 'F0281', 'F0290', 'F0391', 'F04',
      'F060', 'F068', 'G3101', 'G3109', 'G311', 'G3183', 'G3185', 'G3189', 'G319',
      'G454', 'G937', 'G94', 'G910', 'G911', 'G912', 'F0150', 'F0151', 'I675',
      'I671', 'I672', 'I674', 'I676', 'I677', 'I6781', 'I6782', 'I6789', 'I679')

"""
df_ad_patients_with_discharge_time = bf.read_gbq(query)
print(len(df_ad_patients_with_discharge_time))
df_ad_patients_with_discharge_time.head(10)

# WHERE
    # d.icd_code IN ('G300', 'G301', 'G308', 'G309', 'F0280', 'F0281', 'F0290', 'F0391', 'F04',
    # 'F060', 'F068', 'G3101', 'G3109', 'G311', 'G3183', 'G3185', 'G3189', 'G319',
    # 'G454', 'G937', 'G94', 'G910', 'G911', 'G912', 'F0150', 'F0151', 'I675',
    # 'I671', 'I672', 'I674', 'I676', 'I677', 'I6781', 'I6782', 'I6789', 'I679',
    # '29012', '331', '2904', '29041', '29042', '29043', '294', '2948', '2949',
    # '2941', '29411', '2942', '29421', '3312', '3316', '3317', '33111', '33119',
    # '33181', '33182', '33189', '4377', '3313', '3314', '3315', '437', '4371',
    # '4372', '4373', '4374', '4375', '4376', '4378', '4379', '33183', '331',
    # '3311' )

In [None]:
# df_ad_patients_with_discharge_time.head(10)

In [None]:
# One entry per patient: keep the first diagnosis row of each subject_id.
one_row_per_patient = df_ad_patients_with_discharge_time.drop_duplicates(subset=['subject_id'])
ad_patients_list = list(one_row_per_patient['subject_id'])
len(ad_patients_list)
# df_ad_patients

## Possible Tests

## **See what tests are taken by these patients**

In [None]:
# Lab events of the AD cohort: diagnoses (filtered by ICD code) joined to lab events of
# the same patient and admission (flag 'abnormal...' or NULL, numeric value present),
# enriched with the lab item label / fluid / category.

query = """
  SELECT
    d.subject_id,
    d.hadm_id,
    d.icd_code,
    d.icd_version,
    l.itemid,
    l.valuenum,
    l.valueuom,
    l.labevent_id,
    l.charttime,
    l.flag,
    dlab.label,
    dlab.fluid,
    dlab.category
  FROM
    (
      SELECT subject_id, hadm_id, icd_code, icd_version
      FROM `physionet-data.mimiciv_hosp.diagnoses_icd`
      WHERE icd_code IN ('G300', 'G301', 'G308', 'G309', 'F0280', 'F0281', 'F0290', 'F0391', 'F04',
            'F060', 'F068', 'G3101', 'G3109', 'G311', 'G3183', 'G3185', 'G3189', 'G319',
            'G454', 'G937', 'G94', 'G910', 'G911', 'G912', 'F0150', 'F0151', 'I675',
            'I671', 'I672', 'I674', 'I676', 'I677', 'I6781', 'I6782', 'I6789', 'I679')
    ) AS d
  INNER JOIN
    (
      SELECT subject_id, hadm_id, labevent_id, itemid, valuenum, valueuom, charttime, flag
      FROM `physionet-data.mimiciv_hosp.labevents`
      WHERE lower(flag) LIKE 'abnormal%'
        OR flag IS NULL
    ) AS l
  ON
    d.subject_id = l.subject_id
    AND d.hadm_id = l.hadm_id
  INNER JOIN
    `physionet-data.mimiciv_hosp.d_labitems` AS dlab
  ON
    l.itemid = dlab.itemid
    WHERE l.valuenum IS NOT NULL  -- remove NULL values
  # WHERE (LOWER(dlab.label) LIKE '%csf' OR
  #  LOWER(dlab.label) LIKE'%b12%')
"""
# Sort per patient by chart time (labevent_id breaks ties) and renumber the rows.
df_ad_patients_lab_results = (
    bf.read_gbq(query)
    .sort_values(by=['subject_id', 'charttime', 'labevent_id'])
    .reset_index(drop=True)
)
print(len(df_ad_patients_lab_results))
df_ad_patients_lab_results.head(10)


For each patient, keep only the lab results whose `charttime` is no later than the discharge time of the admission in which the diagnosis was recorded.

In [None]:
# One discharge time per (patient, admission), materialised as a pandas frame.
df_temp = (
    df_ad_patients_with_discharge_time[['subject_id', 'hadm_id', 'discharge_time']]
    .drop_duplicates(subset=['subject_id', 'hadm_id'])
    .to_pandas()
)

# Attach the discharge time to every lab row (left join keeps all lab rows) ...
labs_with_discharge = df_ad_patients_lab_results.to_pandas().merge(
    df_temp, on=['subject_id', 'hadm_id'], how='left'
)

# ... and keep only the rows charted at or before that discharge time.
charted_by_discharge = labs_with_discharge['charttime'] <= labs_with_discharge['discharge_time']
df_ad_patients_lab_results_pd = labs_with_discharge[charted_by_discharge].reset_index(drop=True)

print(len(df_ad_patients_lab_results_pd))
df_ad_patients_lab_results_pd.head(10)


In [None]:
# Lab item ids are compared and merged as strings further down (e.g. with
# `ad_test_names`), so cast them once here.
# This works on the frame produced by the previous cell on purpose: re-loading
# `df_ad_patients_lab_results.to_pandas()` here would silently throw away the
# "charttime <= discharge_time" filter that was just applied.
df_ad_patients_lab_results_pd = df_ad_patients_lab_results_pd.assign(
    itemid=df_ad_patients_lab_results_pd['itemid'].astype(str)
)
df_ad_patients_lab_results_pd.dtypes

In [None]:
# see how many unique patients:
print(df_ad_patients_lab_results_pd['subject_id'].nunique(dropna=False))

In [None]:
# Full lab-item dictionary (itemid -> label / fluid / category).
lab_items_query = """
  SELECT *
  FROM `physionet-data.mimiciv_hosp.d_labitems`
"""
possible_tests = bf.read_gbq(lab_items_query)

# Restrict the dictionary to the item ids that occur in the AD cohort's lab results.
ad_tests = list(set(df_ad_patients_lab_results['itemid']))
taken_by_cohort = possible_tests['itemid'].isin(ad_tests)
ad_test_names = possible_tests[taken_by_cohort].to_pandas()
# String ids, so the table can be merged with the string-typed itemid counts.
ad_test_names['itemid'] = ad_test_names['itemid'].astype(str)

ad_test_names.head(1)

In [None]:
# Number of distinct patients that have at least one result for each lab item.
itemid_counts_df = (
    df_ad_patients_lab_results_pd
    .groupby('itemid')['subject_id']
    .nunique()
    .reset_index(name='count')
)
itemid_counts_df['itemid'] = itemid_counts_df['itemid'].astype(str)

# Attach the item descriptions and list the most widely taken tests first.
merged_df = itemid_counts_df.merge(ad_test_names, on='itemid', how='inner')
merged_df_sorted = merged_df.sort_values(by='count', ascending=False)

merged_df_sorted

### **Apply Word2vec embedding, treating each patient's test history as a "sentence" and each labtest as a "word".**

Prepare the sentences:

In [None]:
# Build one chronologically ordered lab-test "sentence" per patient (MIMIC-IV).
#
# Structure of the query:
#   c        - one row per patient with a qualifying diagnosis: the admission time of
#              the earliest admission carrying such a diagnosis, the distinct
#              qualifying ICD codes, and label_ad (1 if any code starts with G30 / 331).
#   filtered - the patient's numeric blood lab events (flag 'abnormal...' or NULL)
#              charted strictly before that first diagnosis admission, ranked per
#              (patient, itemid) by time.
#   outer    - keeps the first two events per itemid and aggregates per patient.
#
# Fixes compared with the previous version of this query:
#   * diagnoses and admissions were joined to labevents on subject_id only, so each lab
#     row was repeated once per diagnosis row and per admission; `rn <= 2` then tended
#     to keep the same lab event twice instead of the first two events.
#   * "before the first AD diagnosis" was expressed as `l.hadm_id < a.hadm_id` against
#     any admission of the patient; hadm_id is an identifier, not a timestamp, so the
#     comparison is now made on charttime vs. the diagnosis admission's admittime.
#   * the itemid and value aggregates were ordered by charttime alone; events sharing a
#     charttime could be ordered differently in the two aggregates. labevent_id is now
#     used as a tie-breaker everywhere so both sequences stay aligned.
query = """
  SELECT
    subject_id,
    STRING_AGG(CAST(itemid AS STRING) ORDER BY charttime ASC, labevent_id ASC) AS itemid_sequence,
    ARRAY_AGG(CAST(valuenum AS FLOAT64) ORDER BY charttime ASC, labevent_id ASC) AS test_value_sequence,
    COUNT(*) AS sequence_length,
    ANY_VALUE(icd_codes) AS icd_codes,
    'mimic_iv' AS data_source,
    ANY_VALUE(label_ad) AS label_ad  -- 1 when the patient has a G30* / 331* code
  FROM (
    SELECT
        l.subject_id,
        l.itemid,
        l.valuenum,
        l.charttime,
        l.labevent_id,
        c.icd_codes,
        c.label_ad,
        ROW_NUMBER() OVER (
          PARTITION BY l.subject_id, l.itemid
          ORDER BY l.charttime ASC, l.labevent_id ASC
        ) AS rn
    FROM (
        SELECT
            d.subject_id,
            MIN(a.admittime) AS first_dx_admittime,
            STRING_AGG(DISTINCT d.icd_code ORDER BY d.icd_code ASC) AS icd_codes,
            MAX(IF(d.icd_code LIKE 'G30%' OR d.icd_code LIKE '331%', 1, 0)) AS label_ad
        FROM
            `physionet-data.mimiciv_hosp.diagnoses_icd` AS d
        JOIN
            `physionet-data.mimiciv_hosp.admissions` AS a
        ON
            d.hadm_id = a.hadm_id
        WHERE
            d.icd_code LIKE 'G30%' OR
            d.icd_code LIKE 'F01%' OR
            d.icd_code LIKE 'F03%' OR
            d.icd_code LIKE 'F02%' OR
            d.icd_code LIKE 'R54%' OR
            d.icd_code IN ('G318', 'G310', 'G311', 'G319', '3310', '3311', '3312', '3319', '2904', '2900', '2901', '2902', '2903', '2908', '2909', '797')
        GROUP BY
            d.subject_id
    ) AS c
    JOIN
        `physionet-data.mimiciv_hosp.labevents` AS l
    ON
        l.subject_id = c.subject_id
    JOIN
        `physionet-data.mimiciv_hosp.d_labitems` AS dlab
    ON
        l.itemid = dlab.itemid
    WHERE
        l.valuenum IS NOT NULL
        AND l.charttime < c.first_dx_admittime  -- only labs before the first diagnosis admission
        AND (LOWER(l.flag) LIKE 'abnormal%' OR l.flag IS NULL)
        AND LOWER(dlab.fluid) LIKE '%blood%'
  ) AS filtered
  WHERE
    rn <= 2  -- only keep the first two entries for each itemid per patient
  GROUP BY
    subject_id
  ORDER BY
    subject_id;
"""

df_itemid_sequences_4 = bf.read_gbq(query)
print(len(df_itemid_sequences_4))
df_itemid_sequences_4.head(10)


In [None]:
# Materialise the MIMIC-IV sequences in pandas and look at how long they are.
df_labtest_sequences_pd_4 = df_itemid_sequences_4.to_pandas()
mimic4_sequence_lengths = df_labtest_sequences_pd_4['sequence_length']
sns.displot(mimic4_sequence_lengths)


In [None]:
# Longest and shortest MIMIC-IV sequence.
mimic4_lengths = df_labtest_sequences_pd_4['sequence_length']
print(mimic4_lengths.max())
print(mimic4_lengths.min())

In [None]:
# Column dtypes of the MIMIC-IV sequence frame after conversion to pandas.
print(df_labtest_sequences_pd_4.dtypes)

### **Similarly for MIMIC_III**

In [None]:
# Build one chronologically ordered lab-test "sentence" per patient (MIMIC-III),
# mirroring the MIMIC-IV query.
#
#   c        - one row per patient with an ICD-9 331* diagnosis: admission time of the
#              earliest admission carrying such a code, the distinct codes, and label_ad.
#              Because the cohort filter itself is `331%`, label_ad is 1 for every row
#              (this matches the previous behaviour of the query).
#   filtered - numeric blood lab events (flag 'abnormal...' or NULL) charted strictly
#              before that admission, ranked per (patient, itemid) by time.
#   outer    - first two events per itemid, aggregated per patient.
#
# Fixes compared with the previous version of this query:
#   * joining diagnoses / admissions to labevents on subject_id only repeated every lab
#     row per diagnosis and per admission, so `rn <= 2` tended to keep duplicates of
#     the same lab event.
#   * `l.hadm_id < a.hadm_id` compared identifiers rather than times; the cut-off is
#     now charttime vs. the diagnosis admission's admittime.
#   * row_id is used as a tie-breaker in every ORDER BY so that the itemid sequence and
#     the value sequence stay aligned for events sharing a charttime.
query = """
  SELECT
    subject_id,
    STRING_AGG(CAST(itemid AS STRING) ORDER BY charttime ASC, row_id ASC) AS itemid_sequence,
    ARRAY_AGG(CAST(valuenum AS FLOAT64) ORDER BY charttime ASC, row_id ASC) AS test_value_sequence,
    COUNT(*) AS sequence_length,
    ANY_VALUE(icd_codes) AS icd_codes,
    'mimic_iii' AS data_source,
    ANY_VALUE(label_ad) AS label_ad
  FROM (
    SELECT
      l.subject_id,
      l.itemid,
      l.valuenum,
      l.charttime,
      l.row_id,
      c.icd_codes,
      c.label_ad,
      ROW_NUMBER() OVER (
        PARTITION BY l.subject_id, l.itemid
        ORDER BY l.charttime ASC, l.row_id ASC
      ) AS rn
    FROM (
      SELECT
        d.subject_id,
        MIN(a.admittime) AS first_dx_admittime,
        STRING_AGG(DISTINCT d.icd9_code ORDER BY d.icd9_code ASC) AS icd_codes,
        MAX(IF(d.icd9_code LIKE '331%', 1, 0)) AS label_ad
      FROM
        `physionet-data.mimiciii_clinical.diagnoses_icd` AS d
      JOIN
        `physionet-data.mimiciii_clinical.admissions` AS a
      ON
        d.hadm_id = a.hadm_id
      WHERE
        d.icd9_code LIKE '331%'  -- ICD-9 codes for Alzheimer's and related diseases
      GROUP BY
        d.subject_id
    ) AS c
    JOIN
      `physionet-data.mimiciii_clinical.labevents` AS l
    ON
      l.subject_id = c.subject_id
    INNER JOIN
      `physionet-data.mimiciii_clinical.d_labitems` AS dlab
    ON
      l.itemid = dlab.itemid
    WHERE
      l.valuenum IS NOT NULL
      AND l.charttime < c.first_dx_admittime  -- only labs before the first diagnosis admission
      AND (LOWER(l.flag) LIKE 'abnormal%' OR l.flag IS NULL)
      AND LOWER(dlab.fluid) LIKE '%blood%'
  ) AS filtered
  WHERE
    rn <= 2  -- only keep the first two entries for each itemid per patient
  GROUP BY
    subject_id
  ORDER BY
    subject_id;
"""
df_itemid_sequences_3 = bf.read_gbq(query)
print(len(df_itemid_sequences_3))
df_itemid_sequences_3.head(10)

In [None]:
# Materialise the MIMIC-III sequences in pandas and look at how long they are.
df_labtest_sequences_pd_3 = df_itemid_sequences_3.to_pandas()
mimic3_sequence_lengths = df_labtest_sequences_pd_3['sequence_length']
sns.displot(mimic3_sequence_lengths)

In [None]:
# Preview the first rows of the MIMIC-III sequence frame.
df_labtest_sequences_pd_3.head()

**Combine data from MIMIC-III and MIMIC-IV**

In [None]:
# Stack the MIMIC-IV and MIMIC-III patient sequences into one frame with a fresh
# 0..n-1 index (later cells address rows by position through .loc[i, ...]).
df_labtest_sequences_pd = pd.concat([df_labtest_sequences_pd_4, df_labtest_sequences_pd_3], ignore_index=True)

import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats

plt.rcParams.update({'font.size': 8})

# Skewness of the sequence-length distribution (reported in the plot title as well).
skewness = df_labtest_sequences_pd['sequence_length'].skew()
print(f'Skewness: {skewness}')

sns.displot(df_labtest_sequences_pd['sequence_length'], kde=True, color='#00BFFF')

plt.title(f'Sequence Length Distribution with Skewness = {skewness:.2f}')
plt.xlabel('Sequence Length')
plt.ylabel('Frequency')
# No plt.legend() here: the plot has no labelled artists, so the call only produced a
# "no artists with labels" warning and an empty legend.
plt.show()


In [None]:
# Print every comma-joined itemid sequence that contains the text '52285'.
for seq in df_labtest_sequences_pd['itemid_sequence']:
    if '52285' in seq:
        print(seq)

**Split itemid sequences, save dataframe**

In [None]:
import ast


def _split_itemid_sequence(raw_sequence):
    """Return the itemid sequence of one patient as a list of strings.

    * comma-joined string -> split on ','
    * list / tuple / numpy array (cell was already run once) -> returned as a list
    * anything else (e.g. a missing value) -> empty list

    Accepting already-split values makes the cell safe to re-run; previously a second
    run replaced every list with [] because only `str` inputs were handled.
    """
    if isinstance(raw_sequence, str):
        return raw_sequence.split(',')
    if isinstance(raw_sequence, (list, tuple, np.ndarray)):
        return list(raw_sequence)
    return []


df_labtest_sequences_pd['itemid_sequence'] = df_labtest_sequences_pd['itemid_sequence'].apply(_split_itemid_sequence)
# df_labtest_sequences_pd['test_value_sequence'] = df_labtest_sequences_pd['test_value_sequence'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else [])
# df_labtest_sequences_pd.to_csv(file_path + 'df_labtest_sequences_pd.csv', index=False)

# patients_itemid_seqs = df_labtest_sequences_pd['itemid_sequence']
# patients_itemid_seqs[:1]
df_labtest_sequences_pd

In [None]:
# df_labtest_sequences_pd = pd.read_csv('df_labtest_sequences_pd.csv')
# Row count of the combined frame, then its first row.
print(len(df_labtest_sequences_pd))
df_labtest_sequences_pd.head(1)

In [None]:
# Column dtypes of the combined sequence frame.
df_labtest_sequences_pd.dtypes

# **Apply scalers for each labtest**

**Train scalers for each labtest itemid based on global data**

In [None]:
# from sklearn.preprocessing import RobustScaler

# labtest_scalers = {}
# outlier_ranges = []

# unique_itemids = df_ad_patients_lab_results_pd['itemid'].unique()
# num_itemids = len(unique_itemids)

# for i in range(num_itemids):
#   itemid = unique_itemids[i]

#   item_values = df_ad_patients_lab_results_pd[df_ad_patients_lab_results_pd['itemid'] == itemid]['valuenum'].dropna()
#   # print(f"Processing Itemid: {itemid}, values: {item_values}")

#   if item_values.empty:
#     print(f"Itemid: {itemid} has no valid data, skipping.")
#     continue


#   if item_values.size == 0:
#     print(f"Itemid: {itemid} has no valid data, skipping.")
#     continue

#   sys.stdout.write(f"\rProcessing itemid {itemid} , {i + 1}/{num_itemids} ({(i + 1) / num_itemids * 100:.2f}%)")
#   sys.stdout.flush()

#   item_values_array = item_values.values.reshape(-1, 1)


#   scaler = RobustScaler()
#   scaler.fit(item_values_array)

#   labtest_scalers[itemid] = scaler



In [None]:
# print(len(labtest_scalers))
# print(labtest_scalers)

**save scaler**

In [None]:
# import pickle

# with open(file_path + 'labtest_scalers.pkl', 'wb') as file:
#     pickle.dump(labtest_scalers, file)

**Concatenate vectors of patients to matrices**


In [None]:
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from collections import defaultdict
import sys

# Patients left with fewer usable lab events than this are dropped at the end of the cell.
MIN_SEQUENCE_LENGTH = 4


def _build_patient_matrix(itemid_sequence, test_value_sequence, vector_lookup, missing_counter):
    """Turn one patient's (itemid, value) pairs into a list of feature vectors.

    For every pair whose itemid has an entry in `vector_lookup`, the embedding vector is
    extended with the raw test value (np.append). Pairs without an embedding are skipped
    and counted in `missing_counter` (keyed by itemid).

    Returns:
        (rows, kept_itemids, kept_values) - three parallel lists covering only the pairs
        that were kept, in their original order.
    """
    rows, kept_itemids, kept_values = [], [], []
    for itemid, test_value in zip(itemid_sequence, test_value_sequence):
        embedding = vector_lookup.get(str(itemid))
        if embedding is None:
            missing_counter[itemid] += 1
            continue  # no embedding for this itemid
        rows.append(np.append(embedding, test_value))
        kept_itemids.append(itemid)
        kept_values.append(test_value)
    return rows, kept_itemids, kept_values


df_all = copy.deepcopy(df_labtest_sequences_pd)

# Kept from the original cell: lists pass through unchanged, everything else is returned
# as-is. NOTE(review): presumably this also turns the column into a plain object column
# so per-row lists can be written back below - confirm against the dtype from to_pandas().
df_all['test_value_sequence'] = df_all['test_value_sequence'].apply(lambda x: list(x) if isinstance(x, list) else x)

df_all['matrix'] = None
# Placeholder: per-itemid scaling is currently disabled, so this column stays empty.
df_all['scaled_value_sequence'] = None

# Scaler lookup and outlier filtering are disabled, so these two counters stay empty;
# they are still defined (and one is printed) to keep the cell's outputs unchanged.
missing_scaler_count = defaultdict(int)
missing_pca_vector_count = defaultdict(int)
outlier_count = defaultdict(int)

num_patients = len(df_all)

for i in range(num_patients):
    sys.stdout.write(f"\rProcessing patient {i + 1}/{num_patients} ({(i + 1) / num_patients * 100:.2f}%)")
    sys.stdout.flush()

    patient_matrix, updated_itemid_sequence, updated_test_value_sequence = _build_patient_matrix(
        df_all.loc[i, 'itemid_sequence'],
        df_all.loc[i, 'test_value_sequence'],
        itemid_to_vector,
        missing_pca_vector_count,
    )

    # Write the matrix back and keep the sequences consistent with the rows actually kept.
    df_all.at[i, 'matrix'] = patient_matrix
    df_all.at[i, 'itemid_sequence'] = updated_itemid_sequence
    df_all.at[i, 'test_value_sequence'] = updated_test_value_sequence
    df_all.at[i, 'sequence_length'] = len(updated_itemid_sequence)

print('\n')
print(df_all.head(1))
print(f"Number of patients: {len(df_all)}")
# print(f"Number of values with missing scaler: {missing_scaler_count}")
print(f"Number of values with missing pca vector: {missing_pca_vector_count}")
print(f"Number of values outside outlier range: {outlier_count}")

df_all = df_all[df_all['sequence_length'] >= MIN_SEQUENCE_LENGTH].reset_index(drop=True)
df_all.head(5)


# **save matrix df**

In [None]:
# df_all['matrix'] = df_all['matrix'].apply(lambda x: str(x))
# df_all.to_csv(file_path + 'df_labtest_sequences_pd_pre_blood.csv', index=False)
# np.save(file_path + 'matrix_data_pre_blood.npy', df_all['matrix'].values)

In [None]:
# df_all = pd.read_csv(file_path + 'df_labtest_sequences_pd_pre_blood.csv')
# df_all['matrix'] = df_all['matrix'].apply(lambda x: np.array(eval(x, {'array': np.array})))


In [None]:

# df_labtest_sequences_pd = copy.deepcopy(df_all)
# df_all.head(1)

# **GRU**

In [None]:
# Independent working copy of the matrix frame for the GRU experiments.
df_labtest_GRU = df_all.copy(deep=True)
df_labtest_GRU.head()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import copy
import random
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    confusion_matrix,
    precision_score,
    recall_score,
    f1_score,
    roc_curve,
    auc
)
import matplotlib.pyplot as plt

class EarlyStopping:
    """Track the best validation loss and signal when training should stop.

    Call the instance once per epoch with the validation loss and the model. A deep copy
    of the model's state_dict is kept for the best (lowest) loss seen so far. When the
    loss has been strictly worse than the best one for `patience` calls in a row,
    `early_stop` becomes True.

    Attributes:
        patience:   number of consecutive non-improving calls tolerated.
        counter:    consecutive non-improving calls seen since the last improvement.
        best_score: negated best validation loss (higher is better), or None before use.
        best_model: deep copy of the state_dict taken at the best score, or None.
        early_stop: True once patience has been exhausted (stays True afterwards).
    """

    def __init__(self, patience=7):
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.best_model = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss.

        Returns:
            bool: the current value of `early_stop`. (Previously nothing was returned,
            so `if early_stopping(loss, model):` could never be true.)
        """
        score = -val_loss
        if self.best_score is None:
            # First call: whatever we have is the best so far.
            self.best_score = score
            self.best_model = copy.deepcopy(model.state_dict())
        elif score < self.best_score:
            # Strictly worse than the best loss: spend one unit of patience.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Equal or better: take a new snapshot and reset the patience counter.
            self.best_score = score
            self.best_model = copy.deepcopy(model.state_dict())
            self.counter = 0
        return self.early_stop

class BalancedFocalLoss(nn.Module):
    """Focal loss whose samples are re-weighted by inverse class frequency in the batch.

    Per sample: weight[target] * alpha * (1 - p_t) ** gamma * CE, where CE is the
    cross-entropy of the sample, p_t = exp(-CE), and
    weight[c] = batch_size / (2 * count of class c in this batch).
    The per-sample values are averaged over the batch.
    """

    def __init__(self, alpha=0.5, gamma=2):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma

    def forward(self, inputs, targets):
        """Return the scalar loss for logits `inputs` and integer class `targets`."""
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        prob_of_target = torch.exp(-per_sample_ce)

        # Inverse-frequency class weights computed from this batch only.
        class_counts = torch.bincount(targets).float()
        inverse_frequency = targets.size(0) / (2 * class_counts)

        modulating_factor = (1 - prob_of_target) ** self.gamma
        weighted = inverse_frequency[targets] * self.alpha * modulating_factor * per_sample_ce
        return weighted.mean()

class ResidualBlock(nn.Module):
    """Two-layer feed-forward block with a skip connection: forward(x) = x + block(x).

    `block` is Linear -> LayerNorm -> ReLU -> Dropout(0.3) -> Linear, all of width
    `hidden_size`, so input and output have the same shape.
    """

    def __init__(self, hidden_size):
        super().__init__()
        layers = [
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_size, hidden_size),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Add the transformed input back onto the input itself."""
        transformed = self.block(x)
        return x + transformed

class EnhancedMedicalGRUDelta(nn.Module):
    """Sequence classifier: per-step MLP -> bidirectional GRU -> self-attention -> mean pool.

    Args:
        input_size:   width of each time-step vector.
        hidden_size:  width of the per-step features and of each GRU direction.
        num_layers:   number of stacked GRU layers.
        dropout_prob: dropout used in the feature MLP, GRU, attention and classifier.
        num_heads:    attention heads; must divide `hidden_size * 2`. Defaults to 16,
                      the value that used to be hard-coded.

    forward(x) takes a sequence of 2-D tensors, one per patient, each of shape
    (sequence_length_i, input_size), and returns logits of shape (batch, 2).
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout_prob, num_heads=16):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Per-time-step feature extractor (applied to every step independently).
        self.feature_extraction = nn.Sequential(
            nn.Linear(input_size, hidden_size * 2),
            nn.LayerNorm(hidden_size * 2),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(hidden_size * 2, hidden_size),
            ResidualBlock(hidden_size)
        )

        self.gru = nn.GRU(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout_prob,
            bidirectional=True
        )

        # Sequence-first multi-head self-attention over the GRU outputs
        # (hidden_size * 2 because the GRU is bidirectional).
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_size * 2,
            num_heads=num_heads,
            dropout=dropout_prob
        )

        # Projects the raw inputs to the attention width for the skip connection.
        self.residual_conn = nn.Linear(input_size, hidden_size * 2)

        self.classifier = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(hidden_size, 2)
        )

    def forward(self, x):
        """Classify a batch of variable-length sequences.

        Padded time steps are excluded everywhere: the GRU runs on packed sequences, the
        attention receives a key padding mask, and the final mean is taken over real steps
        only. Without the key padding mask (the previous behaviour) real steps attended to
        padding, so a patient's logits depended on the other sequences in the batch.
        """
        x_lengths = [len(seq) for seq in x]
        x_padded = nn.utils.rnn.pad_sequence(x, batch_first=True)
        batch_size, max_len = x_padded.size(0), x_padded.size(1)

        # Run the feature MLP on all steps at once, then restore (batch, time, hidden).
        x_features = self.feature_extraction(x_padded.reshape(-1, x_padded.size(-1)))
        x_features = x_features.view(batch_size, max_len, self.hidden_size)

        packed_input = nn.utils.rnn.pack_padded_sequence(
            x_features,
            x_lengths,
            batch_first=True,
            enforce_sorted=False
        )
        packed_output, _ = self.gru(packed_input)
        output, _ = nn.utils.rnn.pad_packed_sequence(
            packed_output, batch_first=True, total_length=max_len
        )

        # True for real time steps, False for padding.
        lengths = torch.tensor(x_lengths, device=output.device)
        valid = torch.arange(max_len, device=output.device)[None, :] < lengths[:, None]

        # nn.MultiheadAttention is sequence-first here, hence the transposes.
        seq_first = output.transpose(0, 1)
        attention_output, _ = self.attention(
            seq_first,
            seq_first,
            seq_first,
            key_padding_mask=~valid  # True marks keys that must be ignored
        )
        attention_output = attention_output.transpose(0, 1) + self.residual_conn(x_padded)

        # Mean over the real time steps only.
        masked_output = attention_output * valid.unsqueeze(-1)
        pooled = masked_output.sum(dim=1) / valid.sum(dim=1, keepdim=True)

        return self.classifier(pooled)

def prepare_medical_data(df, test_size=0.1, random_state=42):
    """Split the patient frame into training tensors and a class-balanced evaluation set.

    Args:
        df:           frame with a 'matrix' column (one 2-D sequence of feature vectors per
                      patient) and an integer 'label_ad' column (0 / 1).
        test_size:    fraction of patients held out for evaluation.
        random_state: seed for the stratified split and for the down-sampling.

    Returns:
        (X_train, y_train, X_eval, y_eval): X_* are lists of float32 tensors (one per
        patient), y_* are int64 label tensors.

    The evaluation set is balanced by down-sampling the larger class to the size of the
    smaller one. Previously only the negatives were sampled, which raised a ValueError
    whenever the held-out split contained more positives than negatives.
    """
    train_df, eval_df = train_test_split(
        df,
        test_size=test_size,
        random_state=random_state,
        stratify=df['label_ad']
    )

    positives = eval_df[eval_df['label_ad'] == 1]
    negatives = eval_df[eval_df['label_ad'] == 0]
    n_per_class = min(len(positives), len(negatives))
    if len(positives) > n_per_class:
        positives = positives.sample(n=n_per_class, random_state=random_state)
    # Negatives are always drawn through sample(), as before.
    negatives = negatives.sample(n=n_per_class, random_state=random_state)
    eval_df = pd.concat([positives, negatives])

    def process_sequence(matrix):
        # Each patient's matrix may be a list of vectors; make it one float32 tensor.
        if not isinstance(matrix, np.ndarray):
            matrix = np.array(matrix)
        return torch.tensor(matrix, dtype=torch.float32)

    def to_label_tensor(frame):
        # to_numpy with an explicit dtype also handles pandas' nullable integer columns.
        return torch.tensor(frame['label_ad'].to_numpy(dtype='int64'), dtype=torch.long)

    X_train_tensor = [process_sequence(matrix) for matrix in train_df['matrix']]
    y_train_tensor = to_label_tensor(train_df)

    X_eval_tensor = [process_sequence(matrix) for matrix in eval_df['matrix']]
    y_eval_tensor = to_label_tensor(eval_df)

    return X_train_tensor, y_train_tensor, X_eval_tensor, y_eval_tensor

def train_model(model, X_train, y_train, X_eval, y_eval, config):
    """Train `model` with mini-batches, LR scheduling and early stopping.

    Args:
        model:            module called with a list of per-patient tensors, returning logits.
        X_train, X_eval:  lists of per-patient tensors.
        y_train, y_eval:  integer label tensors aligned with X_train / X_eval.
        config:           dict with 'learning_rate', 'weight_decay', 'batch_size',
                          'num_epochs' and 'patience'.

    Returns:
        (model, train_losses, eval_losses, train_accs, eval_accs) with one list entry per
        completed epoch. The returned model carries the weights of the epoch with the
        best validation loss.

    Changes compared with the previous version:
      * stopping is decided from `early_stopping.early_stop` instead of the call's return
        value, so it works whether or not EarlyStopping.__call__ returns a flag;
      * the best weights are restored also when all epochs run without early stopping;
      * the epoch loss is averaged over the batches actually processed, where
        `len(X_train) // batch_size` ignored the last partial batch and divided by zero
        when there were fewer samples than one batch.
    """
    criterion = BalancedFocalLoss()
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=3,
        min_lr=1e-6
    )
    early_stopping = EarlyStopping(patience=config['patience'])
    batch_size = config['batch_size']

    train_losses, eval_losses = [], []
    train_accs, eval_accs = [], []

    for epoch in range(config['num_epochs']):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        # New random order of the training patients every epoch.
        permutation = torch.randperm(len(X_train))
        for start in range(0, len(X_train), batch_size):
            optimizer.zero_grad()
            indices = permutation[start:start + batch_size]
            batch_X = [X_train[idx] for idx in indices.tolist()]
            batch_y = y_train[indices]

            outputs = model(batch_X)
            loss = criterion(outputs, batch_y)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += batch_y.size(0)
            correct += predicted.eq(batch_y).sum().item()
            num_batches += 1

        train_loss = running_loss / max(num_batches, 1)
        train_acc = correct / max(total, 1)

        # Validation on the whole evaluation set in a single forward pass.
        model.eval()
        with torch.no_grad():
            eval_outputs = model(X_eval)
            eval_loss = criterion(eval_outputs, y_eval).item()
            _, predicted = eval_outputs.max(1)
            eval_acc = predicted.eq(y_eval).sum().item() / len(y_eval)

        train_losses.append(train_loss)
        eval_losses.append(eval_loss)
        train_accs.append(train_acc)
        eval_accs.append(eval_acc)

        current_lr = optimizer.param_groups[0]['lr']
        print(f'Epoch {epoch+1}/{config["num_epochs"]} - '
              f'Train Loss: {train_loss:.4f}, '
              f'Train Acc: {train_acc:.4f}, '
              f'Val Loss: {eval_loss:.4f}, '
              f'Val Acc: {eval_acc:.4f}, '
              f'LR: {current_lr:.6f}')

        scheduler.step(eval_loss)
        early_stopping(eval_loss, model)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break

    # Hand back the best checkpoint, not merely the last epoch's weights.
    if early_stopping.best_model is not None:
        model.load_state_dict(early_stopping.best_model)

    return model, train_losses, eval_losses, train_accs, eval_accs

# Hyper-parameters of the model and of the training loop.
config = {
    'hidden_size': 256,
    'num_layers': 3,
    'dropout_prob': 0.5,
    'learning_rate': 0.001,
    'weight_decay': 0.01,
    'batch_size': 32,
    'num_epochs': 10,
    'patience': 5
}

# Seed every RNG involved (weight initialisation, dropout, batch shuffling) so that a
# "Restart & Run All" gives the same training run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

X_train, y_train, X_eval, y_eval = prepare_medical_data(df_labtest_GRU)
# Width of one time-step vector, taken from the first training patient.
config['input_size'] = X_train[0].shape[1]

# Only the architecture-related entries of `config` are constructor arguments.
model_kwargs = {
    name: config[name]
    for name in ('input_size', 'hidden_size', 'num_layers', 'dropout_prob')
}
model = EnhancedMedicalGRUDelta(**model_kwargs)

model, train_losses, eval_losses, train_accs, eval_accs = train_model(
    model, X_train, y_train, X_eval, y_eval, config
)

# Training vs. validation curves: loss on the left, accuracy on the right.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(train_losses, label='Train Loss')
loss_ax.plot(eval_losses, label='Val Loss')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.legend()

acc_ax.plot(train_accs, label='Train Acc')
acc_ax.plot(eval_accs, label='Val Acc')
acc_ax.set_xlabel('Epochs')
acc_ax.set_ylabel('Accuracy')
acc_ax.legend()

fig.tight_layout()
plt.show()

# Final evaluation on the balanced hold-out set: point metrics, then the ROC curve.
model.eval()
with torch.no_grad():
    outputs = model(X_eval)
    predicted = torch.max(outputs, 1).indices
    # Probability assigned to class 1, used as the ranking score for the ROC curve.
    probs = F.softmax(outputs, dim=1)[:, 1]

n_correct = (predicted == y_eval).sum().item()
accuracy = n_correct / len(y_eval)
print(f"Final Accuracy: {accuracy:.4f}")
print(f"Confusion Matrix:\n{confusion_matrix(y_eval, predicted)}")
print(f"Precision: {precision_score(y_eval, predicted):.4f}")
print(f"Recall: {recall_score(y_eval, predicted):.4f}")
print(f"F1 Score: {f1_score(y_eval, predicted):.4f}")

fpr, tpr, _ = roc_curve(y_eval, probs)
roc_auc = auc(fpr, tpr)
print(f"AUC: {roc_auc:.4f}")

plt.figure()
plt.plot(fpr, tpr, label=f'ROC (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], 'k--')  # chance level
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend()
plt.show()