# utils

> General-purpose utility functions used throughout the project

In [None]:
#| default_exp common.utils

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| hide
from fastcore.test import *

In [None]:
#| export
from typing import Callable, Dict, Generic, Iterable, Sequence, Tuple, TypeVar

In [None]:
#| export
import torch

In [None]:
# | export
T = TypeVar('T')  # Generic item type; parameterizes aggregate_by_string_key and DataWrapper

In [None]:
# | export
def aggregate_by_string_key(
    items: Iterable[T], key: Callable[[T], str]
) -> Dict[str, T]:
    """Index ``items`` by a string key.

    Each item is stored under ``key(item)``. When several items map to the
    same key, the one appearing last in ``items`` wins. Keys keep the order
    in which they were first seen.
    """
    by_key: Dict[str, T] = {}
    for element in items:
        # Later elements overwrite earlier ones that share a key.
        by_key[key(element)] = element
    return by_key

In [None]:
# Tests for aggregate_by_string_key: on a duplicate key the last item wins
items = [('a', 1), ('b', 2), ('a', 3)]
test_eq(
    aggregate_by_string_key(items, key=lambda pair: pair[0]),
    {'a': ('a', 3), 'b': ('b', 2)},
)

In [None]:
# | export
class DataWrapper(Generic[T]):
    """Wrap a sequence of items so it displays in a human-friendly way.

    ``str()`` joins the formatted items with ``", "``, ``print()`` writes one
    formatted item per line, and indexing is delegated to the wrapped data.

    Parameters:
    -----------
    data:
        The items to wrap. Stored as-is (not copied).
    format_item_fn:
        Converts a single item to text. Defaults to ``repr``.
    """

    def __init__(
        self,
        data: Sequence[T],
        format_item_fn: Callable[[T], str] = repr,
    ):
        self.data = data
        self.format_item_fn = format_item_fn

    def __repr__(self):
        # Use the runtime class name so subclasses report their own type
        # instead of always claiming to be a DataWrapper.
        return f"{type(self).__name__}({self.data!r})"

    def __str__(self):
        return ', '.join(self.format_item_fn(d) for d in self.data)

    def __getitem__(self, i):
        # Delegates directly, so slices return whatever self.data returns.
        return self.data[i]

    def print(self):
        """Print each item on its own line using ``format_item_fn``."""
        for d in self.data:
            print(self.format_item_fn(d))

In [None]:
# Tests for DataWrapper: str() formats every item and joins them with ", "
dw = DataWrapper(data=[1, 2, 3], format_item_fn=lambda item: f"{item}!")
test_eq(str(dw), '1!, 2!, 3!')

In [None]:
# | export

