# PlayerBERT Training (Masked Event Modeling)

This notebook:
- Loads the pre-trained `EventEncoder`.
- Builds per-player, per-match event sequences ordered by timestamp.
- Trains a **PlayerBERT** model using **Masked Event Modeling** (MSE on masked event embeddings).
- Keeps the EventEncoder frozen (for simplicity) but uses it as the embedding initializer.


In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
from pathlib import Path
import json
import math
import random
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Paths (edit if needed)
DATA_PATH = Path('/content/drive/MyDrive/MLSE/events360_v4.jsonl')
EVENT_ENCODER_CKPT = Path('/content/drive/MyDrive/MLSE/models/event_encoder_mam.pt')
PLAYERBERT_OUT = Path('/content/drive/MyDrive/MLSE/models/playerbert_mam.pt')
PLAYERBERT_OUT.parent.mkdir(parents=True, exist_ok=True)

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)
print('Data:', DATA_PATH.resolve())
print('EventEncoder:', EVENT_ENCODER_CKPT.resolve())


Device: cuda
Data: /content/drive/MyDrive/MLSE/events360_v4.jsonl
EventEncoder: /content/drive/MyDrive/MLSE/models/event_encoder_mam.pt


In [3]:
# Robust JSONL reader

def iter_json_objects(fp):
    decoder = json.JSONDecoder()
    for line in fp:
        line = line.strip()
        if not line:
            continue
        idx = 0
        while idx < len(line):
            obj, end = decoder.raw_decode(line, idx)
            yield obj
            idx = end
            while idx < len(line) and line[idx].isspace():
                idx += 1


In [4]:
# EventEncoder components (must match train_event_encoder)

class PlayerMLP(nn.Module):
    def __init__(self, in_dim=6, hidden=64, out_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x):
        return self.net(x)


class SetEncoder(nn.Module):
    def __init__(self, player_dim=6, hidden=64, out_dim=128):
        super().__init__()
        self.player_mlp = PlayerMLP(in_dim=player_dim, hidden=hidden, out_dim=out_dim)

    def forward(self, freeze_frames, actor_locs, device):
        batch_embeds = []
        for ff, (ax, ay) in zip(freeze_frames, actor_locs):
            if ff is None or (hasattr(ff, '__len__') and len(ff) == 0):
                batch_embeds.append(torch.zeros(128, device=device))
                continue
            per_player = []
            for p in ff:
                loc = p.get('location')
                if loc is None or len(loc) < 2:
                    continue
                dx = float(loc[0]) - ax
                dy = float(loc[1]) - ay
                dist = math.sqrt(dx*dx + dy*dy)
                angle = math.atan2(dy, dx)
                is_teammate = 1.0 if p.get('teammate', False) else 0.0
                is_keeper = 1.0 if p.get('keeper', False) else 0.0
                vec = torch.tensor([dx, dy, dist, angle, is_teammate, is_keeper], device=device)
                per_player.append(vec)
            if not per_player:
                batch_embeds.append(torch.zeros(128, device=device))
                continue
            players = torch.stack(per_player, dim=0)
            emb = self.player_mlp(players).mean(dim=0)
            batch_embeds.append(emb)
        return torch.stack(batch_embeds, dim=0)


class EventTransformer(nn.Module):
    def __init__(self, vocab_sizes, d_model=128, nhead=4, num_layers=2):
        super().__init__()
        self.features = list(vocab_sizes.keys())
        self.safe_names = [f"f{i}" for i in range(len(self.features))]
        self.name_map = dict(zip(self.features, self.safe_names))
        self.value_embeds = nn.ModuleDict({
            self.name_map[f]: nn.Embedding(vocab_sizes[f], d_model) for f in self.features
        })
        self.feature_embeds = nn.Embedding(len(self.features), d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, feat_ids):
        B, F = feat_ids.shape
        tokens = []
        for i, f in enumerate(self.features):
            v = self.value_embeds[self.name_map[f]](feat_ids[:, i])
            f_emb = self.feature_embeds(torch.tensor(i, device=feat_ids.device))
            tokens.append(v + f_emb)
        x = torch.stack(tokens, dim=1)
        h = self.encoder(x)
        z_event = h.mean(dim=1)
        return z_event, h


