In [1]:
# from google.colab import drive
# drive.mount('/content/drive')

Get hit type

In [16]:
import json

with open('./hit_types.json', 'r') as file:
    hit_types = json.load(file)

hit_types

['clear', 'drive', 'drop', 'lob', 'net shot', 'push/rush', 'smash']

In [17]:
import numpy as np

def idx_to_onehot(idx, num_classes):
    """
    Converts an index to a one-hot encoded vector with the first position reserved for null.
    
    Parameters:
    - idx: The index to convert, with 0 being the first actual category.
    - num_classes: The total number of categories excluding the null category.
    
    Returns:
    - A one-hot encoded vector with size (num_classes + 1) to include the null category.
    """
    # Initialize a vector of zeros with length num_classes + 1 (for the null category)
    onehot = np.zeros(num_classes + 1)
    # Increment the idx by 1 to reserve the first position for null and set the appropriate position to 1
    onehot[idx + 1] = 1
    return onehot


In [18]:
import torch
from torch.utils.data import Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import pandas as pd
from datetime import datetime
import pandas as pd
import torch
from torch.utils.data import Dataset
import numpy as np
from torch.utils.data import DataLoader
import os

'''
human keypoints: 17*2
court keypoints: 6*2
net keypoints:   4*2
ball keypoints:  1*2
total keypoints: 28*2

'''

class HitDataset(Dataset):
    def __init__(self, dataset_folder,num_consecutive_frames,normalization=True):
        self.dataset=[]
        self.positive=0
        self.negative=0
        # 遍历文件夹及其子文件夹，找到所有的CSV文件路径
        cnt=0
        for root, dirs, files in os.walk(dataset_folder):
            for file in files:
                if file.endswith(".csv"):
                    # 定义处理函数
                    data_path=os.path.join(root, file)
                    print(data_path)

                    try:
                      df = pd.read_csv(data_path, converters={"ball": eval,"top":eval,"bottom":eval,"court":eval,"net":eval})
                    except:
                      print('Error! cannot process: ', data_path)
                      continue

                    rows = len(df)
                    remainder = rows % num_consecutive_frames
                    if remainder > 0:
                        num_to_pad = num_consecutive_frames - remainder
                    else:
                        num_to_pad = 0

                    if num_to_pad > 0:
                        last_row = df.iloc[-1]
                        padding_data = np.tile(last_row.values, (num_to_pad, 1))
                        padded_df = pd.DataFrame(padding_data, columns=df.columns)
                        df = pd.concat([df, padded_df], axis=0)
                        df = df.reset_index(drop=True)

                    small_dataset =df

                    for i in range(len(small_dataset)):
                        if i%num_consecutive_frames!=0:
                            continue
                        if i>=len(small_dataset)-num_consecutive_frames:
                            break
                        oridata=small_dataset.loc[i:i+num_consecutive_frames-1,:].copy()
                        oridata=oridata.reset_index(drop=True)
                        data=[]

                        target1=None

                        for index, row in oridata.iterrows():
                            pos=np.array(row['pos'])
                            hit = row["type"]
                            if hit in hit_types:
                                if str(pos)=='top':
                                    if target1 is None:
                                        target1 = idx_to_onehot(hit_types.index(hit), len(hit_types))
                                        self.positive+=1
                                elif str(pos)=='bottom':
                                    if target1 is None:
                                        target1= idx_to_onehot(hit_types.index(hit), len(hit_types))
                                        self.positive+=1


                            top=np.array(row['top']).reshape(-1,2)
                            bottom=np.array(row['bottom']).reshape(-1,2)
                            court=np.array(row['court']).reshape(-1,2)
                            ball=np.array(row['ball']).reshape(-1,2)

                            frame_data = np.concatenate((top, bottom, court, ball), axis=0)

                            if normalization:
                                frame_data[:,0]/=1920
                                frame_data[:,1]/=1080
                            data.append(frame_data.reshape(1,-1))
                        data=np.array(data)
                        if target1 is None:
                            if self.negative>self.positive:
                                continue
                            target1=[1] + ([0]*len(hit_types))
                            self.negative+=1
                        self.dataset.append((data.reshape(-1),target1))
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        # 假设每个样本是一个元组 (input, target)
        sample = self.dataset[index]
        input_data = sample[0]
        target1 = sample[1]

        # 转换为Tensor对象
        input_tensor = torch.tensor(input_data)
        target1_tensor = torch.tensor(target1)

        return input_tensor, target1_tensor


num_consecutive_frames=30
batch_size=30
shuffle=True
normalization=True

TrainDataset=HitDataset("./train",num_consecutive_frames,normalization)
ValidDataset=HitDataset("./valid",num_consecutive_frames,normalization)

./train\Akane_YAMAGUCHI_AN_Se_Young_DAIHATSU_YONEX_Japan_Open_2022_Finals\rally_1-1.csv
./train\Akane_YAMAGUCHI_AN_Se_Young_DAIHATSU_YONEX_Japan_Open_2022_Finals\rally_1-10.csv
./train\Akane_YAMAGUCHI_AN_Se_Young_DAIHATSU_YONEX_Japan_Open_2022_Finals\rally_1-12.csv
./train\Akane_YAMAGUCHI_AN_Se_Young_DAIHATSU_YONEX_Japan_Open_2022_Finals\rally_1-15.csv
./train\Akane_YAMAGUCHI_AN_Se_Young_DAIHATSU_YONEX_Japan_Open_2022_Finals\rally_1-19.csv
./train\Akane_YAMAGUCHI_AN_Se_Young_DAIHATSU_YONEX_Japan_Open_2022_Finals\rally_1-20.csv
./train\Akane_YAMAGUCHI_AN_Se_Young_DAIHATSU_YONEX_Japan_Open_2022_Finals\rally_1-21.csv
./train\Akane_YAMAGUCHI_AN_Se_Young_DAIHATSU_YONEX_Japan_Open_2022_Finals\rally_1-24.csv
./train\Akane_YAMAGUCHI_AN_Se_Young_DAIHATSU_YONEX_Japan_Open_2022_Finals\rally_1-25.csv
./train\Akane_YAMAGUCHI_AN_Se_Young_DAIHATSU_YONEX_Japan_Open_2022_Finals\rally_1-26.csv
./train\Akane_YAMAGUCHI_AN_Se_Young_DAIHATSU_YONEX_Japan_Open_2022_Finals\rally_1-27.csv
./train\Akane_YAMAGUCH

