# Model Training: Masked Attribute Modeling with 360 Fusion

This notebook trains an **EventEncoder** using masked attribute modeling (MAM). It includes:
1. Bucketized numerical features (already computed in preprocessing).
2. Per-player MLP with mean pooling for 360 freeze-frame features.
3. Transformer encoder for event tabular features.
4. Gated fusion of event + frame embeddings.
5. Masked attribute modeling training.
6. Save trained model.


In [2]:
# Colab only: mount Google Drive so that the paths under /content/drive used
# by the following cells are reachable.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import os
# Sanity check: show the contents of the project folder on Drive
# (the cell's last expression is displayed as its output).
os.listdir('/content/drive/MyDrive/MLSE')

['events360_v4.jsonl', 'models']

In [4]:
from pathlib import Path
import json
import math
import random
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Input events file (JSON objects, one or more per line) and checkpoint target.
DATA_PATH = Path('/content/drive/MyDrive/MLSE/events360_v4.jsonl')
MODEL_OUT = Path('/content/drive/MyDrive/MLSE/models/event_encoder_mam.pt')
# Create the checkpoint directory up front so saving cannot fail on a missing folder.
MODEL_OUT.parent.mkdir(parents=True, exist_ok=True)

# Seed every RNG used in this notebook (Python, NumPy, PyTorch) with one value.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)

# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device:', device)
print('Data path:', DATA_PATH.resolve())


Device: cuda
Data path: /content/drive/MyDrive/MLSE/events360_v4.jsonl


In [5]:
# Helper functions

def get_flat(d, key, default=None):
    """Return ``d[key]`` when the key is present in ``d``, otherwise ``default``."""
    if key in d:
        return d[key]
    return default


def iter_json_objects(fp):
    """Yield every JSON value found in ``fp``, an iterable of text lines.

    A line may hold several JSON values back to back (optionally separated by
    whitespace); blank lines are skipped. Malformed content propagates the
    ``json.JSONDecodeError`` raised by the decoder.
    """
    parse_at = json.JSONDecoder().raw_decode
    for raw_line in fp:
        text = raw_line.strip()
        pos, size = 0, len(text)
        # An empty (blank) line never enters this loop.
        while pos < size:
            value, pos = parse_at(text, pos)
            yield value
            # raw_decode does not skip leading whitespace, so advance past it here.
            while pos < size and text[pos].isspace():
                pos += 1


In [6]:
# Feature schema (derived from flattened dataset)
# We keep categorical/bucketized features and exclude list fields and ids.

import json

# Keys whose name STARTS WITH one of these strings are dropped from the
# feature set (identifiers and list-valued fields).
# NOTE(review): only a leading 'id' is matched here; keys that merely end in
# '.id' (e.g. 'ball_receipt.outcome.id' in the output below) are kept.
EXCLUDE_PREFIXES = [
    'event_uuid',
    'id',
    'related_events',
    'freeze_frame',
    'visible_area',
]

# Keys CONTAINING one of these substrings are treated as raw continuous
# fields and dropped, unless they also match INCLUDE_SUFFIXES.
EXCLUDE_CONTAINS = [
    'location',   # raw locations
    'end_location',
    'angle',
    'length',
    'statsbomb_xg',
]

# Markers of bucketized fields, which are kept even when they match
# EXCLUDE_CONTAINS. Despite the name, the filter below tests these with
# substring containment (`suf in k`), not `str.endswith`.
INCLUDE_SUFFIXES = [
    'bucket',
    'bucket.label',
]

# Special vocabulary entries: [MASK] replaces masked inputs, [UNK] stands in
# for missing / None / unseen values.
MASK_TOKEN = '[MASK]'
UNK_TOKEN = '[UNK]'

# Scan dataset to get flattened keys
# Union of the keys seen across all events in the file.
cols = set()
with DATA_PATH.open('r', encoding='utf-8') as f:
    for ev in iter_json_objects(f):
        cols.update(ev.keys())


