# Drug Recommendation using MICRON Model on MIMIC-III Dataset

This notebook demonstrates how to use the MICRON model for drug recommendation on the MIMIC-III dataset. The model is implemented with the PyHealth 2.0 framework.

MICRON is a recurrent residual network introduced in "Change Matters: Medication Change Prediction with Recurrent Residual Networks" (Yang et al., IJCAI 2021). It predicts a patient's medications from their diagnoses and procedures.
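
The "residual" part of MICRON's design can be read as predicting how a patient's medication set changes between visits, rather than rebuilding the whole set each time. The sketch below only illustrates that idea with plain Python sets. The helper names `medication_change` and `apply_change` and the drug codes are hypothetical and are not part of PyHealth.

```python
def medication_change(previous, current):
    """Return the drugs added and removed between two consecutive visits."""
    previous, current = set(previous), set(current)
    added = current - previous
    removed = previous - current
    return added, removed


def apply_change(previous, added, removed):
    """Rebuild the current medication set from the previous set plus the change."""
    return (set(previous) | set(added)) - set(removed)


# Hypothetical drug codes for two consecutive visits of one patient
visit_1 = ["A10B", "C09A", "N02B"]
visit_2 = ["A10B", "C09A", "B01A"]

added, removed = medication_change(visit_1, visit_2)
print("added:", sorted(added), "removed:", sorted(removed))
```

Applying the change to the first visit recovers the second visit's set, which is the sense in which the update acts as a residual on top of the previous prescription.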

## 1. Setup Google Drive and Environment

First, we'll mount Google Drive to access and save our data. We'll also install PyHealth from the forked repository and its dependencies. The notebook uses the latest version of PyHealth from https://github.com/naveenkcb/PyHealth.

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Install PyHealth from the forked repository
!pip install git+https://github.com/naveenkcb/PyHealth.git
# Install other required packages
!pip install torch scikit-learn pandas numpy tqdm

## 2. Import Required Libraries and Setup Configuration

Now we'll import the necessary libraries and set up our configuration for the MICRON model.

In [None]:
import os
import numpy as np
import pandas as pd
import torch
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.models import MICRON
from pyhealth.trainer import Trainer
from pyhealth.metrics import multilabel_metrics_fn  # PyHealth's multilabel metric function

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Configuration
MIMIC3_PATH = "https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III"  # Update this path to your MIMIC-III data location
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

## 3. Load and Process MIMIC-III Dataset

We'll load the MIMIC-III dataset using PyHealth's built-in dataset loader and prepare it for training. The dataset will be processed to include patient diagnoses, procedures, and medications.
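
Drug recommendation is a multilabel problem: a single visit can carry several drugs at once, so the target is typically a 0/1 vector over the drug vocabulary. As a minimal sketch of that encoding, the example below uses hypothetical helpers (`build_vocab`, `to_multi_hot`) and made-up drug codes. PyHealth performs its own processing internally, so this is only an illustration of the idea.

```python
def build_vocab(samples):
    """Map every drug code seen in the samples to a column index."""
    codes = sorted({code for drugs in samples for code in drugs})
    return {code: idx for idx, code in enumerate(codes)}


def to_multi_hot(drugs, vocab):
    """Encode one visit's drug list as a 0/1 vector over the vocabulary."""
    vector = [0] * len(vocab)
    for code in drugs:
        vector[vocab[code]] = 1
    return vector


# Hypothetical drug lists for three visits
visits = [["A10B", "C09A"], ["C09A"], ["B01A", "N02B", "A10B"]]
vocab = build_vocab(visits)
encoded = [to_multi_hot(drugs, vocab) for drugs in visits]
print(vocab)
print(encoded)
```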

In [None]:
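# Load MIMIC-III dataset (assumes the PyHealth 2.0 API, which uses lowercase table names)
from pyhealth.datasets import split_by_patient
from pyhealth.tasks import DrugRecommendationMIMIC3

base_dataset = MIMIC3Dataset(
    root=MIMIC3_PATH,
    tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
)

# Turn raw visits into samples. The task defines the schema:
# conditions and procedures (plus drug history) as inputs, drugs as a multilabel target.
dataset = base_dataset.set_task(DrugRecommendationMIMIC3())

# Split by patient (80/10/10) so that no patient appears in more than one split
train_dataset, val_dataset, test_dataset = split_by_patient(dataset, [0.8, 0.1, 0.1])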

## 4. Initialize and Configure MICRON Model

Now we'll set up the MICRON model with appropriate hyperparameters for drug recommendation.
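
Among the hyperparameters set in the next cell, `lam` weights a reconstruction term against the main prediction loss. The sketch below assumes the total loss has the form prediction loss + `lam` × reconstruction loss. That form, the helper names, and the numbers are illustrative assumptions, so check them against the MICRON implementation you install.

```python
import math


def binary_cross_entropy(probs, targets, eps=1e-7):
    """Mean binary cross-entropy over per-drug probabilities."""
    total = 0.0
    for p, y in zip(probs, targets):
        p = min(max(p, eps), 1 - eps)  # clamp to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(probs)


def mean_squared_error(a, b):
    """Mean squared difference between two equally long vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def combined_loss(probs, targets, reconstructed, lam=0.1):
    """Prediction loss plus a lam-weighted reconstruction penalty (assumed form)."""
    prediction_loss = binary_cross_entropy(probs, targets)
    reconstruction_loss = mean_squared_error(probs, reconstructed)
    return prediction_loss + lam * reconstruction_loss


# Hypothetical per-drug probabilities, true labels, and a reconstruction of the probabilities
probs = [0.9, 0.2, 0.7]
targets = [1, 0, 1]
reconstructed = [0.8, 0.3, 0.4]
print(combined_loss(probs, targets, reconstructed, lam=0.1))
```

With `lam=0` the reconstruction term vanishes and only the prediction loss remains. Larger values make the penalty count for more.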

In [None]:
# Model hyperparameters
model_params = {
    "embedding_dim": 128,
    "hidden_dim": 128,
    "lam": 0.1  # Weight of the reconstruction loss term
}

# Initialize MICRON model on the full task dataset, which provides the code vocabularies
model = MICRON(
    dataset=dataset,
    **model_params
).to(DEVICE)

# Build data loaders; the PyHealth Trainer consumes loaders rather than datasets
from pyhealth.datasets import get_dataloader
train_loader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_loader = get_dataloader(val_dataset, batch_size=32, shuffle=False)

# Configure trainer; metrics are passed by name
trainer = Trainer(
    model=model,
    device=DEVICE,
    metrics=["jaccard_macro", "f1_macro", "pr_auc_samples"],
)

## 5. Train the Model

Let's train the MICRON model on our processed MIMIC-III dataset.

In [None]:
# Train the model, monitoring the macro-averaged Jaccard score on the validation set
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=10,
    monitor="jaccard_macro",
)
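
When a validation metric is monitored during training, the usual pattern is to remember the epoch with the best score so its weights can be restored afterwards. The loop below sketches that bookkeeping with made-up per-epoch Jaccard scores. The function `select_best_epoch` is hypothetical and is not PyHealth's Trainer code.

```python
def select_best_epoch(scores, higher_is_better=True):
    """Return (epoch index, score) of the best epoch; ties keep the earliest epoch."""
    best_epoch, best_score = None, None
    for epoch, score in enumerate(scores):
        if best_score is None:
            improved = True
        elif higher_is_better:
            improved = score > best_score
        else:
            improved = score < best_score
        if improved:
            best_epoch, best_score = epoch, score
    return best_epoch, best_score


# Hypothetical validation Jaccard scores, one per epoch
val_jaccard = [0.31, 0.38, 0.42, 0.41, 0.44, 0.43]
best_epoch, best_score = select_best_epoch(val_jaccard)
print(f"Best epoch: {best_epoch}, score: {best_score}")
```

For a metric where lower is better, such as a validation loss, pass `higher_is_better=False`.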

# Save the trained model
torch.save(model.state_dict(), "/content/drive/MyDrive/micron_model.pt")

## 6. Evaluate Model Performance

Finally, let's evaluate our trained model on the test set and visualize the results.
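
Because each patient can receive several drugs, multilabel predictions are commonly obtained by thresholding each drug's probability instead of picking a single most likely class. A minimal sketch follows, with a hypothetical helper `threshold_predictions`, made-up drug codes, and an assumed cutoff of 0.5:

```python
def threshold_predictions(probabilities, drug_codes, threshold=0.5):
    """Turn per-drug probabilities into a list of recommended drug codes per patient."""
    return [
        [code for code, p in zip(drug_codes, row) if p >= threshold]
        for row in probabilities
    ]


# Hypothetical vocabulary and predicted probabilities for two patients
drug_codes = ["A10B", "B01A", "C09A", "N02B"]
probabilities = [
    [0.91, 0.12, 0.66, 0.40],
    [0.05, 0.50, 0.30, 0.85],
]
recommended = threshold_predictions(probabilities, drug_codes)
print(recommended)
```

The cutoff trades precision against recall: raising it recommends fewer drugs, lowering it recommends more.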

In [None]:
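# Build the test loader and set the model to evaluation mode
from pyhealth.datasets import get_dataloader
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
model.eval()

# Initialize lists to store predicted probabilities and actual values
all_probs = []
all_labels = []

# Evaluate on test set
with torch.no_grad():
    for batch in test_dataloader:
        # Forward pass; PyHealth models return a dict with "y_prob" and "y_true"
        output = model(**batch)

        # Store probabilities and labels
        all_probs.append(output["y_prob"].cpu().numpy())
        all_labels.append(output["y_true"].cpu().numpy())

# Stack batches into (num_samples, num_drugs) arrays
all_probs = np.concatenate(all_probs, axis=0)
all_labels = np.concatenate(all_labels, axis=0).astype(int)

# The task is multilabel, so threshold each probability instead of taking an argmax
all_preds = (all_probs >= 0.5).astype(int)

# Calculate metrics (accuracy here is the fraction of correct patient-drug decisions)
accuracy = (all_preds == all_labels).mean()
print(f"Test Accuracy: {accuracy:.4f}")

# Calculate additional metrics
from sklearn.metrics import precision_recall_fscore_support
precision, recall, f1, _ = precision_recall_fscore_support(
    all_labels, all_preds, average='weighted', zero_division=0
)

print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")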

In [None]:
# Visualize results using a confusion matrix
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# Flatten to one decision per patient-drug pair;
# confusion_matrix does not accept 2-D multilabel indicator arrays
cm = confusion_matrix(np.ravel(all_labels), np.ravel(all_preds))

# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('Confusion Matrix (all patient-drug pairs)')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.show()

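# Plot training history
# `train_losses` should hold one training loss per epoch. The training cell does not
# record it, so collect the values yourself (for example from the training log) first.
import matplotlib.pyplot as plt
train_losses = globals().get("train_losses", [])
if train_losses:
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.title('Training Loss Over Time')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()
else:
    print("No training losses recorded; skipping the loss plot.")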

## Conclusion

We have successfully implemented and trained a MICRON model for drug recommendation using the MIMIC-III dataset. The model's performance can be evaluated using the metrics above:

1. Accuracy: The fraction of predictions that were correct
2. Precision: The fraction of recommended drugs that were actually prescribed
3. Recall: The fraction of prescribed drugs that the model recommended
4. F1 Score: The harmonic mean of precision and recall
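
For a single patient, these metrics reduce to simple set arithmetic over the prescribed and recommended drugs. The worked example below uses a hypothetical helper `set_metrics` and made-up drug codes. It also reports the Jaccard score (overlap divided by union), which is often used for drug recommendation.

```python
def set_metrics(true_drugs, predicted_drugs):
    """Precision, recall, F1, and Jaccard for one patient's drug sets."""
    true_set, pred_set = set(true_drugs), set(predicted_drugs)
    overlap = len(true_set & pred_set)
    union = len(true_set | pred_set)
    precision = overlap / len(pred_set) if pred_set else 0.0
    recall = overlap / len(true_set) if true_set else 0.0
    denominator = precision + recall
    f1 = 2 * precision * recall / denominator if denominator else 0.0
    jaccard = overlap / union if union else 0.0
    return {"precision": precision, "recall": recall, "f1": f1, "jaccard": jaccard}


# Hypothetical prescribed and recommended drugs for one patient
prescribed = ["A10B", "C09A", "N02B"]
recommended = ["A10B", "C09A", "B01A"]
scores = set_metrics(prescribed, recommended)
print(scores)
```

Here two of the three recommended drugs are correct and two of the three prescribed drugs are found, so precision, recall, and F1 all equal 2/3, while the Jaccard score is 2/4 = 0.5.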

The confusion matrix visualization helps us understand where the model performs well and where it might need improvement. The training loss plot shows how the model learned over time.

Next steps could include:
- Hyperparameter tuning to improve performance
- Testing with different model architectures
- Analyzing specific cases where the model performs well or poorly
- Incorporating additional patient features