In [53]:
import sys
from pathlib import Path
# Put the directory one level above the current working directory at the
# front of the import path (once), so the `src.*` imports below can resolve.
project_root = Path().resolve().parent
_root_entry = str(project_root)
if _root_entry in sys.path:
    pass  # already registered — keep sys.path free of duplicates
else:
    sys.path.insert(0, _root_entry)

# 1) Imports
import json, numpy as np, tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from transformers import DistilBertTokenizerFast

from src.models.multimodal_model import build_multimodal_model  # only if you prefer build+load_weights
from src.retriever.meds_rag       import retrieve_meds
from src.utils.question_logic     import get_followup
from src.utils.geocode_utils      import find_allergists





In [54]:
# 2) Model artefact locations
MM_KERAS = Path("models/multimodal_model.keras")
# Must match the name used by the save/load cells further down
# ("models/multimodal.weights.h5"); Keras 3 `save_weights` also requires
# the ".weights.h5" suffix, so the old "multimodal_weights.h5" was unusable.
MM_H5    = Path("models/multimodal.weights.h5")

# 3) Load label maps & feature columns

def _load_json(path):
    """Parse the JSON document stored at `path` (read as UTF-8)."""
    return json.loads(Path(path).read_text(encoding="utf-8"))


# name -> id maps as stored on disk, plus the inverse id -> name lookups
LABEL_MAP_IMG = _load_json("label_map_image.json")
ID2LABEL_IMG  = {idx: name for name, idx in LABEL_MAP_IMG.items()}
LABEL_MAP_TXT = _load_json("label_map_text.json")
ID2LABEL_TXT  = {idx: name for name, idx in LABEL_MAP_TXT.items()}

# Ordered list of metadata feature names; its length is the metadata input width.
FEATURE_COLS  = _load_json("feature_cols_image.json")

NUM_IMG_CLASSES  = len(LABEL_MAP_IMG)
# Reuse the already-parsed text label map instead of reading the file again.
NUM_TEXT_CLASSES = len(LABEL_MAP_TXT)




In [103]:
# Optional: rebuild the architecture and export its weights as .h5
#
# NOTE(review): nothing in this cell restores a trained checkpoint between
# building and saving, so unless build_multimodal_model loads one itself the
# exported file contains freshly initialised fusion weights — confirm this is
# intended before using the file for inference.
mm_model = build_multimodal_model(
    image_input_shape=(224, 224, 3),
    num_image_classes=NUM_IMG_CLASSES,
    text_pretrained="distilbert-base-uncased",
    num_text_labels=NUM_TEXT_CLASSES,
    max_seq_len=256,
    metadata_dim=len(FEATURE_COLS),
    num_classes=NUM_IMG_CLASSES,
    dropout=0.3
)

# Save weights as .h5 for safe FastAPI loading.
# The target directory is created first so the export does not fail on a
# fresh checkout where "models/" does not exist yet.
_weights_out = Path("models/multimodal.weights.h5")
_weights_out.parent.mkdir(parents=True, exist_ok=True)
mm_model.save_weights(str(_weights_out))
print(f"✅ Saved H5 weights to {_weights_out.as_posix()}")


Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFDistilBertModel: ['vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_projector.bias']
- This IS expected if you are initializing TFDistilBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFDistilBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertModel for predictions without further training.


✅ Saved H5 weights to models/multimodal.weights.h5


In [105]:
# 4) Build-or-Load the multimodal model
#
# The architecture is rebuilt in code and only the weights are read from disk.
# A saved .keras file is deliberately not loaded (see the message below about
# Lambda deserialization).
from src.models.multimodal_model import build_multimodal_model

MM_KERAS = Path("models/multimodal_model.keras")
MM_H5 = Path("models/multimodal.weights.h5")