def _keep_feature(name):
    """Return True when key ``name`` passes the prefix / substring exclusion rules."""
    if name.startswith(tuple(EXCLUDE_PREFIXES)):
        return False
    looks_raw = any(tok in name for tok in EXCLUDE_CONTAINS)
    looks_bucketized = any(suf in name for suf in INCLUDE_SUFFIXES)
    # Raw continuous fields are dropped unless they are a bucketized variant.
    return not looks_raw or looks_bucketized


# Sorted so that the feature order is deterministic across runs.
EVENT_FEATURES = [k for k in sorted(cols) if _keep_feature(k)]

print('Derived EVENT_FEATURES:', len(EVENT_FEATURES))


Derived EVENT_FEATURES: 72


In [7]:
# Build vocab per feature
# One value -> index table per feature; indices 0 and 1 are reserved for the
# [UNK] and [MASK] tokens respectively.
feature_vocab = {}
for feat in EVENT_FEATURES:
    feature_vocab[feat] = {UNK_TOKEN: 0, MASK_TOKEN: 1}

with DATA_PATH.open('r', encoding='utf-8') as f:
    for ev in iter_json_objects(f):
        for feat, vocab in feature_vocab.items():
            raw = get_flat(ev, feat, UNK_TOKEN)
            if raw is None:
                token = UNK_TOKEN
            elif isinstance(raw, bool):
                # Booleans are stored as their string form ('True' / 'False').
                token = str(raw)
            else:
                token = raw
            # First occurrence gets the next free index; later ones are no-ops.
            vocab.setdefault(token, len(vocab))

vocab_sizes = {name: len(vocab) for name, vocab in feature_vocab.items()}
print('Vocab sizes (sample):', list(vocab_sizes.items())[:5])


Vocab sizes (sample): [('50_50', 2), ('ball_receipt', 2), ('ball_receipt.outcome.id', 3), ('ball_receipt.outcome.name', 3), ('ball_recovery', 2)]


In [8]:
# Dataset

class EventDataset(Dataset):
    """In-memory dataset of events encoded as per-feature vocabulary indices.

    Args:
        path: ``pathlib.Path`` of the events file; every JSON object in it is
            loaded eagerly in ``__init__``.
        feature_vocab: mapping ``feature name -> {value: index}``. Its key
            order defines the column order of the returned index tensor, which
            is the same order the model derives from ``vocab_sizes``.
        frame_keys: ordered names of numeric entries to read from each event's
            ``'frame_features'`` dict. Defaults to no keys, which reproduces
            the previous behaviour of always returning an empty frame vector.

    Each item is a tuple ``(feat_ids, frame_vec)``: a ``long`` tensor with one
    index per feature and a ``float32`` tensor with one value per frame key.
    """

    def __init__(self, path, feature_vocab, frame_keys=()):
        with path.open('r', encoding='utf-8') as f:
            self.events = list(iter_json_objects(f))
        self.feature_vocab = feature_vocab
        self.features = list(feature_vocab.keys())
        self.frame_keys = tuple(frame_keys)

    def __len__(self):
        return len(self.events)

    def _encode(self, ev, feat):
        """Map the value of ``feat`` in event ``ev`` to its vocabulary index."""
        val = get_flat(ev, feat, UNK_TOKEN)
        # Same normalisation as the vocabulary-building cell: booleans become
        # strings, None becomes the unknown token.
        if isinstance(val, bool):
            val = str(val)
        if val is None:
            val = UNK_TOKEN
        vocab = self.feature_vocab[feat]
        # Values never seen while building the vocabulary fall back to [UNK].
        return vocab.get(val, vocab.get(UNK_TOKEN, 0))

    def _frame_vector(self, ev):
        """Read ``frame_keys`` from the event's frame dict, zero-filling gaps."""
        frame = ev.get('frame_features')
        if not isinstance(frame, dict):
            return [0.0] * len(self.frame_keys)
        values = []
        for key in self.frame_keys:
            raw = frame.get(key)
            values.append(0.0 if raw is None else float(raw))
        return values

    def __getitem__(self, idx):
        ev = self.events[idx]
        feat_ids = [self._encode(ev, feat) for feat in self.features]
        frame_vec = self._frame_vector(ev)
        return torch.tensor(feat_ids, dtype=torch.long), torch.tensor(frame_vec, dtype=torch.float32)

