In [2]:
pip install torch_geometric --upgrade

Note: you may need to restart the kernel to use updated packages.


In [3]:
from SoccerNet.Downloader import SoccerNetDownloader
# Download the tracking-2023 splits (train / test / challenge) as .zip archives.
# NOTE: this cell only downloads the archives; the graph-building cell expects them
# extracted to <SOCCERNET_ROOT>/tracking-2023/<split>/<SEQUENCE>/ — unzip them first.
SOCCERNET_ROOT = "path/to/SoccerNet"  # change to your local data root
mySoccerNetDownloader = SoccerNetDownloader(LocalDirectory=SOCCERNET_ROOT)
mySoccerNetDownloader.downloadDataTask(task="tracking-2023", split=["train", "test", "challenge"])


Downloading path/to/SoccerNet/tracking-2023/train.zip...: : 9.58GiB [04:49, 33.0MiB/s]                          
Downloading path/to/SoccerNet/tracking-2023/test.zip...: : 8.71GiB [05:05, 28.5MiB/s]                          
Downloading path/to/SoccerNet/tracking-2023/challenge2023.zip...: : 5.31GiB [03:12, 27.6MiB/s]                          


In [15]:
import pandas as pd
import configparser
from pathlib import Path

# Inspect one training sequence: load its annotations and group tracklet IDs by role.
base_path = Path("path/to/SoccerNet/tracking-2023/train/SNMOT-060")

# gt.txt and det.txt share the same comma-separated column layout.
MOT_COLUMNS = ["frame", "id", "x", "y", "w", "h", "conf", "class_id", "visibility", "extra"]

# Ground truth
gt_path = base_path / "gt" / "gt.txt"
gt = pd.read_csv(gt_path, header=None)
gt.columns = MOT_COLUMNS

# Detections
det_path = base_path / "det" / "det.txt"
det = pd.read_csv(det_path, header=None)
det.columns = MOT_COLUMNS

# Game metadata (tracklet roles); optionxform=str preserves key case ("trackletID_1").
config = configparser.ConfigParser()
config.optionxform = str
config.read(base_path / "gameinfo.ini")

# Sequence metadata; preserve key case here too (e.g. "imWidth", "frameRate").
seq = configparser.ConfigParser()
seq.optionxform = str
seq.read(base_path / "seqinfo.ini")

ball_id = None
players_left = []
players_right = []
referees = []
goalkeepers_left = []
goalkeepers_right = []

# Both singular and plural goalkeeper spellings are accepted.
GOALKEEPER_LEFT_ROLES = ("goalkeepers team left", "goalkeeper team left")
GOALKEEPER_RIGHT_ROLES = ("goalkeepers team right", "goalkeeper team right")

for key, value in config["Sequence"].items():
    if not key.startswith("trackletID_"):
        continue
    tid = int(key.replace("trackletID_", ""))

    # Values look like "<role>;<role_info>"; partition tolerates a missing or extra ";".
    role, _, role_info = value.partition(";")
    role = role.strip()

    if role == "ball":
        ball_id = tid
    elif role == "player team left":
        players_left.append(tid)
    elif role == "player team right":
        players_right.append(tid)
    elif role in GOALKEEPER_LEFT_ROLES:
        goalkeepers_left.append(tid)
    elif role in GOALKEEPER_RIGHT_ROLES:
        goalkeepers_right.append(tid)
    elif role.startswith("referee"):
        referees.append(tid)


In [16]:
# Summarize the tracklet-ID groups parsed from gameinfo.ini, one line per role.
role_summary = (
    ("Ball ID:", ball_id),
    ("Left team players:", players_left),
    ("Right team players:", players_right),
    ("Left GK:", goalkeepers_left),
    ("Right GK:", goalkeepers_right),
    ("Referees:", referees),
)
for label, ids in role_summary:
    print(label, ids)


Ball ID: 18
Left team players: [1, 2, 11, 12, 13, 15, 16, 19, 20, 21]
Right team players: [3, 4, 5, 6, 7, 8, 9, 10, 23, 24]
Left GK: [22]
Right GK: [25]
Referees: [14, 17, 26]


In [3]:
"""
GNN preprocessing pipeline for SoccerNet tracking (tracking-2023 layout).

Builds one torch_geometric.data.Data graph per kept frame and saves one list of
graphs per split.

Expected folder layout for each sequence:
  <BASE_DIR>/<split>/<SEQUENCE_NAME>/
      gt/gt.txt        # required
      gameinfo.ini     # required (tracklet roles)
      seqinfo.ini      # required (image size, frame rate, ...)
      events.json      # optional (per-frame labels)
      det/det.txt      # not read by this cell

Sequences missing any required file are skipped with a message. Frames left
with no nodes after role filtering produce no graph.

Outputs:
  outputs/graphs/<split>_graphs.pt  # list of Data objects (one per kept frame)
"""