In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ShotTypeModel(nn.Module):
    def __init__(self, feature_dim, num_consecutive_frames, num_classes):
        super(ShotTypeModel, self).__init__()
        self.num_consecutive_frames = num_consecutive_frames
        self.feature_dim = feature_dim

        # Change GRU to LSTM
        self.lstm1 = nn.LSTM(feature_dim // num_consecutive_frames, 64, bidirectional=True, batch_first=True)
        self.lstm2 = nn.LSTM(128, 64, bidirectional=True, batch_first=True)
        self.global_maxpool = nn.MaxPool1d(num_consecutive_frames)
        self.dense = nn.Linear(128, num_classes)  # Output layer for binary classification

    def forward(self, x):
        batch_size = x.shape[0]
        x = x.float()
        # Reshape input data
        x = x.view(batch_size, self.num_consecutive_frames, self.feature_dim // self.num_consecutive_frames)
        # Apply LSTM layers
        x, _ = self.lstm1(x)
        x, _ = self.lstm2(x)
        # Apply global max pooling and dense layer
        x = x.transpose(1, 2)
        x = self.global_maxpool(x).squeeze()
        x = self.dense(x)
        return x  # Output is now logits without softmax

feature_dim=82*num_consecutive_frames
print(TrainDataset.positive,TrainDataset.negative)
print(ValidDataset.positive,ValidDataset.negative)

train_data_loader = DataLoader(TrainDataset, batch_size=batch_size, shuffle=shuffle)
valid_data_loader = DataLoader(ValidDataset, batch_size=batch_size, shuffle=shuffle)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model=ShotTypeModel(feature_dim, num_consecutive_frames, len(hit_types)+1)
criterion = nn.CrossEntropyLoss()
model.to(device)
criterion.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 100

11625 3746
741 266


In [20]:
train_loss_list = []
valid_loss_list=[]
best_acc = 0
for epoch in range(num_epochs):

    train_loss_sum=0
    model.train()
    for batch_data in train_data_loader:
        inputs, labels = batch_data
        inputs = inputs.to(device)  # 将输入数据移动到设备上
        labels = labels.to(device).float()  # 将输入数据移动到设备上

        outputs = model(inputs)
        outputs=outputs.reshape(-1, len(hit_types)+1)

        train_loss = criterion(outputs,labels)
        train_loss_sum+=train_loss.detach()
        # 反向传播和优化
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
        # 执行你的训练或测试操作
    train_loss_list.append(train_loss_sum)


    # 打印训练信息
    if (epoch + 1) % 1== 0:
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs,
                                                    train_loss.item()))
    with torch.no_grad():
        correct = 0
        total = 0
        for inputs, labels in valid_data_loader:
            inputs = inputs.to(device)  # 将输入数据移动到设备上
            labels= labels.to(device).float()
            # 前向传播
            outputs = model(inputs)
            outputs=outputs.reshape(-1, len(hit_types)+1)


            y_true=torch.argmax(labels,axis=1)
            y_pred=torch.argmax(outputs,axis=1)
            total+=len(y_true)
            correct+=(y_true==y_pred).sum().item()
        if total==0:
            print(f'Accuracy on test set: {0}')
            continue



        print(f'Accuracy on test set: {correct/total}')

        if correct/total > best_acc:
            best_acc = correct/total
            torch.save(model, './shot_detect.pth')


Epoch [1/100], Loss: 1.5031
Accuracy on test set: 0.407149950347567
Epoch [2/100], Loss: 1.7282
Accuracy on test set: 0.5471698113207547
Epoch [3/100], Loss: 0.8852
Accuracy on test set: 0.548162859980139
Epoch [4/100], Loss: 1.1773
Accuracy on test set: 0.6236345580933466
Epoch [5/100], Loss: 0.6306
Accuracy on test set: 0.6673286991062563
Epoch [6/100], Loss: 1.2748
Accuracy on test set: 0.6355511420059583
Epoch [7/100], Loss: 0.6766
Accuracy on test set: 0.6564051638530288
Epoch [8/100], Loss: 1.7603
Accuracy on test set: 0.6643495531281033
Epoch [9/100], Loss: 1.6043
Accuracy on test set: 0.5938430983118173
Epoch [10/100], Loss: 1.0141
Accuracy on test set: 0.6941410129096326
Epoch [11/100], Loss: 0.8396
Accuracy on test set: 0.7269116186693148
Epoch [12/100], Loss: 0.7918
Accuracy on test set: 0.7219463753723933
Epoch [13/100], Loss: 1.1260
Accuracy on test set: 0.7080436941410129
Epoch [14/100], Loss: 0.9820
Accuracy on test set: 0.7636544190665343
Epoch [15/100], Loss: 0.9160
Ac