Import Libraries

In [8]:
import pandas as pd
import numpy as np
import pickle

DDI_metric

In [9]:
from cornac.metrics import RankingMetric

class DDIRate(RankingMetric):
    def __init__(self, ddi_matrix, k=10, name="DDI@10"):
        """
        Parameters:
        - ddi_pairs: set of (drug_id_1, drug_id_2) tuples indicating known DDIs.
        - k: number of top predicted items to consider per user.
        """
        super().__init__(name=name, k=k)
        self.ddi_matrix = ddi_matrix

    def compute(self, gt_pos, gt_neg, pd_rank, pd_scores, item_indices=None):
        top_k_items = pd_rank[:self.k]
        ddi_count = 0
        total_pairs = 0

        for i in range(len(top_k_items)):
            for j in range(i + 1, len(top_k_items)):
                d1, d2 = top_k_items[i], top_k_items[j]    
                if frozenset({d1, d2}) in self.ddi_matrix or frozenset({d2, d1}) in self.ddi_matrix:
                    ddi_count += 1
                total_pairs += 1

        ddi_rate = ddi_count / total_pairs if total_pairs > 0 else 0.0
        return ddi_rate



Toxicity Metric

In [10]:
import numpy as np
from itertools import combinations

class ToxicityDDIRate(RankingMetric):
    def __init__(self, toxicity_matrix, k=10, name="ToxicityDDI@10"):
        """
        Parameters:
        - toxicity_matrix: 2D NumPy array or sparse matrix where toxicity_matrix[i, j] 
                           gives the toxicity score of the DDI between drugs i and j.
                           (0 if no interaction, >0 if interaction exists)
        - k: number of top predicted items to consider per user.
        """
        super().__init__(name=name, k=k)
        self.toxicity_matrix = toxicity_matrix

    def compute(self, gt_pos, gt_neg, pd_rank, pd_scores, item_indices=None):
        top_k_items = pd_rank[:self.k]
        if len(top_k_items) < 2:
            return 0.0

        # All unordered pairs among top-k
        pairs = np.array(list(combinations(top_k_items, 2)))

        # Sum toxicity of interactions among top-k items
        toxicity_sum = self.toxicity_matrix[pairs[:, 0], pairs[:, 1]].sum()
        total_pairs = len(pairs)

        return toxicity_sum / total_pairs if total_pairs > 0 else 0.0


Baseline Models BPR and WMF

In [11]:
notes = pd.read_csv(
    r'...\NOTEEVENTS.csv.gz',
    dtype={4: str, 5: str}  # or int, float, etc. depending on data
)

notes = notes[notes["CATEGORY"].isin(["Discharge summary"])]
notes = notes.dropna(subset=["TEXT", "HADM_ID"])

patient_texts = notes.groupby("HADM_ID")["TEXT"].apply(lambda x: "\n".join(x)).reset_index()

In [12]:
with open(r'...\mapped_ddi_pairs.pkl', 'rb') as f:
    mapped_ddi_pairs = pickle.load(f)

In [15]:
import re
from tqdm import tqdm
from rapidfuzz import process
import cornac
from cornac.data import Reader
from cornac.eval_methods import CrossValidation
from cornac.data import Dataset


# ------------------- Step 0: Load Files -------------------
ratings_df = pd.read_csv(r'...\user_drug_rating_visit_anemia.csv')
matched_df = pd.read_csv(r'...\drugbank_mimic_rxcui_map.csv')
patient_texts['HADM_ID'] = patient_texts['HADM_ID'].astype(float).astype(int).astype(str)


# ------------------- Step 1: Clean Drug Names -------------------
def clean_drug_name(name):
    if pd.isnull(name): return ""
    name = name.lower().strip()
    name = re.sub(r'\b\d+(\.\d+)?\s*(mg|ml|mcg|units|tablet|tab|capsule|cap|drop|syrup|patch|ointment|cream|injection|solution|suspension|oral|inj|dose|suppository)\b', '', name)
    name = re.sub(r'[^\w\s]', '', name)
    name = re.sub(r'\s+', ' ', name)
    return name.strip()

