In [None]:
import cv2
import numpy as np
import mediapipe as mp
import torch
from collections import deque
from model import TransferBiLSTM  # your updated model

# Build the multiclass classifier and restore its trained weights onto the CPU.
model = TransferBiLSTM(
    input_size=132,
    hidden_size=64,
    num_layers=2,
    num_classes=5,
)
cpu_state = torch.load('pain_model_multiclass.pth', map_location=torch.device('cpu'))
model.load_state_dict(cpu_state)
model.eval()  # this cell only runs inference

# Rolling window holding the landmark vectors of the most recent `seq_len` frames.
seq_len = 8
landmark_buffer = deque(maxlen=seq_len)

# MediaPipe Holistic tracker used to obtain landmarks for every frame.
mp_holistic = mp.solutions.holistic
holistic = mp_holistic.Holistic()

def extract_landmarks(results, feature_size=132):
    """Flatten MediaPipe Holistic landmarks into a fixed-layout feature vector.

    Landmark groups are laid out in the order pose (33 points), face (468),
    left hand (21), right hand (21); every point contributes
    ``(x, y, z, visibility)``. A group that was not detected is zero-filled so
    the layout stays the same from frame to frame.

    Args:
        results: Object exposing ``pose_landmarks``, ``face_landmarks``,
            ``left_hand_landmarks`` and ``right_hand_landmarks``. Each is
            either falsy (not detected) or has a ``landmark`` iterable of
            points with ``x``, ``y``, ``z`` and ``visibility`` attributes.
        feature_size: Number of leading features to keep. The default of 132
            (33 pose points * 4 values) keeps only the pose group. Pass
            ``None`` to keep every group.

    Returns:
        1-D ``numpy.ndarray`` equal to the concatenation of all groups sliced
        with ``[:feature_size]``.
    """
    groups = (
        (results.pose_landmarks, 33),
        (results.face_landmarks, 468),
        (results.left_hand_landmarks, 21),
        (results.right_hand_landmarks, 21),
    )
    # Only a non-negative limit lets us stop early; ``None`` or a negative
    # slice bound needs every group to be present before slicing.
    can_stop_early = feature_size is not None and feature_size >= 0

    parts = []
    collected = 0
    for landmarks, expected_points in groups:
        # Skip groups that the final slice would throw away anyway (with the
        # default size this avoids flattening 468 face points per frame).
        if parts and can_stop_early and collected >= feature_size:
            break
        if landmarks:
            part = np.array(
                [[lmk.x, lmk.y, lmk.z, lmk.visibility] for lmk in landmarks.landmark]
            ).flatten()
        else:
            part = np.zeros(expected_points * 4)
        parts.append(part)
        collected += part.size

    return np.concatenate(parts)[:feature_size]

# Webcam capture: classify a rolling window of landmark vectors and overlay
# the predicted class on the live preview. Press 'q' to stop.
cap = cv2.VideoCapture(0)

