# Predict using trained models on test data
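Create a streamlined process to:
- turn per-frame object-detection predictions (club and club-head boxes) into fixed-length sequences
- run the trained models on those sequences to classify swing vs. non-swing
- verify the results by plotting the predicted golf swings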

In [7]:
import json
import numpy as np
from scipy import interpolate

def load_jsonl(file_path):
    """Read a JSON Lines file and return one parsed JSON value per line, in order."""
    records = []
    with open(file_path, 'r') as handle:
        for raw_line in handle:
            records.append(json.loads(raw_line))
    return records

def load_swing_intervals(file_path):
    """Read the swing-interval annotation file, stored as a single JSON document."""
    with open(file_path, 'r') as handle:
        intervals = json.load(handle)
    return intervals

def interpolate_missing(frames_dict, max_frame):
    """Fill every frame index in ``[0, max_frame]`` by per-feature linear interpolation.

    Args:
        frames_dict: mapping of frame index -> list of feature values for the
            frames that have data. All value lists must have the same length
            (8 in this notebook: club box + club-head box); the feature count
            is taken from the data rather than hard-coded.
        max_frame: last frame index to produce (inclusive).

    Returns:
        ``np.ndarray`` of shape ``(max_frame + 1, n_features)``. Gaps between
        known frames are linearly interpolated, frames outside the known range
        are linearly extrapolated, and every value is clipped to ``[0, 1]``.

    Raises:
        ValueError: if ``frames_dict`` is empty.
    """
    if not frames_dict:
        raise ValueError("interpolate_missing: frames_dict is empty; nothing to interpolate")

    frames = np.arange(max_frame + 1)
    existing_frames = np.array(list(frames_dict.keys()))
    existing_values = np.array(list(frames_dict.values()), dtype=float)

    # interp1d needs at least two samples; with one detection the best we can
    # do is hold that value constant across all frames.
    if len(existing_frames) == 1:
        return np.clip(np.tile(existing_values, (len(frames), 1)), 0, 1)

    # Interpolate all feature columns at once along the frame axis.
    interp_func = interpolate.interp1d(existing_frames, existing_values, kind='linear', axis=0,
                                       bounds_error=False, fill_value='extrapolate')
    # Extrapolation can overshoot, so clamp back into the normalised [0, 1] range.
    return np.clip(interp_func(frames), 0, 1)

def process_video(predictions_file, swing_intervals_file=None, sequence_length=64, overlap=32):
    """Turn a per-frame detection JSONL file into fixed-length feature sequences.

    Each line of ``predictions_file`` is treated as one frame (its line index
    is the frame number). A frame contributes a feature row only when both a
    ``club`` and a ``club_head`` detection are present; the row is
    ``[club x, y, width, height, club_head x, y, width, height]``, each
    normalised by the image width/height recorded in the first line. Frames
    without both detections are filled by ``interpolate_missing`` up to the
    last frame that has both, and the result is cut into windows.

    Args:
        predictions_file: path to the JSONL predictions file.
        swing_intervals_file: optional path to a JSON list of ``[start, end)``
            frame intervals containing swings. When given, labels are returned.
        sequence_length: number of frames per window.
        overlap: despite its name, this is the *step* between consecutive
            window starts (with the defaults, a step of 32 on 64-frame
            windows also happens to be a 32-frame overlap).

    Returns:
        ``sequences`` shaped ``(n_windows, sequence_length, n_features)``, or
        ``(sequences, labels)`` when ``swing_intervals_file`` is given. A
        window is labelled 1 when its middle frame lies inside any interval.

    Raises:
        ValueError: if the file has no lines, or no frame has both detections.
    """
    predictions = load_jsonl(predictions_file)
    if not predictions:
        raise ValueError(f"No predictions found in {predictions_file}")
    swing_intervals = load_swing_intervals(swing_intervals_file) if swing_intervals_file else None
    # An empty interval list is still ground truth (every window is negative),
    # so test for None rather than truthiness; otherwise the return type would
    # silently flip from a (sequences, labels) tuple to a bare array.
    with_labels = swing_intervals is not None

    # Image dimensions are read once, from the first prediction.
    img_width = predictions[0]['image']['width']
    img_height = predictions[0]['image']['height']

    processed_frames = {}
    for frame, pred in enumerate(predictions):
        frame_data = {}
        for p in pred['predictions']:
            if p['class'] in ('club', 'club_head'):
                # If a class is detected more than once in a frame, the last box wins.
                frame_data[p['class']] = [
                    p['x'] / img_width,
                    p['y'] / img_height,
                    p['width'] / img_width,
                    p['height'] / img_height,
                ]

        # Only keep frames where both club and club_head were detected.
        if 'club' in frame_data and 'club_head' in frame_data:
            processed_frames[frame] = [*frame_data['club'], *frame_data['club_head']]

    if not processed_frames:
        raise ValueError(
            f"No frame in {predictions_file} has both a club and a club_head detection"
        )

    # The sequence spans frames 0..last frame with both detections; any
    # trailing frames after that are not included.
    max_frame = max(processed_frames)
    interpolated_frames = interpolate_missing(processed_frames, max_frame)

    sequences = []
    labels = []
    for start in range(0, len(interpolated_frames) - sequence_length + 1, overlap):
        sequences.append(interpolated_frames[start:start + sequence_length])
        if with_labels:
            seq_mid = start + sequence_length // 2
            is_swing = any(s <= seq_mid < e for s, e in swing_intervals)
            labels.append(1 if is_swing else 0)

    if sequences:
        sequences = np.array(sequences)
    else:
        # Too few frames for one window: keep a well-formed 3-D empty batch.
        sequences = np.empty((0, sequence_length, interpolated_frames.shape[1]))

    if with_labels:
        return sequences, np.array(labels, dtype=int)
    return sequences
    
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import random

def plot_sample_sequences(sequences, labels, num_samples=3):
    """Show random label-1 and label-0 sequences in a plotly subplot grid.

    The grid has 4 rows x ``num_samples`` columns:
    - Rows 1-2 hold the X and Y traces of up to ``num_samples`` randomly chosen
      sequences with label 1 ("Positive Samples").
    - Rows 3-4 hold the same for label 0 ("Negative Samples").
    - Blue lines are features 0/1 (club x/y) and red lines are features 4/5
      (club-head x/y).
    - All y-axes are fixed to [0, 1].

    Args:
        sequences: indexable collection of 2-D arrays, each (frames, features).
        labels: per-sequence labels (list or array); only 0 and 1 are plotted.
        num_samples: maximum number of sequences shown per class (= columns).

    Raises:
        ValueError: if ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    # np.asarray lets plain Python lists work too (`some_list == 1` is one bool).
    labels = np.asarray(labels)
    positive_indices = np.where(labels == 1)[0]
    negative_indices = np.where(labels == 0)[0]

    # Randomly sample from positive and negative sequences.
    sample_positive = random.sample(list(positive_indices), min(num_samples, len(positive_indices)))
    sample_negative = random.sample(list(negative_indices), min(num_samples, len(negative_indices)))

    # subplot_titles are assigned cell by cell in row-major order, so supply one
    # entry per cell: titles on the X rows (1 and 3), blanks on the Y rows (2 and 4).
    blank = [''] * num_samples
    subplot_titles = (['Positive Samples'] * num_samples + blank
                      + ['Negative Samples'] * num_samples + blank)
    fig = make_subplots(rows=4, cols=num_samples,
                        subplot_titles=subplot_titles,
                        vertical_spacing=0.1,
                        row_heights=[0.23, 0.23, 0.23, 0.23])

    def plot_sequence(seq, start_row, col, show_legend):
        """Add one sequence: X traces at ``start_row``, Y traces at ``start_row + 1``."""
        frames = np.arange(len(seq))
        traces = [
            (0, 'Club X', 'blue', start_row, show_legend),
            (4, 'Club Head X', 'red', start_row, show_legend),
            (1, 'Club Y', 'blue', start_row + 1, False),
            (5, 'Club Head Y', 'red', start_row + 1, False),
        ]
        for feature, name, color, row, legend in traces:
            fig.add_trace(go.Scatter(x=frames, y=seq[:, feature], mode='lines+markers', name=name,
                                     line=dict(color=color), showlegend=legend),
                          row=row, col=col)

    panels = [(sequences[idx], 1, col) for col, idx in enumerate(sample_positive, start=1)]
    panels += [(sequences[idx], 3, col) for col, idx in enumerate(sample_negative, start=1)]
    # Show the legend exactly once, on the first plotted panel, so it still
    # appears when there are no positive samples.
    for n, (seq, start_row, col) in enumerate(panels):
        plot_sequence(seq, start_row, col, show_legend=(n == 0))

    fig.update_layout(height=1200, width=1200, title_text="Sample Sequences: Positive vs Negative")
    fig.update_xaxes(title_text="Frame Number")
    fig.update_yaxes(range=[0, 1])
    for col in range(1, num_samples + 1):
        for row, axis_title in ((1, "X Coordinate"), (2, "Y Coordinate"),
                                (3, "X Coordinate"), (4, "Y Coordinate")):
            fig.update_yaxes(title_text=axis_title, row=row, col=col)

    fig.show()



In [2]:
# Prediction without labels: no swing-interval file is passed, so
# process_video returns only the sequence array (no labels).
video_name = "IMG_3517"
sequences = process_video(f"predictions/test/{video_name}_predictions.jsonl")

In [3]:
import torch
import torch.nn as nn
import os
from models import (
    BaseModel,
    FlattenModel,
    LogisticRegression,
    MLP,
    LSTMModel,
    BidirectionalLSTMModel
)

def load_model(model_path, model_class, *args, **kwargs):
    """Load a checkpoint from ``model_path`` onto the CPU and put it in eval mode.

    Two checkpoint formats are accepted:

    * a state dict (any ``dict``, e.g. the ``OrderedDict`` returned by
      ``model.state_dict()``): a fresh ``model_class(*args, **kwargs)`` is
      built and the weights are loaded into it;
    * a whole pickled ``nn.Module``: it is used as-is, and
      ``model_class``/``args``/``kwargs`` are ignored (no throwaway instance
      is constructed).

    Returns:
        The model, already switched to eval mode.

    Raises:
        TypeError: if the file holds neither a dict nor an ``nn.Module``.
    """
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. The default of its `weights_only` argument depends on the
    # installed torch version (newer releases default to True, which rejects
    # whole pickled modules); confirm which format the saved .pth files use.
    loaded = torch.load(model_path, map_location=torch.device('cpu'))
    if isinstance(loaded, dict):
        # State dict: build the architecture only now that we know we need it.
        model = model_class(*args, **kwargs)
        model.load_state_dict(loaded)
    elif isinstance(loaded, nn.Module):
        # Full model: use it directly.
        model = loaded
    else:
        raise TypeError(f"Unexpected type loaded from {model_path}: {type(loaded)}")
    model.eval()
    return model

def predict(model, sequences):
    """Return the arg-max class index of ``model``'s output for every sequence.

    ``sequences`` is converted to a float32 tensor and passed to ``model`` as a
    single batch with autograd disabled; the predicted class is the index of
    the largest value along dim 1. The result is a NumPy array.
    """
    # NOTE(review): taking the max over dim 1 assumes the model emits one score
    # per class, i.e. outputs shaped (batch, n_classes) with n_classes >= 2.
    # A single-column output would always give class 0; confirm against models.py.
    with torch.no_grad():
        batch = torch.tensor(sequences, dtype=torch.float32)
        scores = model(batch)
        best = torch.max(scores, dim=1).indices
    return best.numpy()

# Constructor kwargs for each architecture, keyed by the model name that the
# loading loop parses out of each .pth filename.
# The flat models take a whole window at once: 64 frames x 8 features.
# NOTE(review): these sizes presumably must match the settings the models
# were trained with (sequence_length=64 and 8 features per frame).
input_dim = 64 * 8
model_configs = dict(
    logistic_regression=(LogisticRegression, dict(input_dim=input_dim)),
    mlp=(MLP, dict(input_dim=input_dim)),
    lstm=(LSTMModel, dict(input_size=8, hidden_size=128, num_layers=1)),
    bidirectional_lstm=(BidirectionalLSTMModel, dict(input_size=8, hidden_size=128, num_layers=1)),
)

# Load and test all models.
# Checkpoint names are expected to look like "<version>_<model_name>.pth"
# (e.g. "v1_mlp.pth"); the part after the first underscore, up to the first
# dot, selects the entry in `model_configs`.
# NOTE: `labels` is reassigned for every model, so after the loop it holds the
# predictions of the last model that predicted successfully.
model_dir = 'models'
for model_file in sorted(os.listdir(model_dir)):  # sorted for a reproducible order
    # Only .pth files are checkpoints. Skipping everything else here keeps the
    # lookup below from reusing a previous iteration's model_name/model_path.
    if not model_file.endswith('.pth'):
        continue
    model_path = os.path.join(model_dir, model_file)
    # partition() yields '' when there is no underscore, which is then
    # reported as an unknown model type instead of raising IndexError.
    model_name = model_file.partition('_')[2].split('.')[0].lower()

    if model_name not in model_configs:
        print(f"Unknown model type for file: {model_file}")
        continue

    model_class, model_params = model_configs[model_name]
    try:
        model = load_model(model_path, model_class, **model_params)
        print(f"Successfully loaded model: {model_file}")

        # `sequences` comes from the process_video cell above.
        labels = predict(model, sequences)
        print(f"Predictions for {model_file}:")
        # Count each class explicitly so the summary stays correct even if a
        # model ever predicts a class index other than 0 or 1.
        print(f"0s: {int(np.sum(labels == 0))}")
        print(f"1s: {int(np.sum(labels == 1))}")
        print("------------------------------------------------------------")
    except Exception as e:
        # Notebook-level boundary: report and move on to the next checkpoint.
        print(f"Error loading or predicting with model {model_file}: {str(e)}")

  loaded = torch.load(model_path, map_location=torch.device('cpu'))


Successfully loaded model: v1_bidirectional_lstm.pth
Predictions for v1_bidirectional_lstm.pth:
0s: 809
1s: 0
------------------------------------------------------------
Successfully loaded model: v1_logistic_regression.pth
Predictions for v1_logistic_regression.pth:
0s: 809
1s: 0
------------------------------------------------------------
Successfully loaded model: v1_lstm.pth
Predictions for v1_lstm.pth:
0s: 809
1s: 0
------------------------------------------------------------
Successfully loaded model: v1_mlp.pth
Predictions for v1_mlp.pth:
0s: 809
1s: 0
------------------------------------------------------------
Successfully loaded model: v2_bidirectional_lstm.pth
Predictions for v2_bidirectional_lstm.pth:
0s: 809
1s: 0
------------------------------------------------------------
Successfully loaded model: v2_logistic_regression.pth
Predictions for v2_logistic_regression.pth:
0s: 809
1s: 0
------------------------------------------------------------
Successfully loaded model: v

In [8]:
# NOTE(review): `labels` is presumably the predictions of the last model run in
# the evaluation loop, not ground truth, so "Positive" means "predicted swing".
plot_sample_sequences(sequences=sequences, labels=labels, num_samples=4)

In [9]:
from golf_swing_plotter import GolfSwingPlotter

# Inputs for GolfSwingPlotter: the test clip's directory, the directory holding
# its detection predictions, and the clip's file name.
video_dir = "input_videos/test"
prediction_dir = "predictions/test"
video_file = "IMG_3517.MOV"
plotter = GolfSwingPlotter(video_file, video_dir, prediction_dir)

In [16]:
# Sanity-check that OpenCV can open the test clip.
import os

import cv2  # used here but not imported anywhere else in the visible notebook

video_path = "input_videos/test/IMG_3517.MOV"
cap = cv2.VideoCapture(video_path)

if not cap.isOpened():
    print("Error: Could not open video.")
    # Distinguish a wrong path from a file OpenCV cannot decode (e.g. a .MOV
    # codec unsupported by the installed OpenCV/FFmpeg build).
    if not os.path.exists(video_path):
        print(f"  File not found: {os.path.abspath(video_path)}")
    else:
        print(f"  File exists but OpenCV could not open/decode it: {video_path}")
    cap.release()


In [14]:
# Render the combined plot for frames 1000-3000 with num_frames=10.
plotter.combined_plot(start_frame=1000, end_frame=3000, num_frames=10)

Error: Could not open video.


AttributeError: 'NoneType' object has no attribute 'number'