In [1]:
from datasets import load_dataset

# Load the "dnth/active-learning-imagenette" dataset; every later cell reads
# its splits through the `dataset` name.
dataset = load_dataset("dnth/active-learning-imagenette")

# Bare expression: display the DatasetDict repr (splits, features, row counts).
dataset

DatasetDict({
    evaluation: Dataset({
        features: ['image', 'filepath', 'label_name'],
        num_rows: 3925
    })
    unlabeled: Dataset({
        features: ['image', 'filepath', 'label_name'],
        num_rows: 9469
    })
})

## Sample an initial training set

To get started, we randomly sample 10 images per class from the `unlabeled` split to "seed" the training set.

In [2]:
import numpy as np

# Seed NumPy's legacy global RNG so the same row indices are drawn on every run.
seed = 316
np.random.seed(seed)

unique_labels = dataset["unlabeled"].unique("label_name")
samples = []
n_samples_per_class = 10

# Build the label array once; it does not change between classes, so there is
# no reason to rebuild it from the dataset column on every loop iteration.
all_labels = np.array(dataset["unlabeled"]["label_name"])

print("Sampling process:")
for label in unique_labels:
    # Row positions (within the "unlabeled" split) that carry this label
    label_indices = np.where(all_labels == label)[0]
    # Draw n_samples_per_class distinct positions; np.random.choice raises
    # ValueError if the class has fewer rows than that (replace=False).
    random_indices = np.random.choice(
        label_indices, size=n_samples_per_class, replace=False
    )
    samples.extend(random_indices)

initial_samples = dataset["unlabeled"].select(samples)

# Fail loudly if the seed set is not exactly n_samples_per_class rows per label.
assert len(initial_samples) == n_samples_per_class * len(unique_labels), (
    "unexpected seed-set size"
)

# Verify the result (should show 100 rows total, 10 per class)
print("\n=== Final Results ===")
print(f"Total samples: {len(initial_samples)}")
print("\nSamples per class:")
print(initial_samples.select_columns(["label_name"]).to_pandas().value_counts())

Sampling process:

=== Final Results ===
Total samples: 100

Samples per class:
label_name      
English springer    10
French horn         10
cassette player     10
chain saw           10
church              10
garbage truck       10
gas pump            10
golf ball           10
parachute           10
tench               10
Name: count, dtype: int64


In [3]:
# Display the seed set's repr (features and row count).
initial_samples

Dataset({
    features: ['image', 'filepath', 'label_name'],
    num_rows: 100
})

In [4]:
# Peek at one record to see the fields each row carries.
initial_samples[0]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=90x90>,
 'filepath': 'data/imagenette/2/00710.jpg',
 'label_name': 'cassette player'}

In [5]:
# Rebuild the seed set from the sampled indices instead of mutating
# `initial_samples` in place: the in-place form fails on a second run of this
# cell because the "image" column has already been removed.
initial_samples = (
    dataset["unlabeled"]
    .select(samples)
    .remove_columns("image")                # keep only filepath + label
    .rename_column("label_name", "label")
)

df = initial_samples.to_pandas()

# Bare expression: display the seed-set table.
df


Unnamed: 0,filepath,label
0,data/imagenette/2/00710.jpg,cassette player
1,data/imagenette/2/00063.jpg,cassette player
2,data/imagenette/2/00506.jpg,cassette player
3,data/imagenette/2/00575.jpg,cassette player
4,data/imagenette/2/00136.jpg,cassette player
...,...,...
95,data/imagenette/5/08926.jpg,French horn
96,data/imagenette/5/09359.jpg,French horn
97,data/imagenette/5/08999.jpg,French horn
98,data/imagenette/5/08777.jpg,French horn


In [6]:
# Write straight from `initial_samples` rather than from `df`: `df` is rebound
# to other tables later in this notebook, so its contents depend on which cell
# ran last.
initial_samples.to_pandas().to_parquet("initial_samples.parquet")

## Get the evaluation set

In [7]:
# Metadata-only view of the evaluation split: drop the image column and
# rename "label_name" to "label".
eval_samples = (
    dataset["evaluation"]
    .remove_columns("image")
    .rename_column("label_name", "label")
)

eval_samples

Dataset({
    features: ['filepath', 'label'],
    num_rows: 3925
})

In [8]:
# Convert the evaluation metadata to a pandas DataFrame and display it.
# Note: this rebinds `df`, which previously held the seed-set table.
df = eval_samples.to_pandas()
df

Unnamed: 0,filepath,label
0,data/imagenette/2/00000.jpg,cassette player
1,data/imagenette/2/00001.jpg,cassette player
2,data/imagenette/2/00002.jpg,cassette player
3,data/imagenette/2/00003.jpg,cassette player
4,data/imagenette/2/00004.jpg,cassette player
...,...,...
3920,data/imagenette/5/03920.jpg,French horn
3921,data/imagenette/5/03921.jpg,French horn
3922,data/imagenette/5/03922.jpg,French horn
3923,data/imagenette/5/03923.jpg,French horn


