# GAT — Graph Construction
Transforms IMPECT event data into per-possession `torch_geometric.data.Data` objects for a Graph Attention Network.

**Graph design:**
- **Node** = one action (row) in a possession sequence
- **Real edge** i → i+1: the sequential action transition
- **Synthetic edges** i → i+2 and i → i+3: long-range synergy links that skip one or two intermediate actions

In [None]:
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from sklearn.preprocessing import LabelEncoder
from kloppy import impect

In [None]:
# Match selection: one IMPECT open-data match, identified by match + competition id.
# The complete event stream is loaded so possession changes can be detected correctly.
MATCH_ID = 122838
COMPETITION_ID = 743

dataset = impect.load_open_data(
    match_id=MATCH_ID,
    competition_id=COMPETITION_ID,
)

# Re-express coordinates in the "secondspectrum" system, then flatten to pandas.
# NOTE(review): later cells assume a pitch-centred, metre-based system
# (x in [-52.5, 52.5]) — confirm against the kloppy coordinate-system docs.
events_df = dataset.transform(to_coordinate_system="secondspectrum").to_df(engine="pandas")

# Work on a copy of the *full* stream with a clean 0-based index.
# NO_VIDEO events are kept here on purpose; they are handled at graph-build time.
df = events_df.copy().reset_index(drop=True)

# Quick overview of what was loaded.
print(f"Total events: {len(df)}")
print(f"Event types:  {sorted(df['event_type'].unique().tolist())}")
print(f"Columns:      {df.columns.tolist()}")

In [None]:
def parse_duration_to_seconds(ts) -> float:
    """
    Convert a Kloppy duration value to float seconds.

    Accepted inputs:
      - ``None``, float NaN, ``pd.NaT``, NaT ``np.timedelta64``  -> 0.0
      - ``pd.Timedelta`` / ``datetime.timedelta``                -> ``total_seconds()``
      - ``np.timedelta64``                                       -> value expressed in seconds
      - anything else is stringified and parsed as whitespace-separated
        ``<number><unit>`` tokens, e.g.
          '0\u00b5s', '4s 192999\u00b5s', '50m 29s 618999\u00b5s',
          '24s 350ms', '50m 39s 229ms', '1h 2m 3s'
        Recognised units: d, h, m, s, ms, \u00b5s (micro sign), \u03bcs (Greek mu),
        us, ns.

    Tokens whose unit is not recognised are ignored (best-effort parsing);
    a recognised unit with a non-numeric prefix raises ``ValueError``.

    String results are rounded to microsecond precision (6 decimals).
    """
    # Local import keeps the function self-contained in a notebook cell.
    from datetime import timedelta

    # Missing values count as "no time elapsed".
    if ts is None or ts is pd.NaT or (isinstance(ts, float) and np.isnan(ts)):
        return 0.0

    # pd.Timedelta is a subclass of datetime.timedelta, so one check covers both.
    if isinstance(ts, timedelta):
        return ts.total_seconds()

    if isinstance(ts, np.timedelta64):
        if np.isnat(ts):
            return 0.0
        return float(ts / np.timedelta64(1, "s"))

    # (suffix, multiplier, divisor) — ordered so that multi-character suffixes
    # are tested before the single-character ones they end with
    # ("ms"/"ns"/"us" before "s"; "ms" before "m").
    units = (
        ("\u00b5s", 1, 1_000_000),
        ("\u03bcs", 1, 1_000_000),
        ("us", 1, 1_000_000),
        ("ns", 1, 1_000_000_000),
        ("ms", 1, 1_000),
        ("d", 86_400, 1),
        ("h", 3_600, 1),
        ("m", 60, 1),
        ("s", 1, 1),
    )

    total = 0.0
    for part in str(ts).split():
        for suffix, multiplier, divisor in units:
            if part.endswith(suffix):
                total += float(part[: -len(suffix)]) * multiplier / divisor
                break
    return round(total, 6)


