<a href="https://colab.research.google.com/github/rghvsrdhtra/Multimodal_Training_Colab/blob/main/Multimodal_Training_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# This notebook was designed as a flexible environment for exploring and comparing various model architectures.


# --- Import Libraries ---
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import os
import pandas as pd
import numpy as np
from google.colab import drive
import pickle
import shutil
from tqdm import tqdm
from PIL import UnidentifiedImageError

# ##############################################################################
# ## ⚙️ 1. Configuration
# ##############################################################################
# Every user-adjustable setting lives in this section; the rest of the
# notebook only reads these names.

# === 1.1. Google Drive Paths ===
# Project folder under '/content/drive/My Drive/' that holds the dataset and
# receives every saved output.
DRIVE_PROJECT_DIR_NAME = 'UPMC-FOOD101'

# === 1.2. Local Colab Environment ===
# Scratch location on the Colab runtime that the dataset is copied into.
LOCAL_DATA_PATH = '/content/data'

# === 1.3. Data Structure ===
# Sub-folders of the project folder (images, and the CSVs describing them).
IMAGE_DIR_NAME, TEXT_DIR_NAME = 'images', 'texts'

# Per-split CSV files, located inside TEXT_DIR_NAME.
TRAIN_CSV_NAME, TEST_CSV_NAME = 'train_titles.csv', 'test_titles.csv'

# === 1.4. CSV/DataFrame Structure ===
# Column names used for the CSVs; without a header row they are assigned in
# exactly this order (filename, title, label).
FILENAME_COL, TEXT_COL, LABEL_COL = 'filename', 'title', 'label'

# True when the CSVs start with a header row; False to apply the names above.
CSV_HAS_HEADER = False

# === 1.5. Model Hyperparameters ===
IMAGE_SIZE = (128, 128)   # (height, width) every image is resized to
BATCH_SIZE = 16
EPOCHS = 50               # upper bound; early stopping may end training sooner
LEARNING_RATE = 1e-3      # Adam learning rate
DROPOUT_RATE = 0.5        # dropout applied to the fused image+text features

# === 1.6. Text Model Parameters ===
MAX_TOKENS = 10_000       # vocabulary size cap for the text vectorizer
SEQUENCE_LENGTH = 20      # titles are padded/truncated to this many tokens
EMBEDDING_DIM = 128       # width of each learned word embedding

# === 1.7. Output Asset Names ===
# File names written back into the Drive project folder.
MODEL_SAVE_NAME = 'best_multimodal_model.h5'
VOCAB_SAVE_NAME = 'text_vectorizer_vocab.pkl'
ACCURACY_PLOT_NAME = 'training_accuracy_plot.png'
LOSS_PLOT_NAME = 'training_loss_plot.png'

# ##############################################################################
# ## 🚀 2. Setup: Mount Drive & Define Paths
# ##############################################################################

print("Connecting to Google Drive...")
drive.mount('/content/drive')

# Every path below is derived from the configuration in section 1.
DRIVE_PROJECT_PATH = os.path.join('/content/drive/My Drive', DRIVE_PROJECT_DIR_NAME)

# Destination folders for the local copy on the Colab runtime.
LOCAL_IMAGES_PATH = os.path.join(LOCAL_DATA_PATH, IMAGE_DIR_NAME)
LOCAL_TEXTS_PATH = os.path.join(LOCAL_DATA_PATH, TEXT_DIR_NAME)

# Source folders inside the Drive project directory.
drive_images_path, drive_texts_path = (
    os.path.join(DRIVE_PROJECT_PATH, sub) for sub in (IMAGE_DIR_NAME, TEXT_DIR_NAME)
)

# Split-specific image roots and CSV files, all read from the local copy.
train_img_dir, test_img_dir = (
    os.path.join(LOCAL_IMAGES_PATH, split) for split in ('train', 'test')
)
train_csv_path, test_csv_path = (
    os.path.join(LOCAL_TEXTS_PATH, csv_name) for csv_name in (TRAIN_CSV_NAME, TEST_CSV_NAME)
)

print(f"Project path set to: {DRIVE_PROJECT_PATH}")

# ##############################################################################
# ## 📥 3. Copy Data from Drive to Local
# ##############################################################################
print("\n--- Copying dataset from Google Drive to local Colab environment ---")

# Each (Drive source, local destination) pair is checked on its own. Checking
# only LOCAL_DATA_PATH would silently skip a copy that failed half-way
# (e.g. images copied but texts missing) on every later re-run.
_copy_jobs = [
    (drive_images_path, LOCAL_IMAGES_PATH),
    (drive_texts_path, LOCAL_TEXTS_PATH),
]

