In [65]:
import json
import numpy as np
from collections import defaultdict
from scipy import interpolate

def load_jsonl(file_path):
    with open(file_path, 'r') as f:
        return [json.loads(line) for line in f]

def load_swing_intervals(file_path):
    with open(file_path, 'r') as f:
        return json.load(f)

def interpolate_missing(frames_dict, max_frame):
    frames = np.arange(max_frame + 1)
    existing_frames = np.array(list(frames_dict.keys()))
    existing_values = np.array(list(frames_dict.values()))
    
    # Perform linear interpolation for each feature
    interpolated = []
    for i in range(8):  # 8 features
        feature_values = existing_values[:, i]
        interp_func = interpolate.interp1d(existing_frames, feature_values, kind='linear', 
                                           bounds_error=False, fill_value='extrapolate')
        interpolated.append(interp_func(frames))
    
    return np.array(interpolated).T

def process_video(predictions_file, swing_intervals_file, sequence_length=64, overlap=32):
    predictions = load_jsonl(predictions_file)
    swing_intervals = load_swing_intervals(swing_intervals_file)
    
    # Extract image dimensions from the first prediction
    img_width = predictions[0]['image']['width']
    img_height = predictions[0]['image']['height']
    
    # Process predictions
    processed_frames = {}
    for frame, pred in enumerate(predictions):

        frame_data = {}
        for p in pred['predictions']:
            if p['class'] in ['club', 'club_head']:
                frame_data[p['class']] = [
                    p['x'] / img_width,
                    p['y'] / img_height,
                    p['width'] / img_width,
                    p['height'] / img_height
                ]
        
        # Only add frame data if both club and club_head are detected
        if 'club' in frame_data and 'club_head' in frame_data:
            processed_frames[frame] = [
                *frame_data['club'],
                *frame_data['club_head']
            ]
    
    # Determine the total number of frames
    max_frame = max(processed_frames.keys())
    
    # Interpolate missing frames
    interpolated_frames = interpolate_missing(processed_frames, max_frame)
    
    # Create sequences
    sequences = []
    labels = []
    for i in range(0, len(interpolated_frames) - sequence_length + 1, overlap):
        seq = interpolated_frames[i:i+sequence_length]
        sequences.append(seq)
        
        # Label the sequence
        seq_mid = i + sequence_length // 2
        is_swing = any(start <= seq_mid < end for start, end in swing_intervals)
        labels.append(1 if is_swing else 0)
    
    return np.array(sequences), np.array(labels)


# The 8 dimensions represent:
# 1. Club x-coordinate (normalized)
# 2. Club y-coordinate (normalized)
# 3. Club width (normalized)
# 4. Club height (normalized)
# 5. Club head x-coordinate (normalized)
# 6. Club head y-coordinate (normalized)
# 7. Club head width (normalized)
# 8. Club head height (normalized)


In [76]:
# Usage
video_name = "IMG_3515"
sequences, labels = process_video(
    f"predictions/{video_name}_predictions.jsonl",
    f"swing_intervals/{video_name}_swing_intervals.json"
)

print(f"Sequences shape: {sequences.shape}")
print(f"Labels shape: {labels.shape}")
print(f"Positive samples: {sum(labels)}")
assert np.all((sequences >= 0) & (sequences <= 1)), "Some values are not between 0 and 1"

Sequences shape: (235, 64, 8)
Labels shape: (235,)
Positive samples: 23


AssertionError: Some values are not between 0 and 1

In [78]:
sequences[10:15]

