In [None]:
# # core
# import json, os, numpy as np
# import tensorflow as tf
# from pathlib import Path
# import os, sys
# # adjust the path so that the folder containing `src/` is on sys.path
# # if your notebook lives in `project/notebooks/`, this gives you `project/`
# project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
# if project_root not in sys.path:
#     sys.path.insert(0, project_root)

# # models
# from src.models.image_model     import build_image_model
# from src.models.text_model      import build_text_model
# from src.models.multimodal_model import build_multimodal_model

# # retriever
# from src.retriever.meds_rag      import retrieve_meds

# # utils
# from src.utils.geocode_utils     import find_allergists
# from src.utils.question_logic    import get_followup
# from src.utils.data_utils        import seed_everything

# # preprocessing
# from tensorflow.keras.preprocessing import image
# from transformers import DistilBertTokenizerFast

# # ─── Paths ─────────────────────────────────────────────────
# MODEL_DIR    = Path("models")
# IMG_MODEL_FP = MODEL_DIR / "image_model.keras"
# TXT_MODEL_FP = MODEL_DIR / "text_model.keras"
# MM_MODEL_FP  = MODEL_DIR / "multimodal_model.keras"

# LABEL_MAP_IMG = json.loads(Path("label_map_image.json").read_text())
# LABEL_MAP_TXT = json.loads(Path("label_map_text.json").read_text())
# FOOD2ALLERGY  = json.loads(Path("food2allergy.json").read_text())

# FEATURE_COLS     = json.loads(Path("feature_cols.json").read_text())            # metadata names
# FEATURE_COLS_IMG = json.loads(Path("feature_cols_image.json").read_text())      # if needed

# # set random seeds
# seed_everything(42)

# # tokenizer
# tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

# 0) Make the project importable: the parent of the current working directory
#    is pushed to the front of sys.path (assumes the notebook is run from a
#    folder one level below the directory that contains `src/` — TODO confirm).
import sys
from pathlib import Path

project_root = Path().resolve().parent
_root_entry = str(project_root)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)

# 1) Imports
import json, numpy as np, tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from transformers import DistilBertTokenizerFast

from src.models.multimodal_model import build_multimodal_model  # only if you prefer build+load_weights
from src.retriever.meds_rag       import retrieve_meds
from src.utils.question_logic     import get_followup
from src.utils.geocode_utils      import find_allergists





In [18]:
# # ─── Cell 2 — Load your trained models from the `models/` folder ────────────
# from pathlib import Path
# from tensorflow.keras.models import load_model

# # Make sure this points to the directory where you ran model.save(...)
# MODEL_DIR        = Path("models")       # ← your .keras files are here
# IMG_MODEL_PATH   = MODEL_DIR / "image_model.keras"
# TXT_MODEL_PATH   = MODEL_DIR / "text_model.keras"
# MM_MODEL_PATH    = MODEL_DIR / "multimodal_model.keras"

# print("🔄 Loading models…")
# img_model = load_model(str(IMG_MODEL_PATH), compile=False)
# txt_model = load_model(str(TXT_MODEL_PATH), compile=False)
# mm_model  = load_model(str(MM_MODEL_PATH),  compile=False)
# print("✅ Models loaded successfully.")

# 2) Locations of the saved fusion model. Both paths are relative to the
#    current working directory; nothing is loaded in this cell.
MM_KERAS = Path("models/multimodal_model.keras")
MM_H5    = Path("models/multimodal_weights.h5")


def _read_json(path):
    """Parse the JSON file at `path`, decoding it explicitly as UTF-8."""
    # Path.read_text() without `encoding` falls back to the locale's preferred
    # encoding, which is not UTF-8 on every platform.
    return json.loads(Path(path).read_text(encoding="utf-8"))


# 3) Load label maps & feature columns
LABEL_MAP_IMG    = _read_json("label_map_image.json")
# Inverse of LABEL_MAP_IMG: its values become the keys.
# NOTE(review): assumes the JSON maps label name -> integer class index — confirm.
ID2LABEL_IMG     = {v: k for k, v in LABEL_MAP_IMG.items()}
FEATURE_COLS     = _read_json("feature_cols_image.json")
NUM_IMG_CLASSES  = len(LABEL_MAP_IMG)
# The text label map is read only for its size.
NUM_TEXT_CLASSES = len(_read_json("label_map_text.json"))