if all(os.path.isdir(dst) for _, dst in _copy_jobs):
    print("Local data folder already exists. Skipping copy.")
else:
    for src, dst in _copy_jobs:
        if os.path.isdir(dst):
            continue  # Copied by an earlier run; only fetch what is missing.
        try:
            shutil.copytree(src, dst)
        except Exception as e:
            # Drop any partial copy so the next run retries instead of skipping it.
            shutil.rmtree(dst, ignore_errors=True)
            print("ERROR: Could not copy data. Ensure Drive paths are correct.")
            print(f"Details: {e}")
            # Stop execution if data copy fails
            raise
    print("Dataset copied successfully.")

# ##############################################################################
# ## 🧹 4. Load and Preprocess DataFrames
# ##############################################################################
print("\n--- Loading and Preparing Data from Local Colab Storage ---")

def load_dataframe(csv_path):
    """Load one split's CSV file using the configured column layout.

    When ``CSV_HAS_HEADER`` is False the columns are named
    ``[FILENAME_COL, TEXT_COL, LABEL_COL]`` in that order. Otherwise the file's
    own header row is used, and it must contain those three names.

    Args:
        csv_path: Path to the CSV file.

    Returns:
        The loaded DataFrame, or ``None`` if the file could not be read or is
        missing a required column. An error is printed in both cases.
    """
    required_cols = [FILENAME_COL, TEXT_COL, LABEL_COL]
    try:
        if CSV_HAS_HEADER:
            df = pd.read_csv(csv_path, header=0)
        else:
            df = pd.read_csv(csv_path, names=required_cols, header=None)
    except Exception as e:
        print(f"ERROR: Could not load local CSV file: {csv_path}.")
        print(f"Details: {e}")
        return None

    # A header that does not match the configuration would otherwise only
    # surface later, as an opaque KeyError while building image paths.
    missing = [col for col in required_cols if col not in df.columns]
    if missing:
        print(f"ERROR: CSV file {csv_path} is missing expected column(s): {missing}.")
        print(f"Headers: {df.columns.tolist()}")
        return None

    print(f"Loaded {csv_path} with {len(df)} rows.")
    print(f"Headers: {df.columns.tolist()}")
    return df

train_df = load_dataframe(train_csv_path)
test_df = load_dataframe(test_csv_path)

if train_df is None:
    raise ValueError("Failed to load training data. Stopping execution.")

# --- 4.1. Clean DataFrame ---
def clean_df(df, img_dir, label_col=None, filename_col=None):
    """Attach image paths to *df* and keep only the rows whose image exists.

    Each image is expected at ``<img_dir>/<label>/<filename>``.

    Args:
        df: DataFrame of samples, or ``None`` (e.g. when loading failed).
        img_dir: Root image directory of one split.
        label_col: Label column name; defaults to ``LABEL_COL``.
        filename_col: Filename column name; defaults to ``FILENAME_COL``.

    Returns:
        A new DataFrame with an added ``'full_path'`` column. It holds only the
        rows whose file exists and is re-indexed from 0. The input frame is
        not modified. An empty DataFrame is returned when *df* is None.
    """
    if df is None:
        return pd.DataFrame()  # Return empty DF if loading failed

    label_col = LABEL_COL if label_col is None else label_col
    filename_col = FILENAME_COL if filename_col is None else filename_col

    print(f"\nCleaning DataFrame for image directory: {img_dir}")
    # Build the paths with a comprehension rather than DataFrame.apply(axis=1).
    # On an empty frame, apply can return a DataFrame instead of a Series, and
    # that cannot be assigned to a single column.
    full_paths = [
        os.path.join(img_dir, str(label), str(fname))
        for label, fname in zip(df[label_col], df[filename_col])
    ]
    exists = np.array([os.path.exists(path) for path in full_paths], dtype=bool)
    print(f"Found {int(exists.sum())} existing files out of {len(df)} entries.")

    # assign() returns a copy, so the caller's frame is left untouched, and the
    # result is an independent frame rather than a slice of the input.
    cleaned = df.assign(full_path=full_paths)
    return cleaned[exists].reset_index(drop=True)

train_df = clean_df(train_df, train_img_dir)
test_df = clean_df(test_df, test_img_dir)

print(f"Cleaned to {len(train_df)} training samples and {len(test_df)} testing samples.")

# Nothing to train on: fail fast instead of crashing deep inside the pipeline.
if train_df.empty:
    raise ValueError("Training dataframe is empty after cleaning. Check paths and filenames.")