ratings_df.dropna(subset=["user", "item", "rating"], inplace=True)
ratings_df['user'] = ratings_df['user'].astype(str)
ratings_df['item'] = ratings_df['item'].astype(str)
ratings_df['rating'] = ratings_df['rating'].astype(float)
ratings_df['clean_item'] = ratings_df['item'].apply(clean_drug_name)

matched_df['Generic_Name'] = matched_df['Generic_Name'].astype(str)
matched_df['clean_generic'] = matched_df['Generic_Name'].apply(clean_drug_name)


# ------------------- Step 2: Fuzzy Match Drug Names -------------------
ratings_items = ratings_df['clean_item'].unique()
mimic_generics = matched_df['clean_generic'].unique().tolist()

lookup = {}
for item in ratings_items:
    match = process.extractOne(item, mimic_generics, score_cutoff=80)
    if match:
        lookup[item] = match[0]

ratings_df['matched_generic'] = ratings_df['clean_item'].map(lookup)
ratings_df['matched_generic'] = ratings_df['matched_generic'].fillna(ratings_df['clean_item'])

matched_count = ratings_df['matched_generic'].notnull().sum()
total_count = ratings_df['clean_item'].nunique()
print(f"✔️ Fuzzy matched drugs: {matched_count} / {total_count}")


# ------------------- Step 3: Filter Users With Embeddings -------------------
subj_id_to_emb_idx = {sid: idx for idx, sid in enumerate(patient_texts['HADM_ID'])}

# ------------------- Step 4: Prepare UIR -------------------
uir_data = list(zip(ratings_df['user'], ratings_df['matched_generic'], ratings_df['rating']))
if not uir_data:
    raise ValueError("No valid UIR data after filtering!")


# ------------------- Step 5: Build Dataset -------------------
cornac_data = Dataset.from_uir(uir_data, seed=123)
uid_map = cornac_data.uid_map
iid_map = cornac_data.iid_map


ratio_split = CrossValidation(
    n_folds=10,
    data=uir_data,
    exclude_unknowns=True,
    rating_threshold=1.0,
    verbose=True,
    seed=123,
)


# ------------------- Step 7: Filter DDI Pairs -------------------
current_drugs = {drug.lower().strip() for drug in ratings_df["matched_generic"]}
filtered_ddi_pairs = [
    (d1.lower().strip(), d2.lower().strip(), sev)
    for (d1, d2, sev) in mapped_ddi_pairs
    if d1.lower().strip() in current_drugs and d2.lower().strip() in current_drugs
]




SEED = 42
VERBOSE = False
K = 50
wmf = cornac.models.WMF(k=K, max_iter=100, a=1.0, b=0.01, learning_rate=0.001, lambda_u=0.01, lambda_v=0.01,
          verbose=VERBOSE, seed=SEED, name=f"WMF(K={K})")


# Create toxicity matrix
toxicity_map = {"minor": 1.0, "moderate": 2.0, "major": 3.0}
num_items = len(iid_map)
toxicity_matrix = np.zeros((num_items, num_items), dtype=float)
for d1, d2, sev in filtered_ddi_pairs:
    if d1 in iid_map and d2 in iid_map:
        i, j = iid_map[d1], iid_map[d2]
        toxicity_matrix[i, j] = toxicity_map[sev]
        toxicity_matrix[j, i] = toxicity_map[sev]

toxicity_ddi_metric = ToxicityDDIRate(toxicity_matrix=toxicity_matrix, k=10)

ddi_index_pairs = set()
for d1, d2, _ in filtered_ddi_pairs:
    if d1 in iid_map and d2 in iid_map:
        ddi_index_pairs.add(frozenset((iid_map[d1], iid_map[d2])))

ddi_metric = DDIRate(ddi_matrix=ddi_index_pairs, k=10)

eval_metrics = [
  cornac.metrics.Recall(k=10),
  cornac.metrics.NDCG(k=[10]),
  #cornac.metrics.MRR(),
  ddi_metric,
  toxicity_ddi_metric
]

