## Step 1: Load the MIMIC-III Dataset

We'll load the MIMIC-III dataset using PyHealth 2's new `MIMIC3Dataset` class. We need to specify:
- `root`: Path to the MIMIC-III data directory
- `tables`: Clinical tables to load (diagnoses, procedures, prescriptions)
- `dev`: Set to `True` for development/testing with a small subset of data

In [1]:
from pyhealth.datasets import MIMIC3Dataset

# Load MIMIC-III dataset
dataset = MIMIC3Dataset(
    root=r"F:\coding_projects\pyhealth\downloads\mimic-iii-demo",
    tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
    dev=True,  # Set to False for full dataset
)

dataset.stats()

No config path provided, using default config
Initializing mimic3 dataset from F:\coding_projects\pyhealth\downloads\mimic-iii-demo (dev mode: True)
Initializing mimic3 dataset from F:\coding_projects\pyhealth\downloads\mimic-iii-demo (dev mode: True)
Scanning table: patients from F:\coding_projects\pyhealth\downloads\mimic-iii-demo\PATIENTS.csv.gz
Original path does not exist. Using alternative: F:\coding_projects\pyhealth\downloads\mimic-iii-demo\PATIENTS.csv
Scanning table: admissions from F:\coding_projects\pyhealth\downloads\mimic-iii-demo\ADMISSIONS.csv.gz
Original path does not exist. Using alternative: F:\coding_projects\pyhealth\downloads\mimic-iii-demo\ADMISSIONS.csv
Scanning table: icustays from F:\coding_projects\pyhealth\downloads\mimic-iii-demo\ICUSTAYS.csv.gz
Original path does not exist. Using alternative: F:\coding_projects\pyhealth\downloads\mimic-iii-demo\ICUSTAYS.csv
Scanning table: patients from F:\coding_projects\pyhealth\downloads\mimic-iii-demo\PATIENTS.csv.gz
O



## Step 2: Define the Mortality Prediction Task

PyHealth 2 uses task classes to define how to extract samples from the raw EHR data. The `MortalityPredictionMIMIC3` task:
- Extracts diagnosis codes (ICD-9), procedure codes, and drug information from each visit
- Creates binary labels based on mortality in the next visit
- Filters out visits without sufficient clinical codes

You can optionally specify a `cache_dir` to save processed samples for faster future loading.

In [2]:
from pyhealth.tasks import MortalityPredictionMIMIC3

# Define the mortality prediction task
task = MortalityPredictionMIMIC3()

# Apply the task to generate samples
samples = dataset.set_task(
    task=task,
    cache_dir="./cache_mortality_mimic3"  # Cache processed samples
)

print(f"Generated {len(samples)} samples")
print(f"\nInput schema: {samples.input_schema}")
print(f"Output schema: {samples.output_schema}")

Setting task MortalityPredictionMIMIC3 for mimic3 base dataset...
Loading cached samples from cache_mortality_mimic3\MortalityPredictionMIMIC3.parquet
Loading cached samples from cache_mortality_mimic3\MortalityPredictionMIMIC3.parquet
Loaded 26 cached samples
Label mortality vocab: {0: 0, 1: 1}
Loaded 26 cached samples
Label mortality vocab: {0: 0, 1: 1}


Processing samples: 100%|██████████| 26/26 [00:00<00:00, 5198.89it/s]

Generated 26 samples for task MortalityPredictionMIMIC3
Generated 26 samples

Input schema: {'conditions': 'sequence', 'procedures': 'sequence', 'drugs': 'sequence'}
Output schema: {'mortality': 'binary'}
Generated 26 samples

Input schema: {'conditions': 'sequence', 'procedures': 'sequence', 'drugs': 'sequence'}
Output schema: {'mortality': 'binary'}





## Step 3: Explore a Sample

Let's examine what a single sample looks like. Each sample represents one hospital visit with:
- **conditions**: List of ICD-9 diagnosis codes
- **procedures**: List of ICD-9 procedure codes  
- **drugs**: List of drug names
- **mortality**: Binary label (0 = survived, 1 = deceased in next visit)

In [3]:
# Display a sample
print("Sample structure:")
print(samples[0])

# Show statistics
print("\n" + "="*50)
print("Dataset Statistics:")
print("="*50)

