In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import gc
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve

In [None]:
# Mount Google Drive at /content/drive; the data paths below point into this mount.
# The google.colab module is only importable inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Locations of the gzipped MIMIC-III CSV exports inside the mounted drive folder
data_path = '/content/drive/My Drive/mimic_iii_data_raw/'
admission_path = f'{data_path}ADMISSIONS.csv.gz'
patients_path = f'{data_path}PATIENTS.csv.gz'
icu_stays_path = f'{data_path}ICUSTAYS.csv.gz'
note_events_path = f'{data_path}NOTEEVENTS.csv.gz'
chart_events_path = f'{data_path}CHARTEVENTS.csv.gz'

In [None]:
# Patients table: parse the date of birth, then keep only the columns used downstream
patients_df = pd.read_csv(patients_path, compression='gzip').assign(
    DOB=lambda frame: pd.to_datetime(frame['DOB'])
)
patients_df = patients_df[['SUBJECT_ID', 'GENDER', 'DOB']]
patients_df.head()

In [None]:
# Admissions table with its three timestamp columns converted to datetimes
admissions_df = pd.read_csv(admission_path, compression='gzip').assign(
    ADMITTIME=lambda frame: pd.to_datetime(frame['ADMITTIME']),
    DISCHTIME=lambda frame: pd.to_datetime(frame['DISCHTIME']),
    DEATHTIME=lambda frame: pd.to_datetime(frame['DEATHTIME']),
)
admissions_df.head()

In [None]:
# Merge admissions and patients data
df = pd.merge(admissions_df, patients_df, on='SUBJECT_ID', how='inner')

# The source tables are no longer needed once merged; release their memory.
del admissions_df
del patients_df
gc.collect()

# Age in completed years at admission: the year difference, minus one when the
# birthday has not yet occurred in the admission year. A bare year difference
# over-counts by one for those patients and lets some 17-year-olds pass the
# adult filter below.
# NOTE(review): calendar fields are compared instead of subtracting timestamps;
# presumably this avoids overflow on heavily shifted DOB values - confirm.
birthday_pending = (
    (df['ADMITTIME'].dt.month < df['DOB'].dt.month)
    | (
        (df['ADMITTIME'].dt.month == df['DOB'].dt.month)
        & (df['ADMITTIME'].dt.day < df['DOB'].dt.day)
    )
)
df['AGE'] = df['ADMITTIME'].dt.year - df['DOB'].dt.year - birthday_pending.astype(int)
df = df[df['AGE'] >= 18]

# Keep each patient's earliest admission only.
df = df.sort_values(by=['SUBJECT_ID', 'ADMITTIME'])
df = df.drop_duplicates(subset=['SUBJECT_ID'], keep='first')
df = df[df['HAS_CHARTEVENTS_DATA'] == 1]

# Reproducible sub-sample of at most 2000 admissions; capping at len(df) keeps
# sample() from raising when fewer eligible rows are available.
df = df.sample(n=min(2000, len(df)), random_state=42)
df.head()

In [None]:
# Expose the hospital expire flag under the label name and keep only the modelling columns
df = df.assign(IN_HOSPITAL_MORTALITY=df['HOSPITAL_EXPIRE_FLAG'])[[
    'SUBJECT_ID', 'HADM_ID', 'ADMITTIME', 'DISCHTIME', 'DEATHTIME',
    'ADMISSION_TYPE', 'DIAGNOSIS', 'AGE', 'GENDER', 'IN_HOSPITAL_MORTALITY',
]]
df.head()

In [None]:
# ICU stays table with the unit entry/exit timestamps converted to datetimes
icustays_df = pd.read_csv(icu_stays_path, compression='gzip').assign(
    INTIME=lambda frame: pd.to_datetime(frame['INTIME']),
    OUTTIME=lambda frame: pd.to_datetime(frame['OUTTIME']),
)
icustays_df

In [None]:
# Restrict ICU stays to the cohort's admissions, then keep the earliest stay of each admission
cohort_icustays = icustays_df.merge(df[['HADM_ID']], how='inner', on='HADM_ID')
del icustays_df
gc.collect()
cohort_icustays = cohort_icustays.sort_values(by=['HADM_ID', 'INTIME'])

# After sorting by INTIME, the first row per HADM_ID is the first ICU stay.
first_icustays = cohort_icustays.drop_duplicates(subset=['HADM_ID'])[
    ['HADM_ID', 'ICUSTAY_ID', 'INTIME', 'OUTTIME', 'LOS', 'FIRST_CAREUNIT']
]
first_icustays

In [None]:
# Attach the first ICU stay to each cohort admission (admissions without one are dropped)
df = df.merge(first_icustays, how='inner', on='HADM_ID')
del first_icustays
gc.collect()
df