array([[[0.55648148, 0.73411458, 0.21481481, ..., 0.77291667,
         0.02314815, 0.0125    ],
        [0.55787037, 0.73203125, 0.21944444, ..., 0.76953125,
         0.02222222, 0.01302083],
        [0.55972222, 0.73203125, 0.21944444, ..., 0.7671875 ,
         0.02407407, 0.01354167],
        ...,
        [0.52962963, 0.75729167, 0.15      , ..., 0.8125    ,
         0.02407407, 0.01354167],
        [0.52916667, 0.75755208, 0.14907407, ..., 0.81354167,
         0.02314815, 0.0125    ],
        [0.52824074, 0.75807292, 0.14722222, ..., 0.81432292,
         0.02314815, 0.01197917]],

       [[0.52592593, 0.75963542, 0.1462963 , ..., 0.81640625,
         0.02407407, 0.0109375 ],
        [0.52546296, 0.75963542, 0.14537037, ..., 0.81666667,
         0.02407407, 0.01145833],
        [0.52546296, 0.75989583, 0.14537037, ..., 0.81666667,
         0.02314815, 0.01145833],
        ...,
        [0.56435185, 0.74296875, 0.18425926, ..., 0.79270833,
         0.02407407, 0.01041667],
        [0.5

In [63]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import random

def plot_sample_sequences(sequences, labels, num_samples=3):
    # Get indices of positive and negative samples
    positive_indices = np.where(labels == 1)[0]
    negative_indices = np.where(labels == 0)[0]
    
    # Randomly sample from positive and negative sequences
    sample_positive = random.sample(list(positive_indices), min(num_samples, len(positive_indices)))
    sample_negative = random.sample(list(negative_indices), min(num_samples, len(negative_indices)))
    
    # Create subplots: 2 rows (positive/negative) x num_samples columns x 2 sub-rows (x/y coordinates)
    fig = make_subplots(rows=4, cols=num_samples, 
                        subplot_titles=(['Positive Samples']*num_samples + ['Negative Samples']*num_samples),
                        vertical_spacing=0.1,
                        row_heights=[0.23, 0.23, 0.23, 0.23])
    
    def plot_sequence(seq, start_row, col):
        frames = np.arange(len(seq))
        
        # Plot x coordinates
        fig.add_trace(go.Scatter(x=frames, y=seq[:, 0], mode='lines+markers', name='Club X', 
                                 line=dict(color='blue'), showlegend=start_row==1 and col==1), 
                      row=start_row, col=col)
        fig.add_trace(go.Scatter(x=frames, y=seq[:, 4], mode='lines+markers', name='Club Head X', 
                                 line=dict(color='red'), showlegend=start_row==1 and col==1), 
                      row=start_row, col=col)
        
        # Plot y coordinates
        fig.add_trace(go.Scatter(x=frames, y=seq[:, 1], mode='lines+markers', name='Club Y', 
                                 line=dict(color='blue'), showlegend=False), 
                      row=start_row+1, col=col)
        fig.add_trace(go.Scatter(x=frames, y=seq[:, 5], mode='lines+markers', name='Club Head Y', 
                                 line=dict(color='red'), showlegend=False), 
                      row=start_row+1, col=col)
    
    # Plot positive samples
    for i, idx in enumerate(sample_positive):
        plot_sequence(sequences[idx], 1, i+1)
    
    # Plot negative samples
    for i, idx in enumerate(sample_negative):
        plot_sequence(sequences[idx], 3, i+1)
    
    # Update layout
    fig.update_layout(height=1200, width=1200, title_text="Sample Sequences: Positive vs Negative")
    fig.update_xaxes(title_text="Frame Number")
    fig.update_yaxes(title_text="Normalized Coordinate", range=[0, 1])
    
    # Add y-axis titles
    for i in range(1, num_samples + 1):
        fig.update_yaxes(title_text="X Coordinate", row=1, col=i)
        fig.update_yaxes(title_text="Y Coordinate", row=2, col=i)
        fig.update_yaxes(title_text="X Coordinate", row=3, col=i)
        fig.update_yaxes(title_text="Y Coordinate", row=4, col=i)
    
    # Show the plot
    fig.show()


In [56]:
sequences[0]

array([[-1.96388889,  3.69713542,  1.18703704,  4.89114583, -1.69814815,
         6.21927083,  0.33703704, -0.15625   ],
       [-1.94953704,  3.68072917,  1.18055556,  4.86354167, -1.68518519,
         6.18854167,  0.33518519, -0.15520833],
       [-1.93518519,  3.66432292,  1.17407407,  4.8359375 , -1.67222222,
         6.1578125 ,  0.33333333, -0.15416667],
       [-1.92083333,  3.64791667,  1.16759259,  4.80833333, -1.65925926,
         6.12708333,  0.33148148, -0.153125  ],
       [-1.90648148,  3.63151042,  1.16111111,  4.78072917, -1.6462963 ,
         6.09635417,  0.32962963, -0.15208333],
       [-1.89212963,  3.61510417,  1.15462963,  4.753125  , -1.63333333,
         6.065625  ,  0.32777778, -0.15104167],
       [-1.87777778,  3.59869792,  1.14814815,  4.72552083, -1.62037037,
         6.03489583,  0.32592593, -0.15      ],
       [-1.86342593,  3.58229167,  1.14166667,  4.69791667, -1.60740741,
         6.00416667,  0.32407407, -0.14895833],
       [-1.84907407,  3.56588542

In [44]:
# Usage
plot_sample_sequences(sequences, labels, num_samples=6)