In [None]:
from libcrap import traverse_files

import numpy as np
import re
import matplotlib.pyplot as plt
import random

from sklearn.preprocessing import LabelEncoder

In [None]:
class DictWithCounter(dict):
    """A dict that hands out consecutive integer ids (0, 1, 2, ...) to new keys.

    Ids are assigned in the order keys are first passed to ``get_maybe_add``.

    >>> d = DictWithCounter()
    >>> print(d.get_maybe_add("aaa"))
    0
    >>> print(d.get_maybe_add("bbb"))
    1
    >>> print(d.get_maybe_add("aaa"))
    0
    >>> print(sorted(d.items()))
    [('aaa', 0), ('bbb', 1)]
    """
    def __init__(self):
        super().__init__()
        # Id that the next previously-unseen key will receive.
        self._next_value = 0

    def get_maybe_add(self, item):
        """Return the id of ``item``, first assigning it the next free id if new."""
        try:
            return self[item]
        except KeyError:
            assigned = self[item] = self._next_value
            self._next_value = assigned + 1
            return assigned

In [None]:
# Layout of the ETH-80 data loaded below: NUM_CLASSES classes with
# NUM_OBJECTS_PER_CLASS objects each, every object stored at NUM_ANGLES
# angle pairs as an IMAGE_CHANNELS x IMAGE_HEIGHT x IMAGE_WIDTH image.
NUM_CLASSES = 8
NUM_OBJECTS_PER_CLASS = 10
# Derived (not hard-coded) so it can never drift from the two constants above.
NUM_OBJECTS = NUM_CLASSES * NUM_OBJECTS_PER_CLASS
NUM_ANGLES = 41
IMAGE_HEIGHT = 128
IMAGE_WIDTH = 128
IMAGE_CHANNELS = 3

In [None]:
if __name__ == "__main__":
    # Run the examples embedded in DictWithCounter's docstring; doctest
    # prints any failures itself, then we announce that the check ran.
    from doctest import run_docstring_examples
    run_docstring_examples(DictWithCounter, globals())
    print("ran a test")

In [None]:
def load_eth80(path, use_torch):
    """Load the ETH-80 images found under ``path`` into one dense array.

    Every file path yielded by ``traverse_files(path)`` that ends in
    ``<class><1-2 digit number>-<ddd>-<ddd>.png`` (e.g. ``apple1-000-000.png``)
    is read with ``plt.imread``, moved to channel-first layout and stored at
    ``dataset[object_index, angles_index]``. Indices are assigned in order of
    first appearance over the *sorted* file list, so they are reproducible
    regardless of filesystem traversal order.

    Args:
        path: root passed to ``traverse_files``.
        use_torch: if true, ``dataset`` and ``labels`` are returned as torch
            tensors instead of numpy arrays.

    Returns:
        dataset: float32 array of shape (NUM_OBJECTS, NUM_ANGLES,
            IMAGE_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH).
        labels: integer class label per object, from ``LabelEncoder``.
        label_encoder: the fitted ``LabelEncoder``.
        object_classes: class name per object index (e.g. ``"apple"``).

    Raises:
        ValueError: if more than NUM_OBJECTS distinct objects or NUM_ANGLES
            distinct angle pairs are found, if the same (object, angles) image
            appears twice, or if the number of loaded images is not exactly
            NUM_OBJECTS * NUM_ANGLES.
    """
    # Escaped dot and end anchor: only names that really end in ".png" match.
    image_name_re = re.compile(r"([a-z]+)(\d\d?)-(\d\d\d-\d\d\d)\.png$")
    # Sorted so object/angle index assignment is deterministic across runs.
    all_filenames = sorted(traverse_files(path))
    dataset = np.zeros(
        (NUM_OBJECTS, NUM_ANGLES, IMAGE_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH),
        dtype=np.float32, order="C"
    )
    object_classes = [None] * NUM_OBJECTS
    object_id_to_index = DictWithCounter()
    angles_to_index = DictWithCounter()
    # Marks filled (object, angles) slots: a duplicate would otherwise silently
    # overwrite an image and could still satisfy the final count check.
    is_loaded = np.zeros((NUM_OBJECTS, NUM_ANGLES), dtype=bool)

    for file_path in all_filenames:
        match = image_name_re.search(file_path)
        if match is None:
            continue
        class_name, object_number, angles = match.groups()
        object_id = class_name + object_number

        object_index = object_id_to_index.get_maybe_add(object_id)
        angles_index = angles_to_index.get_maybe_add(angles)
        if object_index >= NUM_OBJECTS:
            raise ValueError(
                f"Found more than {NUM_OBJECTS} distinct objects "
                f"(extra object {object_id!r} in {file_path!r})"
            )
        if angles_index >= NUM_ANGLES:
            raise ValueError(
                f"Found more than {NUM_ANGLES} distinct angle pairs "
                f"(extra angles {angles!r} in {file_path!r})"
            )
        if is_loaded[object_index, angles_index]:
            raise ValueError(
                f"Duplicate image for object {object_id!r} at angles "
                f"{angles!r}: {file_path!r}"
            )

        object_classes[object_index] = class_name
        image = plt.imread(file_path)
        # (H, W, C) -> (C, H, W)
        dataset[object_index, angles_index] = np.moveaxis(image, 2, 0)
        is_loaded[object_index, angles_index] = True

    loaded_num = int(is_loaded.sum())
    if loaded_num != NUM_OBJECTS * NUM_ANGLES:
        raise ValueError(
            f"Loaded {loaded_num} images, but should've loaded {NUM_OBJECTS * NUM_ANGLES}"
        )

    label_encoder = LabelEncoder()
    labels = label_encoder.fit_transform(object_classes)

    if use_torch:
        import torch  # optional dependency, only needed on this branch
        # from_numpy shares memory with the freshly built local array instead
        # of copying it, so peak memory is not doubled for the image buffer.
        dataset = torch.from_numpy(dataset)
        labels = torch.tensor(labels)

    return dataset, labels, label_encoder, object_classes

