In [None]:
import os
import glob
import pickle
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Conv1D, MaxPooling1D, LSTM, Dense
from tensorflow.keras.utils import to_categorical
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from scipy.stats import mode
from google.colab import drive
from scipy.signal import resample  # <-- ADDED IMPORT FOR RESAMPLING

# ---
# STEP 1: MOUNT GOOGLE DRIVE
# ---
# Attach Google Drive to the runtime at /content/drive. A failed mount is
# only reported; the rest of the cell still runs.
print("Mounting Google Drive...")
try:
    drive.mount('/content/drive', force_remount=True)
except Exception as mount_error:
    print(f"Error mounting drive: {mount_error}")
else:
    print("Drive mounted successfully.")

# ---
# STEP 2: SET THE PATH TO YOUR DATASET FOLDER
# ---
# This is the one setting you have to edit.
# In the Colab file browser (left sidebar), locate your WESAD folder,
# right-click it, choose "Copy path" and paste the result below.

# -----------------------------------------------------------------------
# ----> !! IMPORTANT !! <----
# ----> EDIT THE LINE BELOW so it holds the correct folder path.
#
# Example path: '/content/drive/MyDrive/WESAD'
#
dataset_drive_path = '/content/drive/MyDrive/WESAD'  # <-- Change this path!
# -----------------------------------------------------------------------

# Validate the setting: it must be an existing filesystem path, not a URL.
path_is_url = 'https:' in dataset_drive_path
if os.path.exists(dataset_drive_path) and not path_is_url:
    print(f"Dataset path found: {dataset_drive_path}")
else:
    print("--- ERROR ---")
    if path_is_url:
        print("The path looks like a URL. It must be a file path.")
        print("Please use the file browser on the left, right-click your folder, and 'Copy path'.")
    else:
        print(f"The path '{dataset_drive_path}' does not exist.")
    print("Please check the path and try again.")

# ---
# STEP 3: FIND ALL INDIVIDUAL SUBJECT FILES
# ---
# Search recursively below 'dataset_drive_path' for pickle files whose
# name starts with 'S' (e.g. S2.pkl).
search_pattern = os.path.join(dataset_drive_path, '**', 'S*.pkl')

# glob yields matches in arbitrary (filesystem-dependent) order. Sorting makes
# the subject order - and therefore the concatenated dataset and the seeded
# train/test split derived from it - identical from run to run.
subject_pkl_files = sorted(glob.glob(search_pattern, recursive=True))

# Keep only the main per-subject file, i.e. a pickle whose base name equals
# the name of the folder that contains it (S2/S2.pkl). Any other S*.pkl that
# the recursive search picked up is ignored.
final_subject_files = [
    pkl_path
    for pkl_path in subject_pkl_files
    if os.path.splitext(os.path.basename(pkl_path))[0]
    == os.path.basename(os.path.dirname(pkl_path))
]

if not final_subject_files:
    print(f"\n--- ERROR ---")
    print(f"Could not find any subject .pkl files (e.g., S2/S2.pkl) in the folder:")
    print(f"{dataset_drive_path}")
    print(f"Please check your 'dataset_drive_path' variable. Make sure it points to the folder containing S2, S3, etc.")
else:
    print(f"\nFound {len(final_subject_files)} subject data files.")
    print("Example path:", final_subject_files[0])

# ---
# STEP 4: DEFINE PRE-PROCESSING FUNCTIONS (REWRITTEN)
# ---
# This function is now corrected to resample all signals to a
# common frequency (4Hz) before combining them.

