In [None]:
import os, random, sys
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, regularizers
from tensorflow.keras.utils import to_categorical
from sklearn.preprocessing import LabelEncoder
import cv2
from tqdm import tqdm
import matplotlib.pyplot as plt

import tensorflow as tf
tf.keras.backend.clear_session()

# Which stream(s) `main_train()` uses by default:
#   'moj'    -> 3-channel MOJ images
#   'ed'     -> 1-channel ED-MHI images
#   'fusion' -> both streams
MODE = 'fusion'

# Dataset roots; each must contain view_1/, view_2/ and view_3/ sub-folders.
MOJ_ROOT = '/kaggle/input/n-ucla-moj'
ED_MHI_ROOT = '/kaggle/input/n-ucla-edmhi'

# training params
SEED = 0
BATCH_SIZE = 8
EPOCHS_HEAD = 10          # stage 1: encoder(s) frozen, classifier head only
EPOCHS_FINETUNE = 40      # stage 2: everything trainable
LR_HEAD = 1e-3
LR_FINETUNE = 5e-5
LAMBDA_CV_MOJ = 0.01      # weight of the cross-view consistency term, MOJ stream
LAMBDA_CV_ED = 0.1        # weight of the cross-view consistency term, ED-MHI stream
TARGET_SIZE = (224, 224)  # (height, width) of the encoder inputs
NUM_WORKERS = 2           # not referenced elsewhere in this notebook

PATIENCE = 8              # not referenced elsewhere in this notebook
# Keras 3 `save_weights` requires a '.weights.h5' suffix; tf.keras 2.x writes HDF5
# for any '.h5' path, so this name works with both.
MODEL_SAVE_PATH = "/kaggle/working/best_cvcl_model.weights.h5"

DETERMINISTIC = False

# Seed the Python, NumPy and TensorFlow RNGs. This covers weight init,
# augmentation and fit() shuffling. GPU kernels can still be nondeterministic
# unless DETERMINISTIC is set.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)
if DETERMINISTIC:
    # Force deterministic TF ops (slower; requires TF >= 2.9).
    tf.config.experimental.enable_op_determinism()



In [None]:
import re
def parse_subject_id(fname):
    """Return the first subject token ("s" + two digits, e.g. "s03") in `fname`, or None."""
    match = re.search(r"s\d{2}", fname)
    if match is None:
        return None
    return match.group(0)

def parse_action_id(fname):
    """Return the leading "aNN_sNN_eNN" action id of `fname`, or None if it does not start with one."""
    found = re.match(r"(a\d{2}_s\d{2}_e\d{2})", fname)
    return None if found is None else found.group(1)

def collect_samples(root_dir, views=("view_1", "view_2", "view_3")):
    """Index the image files of each camera view under `root_dir`.

    Expects one sub-directory per view (``root_dir/view_1`` ... ``view_3`` by
    default). Keeps only ``.jpg``/``.png`` files (extension matched
    case-insensitively) whose names start with an action id and contain a
    subject id.

    Args:
        root_dir: dataset root containing one folder per view.
        views: sub-directory names to scan; defaults to the three views used
            throughout this notebook.

    Returns:
        dict view -> {action_id: (path, subject_id)}

    Raises:
        FileNotFoundError: if a view directory does not exist (raised by os.listdir).
    """
    view_samples = {}
    for view in views:
        view_path = os.path.join(root_dir, view)
        samples = {}
        # os.listdir returns entries in arbitrary order. If two files map to the
        # same action id the later one wins, so sort to make that choice reproducible.
        for fname in sorted(os.listdir(view_path)):
            if not fname.lower().endswith((".jpg", ".png")):
                continue
            sid = parse_subject_id(fname)
            aid = parse_action_id(fname)
            if sid and aid:
                samples[aid] = (os.path.join(view_path, fname), sid)
        view_samples[view] = samples
    return view_samples