In [9]:
# Model components

class PlayerMLP(nn.Module):
    """Two-layer perceptron (Linear -> ReLU -> Linear) over the last input dimension."""

    def __init__(self, in_dim=6, hidden=64, out_dim=128):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.net(x)
        return projected


class SetEncoder(nn.Module):
    """Encode each event's freeze frame (a set of players) into one vector.

    Every player with a usable ``'location'`` is described by six numbers
    relative to the acting player: ``dx, dy, distance, angle, is_teammate,
    is_keeper``. These rows go through ``PlayerMLP`` and are mean-pooled.
    Events without any usable player get an all-zero vector of size
    ``out_dim``.

    Args:
        player_dim: width of the per-player feature row (the rows built here
            have 6 entries).
        hidden: hidden width of the per-player MLP.
        out_dim: size of the returned embedding per event.
    """

    def __init__(self, player_dim=6, hidden=64, out_dim=128):
        super().__init__()
        # Remembered so the zero fallback always matches the MLP output width.
        self.out_dim = out_dim
        self.player_mlp = PlayerMLP(in_dim=player_dim, hidden=hidden, out_dim=out_dim)

    @staticmethod
    def _player_rows(ff, ax, ay):
        """Build one feature row per player in ``ff`` that has a 2-D location."""
        rows = []
        for p in ff:
            loc = p.get('location')
            if loc is None or len(loc) < 2:
                continue
            dx = float(loc[0]) - ax
            dy = float(loc[1]) - ay
            dist = math.sqrt(dx * dx + dy * dy)
            angle = math.atan2(dy, dx)
            is_teammate = 1.0 if p.get('teammate', False) else 0.0
            is_keeper = 1.0 if p.get('keeper', False) else 0.0
            rows.append([dx, dy, dist, angle, is_teammate, is_keeper])
        return rows

    def forward(self, freeze_frames, actor_locs, device):
        """Return a ``(len(freeze_frames), out_dim)`` tensor on ``device``.

        Args:
            freeze_frames: per-event iterable of player dicts, or ``None`` /
                an empty container when the event has no frame.
            actor_locs: per-event ``(x, y)`` of the acting player.
            device: device on which the tensors are created.
        """
        batch_embeds = []
        for ff, (ax, ay) in zip(freeze_frames, actor_locs):
            # handle None and empty lists/arrays/tensors
            is_empty = ff is None or (hasattr(ff, '__len__') and len(ff) == 0)
            rows = [] if is_empty else self._player_rows(ff, ax, ay)
            if not rows:
                batch_embeds.append(torch.zeros(self.out_dim, device=device))
                continue
            # One tensor per frame instead of one device tensor per player.
            players = torch.tensor(rows, device=device)
            batch_embeds.append(self.player_mlp(players).mean(dim=0))
        if not batch_embeds:
            # torch.stack rejects an empty list; return an empty batch instead.
            return torch.zeros(0, self.out_dim, device=device)
        return torch.stack(batch_embeds, dim=0)


class EventTransformer(nn.Module):
    """Transformer encoder over one token per categorical event feature.

    Each token is the sum of a value embedding (one table per feature) and a
    learned embedding identifying the feature's column position.
    """

    def __init__(self, vocab_sizes, d_model=128, nhead=4, num_layers=2):
        super().__init__()
        self.features = list(vocab_sizes.keys())
        # Feature names may contain dots, which ModuleDict keys must not;
        # address the tables through positional aliases f0, f1, ...
        self.safe_names = [f"f{i}" for i in range(len(self.features))]
        self.name_map = dict(zip(self.features, self.safe_names))

        self.value_embeds = nn.ModuleDict()
        for feat, alias in self.name_map.items():
            self.value_embeds[alias] = nn.Embedding(vocab_sizes[feat], d_model)
        self.feature_embeds = nn.Embedding(len(self.features), d_model)

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, feat_ids):
        """Return ``(z_event, h)``: mean-pooled and per-token encoder outputs."""
        batch_size, num_cols = feat_ids.shape  # unpacking also enforces a 2-D input
        column_ids = torch.arange(len(self.features), device=feat_ids.device)
        column_emb = self.feature_embeds(column_ids)
        tokens = [
            self.value_embeds[alias](feat_ids[:, col]) + column_emb[col]
            for col, alias in enumerate(self.safe_names)
        ]
        x = torch.stack(tokens, dim=1)
        h = self.encoder(x)
        return h.mean(dim=1), h


