In [13]:
import json, random
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow.keras import layers, models, callbacks


In [14]:
# Cell 1 ── Imports & Paths
import json, pandas as pd, numpy as np, tensorflow as tf
from pathlib import Path
from sklearn.model_selection import train_test_split

# Where your allergy images live:
IMG_DIRS    = [
    Path("/mnt/ssd1/saumia/data/images/imgs_part_1"),
    Path("/mnt/ssd1/saumia/data/images/imgs_part_2"),
    Path("/mnt/ssd1/saumia/data/images/imgs_part_3"),
]


In [22]:
# Cell 2 ── Build & Clean DataFrame (final, with NaN‐drop before split)

import json
import pandas as pd
from pathlib import Path
from sklearn.model_selection import train_test_split

# 1) Read your metadata CSV
meta_df = pd.read_csv("/mnt/ssd1/saumia/data/text/metadata.csv")

# 2) Gather rows by matching image filenames to metadata
rows = []
for img_dir in IMG_DIRS:
    for p in img_dir.glob("*.png"):
        parts = p.stem.split("_")
        if len(parts) < 4:
            continue
        try:
            lesion_id = int(parts[2])
        except ValueError:
            continue
        m = meta_df[meta_df["lesion_id"] == lesion_id]
        if m.empty:
            continue
        m = m.iloc[0]
        row = {
            "path": str(p),
            "label": m["diagnostic"],
            "age": m["age"],
            "diameter_1": m["diameter_1"],
            "diameter_2": m["diameter_2"],
            "gender_M": 1.0 if str(m["gender"]).upper() == "MALE" else 0.0,
            "region": m["region"],
        }
        # six boolean cols
        for col in ["itch", "bleed", "elevation", "changed", "hurt", "grew"]:
            row[col] = 1.0 if bool(m.get(col)) else 0.0
        rows.append(row)

df = pd.DataFrame(rows)
assert not df.empty, "No images matched metadata!"

# 3) Drop nuisance features
constant_cols = ["itch", "bleed", "elevation", "changed", "hurt", "grew"]
df = df.drop(columns=constant_cols, errors="ignore")

# 4) Normalize continuous features to [0,1]
for col in ["age", "diameter_1", "diameter_2"]:
    if col in df:
        mn, mx = df[col].min(), df[col].max()
        df[col] = (df[col] - mn) / (mx - mn + 1e-8)

# 5) One‐hot encode the region column
df = pd.get_dummies(df, columns=["region"], dtype="float32")

# 6) Create label map and feature‐column list
classes   = sorted(df["label"].unique())
label_map = {c: i for i, c in enumerate(classes)}
feat_cols = [c for c in df.columns if c not in {"path", "label"}]

df["label_id"] = df["label"].map(label_map)

# 7) Persist your mappings
with open("label_map_image.json", "w")    as f: json.dump(label_map, f, indent=2)
with open("feature_cols_image.json", "w") as f: json.dump(feat_cols, f, indent=2)

# 8) Drop any rows with missing feature or label values
df = df.dropna(subset=feat_cols + ["label_id"]).reset_index(drop=True)

# 9) Stratified train/validation split
train_df, val_df = train_test_split(
    df,
    test_size=0.2,
    stratify=df["label"],
    random_state=42
)

# 10) Ensure correct dtypes
train_df[feat_cols] = train_df[feat_cols].astype("float32")
val_df[feat_cols]   = val_df[feat_cols].astype("float32")
train_df["label_id"] = train_df["label_id"].astype("int32")
val_df["label_id"]   = val_df["label_id"].astype("int32")

print(f"{len(classes)} classes  |  {len(train_df)} train  /  {len(val_df)} val samples")


6 classes  |  1184 train  /  297 val samples


In [23]:
train_df.describe()