class EventEncoder(nn.Module):
    def __init__(self, vocab_sizes):
        super().__init__()
        self.event_encoder = EventTransformer(vocab_sizes)
        self.frame_encoder = SetEncoder()
        self.gate = nn.Sequential(
            nn.Linear(128 * 2, 128),
            nn.Sigmoid(),
        )

    def forward(self, feat_ids, freeze_frames, actor_locs, device):
        z_event, _ = self.event_encoder(feat_ids)
        z_frame = self.frame_encoder(freeze_frames, actor_locs, device)
        g = self.gate(torch.cat([z_event, z_frame], dim=-1))
        z = g * z_event + (1 - g) * z_frame
        return z


In [5]:
# Load EventEncoder checkpoint

ckpt = torch.load(EVENT_ENCODER_CKPT, map_location='cpu')
feature_vocab = ckpt['feature_vocab']
vocab_sizes = {k: len(v) for k, v in feature_vocab.items()}

encoder = EventEncoder(vocab_sizes).to(device)
encoder.load_state_dict(ckpt['event_encoder'])
encoder.eval()

print('Loaded EventEncoder with', len(feature_vocab), 'features')


Loaded EventEncoder with 72 features


In [6]:
# Helper: build feature ids for a flattened event

UNK_TOKEN = '[UNK]'

FEATURE_LIST = list(feature_vocab.keys())


def build_feat_ids(ev):
    ids = []
    for feat in FEATURE_LIST:
        val = ev.get(feat, UNK_TOKEN)
        if isinstance(val, bool):
            val = str(val)
        if val is None:
            val = UNK_TOKEN
        idx = feature_vocab[feat].get(val, 0)
        # safety clamp
        if idx >= len(feature_vocab[feat]):
            idx = 0
        ids.append(idx)
    return torch.tensor(ids, dtype=torch.long)


In [23]:
# Debug check: verify feature id ranges

max_ids = [0]*len(FEATURE_LIST)

with DATA_PATH.open('r', encoding='utf-8') as f:
    for ev in iter_json_objects(f):
        ids = build_feat_ids(ev)
        for i, v in enumerate(ids.tolist()):
            if v > max_ids[i]:
                max_ids[i] = v

bad = []
for i, feat in enumerate(FEATURE_LIST):
    if max_ids[i] >= len(feature_vocab[feat]):
        bad.append((feat, max_ids[i], len(feature_vocab[feat])))

print('Bad features:', bad[:10])
print('Total bad features:', len(bad))


Bad features: []
Total bad features: 0


In [7]:
# Build per-player, per-match sequences ordered by timestamp

# Group events by (match_id, player_id)
sequences = defaultdict(list)

with DATA_PATH.open('r', encoding='utf-8') as f:
    for ev in iter_json_objects(f):
        match_id = ev.get('match_id')
        player_id = ev.get('player.id')
        if match_id is None or player_id is None:
            continue
        sequences[(match_id, player_id)].append(ev)

# Sort events by period, minute, second, timestamp
for key, events in sequences.items():
    events.sort(key=lambda e: (
        e.get('period', 0),
        e.get('minute', 0),
        e.get('second', 0),
        e.get('timestamp', ''),
        e.get('index', 0),
    ))

print('Total sequences:', len(sequences))


Total sequences: 2976


In [11]:
# Dataset with sliding windows

class SequenceDataset(Dataset):
    def __init__(self, sequences, max_len=256, stride=128):
        self.samples = []
        for key, events in sequences.items():
            if len(events) == 0:
                continue
            if len(events) <= max_len:
                self.samples.append(events)
            else:
                for i in range(0, len(events) - max_len + 1, stride):
                    self.samples.append(events[i:i+max_len])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


In [16]:
# Collate function: encode events with frozen EventEncoder

