# Test OpenPose heavy model

Computationally heavy.

Test the ST-GCN model.

In [1]:
# Report the installed PyTorch build and whether a CUDA device is usable
import torch

has_cuda = torch.cuda.is_available()  # queried once, reused below
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {has_cuda}")
if has_cuda:
    # Only the first device (index 0) is reported
    print(f"GPU: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.9.0+cpu
CUDA available: False


In [None]:
# Install dependencies
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install numpy
!pip install matplotlib
!pip install opencv-python
!pip install tqdm
!pip install scikit-learn
!pip install pandas

In [None]:
# Clone the ST-GCN repository
!git clone https://github.com/yysijie/st-gcn.git
%cd st-gcn

In [None]:
# Install additional dependencies required by ST-GCN
!pip install -r requirements.txt

In [None]:
# Download pre-trained model
!mkdir -p checkpoints
!wget -P checkpoints https://github.com/yysijie/st-gcn/raw/master/checkpoints/kinetics-st-gcn.pt

In [None]:
# Create a directory for sample data
!mkdir -p data/sample

# Download sample videos (you can replace these with your own videos)
!wget -P data/sample https://github.com/yysijie/st-gcn/raw/master/resource/media/skateboarding.mp4
!wget -P data/sample https://github.com/yysijie/st-gcn/raw/master/resource/media/walking.mp4

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import cv2
from tqdm import tqdm
import json
from IPython.display import HTML
from base64 import b64encode

In [None]:
# Import ST-GCN modules
import torch
from torch.autograd import Variable
from net.st_gcn import Model
from utils import *

# Load the model
def load_model():
    """Build the ST-GCN network and load its pre-trained weights.

    The checkpoint is read from ``checkpoints/kinetics-st-gcn.pt`` relative to
    the current working directory. The model is placed on the GPU when CUDA is
    available, otherwise on the CPU, and is switched to evaluation mode.

    Returns:
        The initialized ``Model`` instance, ready for inference.
    """
    num_class = 400  # Number of classes in Kinetics dataset
    graph_args = {'layout': 'openpose', 'strategy': 'spatial'}

    # NOTE(review): the upstream ST-GCN Model constructor takes `in_channels`
    # as its first argument; 3 matches the (x, y, confidence) keypoint channels
    # used elsewhere in this notebook. Keyword arguments keep the call explicit.
    model = Model(
        in_channels=3,
        num_class=num_class,
        graph_args=graph_args,
        edge_importance_weighting=True,
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the pre-trained weights. map_location remaps tensors that were saved
    # from a GPU so the checkpoint also loads on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file — only load checkpoints from
    # a trusted source.
    weights_path = 'checkpoints/kinetics-st-gcn.pt'
    weights = torch.load(weights_path, map_location=device)

    # Some checkpoints wrap the weights under a 'state_dict' key (old format);
    # others are the state dict itself (new format).
    state_dict = weights['state_dict'] if 'state_dict' in weights else weights
    model.load_state_dict(state_dict)

    model = model.to(device)
    model.eval()

    return model

model = load_model()
print("Model loaded successfully!")

In [None]:
# Define utility functions for processing
def load_json(file_path):
    """Parse the JSON document stored at ``file_path``.

    The file is decoded as UTF-8 explicitly rather than with the platform's
    default encoding, so non-ASCII content is read the same way everywhere.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The decoded Python object (dict, list, str, number, bool or None).
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def process_video(video_path, output_path, seed=None):
    """Create placeholder keypoints for every frame of a video and save them.

    The video is decoded only to count its frames; no pose estimation is run.
    The keypoints are uniform random values and stand in for the output of a
    real keypoint extractor such as OpenPose.

    Args:
        video_path: Path of the video file to read.
        output_path: Destination passed to ``np.save`` for the keypoint array.
        seed: Optional seed for reproducible placeholder keypoints. When
            ``None`` (the default) NumPy's global random state is used.

    Returns:
        Array of shape ``(num_frames, 18, 3)`` holding the generated keypoints.

    Raises:
        IOError: If the video cannot be opened.
        ValueError: If no frame can be read from the video.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise IOError(f"Could not open video: {video_path}")

        # Count frames one at a time instead of keeping every decoded frame in
        # memory; only the count is needed below.
        num_frames = 0
        while True:
            ret, _ = cap.read()
            if not ret:
                break
            num_frames += 1
    finally:
        # Release the capture even when opening or decoding fails
        cap.release()

    # Fail before saving: an empty keypoint file would otherwise be written to
    # output_path and picked up as a valid cached result later on.
    if num_frames == 0:
        raise ValueError(f"No frames could be read from video: {video_path}")

    # Here you would normally extract keypoints using OpenPose or similar
    # For this example, we'll simulate with random keypoints
    num_joints = 18  # OpenPose format
    num_coords = 3  # x, y, confidence

    # Generate random keypoints (replace with actual keypoint extraction)
    rng = np.random if seed is None else np.random.RandomState(seed)
    keypoints = rng.rand(num_frames, num_joints, num_coords)

    # Save keypoints
    np.save(output_path, keypoints)

    return keypoints

def load_keypoints(keypoints_path):
    """Read the keypoint array stored at ``keypoints_path`` via ``np.load``."""
    stored_keypoints = np.load(keypoints_path)
    return stored_keypoints

def preprocess_keypoints(keypoints):
    """Standardize keypoints per feature across the frame axis.

    Every (joint, coordinate) feature is shifted to zero mean and scaled to
    unit standard deviation over all frames. This is a simplified
    normalization; real preprocessing would be more involved.

    Args:
        keypoints: Array whose first axis indexes frames, normally shaped
            ``(num_frames, num_joints, num_coords)``.

    Returns:
        The normalized array. A 3-D input keeps its own shape, so layouts
        other than 18 joints x 3 coordinates are supported; any other input
        is reshaped to ``(num_frames, 18, 3)`` as before.

    Raises:
        ValueError: If ``keypoints`` contains no frames or no values.
    """
    keypoints = np.asarray(keypoints)
    if keypoints.ndim == 0 or keypoints.size == 0:
        raise ValueError("keypoints must contain at least one frame")

    num_frames = keypoints.shape[0]
    if keypoints.ndim == 3:
        # Restore whatever joint/coordinate layout the caller supplied
        output_shape = keypoints.shape
    else:
        # Legacy default: OpenPose layout of 18 joints x (x, y, confidence)
        output_shape = (num_frames, 18, 3)

    # One row per frame, one column per (joint, coordinate) feature
    flat = keypoints.reshape(num_frames, -1)
    mean = np.mean(flat, axis=0)
    std = np.std(flat, axis=0)
    # The small epsilon avoids division by zero for constant features
    normalized = (flat - mean) / (std + 1e-9)

    return normalized.reshape(output_shape)

def predict_action(model, keypoints):
    """Predict the action class index for one keypoint sequence.

    Args:
        model: Network to run; it is called with a single 5-D float tensor.
        keypoints: Array of shape ``(num_frames, num_joints, num_coords)``.

    Returns:
        The index of the highest-scoring class as a Python int.
    """
    # Preprocess keypoints
    keypoints = preprocess_keypoints(keypoints)

    # Convert to tensor, laid out as (T, V, C) = (frames, joints, coords)
    data = torch.tensor(keypoints, dtype=torch.float32)

    # NOTE(review): upstream ST-GCN's forward unpacks its input as
    # (N, C, T, V, M) = (batch, channels, frames, joints, persons).
    # Move channels first, then add the batch and single-person axes.
    data = data.permute(2, 0, 1).contiguous().unsqueeze(0).unsqueeze(-1)

    # Move to GPU if available
    if torch.cuda.is_available():
        data = data.cuda()

    # Make prediction
    with torch.no_grad():
        output = model(data)

    # Get predicted class
    _, predicted = torch.max(output, 1)

    return predicted.item()

# Load action labels
def load_labels():
    """Return placeholder label names, one for each class index 0..399."""
    # Placeholder names only - actual labels would come from the dataset
    names = []
    for class_idx in range(400):
        names.append(f"action_{class_idx}")
    return names

labels = load_labels()

In [None]:
# Process sample videos and extract keypoints
video_paths = ['data/sample/skateboarding.mp4', 'data/sample/walking.mp4']
keypoints_paths = []

for video_path in video_paths:
    # splitext removes only the final extension, so a file name that contains
    # additional dots is not truncated at the first one.
    video_name = os.path.splitext(os.path.basename(video_path))[0]
    keypoints_path = f'data/sample/{video_name}_keypoints.npy'

    # Keypoints are cached on disk; extraction runs only when the file is missing
    if os.path.exists(keypoints_path):
        print(f"Keypoints already exist at {keypoints_path}")
    else:
        print(f"Processing {video_path}...")
        process_video(video_path, keypoints_path)
        print(f"Keypoints saved to {keypoints_path}")

    keypoints_paths.append(keypoints_path)

In [None]:
# Run inference on sample videos
for video_path, keypoints_path in zip(video_paths, keypoints_paths):
    print(f"\nProcessing video: {video_path}")

    # Load keypoints
    keypoints = load_keypoints(keypoints_path)
    print(f"Keypoints shape: {keypoints.shape}")

    # Make prediction
    predicted_class = predict_action(model, keypoints)
    print(f"Predicted action: {labels[predicted_class]}")

    # Display video: the file is embedded in the page as a base64 data URL.
    # The with-block closes the file handle as soon as the bytes are read.
    print(f"Displaying video: {video_path}")
    with open(video_path, 'rb') as video_file:
        mp4 = video_file.read()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    display(HTML(f"""
    <video width=400 controls>
          <source src="{data_url}" type="video/mp4">
    </video>
    """))

In [None]:
# Visualize keypoints on frames
def visualize_keypoints(video_path, keypoints_path):
    """Draw keypoints and skeleton lines on a video and write an annotated copy.

    Joints and connections are drawn only where the confidence value (third
    keypoint component) exceeds 0.3. Processing stops at the end of the video
    or when the keypoint array runs out of frames, whichever comes first.

    Args:
        video_path: Path of the input video.
        keypoints_path: Path of the saved keypoint array, indexed as
            ``[frame][joint] -> (x, y, confidence)``.

    Returns:
        Path of the annotated video: the input path with its extension
        replaced by ``_keypoints.mp4``.
    """
    # Load keypoints first so a missing file fails before any video is opened
    keypoints = load_keypoints(keypoints_path)

    # Derive the output name from the path stem. Unlike str.replace('.mp4', ...),
    # this can never yield the input path itself (which would overwrite the
    # video while it is being read) when the input has another extension.
    root, _ = os.path.splitext(video_path)
    output_path = root + '_keypoints.mp4'

    # Define skeleton connections (OpenPose format)
    skeleton = [
        (1, 2), (1, 5), (2, 3), (3, 4), (5, 6), (6, 7),
        (1, 8), (8, 9), (9, 10), (1, 11), (11, 12), (12, 13),
        (1, 0), (0, 14), (14, 16), (0, 15), (15, 17)
    ]
    confidence_threshold = 0.3

    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        # Keep the fractional frame rate (e.g. 29.97) instead of truncating it,
        # and fall back to 30 fps when the container reports no usable value.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30.0

        # Create output video with the same frame size as the input
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))

        # Process each frame
        frame_idx = 0
        while cap.isOpened() and frame_idx < len(keypoints):
            ret, frame = cap.read()
            if not ret:
                break

            # Get keypoints for this frame
            frame_keypoints = keypoints[frame_idx]

            # Draw keypoints
            # NOTE(review): coordinates are used as pixel positions — confirm;
            # the placeholder keypoints from process_video lie in [0, 1).
            for kp in frame_keypoints:
                if kp[2] > confidence_threshold:
                    cv2.circle(frame, (int(kp[0]), int(kp[1])), 3, (0, 255, 0), -1)

            # Draw skeleton
            for start_idx, end_idx in skeleton:
                kp1 = frame_keypoints[start_idx]
                kp2 = frame_keypoints[end_idx]
                if kp1[2] > confidence_threshold and kp2[2] > confidence_threshold:
                    cv2.line(frame, (int(kp1[0]), int(kp1[1])), (int(kp2[0]), int(kp2[1])), (255, 0, 0), 2)

            # Write frame to output video
            out.write(frame)

            frame_idx += 1
    finally:
        # Release resources even when reading, drawing or writing fails
        cap.release()
        if out is not None:
            out.release()

    return output_path

# Visualize keypoints for each video
for video_path, keypoints_path in zip(video_paths, keypoints_paths):
    print(f"\nVisualizing keypoints for {video_path}")
    output_path = visualize_keypoints(video_path, keypoints_path)
    print(f"Output video saved to {output_path}")

    # Display video with keypoints: the file is embedded as a base64 data URL.
    # The with-block closes the file handle as soon as the bytes are read.
    # NOTE(review): the annotated video is written with the 'mp4v' codec, which
    # some browsers may not play inline — confirm playback.
    with open(output_path, 'rb') as video_file:
        mp4 = video_file.read()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    display(HTML(f"""
    <video width=400 controls>
          <source src="{data_url}" type="video/mp4">
    </video>
    """))