In [None]:
import pickle

import pandas as pd
from collections import defaultdict
from collections import Counter
import torch
from numpy import ndarray
from torch_geometric.data import Data


# Load the pickled dataset from disk.
# NOTE(security): pickle.load can execute arbitrary code while deserialising;
# only open this file if it comes from a trusted source.
with open("data/raw/epl_2015.pkl", "rb") as f:
    loaded_list = pickle.load(f)

# Show how many entries the loaded object holds.
print(len(loaded_list))
# print(loaded_list[0])
# Use the "events" table of the first entry for the rest of the notebook.
events: pd.DataFrame = loaded_list[0]["events"]

# Bare last expression so the notebook renders the frame as cell output.
events

In [None]:
def count_passes(
    team_events: pd.DataFrame,
) -> tuple[Counter[tuple[int, int]], list[str]]:
    """Count passer -> next-passer links between consecutive successful passes.

    The rows of ``team_events`` are scanned in their existing order. Whenever a
    successful pass by one player is immediately followed by a successful pass
    by another row's player, one link ``(first, second)`` is counted. Any row
    that is not a successful pass breaks the chain. A successful pass whose
    ``player`` is missing is skipped without breaking the chain. Two
    consecutive successful passes by the same player produce a self-link.

    Args:
        team_events: Events of a single team with at least the columns
            ``type``, ``outcome_type`` and ``player``.

    Returns:
        A pair ``(pass_counts, team_players)``. ``team_players`` is the list of
        distinct non-null players in order of first appearance, and
        ``pass_counts`` maps ``(from_idx, to_idx)`` positions in that list to
        the number of links counted.

    Raises:
        ValueError: If ``team_events`` has no rows.
    """
    if team_events.empty:
        raise ValueError("No events found for team.")

    # Convert to a plain list so the returned value matches the declared
    # ``list[str]`` type (``unique()`` alone yields an array).
    team_players: list[str] = team_events["player"].dropna().unique().tolist()
    player_to_idx = {player: idx for idx, player in enumerate(team_players)}

    pass_counts: Counter[tuple[int, int]] = Counter()
    # Player of the previous successful pass in the current chain, if any.
    pass_from = None

    rows = team_events[["type", "outcome_type", "player"]].to_numpy()
    for event_type, outcome_type, player in rows:
        if event_type != "Pass" or outcome_type != "Successful":
            # Anything other than a successful pass ends the current chain.
            pass_from = None
            continue
        if player not in player_to_idx:
            # Missing player: ignore the row but keep the chain alive.
            continue
        if pass_from is not None:
            # pass_from was validated against player_to_idx when it was set.
            pass_counts[(player_to_idx[pass_from], player_to_idx[player])] += 1
        pass_from = player

    return pass_counts, team_players

In [None]:
def build_graph(
    pass_counts: Counter[tuple[int, int]],
    team_players: list[str],
    team: str,
    time_range: tuple[int, int],
) -> Data:
    """Assemble a weighted, directed pass graph as a ``Data`` object.

    Args:
        pass_counts: Maps ``(source_idx, target_idx)`` node-index pairs to the
            number of passes on that edge.
        team_players: Players of the team; its length sets the number of
            node-feature rows.
        team: Team label stored on the graph as ``team``.
        time_range: Pair stored unchanged on the graph as ``time_range``.

    Returns:
        A ``Data`` object carrying ``x``, ``edge_index`` and ``edge_weight``
        plus the extra attributes ``team``, ``time_range`` and ``players``.

    Raises:
        ValueError: If ``pass_counts`` is empty.
    """
    if not pass_counts:
        raise ValueError("No passes found.")

    edge_pairs = list(pass_counts.keys())
    edge_totals = list(pass_counts.values())
    node_count = len(team_players)

    graph = Data(
        # Placeholder node features: each node's own index as one float column.
        x=torch.arange(node_count, dtype=torch.float).unsqueeze(1),
        # One row per edge is transposed into the two-row (source, target) layout.
        edge_index=torch.tensor(edge_pairs, dtype=torch.long).t().contiguous(),
        edge_weight=torch.tensor(edge_totals, dtype=torch.float),
    )

    # Attach bookkeeping metadata alongside the tensors.
    graph.team = team
    graph.time_range = time_range
    graph.players = team_players

    return graph

