In [1]:
import os

In [3]:
from torch.utils.data import Dataset

In [4]:
import struct, sys

# Ball prediction dataset

A PyTorch loader for basketball data.

In [2]:
# Root directory handed to BallPredictionDataset further down.
# NOTE(review): absolute, machine-specific path — adjust for the local environment.
b = "/tmp/stephan/bball_data/predict_action/01590"

In [91]:
# Each file contains a play. The XXXX_gt.csv file contains the ground truth of the ball trajectory, in the format x1, y1, x2, y2, ...

# In the XXXX.csv file, there are 7 columns, which are:
# Time
# Team ID
# X
# Y
# Z
# Player ID
# Action label

# A value of -1 in the team ID and player ID columns indicates the ball. Concatenate the locations that share a player ID to get that player's trajectory.
import numpy as np

def load_tuples(fn, num):
  """Read a binary file as consecutive records of `num` signed bytes.

  Args:
    fn: Path of the binary file to read.
    num: Number of signed 8-bit values that make up one record.

  Returns:
    A NumPy array with one row per record and `num` columns. An empty file
    yields an empty integer array of shape (0, num).

  Raises:
    struct.error: If the file length is not a multiple of `num`, i.e. the
      last record is truncated.
  """
  record_format = "{}b".format(num)
  records = []
  with open(fn, "rb") as f:
    # iter() with a sentinel keeps pulling fixed-size chunks until EOF (b"").
    for chunk in iter(lambda: f.read(num), b""):
      records.append(struct.unpack(record_format, chunk))
  if not records:
    # Preserve the column count so callers always get a 2-D result.
    return np.empty((0, max(num, 0)), dtype=int)
  return np.array(records)

def check_tuples(fn, num):
  """Load `num`-byte records from `fn` and report records with a positive first value.

  Prints the file name once, then one line per record whose first value is
  greater than zero, showing the file name, the 1-based record number and the
  record itself.

  Args:
    fn: Path of the binary file to read.
    num: Number of signed 8-bit values per record.

  Returns:
    A NumPy array built from all records that were read.
  """
  record_format = "{}b".format(num)
  records = []
  with open(fn, "rb") as f:
    print(fn)    
    # Read fixed-size chunks until EOF; numbering starts at 1.
    for index, chunk in enumerate(iter(lambda: f.read(num), b""), start=1):
      record = struct.unpack(record_format, chunk)
      records.append(record)
      if record[0] > 0:
        print(fn, index, record)
  return np.array(records)

class BallPredictionDataset(Dataset):
    """8 second plays at 25 Hz.

    Input are locations of 10 players, on a grid.
    Labels are the ball locations.
    Actions are noted only for first 4 seconds.

    Files in `root_dir` are split by name: names containing "label" hold the
    ball labels, names containing "action" hold the actions, and every other
    file holds player positions. Sample `idx` combines the idx-th file of
    each group.
    """
    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with the position, label and action
                files of all plays.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort so the three
        # per-play file lists line up by index and are reproducible.
        # NOTE(review): alignment assumes the three files of a play share a
        # fixed-width name prefix — confirm against the data layout.
        self.fn = sorted(os.listdir(root_dir))
        self.fn_x = [i for i in self.fn if not ("label" in i or "action" in i)]  # pos
        self.fn_y = [i for i in self.fn if "label" in i]  # ball
        self.fn_a = [i for i in self.fn if "action" in i]  # action
        assert len(self.fn_x) == len(self.fn_y), \
            "number of position files and label files differ"
        assert len(self.fn_x) == len(self.fn_a), \
            "number of position files and action files differ"

    def __len__(self):
        """Return the number of plays (one per position file)."""
        return len(self.fn_x)

    def __getitem__(self, idx):
        """Load the play at position `idx`.

        Returns:
            dict with keys 'x' (positions), 'y' (ball labels) and 'a'
            (actions), each the array returned by `load_tuples` for
            two-value records; passed through `transform` when one is set.
        """
        # The action file is read once here; use check_tuples directly on a
        # file when its positive-label diagnostics are wanted.
        _xs = load_tuples(os.path.join(self.root_dir, self.fn_x[idx]), 2)
        _ys = load_tuples(os.path.join(self.root_dir, self.fn_y[idx]), 2)
        _as = load_tuples(os.path.join(self.root_dir, self.fn_a[idx]), 2)

        sample = {'x': _xs, 'y': _ys, "a": _as}

        if self.transform:
            sample = self.transform(sample)

        return sample

