# In [None]:
# Decision tree and CNN with Zernike-moment features on the CK+ dataset
import os
import zipfile
import cv2
import numpy as np
import mediapipe as mp
import mahotas
from collections import Counter
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Dropout, Concatenate
from sklearn.tree import DecisionTreeClassifier
import seaborn as sns

# --- Extract dataset ---
# Source archive (machine-specific absolute path) and the folder it is unpacked into.
zip_path = "C:\\Users\\yuvan\\OneDrive\\Desktop\\CK+48"
extract_dir = "data"

# The existence of the target folder is the only "already extracted" signal:
# the archive is opened only when that folder is missing.
if os.path.exists(extract_dir):
    print("Dataset already exists, skipping extraction.")
else:
    print("Extracting dataset...")
    with zipfile.ZipFile(zip_path, mode='r') as zip_ref:
        zip_ref.extractall(path=extract_dir)

def find_subfolder(base_dir, name='train'):
    """Return the path of the first directory called *name* under *base_dir*.

    The tree is traversed with ``os.walk`` and the search stops at the first
    directory listing that contains *name*.

    Args:
        base_dir: Root directory of the search.
        name: Directory name to look for.

    Returns:
        The parent path joined with *name*.

    Raises:
        FileNotFoundError: If no directory called *name* exists under *base_dir*.
    """
    # Lazy generator: os.walk is only advanced until the first hit.
    candidates = (
        os.path.join(parent, name)
        for parent, subdirs, _files in os.walk(base_dir)
        if name in subdirs
    )
    match = next(candidates, None)
    if match is None:
        raise FileNotFoundError(f"'{name}' folder not found inside {base_dir}")
    return match

# Locate the split folders wherever they sit inside the extracted tree.
train_dir = find_subfolder(base_dir=extract_dir, name='train')
test_dir = find_subfolder(base_dir=extract_dir, name='test')

# --- Mediapipe FaceMesh ---
# Still-image detector (static_image_mode=True), limited to a single face;
# get_face_crop() uses this instance.
mp_face_mesh = mp.solutions.face_mesh
face_mesh = mp_face_mesh.FaceMesh(
    static_image_mode=True,
    max_num_faces=1,
    refine_landmarks=True,
)