class EventEncoder(nn.Module):
    """Fuse the event-feature embedding with the freeze-frame embedding.

    A sigmoid gate computed from both 128-d embeddings mixes them
    element-wise: ``z = g * z_event + (1 - g) * z_frame``.
    """

    def __init__(self, vocab_sizes):
        super().__init__()
        self.event_encoder = EventTransformer(vocab_sizes)
        self.frame_encoder = SetEncoder()
        gate_layers = (nn.Linear(128 * 2, 128), nn.Sigmoid())
        self.gate = nn.Sequential(*gate_layers)

    def forward(self, feat_ids, freeze_frames, actor_locs, device):
        """Return ``(z, h_tokens)``: the fused embedding and the per-token outputs."""
        z_event, h_tokens = self.event_encoder(feat_ids)
        z_frame = self.frame_encoder(freeze_frames, actor_locs, device)
        gate_input = torch.cat((z_event, z_frame), dim=-1)
        g = self.gate(gate_input)
        frame_share = 1 - g
        z = g * z_event + frame_share * z_frame
        return z, h_tokens


In [10]:
# Masked Attribute Modeling

class MAMHead(nn.Module):
    """One linear classifier per feature for masked attribute modeling.

    Args:
        vocab_sizes: mapping ``feature name -> number of classes``; its key
            order must match the token order of the representations passed to
            ``forward``.
        d_model: width of each token representation. Defaults to 128, the
            width produced by ``EventTransformer`` with its default settings.
    """

    def __init__(self, vocab_sizes, d_model=128):
        super().__init__()
        self.features = list(vocab_sizes.keys())
        # ModuleDict keys cannot include dots; use positional aliases instead.
        self.safe_names = [f"f{i}" for i in range(len(self.features))]
        self.name_map = dict(zip(self.features, self.safe_names))
        self.heads = nn.ModuleDict({
            self.name_map[f]: nn.Linear(d_model, vocab_sizes[f]) for f in self.features
        })

    def forward(self, token_reprs):
        """Map ``token_reprs[:, i, :]`` to logits for the i-th feature.

        Returns:
            dict ``feature name -> logits`` with one row per batch element and
            one column per class of that feature.
        """
        return {
            f: self.heads[self.name_map[f]](token_reprs[:, i, :])
            for i, f in enumerate(self.features)
        }


def mask_inputs(feat_ids, mask_prob=0.15, vocab_sizes=None, mask_token_id=1):
    """Corrupt a batch of feature indices for masked attribute modeling.

    Each position is selected independently with probability ``mask_prob``.
    A selected position is replaced by ``mask_token_id`` 80% of the time, by a
    random index of its feature's vocabulary 10% of the time (only when
    ``vocab_sizes`` is given), and left unchanged otherwise.

    The work is done with whole-tensor operations on ``feat_ids``' device, so
    randomness comes from the PyTorch generator (seeded by
    ``torch.manual_seed``) instead of Python's ``random`` module.

    Args:
        feat_ids: 2-D integer tensor ``(batch, num_features)``; not modified.
        mask_prob: selection probability per position.
        vocab_sizes: optional sequence with one vocabulary size per feature
            column, used to draw random replacement indices.
        mask_token_id: index of the mask token; 1 matches the vocabularies
            built in this notebook.

    Returns:
        ``(masked, labels)``: the corrupted copy of ``feat_ids`` and a tensor
        holding the original index at selected positions and -100 elsewhere.

    Raises:
        ValueError: if ``vocab_sizes`` does not have one entry per column.
    """
    num_rows, num_feats = feat_ids.shape
    dev = feat_ids.device

    selected = torch.rand(num_rows, num_feats, device=dev) < mask_prob
    # One uniform draw per position decides between mask / random / keep.
    action = torch.rand(num_rows, num_feats, device=dev)

    labels = feat_ids.masked_fill(~selected, -100)
    masked = feat_ids.masked_fill(selected & (action < 0.8), mask_token_id)

    if vocab_sizes is not None:
        sizes = torch.as_tensor(vocab_sizes, device=dev).reshape(-1)
        if sizes.numel() != num_feats:
            raise ValueError(
                f"vocab_sizes has {sizes.numel()} entries but feat_ids has {num_feats} columns"
            )
        draws = (torch.rand(num_rows, num_feats, device=dev) * sizes).to(feat_ids.dtype)
        # Guard against float rounding producing an index equal to the vocab size.
        random_ids = torch.minimum(draws, (sizes - 1).to(feat_ids.dtype))
        swap = selected & (action >= 0.8) & (action < 0.9)
        masked = torch.where(swap, random_ids, masked)

    return masked, labels


