In [2]:
#Import external libraries
import os
import sys
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

In [None]:
# Locate the project root and put it on sys.path so the project packages
# imported below (data/, model_training/, utils/, config) can be found.
def find_project_root(start_dir, markers=('data', 'model_training', 'utils')):
    """Return `start_dir` or its parent, whichever contains every marker directory.

    Args:
        start_dir: Directory to start from (the notebook's own directory).
        markers: Sub-directory names that must all exist in the project root.
            The defaults are the packages this notebook imports from.

    Returns:
        Absolute path of the first matching directory, or None when neither
        `start_dir` nor its parent contains all markers.
    """
    start_dir = os.path.abspath(start_dir)
    for candidate in (start_dir, os.path.dirname(start_dir)):
        if all(os.path.isdir(os.path.join(candidate, marker)) for marker in markers):
            return candidate
    return None


try:
    # Script-style execution: __file__ is this file, so its directory is the start point.
    notebook_path = os.path.abspath(__file__)
    notebook_dir = os.path.dirname(notebook_path)
    legacy_root = notebook_dir
except NameError:
    # Jupyter: __file__ is undefined and the working directory is used instead.
    # It is already a directory, so it must not be treated like a file path
    # when looking for the root (the previous code always stripped one level,
    # which pointed above the project when the notebook sat in the root).
    notebook_path = os.path.abspath('.')
    notebook_dir = notebook_path
    legacy_root = os.path.dirname(notebook_path)

# Previous guess, kept as the fallback so existing layouts behave as before.
if os.path.basename(legacy_root) == 'notebooks':
    legacy_root = os.path.dirname(legacy_root)

# Prefer a directory that actually contains the project packages.
project_root = find_project_root(notebook_dir) or legacy_root

if project_root not in sys.path:
    sys.path.insert(0, project_root)
    print(f"Project root added to path: {project_root}")
else:
    print(f"Project root already in path: {project_root}")

In [30]:
# Import from our modules
from data.ClinicalNoteDataset import ClinicalNoteDataset
from model_training.DiagnosisDateRelationModel import DiagnosisDateRelationModel
from data.synthetic_data_generator import generate_dataset #this is used in train.py
from utils.extraction_utils import extract_entities
from utils.training_utils import train_model, evaluate_model, plot_training_curves, load_and_prepare_data, preprocess_note_for_prediction, create_prediction_dataset, predict_relationships, relate_diagnosis_to_date_rule_based
from data.sample_note import CLINICAL_NOTE
from config import *
from model_training.training_config import *
from model_training.Vocabulary import Vocabulary

In [None]:
# Report the compute device used for training and inference.
# DEVICE is not defined in this notebook; presumably it comes from one of the
# star imports (config / model_training.training_config) — verify.
print(f"Using device: {DEVICE}")

Step 1 - Generate or Load Dataset

In [6]:
# Point the dataset and vocabulary paths at files under the project root.
# These assignments rebind DATASET_PATH / VOCAB_PATH for every cell below
# (presumably replacing values supplied by the star-imported config — verify).
DATASET_PATH = os.path.join(project_root, 'data/synthetic_data.json')
VOCAB_PATH = os.path.join(project_root, 'model_training/vocab.pt')

In [None]:
# Step 1: Load the dataset from disk, or generate one when no file exists yet.
if os.path.exists(DATASET_PATH):
    print(f"Loading existing dataset from {DATASET_PATH}")
    with open(DATASET_PATH, 'r') as f:
        dataset = json.load(f)
else:
    # Previously `dataset` stayed undefined on this path and the summary line
    # below failed with NameError.
    print(f"No dataset found at {DATASET_PATH}; generating {NUM_SAMPLES} synthetic notes in memory")
    dataset = generate_dataset(num_notes=NUM_SAMPLES)
    # NOTE(review): Step 2 reads from DATASET_PATH, not from this variable —
    # confirm whether generate_dataset also writes the file.

print(f"Dataset contains {len(dataset)} clinical notes")

In [None]:
# Inspect the dataset loaded from DATASET_PATH.
# Displays the whole object; output can be large.
dataset

In [9]:
# Generate a fresh synthetic dataset with NUM_SAMPLES notes.
# This rebinds `dataset`, replacing whatever the load cell produced.
# NOTE(review): Step 2 passes DATASET_PATH to load_and_prepare_data rather than
# this variable — confirm whether generate_dataset also writes that file.
dataset = generate_dataset(num_notes=NUM_SAMPLES) # generate_dataset is imported from data/synthetic_data_generator.py

In [None]:
# Inspect the freshly generated dataset.
# Displays the whole object; output can be large.
dataset

Step 2 - Prepare data for model training 