# Count unique codes
all_conditions = set()
all_procedures = set()
all_drugs = set()
mortality_count = 0
for sample in samples:
    all_conditions.update(sample['conditions'])
    all_procedures.update(sample['procedures'])
    all_drugs.update(sample['drugs'])
    mortality_count += float(sample['mortality'])

print(f"Unique diagnosis codes: {len(all_conditions)}")
print(f"Unique procedure codes: {len(all_procedures)}")
print(f"Unique drugs: {len(all_drugs)}")
print(f"\nMortality rate: {mortality_count/len(samples)*100:.2f}%")
print(f"Positive samples: {mortality_count}")
print(f"Negative samples: {len(samples) - mortality_count}")

Sample structure:
{'hadm_id': '102203', 'patient_id': '42135', 'conditions': tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]), 'procedures': tensor([1, 2]), 'drugs': tensor([ 1,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13,  3, 14, 15, 16,
        17, 16, 17,  1, 17, 18,  4, 19, 18, 19, 20,  1, 18, 21, 22, 18, 18, 21,
        14,  1,  1, 21, 18, 18, 14, 23, 23,  1, 24, 25, 19,  3,  3, 26,  2,  3,
        23,  4, 13, 12, 11,  9,  8,  6, 23, 10, 14, 27, 28, 28, 28, 29,  1, 30,
        31, 32, 33,  3,  3,  5,  7, 23,  9, 11,  4,  6,  8, 12, 13, 15, 18, 18,
        14, 34, 18, 18, 20,  3, 14]), 'mortality': tensor([1.])}

Dataset Statistics:
Unique diagnosis codes: 404
Unique procedure codes: 101
Unique drugs: 2647

Mortality rate: 19.23%
Positive samples: 5.0
Negative samples: 21.0


## Step 4: Split the Dataset

We split the data into training, validation, and test sets using a 70-10-20 split.

**Note:** We use `split_by_sample` which randomly splits samples. For time-series tasks, you might want to use temporal splits to avoid data leakage.

In [4]:
from pyhealth.datasets import split_by_sample