def intersect_views(view_samples):
    """Return the set of action ids that appear in all of view_1, view_2 and view_3."""
    common = set(view_samples["view_1"].keys())
    common &= set(view_samples["view_2"].keys())
    common &= set(view_samples["view_3"].keys())
    return common


In [None]:

import re
# parse_subject_id, parse_action_id, collect_samples and intersect_views are
# defined in the previous cell. They are deliberately not redefined here:
# a second verbatim copy silently shadows the first, so edits made to the
# earlier definitions would be ignored.


In [None]:
def load_image(path, target_size=(224,224), is_moj=True):
    """Read one image from disk as a float32 array scaled into [0, 1].

    Args:
        path: image file path.
        target_size: (height, width) of the returned image. This is the order the
            encoders' input shapes use: ``(TARGET_SIZE[0], TARGET_SIZE[1], C)``.
        is_moj: True -> 3-channel RGB image divided by 255, shape (H, W, 3).
            False -> single-channel image divided by its own maximum,
            shape (H, W, 1).

    Raises:
        FileNotFoundError: if OpenCV cannot read `path`. cv2.imread returns None
            instead of raising, which otherwise fails later with a cryptic cv2 error.
    """
    height, width = target_size
    flag = cv2.IMREAD_COLOR if is_moj else cv2.IMREAD_GRAYSCALE
    img = cv2.imread(path, flag)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")

    # cv2.resize expects dsize as (width, height); target_size is (height, width).
    if is_moj:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (width, height))
        img = img.astype("float32") / 255.0
    else:
        img = cv2.resize(img, (width, height))
        img = np.expand_dims(img, axis=-1).astype("float32")
        # Per-image max normalisation; the epsilon keeps all-black frames at 0 instead of NaN.
        img = img / (img.max() + 1e-6)

    return img

def build_triplet_dataset(moj_samples, ed_mhi_samples, sample_list, target_size=(224,224), le=None):
    """Load the three-view MOJ and ED-MHI images for every action id in `sample_list`.

    Args:
        moj_samples, ed_mhi_samples: view -> {action_id: (path, subject)} mappings
            as produced by `collect_samples`.
        sample_list: action ids to load. Each must be present in view_1..view_3
            of both mappings, otherwise a KeyError is raised.
        target_size: forwarded to `load_image`.
        le: an already-fitted LabelEncoder to reuse (e.g. the training one). When
            None, a new encoder is fitted on the labels of `sample_list`.

    Returns:
        (X1_moj, X2_moj, X3_moj, X1_ed, X2_ed, X3_ed, y_onehot, le, num_classes).
        The X arrays are stacked in `sample_list` order. Each sample's label is
        the part of its action id before the first "_" (e.g. "a01").
    """
    view_names = ("view_1", "view_2", "view_3")
    moj_per_view = [[] for _ in view_names]
    ed_per_view = [[] for _ in view_names]
    labels = []

    for aid in tqdm(sample_list, desc="Building triplet dataset"):
        # Resolve all three paths first, then load.
        moj_entries = [moj_samples[v][aid] for v in view_names]
        for bucket, (path, _subject) in zip(moj_per_view, moj_entries):
            bucket.append(load_image(path, target_size, True))

        ed_entries = [ed_mhi_samples[v][aid] for v in view_names]
        for bucket, (path, _subject) in zip(ed_per_view, ed_entries):
            bucket.append(load_image(path, target_size, False))

        labels.append(aid.split("_")[0])

    if le is not None:
        y_enc = le.transform(labels)
    else:
        le = LabelEncoder()
        y_enc = le.fit_transform(labels)
    num_classes = len(le.classes_)
    y_cat = to_categorical(y_enc, num_classes=num_classes)

    moj_arrays = tuple(np.array(bucket) for bucket in moj_per_view)
    ed_arrays = tuple(np.array(bucket) for bucket in ed_per_view)
    return moj_arrays + ed_arrays + (y_cat, le, num_classes)


