<a href="https://colab.research.google.com/github/sanikasanikachaudhari071/BE_project/blob/main/implementaion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install tensorflow keras opencv-python mtcnn

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Dense, Conv2D, MaxPool2D, GlobalAveragePooling2D,
    Concatenate, Lambda, LayerNormalization, MultiHeadAttention, Dropout
)
from tensorflow.keras.models import Model
from tensorflow.keras.applications import DenseNet121

# Each sample is a square RGB image laid out as (height, width, channels).
IMAGE_SIZE = 224
input_shape = (IMAGE_SIZE, IMAGE_SIZE, 3)
img_input = Input(shape=input_shape)

In [None]:
# 1. Spatial stream: DenseNet121 pre-trained on ImageNet, used as a backbone.
# include_top=False drops the ImageNet classifier; pooling='avg' appends a
# GlobalAveragePooling2D layer, so the output is one flat vector per image.
# NOTE(review): Keras' ImageNet DenseNet weights expect inputs processed by
# tf.keras.applications.densenet.preprocess_input, but img_input is fed in
# unchanged here. Confirm the data pipeline applies that preprocessing.
base_model = DenseNet121(
    weights='imagenet',
    include_top=False,
    input_tensor=img_input,
    pooling='avg',
)

# 2. Optional backbone freeze for an initial head-only training phase.
# The default (False) keeps every backbone layer trainable, i.e. the network
# is fine-tuned end to end. Set to True to train only the new layers first.
FREEZE_BACKBONE = False
if FREEZE_BACKBONE:
    base_model.trainable = False

# 3. The pooled backbone output is the "spatial vector".
spatial_vector = base_model.output

In [None]:
# 1. DCT Layer
# Computes a 2D DCT of every colour channel over the spatial (H, W) axes.
def dct_layer(x):
    """Return the orthonormal 2D DCT-II of each channel of an NHWC batch.

    ``tf.signal.dct`` only transforms the innermost axis. Calling it directly
    on an NHWC tensor would therefore transform across the colour channels
    rather than over the pixel grid. The separable 2D DCT is instead computed
    by moving H and W to the end and applying the 1D transform along W, then
    along H.

    Args:
        x: Rank-4 tensor laid out (batch, height, width, channels). It is cast
            to float32 because ``tf.signal.dct`` accepts only float32/float64.

    Returns:
        A float32 tensor with the same shape as ``x`` holding the DCT
        coefficients, still in NHWC layout.
    """
    x = tf.cast(x, tf.float32)
    # (B, H, W, C) -> (B, C, H, W): put the spatial axes innermost.
    x = tf.transpose(x, perm=[0, 3, 1, 2])
    x = tf.signal.dct(x, type=2, norm='ortho')  # 1D DCT along W
    x = tf.linalg.matrix_transpose(x)           # (B, C, W, H)
    x = tf.signal.dct(x, type=2, norm='ortho')  # 1D DCT along H
    x = tf.linalg.matrix_transpose(x)           # back to (B, C, H, W)
    # (B, C, H, W) -> (B, H, W, C) so downstream Conv2D layers see NHWC.
    return tf.transpose(x, perm=[0, 2, 3, 1])

# Frequency stream: transform the raw image with dct_layer (wrapped in a
# Lambda layer so it becomes part of the Keras graph), then extract features
# with a small two-stage CNN.
freq_stream = Lambda(dct_layer)(img_input)

# The DCT keeps the input shape (224x224x3), so the convolutions see an
# image-sized coefficient map. Each stage is conv 3x3 (same padding, ReLU)
# followed by a 2x2 max-pool that halves the spatial resolution.
# NOTE(review): the coefficients are not rescaled before the first conv; if
# pixels arrive in [0, 255] the low-frequency terms are very large. Consider
# log-scaling or normalizing them (confirm the expected input range).
cnn = freq_stream
for n_filters in (32, 64):
    cnn = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(cnn)
    cnn = MaxPool2D((2, 2))(cnn)

# Collapse the spatial grid into a flat 64-dim "freq vector" per image.
freq_vector = GlobalAveragePooling2D()(cnn)

