# **Sanvia: AI for Early Hope - Notebook 2**
**Modeling & Training Pipeline for VinDr-Mammo**

Initialize Environment

In [14]:
import os
import json
import numpy as np
import pandas as pd
import tensorflow as tf
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import matplotlib.pyplot as plt

# Reproducibility: tf.keras.utils.set_random_seed seeds Python's `random`, NumPy and
# TensorFlow in one call (np.random.seed + tf.random.set_seed leave `random` unseeded).
SEED = 42
tf.keras.utils.set_random_seed(SEED)
# Request deterministic TensorFlow kernels.
# NOTE(review): TensorFlow is already imported at this point; on TF >= 2.8 the documented
# switch is tf.config.experimental.enable_op_determinism() — confirm the env var is still
# honoured by the TF version in use.
os.environ['TF_DETERMINISTIC_OPS'] = '1'

# mixed_float16: layers compute in float16 while variables are kept in float32.
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)
print(f"Mixed precision: {policy.name}")

# Google Drive layout shared with Notebook 1; every artifact lives under OUTPUT_DIR.
DATA_DIR = Path('/content/drive/MyDrive/VnDir_Mammo')
OUTPUT_DIR = DATA_DIR / 'sanvia_outputs'
CONFIG_PATH = OUTPUT_DIR / 'config.json'

Mixed precision: mixed_float16


Load Notebook 1 Artifacts

In [15]:
# Load the configuration, class weights and view mapping written by Notebook 1.
with open(CONFIG_PATH, 'r') as f:
    CONFIG = json.load(f)

CLASS_WEIGHTS_PATH = OUTPUT_DIR / 'class_weights.json'
with open(CLASS_WEIGHTS_PATH, 'r') as f:
    CLASS_WEIGHTS = json.load(f)

ARTIFACT_DIR = OUTPUT_DIR / 'artifacts'
VIEW_MAPPING_PATH = ARTIFACT_DIR / 'view_mapping_final.csv'
view_mapping_df = pd.read_csv(VIEW_MAPPING_PATH)

# Fail fast if an artifact lacks something the pipeline/model cells below read, instead
# of surfacing as a KeyError deep inside tf.data or model construction.
REQUIRED_CONFIG_KEYS = ('images_dir', 'num_channels', 'buffer_size', 'seed',
                        'birads_classes', 'density_classes', 'tab_embed_dim')
REQUIRED_VIEW_COLUMNS = ('study_id', 'split_final',
                         'image_id_L_CC', 'image_id_L_MLO', 'image_id_R_CC', 'image_id_R_MLO',
                         'breast_birads_L_CC', 'breast_density_L_CC',
                         'age_norm', 'age_missing_flag')
missing_config_keys = [key for key in REQUIRED_CONFIG_KEYS if key not in CONFIG]
assert not missing_config_keys, f"config.json is missing keys: {missing_config_keys}"
missing_weight_tasks = [task for task in ('birads', 'density') if task not in CLASS_WEIGHTS]
assert not missing_weight_tasks, f"class_weights.json is missing tasks: {missing_weight_tasks}"
missing_view_columns = [col for col in REQUIRED_VIEW_COLUMNS if col not in view_mapping_df.columns]
assert not missing_view_columns, f"view mapping CSV is missing columns: {missing_view_columns}"

print("All Data And Artifacts Loaded Successfully")

# Training-time overrides of the values stored by Notebook 1.
CONFIG['img_size'] = [256, 256]
CONFIG['batch_size'] = 64
print(f"✅ img_size: {CONFIG['img_size']}, batch_size: {CONFIG['batch_size']}")


All Data And Artifacts Loaded Successfully
✅ img_size: [256, 256], batch_size: 64


**Rebuild tf.data Pipelines**

In [16]:
def build_tf_dataset(view_df: pd.DataFrame, config: Dict, is_training: bool = False) -> tf.data.Dataset:
    """Build a batched four-view + tabular ``tf.data`` pipeline from a view-mapping frame.

    Args:
        view_df: one row per study with columns ``study_id``,
            ``image_id_{L_CC,L_MLO,R_CC,R_MLO}``, ``breast_birads_L_CC``,
            ``breast_density_L_CC``, ``age_norm`` and ``age_missing_flag``.
        config: needs ``images_dir``, ``img_size``, ``num_channels``, ``batch_size``,
            ``birads_classes``, ``density_classes`` and, for training, ``buffer_size``
            and ``seed``.
        is_training: shuffle every epoch and allow non-deterministic map ordering.

    Returns:
        Dataset of ``((L_CC, L_MLO, R_CC, R_MLO, tabular), (birads, density))`` batches.
        Images are float32 in [0, 1] with shape ``(*img_size, num_channels)`` (aspect
        ratio preserved, zero padded). ``tabular`` is ``[age_norm, age_missing_flag]``.
        Targets are one-hot float32.

    NOTE(review): both targets are read from the ``*_L_CC`` columns only — confirm this
    is the intended study-level label.
    """
    images_root = Path(config['images_dir'])
    target_h, target_w = (int(v) for v in config['img_size'])
    num_channels = int(config['num_channels'])
    view_order = ('L_CC', 'L_MLO', 'R_CC', 'R_MLO')
    records = view_df.to_dict('records')

    def _resolve_image_path(study_id: str, image_id: str) -> str:
        """Return the first existing PNG path: per-study subfolder first, then flat root."""
        file_name = image_id if image_id.endswith('.png') else f'{image_id}.png'
        candidates = (images_root / study_id / file_name, images_root / file_name)
        for candidate in candidates:
            if candidate.exists():
                return str(candidate)
        # Neither exists: return the primary layout so tf.io.read_file raises a
        # NotFoundError naming a concrete path.
        return str(candidates[0])

    def gen():
        # Path resolution happens here in Python. A try/except inside a tf.function
        # cannot catch file-read errors, because those are raised when the graph runs,
        # not when it is traced.
        for record in records:
            study_id = str(record['study_id'])
            yield (
                *(_resolve_image_path(study_id, str(record[f'image_id_{view}']))
                  for view in view_order),
                int(record['breast_birads_L_CC']),
                int(record['breast_density_L_CC']),
                float(record['age_norm']),
                int(record['age_missing_flag']),
            )

    dataset = tf.data.Dataset.from_generator(
        gen,
        output_signature=(
            *(tf.TensorSpec(shape=(), dtype=tf.string) for _ in view_order),
            tf.TensorSpec(shape=(), dtype=tf.int64),
            tf.TensorSpec(shape=(), dtype=tf.int64),
            tf.TensorSpec(shape=(), dtype=tf.float32),
            tf.TensorSpec(shape=(), dtype=tf.int64),
        )
    )
    # from_generator has unknown cardinality. Declaring it lets Keras know the epoch
    # length and restart the dataset each epoch instead of running out of data.
    dataset = dataset.apply(tf.data.experimental.assert_cardinality(len(records)))

    # Only resolved paths and labels are cached (cheap), so an in-memory cache suffices.
    # This avoids stale or locked on-disk cache files after an interrupted run.
    dataset = dataset.cache()

    if is_training:
        dataset = dataset.shuffle(
            buffer_size=max(1, min(int(config['buffer_size']), len(records))),
            seed=int(config['seed']),
            reshuffle_each_iteration=True
        )

    def _load_image(path):
        """Read one grayscale PNG and return a (target_h, target_w, num_channels) float32 in [0, 1]."""
        image = tf.io.decode_png(tf.io.read_file(path), channels=1)
        image = tf.cast(image, tf.float32) / 255.0
        # resize_with_pad keeps the aspect ratio and zero-pads to the target size.
        image = tf.image.resize_with_pad(image, target_h, target_w, method='bilinear')
        if num_channels == 3:
            image = tf.image.grayscale_to_rgb(image)
        image.set_shape((target_h, target_w, num_channels))
        return image

    def process_record(path_L_CC, path_L_MLO, path_R_CC, path_R_MLO,
                       birads_L_CC, density_L_CC, age_norm, age_missing_flag):
        images = tuple(_load_image(p) for p in (path_L_CC, path_L_MLO, path_R_CC, path_R_MLO))

        tabular = tf.stack([
            tf.cast(age_norm, tf.float32),
            tf.cast(age_missing_flag, tf.float32)
        ], axis=0)

        birads = tf.one_hot(tf.cast(birads_L_CC, tf.int32),
                            depth=config['birads_classes'], dtype=tf.float32)
        density = tf.one_hot(tf.cast(density_L_CC, tf.int32),
                             depth=config['density_classes'], dtype=tf.float32)

        return (*images, tabular), (birads, density)

    dataset = dataset.map(
        process_record,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=not is_training
    )

    dataset = dataset.batch(config['batch_size'])
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset

In [17]:
# Partition by the split assigned in Notebook 1. Each frame is a copy, so later edits
# cannot write back into view_mapping_df.
train_df = view_mapping_df[view_mapping_df['split_final'] == 'train'].copy()
val_df = view_mapping_df[view_mapping_df['split_final'] == 'val'].copy()
test_df = view_mapping_df[view_mapping_df['split_final'] == 'test'].copy()

# An empty split would otherwise surface later as zero validation steps or an empty evaluation.
for split_name, split_frame in (('train', train_df), ('val', val_df), ('test', test_df)):
    assert not split_frame.empty, f"No rows with split_final == '{split_name}'"

train_dataset = build_tf_dataset(train_df, CONFIG, is_training=True)
val_dataset = build_tf_dataset(val_df, CONFIG, is_training=False)
test_dataset = build_tf_dataset(test_df, CONFIG, is_training=False)

print(f"Datasets: Train{len(train_df)}، Val {len(val_df)}، Test {len(test_df)}")


Datasets: Train4000، Val 500، Test 500


In [18]:
print("PRE-TRAINING VERIFICATION")

# 1. Image root and its on-disk layout (per-study subfolders vs. a flat directory)
images_root = Path(CONFIG['images_dir'])
print(f"1. Images root: {images_root}")
print(f"   Exists: {images_root.exists()}")

if images_root.exists():
    subfolders = [d for d in images_root.iterdir() if d.is_dir()]
    if subfolders:
        print(f"   Found {len(subfolders)} subfolders")
        sample = subfolders[0]
        images = list(sample.glob('*.png'))
        print(f"   Sample subfolder: {sample.name}")
        print(f"   Images in it: {len(images)}")
        if images:
            print(f"   Example: {images[0].name}")
    else:
        flat_images = list(images_root.glob('*.png'))
        print(f"   No subfolders. Flat structure: {len(flat_images)} images")

# 2. A few mapping rows. For each view, check whether the referenced PNG actually
# exists under either layout the loader tries (<root>/<study_id>/<id>.png or <root>/<id>.png).
print("\n2. Sample DataFrame rows:")
sample_rows = view_mapping_df.head(2)
for idx, row in sample_rows.iterrows():
    study_id = str(row['study_id'])
    print(f"\n   Row {idx} - Study: {row['study_id']}")
    for view in ['L_CC', 'L_MLO', 'R_CC', 'R_MLO']:
        img_id = row[f'image_id_{view}']
        file_name = str(img_id) if str(img_id).endswith('.png') else f'{img_id}.png'
        found = any(path.exists() for path in (images_root / study_id / file_name,
                                                images_root / file_name))
        print(f"     {view}: {img_id} [{'found' if found else 'MISSING'}]")

print("\n" + "=" * 60)

PRE-TRAINING VERIFICATION
1. Images root: /content/drive/MyDrive/VnDir_Mammo/images/images_png
   Exists: True
   Found 5000 subfolders
   Sample subfolder: cd574c78251753f3f55c853068acdb4e
   Images in it: 4
   Example: 2d84d04a07ef137d95cd54d08e39cd8a.png

2. Sample DataFrame rows:

   Row 0 - Study: 0025a5dc99fd5c742026f0b2b030d3e9
     L_CC: 451562831387e2822923204cf8f0873e
     L_MLO: 2ddfad7286c2b016931ceccd1e2c7bbc
     R_CC: fcf12c2803ba8dc564bf1287c0c97d9a
     R_MLO: 47c8858666bcce92bcbd57974b5ce522

   Row 1 - Study: 0028fb2c7f0b3a5cb9a80cb0e1cdbb91
     L_CC: 3704f91985dcbc69f6ac2803523d1ecb
     L_MLO: 7fc1f1bb8bb1a7efaf7104e49c4d8b86
     R_CC: c4ce68631bf70949570ded31a3c69e60
     R_MLO: 16e58fc1d65fa7587247e6224ee96527



**Custom Layers - Tabular Encoder**
MLP-based tabular feature encoder with LayerNorm and dropout.


In [19]:
class TabularEncoder(tf.keras.layers.Layer):
    """Two-stage GELU MLP that embeds tabular features.

    Each stage is Dense -> LayerNormalization -> Dropout. The first Dense is
    ``2 * embed_dim`` wide and the second is ``embed_dim`` wide, so the output has
    ``embed_dim`` features.

    Args:
        embed_dim: width of the output embedding.
        dropout_rate: rate used by both Dropout layers.
    """

    def __init__(self, embed_dim: int, dropout_rate: float = 0.3, **kwargs):
        super().__init__(**kwargs)
        # Stored so get_config can rebuild the layer when the model is serialised.
        self.embed_dim = embed_dim
        self.dropout_rate = dropout_rate
        self.dense1 = tf.keras.layers.Dense(embed_dim * 2, activation='gelu')
        self.norm1 = tf.keras.layers.LayerNormalization()
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.dense2 = tf.keras.layers.Dense(embed_dim, activation='gelu')
        self.norm2 = tf.keras.layers.LayerNormalization()
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)

    def call(self, inputs, training=False):
        x = self.dense1(inputs)
        x = self.norm1(x)
        x = self.dropout1(x, training=training)
        x = self.dense2(x)
        x = self.norm2(x)
        x = self.dropout2(x, training=training)
        return x

    def get_config(self):
        """Constructor arguments needed by model.to_json() / model.save()."""
        config = super().get_config()
        config.update({'embed_dim': self.embed_dim, 'dropout_rate': self.dropout_rate})
        return config

**Custom Layers - Deformable Cross-Attention (Simplified)**
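Cross-attention layer for CC-MLO fusion. Queries are built from the CC features and keys/values from the MLO features, using 1×1-convolution projections split into heads. Despite the "deformable" name, the layer learns no sampling offsets and performs no bilinear sampling.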

In [20]:
class DeformableCrossAttentionLayer(tf.keras.layers.Layer):
    """Multi-head scaled dot-product cross-attention between two feature maps.

    Queries come from ``query_features`` and keys/values from ``key_value_features``.
    Each is produced by a 1x1 GELU convolution and split into ``num_heads`` heads of
    ``key_dim`` channels. Every query position attends over *all* key/value positions,
    so the two maps need not be spatially aligned. The previous implementation took
    the softmax over the heads axis at each pixel, so no attention across positions
    happened at all.

    Despite the class name, no sampling offsets are learned. The name is kept so
    existing model-building code and saved configs keep working.

    Returns:
        Tensor of shape ``(batch, H_query, W_query, key_dim)`` from a final 1x1 conv.
    """

    def __init__(self, num_heads: int = 8, key_dim: int = 64, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim
        projected_dim = num_heads * key_dim
        self.query_proj = tf.keras.layers.Conv2D(projected_dim, 1, activation='gelu')
        self.key_proj = tf.keras.layers.Conv2D(projected_dim, 1, activation='gelu')
        self.value_proj = tf.keras.layers.Conv2D(projected_dim, 1, activation='gelu')
        self.output_proj = tf.keras.layers.Conv2D(key_dim, 1)

    def call(self, query_features, key_value_features, training=False):
        batch_size = tf.shape(query_features)[0]
        query_h = tf.shape(query_features)[1]
        query_w = tf.shape(query_features)[2]
        kv_h = tf.shape(key_value_features)[1]
        kv_w = tf.shape(key_value_features)[2]

        # Flatten each spatial grid into a token axis: (batch, tokens, heads, key_dim).
        queries = tf.reshape(self.query_proj(query_features),
                             [batch_size, query_h * query_w, self.num_heads, self.key_dim])
        keys = tf.reshape(self.key_proj(key_value_features),
                          [batch_size, kv_h * kv_w, self.num_heads, self.key_dim])
        values = tf.reshape(self.value_proj(key_value_features),
                            [batch_size, kv_h * kv_w, self.num_heads, self.key_dim])

        # Scores are (batch, heads, query_tokens, kv_tokens), softmaxed over the kv tokens.
        # The 1/sqrt(key_dim) scale also keeps float16 logits in range under mixed precision.
        scale = tf.math.rsqrt(tf.cast(self.key_dim, queries.dtype))
        scores = tf.einsum('bqnc,bknc->bnqk', queries, keys) * scale
        probs = tf.nn.softmax(scores, axis=-1)
        attended = tf.einsum('bnqk,bknc->bqnc', probs, values)

        attended = tf.reshape(attended, [batch_size, query_h, query_w, self.num_heads * self.key_dim])
        return self.output_proj(attended)

    def get_config(self):
        """Constructor arguments needed by model.to_json() / model.save()."""
        config = super().get_config()
        config.update({'num_heads': self.num_heads, 'key_dim': self.key_dim})
        return config

**Custom Layers - Gated Fusion**
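Gated fusion layer that combines the CC features with the CC→MLO cross-attention output. Each input is transformed and scaled by its own learnable sigmoid gate; the two results are then summed and layer-normalized.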


In [21]:
class GatedFusionLayer(tf.keras.layers.Layer):
    """Fuse two feature tensors with per-input learnable sigmoid gates.

    Each input goes through its own Dense transform. It is then multiplied by a sigmoid
    gate computed from that same input. The two gated tensors are summed and
    layer-normalised. Dense layers act on the last axis, so 4-D feature maps work too.

    Args:
        feature_dim: output width of the transforms and gates.
    """

    def __init__(self, feature_dim: int, **kwargs):
        super().__init__(**kwargs)
        self.feature_dim = feature_dim
        self.cc_gate = tf.keras.layers.Dense(feature_dim, activation='sigmoid')
        self.mlo_gate = tf.keras.layers.Dense(feature_dim, activation='sigmoid')
        self.cc_transform = tf.keras.layers.Dense(feature_dim)
        self.mlo_transform = tf.keras.layers.Dense(feature_dim)
        self.norm = tf.keras.layers.LayerNormalization()

    def call(self, cc_features, mlo_features, training=False):
        cc_gated = self.cc_transform(cc_features) * self.cc_gate(cc_features)
        mlo_gated = self.mlo_transform(mlo_features) * self.mlo_gate(mlo_features)
        # LayerNormalization behaves the same in training and inference.
        return self.norm(cc_gated + mlo_gated)

    def get_config(self):
        """Constructor arguments needed by model.to_json() / model.save()."""
        config = super().get_config()
        config.update({'feature_dim': self.feature_dim})
        return config

**Build Multiview Backbone**
Build a shared EfficientNet-B0 encoder (ImageNet weights) for multiview feature extraction. Returns a model that outputs a reduced feature map for each view.

In [22]:
def build_multiview_backbone(config: Dict) -> tf.keras.Model:
    """Build the shared EfficientNet-B0 feature extractor for the four mammography views.

    One ImageNet-initialised EfficientNet-B0 (no top, no pooling) is applied to every
    view. Each view then gets its own 1x1 GELU conv down to ``config['tab_embed_dim']``
    channels, followed by LayerNormalization and Dropout.

    Inputs are expected in [0, 1], which is what the tf.data pipeline emits. They are
    rescaled to [0, 255] because Keras' EfficientNet models expect raw pixel values and
    normalise internally; their ``preprocess_input`` is a pass-through.

    Args:
        config: needs ``img_size``, ``num_channels`` (3 for ImageNet weights) and
            ``tab_embed_dim``. Optional: ``backbone_frozen_layers`` (default 100) and
            ``backbone_dropout`` (default 0.3).

    Returns:
        Model mapping inputs ``{'L_CC', 'L_MLO', 'R_CC', 'R_MLO'}`` to a dict of
        per-view feature maps under the same keys.
    """
    input_shape = (*config['img_size'], config['num_channels'])
    inputs = {
        view_name: tf.keras.Input(shape=input_shape, name=view_name)
        for view_name in ('L_CC', 'L_MLO', 'R_CC', 'R_MLO')
    }

    base_model = tf.keras.applications.EfficientNetB0(
        include_top=False,
        weights='imagenet',
        input_shape=input_shape,
        pooling=None
    )

    # Freeze the early (generic) blocks; the rest are fine-tuned.
    for layer in base_model.layers[:int(config.get('backbone_frozen_layers', 100))]:
        layer.trainable = False

    to_pixel_range = tf.keras.layers.Rescaling(255.0, name='to_pixel_range')
    dropout_rate = float(config.get('backbone_dropout', 0.3))

    view_features = {}
    for view_name, input_tensor in inputs.items():
        pixels = to_pixel_range(input_tensor)
        preprocessed = tf.keras.applications.efficientnet.preprocess_input(pixels)
        # training=False keeps BatchNorm in inference mode while fine-tuning.
        features = base_model(preprocessed, training=False)
        reduced = tf.keras.layers.Conv2D(config['tab_embed_dim'], 1, activation='gelu')(features)
        reduced = tf.keras.layers.LayerNormalization()(reduced)
        reduced = tf.keras.layers.Dropout(dropout_rate)(reduced)
        view_features[view_name] = reduced

    return tf.keras.Model(inputs=inputs, outputs=view_features, name='multiview_backbone')

Building Model

In [23]:
def build_sanvia_model(config: Dict) -> tf.keras.Model:
    """Build the two-head multimodal Sanvia classifier.

    Inputs, in order: L_CC, L_MLO, R_CC, R_MLO images of shape
    ``(*config['img_size'], config['num_channels'])``, plus a 2-feature tabular vector.

    Per breast side, CC features cross-attend to MLO features, and the result is
    gate-fused with the CC features. Both sides are average-pooled, concatenated with
    the encoded tabular features, and fed through a shared MLP into the softmax heads
    ``birads_head`` and ``density_head``.

    Config keys: ``img_size``, ``num_channels``, ``tab_embed_dim``, ``birads_classes``,
    ``density_classes``. Optional: ``num_heads`` (default 8) and ``mc_dropout``
    (default False).
    """
    image_shape = (*config['img_size'], config['num_channels'])
    embed_dim = config['tab_embed_dim']
    num_heads = int(config.get('num_heads', 8))
    # Dropout follows Keras' training flag: on in fit(), off in evaluate()/predict().
    # Forcing training=True made validation loss (and so early stopping / LR scheduling)
    # and test metrics stochastic. Set config['mc_dropout'] = True for MC-dropout inference.
    dropout_training = True if config.get('mc_dropout', False) else None

    L_CC_input = tf.keras.Input(shape=image_shape, name='L_CC')
    L_MLO_input = tf.keras.Input(shape=image_shape, name='L_MLO')
    R_CC_input = tf.keras.Input(shape=image_shape, name='R_CC')
    R_MLO_input = tf.keras.Input(shape=image_shape, name='R_MLO')
    tabular_input = tf.keras.Input(shape=(2,), name='tabular')

    backbone = build_multiview_backbone(config)
    view_features = backbone({
        'L_CC': L_CC_input,
        'L_MLO': L_MLO_input,
        'R_CC': R_CC_input,
        'R_MLO': R_MLO_input,
    })

    left_cross_attn = DeformableCrossAttentionLayer(
        num_heads=num_heads, key_dim=embed_dim, name='left_cross_attn'
    )(view_features['L_CC'], view_features['L_MLO'])
    right_cross_attn = DeformableCrossAttentionLayer(
        num_heads=num_heads, key_dim=embed_dim, name='right_cross_attn'
    )(view_features['R_CC'], view_features['R_MLO'])

    left_fused = GatedFusionLayer(embed_dim, name='left_fusion')(view_features['L_CC'], left_cross_attn)
    right_fused = GatedFusionLayer(embed_dim, name='right_fusion')(view_features['R_CC'], right_cross_attn)

    left_pooled = tf.keras.layers.GlobalAveragePooling2D()(left_fused)
    right_pooled = tf.keras.layers.GlobalAveragePooling2D()(right_fused)

    visual_features = tf.keras.layers.Concatenate()([left_pooled, right_pooled])
    tabular_encoded = TabularEncoder(embed_dim, dropout_rate=0.3)(tabular_input)

    fused_features = tf.keras.layers.Concatenate()([visual_features, tabular_encoded])
    fused_features = tf.keras.layers.LayerNormalization()(fused_features)
    fused_features = tf.keras.layers.Dropout(0.4)(fused_features, training=dropout_training)
    fused_features = tf.keras.layers.Dense(embed_dim * 2, activation='gelu')(fused_features)
    fused_features = tf.keras.layers.LayerNormalization()(fused_features)
    fused_features = tf.keras.layers.Dropout(0.3)(fused_features, training=dropout_training)

    # Output heads run in float32: under the mixed_float16 policy, the final softmax
    # should not be computed in float16.
    birads_features = tf.keras.layers.Dense(embed_dim, activation='gelu')(fused_features)
    birads_features = tf.keras.layers.Dropout(0.3)(birads_features, training=dropout_training)
    birads_output = tf.keras.layers.Dense(
        config['birads_classes'], activation='softmax', dtype='float32', name='birads_head'
    )(birads_features)

    density_features = tf.keras.layers.Dense(embed_dim, activation='gelu')(fused_features)
    density_features = tf.keras.layers.Dropout(0.3)(density_features, training=dropout_training)
    density_output = tf.keras.layers.Dense(
        config['density_classes'], activation='softmax', dtype='float32', name='density_head'
    )(density_features)

    return tf.keras.Model(
        inputs=[L_CC_input, L_MLO_input, R_CC_input, R_MLO_input, tabular_input],
        outputs=[birads_output, density_output],
        name='Sanvia_Multimodal'
    )

**Focal Loss Implementation**
Focal loss with class weighting for multi-class classification.

In [24]:
def focal_loss_with_class_weights(num_classes: int, class_weights: List[float],
                                  gamma: float = 2.0, alpha: float = 0.25,
                                  label_smoothing: float = 0.0):
    """Build a class-weighted categorical focal loss for one-hot targets.

    Per sample: ``sum_c alpha_c * w_c * (1 - p_c)**gamma * (-y_c * log p_c)``. The
    result is averaged over the batch.

    Args:
        num_classes: number of classes (last dimension of y_true / y_pred).
        class_weights: one weight ``w_c`` per class, in class-index order.
        gamma: focusing parameter; 0 gives class-weighted cross-entropy.
        alpha: balancing factor. Either a scalar applied to every class, or a sequence
            of ``num_classes`` per-class factors. It was previously accepted but
            ignored; it is now applied.
        label_smoothing: if > 0, targets become ``y * (1 - s) + s / num_classes``.

    Returns:
        ``focal_loss(y_true, y_pred)`` returning a scalar float32 tensor.

    Raises:
        ValueError: if ``class_weights`` does not have ``num_classes`` entries.
    """
    class_weights = list(class_weights)
    if len(class_weights) != num_classes:
        raise ValueError(
            f"Expected {num_classes} class weights, got {len(class_weights)}: {class_weights}"
        )
    weight_row = tf.reshape(tf.constant(class_weights, dtype=tf.float32), [1, num_classes])
    # A scalar alpha becomes [1, 1] and is broadcast; a per-class alpha must match num_classes.
    alpha_row = tf.broadcast_to(
        tf.reshape(tf.constant(alpha, dtype=tf.float32), [1, -1]), [1, num_classes]
    )

    def focal_loss(y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        # Predictions may be float16 under mixed precision; clip to keep log() finite.
        y_pred = tf.clip_by_value(tf.cast(y_pred, tf.float32), 1e-7, 1.0 - 1e-7)

        if label_smoothing > 0:
            y_true = y_true * (1.0 - label_smoothing) + label_smoothing / num_classes

        cross_entropy = -y_true * tf.math.log(y_pred)
        modulating_factor = tf.pow(1.0 - y_pred, gamma)
        per_class_loss = alpha_row * weight_row * modulating_factor * cross_entropy

        return tf.reduce_mean(tf.reduce_sum(per_class_loss, axis=-1))

    return focal_loss

In [25]:
# Build and compile the model, smoke-test a forward pass, and set up training callbacks.
sanvia_model = build_sanvia_model(CONFIG)

# Index -> weight dicts in Keras `class_weight` format, kept for reference.
# The focal losses below take the same weights directly.
birads_class_weights = dict(enumerate(CLASS_WEIGHTS['birads'].values()))
density_class_weights = dict(enumerate(CLASS_WEIGHTS['density'].values()))


def _make_task_loss(task: str, num_classes: int):
    """Focal loss for one output head, configured from CONFIG and CLASS_WEIGHTS[task]."""
    # NOTE(review): weights are taken in the key order of class_weights.json, assumed
    # to match the one-hot class index — confirm against Notebook 1.
    return focal_loss_with_class_weights(
        num_classes=num_classes,
        class_weights=list(CLASS_WEIGHTS[task].values()),
        gamma=CONFIG.get('focal_gamma', 2.0),
        alpha=CONFIG.get('focal_alpha', 0.25),
    )


sanvia_model.compile(
    optimizer=tf.keras.optimizers.AdamW(
        learning_rate=CONFIG.get('learning_rate', 1e-4),
        weight_decay=CONFIG.get('weight_decay', 1e-5)
    ),
    loss={
        'birads_head': _make_task_loss('birads', CONFIG['birads_classes']),
        'density_head': _make_task_loss('density', CONFIG['density_classes']),
    },
    metrics={
        'birads_head': [tf.keras.metrics.CategoricalAccuracy(name='birads_acc')],
        'density_head': [tf.keras.metrics.CategoricalAccuracy(name='density_acc')]
    },
    loss_weights={'birads_head': 1.0, 'density_head': 1.0}
)

# Smoke-test a forward pass with inputs shaped from CONFIG (previously hard-coded 256x256x3).
dummy_image_shape = (1, *CONFIG['img_size'], CONFIG['num_channels'])
dummy_inputs = [tf.zeros(dummy_image_shape) for _ in range(4)] + [tf.zeros((1, 2))]
_ = sanvia_model(dummy_inputs)
# len(trainable_variables) counts weight tensors, not parameters — report both.
trainable_params = sum(int(np.prod(v.shape)) for v in sanvia_model.trainable_variables)
print(f"Completed Build Model. Trainable variables: {len(sanvia_model.trainable_variables):,}, "
      f"trainable parameters: {trainable_params:,}")

print("\n" + "="*80)
sanvia_model.summary()
print("="*80)

CONFIG['epochs'] = 100

# Callbacks
callbacks = [
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=15,
        restore_best_weights=True,
        verbose=1
    ),
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        min_lr=1e-6,
        verbose=1
    )
]

# Ceiling division: a final, partial batch still counts as a step.
steps_per_epoch = -(-len(train_df) // CONFIG['batch_size'])
validation_steps = -(-len(val_df) // CONFIG['batch_size'])

print(f"steps_per_epoch: {steps_per_epoch}")
print(f"validation_steps: {validation_steps}")



Completed Build Model. Variables: 197



steps_per_epoch: 63
validation_steps: 8


Train Model

In [26]:
# steps_per_epoch / validation_steps are deliberately not passed. Each epoch is then one
# full pass over the finite datasets, the same ceil(N / batch_size) steps computed above,
# and Keras restarts the datasets every epoch. With explicit steps and a from_generator
# dataset of unknown cardinality, Keras can keep one iterator across epochs and stop
# with "Your input ran out of data" after the first epoch.
history = sanvia_model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=CONFIG['epochs'],
    callbacks=callbacks,
    verbose=1
)

Epoch 1/100


KeyboardInterrupt: 

**Plot Training Curves**
Plot training and validation loss curves, plus per-task accuracy curves.

In [None]:
# Plot loss and per-task accuracy from the Keras History object and save the figure.
history_dict = history.history


def _find_history_key(suffix: str, validation: bool) -> Optional[str]:
    """Return the History key ending in ``suffix`` for train (no ``val_`` prefix) or validation.

    Keras prefixes per-output metrics with the output name, so match on the metric-name
    suffix instead of hard-coding the full key.
    """
    for key in history_dict:
        if key.endswith(suffix) and key.startswith('val_') == validation:
            return key
    return None


fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))

loss_ax.plot(history_dict['loss'], label='Training loss', marker='o')
if 'val_loss' in history_dict:
    loss_ax.plot(history_dict['val_loss'], label='Validation loss', marker='s')
loss_ax.set_xlabel('Epoch', fontsize=12)
loss_ax.set_ylabel('Loss', fontsize=12)
loss_ax.set_title('Curves training and validation', fontsize=14, fontweight='bold')
loss_ax.legend()
loss_ax.grid(True, alpha=0.3)

# compile() names the accuracy metrics 'birads_acc' / 'density_acc', so the previous
# keys 'birads_head_accuracy' / 'density_head_accuracy' are not in History.
# English labels: matplotlib does not shape Arabic script, so Arabic labels render garbled.
for metric_suffix, task_label, marker in (('birads_acc', 'BI-RADS', 'o'),
                                          ('density_acc', 'Density', 's')):
    for validation in (False, True):
        key = _find_history_key(metric_suffix, validation)
        if key is None:
            continue
        split_label = 'val' if validation else 'train'
        acc_ax.plot(history_dict[key], label=f'{task_label} accuracy ({split_label})',
                    marker=marker, linestyle='--' if validation else '-')
acc_ax.set_xlabel('Epoch', fontsize=12)
acc_ax.set_ylabel('Accuracy', fontsize=12)
acc_ax.set_title('Accuracy tasks', fontsize=14, fontweight='bold')
acc_ax.legend()
acc_ax.grid(True, alpha=0.3)

fig.tight_layout()

plot_dir = Path(CONFIG['output_dir'])
plot_dir.mkdir(parents=True, exist_ok=True)
plot_path = plot_dir / 'training_curves_fit.png'
fig.savefig(plot_path, dpi=300, bbox_inches='tight')
print(f"\nCurves Saved in: {plot_path}")

plt.show()

In [None]:
def evaluate_model(model: tf.keras.Model, dataset: tf.data.Dataset,
                   name: str = "Validation") -> Dict[str, float]:
    """Compute accuracy and AUC for both heads over ``dataset``, then print and return them.

    Args:
        model: two-output model returning ``[birads_probs, density_probs]``.
        dataset: yields ``(inputs, (birads_onehot, density_onehot))`` batches.
        name: label used in the printed header.

    Returns:
        Dict with float values for keys ``'birads'`` and ``'density'`` (categorical
        accuracy) and ``'birads_auc'`` and ``'density_auc'``.
    """
    metrics = {
        'birads': tf.keras.metrics.CategoricalAccuracy(),
        'density': tf.keras.metrics.CategoricalAccuracy(),
        # multi_label=True on one-hot targets: AUC per class column, then averaged.
        'birads_auc': tf.keras.metrics.AUC(multi_label=True),
        'density_auc': tf.keras.metrics.AUC(multi_label=True)
    }

    for inputs, (birads_true, density_true) in dataset:
        # The model was built with a list of inputs; pass a list rather than the dataset's tuple.
        birads_pred, density_pred = model(list(inputs), training=False)

        metrics['birads'].update_state(birads_true, birads_pred)
        metrics['density'].update_state(density_true, density_pred)
        metrics['birads_auc'].update_state(birads_true, birads_pred)
        metrics['density_auc'].update_state(density_true, density_pred)

    results = {key: float(metric.result()) for key, metric in metrics.items()}

    print(f"\nResults{name}:")
    print(f"   ACC BI-RADS: {results['birads']:.4f}")
    print(f"   ACC Density: {results['density']:.4f}")
    print(f"   AUC BI-RADS: {results['birads_auc']:.4f}")
    print(f"   AUC Density: {results['density_auc']:.4f}")
    return results

val_results = evaluate_model(sanvia_model, val_dataset, name="Val")
test_results = evaluate_model(sanvia_model, test_dataset, name="Test")


In [None]:
def save_final_model(model: tf.keras.Model, config: Dict):
    """Save the model, its weights and its architecture JSON under ``config['output_dir']``.

    Files written (the directory is created if needed):
      - ``sanvia_Final_model.h5``: full model in HDF5 format.
      - ``sanvia_Final.weights.h5``: weights only. Keras 3 rejects ``save_weights``
        paths that do not end in ``.weights.h5``; older tf.keras accepts the name as HDF5.
      - ``sanvia_Final_architecture.json``: output of ``model.to_json()``.
    """
    output_dir = Path(config['output_dir'])
    output_dir.mkdir(parents=True, exist_ok=True)

    model.save(str(output_dir / 'sanvia_Final_model.h5'))
    model.save_weights(str(output_dir / 'sanvia_Final.weights.h5'))
    with open(output_dir / 'sanvia_Final_architecture.json', 'w') as f:
        f.write(model.to_json())
    print(f"\n💾 Done Saving Model in: {output_dir}")

save_final_model(sanvia_model, CONFIG)
