In [1]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.linear_model import LogisticRegression
from sklearn.manifold import TSNE
from sklearn.metrics import classification_report, roc_auc_score, f1_score
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import MultiLabelBinarizer, LabelEncoder
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel

In [12]:
# Directory that load_mimic_data() joins with each MIMIC-III file name.
DATA_PATH = "/content/"

In [13]:
def load_mimic_data():
    """Read the three MIMIC-III tables used by this notebook from ``DATA_PATH``.

    Returns:
        tuple: ``(admissions, diagnoses, labevents)`` DataFrames, in that order.

    Raises:
        EOFError: If a gzip archive ends prematurely (e.g. an interrupted
            upload). The message names the file that is incomplete.
    """
    tables = []
    for filename in ('ADMISSIONS.csv.gz', 'DIAGNOSES_ICD.csv.gz', 'LABEVENTS.csv.gz'):
        path = os.path.join(DATA_PATH, filename)
        try:
            tables.append(pd.read_csv(path))
        except EOFError as exc:
            # The bare gzip error does not say which of the three archives is
            # truncated, so re-raise the same exception type with the path.
            raise EOFError(
                f"{path} is incomplete ({exc}); re-download or re-upload the file"
            ) from exc
    admissions, diagnoses, labevents = tables
    return admissions, diagnoses, labevents

In [22]:
# Load the three MIMIC-III tables into memory.
# DATA_PATH is assigned again here with the same value as in the earlier cell.
DATA_PATH = "/content/"
admissions, diagnoses, labevents = load_mimic_data()

EOFError: Compressed file ended before the end-of-stream marker was reached

In [14]:
def preprocess_features(labevents, top_n=20):
    """Build a per-subject feature table from the most frequently recorded lab items.

    Args:
        labevents: DataFrame with ``SUBJECT_ID``, ``ITEMID`` and ``VALUENUM`` columns.
        top_n: How many of the most frequent ``ITEMID`` values to keep (default 20).

    Returns:
        DataFrame indexed by ``SUBJECT_ID`` with one column per retained
        ``ITEMID``. Each cell is the subject's mean ``VALUENUM`` for that item;
        combinations with no value are filled with 0.

    Raises:
        ValueError: If ``top_n`` is smaller than 1.
    """
    if top_n < 1:
        raise ValueError(f"top_n must be at least 1, got {top_n}")

    top_items = labevents['ITEMID'].value_counts().nlargest(top_n).index
    filtered = labevents[labevents['ITEMID'].isin(top_items)]
    pivot = filtered.pivot_table(
        index='SUBJECT_ID',
        columns='ITEMID',
        values='VALUENUM',
        aggfunc='mean'
    )
    # NOTE(review): 0 is used as the fill value for labs a subject never had;
    # confirm that 0 is an acceptable stand-in for "not measured".
    return pivot.fillna(0)

In [None]:
# Turn the SUBJECT_ID index into a regular column so the features can be merged on it.
features_df = preprocess_features(labevents).reset_index()

In [15]:
# Extract diagnosis code list per subject
def get_diagnosis_labels(diagnoses):
    """Collect each subject's diagnosis codes into a list.

    Args:
        diagnoses: DataFrame with ``SUBJECT_ID`` and ``ICD9_CODE`` columns.

    Returns:
        Series named ``ICD9_CODE``, indexed by ``SUBJECT_ID``, whose values are
        lists of that subject's non-missing codes in row order. A subject whose
        codes are all missing gets an empty list.
    """
    # Missing codes are dropped: a NaN inside a label list would otherwise be
    # treated as a diagnosis class by any downstream label encoder.
    grouped = diagnoses.groupby('SUBJECT_ID')['ICD9_CODE'].apply(
        lambda codes: codes.dropna().tolist()
    )
    return grouped

In [None]:
# Move SUBJECT_ID out of the index so the per-subject code lists can be merged on it.
diagnosis_df = get_diagnosis_labels(diagnoses).reset_index()

In [None]:
# Merge outcomes
# features_df and diagnosis_df hold one row per SUBJECT_ID, but the MIMIC-III
# ADMISSIONS table is keyed per admission, so a subject may appear several
# times. Merging it directly would replicate identical feature vectors, which
# can then land on both sides of a random train/test split. Keep a single
# outcome row per subject instead.
# NOTE(review): keep='last' takes each subject's last row in file order —
# confirm this is the admission whose outcome should be predicted.
admission_outcomes = admissions[
    ['SUBJECT_ID', 'HOSPITAL_EXPIRE_FLAG', 'DISCHARGE_LOCATION']
].drop_duplicates(subset='SUBJECT_ID', keep='last')
# validate= makes pandas raise if either side unexpectedly repeats a subject.
merged = pd.merge(features_df, admission_outcomes, on='SUBJECT_ID', validate='one_to_one')
merged = pd.merge(merged, diagnosis_df, on='SUBJECT_ID', validate='one_to_one')