# Split dataset: 70% train, 10% validation, 20% test
train_dataset, val_dataset, test_dataset = split_by_sample(
    dataset=samples, 
    ratios=[0.7, 0.1, 0.2]
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

Training samples: 18
Validation samples: 2
Test samples: 6


## Step 5: Create Data Loaders

Data loaders batch the samples and handle data feeding during training and evaluation.

In [5]:
from pyhealth.datasets import get_dataloader

# Create data loaders
train_dataloader = get_dataloader(train_dataset, batch_size=2, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=2, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=2, shuffle=False)

print(f"Training batches: {len(train_dataloader)}")
print(f"Validation batches: {len(val_dataloader)}")
print(f"Test batches: {len(test_dataloader)}")

Training batches: 9
Validation batches: 1
Test batches: 3

Validation batches: 1
Test batches: 3


## Step 6: Initialize the AdaCare Model

The AdaCare model in PyHealth 2 automatically handles different feature types:
- **Sequence features** (like diagnosis/procedure/drug codes) are embedded using learned embeddings
- **Multiple feature keys** are processed by separate AdaCare layers
- The model provides interpretability through attention weights

### Key Parameters:
- `embedding_dim`: Dimension of code embeddings (default: 128)
- `hidden_dim`: Hidden dimension of GRU layers (default: 128)
- `kernel_size`: Kernel size for causal convolution (default: 2)
- `kernel_num`: Number of convolution kernels (default: 64)
- `dropout`: Dropout rate for regularization (default: 0.5)

In [6]:
from pyhealth.models import AdaCare

# Initialize AdaCare model
model = AdaCare(
    dataset=samples,
    embedding_dim=128,
    hidden_dim=128,
)

print(f"Model initialized with {sum(p.numel() for p in model.parameters())} parameters")
print(f"\nModel architecture:")
print(model)

  from .autonotebook import tqdm as notebook_tqdm


Model initialized with 816369 parameters

Model architecture:
AdaCare(
  (embedding_model): EmbeddingModel(embedding_layers=ModuleDict(
    (conditions): Embedding(191, 128, padding_idx=0)
    (procedures): Embedding(46, 128, padding_idx=0)
    (drugs): Embedding(298, 128, padding_idx=0)
  ))
  (adacare): ModuleDict(
    (conditions): AdaCareLayer(
      (nn_conv1): CausalConv1d(128, 64, kernel_size=(2,), stride=(1,), padding=(1,))
      (nn_conv3): CausalConv1d(128, 64, kernel_size=(2,), stride=(1,), padding=(3,), dilation=(3,))
      (nn_conv5): CausalConv1d(128, 64, kernel_size=(2,), stride=(1,), padding=(5,), dilation=(5,))
      (nn_convse): Recalibration(
        (avg_pool): AdaptiveAvgPool1d(output_size=1)
        (nn_c): Linear(in_features=192, out_features=48, bias=True)
        (nn_rescale): Linear(in_features=48, out_features=192, bias=True)
        (sparsemax): Sparsemax()
        (softmax): Softmax(dim=1)
      )
      (nn_inputse): Recalibration(
        (avg_pool): Adapt

## Step 7: Train the Model

We use PyHealth's `Trainer` class which handles:
- Training loop with automatic batching
- Validation during training
- Model checkpointing based on validation metrics
- Early stopping

We monitor the **ROC-AUC** score on the validation set.

In [7]:
from pyhealth.trainer import Trainer

# Initialize trainer
trainer = Trainer(
    model=model,
    metrics=["roc_auc", "pr_auc", "accuracy", "f1"]  # Track multiple metrics
)

# Train the model
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=50,
    monitor="roc_auc",  # Use ROC-AUC for model selection
    optimizer_params={"lr": 1e-3},  # Learning rate
)

AdaCare(
  (embedding_model): EmbeddingModel(embedding_layers=ModuleDict(
    (conditions): Embedding(191, 128, padding_idx=0)
    (procedures): Embedding(46, 128, padding_idx=0)
    (drugs): Embedding(298, 128, padding_idx=0)
  ))
  (adacare): ModuleDict(
    (conditions): AdaCareLayer(
      (nn_conv1): CausalConv1d(128, 64, kernel_size=(2,), stride=(1,), padding=(1,))
      (nn_conv3): CausalConv1d(128, 64, kernel_size=(2,), stride=(1,), padding=(3,), dilation=(3,))
      (nn_conv5): CausalConv1d(128, 64, kernel_size=(2,), stride=(1,), padding=(5,), dilation=(5,))
      (nn_convse): Recalibration(
        (avg_pool): AdaptiveAvgPool1d(output_size=1)
        (nn_c): Linear(in_features=192, out_features=48, bias=True)
        (nn_rescale): Linear(in_features=48, out_features=192, bias=True)
        (sparsemax): Sparsemax()
        (softmax): Softmax(dim=1)
      )
      (nn_inputse): Recalibration(
        (avg_pool): AdaptiveAvgPool1d(output_size=1)
        (nn_c): Linear(in_features

Epoch 0 / 50: 100%|██████████| 9/9 [00:00<00:00, 20.61it/s]

--- Train epoch-0, step-9 ---
loss: 0.4599
loss: 0.4599



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 99.97it/s]

--- Eval epoch-0, step-9 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 0.7692
New best roc_auc score (1.0000) at epoch-0, step-9
f1: 0.0000
loss: 0.7692
New best roc_auc score (1.0000) at epoch-0, step-9





Epoch 1 / 50: 100%|██████████| 9/9 [00:00<00:00, 24.32it/s]

--- Train epoch-1, step-18 ---
loss: 0.1897
loss: 0.1897



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 83.32it/s]

--- Eval epoch-1, step-18 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 0.9792

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 0.9792




Epoch 2 / 50: 100%|██████████| 9/9 [00:00<00:00, 26.94it/s]

--- Train epoch-2, step-27 ---
loss: 0.0799
loss: 0.0799



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 99.97it/s]

--- Eval epoch-2, step-27 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.1449

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.1449




Epoch 3 / 50: 100%|██████████| 9/9 [00:00<00:00, 27.81it/s]

--- Train epoch-3, step-36 ---
loss: 0.0616
loss: 0.0616



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 73.95it/s]

--- Eval epoch-3, step-36 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.2328

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.2328




Epoch 4 / 50: 100%|██████████| 9/9 [00:00<00:00, 25.78it/s]

--- Train epoch-4, step-45 ---
loss: 0.0229
loss: 0.0229



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 111.11it/s]

--- Eval epoch-4, step-45 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.2952

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.2952




Epoch 5 / 50: 100%|██████████| 9/9 [00:00<00:00, 34.61it/s]

--- Train epoch-5, step-54 ---
loss: 0.0136
loss: 0.0136



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 90.88it/s]
Evaluation: 100%|██████████| 1/1 [00:00<00:00, 90.88it/s]


--- Eval epoch-5, step-54 ---
roc_auc: 1.0000
pr_auc: 1.0000
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.3583

accuracy: 0.5000
f1: 0.0000
loss: 1.3583



Epoch 6 / 50: 100%|██████████| 9/9 [00:00<00:00, 20.59it/s]

--- Train epoch-6, step-63 ---
loss: 0.0061
loss: 0.0061



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 76.90it/s]

--- Eval epoch-6, step-63 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.4303

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.4303




Epoch 7 / 50: 100%|██████████| 9/9 [00:00<00:00, 29.12it/s]

--- Train epoch-7, step-72 ---
loss: 0.0052
loss: 0.0052



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 90.88it/s]

--- Eval epoch-7, step-72 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.4894

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.4894




Epoch 8 / 50: 100%|██████████| 9/9 [00:00<00:00, 27.31it/s]

--- Train epoch-8, step-81 ---
loss: 0.0035
loss: 0.0035



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 73.95it/s]

