Instead of having a dense label to supervise the whole heatmap, we now only have feedback for the action(s) that were actually executed (potentially 1-4 per sample, depending on the number of arms, etc.).

See discussion with chat: https://chatgpt.com/c/68b98e84-ba50-8328-a9e5-088a8882eda4

From the tossing bot paper:

"We pass gradients only through the single pixel i on which the grasping primitive was executed. All other pixels backpropagate with 0 loss."

##### Key ideas:
 - Instead of applying BCE loss over the entire (C, H, W) tensor, select logits at the executed (c, h, w) indices.
 - This ensures the computational graph only flows through those selected entries, so gradients elsewhere are zero.
 - To handle multiple executed actions per sample, gather all relevant indices, compute the loss for each, and average.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy problem dimensions: batch size, angle channels, image height, image width.
B, C, H, W = 4, 12, 64, 64
logits = torch.randn(B, C, H, W, requires_grad=True)  # stand-in for the model output
# One binary outcome per sample, for the single action executed on that sample.
labels = torch.randint(0, 2, (B,), dtype=torch.float32)

# Executed action per sample: row i is the (c, h, w) location acted on in sample i.
executed_inds = torch.tensor([
    [3, 20, 15],
    [7, 40, 10],
    [0, 12, 32],
    [5, 50, 50],
])

# Two criteria that differ only in their reduction. The 'none' variant yields
# one loss per sample and is averaged by hand further down, which is why both
# printed loss values agree.
loss_fn = nn.BCEWithLogitsLoss(reduction='none')
loss_fn_mean = nn.BCEWithLogitsLoss(reduction='mean')

# Split the (c, h, w) columns and pair them with the batch index so that a
# single advanced-indexing lookup picks exactly one logit per sample.
b_indices = torch.arange(B)
c_indices, h_indices, w_indices = executed_inds.unbind(dim=1)

# Only these entries take part in the loss, so only they receive gradient.
selected_logits = logits[b_indices, c_indices, h_indices, w_indices]
selected_labels = labels  # shape [B]

print("Selected logits shape:", selected_logits.shape)
print("Selected labels shape:", selected_labels.shape)

# BCE restricted to the selected logits, once unreduced and once mean-reduced.
loss_per_sample = loss_fn(selected_logits, selected_labels)
loss_per_sample_mean = loss_fn_mean(selected_logits, selected_labels)
print(loss_per_sample_mean.shape)
print(loss_per_sample.shape)
loss = loss_per_sample.mean()

loss.backward()
print("Loss:", loss.item())
print("Loss:", loss_per_sample_mean.item())


Let's extend this idea to work for K executed actions per heatmap.

In [None]:
import torch
import torch.nn as nn

# This assumes that the same number of actions (K) is executed for every sample in the batch.
# It hasn't been tested yet; it is saved here from the ChatGPT discussion.

def masked_bce_fixed_k(
    logits: torch.Tensor,              # [B, C, H, W]
    executed_inds: torch.Tensor,       # [B, K, 3]  (c, h, w) per action
    labels: torch.Tensor,              # [B, K]     (0/1) per action
    pos_weight: torch.Tensor | None = None,  # optional for BCEWithLogitsLoss
    reduce: str = "mean",              # "mean" | "sum" | "none"
):
    B, C, H, W = logits.shape
    B2, K, _ = executed_inds.shape
    assert B2 == B and executed_inds.shape[-1] == 3

    b = torch.arange(B, device=logits.device).unsqueeze(1).expand(B, K)  # [B,K]
    c = executed_inds[..., 0]
    h = executed_inds[..., 1]
    w = executed_inds[..., 2]

    # Select only the executed action logits -> [B, K]
    selected_logits = logits[b, c, h, w]

    # Compute BCE on those selected entries only
    loss_fn = nn.BCEWithLogitsLoss(reduction="none", pos_weight=pos_weight)
    per_action_loss = loss_fn(selected_logits, labels)

    if reduce == "mean":
        return per_action_loss.mean()
    elif reduce == "sum":
        return per_action_loss.sum()
    else:
        return per_action_loss  # [B, K]

