In [1]:
import json
import numpy as np
from collections import defaultdict
from scipy import interpolate

def load_jsonl(file_path):
    """Read a JSON-lines file and return its records as a list, in file order.

    Whitespace-only lines (e.g. a trailing empty line) are skipped instead of
    raising ``json.JSONDecodeError``. The file is decoded as UTF-8, the encoding
    JSON text is required to use, rather than the platform default.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

def load_swing_intervals(file_path):
    """Return the decoded JSON content of a swing-intervals file."""
    with open(file_path, 'r') as handle:
        intervals = json.load(handle)
    return intervals

def interpolate_missing(frames_dict, max_frame):
    """Fill every frame index ``0..max_frame`` by per-feature linear interpolation.

    Args:
        frames_dict: mapping ``frame_index -> feature vector``. All vectors must have
            the same length; keys do not need to be sorted.
        max_frame: last frame index to produce (inclusive).

    Returns:
        Array of shape ``(max_frame + 1, n_features)``. Values are handled as follows:
        - between known frames: linearly interpolated;
        - outside the known range: linearly extrapolated;
        - everything is then clipped to ``[0, 1]``.
        With a single known frame, its vector is repeated for every frame.

    Raises:
        ValueError: if ``frames_dict`` is empty.
    """
    if not frames_dict:
        raise ValueError("frames_dict must contain at least one frame")

    frames = np.arange(max_frame + 1)
    existing_frames = np.array(list(frames_dict.keys()))
    existing_values = np.asarray(list(frames_dict.values()), dtype=float)

    if len(existing_frames) == 1:
        # interp1d needs at least two samples; a lone observation is held constant.
        filled = np.repeat(existing_values, len(frames), axis=0)
    else:
        # axis=0 interpolates every feature column in one call, whatever their number.
        interp_func = interpolate.interp1d(existing_frames, existing_values, kind='linear',
                                           axis=0, bounds_error=False,
                                           fill_value='extrapolate')
        filled = interp_func(frames)

    return np.clip(filled, 0, 1)  # Clamp values to [0, 1]

def process_video(predictions_file, swing_intervals_file, sequence_length=64, overlap=32):
    """Turn one video's per-frame detections into fixed-length labelled windows.

    Processing steps:
    1. Each record of ``predictions_file`` is one frame (its line index is the frame
       index).
    2. A frame where both a ``club`` and a ``club_head`` box are present becomes an
       8-feature vector: club x, y, width, height, then club-head x, y, width, height.
       x/width are divided by the image width and y/height by the image height.
    3. Frames lacking either box are filled by ``interpolate_missing``.
    4. The frame series is cut into windows of ``sequence_length`` frames that share
       ``overlap`` frames with their neighbour.
    5. A window is labelled 1 when its middle frame lies in any half-open
       ``[start, end)`` swing interval.

    Args:
        predictions_file: JSON-lines file; each record has ``image.width``,
            ``image.height`` and a ``predictions`` list of boxes.
        swing_intervals_file: JSON file holding ``(start, end)`` frame pairs.
        sequence_length: number of frames per window.
        overlap: frames shared by consecutive windows; the window stride is
            ``sequence_length - overlap``.

    Returns:
        ``(sequences, labels)`` with shapes ``(n_windows, sequence_length, 8)`` and
        ``(n_windows,)``. If the series is shorter than one window, both arrays are
        empty but keep those trailing dimensions so they still concatenate.

    Raises:
        ValueError: in any of these cases:
            - ``overlap`` is not smaller than ``sequence_length``;
            - the predictions file is empty;
            - no frame has both detections.
    """
    stride = sequence_length - overlap
    if stride <= 0:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than sequence_length ({sequence_length})"
        )

    predictions = load_jsonl(predictions_file)
    swing_intervals = load_swing_intervals(swing_intervals_file)
    if not predictions:
        raise ValueError(f"No predictions found in {predictions_file}")

    # Image size is taken from the first record and applied to every frame.
    img_width = predictions[0]['image']['width']
    img_height = predictions[0]['image']['height']
    scale = (img_width, img_height, img_width, img_height)

    processed_frames = {}
    for frame, pred in enumerate(predictions):
        boxes = {}
        for p in pred['predictions']:
            if p['class'] in ('club', 'club_head'):
                # A later box of the same class in a frame overwrites an earlier one.
                raw = (p['x'], p['y'], p['width'], p['height'])
                boxes[p['class']] = [value / size for value, size in zip(raw, scale)]

        # Only keep frames where both club and club_head are detected.
        if 'club' in boxes and 'club_head' in boxes:
            processed_frames[frame] = boxes['club'] + boxes['club_head']

    if not processed_frames:
        raise ValueError(
            f"No frame in {predictions_file} has both 'club' and 'club_head' detections"
        )

    # The series ends at the last frame with both detections; any frames after it
    # are not included in the output.
    max_frame = max(processed_frames)
    interpolated_frames = interpolate_missing(processed_frames, max_frame)

    half = sequence_length // 2
    starts = range(0, len(interpolated_frames) - sequence_length + 1, stride)
    sequences = [interpolated_frames[i:i + sequence_length] for i in starts]
    labels = [
        int(any(start <= i + half < end for start, end in swing_intervals))
        for i in starts
    ]

    if not sequences:
        n_features = interpolated_frames.shape[1]
        return np.empty((0, sequence_length, n_features)), np.empty(0, dtype=int)
    return np.array(sequences), np.array(labels)


# Layout of each 8-feature frame vector (column index: meaning). x and width are
# divided by the image width, y and height by the image height:
# 0. Club x-coordinate
# 1. Club y-coordinate
# 2. Club width
# 3. Club height
# 4. Club head x-coordinate
# 5. Club head y-coordinate
# 6. Club head width
# 7. Club head height

import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import random

def plot_sample_sequences(sequences, labels, num_samples=3, seed=None):
    """Plot randomly chosen positive and negative windows in a 4-row plotly grid.

    Grid layout:
    - rows 1-2 hold positive windows (label 1);
    - rows 3-4 hold negative windows (label 0);
    - there is one window per column.

    Within each pair of rows, the upper subplot plots feature columns 0 and 4
    (club / club-head x) and the lower one columns 1 and 5 (club / club-head y),
    against the frame index within the window.

    Args:
        sequences: indexable collection of 2-D windows (``sequences[i][:, feature]``).
        labels: 1-D labels aligned with ``sequences``; arrays or plain lists work.
        num_samples: windows drawn per class (fewer if a class is smaller); also
            the number of grid columns.
        seed: optional seed for a private RNG so the same windows are drawn on
            every call. ``None`` keeps using the global ``random`` module state.
    """
    # np.asarray makes `labels == 1` elementwise even when a plain list is passed.
    labels = np.asarray(labels)
    positive_indices = np.where(labels == 1)[0]
    negative_indices = np.where(labels == 0)[0]

    rng = random if seed is None else random.Random(seed)
    sample_positive = rng.sample(list(positive_indices), min(num_samples, len(positive_indices)))
    sample_negative = rng.sample(list(negative_indices), min(num_samples, len(negative_indices)))

    # make_subplots assigns titles row by row over all 4 * num_samples cells, so
    # the Y rows need blank placeholders to keep "Negative Samples" on row 3.
    blank_row = [''] * num_samples
    subplot_titles = (['Positive Samples'] * num_samples + blank_row
                      + ['Negative Samples'] * num_samples + blank_row)
    fig = make_subplots(rows=4, cols=num_samples,
                        subplot_titles=subplot_titles,
                        vertical_spacing=0.1,
                        row_heights=[0.23, 0.23, 0.23, 0.23])

    def plot_sequence(seq, start_row, col):
        """Draw one window: x features on `start_row`, y features on the row below."""
        frames = np.arange(len(seq))
        # Legend entries are only emitted once, from the top-left subplot.
        show_legend = start_row == 1 and col == 1

        # Plot x coordinates
        fig.add_trace(go.Scatter(x=frames, y=seq[:, 0], mode='lines+markers', name='Club X',
                                 line=dict(color='blue'), showlegend=show_legend),
                      row=start_row, col=col)
        fig.add_trace(go.Scatter(x=frames, y=seq[:, 4], mode='lines+markers', name='Club Head X',
                                 line=dict(color='red'), showlegend=show_legend),
                      row=start_row, col=col)

        # Plot y coordinates
        fig.add_trace(go.Scatter(x=frames, y=seq[:, 1], mode='lines+markers', name='Club Y',
                                 line=dict(color='blue'), showlegend=False),
                      row=start_row + 1, col=col)
        fig.add_trace(go.Scatter(x=frames, y=seq[:, 5], mode='lines+markers', name='Club Head Y',
                                 line=dict(color='red'), showlegend=False),
                      row=start_row + 1, col=col)

    # Plot positive samples
    for i, idx in enumerate(sample_positive):
        plot_sequence(sequences[idx], 1, i + 1)

    # Plot negative samples
    for i, idx in enumerate(sample_negative):
        plot_sequence(sequences[idx], 3, i + 1)

    # Update layout
    fig.update_layout(height=1200, width=1200, title_text="Sample Sequences: Positive vs Negative")
    fig.update_xaxes(title_text="Frame Number")
    fig.update_yaxes(title_text="Normalized Coordinate", range=[0, 1])

    # Add y-axis titles
    for i in range(1, num_samples + 1):
        fig.update_yaxes(title_text="X Coordinate", row=1, col=i)
        fig.update_yaxes(title_text="Y Coordinate", row=2, col=i)
        fig.update_yaxes(title_text="X Coordinate", row=3, col=i)
        fig.update_yaxes(title_text="Y Coordinate", row=4, col=i)

    # Show the plot
    fig.show()



In [3]:
# # Usage
# video_name = "IMG_3515"
# sequences, labels = process_video(
#     f"predictions/{video_name}_predictions.jsonl",
#     f"swing_intervals/{video_name}_swing_intervals.json"
# )

# print(f"Sequences shape: {sequences.shape}")
# print(f"Labels shape: {labels.shape}")
# print(f"Positive samples: {sum(labels)}")
# assert np.all((sequences >= 0) & (sequences <= 1)), "Some values are not between 0 and 1"

# # Usage
# plot_sample_sequences(sequences, labels, num_samples=4)

In [4]:
import os
import numpy as np

def combine_sequences(predictions_dir, swing_intervals_dir):
    """Build windows for every labelled video and stack them into one dataset.

    For each ``<video>_predictions.jsonl`` in ``predictions_dir``, the matching
    ``<video>_swing_intervals.json`` is looked up in ``swing_intervals_dir``.
    When it exists, the pair is passed to ``process_video``; videos without one
    are skipped. Files are visited in sorted order, so row order is reproducible.

    Returns:
        ``(sequences, labels)`` concatenated along axis 0 across all used videos.

    Raises:
        ValueError: if no predictions file has a matching swing-intervals file.
    """
    suffix = '_predictions.jsonl'
    sequences = []
    labels = []

    # sorted(): os.listdir order is filesystem-dependent.
    for file in sorted(os.listdir(predictions_dir)):
        if not file.endswith(suffix):
            continue

        # Strip only the trailing suffix to get the video name.
        video_name = file[:-len(suffix)]
        predictions_file = os.path.join(predictions_dir, file)
        swing_intervals_file = os.path.join(swing_intervals_dir, f'{video_name}_swing_intervals.json')

        # Unlabelled videos (no swing-intervals file) are skipped.
        if not os.path.exists(swing_intervals_file):
            continue

        seq, label = process_video(predictions_file, swing_intervals_file)
        sequences.append(seq)
        labels.append(label)

    if not sequences:
        raise ValueError(
            f"No '*{suffix}' file in {predictions_dir!r} has a matching "
            f"swing-intervals file in {swing_intervals_dir!r}"
        )

    # Combine the sequences and labels into a single array
    return np.concatenate(sequences, axis=0), np.concatenate(labels, axis=0)

# Pool the labelled windows of every video, then sanity-check the result.
predictions_dir, swing_intervals_dir = 'predictions', 'swing_intervals'

combined_sequences, combined_labels = combine_sequences(predictions_dir, swing_intervals_dir)

print(f"Combined sequences shape: {combined_sequences.shape}")
print(f"Combined labels shape: {combined_labels.shape}")
print(f"Positive samples: {combined_labels.sum()}")
# Every feature should have been clipped into [0, 1].
assert np.logical_and(combined_sequences >= 0, combined_sequences <= 1).all(), "Some values are not between 0 and 1"

Combined sequences shape: (1205, 64, 8)
Combined labels shape: (1205,)
Positive samples: 87


In [6]:
# plot_sample_sequences(sequences, labels, num_samples=5)

In [7]:
import numpy as np

def _standardize_rows(batch):
    """Flatten each item to a row, centre it, and scale it to unit L2 norm."""
    rows = np.asarray(batch).reshape(len(batch), -1).astype(np.float64)
    rows = rows - rows.mean(axis=1, keepdims=True)
    norms = np.linalg.norm(rows, axis=1, keepdims=True)
    with np.errstate(invalid='ignore', divide='ignore'):
        # Zero-variance rows become NaN (0/0), as np.corrcoef would report.
        return rows / norms


def correlation_sampling(positive_sequences, negative_sequences, num_samples=87):
    """Build a balanced set from the negatives that most resemble the positives.

    Selection rules:
    - every sequence is flattened;
    - each negative is scored by its highest Pearson correlation with any positive;
    - the ``num_samples`` best-scoring negatives are kept, each at most once;
    - all positives are appended after them.

    Args:
        positive_sequences: array of positive windows, shape ``(P, ...)``.
        negative_sequences: array of negative windows with the same trailing shape.
        num_samples: number of negatives to keep, capped at the number available.
            The default 87 equals the positive count of the dataset this was written
            for; pass ``len(positive_sequences)`` for a 1:1 balance on other data.

    Returns:
        ``(sequences, labels)``: the selected negatives (label 0) followed by every
        positive (label 1).
    """
    positive_sequences = np.asarray(positive_sequences)
    negative_sequences = np.asarray(negative_sequences)
    k = max(0, min(num_samples, len(negative_sequences)))

    if k and len(positive_sequences):
        # (P, N) Pearson correlations via one matrix product of standardized rows.
        corr = _standardize_rows(positive_sequences) @ _standardize_rows(negative_sequences).T
        # Undefined (NaN) correlations would otherwise sort to the top of a
        # descending ranking; rank them last instead.
        corr = np.where(np.isnan(corr), -np.inf, corr)
        # A negative's score is its best match with any positive. Ranking negatives
        # (not pairs) means no negative is selected twice.
        scores = corr.max(axis=0)
        chosen = np.argsort(-scores, kind='stable')[:k]
    else:
        chosen = np.empty(0, dtype=int)

    selected_negatives = negative_sequences[chosen]
    balanced_sequences = np.concatenate([selected_negatives, positive_sequences], axis=0)
    balanced_labels = np.concatenate([
        np.zeros(len(selected_negatives), dtype=int),
        np.ones(len(positive_sequences), dtype=int),
    ])
    return balanced_sequences, balanced_labels

# Split the pooled windows by label, then keep one correlated negative per positive.
positive_sequences = combined_sequences[combined_labels == 1]
negative_sequences = combined_sequences[combined_labels == 0]

# Pass the positive count explicitly instead of relying on a hard-coded default,
# so the 1:1 balance still holds when the dataset changes.
balanced_sequences, balanced_labels = correlation_sampling(
    positive_sequences, negative_sequences, num_samples=len(positive_sequences)
)

In [8]:
# Summary of the balanced set (labels previously said "Combined" by copy-paste).
print(f"Balanced sequences shape: {balanced_sequences.shape}")
print(f"Balanced labels shape: {balanced_labels.shape}")
print(f"Positive samples: {int(balanced_labels.sum())}")
assert np.all((balanced_sequences >= 0) & (balanced_sequences <= 1)), "Some values are not between 0 and 1"

Combined sequences shape: (174, 64, 8)
Combined labels shape: (174,)
Positive samples: 87


In [10]:
# plot_sample_sequences(balanced_sequences, balanced_labels, num_samples=10)

In [11]:
import torch
# Output directory for the balanced training set. exist_ok avoids the
# check-then-create race and keeps the cell safe to re-run.
data_dir = 'sequences'
os.makedirs(data_dir, exist_ok=True)

# Save windows and labels as float32 tensors (torch.float) in a single file.
data = {
    'X': torch.tensor(balanced_sequences, dtype=torch.float),
    'y': torch.tensor(balanced_labels, dtype=torch.float)
}

torch.save(data, os.path.join(data_dir, 'data_balanced.pt'))

In [12]:
# data = torch.load(os.path.join(data_dir, 'data_01.pt'))
# X = data['X']
# y = data['y']