# In [None]:
import tkinter as tk
from tkinter import ttk
from PIL import Image, ImageTk
import cv2
import torch
import numpy as np
import os
import sys

# -------------------------
# Helpers & Resource Paths
# -------------------------
def resource_path(relative_path):
    """Return the path of a resource relative to the application's base folder.

    The base folder is ``sys._MEIPASS`` when that attribute exists (PyInstaller
    sets it in frozen bundles to the unpacked data directory); otherwise it is
    the absolute path of the current working directory.

    Args:
        relative_path: Path of the resource relative to the base folder.

    Returns:
        ``relative_path`` joined onto the base folder.
    """
    # Only the "attribute is absent" case is expected here, so look it up
    # explicitly instead of hiding unrelated errors behind a broad except.
    base_path = getattr(sys, "_MEIPASS", None)
    if base_path is None:
        base_path = os.path.abspath(".")
    return os.path.join(base_path, relative_path)

# Import models
from emotion_classification.models.cnn import CNN
from emotion_classification.models.resnet_vanilla import LightweightResNet
from emotion_classification.models.resnet18 import ResNet18Emotion

# -------------------------
# Detection & Device Setup
# -------------------------
from ultralytics import YOLO
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Face-detector weights, resolved through resource_path().
YOLO_MODEL_PATH = resource_path("./face_detection/yolo/yolov8n_face_detection/weights/best.pt")
# The detector is built once at import time and reused for every frame.
face_detector = YOLO(YOLO_MODEL_PATH)
face_detector.to(DEVICE)

# -------------------------
# Emotion Model Configuration
# -------------------------
# Number of output classes handed to every model constructor; equals len(EMOTIONS).
num_classes = 4
# Class labels, indexed by the predicted class id.
EMOTIONS = ["Angry", "Happy", "Sad", "Surprise"]
# Folder containing the per-model checkpoint directories.
BASE_MODEL_DIR = resource_path("./emotion_classification")

# Display name -> (model constructor, checkpoint path).
MODEL_INFO = {
    "CNN": (CNN, os.path.join(BASE_MODEL_DIR, "checkpoints_cnn/best.pth")),
    "ResNet18": (ResNet18Emotion, os.path.join(BASE_MODEL_DIR, "checkpoints_resnet18/best.pth")),
    "Resnet_Vanilla": (LightweightResNet, os.path.join(BASE_MODEL_DIR, "checkpoints_resnet_vanilla/best.pth")),
}

# Cache of already-built models keyed by display name; filled by get_model().
loaded_models = {}

def safe_load_weights(model, ckpt_path):
    """Prepare *model* for inference, loading weights from *ckpt_path* if it exists.

    The model is moved to ``DEVICE`` and switched to evaluation mode in every
    case. When the checkpoint file is missing a warning is printed and the
    model is returned with whatever weights it already has.

    The checkpoint may be either a plain state dict or a mapping that stores
    the state dict under the ``"model_state"`` key. A strict load is tried
    first; if it fails, a non-strict load is attempted and the keys that could
    not be matched are reported.

    Args:
        model: The ``torch.nn.Module`` to populate.
        ckpt_path: Filesystem path of the checkpoint.

    Returns:
        The same *model* instance, on ``DEVICE`` and in eval mode.

    Raises:
        RuntimeError: If even the non-strict load fails.
    """
    model.to(DEVICE)
    # Set eval mode up front so the missing-checkpoint path also returns a
    # model whose BatchNorm/Dropout layers behave correctly at inference.
    model.eval()
    if not os.path.exists(ckpt_path):
        print(f"[Warning] Checkpoint not found: {ckpt_path}")
        return model

    ckpt = torch.load(ckpt_path, map_location=DEVICE, weights_only=True)
    state = ckpt.get("model_state", ckpt)
    try:
        model.load_state_dict(state)
    except RuntimeError:
        # Best-effort fallback: accept a partially matching checkpoint, but
        # say so instead of silently running with some untrained layers.
        result = model.load_state_dict(state, strict=False)
        print(
            f"[Warning] Non-strict load for {ckpt_path}: "
            f"missing={list(result.missing_keys)}, "
            f"unexpected={list(result.unexpected_keys)}"
        )

    return model

def get_model(model_name):
    """Return the ready-to-use model registered under *model_name*.

    Models are built lazily: the first request constructs the network with
    ``num_classes`` outputs, loads its checkpoint via ``safe_load_weights`` and
    stores it in ``loaded_models``; later requests reuse the stored instance.

    Args:
        model_name: A key of ``MODEL_INFO``.

    Returns:
        The cached model instance.

    Raises:
        KeyError: If *model_name* is not a key of ``MODEL_INFO``.
    """
    cached = loaded_models.get(model_name)
    if cached is not None:
        return cached

    model_cls, weights_path = MODEL_INFO[model_name]
    loaded_models[model_name] = safe_load_weights(model_cls(num_classes), weights_path)
    return loaded_models[model_name]