# Put everything together into an experiment and run it
cornac.Experiment(
    eval_method=ratio_split,
    models=[wmf],
    metrics=eval_metrics,
    user_based=True,
).run()

✔️ Fuzzy matched drugs: 448385 / 1880




rating_threshold = 1.0
exclude_unknowns = True
Fold: 1




---
Training data:
Number of users = 15162
Number of items = 1122
Number of ratings = 367245
Max rating = 1.0
Min rating = 1.0
Global mean = 1.0
---
Test data:
Number of users = 15162
Number of items = 1122
Number of ratings = 44305
Number of unknown users = 0
Number of unknown items = 0




---
Validation data:
Number of users = 15162
Number of items = 1122
Number of ratings = 44305
---
Total users = 15162
Total items = 1122

[WMF(K=50)] Training started!


[WMF(K=50)] Evaluation started!


Ranking:   0%|          | 0/13756 [00:00<?, ?it/s]

Fold: 2




---
Training data:
Number of users = 15162
Number of items = 1116
Number of ratings = 367247
Max rating = 1.0
Min rating = 1.0
Global mean = 1.0




---
Test data:
Number of users = 15162
Number of items = 1116
Number of ratings = 44305
Number of unknown users = 0
Number of unknown items = 0
---
Validation data:
Number of users = 15162
Number of items = 1116
Number of ratings = 44305
---
Total users = 15162
Total items = 1116

[WMF(K=50)] Training started!

[WMF(K=50)] Evaluation started!


Ranking:   0%|          | 0/13807 [00:00<?, ?it/s]

Fold: 3




---
Training data:
Number of users = 15161
Number of items = 1129
Number of ratings = 367184
Max rating = 1.0
Min rating = 1.0
Global mean = 1.0
---
Test data:
Number of users = 15161
Number of items = 1129
Number of ratings = 44296
Number of unknown users = 0
Number of unknown items = 0




---
Validation data:
Number of users = 15161
Number of items = 1129
Number of ratings = 44296
---
Total users = 15161
Total items = 1129

[WMF(K=50)] Training started!

[WMF(K=50)] Evaluation started!


Ranking:   0%|          | 0/13753 [00:00<?, ?it/s]

Fold: 4




---
Training data:
Number of users = 15160
Number of items = 1104
Number of ratings = 367068
Max rating = 1.0
Min rating = 1.0
Global mean = 1.0
---
Test data:
Number of users = 15160
Number of items = 1104
Number of ratings = 44287
Number of unknown users = 0
Number of unknown items = 0




---
Validation data:
Number of users = 15160
Number of items = 1104
Number of ratings = 44287
---
Total users = 15160
Total items = 1104

[WMF(K=50)] Training started!

[WMF(K=50)] Evaluation started!


Ranking:   0%|          | 0/13777 [00:00<?, ?it/s]

Fold: 5




---
Training data:
Number of users = 15161
Number of items = 1127
Number of ratings = 367155
Max rating = 1.0
Min rating = 1.0
Global mean = 1.0
---
Test data:
Number of users = 15161
Number of items = 1127
Number of ratings = 44303
Number of unknown users = 0
Number of unknown items = 0




---
Validation data:
Number of users = 15161
Number of items = 1127
Number of ratings = 44303
---
Total users = 15161
Total items = 1127

[WMF(K=50)] Training started!

[WMF(K=50)] Evaluation started!


Ranking:   0%|          | 0/13708 [00:00<?, ?it/s]

Fold: 6




---
Training data:
Number of users = 15160
Number of items = 1118
Number of ratings = 367163
Max rating = 1.0
Min rating = 1.0
Global mean = 1.0
---
Test data:
Number of users = 15160
Number of items = 1118
Number of ratings = 44311
Number of unknown users = 0
Number of unknown items = 0




---
Validation data:
Number of users = 15160
Number of items = 1118
Number of ratings = 44311
---
Total users = 15160
Total items = 1118

[WMF(K=50)] Training started!

[WMF(K=50)] Evaluation started!


Ranking:   0%|          | 0/13730 [00:00<?, ?it/s]

Fold: 7