In [None]:
# Process part of the chart events data
def load_chartevents_for_cohort(chart_events_path, cohort_icustay_ids, max_chunks=1, chunk_size=5000000):
    """Stream the gzipped CHARTEVENTS CSV and keep only rows of the cohort's ICU stays.

    Args:
        chart_events_path: Path to the gzip-compressed CHARTEVENTS CSV file.
        cohort_icustay_ids: Iterable of ICUSTAY_ID values to keep.
        max_chunks: Maximum number of chunks to read from the start of the
            file; ``None`` reads the whole file.
        chunk_size: Number of CSV rows per chunk.

    Returns:
        A list of DataFrames, one per chunk that contained at least one cohort
        row. The list is empty when nothing matched.
    """
    cohort_ids = set(cohort_icustay_ids)
    filtered_chartevents = []

    columns_to_read = ['ICUSTAY_ID', 'ITEMID', 'CHARTTIME', 'VALUE', 'VALUENUM', 'VALUEUOM']

    # nrows caps the parser itself; None lifts the cap so the full file is read.
    row_limit = None if max_chunks is None else chunk_size * max_chunks

    rows_seen = 0
    # The context manager closes the underlying gzip handle even when the loop
    # exits early or raises (deleting the reader object does not close it).
    with pd.read_csv(
        chart_events_path,
        compression='gzip',
        usecols=columns_to_read,
        chunksize=chunk_size,
        nrows=row_limit
    ) as chartevents_reader:
        for chunk_idx, chunk in enumerate(chartevents_reader, 1):
            # Track the real row count so a short final chunk is reported correctly.
            start_row = rows_seen + 1
            rows_seen += len(chunk)
            print(f"  Processing chartevents chunk {chunk_idx} (rows {start_row} to {rows_seen})...")

            relevant_records = chunk[chunk['ICUSTAY_ID'].isin(cohort_ids)]

            if not relevant_records.empty:
                filtered_chartevents.append(relevant_records)

            if max_chunks is not None and chunk_idx >= max_chunks:
                print(f"  Stopped reading CHARTEVENTS after {max_chunks} chunk(s).")
                break

    gc.collect()

    return filtered_chartevents

In [None]:
# ICU stays of the cohort; with the loader's defaults only the first chunk of CHARTEVENTS is scanned
cohort_icustay_ids = pd.unique(df['ICUSTAY_ID'])
raw_chartevents_for_cohort = load_chartevents_for_cohort(
    chart_events_path=chart_events_path,
    cohort_icustay_ids=cohort_icustay_ids,
)

In [None]:
# Stack the per-chunk matches into one frame and store the stay id as an integer
all_raw_events_df = pd.concat(
    raw_chartevents_for_cohort, ignore_index=True
).astype({"ICUSTAY_ID": int})
all_raw_events_df

In [None]:
# Join chart events onto the cohort: the result has one row per (ICU stay, chart event)
df = df.merge(all_raw_events_df, how='inner', on='ICUSTAY_ID')
del all_raw_events_df, raw_chartevents_for_cohort
gc.collect()
df

In [None]:
def load_noteevents_for_cohort(note_events_path, cohort_hadm_ids, max_chunks=1, chunk_size=500000):
    """Stream the gzipped NOTEEVENTS CSV and keep only notes of the cohort's admissions.

    Args:
        note_events_path: Path to the gzip-compressed NOTEEVENTS CSV file.
        cohort_hadm_ids: Iterable of HADM_ID values to keep.
        max_chunks: Maximum number of chunks to read from the start of the
            file; ``None`` reads the whole file.
        chunk_size: Number of CSV rows per chunk.

    Returns:
        A list of DataFrames, one per chunk that contained at least one cohort
        note. The list is empty when nothing matched.
    """
    cohort_ids = set(cohort_hadm_ids)
    filtered_noteevents = []

    columns_to_read = [
        'ROW_ID', 'SUBJECT_ID', 'HADM_ID', 'CHARTDATE', 'CHARTTIME',
        'STORETIME', 'CATEGORY', 'DESCRIPTION', 'CGID', 'ISERROR', 'TEXT'
    ]

    # nrows caps the parser itself; None lifts the cap so the full file is read.
    row_limit = None if max_chunks is None else chunk_size * max_chunks

    rows_seen = 0
    # The context manager closes the underlying gzip handle even when the loop
    # exits early or raises (deleting the reader object does not close it).
    with pd.read_csv(
        note_events_path,
        compression='gzip',
        usecols=columns_to_read,
        chunksize=chunk_size,
        nrows=row_limit
    ) as noteevents_reader:
        for chunk_idx, chunk in enumerate(noteevents_reader, 1):
            # Track the real row count so a short final chunk is reported correctly.
            start_row = rows_seen + 1
            rows_seen += len(chunk)
            print(f"  Processing noteevents chunk {chunk_idx} (rows {start_row} to {rows_seen})...")

            relevant_notes = chunk[chunk['HADM_ID'].isin(cohort_ids)]

            if not relevant_notes.empty:
                filtered_noteevents.append(relevant_notes)

            if max_chunks is not None and chunk_idx >= max_chunks:
                print(f"  Stopped reading NOTEEVENTS after {max_chunks} chunk(s).")
                break

    gc.collect()

    return filtered_noteevents