# --- 4.2. Create Class Dictionaries ---
# Class ids are positions in the sorted list of training labels, so the
# mapping is deterministic for a given training set.
class_names = sorted(train_df[LABEL_COL].unique())
NUM_CLASSES = len(class_names)
index_to_class = dict(enumerate(class_names))
class_to_index = {name: idx for idx, name in index_to_class.items()}
print(f"\nFound {NUM_CLASSES} classes.")
print(class_names)

# ##############################################################################
# ## 🔠 5. Text Vectorization
# ##############################################################################
print("\n--- Setting up Text Vectorizer ---")

# Turns raw titles into fixed-length integer sequences. Text is lower-cased
# and stripped of punctuation, the vocabulary is capped at MAX_TOKENS, and
# each sequence is padded/truncated to SEQUENCE_LENGTH.
text_vectorizer = layers.TextVectorization(
    standardize="lower_and_strip_punctuation",
    output_sequence_length=SEQUENCE_LENGTH,
    max_tokens=MAX_TOKENS,
)

# Learn the vocabulary from training titles only. Missing (NaN) titles are
# dropped so they contribute no tokens.
valid_train_titles = train_df[TEXT_COL].dropna()
text_vectorizer.adapt(valid_train_titles)

# --- Save the vocabulary ---
# Best effort: a failed write only prints a warning.
vocab_save_path = os.path.join(DRIVE_PROJECT_PATH, VOCAB_SAVE_NAME)
try:
    with open(vocab_save_path, 'wb') as f:
        pickle.dump(text_vectorizer.get_vocabulary(), f)
except Exception as e:
    print(f"WARNING: Could not save vocabulary. Details: {e}")
else:
    print(f"Vocabulary saved successfully to: {vocab_save_path}")

# ##############################################################################
# ## 💾 6. Pre-load Data into Memory
# ##############################################################################
print("\n--- Pre-loading all images and text into memory ---")

def load_data_into_memory(df):
    """Load images, vectorize titles and encode labels for every row of *df*.

    A row is skipped, with a warning, when its label is not one of the
    training classes in ``class_to_index`` or its image cannot be read. Each
    row contributes to all three outputs or to none, so they stay aligned.

    Args:
        df: DataFrame with a ``'full_path'`` column (as produced by
            ``clean_df``) plus the configured ``TEXT_COL`` and ``LABEL_COL``.

    Returns:
        ``(images, vectorized_titles, labels)``:
        - ``images``: one array per kept row, stacked, resized to ``IMAGE_SIZE``.
        - ``vectorized_titles``: ``text_vectorizer`` output for the titles.
        - ``labels``: integer class indices.
        When no row can be loaded, correctly shaped empty arrays are returned.
    """
    images = []
    titles = []
    labels = []

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Loading data"):
        path = row['full_path']

        # Resolve the label before touching the image. An unseen label used to
        # raise KeyError after the image and title were already appended,
        # leaving the three lists with different lengths.
        label_index = class_to_index.get(row[LABEL_COL])
        if label_index is None:
            print(f"\nWARNING: Skipping {path}: label {row[LABEL_COL]!r} is not a training class.")
            continue

        try:
            img = keras.preprocessing.image.load_img(path, target_size=IMAGE_SIZE)
            img_array = keras.preprocessing.image.img_to_array(img)
        except (UnidentifiedImageError, FileNotFoundError, IsADirectoryError):
            print(f"\nWARNING: Skipping corrupted or missing file: {path}")
            continue
        except Exception as e:
            print(f"\nWARNING: An unexpected error occurred for {path}. Error: {e}")
            continue

        # Every step succeeded: record the sample in all three lists together.
        title = row[TEXT_COL]
        # A missing title becomes "" rather than the literal word "nan".
        titles.append("" if pd.isna(title) else str(title))
        images.append(img_array)
        labels.append(label_index)

    if not labels:
        # np.array([]) is a float64 array, not strings, so it cannot be passed
        # to the text vectorizer. Return shaped empties instead.
        return (
            np.zeros((0, IMAGE_SIZE[0], IMAGE_SIZE[1], 3), dtype=np.float32),
            np.zeros((0, SEQUENCE_LENGTH), dtype=np.int64),
            np.zeros((0,), dtype=np.int64),
        )

    # Vectorize all titles in one call.
    vectorized_titles = text_vectorizer(np.array(titles))

    return np.array(images), vectorized_titles, np.array(labels)