# Sanity tests: each case is (raw duration string, expected seconds, tolerance).
# A failing case raises AssertionError before the confirmation line is printed.
for _raw, _expected, _tol in (
    ("0\u00b5s", 0.0, 1e-9),
    ("4s 192999\u00b5s", 4.192999, 1e-6),
    ("50m 29s 618999\u00b5s", 3029.618999, 1e-4),
    ("24s 350ms", 24.350, 1e-6),
    ("50m 39s 229ms", 3039.229, 1e-4),
):
    assert abs(parse_duration_to_seconds(_raw) - _expected) < _tol
print("Timestamp parser OK")

In [None]:
# --- Missing-value handling ---
# team_id: carry the last known team forward so team-less events (e.g. NO_VIDEO)
# inherit the team of the preceding event.
df["team_id"] = df["team_id"].ffill()

# Label-like columns get an explicit placeholder category
# (some events, e.g. NO_VIDEO, have no player).
for _col, _placeholder in (
    ("player_id", "UNKNOWN"),
    ("pass_type", "NONE"),
    ("result", "NONE"),
):
    df[_col] = df[_col].fillna(_placeholder)

# Flag columns: a missing flag is treated as False.
for _col in ("is_under_pressure", "success"):
    df[_col] = df[_col].fillna(False).astype(bool)

# Coordinates: a missing start defaults to 0.0; a missing end falls back to the
# start value, which is why each start column is filled before its end column.
for _start_col, _end_col in (
    ("coordinates_x", "end_coordinates_x"),
    ("coordinates_y", "end_coordinates_y"),
):
    df[_start_col] = df[_start_col].fillna(0.0)
    df[_end_col] = df[_end_col].fillna(df[_start_col])

# --- Timestamps in seconds, then chronological order within each period ---
df["timestamp_sec"] = df["timestamp"].apply(parse_duration_to_seconds)
df = df.sort_values(by=["period_id", "timestamp_sec"], ignore_index=True)

# --- Label encoders, fitted on the whole match so that integer ids are
# --- consistent across every graph built from it ---
player_enc = LabelEncoder()
event_type_enc = LabelEncoder()
pass_type_enc = LabelEncoder()
result_enc = LabelEncoder()

for _src_col, _enc_col, _encoder in (
    ("player_id", "player_id_enc", player_enc),
    ("event_type", "event_type_enc", event_type_enc),
    ("pass_type", "pass_type_enc", pass_type_enc),
    ("result", "result_enc", result_enc),
):
    _labels = df[_src_col].astype(str)
    df[_enc_col] = _encoder.fit(_labels).transform(_labels)

print(f"Players:     {len(player_enc.classes_)} unique")
print(f"Event types: {list(event_type_enc.classes_)}")
print(f"Pass types:  {list(pass_type_enc.classes_)}")
print(f"Results:     {list(result_enc.classes_)}")

In [None]:
def segment_possessions(df: pd.DataFrame) -> list[pd.DataFrame]:
    """
    Split a match DataFrame (all events, sorted chronologically) into
    possession sequences.

    A new possession begins when:
      - team_id changes from the previous event, OR
      - the previous event was a SHOT, OR
      - period_id changes from the previous event (only checked when the
        frame has a ``period_id`` column).

    The period rule prevents a possession from continuing across a break in
    play. When timestamps are relative to the start of each period, a
    possession spanning two periods would otherwise have a large negative
    duration.

    Parameters
    ----------
    df : pd.DataFrame
        Needs ``team_id`` and ``event_type`` columns; ``period_id`` is optional.

    Returns
    -------
    list[pd.DataFrame]
        One frame per possession, each with a reset 0-based index. An empty
        input yields a single empty frame.
    """
    n_events = len(df)
    if n_events == 0:
        return [df.iloc[0:].reset_index(drop=True)]

    # Pull the columns out once as plain lists: indexing df.iloc[i] inside the
    # loop would build a full row Series for every event.
    teams = df["team_id"].tolist()
    kinds = df["event_type"].tolist()
    periods = df["period_id"].tolist() if "period_id" in df.columns else [None] * n_events

    # cuts holds the start row of every possession, plus a final end sentinel.
    cuts = [0]
    consecutive = zip(teams, teams[1:], kinds, periods, periods[1:])
    for i, (team_prev, team_cur, kind_prev, period_prev, period_cur) in enumerate(consecutive, start=1):
        if team_prev != team_cur or kind_prev == "SHOT" or period_prev != period_cur:
            cuts.append(i)
    cuts.append(n_events)

    return [
        df.iloc[start:stop].reset_index(drop=True)
        for start, stop in zip(cuts, cuts[1:])
    ]


