In [1]:
import os
# Configure PyTorch's CUDA caching allocator to use expandable segments.
# This is set in the very first cell so it is in the environment before torch is imported.
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"})

In [2]:
# update colab repo
!rm -rf /content/tennis_vision/
!git clone https://github.com/relja456/tennis_vision.git

Cloning into 'tennis_vision'...
remote: Enumerating objects: 71, done.[K
remote: Counting objects: 100% (27/27), done.[K
remote: Compressing objects: 100% (18/18), done.[K
remote: Total 71 (delta 8), reused 20 (delta 5), pack-reused 44 (from 2)[K
Receiving objects: 100% (71/71), 174.40 MiB | 20.74 MiB/s, done.
Resolving deltas: 100% (24/24), done.


In [3]:
import torch
from torch.cuda.amp import GradScaler, autocast
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from tennis_vision.dataset_init.dataset_ball_only_triplet_frames import DatasetBallOnlyTripletFrames
from tennis_vision.models.tracknet.tracknet import TrackNet

In [5]:
import gdown

# Google Drive file ID of the archive that the next cell unzips.
GDRIVE_FILE_ID = "1sUbc-TXS0pVA9GG84qmRgesOuNYwT9xk"

# Download to the same absolute path the unzip cell reads from, so the two cells
# agree regardless of the current working directory.
gdown.download(f"https://drive.google.com/uc?id={GDRIVE_FILE_ID}", '/content/gdrive_data', quiet=False)


Downloading...
From (original): https://drive.google.com/uc?id=1sUbc-TXS0pVA9GG84qmRgesOuNYwT9xk
From (redirected): https://drive.google.com/uc?id=1sUbc-TXS0pVA9GG84qmRgesOuNYwT9xk&confirm=t&uuid=a048e94d-182e-4ec3-a23d-16ef1e17050c
To: /content/gdrive_data
100%|██████████| 2.57G/2.57G [00:45<00:00, 56.4MB/s]


'gdrive_data'

In [8]:
!unzip -q "/content/gdrive_data" -d "/content/gdrive_ds"

In [4]:
# Attach the user's Google Drive to the Colab filesystem.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [4]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [5]:
# Build the network, then move it to the selected device.
# in_channels=9 matches the 9-channel (three stacked 3-channel frames) input built by the inference cell.
model = TrackNet(in_channels=9, num_bins=1)
model = model.to(device)

  self.scaler = GradScaler(enabled=torch.cuda.is_available())


In [11]:
DATASET_ROOT_PATH = '/content/gdrive_ds/ball_only'

# Full dataset; samples are produced at out_size=(352, 640).
train_ds = DatasetBallOnlyTripletFrames(DATASET_ROOT_PATH, out_size=(352, 640))

# Hold out 10% of the samples (rounded down) for validation.
# The seeded generator makes the split identical on every run.
n_samples = len(train_ds)
val_split = int(n_samples * 0.1)
split_generator = torch.Generator().manual_seed(42)
train_subset, val_subset = torch.utils.data.random_split(
    train_ds, [n_samples - val_split, val_split], generator=split_generator
)

# Settings shared by both loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=16, num_workers=2, pin_memory=True, persistent_workers=True)
train_loader = DataLoader(train_subset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_subset, shuffle=False, **loader_kwargs)

In [12]:
# Run the project's training routine on the train/validation loaders.
# NOTE(review): the first argument (5) is presumably the number of epochs — confirm
# against TrackNet.train_override, whose signature is not visible in this notebook.
model.train_override(5, train_loader, val_loader)

Training:   0%|          | 0/1116 [00:00<?, ?it/s]

Validation:   0%|          | 0/124 [00:00<?, ?it/s]

epoch 01 | train 0.0889 | val 0.0213
  ↳ saved /content/tennis_vision/models/tracknet/tracknet_best.pth


Training:   0%|          | 0/1116 [00:00<?, ?it/s]

Validation:   0%|          | 0/124 [00:00<?, ?it/s]

epoch 02 | train 0.0101 | val 0.0051
  ↳ saved /content/tennis_vision/models/tracknet/tracknet_best.pth


Training:   0%|          | 0/1116 [00:00<?, ?it/s]

Validation:   0%|          | 0/124 [00:00<?, ?it/s]

epoch 03 | train 0.0037 | val 0.0028
  ↳ saved /content/tennis_vision/models/tracknet/tracknet_best.pth


Training:   0%|          | 0/1116 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [6]:
import os, csv, cv2, math
import numpy as np
import torch
import torch.nn.functional as F

@torch.inference_mode()
def infer_and_save_video_from_file(
    model,
    video_path,
    weights="/content/tracknet_best.pth",
    out_path="/content/output_pp.mp4",
    fps=None,                     # if None, use source fps
    heat_alpha=0.40,
    # peak finding
    base_thresh=0.35,
    nms_ks=7,
    topk=10,
    # normalization / dynamic threshold
    use_zscore=True,
    perc_floor=97.5,
    # look-ahead global path (beam search)
    beam_width=6,
    w_conf=1.0,
    w_vel=0.06,
    w_acc=0.01,
    max_step_px=110,
    # batching
    batch_size=8,
    # ROI mask
    roi_mask=None,                # numpy bool/float mask; either (H,W) or (orig_h,orig_w)
    # drawing
    trail_len=3,
    dot_radius=3,
    # outputs
    save_csv=None,
    last_n=None,                  # if set, only process last N usable frames
    # sizes / preprocessing
    model_in_size=(352, 640),     # (H, W) input size for model triplets
    heat_out_size=(352, 640),     # (H, W) size to which logits are resized for peak finding
    preprocess_fn=None,           # callable(frames_bgr_3:list[np.ndarray], in_size:(H,W)) -> torch.FloatTensor (9,H,W)
):
    """
    Track a single object through a video file and write an annotated MP4 (and optional CSV).

    Pipeline:
      1. Read every frame of ``video_path`` into memory.
      2. Build 3-frame windows (t-2, t-1, t) -> 9-channel tensor; predictions are
         aligned to frame ``t``, so the first two source frames are not rendered.
      3. Run ``model`` in batches; apply sigmoid and resize logits to ``heat_out_size``.
      4. Extract peaks per frame (optional z-score squashing, percentile floor, NMS).
      5. Beam-search a path across frames, rewarding peak confidence and agreement
         with constant-velocity prediction, penalising acceleration, and rejecting
         steps longer than ``max_step_px``.
      6. Refine the chosen peak with an intensity-weighted centroid on the raw
         probability map (3x3 neighbourhood).
      7. Overlay the heatmap, draw a green dot/trail, and write frames at the
         source resolution.

    Args:
        model: network called as ``model(x)`` with ``x`` of shape (B, 9, H_in, W_in);
            channel 0 of its output is used as the logit map. It is moved to the
            selected device and put in eval mode; ``weights`` is loaded into it.
        video_path: path of the input video.
        weights: checkpoint path; either a state dict or a dict with a ``"model"`` key.
        out_path: destination MP4 (``mp4v`` codec).
        fps: output frame rate; ``None`` uses the source rate (25.0 if unreported).
        heat_alpha: blend weight of the heatmap overlay.
        base_thresh: lower bound for the peak threshold.
        nms_ks: NMS window size; even values are rounded up to the next odd value.
        topk: maximum number of peaks kept per frame.
        use_zscore: if True, peaks are found on sigmoid((p - mean) / std).
        perc_floor: percentile of the (normalised) map used as dynamic threshold.
        beam_width: number of hypotheses kept per frame (minimum 1).
        w_conf, w_vel, w_acc: score weights for confidence, velocity prior, acceleration.
        max_step_px: maximum allowed displacement between consecutive positions,
            measured in ``heat_out_size`` pixels.
        batch_size: number of triplets per forward pass (minimum 1).
        roi_mask: optional array of shape ``heat_out_size`` or (orig_h, orig_w);
            values > 0.5 are kept, everything else is zeroed in the probability map.
        trail_len: number of recent positions joined by line segments (0 disables).
        dot_radius: radius of the position marker.
        save_csv: optional CSV path; rows are ``frame_index, x, y`` in source-resolution
            pixels, with empty x/y for frames without a detection.
        last_n: if set, only the last N usable frames are processed.
        model_in_size: (H, W) passed to ``preprocess_fn``.
        heat_out_size: (H, W) of the probability maps used for peak finding.
        preprocess_fn: optional replacement for the built-in preprocessing
            (resize, BGR->RGB, scale to [0, 1], stack to (9, H, W)).

    Raises:
        FileNotFoundError: the video cannot be opened.
        ValueError: the video yields no frames, or ``roi_mask`` has an unsupported shape.
        RuntimeError: the output video writer cannot be opened.

    NOTE: For best results, pass a `preprocess_fn` that matches the training transforms exactly.
    NOTE(review): all frames and probability maps are held in memory at once; long
    videos may need a streaming variant.
    """

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Guard against degenerate settings that would otherwise crash further down
    # (empty beam -> argmax of empty list; zero batch size -> range() step of 0).
    beam_width = max(1, int(beam_width))
    batch_size = max(1, int(batch_size))

    # ---------- helpers ----------
    def load_best_weights(m, path, dev):
        """Load a checkpoint into `m`, move it to `dev` and switch to eval mode."""
        # NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
        ckpt = torch.load(path, map_location=dev)
        state = ckpt.get("model", ckpt)  # accept wrapped ({"model": ...}) or bare state dicts
        m.load_state_dict(state, strict=True)
        m.to(dev).eval()

    def overlay_heatmap(orig_bgr, prob01, alpha=0.4):
        """Blend a JET-coloured version of `prob01` (values in [0, 1]) over a BGR frame."""
        hm = (prob01 * 255.0).clip(0, 255).astype(np.uint8)
        hm = cv2.applyColorMap(hm, cv2.COLORMAP_JET)
        return cv2.addWeighted(orig_bgr, 1.0, hm, alpha, 0.0)

    def prep_mask(mask, H, W, orig_h, orig_w):
        """Return the ROI mask as a float32 {0, 1} array of shape (H, W), or None."""
        if mask is None:
            return None
        m = np.asarray(mask)  # tolerate nested lists as well as arrays
        if m.shape == (orig_h, orig_w):
            m = cv2.resize(m.astype(np.float32), (W, H), interpolation=cv2.INTER_NEAREST)
        elif m.shape != (H, W):
            raise ValueError("roi_mask must be (H,W) or (orig_h,orig_w)")
        return (m > 0.5).astype(np.float32)

    def nms_peaks(prob2d, thresh, nms=7, topk=10):
        """Return up to `topk` local maxima above `thresh` as (x, y, score) tuples, best first."""
        # An even kernel with padding=k//2 yields an (H+1, W+1) map that cannot be
        # compared with `p`; force an odd window so the pooled map keeps its shape.
        nms = max(1, int(nms)) | 1
        p = prob2d[None, None]
        p = F.avg_pool2d(p, 3, 1, 1)  # denoise
        m = F.max_pool2d(p, kernel_size=nms, stride=1, padding=nms // 2)
        keep = (p == m) & (p > thresh)
        ys, xs = torch.where(keep[0, 0])
        if ys.numel() == 0:
            return []
        scores = p[0, 0, ys, xs]
        order = torch.argsort(scores, descending=True)[:topk]
        xs = xs[order].int().tolist()
        ys = ys[order].int().tolist()
        sc = scores[order].float().tolist()
        return list(zip(xs, ys, sc))

    def subpixel_refine(prob_np, x, y):
        """Intensity-weighted centroid of the (clipped) 3x3 patch around integer (x, y)."""
        H, W = prob_np.shape
        x0, x1 = max(0, x - 1), min(W - 1, x + 1)
        y0, y1 = max(0, y - 1), min(H - 1, y + 1)
        patch = prob_np[y0:y1 + 1, x0:x1 + 1]
        s = patch.sum()
        if patch.size == 0 or s < 1e-6:
            # No usable mass around the peak: keep the integer location.
            return float(x), float(y)
        ys, xs = np.mgrid[y0:y1 + 1, x0:x1 + 1]
        w = patch / (s + 1e-6)
        return float((w * xs).sum()), float((w * ys).sum())

    def vel_prior_reward(prev_pos, prev_vel, cand_pos, sigma=25.0):
        """Gaussian reward in (0, 1] for landing near the constant-velocity prediction."""
        if prev_pos is None or prev_vel is None:
            return 0.0
        px, py = prev_pos
        vx, vy = prev_vel
        tx, ty = px + vx, py + vy
        dx, dy = cand_pos[0] - tx, cand_pos[1] - ty
        d2 = dx * dx + dy * dy
        return math.exp(-0.5 * d2 / (sigma * sigma))

    def default_preprocess(frames3_bgr, in_size):
        """Resize each frame to (H, W), convert BGR->RGB, scale to [0, 1], stack to (9, H, W)."""
        H, W = in_size
        arrs = []
        for fr in frames3_bgr:
            fr = cv2.resize(fr, (W, H), interpolation=cv2.INTER_LINEAR)
            fr = cv2.cvtColor(fr, cv2.COLOR_BGR2RGB)
            fr = fr.astype(np.float32) / 255.0
            arrs.append(fr.transpose(2, 0, 1))  # (3,H,W)
        x = np.concatenate(arrs, axis=0)        # (9,H,W)
        return torch.from_numpy(x).float()

    if preprocess_fn is None:
        preprocess_fn = default_preprocess

    H_out, W_out = heat_out_size

    # ---------- open video, validate inputs, read all frames ----------
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise FileNotFoundError(f"Cannot open video: {video_path}")

        src_fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
        if fps is None:
            fps = src_fps

        orig_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        orig_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Fail fast on a bad checkpoint or mask before the (expensive) frame read.
        load_best_weights(model, weights, device)
        mask_hw = prep_mask(roi_mask, H_out, W_out, orig_h, orig_w)

        frames = []
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            frames.append(frame)
    finally:
        # Release the capture handle even if loading weights or decoding fails.
        cap.release()

    if not frames:
        # Previously this fell through to frames[-1] on an empty list (IndexError).
        raise ValueError(f"No frames could be decoded from video: {video_path}")

    if len(frames) < 3:
        # Need at least three to form first triplet
        print("⚠️ Need at least 3 frames; duplicating edges to proceed.")
        frames.extend([frames[-1]] * (3 - len(frames)))

    # Indices of the 't' frame of each (t-2, t-1, t) window; start at 2 so t-2 exists.
    idxs = list(range(2, len(frames)))
    if last_n is not None:
        idxs = idxs[-int(last_n):]

    # ---------- batched inference over triplets ----------
    probs = [None] * len(idxs)   # prob maps aligned with idxs positions (H_out, W_out)
    for start in range(0, len(idxs), batch_size):
        batch_positions = range(start, min(start + batch_size, len(idxs)))
        batch_tensors = []
        for bi in batch_positions:
            t = idxs[bi]
            trip = [frames[t - 2], frames[t - 1], frames[t]]
            batch_tensors.append(preprocess_fn(trip, model_in_size))  # (9, H_in, W_in)
        xb = torch.stack(batch_tensors, dim=0).to(device)             # (B,9,H_in,W_in)

        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=torch.cuda.is_available()):
            logit = model(xb)  # (B,1,h,w) – any output size
            if logit.shape[-2:] != (H_out, W_out):
                logit = F.interpolate(logit, size=(H_out, W_out), mode="bilinear", align_corners=False)
            pb = torch.sigmoid(logit).float().cpu()[:, 0]  # (B,H_out,W_out)

        for k, bi in enumerate(batch_positions):
            p = pb[k].numpy()
            if mask_hw is not None:
                p = p * mask_hw
            probs[bi] = p

    # ---------- peak extraction ----------
    peaks_list = []
    for p in probs:
        if use_zscore:
            # Squash the z-scored map through a sigmoid so thresholds stay in (0, 1).
            mu, sd = float(p.mean()), float(p.std()) + 1e-6
            p_norm = (p - mu) / sd
            p_view = 1.0 / (1.0 + np.exp(-p_norm))
        else:
            p_view = p
        pf = np.percentile(p_view, perc_floor)
        thresh = max(base_thresh, pf)
        peaks_list.append(nms_peaks(torch.from_numpy(p_view), thresh=thresh, nms=nms_ks, topk=topk))

    # ---------- beam search with velocity + acceleration ----------
    # Each hypothesis: position/velocity in heat-map pixels, cumulative score, index of
    # its parent in the previous beam ("back"), and whether its position comes from a
    # peak in *this* frame ("det") rather than being carried over.
    beams = []
    if peaks_list and peaks_list[0]:
        init = []
        for (x, y, s) in peaks_list[0][:beam_width]:
            init.append({"pos": (float(x), float(y)), "vel": None, "score": float(s),
                         "back": -1, "det": True})
        beams.append(init)
    else:
        beams.append([{"pos": None, "vel": None, "score": 0.0, "back": -1, "det": False}])

    for t in range(1, len(idxs)):
        prev_beam = beams[-1]
        curr = peaks_list[t]
        cand_list = []

        if not curr:
            # No peaks in this frame: carry every hypothesis forward with a miss penalty.
            for i, h in enumerate(prev_beam):
                cand_list.append({"pos": h["pos"], "vel": h["vel"], "score": h["score"] - 0.5,
                                  "back": i, "det": False})
        else:
            for i, h in enumerate(prev_beam):
                if h["pos"] is None:
                    # Track not started yet: any peak may seed it.
                    for (x, y, s) in curr:
                        cand_list.append({"pos": (float(x), float(y)), "vel": None,
                                          "score": h["score"] + w_conf * float(s),
                                          "back": i, "det": True})
                else:
                    px, py = h["pos"]
                    vx, vy = h["vel"] if h["vel"] is not None else (0.0, 0.0)
                    for (x, y, s) in curr:
                        d = math.hypot(x - px, y - py)
                        if d > max_step_px:
                            continue  # physically implausible jump
                        nvx, nvy = (x - px), (y - py)
                        acc = math.hypot(nvx - vx, nvy - vy)
                        vpr = vel_prior_reward(h["pos"], h["vel"], (x, y), sigma=25.0)
                        score = h["score"] + w_conf * float(s) + w_vel * vpr - w_acc * acc
                        cand_list.append({"pos": (float(x), float(y)), "vel": (nvx, nvy),
                                          "score": score, "back": i, "det": True})

        if not cand_list:
            # Every peak was gated out: keep the best previous hypothesis, penalised,
            # and mark it as not detected so no stale coordinate is reported for this frame.
            # NOTE(review): the carried position is never reset, so after a large jump
            # (e.g. a scene cut) the track may stay lost — consider re-seeding.
            j = int(np.argmax([h["score"] for h in prev_beam]))
            beams.append([{"pos": prev_beam[j]["pos"], "vel": prev_beam[j]["vel"],
                           "score": prev_beam[j]["score"] - 0.5, "back": j, "det": False}])
        else:
            cand_list.sort(key=lambda h: h["score"], reverse=True)
            beams.append(cand_list[:beam_width])

    # Backtrack from the best final hypothesis through the parent indices.
    last = beams[-1]
    j = int(np.argmax([h["score"] for h in last]))
    sel = [None] * len(beams)
    for t in range(len(beams) - 1, -1, -1):
        sel[t] = beams[t][j]
        j = beams[t][j]["back"] if beams[t][j]["back"] != -1 else 0

    # Subpixel refine on original prob maps (not percentile-thresholded).
    sel_xy = []
    for t, h in enumerate(sel):
        pos = h["pos"]
        if pos is None or not h["det"]:
            sel_xy.append(None)
        else:
            cx, cy = subpixel_refine(probs[t], int(round(pos[0])), int(round(pos[1])))
            sel_xy.append((cx, cy))

    # ---------- render + CSV ----------
    sx, sy = orig_w / float(W_out), orig_h / float(H_out)  # heat-map -> source pixel scale
    trail = []

    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (orig_w, orig_h))
    csv_f = None
    try:
        if not writer.isOpened():
            # VideoWriter fails silently (e.g. missing codec); surface it instead of
            # reporting success for a file that was never written.
            raise RuntimeError(f"Cannot open video writer for: {out_path}")

        if save_csv:
            os.makedirs(os.path.dirname(save_csv) or ".", exist_ok=True)
            csv_f = open(save_csv, "w", newline="")
            csv_w = csv.writer(csv_f)
            csv_w.writerow(["frame_index", "x", "y"])  # original-res coordinates

        # idxs[t] is the video frame aligned to probs[t]
        for t, vf in enumerate(idxs):
            orig = frames[vf].copy()
            p_big = cv2.resize(probs[t], (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
            frame = overlay_heatmap(orig, p_big, alpha=heat_alpha)

            if sel_xy[t] is not None:
                ux = int(round(sel_xy[t][0] * sx))
                uy = int(round(sel_xy[t][1] * sy))

                if trail_len > 0:
                    trail.append((ux, uy))
                    if len(trail) > trail_len:
                        trail.pop(0)
                    for k in range(1, len(trail)):
                        cv2.line(frame, trail[k - 1], trail[k], (0, 255, 0), 2, cv2.LINE_AA)

                cv2.circle(frame, (ux, uy), dot_radius, (0, 255, 0), -1)

                if save_csv:
                    csv_w.writerow([vf, ux, uy])
            else:
                if save_csv:
                    csv_w.writerow([vf, "", ""])

            writer.write(frame)
    finally:
        # Always finalise the outputs, even if rendering raises part-way through.
        writer.release()
        if csv_f is not None:
            csv_f.close()

    print(f"✅ Saved video to {out_path}" + (f" and CSV to {save_csv}" if save_csv else ""))


In [7]:
# Source clip and the checkpoint that the training run saved as its best model.
VIDEO_PATH = "/content/Novak Djokovic v Carlos Alcaraz Extended Highlights ｜ Australian Open 2025 Quarterfinal.mp4"
WEIGHTS_PATH = "/content/tennis_vision/models/tracknet/tracknet_best.pth"

# Where the annotated video and the per-frame coordinates are written.
OUTPUT_VIDEO_PATH = "/content/output_pp2.mp4"
OUTPUT_CSV_PATH = "/content/track.csv"

# Both sizes are (H, W). The first is what the model is fed; the second is the
# heatmap resolution used before the overlay is upscaled to the source frame size.
FRAME_HW = (352, 640)

infer_and_save_video_from_file(
    model,
    video_path=VIDEO_PATH,
    weights=WEIGHTS_PATH,
    out_path=OUTPUT_VIDEO_PATH,
    save_csv=OUTPUT_CSV_PATH,
    model_in_size=FRAME_HW,
    heat_out_size=FRAME_HW,
)

✅ Saved video to /content/output_pp2.mp4 and CSV to /content/track.csv