X_train_images, X_train_text, y_train = load_data_into_memory(train_df)
X_test_images, X_test_text, y_test = load_data_into_memory(test_df)

print(f"\nSuccessfully loaded {len(X_train_images)} training samples and {len(X_test_images)} testing samples.")

# ##############################################################################
# ## 📦 7. Create tf.data.Dataset Pipelines
# ##############################################################################
print("\n--- Creating tf.data.Dataset pipelines ---")

def create_dataset(images, texts, labels, shuffle=False, shuffle_buffer_size=None):
    """Build a batched, prefetched tf.data.Dataset from in-memory arrays.

    Elements are ``({"image_input": image, "text_input": text}, label)``. The
    dict keys are the names given to the model's ``keras.Input`` layers.

    Args:
        images, texts, labels: Per-sample arrays/tensors that share the same
            first dimension.
        shuffle: If True, reshuffle the samples every epoch (use for training).
        shuffle_buffer_size: Shuffle buffer size. Defaults to the number of
            samples, which gives a full, uniform shuffle. tf.data keeps up to
            this many elements buffered, so pass a smaller value if memory is
            tight.

    Returns:
        The dataset, batched by ``BATCH_SIZE``.
    """
    dataset = tf.data.Dataset.from_tensor_slices(
        ({"image_input": images, "text_input": texts}, labels)
    )
    if shuffle:
        buffer_size = shuffle_buffer_size or max(1, len(labels))
        dataset = dataset.shuffle(buffer_size, reshuffle_each_iteration=True)
    return dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# Shuffle the training split. Rows arrive in CSV order, which may be grouped
# by class, and batches drawn from class-sorted data train poorly.
train_dataset = create_dataset(X_train_images, X_train_text, y_train, shuffle=True)
print("Training dataset pipeline created.")

test_dataset = None
if len(X_test_images) > 0:
    test_dataset = create_dataset(X_test_images, X_test_text, y_test)
    print("Test (validation) dataset pipeline created.")
else:
    print("\nWARNING: No valid test data found. Validation will be skipped.")

# ##############################################################################
# ## 🤖 8. Build the Multimodal Model
# ##############################################################################
# Two branches (image CNN, title LSTM) are each projected to 128 features,
# concatenated, and classified into NUM_CLASSES with a softmax head.
# Layers are created in a fixed order. Keras auto-names unnamed layers by
# creation order, so reordering them changes layer names.
print("\n--- Building the Multimodal Model ---")

# --- 8.1. Image Branch (CNN) ---
# Data augmentation: random horizontal flips and small random rotations.
data_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1)
], name="data_augmentation")

# The name must match the "image_input" key produced by create_dataset.
image_input = keras.Input(
    shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 3),
    name="image_input"
)
augmented_image = data_augmentation(image_input)

# Pre-trained base model
# You can swap this with any keras.applications model (e.g., VGG16, ResNet50)
# NOTE(review): images reach the backbone unscaled (raw img_to_array output).
# That suits EfficientNet, but VGG16/ResNet50 expect their own
# preprocess_input. Confirm the preprocessing if you swap the backbone.
base_model = keras.applications.EfficientNetB0(
    include_top=False,
    weights="imagenet",
    pooling="avg"
)
base_model.trainable = False # Freeze pre-trained weights
# training=False keeps the frozen backbone (including its normalization
# layers) in inference mode even while the rest of the model trains.
image_features = base_model(augmented_image, training=False)
image_features = layers.Dense(128, activation="relu")(image_features)

# --- 8.2. Text Branch (LSTM) ---
# Receives the integer token ids from text_vectorizer (SEQUENCE_LENGTH per
# title). The name must match the "text_input" key produced by create_dataset.
text_input = keras.Input(
    shape=(SEQUENCE_LENGTH,),
    dtype="int32",
    name="text_input"
)
# input_dim matches the vectorizer's max_tokens, so every token id fits.
text_features = layers.Embedding(
    input_dim=MAX_TOKENS,
    output_dim=EMBEDDING_DIM
)(text_input)
# Only the final LSTM state is kept (return_sequences is left at its default).
text_features = layers.LSTM(64)(text_features)
text_features = layers.Dense(128, activation="relu")(text_features)

# --- 8.3. Fusion Branch ---
# Late fusion: concatenate the two 128-d feature vectors, regularize with
# dropout, then classify.
combined_features = layers.concatenate([image_features, text_features])
combined_features = layers.Dropout(DROPOUT_RATE)(combined_features)
combined_features = layers.Dense(64, activation="relu")(combined_features)
output = layers.Dense(NUM_CLASSES, activation="softmax")(combined_features)