# Segment the whole match and summarise how many events each possession holds.
possessions = segment_possessions(df)
lengths = list(map(len, possessions))

print(f"Possessions: {len(possessions)}")
print(f"Size — min: {min(lengths)}, max: {max(lengths)}, mean: {np.mean(lengths):.1f}")

In [None]:
import math

# --- Pitch geometry (metres) ---
FIELD_LENGTH = 105.0                 # total pitch length along x
FIELD_WIDTH = 68.0                   # total pitch width along y
FIELD_HALF = FIELD_LENGTH / 2        # 52.5 — distance from centre to either goal line

# Upper bound used to normalise distances to a goal: the diagonal from a
# corner flag to the centre of the opposite goal line (≈ 110.4 m).
MAX_FIELD_DIST = (FIELD_LENGTH ** 2 + (FIELD_WIDTH / 2) ** 2) ** 0.5

# Midpoint of the duration sigmoid — gives ≈ 0 at t = 0 s and ≈ 1 at t = 15 s.
SIGMOID_SHIFT = 7.5

# --- End-of-play adjustments, added on top of the multiplicative score ---
GOAL_BONUS, SHOT_BONUS = 1.0, 0.3
OUT_OF_BOUNDS_PEN, INTERCEPT_PEN = -0.1, -0.2


def get_attacking_goal_x(team_id: str, period_id: int, home_team_id: str) -> float:
    """
    Return the x-coordinate of the opponent's goal for the possessing team.

    Convention assumed for Kloppy secondspectrum + IMPECT Bundesliga data:
      Period 1 — home team attacks toward +x  (opponent's goal at x = +52.5)
      Period 2 — teams switch; home attacks toward -x (opponent's goal at x = -52.5)

    Later periods are assumed to keep alternating ends (odd periods like
    period 1, even periods like period 2), so the two teams always attack
    opposite goals.
    NOTE(review): the actual direction in extra time depends on the match —
    confirm against the data before relying on periods > 2.

    Parameters
    ----------
    team_id : team in possession (compared to ``home_team_id`` as strings).
    period_id : 1-based period number.
    home_team_id : id of the home team.

    Returns
    -------
    float : +FIELD_HALF or -FIELD_HALF.
    """
    is_home = str(team_id) == str(home_team_id)
    # Odd periods: home attacks +x. Even periods: ends are swapped.
    home_attacks_positive = period_id % 2 == 1
    # The away team always attacks the opposite end to the home team.
    attacks_positive = is_home == home_attacks_positive
    return FIELD_HALF if attacks_positive else -FIELD_HALF