In [None]:
# Feature matrix: every merged column except the subject key and the three label columns.
X = merged.drop(columns=['SUBJECT_ID', 'HOSPITAL_EXPIRE_FLAG', 'DISCHARGE_LOCATION', 'ICD9_CODE']).values
# Label arrays for the three downstream tasks, row-aligned with X.
y_mortality = merged['HOSPITAL_EXPIRE_FLAG'].values
y_discharge = merged['DISCHARGE_LOCATION'].values
# Each entry of y_diagnosis is the ICD9_CODE value for that row (one per-subject list of codes).
y_diagnosis = merged['ICD9_CODE'].values

In [16]:
# Dataset class for contrastive learning
class ContrastivePatientDataset(Dataset):
    """Yield (sample, noisy copy) pairs for contrastive training.

    Args:
        features: Indexable collection of NumPy feature vectors (e.g. a 2-D
            array); ``features[i]`` must expose ``.shape``.
        noise_std: Standard deviation of the zero-mean Gaussian noise added to
            build the second view (default 0.01). Must be non-negative.

    Raises:
        ValueError: If ``noise_std`` is negative.
    """

    def __init__(self, features, noise_std=0.01):
        if noise_std < 0:
            raise ValueError(f"noise_std must be non-negative, got {noise_std}")
        self.features = features
        self.noise_std = noise_std

    def __getitem__(self, index):
        """Return ``(x1, x2)`` float32 tensors: the raw row and its augmented copy."""
        x1 = self.features[index]
        x2 = self._augment(x1)
        return torch.tensor(x1).float(), torch.tensor(x2).float()

    def __len__(self):
        """Number of samples in the dataset."""
        return len(self.features)

    def _augment(self, x):
        """Return ``x`` plus Gaussian noise drawn from NumPy's global RNG."""
        noise = np.random.normal(0, self.noise_std, size=x.shape)
        return x + noise