# Guard clause: without the weights file there is nothing usable to load.
if not MM_H5.exists():
    print("⚠️ .h5 weights not found.")
    if MM_KERAS.exists():
        print("⚠️ .keras file found but skipped due to Lambda deserialization issues.")
    # The hint is derived from MM_H5 so it always names the file this cell
    # actually looks for (the previous hard-coded hint pointed elsewhere).
    raise FileNotFoundError(
        "❌ Could not find a usable model file.\n"
        "Ensure you saved weights with:\n"
        f"  mm_model.save_weights('{MM_H5.as_posix()}')"
    )

print("🔄 Rebuilding architecture and loading weights from .h5…")
mm_model = build_multimodal_model(
    image_input_shape=(224, 224, 3),
    num_image_classes=NUM_IMG_CLASSES,
    text_pretrained="distilbert-base-uncased",
    num_text_labels=NUM_TEXT_CLASSES,
    max_seq_len=256,
    metadata_dim=len(FEATURE_COLS),
    num_classes=NUM_IMG_CLASSES,
    dropout=0.3
)
mm_model.load_weights(str(MM_H5))
print(f"✅ Built fusion model and loaded weights from {MM_H5}")


🔄 Rebuilding architecture and loading weights from .h5…


Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFDistilBertModel: ['vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_projector.bias']
- This IS expected if you are initializing TFDistilBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFDistilBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertModel for predictions without further training.


✅ Built fusion model and loaded weights from models/multimodal.weights.h5


In [106]:
# Persist the complete model (architecture + weights) in the .keras format.
# NOTE(review): the build-or-load cell reports skipping .keras files because
# of Lambda deserialization issues, so this artefact may not be reloadable —
# confirm before depending on it.
_keras_out = "models/multimodal_model.keras"
mm_model.save(_keras_out)
print(f"✅ Full multimodal model saved to {_keras_out}")


✅ Full multimodal model saved to models/multimodal_model.keras


In [107]:
# 5) Load text tokenizer
# The tokenizer files are read from a local directory.
TOKENIZER_DIR = Path("weights") / "text_tokenizer"
tokenizer = DistilBertTokenizerFast.from_pretrained(
    str(TOKENIZER_DIR)
)
print("✅ Loaded text tokenizer")


✅ Loaded text tokenizer


In [108]:

# 6) Preprocessing helpers
def preprocess_image(path: str):
    """Load an image file and prepare it as a one-item batch for the model.

    The file is read with ``load_img`` at a target size of 224x224, converted
    with ``img_to_array``, run through EfficientNet's ``preprocess_input`` and
    finally given a new leading (batch) axis.
    """
    pixels = img_to_array(load_img(path, target_size=(224, 224)))
    prepared = tf.keras.applications.efficientnet.preprocess_input(pixels)
    # Indexing with None prepends the batch axis.
    return prepared[None]

def preprocess_text(txt: str):
    """Tokenize free text with the notebook-level ``tokenizer``.

    The text is truncated / padded to exactly 256 tokens and TensorFlow
    tensors are requested.

    Returns:
        A ``(input_ids, attention_mask)`` pair taken from the tokenizer output.
    """
    tokenizer_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": 256,
        "return_tensors": "tf",
    }
    encoded = tokenizer(txt, **tokenizer_options)
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    return input_ids, attention_mask

def preprocess_meta(meta: dict):
    """Turn a metadata dict into a float32 row vector with a batch axis.

    Values are picked in the order given by the notebook-level
    ``FEATURE_COLS``; any feature missing from ``meta`` is filled with 0.0.
    """
    ordered_values = []
    for column in FEATURE_COLS:
        ordered_values.append(meta.get(column, 0.0))
    row = np.array(ordered_values, dtype=np.float32)
    # Add the leading batch axis expected by the model input.
    return np.expand_dims(row, axis=0)



In [84]:
# ─── Cell 7 — Gather User Inputs ───────────────────────────────────────────

