In [46]:
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, GATv2Conv
from torch_geometric.utils import to_undirected

from typing import Tuple, Dict, List
from PIL import Image
from pathlib import Path
from tqdm import tqdm

In [47]:
class RFEncoder(nn.Module):
    def __init__(self, in_dim: int = 2, hid: int = 64, heads: int = 2):
        super().__init__()
        self.conv1 = HeteroConv(
            {
                ('pixel', 'adjacent', 'pixel'):  # ← 3-tuple!
                    GATv2Conv(in_dim, hid,
                              heads=heads,
                              concat=False,
                              add_self_loops=False),
                ('pixel', 'ray', 'pixel'):       # ← 3-tuple!
                    GATv2Conv(in_dim, hid,
                              heads=heads,
                              concat=False,
                              add_self_loops=False),
            },
            aggr='mean',
        )

        self.conv2 = HeteroConv(
            {
                ('pixel', 'adjacent', 'pixel'):
                    GATv2Conv(hid, hid,
                              heads=heads,
                              concat=False,
                              add_self_loops=False),
                ('pixel', 'ray', 'pixel'):
                    GATv2Conv(hid, hid,
                              heads=heads,
                              concat=False,
                              add_self_loops=False),
            },
            aggr='mean',
        )

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {k: F.relu(v) for k, v in x_dict.items()}
        x_dict = self.conv2(x_dict, edge_index_dict)
        return x_dict


class RFPredictor(nn.Module):
    def __init__(self, node_dim=64):
        super().__init__()
        self.enc = RFEncoder() # ! changed from out_dim=node_dim
        self.mlp = nn.Sequential(
            nn.Linear(node_dim * 2, node_dim),
            nn.ReLU(),
            nn.Linear(node_dim, 1),
        )

    def forward(self,
                data: HeteroData,
                tx_idx,
                rx_idx):
        # --- 1. get pixel embeddings ------------------------------------
        z = self.enc(data.x_dict, data.edge_index_dict)["pixel"]

        # --- 2. normalise the indices -----------------------------------
        tx_idx = torch.as_tensor(tx_idx, dtype=torch.long, device=z.device)
        rx_idx = torch.as_tensor(rx_idx, dtype=torch.long, device=z.device)

        if tx_idx.dim() == 0:
            tx_idx = tx_idx.unsqueeze(0)
            rx_idx = rx_idx.unsqueeze(0)

        # --- 3. gather & predict ----------------------------------------
        pair_emb = torch.cat([z[tx_idx], z[rx_idx]], dim=-1)
        pred = self.mlp(pair_emb).squeeze(-1)
        return pred

In [48]:
def load_walkable_nodes(mask_path: Path, cell_size: int) -> Tuple[np.ndarray, Dict[Tuple[int, int], int]]:
    img = Image.open(mask_path).convert("L")  # grayscale
    mask = np.array(img) > 0  # bool
    pooled = pool_mask(mask, cell_size)
    coords = np.argwhere(pooled)  # (row, col)
    id_map = {tuple(coord): idx for idx, coord in enumerate(coords)}
    return coords, id_map, pooled.shape[::-1]  # coords, map, (Wc, Hc)


# -----------------------------------------------------------
# Graph edge construction helpers
# -----------------------------------------------------------

def make_adjacent_edges(coords: np.ndarray, id_map: Dict[Tuple[int, int], int]) -> List[Tuple[int, int]]:
    dirs = np.array([[1, 0], [-1, 0], [0, 1], [0, -1]], dtype=int)
    edges = []
    for coord in coords:
        for d in dirs:
            nb = tuple(coord + d)
            if nb in id_map:
                edges.append((id_map[tuple(coord)], id_map[nb]))
    return edges


def bresenham(p0: Tuple[int, int], p1: Tuple[int, int]) -> List[Tuple[int, int]]:
    x0, y0 = p0
    x1, y1 = p1
    dx = abs(x1 - x0)
    dy = -abs(y1 - y0)
    sx = 1 if x0 < x1 else -1
    sy = 1 if y0 < y1 else -1
    err = dx + dy
    line = []
    while True:
        line.append((x0, y0))
        if x0 == x1 and y0 == y1:
            break
        e2 = 2 * err
        if e2 >= dy:
            err += dy
            x0 += sx
        if e2 <= dx:
            err += dx
            y0 += sy
    return line


def make_ray_edges(df: pd.DataFrame, id_map: Dict[Tuple[int, int], int], S: int) -> List[Tuple[int, int]]:
    edges = set()
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Trace rays"):
        tx = downsample_coord(int(row.tx_location_i), int(row.tx_location_j), S)
        rx = downsample_coord(int(row.i), int(row.j), S)
        if tx not in id_map or rx not in id_map:
            continue
        for p0, p1 in zip(bresenham(tx, rx)[:-1], bresenham(tx, rx)[1:]):
            if p0 in id_map and p1 in id_map:
                u, v = id_map[p0], id_map[p1]
                edges.add((u, v))
                edges.add((v, u))
    return list(edges)

In [49]:
def pool_mask(mask: np.ndarray, cell_size: int = 4):
    h, w = mask.shape
    h_crop = (h // cell_size) * cell_size         # largest multiple ≤ h
    w_crop = (w // cell_size) * cell_size
    mask = mask[:h_crop, :w_crop]                 # throw away the ragged fringe

    # now safe to reshape/pool
    pooled = mask.reshape(h_crop // cell_size, cell_size,
                          w_crop // cell_size, cell_size
                         ).max(axis=(1, 3))       # OR .mean(...)
    return pooled


def downsample_coord(row: int, col: int, S: int) -> Tuple[int, int]:
    """Map original‑resolution (row, col) to coarse grid indices."""
    return row // S, col // S

## Testing Script

In [50]:
## loading model .pt weights
state_dict = torch.load('best_model.pt', 
                        map_location=torch.device("cpu"),
                        weights_only=True
                       )
hidden_dim = 64
model = RFPredictor(hidden_dim)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [51]:
cell_size = 4

coords, id_map, (Wc, Hc) = load_walkable_nodes('train_data/walkable_mask.png', cell_size)
adj = make_adjacent_edges(coords, id_map)
df = pd.read_csv('train_data/training_walks.csv', delimiter=',')

CACHE = f"train_data/ray_edges_cs{cell_size}.pt"
if not os.path.exists(CACHE):
    ray = make_ray_edges(df, id_map, cell_size)
    tmp = CACHE + ".tmp"
    torch.save({"cell_size": cell_size,
                "edge_index": ray}, tmp)
    os.replace(tmp, CACHE)      # atomic move
else:
    blob = torch.load(CACHE)
    assert blob["cell_size"] == cell_size
    ray = blob["edge_index"]

data = HeteroData()
# node features: normalised coarse‑grid coords (x=j, y=i)
xy = coords[:, [1, 0]].astype(np.float32)
xy[:, 0] /= Wc; xy[:, 1] /= Hc
data["pixel"].x = torch.from_numpy(xy)
def to_idx(e):
    return to_undirected(torch.tensor(e, dtype=torch.long).t())
data["pixel", "adjacent", "pixel"].edge_index = to_idx(adj)
data["pixel", "ray", "pixel"].edge_index = to_idx(ray)

Trace rays: 100%|██████████████████| 1600000/1600000 [01:17<00:00, 20560.99it/s]