def process_subject_data(subject_data):
    """Extract one subject's wrist signals and align them to the EDA timeline.

    The EDA channel defines the target length N. ACC, BVP and TEMP are
    Fourier-resampled to N samples. The label track is reduced to N samples
    by picking the nearest earlier original sample, so the output can only
    contain label values that actually occur in the recording.

    Args:
        subject_data: Mapping with a 'signal' entry (holding a 'wrist'
            mapping with 'ACC', 'BVP', 'EDA' and 'TEMP' arrays) and a
            'label' array.

    Returns:
        A ``(features, labels)`` tuple. ``features`` stacks the columns
        [ACC..., BVP, EDA, TEMP] with N rows (6 columns when ACC has three
        axes); ``labels`` is an ``(N,)`` integer array. ``(None, None)`` is
        returned when wrist data is missing, a required signal is empty, or
        processing fails.
    """
    # Check if 'wrist' data is present
    if 'wrist' not in subject_data['signal']:
        return None, None

    try:
        wrist = subject_data['signal']['wrist']
        wrist_eda = np.asarray(wrist['EDA']).flatten()
        wrist_temp = np.asarray(wrist['TEMP']).flatten()

        # EDA is the reference channel: every other signal is brought to its length.
        target_len = len(wrist_eda)
        if target_len == 0:
            return None, None  # Skip if no EDA data

        wrist_acc = np.asarray(wrist['ACC'])
        wrist_bvp = np.asarray(wrist['BVP']).flatten()
        labels_raw = np.asarray(subject_data['label']).flatten()
        if labels_raw.size == 0:
            return None, None  # Nothing to label the samples with

        # Continuous signals: Fourier resampling along the time axis (axis 0).
        acc_resampled = resample(wrist_acc, target_len)
        bvp_resampled = resample(wrist_bvp, target_len)
        # TEMP is resampled as well in case its length differs slightly from EDA.
        temp_resampled = resample(wrist_temp, target_len)

        # Labels are categorical, so they must NOT be interpolated: Fourier
        # resampling followed by rounding invents classes at every transition
        # (e.g. a 0 -> 4 step passes through 1, 2, 3) and at the periodic
        # wrap-around. Select existing samples at proportional positions instead.
        label_idx = (
            np.arange(target_len, dtype=np.int64) * labels_raw.size
        ) // target_len
        labels = labels_raw[label_idx].astype(int)

        # Combine all features into a single array
        features = np.hstack([
            acc_resampled,                         # (N, 3)
            bvp_resampled.reshape(-1, 1),          # (N, 1)
            wrist_eda.reshape(-1, 1),              # (N, 1)
            temp_resampled.reshape(-1, 1)          # (N, 1)
        ])

        return features, labels

    except Exception as e:
        # Best effort: a malformed subject is reported and skipped by the caller.
        print(f"    - Error processing signals: {e}")
        return None, None

def create_windows(data, labels, window_size_sec, overlap_sec, fs=4):
    """Slice a recording into overlapping windows and label each one.

    Each window receives the most frequent label among its samples (the
    smallest value wins a tie). Only windows labelled 1 (Baseline),
    2 (Stress) or 3 (Amusement) are kept; 0 (transient) and all other
    labels are discarded.

    Args:
        data: Array-like of shape (num_samples, num_features).
        labels: Array-like with one label per sample of ``data``.
        window_size_sec: Window length in seconds.
        overlap_sec: Overlap between consecutive windows in seconds.
        fs: Sampling rate of ``data`` in Hz.

    Returns:
        ``(X, y)`` where ``X`` has shape (num_windows, window_size,
        num_features) and ``y`` has shape (num_windows,). Both are empty
        arrays when no window qualifies.

    Raises:
        ValueError: If the window length is not positive or the overlap is
            not smaller than the window.
    """
    # Convert to whole sample counts so float durations/rates also work.
    window_size = int(round(window_size_sec * fs))  # 30 s * 4 Hz = 120 samples
    overlap = int(round(overlap_sec * fs))          # 15 s * 4 Hz = 60 samples
    stride = window_size - overlap
    if window_size <= 0 or stride <= 0:
        raise ValueError(
            "window_size_sec must be positive and overlap_sec must be "
            "smaller than window_size_sec"
        )

    wanted_labels = (1, 2, 3)
    # Only positions that have both a sample and a label can be windowed.
    usable_len = min(len(data), len(labels))

    X, y = [], []

    # '+ 1' keeps the final window that ends exactly on the last sample;
    # without it a recording of exactly one window length yields nothing.
    for start in range(0, usable_len - window_size + 1, stride):
        stop = start + window_size
        window_labels = np.asarray(labels[start:stop]).ravel()

        # Majority vote; np.unique sorts its values, so argmax picks the
        # smallest label among equally frequent ones.
        values, counts = np.unique(window_labels, return_counts=True)
        label = values[np.argmax(counts)]

        if label in wanted_labels:
            X.append(data[start:stop])
            y.append(label)

    return np.array(X), np.array(y)

# ---
# STEP 5: CREATE THE FULL, WINDOWED DATASET
# ---

# Accumulators: one array of windows and one array of labels per subject
all_X = []
all_y = []

# Windowing configuration shared by every subject
WINDOW_SEC = 30
OVERLAP_SEC = 15
SAMPLING_RATE = 4 # 4 Hz

print("\nProcessing data and creating windows for all subjects...")