In [None]:
# Admissions of the cohort; with the loader's defaults only the first chunk of NOTEEVENTS is scanned
cohort_hadm_ids = pd.unique(df['HADM_ID'])
raw_noteevents_for_cohort = load_noteevents_for_cohort(
    note_events_path=note_events_path,
    cohort_hadm_ids=cohort_hadm_ids,
)

In [None]:
# Stack the per-chunk notes, store the admission id as an integer and keep the note fields of interest
all_raw_notes_df = (
    pd.concat(raw_noteevents_for_cohort, ignore_index=True)
    .astype({"HADM_ID": int})
    .loc[:, ["HADM_ID", "CHARTDATE", "CATEGORY", "DESCRIPTION", "TEXT"]]
)

In [None]:
# Merge note events data: all non-null notes of an admission are joined into one
# space-separated string, which is then attached to every row of that admission
notes_combined = (
    all_raw_notes_df
    .groupby('HADM_ID')
    .agg({'TEXT': lambda notes: ' '.join(map(str, notes.dropna()))})
    .reset_index()
)
df_merged = df.merge(notes_combined, how='inner', on='HADM_ID')
df_merged

In [None]:
# Age histogram over unique admissions (df holds one row per chart event, so de-duplicate first)
unique_admissions = df.drop_duplicates(subset=['HADM_ID'])
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(data=unique_admissions, x='AGE', kde=True, bins=20, ax=ax)
ax.set(
    title='Age Distribution of Unique Admissions',
    xlabel='Age (years)',
    ylabel='Number of Admissions',
)
fig.tight_layout()
plt.show()

In [None]:
# Gender counts over unique admissions
fig, ax = plt.subplots(figsize=(10, 6))
sns.countplot(x='GENDER', data=unique_admissions, palette="Set2", hue='GENDER', ax=ax)
ax.set(
    title='Gender Distribution of Unique Admissions',
    xlabel='Gender',
    ylabel='Number of Admissions',
)
plt.show()

In [None]:
# Outcome counts over unique admissions; tick labels carry each class's share in percent
mortality_counts = unique_admissions['IN_HOSPITAL_MORTALITY'].value_counts(normalize=True) * 100
fig, ax = plt.subplots(figsize=(10, 6))
sns.countplot(
    x='IN_HOSPITAL_MORTALITY',
    data=unique_admissions,
    palette='Set2',
    hue='IN_HOSPITAL_MORTALITY',
    legend=False,
    ax=ax,
)
ax.set(
    title='In-Hospital Mortality Rate (Unique Admissions)',
    xlabel='In-Hospital Mortality',
    ylabel='Number of Admissions',
)
labels = [f'Survived ({mortality_counts.get(0,0):.1f}%)',
          f'Expired ({mortality_counts.get(1,0):.1f}%)']
ax.set_xticks([0, 1])
ax.set_xticklabels(labels)
plt.show()

In [None]:
# Admission types as horizontal bars, most frequent first
fig, ax = plt.subplots(figsize=(10, 6))
sns.countplot(
    y='ADMISSION_TYPE',
    data=unique_admissions,
    order=unique_admissions['ADMISSION_TYPE'].value_counts().index,
    palette="Set2",
    hue="ADMISSION_TYPE",
    ax=ax,
)
ax.set(
    title='Admission Type Distribution (Unique Admissions)',
    xlabel='Number of Admissions',
    ylabel='Admission Type',
)
fig.tight_layout()
plt.show()

In [None]:
# Summary statistics of the numeric admission-level columns, one row per column
numeric_cols_admissions = ['AGE', 'LOS']
print(unique_admissions.loc[:, numeric_cols_admissions].describe().transpose())

In [None]:
# Admission type versus outcome, with row and column totals
print(pd.crosstab(
    index=unique_admissions['ADMISSION_TYPE'],
    columns=unique_admissions['IN_HOSPITAL_MORTALITY'],
    margins=True,
    margins_name="Total",
))

In [None]:
# The exploratory plots are done; drop the de-duplicated frame and reclaim its memory.
del unique_admissions
gc.collect()

In [None]:
# Collapse to one row per ICU stay: mean of all numeric chart values, the stay's
# note text and its outcome label; stays missing any of the three are discarded
df_final = (
    df_merged
    .groupby('ICUSTAY_ID')
    .agg({
        'VALUENUM': 'mean',
        'TEXT': 'first',
        'IN_HOSPITAL_MORTALITY': 'first',
    })
    .reset_index()
    .dropna(subset=['VALUENUM', 'TEXT', 'IN_HOSPITAL_MORTALITY'])
)
df_final