# --- Example usage ---
# Four samples with K = 3 executed actions each; every triple is (c, h, w).
B, C, H, W, K = 4, 12, 64, 64, 3
logits = torch.randn(B, C, H, W, requires_grad=True)
executed_inds = torch.tensor(
    [
        [[3, 20, 15], [3, 21, 15], [7, 40, 10]],
        [[7, 40, 10], [7, 41, 10], [7, 42, 10]],
        [[0, 12, 32], [0, 13, 32], [1, 12, 31]],
        [[5, 50, 50], [5, 51, 50], [5, 52, 50]],
    ]
)
# One 0/1 outcome per executed action, aligned with executed_inds.
labels = torch.tensor(
    [
        [1, 0, 1],
        [0, 1, 1],
        [1, 1, 0],
        [0, 0, 1],
    ],
    dtype=torch.float32,
)

# Average the per-action losses and backpropagate through the selected logits.
loss = masked_bce_fixed_k(logits, executed_inds, labels, reduce="mean")
loss.backward()


In [None]:
import torch
import torch.nn as nn

# This version supports a different number of executed actions for each sample in the batch.
def masked_BCEWithLogitsLoss(
    logits: torch.Tensor,                         # [B, C, H, W]
    executed_list: list[torch.Tensor],            # length B, each [Ki, 3] (c,h,w) (where Ki is the number of actions executed in batch i)
    labels_list: list[torch.Tensor],              # length B, each [Ki] (0/1)
) -> tuple[torch.Tensor, dict]:
    """Mean BCE-with-logits over only the executed (c, h, w) locations.

    From the tossing bot paper: "We pass gradients only through the single
    pixel i on which the grasping primitive was executed. All other pixels
    backpropagate with 0 loss."

    Key ideas:
    - Instead of applying BCE loss over the entire (C, H, W) tensor, select
      the logits at the executed (c, h, w) indices.
    - The loss then depends only on the selected entries, so gradients
      elsewhere are zero.
    - To handle a different number of executed actions per sample, flatten
      all actions across the batch, compute the loss for each, and average.

    The per-sample inputs look something like:

        executed_list = [
            [[3, 20, 15], [7, 40, 10]],    # B0 x 2
            [[7, 40, 10]],                 # B1 x 1
            [],                            # B2 x 0 (no action)
            [[5, 50, 50], [5, 51, 50], [5, 52, 50]],  # B3 x 3
        ]
        labels_list = [
            [1, 0],
            [1],
            [],
            [0, 0, 1],
        ]

    They are flattened so the logits can be indexed with a single lookup:

        b_all = [0, 0, 1, 3, 3, 3] # Notice how batch 2 disappears..
        c_all = [3, 7, 7, 5, 5, 5]
        h_all = [20, 40, 40, 50, 51, 52]
        w_all = [15, 10, 10, 50, 50, 50]
        y_all = [1, 0, 1, 0, 0, 1]

    Args:
        logits: Heatmap logits of shape [B, C, H, W].
        executed_list: One entry per sample: a [Ki, 3] tensor (or nested
            list) of (c, h, w) indices, or None / empty for "no action".
        labels_list: One entry per sample: a [Ki] tensor (or list) of 0/1
            outcomes matching ``executed_list``.

    Returns:
        ``(loss, info)``. ``loss`` is the scalar mean BCE over all N executed
        actions. ``info`` maps 'b_all', 'c_all', 'h_all', 'w_all' and 'y_all'
        to the flattened [N] tensors. If no action was executed anywhere,
        ``loss`` is a zero that is still attached to ``logits`` and the info
        tensors are empty.

    Raises:
        ValueError: If the two lists differ in length, or a sample's indices
            and labels do not have shapes [Ki, 3] and [Ki].
    """
    if logits.dim() != 4:
        raise ValueError(
            f"logits must have shape [B, C, H, W], got {tuple(logits.shape)}"
        )
    if len(executed_list) != len(labels_list):
        raise ValueError(
            f"executed_list has {len(executed_list)} entries but "
            f"labels_list has {len(labels_list)}"
        )
    device = logits.device

    b_all, c_all, h_all, w_all, y_all = [], [], [], [], []
    for i, (inds_i, y_i) in enumerate(zip(executed_list, labels_list)):
        if inds_i is None or len(inds_i) == 0:  # no actions executed: skip
            continue
        # as_tensor accepts both tensors and the nested lists shown above.
        inds_i = torch.as_tensor(inds_i, device=device).long()   # [Ki, 3]
        y_i = torch.as_tensor(y_i, device=device).float()        # [Ki]
        if inds_i.dim() != 2 or inds_i.shape[1] != 3:
            raise ValueError(
                f"executed_list[{i}] must have shape [Ki, 3], "
                f"got {tuple(inds_i.shape)}"
            )
        Ki = inds_i.shape[0]
        if tuple(y_i.shape) != (Ki,):
            raise ValueError(
                f"labels_list[{i}] must have shape [{Ki}], "
                f"got {tuple(y_i.shape)}"
            )

        b_all.append(torch.full((Ki,), i, device=device, dtype=torch.long))
        c_all.append(inds_i[:, 0])
        h_all.append(inds_i[:, 1])
        w_all.append(inds_i[:, 2])
        y_all.append(y_i)

    if not b_all:  # no actions executed this step
        # Summing an empty slice of logits gives an exact zero that stays in
        # the autograd graph, so backward() yields all-zero gradients. The
        # info dict keeps the usual keys so the return shape is uniform.
        empty_info = {
            'b_all': torch.empty(0, dtype=torch.long, device=device),
            'c_all': torch.empty(0, dtype=torch.long, device=device),
            'h_all': torch.empty(0, dtype=torch.long, device=device),
            'w_all': torch.empty(0, dtype=torch.long, device=device),
            'y_all': torch.empty(0, dtype=torch.float32, device=device),
        }
        return logits[:0].sum(), empty_info

    b_all = torch.cat(b_all)  # [N]
    c_all = torch.cat(c_all)  # [N]
    h_all = torch.cat(h_all)  # [N]
    w_all = torch.cat(w_all)  # [N]
    y_all = torch.cat(y_all)  # [N]

    selected_logits = logits[b_all, c_all, h_all, w_all]  # [N]

    loss_fn = nn.BCEWithLogitsLoss(reduction="mean")
    loss = loss_fn(selected_logits, y_all)  # scalar

    info = {
        'b_all': b_all,
        'c_all': c_all,
        'h_all': h_all,
        'w_all': w_all,
        'y_all': y_all,
    }

    return loss, info

    #  Some ideas here if you wanted to do action weighting...
    #  (sketch only: pos_weight / equal_weight_per_image are not parameters yet;
    #   b_all already records which image each action belongs to)

    # loss_fn = nn.BCEWithLogitsLoss(reduction="none", pos_weight=pos_weight)
    # per_action = loss_fn(selected_logits, y_all)  # [N]

    # if equal_weight_per_image:
    #     # Average actions inside each image, then mean across images that had >=1 action
    #     loss_per_img = []
    #     for i in range(B):
    #         mask_i = (b_all == i)
    #         if mask_i.any():
    #             loss_per_img.append(per_action[mask_i].mean())
    #     return torch.stack(loss_per_img).mean()
    # else:
    #     # All actions across batch equally weighted
    #     return per_action.mean()