import os
from pathlib import Path
import configparser
import pandas as pd
import numpy as np
from sklearn.neighbors import NearestNeighbors
import torch
from torch_geometric.data import Data
import tqdm
import json

# -----------------------------
# CONFIG
# -----------------------------
BASE_DIR = Path("path/to/SoccerNet/tracking-2023")  # change to your root; must contain extracted split folders
SPLITS = ["train", "test", "challenge"]  # NOTE(review): download yields challenge2023.zip and the run reports `challenge` not found — confirm extracted folder name
OUTPUT_DIR = Path("outputs/graphs")
K_NEIGHBORS = 6            # k for k-NN graph edges (excluding self loops); frames with <= k nodes are fully connected
USE_TEMPORAL_EDGES = False # NOTE(review): not read anywhere in this cell — temporal edges are not implemented yet
TEMPORAL_WINDOW = 1        # NOTE(review): not read anywhere in this cell (intended: future frames to link, 1 => t -> t+1)
INCLUDE_REFS = False       # False keeps only player / ball / goalkeeper tracklets (referees and unmapped tracks dropped)
DEVICE = "cpu"             # NOTE(review): not read anywhere in this cell; all tensors are created on the default device
# -----------------------------

# Created when the cell runs so torch.save in process_split has a target directory.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# -----------------------------
# HELPERS: parsers
# -----------------------------
def read_seqinfo(seqinfo_path: Path):
    """
    Read the integer fields of a sequence's seqinfo.ini.

    Looks in the [Sequence] section (raises KeyError if it is absent) and falls back
    to a default for each missing option.

    Returns:
        dict with keys "frame_rate" (default 25), "seq_length" (default 0),
        "im_w" (default 1920) and "im_h" (default 1080), all ints.
    """
    parser = configparser.ConfigParser()
    parser.optionxform = str  # ini keys are camelCase; keep them as written
    parser.read(seqinfo_path)
    section = parser["Sequence"]

    # (result key, ini option, fallback when the option is missing)
    fields = (
        ("frame_rate", "frameRate", 25),
        ("seq_length", "seqLength", 0),
        ("im_w", "imWidth", 1920),
        ("im_h", "imHeight", 1080),
    )
    return {key: section.getint(option, fallback) for key, option, fallback in fields}

def read_gameinfo(gameinfo_path: Path):
    """
    Parse a sequence's gameinfo.ini into tracklet roles plus optional action metadata.

    Every [Sequence] option whose name starts with "trackletID_" is read as
    "<role>;<role_info>" (e.g. "player team left;10"); only the first ";" splits.
    A value without ";" becomes the role, with an empty role_info. Both parts are
    whitespace-stripped.

    Returns:
        {"mapping": {tid: {"role": str, "role_info": str}},
         "action_class": str or None, "action_position": str or None}
    """
    parser = configparser.ConfigParser()
    parser.optionxform = str  # keep "trackletID_" casing so the prefix test matches
    parser.read(gameinfo_path)
    section = parser["Sequence"]

    prefix = "trackletID_"
    mapping = {}
    for option, raw_value in section.items():
        if not option.startswith(prefix):
            continue
        tracklet_id = int(option.replace(prefix, ""))
        role, _sep, info = raw_value.partition(";")
        mapping[tracklet_id] = {"role": role.strip(), "role_info": info.strip()}

    return {
        "mapping": mapping,
        "action_class": section.get("actionClass", None),
        "action_position": section.get("actionPosition", None),
    }

def load_gt(gt_path: Path):
    """
    Load a MOT-style gt.txt (no header) into a DataFrame with named columns.

    The first seven columns are named frame, track_id, x, y, w, h, conf; any extra
    columns are named c8, c9, ... by their 1-based position. "frame" and
    "track_id" are cast to int.

    Raises:
        ValueError: if the file has fewer than 7 columns.
    """
    df = pd.read_csv(gt_path, header=None)
    n_cols = df.shape[1]
    base_names = ["frame", "track_id", "x", "y", "w", "h", "conf"]
    if n_cols < len(base_names):
        # Previously this fell through to an assignment that could never succeed.
        raise ValueError(
            f"{gt_path}: expected at least 7 columns (frame,track_id,x,y,w,h,conf), got {n_cols}"
        )
    df.columns = base_names + [f"c{i}" for i in range(len(base_names) + 1, n_cols + 1)]
    df["frame"] = df["frame"].astype(int)
    df["track_id"] = df["track_id"].astype(int)
    return df

