In [1]:
from active_embedder import *
from prob_cover import *
from data_labeler import *
from resnet_verification import *

Global seed set to 42


## Active Learning for Classification Toolkit

Why we made this: Active Learning algorithms such as ProbCover have achieved state-of-the-art results, and have even improved on Self-Supervised and Semi-Supervised techniques. However, most of these results come from simulated environments: the datasets were actually labelled, but the labels were hidden from the model until the Active Learning "oracle" revealed them. We wanted to make Active Learning work for you, dear Reader, with a truly unlabelled dataset, by providing a full-service workflow. You start with a folder of unlabelled images and use this Toolkit to generate embeddings, select the examples to label, and save the labels.

When to use this: This toolkit focuses on the Cold Start problem. Many approaches, such as weakly supervised or semi-supervised learning, rely on an "initial set" of labelled examples and work to propagate those labels to unlabelled examples. The harder problem is when ALL your data is unlabelled: how do you even know where to start labelling? That's where this toolkit comes in. We'll help you choose which examples to label so you can get a working classifier from as few labels as possible, and provide guidance on when you can stop labelling.

How to use the toolkit: Unfortunately, the notebook's interactive widgets tend to become very slow on Colab. We recommend cloning this repo and running the notebook locally in Jupyter.

The steps are as follows:

### [Part 1: Prep work](#part_1)

Enter the root directory where your images are stored in the cell below. The cell after that finds all images in the folder, so don't worry about file naming or folder structure.

### [Part 2: Create Embeddings](#part_2)

Specify or create embeddings. All our Active Learning algorithms require a good embedding space as a prerequisite. If you already have self-supervised embeddings for your data, simply enter the path to the .npy or .pth file. If you don't have embeddings, you have three options: run a forward pass through a pre-trained VGG model, run a forward pass through a self-supervised ResNet, or fine-tune a self-supervised model on your dataset. The first two will run quickly; the third will likely take about a day to run.

### [Part 3: Label your examples with Active Learning](#part_3)

Use your embeddings to select examples to label. Instantiate one of our Active Learner classes as described below, and it will generate a list of examples to label. This step saves a data manifest for future use.

### [Part 4: Test a model](#part_4)

Use those labels to build a DataLoader. You're now ready to train a classifier that should have maximum performance per labelled example!

A final note before we begin: to keep this notebook organized, most lengthy functions (e.g., our Active Learner classes and our visualization classes) are imported. However, the whole reason we made this a notebook instead of a GUI is so that you can see them, inspect them, and create your own classes that may well improve on anything we did. At your leisure, browse the rest of the code in this repo to understand how these algorithms maximize performance, and modify it to your heart's delight.

Sound good? Let's get started!

## <a id="part_1">Part 1: Prep work</a>

Enter the root directory under which all images are stored. The folder structure doesn't matter, but note that ALL images with the chosen file type in the folder you specify (including its subfolders) will be added to your dataset.

In [2]:
image_dir = "/Users/TestAccount/Downloads/images" # ENTER IMAGE FOLDER HERE
image_size = 256  # ENTER IMAGE HEIGHT OR WIDTH HERE (ASSUMED TO BE THE SAME)
image_format = ".png"  # SPECIFY FILE TYPE, WE SUPPORT .jpg, .jpeg, OR .png

In [3]:
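import os

# Recursively collect every file under image_dir whose extension matches image_format
# (case-insensitive, so ".PNG" files are picked up when image_format is ".png")
image_list = []
for root, dirs, files in os.walk(image_dir, topdown=True):
    for name in files:
        if name.lower().endswith(image_format.lower()):
            image_list.append(os.path.join(root, name))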

In [5]:
len(image_list)

1000
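The cell above matches a single extension, and `os.walk` does not guarantee any particular file order (it is built on `os.scandir`, which yields entries in arbitrary order). Because the active learners appear to look images up in `image_list` by position, a stable order matters if you reload saved embeddings in a later session. Below is a minimal sketch, assuming you want several extensions and a sorted list; `find_images` is a hypothetical helper, not part of the toolkit.

```python
import os
from typing import Iterable, List


def find_images(root: str, extensions: Iterable[str] = (".jpg", ".jpeg", ".png")) -> List[str]:
    """Hypothetical helper: recursively collect image paths under `root`.

    Extensions are matched case-insensitively (so ".PNG" counts as ".png"),
    and the result is sorted so the order is the same in every session.
    """
    exts = tuple(e.lower() for e in extensions)
    paths = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            if name.lower().endswith(exts):  # str.endswith accepts a tuple of suffixes
                paths.append(os.path.join(dirpath, name))
    return sorted(paths)
```

For example, `image_list = find_images(image_dir, extensions=(image_format,))` would reproduce the cell above with a deterministic order. If you already saved embeddings with the unsorted list, keep using that list, since switching to a sorted list would misalign the rows.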

If you already have embeddings for your images, enter the path to the .npy or .pth file below.

In [6]:
embeddings_loc = None # REPLACE None WITH THE PATH TO YOUR EMBEDDINGS IF YOU HAVE THEM; LEAVE AS None IF YOU WANT US TO CREATE THEM
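If you point `embeddings_loc` at your own file, it is worth checking that it holds one row per image before going further. Below is a minimal sketch, assuming the file contains a single 2-D array (`.npy`) or a single tensor saved with `torch.save` (`.pth`/`.pt`). `load_embeddings` is a hypothetical helper, and the toolkit's own loader may accept other layouts.

```python
import os

import numpy as np


def load_embeddings(path: str, expected_rows: int) -> np.ndarray:
    """Hypothetical helper: load an (N, D) embedding matrix and check N."""
    ext = os.path.splitext(path)[1].lower()
    if ext == ".npy":
        emb = np.load(path)
    elif ext in (".pth", ".pt"):
        import torch  # only needed for PyTorch files
        emb = torch.load(path, map_location="cpu")
        # Assumption: the file holds a single tensor (or something array-like)
        emb = emb.detach().cpu().numpy() if hasattr(emb, "detach") else np.asarray(emb)
    else:
        raise ValueError(f"Unsupported embeddings file: {path}")
    if emb.ndim != 2:
        raise ValueError(f"Expected a 2-D array, got shape {emb.shape}")
    if emb.shape[0] != expected_rows:
        raise ValueError(f"{emb.shape[0]} embeddings for {expected_rows} images")
    return emb
```

A call such as `load_embeddings(embeddings_loc, len(image_list))` would fail loudly if the file and the image folder do not match.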

## <a id ="part_2">Part 2: Create Embeddings</a>

**If you already have embeddings, skip ahead to [Part 3](#part_3)**

In [7]:
embeddings_path = "xr_embeddings_simclr.pth" # ENTER WHERE YOU WANT THE EMBEDDINGS TO BE SAVED ENDING IN ".pth"

If you don't have embeddings, we'll create them for you. Choose from the following options:

If you are working with normal images (e.g., images at ground-level) use one of the following

1: VGG pre-trained on ImageNet (fast)

2: ResNet pre-trained with SimCLR (medium, requires a large download)

3: Fine-tune SimCLR on your data (runtime depends on the number of images in your dataset, but budget up to a day for a dataset of 50,000+ images)

4: Specify a SimCLR checkpoint to use

TO BE IMPLEMENTED SOON

If you are working with remote sensing (e.g., satellite images) use one of the following:

5: VGG pre-trained on fMOW (fast)

6: Frozen SatMAE (medium, requires a large download)

7: Fine-tune SatMAE (runtime is around a day usually)

If you are working with medical imagery (e.g., X-rays), use one of the following:

8: VGG pre-trained on SOMETHING

9: Frozen medical SimCLR

10: Fine-tune on SimCLR

In [8]:
embedding_option = 1 # Replace with the number of the option you want to use

In [9]:
contrast_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
                                          transforms.Resize(size=224),
                                          transforms.RandomResizedCrop(size=224),
                                          transforms.RandomApply([
                                              transforms.ColorJitter(brightness=0.5,
                                                                     contrast=0.5,
                                                                     saturation=0.5,
                                                                     hue=0.1)
                                          ], p=0.8),
                                          transforms.RandomGrayscale(p=0.2),
                                          transforms.GaussianBlur(kernel_size=9),
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.5,), (0.5,))
                                         ])

If you chose option 2, 5, or 8, you must provide a path to where you want the model stored, and a 

In [10]:
embedder = Embedder(image_list, image_size, embedding_option,
                    transform=contrast_transforms,
                    embedding_ckpt='simclr/epoch=0-step=70.ckpt')

In [11]:
embeddings_transform = transforms.Compose([transforms.Resize(224),
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.5,), (0.5,))
                                         ])

In [12]:
embeddings = embedder.get_embeddings(embeddings_transform)


Now save the embeddings to disk so you don't have to recompute them later.

In [13]:
torch.save(embeddings,embeddings_path)

In [14]:
embeddings_loc = embeddings_path

## <a id=part_3>Part 3: Active Learning</a>
    
#### If you already created manifests of your data, skip to [Part 4](#part_4)

At this point we will create an "Active Learner" instance, which provides a reference to the list of "best" examples to label. First, enter the list of classes in your data (the labels must be mutually exclusive), then decide how many examples you want to label:

In [15]:
labels_list = ["malignant","benign"]

In [16]:
examples_to_label = 50

So far, we have implemented two kinds of Active Learners. The first, ProbCover, comes from <a href='https://arxiv.org/abs/2205.11320'>Active Learning Through a Covering Lens</a> and aims to minimize the 1-NN error for a given labeling budget. The second, CoverNN, is our own, and minimizes the likelihood that a k-NN neighbourhood contains images of different classes. Run ONE of the following two cells, depending on which active learner you want to use.
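To make the covering idea concrete, here is a minimal NumPy sketch of the greedy selection described in the ProbCover paper. It connects every point to all points closer than a radius `delta`, then repeatedly picks the point whose ball covers the most not-yet-covered points. This is an illustration for small datasets only (it builds a dense N×N distance matrix), not the toolkit's `ProbCover` class. `probcover_greedy` and the choice of Euclidean distance on L2-normalised rows are assumptions.

```python
import numpy as np


def probcover_greedy(embeddings: np.ndarray, budget: int, delta: float):
    """Greedy max-coverage selection in the spirit of ProbCover (illustrative sketch).

    Returns the selected row indices and the fraction of points covered,
    where a point is covered if it lies within `delta` of a selected point.
    """
    # Assumption: rows are L2-normalised and compared with Euclidean distance
    x = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    sq = np.sum(x ** 2, axis=1)
    dist = np.sqrt(np.maximum(sq[:, None] + sq[None, :] - 2 * x @ x.T, 0.0))
    in_ball = dist < delta              # in_ball[i, j]: point j lies in the delta-ball of point i
    covered = np.zeros(len(x), dtype=bool)
    selected = []
    for _ in range(min(budget, len(x))):
        # How many still-uncovered points each candidate's ball would add
        gains = (in_ball & ~covered[None, :]).sum(axis=1)
        gains[selected] = -1            # never pick the same point twice
        best = int(np.argmax(gains))
        selected.append(best)
        covered |= in_ball[best]
    return selected, float(covered.mean())
```

The radius matters: a very small `delta` makes every ball contain only its own point (selection degenerates to arbitrary picks), while a very large one lets a single ball swallow several classes.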

In [17]:
prob_labels = ProbCover(save_dir="", image_list=image_list, embeddings_loc=embeddings_loc,
                        num_classes=len(labels_list), delta=None, input_size=224)


In [18]:
prob_labels = CoverNN(save_dir="", image_list=image_list, embeddings_loc=embeddings_loc,
                      num_classes=len(labels_list), k=30, input_size=224)

Finished loading data...
Start constructing graph using k=30
Finished constructing graph using k=30
Graph contains 27000 edges.
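CoverNN starts from a k-nearest-neighbour graph, as the log above shows. Below is a minimal NumPy sketch of building such a graph. It assumes each point gets exactly `k` outgoing edges to its nearest other points under Euclidean distance. `knn_graph` is hypothetical, and the toolkit's own construction (distance metric, self-loops, edge direction) may differ. Under this assumption the edge count is N·k, so 27000 edges at k=30 would correspond to 900 points.

```python
import numpy as np


def knn_graph(embeddings: np.ndarray, k: int) -> np.ndarray:
    """Hypothetical helper: return an (N*k, 2) array of directed edges (i -> j)."""
    x = np.asarray(embeddings, dtype=float)
    sq = np.sum(x ** 2, axis=1)
    dist = sq[:, None] + sq[None, :] - 2 * x @ x.T   # squared Euclidean distances
    np.fill_diagonal(dist, np.inf)                    # exclude self-loops
    neighbours = np.argsort(dist, axis=1)[:, :k]      # k closest other points per row
    sources = np.repeat(np.arange(len(x)), k)
    return np.stack([sources, neighbours.ravel()], axis=1)
```

The dense distance matrix keeps the sketch short. For tens of thousands of images, a library nearest-neighbour index would be the usual choice.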


In [19]:
oracle_results = prob_labels.select_samples(examples_to_label)


Start selecting 50 samples.
Iteration is 0.	Min distance is 0.518.	Coverage is 0.000
Iteration is 1.	Min distance is 0.523.	Coverage is 0.033
Iteration is 2.	Min distance is 0.530.	Coverage is 0.066
Iteration is 3.	Min distance is 0.532.	Coverage is 0.086
Iteration is 4.	Min distance is 0.538.	Coverage is 0.107
Iteration is 5.	Min distance is 0.544.	Coverage is 0.126
Iteration is 6.	Min distance is 0.545.	Coverage is 0.141
Iteration is 7.	Min distance is 0.545.	Coverage is 0.150
Iteration is 8.	Min distance is 0.546.	Coverage is 0.162
Iteration is 9.	Min distance is 0.552.	Coverage is 0.176
Iteration is 10.	Min distance is 0.555.	Coverage is 0.184
Iteration is 11.	Min distance is 0.555.	Coverage is 0.187
Iteration is 12.	Min distance is 0.557.	Coverage is 0.197
Iteration is 13.	Min distance is 0.557.	Coverage is 0.203
Iteration is 14.	Min distance is 0.558.	Coverage is 0.213
Iteration is 15.	Min distance is 0.559.	Coverage is 0.220
Iteration is 16.	Min distance is 0.559.	Coverage is 0.

In [20]:
oracle_image_list = [[image_list[prob_labels.dict_indices["train"][i]] for i in sublist] for sublist in oracle_results]

Please run the following cell TWICE in a row so that it saves your results properly. Some users have reported that running it only once does not always save.

In [21]:
train_labeler = DataLabeler(oracle_image_list,labels_list)

100%|███████████████████████████████████████████████████████████████████████████████████| 50/50 [00:03<00:00, 16.30it/s]


In [22]:
train_labeler.display_pictures_button(oracle_image_list,0,train_labeler.df)

Loading images in cluster 2...



Let's take a look at your manifest!

In [23]:
train_labeler.df

| Unnamed: 0 | path | label |
|---|---|---|
| 0 | /Users/TestAccount/Downloads/images/00001230_0... | malignant |
| 1 | /Users/TestAccount/Downloads/images/00001018_0... | malignant |
| 2 | /Users/TestAccount/Downloads/images/00000857_0... | malignant |
| 3 | /Users/TestAccount/Downloads/images/00000996_0... | malignant |
| 4 | /Users/TestAccount/Downloads/images/00001171_0... | malignant |
| 5 | /Users/TestAccount/Downloads/images/00001230_0... | malignant |
| 6 | /Users/TestAccount/Downloads/images/00001217_0... | malignant |
| 7 | /Users/TestAccount/Downloads/images/00000908_0... | malignant |
| 8 | /Users/TestAccount/Downloads/images/00000149_0... | malignant |
| 9 | /Users/TestAccount/Downloads/images/00000081_0... | malignant |


Running the cell above shows the labels you have saved so far. Let's write them to a CSV right away so nothing happens to them. Feel free to change the file name.

In [24]:
train_labeler.df.to_csv("train_df.csv")
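Before moving on, it is worth checking how many examples ended up in each class. Every row in the manifest preview above carries the same label, and a lopsided training set makes validation accuracy harder to interpret. Below is a minimal standard-library sketch. It assumes the CSV has a `label` column as shown above (the unnamed leading column is the DataFrame index, which `to_csv` writes by default). `label_counts` is a hypothetical helper.

```python
import csv
from collections import Counter


def label_counts(manifest_path: str) -> Counter:
    """Hypothetical helper: count how many rows of a manifest CSV carry each label."""
    with open(manifest_path, newline="") as f:
        return Counter(row["label"] for row in csv.DictReader(f))
```

For example, `print(label_counts("train_df.csv"))` prints one count per class you have labelled.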

The next few cells assume you want to measure your model's performance on a separate validation set. Here, you can label as many randomly chosen validation examples as you like; further below, you train a model to see how well you're doing.

In [25]:
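val_to_label = 50
# Take the first val_to_label images of the validation split (capped at the split's size)
val_image_list = [[image_list[prob_labels.dict_indices["val"][i]]]
                  for i in range(min(val_to_label, len(prob_labels.dict_indices["val"])))]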
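The validation images come from `prob_labels.dict_indices["val"]`, which the active learner appears to keep separate from the pool it selects training examples from. If you ever need such a held-out split yourself (for example, with a learner that does not expose one), below is a minimal sketch of a seeded, disjoint random split. `split_indices`, the 10% default, and reusing the notebook's seed of 42 are assumptions, not the toolkit's settings.

```python
import numpy as np


def split_indices(n: int, val_fraction: float = 0.1, seed: int = 42) -> dict:
    """Hypothetical helper: shuffle 0..n-1 once and cut it into disjoint train/val lists."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(n)
    n_val = int(round(n * val_fraction))
    return {"train": order[n_val:].tolist(), "val": order[:n_val].tolist()}
```

Because the split is seeded, calling it again with the same arguments returns the same indices, which keeps the validation set fixed across sessions.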

Please run the cell below TWICE in a row so that it saves properly.

In [26]:
val_labeler = DataLabeler(val_image_list,labels_list)

100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 157.01it/s]


In [27]:
val_labeler.display_pictures_button(val_image_list,0,val_labeler.df)

Loading images in cluster 6...



In [28]:
val_labeler.df

| Unnamed: 0 | path | label |
|---|---|---|
| 0 | /Users/TestAccount/Downloads/images/00000248_0... | malignant |
| 1 | /Users/TestAccount/Downloads/images/00000272_0... | malignant |
| 2 | /Users/TestAccount/Downloads/images/00000061_0... | malignant |
| 3 | /Users/TestAccount/Downloads/images/00000021_0... | malignant |
| 4 | /Users/TestAccount/Downloads/images/00000766_0... | malignant |
| 5 | /Users/TestAccount/Downloads/images/00000013_0... | benign |


And let's save the validation manifest for safekeeping:

In [29]:
val_labeler.df.to_csv("val_df.csv")

## <a id="part_4">Part 4: Testing your model</a>

In [30]:
labels_map = {label: i for i, label in enumerate(labels_list)}

In [31]:
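# Build datasets from the saved manifests, using the same resize/normalisation as for the embeddings
train_data = ManifestData("train_df.csv", label_map=labels_map, transform=embeddings_transform)
val_data = ManifestData("val_df.csv", label_map=labels_map, transform=embeddings_transform)
# If DataLoader worker processes raise errors inside Jupyter (e.g. on macOS), try num_workers=0
train_loader = DataLoader(train_data, batch_size=256, num_workers=4, shuffle=True)
val_loader = DataLoader(val_data, batch_size=1, num_workers=4, shuffle=False)
# ResNet-based classifier from resnet_verification, trained on the labelled examples
verification_model = VerificationModel(0.001, weight_decay=.001, num_classes=len(labels_list))
trainer = pl.Trainer(max_epochs=50)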

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
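`ManifestData` is imported from the toolkit. For readers who want to see what such a class needs to do, below is a minimal sketch of a map-style dataset over the manifest CSV. `ManifestDatasetSketch` is hypothetical. It assumes the `path` and `label` columns shown above and, by default, opens images with Pillow and converts them to RGB. PyTorch's `DataLoader` works with any map-style dataset, i.e. an object implementing `__getitem__` and `__len__`.

```python
import csv
from typing import Callable, Dict, Optional


def _open_rgb(path: str):
    from PIL import Image  # imported lazily so the class can be used with a custom loader
    with Image.open(path) as img:
        return img.convert("RGB")  # convert() returns a new, fully loaded image


class ManifestDatasetSketch:
    """Hypothetical map-style dataset: row i of the manifest -> (image, class index)."""

    def __init__(self, csv_path: str, label_map: Dict[str, int],
                 transform: Optional[Callable] = None, loader: Callable = _open_rgb):
        with open(csv_path, newline="") as f:
            rows = list(csv.DictReader(f))
        # Keep only rows whose label is one of the known classes
        self.samples = [(r["path"], label_map[r["label"]]) for r in rows if r["label"] in label_map]
        self.transform = transform
        self.loader = loader

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        path, target = self.samples[idx]
        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)
        return image, target
```

Used as `ManifestDatasetSketch("val_df.csv", labels_map, transform=embeddings_transform)`, each item would be a normalised image tensor paired with its class index.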


In [32]:
trainer.fit(verification_model,train_loader,val_loader)


  | Name    | Type             | Params
---------------------------------------------
0 | convnet | ResNet           | 11.2 M
1 | loss    | CrossEntropyLoss | 0     
2 | val_acc | Accuracy         | 0     
---------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.710    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fa948355ab0>
Traceback (most recent call last):
  File "/opt/anaconda3/envs/torch_simple/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1466, in __del__
    self._shutdown_workers()
  File "/opt/anaconda3/envs/torch_simple/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1430, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/opt/anaconda3/envs/torch_simple/lib/python3.10/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/opt/anaconda3/envs/torch_simple/lib/python3.10/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
  File "/opt/anaconda3/envs/torch_simple/lib/python3.10/multiprocessing/connection.py", line 936, in wait
    ready = selector.select(timeout)
  File "/opt/anaconda3/envs/torch_simple/lib/python3.10/selectors.py", line 416, in select
    fd_event_


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


How did you do? Check the validation accuracy shown in the progress bar.
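Plain accuracy can look good on a lopsided validation set. In the validation preview above, five of the six rows are labelled malignant, so always predicting malignant would already score 5/6 on those rows. Below is a minimal NumPy sketch of per-class recall and balanced accuracy (the mean of the per-class recalls), which you could compute from your model's validation predictions. `balanced_accuracy` is a hypothetical helper, and collecting the predictions from the trained model is left to you.

```python
import numpy as np


def balanced_accuracy(y_true, y_pred, num_classes: int):
    """Return (per-class recall, balanced accuracy) for integer class labels."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    recalls = np.full(num_classes, np.nan)
    for c in range(num_classes):
        mask = y_true == c
        if mask.any():  # classes absent from y_true stay NaN and are ignored below
            recalls[c] = np.mean(y_pred[mask] == c)
    return recalls, float(np.nanmean(recalls))
```

For the always-malignant example above, this gives a recall of 1.0 for malignant, 0.0 for benign, and a balanced accuracy of 0.5, which exposes the problem that plain accuracy hides.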