# --- Example usage ---
# Four samples with 2, 1, 0 and 3 executed actions respectively.
B, C, H, W = 4, 12, 64, 64
logits = torch.randn(B, C, H, W, requires_grad=True)

executed_list = [
    torch.tensor([[3, 20, 15], [7, 40, 10]]),                # K0=2
    torch.tensor([[7, 40, 10]]),                             # K1=1
    torch.empty((0, 3), dtype=torch.long),                   # K2=0 (no action)
    torch.tensor([[5, 50, 50], [5, 51, 50], [5, 52, 50]]),   # K3=3
]
labels_list = [
    torch.tensor([1, 0], dtype=torch.float32),
    torch.tensor([1], dtype=torch.float32),
    torch.empty(0, dtype=torch.float32),
    torch.tensor([0, 0, 1], dtype=torch.float32),
]

loss, info = masked_BCEWithLogitsLoss(logits, executed_list, labels_list)

loss.backward()
print(info)

# Flattened indices and labels the call above should report; sample 2
# contributes nothing because it has no executed action.
expected_values = {
    'b_all': torch.tensor([0, 0, 1, 3, 3, 3]),
    'c_all': torch.tensor([3, 7, 7, 5, 5, 5]),
    'h_all': torch.tensor([20, 40, 40, 50, 51, 52]),
    'w_all': torch.tensor([15, 10, 10, 50, 50, 50]),
    'y_all': torch.tensor([1, 0, 1, 0, 0, 1]),
}