In [92]:
# Build the ball-prediction dataset over the directory currently bound to `b`.
bd = BallPredictionDataset(b)

In [None]:
# Walk every play and show the shapes of its positions, labels and actions.
for i in range(len(bd)):
  sample = bd[i]
  print(sample["x"].shape, sample["y"].shape, sample["a"].shape, sep="\n")

# Raw tracking dataset

In [5]:
# Root of the raw tracking data; ShotsDataset reads its "features" and
# "groundtruth" subdirectories.
# NOTE(review): this rebinds `b`, which earlier pointed at the ball-prediction
# directory — cells using `b` depend on which assignment ran last.
b = "/cs/ml/datasets/bball/v1/bball_tracking"

In [112]:
# Each file contains a play. The XXXX_gt.csv file contains the ground truth of the ball trajectory, in the format x1, y1, x2, y2, ...

# In the XXXX.csv file, there are 7 columns, which are:
# Time
# Team ID
# X
# Y
# Z
# Player ID
# Action label

# A value of -1 in the team ID and player ID columns indicates the ball. Concatenate the locations that share a player ID to get that player's trajectory.
import numpy as np

# Byte width of one stored integer, used to turn integer offsets/counts into
# byte offsets/counts. Assumed to match the size of struct's native "i".
BYTES_PER_INT = 4

def load_ints(fn, num, pos=0):
  """Read `num` consecutive integers from a binary file.

  Args:
    fn: Path of the binary file.
    num: How many integers to read.
    pos: Offset of the first integer, counted in integers (not bytes) from
      the start of the file.

  Returns:
    A 1-D NumPy array holding the `num` values.

  Raises:
    struct.error: If fewer than `num` integers are available at `pos`.
  """
  with open(fn, "rb") as f:
    f.seek(pos * BYTES_PER_INT)
    raw = f.read(num * BYTES_PER_INT)
  values = struct.unpack("{}i".format(num), raw)
  return np.array(values)