In [None]:
def fix_img_axes_for_show(image):
    """Convert a channel-first image (C, H, W) to channel-last (H, W, C) for imshow.

    Accepts numpy arrays as well as CPU torch tensors. The input is passed
    through ``np.asarray`` first because ``np.moveaxis`` calls the argument's
    own ``transpose`` method, and ``torch.Tensor.transpose`` only swaps two
    dims, so a tensor handed straight to ``np.moveaxis`` raises a TypeError.
    For numpy input ``np.asarray`` is a no-op (no copy).
    """
    return np.moveaxis(np.asarray(image), 0, 2)

In [None]:
def choose_random_image():
    """Pick a uniformly random (object index, angle index) pair.

    Draws from the global ``random`` module, so ``random.seed`` controls it.
    """
    object_index = random.randrange(NUM_OBJECTS)
    angle_index = random.randrange(NUM_ANGLES)
    return object_index, angle_index

In [None]:
def show_image(dataset, obj_id, angle_id, ax=None):
    """Draw image ``dataset[obj_id, angle_id]`` on ``ax``.

    When no axes are given, a new figure with a single axes is created.
    The stored image is channel-first, so it is converted before ``imshow``.
    """
    target_ax = plt.subplots()[1] if ax is None else ax
    picture = dataset[obj_id, angle_id]
    target_ax.imshow(fix_img_axes_for_show(picture))

In [None]:
def stratified_split_torch(dataset, labels, num_test_per_class, rng=None):
    """Randomly split objects into train/test sets, stratified by class.

    For every class present in ``labels``, exactly ``num_test_per_class``
    *distinct* objects are drawn (without replacement) for the test set; all
    remaining objects form the train set. Per-class object indices are found
    from ``labels`` directly, so no fixed class size or label layout is assumed.

    Args:
        dataset: torch tensor whose first axis indexes objects.
        labels: 1-D torch integer tensor with one class label per object.
        num_test_per_class: test objects per class; must be >= 1 and strictly
            smaller than every class's size so each class keeps a train object.
        rng: optional ``random.Random`` instance for reproducible splits;
            defaults to the global ``random`` module.

    Returns:
        (X_train, y_train, X_test, y_test), each in ascending object order.

    Raises:
        AssertionError: if ``num_test_per_class`` is out of range for a class.
    """
    import torch  # optional dependency, only needed for this function

    sampler = random if rng is None else rng
    test_objects = set()
    for label in torch.unique(labels).tolist():
        class_objects = torch.nonzero(labels == label).flatten().tolist()
        assert 1 <= num_test_per_class < len(class_objects), (
            f"num_test_per_class={num_test_per_class} must be in "
            f"[1, {len(class_objects)}) for class {label}"
        )
        # sample() draws without replacement (random.choices would allow
        # repeats and yield fewer than num_test_per_class test objects).
        test_objects.update(sampler.sample(class_objects, num_test_per_class))

    train_objects = [i for i in range(labels.shape[0]) if i not in test_objects]
    test_objects = sorted(test_objects)
    X_train = dataset[train_objects]
    y_train = labels[train_objects]
    X_test = dataset[test_objects]
    y_test = labels[test_objects]
    return X_train, y_train, X_test, y_test