In [None]:
def get_augmentation_layer():
    """Return the geometric augmentation stack used in front of the RGB (MOJ) encoder.

    The stack is: horizontal flip, small rotation, small zoom, small translation.
    """
    augment_ops = [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.08),
        layers.RandomZoom(0.08),
        layers.RandomTranslation(0.05, 0.05),
    ]
    return tf.keras.Sequential(augment_ops, name="data_augmentation")
def get_augmentation_layer_ed():
    """Return the augmentation used in front of the grayscale (ED-MHI) encoder: horizontal flip only."""
    flip = layers.RandomFlip("horizontal")
    return tf.keras.Sequential([flip])

def build_encoder_rgb(input_shape=(224,224,3), use_augment=True, dropout_rate=0.3):
    """Build the CNN that maps one RGB view to a 256-d, layer-normalised feature.

    Architecture:
        [augmentation] -> 4 x (Conv2D + ReLU -> BatchNorm), with MaxPooling after
        all but the last conv -> GlobalAveragePooling -> Dropout
        -> Dense(256, ReLU) -> LayerNormalization.

    Args:
        input_shape: (H, W, 3) input size.
        use_augment: prepend `get_augmentation_layer()` when True.
        dropout_rate: dropout applied to the pooled features.
    """
    inputs = layers.Input(shape=input_shape)
    h = get_augmentation_layer()(inputs) if use_augment else inputs

    conv_specs = [(32, (7, 7)), (64, (5, 5)), (128, (3, 3)), (256, (3, 3))]
    last_idx = len(conv_specs) - 1
    for idx, (filters, kernel) in enumerate(conv_specs):
        h = layers.Conv2D(filters, kernel, activation='relu', padding='same')(h)
        h = layers.BatchNormalization()(h)
        if idx < last_idx:
            h = layers.MaxPooling2D()(h)

    h = layers.GlobalAveragePooling2D()(h)
    h = layers.Dropout(dropout_rate)(h)
    h = layers.Dense(256, activation='relu')(h)
    h = layers.LayerNormalization()(h)
    return tf.keras.Model(inputs, h, name="encoder_rgb")

def build_encoder_gray(input_shape=(224,224,1), use_augment=False, dropout_rate=0.3):
    """Build the CNN that maps one grayscale (ED-MHI) view to a 256-d, layer-normalised feature.

    Architecture:
        [flip augmentation] -> 3 x (Conv2D + ReLU -> BatchNorm), with MaxPooling
        after all but the last conv -> GlobalAveragePooling -> Dropout
        -> Dense(256, ReLU) -> LayerNormalization.

    Args:
        input_shape: (H, W, 1) input size.
        use_augment: prepend `get_augmentation_layer_ed()` when True.
        dropout_rate: dropout applied to the pooled features.
    """
    inputs = layers.Input(shape=input_shape)
    h = get_augmentation_layer_ed()(inputs) if use_augment else inputs

    conv_specs = [(64, (7, 7)), (128, (5, 5)), (256, (3, 3))]
    last_idx = len(conv_specs) - 1
    for idx, (filters, kernel) in enumerate(conv_specs):
        h = layers.Conv2D(filters, kernel, activation='relu', padding='same')(h)
        h = layers.BatchNormalization()(h)
        if idx < last_idx:
            h = layers.MaxPooling2D()(h)

    h = layers.GlobalAveragePooling2D()(h)
    h = layers.Dropout(dropout_rate)(h)
    h = layers.Dense(256, activation='relu')(h)
    h = layers.LayerNormalization()(h)
    return tf.keras.Model(inputs, h, name="encoder_gray")