Unnamed: 0,age,diameter_1,diameter_2,gender_M,region_ABDOMEN,region_ARM,region_BACK,region_CHEST,region_EAR,region_FACE,region_FOOT,region_FOREARM,region_HAND,region_LIP,region_NECK,region_NOSE,region_SCALP,region_THIGH,label_id
count,1184.0,1184.0,1184.0,1184.0,1184.0,1184.0,1184.0,1184.0,1184.0,1184.0,1184.0,1184.0,1184.0,1184.0,1184.0,1184.0,1184.0,1184.0,1184.0
mean,0.650031,0.117426,0.124278,0.493243,0.009291,0.086993,0.11402,0.128378,0.033784,0.231419,0.010135,0.130068,0.048142,0.014358,0.059966,0.096284,0.003378,0.033784,1.461993
std,0.16662,0.086309,0.082412,0.500166,0.095979,0.281944,0.31797,0.334652,0.180749,0.421918,0.100204,0.33652,0.214156,0.119012,0.237525,0.295105,0.05805,0.180749,1.38226
min,0.079545,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,0.545455,0.07,0.071429,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
50%,0.659091,0.1,0.1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
75%,0.772727,0.15,0.142857,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
max,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,5.0


In [24]:
# Cell 3 ── tf.data Pipelines
IMG_SIZE   = (224,224)
BATCH_SIZE = 16
AUTOTUNE   = tf.data.AUTOTUNE

def preprocess_image(path):
    img = tf.io.read_file(path)
    img = tf.image.decode_png(img, channels=3)
    img = tf.image.resize(img, IMG_SIZE)
    return tf.keras.applications.efficientnet.preprocess_input(img)

def make_dataset(df, shuffle=True):
    paths  = df["path"].values
    metas  = df[feat_cols].values
    labels = df["label_id"].values
    ds = tf.data.Dataset.from_tensor_slices((paths, metas, labels))
    if shuffle:
        ds = ds.shuffle(len(df), seed=42)
    def _load(path, meta, label):
        return {"image": preprocess_image(path), "meta": meta}, label
    ds = ds.map(_load, num_parallel_calls=AUTOTUNE)
    return ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)

train_ds = make_dataset(train_df, shuffle=True)
val_ds   = make_dataset(val_df,   shuffle=False)
# Clean and coerce feature types
for col in feat_cols:
    train_df[col] = pd.to_numeric(train_df[col], errors="coerce")
    val_df[col]   = pd.to_numeric(val_df[col], errors="coerce")

# Drop rows with missing features
train_df = train_df.dropna(subset=feat_cols + ["label_id"]).reset_index(drop=True)
val_df   = val_df.dropna(subset=feat_cols + ["label_id"]).reset_index(drop=True)
print(f"✅ Dropped rows with missing features: now {len(train_df)} train / {len(val_df)} val")

# Ensure all features are float32
train_df[feat_cols] = train_df[feat_cols].astype("float32")
val_df[feat_cols]   = val_df[feat_cols].astype("float32")

# Optional: Clip to [0, 1] for safety
train_df[feat_cols] = train_df[feat_cols].clip(0.0, 1.0)
val_df[feat_cols]   = val_df[feat_cols].clip(0.0, 1.0)


# sanity check
for (batch_x, batch_y) in train_ds.take(1):
    print("Image batch shape:", batch_x["image"].shape)
    print("Meta batch shape:",  batch_x["meta"].shape)
    print("Labels:", batch_y.numpy()[:8])


✅ Dropped rows with missing features: now 1184 train / 297 val
Image batch shape: (16, 224, 224, 3)
Meta batch shape: (16, 18)
Labels: [3 2 1 3 1 1 1 0]


2025-05-23 05:52:43.510408: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [25]:
# Cell 4 ── Build & Compile the Model (Improved Version)
from tensorflow.keras import layers, Model
import tensorflow as tf

# Load EfficientNetB0 with pretrained ImageNet weights
base_img = tf.keras.applications.EfficientNetB0(
    include_top=False, pooling="avg", weights="imagenet")
base_img.trainable = False  # Freeze to prevent overfitting on small data

# Inputs
img_in  = layers.Input(shape=IMG_SIZE + (3,), name="image")
meta_in = layers.Input(shape=(len(feat_cols),), name="meta")

# Image path
x1 = base_img(img_in)

# Metadata path
x2 = layers.BatchNormalization()(meta_in)
x2 = layers.Dense(64, activation="relu")(x2)
x2 = layers.Dropout(0.3)(x2)

# Combine
x = layers.concatenate([x1, x2])
x = layers.Dense(128, activation="relu")(x)
x = layers.Dropout(0.4)(x)
out = layers.Dense(len(classes), activation="softmax")(x)

# Build model
model = Model([img_in, meta_in], out, name="image_meta_model")



# Compile with gradient clipping and learning rate
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4, clipnorm=1.0),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)