In [9]:
# Write straight from `eval_samples` rather than from `df`: `df` is rebound to
# other tables elsewhere in this notebook, so its contents depend on which cell
# ran last.
eval_samples.to_pandas().to_parquet("evaluation_samples.parquet")

## Get the unlabeled dataset

In [10]:
# Metadata-only view of the full "unlabeled" split. The seed rows are still
# part of it at this point, so it is not exported here; the parquet file is
# written only after those rows have been filtered out further down.
unlabeled_samples = (
    dataset["unlabeled"]
    .remove_columns("image")
    .rename_column("label_name", "label")
)

df = unlabeled_samples.to_pandas()
df


Unnamed: 0,filepath,label
0,data/imagenette/2/00000.jpg,cassette player
1,data/imagenette/2/00001.jpg,cassette player
2,data/imagenette/2/00002.jpg,cassette player
3,data/imagenette/2/00003.jpg,cassette player
4,data/imagenette/2/00004.jpg,cassette player
...,...,...
9464,data/imagenette/5/09464.jpg,French horn
9465,data/imagenette/5/09465.jpg,French horn
9466,data/imagenette/5/09466.jpg,French horn
9467,data/imagenette/5/09467.jpg,French horn


In [11]:
# Collect the seed set's filepaths into a set; the later filter step tests
# membership against it.
initial_filepaths = set(initial_samples['filepath'])
initial_filepaths

{'data/imagenette/0/01061.jpg',
 'data/imagenette/0/01098.jpg',
 'data/imagenette/0/01207.jpg',
 'data/imagenette/0/01353.jpg',
 'data/imagenette/0/01371.jpg',
 'data/imagenette/0/01520.jpg',
 'data/imagenette/0/01629.jpg',
 'data/imagenette/0/01645.jpg',
 'data/imagenette/0/01716.jpg',
 'data/imagenette/0/01767.jpg',
 'data/imagenette/1/05684.jpg',
 'data/imagenette/1/05841.jpg',
 'data/imagenette/1/06043.jpg',
 'data/imagenette/1/06189.jpg',
 'data/imagenette/1/06196.jpg',
 'data/imagenette/1/06197.jpg',
 'data/imagenette/1/06454.jpg',
 'data/imagenette/1/06563.jpg',
 'data/imagenette/1/06574.jpg',
 'data/imagenette/1/06582.jpg',
 'data/imagenette/2/00003.jpg',
 'data/imagenette/2/00063.jpg',
 'data/imagenette/2/00136.jpg',
 'data/imagenette/2/00176.jpg',
 'data/imagenette/2/00195.jpg',
 'data/imagenette/2/00420.jpg',
 'data/imagenette/2/00506.jpg',
 'data/imagenette/2/00575.jpg',
 'data/imagenette/2/00710.jpg',
 'data/imagenette/2/00974.jpg',
 'data/imagenette/3/01974.jpg',
 'data/i

In [12]:
# Number of distinct seed filepaths; equal to the seed-set row count only if
# no filepath is duplicated.
len(initial_filepaths)

100

In [13]:
# Drop every row whose filepath was already drawn into the seed set, so the
# remaining pool and the seed set do not overlap.
unlabeled_samples = unlabeled_samples.filter(
    lambda x: x["filepath"] not in initial_filepaths
)

# Each seed filepath must have removed exactly one row. Comparing against the
# untouched split (not the pre-filter variable) keeps this check valid when the
# cell is re-run.
assert len(unlabeled_samples) == len(dataset["unlabeled"]) - len(initial_filepaths), (
    "filtered pool size does not match full split minus seed set"
)

In [14]:
# Display the filtered pool's repr to check the reduced row count.
unlabeled_samples

Dataset({
    features: ['filepath', 'label'],
    num_rows: 9369
})

In [15]:
# Convert the filtered pool to a pandas DataFrame and display it.
# Note: this rebinds `df`, which earlier held the unfiltered pool.
df = unlabeled_samples.to_pandas()
df


Unnamed: 0,filepath,label
0,data/imagenette/2/00000.jpg,cassette player
1,data/imagenette/2/00001.jpg,cassette player
2,data/imagenette/2/00002.jpg,cassette player
3,data/imagenette/2/00004.jpg,cassette player
4,data/imagenette/2/00005.jpg,cassette player
...,...,...
9364,data/imagenette/5/09464.jpg,French horn
9365,data/imagenette/5/09465.jpg,French horn
9366,data/imagenette/5/09466.jpg,French horn
9367,data/imagenette/5/09467.jpg,French horn


In [16]:
# Write straight from the filtered `unlabeled_samples` rather than from `df`:
# `df` also held the unfiltered pool earlier in this notebook, so an
# out-of-order run could export the seed rows into this file.
unlabeled_samples.to_pandas().to_parquet("unlabeled_samples.parquet")