def score_play(play_clean: pd.DataFrame, attacking_goal_x: float) -> float:
    """
    Compute scalar play score P for one possession sequence (NO_VIDEO removed).

        P = (displacement_scaled * advancement_sq * sigmoid_t) + end_bonus

    Components
    ----------
    displacement_scaled : float [-1, 1]
        How much the ball moved toward the opponent's goal during the play,
        normalised by MAX_FIELD_DIST (~110.4 m) and clamped to [-1, 1].
        Negative when the ball moved away from goal, positive when it moved
        toward goal.

    advancement_sq : float [0, 1]
        Square of the ball's ending field position along the attack axis.
        0 = own goal line, 1 = opponent's goal line. Squaring heavily rewards
        plays that finish deep in the attacking third.

    sigmoid_t : float [0, 1]
        Logistic sigmoid of (t - SIGMOID_SHIFT), where t is the play duration
        in seconds. Starts near 0 at t = 0 and approaches 1 at t = 15 s,
        rewarding sustained possession. A duration so negative that the
        exponential overflows is treated as sigmoid = 0.

    end_bonus : float
        SHOT resulting in GOAL  : +1.0   (an OWN_GOAL result does NOT count)
        SHOT (no goal)          : +0.3
        Ball out of bounds      : -0.1   (last result contains 'OUT')
        Interception            : -0.2   (last event_type == 'INTERCEPTION')
        Otherwise               :  0.0

    Parameters
    ----------
    play_clean : pd.DataFrame
        Possession rows with NO_VIDEO removed, sorted chronologically.
        Reads coordinates_x/y of the first row and end_coordinates_x/y,
        event_type, result of the last row, plus timestamp_sec of both.
    attacking_goal_x : float
        x-coordinate of the opponent's goal (+52.5 or -52.5).

    Returns
    -------
    float : P (0.0 for an empty play)
    """
    if len(play_clean) == 0:
        return 0.0

    first = play_clean.iloc[0]
    last  = play_clean.iloc[-1]

    x_start = float(first["coordinates_x"])
    y_start = float(first["coordinates_y"])
    x_end   = float(last["end_coordinates_x"])
    y_end   = float(last["end_coordinates_y"])
    goal_y  = 0.0  # goal centre lies on the pitch's long axis

    # --- 1. Ball displacement toward goal, clamped to [-1, 1] ---
    d_start = ((x_start - attacking_goal_x) ** 2 + (y_start - goal_y) ** 2) ** 0.5
    d_end   = ((x_end   - attacking_goal_x) ** 2 + (y_end   - goal_y) ** 2) ** 0.5
    # Clamp both ends so off-pitch coordinates cannot leave the documented range.
    displacement_scaled = max(-1.0, min(1.0, (d_start - d_end) / MAX_FIELD_DIST))

    # --- 2. Ball ending field advancement [0, 1], squared ---
    # 0 = ball at own goal line, 1 = ball at opponent's goal line
    attack_sign    = 1.0 if attacking_goal_x > 0 else -1.0
    advancement    = (x_end * attack_sign + FIELD_HALF) / FIELD_LENGTH
    advancement    = max(0.0, min(1.0, advancement))
    advancement_sq = advancement ** 2

    # --- 3. Sigmoid of play duration: ≈ 0 at t=0, ≈ 1 at t=15 s ---
    t = float(last["timestamp_sec"] - first["timestamp_sec"])
    try:
        sig_t = 1.0 / (1.0 + math.exp(-(t - SIGMOID_SHIFT)))
    except OverflowError:
        # math.exp overflows only for a hugely negative t (e.g. timestamps
        # that restart mid-play); the sigmoid's limit there is 0.
        sig_t = 0.0

    mult = displacement_scaled * advancement_sq * sig_t

    # --- 4. End-of-play event bonus ---
    last_type   = str(last["event_type"])
    last_result = str(last["result"]).upper()

    if last_type == "SHOT":
        # A plain substring test would also match "OWN_GOAL"; exclude it so an
        # own goal is never rewarded as a goal for the possessing team.
        # NOTE(review): own goals currently fall back to SHOT_BONUS — decide
        # whether they deserve no bonus (or a penalty) instead.
        scored = "GOAL" in last_result and "OWN" not in last_result
        end_bonus = GOAL_BONUS if scored else SHOT_BONUS
    elif "OUT" in last_result:
        end_bonus = OUT_OF_BOUNDS_PEN
    elif last_type == "INTERCEPTION":
        end_bonus = INTERCEPT_PEN
    else:
        end_bonus = 0.0

    return mult + end_bonus

