# In [1]:
import torch
from torch import Tensor


def add_coord_chans(tens: Tensor) -> Tensor:
    """
    Append coordinate channels (one per spatial dimension) to a batched tensor.

    The input is laid out as ``(batch, channels, *spatial)``. For each spatial
    dimension, in order, a new channel is appended holding the integer index
    along that dimension (e.g. row then column for a 2-D ``(B, C, H, W)``
    input). Indices are raw (``0 .. size - 1``), not normalized.

    Args:
        tens: tensor of shape ``(batch, channels, *spatial)``; at least 2-D.

    Returns:
        New tensor of shape ``(batch, channels + len(spatial), *spatial)``
        with the coordinate channels placed after the original channels, on
        the same device as ``tens``. Its dtype is ``tens.dtype`` promoted with
        ``int64`` (so floating-point inputs keep their dtype). If there are no
        spatial dimensions, a copy of ``tens`` is returned.

    Raises:
        ValueError: if ``tens`` has fewer than two dimensions.
    """
    if tens.dim() < 2:
        raise ValueError(
            "expected a tensor of shape (batch, channels, *spatial), "
            f"got shape {tuple(tens.shape)}"
        )
    batch_size, _, *dims = tens.shape

    # The dtype torch.cat would promote to; casting explicitly avoids relying
    # on mixed-dtype concatenation support.
    out_dtype = torch.promote_types(tens.dtype, torch.int64)

    # Build the index ranges on the input's device so cat works for GPU inputs.
    coords = [torch.arange(d, device=tens.device) for d in dims]
    # meshgrid rejects an empty tensor list, so skip it when there are no
    # spatial dimensions (no coordinate channels to add).
    grids = torch.meshgrid(coords, indexing='ij') if coords else ()

    # Broadcast each (*spatial) grid to (batch, 1, *spatial). expand is a view;
    # the single copy happens in cat.
    chans = [g.to(out_dtype).expand(batch_size, 1, *dims) for g in grids]

    return torch.cat([tens, *chans], dim=1)