# TCN Model Training on MIMIC-III Dataset

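Train a TCN (Temporal Convolutional Network) model for mortality prediction on MIMIC-III. This notebook uses the publicly hosted synthetic MIMIC-III sample in development mode (`dev=True`), so the results are illustrative rather than clinically meaningful.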

In [None]:
from pyhealth.datasets import MIMIC3Dataset

# Publicly hosted synthetic MIMIC-III sample and the source tables we need:
# diagnoses, procedures, and prescriptions.
DATA_ROOT = "https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III"
TABLES = ["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"]

dataset = MIMIC3Dataset(root=DATA_ROOT, tables=TABLES, dev=True)
dataset.stats()

## Set Mortality Prediction Task

We use PyHealth's MIMIC-III mortality prediction task (`MortalityPredictionMIMIC3`), which predicts patient mortality from diagnosis, procedure, and prescription (drug) codes.

In [None]:
from pyhealth.tasks import MortalityPredictionMIMIC3

# Build the task definition first, then apply it to the loaded dataset.
# The resulting `samples` feed the split, the model, and the trainer below.
task = MortalityPredictionMIMIC3()
samples = dataset.set_task(task)

## Split Dataset

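Split the samples into training (70%), validation (15%), and test (15%) sets at the patient level. All samples from a given patient land in the same split, so no patient leaks between training and evaluation.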

In [None]:
from pyhealth.datasets import split_by_patient, get_dataloader

# A fixed seed makes the patient-level split, and therefore the reported test
# metrics, reproducible across runs. (Training-time shuffling is still random.)
SPLIT_SEED = 42
BATCH_SIZE = 64

# 70/15/15 split, grouped by patient.
train_dataset, val_dataset, test_dataset = split_by_patient(
    samples, ratios=[0.7, 0.15, 0.15], seed=SPLIT_SEED
)

# Only the training loader shuffles; validation/test order stays fixed.
train_loader = get_dataloader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = get_dataloader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = get_dataloader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

## Initialize TCN Model

Create the TCN model with `embedding_dim=128`, `num_channels=128`, a kernel size of 2, and dropout of 0.5.

In [None]:
from pyhealth.models import TCN

# Baseline TCN hyperparameters, collected in one place so they are easy to tweak.
tcn_hparams = {
    "embedding_dim": 128,
    "num_channels": 128,
    "kernel_size": 2,
    "dropout": 0.5,
}
model = TCN(dataset=samples, **tcn_hparams)

## Train Model

Train the model for 10 epochs with the PyHealth `Trainer`. Training tracks PR-AUC, ROC-AUC, F1, and accuracy, and monitors validation ROC-AUC.

In [None]:
from pyhealth.trainer import Trainer

# Metrics reported on the validation set during training and on the test set later.
metric_names = ["pr_auc", "roc_auc", "f1", "accuracy"]
trainer = Trainer(model=model, metrics=metric_names)

# Fit on the training split; validation ROC-AUC is the monitored metric.
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=10,
    monitor="roc_auc",
)

## Evaluate on Test Set

Evaluate the trained model on the test set and print the results.

In [None]:
results = trainer.evaluate(test_loader)

# (display label, key in the results dict), printed in this order.
report_rows = (
    ("ROC-AUC", "roc_auc"),
    ("PR-AUC", "pr_auc"),
    ("F1 Score", "f1"),
    ("Accuracy", "accuracy"),
)

print("Test Set Results:")
for label, key in report_rows:
    print(f"  {label}: {results[key]:.4f}")

## Custom TCN Configuration

You can customize the TCN architecture by changing its hyperparameters. For example, passing a list to `num_channels` sets the number of channels for each layer explicitly:

In [None]:
# Create a TCN with a custom architecture.
custom_model = TCN(
    dataset=samples,
    embedding_dim=64,
    num_channels=[64, 128, 256],  # List: one channel count per TCN layer
    kernel_size=3,
    dropout=0.3,
)

print("Custom TCN architecture:")
print(f"Embedding dim: {custom_model.embedding_dim}")
# num_channels was given as a per-layer list above, so label it as such
# rather than as a single output size.
print(f"Channels per layer: {custom_model.num_channels}")
print(f"Number of features: {len(custom_model.feature_keys)}")