In [None]:
NO_VIDEO_THRESHOLD = 3.0  # seconds — discard a play if NO_VIDEO exceeds this


def build_graph(play: pd.DataFrame, home_team_id: str) -> Data | None:
    """
    Turn one possession sequence into a PyTorch Geometric ``Data`` graph.

    NO_VIDEO handling
    -----------------
    Every event's duration is the gap to the next event's timestamp (the last
    event gets 0). If the summed duration of NO_VIDEO events is greater than
    NO_VIDEO_THRESHOLD the play is rejected (``None``); otherwise the NO_VIDEO
    rows are dropped and the remaining events become the nodes. Plays left
    with fewer than two events are rejected as well.

    Node features (6)
    -----------------
      [player_id_enc, coord_x, coord_y, event_type_enc, timestamp_rel, is_under_pressure]
      timestamp_rel is measured from the first remaining event of the play.

    Edge attributes (7)
    -------------------
      [edge_type, pass_type_enc, end_x, end_y, action_distance, result_enc, success]
      edge_type 1.0 = real edge i -> i+1, carrying the attributes of action i
      edge_type 2.0 = synthetic edge i -> i+2 / i -> i+3, all other slots zero

    Label
    -----
      y : tensor([P]) with P = score_play() evaluated for the team/period of
      the first remaining event.

    Returns the ``Data`` object, or ``None`` when the play is rejected.
    """
    # --- NO_VIDEO budget check ---
    gaps = np.append(np.diff(play["timestamp_sec"].to_numpy()), 0.0)
    is_no_video = (play["event_type"] == "NO_VIDEO").to_numpy()
    if is_no_video.any() and gaps[is_no_video].sum() > NO_VIDEO_THRESHOLD:
        return None  # too much of the play was not filmed

    # Only filmed events become nodes.
    play_clean = play[~is_no_video].reset_index(drop=True)
    n = len(play_clean)
    if n < 2:
        return None

    # --- Node feature matrix (N, 6) ---
    t0 = play_clean["timestamp_sec"].iloc[0]
    node_rows = [
        [float(pid), float(px), float(py), float(etype), float(ts - t0), float(pressed)]
        for pid, px, py, etype, ts, pressed in zip(
            play_clean["player_id_enc"],
            play_clean["coordinates_x"],
            play_clean["coordinates_y"],
            play_clean["event_type_enc"],
            play_clean["timestamp_sec"],
            play_clean["is_under_pressure"],
        )
    ]
    x = torch.tensor(node_rows, dtype=torch.float)

    # --- Per-action attributes used by the real (i -> i+1) edges ---
    real_kind, synth_kind = 1.0, 2.0
    per_action = [
        [
            real_kind,
            float(ptype),
            float(ex),
            float(ey),
            float(((ex - sx) ** 2 + (ey - sy) ** 2) ** 0.5),  # straight-line action length
            float(res),
            float(ok),
        ]
        for sx, sy, ex, ey, ptype, res, ok in zip(
            play_clean["coordinates_x"].to_numpy(),
            play_clean["coordinates_y"].to_numpy(),
            play_clean["end_coordinates_x"].to_numpy(),
            play_clean["end_coordinates_y"].to_numpy(),
            play_clean["pass_type_enc"].to_numpy(),
            play_clean["result_enc"].to_numpy(),
            play_clean["success"].to_numpy(),
        )
    ]

    # --- Edge lists: for each source node emit hop 1 (real), then hops 2 and 3 (synthetic) ---
    sources, targets, edge_rows = [], [], []
    for i in range(n):
        for hop in (1, 2, 3):
            j = i + hop
            if j >= n:
                break  # larger hops would overshoot as well
            sources.append(i)
            targets.append(j)
            if hop == 1:
                edge_rows.append(per_action[i])
            else:
                edge_rows.append([synth_kind, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

    edge_index = torch.tensor([sources, targets], dtype=torch.long)  # (2, E)
    edge_attr = torch.tensor(edge_rows, dtype=torch.float)           # (E, 7)

    # --- Training label: play score for the possessing team ---
    opener = play_clean.iloc[0]
    goal_x = get_attacking_goal_x(opener["team_id"], opener["period_id"], home_team_id)
    y = torch.tensor([score_play(play_clean, goal_x)], dtype=torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)

In [None]:
# The first team in the dataset metadata is used as the home team.
# NOTE(review): assumes kloppy orders metadata.teams as [home, away] — confirm.
home_team_id = str(dataset.metadata.teams[0].team_id)

# Build one graph per possession; build_graph() returns None for rejected plays.
graphs = []
for play in possessions:
    if (g := build_graph(play, home_team_id)) is not None:
        graphs.append(g)
skipped = len(possessions) - len(graphs)

scores = [graph.y.item() for graph in graphs]
print(f"Graphs built:  {len(graphs)}")
print(f"Skipped:       {skipped}  (single-event plays or NO_VIDEO > {NO_VIDEO_THRESHOLD}s)")
print(f"Score — min: {min(scores):.3f}, max: {max(scores):.3f}, mean: {np.mean(scores):.3f}")

In [None]:
def validate_graphs(graphs: list[Data]) -> None:
    """
    Print summary statistics and assert the structural invariants of each graph.

    Checks per graph:
      - node feature width is 6, edge attribute width is 7
      - edge_index has two rows and only references existing nodes
      - y has shape (1,)
      - the edge count equals (n-1) + (n-2) + (n-3), each term floored at 0,
        i.e. one real edge and up to two synthetic edges per source node

    An empty list is reported and accepted without further checks.

    Raises
    ------
    AssertionError
        On the first graph that violates an invariant.
    """
    print(f"Total graphs: {len(graphs)}")
    if not graphs:
        # min()/max() below are undefined for an empty collection.
        print("No graphs to validate.")
        return

    sizes = [g.num_nodes for g in graphs]
    edges = [g.num_edges for g in graphs]
    print(f"Nodes — min: {min(sizes)}, max: {max(sizes)}, mean: {np.mean(sizes):.1f}")
    print(f"Edges — min: {min(edges)}, max: {max(edges)}, mean: {np.mean(edges):.1f}")

    for i, g in enumerate(graphs):
        assert g.x.shape[1] == 6,          f"Graph {i}: node feature width {g.x.shape[1]} != 6"
        assert g.edge_attr.shape[1] == 7,  f"Graph {i}: edge attr width {g.edge_attr.shape[1]} != 7"
        assert g.edge_index.shape[0] == 2, f"Graph {i}: edge_index shape {g.edge_index.shape}"
        # Reductions over an empty edge_index raise, so only range-check real edges.
        if g.num_edges > 0:
            lowest = g.edge_index.min().item()
            highest = g.edge_index.max().item()
            assert 0 <= lowest and highest < g.num_nodes, f"Graph {i}: out-of-range node index"
        assert g.y.shape == (1,),          f"Graph {i}: y shape {g.y.shape} != (1,)"

        n = g.num_nodes
        expected = max(0, n - 1) + max(0, n - 2) + max(0, n - 3)
        assert g.num_edges == expected, (
            f"Graph {i}: expected {expected} edges for n={n}, got {g.num_edges}"
        )

    print("All assertions passed.")

    print("\nSample graphs:")
    for g in graphs[:5]:
        # Column 0 of edge_attr is the edge type: 1.0 real, 2.0 synthetic.
        real  = (g.edge_attr[:, 0] == 1.0).sum().item()
        synth = (g.edge_attr[:, 0] == 2.0).sum().item()
        print(f"  nodes={g.num_nodes:>3}, edges={g.num_edges:>3} "
              f"(real={real:>2}, synthetic={synth:>2})  score={g.y.item():+.4f}")


# Run the structural checks on the graphs built for this match;
# an AssertionError is raised at the first graph that violates an invariant.
validate_graphs(graphs)