def get_face_crop(img):
    """Crop the first detected face out of a BGR image and resize it to 48x48.

    The crop is the axis-aligned bounding box of all FaceMesh landmarks of the
    first detected face, clamped to the image borders.

    Args:
        img: 3-dimensional image array (height, width, channels) in BGR order,
            as it is converted with ``cv2.COLOR_BGR2RGB`` before detection.

    Returns:
        The 48x48 resized face crop, or ``None`` when no face is detected or
        the clamped bounding box is empty.
    """
    h, w, _ = img.shape
    detection = face_mesh.process(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    if not detection.multi_face_landmarks:
        return None

    # Landmark coordinates are scaled by the image width/height to get pixels.
    points = detection.multi_face_landmarks[0].landmark
    xs = [p.x for p in points]
    ys = [p.y for p in points]

    # Scale to pixels, then clamp to the image rectangle.
    left = max(0, int(min(xs) * w))
    right = min(w, int(max(xs) * w))
    top = max(0, int(min(ys) * h))
    bottom = min(h, int(max(ys) * h))

    face = img[top:bottom, left:right]
    if face.size == 0:
        return None
    return cv2.resize(face, (48, 48))

def extract_zernike(gray_img, radius=21, degree=8):
    """Compute Zernike moments of a mean-thresholded grayscale image.

    The image is binarised by comparing every pixel with the image mean, and
    the resulting 0/1 ``uint8`` mask is handed to
    ``mahotas.features.zernike_moments``.

    Args:
        gray_img: Grayscale image array (the callers in this file pass 48x48
            face crops).
        radius: Radius forwarded to mahotas. Defaults to 21, the value that
            was previously hard-coded.
        degree: Maximum moment degree forwarded to mahotas. Defaults to 8,
            the value that was previously hard-coded.

    Returns:
        The feature vector returned by ``mahotas.features.zernike_moments``.
    """
    mask = (gray_img > gray_img.mean()).astype(np.uint8)
    return mahotas.features.zernike_moments(mask, radius, degree=degree)

# Emotion classes kept for training/evaluation; any other folder is skipped.
valid_labels = ['happy', 'sad', 'angry', 'fear', 'surprise']

def load_images_with_zernike(folder):
    """Load face crops, Zernike features and labels from a class-per-folder tree.

    Every sub-directory of *folder* whose (remapped) name is in
    ``valid_labels`` is scanned. Each readable image is cropped to the face
    with ``get_face_crop``, converted to grayscale and described with
    ``extract_zernike``. Unreadable images, images without a detected face and
    images whose feature extraction fails are skipped.

    Args:
        folder: Directory containing one sub-directory per emotion label.

    Returns:
        A tuple ``(X_img, X_zernike, y)`` of numpy arrays holding the
        grayscale face crops, their Zernike feature vectors and the string
        labels, in matching order.
    """
    X_img, X_zernike, y = [], [], []
    for original_label in os.listdir(folder):
        # NOTE(review): 'neutral' folders are relabelled as 'surprise' --
        # confirm that merging these two classes is intentional.
        label = 'surprise' if original_label == 'neutral' else original_label
        if label not in valid_labels:
            continue
        label_path = os.path.join(folder, original_label)
        # A stray regular file named like a label would make os.listdir fail.
        if not os.path.isdir(label_path):
            continue
        for file in os.listdir(label_path):
            img_path = os.path.join(label_path, file)
            img = cv2.imread(img_path)
            if img is None:
                continue
            face = get_face_crop(img)
            if face is None:
                continue
            gray = cv2.cvtColor(face, cv2.COLOR_BGR2GRAY)
            # Best-effort: a sample whose features cannot be computed is
            # dropped, but KeyboardInterrupt/SystemExit are no longer
            # swallowed and the reason is reported.
            try:
                zernike_feat = extract_zernike(gray)
            except Exception as exc:
                print(f"Skipping {img_path}: Zernike extraction failed ({exc})")
                continue
            X_img.append(gray)
            X_zernike.append(zernike_feat)
            y.append(label)
    return np.array(X_img), np.array(X_zernike), np.array(y)

# --- Load dataset ---
print("Loading training data...")
X_train_img, X_train_zernike, y_train = load_images_with_zernike(train_dir)
print("Loading testing data...")
X_test_img, X_test_zernike, y_test = load_images_with_zernike(test_dir)

# Normalize and reshape: scale pixels to [0, 1] and append a trailing channel
# axis so the arrays match the CNN's (48, 48, 1) input.
X_train_img = X_train_img.astype('float32') / 255.
X_test_img = X_test_img.astype('float32') / 255.
X_train_img = X_train_img[..., np.newaxis]
X_test_img = X_test_img[..., np.newaxis]

# Encode labels: the encoder is fitted on the training labels only, so the
# test labels are mapped with the same class -> integer assignment.
le = LabelEncoder()
y_train_enc = le.fit_transform(y_train)
y_test_enc = le.transform(y_test)

# Pin the one-hot width to the number of known classes. Without num_classes,
# to_categorical sizes the matrix from the largest label present, so a test
# split that lacks the highest class would not match the softmax output.
num_classes = len(le.classes_)
y_train_cat = to_categorical(y_train_enc, num_classes=num_classes)
y_test_cat = to_categorical(y_test_enc, num_classes=num_classes)

# --- CNN + Zernike Model ---
# Image branch: two conv/max-pool stages over the 48x48 single-channel crop,
# flattened into one vector.
cnn_input = Input(shape=(48, 48, 1), name='cnn_input')
x = Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(cnn_input)
x = MaxPooling2D(pool_size=2, strides=2)(x)
x = Conv2D(filters=64, kernel_size=(3, 3), activation='relu')(x)
x = MaxPooling2D(pool_size=2, strides=2)(x)
x = Flatten()(x)

# Feature branch: the Zernike vector (its length is taken from the training
# data) goes through one dense layer.
zernike_input = Input(shape=(X_train_zernike.shape[1],), name='zernike_input')
z = Dense(units=64, activation='relu')(zernike_input)

# Fusion head: concatenate both branches, then dense -> dropout -> softmax
# with one output per encoded class.
combined = Concatenate()([x, z])
combined = Dense(units=128, activation='relu')(combined)
combined = Dropout(rate=0.3)(combined)
output = Dense(units=len(le.classes_), activation='softmax')(combined)

model = Model(inputs=[cnn_input, zernike_input], outputs=output)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

print("Training CNN + Zernike model...")
# The test split is passed as validation data here.
model.fit(
    x=[X_train_img, X_train_zernike],
    y=y_train_cat,
    epochs=20,
    batch_size=32,
    validation_data=([X_test_img, X_test_zernike], y_test_cat),
)

# Persisted so the real-time section can reload it from disk.
model.save("emotion_cnn_zernike_model.h5")

# Evaluate CNN+Zernike model
# The model outputs one softmax row per sample; the arg-max column is the
# predicted class index, comparable with the integer-encoded test labels.
y_pred = model.predict([X_test_img, X_test_zernike])
y_pred_cls = y_pred.argmax(axis=1)

cnn_accuracy = accuracy_score(y_test_enc, y_pred_cls)
cnn_error_rate = 1 - cnn_accuracy
print(f"CNN + Zernike Test Accuracy: {cnn_accuracy * 100:.2f}%")
print(f"CNN + Zernike Error Rate: {cnn_error_rate * 100:.2f}%")

# --- Train and evaluate Decision Tree on Zernike features only ---
print("\nTraining Decision Tree on Zernike features...")
dt = DecisionTreeClassifier(random_state=42)
dt.fit(X_train_zernike, y_train_enc)

y_dt_pred = dt.predict(X_test_zernike)
dt_accuracy = accuracy_score(y_test_enc, y_dt_pred)
dt_error_rate = 1 - dt_accuracy
print(f"Decision Tree Accuracy: {dt_accuracy * 100:.2f}%")
print(f"Decision Tree Error Rate: {dt_error_rate * 100:.2f}%")

# Explicit label ids keep the report rows and the matrix axes aligned with
# le.classes_ even when a class is absent from both the test labels and the
# predictions; otherwise the report raises on a target_names length mismatch
# and the matrix no longer matches the tick labels.
class_ids = np.arange(len(le.classes_))

print("\nClassification Report (Decision Tree):")
print(classification_report(y_test_enc, y_dt_pred, labels=class_ids, target_names=le.classes_))

# Confusion Matrix (rows: actual class, columns: predicted class)
cm = confusion_matrix(y_test_enc, y_dt_pred, labels=class_ids)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=le.classes_, yticklabels=le.classes_)
plt.title("Decision Tree Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.show()

# --- Real-time Detection (CNN + Zernike only) ---
print("Starting real-time detection. Press 'q' to quit.")
model = load_model("emotion_cnn_zernike_model.h5")
emotion_labels = le.classes_
# Separate FaceMesh instance in video mode (static_image_mode=False).
face_mesh_live = mp_face_mesh.FaceMesh(static_image_mode=False, max_num_faces=1, refine_landmarks=True)

cap = cv2.VideoCapture(0)
predicted_emotions = []  # one entry per frame with a successful prediction

# try/finally guarantees the camera and the preview window are released even
# if processing a frame raises part-way through the session.
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            # No frame delivered (camera missing or stream ended).
            break
        h, w, _ = frame.shape
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = face_mesh_live.process(rgb)
        if results.multi_face_landmarks:
            # Bounding box of all landmarks, scaled to pixels and clamped to
            # the frame.
            landmarks = results.multi_face_landmarks[0].landmark
            x_all = [lm.x for lm in landmarks]
            y_all = [lm.y for lm in landmarks]
            x_min, x_max = int(min(x_all) * w), int(max(x_all) * w)
            y_min, y_max = int(min(y_all) * h), int(max(y_all) * h)
            x_min, x_max = max(0, x_min), min(w, x_max)
            y_min, y_max = max(0, y_min), min(h, y_max)

            face_crop = frame[y_min:y_max, x_min:x_max]
            if face_crop.size > 0:
                # Same order as the training pipeline (resize the BGR crop,
                # then convert to grayscale) so live inputs are preprocessed
                # exactly like the training samples.
                face_small = cv2.resize(face_crop, (48, 48))
                resized = cv2.cvtColor(face_small, cv2.COLOR_BGR2GRAY)
                norm = resized.astype('float32') / 255.
                # The dataset loader tolerates feature-extraction failures;
                # do the same here so one bad frame does not end the session.
                try:
                    zernike_feat = extract_zernike(resized)
                except Exception:
                    zernike_feat = None
                if zernike_feat is not None:
                    # Add batch (and channel) axes expected by the model.
                    input_img = np.expand_dims(norm, axis=(0, -1))
                    input_zernike = np.expand_dims(zernike_feat, axis=0)
                    pred = model.predict([input_img, input_zernike], verbose=0)
                    emotion = emotion_labels[np.argmax(pred)]
                    predicted_emotions.append(emotion)
                    cv2.putText(frame, f'Emotion: {emotion}', (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 2)
        cv2.imshow("Emotion Detection", frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    cap.release()
    cv2.destroyAllWindows()

# --- Plot Real-Time Emotion Frequency ---
# Bar chart of how often each emotion was predicted during the webcam
# session; nothing is drawn when no prediction was recorded.
if not predicted_emotions:
    print("No emotions detected during the session.")
else:
    counts = Counter(predicted_emotions)
    plt.bar(counts.keys(), counts.values(), color='orange')
    plt.title("Real-Time Emotion Frequency")
    plt.xlabel("Emotion")
    plt.ylabel("Count")
    plt.grid(True)
    plt.show()