def build_classifier(num_classes, input_dim=256):
    """Build the MLP head: input_dim -> 256 -> 128 -> num_classes softmax probabilities.

    Each hidden Dense(ReLU) layer is followed by Dropout(0.3).
    """
    feats = layers.Input(shape=(input_dim,))
    h = feats
    for width in (256, 128):
        h = layers.Dense(width, activation='relu')(h)
        h = layers.Dropout(0.3)(h)
    probs = layers.Dense(num_classes, activation='softmax')(h)
    return tf.keras.Model(feats, probs, name="classifier")




In [None]:
class CrossViewModel(tf.keras.Model):
    """Three-view action classifier trained with a cross-view consistency penalty.

    mode: 'moj' | 'ed' | 'fusion'
    - 'moj':    encoder_rgb encodes x1, x2, x3; the three features are averaged.
    - 'ed':     encoder_gray encodes x1, x2, x3; the three features are averaged.
    - 'fusion': both encoders are used. The two view-averaged features are
                concatenated as [rgb, gray] before the classifier.

    Training objective:
        CE(y, classifier(features)) + lambda_cv * CV(f1, f2, f3), for each encoder in use.
    CV is the sum of the mean squared differences between every pair of
    L2-normalised view features. `loss_ce` uses the Keras default
    from_logits=False, so `classifier` must output probabilities, as the
    softmax head from `build_classifier` does.

    Logged metrics:
    - training: loss, accuracy, cv_loss
    - evaluation: loss, accuracy (reported by `fit` as val_loss / val_accuracy)
    All are running means over the epoch rather than last-batch values.
    """

    _MODES = ('moj', 'ed', 'fusion')

    def __init__(self, encoder_rgb=None, encoder_gray=None, classifier=None,
                 mode='moj', lambda_cv_moj=0.01, lambda_cv_ed=0.01):
        super().__init__()
        # Fail fast rather than on the first forward pass.
        if mode not in self._MODES:
            raise ValueError(f"Unknown mode: {mode!r}")
        if mode in ('moj', 'fusion') and encoder_rgb is None:
            raise ValueError(f"mode={mode!r} requires encoder_rgb")
        if mode in ('ed', 'fusion') and encoder_gray is None:
            raise ValueError(f"mode={mode!r} requires encoder_gray")
        if classifier is None:
            raise ValueError("classifier is required")

        self.encoder_rgb = encoder_rgb
        self.encoder_gray = encoder_gray
        self.classifier = classifier
        self.mode = mode
        self.loss_ce = tf.keras.losses.CategoricalCrossentropy()
        self.lambda_cv_moj = lambda_cv_moj
        self.lambda_cv_ed = lambda_cv_ed

        # Running means, so the logged loss is an epoch average rather than the last batch.
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.train_acc = tf.keras.metrics.CategoricalAccuracy(name="accuracy")
        self.cv_tracker = tf.keras.metrics.Mean(name="cv_loss")
        self.val_loss_tracker = tf.keras.metrics.Mean(name="val_loss")
        self.val_acc = tf.keras.metrics.CategoricalAccuracy(name="val_accuracy")

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them between epochs / evaluations.
        return [self.loss_tracker, self.train_acc, self.cv_tracker,
                self.val_loss_tracker, self.val_acc]

    def reset_metrics(self):
        """Reset every running metric of this model."""
        for metric in self.metrics:
            metric.reset_state()

    @staticmethod
    def _mean_of_views(features):
        """Average the three per-view feature tensors."""
        return (features[0] + features[1] + features[2]) / 3.0

    @staticmethod
    def _view_consistency(features):
        """Sum of pairwise mean squared differences between the L2-normalised view features."""
        f1n, f2n, f3n = [tf.math.l2_normalize(f, axis=1) for f in features]
        return (tf.reduce_mean(tf.square(f1n - f2n)) +
                tf.reduce_mean(tf.square(f2n - f3n)) +
                tf.reduce_mean(tf.square(f1n - f3n)))

    def _encode(self, inputs, training):
        """Encode all views for the current mode.

        Returns:
            (classifier_input, groups), where groups is a list of
            (lambda_cv, [f1, f2, f3]) with one entry per encoder used.
        """
        if self.mode == 'moj':
            x1, x2, x3 = inputs
            feats = [self.encoder_rgb(x, training=training) for x in (x1, x2, x3)]
            return self._mean_of_views(feats), [(self.lambda_cv_moj, feats)]
        if self.mode == 'ed':
            x1, x2, x3 = inputs
            feats = [self.encoder_gray(x, training=training) for x in (x1, x2, x3)]
            return self._mean_of_views(feats), [(self.lambda_cv_ed, feats)]
        if self.mode == 'fusion':
            x1_m, x2_m, x3_m, x1_e, x2_e, x3_e = inputs
            feats_m = [self.encoder_rgb(x, training=training) for x in (x1_m, x2_m, x3_m)]
            feats_e = [self.encoder_gray(x, training=training) for x in (x1_e, x2_e, x3_e)]
            # tf.concat rather than a new layers.Concatenate() created on every call.
            fused = tf.concat([self._mean_of_views(feats_m), self._mean_of_views(feats_e)], axis=1)
            return fused, [(self.lambda_cv_moj, feats_m), (self.lambda_cv_ed, feats_e)]
        raise ValueError("Unknown mode")

    def call(self, inputs, training=False):
        """Return class probabilities.

        `inputs` is [x1, x2, x3] for 'moj'/'ed', or the 3 MOJ views followed by
        the 3 ED views for 'fusion'.
        """
        feat, _ = self._encode(inputs, training)
        return self.classifier(feat, training=training)

    def train_step(self, data):
        """One optimisation step on CE + weighted cross-view consistency."""
        x, y = data
        with tf.GradientTape() as tape:
            feat, view_groups = self._encode(x, training=True)
            probs = self.classifier(feat, training=True)
            loss = self.loss_ce(y, probs)
            cv_terms = []
            for lam, feats in view_groups:
                cv = self._view_consistency(feats)
                cv_terms.append(cv)
                loss = loss + lam * cv
        cv_total = tf.add_n(cv_terms)  # unweighted sum, for logging only

        # Frozen encoders contribute no variables here.
        trainable_vars = self.trainable_variables
        grads = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(grads, trainable_vars))

        self.loss_tracker.update_state(loss)
        self.train_acc.update_state(y, probs)
        self.cv_tracker.update_state(cv_total)
        return {
            "loss": self.loss_tracker.result(),
            "accuracy": self.train_acc.result(),
            "cv_loss": self.cv_tracker.result(),
        }

    def test_step(self, data):
        """Evaluate CE loss and accuracy (no CV term) in inference mode."""
        x, y = data
        feat, _ = self._encode(x, training=False)
        probs = self.classifier(feat, training=False)
        self.val_loss_tracker.update_state(self.loss_ce(y, probs))
        self.val_acc.update_state(y, probs)
        return {
            "loss": self.val_loss_tracker.result(),
            "accuracy": self.val_acc.result(),
        }