--- Eval epoch-8, step-81 ---
roc_auc: 1.0000
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.5430

pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.5430




Epoch 9 / 50: 100%|██████████| 9/9 [00:00<00:00, 22.38it/s]

--- Train epoch-9, step-90 ---
loss: 0.0016
loss: 0.0016



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 99.99it/s]

--- Eval epoch-9, step-90 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.5915

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.5915




Epoch 10 / 50: 100%|██████████| 9/9 [00:00<00:00, 26.26it/s]

--- Train epoch-10, step-99 ---
loss: 0.0036
loss: 0.0036



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 62.50it/s]

--- Eval epoch-10, step-99 ---
roc_auc: 1.0000
pr_auc: 1.0000
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.6252

accuracy: 0.5000
f1: 0.0000
loss: 1.6252




Epoch 11 / 50: 100%|██████████| 9/9 [00:00<00:00, 30.76it/s]

--- Train epoch-11, step-108 ---
loss: 0.0018
loss: 0.0018



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 142.86it/s]

--- Eval epoch-11, step-108 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.6449

f1: 0.0000
loss: 1.6449




Epoch 12 / 50: 100%|██████████| 9/9 [00:00<00:00, 21.60it/s]

--- Train epoch-12, step-117 ---
loss: 0.0011
loss: 0.0011



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 90.91it/s]

--- Eval epoch-12, step-117 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.6705

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.6705




Epoch 13 / 50: 100%|██████████| 9/9 [00:00<00:00, 29.55it/s]

--- Train epoch-13, step-126 ---
loss: 0.0011
loss: 0.0011



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 99.88it/s]

--- Eval epoch-13, step-126 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.6970

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.6970




Epoch 14 / 50: 100%|██████████| 9/9 [00:00<00:00, 26.50it/s]

--- Train epoch-14, step-135 ---
loss: 0.0008
loss: 0.0008



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 83.35it/s]

--- Eval epoch-14, step-135 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.7247

loss: 1.7247




Epoch 15 / 50: 100%|██████████| 9/9 [00:00<00:00, 21.35it/s]

--- Train epoch-15, step-144 ---
loss: 0.0011
loss: 0.0011



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 100.00it/s]

--- Eval epoch-15, step-144 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.7483

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.7483




Epoch 16 / 50: 100%|██████████| 9/9 [00:00<00:00, 30.92it/s]

--- Train epoch-16, step-153 ---
loss: 0.0007
loss: 0.0007



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 125.02it/s]

--- Eval epoch-16, step-153 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.7723

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.7723




Epoch 17 / 50: 100%|██████████| 9/9 [00:00<00:00, 24.45it/s]

--- Train epoch-17, step-162 ---
loss: 0.0007
loss: 0.0007



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 111.14it/s]

--- Eval epoch-17, step-162 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.7988

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.7988




