In [1]:
import torch
from torch import Tensor


def append_coord_chans(tens: Tensor) -> Tensor:
    """
    Add coordinate channels (e.g. row and col) to a batched tensor.

    One channel is appended per spatial dimension. Channel ``k`` holds each
    element's integer index along spatial dimension ``k``. The coordinate
    channels are identical for every item in the batch.

    Assumes the first dimension is batch size and the second is the number
    of channels (network width). Every remaining dimension is treated as
    spatial.

    Args:
        tens: tensor of shape (batch, channels, *spatial) to add channels to

    Returns: tensor of shape (batch, channels + len(spatial), *spatial) on the
        same device as ``tens``. Its dtype is ``tens.dtype`` promoted with
        int64, so floating-point inputs keep their dtype and integer/bool
        inputs become int64. If ``tens`` has no spatial dimensions, it is
        returned as-is.
    """
    batch_size, _, *dims = tens.shape
    if not dims:
        # No spatial axes means no coordinates to add; torch.meshgrid would
        # otherwise reject the empty list of inputs.
        return tens

    # Build the coordinates on the input's device (torch.cat cannot mix
    # devices) and in the dtype the concatenation produces, instead of
    # relying on implicit promotion inside torch.cat.
    out_dtype = torch.promote_types(tens.dtype, torch.int64)
    coords = [torch.arange(d, device=tens.device).to(out_dtype) for d in dims]
    chans = torch.meshgrid(coords, indexing='ij')

    # Broadcast each (*dims) grid to (batch, 1, *dims). expand() returns a
    # view, so unlike tile() no per-batch copies are made before torch.cat.
    chans = [c.expand(batch_size, 1, *dims) for c in chans]

    return torch.cat([tens.to(out_dtype), *chans], dim=1)

In [2]:
from typing import List
from torch import Tensor


def flatten_cnn_featmaps(tensor: Tensor) -> List[Tensor]:
    """
    Turn a batch of CNN feature maps into per-sample feature tables.

    Coordinate channels are appended first (via ``append_coord_chans``), so
    each row of a table carries both the original channel values and the
    spatial position it came from.

    Args:
        tensor: batch of feature maps, batch dimension first and channel
            dimension second

    Returns: one tensor per batch element, with one row per spatial
        location and one column per (feature + coordinate) channel
    """
    augmented = append_coord_chans(tensor)

    tables = []
    for fmap in augmented:
        n_chans = fmap.shape[0]
        # Collapse all spatial axes into one, then put locations on rows.
        tables.append(fmap.view(n_chans, -1).T)
    return tables