class ShotsDataset(Dataset):
    """8 second plays at 25 Hz. 
    Input are locations of ballhandler on the court and defenders in grid around bh.
    Labels are whether bh shot ("strong") or not ("weak labels"). 
    
    Note that sequence structure is lost in this dataset: we can only train memory-less models
    on this data.
    
    There are 439 unique ballhandlers.
    
    Feature data bh: [serial_idx] position of ballhandler 
    Feature data def: [serial_idx0 ... serial_idx4]. Could be -1 if there is no defender there. At most 5 defenders (= size of team)
    Ground truth: [ballhandler_id frame_id action_label]
    
    Sample order: indices [0, num_neg) take their ground truth from the weak
    label file, indices [num_neg, num_neg + num_pos) from the strong label
    file. Feature files are indexed directly by the sample index.
    
    Grid dims:
    
    Ballhandler:
    10: 4x5
    5: 8x10 
    2: 20x25
    1: 40x50
    
    Defender: 
    2:  6x6 (+1 bias) 
    1:  12x12 (+1 bias)
    
    "NumberOfFrames"                              : 15740724,
    "NumberOfStrongLabels_Test"                   : 0,
    "NumberOfStrongLabels_Train"                  : 1160723,
    "NumberOfStrongLabels_TrainVal"               : 1289692,
    "NumberOfStrongLabels_Val"                    : 128969,
    "NumberOfWeakLabels_Test"                     : 0,
    "NumberOfWeakLabels_Train"                    : 13005929,
    "NumberOfWeakLabels_TrainVal"                 : 14451032,
    "NumberOfWeakLabels_Val"                      : 1445103
    """
    # Integers per ground-truth record: ballhandler_id, frame_id, action_label.
    _INTS_PER_LABEL = 3
    # Integers per defender feature record (one slot per possible defender).
    _INTS_PER_DEF = 5

    def __init__(self, root_dir, transform=None):
        """
        Args:            
            root_dir (string): Directory containing the "features" and
                "groundtruth" subdirectories.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        fp_feat = os.path.join(root_dir, "features")
        fp_gt = os.path.join(root_dir, "groundtruth")
        self.feat = os.listdir(fp_feat)
        self.gt = os.listdir(fp_gt)
        
        # Ballhandler position files, keyed by grid resolution.
        self.feat_bh = {
          "1": os.path.join(fp_feat, "ballh_c1.bin"),
          "2": os.path.join(fp_feat, "ballh_c2.bin"),
          "5": os.path.join(fp_feat, "ballh_c5.bin"),
          "10": os.path.join(fp_feat, "ballh_c10.bin")
        }        
        
        # Defender occupancy files, keyed by grid resolution.
        self.feat_def = {
          "1": os.path.join(fp_feat, "feat_defender_occupancy_raw_lvl1.bin"),
          "2": os.path.join(fp_feat, "feat_defender_occupancy_coarsegrid_lvl2.bin")
        }
        # Resolutions actually read by __getitem__.
        self.res_bh = "1"
        self.res_def = "1"
        self.gt_pos = os.path.join(fp_gt, "gtruth_ballhandler_strong.bin")
        self.gt_neg = os.path.join(fp_gt, "gtruth_ballhandler_weak.bin")
        # Record counts derived from file sizes; integer division keeps the
        # arithmetic exact.
        bytes_per_label = BYTES_PER_INT * self._INTS_PER_LABEL
        self.num_pos = os.path.getsize(self.gt_pos) // bytes_per_label
        self.num_neg = os.path.getsize(self.gt_neg) // bytes_per_label
        
    def __len__(self):
        """Return the total number of samples (weak plus strong labels)."""
        return self.num_neg + self.num_pos

    def __getitem__(self, idx):
        """Load the sample at `idx`.

        Args:
            idx: Sample index; negative values count from the end.

        Returns:
            dict with 'f_bh' (1 int), 'f_def' (5 ints) and 'y' (3 ints),
            passed through `transform` when one is set.

        Raises:
            IndexError: If `idx` is outside the dataset.
        """
        total = len(self)
        requested = idx
        if idx < 0:
            idx += total
        # Validate before touching any file so out-of-range access fails
        # with the exception the sequence protocol expects.
        if not 0 <= idx < total:
            raise IndexError(
                "index {} out of range for dataset of size {}".format(requested, total))

        f_bh = load_ints(self.feat_bh[self.res_bh], 1, pos=idx)
        f_def = load_ints(self.feat_def[self.res_def], self._INTS_PER_DEF,
                          pos=self._INTS_PER_DEF * idx)
        
        if idx >= self.num_neg:
            # Strong labels follow the weak ones; rebase into the strong file.
            pos = self._INTS_PER_LABEL * (idx - self.num_neg)
            y = load_ints(self.gt_pos, self._INTS_PER_LABEL, pos=pos)
        else:
            pos = self._INTS_PER_LABEL * idx
            y = load_ints(self.gt_neg, self._INTS_PER_LABEL, pos=pos)
                      
        sample = {'f_bh': f_bh, 'f_def': f_def, "y": y}

        if self.transform:
            sample = self.transform(sample)

        return sample

In [113]:
# Index the raw tracking data under the directory currently bound to `b`.
sd = ShotsDataset(b)

In [117]:
# Show the first 100 samples at indices >= num_neg (the strong-label range).
for i in range(100):
  print(sd[sd.num_neg + i])

Seeking to 14451032 / 14451032 15740724
{'f_bh': array([1513]), 'f_def': array([-1, -1, -1, -1, 80]), 'y': array([302,  73,   1])}
Seeking to 14451033 / 14451032 15740724
{'f_bh': array([1464]), 'f_def': array([-1, -1, -1, -1, 80]), 'y': array([302,  74,   1])}
Seeking to 14451034 / 14451032 15740724
{'f_bh': array([1464]), 'f_def': array([-1, -1, -1, -1, 92]), 'y': array([302,  75,   1])}
Seeking to 14451035 / 14451032 15740724
{'f_bh': array([1414]), 'f_def': array([-1, -1, -1, -1, 92]), 'y': array([302,  76,   1])}
Seeking to 14451036 / 14451032 15740724
{'f_bh': array([1364]), 'f_def': array([-1, -1, -1, -1, 91]), 'y': array([302,  77,   1])}
Seeking to 14451037 / 14451032 15740724
{'f_bh': array([1364]), 'f_def': array([-1, -1, -1, -1, 91]), 'y': array([302,  78,   1])}
Seeking to 14451038 / 14451032 15740724
{'f_bh': array([1314]), 'f_def': array([-1, -1, -1, -1, 91]), 'y': array([302,  79,   1])}
Seeking to 14451039 / 14451032 15740724
{'f_bh': array([1264]), 'f_def': array([-1,