In [None]:
# Debug aid: show both stream outputs' shapes before they are fused.
print(f"Spatial vector shape: {spatial_vector.shape}")
print(f"Frequency vector shape: {freq_vector.shape}")

# Early fusion: join the spatial and frequency features along the last
# (feature) axis into a single flat vector per image.
streams = [spatial_vector, freq_vector]
fused_features = Concatenate()(streams)

In [None]:
# A Transformer Encoder block (pre-LayerNorm variant).
def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0):
    """Apply one pre-LayerNorm Transformer encoder block.

    Each sub-layer computes ``y = x + Dropout(SubLayer(LayerNorm(x)))``. The
    residual path carries the un-normalized input.

    Args:
        inputs: Keras tensor of shape (batch, features) or
            (batch, seq_len, features). A rank-2 input is treated as a
            sequence of length 1, and the result is returned at rank 2.
        head_size: Per-head projection size (``key_dim`` of the attention).
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward sub-layer.
        dropout: Dropout rate used inside attention and after each sub-layer.

    Returns:
        Keras tensor with the same shape as ``inputs``.

    Raises:
        ValueError: If ``inputs`` is not rank 2 or rank 3.
    """
    rank = len(inputs.shape)
    if rank not in (2, 3):
        raise ValueError(
            f"transformer_encoder expects a rank-2 or rank-3 input, got shape {inputs.shape}"
        )
    feature_dim = inputs.shape[-1]
    is_flat = rank == 2

    # MultiHeadAttention needs (batch, seq_len, features). A flat vector is
    # wrapped as a 1-token sequence with a Reshape layer, because raw tf ops
    # such as tf.expand_dims cannot be applied to symbolic Keras tensors
    # under Keras 3.
    seq = tf.keras.layers.Reshape((1, feature_dim))(inputs) if is_flat else inputs

    # Attention sub-layer (self-attention).
    # NOTE(review): with a length-1 sequence the softmax over a single key is
    # always 1, so attention reduces to a learned linear projection. Feeding
    # real tokens (e.g. one per stream/patch) is needed for it to mix anything.
    attn_in = LayerNormalization(epsilon=1e-6)(seq)
    attn_out = MultiHeadAttention(
        key_dim=head_size, num_heads=num_heads, dropout=dropout
    )(attn_in, attn_in)
    attn_out = Dropout(dropout)(attn_out)
    # Skip connection from the un-normalized input, not from LayerNorm output.
    seq = tf.keras.layers.Add()([seq, attn_out])

    # Feed-forward sub-layer, projected back to the input feature width.
    ffn = LayerNormalization(epsilon=1e-6)(seq)
    ffn = Dense(ff_dim, activation="relu")(ffn)
    ffn = Dropout(dropout)(ffn)
    ffn = Dense(feature_dim)(ffn)
    seq = tf.keras.layers.Add()([seq, ffn])

    # Drop the temporary sequence axis again for flat inputs.
    return tf.keras.layers.Reshape((feature_dim,))(seq) if is_flat else seq

# Hyper-parameters for the single Transformer encoder block.
encoder_kwargs = {
    "head_size": 256,
    "num_heads": 4,
    "ff_dim": 512,
    "dropout": 0.1,
}

# Refine the fused spatial + frequency features with the encoder.
transformer_output = transformer_encoder(fused_features, **encoder_kwargs)

In [None]:
# Classification head: a small MLP (128 -> Dropout(0.5) -> 64) over the
# encoder output.
x = transformer_output
for head_layer in (
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(64, activation='relu'),
):
    x = head_layer(x)

# A single sigmoid unit for binary real/fake classification, using the label
# convention 0 = real, 1 = fake. The output is the predicted probability of
# the "fake" label.
output = Dense(1, activation='sigmoid')(x)

# Assemble the end-to-end model: image in, fake-probability out.
model = Model(inputs=img_input, outputs=output)

# Train with Adam at a small learning rate and binary cross-entropy loss.
adam = tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(optimizer=adam, loss='binary_crossentropy', metrics=['accuracy'])

# Print the layer-by-layer architecture and parameter counts.
model.summary()