---
Training data:
Number of users = 15161
Number of items = 1117
Number of ratings = 367178
Max rating = 1.0
Min rating = 1.0
Global mean = 1.0
---
Test data:
Number of users = 15161
Number of items = 1117
Number of ratings = 44295
Number of unknown users = 0
Number of unknown items = 0




---
Validation data:
Number of users = 15161
Number of items = 1117
Number of ratings = 44295
---
Total users = 15161
Total items = 1117

[WMF(K=50)] Training started!

[WMF(K=50)] Evaluation started!


Ranking:   0%|          | 0/13804 [00:00<?, ?it/s]

Fold: 8




---
Training data:
Number of users = 15162
Number of items = 1125
Number of ratings = 367193
Max rating = 1.0
Min rating = 1.0
Global mean = 1.0
---
Test data:
Number of users = 15162
Number of items = 1125
Number of ratings = 44289
Number of unknown users = 0
Number of unknown items = 0




---
Validation data:
Number of users = 15162
Number of items = 1125
Number of ratings = 44289
---
Total users = 15162
Total items = 1125

[WMF(K=50)] Training started!

[WMF(K=50)] Evaluation started!


Ranking:   0%|          | 0/13745 [00:00<?, ?it/s]

Fold: 9




---
Training data:
Number of users = 15161
Number of items = 1111
Number of ratings = 367235
Max rating = 1.0
Min rating = 1.0
Global mean = 1.0
---
Test data:
Number of users = 15161
Number of items = 1111
Number of ratings = 44252
Number of unknown users = 0
Number of unknown items = 0




---
Validation data:
Number of users = 15161
Number of items = 1111
Number of ratings = 44252
---
Total users = 15161
Total items = 1111

[WMF(K=50)] Training started!

[WMF(K=50)] Evaluation started!


Ranking:   0%|          | 0/13736 [00:00<?, ?it/s]

Fold: 10




---
Training data:
Number of users = 15161
Number of items = 1114
Number of ratings = 367026
Max rating = 1.0
Min rating = 1.0
Global mean = 1.0
---
Test data:
Number of users = 15161
Number of items = 1114
Number of ratings = 44314
Number of unknown users = 0
Number of unknown items = 0




---
Validation data:
Number of users = 15161
Number of items = 1114
Number of ratings = 44314
---
Total users = 15161
Total items = 1114

[WMF(K=50)] Training started!

[WMF(K=50)] Evaluation started!


Ranking:   0%|          | 0/13817 [00:00<?, ?it/s]


TEST:
...
[WMF(K=50)]
       | DDI@10 | NDCG@10 | Recall@10 | SeverityDDI@10 | Train (s) | Test (s)
------ + ------ + ------- + --------- + -------------- + --------- + --------
Fold 0 | 0.3000 |  0.3076 |    0.4141 |         0.3528 |   74.8661 |  12.9587
Fold 1 | 0.3285 |  0.3126 |    0.4212 |         0.3589 |   42.2129 |  12.5640
Fold 2 | 0.2454 |  0.3041 |    0.4146 |         0.2670 |   43.5199 |  11.5935
Fold 3 | 0.3429 |  0.3069 |    0.4171 |         0.3793 |   38.2493 |  11.2342
Fold 4 | 0.2694 |  0.2929 |    0.4028 |         0.3094 |   42.1554 |  11.4217
Fold 5 | 0.2865 |  0.2999 |    0.4129 |         0.3340 |   38.8274 |  12.5310
Fold 6 | 0.2142 |  0.3154 |    0.4195 |         0.2364 |   42.4055 |  11.7966
Fold 7 | 0.2561 |  0.3078 |    0.4150 |         0.2757 |   41.2355 |  11.4685
Fold 8 | 0.2565 |  0.2964 |    0.4141 |         0.2815 |   42.4790 |  12.8358
Fold 9 | 0.3146 |  0.3124 |    0.4200 |         0.3404 |   42.6521 |  12.8398
------ + ------ + ------- + --------- + -

Baseline NeuMf

In [16]:
import re
from tqdm import tqdm
from rapidfuzz import process
import cornac
from cornac.data import Reader
from cornac.eval_methods import CrossValidation
import pandas as pd