In [None]:
def make_callbacks(patience=6, lr_patience=3, lr_factor=0.5):
    """Build EarlyStopping + ReduceLROnPlateau callbacks for `model.fit`.

    Args:
        patience: epochs without `val_accuracy` improvement before stopping;
            the best weights are then restored.
        lr_patience: epochs without `val_loss` improvement before the LR is cut.
        lr_factor: multiplier applied to the learning rate on each plateau.

    Returns:
        list of Keras callbacks.

    Both callbacks monitor `val_*` metrics, i.e. whatever is passed as
    `validation_data`. If that is the test split, early stopping with
    restore_best_weights selects the model on test data.
    """
    return [
        tf.keras.callbacks.EarlyStopping(
            monitor='val_accuracy',
            mode='max',
            patience=patience,
            restore_best_weights=True,
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=lr_factor,
            patience=lr_patience,
        ),
    ]


def plot_history(history):
    """Plot train-vs-validation loss and accuracy, one figure each.

    Args:
        history: object with a ``.history`` dict of per-epoch metric lists, such
            as a Keras ``History`` or the merged object returned by
            ``main_train``. Series that are absent (e.g. ``val_*`` when fit ran
            without validation data) are skipped instead of raising KeyError.
            Each series uses its own length for the x axis.
    """
    hist = history.history
    panels = (
        ("Train vs Val Loss", "Loss",
         (("loss", "Train Loss"), ("val_loss", "Val Loss"))),
        ("Train vs Val Accuracy", "Accuracy",
         (("accuracy", "Train Acc"), ("val_accuracy", "Val Acc"))),
    )
    for title, ylabel, series in panels:
        fig, ax = plt.subplots(figsize=(12, 5))
        for key, label in series:
            values = hist.get(key)
            if values:
                ax.plot(range(len(values)), values, label=label)
        ax.set_title(title)
        ax.set_xlabel("Epochs")
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)
        plt.show()