# 7a) Require an image path that points at an existing file.
# Re-prompt until the path is non-empty AND is a real file, so a typo is
# caught here rather than later when the image is loaded.
img_path = None
while not img_path:
    candidate = input("📁 Path to image (required): ").strip()
    if candidate and Path(candidate).is_file():
        img_path = candidate
    else:
        print("⚠️  Please provide a valid image file path.")

# 7b) Free-text symptoms (optional); a blank answer yields an empty string.
symptom_text = input("✍️  Describe your symptoms (or leave blank): ").strip()

# 7c) Structured metadata: one numeric value per feature column.
print("\nEnter metadata values (leave blank for 0):")
meta_input = {}
for f in FEATURE_COLS:
    # Keep asking for this feature until the answer is blank (-> 0.0) or a
    # valid number, so one typo does not abort the whole cell.
    while True:
        val = input(f"  {f} = ").strip()
        if not val:
            meta_input[f] = 0.0
            break
        try:
            meta_input[f] = float(val)
            break
        except ValueError:
            print(f"⚠️  '{val}' is not a number; please re-enter {f}.")


⚠️  Please provide a valid image file path.
⚠️  Please provide a valid image file path.
⚠️  Please provide a valid image file path.

Enter metadata values (leave blank for 0):


In [74]:
# ─── Cell 8 — Preprocess Inputs ────────────────────────────────────────────

import numpy as np
import tensorflow as tf

# 8a) Image: the helper loads, preprocesses and adds the batch axis.
x_img = preprocess_image(img_path)

# 8b) Text: all-zero ids/mask of width 256 when no symptoms were entered,
#     otherwise the tokenizer output.
if not symptom_text:
    _blank_shape = (1, 256)
    x_ids  = tf.zeros(_blank_shape, dtype=tf.int32)
    x_mask = tf.zeros(_blank_shape, dtype=tf.int32)
else:
    x_ids, x_mask = preprocess_text(symptom_text)

# 8c) Metadata: dict -> row vector with batch axis.
x_meta = preprocess_meta(meta_input)

# Report the resulting shapes.
print("✅ Inputs ready:")
print(f"   • x_img: {x_img.shape}")
print(f"   • x_ids: {x_ids.shape} x_mask: {x_mask.shape}")
print(f"   • x_meta: {x_meta.shape}")


✅ Inputs ready:
   • x_img: (1, 224, 224, 3)
   • x_ids: (1, 256) x_mask: (1, 256)
   • x_meta: (1, 18)


In [75]:
# ─── Cell 9 — Fusion-Only Prediction ───────────────────────────────────────

# Feed all four inputs to the fusion model and keep the single batch row.
_model_inputs = [x_img, x_ids, x_mask, x_meta]
preds = mm_model.predict(_model_inputs)[0]

_last_index = preds.shape[0] - 1
print("▶ preds shape:", preds.shape)
print(f"✅ Valid class indices: 0 through {_last_index}")

# Highest-scoring class, its score, and its label from the image label map.
cls       = int(preds.argmax())
conf      = float(preds[cls])
diagnosis = ID2LABEL_IMG[cls]

print(f"\n🏷️  Diagnosis: {diagnosis} (index {cls}, {conf*100:.1f}%)")


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 203ms/step
▶ preds shape: (6,)
✅ Valid class indices: 0 through 5

🏷️  Diagnosis: SCC (index 4, 72.1%)


In [76]:
# 10) Interactive follow-up question(s)
from src.utils.question_logic import FOLLOWUP_BANK, LABEL_TO_BUCKET

# Look up the bucket for the predicted label (None when the label is unmapped).
bucket = LABEL_TO_BUCKET.get(diagnosis)

# Bucket-specific question, with a generic fallback for unknown buckets.
primary_q = FOLLOWUP_BANK.get(bucket, "Could you tell me when the symptoms started?")

# For the "skin" bucket the digestive question is appended to the primary one.
_question = primary_q
if bucket == "skin":
    secondary_q = FOLLOWUP_BANK["digestive"]
    combined_q = f"{primary_q} Also: {secondary_q}"
    _question = combined_q