from cornac.data import Dataset


# ------------------- Step 0: Load Files -------------------
ratings_df = pd.read_csv(r'...\user_drug_rating_visit_anemia.csv')
matched_df = pd.read_csv(r'...\drugbank_mimic_rxcui_map.csv')
patient_texts['HADM_ID'] = patient_texts['HADM_ID'].astype(float).astype(int).astype(str)


# ------------------- Step 1: Clean Drug Names -------------------
def clean_drug_name(name):
    if pd.isnull(name): return ""
    name = name.lower().strip()
    name = re.sub(r'\b\d+(\.\d+)?\s*(mg|ml|mcg|units|tablet|tab|capsule|cap|drop|syrup|patch|ointment|cream|injection|solution|suspension|oral|inj|dose|suppository)\b', '', name)
    name = re.sub(r'[^\w\s]', '', name)
    name = re.sub(r'\s+', ' ', name)
    return name.strip()

ratings_df.dropna(subset=["user", "item", "rating"], inplace=True)
ratings_df['user'] = ratings_df['user'].astype(str)
ratings_df['item'] = ratings_df['item'].astype(str)
ratings_df['rating'] = ratings_df['rating'].astype(float)
ratings_df['clean_item'] = ratings_df['item'].apply(clean_drug_name)

matched_df['Generic_Name'] = matched_df['Generic_Name'].astype(str)
matched_df['clean_generic'] = matched_df['Generic_Name'].apply(clean_drug_name)


# ------------------- Step 2: Fuzzy Match Drug Names -------------------
ratings_items = ratings_df['clean_item'].unique()
mimic_generics = matched_df['clean_generic'].unique().tolist()

lookup = {}
for item in ratings_items:
    match = process.extractOne(item, mimic_generics, score_cutoff=80)
    if match:
        lookup[item] = match[0]

ratings_df['matched_generic'] = ratings_df['clean_item'].map(lookup)
ratings_df['matched_generic'] = ratings_df['matched_generic'].fillna(ratings_df['clean_item'])

matched_count = ratings_df['matched_generic'].notnull().sum()
total_count = ratings_df['clean_item'].nunique()
print(f"✔️ Fuzzy matched drugs: {matched_count} / {total_count}")


# ------------------- Step 3: Filter Users With Embeddings -------------------
subj_id_to_emb_idx = {sid: idx for idx, sid in enumerate(patient_texts['HADM_ID'])}

# ------------------- Step 4: Prepare UIR -------------------
uir_data = list(zip(ratings_df['user'], ratings_df['matched_generic'], ratings_df['rating']))
if not uir_data:
    raise ValueError("No valid UIR data after filtering!")


# ------------------- Step 5: Build Dataset -------------------
cornac_data = Dataset.from_uir(uir_data, seed=123)
uid_map = cornac_data.uid_map
iid_map = cornac_data.iid_map


# ------------------- Step 7: Filter DDI Pairs -------------------
current_drugs = {drug.lower().strip() for drug in ratings_df["matched_generic"]}
filtered_ddi_pairs = [
    (d1.lower().strip(), d2.lower().strip(), sev)
    for (d1, d2, sev) in mapped_ddi_pairs
    if d1.lower().strip() in current_drugs and d2.lower().strip() in current_drugs
]




ratio_split = CrossValidation(
    n_folds=10,
    data=uir_data,
    exclude_unknowns=True,
    rating_threshold=1.0,
    verbose=True,
    seed=123,
)

neumf = cornac.models.NeuMF(
    num_factors=8,
    layers=[64, 32, 16, 8],
    act_fn="tanh",
    learner="adam",
    num_epochs=30,
    batch_size=256,
    lr=0.001,
    num_neg=50,
    seed=42,
)

# Create toxicity matrix
toxicity_map = {"minor": 1.0, "moderate": 2.0, "major": 3.0}
num_items = len(iid_map)
toxicity_matrix = np.zeros((num_items, num_items), dtype=float)
for d1, d2, sev in filtered_ddi_pairs:
    if d1 in iid_map and d2 in iid_map:
        i, j = iid_map[d1], iid_map[d2]
        toxicity_matrix[i, j] = toxicity_map[sev]
        toxicity_matrix[j, i] = toxicity_map[sev]