# -----------------------------
# Graph builder: per-frame
# -----------------------------
def _role_onehot(role: str) -> list:
    """6-way role indicator: player_left, player_right, goalkeeper (either side), ball, referee, other.

    "other" is set exactly when none of the first five slots match (including an
    empty/unknown role), so every node has one active slot.
    """
    r = role.lower()
    flags = [
        r == "player team left",
        r == "player team right",
        "goalkeep" in r,
        r == "ball",
        r.startswith("referee"),
    ]
    flags.append(not any(flags))
    return [1.0 if flag else 0.0 for flag in flags]


def _knn_edge_index(coords: np.ndarray, k: int) -> np.ndarray:
    """Directed edges (2, E) from each node to its k nearest *other* nodes (Euclidean).

    n <= 1 gives no edges; n <= k connects every ordered pair i != j. Self-loops are
    excluded explicitly (diagonal distance set to inf), so nodes sharing an identical
    position still never link to themselves.
    """
    n = coords.shape[0]
    if n <= 1:
        return np.zeros((2, 0), dtype=np.int64)
    if n <= k:
        src, dst = np.nonzero(~np.eye(n, dtype=bool))  # row-major: i ascending, then j
        return np.vstack([src, dst]).astype(np.int64)
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    np.fill_diagonal(dist, np.inf)
    nbrs = np.argsort(dist, axis=1, kind="stable")[:, :k]  # nearest first, ties by index
    src = np.repeat(np.arange(n), k)
    return np.vstack([src, nbrs.reshape(-1)]).astype(np.int64)


def build_graph_for_frame(nodes_df, roles_map, seq_info, k_neighbors=6, include_refs=False):
    """
    Build node features and spatial k-NN edges for the tracklets present in one frame.

    Args:
        nodes_df: rows for this frame; must provide columns 'track_id', 'x', 'y',
            'w', 'h', 'conf' (extra columns are ignored).
        roles_map: dict track_id -> {'role': ..., 'role_info': ...}; tracks missing
            from it get role "" (the "other" slot).
        seq_info: dict with at least 'im_w' and 'im_h'; x and w are divided by im_w,
            y and h by im_h.
        k_neighbors: outgoing edges per node; frames with at most k_neighbors nodes
            are fully connected instead (no self-loops).
        include_refs: unused here (the caller filters nodes); kept for API compatibility.

    Returns:
        None for an empty frame, otherwise a dict with
          'x'          float tensor (N, 11): [x, y, w, h, conf] + role one-hot
                       (player_left, player_right, goalkeeper, ball, referee, other),
          'edge_index' long tensor (2, E) of directed source -> target edges,
          'track_ids'  list[int] aligned with the rows of 'x',
          'raw_coords' float ndarray (N, 2) of the normalized (x, y) used for k-NN.
    """
    if nodes_df.shape[0] == 0:
        return None

    im_w, im_h = seq_info["im_w"], seq_info["im_h"]
    xs = nodes_df["x"].to_numpy(dtype=float) / im_w
    ys = nodes_df["y"].to_numpy(dtype=float) / im_h
    ws = nodes_df["w"].to_numpy(dtype=float) / im_w
    hs = nodes_df["h"].to_numpy(dtype=float) / im_h
    confs = nodes_df["conf"].to_numpy(dtype=float)

    track_ids = nodes_df["track_id"].to_numpy().astype(int).tolist()
    role_vectors = np.array(
        [_role_onehot(roles_map.get(tid, {}).get("role", "")) for tid in track_ids],
        dtype=float,
    )

    node_features = np.column_stack([xs, ys, ws, hs, confs, role_vectors])
    # NOTE(review): k-NN uses the (x, y) columns as given; for MOT-style boxes this is
    # presumably the top-left corner rather than the box centre — confirm intent.
    coords = np.column_stack([xs, ys])

    x = torch.tensor(node_features, dtype=torch.float)
    edge_index = torch.as_tensor(_knn_edge_index(coords, k_neighbors), dtype=torch.long)

    return {
        "x": x,
        "edge_index": edge_index,
        "track_ids": track_ids,
        "raw_coords": coords,
    }

# -----------------------------
# Main loop: build dataset for split
# -----------------------------
def _with_velocity(gt: pd.DataFrame) -> pd.DataFrame:
    """Return ``gt`` (original row order) with per-track finite-difference velocities.

    vx / vy are in pixels per frame: (pos[t] - pos[t_prev]) / max(1, t - t_prev) along
    each track ordered by frame, and 0.0 for a track's first observation.
    """
    ordered = gt.sort_values(["track_id", "frame"])
    by_track = ordered.groupby("track_id", sort=False)
    step = by_track["frame"].diff().clip(lower=1)
    ordered = ordered.assign(
        vx=(by_track["x"].diff() / step).fillna(0.0),
        vy=(by_track["y"].diff() / step).fillna(0.0),
    )
    return ordered.sort_index()