# Single prompt for whichever question text was selected above.
user_followup = input(f"❓ {_question}\nYour answer: ").strip()

print("👍 Thanks for your answer!")


👍 Thanks for your answer!


In [77]:
# 11) Medications / Instructions
# Ask the retriever for three documents matching the diagnosis label.
meds = retrieve_meds(diagnosis, k=3)
print("\n💊 Medications / Instructions:")
for rank, text in enumerate(meds, start=1):
    # Flatten newlines and show at most the first 200 characters.
    one_line = text.replace("\n", " ")
    print(f"  {rank}. {one_line[:200]}…")



💊 Medications / Instructions:
  1. Directions: Acne Clearing Cleanser Acne Clearing Tonic Acne Clearing Treatment 101…
  2. 1 INDICATIONS AND USAGE Varenicline tablets are indicated for use as an aid to smoking cessation treatment. Varenicline tablets are a nicotinic receptor partial agonist indicated for use as an aid to …
  3. INDICATIONS Condition listed above or as directed by the physician…


In [79]:
# 12) Nearby allergists

def _ask_coordinate(prompt, low, high):
    """Prompt until the user enters a number within [low, high] and return it."""
    while True:
        raw = input(prompt).strip()
        try:
            value = float(raw)
        except ValueError:
            # Blank or non-numeric input: ask again instead of crashing the cell.
            print(f"⚠️  '{raw}' is not a number.")
            continue
        # The range comparison is also False for NaN, so NaN is re-prompted too.
        if low <= value <= high:
            return value
        print(f"⚠️  Value must be between {low} and {high}.")


print("\n🩺 Enter your coordinates to find allergists:")
lat = _ask_coordinate("  latitude = ", -90.0, 90.0)
lng = _ask_coordinate("  longitude = ", -180.0, 180.0)

# Search within 5000 m of the given point.
places = find_allergists(lat, lng, radius_m=5000)
print("Nearby allergists:")
if not places:
    # Say so explicitly instead of leaving a bare header with nothing under it.
    print("  (none found within 5000 m)")
for place in places or []:
    name = place.get("name", "Unknown")
    # NOTE(review): only the street is read from the "tags" dict; confirm the
    # record layout returned by find_allergists.
    addr = place.get("tags", {}).get("addr:street", "")
    print(f"  • {name} ({addr})")



🩺 Enter your coordinates to find allergists:
Nearby allergists:


In [None]:
# <!DOCTYPE html>
# <html lang="en">
# <head>
#   <meta charset="UTF-8">
#   <title>AllerGenie Multimodal Demo</title>
#   <style>
#     body { font-family: sans-serif; max-width: 800px; margin: 2em auto; }
#     label { display: block; margin-top: 1em; }
#     input, textarea, select { width: 100%; padding: 0.5em; margin-top: 0.2em; }
#     button { margin-top: 1.5em; padding: 0.7em 1.5em; font-size: 1em; }
#     #results { margin-top: 2em; padding: 1em; border: 1px solid #ccc; }
#   </style>
# </head>
# <body>

#   <h1>AllerGenie Multimodal Inference</h1>

#   <!-- 1) Image Upload -->
#   <label>
#     📁 Upload lesion image:
#     <input type="file" id="imgInput" accept="image/*">
#   </label>

#   <!-- 2) Symptom Text -->
#   <label>
#     ✍️ Describe your symptoms:
#     <textarea id="symptomText" rows="3" placeholder="e.g. 'I have had a red, itchy patch on my arm…'"></textarea>
#   </label>

#   <!-- 3) Metadata Fields -->
#   <fieldset>
#     <legend>Enter metadata values (leave blank for 0)</legend>
    
#     <label>
#       Age (years):
#       <input type="number" step="0.1" id="age" placeholder="e.g. 45">
#     </label>