In [None]:
# Step 2: Build model inputs from the dataset file.
# Diagnosis and date extraction happen inside load_and_prepare_data; it hands
# back the features, their labels and the vocabulary built from the notes.
features, labels, vocab_instance = load_and_prepare_data(
    DATASET_PATH,
    MAX_DISTANCE,
    Vocabulary,
)

In [None]:
# Sanity-check the vocabulary returned by the data preparation step.
if not vocab_instance:
    # Only reported here; later cells still dereference vocab_instance.
    print("Error: Vocabulary building failed.")
else:
    print(f"Successfully built vocabulary with {vocab_instance.n_words} words.")

In [None]:
# Summarize how many examples were prepared and how large the vocabulary is.
print(f"Loaded {len(features)} examples with vocabulary size {vocab_instance.n_words}")

In [None]:
# Inspect the prepared features and labels (displays both in full).
features, labels

In [15]:
# Save vocabulary for later use in prediction (currently disabled).
# NOTE(review): the "Testing" section looks for a file at VOCAB_PATH; with the
# two lines below commented out, nothing in this notebook writes it.
#torch.save(vocab_instance, VOCAB_PATH)
#print(f"Saved vocabulary to {VOCAB_PATH}")

In [None]:
# Check class balance of labels.
# An empty label list means nothing can be trained, so stop with an exception.
# The interactive `exit()` helper is not meant for program flow and does not
# reliably halt a notebook run.
if len(labels) == 0:
    raise ValueError("No examples found in the dataset!")

positive = sum(labels)  # assumes labels are 0/1 values — TODO confirm
negative = len(labels) - positive
print(f"Class distribution: {positive} positive examples ({positive/len(labels)*100:.1f}%), {negative} negative examples ({negative/len(labels)*100:.1f}%)")

Step 3 - Create Train / Val / Test Datasets

In [17]:
# Step 3: Train-validation-test split.
# First cut: hold out 20% of all examples as the test set (fixed seed).
train_features, test_features, train_labels, test_labels = train_test_split(
    features,
    labels,
    test_size=0.2,
    random_state=42,
)

In [18]:
# Second cut: take 25% of the remaining training data as the validation set
# (0.25 of the 80% left after the test split, i.e. 20% of all examples).
# Note: this rebinds train_features / train_labels, so re-running this cell
# alone splits the already reduced training set again.
train_features, val_features, train_labels, val_labels = train_test_split(
    train_features,
    train_labels,
    test_size=0.25,
    random_state=42,
)

In [None]:
# Report the number of examples in each split.
print(f"Train: {len(train_features)}, Validation: {len(val_features)}, Test: {len(test_features)}")

In [20]:
# Wrap each split in a ClinicalNoteDataset; all three share the same
# vocabulary and the same context-length / distance settings.
train_dataset, val_dataset, test_dataset = (
    ClinicalNoteDataset(split_features, split_labels, vocab_instance, MAX_CONTEXT_LEN, MAX_DISTANCE)
    for split_features, split_labels in (
        (train_features, train_labels),
        (val_features, val_labels),
        (test_features, test_labels),
    )
)

In [None]:
# Inspect the features and labels stored on the training dataset object.
train_dataset.features, train_dataset.labels

In [22]:
# Batch each split with a DataLoader; only the training loader shuffles.
train_loader, val_loader, test_loader = (
    DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True),
    DataLoader(val_dataset, batch_size=BATCH_SIZE),
    DataLoader(test_dataset, batch_size=BATCH_SIZE),
)

Step 4 - Initialize and Train Model

In [23]:
# Build the relation model sized to the vocabulary, then move it to DEVICE.
model = DiagnosisDateRelationModel(
    vocab_size=vocab_instance.n_words,
    embedding_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
)
model = model.to(DEVICE)

In [None]:
# Display the vocabulary size and the training hyperparameters in use.
vocab_instance.n_words, EMBEDDING_DIM, HIDDEN_DIM, LEARNING_RATE, NUM_EPOCHS, DEVICE

In [None]:
# Display the model object (its repr) for inspection.
model

In [26]:
# Optimizer and loss function.
# Adam updates every model parameter at LEARNING_RATE.
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
# Binary cross-entropy on probabilities.
# NOTE(review): BCELoss expects outputs in [0, 1] — confirm the model ends in a sigmoid.
criterion = nn.BCELoss()

In [None]:
# Run the training loop; it hands back the per-run training losses,
# validation losses and validation accuracies used for plotting below.
train_losses, val_losses, val_accs = train_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    NUM_EPOCHS,
    DEVICE,
    MODEL_PATH,
)

In [None]:
# Plot the training loss, validation loss and validation accuracy curves.
# MODEL_PATH is passed through to the helper; presumably it determines where
# the figure is saved — verify against plot_training_curves.
plot_training_curves(train_losses, val_losses, val_accs, MODEL_PATH)

Step 5 - Evaluate Model

