<a href="https://colab.research.google.com/github/yusuke-satani/swing_classification/blob/main/LSTM_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import cv2
import numpy as np
import os

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset and the saved model are
# read from paths under this mount point later in the notebook.
drive.mount('/content/drive')

In [None]:
import numpy as np
# NOTE(security): allow_pickle=True lets np.load unpickle arbitrary Python
# objects stored in the file, which can execute code on load. Only use this
# with a file you produced yourself or otherwise trust.
# Each entry of all_data is indexed below as entry[0] (sequence) and
# entry[1] (label).
all_data = np.load('/content/drive/MyDrive/all_data.npy', allow_pickle=True)

In [None]:
def pad_sequences(sequences, max_length):
    """Force every sequence to exactly ``max_length`` steps along axis 0.

    Sequences longer than ``max_length`` are truncated to their first
    ``max_length`` steps; shorter ones are zero-padded at the end. Only the
    first axis is changed, so sequences of any rank >= 1 are accepted
    (e.g. ``(frames, keypoints, coords)`` or already-flattened
    ``(frames, features)``).

    Args:
        sequences: Iterable of array-likes whose first axis is the step axis.
        max_length: Target length of the first axis.

    Returns:
        A list with one ``numpy.ndarray`` per input sequence, each having
        ``shape[0] == max_length``.
    """
    aligned = []
    for seq in sequences:
        seq = np.asarray(seq)
        missing = max_length - len(seq)
        if missing < 0:
            aligned.append(seq[:max_length])
            continue
        # Pad only at the end of the step axis; leave every other axis as is.
        pad_width = [(0, missing)] + [(0, 0)] * (seq.ndim - 1)
        aligned.append(np.pad(seq, pad_width, 'constant'))
    return aligned

# Split each record of all_data into its sequence (index 0) and its label
# (index 1), keeping the two lists in the same order.
X = [record[0] for record in all_data]
y = [record[1] for record in all_data]

# Truncate / zero-pad every sequence to 100 steps so they stack into one array.
X_padded = pad_sequences(X, 100)

X_array = np.array(X_padded)
y_array = np.array(y)

# Label i in y_array selects row i of the identity matrix, i.e. a one-hot
# vector with one column per entry of shot_types.
shot_types = [
    'forehand_stroke',
    'forehand_volley',
    'forehand_slice',
    'backhand_stroke',
    'backhand_volley',
    'backhand_slice',
]
y_onehot = np.identity(len(shot_types))[y_array]

In [None]:
import numpy as np
from tensorflow.keras.utils import Sequence
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, Dense, Dropout, AdditiveAttention, Concatenate
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