Epoch 18 / 50: 100%|██████████| 9/9 [00:00<00:00, 31.41it/s]

--- Train epoch-18, step-171 ---
loss: 0.0008
loss: 0.0008



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 111.10it/s]

--- Eval epoch-18, step-171 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.8302

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.8302




Epoch 19 / 50: 100%|██████████| 9/9 [00:00<00:00, 22.84it/s]

--- Train epoch-19, step-180 ---
loss: 0.0007
loss: 0.0007



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 111.08it/s]

--- Eval epoch-19, step-180 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.8568

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.8568




Epoch 20 / 50: 100%|██████████| 9/9 [00:00<00:00, 29.55it/s]

--- Train epoch-20, step-189 ---
loss: 0.0006
loss: 0.0006



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 76.87it/s]

--- Eval epoch-20, step-189 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.8795

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.8795




Epoch 21 / 50: 100%|██████████| 9/9 [00:00<00:00, 23.40it/s]

--- Train epoch-21, step-198 ---
loss: 0.0008
loss: 0.0008



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 111.10it/s]

--- Eval epoch-21, step-198 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.8982

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.8982




Epoch 22 / 50: 100%|██████████| 9/9 [00:00<00:00, 25.74it/s]

--- Train epoch-22, step-207 ---
loss: 0.0005
loss: 0.0005



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 106.27it/s]

--- Eval epoch-22, step-207 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.9166

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.9166




Epoch 23 / 50: 100%|██████████| 9/9 [00:00<00:00, 22.00it/s]

--- Train epoch-23, step-216 ---
loss: 0.0004
loss: 0.0004



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 99.97it/s]

--- Eval epoch-23, step-216 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.9363

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.9363




Epoch 24 / 50: 100%|██████████| 9/9 [00:00<00:00, 20.83it/s]

--- Train epoch-24, step-225 ---
loss: 0.0005
loss: 0.0005



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 90.91it/s]

--- Eval epoch-24, step-225 ---
roc_auc: 1.0000
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.9573

pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.9573




Epoch 25 / 50: 100%|██████████| 9/9 [00:00<00:00, 21.02it/s]

--- Train epoch-25, step-234 ---
loss: 0.0006
loss: 0.0006



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 124.89it/s]

--- Eval epoch-25, step-234 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.9779

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.9779




Epoch 26 / 50: 100%|██████████| 9/9 [00:00<00:00, 22.72it/s]

--- Train epoch-26, step-243 ---
loss: 0.0003
loss: 0.0003



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 111.12it/s]

--- Eval epoch-26, step-243 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.9958

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 1.9958




Epoch 27 / 50: 100%|██████████| 9/9 [00:00<00:00, 23.13it/s]

--- Train epoch-27, step-252 ---
loss: 0.0002
loss: 0.0002



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 100.02it/s]

--- Eval epoch-27, step-252 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.0142

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.0142




Epoch 28 / 50: 100%|██████████| 9/9 [00:00<00:00, 22.00it/s]

--- Train epoch-28, step-261 ---
loss: 0.0003
loss: 0.0003



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 111.02it/s]

--- Eval epoch-28, step-261 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.0343

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.0343




Epoch 29 / 50: 100%|██████████| 9/9 [00:00<00:00, 19.10it/s]

--- Train epoch-29, step-270 ---
loss: 0.0002
loss: 0.0002



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 90.91it/s]

--- Eval epoch-29, step-270 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.0511

loss: 2.0511




Epoch 30 / 50: 100%|██████████| 9/9 [00:00<00:00, 19.60it/s]

--- Train epoch-30, step-279 ---
loss: 0.0003
loss: 0.0003



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 100.01it/s]

--- Eval epoch-30, step-279 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.0675

loss: 2.0675




Epoch 31 / 50: 100%|██████████| 9/9 [00:00<00:00, 23.19it/s]

--- Train epoch-31, step-288 ---
loss: 0.0002
loss: 0.0002



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 76.90it/s]

--- Eval epoch-31, step-288 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.0865

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.0865




Epoch 32 / 50: 100%|██████████| 9/9 [00:00<00:00, 23.07it/s]

--- Train epoch-32, step-297 ---
loss: 0.0002
loss: 0.0002



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 111.12it/s]