In [None]:
def _subject_split(sample_ids, train_subjects, test_subjects):
    """Split action ids into sorted (train_ids, test_ids) lists by their subject token.

    Sorting makes the sample order independent of set iteration order, which
    for str items can differ between Python processes.
    """
    train_ids = sorted(aid for aid in sample_ids if parse_subject_id(aid) in train_subjects)
    test_ids = sorted(aid for aid in sample_ids if parse_subject_id(aid) in test_subjects)
    return train_ids, test_ids


def _set_encoders_trainable(mode, encoder_rgb, encoder_gray, trainable):
    """Set `.trainable` on the encoder(s) that `mode` uses.

    Callers re-compile afterwards so the new trainable set is picked up.
    """
    if mode in ('moj', 'fusion'):
        encoder_rgb.trainable = trainable
    if mode in ('ed', 'fusion'):
        encoder_gray.trainable = trainable


def _per_mode_weights_path(base_path, mode):
    """Insert `_<mode>` before the file suffix so runs of different modes don't overwrite each other.

    '.weights.h5' is kept intact as a single suffix.
    """
    if base_path.endswith(".weights.h5"):
        suffix = ".weights.h5"
    else:
        suffix = os.path.splitext(base_path)[1]
    stem = base_path[:len(base_path) - len(suffix)]
    return f"{stem}_{mode}{suffix}"


def _merge_histories(*histories):
    """Concatenate the per-epoch lists of several Keras History objects, key by key."""
    from types import SimpleNamespace

    merged = {}
    for hist in histories:
        for key, values in hist.history.items():
            merged.setdefault(key, []).extend(values)
    return SimpleNamespace(history=merged)


def _plot_confusion_matrix(y_true, y_pred, class_names):
    """Draw an annotated confusion matrix with one row/column per class.

    Rows and columns follow the order of `class_names`, including classes
    absent from both y_true and y_pred.
    """
    from sklearn.metrics import confusion_matrix

    label_ids = list(range(len(class_names)))
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)

    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(cm)
    ax.set_title("Confusion Matrix")
    ax.set_xlabel("Predicted Label")
    ax.set_ylabel("True Label")
    ax.set_xticks(label_ids)
    ax.set_xticklabels(class_names, rotation=90)
    ax.set_yticks(label_ids)
    ax.set_yticklabels(class_names)
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, cm[i, j], ha='center', va='center')
    fig.colorbar(im, ax=ax)
    fig.tight_layout()
    plt.show()