toxicity_ddi_metric = ToxicityDDIRate(toxicity_matrix=toxicity_matrix, k=10)

ddi_index_pairs = set()
for d1, d2, _ in filtered_ddi_pairs:
    if d1 in iid_map and d2 in iid_map:
        ddi_index_pairs.add(frozenset((iid_map[d1], iid_map[d2])))

ddi_metric = DDIRate(ddi_matrix=ddi_index_pairs, k=10)



eval_metrics = [
  cornac.metrics.Recall(k=10),
  cornac.metrics.NDCG(k=[10,]),
  ddi_metric,
  toxicity_ddi_metric
]

# Put everything together into an experiment and run it
cornac.Experiment(
    eval_method=ratio_split,
    models=[neumf],
    metrics=eval_metrics,
    user_based=True,
).run()








✔️ Fuzzy matched drugs: 448385 / 1880




rating_threshold = 1.0
exclude_unknowns = True
Fold: 1




---
Training data:
Number of users = 15162
Number of items = 1122
Number of ratings = 367245
Max rating = 1.0
Min rating = 1.0
Global mean = 1.0
---
Test data:
Number of users = 15162
Number of items = 1122
Number of ratings = 44305
Number of unknown users = 0
Number of unknown items = 0




---
Validation data:
Number of users = 15162
Number of items = 1122
Number of ratings = 44305
---
Total users = 15162
Total items = 1122

[NeuMF] Training started!


  0%|          | 0/30 [00:00<?, ?it/s]


[NeuMF] Evaluation started!


Ranking:   0%|          | 0/13756 [00:00<?, ?it/s]

Fold: 2




---
Training data:
Number of users = 15162
Number of items = 1116
Number of ratings = 367247
Max rating = 1.0
Min rating = 1.0
Global mean = 1.0
---
Test data:
Number of users = 15162
Number of items = 1116
Number of ratings = 44305
Number of unknown users = 0
Number of unknown items = 0




---
Validation data:
Number of users = 15162
Number of items = 1116
Number of ratings = 44305
---
Total users = 15162
Total items = 1116

[NeuMF] Training started!


  0%|          | 0/30 [00:00<?, ?it/s]


[NeuMF] Evaluation started!


Ranking:   0%|          | 0/13807 [00:00<?, ?it/s]

Fold: 3




---
Training data:
Number of users = 15161
Number of items = 1129
Number of ratings = 367184
Max rating = 1.0
Min rating = 1.0
Global mean = 1.0
---
Test data:
Number of users = 15161
Number of items = 1129
Number of ratings = 44296
Number of unknown users = 0
Number of unknown items = 0




---
Validation data:
Number of users = 15161
Number of items = 1129
Number of ratings = 44296
---
Total users = 15161
Total items = 1129

[NeuMF] Training started!


  0%|          | 0/30 [00:00<?, ?it/s]


[NeuMF] Evaluation started!


Ranking:   0%|          | 0/13753 [00:00<?, ?it/s]

Fold: 4




---
Training data:
Number of users = 15160
Number of items = 1104
Number of ratings = 367068
Max rating = 1.0
Min rating = 1.0
Global mean = 1.0
---
Test data:
Number of users = 15160
Number of items = 1104
Number of ratings = 44287
Number of unknown users = 0
Number of unknown items = 0




---
Validation data:
Number of users = 15160
Number of items = 1104
Number of ratings = 44287
---
Total users = 15160
Total items = 1104

[NeuMF] Training started!


  0%|          | 0/30 [00:00<?, ?it/s]


[NeuMF] Evaluation started!


Ranking:   0%|          | 0/13777 [00:00<?, ?it/s]

Fold: 5




---
Training data:
Number of users = 15161
Number of items = 1127
Number of ratings = 367155
Max rating = 1.0
Min rating = 1.0
Global mean = 1.0
---
Test data:
Number of users = 15161
Number of items = 1127
Number of ratings = 44303
Number of unknown users = 0
Number of unknown items = 0