for pkl_file_path in final_subject_files:
    subject_id = os.path.basename(pkl_file_path).replace('.pkl', '')
    print(f"  - Processing {subject_id}...")

    try:
        # Read this subject's pickle from Google Drive.
        # NOTE: unpickling can run arbitrary code - only load files you trust.
        with open(pkl_file_path, 'rb') as f:
            subject_data = pickle.load(f, encoding='latin1')

        features, labels = process_subject_data(subject_data)
        if features is None or labels is None:
            print(f"    - SKIPPING {subject_id} (Missing wrist data or other error)")
            continue

        # Cut the subject's recording into labelled windows
        X_subject, y_subject = create_windows(
            features,
            labels,
            window_size_sec=WINDOW_SEC,
            overlap_sec=OVERLAP_SEC,
            fs=SAMPLING_RATE,
        )

        n_windows = X_subject.shape[0]
        if not n_windows > 0:
            print(f"    - No valid windows (labels 1, 2, 3) found for {subject_id}")
            continue

        all_X.append(X_subject)
        all_y.append(y_subject)
        print(f"    - Added {n_windows} windows for {subject_id}")

    except Exception as e:
        # One broken subject must not abort the whole run
        print(f"    - FAILED to process {subject_id}. Error: {e}")


# --- ERROR CHECK ---
# Stop with a clear message if nothing was collected, instead of letting
# the concatenation below fail on empty lists.
if len(all_X) == 0 or len(all_y) == 0:
    print("\n\n--- CRITICAL ERROR ---")
    print("No data was successfully processed.")
    print("This could mean your 'dataset_drive_path' (STEP 2) is wrong,")
    print("or the subjects found do not contain valid wrist data or labels (1, 2, 3).")
    print("Please check your path and the data.")
    raise ValueError("No valid subject data found to create windows.")
# --- END OF CHECK ---


# Stack the per-subject windows into one dataset (along the window axis)
X = np.concatenate(all_X)

# Shift labels from [1, 2, 3] to [0, 1, 2], the zero-based class ids used for training:
# 1 (Baseline)  -> 0
# 2 (Stress)    -> 1
# 3 (Amusement) -> 2
y = np.concatenate(all_y) - 1

print(f"\nTotal windows created: {X.shape[0]}")
print(f"Window shape (samples, features): {X.shape[1:]}")
print(f"Labels shape: {y.shape}")
print(f"Unique labels: {np.unique(y)}")

# ---
# STEP 6: SCALE DATA & SPLIT FOR TRAINING
# ---

print("\nScaling data...")
# X shape is (num_windows, num_samples, num_features).
# Scaling is per feature, so the windows are flattened to
# (num_windows * num_samples, num_features) for the scaler.
num_windows, num_samples, num_features = X.shape
X_reshaped = X.reshape(-1, num_features)

# Decide the train/test membership of every window BEFORE fitting the scaler.
# Fitting on all windows would leak the test set's mean/variance into training.
# Splitting indices (instead of the data) uses the same test_size, seed and
# stratification as a direct split of the data would.
train_idx, test_idx, y_train, y_test = train_test_split(
    np.arange(num_windows), y, test_size=0.2, random_state=42, stratify=y
)

# Fit on training windows only, then apply the same transform to everything.
scaler = StandardScaler()
scaler.fit(X[train_idx].reshape(-1, num_features))
X_scaled_reshaped = scaler.transform(X_reshaped)

# Reshape back to (num_windows, num_samples, num_features)
X_scaled = X_scaled_reshaped.reshape(num_windows, num_samples, num_features)

print("Splitting into training and testing sets...")
# NOTE(review): this is a random window-level split. If consecutive windows
# overlap, segments of the same recording end up in both sets and the test
# score is optimistic - a subject-wise split gives a more honest estimate.
X_train, X_test = X_scaled[train_idx], X_scaled[test_idx]

print(f"\nTraining data shape: {X_train.shape}")
print(f"Testing data shape: {X_test.shape}")

# ---
# STEP 7: BUILD THE CNN-LSTM MODEL (as per FIGURE 1)
# ---

# Network dimensions, read off the training tensor
window_length, num_features = X_train.shape[1:]  # e.g., (120 samples, 6 features)
num_classes = 3                                  # (0: Baseline, 1: Stress, 2: Amusement)

print("\nBuilding CNN-LSTM model...")

# The stack, in order: input -> convolutional feature extraction ->
# recurrent temporal analysis -> dense classification head.
cnn_lstm_layers = [
    Input(shape=(window_length, num_features), name="input_layer"),

    # 1. CNN Feature Extraction
    Conv1D(
        filters=64,
        kernel_size=5,
        activation='relu',
        padding='same',
        name="conv1d_1",
    ),
    MaxPooling1D(pool_size=2, name="maxpool_1"),

    # 2. LSTM Temporal Analysis (only the final state is passed on)
    LSTM(128, return_sequences=False, name="lstm_1"),

    # 3. Output Layer (Emotion Classification)
    Dense(64, activation='relu', name="dense_1"),
    Dense(num_classes, activation='softmax', name="output_layer"),
]
model = Sequential(cnn_lstm_layers)

model.summary()

# ---
# STEP 8: TRAIN THE MODEL
# ---