#     <label>
#       Diameter 1 (mm):
#       <input type="number" step="0.1" id="diameter_1" placeholder="e.g. 12">
#     </label>

#     <label>
#       Diameter 2 (mm):
#       <input type="number" step="0.1" id="diameter_2" placeholder="e.g. 8">
#     </label>

#     <label>
#       Gender:
#       <select id="gender_M">
#         <option value="0">Female</option>
#         <option value="1">Male</option>
#       </select>
#     </label>

#     <!-- Region one-hots -->
#     <label>Region:</label>
#     <div style="display: flex; flex-wrap: wrap; gap: 0.5em;">
#       <label><input type="radio" name="region" value="region_ABDOMEN"> Abdomen</label>
#       <label><input type="radio" name="region" value="region_ARM"> Arm</label>
#       <label><input type="radio" name="region" value="region_BACK"> Back</label>
#       <label><input type="radio" name="region" value="region_CHEST"> Chest</label>
#       <label><input type="radio" name="region" value="region_EAR"> Ear</label>
#       <label><input type="radio" name="region" value="region_FACE"> Face</label>
#       <label><input type="radio" name="region" value="region_FOOT"> Foot</label>
#       <label><input type="radio" name="region" value="region_FOREARM"> Forearm</label>
#       <label><input type="radio" name="region" value="region_HAND"> Hand</label>
#       <label><input type="radio" name="region" value="region_LIP"> Lip</label>
#       <label><input type="radio" name="region" value="region_NECK"> Neck</label>
#       <label><input type="radio" name="region" value="region_NOSE"> Nose</label>
#       <label><input type="radio" name="region" value="region_SCALP"> Scalp</label>
#       <label><input type="radio" name="region" value="region_THIGH"> Thigh</label>
#     </div>
#   </fieldset>

#   <button id="submitBtn">🩺 Submit for Diagnosis</button>

#   <div id="results" hidden>
#     <h2>Results</h2>
#     <p id="diagnosis"></p>
#     <p id="confidence"></p>
#     <p id="followup"></p>
#   </div>

#   <script>
#     document.getElementById('submitBtn').addEventListener('click', async () => {
#       const imgFile = document.getElementById('imgInput').files[0];
#       if (!imgFile) {
#         alert("Please upload an image.");
#         return;
#       }

#       // gather metadata
#       const meta = {
#         age: parseFloat(document.getElementById('age').value) || 0.0,
#         diameter_1: parseFloat(document.getElementById('diameter_1').value) || 0.0,
#         diameter_2: parseFloat(document.getElementById('diameter_2').value) || 0.0,
#         gender_M: parseFloat(document.getElementById('gender_M').value),
#       };
#       // set one-hot region
#       const regionEls = document.querySelectorAll('input[name="region"]');
#       regionEls.forEach(el => meta[el.value] = (el.checked ? 1.0 : 0.0));

#       const symptom_text = document.getElementById('symptomText').value;

#       // build FormData
#       const form = new FormData();
#       form.append('image', imgFile);
#       form.append('symptom_text', symptom_text);
#       form.append('metadata', JSON.stringify(meta));

#       // call your backend API
#       const resp = await fetch('/api/infer', {
#         method: 'POST',
#         body: form
#       });
#       if (!resp.ok) {
#         alert("Error: " + resp.statusText);
#         return;
#       }
#       const { diagnosis, confidence, followup } = await resp.json();

#       // display
#       document.getElementById('diagnosis').innerText = `Diagnosis: ${diagnosis}`;
#       document.getElementById('confidence').innerText = `Confidence: ${(confidence*100).toFixed(1)}%`;
#       document.getElementById('followup').innerText = `Follow-up: ${followup}`;
#       document.getElementById('results').hidden = false;
#     });
#   </script>
# </body>
# </html>


IndentationError: unindent does not match any outer indentation level (<tokenize>, line 92)