---
Validation data:
Number of users = 15161
Number of items = 1127
Number of ratings = 44303
---
Total users = 15161
Total items = 1127

[NeuMF] Training started!


  0%|          | 0/30 [00:00<?, ?it/s]


[NeuMF] Evaluation started!


Ranking:   0%|          | 0/13708 [00:00<?, ?it/s]

Fold: 6




---
Training data:
Number of users = 15160
Number of items = 1118
Number of ratings = 367163
Max rating = 1.0
Min rating = 1.0
Global mean = 1.0
---
Test data:
Number of users = 15160
Number of items = 1118
Number of ratings = 44311
Number of unknown users = 0
Number of unknown items = 0




---
Validation data:
Number of users = 15160
Number of items = 1118
Number of ratings = 44311
---
Total users = 15160
Total items = 1118

[NeuMF] Training started!


  0%|          | 0/30 [00:00<?, ?it/s]


[NeuMF] Evaluation started!


Ranking:   0%|          | 0/13730 [00:00<?, ?it/s]

Fold: 7




---
Training data:
Number of users = 15161
Number of items = 1117
Number of ratings = 367178
Max rating = 1.0
Min rating = 1.0
Global mean = 1.0
---
Test data:
Number of users = 15161
Number of items = 1117
Number of ratings = 44295
Number of unknown users = 0
Number of unknown items = 0
---
Validation data:
Number of users = 15161
Number of items = 1117
Number of ratings = 44295
---
Total users = 15161
Total items = 1117





[NeuMF] Training started!


  0%|          | 0/30 [00:00<?, ?it/s]


[NeuMF] Evaluation started!


Ranking:   0%|          | 0/13804 [00:00<?, ?it/s]

Fold: 8




---
Training data:
Number of users = 15162
Number of items = 1125
Number of ratings = 367193
Max rating = 1.0
Min rating = 1.0
Global mean = 1.0
---
Test data:
Number of users = 15162
Number of items = 1125
Number of ratings = 44289
Number of unknown users = 0
Number of unknown items = 0




---
Validation data:
Number of users = 15162
Number of items = 1125
Number of ratings = 44289
---
Total users = 15162
Total items = 1125

[NeuMF] Training started!


  0%|          | 0/30 [00:00<?, ?it/s]


[NeuMF] Evaluation started!


Ranking:   0%|          | 0/13745 [00:00<?, ?it/s]

Fold: 9




---
Training data:
Number of users = 15161
Number of items = 1111
Number of ratings = 367235
Max rating = 1.0
Min rating = 1.0
Global mean = 1.0
---
Test data:
Number of users = 15161
Number of items = 1111
Number of ratings = 44252
Number of unknown users = 0
Number of unknown items = 0




---
Validation data:
Number of users = 15161
Number of items = 1111
Number of ratings = 44252
---
Total users = 15161
Total items = 1111

[NeuMF] Training started!


  0%|          | 0/30 [00:00<?, ?it/s]


[NeuMF] Evaluation started!


Ranking:   0%|          | 0/13736 [00:00<?, ?it/s]

Fold: 10




---
Training data:
Number of users = 15161
Number of items = 1114
Number of ratings = 367026
Max rating = 1.0
Min rating = 1.0
Global mean = 1.0
---
Test data:
Number of users = 15161
Number of items = 1114
Number of ratings = 44314
Number of unknown users = 0
Number of unknown items = 0




---
Validation data:
Number of users = 15161
Number of items = 1114
Number of ratings = 44314
---
Total users = 15161
Total items = 1114

[NeuMF] Training started!


  0%|          | 0/30 [00:00<?, ?it/s]


[NeuMF] Evaluation started!


Ranking:   0%|          | 0/13817 [00:00<?, ?it/s]


TEST:
...
[NeuMF]
       | DDI@10 | NDCG@10 | Recall@10 | SeverityDDI@10 |  Train (s) | Test (s)
