In [1]:
from fastai.vision.all import *
from datasets import load_dataset

learn = load_learner("initial_model.pkl")

learn

<fastai.learner.Learner at 0x75ea81c44230>

In [2]:
dataset = load_dataset("dnth/active-learning-imagenette")
dataset["unlabeled"]


Dataset({
    features: ['image', 'filepath', 'label_name'],
    num_rows: 9469
})

In [3]:
unlabeled_dataset = dataset["unlabeled"].remove_columns(["image"]).to_pandas()
unlabeled_dataset


Unnamed: 0,filepath,label_name
0,data/imagenette/2/00000.jpg,cassette player
1,data/imagenette/2/00001.jpg,cassette player
2,data/imagenette/2/00002.jpg,cassette player
3,data/imagenette/2/00003.jpg,cassette player
4,data/imagenette/2/00004.jpg,cassette player
...,...,...
9464,data/imagenette/5/09464.jpg,French horn
9465,data/imagenette/5/09465.jpg,French horn
9466,data/imagenette/5/09466.jpg,French horn
9467,data/imagenette/5/09467.jpg,French horn


In [4]:
list_of_images = unlabeled_dataset["filepath"].tolist()

list_of_images

['data/imagenette/2/00000.jpg',
 'data/imagenette/2/00001.jpg',
 'data/imagenette/2/00002.jpg',
 'data/imagenette/2/00003.jpg',
 'data/imagenette/2/00004.jpg',
 'data/imagenette/2/00005.jpg',
 'data/imagenette/2/00006.jpg',
 'data/imagenette/2/00007.jpg',
 'data/imagenette/2/00008.jpg',
 'data/imagenette/2/00009.jpg',
 'data/imagenette/2/00010.jpg',
 'data/imagenette/2/00011.jpg',
 'data/imagenette/2/00012.jpg',
 'data/imagenette/2/00013.jpg',
 'data/imagenette/2/00014.jpg',
 'data/imagenette/2/00015.jpg',
 'data/imagenette/2/00016.jpg',
 'data/imagenette/2/00017.jpg',
 'data/imagenette/2/00018.jpg',
 'data/imagenette/2/00019.jpg',
 'data/imagenette/2/00020.jpg',
 'data/imagenette/2/00021.jpg',
 'data/imagenette/2/00022.jpg',
 'data/imagenette/2/00023.jpg',
 'data/imagenette/2/00024.jpg',
 'data/imagenette/2/00025.jpg',
 'data/imagenette/2/00026.jpg',
 'data/imagenette/2/00027.jpg',
 'data/imagenette/2/00028.jpg',
 'data/imagenette/2/00029.jpg',
 'data/imagenette/2/00030.jpg',
 'data/i

In [5]:
test_dl = learn.dls.test_dl(list_of_images, bs=128)

In [6]:
preds, targets, decoded = learn.get_preds(dl=test_dl, with_decoded=True)

In [7]:
unlabeled_dataset['confidence'] = torch.max(F.softmax(preds, dim=1), dim=1)[0].numpy()
unlabeled_dataset['predicted_label'] = [learn.dls.vocab[i] for i in decoded.numpy()]

unlabeled_dataset

Unnamed: 0,filepath,label_name,confidence,predicted_label
0,data/imagenette/2/00000.jpg,cassette player,0.230600,cassette player
1,data/imagenette/2/00001.jpg,cassette player,0.226287,cassette player
2,data/imagenette/2/00002.jpg,cassette player,0.231944,cassette player
3,data/imagenette/2/00003.jpg,cassette player,0.215759,cassette player
4,data/imagenette/2/00004.jpg,cassette player,0.231724,cassette player
...,...,...,...,...
9464,data/imagenette/5/09464.jpg,French horn,0.218210,French horn
9465,data/imagenette/5/09465.jpg,French horn,0.231173,French horn
9466,data/imagenette/5/09466.jpg,French horn,0.231816,French horn
9467,data/imagenette/5/09467.jpg,French horn,0.230855,French horn


In [8]:
# Sort by confidence
unlabeled_dataset = unlabeled_dataset.sort_values(by="confidence", ascending=True)
unlabeled_dataset

Unnamed: 0,filepath,label_name,confidence,predicted_label
2489,data/imagenette/3/02489.jpg,chain saw,0.109699,tench
9279,data/imagenette/5/09279.jpg,French horn,0.110210,French horn
4602,data/imagenette/9/04602.jpg,parachute,0.110406,church
2701,data/imagenette/3/02701.jpg,chain saw,0.112579,garbage truck
2804,data/imagenette/3/02804.jpg,chain saw,0.112646,English springer
...,...,...,...,...
1794,data/imagenette/0/01794.jpg,tench,0.231969,tench
1151,data/imagenette/0/01151.jpg,tench,0.231969,tench
1566,data/imagenette/0/01566.jpg,tench,0.231969,tench
1787,data/imagenette/0/01787.jpg,tench,0.231969,tench


In [9]:
# Get top 100 predictions with low confidence
top_100_low_confidence = unlabeled_dataset.head(100)
top_100_low_confidence



Unnamed: 0,filepath,label_name,confidence,predicted_label
2489,data/imagenette/3/02489.jpg,chain saw,0.109699,tench
9279,data/imagenette/5/09279.jpg,French horn,0.110210,French horn
4602,data/imagenette/9/04602.jpg,parachute,0.110406,church
2701,data/imagenette/3/02701.jpg,chain saw,0.112579,garbage truck
2804,data/imagenette/3/02804.jpg,chain saw,0.112646,English springer
...,...,...,...,...
4543,data/imagenette/9/04543.jpg,parachute,0.121954,chain saw
7426,data/imagenette/8/07426.jpg,golf ball,0.122035,golf ball
2888,data/imagenette/4/02888.jpg,church,0.122096,French horn
7173,data/imagenette/8/07173.jpg,golf ball,0.122100,cassette player


In [14]:
# Let add this to the dataset

training_samples = pd.read_parquet("training_samples.parquet")
training_samples


Unnamed: 0,filepath,label_name
0,data/imagenette/2/00346.jpg,cassette player
1,data/imagenette/2/00845.jpg,cassette player
2,data/imagenette/2/00383.jpg,cassette player
3,data/imagenette/2/00503.jpg,cassette player
4,data/imagenette/2/00002.jpg,cassette player
...,...,...
95,data/imagenette/5/09030.jpg,French horn
96,data/imagenette/5/09299.jpg,French horn
97,data/imagenette/5/09018.jpg,French horn
98,data/imagenette/5/08529.jpg,French horn


In [15]:
top_100_low_confidence.drop(columns=["confidence", "predicted_label"])

training_samples = pd.concat([training_samples, top_100_low_confidence])
training_samples

training_samples.to_parquet("training_samples_iteration_1.parquet")