def preprocess_face(face_img, model_name):
    """Convert a BGR face crop into a batched float tensor on ``DEVICE``.

    The crop is resized to a square whose side depends on the model, converted
    from BGR to RGB, reordered from height-width-channel to
    channel-height-width, given a leading batch dimension and divided by 255.

    Args:
        face_img: Face crop as an image array in BGR channel order.
        model_name: Name of the target model; selects the input size.

    Returns:
        A ``float32`` tensor with a leading batch dimension of 1.
    """
    # ResNet18 trained on 224, CNN/Vanilla trained on 96
    if model_name == "ResNet18":
        side = 224
    else:
        side = 96

    resized = cv2.resize(face_img, (side, side))
    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)

    pixels = torch.tensor(rgb, dtype=torch.float32)
    batch = pixels.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
    return batch / 255.0

# -------------------------
# Main Application GUI
# -------------------------
class EmotionApp:
    """Tkinter window showing the webcam feed annotated with faces and emotions.

    Every frame read from camera 0 is passed to the module-level
    ``face_detector``. Each detected face is cropped, classified by the model
    selected in the drop-down, and drawn on the frame with its label and
    confidence before the frame is shown on the canvas.
    """

    def __init__(self, root):
        """Build the widgets, open the camera and start the frame loop.

        Args:
            root: The Tk window that hosts the application.
        """
        self.root = root
        self.root.title("Face & Emotion Detection")
        self.root.geometry("1200x800")

        # UI Elements
        self.selected_model = tk.StringVar(value="CNN")
        # Read-only: free-typed text would not be a key of MODEL_INFO, and the
        # resulting KeyError inside the frame callback would stop the loop.
        self.model_menu = ttk.Combobox(
            root,
            textvariable=self.selected_model,
            values=list(MODEL_INFO.keys()),
            state="readonly",
        )
        self.model_menu.pack(pady=10)

        self.canvas = tk.Canvas(root, bg="black")
        self.canvas.pack(fill="both", expand=True)
        # A single canvas item is reused for every frame; creating a new item
        # per frame would pile up items on the canvas indefinitely.
        self.image_item = self.canvas.create_image(0, 0, anchor=tk.NW)
        self.img_tk = None
        # Id of the pending `after` callback, so it can be cancelled on close.
        self._after_id = None

        # Video Capture
        self.cap = cv2.VideoCapture(0)
        self.root.protocol("WM_DELETE_WINDOW", self.on_close)
        self.update_frame()

    def update_frame(self):
        """Read one frame, annotate and display it, then schedule the next call."""
        ret, frame = self.cap.read()
        if not ret:
            # No frame available: retry a little later.
            self._after_id = self.root.after(30, self.update_frame)
            return

        display_frame = frame.copy()

        # 1. Detect Faces
        results = face_detector(frame, verbose=False)
        boxes = results[0].boxes.xyxy.cpu().numpy().astype(int)

        # 2. Process each detected face
        if len(boxes):
            self._annotate_faces(frame, display_frame, boxes)

        # 3. Render to Tkinter Canvas
        self._render(display_frame)

        self._after_id = self.root.after(10, self.update_frame)

    def _annotate_faces(self, frame, display_frame, boxes):
        """Classify each face box of *frame* and draw the result on *display_frame*."""
        # The selection cannot change during one callback, so resolve the
        # model once per frame rather than once per face.
        model_name = self.selected_model.get()
        model = get_model(model_name)
        height, width = frame.shape[:2]

        for box in boxes:
            x1, y1, x2, y2 = (int(v) for v in box)
            # Clamp to the frame: a negative coordinate would otherwise be
            # treated as a from-the-end index by the slice below.
            x1, y1 = max(x1, 0), max(y1, 0)
            x2, y2 = min(x2, width), min(y2, height)

            face_img = frame[y1:y2, x1:x2]
            if face_img.size == 0:
                continue

            input_tensor = preprocess_face(face_img, model_name)
            with torch.no_grad():
                out = model(input_tensor)
                # Some models return a tuple; the logits are its first element.
                logits = out[0] if isinstance(out, tuple) else out
                pred = torch.argmax(logits, dim=1).item()
                conf = torch.softmax(logits, dim=1)[0, pred].item()
            label = f"{EMOTIONS[pred]} ({conf:.2f})"

            # Draw results on frame
            cv2.rectangle(display_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            # Keep the label visible when the box touches the top edge.
            text_y = max(y1 - 10, 20)
            cv2.putText(display_frame, label, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

    def _render(self, display_frame):
        """Show the BGR image *display_frame* on the canvas."""
        img_rgb = cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB)
        photo = ImageTk.PhotoImage(Image.fromarray(img_rgb))
        self.canvas.itemconfig(self.image_item, image=photo)
        # Keep a reference on the instance so the image object stays alive
        # for as long as the canvas displays it.
        self.img_tk = photo

    def on_close(self):
        """Stop the frame loop, release the camera and destroy the window."""
        if self._after_id is not None:
            self.root.after_cancel(self._after_id)
            self._after_id = None
        self.cap.release()
        self.root.destroy()

# Launch the GUI only when this file is executed as a script.
if __name__ == "__main__":
    root = tk.Tk()
    app = EmotionApp(root)
    # Blocks until the window is closed.
    root.mainloop()