def main_train(mode=MODE):
    """Train and evaluate one CrossViewModel variant on the cross-subject split.

    Pipeline:
    1. Index both datasets and keep action ids present in all three views of both.
    2. Split by subject: s01-s09 train, s10 test.
    3. Load the images.
    4. Train in two stages:
       (1) encoder(s) frozen, head only, LR_HEAD for EPOCHS_HEAD epochs;
       (2) everything trainable, LR_FINETUNE for EPOCHS_FINETUNE more epochs.
    5. Evaluate, plot curves and a confusion matrix, and save weights to a
       per-mode variant of MODEL_SAVE_PATH.

    NOTE: the test split doubles as `validation_data`, so the val_* curves are
    test-set curves.

    Args:
        mode: 'moj' | 'ed' | 'fusion' (see CrossViewModel).

    Returns:
        (model, merged_history, label_encoder). merged_history.history holds
        the concatenated per-epoch logs of both stages.

    Raises:
        ValueError: unknown mode, or an empty train/test split.
    """
    if mode not in ('moj', 'ed', 'fusion'):
        raise ValueError("Unknown mode")

    # ---- data ----
    moj_samples = collect_samples(MOJ_ROOT)
    ed_mhi_samples = collect_samples(ED_MHI_ROOT)

    # Keep only action ids present in all three views of *both* datasets.
    all_common = intersect_views(moj_samples) & intersect_views(ed_mhi_samples)
    print("Total common samples:", len(all_common))

    # Cross-subject split: s01-s09 train, s10 test.
    train_subjects = {f"s{i:02d}" for i in range(1, 10)}
    train_ids, test_ids = _subject_split(all_common, train_subjects, {"s10"})
    print("Train samples:", len(train_ids), "Test samples:", len(test_ids))
    if not train_ids or not test_ids:
        raise ValueError(
            f"Empty split (train={len(train_ids)}, test={len(test_ids)}); "
            "check MOJ_ROOT / ED_MHI_ROOT and the file naming.")

    (X1_moj, X2_moj, X3_moj,
     X1_ed, X2_ed, X3_ed,
     Y, le, num_classes) = build_triplet_dataset(moj_samples, ed_mhi_samples, train_ids, target_size=TARGET_SIZE)

    # Reuse the training LabelEncoder so class indices agree between splits.
    (X1_moj_test, X2_moj_test, X3_moj_test,
     X1_ed_test, X2_ed_test, X3_ed_test,
     Y_test, _, _) = build_triplet_dataset(moj_samples, ed_mhi_samples, test_ids, target_size=TARGET_SIZE, le=le)

    print("num_classes:", num_classes)

    # ---- model ----
    # Built fresh on every call so separate runs never share weights; only the
    # encoder(s) the mode needs are created. They start from random init.
    use_rgb = mode in ('moj', 'fusion')
    use_gray = mode in ('ed', 'fusion')
    encoder_rgb = (build_encoder_rgb(input_shape=(TARGET_SIZE[0], TARGET_SIZE[1], 3), use_augment=True)
                   if use_rgb else None)
    encoder_gray = (build_encoder_gray(input_shape=(TARGET_SIZE[0], TARGET_SIZE[1], 1), use_augment=True)
                    if use_gray else None)

    if mode == 'moj':
        classifier = build_classifier(num_classes, input_dim=256)
        model = CrossViewModel(encoder_rgb=encoder_rgb, classifier=classifier, mode='moj',
                               lambda_cv_moj=LAMBDA_CV_MOJ)
        train_inputs = [X1_moj, X2_moj, X3_moj]
        val_inputs = [X1_moj_test, X2_moj_test, X3_moj_test]
    elif mode == 'ed':
        classifier = build_classifier(num_classes, input_dim=256)
        model = CrossViewModel(encoder_gray=encoder_gray, classifier=classifier, mode='ed',
                               lambda_cv_ed=LAMBDA_CV_ED)
        train_inputs = [X1_ed, X2_ed, X3_ed]
        val_inputs = [X1_ed_test, X2_ed_test, X3_ed_test]
    else:  # 'fusion': both 256-d streams are concatenated -> 512-d classifier input
        classifier = build_classifier(num_classes, input_dim=512)
        model = CrossViewModel(encoder_rgb=encoder_rgb, encoder_gray=encoder_gray, classifier=classifier,
                               mode='fusion', lambda_cv_moj=LAMBDA_CV_MOJ, lambda_cv_ed=LAMBDA_CV_ED)
        train_inputs = [X1_moj, X2_moj, X3_moj, X1_ed, X2_ed, X3_ed]
        val_inputs = [X1_moj_test, X2_moj_test, X3_moj_test, X1_ed_test, X2_ed_test, X3_ed_test]

    for title, sub_model in (("ENCODER RGB", encoder_rgb),
                             ("ENCODER GRAY", encoder_gray),
                             ("CLASSIFIER", classifier)):
        if sub_model is not None:
            print(f"\n====== {title} SUMMARY ======")
            sub_model.summary()

    # A subclassed model has no known input shapes until it has been called,
    # and tf.keras may refuse to summarise it before that. One forward pass on
    # a single sample builds it.
    model([tf.convert_to_tensor(x[:1]) for x in train_inputs], training=False)
    model.summary()

    # ---- stage 1: freeze encoder(s), warm up the classifier head ----
    _set_encoders_trainable(mode, encoder_rgb, encoder_gray, False)
    model.compile(optimizer=tf.keras.optimizers.Adam(LR_HEAD))  # compile *after* toggling trainable
    print("Stage 1 (warm-up) training with encoder frozen...")
    history1 = model.fit(
        train_inputs, Y,
        validation_data=(val_inputs, Y_test),
        epochs=EPOCHS_HEAD, batch_size=BATCH_SIZE, verbose=1
    )

    # ---- stage 2: unfreeze and fine-tune ----
    _set_encoders_trainable(mode, encoder_rgb, encoder_gray, True)
    model.compile(optimizer=tf.keras.optimizers.Adam(LR_FINETUNE))
    print("Stage 2 (fine-tune) training with encoder unfrozen...")
    # Continue the epoch count right after stage 1 (0-based indices
    # start_epoch .. start_epoch + EPOCHS_FINETUNE - 1), so exactly
    # EPOCHS_FINETUNE epochs run and none is repeated.
    start_epoch = len(history1.epoch)
    history2 = model.fit(
        train_inputs, Y,
        validation_data=(val_inputs, Y_test),
        initial_epoch=start_epoch,
        epochs=start_epoch + EPOCHS_FINETUNE,
        batch_size=BATCH_SIZE, verbose=1
    )

    merged = _merge_histories(history1, history2)

    # ---- evaluate & report ----
    eval_results = model.evaluate(val_inputs, Y_test, verbose=0)
    print("Final eval (on test set):", eval_results)

    plot_history(merged)

    # Best-effort save: a failure is reported but does not abort the run.
    weights_path = _per_mode_weights_path(MODEL_SAVE_PATH, mode)
    try:
        model.save_weights(weights_path)
        print("Saved weights to", weights_path)
    except Exception as e:
        print("Could not save weights:", e)

    y_pred = np.argmax(model.predict(val_inputs), axis=1)
    y_true = np.argmax(Y_test, axis=1)
    _plot_confusion_matrix(y_true, y_pred, [str(c) for c in le.classes_])

    return model, merged, le


In [None]:
# Train and evaluate the MOJ-only (3-channel) variant. Rebinds the globals
# model/history/le, which the later run cells overwrite in turn.
if __name__ == "__main__":

    model, history, le = main_train(mode='moj')


In [None]:
# Train and evaluate the ED-MHI-only (1-channel) variant. Overwrites the
# model/history/le produced by the previous run cell.
if __name__ == "__main__":
    model, history, le = main_train(mode='ed')
    


In [None]:
# Train and evaluate the fused MOJ + ED-MHI variant. After Run All, these
# model/history/le are the ones left in the notebook namespace.
if __name__ == "__main__":
    model, history, le = main_train(mode='fusion')
