**Cell 1: Installs**

In [1]:
# CELL 1: Installs (Revised to handle hash issues)

# 1. Upgrade pip itself
print("Upgrading pip...")
!pip install --upgrade pip

# 2. Clear pip's cache to remove potentially corrupted downloads
print("\nClearing pip cache...")
!pip cache purge

# 3. Install necessary libraries without using the cache directory
#    (--no-cache-dir forces fresh downloads and helps avoid hash mismatches)
print("\nInstalling required packages...")
!pip install -q --no-cache-dir \
    transformers==4.38.2 \
    sentence-transformers==2.7.0 \
    faiss-cpu==1.8.0 \
    torch==2.1.2 \
    accelerate==0.28.0 \
    scikit-learn==1.3.2 \
    pandas==2.1.4 \
    matplotlib==3.8.2 \
    Pillow==10.2.0 \
    tqdm==4.66.2

print("\nPackage installation attempt finished.")

# Optional: Verify faiss installation by trying to import it here
try:
    import faiss
    print("Successfully imported faiss after installation.")
    # You might see the CUDA registration warnings here again, which is fine.
except ImportError:
    print("ERROR: Failed to import faiss even after installation attempt. Check install logs above.")
except Exception as e:
    print(f"An unexpected error occurred during faiss import check: {e}")

Upgrading pip...
Collecting pip
  Downloading pip-25.0.1-py3-none-any.whl.metadata (3.7 kB)
Downloading pip-25.0.1-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m22.5 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.0.1

Clearing pip cache...
Files removed: 6 (1.9 MB)