--- Eval epoch-32, step-297 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.1026
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.1026





Epoch 33 / 50: 100%|██████████| 9/9 [00:00<00:00, 15.74it/s]

--- Train epoch-33, step-306 ---
loss: 0.0002
loss: 0.0002



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 86.83it/s]

--- Eval epoch-33, step-306 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.1172

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.1172




Epoch 34 / 50: 100%|██████████| 9/9 [00:00<00:00, 21.68it/s]

--- Train epoch-34, step-315 ---
loss: 0.0001
loss: 0.0001



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 90.91it/s]

--- Eval epoch-34, step-315 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.1304

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.1304




Epoch 35 / 50: 100%|██████████| 9/9 [00:00<00:00, 22.72it/s]

--- Train epoch-35, step-324 ---
loss: 0.0002
loss: 0.0002



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 100.02it/s]

--- Eval epoch-35, step-324 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.1443

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.1443




Epoch 36 / 50: 100%|██████████| 9/9 [00:00<00:00, 28.98it/s]

--- Train epoch-36, step-333 ---
loss: 0.0001
loss: 0.0001



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 90.92it/s]

--- Eval epoch-36, step-333 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.1572

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.1572




Epoch 37 / 50: 100%|██████████| 9/9 [00:00<00:00, 20.42it/s]

--- Train epoch-37, step-342 ---
loss: 0.0004
loss: 0.0004



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 76.92it/s]

--- Eval epoch-37, step-342 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.1685

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.1685




Epoch 38 / 50: 100%|██████████| 9/9 [00:00<00:00, 24.45it/s]

--- Train epoch-38, step-351 ---
loss: 0.0002
loss: 0.0002



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 90.87it/s]

--- Eval epoch-38, step-351 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.1819

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.1819




Epoch 39 / 50: 100%|██████████| 9/9 [00:00<00:00, 22.38it/s]

--- Train epoch-39, step-360 ---
loss: 0.0002
loss: 0.0002



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 76.82it/s]
Evaluation: 100%|██████████| 1/1 [00:00<00:00, 76.82it/s]


--- Eval epoch-39, step-360 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.1938

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.1938



Epoch 40 / 50: 100%|██████████| 9/9 [00:00<00:00, 29.85it/s]

--- Train epoch-40, step-369 ---
loss: 0.0001
loss: 0.0001



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 90.89it/s]

--- Eval epoch-40, step-369 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.2094

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.2094




Epoch 41 / 50: 100%|██████████| 9/9 [00:00<00:00, 25.85it/s]

--- Train epoch-41, step-378 ---
loss: 0.0001
loss: 0.0001



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 16.25it/s]

--- Eval epoch-41, step-378 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.2243

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.2243




Epoch 42 / 50: 100%|██████████| 9/9 [00:00<00:00, 22.95it/s]

--- Train epoch-42, step-387 ---
loss: 0.0001
loss: 0.0001



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 125.04it/s]

--- Eval epoch-42, step-387 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.2387

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.2387




Epoch 43 / 50: 100%|██████████| 9/9 [00:00<00:00, 26.74it/s]

--- Train epoch-43, step-396 ---
loss: 0.0002
loss: 0.0002



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 111.05it/s]

--- Eval epoch-43, step-396 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.2484

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.2484




Epoch 44 / 50: 100%|██████████| 9/9 [00:00<00:00, 21.53it/s]

--- Train epoch-44, step-405 ---
loss: 0.0002
loss: 0.0002



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 99.98it/s]

--- Eval epoch-44, step-405 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.2582

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.2582




Epoch 45 / 50: 100%|██████████| 9/9 [00:00<00:00, 18.65it/s]

--- Train epoch-45, step-414 ---
loss: 0.0002
loss: 0.0002



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 28.96it/s]

--- Eval epoch-45, step-414 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.2729

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.2729




Epoch 46 / 50: 100%|██████████| 9/9 [00:00<00:00, 21.32it/s]

--- Train epoch-46, step-423 ---
loss: 0.0002
loss: 0.0002



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 71.39it/s]

--- Eval epoch-46, step-423 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.2867

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.2867




Epoch 47 / 50: 100%|██████████| 9/9 [00:00<00:00, 26.19it/s]