# Encoder model
class Encoder(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Width of the hidden layer and of the output embedding.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Layers are built in forward order and stored under ``net`` so the
        # parameter names (net.0.*, net.2.*) stay the same.
        first_layer = nn.Linear(input_dim, hidden_dim)
        second_layer = nn.Linear(hidden_dim, hidden_dim)
        self.net = nn.Sequential(first_layer, nn.ReLU(), second_layer)

    def forward(self, x):
        """Map a batch of feature vectors to their embeddings."""
        embedding = self.net(x)
        return embedding

# Contrastive Loss
class ContrastiveLoss(nn.Module):
    """Cross-entropy over temperature-scaled cosine similarities.

    Row ``k`` of ``z_i`` is scored against every row of ``z_j``; the matching
    row ``k`` of ``z_j`` is the target class.

    Args:
        temperature: Divisor applied to the similarity matrix (default 0.5).
    """

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        """Return the mean cross-entropy loss for two batches of embeddings."""
        # L2-normalise rows so the dot products below are cosine similarities.
        unit_i = nn.functional.normalize(z_i, dim=1)
        unit_j = nn.functional.normalize(z_j, dim=1)
        similarity = (unit_i @ unit_j.T) / self.temperature
        # The positive pair for row k sits on the diagonal.
        targets = torch.arange(len(unit_i), device=unit_i.device)
        return nn.functional.cross_entropy(similarity, targets)

# Train contrastive model
def train_contrastive_model(data, hidden_dim=128, epochs=5, batch_size=64, lr=1e-3):
    """Train an ``Encoder`` on ``data`` with the contrastive objective.

    Args:
        data: 2-D array of features; ``data.shape[1]`` sets the encoder input size.
        hidden_dim: Encoder hidden/output width (default 128).
        epochs: Number of passes over the data (default 5).
        batch_size: Minibatch size for the shuffled ``DataLoader`` (default 64).
        lr: Adam learning rate (default 1e-3).

    Returns:
        The trained ``Encoder`` instance.
    """
    model = Encoder(input_dim=data.shape[1], hidden_dim=hidden_dim)
    loss_fn = ContrastiveLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    dataset = ContrastivePatientDataset(data)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for epoch in range(epochs):
        loss_total = 0.0
        samples_seen = 0
        for x1, x2 in loader:
            z1 = model(x1)
            z2 = model(x2)
            loss = loss_fn(z1, z2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size so a short final batch does not skew the mean.
            loss_total += loss.item() * len(x1)
            samples_seen += len(x1)
        # Report the epoch mean rather than only the last minibatch's loss.
        print(f"Epoch {epoch+1}, Loss: {loss_total / max(samples_seen, 1):.4f}")
    return model

In [None]:
# Fit the encoder on the full feature matrix, then embed every row of X with it.
model = train_contrastive_model(X)
# .detach() drops the autograd graph so the output can be converted to a NumPy array.
embeddings = model(torch.tensor(X).float()).detach().numpy()

In [17]:
# Mortality Prediction Task
def mortality_prediction(embeddings, labels, random_state=42):
    """Fit a logistic-regression classifier on the embeddings and print its metrics.

    Args:
        embeddings: 2-D array of embeddings, one row per sample.
        labels: Binary label per row of ``embeddings``.
        random_state: Seed for the train/test split (default 42) so the
            printed metrics are reproducible across runs.

    Prints a classification report and the ROC AUC on a 20% held-out split.
    """
    # Stratify so both splits keep the label ratio; ROC AUC needs both classes
    # to be present in the test split.
    X_train, X_test, y_train, y_test = train_test_split(
        embeddings, labels, test_size=0.2, stratify=labels, random_state=random_state
    )
    clf = LogisticRegression(max_iter=1000)
    clf.fit(X_train, y_train)
    preds = clf.predict(X_test)
    # Column 1 of predict_proba is the probability of the second class in clf.classes_.
    proba = clf.predict_proba(X_test)[:, 1]
    print("\nMortality Prediction Report:")
    print(classification_report(y_test, preds))
    print("ROC AUC:", roc_auc_score(y_test, proba))

In [None]:
# Evaluate the embeddings against the HOSPITAL_EXPIRE_FLAG labels.
mortality_prediction(embeddings, y_mortality)

In [18]:
# Discharge Disposition Task
def discharge_prediction(embeddings, dispositions, random_state=42):
    """Fit a multi-class logistic regression on the embeddings and print a report.

    Args:
        embeddings: 2-D array of embeddings, one row per sample.
        dispositions: Disposition label per row of ``embeddings``; encoded to
            integers with ``LabelEncoder`` before fitting.
        random_state: Seed for the train/test split (default 42) so the
            printed report is reproducible across runs.

    Prints a classification report on a 20% held-out split.
    """
    le = LabelEncoder()
    labels = le.fit_transform(dispositions)
    X_train, X_test, y_train, y_test = train_test_split(
        embeddings, labels, test_size=0.2, random_state=random_state
    )
    clf = LogisticRegression(max_iter=1000)
    clf.fit(X_train, y_train)
    preds = clf.predict(X_test)
    print("\nDischarge Disposition Prediction Report:")
    # Pass the full label set with the original names so rows read as
    # dispositions rather than encoded integers; listing every class keeps
    # target_names aligned even when a rare class is absent from the test split.
    print(classification_report(
        y_test,
        preds,
        labels=np.arange(len(le.classes_)),
        target_names=[str(name) for name in le.classes_],
        zero_division=0,
    ))

In [None]:
# Evaluate the embeddings against the DISCHARGE_LOCATION labels.
discharge_prediction(embeddings, y_discharge)

In [19]:
# Diagnosis Code Prediction Task (multi-label)
def diagnosis_prediction(embeddings, diagnosis_codes, random_state=42):
    """Fit one-vs-rest logistic regressions for multi-label code prediction.

    Args:
        embeddings: 2-D array of embeddings, one row per sample.
        diagnosis_codes: Iterable of code collections, one per row of
            ``embeddings``; binarised with ``MultiLabelBinarizer``.
        random_state: Seed for the train/test split (default 42) so the
            printed score is reproducible across runs.

    Prints the micro-averaged F1 score on a 20% held-out split.
    """
    mlb = MultiLabelBinarizer()
    Y = mlb.fit_transform(diagnosis_codes)
    X_train, X_test, y_train, y_test = train_test_split(
        embeddings, Y, test_size=0.2, random_state=random_state
    )
    # LogisticRegression alone only accepts a 1-D target; OneVsRestClassifier
    # fits one binary model per column of the indicator matrix.
    clf = OneVsRestClassifier(LogisticRegression(max_iter=1000))
    clf.fit(X_train, y_train)
    preds = clf.predict(X_test)
    print(
        "\nDiagnosis Code Prediction F1 (Micro):",
        f1_score(y_test, preds, average='micro', zero_division=0),
    )


In [None]:
# Evaluate the embeddings against the per-row ICD9_CODE label lists.
diagnosis_prediction(embeddings, y_diagnosis)

In [20]:
# t-SNE Visualization
def visualize_embeddings(embeddings, labels, title="t-SNE Visualization"):
    """Project embeddings to 2-D with t-SNE and show a scatter plot.

    Args:
        embeddings: 2-D array of embeddings, one row per sample.
        labels: Per-sample values used to colour the points.
        title: Title shown above the plot.
    """
    projection = TSNE(n_components=2).fit_transform(embeddings)
    xs, ys = projection[:, 0], projection[:, 1]

    _, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(xs, ys, c=labels, cmap='viridis', alpha=0.6)
    ax.set_title(title)
    plt.show()

In [None]:
# Plot the 2-D t-SNE projection of the embeddings, coloured by the mortality labels.
visualize_embeddings(embeddings, y_mortality, title="t-SNE of Contrastive Embeddings (Mortality)")