def _is_kept_role(role: str) -> bool:
    """Roles kept when referees are excluded: players, ball and goalkeepers (case-insensitive)."""
    r = role.lower()
    return "player" in r or "ball" in r or "goalkeep" in r


def process_split(split_dir: Path, out_path: Path, k_neighbors=6, include_refs=False):
    """
    Build one graph per kept frame for every sequence in a split and save them.

    Each immediate sub-directory of ``split_dir`` is a sequence and must contain
    gt/gt.txt, gameinfo.ini and seqinfo.ini; sequences missing any of these are
    skipped with a message. events.json is optional.

    Args:
        split_dir: directory holding the sequence folders of one split.
        out_path: destination for ``torch.save`` of the list of ``Data`` objects.
        k_neighbors: k for the spatial k-NN edges (see build_graph_for_frame).
        include_refs: when False, only tracklets whose role mentions "player",
            "ball" or "goalkeep" are kept; referees and tracks absent from
            gameinfo.ini are dropped, and frames left empty are skipped.

    Side effects:
        Writes ``out_path`` and prints progress / a summary line.
        NOTE(review): on torch >= 2.6, reading this file back needs
        ``torch.load(out_path, weights_only=False)`` since it pickles Data objects.
    """
    data_list = []
    sequences = sorted(p for p in split_dir.iterdir() if p.is_dir())

    for seq_path in tqdm.tqdm(sequences, desc=f"Seqs in {split_dir.name}"):
        gt_path = seq_path / "gt" / "gt.txt"
        gameinfo_path = seq_path / "gameinfo.ini"
        seqinfo_path = seq_path / "seqinfo.ini"
        events_path = seq_path / "events.json"  # optional

        if not (gt_path.exists() and gameinfo_path.exists() and seqinfo_path.exists()):
            print("Skipping (missing files):", seq_path)
            continue

        seq_info = read_seqinfo(seqinfo_path)
        roles_map = read_gameinfo(gameinfo_path)["mapping"]

        # vx / vy are attached to every row; build_graph_for_frame does not read them yet.
        gt = _with_velocity(load_gt(gt_path))

        # events.json is optional: only call load_events when the file exists.
        # NOTE(review): load_events is not defined in this cell — confirm where it lives.
        events_map = load_events(events_path) if events_path.exists() else {}

        # The role filter depends only on the sequence, so compute it once.
        keep_track_ids = None
        if not include_refs:
            keep_track_ids = {tid for tid, r in roles_map.items() if _is_kept_role(r["role"])}

        # One pass over frames in ascending order instead of a boolean mask per frame.
        for fr, frame_rows in gt.groupby("frame", sort=True):
            nodes_df = frame_rows.reset_index(drop=True)
            if keep_track_ids is not None:
                nodes_df = nodes_df[nodes_df["track_id"].isin(keep_track_ids)].reset_index(drop=True)
                if nodes_df.empty:
                    continue

            built = build_graph_for_frame(
                nodes_df, roles_map, seq_info, k_neighbors=k_neighbors, include_refs=include_refs
            )
            if built is None:
                continue

            label = events_map.get(fr, None)
            data = Data(
                x=built["x"],
                edge_index=built["edge_index"],
                # placeholder binary label: 1 if the frame has an event entry, else 0
                y=torch.tensor([0 if label is None else 1]),
            )
            # raw metadata for later inspection / evaluation
            data.seq_name = seq_path.name
            data.frame = int(fr)
            data.track_ids = built["track_ids"]
            data.raw_coords = torch.tensor(built["raw_coords"], dtype=torch.float)
            data.roles_map = roles_map
            data.seq_info = seq_info

            data_list.append(data)

    torch.save(data_list, out_path)
    print(f"Saved {len(data_list)} graphs to {out_path}")

# -----------------------------
# Run for splits
# -----------------------------
if __name__ == "__main__":
    # Build and save graphs for every configured split that exists on disk.
    for split_name in SPLITS:
        split_root = BASE_DIR / split_name
        if split_root.exists():
            process_split(
                split_root,
                OUTPUT_DIR / f"{split_name}_graphs.pt",
                k_neighbors=K_NEIGHBORS,
                include_refs=INCLUDE_REFS,
            )
        else:
            print("Split not found, skipping:", split_root)


Seqs in train: 100%|██████████| 57/57 [03:48<00:00,  4.01s/it]


Saved 42750 graphs to outputs/graphs/train_graphs.pt


Seqs in test: 100%|██████████| 49/49 [03:14<00:00,  3.96s/it]


Saved 36750 graphs to outputs/graphs/test_graphs.pt
Split not found, skipping: path/to/SoccerNet/tracking-2023/challenge