In [None]:
# Step 5: Evaluate model.
# Read the weights stored at MODEL_PATH onto DEVICE and load them into the
# model (presumably the checkpoint written by train_model — verify).
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(state_dict)

In [None]:
# Evaluate the model on the held-out test loader and keep what evaluate_model returns.
results = evaluate_model(model, test_loader, DEVICE)

Testing

In [None]:
# Restore the trained model from disk.
# A missing file is only reported; `model` then keeps whatever it held before.
if not os.path.exists(MODEL_PATH):
    print(f"Error: Model file {MODEL_PATH} not found!")
    print("Please run train_model.py first to train the model.")
else:
    print(f"Loading model from {MODEL_PATH}")
    # Same architecture as used for training, sized by the in-memory vocabulary.
    model = DiagnosisDateRelationModel(
        vocab_size=vocab_instance.n_words,
        embedding_dim=EMBEDDING_DIM,
        hidden_dim=HIDDEN_DIM,
    ).to(DEVICE)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))

# Restore the saved vocabulary.
# NOTE(review): the result is bound to `vocab`, while the prediction cells
# below keep using `vocab_instance` — confirm which one is intended.
if not os.path.exists(VOCAB_PATH):
    print(f"Error: Vocabulary file {VOCAB_PATH} not found!")
    print("Please run train_model.py first to generate the vocabulary.")
else:
    print(f"Loading vocabulary from {VOCAB_PATH}")
    # weights_only=False unpickles arbitrary objects; only load trusted files.
    vocab = torch.load(VOCAB_PATH, weights_only=False)

In [None]:
# Display the model and the in-memory vocabulary used for prediction.
model, vocab_instance

In [None]:
# Apply the pipeline to the sample clinical note: build prediction features
# from CLINICAL_NOTE using the same MAX_DISTANCE setting as in training.
# Note: this rebinds `features`, which previously held the training features.
features = preprocess_note_for_prediction(CLINICAL_NOTE, MAX_DISTANCE)
features

In [None]:
# Convert the note's features into model-ready inputs using the in-memory
# vocabulary, then display the result.
test_data = create_prediction_dataset(features, vocab_instance, DEVICE, MAX_DISTANCE, MAX_CONTEXT_LEN)
test_data

In [None]:
# Run the model over the prepared inputs and display the predictions.
# Cells below read the 'date', 'diagnosis' and 'confidence' keys of each item.
ml_relationships = predict_relationships(model, test_data)
ml_relationships

In [None]:
# Group the ML predictions by date: date -> [(diagnosis, confidence), ...]
date_dict = {}
for rel in ml_relationships:
    # setdefault creates the list the first time a date is seen.
    date_dict.setdefault(rel['date'], []).append((rel['diagnosis'], rel['confidence']))

date_dict

In [None]:
# Print ML results as a timeline: dates in sorted order, and within each date
# the diagnoses from highest to lowest confidence.
# The inner list gets its own name so this loop does not rebind the global
# `diagnoses` that the rule-based cells read.
# NOTE(review): dates are sorted as plain keys; if they are strings this is
# only chronological for formats such as YYYY-MM-DD — confirm.
print("\nPatient Timeline from ML Model:")
for date, dated_diagnoses in sorted(date_dict.items()):
    print(f"\n{date}:")
    for diagnosis, confidence in sorted(dated_diagnoses, key=lambda pair: pair[1], reverse=True):
        print(f"  - {diagnosis} (confidence: {confidence:.2f})")

In [None]:
# Compare with rule-based approach: extract the diagnoses and dates from the
# same sample note, then display both results.
diagnoses, dates = extract_entities(CLINICAL_NOTE)
diagnoses, dates

In [None]:
# Pair diagnoses with dates using the rule-based helper and display the result.
# Cells below read the 'date', 'diagnosis' and 'distance' keys of each item.
rule_based_relationships = relate_diagnosis_to_date_rule_based(diagnoses, dates)
rule_based_relationships

In [None]:
# Group the rule-based pairs by date: date -> [(diagnosis, distance), ...]
rule_date_dict = {}
for rel in rule_based_relationships:
    # setdefault creates the list the first time a date is seen.
    rule_date_dict.setdefault(rel['date'], []).append((rel['diagnosis'], rel['distance']))

rule_date_dict

In [None]:
# Print rule-based results as a timeline: dates in sorted order, and within
# each date the diagnoses from smallest to largest distance.
# The inner list gets its own name so this loop does not overwrite the
# `diagnoses` returned by extract_entities; re-running the rule-based cell
# after this one would otherwise receive the wrong input.
print("\nPatient Timeline from Rule-based Approach:")
for date, dated_diagnoses in sorted(rule_date_dict.items()):
    print(f"\n{date}:")
    for diagnosis, distance in sorted(dated_diagnoses, key=lambda pair: pair[1]):
        print(f"  - {diagnosis} (distance: {distance} chars)")