try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            # No frame could be read (camera unplugged or stream ended).
            break

        # OpenCV delivers BGR frames; convert to RGB before handing to MediaPipe.
        image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = holistic.process(image)

        landmarks = extract_landmarks(results)
        landmark_buffer.append(landmarks)

        # Predict only once the window holds `seq_len` frames.
        if len(landmark_buffer) == seq_len:
            # Stack into a single ndarray first: building a tensor directly
            # from a list of ndarrays is very slow and triggers a UserWarning.
            window = np.stack(list(landmark_buffer)).astype(np.float32)
            input_seq = torch.from_numpy(window).unsqueeze(0)  # (1, seq_len, features)
            with torch.no_grad():
                logits = model(input_seq)
                probs = torch.softmax(logits, dim=1).numpy()[0]
            predicted_class = int(np.argmax(probs))
            confidence = probs[predicted_class]

            # Display result: green for class 0, red for any other class.
            text = f'Pain Class: {predicted_class} ({confidence:.2f})'
            cv2.putText(frame, text, (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1,
                        (0, 0, 255) if predicted_class > 0 else (0, 255, 0), 2)

        cv2.imshow("Real-Time Pain Detection", frame)
        if cv2.waitKey(10) & 0xFF == ord('q'):
            break
finally:
    # Always give the camera, the preview window and the MediaPipe graph
    # back, even if inference raised part-way through the loop.
    cap.release()
    cv2.destroyAllWindows()
    holistic.close()


# Full pipeline (initial version): image model combined with a pose-landmark sequence model


In [None]:
import torch
import cv2
import mediapipe as mp
import numpy as np
from torchvision import transforms

# Load both models
# NOTE(review): these checkpoints are loaded as whole pickled modules
# (``.eval()`` is called directly on the result), so ``torch.load`` unpickles
# arbitrary objects. Only load files from a trusted source.
# TODO: replace this machine-specific absolute directory with a configurable path.
MODEL_DIR = "/Users/suryanshpatel/Projects/Directed Readings/Technical/src/Models"

# Map onto the CPU explicitly: the inference loop feeds CPU tensors, and a
# checkpoint saved on another device would otherwise fail to load or mismatch.
image_model = torch.load(f"{MODEL_DIR}/best_model.pth", map_location=torch.device("cpu"))
lstm_model = torch.load(f"{MODEL_DIR}/gru_classifier.pth", map_location=torch.device("cpu"))

image_model.eval()
lstm_model.eval()

# Setup Mediapipe
mp_pose = mp.solutions.pose
pose = mp_pose.Pose()

# --- IMAGE PROCESSING (2-ch grayscale) ---
def get_2ch_image_from_frame(frame, size=(48, 48)):
    """Convert a BGR frame into a 2-channel (grayscale + edges) input tensor.

    Args:
        frame: BGR image, e.g. as returned by ``cv2.VideoCapture.read``.
        size: Target size handed to ``cv2.resize``; defaults to ``(48, 48)``.

    Returns:
        ``torch.float32`` tensor with a leading batch axis of 1 and two
        channels, scaled to ``[0, 1]``: channel 0 is the resized grayscale
        image, channel 1 is the absolute Laplacian response saturated at 255.
    """
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    gray = cv2.resize(gray, size)

    # Absolute Laplacian responses of an 8-bit image can exceed 255. Casting
    # such values straight to uint8 is out of range (wraps / platform
    # dependent), so strong edges are saturated explicitly instead.
    # NOTE(review): if the image model was trained on features produced by the
    # unclipped cast, confirm the training preprocessing is updated to match.
    edges = np.absolute(cv2.Laplacian(gray, cv2.CV_64F))
    edges = np.clip(edges, 0, 255).astype(np.uint8)

    stacked = np.stack([gray, edges], axis=0)
    return torch.tensor(stacked, dtype=torch.float32).unsqueeze(0) / 255.0

# --- LANDMARK PROCESSING ---
def extract_landmark_vector(results):
    """Turn the detected pose landmarks into a single-row feature tensor.

    Args:
        results: Object with a ``pose_landmarks`` attribute that is either
            falsy (no detection) or exposes a ``landmark`` iterable of points
            with ``x``, ``y`` and ``z`` attributes.

    Returns:
        ``torch.float32`` tensor of shape ``(1, 3 * num_points)`` holding the
        ``x, y, z`` values point by point, or ``None`` when no pose was found.
    """
    detected = results.pose_landmarks
    if not detected:
        return None

    coords = np.array([[point.x, point.y, point.z] for point in detected.landmark])
    row = torch.tensor(coords.flatten(), dtype=torch.float32)
    return row.unsqueeze(0)

# Real-time prediction: fuse the per-frame image model with the pose-landmark
# sequence model and overlay the resulting class. Press 'q' to stop.
cap = cv2.VideoCapture(0)

sequence = []
sequence_length = 30  # how many frames per sequence

try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            # No frame could be read (camera unplugged or stream ended).
            break

        # OpenCV delivers BGR frames; convert to RGB before handing to MediaPipe.
        image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = pose.process(image)

        # --- Landmark Model Inference ---
        # Reset every frame so a missing or still-filling sequence can never
        # leave `prob_lm` undefined or reuse a stale prediction.
        prob_lm = None
        lm_vector = extract_landmark_vector(results)
        if lm_vector is not None:
            sequence.append(lm_vector)
            if len(sequence) > sequence_length:
                sequence.pop(0)

            if len(sequence) == sequence_length:
                # Each vector is (1, features); concatenate along dim 0 to get
                # (sequence_length, features), then add the batch axis. Stacking
                # instead would produce a 4-D (1, 30, 1, features) tensor.
                # NOTE(review): assumes lstm_model takes (batch, seq, features) - confirm.
                landmark_input = torch.cat(sequence, dim=0).unsqueeze(0)
                with torch.no_grad():
                    pred_lm = lstm_model(landmark_input)
                    prob_lm = torch.softmax(pred_lm, dim=1)

        # --- Image Model Inference ---
        image_tensor = get_2ch_image_from_frame(frame)
        with torch.no_grad():
            pred_img = image_model(image_tensor)
            prob_img = torch.softmax(pred_img, dim=1)

        # --- Combine Predictions ---
        # Average both models when a landmark prediction exists; otherwise
        # fall back to the image model alone (no hard-coded class count).
        if prob_lm is None:
            combined_prob = prob_img
        else:
            combined_prob = (prob_img + prob_lm) / 2
        final_class = torch.argmax(combined_prob, dim=1).item()

        # --- Display ---
        cv2.putText(frame, f"Pain Level: {final_class}", (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
        cv2.imshow("Pain Detection", frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always give the camera, the preview window and the MediaPipe graph
    # back, even if inference raised part-way through the loop.
    cap.release()
    cv2.destroyAllWindows()
    pose.close()