------ + ------ + ------- + --------- + -------------- + ---------- + --------
Fold 0 | 0.2723 |  0.3587 |    0.3963 |         0.3168 | 17460.1610 | 241.5111
Fold 1 | 0.2819 |  0.3641 |    0.4054 |         0.3133 | 17209.3704 | 243.6673
Fold 2 | 0.2436 |  0.3608 |    0.3996 |         0.2675 | 16638.8669 | 213.8703
Fold 3 | 0.3076 |  0.3643 |    0.4013 |         0.3454 | 16473.8948 | 209.3395
Fold 4 | 0.2710 |  0.3752 |    0.4165 |         0.3093 | 16499.7353 | 212.9331
Fold 5 | 0.2600 |  0.3705 |    0.4146 |         0.2941 | 16257.6613 | 210.6989
Fold 6 | 0.2230 |  0.3634 |    0.4070 |         0.2524 | 16222.4830 | 220.6049
Fold 7 | 0.2576 |  0.3760 |    0.4171 |         0.2848 | 16323.8270 | 221.3236
Fold 8 | 0.2537 |  0.3710 |    0.4127 |         0.2812 | 16623.8170 | 226.5266
Fold 9 | 0.2814 |  0.3570 |    0.3929 |         0.3118 | 16832.8561 | 226.6828
------ + ------ + ------- + -----

In [19]:
# ------------------- Step 0: Pick 5 Patients -------------------
case_patients = ratings_df["user"].unique()[:10]  # first 5 unique patients
print("\n===== CASE STUDY: Top-10 Recommendations (NeuMF) =====\n")

# ------------------- Step 1: Prepare Train Set -------------------
train_set = cornac_data  # using all data for example
neumf.fit(train_set)     # fit NeuMF model

# ------------------- Step 2: Generate Top-10 Recommendations -------------------
for pid in case_patients:
    if pid not in uid_map:
        continue  # skip unmapped patient

    internal_uid = uid_map[pid]

    # Ground-truth drugs for this patient
    gt_drugs = ratings_df.loc[ratings_df["user"] == pid, "matched_generic"].unique().tolist()

    # All items (internal IDs)
    all_items = list(iid_map.keys())
    item_indices = [iid_map[i] for i in all_items]

    # Get NeuMF scores for this user
    scores = neumf.score(user_idx=internal_uid)

    # Rank top-10 items
    top10_idx = np.argsort(scores)[::-1][:10]
    top10_drugs = [(all_items[i], scores[i]) for i in top10_idx]

    # Print results
    print(f"Patient HADM_ID: {pid}")
    print(f"  Ground Truth Drugs: {gt_drugs}")
    print("  Top-10 Recommendations:")
    for rank, (drug, score) in enumerate(top10_drugs, 1):
        marker = "✓" if drug in gt_drugs else ""
        print(f"    {rank:2d}. {drug} ({score:.4f}) {marker}")
    print("-" * 60)



===== CASE STUDY: Top-10 Recommendations (NeuMF) =====

Patient HADM_ID: 100003
  Ground Truth Drugs: ['acetaminophen', 'chloraseptic throat spray', 'folic acid', 'furosemide', 'lactulose', 'lidocaine', 'magnesium sulfate', 'nadolol', 'sarna lotion', 'sodium chloride', 'spironolactone', 'terbinafine 1 cream', 'thiamine']
  Top-10 Recommendations:
     1. potassium chloride (0.9998) 
     2. acetaminophen (0.9938) ✓
     3. spironolactone (0.8724) ✓
     4. sarna lotion (0.8684) ✓
     5. nadolol (0.8451) ✓
     6. thiamine (0.8078) ✓
     7. folic acid (0.7814) ✓
     8. lactulose (0.7784) ✓
     9. lidocaine (0.7464) ✓
    10. sodium chloride (0.7431) ✓
------------------------------------------------------------
Patient HADM_ID: 100009
  Ground Truth Drugs: ['fenofibrate', 'acetaminophen', 'aspirin', 'atenolol', 'bisacodyl', 'calcium carbonate', 'cephalexin', 'chlorhexidine gluconate', 'dextrose 50', 'docusate sodium', 'sodium citrate', 'ezetimibe', 'furosemide', 'glucagon', 'glycop