In [1]:
!pip install -q flax optax

In [2]:
import os
import pandas as pd
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, roc_auc_score

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.training import train_state
import optax
from tqdm import tqdm
import csv

2025-06-21 21:47:34.584136: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750542454.928942      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750542455.025904      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
DATASET_PATH = "/kaggle/input/140k-real-and-fake-faces"
IMAGE_DIR = os.path.join(DATASET_PATH, "real_vs_fake")

train_df = pd.read_csv(os.path.join(DATASET_PATH, "train.csv"))
val_df = pd.read_csv(os.path.join(DATASET_PATH, "valid.csv"))

# Map labels: REAL → 0, FAKE → 1
label_map = {'REAL': 0, 'FAKE': 1}
train_df['label'] = train_df['label'].map(label_map)
val_df['label'] = val_df['label'].map(label_map)

In [10]:
BASE_IMAGE_PATH = "/kaggle/input/140k-real-and-fake-faces/real_vs_fake/real-vs-fake"

In [11]:
def load_dataset_from_df(df, base_image_dir, image_size=(224, 224), batch_size=32):
    def process(file_path, label):
        img = tf.io.read_file(file_path)
        img = tf.image.decode_jpeg(img, channels=3)
        img = tf.image.resize(img, image_size)
        img = img / 255.0
        return img, label

    # ✅ Join full image path using `path` from CSV (e.g., "train/real/31355.jpg")
    image_paths = [os.path.join(base_image_dir, rel_path) for rel_path in df["path"]]
    labels = df["label"].values

    ds = tf.data.Dataset.from_tensor_slices((image_paths, labels))
    ds = ds.map(process, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return ds

In [12]:
import pandas as pd

DATASET_PATH = "/kaggle/input/140k-real-and-fake-faces"
BASE_IMAGE_PATH = os.path.join(DATASET_PATH, "real_vs_fake/real-vs-fake")

train_df = pd.read_csv(os.path.join(DATASET_PATH, "train.csv"))
val_df = pd.read_csv(os.path.join(DATASET_PATH, "valid.csv"))

train_ds = load_dataset_from_df(train_df, BASE_IMAGE_PATH)
val_ds = load_dataset_from_df(val_df, BASE_IMAGE_PATH)

In [13]:
class GatingFunction(nn.Module):
    num_experts: int
    threshold: float = 0.0

    @nn.compact
    def __call__(self, x):
        logits = nn.Dense(self.num_experts)(x)
        weights = nn.softmax(logits)
        weights = jnp.where(weights > self.threshold, weights, 0.0)
        weights = weights / (jnp.sum(weights, axis=1, keepdims=True) + 1e-8)
        self.sow('intermediates', 'gating_weights', weights)
        return weights

class Expert(nn.Module):
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)
        return x

class ForensicMoEBlock(nn.Module):
    num_experts: int
    expert_hidden_dim: int
    threshold: float = 0.1

    @nn.compact
    def __call__(self, x):
        gating = GatingFunction(self.num_experts, self.threshold)(x)
        expert_outputs = [Expert(self.expert_hidden_dim)(x) for _ in range(self.num_experts)]
        expert_stack = jnp.stack(expert_outputs, axis=0)  # (experts, batch, dim)
        gating = gating.T[:, :, None]  # (experts, batch, 1)
        gated_output = jnp.sum(gating * expert_stack, axis=0)
        return gated_output

class VMoEBinaryClassifier(nn.Module):
    num_classes: int = 2

    @nn.compact
    def __call__(self, x, train=True):
        x = x.reshape((x.shape[0], -1))
        for _ in range(2):
            x = ForensicMoEBlock(num_experts=3, expert_hidden_dim=256)(x)
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        x = nn.Dense(self.num_classes)(x)
        return x

In [14]:
def create_train_state(rng, model, learning_rate, input_shape):
    variables = model.init(rng, jnp.ones(input_shape))
    tx = optax.adam(learning_rate)
    state = train_state.TrainState.create(apply_fn=model.apply, params=variables['params'], tx=tx)
    return state, variables

@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        x, y = batch
        logits = state.apply_fn({'params': params}, x)
        one_hot = jax.nn.one_hot(y, 2)
        loss = optax.softmax_cross_entropy(logits, one_hot).mean()
        return loss
    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)

def evaluate(model, state, val_ds):
    y_true, y_pred, y_prob, expert_usage = [], [], [], []

    for batch in val_ds:
        x_batch, y_batch = batch
        x_batch = jnp.array(x_batch).reshape((x_batch.shape[0], -1))
        vars_with_gating = {'params': state.params}
        logits, intermediates = model.apply(vars_with_gating, x_batch, mutable=['intermediates'], train=False)
        gating_weights = intermediates['intermediates']['gating_weights']
        probs = jax.nn.softmax(logits, axis=-1)
        preds = jnp.argmax(probs, axis=-1)

        y_true.extend(np.array(y_batch))
        y_pred.extend(np.array(preds))
        y_prob.extend(np.array(probs[:, 1]))
        expert_usage.extend(np.array(gating_weights))

    print(classification_report(y_true, y_pred))
    print("AUC Score:", roc_auc_score(y_true, y_prob))

    plt.figure(figsize=(8, 4))
    plt.title("Average Gating Weights Across Validation Set")
    plt.bar(range(len(expert_usage[0])), np.mean(expert_usage, axis=0))
    plt.xlabel("Expert Index")
    plt.ylabel("Average Weight")
    plt.show()

In [16]:
rng = jax.random.PRNGKey(0)
input_shape = (1, 224 * 224 * 3)
model = VMoEBinaryClassifier()
state, variables = create_train_state(rng, model, 1e-4, input_shape)

EPOCHS = 5
for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch + 1}/{EPOCHS}")
    for batch in tqdm(train_ds):
        x_batch, y_batch = batch
        x_batch = jnp.array(x_batch).reshape((x_batch.shape[0], -1))
        y_batch = jnp.array(y_batch)
        state = train_step(state, (x_batch, y_batch))
    evaluate(model, state, val_ds)


Epoch 1/5


100%|██████████| 3125/3125 [59:42<00:00,  1.15s/it]  


KeyError: 'gating_weights'