Installing required packages...
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.5/8.5 MB[0m [31m150.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m27.0/27.0 MB[0m [31m339.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m670.2/670.2 MB[0m [31m271.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

**Cell 2: Imports**

In [2]:
import json
import os
import random
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.nn import MultiheadAttention 
from torch.utils.data import Dataset, DataLoader
from transformers import (
    BartForSequenceClassification,
    BartTokenizer,
    BartModel, # Base BART model
    CLIPProcessor, # For vision features
    CLIPVisionModel, # For vision features
    get_linear_schedule_with_warmup,
    CLIPVisionConfig # To get visual feature dimension
)
from torch.optim import AdamW
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, classification_report, hamming_loss, multilabel_confusion_matrix
from sklearn.preprocessing import LabelEncoder
from sentence_transformers import SentenceTransformer
import faiss
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm # Use notebook version of tqdm
import logging
import re
from typing import List, Dict, Tuple, Optional
import pickle
import gc # Garbage collector
import torch.nn.functional as F # For sigmoid and softmax
from PIL import Image
import warnings
from sklearn.exceptions import UndefinedMetricWarning
import traceback
from collections import defaultdict

# Silence the warning categories that would otherwise flood the cell output
for _ignored_category in (UserWarning, FutureWarning, UndefinedMetricWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)

# Notebook-wide logging: INFO level, timestamped records, one shared logger
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

**Cell 3: Configuration and Seed**

In [None]:
# CELL 3: Configuration and Seed (Updated)

# --- Basic Hyperparameters ---
MAX_LEN = 512          # Max length for BART tokenizer
BATCH_SIZE = 8         # Adjust based on GPU memory (T4 likely needs 4 or 8)
NUM_EPOCHS = 10        # Number of training epochs PER ensemble member
LEARNING_RATE = 3e-5
ADAM_BETA1 = 0.9
ADAM_BETA2 = 0.999
ADAM_EPSILON = 1e-8
WEIGHT_DECAY = 0.05
DROPOUT = 0.3

# --- RAG Hyperparameters ---
TEXT_EMBEDDING_MODEL = "BAAI/bge-m3"
RETRIEVAL_K = 3          # Number of examples to retrieve for prompts

# --- Vision Hyperparameters ---
VISION_MODEL_NAME = "openai/clip-vit-base-patch32"
try:
    vision_config = CLIPVisionConfig.from_pretrained(VISION_MODEL_NAME)
    # The feature-extraction cell stores `CLIPVisionModel(...).pooler_output`.
    # CLIPVisionModel has no projection head, so that vector is `hidden_size`
    # wide. `projection_dim` only describes the output of the projection layer
    # in CLIPModel / CLIPVisionModelWithProjection and would not match the
    # saved features.
    VISUAL_FEATURE_DIM = vision_config.hidden_size
except Exception as e:
    logger.warning(f"Could not load vision config for {VISION_MODEL_NAME}: {e}. Defaulting VISUAL_FEATURE_DIM to 768.")
    VISUAL_FEATURE_DIM = 768

# --- Ensemble Hyperparameters ---
NUM_ENSEMBLE_MODELS = 3 # Number of models to train in the ensemble
BASE_SEED = 42           # Base seed for reproducibility

# --- Classification Model ---
# Hugging Face identifier of the BART checkpoint used as the text backbone.
# NOTE: this is the stock `facebook/bart-base` checkpoint; put a MentalBART
# identifier here instead if that model should be used.
BASE_TEXT_MODEL_NAME = "facebook/bart-base"


# --- Kaggle File Paths ---
KAGGLE_INPUT_DIR = "/kaggle/input/axiom-dataset"
KAGGLE_WORKING_DIR = "/kaggle/working"

# --- Seed Function ---
def set_seed(seed):
    """Seed every random number generator the notebook relies on.

    Seeds Python's `random`, NumPy and PyTorch (CPU, plus all CUDA devices
    when CUDA is available), exports PYTHONHASHSEED, and logs the value.

    Args:
        seed: Integer seed applied to all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Optional, stricter reproducibility (left disabled):
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    logger.info(f"Seed set to {seed}")

# --- Device ---
# Prefer the GPU when one is visible to PyTorch, otherwise run on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Record the effective run configuration in the log.
for _startup_message in (
    f"Using device: {device}",
    f"Visual Feature Dimension set to: {VISUAL_FEATURE_DIM}",
    f"Base Text Model set to: {BASE_TEXT_MODEL_NAME}",
    f"Text Embedding Model: {TEXT_EMBEDDING_MODEL}",
    f"Vision Model: {VISION_MODEL_NAME}",
):
    logger.info(_startup_message)


# Set initial seed
set_seed(BASE_SEED)

config.json:   0%|          | 0.00/4.19k [00:00<?, ?B/s]

**Cell 4: Visual Feature Extraction (Run Once)**

In [4]:
# === VISUAL FEATURE EXTRACTION ===
# NOTE: This cell only needs to be run once if the features are saved.
# If features (.pt files) already exist in /kaggle/working/visual_features,
# you can skip running this cell on subsequent runs.

def get_image_path(sample_id, dataset_base_path, json_path):
    """Locate the image file that belongs to a sample.

    The split (train/test/val) is taken from the sample-ID prefix
    ('TR-', 'TE-', 'VL-') when the ID is a string carrying one; otherwise it is
    inferred from the JSON file name, defaulting to 'train' with a warning.
    The image folder is then resolved next to the dataset directory:

    * Anxiety_Data    -> <parent>/anxiety_<split>_image
    * Depressive_Data -> <parent>/Images/depressive_image/<split>, falling back
                         to <parent>/depressive_image/<split>

    Args:
        sample_id: Sample identifier; also the image file's base name.
        dataset_base_path: Path of the dataset directory (must contain
            "Anxiety_Data" or "Depressive_Data"). A trailing path separator
            is tolerated.
        json_path: Path of the JSON file the sample came from (may be None).

    Returns:
        The path of the first existing file among <id>.jpeg, <id>.jpg and
        <id>.png in the resolved folder, or None when the folder or the file
        cannot be found.
    """
    split = None
    # Try matching known prefixes first
    if isinstance(sample_id, str):
        for prefix, prefix_split in (('TR-', 'train'), ('TE-', 'test'), ('VL-', 'val')):
            if sample_id.startswith(prefix):
                split = prefix_split
                break

    # If split not determined by prefix, infer from the JSON file path
    if not split:
        json_hint = str(json_path) if json_path else ""  # tolerate a missing json_path
        # 'train.json' / 'test.json' also match 'anxiety_train.json' / 'anxiety_test.json'
        if 'train.json' in json_hint:
            split = 'train'
        elif 'test.json' in json_hint:
            split = 'test'
        elif 'val.json' in json_hint:
            split = 'val'
        else:
            # Last resort - needs better logic if IDs aren't prefixed and file names don't help
            logger.warning(f"Cannot determine split (train/test/val) for sample_id: {sample_id} from path {json_path}. Assuming 'train'. Check logic.")
            split = 'train'

    # os.path.dirname("/x/Anxiety_Data/") is "/x/Anxiety_Data", not "/x", so
    # strip trailing separators before stepping one level up.
    trimmed_base = dataset_base_path.rstrip(os.sep + (os.altsep or ''))
    parent_dir = os.path.dirname(trimmed_base or dataset_base_path)

    # Construct the image folder path based on dataset type and split
    img_folder = None
    if "Anxiety_Data" in dataset_base_path:
        img_folder = os.path.join(parent_dir, f"anxiety_{split}_image")
    elif "Depressive_Data" in dataset_base_path:
        # Two layouts have been observed for the depression images; take the
        # first candidate that exists.
        candidates = (
            os.path.join(parent_dir, "Images", "depressive_image", split),
            os.path.join(parent_dir, "depressive_image", split),
        )
        img_folder = next((c for c in candidates if os.path.isdir(c)), None)
        if img_folder is None:
            logger.error(f"Cannot find depression image folder for split '{split}' near {parent_dir}")
            return None
    else:
        logger.error(f"Unknown dataset base path structure: {dataset_base_path}")
        return None

    if not os.path.isdir(img_folder):
        logger.error(f"Determined image folder does not exist: {img_folder}")
        return None

    # Try common extensions
    for ext in ['.jpeg', '.jpg', '.png']:
        img_path = os.path.join(img_folder, f"{sample_id}{ext}")
        if os.path.exists(img_path):
            return img_path

    logger.debug(f"Image file not found for {sample_id} in {img_folder} with common extensions (.jpeg, .jpg, .png).")
    return None


def extract_and_save_features(dataset_name, split_name, json_path, dataset_base_path, processor, vision_model, device, output_file):
    """Extract one CLIP feature vector per sample image and save them to disk.

    Loads the samples listed in `json_path`, resolves each sample's image with
    `get_image_path`, runs it through `vision_model`, and stores the pooled
    output (moved to CPU) in a dict keyed by sample ID. The dict is written to
    `output_file` with `torch.save`; nothing is written when no feature could
    be extracted. All failures are logged rather than raised.

    Args:
        dataset_name: Dataset label used in log messages and the progress bar.
        split_name: Split label used in log messages and the progress bar.
        json_path: JSON file holding the list of samples.
        dataset_base_path: Dataset directory, forwarded to `get_image_path`.
        processor: Image processor called as `processor(images=..., return_tensors="pt")`.
        vision_model: Model whose output exposes `pooler_output`.
        device: Device the processed inputs are moved to.
        output_file: Destination path for the saved feature dict.

    Returns:
        None.
    """
    logger.info(f"Extracting features for {dataset_name} - {split_name} from {json_path}...")
    features_map = {}
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        logger.error(f"JSON file not found: {json_path}. Cannot extract features.")
        return
    except Exception as e:
        logger.error(f"Failed to load JSON {json_path}: {e}")
        return

    missing_images = 0
    vision_model.eval() # Set model to evaluation mode
    with torch.no_grad(): # Disable gradient calculations
        for sample in tqdm(data, desc=f"Processing {split_name} images for {dataset_name}"):
            sample_id = sample.get('sample_id', sample.get('id')) # Handle both 'sample_id' and 'id' keys
            if not sample_id:
                logger.warning("Skipping sample with missing ID.")
                continue

            # Pass the json_path to help get_image_path infer split if needed
            image_path = get_image_path(sample_id, dataset_base_path, json_path)

            if image_path and os.path.exists(image_path):
                try:
                    # Image.open is lazy and keeps the file open; the context
                    # manager closes it as soon as the RGB copy exists.
                    with Image.open(image_path) as raw_image:
                        image = raw_image.convert("RGB")
                    inputs = processor(images=image, return_tensors="pt").to(device)
                    outputs = vision_model(**inputs)
                    # For CLIPVisionModel, pooler_output is the layer-normalised CLS
                    # hidden state (no projection head is applied), so its width is
                    # the vision config's hidden_size.
                    # Squeeze to remove batch dimension (as we process one image at a time)
                    features = outputs.pooler_output.squeeze().cpu() # (feature_dim,) -> Move to CPU before storing
                    if features.shape[0] != VISUAL_FEATURE_DIM:
                         logger.warning(f"Extracted feature dim {features.shape[0]} != expected {VISUAL_FEATURE_DIM} for {sample_id}. Check CLIP model/config.")
                         # Handle dimension mismatch if necessary (e.g., skip, pad, project) - here we just warn
                    features_map[sample_id] = features
                except FileNotFoundError:
                    # This case should be less likely now with the initial check
                    logger.warning(f"File not found during processing (should have been checked): {image_path}")
                    missing_images += 1
                except Exception as e:
                    # Best effort: one unreadable image must not abort the whole split
                    logger.warning(f"Failed to process image {image_path}: {e}")
                    missing_images += 1
            else:
                logger.debug(f"Image path not found or invalid for sample_id: {sample_id}. Path sought: {image_path}")
                missing_images += 1

    if missing_images > 0:
        logger.warning(f"Could not find or process {missing_images}/{len(data)} images for {dataset_name} - {split_name}.")

    if features_map:
        logger.info(f"Saving {len(features_map)} extracted features to {output_file}...")
        try:
            torch.save(features_map, output_file)
            logger.info("Features saved successfully.")
        except Exception as e:
            logger.error(f"Failed to save features to {output_file}: {e}")
    else:
        logger.warning(f"No features were extracted for {dataset_name} - {split_name}. Output file not saved.")

# --- Load CLIP Model --- (Do this once)
clip_processor = None
clip_vision_model = None
try:
    logger.info(f"Loading CLIP processor and vision model: {VISION_MODEL_NAME}")
    clip_processor = CLIPProcessor.from_pretrained(VISION_MODEL_NAME)
    clip_vision_model = CLIPVisionModel.from_pretrained(VISION_MODEL_NAME).to(device)
    clip_vision_model.eval() # Ensure model is in eval mode
    logger.info("CLIP models loaded successfully.")
except Exception as e:
    logger.error(f"Failed to load CLIP model or processor: {e}", exc_info=True)
    logger.error("Visual feature extraction cannot proceed.")

# The feature directory is defined and created whether or not CLIP loaded, so
# code that refers to `output_feature_dir` afterwards does not hit a NameError
# (loading the feature files themselves will still fail if none were produced).
output_feature_dir = os.path.join(KAGGLE_WORKING_DIR, "visual_features")
os.makedirs(output_feature_dir, exist_ok=True)

# --- Define Paths and Run Extraction (if models loaded) ---
# Explicit `is not None` checks: truthiness of model/processor objects is not
# a reliable "was it loaded" signal.
if clip_processor is not None and clip_vision_model is not None:
    # Define base directories within the input dataset
    anxiety_base_dir = os.path.join(KAGGLE_INPUT_DIR, "dataset", "Anxiety_Data")
    depression_base_dir = os.path.join(KAGGLE_INPUT_DIR, "dataset", "Depressive_Data")

    logger.info(f"Visual features will be saved to: {output_feature_dir}")

    # (dataset name, split, samples JSON, dataset directory, output feature file)
    datasets_to_process = [
        # Anxiety
        ("anxiety", "train", os.path.join(anxiety_base_dir, "anxiety_train.json"), anxiety_base_dir, os.path.join(output_feature_dir, "anxiety_train_features.pt")),
        ("anxiety", "test", os.path.join(anxiety_base_dir, "anxiety_test.json"), anxiety_base_dir, os.path.join(output_feature_dir, "anxiety_test_features.pt")),
        # Depression
        ("depression", "train", os.path.join(depression_base_dir, "train.json"), depression_base_dir, os.path.join(output_feature_dir, "depression_train_features.pt")),
        ("depression", "test", os.path.join(depression_base_dir, "test.json"), depression_base_dir, os.path.join(output_feature_dir, "depression_test_features.pt")),
        ("depression", "val", os.path.join(depression_base_dir, "val.json"), depression_base_dir, os.path.join(output_feature_dir, "depression_val_features.pt")) # Add validation set
    ]

    for name, split, json_p, data_base_p, out_f in datasets_to_process:
        # Only process if the corresponding JSON file exists
        if not os.path.exists(json_p):
            logger.warning(f"JSON file not found: {json_p}. Skipping feature extraction for {name} - {split}.")
            continue
        # Saved features act as a cache: never recompute an existing file
        if os.path.exists(out_f):
            logger.info(f"Visual features file already exists: {out_f}. Skipping extraction for {name} - {split}.")
            continue

        logger.info(f"Starting feature extraction for: {name} - {split}")
        extract_and_save_features(name, split, json_p, data_base_p, clip_processor, clip_vision_model, device, out_f)
        gc.collect() # Collect garbage after processing each split
        # Compare the device *type*: torch.device('cuda:0') != torch.device('cuda')
        if device.type == 'cuda':
            torch.cuda.empty_cache()

    # Clean up vision model from memory after all extractions are done
    logger.info("Visual feature extraction process finished. Cleaning up CLIP model...")
    del clip_vision_model
    del clip_processor
    gc.collect()
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    logger.info("CLIP model cleanup complete.")
else:
    logger.error("CLIP model failed to load. Cannot extract visual features.")

2025-04-13 05:43:39.892125: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744523020.081546      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744523020.133529      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/862k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.22M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

Processing train images for anxiety:   0%|          | 0/2608 [00:00<?, ?it/s]

Processing test images for anxiety:   0%|          | 0/652 [00:00<?, ?it/s]

Processing train images for depression:   0%|          | 0/8814 [00:00<?, ?it/s]

Processing test images for depression:   0%|          | 0/520 [00:00<?, ?it/s]

Processing val images for depression:   0%|          | 0/361 [00:00<?, ?it/s]

**Cell 5: Data Loading, Cleaning, Splitting Functions**

In [5]:
# --- load_data Function (handles image_id and label variations) ---
def load_data(file_path):
    """Read a dataset file and return its usable samples in normalised form.

    The file is parsed as a single JSON document first; if that fails with a
    decode error it is re-read as JSON Lines, skipping lines that do not parse.
    Entries that are not dicts, have no 'sample_id'/'id', or repeat an already
    seen ID are dropped, as are samples lacking 'ocr_text' or the label field.

    Every returned sample dict is the loaded dict, updated in place with:
        image_id        -- the sample's 'image_id', or its ID when absent/empty
        ocr_text        -- as found in the file
        triples         -- as found in the file ("" when missing)
        original_labels -- anxiety: one label string; depression: list of strings
        stratify_label  -- anxiety: the label; depression: the first label

    Whether a file is an anxiety file is decided by "Anxiety_Data" appearing
    in `file_path`; everything else is treated as depression (multi-label).

    Args:
        file_path: Path of the JSON / JSON Lines file.

    Returns:
        List of processed sample dicts; empty when the file is missing or
        cannot be parsed.
    """
    logger.info(f"Loading data from: {file_path}")
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            data = json.load(handle)
    except FileNotFoundError:
        logger.error(f"File not found: {file_path}")
        return []
    except json.JSONDecodeError as e:
        logger.error(f"Error decoding JSON from {file_path}: {e}")
        # Second chance: the file may be JSON Lines rather than one document
        logger.info(f"Attempting to load {file_path} as JSON Lines (.jsonl)")
        data = []
        try:
            with open(file_path, 'r', encoding='utf-8') as handle:
                for raw_line in handle:
                    try:
                        data.append(json.loads(raw_line))
                    except json.JSONDecodeError:
                        logger.warning(f"Skipping invalid JSON line in {file_path}: {raw_line.strip()}")
            if not data:
                return []  # nothing usable even as JSON Lines
            logger.info(f"Successfully loaded {len(data)} lines as JSON Lines.")
        except Exception as e_jsonl:
            logger.error(f"Failed to load as JSON Lines as well: {e_jsonl}")
            return []

    anxiety_mode = "Anxiety_Data" in file_path
    dataset_kind = 'anxiety' if anxiety_mode else 'depression'
    label_key = 'meme_anxiety_category' if anxiety_mode else 'meme_depressive_categories'

    kept_samples = []
    seen_ids = set()  # IDs already accepted for processing, to drop duplicates

    for position, record in enumerate(data):
        if not isinstance(record, dict):
            logger.warning(f"Skipping non-dictionary item at index {position} in {file_path}")
            continue

        # 'sample_id' is preferred, 'id' is the fallback
        record_id = record.get('sample_id', record.get('id', None))
        if record_id is None:
            logger.warning(f"Skipping sample at index {position} due to missing ID.")
            continue
        if record_id in seen_ids:
            logger.warning(f"Skipping duplicate sample ID: {record_id}")
            continue
        seen_ids.add(record_id)

        text = record.get('ocr_text', None)
        triples_value = record.get('triples', "")  # empty string when missing

        # The image identifier defaults to the sample ID
        picture_id = record.get('image_id', record_id)
        if not picture_id:
            logger.warning(f"Sample {record_id} missing a valid image identifier. Using sample_id '{record_id}'.")
            picture_id = record_id
        record['image_id'] = picture_id

        raw_label = record.get(label_key, None)
        if text is None or raw_label is None:
            ocr_state = 'Present' if text is not None else 'Missing'
            label_state = 'Present' if raw_label is not None else 'Missing'
            logger.warning(f"Skipping {dataset_kind} sample {record_id} due to missing 'ocr_text' or '{label_key}'. OCR: {ocr_state}, Label: {label_state}")
            continue

        if anxiety_mode:
            # Standardise the known label variants (single label per sample)
            if raw_label == 'Irritatbily':
                raw_label = 'Irritability'
            elif raw_label == 'Unknown':
                raw_label = 'Unknown Anxiety'
            final_labels = raw_label
            stratify_on = raw_label
        else:
            # Depression labels may arrive as a list or a comma-separated string
            if isinstance(raw_label, list):
                labels = [str(item).strip() for item in raw_label if str(item).strip()]
            elif isinstance(raw_label, str):
                labels = [piece.strip() for piece in raw_label.split(',') if piece.strip()]
            else:
                logger.warning(f"Unexpected label format for depression sample {record_id}: {type(raw_label)}. Treating as empty.")
                labels = []

            labels = ['Unknown Depression' if item == 'Unknown' else item for item in labels]

            if not labels:
                logger.warning(f"Depression sample {record_id} resulted in empty label list after processing. Assigning 'Unknown Depression'. Original data: {raw_label}")
                labels = ["Unknown Depression"]

            final_labels = labels
            stratify_on = labels[0]  # first label drives stratification

        record['original_labels'] = final_labels
        record['stratify_label'] = stratify_on
        record['triples'] = triples_value
        record['ocr_text'] = text
        kept_samples.append(record)

    logger.info(f"Loaded {len(kept_samples)} samples from {file_path} after filtering and processing.")
    return kept_samples


# --- clean_triples Function ---
def clean_triples(triples_text):
    """Normalise the structured "triples" text, keeping its section headers.

    Looks for the sections "Cause-Effect", "Figurative Understanding" and
    "Mental State" (case-insensitive, each followed by a colon). The content
    of every section found is whitespace-collapsed and emitted as
    "<Section>: <content>", one section per line, always in the canonical
    order above regardless of the order in the input. A section's content
    ends at the header of any other section or at the end of the text.

    Args:
        triples_text: Raw triples text. Anything that is not a non-blank
            string (None, NaN, lists, numbers, "") yields "".

    Returns:
        The cleaned text. When no known section with content is present, the
        input is returned with whitespace runs collapsed to single spaces.
    """
    # Checking the type first also covers None/NaN and avoids calling
    # pd.isna() on list-like input, whose array result cannot be used in a
    # boolean expression.
    if not isinstance(triples_text, str) or not triples_text.strip():
        return ""

    sections = ["Cause-Effect", "Figurative Understanding", "Mental State"]
    header_patterns = {name: rf"{re.escape(name)}\s*:" for name in sections}
    cleaned_parts = []

    for section in sections:
        # Stop at the header of ANY other section (not only the later ones), so
        # sections that appear out of order do not swallow each other.
        other_headers = "|".join(
            pattern for name, pattern in header_patterns.items() if name != section
        )
        # (?s) lets '.' span newlines; the lazy group plus lookahead ends the
        # capture right before the next header or the end of the text.
        pattern = rf"(?s){header_patterns[section]}(.*?)(?=\s*(?:{other_headers}|$))"
        match = re.search(pattern, triples_text, re.IGNORECASE)

        content = re.sub(r'\s+', ' ', match.group(1)).strip() if match else ""
        if content:  # only keep sections that actually carry content
            cleaned_parts.append(f"{section}: {content}")
        else:
            logger.debug(f"Section '{section}' not found or empty in triples: {triples_text[:100]}...")

    # If no standard sections were found, return the original text after basic whitespace cleaning
    if not cleaned_parts:
        logger.debug(f"No standard sections found in triples. Returning cleaned original text: {triples_text[:100]}...")
        return re.sub(r'\s+', ' ', triples_text).strip()

    return "\n".join(cleaned_parts).strip()


# --- split_data Function (Handles stratification carefully) ---
def split_data(data: List[Dict], val_size: float = 0.1, test_size: float = 0.2, random_state: int = 42) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Partition samples into train, validation and test lists.

    The test set is carved out first, then the validation set is taken from
    what remains (its ratio is rescaled so `val_size` stays a fraction of the
    full data). Both steps stratify on each sample's 'stratify_label' and fall
    back to a plain random split when stratification raises ValueError.

    Args:
        data: Sample dicts; each needs 'stratify_label' and 'image_id'.
        val_size: Fraction of all samples for validation, in [0, 1).
        test_size: Fraction of all samples for testing, in [0, 1).
        random_state: Seed passed to both splits.

    Returns:
        (train, validation, test). Empty input or a missing 'stratify_label'
        gives three empty lists; invalid ratios give (data, [], []).
    """
    if not data:
        logger.error("Cannot split empty data list.")
        return [], [], []

    n_samples = len(data)
    logger.info(f"Attempting to split {n_samples} samples. Val ratio: {val_size}, Test ratio: {test_size}")

    # Ensure ratios are valid
    if val_size < 0 or val_size >= 1 or test_size < 0 or test_size >= 1 or (val_size + test_size >= 1):
        logger.error(f"Invalid split ratios: val={val_size}, test={test_size}. Ratios must be [0, 1) and sum < 1.")
        return data, [], []  # invalid ratios: everything stays in train

    try:
        stratify_labels = [sample['stratify_label'] for sample in data]
    except KeyError:
        logger.error("All samples must have a 'stratify_label' key for splitting.")
        return [], [], []

    # Classes with a single member cannot be stratified; warn but carry on
    label_values, label_counts = np.unique(stratify_labels, return_counts=True)
    singleton_labels = label_values[label_counts == 1]
    if len(singleton_labels) > 0:
        logger.warning(f"Labels with only 1 sample found: {list(singleton_labels)}. Stratification might be unstable or fail. Consider merging/removing.")

    # --- First split: carve out the test set ---
    if test_size > 0:
        positions = range(n_samples)
        try:
            logger.info(f"Splitting off test set ({test_size * 100:.1f}%)...")
            pool_idx, test_idx = train_test_split(
                positions,
                test_size=test_size,
                random_state=random_state,
                stratify=stratify_labels
            )
        except ValueError as e:
            logger.warning(f"Stratified split for test set failed: {e}. Falling back to non-stratified split for test set.")
            pool_idx, test_idx = train_test_split(
                positions,
                test_size=test_size,
                random_state=random_state
            )
        else:
            logger.info(f"Split complete: Train/Val pool = {len(pool_idx)}, Test = {len(test_idx)}")

        train_val_data = [data[i] for i in pool_idx]
        test_data = [data[i] for i in test_idx]
        # Labels are kept even after a fallback, in case the next split can stratify
        train_val_labels = [stratify_labels[i] for i in pool_idx]
    else:
        logger.info("No test set requested (test_size=0). Using all data for train/val split.")
        train_val_data = data
        train_val_labels = stratify_labels
        test_data = []

    # --- Second split: carve the validation set out of the train/val pool ---
    if val_size > 0 and len(train_val_data) > 0:
        # Rescale val_size to the pool when a test set has already been removed
        relative_val_size = val_size / (1.0 - test_size) if test_size > 0 else val_size

        if relative_val_size <= 0 or relative_val_size >= 1:
            logger.warning(f"Calculated relative validation size ({relative_val_size:.3f}) is invalid or zero. Assigning all remaining data to train set.")
            train_data, val_data = train_val_data, []
        elif len(train_val_data) < 2:
            logger.warning(f"Only {len(train_val_data)} sample(s) left for train/val split. Assigning all to train set.")
            train_data, val_data = train_val_data, []
        else:
            pool_positions = range(len(train_val_data))
            try:
                logger.info(f"Splitting off validation set ({relative_val_size * 100:.1f}% of remaining)...")
                train_idx, val_idx = train_test_split(
                    pool_positions,
                    test_size=relative_val_size,
                    random_state=random_state,
                    stratify=train_val_labels
                )
            except ValueError as e:
                logger.warning(f"Stratified split for validation set failed: {e}. Falling back to non-stratified split for validation set.")
                if len(train_val_data) >= 2:
                    train_idx, val_idx = train_test_split(
                        pool_positions,
                        test_size=relative_val_size,
                        random_state=random_state
                    )
                    train_data = [train_val_data[i] for i in train_idx]
                    val_data = [train_val_data[i] for i in val_idx]
                else:  # safeguard; the size check above should make this unreachable
                    logger.warning("Cannot perform non-stratified split with < 2 samples. Assigning all to train.")
                    train_data, val_data = train_val_data, []
            else:
                train_data = [train_val_data[i] for i in train_idx]
                val_data = [train_val_data[i] for i in val_idx]
                logger.info(f"Split complete: Train = {len(train_data)}, Validation = {len(val_data)}")
    else:
        logger.info("No validation set requested (val_size=0) or train/val pool empty. Assigning remaining data to train set.")
        train_data, val_data = train_val_data, []

    # Final check on sizes
    logger.info(f"Final split sizes: Train={len(train_data)}, Validation={len(val_data)}, Test={len(test_data)}")
    if len(train_data) + len(val_data) + len(test_data) != n_samples:
        logger.warning("Total samples after split do not match initial count. Check logic.")

    # Sanity check: no image ID may appear in more than one set
    ids_by_set = {
        "Train": {sample['image_id'] for sample in train_data},
        "Validation": {sample['image_id'] for sample in val_data},
        "Test": {sample['image_id'] for sample in test_data},
    }
    for first, second in (("Train", "Validation"), ("Train", "Test"), ("Validation", "Test")):
        if ids_by_set[first] & ids_by_set[second]:
            logger.error(f"Overlap detected between {first} and {second} sets!")

    return train_data, val_data, test_data

**Cell 6: RAG Components (Text-Based Embeddings)**

In [6]:
class EmbeddingGenerator:
    """Wraps a SentenceTransformer model to embed text, alone or as fused OCR+triples vectors."""

    def __init__(self, model_name: str = TEXT_EMBEDDING_MODEL, device: Optional[str] = None):
        """Load ``model_name`` and record the width of the vectors it produces.

        Args:
            model_name: Name/path handed to ``SentenceTransformer``.
            device: Device string; when falsy, 'cuda' is chosen if available, else 'cpu'.

        Raises:
            Exception: whatever the model load or the probe encode raised (logged, then re-raised).
        """
        if device:
            self.device = device
        elif torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'
        logger.info(f"Initializing SentenceTransformer model: {model_name} on device: {self.device}")
        try:
            self.model = SentenceTransformer(model_name, device=self.device)
            # Encode one throwaway sentence; the second axis of the result is the embedding width.
            probe = self.model.encode(["test sentence"])
            self.embedding_dim = probe.shape[1]
            logger.info(f"SentenceTransformer model loaded successfully. Embedding dimension: {self.embedding_dim}")
        except Exception as load_err:
            logger.error(f"Failed to load SentenceTransformer model '{model_name}': {load_err}", exc_info=True)
            raise  # A missing model is fatal for everything downstream

    def generate_embeddings(self, texts: List[str], batch_size: int = 32) -> Optional[np.ndarray]:
        """Encode ``texts`` into a NumPy array.

        Returns:
            The array returned by ``model.encode``, or None when ``texts`` is empty
            or encoding raised (the error is logged, not propagated).
        """
        if not texts:
            logger.warning("Received empty list of texts for embedding generation.")
            return None
        logger.info(f"Generating embeddings for {len(texts)} texts...")
        try:
            encode_options = dict(
                batch_size=batch_size,
                # Progress bar only for large inputs
                show_progress_bar=len(texts) > 1000,
                convert_to_numpy=True,
                device=self.device,
            )
            vectors = self.model.encode(texts, **encode_options)
            logger.info(f"Generated embeddings with shape: {vectors.shape}")
            return vectors
        except Exception as encode_err:
            logger.error(f"Error during embedding generation: {encode_err}", exc_info=True)
            return None

    def generate_fused_embeddings(self, ocr_texts: List[str], triples_texts: List[str], batch_size: int = 32) -> Optional[np.ndarray]:
        """Embed OCR texts and triples texts separately, L2-normalize each, and concatenate.

        Returns:
            An array with both normalized halves joined along axis 1, or None when the
            two lists differ in length or neither half could be embedded. If only one
            half is unavailable it is replaced by zero vectors of ``embedding_dim``.
        """
        logger.info("Generating fused text embeddings (OCR + Triples)...")
        if len(ocr_texts) != len(triples_texts):
            logger.error(f"Length mismatch between OCR texts ({len(ocr_texts)}) and Triples texts ({len(triples_texts)}). Cannot fuse.")
            return None

        # Embed each modality on its own; an empty list is never sent to the model.
        ocr_embeddings = None
        if ocr_texts:
            ocr_embeddings = self.generate_embeddings(ocr_texts, batch_size=batch_size)
        triples_embeddings = None
        if triples_texts:
            triples_embeddings = self.generate_embeddings(triples_texts, batch_size=batch_size)

        if ocr_embeddings is None and triples_embeddings is None:
            logger.error("Both OCR and Triples embedding generation failed or produced None.")
            return None

        # Exactly one half may still be missing: stand in zero vectors so shapes line up.
        zero_block_shape = (len(ocr_texts), self.embedding_dim)
        if ocr_embeddings is None:
            logger.warning("OCR embedding generation failed or list was empty. Using zero vectors.")
            ocr_embeddings = np.zeros(zero_block_shape, dtype=np.float32)
        if triples_embeddings is None:
            logger.warning("Triples embedding generation failed or list was empty. Using zero vectors.")
            triples_embeddings = np.zeros(zero_block_shape, dtype=np.float32)

        epsilon = 1e-12  # keeps the divisor non-zero for all-zero rows

        def _unit_rows(matrix):
            """Divide every row by its L2 norm (plus epsilon)."""
            divisors = np.linalg.norm(matrix, axis=1, keepdims=True) + epsilon
            return np.divide(matrix, divisors, out=np.zeros_like(matrix), where=divisors != 0)

        ocr_unit = _unit_rows(ocr_embeddings)
        triples_unit = _unit_rows(triples_embeddings)

        # Fusion is plain concatenation of the two normalized halves.
        fused_embeddings = np.concatenate([ocr_unit, triples_unit], axis=1)
        logger.info(f"Generated fused embeddings with shape: {fused_embeddings.shape}")

        # Drop the intermediates before returning.
        del ocr_embeddings, triples_embeddings, ocr_unit, triples_unit
        gc.collect()

        return fused_embeddings


class RAGRetriever:
    """Owns a FAISS flat L2 index over a set of embeddings and answers nearest-neighbour queries."""

    def __init__(self, embeddings: Optional[np.ndarray], top_k: int = RETRIEVAL_K):
        """Store ``top_k`` and, when embeddings are supplied, build the index right away."""
        self.top_k = top_k
        self.index = None
        self.dimension = 0

        if embeddings is None or embeddings.size <= 0:
            logger.warning("RAGRetriever initialized with no embeddings. Retrieval will not be possible.")
            return

        # FAISS works on float32 vectors
        if embeddings.dtype != np.float32:
            logger.warning(f"Embeddings dtype is {embeddings.dtype}, converting to float32 for FAISS.")
            embeddings = embeddings.astype(np.float32)
        self.build_index(embeddings)

    def build_index(self, embeddings: np.ndarray):
        """Create an ``IndexFlatL2`` sized to ``embeddings`` and add every row to it.

        On invalid input nothing is changed; if FAISS raises, ``self.index`` is left as None.
        """
        if embeddings is None or embeddings.shape[0] == 0:
            logger.error("Cannot build FAISS index from empty or None embeddings.")
            return
        if embeddings.ndim != 2:
            logger.error(f"Embeddings must be 2D (samples, dimension), but got shape {embeddings.shape}. Cannot build index.")
            return

        n_samples, self.dimension = embeddings.shape
        logger.info(f"Building FAISS IndexFlatL2 with dimension {self.dimension} for {n_samples} vectors.")

        try:
            flat_index = faiss.IndexFlatL2(self.dimension)  # Euclidean distance
            flat_index.add(embeddings)
            self.index = flat_index
            logger.info(f"FAISS index built successfully. Index size: {self.index.ntotal} vectors.")
        except Exception as build_err:
            logger.error(f"Error building FAISS index: {build_err}", exc_info=True)
            self.index = None  # never leave a half-built index behind

    def retrieve_similar(self, query_embeddings: np.ndarray) -> Optional[np.ndarray]:
        """Return neighbour indices for each query row, shape ``(num_queries, k)``.

        ``k`` is ``top_k + 1`` capped at the index size; the extra slot leaves room for
        the query itself when it is part of the index (self-matches are NOT removed here).
        Returns None when there is no index, the query is empty/None, the query width
        does not match the index, or the search raises.
        """
        if self.index is None:
            logger.error("FAISS index is not built. Cannot retrieve similar items.")
            return None
        if query_embeddings is None or query_embeddings.size == 0:
            logger.warning("Received empty or None query embeddings for retrieval.")
            return None

        # Normalize the query to a float32 matrix
        queries = query_embeddings
        if queries.dtype != np.float32:
            queries = queries.astype(np.float32)
        if queries.ndim == 1:
            queries = queries[np.newaxis, :]  # (D,) -> (1, D)

        query_dim = queries.shape[1]
        if query_dim != self.dimension:
            logger.error(f"Query embedding dimension ({query_dim}) does not match index dimension ({self.dimension}).")
            return None

        n_queries = queries.shape[0]
        # One more than top_k so a self-match can be discarded later; never more than the index holds.
        k = min(self.top_k + 1, self.index.ntotal)
        if k == 0:
            logger.warning("FAISS index is empty (ntotal=0). Cannot retrieve.")
            # One empty row per query keeps the result shape consistent
            return np.array([[] for _ in range(n_queries)], dtype=int)

        logger.info(f"Searching for top {k} neighbors for {n_queries} query vectors...")
        try:
            # search() yields (distances, indices); only the indices are handed back
            _, neighbor_ids = self.index.search(queries, k=k)
            logger.info(f"FAISS search completed. Found indices shape: {neighbor_ids.shape}")
            # Raw indices are returned as-is; filtering out self-retrieval is left to the caller.
            return neighbor_ids
        except Exception as search_err:
            logger.error(f"Error during FAISS search: {search_err}", exc_info=True)
            return None


class PromptConstructor:
    """Constructs prompts for the classification model, optionally including RAG examples.

    Retrieved examples are looked up by position in ``train_data``, so the indices
    passed to ``construct_prompt`` must refer to the same ordering the retrieval
    index was built from.
    """

    # Keys tried, in priority order, to identify a sample when skipping self-retrieval.
    # 'image_id' is the identifier the split and dataset code in this notebook rely on.
    _ID_KEYS = ('sample_id', 'id', 'image_id')

    def __init__(self, train_data: List[Dict], label_encoder: LabelEncoder):
        """Keep the training pool and pre-render the comma-separated class list.

        If ``label_encoder`` has no ``classes_`` attribute (not fitted), the error is
        logged and a placeholder class list is used instead of raising.
        """
        self.train_data = train_data
        self.label_encoder = label_encoder
        try:
             self.class_names = ", ".join(label_encoder.classes_)
             logger.info(f"PromptConstructor initialized with {len(train_data)} training examples. Target classes: {self.class_names}")
        except AttributeError:
             logger.error("LabelEncoder does not seem to be fitted yet (no classes_ attribute).")
             self.class_names = "ERROR_CLASSES_UNDEFINED"

    @classmethod
    def _sample_identifier(cls, record: Dict):
        """Return the first non-None identifier found under ``_ID_KEYS``, or None."""
        for key in cls._ID_KEYS:
            value = record.get(key)
            if value is not None:
                return value
        return None

    @staticmethod
    def _format_labels(labels) -> str:
        """Render a label list (multilabel) or a single label (multiclass) as text."""
        if isinstance(labels, list):
            return ", ".join(labels) if labels else "None"
        return str(labels) if labels is not None else "N/A"

    def construct_prompt(self, sample: Dict, retrieved_indices: Optional[List[int]] = None) -> str:
        """Constructs a classification prompt for a given sample, optionally adding similar examples.

        Args:
            sample: The record to classify; 'ocr_text', 'triples' and 'original_labels'
                are read from it (a list under 'original_labels' selects multilabel wording).
            retrieved_indices: Positions into ``train_data`` of similar examples. Any
                iterable of integers is accepted, including a NumPy row as produced by
                a FAISS search. Negative or out-of-range entries are skipped with a warning.

        Returns:
            The prompt text, ending with "Category:" for the model to complete.
        """
        is_multilabel = isinstance(sample.get('original_labels', None), list)
        task_kind = 'multilabel' if is_multilabel else 'multiclass'
        output_rule = (
            'Output all applicable labels separated by commas if multiple apply.'
            if is_multilabel else
            'Output only the single most likely label.'
        )

        # --- System/Task Instruction ---
        prompt = (
            f"Perform {task_kind} "
            f"classification for the given meme's text and knowledge graph context.\n"
            f"Choose the most relevant category/categories from the following list: {self.class_names}.\n"
            f"{output_rule}\n\n"
        )

        # --- Few-Shot Examples (RAG) ---
        # Materialize first: truth-testing a multi-element NumPy array raises ValueError.
        indices = list(retrieved_indices) if retrieved_indices is not None else []
        if indices:
            prompt += "Here are some potentially similar examples:\n\n"
            current_id = self._sample_identifier(sample)
            pool_size = len(self.train_data)
            num_added = 0
            for raw_idx in indices:
                try:
                    idx = int(raw_idx)
                    # Reject negatives explicitly: FAISS pads missing neighbours with -1,
                    # and Python would silently map that to the LAST training example.
                    if not 0 <= idx < pool_size:
                        logger.warning(f"Retrieved index {idx} is out of bounds for train_data (size {pool_size}).")
                        continue

                    retrieved_sample = self.train_data[idx]
                    if current_id is not None and current_id == self._sample_identifier(retrieved_sample):
                        logger.debug(f"Skipping self-retrieval for index {idx} (ID: {current_id})")
                        continue  # Never show a sample its own labelled record

                    ex_text = retrieved_sample.get("ocr_text", "N/A")
                    ex_triples = retrieved_sample.get("triples", "")
                    ex_label_str = self._format_labels(retrieved_sample.get("original_labels", "N/A"))

                    example_block = f"--- Example {num_added + 1} ---\n"
                    example_block += f"Example Text: {ex_text}\n"
                    if ex_triples:
                        example_block += f"Example Knowledge:\n{ex_triples}\n"  # Keep newline for readability
                    example_block += f"Example Category: {ex_label_str}\n\n"

                    prompt += example_block
                    num_added += 1
                    if num_added >= RETRIEVAL_K:  # Limit to K examples even if more retrieved
                        break
                except Exception as e:
                    logger.error(f"Error processing retrieved example at index {raw_idx}: {e}")

            if num_added > 0:
                prompt += "--- End of Examples ---\n\n"
            else:
                prompt += "No similar examples found or provided.\n\n"

        # --- Current Sample to Classify ---
        prompt += "Now, classify the following meme:\n\n"
        current_text = sample.get('ocr_text', 'N/A')
        current_triples = sample.get('triples', '')

        prompt += f"Text: {current_text}\n"
        if current_triples:
            prompt += f"Knowledge:\n{current_triples}\n"  # Keep newline
        prompt += "Category:"  # Model should predict what comes after this

        return prompt.strip()

**Cell 7: Dataset Classes (Handles Image Features)**

In [7]:
class AnxietyDataset(Dataset):
    """Single-label anxiety dataset: tokenized prompt + image feature vector + class index."""

    def __init__(self,
                 samples: List[Dict],
                 prompts: List[str],
                 tokenizer: BartTokenizer,
                 max_len: int,
                 label_encoder: LabelEncoder,
                 image_features_map: Dict[str, torch.Tensor],
                 visual_feature_dim: int = VISUAL_FEATURE_DIM):
        """Store the inputs and log basic consistency checks.

        ``prompts`` is indexed in parallel with ``samples``; a length mismatch is only
        warned about. ``image_features_map`` is looked up by each sample's 'image_id'.
        """
        self.samples = samples
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.label_encoder = label_encoder
        self.image_features_map = image_features_map
        self.visual_feature_dim = visual_feature_dim
        # Shared all-zero stand-in for samples without usable image features
        self.default_image_feature = torch.zeros(self.visual_feature_dim, dtype=torch.float)

        logger.info(f"AnxietyDataset created with {len(samples)} samples.")
        if len(samples) != len(prompts):
            logger.warning(f"Mismatch between number of samples ({len(samples)}) and prompts ({len(prompts)}).")

        # Spot-check the first few samples for pre-computed image features
        head = (samples[i] for i in range(min(5, len(samples))))
        missing_count = sum(1 for record in head if record.get('image_id') not in image_features_map)
        if missing_count > 0:
            logger.warning(f"In first 5 samples, {missing_count} are missing pre-computed image features.")

    def __len__(self):
        """Number of samples."""
        return len(self.samples)

    def _encode_label(self, sample, idx):
        """Map the sample's label string to its class index.

        Returns -1 when the sample carries no label and 0 when encoding fails
        (unknown label or any other error; both are logged).
        """
        label_str = sample.get("original_labels", None)
        if label_str is None:
            return -1
        try:
            return self.label_encoder.transform([label_str])[0]
        except ValueError:
            logger.error(f"Label '{label_str}' in sample {idx} (ID: {sample.get('id', 'N/A')}) not found in LabelEncoder classes: {self.label_encoder.classes_}. Assigning index 0 (or handle differently).")
            return 0
        except Exception as e:
            logger.error(f"Error encoding label '{label_str}' for sample {idx}: {e}")
            return 0

    def _image_features_for(self, image_id, idx):
        """Fetch the float feature vector for ``image_id``, or the zero default if unavailable."""
        if not image_id or image_id not in self.image_features_map:
            logger.debug(f"Image features missing for image_id '{image_id}' in sample {idx}. Using default zero vector.")
            return self.default_image_feature

        features = self.image_features_map[image_id]
        if features.shape[0] != self.visual_feature_dim:
            logger.warning(f"Image feature for ID {image_id} has incorrect dimension {features.shape[0]}, expected {self.visual_feature_dim}. Using zeros.")
            features = self.default_image_feature
        return features if features.dtype == torch.float else features.float()

    def _tokenize(self, prompt, sample, idx):
        """Tokenize to fixed length; on failure log and return all-zero ids and mask."""
        try:
            encoded = self.tokenizer(
                prompt,
                max_length=self.max_len,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            )
            # The tokenizer adds a leading batch axis; drop it
            return encoded["input_ids"].squeeze(0), encoded["attention_mask"].squeeze(0)
        except Exception as e:
            logger.error(f"Error tokenizing prompt for sample {idx} (ID: {sample.get('id', 'N/A')}): {e}", exc_info=True)
            blank_ids = torch.zeros(self.max_len, dtype=torch.long)
            blank_mask = torch.zeros(self.max_len, dtype=torch.long)
            return blank_ids, blank_mask

    def __getitem__(self, idx):
        """Return a dict with 'input_ids', 'attention_mask', 'image_features' and 'label'."""
        sample = self.samples[idx]
        prompt = self.prompts[idx]

        label_idx = self._encode_label(sample, idx)
        img_features = self._image_features_for(sample.get("image_id", None), idx)
        input_ids, attention_mask = self._tokenize(prompt, sample, idx)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "image_features": img_features,
            # long dtype: class indices for CrossEntropyLoss
            "label": torch.tensor(label_idx, dtype=torch.long)
        }


class DepressionDataset(Dataset):
    """Multi-label depression dataset: tokenized prompt + image feature vector + multi-hot target."""

    def __init__(self,
                 samples: List[Dict],
                 prompts: List[str],
                 tokenizer: BartTokenizer,
                 max_len: int,
                 label_encoder: LabelEncoder,
                 image_features_map: Dict[str, torch.Tensor],
                 visual_feature_dim: int = VISUAL_FEATURE_DIM):
        """Store the inputs, derive the label count from the encoder, and log sanity checks.

        ``prompts`` is indexed in parallel with ``samples``; a length mismatch is only
        warned about. ``image_features_map`` is looked up by each sample's 'image_id'.
        """
        self.samples = samples
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.label_encoder = label_encoder
        self.num_labels = len(label_encoder.classes_)
        self.image_features_map = image_features_map
        self.visual_feature_dim = visual_feature_dim
        # Shared all-zero stand-in for samples without usable image features
        self.default_image_feature = torch.zeros(self.visual_feature_dim, dtype=torch.float)

        logger.info(f"DepressionDataset created with {len(samples)} samples and {self.num_labels} labels.")
        if len(samples) != len(prompts):
            logger.warning(f"Mismatch between number of samples ({len(samples)}) and prompts ({len(prompts)}).")

        # Spot-check the first few samples for pre-computed image features
        head = (samples[i] for i in range(min(5, len(samples))))
        missing_count = sum(1 for record in head if record.get('image_id') not in image_features_map)
        if missing_count > 0:
            logger.warning(f"In first 5 samples, {missing_count} are missing pre-computed image features.")

    def __len__(self):
        """Number of samples."""
        return len(self.samples)

    def _multi_hot(self, sample, idx):
        """Build the float multi-hot target; unknown or failing labels are logged and left at 0."""
        target = torch.zeros(self.num_labels, dtype=torch.float)  # float for BCEWithLogitsLoss

        label_str_list = sample.get("original_labels", [])
        if not isinstance(label_str_list, list):
            logger.warning(f"Expected label_str_list to be a list for sample {idx}, but got {type(label_str_list)}. Treating as empty.")
            label_str_list = []

        if not label_str_list:
            # A sample without labels keeps an all-zero target
            logger.warning(f"Sample {idx} (ID: {sample.get('id', 'N/A')}) has an empty label list.")
            return target

        for label_str in label_str_list:
            try:
                label_idx = self.label_encoder.transform([label_str])[0]
                if not (0 <= label_idx < self.num_labels):
                    logger.error(f"Label index {label_idx} out of bounds for label '{label_str}' in sample {idx}.")
                    continue
                target[label_idx] = 1.0
            except ValueError:
                logger.error(f"Label '{label_str}' in sample {idx} (ID: {sample.get('id', 'N/A')}) not found in LabelEncoder classes: {self.label_encoder.classes_}. Skipping this label.")
            except Exception as e:
                logger.error(f"Error encoding label '{label_str}' for sample {idx}: {e}")
        return target

    def _image_features_for(self, image_id, idx):
        """Fetch the float feature vector for ``image_id``, or the zero default if unavailable."""
        if not image_id or image_id not in self.image_features_map:
            logger.debug(f"Image features missing for image_id '{image_id}' in sample {idx}. Using default zero vector.")
            return self.default_image_feature

        features = self.image_features_map[image_id]
        if features.shape[0] != self.visual_feature_dim:
            logger.warning(f"Image feature for ID {image_id} has incorrect dimension {features.shape[0]}, expected {self.visual_feature_dim}. Using zeros.")
            features = self.default_image_feature
        return features if features.dtype == torch.float else features.float()

    def _tokenize(self, prompt, sample, idx):
        """Tokenize to fixed length; on failure log and return all-zero ids and mask."""
        try:
            encoded = self.tokenizer(
                prompt,
                max_length=self.max_len,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            )
            # The tokenizer adds a leading batch axis; drop it
            return encoded["input_ids"].squeeze(0), encoded["attention_mask"].squeeze(0)
        except Exception as e:
            logger.error(f"Error tokenizing prompt for sample {idx} (ID: {sample.get('id', 'N/A')}): {e}", exc_info=True)
            blank_ids = torch.zeros(self.max_len, dtype=torch.long)
            blank_mask = torch.zeros(self.max_len, dtype=torch.long)
            return blank_ids, blank_mask

    def __getitem__(self, idx):
        """Return a dict with 'input_ids', 'attention_mask', 'image_features' and 'label'."""
        sample = self.samples[idx]
        prompt = self.prompts[idx]

        multi_hot_label = self._multi_hot(sample, idx)
        img_features = self._image_features_for(sample.get("image_id", None), idx)
        input_ids, attention_mask = self._tokenize(prompt, sample, idx)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "image_features": img_features,
            "label": multi_hot_label
        }

**Cell 8: Model Definition (Fusion Classifier)**

In [8]:
# CELL 8: Model Definition (NEW Fusion Classifier with Cross-Attention Simulation - Refined Init)
from torch.nn import MultiheadAttention # Ensure import

class CrossAttentionModule(nn.Module):
    """Two-way cross-attention between a text sequence and a visual sequence.

    Text queries attend over the visual tokens first; the visual tokens then attend
    over the updated text. Each direction is a residual connection followed by LayerNorm.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        """Create one batch-first attention layer and one LayerNorm per direction."""
        super().__init__()
        self.embed_dim = embed_dim
        attn_options = dict(dropout=dropout, batch_first=True)
        # text -> vision direction
        self.txt_q_vis_kv_attn = MultiheadAttention(embed_dim, num_heads, **attn_options)
        # vision -> text direction
        self.vis_q_txt_kv_attn = MultiheadAttention(embed_dim, num_heads, **attn_options)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, text_features, visual_feature):
        """Return ``(attended_text, attended_visual)``, each shaped like its input."""
        # Text queries, visual keys/values; the attention weights are discarded
        text_update = self.txt_q_vis_kv_attn(
            query=text_features, key=visual_feature, value=visual_feature
        )[0]
        attended_text = self.norm1(text_features + text_update)

        # Visual queries over the already-updated text
        visual_update = self.vis_q_txt_kv_attn(
            query=visual_feature, key=attended_text, value=attended_text
        )[0]
        attended_visual = self.norm2(visual_feature + visual_update)

        return attended_text, attended_visual


class _ClassifierOutput:
    """Plain result holder for a classifier forward pass (``loss`` and ``logits`` attributes)."""

    def __init__(self, loss, logits):
        self.loss = loss
        self.logits = logits


class M3H_CrossAttentionClassifier(nn.Module):
    """Classifier using BART encoder, visual features, and cross-attention.

    The prompt is encoded with the BART encoder, the image feature vector is projected
    to the encoder width and treated as a one-token sequence, both go through
    ``CrossAttentionModule``, and the first text position is classified.
    """

    # Target value that marks "no ground truth". AnxietyDataset emits -1 for samples
    # without a label; such targets are excluded from the single-label loss.
    MISSING_LABEL_INDEX = -1

    def __init__(self,
                 num_labels: int,
                 pretrained_model_name: str,
                 is_multilabel: bool = False,
                 visual_feature_dim: int = 768,
                 num_attention_heads: int = 8,
                 dropout_prob: float = 0.1):
        """Build the encoder, visual projection, cross-attention, head and loss.

        Args:
            num_labels: Number of output logits.
            pretrained_model_name: Name/path passed to ``BartModel.from_pretrained``.
            is_multilabel: True selects BCEWithLogitsLoss, False CrossEntropyLoss.
            visual_feature_dim: Width of the incoming image feature vectors.
            num_attention_heads: Heads used by both cross-attention directions.
            dropout_prob: Dropout for the attention layers and before the head.

        Raises:
            Exception: any error during construction is logged and re-raised.
        """
        super().__init__()
        self.num_labels = num_labels
        self.is_multilabel = is_multilabel
        self.visual_feature_dim = visual_feature_dim

        logger.info(f"--- Initializing M3H_CrossAttentionClassifier ---")
        logger.info(f"  Base Text Model: {pretrained_model_name}")
        logger.info(f"  Task Type: {'Multi-Label' if is_multilabel else 'Single-Label'}")
        logger.info(f"  Num Labels: {num_labels}")
        logger.info(f"  Visual Dim (Input): {visual_feature_dim}")
        logger.info(f"  Attention Heads: {num_attention_heads}")
        logger.info(f"  Dropout: {dropout_prob}")

        try:
            # Load BART base model FIRST to get hidden size
            logger.info("  Loading BART base model...")
            self.bart = BartModel.from_pretrained(pretrained_model_name)
            self.text_feature_dim = self.bart.config.hidden_size
            logger.info(f"  BART Loaded. Hidden Dim: {self.text_feature_dim}")

            # Visual features must match the text width before cross-attention
            if self.visual_feature_dim != self.text_feature_dim:
                logger.info(f"  Adding visual projection layer: {visual_feature_dim} -> {self.text_feature_dim}")
                self.visual_projection = nn.Linear(self.visual_feature_dim, self.text_feature_dim)
            else:
                logger.info("  Visual dim matches BART. Using Identity projection.")
                self.visual_projection = nn.Identity()

            logger.info("  Defining CrossAttentionModule...")
            self.cross_attention = CrossAttentionModule(
                embed_dim=self.text_feature_dim,
                num_heads=num_attention_heads,
                dropout=dropout_prob
            )

            # The head reads only the attended text representation
            self.classifier_input_dim = self.text_feature_dim
            logger.info(f"  Defining Classifier Head (Input Dim: {self.classifier_input_dim})...")
            self.dropout = nn.Dropout(dropout_prob)
            self.classifier = nn.Linear(self.classifier_input_dim, num_labels)

            if self.is_multilabel:
                self.loss_fct = nn.BCEWithLogitsLoss()
                logger.info("  Using BCEWithLogitsLoss.")
            else:
                self.loss_fct = nn.CrossEntropyLoss()
                logger.info("  Using CrossEntropyLoss.")

            logger.info("--- M3H_CrossAttentionClassifier Initialized Successfully ---")

        except Exception as e:
            logger.error(f"Error initializing M3H_CrossAttentionClassifier: {e}", exc_info=True)
            raise # Re-raise to prevent proceeding with a broken model

    def _single_label_loss(self, logits, labels):
        """Cross-entropy over the labelled targets of a batch.

        Targets equal to ``MISSING_LABEL_INDEX`` are remapped to the loss's own
        ``ignore_index`` so they are skipped instead of being treated as an
        out-of-range class. Returns None when no valid target remains, because a
        mean over zero elements would be NaN.
        """
        ignore_index = self.loss_fct.ignore_index
        targets = labels.view(-1)
        targets = targets.masked_fill(targets == self.MISSING_LABEL_INDEX, ignore_index)
        if not bool((targets != ignore_index).any()):
            logger.warning("Batch contains no valid labels; skipping loss computation.")
            return None
        return self.loss_fct(logits.view(-1, self.num_labels), targets)

    def forward(self, input_ids, attention_mask, image_features, labels=None):
        """Run the model and, when ``labels`` is given, compute the loss.

        Returns:
            An object with ``logits`` and ``loss``. ``loss`` is None when no labels
            were passed, when a single-label batch has no valid target, or when the
            loss computation raised (the error is logged).
        """
        # 1. BART Encoder
        encoder_outputs = self.bart.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        # 2. Project Visual & Reshape into a one-token sequence
        projected_visual_features = self.visual_projection(image_features)
        visual_features_seq = projected_visual_features.unsqueeze(1) # (B, 1, Dim)

        # 3. Cross-Attention (only the attended text is used below)
        attended_text, attended_visual = self.cross_attention(
            text_features=encoder_outputs,
            visual_feature=visual_features_seq
        )

        # 4. Pool & Dropout
        pooled_output = attended_text[:, 0, :] # First token position as the sequence summary
        final_representation = self.dropout(pooled_output)

        # 5. Classify
        logits = self.classifier(final_representation)

        # 6. Loss Calculation
        loss = None
        if labels is not None:
            try:
                if self.is_multilabel:
                    loss = self.loss_fct(logits, labels.float())
                else:
                    loss = self._single_label_loss(logits, labels)
            except Exception as loss_e:
                 logger.error(f"Error calculating loss: {loss_e}. Logits shape: {logits.shape}, Labels shape: {labels.shape}, Is MultiLabel: {self.is_multilabel}", exc_info=True)
                 loss = None # Logits are still returned; the caller must handle a missing loss

        return _ClassifierOutput(loss=loss, logits=logits)

**Cell 9: Training Function (Trains ONE Model)**

In [9]:
# CELL 9: Training Function (Trains ONE Model)

def _empty_epoch_metrics():
    """Placeholder metrics used when an epoch phase produced no usable batches."""
    return {"f1_micro": 0.0, "f1_macro": 0.0, "f1_weighted": 0.0, "accuracy": 0.0, "hamming": 1.0}


def _epoch_classification_metrics(logit_batches, label_batches, is_multilabel):
    """Concatenate one epoch's per-batch CPU logits/labels and score them.

    Returns a dict with the keys ``f1_micro``, ``f1_macro``, ``f1_weighted``,
    ``accuracy`` and ``hamming``.
    """
    logits = torch.cat(logit_batches, dim=0)
    targets = torch.cat(label_batches, dim=0)
    if is_multilabel:
        # Independent sigmoid per label, thresholded at 0.5.
        y_pred = (torch.sigmoid(logits).numpy() > 0.5).astype(int)
        y_true = targets.numpy().astype(int)
        hamming = hamming_loss(y_true, y_pred)
    else:
        y_pred = torch.argmax(logits, dim=1).numpy()
        y_true = targets.numpy()
        hamming = 0.0  # Hamming loss is not reported for single-label runs
    return {
        "f1_micro": f1_score(y_true, y_pred, average="micro", zero_division=0),
        "f1_macro": f1_score(y_true, y_pred, average="macro", zero_division=0),
        "f1_weighted": f1_score(y_true, y_pred, average="weighted", zero_division=0),
        "accuracy": accuracy_score(y_true, y_pred),  # subset accuracy when multi-label
        "hamming": hamming,
    }


def _log_epoch_metrics(epoch_num, phase, avg_loss, metrics, is_multilabel):
    """Emit the one-line per-phase summary; ``phase`` is already padded for alignment."""
    if is_multilabel:
        logger.info(
            f"Epoch {epoch_num} {phase} - Loss: {avg_loss:.4f}, Acc: {metrics['accuracy']:.4f}, "
            f"Hamming: {metrics['hamming']:.4f}, MicroF1: {metrics['f1_micro']:.4f}, MacroF1: {metrics['f1_macro']:.4f}"
        )
    else:
        logger.info(
            f"Epoch {epoch_num} {phase} - Loss: {avg_loss:.4f}, Acc: {metrics['accuracy']:.4f}, "
            f"MacroF1: {metrics['f1_macro']:.4f}, WeightedF1: {metrics['f1_weighted']:.4f}"
        )


def train_and_evaluate(
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    device: torch.device,
    num_epochs: int,
    output_dir: str, # Directory for saving models and logs for this specific run
    label_encoder: LabelEncoder,
    is_multilabel: bool
) -> Tuple[Dict[str, List], str]:
    """Train one model instance and validate it after every epoch.

    Each batch dict must provide ``input_ids``, ``attention_mask``,
    ``image_features`` and ``label``. The model is called with those tensors
    and must return an object exposing ``.loss`` and ``.logits``.

    Args:
        train_dataloader: Batches used for optimisation.
        val_dataloader: Batches used for per-epoch validation.
        model: Model to train; it is not moved to ``device`` here.
        optimizer: Optimiser stepped once per successful training batch.
        scheduler: LR scheduler stepped once per successful training batch.
        device: Device the batch tensors are moved to.
        num_epochs: Number of passes over ``train_dataloader``.
        output_dir: Directory (created if missing) that receives
            ``best_model_macro_f1.pt``, ``last_model.pt`` and ``training_log.csv``.
        label_encoder: Accepted for interface compatibility; not used here.
        is_multilabel: ``True`` for sigmoid/threshold multi-label scoring,
            ``False`` for argmax single-label scoring.

    Returns:
        ``(history, best_model_path)`` where ``history`` maps metric names to
        per-epoch lists and ``best_model_path`` is the checkpoint with the
        highest validation macro F1. The first completed epoch always writes
        this checkpoint, so the path is valid even if macro F1 never rises
        above 0.

    Raises:
        Exception: Any unexpected error in the epoch loop is logged, the CSV
            log is flushed, and the error is re-raised. ``KeyboardInterrupt``
            is handled by saving the last state and returning normally.
    """

    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Starting training run. Output directory: {output_dir}")
    logger.info(f"Task Type: {'Multi-Label' if is_multilabel else 'Single-Label (Multiclass)'}")

    best_val_f1_macro = 0.0 # Best validation macro F1 seen so far (drives checkpointing)
    best_model_saved = False # Guarantees the returned path was written by THIS run
    best_model_path = os.path.join(output_dir, "best_model_macro_f1.pt")
    last_model_path = os.path.join(output_dir, "last_model.pt")
    log_file = os.path.join(output_dir, "training_log.csv")
    training_logs = [] # One dict per epoch, dumped to CSV
    history = defaultdict(list) # Metric name -> per-epoch values

    total_steps = len(train_dataloader) * num_epochs
    logger.info(f"Total training steps: {total_steps} ({len(train_dataloader)} steps/epoch * {num_epochs} epochs)")

    try:
        for epoch in range(num_epochs):
            epoch_num = epoch + 1
            logger.info(f"--- Epoch {epoch_num}/{num_epochs} ---")

            # --- Training Phase ---
            model.train()
            total_train_loss = 0.0
            train_batches_done = 0 # Batches that actually contributed a loss
            all_train_logits = []
            all_train_labels_raw = []

            train_progress = tqdm(train_dataloader, desc=f"Train Epoch {epoch_num}", leave=False)
            for batch_idx, batch in enumerate(train_progress):
                # Zeroing here also discards gradients left by a failed batch.
                optimizer.zero_grad()

                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                image_features = batch["image_features"].to(device)
                labels = batch["label"].to(device)

                try:
                    outputs = model(input_ids=input_ids,
                                    attention_mask=attention_mask,
                                    image_features=image_features,
                                    labels=labels)
                    loss = outputs.loss
                    logits = outputs.logits

                    if loss is None:
                        logger.error(f"Epoch {epoch_num}, Batch {batch_idx}: Loss is None. Check model forward pass and loss calculation.")
                        continue # Nothing to back-propagate for this batch

                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    optimizer.step()
                    scheduler.step() # Per-batch learning-rate schedule

                    total_train_loss += loss.item()
                    train_batches_done += 1

                    # Keep CPU copies for the epoch-level metrics.
                    all_train_logits.append(logits.detach().cpu())
                    all_train_labels_raw.append(labels.cpu())

                    train_progress.set_postfix(loss=loss.item(), lr=scheduler.get_last_lr()[0])

                except Exception as e:
                    # Best effort: log the failing batch and move on.
                    logger.error(f"Error during training batch {batch_idx} in epoch {epoch_num}: {e}", exc_info=True)
                    continue

            # Average over the batches that produced a loss, so skipped or
            # failed batches do not deflate the reported value.
            avg_train_loss = total_train_loss / train_batches_done if train_batches_done else 0.0

            if not all_train_logits or not all_train_labels_raw:
                logger.warning(f"Epoch {epoch_num}: No training logits or labels collected. Skipping training metrics calculation.")
                # Every metric gets a placeholder so the logging below cannot
                # hit an undefined or stale value.
                train_metrics = _empty_epoch_metrics()
            else:
                train_metrics = _epoch_classification_metrics(all_train_logits, all_train_labels_raw, is_multilabel)
                _log_epoch_metrics(epoch_num, "Train", avg_train_loss, train_metrics, is_multilabel)

            # --- Validation Phase ---
            model.eval()
            total_val_loss = 0.0
            val_loss_batches = 0
            all_val_logits = []
            all_val_labels_raw = []

            val_progress = tqdm(val_dataloader, desc=f"Val Epoch {epoch_num}", leave=False)
            with torch.no_grad():
                for batch in val_progress:
                    input_ids = batch["input_ids"].to(device)
                    attention_mask = batch["attention_mask"].to(device)
                    image_features = batch["image_features"].to(device)
                    labels = batch["label"].to(device)

                    try:
                        outputs = model(input_ids=input_ids,
                                        attention_mask=attention_mask,
                                        image_features=image_features,
                                        labels=labels)
                        loss = outputs.loss
                        logits = outputs.logits

                        if loss is not None:
                            total_val_loss += loss.item()
                            val_loss_batches += 1

                        all_val_logits.append(logits.cpu())
                        all_val_labels_raw.append(labels.cpu())

                    except Exception as e:
                        logger.error(f"Error during validation batch in epoch {epoch_num}: {e}", exc_info=True)
                        continue

            avg_val_loss = total_val_loss / val_loss_batches if val_loss_batches else 0.0

            if not all_val_logits or not all_val_labels_raw:
                logger.warning(f"Epoch {epoch_num}: No validation logits or labels collected. Skipping validation metrics calculation.")
                val_metrics = _empty_epoch_metrics()
            else:
                val_metrics = _epoch_classification_metrics(all_val_logits, all_val_labels_raw, is_multilabel)
                _log_epoch_metrics(epoch_num, "Val  ", avg_val_loss, val_metrics, is_multilabel)

            # --- Logging and Saving ---
            current_log = {
                'epoch': epoch_num,
                'train_loss': avg_train_loss, 'val_loss': avg_val_loss,
                'train_f1_macro': train_metrics['f1_macro'], 'val_f1_macro': val_metrics['f1_macro'],
                'train_f1_micro': train_metrics['f1_micro'], 'val_f1_micro': val_metrics['f1_micro'],
                'train_f1_weighted': train_metrics['f1_weighted'], 'val_f1_weighted': val_metrics['f1_weighted'],
                'train_accuracy': train_metrics['accuracy'], 'val_accuracy': val_metrics['accuracy'],
                'train_hamming': train_metrics['hamming'], 'val_hamming': val_metrics['hamming'],
                'learning_rate': scheduler.get_last_lr()[0]
            }
            # The history and the CSV rows share the same keys and values.
            for metric_name, metric_value in current_log.items():
                history[metric_name].append(metric_value)
            training_logs.append(current_log)

            # Checkpoint on improved validation macro F1. The first epoch always
            # saves, so the returned path never points at a missing file or at
            # a stale checkpoint from an earlier run in the same directory.
            val_f1_macro = val_metrics['f1_macro']
            if val_f1_macro > best_val_f1_macro or not best_model_saved:
                best_val_f1_macro = val_f1_macro
                logger.info(f"Epoch {epoch_num}: New best validation Macro F1: {best_val_f1_macro:.4f}. Saving model to {best_model_path}")
                torch.save(model.state_dict(), best_model_path)
                best_model_saved = True
            else:
                logger.info(f"Epoch {epoch_num}: Validation Macro F1 ({val_f1_macro:.4f}) did not improve from best ({best_val_f1_macro:.4f}).")

            # Always keep the most recent weights and an up-to-date CSV.
            torch.save(model.state_dict(), last_model_path)
            pd.DataFrame(training_logs).to_csv(log_file, index=False)

            # Release the per-epoch buffers. The concatenated tensors live only
            # inside the metrics helper, so there is nothing else to delete.
            del all_train_logits, all_train_labels_raw, all_val_logits, all_val_labels_raw
            gc.collect()
            # Compare on device *type*: torch.device('cuda:0') != torch.device('cuda').
            if torch.device(device).type == 'cuda':
                torch.cuda.empty_cache()

        logger.info(f"Training finished after {num_epochs} epochs.")
        logger.info(f"Best Validation Macro F1 achieved: {best_val_f1_macro:.4f}")
        if best_model_saved:
            logger.info(f"Best model saved to: {best_model_path}")
            logger.info(f"Last model saved to: {last_model_path}")
        else:
            logger.warning(f"No epoch completed; no checkpoint was written to {best_model_path}.")
        logger.info(f"Training logs saved to: {log_file}")

    except KeyboardInterrupt:
        logger.warning("Training interrupted by user (KeyboardInterrupt).")
        # Persist whatever state exists so the run can be inspected or resumed.
        torch.save(model.state_dict(), last_model_path)
        pd.DataFrame(training_logs).to_csv(log_file, index=False)
        logger.info(f"Saved last model state to {last_model_path} and logs to {log_file}.")
    except Exception as e:
        logger.error(f"An unexpected error occurred during the training loop: {e}", exc_info=True)
        # Flush the logs collected so far before propagating the failure.
        pd.DataFrame(training_logs).to_csv(log_file, index=False)
        raise

    return history, best_model_path

**Cell 10: Plotting Function**

In [10]:
def _plot_history_series(ax, epochs_range, history, series):
    """Draw every ``(key, label, color)`` series that has data; return how many were drawn."""
    drawn = 0
    for key, label, color in series:
        if key in history and history[key]:
            ax.plot(epochs_range, history[key], 'o-', label=label, color=color)
            drawn += 1
    return drawn


def _decorate_history_axis(ax, title, ylabel, show_legend, ylim=None):
    """Apply the title, axis labels, optional y-limits, legend and grid shared by all panels."""
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    if ylim is not None:
        ax.set_ylim(bottom=ylim[0], top=ylim[1])
    if show_legend:
        ax.legend()
    ax.grid(True, linestyle='--', alpha=0.6)


def plot_training_history(history: Dict[str, List], output_dir: str, suffix: str = ""):
    """Plot the per-epoch metrics in ``history`` as a 2x2 figure, save it, and show it.

    Panels: loss, macro F1, weighted F1, and validation accuracy together with
    validation Hamming loss (the latter only when it has a value above zero).
    Series missing from ``history`` are skipped, not treated as errors.

    Args:
        history: Mapping of metric name to per-epoch values. The keys read are
            ``epoch``, ``train_loss``/``val_loss``,
            ``train_f1_macro``/``val_f1_macro``,
            ``train_f1_weighted``/``val_f1_weighted``, ``val_accuracy`` and
            ``val_hamming``.
        output_dir: Directory that receives ``training_history{suffix}.png``;
            created if it does not exist.
        suffix: Text appended to the figure title and to the file name.

    Returns:
        None. A failure to save is logged and does not raise.
    """
    if not history or not isinstance(history, dict):
        logger.warning("Cannot plot history: Invalid or empty history dictionary provided.")
        return

    # Missing core metrics are reported, but whatever is available still gets plotted.
    required_keys = ['train_loss', 'val_loss', 'train_f1_macro', 'val_f1_macro']
    if not all(key in history and history[key] for key in required_keys):
        logger.warning(f"History dict missing essential data ({required_keys}). Plotting only the metrics that are available.")
        logger.debug(f"Available keys in history: {list(history.keys())}")

    # Prefer the recorded epoch numbers; otherwise infer the count from the loss curve.
    num_epochs = len(history.get('epoch', history.get('train_loss', [])))
    if num_epochs == 0:
        logger.warning("History contains no epochs to plot.")
        return

    epochs_range = history.get('epoch', range(1, num_epochs + 1))

    plt.style.use('seaborn-v0_8-whitegrid')
    fig, axs = plt.subplots(2, 2, figsize=(16, 12))
    try:
        fig.suptitle(f'Training History {suffix}'.strip(), fontsize=16)

        # --- Panel 1: loss ---
        drawn = _plot_history_series(axs[0, 0], epochs_range, history, [
            ('train_loss', 'Train Loss', 'royalblue'),
            ('val_loss', 'Validation Loss', 'darkorange'),
        ])
        _decorate_history_axis(axs[0, 0], 'Training and Validation Loss', 'Loss', drawn > 0)

        # --- Panel 2: macro F1 (bounded to the [0, 1] score range) ---
        drawn = _plot_history_series(axs[0, 1], epochs_range, history, [
            ('train_f1_macro', 'Train Macro F1', 'royalblue'),
            ('val_f1_macro', 'Validation Macro F1', 'darkorange'),
        ])
        _decorate_history_axis(axs[0, 1], 'Macro F1 Score', 'F1 Score', drawn > 0, ylim=(0, 1.05))

        # --- Panel 3: weighted F1 ---
        drawn = _plot_history_series(axs[1, 0], epochs_range, history, [
            ('train_f1_weighted', 'Train Weighted F1', 'royalblue'),
            ('val_f1_weighted', 'Validation Weighted F1', 'darkorange'),
        ])
        _decorate_history_axis(axs[1, 0], 'Weighted F1 Score', 'F1 Score', drawn > 0, ylim=(0, 1.05))

        # --- Panel 4: validation accuracy and, when meaningful, Hamming loss ---
        panel4_series = [('val_accuracy', 'Validation Accuracy', 'forestgreen')]
        hamming_values = history.get('val_hamming')
        if hamming_values:
            if any(h > 0 for h in hamming_values):
                panel4_series.append(('val_hamming', 'Validation Hamming Loss', 'crimson'))
            else:
                logger.info("Skipping Hamming Loss plot as values seem to be zero (likely single-label).")
        drawn = _plot_history_series(axs[1, 1], epochs_range, history, panel4_series)
        _decorate_history_axis(axs[1, 1], 'Validation Accuracy / Hamming Loss', 'Metric Value',
                               drawn > 0, ylim=(-0.05, 1.05))

        # --- Save and show ---
        fig.tight_layout(rect=[0, 0.03, 1, 0.96]) # Leave headroom for the suptitle
        plot_path = os.path.join(output_dir, f'training_history{suffix}.png')
        try:
            os.makedirs(output_dir, exist_ok=True)
            fig.savefig(plot_path, dpi=300)
            logger.info(f"Training history plot saved to: {plot_path}")
        except Exception as e:
            logger.error(f"Failed to save training plot to {plot_path}: {e}")

        plt.show()
    finally:
        # Always release the figure, even if plotting raised part-way through.
        plt.close(fig)

**Cell 11: Evaluation Function (Ensemble)**

In [11]:
# CELL 11: Evaluation Function (Ensemble - Corrected Syntax)

def _format_metric_value(value):
    """Render floats with four decimals and anything else (e.g. 'N/A', error text) via ``str``."""
    return f"{value:.4f}" if isinstance(value, float) else str(value)


def _collect_ensemble_member_outputs(model, dataloader, device, model_idx):
    """Run one model over ``dataloader`` and gather its logits with the matching labels.

    A failing batch is logged and skipped; its labels are skipped too, so the
    returned tensors stay aligned row-for-row.

    Returns:
        ``(logits, labels)`` as CPU tensors, or ``(None, None)`` when no batch succeeded.
    """
    logit_batches, label_batches = [], []
    eval_progress = tqdm(dataloader, desc=f"Eval Model {model_idx + 1}", leave=False, ncols=100)
    with torch.no_grad():
        for batch_idx, batch in enumerate(eval_progress):
            try:
                input_ids = batch['input_ids'].to(device, non_blocking=True)
                attention_mask = batch['attention_mask'].to(device, non_blocking=True)
                image_features = batch['image_features'].to(device, non_blocking=True)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    image_features=image_features,
                    labels=None # Loss is not needed during evaluation
                )
                logit_batches.append(outputs.logits.cpu())
                label_batches.append(batch['label'].cpu())
            except Exception as e:
                logger.error(f"Error during evaluation batch {batch_idx} for model {model_idx + 1}: {e}", exc_info=True)

    if not logit_batches:
        return None, None
    return torch.cat(logit_batches, dim=0), torch.cat(label_batches, dim=0)


def evaluate_ensemble(
    model_paths: List[str],
    dataloader: DataLoader,
    device: torch.device,
    label_encoder: LabelEncoder,
    is_multilabel: bool,
    num_labels: int,
    output_dir: str, # Base output directory for saving reports
    report_suffix: str = "eval_ensemble", # e.g., "validation_ensemble", "test_ensemble"
    base_model_name: str = BASE_TEXT_MODEL_NAME,
    visual_feature_dim: int = VISUAL_FEATURE_DIM
) -> Dict:
    """Evaluate an ensemble by averaging the logits of several checkpoints.

    Each checkpoint is loaded into a fresh ``M3H_CrossAttentionClassifier`` and
    run over ``dataloader``. The logits of all usable models are averaged, then
    metrics and a text report are produced.

    A model is left out of the average when it fails to load, yields no
    logits, or saw labels that differ from the first usable model's labels.
    Differing labels mean the samples differ (for example a shuffling
    dataloader or a failed batch), so its logits could not be averaged
    row-by-row with the others.

    Args:
        model_paths: Paths to ``state_dict`` checkpoints.
        dataloader: Batches with ``input_ids``, ``attention_mask``,
            ``image_features`` and ``label``. It should iterate in a fixed order.
        device: Device used for inference.
        label_encoder: Supplies ``classes_`` for the report's class names.
        is_multilabel: ``True`` for sigmoid/0.5-threshold predictions, ``False`` for argmax.
        num_labels: Number of output classes of the classifier.
        output_dir: Directory that receives the report files (created if missing).
        report_suffix: Tag used in log lines and report file names.
        base_model_name: Pretrained text model name passed to the classifier.
        visual_feature_dim: Visual feature size passed to the classifier.

    Returns:
        A dict of metrics plus a ``'report'`` string. Multi-label runs contain
        ``accuracy_subset`` and ``samples_f1``; single-label runs contain
        ``accuracy`` and no ``samples_f1``, because sample-wise F1 is only
        defined for multi-label targets. On failure the dict contains an
        ``'error'`` entry.
    """
    if not dataloader:
        logger.warning(f"Dataloader for '{report_suffix}' is None or empty. Skipping evaluation.")
        return {'error': 'No data provided'}
    if not model_paths:
        logger.error("No model paths provided for ensemble evaluation.")
        return {'error': 'No models provided'}

    logger.info(f"--- Starting ENSEMBLE Evaluation ({report_suffix}) ---\n  Models: {len(model_paths)}, Samples: {len(dataloader.dataset)}")
    logger.info(f"  Task Type: {'Multi-Label' if is_multilabel else 'Single-Label (Multiclass)'}")

    all_individual_model_logits = [] # One (num_samples, num_labels) tensor per usable model
    all_labels_raw = None # Reference labels, taken from the first usable model

    # --- Get predictions from each model ---
    for model_idx, model_path in enumerate(model_paths):
        logger.info(f"Loading and evaluating model {model_idx + 1}/{len(model_paths)}: {os.path.basename(model_path)}")

        model = None
        try:
            model = M3H_CrossAttentionClassifier(
                num_labels=num_labels,
                pretrained_model_name=base_model_name,
                is_multilabel=is_multilabel,
                visual_feature_dim=visual_feature_dim
            )
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # only load checkpoints from a trusted source.
            # The file is read once and reused for the non-strict retry.
            state_dict = torch.load(model_path, map_location=device, weights_only=False)
            try:
                model.load_state_dict(state_dict, strict=True)
                logger.debug("Model state loaded successfully (strict=True).")
            except RuntimeError as e:
                logger.warning(f"Strict state dict loading failed for model {model_idx+1}: {e}. Retrying with strict=False.")
                model.load_state_dict(state_dict, strict=False)
                logger.debug("Model state loaded with strict=False.")
            del state_dict

            model.to(device)
            model.eval()
        except FileNotFoundError:
            logger.error(f"Model file not found: {model_path}. Skipping this model.")
            continue
        except Exception as e:
            logger.error(f"Failed to load model {model_idx + 1} from {model_path}: {e}", exc_info=True)
            continue

        model_logits, model_labels = _collect_ensemble_member_outputs(model, dataloader, device, model_idx)

        if model_logits is None:
            logger.warning(f"No logits collected for model {model_idx + 1}. It will be skipped in the ensemble average.")
        elif all_labels_raw is None:
            # The first usable model defines the reference labels, so a load
            # failure of model 1 no longer invalidates the whole ensemble.
            all_labels_raw = model_labels
            all_individual_model_logits.append(model_logits)
            logger.debug(f"Logits collected for model {model_idx+1}. Shape: {model_logits.shape}")
            logger.debug(f"Labels collected. Shape: {all_labels_raw.shape}")
        elif not torch.equal(model_labels, all_labels_raw):
            # Different labels mean different samples or order; averaging
            # logits row-by-row would mix unrelated examples.
            logger.warning(
                f"Model {model_idx + 1} saw labels that differ from the reference pass "
                f"({tuple(model_labels.shape)} vs {tuple(all_labels_raw.shape)}); "
                "it will be skipped in the ensemble average."
            )
        else:
            all_individual_model_logits.append(model_logits)
            logger.debug(f"Logits collected for model {model_idx+1}. Shape: {model_logits.shape}")

        # Free the model before loading the next one.
        del model, model_logits, model_labels
        gc.collect()
        # Compare on device *type*: torch.device('cuda:0') != torch.device('cuda').
        if torch.device(device).type == 'cuda':
            torch.cuda.empty_cache()

    # --- Aggregate predictions and calculate metrics ---
    if not all_individual_model_logits or all_labels_raw is None:
        logger.error("Evaluation failed: No valid logits collected from any model or labels are missing.")
        return {'error': 'Logit/Label collection failed'}

    # Stack to (num_valid_models, num_samples, num_labels).
    try:
        stacked_logits = torch.stack(all_individual_model_logits, dim=0)
        logger.info(f"Logits collected from {stacked_logits.shape[0]} successfully evaluated models. Stacked shape: {stacked_logits.shape}")
    except Exception as stack_e:
        logger.error(f"Error stacking logits (check if shapes are consistent): {stack_e}", exc_info=True)
        return {'error': 'Logit stacking failed'}

    # Ensemble by averaging logits across models: (num_samples, num_labels).
    avg_logits = torch.mean(stacked_logits, dim=0)

    try:
        os.makedirs(output_dir, exist_ok=True)
    except OSError as dir_e:
        logger.error(f"Could not create output directory {output_dir}: {dir_e}")

    # --- Calculate Metrics based on Task Type ---
    metrics_dict = {}
    classification_rep = "Classification report generation failed."
    confusion_matrix_report = "Confusion matrix generation failed."

    try:
        if is_multilabel:
            logger.info(f"Calculating multi-label metrics for {report_suffix} (Ensemble)...")
            probs = torch.sigmoid(avg_logits).numpy()
            preds = (probs > 0.5).astype(int)
            labels = all_labels_raw.numpy().astype(int) # Ground-truth multi-hot matrix

            metrics_dict = {
                'accuracy_subset': accuracy_score(labels, preds),
                'hamming_loss': hamming_loss(labels, preds),
                'micro_f1': f1_score(labels, preds, average='micro', zero_division=0),
                'macro_f1': f1_score(labels, preds, average='macro', zero_division=0),
                'weighted_f1': f1_score(labels, preds, average='weighted', zero_division=0),
                'samples_f1': f1_score(labels, preds, average='samples', zero_division=0),
            }

            try:
                classification_rep = classification_report(
                    labels, preds,
                    target_names=label_encoder.classes_,
                    digits=4, zero_division=0
                )
            except Exception as cr_e:
                logger.error(f"Error generating classification report: {cr_e}")
                classification_rep = f"Error generating classification report: {cr_e}"

            # One 2x2 matrix per label, laid out as [[TN, FP], [FN, TP]].
            try:
                cm = multilabel_confusion_matrix(labels, preds)
                cm_report_path = os.path.join(output_dir, f"confusion_matrix_{report_suffix}.txt")
                cm_lines = [f"--- Multilabel Confusion Matrices ({report_suffix}) ---"]
                for i, label_name in enumerate(label_encoder.classes_):
                    cm_lines.append(f"\nLabel: {label_name} (Index {i})")
                    cm_lines.append(f"{cm[i]}")
                    cm_lines.append(f"  TN={cm[i][0,0]}, FP={cm[i][0,1]}, FN={cm[i][1,0]}, TP={cm[i][1,1]}")
                confusion_matrix_report = "\n".join(cm_lines)
                with open(cm_report_path, "w", encoding='utf-8') as f:
                    f.write(confusion_matrix_report)
                logger.info(f"Multi-label confusion matrix report saved to: {cm_report_path}")
            except Exception as cm_e:
                logger.error(f"Error generating multilabel confusion matrix: {cm_e}")
                confusion_matrix_report = f"Error generating multilabel confusion matrix: {cm_e}"

        else: # Single-label (Multiclass)
            logger.info(f"Calculating single-label metrics for {report_suffix} (Ensemble)...")
            preds = torch.argmax(avg_logits, dim=1).numpy()
            labels = all_labels_raw.numpy()

            # average='samples' is deliberately not computed here: scikit-learn
            # rejects sample-wise metrics for non-multilabel targets, which
            # previously sent every single-label evaluation to the error path.
            metrics_dict = {
                'accuracy': accuracy_score(labels, preds),
                'hamming_loss': hamming_loss(labels, preds),
                'micro_f1': f1_score(labels, preds, average='micro', zero_division=0),
                'macro_f1': f1_score(labels, preds, average='macro', zero_division=0),
                'weighted_f1': f1_score(labels, preds, average='weighted', zero_division=0),
            }

            try:
                # Report only classes seen in truth or predictions and known to the encoder.
                n_known = len(label_encoder.classes_)
                present_labels = sorted(list(set(labels) | set(preds)))
                valid_present_labels = [i for i in present_labels if 0 <= i < n_known]
                target_names = [label_encoder.classes_[i] for i in valid_present_labels]
                if valid_present_labels:
                    classification_rep = classification_report(
                        labels, preds,
                        labels=valid_present_labels, target_names=target_names,
                        digits=4, zero_division=0
                    )
                else:
                    classification_rep = "No valid labels found in predictions or ground truth."
            except Exception as cr_e:
                logger.error(f"Error generating classification report: {cr_e}")
                classification_rep = f"Error generating classification report: {cr_e}"

            confusion_matrix_report = "Standard confusion matrix can be generated separately if needed."

    except Exception as e:
        logger.error(f"Error calculating metrics for {report_suffix}: {e}", exc_info=True)
        metrics_dict = {'error': f"Metric calculation failed: {e}"}
        classification_rep = f"Report generation failed due to metric error: {e}"

    # --- Save Report ---
    metrics_dict['report'] = classification_rep
    report_path = os.path.join(output_dir, f"classification_report_{report_suffix}.txt")
    try:
        with open(report_path, "w", encoding='utf-8') as f:
            f.write(f"--- Ensemble Evaluation Metrics ({report_suffix}) ---\n\n")
            for key, value in metrics_dict.items():
                if key != 'report':
                    # The conditional lives in the helper: a format spec cannot
                    # contain an expression, so the old inline form raised.
                    f.write(f"{key.replace('_', ' ').title()}: {_format_metric_value(value)}\n")
            f.write("\n--- Classification Report ---\n\n")
            f.write(metrics_dict.get('report', 'Report generation failed.'))
            f.write("\n\n")
            if is_multilabel:
                f.write(confusion_matrix_report)
        logger.info(f"Evaluation report saved to: {report_path}")
    except Exception as e:
        logger.error(f"Failed to save evaluation report to {report_path}: {e}")

    logger.info(f"--- Ensemble Evaluation ({report_suffix}) Complete ---\n")
    log_key = 'accuracy_subset' if is_multilabel else 'accuracy'
    # Missing metrics (the error path) are shown as 'N/A' without raising.
    headline_metric = _format_metric_value(metrics_dict.get(log_key, 'N/A'))
    headline_macro_f1 = _format_metric_value(metrics_dict.get('macro_f1', 'N/A'))
    logger.info(f"Final {report_suffix} {log_key}: {headline_metric}, Macro F1: {headline_macro_f1}")

    return metrics_dict

**Cell 12: Prediction Function (Ensemble)**

In [12]:
def predict_ensemble(
    model_paths: List[str],
    tokenizer: BartTokenizer,
    label_encoder: LabelEncoder,
    text: str,
    triples: str = "",
    image_features: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
    max_length: int = MAX_LEN,
    is_multilabel: bool = False,
    threshold: float = 0.5,
    base_model_name: str = BASE_TEXT_MODEL_NAME,
    visual_feature_dim: int = VISUAL_FEATURE_DIM,
    num_labels: Optional[int] = None
) -> Dict:
    """Make an ENSEMBLE prediction for a single example.

    Each checkpoint in ``model_paths`` is loaded in turn into a fresh
    ``M3H_CrossAttentionClassifier``, run on the same tokenized prompt and image
    features, and released again. The per-model logits are averaged before the
    final sigmoid (multi-label) or softmax (single-label) is applied.

    Args:
        model_paths: Checkpoint files, one state dict per ensemble member.
        tokenizer: Tokenizer used to encode the prompt.
        label_encoder: Fitted encoder; ``classes_`` supplies the class names.
        text: Text of the example to classify.
        triples: Optional knowledge triples; passed through ``clean_triples``.
        image_features: Tensor of shape ``(visual_feature_dim,)`` or
            ``(1, visual_feature_dim)``. A zero tensor is used when omitted.
        device: Target device (``torch.device`` or device string). Defaults to
            CUDA when available, else CPU.
        max_length: Padding/truncation length for tokenization.
        is_multilabel: Use sigmoid + ``threshold`` instead of softmax + argmax.
        threshold: Probability cut-off for multi-label predictions.
        base_model_name: Pretrained model name given to the classifier.
        visual_feature_dim: Expected size of the visual feature vector.
        num_labels: Number of output labels; defaults to ``len(classes_)``.

    Returns:
        A dict with ``'probabilities'`` plus ``'predicted_labels'`` (multi-label)
        or ``'predicted_class'`` (single-label). On failure, a dict holding a
        single ``'error'`` message instead.
    """

    # Device configuration (torch.device() also accepts an existing device or a string)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)

    if not model_paths:
        return {'error': 'No models provided'}

    if image_features is None:
        print("Warning: No image features provided. Using zero tensor.")
        image_features = torch.zeros(1, visual_feature_dim)
    elif image_features.shape == (visual_feature_dim,):
        # Add the batch dimension the model expects
        image_features = image_features.unsqueeze(0)
    elif image_features.shape != (1, visual_feature_dim):
        return {'error': 'Incorrect image feature dimensions'}

    if num_labels is None:
        num_labels = len(label_encoder.classes_)

    # Prepare prompt
    class_names = ", ".join(str(name) for name in label_encoder.classes_)
    prompt = f"Classify... Choose from: {class_names}.\n\n"
    prompt += f"Text: {text}\n"
    cleaned_triples = clean_triples(triples)
    if cleaned_triples:
        prompt += f"Knowledge: {cleaned_triples}\n"
    prompt += "Category:"

    # Tokenize
    try:
        encoding = tokenizer(prompt, max_length=max_length, padding='max_length', truncation=True, return_tensors='pt')
    except Exception as e:
        print(f"Tokenization error: {e}")
        return {'error': f'Tokenization fail: {e}'}

    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    image_features = image_features.float().to(device)

    # Collect predictions from each model
    all_logits = []
    for model_idx, model_path in enumerate(model_paths):
        model = None
        state_dict = None
        outputs = None
        try:
            model = M3H_CrossAttentionClassifier(
                num_labels=num_labels,
                pretrained_model_name=base_model_name,
                is_multilabel=is_multilabel,
                visual_feature_dim=visual_feature_dim
            )

            # Read the checkpoint from disk once and reuse it for the non-strict retry.
            state_dict = torch.load(model_path, map_location=device)
            try:
                model.load_state_dict(state_dict, strict=True)
            except RuntimeError:
                print(f"Warning: checkpoint for model {model_idx + 1} does not match the architecture exactly; loading with strict=False.")
                model.load_state_dict(state_dict, strict=False)

            model.to(device)
            model.eval()

            with torch.no_grad():
                outputs = model(input_ids, attention_mask, image_features, labels=None)
                all_logits.append(outputs.logits.cpu())

        except Exception as e:
            print(f"Error loading model {model_idx + 1}: {e}")
        finally:
            # Release the model even when loading or inference failed, so a bad
            # checkpoint does not keep memory pinned for the remaining members.
            model = None
            state_dict = None
            outputs = None
            gc.collect()
            if device.type == 'cuda':
                torch.cuda.empty_cache()

    if not all_logits:
        return {'error': 'Prediction failed'}

    # Aggregate predictions
    stacked_logits = torch.stack(all_logits, dim=0)
    # Flatten to 1-D. A bare squeeze() would collapse a single-class output to a
    # 0-d tensor, which breaks the len()/indexing below.
    avg_logits = torch.mean(stacked_logits, dim=0).reshape(-1)

    prediction_result = {'probabilities': {}}

    try:
        if is_multilabel:
            probs = torch.sigmoid(avg_logits).numpy()
            predicted_indices = np.where(probs > threshold)[0]
            predicted_labels = label_encoder.inverse_transform(predicted_indices).tolist() if len(predicted_indices) > 0 else ["None"]
            prediction_result['predicted_labels'] = predicted_labels
            prediction_result['probabilities'] = {label_encoder.classes_[i]: float(probs[i]) for i in range(len(probs))}
        else:
            probs = torch.softmax(avg_logits, dim=0)
            pred_class_idx = torch.argmax(avg_logits).item()
            # Placeholder reported when the index cannot be mapped back to a label
            predicted_label = "Error: Index out of bounds"
            try:
                if 0 <= pred_class_idx < len(label_encoder.classes_):
                    predicted_label = label_encoder.inverse_transform([pred_class_idx])[0]
                else:
                    print(f"Prediction index {pred_class_idx} out of bounds.")
            except Exception as le_error:
                print(f"Label encoder error: {le_error}")

            prediction_result['predicted_class'] = predicted_label
            class_probs = probs.cpu().numpy()
            prediction_result['probabilities'] = {label_encoder.classes_[i]: float(class_probs[i]) for i in range(len(class_probs))}

    except Exception as e:
        print(f"Aggregation error: {e}")
        prediction_result = {'error': f'Aggregation fail: {e}'}

    return prediction_result


**Cell 13: Main Pipeline Function (Orchestrates Ensemble Training)**

In [None]:
# CELL 13: Main Pipeline Function (Orchestrates Ensemble Training - WITH RAG FOR TRAINING)
import time
import traceback
from collections import defaultdict

# Helper: logs current CUDA memory usage at DEBUG level. It must be defined before run_ensemble_pipeline, which calls it.
def check_cuda_memory(step_name=""):
    """Log allocated and reserved CUDA memory (in GB) at DEBUG level.

    `step_name` is only used to tag the log line. When CUDA is unavailable a
    short notice is logged instead and nothing is measured.
    """
    if not torch.cuda.is_available():
        logger.debug(f"CUDA not available ({step_name}). Skipping memory check.")
        return

    bytes_per_gb = 1e9
    allocated = torch.cuda.memory_allocated() / bytes_per_gb
    reserved = torch.cuda.memory_reserved() / bytes_per_gb
    logger.debug(f"CUDA Memory ({step_name}): Allocated={allocated:.3f} GB, Reserved={reserved:.3f} GB")
# <<< End helper function >>>


def run_ensemble_pipeline(
    dataset_type: str = 'anxiety',
    num_ensemble_models: int = NUM_ENSEMBLE_MODELS,
    base_seed: int = BASE_SEED,
    use_test_set: bool = True,
    val_split_ratio: float = 0.1,
    test_split_ratio: float = 0.2,
    batch_size_override: int = BATCH_SIZE,
    num_epochs_per_model: int = NUM_EPOCHS,
    embedding_model_name: str = TEXT_EMBEDDING_MODEL, # For RAG text embeddings
    base_text_model_name: str = BASE_TEXT_MODEL_NAME, # For classifier
    visual_feature_dim: int = VISUAL_FEATURE_DIM,
    visual_feature_dir: str = os.path.join(KAGGLE_WORKING_DIR, "visual_features") # Dir where features are saved/loaded
) -> Tuple[Optional[Dict], Optional[BartTokenizer], Optional[LabelEncoder], List[str]]:
    """
    Runs the full ensemble training and evaluation pipeline.

    Steps: resolve dataset paths, load and split the text data, load visual
    features, clean triples, fit the label encoder, build the RAG index from
    the training split, build prompts, create datasets/dataloaders, train
    `num_ensemble_models` models (seeds `base_seed + i`), evaluate the ensemble
    on validation (and test, when available), and pickle the label encoder.

    RAG examples are added to the TRAINING prompts when a retriever could be
    built. Validation and test prompts are generated without RAG examples
    because no query embeddings are computed for those splits.

    The compute device is read from the module-level `device` variable.

    Args:
        dataset_type: 'anxiety' (single-label) or 'depression' (multi-label).
        num_ensemble_models: Number of models to train.
        base_seed: Seed for data splitting and for the first run.
        use_test_set: If True, load the separate test file when it exists;
            if False, carve a test split out of the training file.
        val_split_ratio: Validation ratio passed to `split_data`.
        test_split_ratio: Test ratio passed to `split_data` (only used when
            `use_test_set` is False and no separate validation file was loaded).
        batch_size_override: Batch size for all dataloaders.
        num_epochs_per_model: Training epochs for each ensemble member.
        embedding_model_name: Text embedding model used for the RAG index.
        base_text_model_name: Pretrained model for tokenizer and classifier.
        visual_feature_dim: Dimension of the precomputed visual features.
        visual_feature_dir: Directory holding the `*_features.pt` files.

    Returns:
        `(final_metrics, tokenizer, label_encoder, trained_model_paths)`.
        When the pipeline aborts, `final_metrics` is None and the model path
        list is empty; tokenizer and label encoder are None if the failure
        happened before they were created.
    """
    def format_metric(value) -> str:
        """Format a metric with 4 decimals; non-numeric placeholders such as 'N/A' pass through as text."""
        try:
            return f"{value:.4f}"
        except (TypeError, ValueError):
            return str(value)

    pipeline_start_time = time.time()
    print(">>> DEBUG: ENTERING run_ensemble_pipeline <<<")
    check_cuda_memory("Start of Pipeline")

    # Compare on the device *type*: torch.device('cuda:0') != torch.device('cuda'),
    # so an equality test would silently skip cache cleanup and pinned memory.
    on_cuda = torch.device(device).type == 'cuda'

    is_multilabel = (dataset_type == 'depression')
    logger.info(f"--- Starting ENSEMBLE Pipeline ({num_ensemble_models} models) for Dataset: {dataset_type} --- ")
    logger.info(f"Task Type: {'Multi-Label' if is_multilabel else 'Single-Label (Multiclass)'}")
    logger.info(f"Using Base Text Model: {base_text_model_name}")
    logger.info(f"Using Text Embedding Model: {embedding_model_name}")
    logger.info(f"Using Vision Model for Features: {VISION_MODEL_NAME} (Dim: {visual_feature_dim})")

    # --- 1. Define Paths ---
    output_basedir_name = dataset_type
    pipeline_base_output_dir = os.path.join(KAGGLE_WORKING_DIR, output_basedir_name, "output", "ensemble_fusion")
    os.makedirs(pipeline_base_output_dir, exist_ok=True)
    logger.info(f"Pipeline base output: {pipeline_base_output_dir}")
    if dataset_type == "anxiety":
        anxiety_base_dir = os.path.join(KAGGLE_INPUT_DIR, "dataset", "Anxiety_Data")
        train_file_path = os.path.join(anxiety_base_dir, "anxiety_train.json")
        test_file_path = os.path.join(anxiety_base_dir, "anxiety_test.json")
        img_feature_train_path = os.path.join(visual_feature_dir, "anxiety_train_features.pt")
        img_feature_test_path = os.path.join(visual_feature_dir, "anxiety_test_features.pt")
        # The anxiety dataset ships no separate validation file
        img_feature_val_path = None
        val_file_path = None
    elif dataset_type == "depression":
        depression_base_dir = os.path.join(KAGGLE_INPUT_DIR, "dataset", "Depressive_Data")
        train_file_path = os.path.join(depression_base_dir, "train.json")
        test_file_path = os.path.join(depression_base_dir, "test.json")
        val_file_path = os.path.join(depression_base_dir, "val.json")
        img_feature_train_path = os.path.join(visual_feature_dir, "depression_train_features.pt")
        img_feature_test_path = os.path.join(visual_feature_dir, "depression_test_features.pt")
        img_feature_val_path = os.path.join(visual_feature_dir, "depression_val_features.pt")
    else:
        logger.error(f"Invalid dataset_type specified: {dataset_type}. Choose 'anxiety' or 'depression'.")
        return None, None, None, []
    print(">>> DEBUG: Paths Defined <<<")

    # --- 2. Load Data (Text) & Split ---
    logger.info("Step 1a: Loading text data...")
    full_train_data = load_data(train_file_path)
    if not full_train_data:
        logger.error(f"Failed to load training data from {train_file_path}. Aborting.")
        return None, None, None, []

    train_data, val_data, test_data = [], [], []
    separate_val_file_loaded = False

    # Load the separate validation file if it exists; `or []` guards a None return
    if val_file_path and os.path.exists(val_file_path):
        val_data = load_data(val_file_path) or []
    if val_data:
        separate_val_file_loaded = True
        logger.info(f"Loaded {len(val_data)} val samples.")
    else:
        logger.warning("Separate val file empty or not specified.")

    # Load the separate test file if requested and present
    if use_test_set and test_file_path and os.path.exists(test_file_path):
        test_data = load_data(test_file_path) or []
    if test_data:
        logger.info(f"Loaded {len(test_data)} test samples.")
    elif use_test_set:
        logger.warning(f"Test file not found: {test_file_path}")

    # Determine train/val/test splits
    if separate_val_file_loaded:
        # All original train is training data
        train_data = full_train_data
    elif use_test_set and test_data:
        # Split train -> train/val
        train_data, val_data, _ = split_data(full_train_data, val_split_ratio, 0, base_seed)
    elif not use_test_set:
        # Split train -> train/val/test
        train_data, val_data, test_data = split_data(full_train_data, val_split_ratio, test_split_ratio, base_seed)
    else:
        # Default: split train -> train/val, no test
        train_data, val_data, _ = split_data(full_train_data, val_split_ratio, 0, base_seed)
    logger.info(f"Final Data split sizes: Train={len(train_data)}, Val={len(val_data)}, Test={len(test_data)}")
    if not train_data or not val_data:
        logger.error("Empty train/val splits.")
        return None, None, None, []
    print(">>> DEBUG: Split data complete <<<")

    # --- 3. Load Visual Features ---
    logger.info("Step 1b: Loading visual features...")

    def load_feature_file(path, description):
        """Load a saved feature mapping from `path`; returns {} if missing or unreadable."""
        features = {}
        if path and os.path.exists(path):
            try:
                features = torch.load(path, map_location='cpu')
                logger.info(f"Loaded {len(features)} {description} features from {os.path.basename(path)}.")
            except Exception as e:
                logger.error(f"Error loading {description} features from {path}: {e}")
        elif path:
            logger.warning(f"{description.capitalize()} feature file not found: {path}")
        return features

    image_features_train = load_feature_file(img_feature_train_path, "train")
    image_features_val = load_feature_file(img_feature_val_path, "validation") if img_feature_val_path else {}
    image_features_test = load_feature_file(img_feature_test_path, "test")
    # Later entries win on duplicate keys: train overrides val, which overrides test
    image_features_map = {**image_features_test, **image_features_val, **image_features_train}
    logger.info(f"Combined visual feature map size: {len(image_features_map)}")
    train_ids_with_features = sum(1 for s in train_data if s.get('image_id') in image_features_map)
    val_ids_with_features = sum(1 for s in val_data if s.get('image_id') in image_features_map)
    test_ids_with_features = sum(1 for s in test_data if s.get('image_id') in image_features_map)
    logger.info(f"Feature coverage: Train={train_ids_with_features}/{len(train_data)}, Val={val_ids_with_features}/{len(val_data)}, Test={test_ids_with_features}/{len(test_data)}")
    print(">>> DEBUG: Loaded visual features <<<")

    # --- 4. Clean Triples ---
    logger.info("Step 1c: Cleaning triples text...")
    for dataset in (train_data, val_data, test_data):
        if not dataset:
            continue
        for sample in dataset:
            sample['triples'] = clean_triples(sample.get('triples', ''))
    print(">>> DEBUG: Cleaned triples <<<")

    # --- 5. Encode Labels ---
    logger.info("Step 1d: Encoding labels...")
    all_possible_labels = set()
    for ds in (train_data, val_data, test_data):
        if not ds:
            continue
        for s in ds:
            labels_in_sample = s.get('original_labels')
            if isinstance(labels_in_sample, list):
                all_possible_labels.update(lbl for lbl in labels_in_sample if lbl)
            elif isinstance(labels_in_sample, str) and labels_in_sample:
                all_possible_labels.add(labels_in_sample)
    if not all_possible_labels:
        logger.error("No labels found.")
        return None, None, None, []
    sorted_labels = sorted(all_possible_labels)
    label_encoder = LabelEncoder()
    label_encoder.fit(sorted_labels)
    num_labels = len(label_encoder.classes_)
    logger.info(f"Labels ({num_labels}): {label_encoder.classes_.tolist()}")
    print(">>> DEBUG: Encoded labels <<<")

    # --- 6. RAG Setup ---
    logger.info("Steps 2 & 3: Generating Text Embeddings for RAG DB and Building Index...")
    # Only the training split is embedded (it is both the RAG DB and the query set
    # for training prompts). Validation/test embeddings stay None, which disables
    # RAG for those splits; generate them here to enable it.
    train_fused_embeddings = None
    val_fused_embeddings = None
    test_fused_embeddings = None
    retriever = None
    embedding_generator = None

    try:
        embedding_generator = EmbeddingGenerator(model_name=embedding_model_name, device=device)
        train_ocr = [s.get('ocr_text', '') for s in train_data]
        train_triples = [s.get('triples', '') for s in train_data]
        train_fused_embeddings = embedding_generator.generate_fused_embeddings(train_ocr, train_triples)

        if train_fused_embeddings is not None:
            retriever = RAGRetriever(train_fused_embeddings, top_k=RETRIEVAL_K)
            if retriever.index is None:
                logger.warning("FAISS index building failed. RAG disabled.")
                retriever = None
        else:
            logger.warning("No train text embeddings generated. RAG disabled.")
            retriever = None
    except Exception as e:
        logger.error(f"RAG embedding/index error: {e}", exc_info=True)
        retriever = None # Ensure retriever is None on error
    finally:
        # Drop the embedding model before the classifiers are trained
        embedding_generator = None
        gc.collect()
        if on_cuda:
            torch.cuda.empty_cache()
    print(">>> DEBUG: RAG setup complete <<<")
    check_cuda_memory("After RAG setup")

    # --- 7. Prompt Construction ---
    logger.info("Step 4: Preparing prompts (RAG examples included if retriever exists)...")
    prompt_constructor = PromptConstructor(train_data, label_encoder)

    def get_prompts_with_rag(data_split: List[Dict],
                             embeddings_for_query: Optional[np.ndarray], # Embeddings of the data_split itself
                             split_name: str,
                             is_training_split: bool = False) -> List[str]:
        """Generates prompts for `data_split`, adding RAG examples from the enclosing `retriever` when possible.

        RAG is used only if the retriever exists AND `embeddings_for_query` is given;
        otherwise every prompt is built with an empty list of retrieved indices.
        """
        logger.info(f"Generating prompts for {split_name} ({len(data_split)} samples). RAG enabled: {retriever is not None}")
        prompts = []
        # Check if RAG is possible *at all* (retriever must exist)
        rag_possible_globally = retriever is not None

        # Check if embeddings are provided *for this split* to perform the query
        can_query_rag = rag_possible_globally and (embeddings_for_query is not None)

        if not can_query_rag and rag_possible_globally:
            logger.warning(f"RAG retriever exists, but no query embeddings provided for {split_name}. Generating basic prompts.")
        elif not rag_possible_globally:
            logger.info(f"RAG retriever not available. Generating basic prompts for {split_name}.")

        for i, sample in enumerate(tqdm(data_split, desc=f"Generating {split_name} Prompts")):
            retrieved_indices_for_prompt = [] # Default: no examples
            if can_query_rag:
                try:
                    query_embedding = embeddings_for_query[i : i + 1]
                    raw_indices = retriever.retrieve_similar(query_embedding)

                    if raw_indices is not None and len(raw_indices[0]) > 0:
                        potential_indices = raw_indices[0]
                        if is_training_split:
                            # The query sample is itself in the index: drop its own
                            # index (i) and keep the top K remaining neighbours
                            retrieved_indices_for_prompt = [idx for idx in potential_indices if idx != i][:RETRIEVAL_K]
                        else:
                            # For val/test, just take the top K
                            retrieved_indices_for_prompt = potential_indices[:RETRIEVAL_K]
                    else:
                        logger.debug(f"Retrieval returned None or empty for sample {i} in {split_name}.")
                except Exception as e:
                    # Fall back to no examples for this sample on error
                    logger.error(f"Error during RAG retrieval for sample {i} in {split_name}: {e}", exc_info=True)

            # Construct the prompt with the retrieved indices (or empty list)
            prompts.append(prompt_constructor.construct_prompt(sample, retrieved_indices_for_prompt))

        return prompts

    # train_fused_embeddings serve both as the index and as the training queries.
    # The val/test embeddings are None at this point, so those prompts carry no RAG examples.
    train_prompts = get_prompts_with_rag(train_data, train_fused_embeddings, "Train", is_training_split=True)
    val_prompts = get_prompts_with_rag(val_data, val_fused_embeddings, "Validation", is_training_split=False)
    test_prompts = get_prompts_with_rag(test_data, test_fused_embeddings, "Test", is_training_split=False) if test_data else []
    print(">>> DEBUG: Prompts prepared <<<")

    # --- 8. Tokenizer, Datasets, DataLoaders ---
    logger.info("Step 5: Loading tokenizer...")
    try:
        tokenizer = BartTokenizer.from_pretrained(base_text_model_name)
    except Exception as e:
        logger.error(f"Tokenizer load fail: {e}")
        return None, None, None, []
    print(">>> DEBUG: Tokenizer loaded <<<")

    logger.info("Step 6: Creating Datasets...")
    try:
        DatasetClass = DepressionDataset if is_multilabel else AnxietyDataset
        train_dataset = DatasetClass(train_data, train_prompts, tokenizer, MAX_LEN, label_encoder, image_features_map, visual_feature_dim)
        val_dataset = DatasetClass(val_data, val_prompts, tokenizer, MAX_LEN, label_encoder, image_features_map, visual_feature_dim)
        test_dataset = DatasetClass(test_data, test_prompts, tokenizer, MAX_LEN, label_encoder, image_features_map, visual_feature_dim) if test_data else None
        print(f">>> DEBUG: Datasets created (Train: {len(train_dataset)}, Val: {len(val_dataset)}) <<<")
    except Exception as e:
        logger.error(f"Dataset fail: {e}")
        return None, tokenizer, label_encoder, []

    logger.info("Step 7: Creating DataLoaders...")
    try:
        num_workers = 2 if torch.cuda.is_available() else 0
        pin_memory = on_cuda
        # Drop the last training batch only when it would contain a single sample
        drop_single_sample_batch = (len(train_dataset) % batch_size_override == 1)
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size_override, shuffle=True, num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_single_sample_batch)
        val_dataloader = DataLoader(val_dataset, batch_size=batch_size_override, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size_override, shuffle=False, num_workers=num_workers, pin_memory=pin_memory) if test_dataset else None
        print(f">>> DEBUG: DataLoaders created (Train batches: {len(train_dataloader)}, Val batches: {len(val_dataloader)}) <<<")
        if len(train_dataloader) == 0 or len(val_dataloader) == 0:
            logger.error("Train/Val Dataloader empty!")
            return None, tokenizer, label_encoder, []
    except Exception as e:
        logger.error(f"Dataloader fail: {e}")
        return None, tokenizer, label_encoder, []

    print(">>> DEBUG: Reached Point 3 - BEFORE ENSEMBLE LOOP <<<")
    check_cuda_memory("Before Ensemble Loop")

    # --- 9. Ensemble Training Loop ---
    logger.info(f"Step 8: Starting Ensemble Training ({num_ensemble_models} models)...")
    trained_model_paths = []
    all_histories = []

    print(">>> DEBUG: Reached Point 4 - ENTERING ENSEMBLE LOOP <<<")
    for i in range(num_ensemble_models):
        print(f">>> DEBUG: Starting ensemble run {i+1}/{num_ensemble_models} <<<")
        # Each ensemble member gets its own seed and output directory
        run_seed = base_seed + i
        set_seed(run_seed)
        model_run_output_dir = os.path.join(pipeline_base_output_dir, f"run_{run_seed}")
        os.makedirs(model_run_output_dir, exist_ok=True)
        logger.info(f"--- Training Model {i + 1}/{num_ensemble_models} (Seed: {run_seed}) --- Output: {model_run_output_dir}")

        model = None
        try:
            print(f">>> DEBUG: Instantiating model run {i+1} <<<")
            model = M3H_CrossAttentionClassifier(
                num_labels=num_labels,
                pretrained_model_name=base_text_model_name,
                is_multilabel=is_multilabel,
                visual_feature_dim=visual_feature_dim,
                dropout_prob=DROPOUT
            ).to(device)
            print(f">>> DEBUG: Model instantiated run {i+1} <<<")
            check_cuda_memory(f"After model init run {i+1}")
        except Exception as e:
            logger.error(f"Model init fail run {i + 1}: {e}", exc_info=True)
            gc.collect()
            if on_cuda:
                torch.cuda.empty_cache()
            continue

        optimizer = None
        scheduler = None
        try:
            print(f">>> DEBUG: Creating opt/sched run {i+1} <<<")
            optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, betas=(ADAM_BETA1, ADAM_BETA2), eps=ADAM_EPSILON, weight_decay=WEIGHT_DECAY)
            total_training_steps = len(train_dataloader) * num_epochs_per_model
            # Linear warmup over the first 10% of the training steps
            num_warmup_steps = int(0.1 * total_training_steps)
            scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_training_steps)
            print(f">>> DEBUG: Opt/sched created run {i+1} <<<")
        except Exception as e:
            logger.error(f"Opt/Sched setup fail run {i + 1}: {e}")
            model = None
            optimizer = None
            gc.collect()
            if on_cuda:
                torch.cuda.empty_cache()
            continue

        history = None
        best_model_path_run = None
        try:
            print(f">>> DEBUG: Calling train_and_evaluate run {i+1} <<<")
            history, best_model_path_run = train_and_evaluate(
                train_dataloader, val_dataloader, model, optimizer, scheduler, device,
                num_epochs_per_model, model_run_output_dir, label_encoder, is_multilabel
            )
            all_histories.append(history)
            trained_model_paths.append(best_model_path_run)
            logger.info(f"Run {i+1} training complete. Best model: {best_model_path_run}")
            plot_training_history(history, model_run_output_dir, suffix=f"_run_{run_seed}")
        except Exception as train_e:
            logger.error(f"Training failed run {i + 1}: {train_e}", exc_info=True)
            # Persist the traceback next to the run's other artefacts
            tb_path = os.path.join(model_run_output_dir, "error_traceback.txt")
            with open(tb_path, "w") as f:
                traceback.print_exc(file=f)
        finally:
            print(f">>> DEBUG: Cleaning up run {i+1} <<<")
            # Drop every reference held by this run before the next model is built
            model = None
            optimizer = None
            scheduler = None
            history = None
            best_model_path_run = None
            gc.collect()
            if on_cuda:
                torch.cuda.empty_cache()
            check_cuda_memory(f"End of run {i+1}")

    print(">>> DEBUG: Reached Point 12 - EXITED ensemble training loop <<<")
    logger.info(f"--- Ensemble Training Finished ({len(trained_model_paths)} models trained) ---")

    if not trained_model_paths:
        logger.error("No models trained.")
        return None, tokenizer, label_encoder, []

    # --- 10. Final Ensemble Evaluation ---
    logger.info("Step 9: Evaluating Ensemble...")
    final_metrics = {}
    logger.info("--- Final Val Eval (Ensemble) ---")
    val_metrics = evaluate_ensemble(
        model_paths=trained_model_paths,
        dataloader=val_dataloader,
        device=device,
        label_encoder=label_encoder,
        is_multilabel=is_multilabel,
        num_labels=num_labels,
        output_dir=pipeline_base_output_dir,
        report_suffix="validation_ensemble",
        base_model_name=base_text_model_name,
        visual_feature_dim=visual_feature_dim
    )
    final_metrics['validation'] = val_metrics
    # Metrics may be missing (e.g. evaluation returned only an error entry), so the
    # 'N/A' fallback must not be pushed through a float format spec.
    val_acc = val_metrics.get('accuracy_subset', val_metrics.get('accuracy', 'N/A'))
    logger.info(f"Ens Val Acc: {format_metric(val_acc)}, MacroF1: {format_metric(val_metrics.get('macro_f1', 'N/A'))}")

    if test_dataloader:
        logger.info("--- Final Test Eval (Ensemble) ---")
        test_metrics = evaluate_ensemble(
            model_paths=trained_model_paths,
            dataloader=test_dataloader,
            device=device,
            label_encoder=label_encoder,
            is_multilabel=is_multilabel,
            num_labels=num_labels,
            output_dir=pipeline_base_output_dir,
            report_suffix="test_ensemble",
            base_model_name=base_text_model_name,
            visual_feature_dim=visual_feature_dim
        )
        final_metrics['test'] = test_metrics
        test_acc = test_metrics.get('accuracy_subset', test_metrics.get('accuracy', 'N/A'))
        logger.info(f"Ens Test Acc: {format_metric(test_acc)}, MacroF1: {format_metric(test_metrics.get('macro_f1', 'N/A'))}")
    else:
        logger.info("No test data.")
        final_metrics['test'] = "Skipped"

    # --- 11. Save Label Encoder ---
    logger.info("Step 10: Saving Label Encoder...")
    label_encoder_path = os.path.join(pipeline_base_output_dir, "label_encoder.pkl")
    try:
        with open(label_encoder_path, 'wb') as f:
            pickle.dump(label_encoder, f)
        logger.info(f"LE saved: {label_encoder_path}")
    except Exception as e:
        # Saving is best-effort: a failure here must not discard the trained models
        logger.error(f"LE save fail: {e}")

    pipeline_end_time = time.time()
    logger.info(f"--- ENSEMBLE Pipeline finished for {dataset_type} in {(pipeline_end_time - pipeline_start_time)/60:.2f} minutes ---")
    return final_metrics, tokenizer, label_encoder, trained_model_paths

**Cell 14: Execution Cell (Runs Ensemble Pipeline)**

In [None]:
# --- Configuration for the Run ---
# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
#       <<<<< CHOOSE DATASET TYPE HERE >>>>>
DATASET_CHOICE = 'anxiety'   # Options: 'anxiety' or 'depression'
# <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<

# --- Other Ensemble & Training Settings ---
NUM_ENSEMBLE_RUNS = 3       # How many models in the ensemble (e.g., 3 or 5)
EPOCHS_PER_MODEL = 10       # Epochs for EACH model in the ensemble
BATCH_SIZE_TO_USE = 8       # Adjust based on GPU memory (T4 usually handles 8 well)

# --- Data Splitting Strategy ---
# If True, expects 'anxiety_test.json' or 'test.json'/'val.json' in dataset dirs
# If False, splits the main training file into train/val/test sets
USE_SEPARATE_TEST_FILE = True
# Ratios used ONLY if USE_SEPARATE_TEST_FILE = False (or if separate files are missing)
# These ratios are relative to the *original* training data size.
VALIDATION_SPLIT = 0.1      # e.g., 10% of original train data for validation
TEST_SPLIT = 0.2            # e.g., 20% of original train data for test

# Override hyperparameters from Cell 3 if needed for this specific run
# LEARNING_RATE_OVERRIDE = 5e-5
# DROPOUT_OVERRIDE = 0.15


# --- Run the Ensemble Pipeline ---
logger.info(f"===== Starting Ensemble Pipeline Execution for: {DATASET_CHOICE} =====")
logger.info(f"Number of ensemble models: {NUM_ENSEMBLE_RUNS}")
logger.info(f"Epochs per model: {EPOCHS_PER_MODEL}")
logger.info(f"Batch size: {BATCH_SIZE_TO_USE}")
logger.info(f"Using separate test file: {USE_SEPARATE_TEST_FILE}")

# Clean up memory before starting
gc.collect()
# Check the device *type*: torch.device('cuda:0') does not compare equal to
# torch.device('cuda'), so an equality test would skip this cleanup on an indexed GPU.
if torch.device(device).type == 'cuda':
    torch.cuda.empty_cache()
    logger.info(f"CUDA Memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    logger.info(f"CUDA Memory reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")


# Initialize variables to store results (kept at these defaults if the pipeline raises)
final_eval_metrics = None
trained_tokenizer = None
trained_label_encoder = None
trained_model_paths = []

try:
    final_eval_metrics, trained_tokenizer, trained_label_encoder, trained_model_paths = run_ensemble_pipeline(
        dataset_type=DATASET_CHOICE,
        num_ensemble_models=NUM_ENSEMBLE_RUNS,
        base_seed=BASE_SEED, # Use the global base seed
        use_test_set=USE_SEPARATE_TEST_FILE,
        val_split_ratio=VALIDATION_SPLIT,
        test_split_ratio=TEST_SPLIT,
        batch_size_override=BATCH_SIZE_TO_USE,
        num_epochs_per_model=EPOCHS_PER_MODEL,
        embedding_model_name=TEXT_EMBEDDING_MODEL, # From config
        base_text_model_name=BASE_TEXT_MODEL_NAME, # From config
        visual_feature_dim=VISUAL_FEATURE_DIM,     # From config
        visual_feature_dir = os.path.join(KAGGLE_WORKING_DIR, "visual_features") # Pass dir where features are stored
        # Pass overrides if defined:
        # learning_rate_override = LEARNING_RATE_OVERRIDE,
        # dropout_override = DROPOUT_OVERRIDE,
    )
    logger.info("<<<<< run_ensemble_pipeline function returned >>>>>")

    # --- Post-Training Analysis ---
    if trained_model_paths and trained_tokenizer and trained_label_encoder:
        logger.info(f"Pipeline completed successfully for '{DATASET_CHOICE}'.")
        logger.info(f"Trained {len(trained_model_paths)} models.")
        logger.info(f"Best model paths: {trained_model_paths}")
        logger.info("Final Evaluation Metrics:")
        print(json.dumps(final_eval_metrics, indent=2)) # Pretty print the metrics dict

        # --- Example Prediction Call (Requires loading features for the sample) ---
        # You would need to get the 'image_features' tensor for a specific sample ID
        # example_sample_id = "some_image_id_from_your_data"
        # try:
        #     # Load the combined feature map again (or pass it)
        #     combined_features_path = os.path.join(KAGGLE_WORKING_DIR, "visual_features", f"{DATASET_CHOICE}_combined_features.pt") # Assuming you saved a combined one
        #     if os.path.exists(combined_features_path):
        #          all_features = torch.load(combined_features_path, map_location='cpu')
        #          example_image_feature = all_features.get(example_sample_id)
        #          if example_image_feature is not None:
        #              logger.info(f"\n--- Running Example Prediction for ID: {example_sample_id} ---")
        #              prediction = predict_ensemble(
        #                  model_paths=trained_model_paths,
        #                  tokenizer=trained_tokenizer,
        #                  label_encoder=trained_label_encoder,
        #                  text="This is some example meme text.",
        #                  triples="Cause-Effect: text causes laughter", # Optional example triples
        #                  image_features=example_image_feature, # Provide the loaded features
        #                  device=device,
        #                  is_multilabel=(DATASET_CHOICE == 'depression'),
        #                  base_model_name=BASE_TEXT_MODEL_NAME,
        #                  visual_feature_dim=VISUAL_FEATURE_DIM,
        #                  num_labels=len(trained_label_encoder.classes_)
        #              )
        #              logger.info("Example Prediction Result:")
        #              print(json.dumps(prediction, indent=2))
        #          else:
        #              logger.warning(f"Could not find image features for example ID: {example_sample_id}")
        #     else:
        #         logger.warning("Could not load combined features map for example prediction.")

        # except Exception as pred_e:
        #     logger.error(f"Error running example prediction: {pred_e}", exc_info=True)


    else:
        logger.error(f"Pipeline for '{DATASET_CHOICE}' did not complete successfully. Check logs for errors.")
        if not trained_model_paths: logger.error("  - No models were successfully trained.")
        if not trained_tokenizer: logger.error("  - Tokenizer was not loaded/returned.")
        if not trained_label_encoder: logger.error("  - Label encoder was not fitted/returned.")


except Exception as pipeline_e:
    logger.error(f"!!!!!! CATASTROPHIC ERROR in pipeline execution cell !!!!!!")
    logger.error(f"Error Type: {type(pipeline_e).__name__}")
    logger.error(f"Error Message: {pipeline_e}")
    logger.error("Traceback:", exc_info=True)
    # Save traceback to file
    tb_main_path = os.path.join(KAGGLE_WORKING_DIR, f"{DATASET_CHOICE}_pipeline_crash_traceback.txt")
    with open(tb_main_path, "w") as f:
        traceback.print_exc(file=f)
    logger.info(f"Crash traceback saved to: {tb_main_path}")


# --- Final Cleanup ---
logger.info("Performing final cleanup...")
# Explicitly delete large objects if they exist.
# Each name is checked first because an earlier failure may have left it unbound.
if 'final_eval_metrics' in locals(): del final_eval_metrics
if 'trained_tokenizer' in locals(): del trained_tokenizer
if 'trained_label_encoder' in locals(): del trained_label_encoder
if 'trained_model_paths' in locals(): del trained_model_paths
# Delete potentially large data splits if no longer needed
if 'train_data' in locals(): del train_data
if 'val_data' in locals(): del val_data
if 'test_data' in locals(): del test_data
if 'full_train_data' in locals(): del full_train_data
if 'image_features_map' in locals(): del image_features_map

gc.collect() # Run garbage collection
# Check the device *type* rather than equality with torch.device('cuda'):
# an indexed device such as torch.device('cuda:0') does not compare equal to
# torch.device('cuda'), which would skip the cache clear on a GPU run.
if torch.device(device).type == 'cuda':
    torch.cuda.empty_cache() # Clear PyTorch CUDA cache
    logger.info("CUDA cache cleared.")
    logger.info(f"Final CUDA Memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    logger.info(f"Final CUDA Memory reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")

logger.info(f"===== Execution cell finished for: {DATASET_CHOICE} =====")

**Cell 15: Zip Output**

In [None]:
import os # Ensure os is imported if running this cell independently

# --- Determine paths based on the DATASET_CHOICE from the previous execution cell ---
# This relies on DATASET_CHOICE being set correctly in Cell 14.
# Result: `output_subdir` is always bound to 'anxiety' or 'depression' after this block.
try:
    # Check if DATASET_CHOICE exists and is one of the expected values.
    # The membership test short-circuits, so DATASET_CHOICE is only read once
    # the name is known to exist.
    if 'DATASET_CHOICE' in locals() and DATASET_CHOICE in ['anxiety', 'depression']:
        output_subdir = DATASET_CHOICE
        logger.info(f"Determined output subdirectory for zipping: '{output_subdir}'")
    else:
        # Attempt to infer from directory structure if variable is missing.
        # 'anxiety' is checked first, so it wins when both output trees exist.
        anxiety_path = "/kaggle/working/anxiety/output/ensemble_fusion/"
        depression_path = "/kaggle/working/depression/output/ensemble_fusion/"
        if os.path.exists(anxiety_path):
            output_subdir = "anxiety"
            logger.warning(f"DATASET_CHOICE variable not found or invalid. Inferred '{output_subdir}' from directory structure.")
        elif os.path.exists(depression_path):
            output_subdir = "depression"
            logger.warning(f"DATASET_CHOICE variable not found or invalid. Inferred '{output_subdir}' from directory structure.")
        else:
            output_subdir = "anxiety" # Default if inference fails
            logger.error(f"DATASET_CHOICE variable not found and cannot infer directory. Defaulting to '{output_subdir}'. Check Cell 14 execution and output paths.")

except NameError:
    output_subdir = "anxiety" # Fallback if DATASET_CHOICE is not defined at all
    logger.error(f"NameError: DATASET_CHOICE variable not defined. Defaulting to '{output_subdir}'. Please run Cell 14 first.")


import shutil  # stdlib; imported in-cell (like `os` above) so this cell can run on its own

# Define the source directory to zip (the base ensemble output dir)
# This directory contains the run_* subfolders, reports, and label encoder
source_path = f"/kaggle/working/{output_subdir}/output/ensemble_fusion/"
zip_filename = f"/kaggle/working/{output_subdir}_ensemble_fusion_output.zip" # Changed filename for clarity

print(f"\nAttempting to zip contents of: {source_path}")
print(f"Creating archive named: {zip_filename}")

# --- Create the zip archive ---
# The archive is built with shutil.make_archive instead of the `zip` CLI:
#   * it overwrites an existing archive, whereas `zip -r` updates it in place and
#     would keep entries from earlier runs;
#   * it works for an empty source directory (the CLI exits with "nothing to do"
#     on an unmatched glob, so no archive was produced in that case);
#   * no paths are interpolated into a shell command line.
if os.path.isdir(source_path):
    source_is_empty = not os.listdir(source_path)
    if source_is_empty:
        print(f"\nWARNING: Source directory '{source_path}' is empty. Zip file will be created but empty.")

    # Archive entries are stored relative to the parent directory, so the zip
    # contains a single top-level '<target_dir_name>/' folder.
    parent_dir = os.path.dirname(source_path.rstrip('/'))
    target_dir_name = os.path.basename(source_path.rstrip('/'))
    print(f"Archiving '{target_dir_name}' (relative to '{parent_dir}') into {zip_filename}")

    try:
        # make_archive appends the '.zip' extension itself, so pass the name without it.
        shutil.make_archive(
            os.path.splitext(zip_filename)[0],
            'zip',
            root_dir=parent_dir,
            base_dir=target_dir_name,
        )
    except OSError as zip_e:
        print(f"\nERROR: Failed to create zip archive at {zip_filename}: {zip_e}")
    else:
        # Check if zip was created
        if not os.path.exists(zip_filename):
            print(f"\nERROR: Zip command executed but archive not found at {zip_filename}. Check permissions or command output.")
        elif source_is_empty:
            print(f"Empty zip archive created for empty source directory.")
        else:
            print(f"\nZip process finished successfully for '{output_subdir}'.")
            print(f"Archive created: {zip_filename}")

else:
    print(f"\nERROR: Source directory not found or is not a directory: {source_path}")
    print("Cannot create zip archive. Please ensure the pipeline ran correctly and produced output.")

# --- Verify by listing the working directory ---
print("\nContents of /kaggle/working/ (showing zip files and output directories):")
# Use ls with options: -l (long format), -h (human readable sizes), -t (sort by time, newest first)
# Filter for zip files and the output directory for clarity
!ls -lht /kaggle/working/ | grep -E '.zip$|anxiety$|depression$'
print("\nFull contents of /kaggle/working/:")
!ls -lht /kaggle/working/