--- Train epoch-47, step-432 ---
loss: 0.0002
loss: 0.0002



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 90.90it/s]

--- Eval epoch-47, step-432 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.2971

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.2971




Epoch 48 / 50: 100%|██████████| 9/9 [00:00<00:00, 26.00it/s]

--- Train epoch-48, step-441 ---
loss: 0.0001
loss: 0.0001



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 76.92it/s]

--- Eval epoch-48, step-441 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.3073

roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.3073




Epoch 49 / 50: 100%|██████████| 9/9 [00:00<00:00, 26.86it/s]

--- Train epoch-49, step-450 ---
loss: 0.0001
loss: 0.0001



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 111.12it/s]

--- Eval epoch-49, step-450 ---
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.3182
Loaded best model
roc_auc: 1.0000
pr_auc: 1.0000
accuracy: 0.5000
f1: 0.0000
loss: 2.3182
Loaded best model





## Step 8: Evaluate on Test Set

After training, we evaluate the model on the held-out test set to measure its generalization performance.

In [8]:
# Evaluate on test set
test_results = trainer.evaluate(test_dataloader)

print("\n" + "="*50)
print("Test Set Performance")
print("="*50)
for metric, value in test_results.items():
    print(f"{metric}: {value:.4f}")

Evaluation: 100%|██████████| 3/3 [00:00<00:00, 50.40it/s]


Test Set Performance
roc_auc: 0.1250
pr_auc: 0.2917
accuracy: 0.6667
f1: 0.0000
loss: 0.7619





## Step 9: Model Interpretability (Optional)

One of AdaCare's key features is interpretability. The model provides attention weights that indicate which features are most important for predictions.

Let's examine the feature importance for a few test samples.

In [9]:
import torch
import numpy as np

# Get a batch from test set
test_batch = next(iter(test_dataloader))

# Run model in evaluation mode
model.eval()
with torch.no_grad():
    output = model(**test_batch)

# Extract interpretability information
if 'feature_importance' in output:
    print("Feature importance available!")
    print(f"Shape: {output['feature_importance']}")
    
    # Display importance for first sample
    print("\nFeature importance for first sample:")
    print("This shows which clinical features the model focuses on.")
else:
    print("Feature importance not available in model output.")

# Display predictions
print("\n" + "="*50)
print("Sample Predictions:")
print("="*50)
predictions = output['y_prob'].cpu().numpy()
true_labels = output['y_true'].cpu().numpy()

for i in range(min(5, len(predictions))):
    pred = predictions[i][0]
    true = int(true_labels[i][0])
    print(f"Sample {i+1}: Predicted={pred:.3f}, True={true}, Prediction={'Mortality' if pred > 0.5 else 'Survival'}")

Feature importance available!
Shape: [tensor([[[0.3808, 0.5472, 0.4767,  ..., 0.5502, 0.5235, 0.5840],
         [0.5611, 0.5604, 0.4398,  ..., 0.5486, 0.5007, 0.5026],
         [0.4841, 0.5169, 0.4600,  ..., 0.4207, 0.4374, 0.5361],
         ...,
         [0.4293, 0.5795, 0.3993,  ..., 0.5154, 0.5159, 0.5424],
         [0.4928, 0.4697, 0.4777,  ..., 0.5893, 0.4837, 0.5156],
         [0.5197, 0.5583, 0.4666,  ..., 0.4400, 0.4263, 0.5351]],

        [[0.5389, 0.5437, 0.4693,  ..., 0.5332, 0.4713, 0.5942],
         [0.4667, 0.5731, 0.5536,  ..., 0.4723, 0.4179, 0.6412],
         [0.4448, 0.5027, 0.4800,  ..., 0.4831, 0.4847, 0.5509],
         ...,
         [0.5170, 0.4971, 0.4677,  ..., 0.5119, 0.4715, 0.5198],
         [0.5170, 0.4971, 0.4677,  ..., 0.5119, 0.4715, 0.5198],
         [0.5170, 0.4971, 0.4677,  ..., 0.5119, 0.4715, 0.5198]]]), tensor([[[0.4630, 0.5156, 0.5087,  ..., 0.4908, 0.4307, 0.5002],
         [0.5030, 0.5333, 0.5260,  ..., 0.4907, 0.4975, 0.5351],
         [0.5030, 0