for key, value in info.items():
    print(key, value.numpy())
    assert torch.all(value == expected_values[key]), f"Expected {key} to be {expected_values[key]}, but got {value}"



I will now turn this into a class so that it can be used in the more traditional define-then-use flow, instead of having to create a new loss function on every call.

In [None]:
import torch
import torch.nn as nn
from typing import List, Dict

class MaskedBCEWithLogitsLoss(nn.BCEWithLogitsLoss):
    """
    BCEWithLogits computed ONLY at executed (c,h,w) indices per batch element.

    Constructor arguments are forwarded unchanged to ``nn.BCEWithLogitsLoss``.

    Forward args:
        logits:        [B, C, H, W]
        executed_list: list length B, each a LongTensor [Ki, 3] with (c,h,w)
        labels_list:   list length B, each a Float/Bool Tensor [Ki] with {0,1}

    Returns:
        loss: Tensor (scalar if reduction != "none")

    After each forward call, ``self.info`` holds the flattened, detached index
    and label tensors ('b_all', 'c_all', 'h_all', 'w_all', 'y_all'), or
    ``{"empty": tensor(True)}`` when no action was executed in the batch.

    ---

    From the tossing bot paper: "We pass gradients only through the single pixel i on which the grasping primitive was executed. All other pixels backpropagate with 0 loss."

    ##### Key ideas to implement:
    - Instead of applying BCE loss over the entire (C, H, W) tensor, select logits at the executed (c, h, w) indices.
    - The loss then depends only on those selected entries, so gradients elsewhere are zero.
    - To handle multiple executed actions per sample, gather all relevant indices, compute the loss for each, and average.

    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Debug record of the most recent forward pass.
        self.info: Dict[str, torch.Tensor] = {}

    def forward(
        self,
        logits: torch.Tensor,
        executed_list: List[torch.Tensor],
        labels_list: List[torch.Tensor],
    ) -> torch.Tensor:
        """Flatten the per-sample actions and apply BCE to the selected logits.

        The per-sample inputs look something like:

            executed_list = [
                [[3, 20, 15], [7, 40, 10]],    # B0 x 2
                [[7, 40, 10]],                 # B1 x 1
                [],                            # B2 x 0 (no action)
                [[5, 50, 50], [5, 51, 50], [5, 52, 50]],  # B3 x 3
            ]
            labels_list = [
                [1, 0],
                [1],
                [],
                [0, 0, 1],
            ]

        They are flattened so the logits can be indexed with a single lookup:

            b_all = [0, 0, 1, 3, 3, 3] # Notice how batch 2 disappears..
            c_all = [3, 7, 7, 5, 5, 5]
            h_all = [20, 40, 40, 50, 51, 52]
            w_all = [15, 10, 10, 50, 50, 50]
            y_all = [1, 0, 1, 0, 0, 1]

        Returns:
            The BCE loss over the N selected logits, reduced according to
            ``self.reduction``. When no action was executed in the whole
            batch, a scalar zero attached to ``logits`` is returned for every
            reduction mode.

        Raises:
            ValueError: If the two lists differ in length, or a sample's
                indices and labels do not have shapes [Ki, 3] and [Ki].
        """
        if logits.dim() != 4:
            raise ValueError(
                f"logits must have shape [B, C, H, W], got {tuple(logits.shape)}"
            )
        if len(executed_list) != len(labels_list):
            raise ValueError(
                f"executed_list has {len(executed_list)} entries but "
                f"labels_list has {len(labels_list)}"
            )
        device = logits.device

        # Flatten variable-K actions across the batch
        b_all, c_all, h_all, w_all, y_all = [], [], [], [], []
        for i, (inds_i, y_i) in enumerate(zip(executed_list, labels_list)):
            if inds_i is None or len(inds_i) == 0:
                continue
            # as_tensor accepts both tensors and nested Python lists.
            inds_i = torch.as_tensor(inds_i, device=device).long()  # [Ki, 3] (c,h,w)
            y_i = torch.as_tensor(y_i, device=device).float()       # [Ki]    (0/1)
            if inds_i.dim() != 2 or inds_i.shape[1] != 3:
                raise ValueError(
                    f"executed_list[{i}] must have shape [Ki, 3], "
                    f"got {tuple(inds_i.shape)}"
                )
            Ki = inds_i.shape[0]
            if tuple(y_i.shape) != (Ki,):
                raise ValueError(
                    f"labels_list[{i}] must have shape [{Ki}], "
                    f"got {tuple(y_i.shape)}"
                )

            b_all.append(torch.full((Ki,), i, device=device, dtype=torch.long))
            c_all.append(inds_i[:, 0])
            h_all.append(inds_i[:, 1])
            w_all.append(inds_i[:, 2])
            y_all.append(y_i)

        # If no actions were executed in the whole batch, return a zero that is
        # genuinely connected to `logits`: summing an empty slice keeps the
        # autograd graph, so backward() produces all-zero gradients instead of
        # leaving logits.grad unset.
        if not b_all:
            self.info = {"empty": torch.tensor(True, device=device)}
            return logits[:0].sum()

        b_all = torch.cat(b_all)  # [N]
        c_all = torch.cat(c_all)  # [N]
        h_all = torch.cat(h_all)  # [N]
        w_all = torch.cat(w_all)  # [N]
        y_all = torch.cat(y_all)  # [N]

        # Index only the executed logits -> gradients flow only here
        selected_logits = logits[b_all, c_all, h_all, w_all]  # [N]

        # Standard BCEWithLogitsLoss reduction
        loss = super().forward(selected_logits, y_all)

        # Save debug info for later inspection
        self.info = {
            "b_all": b_all.detach(),
            "c_all": c_all.detach(),
            "h_all": h_all.detach(),
            "w_all": w_all.detach(),
            "y_all": y_all.detach(),
        }

        return loss

if __name__ == "__main__":
    # Smoke test: four samples with 2, 1, 0 and 3 executed actions.
    B, C, H, W = 4, 12, 64, 64
    logits = torch.randn(B, C, H, W, requires_grad=True)

    executed_list = [
        torch.tensor([[3,20,15],[7,40,10]]),   # K0=2
        torch.tensor([[7,40,10]]),             # K1=1
        torch.tensor([], dtype=torch.long).reshape(0,3),  # K2=0
        torch.tensor([[5,50,50],[5,51,50],[5,52,50]]),    # K3=3
    ]
    labels_list = [
        torch.tensor([1, 0], dtype=torch.float32),
        torch.tensor([1], dtype=torch.float32),
        torch.tensor([], dtype=torch.float32),
        torch.tensor([0, 0, 1], dtype=torch.float32),
    ]

    loss_fn = MaskedBCEWithLogitsLoss(reduction="mean")
    loss = loss_fn(logits, executed_list, labels_list)
    loss.backward()

    print("loss:", float(loss))
    print("info dict:")

    # Flattened indices and labels the loss should have recorded; sample 2
    # contributes nothing because it has no executed action.
    expected_values = dict()
    expected_values['b_all'] = torch.tensor([0, 0, 1, 3, 3, 3])
    expected_values['c_all'] = torch.tensor([3, 7, 7, 5, 5, 5])
    expected_values['h_all'] = torch.tensor([20, 40, 40, 50, 51, 52])
    expected_values['w_all'] = torch.tensor([15, 10, 10, 50, 50, 50])
    expected_values['y_all'] = torch.tensor([1, 0, 1, 0, 0, 1])

    # The loop variables must be the names the assertion reads; the previous
    # `k, v` binding left `key` and `value` undefined (or stale) here.
    for key, value in loss_fn.info.items():
        print(key, value)
        assert torch.all(value == expected_values[key]), f"Expected {key} to be {expected_values[key]}, but got {value}"