class DataGenerator(Sequence):
    """Keras ``Sequence`` that serves in-memory arrays in fixed-size batches.

    Batches are contiguous slices taken in index order; the final batch may
    be smaller than ``batch_size`` when the sample count is not a multiple
    of it.

    Args:
        x_set: Sliceable container of inputs (sliced along its first axis).
        y_set: Sliceable container of targets, aligned with ``x_set``.
        batch_size: Number of samples per batch.
        **kwargs: Forwarded to the Keras base-class constructor (for
            example ``workers`` or ``use_multiprocessing`` on Keras 3).
    """

    def __init__(self, x_set, y_set, batch_size, **kwargs):
        # The Keras 3 base class (PyDataset) expects its constructor to run;
        # without this call Keras emits a warning and base attributes are
        # not initialised. With no kwargs this is also valid on older Keras.
        super().__init__(**kwargs)
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of batches, counting a trailing partial batch."""
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        """Return the ``(inputs, targets)`` pair for batch number ``idx``."""
        start = idx * self.batch_size
        stop = start + self.batch_size
        return self.x[start:stop], self.y[start:stop]

def create_lstm_attention_model(sequence_length, num_keypoints, num_coords, num_classes):
    """Build and compile the stacked LSTM + additive-attention classifier.

    The network takes inputs of shape
    ``(sequence_length, num_keypoints * num_coords)`` and runs two
    sequence-returning LSTM stages (128 then 64 units). After each stage,
    self-attention over the stage output is concatenated back onto that
    output. A final 32-unit LSTM collapses the sequence, followed by two
    Dense/Dropout pairs and a softmax layer with ``num_classes`` outputs.

    Args:
        sequence_length: Number of time steps per input sequence.
        num_keypoints: Keypoints per step; multiplied by ``num_coords`` to
            get the per-step feature width.
        num_coords: Coordinates per keypoint.
        num_classes: Width of the softmax output layer.

    Returns:
        A compiled Keras ``Model`` (Adam with learning rate 0.0001,
        categorical cross-entropy loss, accuracy metric).
    """
    feature_dim = num_keypoints * num_coords
    inputs = Input(shape=(sequence_length, feature_dim))

    # Recurrent stages: LSTM -> self-attention -> concat(LSTM output, attention).
    hidden = inputs
    for units in (128, 64):
        hidden = LSTM(units, return_sequences=True)(hidden)
        attended = AdditiveAttention()([hidden, hidden])
        hidden = Concatenate()([hidden, attended])

    # Last LSTM returns only its final state, dropping the time axis.
    hidden = LSTM(32)(hidden)

    # Classification head: (Dense, Dropout) pairs, then softmax.
    for width, drop_rate in ((128, 0.5), (64, 0.3)):
        hidden = Dense(width, activation='relu')(hidden)
        hidden = Dropout(drop_rate)(hidden)
    outputs = Dense(num_classes, activation='softmax')(hidden)

    model = Model(inputs, outputs)
    model.compile(
        optimizer=Adam(learning_rate=0.0001),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model

# Hold out 20% for testing, then 20% of the remainder for validation
# (roughly 64% / 16% / 20%). The fixed seed keeps the splits reproducible.
X_train, X_test, y_train, y_test = train_test_split(X_array, y_onehot, test_size=0.2, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)

# Flatten everything after the time axis into one feature vector per step,
# matching the (sequence_length, num_keypoints * num_coords) model input.
X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], -1)
X_val = X_val.reshape(X_val.shape[0], X_val.shape[1], -1)
X_test = X_test.reshape(X_test.shape[0], X_test.shape[1], -1)

train_generator = DataGenerator(X_train, y_train, batch_size=16)
val_generator = DataGenerator(X_val, y_val, batch_size=16)
test_generator = DataGenerator(X_test, y_test, batch_size=16)

sequence_length = X_train.shape[1]
num_keypoints = 17
num_coords = 2
# Take the class count from the one-hot width so it cannot drift from the labels.
num_classes = y_train.shape[1]
model = create_lstm_attention_model(sequence_length, num_keypoints, num_coords, num_classes)

# Balanced class weights. compute_class_weight returns one weight per entry
# of `classes`, in that order, so key the dict by the actual class labels.
# Keying by position (enumerate) would assign weights to the wrong classes
# whenever a class is missing from the training split.
train_labels = np.argmax(y_train, axis=1)
present_classes = np.unique(train_labels)
class_weights = compute_class_weight('balanced', classes=present_classes, y=train_labels)
class_weight_dict = {
    int(label): float(weight)
    for label, weight in zip(present_classes, class_weights)
}

# Stop once validation loss has not improved for 10 epochs and roll back to
# the best weights seen.
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=100,
    callbacks=[early_stopping],
    class_weight=class_weight_dict
)

test_loss, test_accuracy = model.evaluate(test_generator)
print(f"Test accuracy: {test_accuracy:.4f}")


In [None]:
# Save to the mounted Drive location that load_model() reads from later in
# this notebook. A relative path would land in the runtime's working
# directory, so the loading cell would pick up a different (or missing) file.
model.save('/content/drive/MyDrive/swing_class.h5')

In [None]:
from tensorflow.keras.models import load_model

# Restore the saved model from Google Drive (requires the Drive mount).
loaded_model = load_model('/content/drive/MyDrive/swing_class.h5')

# Run inference over every batch of the test generator; the result holds one
# row of model outputs per test sample.
predictions = loaded_model.predict(test_generator)


In [None]:
# Index of the highest-scoring output column for each row of predictions.
np.argmax(predictions, axis=1)