In [1]:
%load_ext autoreload
%autoreload 2
%load_ext lab_black
%matplotlib inline
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
# default_exp data_loader

In [25]:
# export
from collections import Counter

from torch.utils.data import DataLoader as PytorchDataLoader, Dataset
from torch.utils.data import WeightedRandomSampler
from tqdm import tqdm


class DataLoader(PytorchDataLoader):
    """Thin subclass of ``torch.utils.data.DataLoader``.

    Adds no attributes or methods of its own; construction and iteration are
    inherited unchanged from the PyTorch loader.
    """


def create_weighted_random_sampler(weights, n_samples=None, replacement=True):
    """Build a ``WeightedRandomSampler`` from per-sample weights.

    Args:
        weights: Sized collection of weights, forwarded to the sampler as-is.
        n_samples: Number of indices the sampler draws. When ``None`` it
            falls back to ``len(weights)``.
        replacement: Forwarded as the sampler's ``replacement`` argument.

    Returns:
        The configured ``WeightedRandomSampler`` instance.
    """
    num_draws = len(weights) if n_samples is None else n_samples
    sampler = WeightedRandomSampler(
        weights, num_samples=num_draws, replacement=replacement
    )
    return sampler


def create_sample_weights_from_multilabels(multilabels, no_label_sample_factor=1):
    """Compute one sampling weight per sample from its set of labels.

    Each label is weighted by the inverse of its total number of occurrences
    across all samples. A sample's weight is the mean of its labels' weights.
    A sample without labels gets ``1 / n_samples ** no_label_sample_factor``,
    i.e. the "factor" acts as an exponent on the number of samples.

    Args:
        multilabels: Sized collection with one entry per sample; each entry
            is a sized collection of hashable labels and may be empty.
        no_label_sample_factor: Exponent applied to the sample count when
            weighting samples that carry no label.

    Returns:
        A list of weights, in the same order as ``multilabels``.

    Note:
        A ``tqdm`` progress bar is displayed while the samples are scanned.
    """
    total_samples = len(multilabels)

    # Single scan: remember each sample's labels and flatten them for counting.
    per_sample_labels = []
    flat_labels = []
    for sample_labels in tqdm(multilabels):
        per_sample_labels.append(sample_labels)
        flat_labels.extend(sample_labels)

    inverse_frequency = {
        label: 1 / occurrences for label, occurrences in Counter(flat_labels).items()
    }

    def weight_of(sample_labels):
        # Unlabeled samples cannot be weighted by label frequency.
        if len(sample_labels) == 0:
            return 1 / total_samples ** no_label_sample_factor
        label_weight_total = sum(inverse_frequency[label] for label in sample_labels)
        return label_weight_total / len(sample_labels)

    return [weight_of(sample_labels) for sample_labels in per_sample_labels]


# Maps each object's name to the object itself (presumably so callers can
# resolve them from a string -- confirm against usage).
lookup = dict(
    DataLoader=DataLoader,
    create_weighted_random_sampler=create_weighted_random_sampler,
    create_sample_weights_from_multilabels=create_sample_weights_from_multilabels,
)

##### Tests for create_sample_weights_from_multilabels

In [20]:
import math

# Label 1 occurs 3 times, label 2 twice and label 3 once, so the per-label
# weights are 1/3, 1/2 and 1. Each sample's weight is the mean of the weights
# of its labels.
multilabels = [[1], [1, 2], [1, 2, 3]]
sample_weights_from_multilabels = create_sample_weights_from_multilabels(multilabels)
expected_sample_weights = [1 / 3, (1 / 3 + 1 / 2) / 2, (1 / 3 + 1 / 2 + 1) / 3]
assert len(sample_weights_from_multilabels) == len(expected_sample_weights)
# Compare with a tolerance: exact equality on computed floats is brittle.
assert all(
    math.isclose(actual, expected)
    for actual, expected in zip(
        sample_weights_from_multilabels, expected_sample_weights
    )
)

100%|██████████| 3/3 [00:00<00:00, 35848.75it/s]


In [21]:
import math

# The unlabeled first sample gets 1 / n_samples ** no_label_sample_factor = 1/3.
# Labels 1 and 2 each occur twice (weight 1/2) and label 3 once (weight 1).
multilabels = [[], [1, 2], [1, 2, 3]]
sample_weights_from_multilabels_with_no_label_sample_factor_1 = (
    create_sample_weights_from_multilabels(multilabels, no_label_sample_factor=1)
)
expected_sample_weights = [1 / 3, (1 / 2 + 1 / 2) / 2, (1 / 2 + 1 / 2 + 1) / 3]
assert len(sample_weights_from_multilabels_with_no_label_sample_factor_1) == len(
    expected_sample_weights
)
# Compare with a tolerance: exact equality on computed floats is brittle.
assert all(
    math.isclose(actual, expected)
    for actual, expected in zip(
        sample_weights_from_multilabels_with_no_label_sample_factor_1,
        expected_sample_weights,
    )
)

100%|██████████| 3/3 [00:00<00:00, 14597.35it/s]


##### Tests for create_weighted_random_sampler

In [87]:
class TestDataset(Dataset):
    """Minimal map-style dataset backed by an in-memory collection.

    Args:
        samples: Indexable, sized collection; ``dataset[i]`` returns
            ``samples[i]``.
    """

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        # Must read the instance attribute: a bare `samples` is not defined in
        # this scope and raised NameError whenever len(dataset) was called.
        return len(self.samples)

    def __getitem__(self, i):
        return self.samples[i]


import torch

# Seed the global RNG so the sampled indices (and the counts displayed below)
# are reproducible on "Restart & Run All".
torch.manual_seed(0)

weights = [0.05, 0.35, 0.6]
weighted_random_sampler = create_weighted_random_sampler(weights=weights, n_samples=100)
data_loader = PytorchDataLoader(
    TestDataset(samples=[0, 1, 2]), sampler=weighted_random_sampler
)
# With the loader's default batch size each batch holds a single sample,
# which `.item()` unwraps to a plain Python number.
sampled_results = [r.item() for r in data_loader]
assert len(sampled_results) == 100
# Every draw must be one of the dataset's samples.
assert set(sampled_results) <= {0, 1, 2}
Counter(sampled_results)

In [95]:
from nbdev.export import notebook2script

# Run nbdev's notebook-to-script conversion; the saved output of this cell
# reports "Converted data_loader.ipynb.".
# NOTE(review): presumably this writes the `# export` cells to the module named
# by the `# default_exp` directive -- confirm against the nbdev version in use.
notebook2script()

Converted data_loader.ipynb.