def collate_fn(batch):
    # filter out empty sequences
    batch = [seq for seq in batch if len(seq) > 0]
    if len(batch) == 0:
        return None, None

    seq_lens = [len(seq) for seq in batch]
    max_len = max(seq_lens)

    embeddings = []
    for seq in batch:
        feat_ids = torch.stack([build_feat_ids(ev) for ev in seq], dim=0)
        freeze_frames = [ev.get('freeze_frame') or [] for ev in seq]
        actor_locs = []
        for ev in seq:
            loc = ev.get('location')
            if loc is None or len(loc) < 2:
                actor_locs.append((0.0, 0.0))
            else:
                actor_locs.append((float(loc[0]), float(loc[1])))

        with torch.no_grad():
            z = encoder(
                feat_ids.to(device),
                freeze_frames,
                actor_locs,
                device,
            )
        embeddings.append(z.cpu())

    d = embeddings[0].shape[-1]
    padded = torch.zeros(len(batch), max_len, d)
    attn_mask = torch.zeros(len(batch), max_len, dtype=torch.bool)
    for i, emb in enumerate(embeddings):
        L = emb.shape[0]
        padded[i, :L, :] = emb
        attn_mask[i, :L] = 1

    return padded, attn_mask


In [13]:
# PlayerBERT model

class PlayerBERT(nn.Module):
    def __init__(self, embed_dim=128, nhead=4, num_layers=2, max_len=512):
        super().__init__()
        self.pos_embed = nn.Embedding(max_len, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.mask_token = nn.Parameter(torch.zeros(embed_dim))

    def forward(self, x, attn_mask, mask_positions=None):
        # x: (B, T, D)
        B, T, D = x.shape
        pos = torch.arange(T, device=x.device).unsqueeze(0).repeat(B, 1)
        x = x + self.pos_embed(pos)

        if mask_positions is not None:
            x = x.clone()
            x[mask_positions] = self.mask_token

        # src_key_padding_mask expects True for padding
        src_key_padding_mask = ~attn_mask
        h = self.encoder(x, src_key_padding_mask=src_key_padding_mask)
        return h


In [14]:
# Masked Event Modeling on sequences

def mask_sequence(attn_mask, mask_prob=0.15):
    # attn_mask: (B, T) True for real tokens
    B, T = attn_mask.shape
    mask_positions = torch.zeros(B, T, dtype=torch.bool)
    for i in range(B):
        for j in range(T):
            if not attn_mask[i, j]:
                continue
            if random.random() < mask_prob:
                mask_positions[i, j] = True
    return mask_positions


In [17]:
# Training loop

seq_dataset = SequenceDataset(sequences, max_len=256, stride=128)
loader = DataLoader(seq_dataset, batch_size=16, shuffle=True, num_workers=0, collate_fn=collate_fn)

playerbert = PlayerBERT(embed_dim=128, nhead=4, num_layers=2, max_len=256).to(device)
optimizer = torch.optim.Adam(playerbert.parameters(), lr=1e-4)

playerbert.train()

for epoch in range(1):
    total_loss = 0.0
    for batch_emb, attn_mask in loader:
        if batch_emb is None:
            continue
        batch_emb = batch_emb.to(device)
        attn_mask = attn_mask.to(device)

        mask_pos = mask_sequence(attn_mask, mask_prob=0.15).to(device)
        outputs = playerbert(batch_emb, attn_mask, mask_positions=mask_pos)

        # MSE loss on masked positions
        target = batch_emb
        pred = outputs
        loss = ((pred - target) ** 2)[mask_pos].mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1} loss: {total_loss/len(loader):.4f}")


Epoch 1 loss: 3.4717


In [18]:
# Save PlayerBERT

state = {
    'playerbert': playerbert.state_dict(),
}

torch.save(state, PLAYERBERT_OUT)
print('Saved PlayerBERT to', PLAYERBERT_OUT.resolve())


Saved PlayerBERT to /content/drive/MyDrive/MLSE/models/playerbert_mam.pt