In [11]:
# Training loop

# Load every event of DATA_PATH into memory; items are encoded to vocabulary
# indices lazily in EventDataset.__getitem__.
dataset = EventDataset(DATA_PATH, feature_vocab)

def collate_fn(batch):
    """Assemble dataset items into ``(feat_ids, freeze_frames, actor_locs)``.

    The first element of every item is stacked into one tensor. The second
    element, when present, is forwarded as that item's freeze frame (else an
    empty list); the third, when present, as the actor location (else
    ``(0.0, 0.0)``).
    """
    id_rows, freeze_frames, actor_locs = [], [], []
    for sample in batch:
        id_rows.append(sample[0])
        freeze_frames.append(sample[1] if len(sample) > 1 else [])
        actor_locs.append(sample[2] if len(sample) > 2 else (0.0, 0.0))
    feat_ids = torch.stack(id_rows, dim=0)
    return feat_ids, freeze_frames, actor_locs

loader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=0, collate_fn=collate_fn)

model = EventEncoder(vocab_sizes).to(device)
head = MAMHead(vocab_sizes).to(device)

# NOTE: the loss below is computed only from the per-token transformer outputs
# (the fused embedding returned by the model is discarded), so the frame
# encoder and the fusion gate receive no gradient in this loop.
optimizer = torch.optim.Adam(list(model.parameters()) + list(head.parameters()), lr=1e-3)
# Positions labelled -100 (not selected for masking) do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=-100)

# Column order of feat_ids / token representations, and the matching vocab sizes.
feature_list = list(vocab_sizes.keys())
feature_vocab_sizes = [vocab_sizes[f] for f in feature_list]

model.train()
head.train()

for epoch in range(3):
    total_loss = 0.0
    num_batches = 0
    for feat_ids, freeze_frames, actor_locs in loader:
        feat_ids = feat_ids.to(device)

        masked, labels = mask_inputs(feat_ids, mask_prob=0.15, vocab_sizes=feature_vocab_sizes)
        masked = masked.to(device)
        labels = labels.to(device)

        # With mean reduction, cross-entropy is NaN for a column whose targets
        # are all ignored (likely on a small final batch). Find the columns
        # that have at least one target with a single host sync per batch.
        has_target = (labels != -100).any(dim=0).tolist()
        if not any(has_target):
            continue

        _, token_reprs = model(masked, freeze_frames, actor_locs, device)
        logits = head(token_reprs)

        per_feature = [
            criterion(logits[f], labels[:, i])
            for i, f in enumerate(feature_list)
            if has_target[i]
        ]
        # Average over the features that actually had masked positions.
        loss = torch.stack(per_feature).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    print(f"Epoch {epoch+1} loss: {total_loss/max(num_batches, 1):.4f}")


Epoch 1 loss: 0.8700
Epoch 2 loss: 0.6477
Epoch 3 loss: 0.6143


In [12]:
# Save model

# Bundle the weights of both modules with the vocabularies needed to encode
# events the same way at load time.
state = {}
state['event_encoder'] = model.state_dict()
state['mam_head'] = head.state_dict()
state['feature_vocab'] = feature_vocab

torch.save(state, MODEL_OUT)
print('Saved model to', MODEL_OUT.resolve())

Saved model to /content/drive/MyDrive/MLSE/models/event_encoder_mam.pt