# --- 8.4. Create and Compile Model ---
model = keras.Model(
    inputs=[image_input, text_input],
    outputs=output
)

# Labels are integer class indices (see class_to_index), hence the sparse
# variant of categorical cross-entropy.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

model.summary()

# ##############################################################################
# ## 🚂 9. Train the Model
# ##############################################################################

# --- 9.1. Define Callbacks ---
# The best checkpoint is written directly into the Drive project folder.
best_model_save_path = os.path.join(DRIVE_PROJECT_PATH, MODEL_SAVE_NAME)

# Track validation accuracy when a validation split exists. Otherwise fall
# back to training accuracy, so both callbacks still have a metric to watch.
if test_dataset is not None:
    monitor_metric = 'val_accuracy'
else:
    monitor_metric = 'accuracy'

# Overwrite the checkpoint only when the monitored accuracy improves.
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=best_model_save_path,
    monitor=monitor_metric,
    mode='max',
    save_best_only=True,
    verbose=1,
)

# Stop after 5 epochs without improvement, keeping the best weights seen.
early_stopping_callback = keras.callbacks.EarlyStopping(
    monitor=monitor_metric,
    mode='max',
    patience=5,
    restore_best_weights=True,
)

print(f"\n--- Starting Model Training (monitoring '{monitor_metric}') ---")

history = model.fit(
    train_dataset,
    epochs=EPOCHS,
    validation_data=test_dataset,  # None means no validation pass
    callbacks=[checkpoint_callback, early_stopping_callback],
)

print("--- Model Training Finished ---")

# ##############################################################################
# ## 📊 10. Save Results and Evaluate
# ##############################################################################
print("\n--- Saving All Outputs Permanently to Google Drive ---")

# --- 10.1. Save Plots ---
history_df = pd.DataFrame(history.history)

def save_plot(history_df, metric_name, plot_save_path):
    """Plot one training metric, plus its validation twin if present, and save it.

    Args:
        history_df: DataFrame built from the training history, with one column
            per logged metric and one row per epoch.
        metric_name: Column to plot, e.g. ``'accuracy'`` or ``'loss'``. The
            matching ``'val_<metric_name>'`` column is added when it exists.
        plot_save_path: Destination path for the image.

    Problems are reported as warnings rather than raised, so a missing metric
    or an unwritable path does not abort the rest of the notebook. The figure
    is always closed, even on error, so no matplotlib figures leak.
    """
    if metric_name not in history_df.columns:
        print(f"WARNING: Metric '{metric_name}' not found in training history; skipping plot.")
        return

    pretty_name = metric_name.capitalize()
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        # Plot training metric
        ax.plot(history_df[metric_name], label=f'Training {pretty_name}')

        # Plot validation metric if it exists
        val_metric_name = f'val_{metric_name}'
        if val_metric_name in history_df.columns:
            ax.plot(history_df[val_metric_name], label=f'Validation {pretty_name}')

        ax.set_title(f'Model {pretty_name} vs. Epochs')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(pretty_name)
        ax.legend()
        ax.grid(True)

        try:
            fig.savefig(plot_save_path)
            print(f"Successfully saved plot to {plot_save_path}")
        except Exception as e:
            print(f"WARNING: Could not save plot. Details: {e}")
    finally:
        plt.close(fig)

# Accuracy and loss curves are saved next to the model in the Drive project folder.
accuracy_plot_path = os.path.join(DRIVE_PROJECT_PATH, ACCURACY_PLOT_NAME)
loss_plot_path = os.path.join(DRIVE_PROJECT_PATH, LOSS_PLOT_NAME)

save_plot(history_df, 'accuracy', accuracy_plot_path)
save_plot(history_df, 'loss', loss_plot_path)

# --- 10.2. Evaluate Final Model ---
# NOTE(review): EarlyStopping(restore_best_weights=True) restores the best
# weights when it stops training early. If all EPOCHS run, whether the
# in-memory weights are the best ones depends on the Keras version. The best
# checkpoint is always available at best_model_save_path.
if test_dataset is None:
    print("\nSkipping final evaluation as no valid test data was found.")
else:
    print("\n--- Evaluating Best Model on Test Data ---")
    loss, accuracy = model.evaluate(test_dataset)
    print(f"Final Test Accuracy: {accuracy*100:.2f}%")
    print(f"Final Test Loss: {loss:.4f}")

print(f"\n--- All outputs saved to your Google Drive in '{DRIVE_PROJECT_DIR_NAME}'! ---")