In [None]:
# Hold out 20% of the ICU stays for testing (fixed seed for reproducibility).
X_train, X_test = train_test_split(df_final, test_size=0.2, random_state=42)

# Work on explicit copies: the frames returned by the split are derived from
# df_final, and assigning a column into them can raise SettingWithCopyWarning
# or write through to a view, depending on the pandas/scikit-learn versions.
X_train = X_train.copy()
X_test = X_test.copy()

# Fit the scaler on the training rows only, so no test statistics leak into it.
# ravel() turns the (n, 1) array returned by the scaler into a plain column.
scaler = StandardScaler()
X_train['VALUENUM'] = scaler.fit_transform(X_train[['VALUENUM']]).ravel()
X_test['VALUENUM'] = scaler.transform(X_test[['VALUENUM']]).ravel()

In [None]:
class MultimodalDataset(Dataset):
    """Dataset yielding one (measurement, tokenized note, label) sample per dataframe row.

    Each row must provide the columns ``VALUENUM``, ``TEXT`` and
    ``IN_HOSPITAL_MORTALITY``.
    """

    def __init__(self, dataframe, tokenizer, max_length=128):
        """Store the frame and tokenizer; ``max_length`` is the padded token length."""
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Characters kept from each note before tokenization (10 per token slot).
        self.text_char_limit = 10 * max_length

    def __len__(self):
        """Number of rows in the wrapped dataframe."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return the sample at positional index ``idx`` as a dict of tensors."""
        record = self.dataframe.iloc[idx]

        sample = {
            'measurements': torch.tensor([record['VALUENUM']], dtype=torch.float32),
        }

        clipped_note = str(record['TEXT'])[:self.text_char_limit]
        tokens = self.tokenizer(
            clipped_note,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # The tokenizer output carries a leading batch dimension; squeeze(0) removes it.
        for field in ('input_ids', 'attention_mask'):
            sample[field] = tokens[field].squeeze(0)

        sample['label'] = torch.tensor(record['IN_HOSPITAL_MORTALITY'], dtype=torch.float32)
        return sample


In [None]:
class MeasurementEncoder(nn.Module):
    """Two-layer MLP mapping a measurement vector to a ``hidden_dim`` embedding."""

    def __init__(self, input_dim=1, hidden_dim=32):
        super(MeasurementEncoder, self).__init__()
        # Linear -> ReLU -> Linear; the output layer has no activation.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Embed ``x`` by passing it through the MLP."""
        embedding = self.model(x)
        return embedding

In [None]:
class TextEncoder(nn.Module):
    """Encode tokenized text with BERT and project the [CLS] vector to ``hidden_dim``.

    Args:
        hidden_dim: Size of the returned embedding.
        freeze_bert: When True, BERT's weights are not trained; only the linear
            projection learns.
    """

    def __init__(self, hidden_dim=32, freeze_bert=True):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Read the width from the loaded config (768 for bert-base-uncased)
        # instead of hard-coding it.
        self.projection = nn.Linear(self.bert.config.hidden_size, hidden_dim)
        self.freeze_bert = freeze_bert

        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
            self.bert.eval()

    def train(self, mode=True):
        """Switch train/eval mode while keeping a frozen BERT in eval mode.

        ``nn.Module.train`` recurses into every submodule, so without this
        override a parent's ``model.train()`` would re-enable BERT's dropout
        and make the frozen features noisy during training.
        """
        super().train(mode)
        if self.freeze_bert:
            self.bert.eval()
        return self

    def forward(self, input_ids, attention_mask):
        """Return a ``(batch, hidden_dim)`` embedding of the tokenized text."""
        if self.freeze_bert:
            # No gradient reaches a frozen BERT, so skip building its autograd graph.
            with torch.no_grad():
                bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        else:
            bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 of the last hidden state is the [CLS] token.
        cls_token = bert_output.last_hidden_state[:, 0, :]
        return self.projection(cls_token)

In [None]:
class MultimodalModel(nn.Module):
    """Late-fusion classifier: concatenates the measurement and text embeddings
    and maps them to a single probability through a linear layer and a sigmoid."""

    def __init__(self, measurement_dim=1, hidden_dim=32):
        super().__init__()

        self.measurement_encoder = MeasurementEncoder(measurement_dim, hidden_dim)
        self.text_encoder = TextEncoder(hidden_dim)

        # Both encoders emit hidden_dim features, so the fused vector is twice as wide.
        fused_dim = 2 * hidden_dim
        self.classifier = nn.Sequential(nn.Linear(fused_dim, 1), nn.Sigmoid())

    def forward(self, measurements, input_ids, attention_mask):
        """Return a ``(batch, 1)`` tensor of probabilities."""
        modality_embeddings = [
            self.measurement_encoder(measurements),
            self.text_encoder(input_ids, attention_mask),
        ]
        return self.classifier(torch.cat(modality_embeddings, dim=1))

In [None]:
# Tokenizer shared by the train and test datasets
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Notes are padded/truncated to 64 tokens
train_dataset = MultimodalDataset(X_train, tokenizer, max_length=64)
test_dataset = MultimodalDataset(X_test, tokenizer, max_length=64)

# Only the training loader shuffles; test order stays fixed
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

model = MultimodalModel()

# The model already ends in a sigmoid, hence plain BCELoss on probabilities
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

In [None]:
def train_model(model, train_loader, criterion, optimizer, num_epochs=1):
    """Train ``model`` for ``num_epochs`` passes over ``train_loader``.

    Each batch is a dict with 'measurements', 'input_ids', 'attention_mask'
    and 'label'. The mean batch loss is printed after every epoch.

    Returns:
        The same ``model`` object, updated in place.
    """
    model.train()

    for epoch_number in range(1, num_epochs + 1):
        running_total = 0.0

        for batch in train_loader:
            measurements, input_ids, attention_mask = (
                batch[field] for field in ('measurements', 'input_ids', 'attention_mask')
            )
            # Labels arrive as (batch,); the model emits (batch, 1).
            targets = batch['label'].unsqueeze(1)

            optimizer.zero_grad()
            batch_loss = criterion(model(measurements, input_ids, attention_mask), targets)
            batch_loss.backward()
            optimizer.step()

            running_total += batch_loss.item()

        avg_loss = running_total / len(train_loader)
        print(f'Epoch {epoch_number}, Loss: {avg_loss:.4f}')

    return model


In [None]:
# Train the baseline for 10 epochs, then persist its weights for the evaluation cells
model = train_model(model, train_loader, criterion, optimizer, num_epochs=10)
torch.save(model.state_dict(), 'multimodal_model.pt')

In [None]:
def evaluate_saved_model(model_path, test_dataset, batch_size=2):
    """Load saved MultimodalModel weights and score them on ``test_dataset``.

    Returns:
        Dict with 'auc_roc', 'auc_pr' and the flat 'predictions' and 'labels'
        arrays the metrics were computed from.
    """
    model = MultimodalModel()
    saved_state = torch.load(model_path)
    model.load_state_dict(saved_state)
    model.eval()

    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    scores, targets = [], []

    with torch.no_grad():
        for batch in test_loader:
            measurements, input_ids, attention_mask = (
                batch[field] for field in ('measurements', 'input_ids', 'attention_mask')
            )
            batch_targets = batch['label'].unsqueeze(1)

            batch_scores = model(measurements, input_ids, attention_mask)

            scores.extend(batch_scores.cpu().numpy())
            targets.extend(batch_targets.cpu().numpy())

    # Each collected element has shape (1,); flatten to 1-D arrays for the metrics.
    scores = np.array(scores).flatten()
    targets = np.array(targets).flatten()

    return {
        'auc_roc': roc_auc_score(targets, scores),
        'auc_pr': average_precision_score(targets, scores),
        'predictions': scores,
        'labels': targets,
    }

In [None]:
def zero_shot_evaluation_saved(model_path, test_dataset, tokenizer):
    """Zero-shot scoring: compare measurement embeddings with two prompt embeddings.

    The saved model's text encoder embeds a "positive" and a "negative" phrase.
    Each test sample's measurement embedding is compared to both by cosine
    similarity, and a softmax over (negative, positive) gives the predicted
    probability of the positive class. The notes in the dataset are not used.

    Returns:
        Dict with 'auc_roc', 'auc_pr', 'predictions' and 'labels'.
    """
    model = MultimodalModel()
    model.load_state_dict(torch.load(model_path))
    model.eval()

    def tokenize_prompt(phrase):
        # Same padding/truncation settings for both prompts.
        return tokenizer(
            phrase,
            max_length=64,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

    positive_tokens = tokenize_prompt("patient deceased")
    negative_tokens = tokenize_prompt("discharged today")

    test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)
    cosine = torch.nn.functional.cosine_similarity

    scores, targets = [], []

    with torch.no_grad():
        # Prompt embeddings do not depend on the batch, so compute them once.
        pos_emb = model.text_encoder(positive_tokens['input_ids'],
                                     positive_tokens['attention_mask'])
        neg_emb = model.text_encoder(negative_tokens['input_ids'],
                                     negative_tokens['attention_mask'])

        for batch in test_loader:
            measurements = batch['measurements']
            batch_targets = batch['label']

            meas_emb = model.measurement_encoder(measurements)
            n_rows = meas_emb.size(0)

            pos_sim = cosine(meas_emb, pos_emb.expand(n_rows, -1), dim=1)
            neg_sim = cosine(meas_emb, neg_emb.expand(n_rows, -1), dim=1)

            # Column 0 = negative prompt, column 1 = positive prompt.
            probs = torch.nn.functional.softmax(
                torch.stack([neg_sim, pos_sim], dim=1), dim=1)

            scores.extend(probs[:, 1].cpu().numpy())
            targets.extend(batch_targets.cpu().numpy())

    scores = np.array(scores)
    targets = np.array(targets)

    return {
        'auc_roc': roc_auc_score(targets, scores),
        'auc_pr': average_precision_score(targets, scores),
        'predictions': scores,
        'labels': targets,
    }


In [None]:
def plot_curves(results):
    """Draw the ROC and precision-recall curves side by side.

    ``results`` must provide 'predictions', 'labels', 'auc_roc' and 'auc_pr';
    the two AUC values are only used for the legend text.
    """
    from sklearn.metrics import roc_curve, precision_recall_curve
    import matplotlib.pyplot as plt

    scores, truth = results['predictions'], results['labels']

    fpr, tpr, _ = roc_curve(truth, scores)
    precision, recall, _ = precision_recall_curve(truth, scores)

    fig, (roc_ax, pr_ax) = plt.subplots(1, 2, figsize=(12, 5))

    roc_ax.plot(fpr, tpr, label=f'AUC-ROC = {results["auc_roc"]:.4f}')
    roc_ax.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
    pr_ax.plot(recall, precision, label=f'AUC-PR = {results["auc_pr"]:.4f}')

    # Shared cosmetics per panel: (axis, x label, y label, title, legend position)
    panels = (
        (roc_ax, 'False Positive Rate', 'True Positive Rate', 'ROC Curve', 'lower right'),
        (pr_ax, 'Recall', 'Precision', 'Precision-Recall Curve', 'lower left'),
    )
    for axis, x_label, y_label, title, legend_loc in panels:
        axis.set_xlabel(x_label)
        axis.set_ylabel(y_label)
        axis.set_title(title)
        axis.legend(loc=legend_loc)
        axis.grid(True, linestyle='--', alpha=0.7)

    plt.tight_layout()
    plt.show()

In [None]:
def compare_with_paper(our_results):
    """Print our metrics next to the paper's and report the relative differences.

    Args:
        our_results: Dict with 'supervised' and 'zero_shot' entries, each
            providing 'auc_roc' and 'auc_pr'.

    Returns:
        Dict keyed like ``our_results`` whose entries hold the relative
        difference to the paper in percent, ``(ours - paper) / paper * 100``,
        under 'auc_roc' and 'auc_pr'.
    """
    paper_results = {
        'supervised': {'auc_roc': 0.856, 'auc_roc_std': 0.004, 'auc_pr': 0.495, 'auc_pr_std': 0.005},
        'zero_shot': {'auc_roc': 0.709, 'auc_pr': 0.214}
    }
    display_names = {'supervised': "Supervised", 'zero_shot': "Zero-shot"}

    print("| Evaluation | AUC-ROC | Paper AUC-ROC | AUC-PR | Paper AUC-PR |")
    print("|------------|---------|---------------|--------|-------------|")

    relative_diffs = {}
    for eval_type in ['supervised', 'zero_shot']:
        display_name = display_names[eval_type]

        our_roc = our_results[eval_type]['auc_roc']
        our_pr = our_results[eval_type]['auc_pr']
        paper_roc = paper_results[eval_type]['auc_roc']
        paper_pr = paper_results[eval_type]['auc_pr']

        print(f"| {display_name:<11} | {our_roc:.4f} | {paper_roc:.4f} | {our_pr:.4f} | {paper_pr:.4f} |")

        relative_diffs[eval_type] = {
            'auc_roc': (our_roc - paper_roc) / paper_roc * 100,
            'auc_pr': (our_pr - paper_pr) / paper_pr * 100,
        }

    # These differences used to be computed and then discarded; report and return them.
    print()
    for eval_type, diffs in relative_diffs.items():
        print(f"{display_names[eval_type]} relative difference vs. paper: "
              f"AUC-ROC {diffs['auc_roc']:+.1f}%, AUC-PR {diffs['auc_pr']:+.1f}%")

    return relative_diffs


In [None]:
def evaluate_model_performance():
    """Evaluate the saved baseline in supervised and zero-shot mode.

    Builds a test dataset from the global ``X_test``, plots the supervised
    curves, prints the comparison table and returns both metric dicts under
    the keys 'supervised' and 'zero_shot'.
    """
    model_path = 'multimodal_model.pt'
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    test_dataset = MultimodalDataset(X_test, tokenizer, max_length=64)

    results = {}

    results['supervised'] = evaluate_saved_model(model_path, test_dataset)
    # Curves are drawn for the supervised run only.
    plot_curves(results['supervised'])

    results['zero_shot'] = zero_shot_evaluation_saved(model_path, test_dataset, tokenizer)

    compare_with_paper(results)
    return results

In [None]:
# Evaluate the saved baseline (supervised and zero-shot) and keep both metric dicts.
results = evaluate_model_performance()

In [None]:
class MultimodalModelWithDropout(nn.Module):
    """Fusion classifier trained with modality dropout.

    During training, a whole modality (measurements or text) may be replaced
    for the entire batch by a learned "missing" token, so the classifier also
    learns to predict from a single modality.

    Args:
        measurement_dim: Number of measurement features per sample.
        hidden_dim: Embedding size of each modality.
        dropout_prob: Probability of dropping each modality on a training batch.
    """

    def __init__(self, measurement_dim=1, hidden_dim=32, dropout_prob=0.5):
        super().__init__()
        self.measurement_encoder = MeasurementEncoder(measurement_dim, hidden_dim)
        self.text_encoder = TextEncoder(hidden_dim)
        self.classifier = nn.Linear(hidden_dim * 2, 1)
        self.sigmoid = nn.Sigmoid()

        self.dropout_prob = dropout_prob

        # Learned stand-ins used in place of a dropped modality's embedding.
        self.missing_measurement_token = nn.Parameter(torch.randn(hidden_dim))
        self.missing_text_token = nn.Parameter(torch.randn(hidden_dim))

    def forward(self, measurements, input_ids, attention_mask, training=True):
        """Return ``(batch, 1)`` probabilities.

        Modalities are dropped only when the module is in training mode
        (``self.training``) AND ``training`` is True. ``training=False``
        remains an explicit opt-out, and a model switched to ``eval()`` never
        drops a modality even if the caller forgets the flag.
        """
        drop_measurement = False
        drop_text = False

        if training and self.training:
            drop_measurement = torch.rand(1).item() < self.dropout_prob
            drop_text = torch.rand(1).item() < self.dropout_prob

            # Never drop both: keep exactly one modality, chosen by a fair coin.
            if drop_measurement and drop_text:
                drop_measurement = torch.rand(1).item() < 0.5
                drop_text = not drop_measurement

        batch_size = measurements.size(0)

        measurement_emb = (
            self.missing_measurement_token.expand(batch_size, -1)
            if drop_measurement
            else self.measurement_encoder(measurements)
        )

        text_emb = (
            self.missing_text_token.expand(batch_size, -1)
            if drop_text
            else self.text_encoder(input_ids, attention_mask)
        )

        combined = torch.cat([measurement_emb, text_emb], dim=1)
        return self.sigmoid(self.classifier(combined))

In [None]:
def train_model_with_dropout(model, train_loader, criterion, optimizer, num_epochs=10):
    """Train ``model`` and record the mean batch loss of every epoch.

    The model is called without an explicit ``training`` argument, so its own
    default applies on each batch.

    Returns:
        ``(model, history)`` where ``history['loss']`` lists one mean loss per epoch.
    """
    model.train()
    epoch_losses = []
    history = {'loss': epoch_losses}

    for epoch_number in range(1, num_epochs + 1):
        running_total = 0.0
        for batch in train_loader:
            measurements, input_ids, attention_mask = (
                batch[field] for field in ('measurements', 'input_ids', 'attention_mask')
            )
            # Labels arrive as (batch,); the model emits (batch, 1).
            targets = batch['label'].unsqueeze(1)

            optimizer.zero_grad()

            batch_loss = criterion(model(measurements, input_ids, attention_mask), targets)

            batch_loss.backward()
            optimizer.step()

            running_total += batch_loss.item()

        epoch_loss = running_total / len(train_loader)
        epoch_losses.append(epoch_loss)
        print(f'Epoch {epoch_number}/{num_epochs}, Loss: {epoch_loss:.4f}')

    return model, history


In [None]:
def _score_modality_setting(test_loader, predict_batch):
    """Run ``predict_batch`` over every batch and return ``(auc_roc, auc_pr)``."""
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch in test_loader:
            outputs = predict_batch(batch)
            labels = batch['label'].unsqueeze(1)

            all_preds.extend(outputs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    all_preds = np.array(all_preds).flatten()
    all_labels = np.array(all_labels).flatten()

    return roc_auc_score(all_labels, all_preds), average_precision_score(all_labels, all_preds)


def evaluate_modality_settings(model, test_dataset):
    """Score a modality-dropout model with both modalities and with each one alone.

    In the single-modality settings, the absent modality's embedding is
    replaced by the model's learned "missing" token.

    Args:
        model: Model exposing ``measurement_encoder``, ``text_encoder``,
            ``classifier``, ``sigmoid``, ``missing_measurement_token`` and
            ``missing_text_token``, and accepting ``training=False`` when called.
        test_dataset: Dataset whose batches contain 'measurements',
            'input_ids', 'attention_mask' and 'label'.

    Returns:
        Dict with keys 'both', 'measurements_only' and 'text_only', each
        mapping to ``{'auc_roc': ..., 'auc_pr': ...}``.
    """
    test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)
    model.eval()

    def fuse(measurement_emb, text_emb):
        # Shared classification head applied to the concatenated embeddings.
        combined = torch.cat([measurement_emb, text_emb], dim=1)
        return model.sigmoid(model.classifier(combined))

    def predict_both(batch):
        return model(batch['measurements'], batch['input_ids'],
                     batch['attention_mask'], training=False)

    def predict_measurements_only(batch):
        batch_size = batch['input_ids'].size(0)
        text_emb = model.missing_text_token.expand(batch_size, -1)
        return fuse(model.measurement_encoder(batch['measurements']), text_emb)

    def predict_text_only(batch):
        batch_size = batch['measurements'].size(0)
        measurement_emb = model.missing_measurement_token.expand(batch_size, -1)
        return fuse(measurement_emb,
                    model.text_encoder(batch['input_ids'], batch['attention_mask']))

    # (result key, label in the summary line, banner printed beforehand, predictor)
    settings = [
        ('both', "Both modalities", None, predict_both),
        ('measurements_only', "Measurements only",
         "\nEvaluating with only measurements...", predict_measurements_only),
        ('text_only', "Text only",
         "\nEvaluating with only text...", predict_text_only),
    ]

    results = {}
    for key, label, banner, predict_batch in settings:
        if banner is not None:
            print(banner)

        roc, pr = _score_modality_setting(test_loader, predict_batch)

        results[key] = {'auc_roc': roc, 'auc_pr': pr}
        print(f"{label} - AUC-ROC: {roc:.4f}, AUC-PR: {pr:.4f}")

    return results

In [None]:
def plot_modality_comparison(results):
    """Grouped bar chart of AUC-ROC and AUC-PR for the three modality settings.

    ``results`` must contain the keys 'both', 'measurements_only' and
    'text_only', each providing 'auc_roc' and 'auc_pr'.
    """
    settings_labels = {
        'both': 'Both Modalities',
        'measurements_only': 'Measurements Only',
        'text_only': 'Text Only',
    }

    roc_values = [results[setting]['auc_roc'] for setting in settings_labels]
    pr_values = [results[setting]['auc_pr'] for setting in settings_labels]

    positions = np.arange(len(settings_labels))
    width = 0.35

    fig, ax = plt.subplots(figsize=(10, 6))
    # The two metrics sit side by side around each tick position.
    bar_groups = (
        ax.bar(positions - width/2, roc_values, width, label='AUC-ROC'),
        ax.bar(positions + width/2, pr_values, width, label='AUC-PR'),
    )

    ax.set_ylabel('Score')
    ax.set_title('Performance by Modality Setting')
    ax.set_xticks(positions)
    ax.set_xticklabels(list(settings_labels.values()))
    ax.legend()

    # Print each bar's value just above it.
    for bar in (bar for group in bar_groups for bar in group):
        bar_height = bar.get_height()
        ax.annotate(f'{bar_height:.3f}',
                    xy=(bar.get_x() + bar.get_width()/2, bar_height),
                    xytext=(0, 3),
                    textcoords="offset points",
                    ha='center', va='bottom')

    fig.tight_layout()
    plt.show()

In [None]:
def run_modality_dropout_extension():
    """Train and evaluate the modality-dropout variant, warm-started from the baseline.

    The two encoders are initialised from 'multimodal_model.pt'; the classifier
    head and the missing-modality tokens start from their fresh initialisation.
    The trained weights are saved to 'multimodal_model_with_dropout.pt'.

    Returns:
        The per-setting metrics produced by ``evaluate_modality_settings``.
    """
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    train_dataset = MultimodalDataset(X_train, tokenizer, max_length=64)
    test_dataset = MultimodalDataset(X_test, tokenizer, max_length=64)
    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

    model = MultimodalModelWithDropout(dropout_prob=0.3)

    # Reload the trained baseline and copy its encoder weights over.
    original_model = MultimodalModel()
    original_model.load_state_dict(torch.load('multimodal_model.pt'))
    model.measurement_encoder.load_state_dict(original_model.measurement_encoder.state_dict())
    model.text_encoder.load_state_dict(original_model.text_encoder.state_dict())

    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

    model, history = train_model_with_dropout(
        model, train_loader, criterion, optimizer, num_epochs=5
    )

    torch.save(model.state_dict(), 'multimodal_model_with_dropout.pt')

    modality_results = evaluate_modality_settings(model, test_dataset)
    plot_modality_comparison(modality_results)

    return modality_results


In [None]:
# Run the modality-dropout experiment end to end and keep its per-setting metrics.
modality_dropout_results = run_modality_dropout_extension()