In [None]:
def build_team_graphs(
    events: pd.DataFrame, time_interval: int = 5
) -> "tuple[list[Data], dict[str, Data]]":
    """Build per-window and cumulative pass graphs for every team.

    The ``minute`` axis is cut into consecutive windows of ``time_interval``
    minutes, starting at 0 and covering the largest minute present. For each
    window and team, passes are counted with ``count_passes`` and turned into
    a graph with ``build_graph``. The per-window counts are also summed per
    team into one cumulative graph.

    Each window is counted independently, so a pass link whose two events fall
    in different windows is not counted in either kind of graph.

    Args:
        events: Event table with ``minute``, ``team`` and ``type`` columns,
            plus the columns ``count_passes`` reads. Rows are assumed to be in
            chronological order -- TODO confirm against the data source.
        time_interval: Window length in minutes; must be positive.

    Returns:
        ``(graphs_segmented, graphs_cumulative)``. ``graphs_segmented`` lists
        one graph per (window, team) that has at least one counted pass, in
        window order; node indices are local to that window's player list.
        ``graphs_cumulative`` maps each team to a single graph whose nodes are
        the team's players in sorted order.

    Raises:
        ValueError: If ``time_interval`` is not positive.
    """
    if time_interval <= 0:
        raise ValueError("time_interval must be a positive number of minutes.")

    graphs_segmented = []
    graphs_cumulative = {}

    # Nothing to segment when there are no usable minute values.
    minutes = events["minute"].dropna()
    if minutes.empty:
        return graphs_segmented, graphs_cumulative

    max_minute = int(minutes.max())
    teams = events["team"].dropna().unique()

    # Bookkeeping rows never take part in pass chains; filter them out once.
    play_events = events[~events["type"].isin(["Start", "End", "FormationSet"])]

    cumulative_pass_counts = defaultdict(Counter)
    cumulative_team_players = defaultdict(set)

    # The stop is max_minute + 1 so the window containing the final minute is
    # always generated, even when max_minute is a multiple of time_interval.
    for start_minute in range(0, max_minute + 1, time_interval):
        end_minute = min(start_minute + time_interval, max_minute + 1)
        in_window = play_events[
            (play_events["minute"] >= start_minute)
            & (play_events["minute"] < end_minute)
        ]

        for team in teams:
            team_events = in_window[in_window["team"] == team]
            # count_passes raises on an empty frame; a quiet window for one
            # team should simply produce no graph.
            if team_events.empty:
                continue

            pass_counts, team_players = count_passes(team_events)
            if not pass_counts:
                continue

            # Accumulate by player name, because indices are window-local.
            for (i, j), count in pass_counts.items():
                link = (team_players[i], team_players[j])
                cumulative_pass_counts[team][link] += count
            cumulative_team_players[team].update(team_players)

            graphs_segmented.append(
                build_graph(
                    pass_counts, team_players, team, (start_minute, end_minute)
                )
            )

    # Build one cumulative graph per team over a stable, sorted player order.
    for team, name_counts in cumulative_pass_counts.items():
        all_players = sorted(cumulative_team_players[team])
        player_to_idx = {player: idx for idx, player in enumerate(all_players)}

        # Every name in name_counts was added to the player set above.
        remapped = Counter(
            {
                (player_to_idx[p_from], player_to_idx[p_to]): count
                for (p_from, p_to), count in name_counts.items()
            }
        )

        # The cumulative range keeps the largest observed minute as its end.
        graphs_cumulative[team] = build_graph(
            remapped, all_players, team, (0, max_minute)
        )

    return graphs_segmented, graphs_cumulative

In [None]:
# Build the per-window and cumulative pass graphs from the loaded events,
# using windows of 5 units of the "minute" column.
graphs_segmented, graphs_cumulative = build_team_graphs(events, time_interval=5)

# Preview: total number of segmented graphs, then details of the first four.
print(f"Segmented graphs: {len(graphs_segmented)} total")
for g in graphs_segmented[:4]:
    print(
        f"  Team: {g.team}, Time: {g.time_range}, Nodes: {g.num_nodes}, Edges: {g.num_edges}"
    )

# Cumulative graphs are keyed by team, one graph per team.
print("\nCumulative graphs:")
for team, g in graphs_cumulative.items():
    print(f"  Team: {team}, Nodes: {g.num_nodes}, Edges: {g.num_edges}")