def topk_across_batches(
    n_batches: int,
    k: int,
    largest: bool,
    load_batch: Callable[[int], torch.Tensor],
    process_batch: Callable[[torch.Tensor], torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Like torch.topk, but works across multiple batches of data. Always
    works over the batch dimension, which is assumed to be the first dimension
    of each batch.

    Batches are loaded and processed one at a time; only each batch's own
    top ``k`` results are retained between iterations.

    Parameters:
    -----------
    n_batches:
        The number of batches to process. Must be at least 1.
    k:
        The number of top values to return. Every batch must contain at
        least ``k`` items.
    largest:
        Whether to return the largest or smallest values.
    load_batch:
        A function that takes a batch index and returns a batch of data.
    process_batch:
        A function that takes a batch of data and returns a tensor of
        values, with the same first dimension size as the batch. The function
        will return the top k of these values along the batch dimension.

    Returns:
    --------
        A tuple of (values, indices) tensors. Both have the shape of one
        processed batch with the first dimension replaced by ``k``.
        ``indices`` are positions in the overall dataset, i.e. in the
        concatenation of all batches along the first dimension, and are on
        the same device as the processed results.

    Raises:
    -------
    AssertionError:
        If ``process_batch`` changes the first-dimension size of a batch, or
        if a batch has fewer than ``k`` items.
    """
    per_batch_values = []
    per_batch_indices = []
    batch_sizes = []

    # Pass 1: reduce every batch to its own top k. Each batch contributes
    # exactly k rows, so row r of the concatenated candidates below comes
    # from batch r // k.
    for batch_idx in range(n_batches):
        batch = load_batch(batch_idx)
        results = process_batch(batch)

        assert (
            results.shape[0] == batch.shape[0]
        ), f"Batch had {batch.shape[0]} items, but results had {results.shape[0]} items."
        assert batch.shape[0] >= k, f"Batch had {batch.shape[0]} items, but k was {k}."

        batch_sizes.append(batch.shape[0])

        batch_values, batch_indices = torch.topk(results, k=k, largest=largest, dim=0)
        per_batch_values.append(batch_values)
        per_batch_indices.append(batch_indices)

    candidate_values = torch.cat(per_batch_values)
    candidate_indices = torch.cat(per_batch_indices)

    # Pass 2: the overall top k is the top k among the per-batch winners.
    topk = torch.topk(candidate_values, k=k, largest=largest, dim=0)

    # Translate positions in candidate_values into positions in the full
    # dataset: (items in all earlier batches) + (position within the batch).
    # offsets[b] is the total size of batches 0..b-1. It is created on the
    # candidates' device so the indexing/addition below never mixes CPU and
    # accelerator tensors.
    device = candidate_indices.device
    offsets = torch.cumsum(torch.tensor([0, *batch_sizes], device=device), dim=0)
    source_batch = topk.indices // k
    index_within_batch = torch.gather(candidate_indices, dim=0, index=topk.indices)
    overall_indices = offsets[source_batch] + index_within_batch

    return topk.values, overall_indices

In [None]:
# Tests for topk_across_batches()

# Overall indices: batch 0 covers 0-4, batch 1 covers 5-9.
batches = [
    [100, 98, 96, 94, 92],
    [99, 97, 95, 93, 91],
]

# Test with largest=True
values, indices = topk_across_batches(
    n_batches=len(batches),
    k=4,
    largest=True,
    load_batch=lambda i: torch.tensor(batches[i]),
    process_batch=lambda batch: batch,
)
test_eq(values, torch.tensor([100, 99, 98, 97]))
test_eq(indices, torch.tensor([0, 5, 1, 6]))

# Test with largest=False
values, indices = topk_across_batches(
    n_batches=len(batches),
    k=4,
    largest=False,
    load_batch=lambda i: torch.tensor(batches[i]),
    process_batch=lambda batch: batch,
)
test_eq(values, torch.tensor([91, 92, 93, 94]))
test_eq(indices, torch.tensor([9, 4, 8, 3]))

# Test where the last batch is smaller than the others
# (batch 2 covers overall indices 10-13).
batches = [
    [100, 98, 96, 94, 92],
    [99, 97, 95, 93, 91],
    [90, 88, 86, 84],
]
values, indices = topk_across_batches(
    n_batches=len(batches),
    k=4,
    largest=False,
    load_batch=lambda i: torch.tensor(batches[i]),
    process_batch=lambda batch: batch,
)
test_eq(values, torch.tensor([84, 86, 88, 90]))
test_eq(indices, torch.tensor([13, 12, 11, 10]))

# Test where all the batches are different sizes.
# Batch sizes 5, 4, 7, 4 give batch start offsets 0, 5, 9, 16.
batches = [
    [100, 98, 96, 120, 160],
    [94, 92, 130, 140],
    [99, 97, 95, 93, 91, 109, 110],
    [90, 101, 104, 108],
]
values, indices = topk_across_batches(
    n_batches=len(batches),
    k=4,
    largest=False,
    load_batch=lambda i: torch.tensor(batches[i]),
    process_batch=lambda batch: batch,
)
test_eq(values, torch.tensor([90, 91, 92, 93]))
test_eq(indices, torch.tensor([16, 13, 6, 12]))

# Test where the results are 2-D tensors
# Each row is one item; topk runs independently for every column, across
# all rows of all batches.
batches = [
    [
        [100, 98, 96, 94, 92],
        [99, 97, 95, 93, 91],
        [90, 88, 86, 84, 82],
    ],
    [
        [200, 198, 196, 194, 192],
        [199, 197, 195, 193, 191],
        [190, 188, 186, 184, 182],
    ],
    [
        [300, 298, 296, 294, 292],
        [299, 297, 295, 293, 291],
    ],
]
values, indices = topk_across_batches(
    n_batches=len(batches),
    k=2,
    largest=False,
    load_batch=lambda i: torch.tensor(batches[i]),
    process_batch=lambda batch: batch,
)
# fmt: off
test_eq(values, torch.tensor([
    [90, 88, 86, 84, 82],
    [99, 97, 95, 93, 91],
]))
test_eq(indices, torch.tensor([
    [2, 2, 2, 2, 2],
    [1, 1, 1, 1, 1]
]))
# fmt: on

# A more interesting 2-D example: the per-column winners come from
# different rows and batches, and some columns contain duplicate values.
# fmt: off
batches = [
    [                          # Overall index:
        [14,  8,  1, 13, 13],  # 0
        [18, 13, 11,  1, 10],  # 1
        [ 8, 16, 15, 10, 14]   # 2
    ],
    [
        [12, 14, 19,  3,  1],  # 3
        [ 1, 15, 16,  3,  0],  # 4
        [ 7, 18,  5,  0,  6]   # 5
    ],
    [
        [10,  7, 16,  9,  0],  # 6
        [ 0, 12,  1,  9,  3],  # 7
        [14, 18, 14,  1,  8]   # 8
    ]
]
# fmt: on
values, indices = topk_across_batches(
    n_batches=len(batches),
    k=2,
    largest=False,
    load_batch=lambda i: torch.tensor(batches[i]),
    process_batch=lambda batch: batch,
)
# fmt: off
test_eq(values, torch.tensor([
    [0, 7, 1, 0, 0],
    [1, 8, 1, 1, 0],
]))
# fmt: on

# The indices picked for duplicates are apparently not stable
# across platforms. So we can't just test for equality with a
# known indices result. Instead let's check:
# - that the values and indices tensors have the same shape
# - that the data at the returned indices equals what was
#   returned in the values tensor.
test_eq(indices.shape, values.shape)

# Cat the batches into one big tensor so we can index into it.
all_data = torch.cat([torch.tensor(b) for b in batches])

# Test that the values at the indices are the same as the returned values
test_eq(torch.gather(all_data, dim=0, index=indices), values)

# A 3-D example: 2 batches of shape (3, 3, 2)
# Each of the 6 items is a (3, 2) slice; topk runs independently for every
# (row, column) position across all items.
# fmt: off
batches = [
    [
        [
            [15, 19],
            [19,  8],
            [10, 12]
        ],
        [
            [17,  0],
            [ 3, 15],
            [ 5, 15]
        ],
        [
            [19, 10],
            [ 7, 17],
            [ 8,  0]
        ]
    ],
    [
        [
            [ 9, 15],
            [13, 11],
            [ 8, 15]
        ],
        [
            [ 5,  0],
            [ 3,  6],
            [10, 15]
        ],
        [
            [ 6, 19],
            [11, 15],
            [ 2,  5]
        ]
    ]
]
# fmt: on
values, indices = topk_across_batches(
    n_batches=len(batches),
    k=2,
    largest=False,
    load_batch=lambda i: torch.tensor(batches[i]),
    process_batch=lambda batch: batch,
)
# fmt: off
test_eq(values, torch.tensor([
    [
        [ 5,  0],
        [ 3,  6],
        [ 2,  0]
    ],
    [
        [ 6,  0],
        [ 3,  8],
        [ 5,  5]
    ]
]))
# fmt: on

# Several positions tie (e.g. two 0s at [0][1], two 3s at [1][0]), so the
# chosen indices may differ across platforms: check only shape and that
# the indices point at the returned values.
test_eq(indices.shape, values.shape)

# Cat the batches into one big tensor so we can index into it.
all_data = torch.cat([torch.tensor(b) for b in batches])

# Test that the values at the indices are the same as the returned values
test_eq(torch.gather(all_data, dim=0, index=indices), values)

# Test with processing function: ranking uses the processed values, while
# the returned indices still point into the raw data.
batches = [
    [100, 98, 96, 94, 92],
    [99, 97, 95, 93, 91],
    [90, 88, 86, 84],
]
values, indices = topk_across_batches(
    n_batches=len(batches),
    k=4,
    largest=False,
    load_batch=lambda i: torch.tensor(batches[i]),
    process_batch=lambda batch: batch * 2,
)
test_eq(values, 2 * torch.tensor([84, 86, 88, 90]))
test_eq(indices, torch.tensor([13, 12, 11, 10]))

# Test that processing function can't change size of batch.
# The regex pins *which* assertion fires, so the test can't pass by
# accident via the batch-size-vs-k check.
batches = [
    [100, 98, 96, 94, 92],
    [99, 97, 95, 93, 91],
    [90, 88, 86, 84],
]
with ExceptionExpected(ex=AssertionError, regex="but results had 3 items"):
    values, indices = topk_across_batches(
        n_batches=len(batches),
        k=4,
        largest=False,
        load_batch=lambda i: torch.tensor(batches[i]),
        process_batch=lambda batch: batch[:3],
    )

# Test that batch can't be smaller than k. Use an identity process_batch so
# the only thing wrong is the size of the first batch.
batches = [
    [100, 98],
    [99, 97, 95, 93, 91],
    [90, 88, 86, 84],
]
with ExceptionExpected(ex=AssertionError, regex="but k was 4"):
    values, indices = topk_across_batches(
        n_batches=len(batches),
        k=4,
        largest=False,
        load_batch=lambda i: torch.tensor(batches[i]),
        process_batch=lambda batch: batch,
    )

In [None]:
#| hide
import nbdev
nbdev.nbdev_export()