In [19]:
# 4) Build-or-load the multimodal model
_keras_found = MM_KERAS.exists()
# The H5 weights file is only probed when the .keras package is absent.
_h5_found = False if _keras_found else MM_H5.exists()

if not (_keras_found or _h5_found):
    raise FileNotFoundError(
        f"Neither {MM_KERAS} nor {MM_H5} were found. "
        "Please run your training notebook to save the model."
    )

if _keras_found:
    # 4a) Preferred: load the full .keras package (compile=False is passed
    #     to load_model, as before)
    mm_model = load_model(str(MM_KERAS), compile=False)
    print(f"✅ Loaded fusion model from {MM_KERAS}")
else:
    # 4b) Fallback: rebuild the architecture, then load the H5 weights into it
    fusion_config = dict(
        image_input_shape=(224, 224, 3),
        num_image_classes=NUM_IMG_CLASSES,
        text_pretrained="distilbert-base-uncased",
        num_text_labels=NUM_TEXT_CLASSES,
        max_seq_len=256,
        metadata_dim=len(FEATURE_COLS),
        num_classes=NUM_IMG_CLASSES,
        dropout=0.3,
    )
    mm_model = build_multimodal_model(**fusion_config)
    mm_model.load_weights(str(MM_H5))
    print(f"✅ Built fusion model and loaded weights from {MM_H5}")



FileNotFoundError: Neither models/multimodal_model.keras nor models/multimodal_weights.h5 were found. Please run your training notebook to save the model.

In [20]:
# 5) Load text tokenizer
TOKENIZER_DIR = Path("weights/text_tokenizer")
# Fail early, with the same style of message as the model-loading cell, when
# the directory is missing.
# NOTE(review): without this check from_pretrained() receives a string that is
# not a local directory and presumably tries to resolve it as a Hub model id —
# confirm against the installed transformers version.
if not TOKENIZER_DIR.is_dir():
    raise FileNotFoundError(
        f"Tokenizer directory {TOKENIZER_DIR} was not found. "
        "Please run your training notebook to save the tokenizer."
    )
tokenizer     = DistilBertTokenizerFast.from_pretrained(str(TOKENIZER_DIR))
print("✅ Loaded text tokenizer")


✅ Loaded text tokenizer


In [21]:

# 6) Preprocessing helpers
def preprocess_image(path: str):
    """Load one image file and turn it into a single-item model input batch.

    The file is opened with ``load_img`` at a 224x224 target size, converted
    with ``img_to_array``, passed through the EfficientNet ``preprocess_input``
    function, and finally given a leading axis of length 1.
    """
    pil_image = load_img(path, target_size=(224, 224))
    pixels = img_to_array(pil_image)
    prepared = tf.keras.applications.efficientnet.preprocess_input(pixels)
    # Indexing with None prepends the batch axis.
    return prepared[None, ...]

def preprocess_text(txt: str, max_length: int = 256):
    """Tokenize one piece of text with the module-level ``tokenizer``.

    Args:
        txt: Raw text to encode.
        max_length: Length the sequence is truncated/padded to. Defaults to
            256, the value that used to be hard-coded here, so existing
            callers are unaffected.

    Returns:
        Tuple ``(input_ids, attention_mask)`` taken from the tokenizer output,
        which is requested with ``return_tensors="tf"``.
    """
    encoded = tokenizer(
        txt,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="tf",
    )
    return encoded["input_ids"], encoded["attention_mask"]

def preprocess_meta(meta: dict):
    """Arrange metadata values in ``FEATURE_COLS`` order as a one-row array.

    Keys missing from ``meta`` contribute 0.0; keys of ``meta`` that are not in
    ``FEATURE_COLS`` are ignored. The result is float32 with a leading axis of
    length 1.
    """
    ordered_values = []
    for column in FEATURE_COLS:
        ordered_values.append(meta.get(column, 0.0))
    row = np.array(ordered_values, dtype=np.float32)
    # np.newaxis is None: prepend the batch axis.
    return row[np.newaxis, ...]



In [30]:
# 7) Gather user inputs
def _prompt_float(prompt: str, default: float = 0.0) -> float:
    """Ask for a number until the reply parses; a blank reply yields `default`.

    A reply that float() rejects no longer aborts the cell (which would lose
    every value already entered); the same prompt is shown again instead.
    """
    while True:
        reply = input(prompt).strip()
        if not reply:
            return default
        try:
            return float(reply)
        except ValueError:
            print(f"  '{reply}' is not a number, please try again.")