print("\nCompiling and training model...")

# Labels are plain integers (0, 1, 2), hence the sparse cross-entropy loss
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# A short run of 10 epochs; a real project would train for 50-100 epochs.
# The held-out test windows are used to monitor validation metrics.
training_options = dict(
    epochs=10,
    batch_size=32,
    validation_data=(X_test, y_test),
)
history = model.fit(X_train, y_train, **training_options)

print("\nModel training complete.")

# Final score on the held-out windows
test_metrics = model.evaluate(X_test, y_test, verbose=0)
loss, accuracy = test_metrics
print(f"\nTest Accuracy: {accuracy * 100:.2f}%")

# ---
# ---
# STEP 9: CONVERT & SAVE THE TFLite MODEL
# ---
# Converts the trained Keras model to TensorFlow Lite and writes it next to
# the dataset on Google Drive, falling back to the local Colab disk.

print("\nConverting model to TensorFlow Lite format (with Select TF Ops for LSTM)...")

# Convert the Keras model
converter = tf.lite.TFLiteConverter.from_keras_model(model)

# 1. Allow TensorFlow ops in addition to the built-in TFLite ops (for the LSTM)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # Enable default TFLite ops
    tf.lite.OpsSet.SELECT_TF_OPS     # Enable TensorFlow ops (for the LSTM)
]
# 2. Disable the experimental tensor-list lowering that otherwise makes the
#    LSTM conversion fail
converter._experimental_lower_tensor_list_ops = False

# The default optimizations are still applied
converter.optimizations = [tf.lite.Optimize.DEFAULT]

# Only the conversion itself is guarded here, so that a later file-system
# problem is never misreported as a conversion failure.
try:
    tflite_model = converter.convert()
except Exception as e:
    print(f"\n--- TFLITE CONVERSION FAILED ---")
    print(f"Error: {e}")
    print("This can happen if TensorFlow versions are incompatible. Please check the runtime.")
else:
    print("Model converted successfully.")

    # Save inside the same folder that was specified in STEP 2
    output_filename = os.path.join(dataset_drive_path, 'emotion_model.tflite')

    try:
        with open(output_filename, 'wb') as f:
            f.write(tflite_model)

        print(f"\n--- SUCCESS! ---")
        print(f"Model saved successfully to your Google Drive at:")
        print(output_filename)
        print(f"File size: {os.path.getsize(output_filename) / 1024:.2f} KB")

    except OSError as e:
        print(f"\n--- ERROR SAVING FILE ---")
        print(f"Could not save model to {output_filename}")
        print(f"Error: {e}")
        print("Trying to save locally instead...")
        # Fallback: write to the Colab working directory
        local_filename = 'emotion_model.tflite'
        try:
            with open(local_filename, 'wb') as f:
                f.write(tflite_model)
        except OSError as local_error:
            print(f"Could not save model locally either: {local_error}")
        else:
            print(f"Successfully saved model locally as '{local_filename}'")
            print("You will need to download it manually from the Colab file browser.")

Mounting Google Drive...
Mounted at /content/drive
Drive mounted successfully.
Dataset path found: /content/drive/MyDrive/WESAD

Found 2 subject data files.
Example path: /content/drive/MyDrive/WESAD/S16/S16.pkl

Processing data and creating windows for all subjects...
  - Processing S16...
    - Added 148 windows for S16
  - Processing S8...
    - Added 148 windows for S8

Total windows created: 296
Window shape (samples, features): (120, 6)
Labels shape: (296,)
Unique labels: [0 1 2]

Scaling data...
Splitting into training and testing sets...

Training data shape: (236, 120, 6)
Testing data shape: (60, 120, 6)

Building CNN-LSTM model...



Compiling and training model...
Epoch 1/10
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 134ms/step - accuracy: 0.5889 - loss: 0.9668 - val_accuracy: 0.7500 - val_loss: 0.5743
Epoch 2/10
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 88ms/step - accuracy: 0.8240 - loss: 0.4543 - val_accuracy: 0.8667 - val_loss: 0.2192
Epoch 3/10
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 151ms/step - accuracy: 0.9188 - loss: 0.1572 - val_accuracy: 0.9833 - val_loss: 0.1077
Epoch 4/10
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 147ms/step - accuracy: 0.9802 - loss: 0.1275 - val_accuracy: 0.9500 - val_loss: 0.1548
Epoch 5/10
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 154ms/step - accuracy: 0.9532 - loss: 0.1452 - val_accuracy: 0.9667 - val_loss: 0.0773
Epoch 6/10
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 98ms/step - accuracy: 0.9841 - loss: 0.0534 - val_accuracy: 0.9667 - val_loss: 0.0882
Epoch 7/1