model.summary()


In [26]:
print("✅ Any NaNs in final train data?", train_df[feat_cols].isnull().any().any())


✅ Any NaNs in final train data? False


In [27]:
# Cell 5 ── Train (Improved Version)
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        "image_meta_best.keras",   # Save only the best
        save_best_only=True,
        monitor="val_accuracy",
        verbose=1
    ),
    tf.keras.callbacks.EarlyStopping(
        monitor="val_accuracy",
        patience=5,
        restore_best_weights=True,
        verbose=1
    ),
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.5,
        patience=2,
        min_lr=1e-6,
        verbose=1
    )
]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=33,
    callbacks=callbacks
)




Epoch 1/33
[1m74/74[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 216ms/step - accuracy: 0.2888 - loss: 1.7759
Epoch 1: val_accuracy improved from -inf to 0.55892, saving model to image_meta_best.keras
[1m74/74[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 302ms/step - accuracy: 0.2909 - loss: 1.7726 - val_accuracy: 0.5589 - val_loss: 1.2571 - learning_rate: 1.0000e-04
Epoch 2/33
[1m74/74[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 219ms/step - accuracy: 0.5467 - loss: 1.2526
Epoch 2: val_accuracy improved from 0.55892 to 0.56566, saving model to image_meta_best.keras
[1m74/74[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 286ms/step - accuracy: 0.5468 - loss: 1.2528 - val_accuracy: 0.5657 - val_loss: 1.2153 - learning_rate: 1.0000e-04
Epoch 3/33
[1m74/74[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 211ms/step - accuracy: 0.5739 - loss: 1.2095
Epoch 3: val_accuracy improved from 0.56566 to 0.57239, saving model to image_meta_best.keras


In [30]:
# 0) Imports
import os
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dropout, Dense
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.utils.class_weight import compute_class_weight

# 1) Settings
DATA_DIR    = "/mnt/ssd1/saumia/data/images/IMG_CLASSES"
BATCH_SIZE  = 32
IMG_SIZE    = (224, 224)
EPOCHS      = 35
PATIENCE    = 5
WEIGHTS_FP  = "models/img_classes_weights.h5"



In [31]:
# 2) Discover classes
class_names = sorted([d for d in os.listdir(DATA_DIR) if os.path.isdir(os.path.join(DATA_DIR, d))])
num_classes = len(class_names)
print(f"Found {num_classes} classes: {class_names}")



Found 10 classes: ['1. Eczema 1677', '10. Warts Molluscum and other Viral Infections - 2103', '2. Melanoma 15.75k', '3. Atopic Dermatitis - 1.25k', '4. Basal Cell Carcinoma (BCC) 3323', '5. Melanocytic Nevi (NV) - 7970', '6. Benign Keratosis-like Lesions (BKL) 2624', '7. Psoriasis pictures Lichen Planus and related diseases - 2k', '8. Seborrheic Keratoses and other Benign Tumors - 1.8k', '9. Tinea Ringworm Candidiasis and other Fungal Infections - 1.7k']


In [33]:
# 3) Generators with augmentation & split
datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2,
    horizontal_flip=True,
    rotation_range=20,
    zoom_range=0.1,
    width_shift_range=0.1,
    height_shift_range=0.1,
)

train_gen = datagen.flow_from_directory(
    DATA_DIR,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode="categorical",
    subset="training",
    shuffle=True,
    seed=42
)

val_gen = datagen.flow_from_directory(
    DATA_DIR,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode="categorical",
    subset="validation",
    shuffle=False,
    seed=42
)



Found 21726 images belonging to 10 classes.
Found 5427 images belonging to 10 classes.


In [34]:
# 4) Compute class weights to handle imbalance
labels = train_gen.classes
cw = compute_class_weight(class_weight="balanced", classes=np.unique(labels), y=labels)
class_weights = {i: w for i, w in enumerate(cw)}
print("Class weights:", class_weights)


Class weights: {0: 1.6189269746646795, 1: 1.290909090909091, 2: 0.8648885350318471, 3: 2.1596421471172964, 4: 0.8170740880030086, 5: 0.3407465495608532, 6: 1.3056490384615385, 7: 1.3215328467153284, 8: 1.469959404600812, 9: 1.5951541850220263}


In [35]:
# 5) Define the CNN model
model_2 = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=IMG_SIZE + (3,)),
    MaxPooling2D(2, 2),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Flatten(),
    Dropout(0.5),
    Dense(128, activation='relu'),
    Dense(num_classes, activation='softmax', name="img_output")
], name="img_classes_cnn")

model_2.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)
model_2.summary()


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [36]:
# 6) Callbacks: checkpoint & early stopping
callbacks = [
    ModelCheckpoint(
        WEIGHTS_FP,
        save_best_only=True,
        monitor="val_accuracy",
        verbose=1
    ),
    EarlyStopping(
        monitor="val_accuracy",
        patience=PATIENCE,
        restore_best_weights=True,
        verbose=1
    )
]


In [37]:
# 7) Train
history = model_2.fit(
    train_gen,
    validation_data=val_gen,
    epochs=EPOCHS,
    class_weight=class_weights,
    callbacks=callbacks
)


  self._warn_if_super_not_called()


Epoch 1/35
[1m679/679[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 751ms/step - accuracy: 0.3409 - loss: 2.4410
Epoch 1: val_accuracy improved from -inf to 0.31196, saving model to models/img_classes_weights.h5




[1m679/679[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m566s[0m 833ms/step - accuracy: 0.3410 - loss: 2.4402 - val_accuracy: 0.3120 - val_loss: 1.9778
Epoch 2/35
[1m679/679[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 748ms/step - accuracy: 0.4933 - loss: 1.5493
Epoch 2: val_accuracy did not improve from 0.31196
[1m679/679[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m560s[0m 825ms/step - accuracy: 0.4933 - loss: 1.5492 - val_accuracy: 0.2828 - val_loss: 1.9565
Epoch 3/35
[1m679/679[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 759ms/step - accuracy: 0.5236 - loss: 1.4489
Epoch 3: val_accuracy did not improve from 0.31196
[1m679/679[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m570s[0m 839ms/step - accuracy: 0.5236 - loss: 1.4489 - val_accuracy: 0.3007 - val_loss: 1.9067
Epoch 4/35
[1m679/679[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 758ms/step - accuracy: 0.5463 - loss: 1.3907
Epoch 4: val_accuracy did not improve from 0.31196
[1m679/679[0m 



[1m679/679[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m568s[0m 836ms/step - accuracy: 0.5640 - loss: 1.3378 - val_accuracy: 0.3324 - val_loss: 1.9886
Epoch 6/35
[1m679/679[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 755ms/step - accuracy: 0.5764 - loss: 1.2895
Epoch 6: val_accuracy did not improve from 0.33241
[1m679/679[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m566s[0m 833ms/step - accuracy: 0.5764 - loss: 1.2895 - val_accuracy: 0.3118 - val_loss: 2.0001
Epoch 7/35
[1m679/679[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 761ms/step - accuracy: 0.5821 - loss: 1.2873
Epoch 7: val_accuracy did not improve from 0.33241
[1m679/679[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m568s[0m 836ms/step - accuracy: 0.5821 - loss: 1.2873 - val_accuracy: 0.3219 - val_loss: 1.9633
Epoch 8/35
[1m679/679[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 755ms/step - accuracy: 0.6046 - loss: 1.2329
Epoch 8: val_accuracy did not improve from 0.33241
[1m679/679[0m 



[1m679/679[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m571s[0m 840ms/step - accuracy: 0.6045 - loss: 1.2139 - val_accuracy: 0.3470 - val_loss: 1.9081
Epoch 10/35
[1m679/679[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 753ms/step - accuracy: 0.6177 - loss: 1.1904
Epoch 10: val_accuracy improved from 0.34697 to 0.35010, saving model to models/img_classes_weights.h5




[1m679/679[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m566s[0m 833ms/step - accuracy: 0.6177 - loss: 1.1904 - val_accuracy: 0.3501 - val_loss: 1.8911
Epoch 11/35
[1m679/679[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 755ms/step - accuracy: 0.6241 - loss: 1.1680
Epoch 11: val_accuracy improved from 0.35010 to 0.36558, saving model to models/img_classes_weights.h5




[1m679/679[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m568s[0m 836ms/step - accuracy: 0.6241 - loss: 1.1680 - val_accuracy: 0.3656 - val_loss: 1.9646
Epoch 12/35
[1m679/679[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 756ms/step - accuracy: 0.6234 - loss: 1.1538
Epoch 12: val_accuracy did not improve from 0.36558
[1m679/679[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m568s[0m 836ms/step - accuracy: 0.6234 - loss: 1.1538 - val_accuracy: 0.3042 - val_loss: 2.1731
Epoch 13/35
[1m679/679[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 757ms/step - accuracy: 0.6291 - loss: 1.1507
Epoch 13: val_accuracy did not improve from 0.36558
[1m679/679[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m567s[0m 834ms/step - accuracy: 0.6291 - loss: 1.1507 - val_accuracy: 0.3453 - val_loss: 2.0328
Epoch 14/35
[1m679/679[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 751ms/step - accuracy: 0.6396 - loss: 1.1300
Epoch 14: val_accuracy did not improve from 0.36558
[1m679/67

In [38]:
# ─── Cell 8 — Two-Head Fusion: IMG_CLASSES CNN + Metadata Head ─────────

import tensorflow as tf
from tensorflow.keras import Model, Sequential, layers
from tensorflow.keras.layers import (
    Input, Conv2D, MaxPooling2D, Flatten, Dropout, Dense, BatchNormalization, Concatenate
)
# 1) Re-build the CNN “backbone” up through the penultimate Dense(128):
cnn_backbone = Sequential([
    Conv2D(32, (3,3), activation='relu', input_shape=(224,224,3)),
    MaxPooling2D(2,2),
    Conv2D(64, (3,3), activation='relu'),
    MaxPooling2D(2,2),
    Flatten(),
    Dropout(0.5),
    Dense(128, activation='relu'),
], name="img_feature_extractor")



In [39]:
# 2) Load the weights you saved from model_2
cnn_backbone.load_weights("models/img_classes_weights.h5", by_name=True)


In [40]:
# 3) Re-attach the original IMG_CLASSES head so it remains part of the graph
old_logits = Dense(
    num_classes, activation='softmax', name="img_output"
)(cnn_backbone.output)


AttributeError: The layer img_feature_extractor has never been called and thus has no defined output.

In [None]:
# 4) Build the metadata branch exactly as in your original image+meta model
meta_in = Input(shape=(len(feat_cols),), name="meta_input")
m = BatchNormalization()(meta_in)
m = Dense(64, activation='relu')(m)
m = Dropout(0.3)(m)


In [None]:
# 5) Fuse image features + metadata and append a new “diagnosis” head
fusion = Concatenate()([cnn_backbone.output, m])
x      = Dense(128, activation='relu')(fusion)
x      = Dropout(0.4)(x)
new_logits = Dense(
    len(classes), activation='softmax', name="meta_output"
)(x)


In [None]:
# 6) Assemble the two-head model
multihead = Model(
    inputs  = [cnn_backbone.input, meta_in],
    outputs = [old_logits, new_logits],
    name    = "img_plus_meta_multihead"
)

# 7) Freeze everything except the new metadata head
for layer in multihead.layers:
    if layer.name != "meta_output":
        layer.trainable = False


In [None]:
# 8) Compile with two losses (ignore the old head’s loss via weight=0 if desired)
multihead.compile(
    optimizer    = tf.keras.optimizers.Adam(1e-3),
    loss         = ["categorical_crossentropy", "sparse_categorical_crossentropy"],
    loss_weights = [0.0, 1.0],     # train only the metadata head
    metrics      = ["accuracy", "accuracy"]
)
multihead.summary()


In [None]:
# 9) Wrap your existing `train_ds`/`val_ds` into the format this model expects:
AUTOTUNE = tf.data.AUTOTUNE

def make_multi_ds(ds):
    return ds.map(
        lambda x, y: (
            (x["image"], x["meta"]),     # inputs
            (tf.one_hot(y, depth=num_classes), y)  # outputs: [img_onehot, meta_label]
        ),
        num_parallel_calls=AUTOTUNE
    ).prefetch(AUTOTUNE)

train_multi_ds = make_multi_ds(train_ds)
val_multi_ds   = make_multi_ds(val_ds)


In [None]:
# 10) Train the metadata head
history_multi = multihead.fit(
    train_multi_ds,
    validation_data=val_multi_ds,
    epochs=10,
    callbacks=[
        tf.keras.callbacks.EarlyStopping(
            monitor="val_loss", patience=3, restore_best_weights=True
        )
    ]
)