# Blank image path -> None; blank symptom text -> "".
img_path     = input("📁 Path to image (or blank): ").strip() or None
symptom_text = input("✍️  Describe your symptoms (or blank): ").strip() or ""
print("Enter metadata values (leave blank for 0):")
# One prompt per entry of FEATURE_COLS, keyed by the feature name.
meta_input = {f: _prompt_float(f"  {f} = ") for f in FEATURE_COLS}



Enter metadata values (leave blank for 0):


In [31]:
# 8) Preprocess inputs — a modality the user left blank is replaced by an
#    all-zero placeholder of the same shape the helpers are called with.
if img_path:
    x_img = preprocess_image(img_path)
else:
    x_img = np.zeros((1, 224, 224, 3), dtype=np.float32)

x_ids, x_mask = (
    preprocess_text(symptom_text)
    if symptom_text
    else (tf.zeros((1, 256), dtype=tf.int32), tf.zeros((1, 256), dtype=tf.int32))
)

x_meta = preprocess_meta(meta_input)


In [32]:
# 9) Run fusion prediction
model_inputs = [x_img, x_ids, x_mask, x_meta]
# Keep only the first row of the prediction (a single sample was submitted).
preds = mm_model.predict(model_inputs)[0]

# Report the score vector's shape and the index range it covers
print("▶ preds shape:", preds.shape)
min_idx = 0
max_idx = preds.shape[0] - 1
print(f"✅ Valid class indices: {min_idx} through {max_idx}")

# Highest-scoring class, its score, and the label it maps to
cls = int(np.argmax(preds))
conf = float(preds[cls])
diagnosis = ID2LABEL_IMG[cls]
print(f"\n🏷️  Diagnosis: {diagnosis} (index {cls}, {conf*100:.1f}%)")


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 196ms/step
▶ preds shape: (6,)
✅ Valid class indices: 0 through 5

🏷️  Diagnosis: SCC (index 4, 86.9%)


In [33]:
# 10) Follow-up question: pass the symptom text and the predicted diagnosis
#     to get_followup and show whatever it returns.
followup_label = "❓ Follow-up:"
followup = get_followup(symptom_text, diagnosis)
print(followup_label, followup)


❓ Follow-up: Have you recently used any new skincare products, detergents, or soaps?


In [28]:
# 11) Medications / Instructions
SNIPPET_CHARS = 200  # maximum number of characters shown per retrieved document

meds = retrieve_meds(diagnosis, k=3)
print("\n💊 Medications / Instructions:")
for i, doc in enumerate(meds, 1):
    # Flatten newlines so each document prints on one line.
    flat = doc.replace("\n", " ")
    snippet = flat[:SNIPPET_CHARS]
    # Mark the text as cut off only when something was actually dropped;
    # previously "…" was appended to short documents as well.
    suffix = "…" if len(flat) > SNIPPET_CHARS else ""
    print(f"  {i}. {snippet}{suffix}")



💊 Medications / Instructions:
  1. Directions: Acne Clearing Cleanser Acne Clearing Tonic Acne Clearing Treatment 101…
  2. 1 INDICATIONS AND USAGE Varenicline tablets are indicated for use as an aid to smoking cessation treatment. Varenicline tablets are a nicotinic receptor partial agonist indicated for use as an aid to …
  3. INDICATIONS Condition listed above or as directed by the physician…


In [29]:
# 12) Nearby allergists
SEARCH_RADIUS_M = 5000  # radius handed to find_allergists as radius_m

print("\n🩺 Enter your coordinates to find allergists:")
lat = float(input("  latitude = ").strip())
lng = float(input("  longitude = ").strip())
places = find_allergists(lat, lng, radius_m=SEARCH_RADIUS_M)
print("Nearby allergists:")
if not places:
    # An empty result used to print the header and nothing else, which is
    # indistinguishable from a cell that stopped early.
    print(f"  (none found within {SEARCH_RADIUS_M} m)")
for p in places or []:
    name = p.get("name", "Unknown")
    # `or {}` also covers a "tags" entry that is present but None.
    addr = (p.get("tags") or {}).get("addr:street", "")
    # Skip the parentheses when no street is available instead of printing "()".
    print(f"  • {name} ({addr})" if addr else f"  • {name}")



🩺 Enter your coordinates to find allergists:
Nearby allergists:
