# -*- coding: utf-8 -*-
"""Few-Shot Learning using Prototypical Network.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1pkpOMH-JqmDntKv50hTmR4EWfTxMaG6g
## Discovering Prototypical Networks
"""
!pip install easyfsl
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import Omniglot
from torchvision.models import resnet18
from tqdm import tqdm
from easyfsl.samplers import TaskSampler
from easyfsl.utils import plot_images, sliding_average
"""Now, we need a dataset. I suggest we use [Omniglot](https://github.com/brendenlake/omniglot), a popular MNIST-like benchmark
for few-shot classification. It contains 1623 characters from 50 different alphabets. Each character has been written by
20 different people.
Bonus: it's part of the `torchvision` package, so it's very easy to download
and work with.
"""
image_size = 28
# NB: background=True selects the train set, background=False selects the test set
# It's the nomenclature from the original paper, we just have to deal with it
train_set = Omniglot(
    root="./data",
    background=True,
    transform=transforms.Compose(
        [
            transforms.Grayscale(num_output_channels=3),
            transforms.RandomResizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    ),
    download=True,
)
test_set = Omniglot(
    root="./data",
    background=False,
    transform=transforms.Compose(
        [
            # Omniglot images have 1 channel, but our model will expect 3-channel images
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize([int(image_size * 1.15), int(image_size * 1.15)]),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ]
    ),
    download=True,
)
"""Simply put, in a few-shot classification task, we have a labeled support set (which kind of acts
like a catalog) and query set. For each image of the query set, we want to predict a label from the
labels present in the support set. A few-shot classification model has to use the information from the
support set in order to classify query images. We say *few-shot* when the support set contains very
few images for each label (typically fewer than 10). The figure below shows a 3-way 2-shot classification task. "3-way" means "3 different classes" and "2-shot" means "2 examples per class".
We expect a model that has never seen any Saint-Bernard, Pug or Labrador during its training to successfully
predict the query labels. The support set is the only information that the model has regarding what a Saint-Bernard,
a Pug or a Labrador can be.
![few-shot classification task](https://images.ctfassets.net/be04ylp8y0qc/bZhboqYXfYeW4I88xmMNv/7c5efdc368206feaad045c674b1ced95/1_AteD0yXLkQ1BbjQTB3Ytwg.png?fm=webp)
Most few-shot classification methods are *metric-based*. They work in two phases: 1) they use a CNN to project both
support and query images into a feature space, and 2) they classify query images by comparing them to support images.
If, in the feature space, an image is closer to pugs than it is to labradors and Saint-Bernards, we will guess that
it's a pug.
From there, we have two challenges:
1. Find a good feature space. This is what convolutional networks are for. A CNN is basically a function that takes an image as input and outputs a representation (or *embedding*) of this image in a given feature space. The challenge here is to have a CNN that will
project images of the same class into representations that are close to each other, even if it has not been trained
on objects of this class.
2. Find a good way to compare the representations in the feature space. This is the job of Prototypical Networks.
![Prototypical classification](https://images.ctfassets.net/be04ylp8y0qc/45M9UcUp6KnzwDaBHeGZb7/bb2dcda5942ee7320600125ac2310af6/0_M0GSRZri859fGo48.png?fm=webp)
From the support set, Prototypical Networks compute a prototype for each class, which is the mean of all embeddings
of support images from this class. Then, each query is simply classified as the nearest prototype in the feature space,
with respect to Euclidean distance.
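As a minimal sketch of this rule, here is the same computation on toy 2-D embeddings, written in plain Python rather than torch so it runs without a GPU or a backbone. The embedding values and the helper names `compute_prototypes` and `classify` are assumptions made for illustration; they are not part of easyfsl.

```python
import math

# Toy 2-D embeddings standing in for backbone outputs (hypothetical values).
support_embeddings = [[0.0, 0.0], [0.0, 2.0], [4.0, 4.0], [6.0, 4.0]]
support_labels = [0, 0, 1, 1]
query_embeddings = [[0.5, 1.5], [5.0, 3.0]]


def compute_prototypes(embeddings, labels):
    # Prototype of a class = coordinate-wise mean of its support embeddings.
    prototypes = []
    for label in sorted(set(labels)):
        members = [e for e, l in zip(embeddings, labels) if l == label]
        prototypes.append([sum(dim) / len(members) for dim in zip(*members)])
    return prototypes


def classify(queries, prototypes):
    # Score = negative Euclidean distance; the highest score wins.
    predictions = []
    for query in queries:
        scores = [-math.dist(query, proto) for proto in prototypes]
        predictions.append(scores.index(max(scores)))
    return predictions


prototypes = compute_prototypes(support_embeddings, support_labels)
predictions = classify(query_embeddings, prototypes)
print(prototypes, predictions)
```

Each query ends up with the label of the prototype it is closest to, which is the same decision the torch module makes with `torch.cdist` on real embeddings.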
We simply define Prototypical Networks as a torch module, with a `forward()` method.
Two things to notice:
1. We initialize `PrototypicalNetworks` with a *backbone*. This is the feature extractor we were talking about.
Here, we use as backbone a ResNet18 pretrained on ImageNet, with its head chopped off and replaced by a `Flatten`
layer. The output of the backbone, for an input image, will be a 512-dimensional feature vector.
2. The forward method doesn't only take one input tensor, but 3: in order to predict the labels of query images,
we also need support images and labels as inputs of the model.
"""
class PrototypicalNetworks(nn.Module):
    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone

    def forward(
        self,
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict query labels using labeled support images.
        """
        # Extract the features of support and query images
        z_support = self.backbone(support_images)
        z_query = self.backbone(query_images)

        # Infer the number of different classes from the labels of the support set
        n_way = len(torch.unique(support_labels))
        # Prototype i is the mean of all instances of features corresponding to labels == i
        z_proto = torch.cat(
            [
                z_support[torch.nonzero(support_labels == label)].mean(0)
                for label in range(n_way)
            ]
        )

        # Compute the euclidean distance from queries to prototypes
        dists = torch.cdist(z_query, z_proto)

        # And here is the super complicated operation to transform those distances into classification scores!
        scores = -dists
        return scores


convolutional_network = resnet18(pretrained=True)
convolutional_network.fc = nn.Flatten()
print(convolutional_network)
model = PrototypicalNetworks(convolutional_network).cuda()
"""We used a pretrained feature extractor,
so our model should already be up and running. Let's see that.
Here we create a dataloader that will feed few-shot classification tasks to our model.
But a regular PyTorch dataloader will feed batches of images, with no consideration for
their label or whether they are support or query. We need 2 specific features in our case.
1. We need images evenly distributed between a given number of classes.
2. We need them split between support and query sets.
For the first point, I wrote a custom sampler: it first samples `n_way` classes from the dataset,
then it samples `n_shot + n_query` images for each class (for a total of `n_way * (n_shot + n_query)`
images in each batch).
For the second point, I have a custom collate function to replace the built-in PyTorch `collate_fn`.
This collate function feeds each batch as the combination of 5 items:
1. support images
2. support labels between 0 and `n_way - 1`
3. query images
4. query labels between 0 and `n_way - 1`
5. a mapping of each label in `range(n_way)` to its true class id in the dataset
(it's not used by the model but it's very useful for us to know what the true class is)
We can see that in PyTorch, a DataLoader is basically the combination of a sampler, a dataset and a collate function
(and some multiprocessing voodoo): the sampler says which items to fetch, the dataset says how to fetch them, and
the collate function says how to present these items together.
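To make the two steps concrete, here is a rough sketch of how one episode could be drawn from a list of labels, in plain Python. It is not easyfsl's actual `TaskSampler` code: the function `sample_episode`, the toy label list and the seeded `random.Random` are assumptions made for illustration, and it returns indices instead of image tensors.

```python
import random


def sample_episode(labels, n_way, n_shot, n_query, rng):
    # Group dataset indices by their true class id.
    indices_by_class = {}
    for index, label in enumerate(labels):
        indices_by_class.setdefault(label, []).append(index)

    # Sampler step: pick n_way classes, then n_shot + n_query items per class.
    class_ids = rng.sample(sorted(indices_by_class), n_way)
    support, query = [], []
    for episode_label, class_id in enumerate(class_ids):
        picked = rng.sample(indices_by_class[class_id], n_shot + n_query)
        # Collate step: split into support/query and relabel to 0..n_way-1.
        support += [(index, episode_label) for index in picked[:n_shot]]
        query += [(index, episode_label) for index in picked[n_shot:]]
    return support, query, class_ids


# Hypothetical dataset: 10 classes with 20 items each.
toy_labels = [class_id for class_id in range(10) for _ in range(20)]
support, query, class_ids = sample_episode(
    toy_labels, n_way=5, n_shot=5, n_query=10, rng=random.Random(0)
)
print(len(support), len(query), class_ids)
```

The returned `class_ids` list plays the role of the fifth item above: `class_ids[episode_label]` gives back the true class id of any support or query item.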
"""
N_WAY = 5 # Number of classes in a task
N_SHOT = 5 # Number of images per class in the support set
N_QUERY = 10 # Number of images per class in the query set
N_EVALUATION_TASKS = 100
# The sampler needs a dataset with a "get_labels" method.
test_set.get_labels = lambda: [
    instance[1] for instance in test_set._flat_character_images
]

test_sampler = TaskSampler(
    test_set, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_EVALUATION_TASKS
)

test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=12,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)
"""We created a dataloader that will feed us with 5-way 5-shot tasks."""
(
    example_support_images,
    example_support_labels,
    example_query_images,
    example_query_labels,
    example_class_ids,
) = next(iter(test_loader))

plot_images(example_support_images, "support images", images_per_row=N_SHOT)
plot_images(example_query_images, "query images", images_per_row=N_QUERY)

model.eval()
example_scores = model(
    example_support_images.cuda(),
    example_support_labels.cuda(),
    example_query_images.cuda(),
).detach()

_, example_predicted_labels = torch.max(example_scores.data, 1)

print("Ground Truth / Predicted")
for i in range(len(example_query_labels)):
    print(
        f"{test_set._characters[example_class_ids[example_query_labels[i]]]} / {test_set._characters[example_class_ids[example_predicted_labels[i]]]}"
    )
"""Keep in mind that the model was trained on very different images (ImageNet), and has only seen 5 examples for each class!"""
def evaluate_on_one_task(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
) -> tuple[int, int]:
    """
    Returns the number of correct predictions of query labels, and the total number of predictions.
    """
    return (
        torch.max(
            model(support_images.cuda(), support_labels.cuda(), query_images.cuda())
            .detach()
            .data,
            1,
        )[1]
        == query_labels.cuda()
    ).sum().item(), len(query_labels)


def evaluate(data_loader: DataLoader):
    # We'll count everything and compute the ratio at the end
    total_predictions = 0
    correct_predictions = 0

    # eval mode affects the behaviour of some layers (such as batch normalization or dropout)
    # no_grad() tells torch not to keep in memory the whole computational graph (it's more lightweight this way)
    model.eval()
    with torch.no_grad():
        for episode_index, (
            support_images,
            support_labels,
            query_images,
            query_labels,
            class_ids,
        ) in tqdm(enumerate(data_loader), total=len(data_loader)):
            correct, total = evaluate_on_one_task(
                support_images, support_labels, query_images, query_labels
            )

            total_predictions += total
            correct_predictions += correct

    print(
        f"Model tested on {len(data_loader)} tasks. Accuracy: {(100 * correct_predictions/total_predictions):.2f}%"
    )


evaluate(test_loader)
"""With absolutely zero training on Omniglot images, and only 5 examples per class, we achieve around 86% accuracy!
## Training a meta-learning algorithm
Let's use the "background" images of Omniglot as training set. Here we prepare a data loader of 40,000 few-shot classification
tasks on which we will train our model. The alphabets used in the training set are entirely separated from those used in the testing set.
This guarantees that at test time, the model will have to classify characters that were not seen during training.
"""
N_TRAINING_EPISODES = 40000
N_VALIDATION_TASKS = 100
train_set.get_labels = lambda: [instance[1] for instance in train_set._flat_character_images]
train_sampler = TaskSampler(
    train_set, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_TRAINING_EPISODES
)
train_loader = DataLoader(
    train_set,
    batch_sampler=train_sampler,
    num_workers=12,
    pin_memory=True,
    collate_fn=train_sampler.episodic_collate_fn,
)
"""We will keep the same model. So our weights will be pre-trained on ImageNet. If we want to start a training from scratch, we can set `pretrained=False` in the definition of the ResNet.
Here we define our loss and our optimizer, and a `fit` function.
This function takes a classification task as input (support set and query set). It predicts the labels of the query set
based on the information from the support set; then it compares the predicted labels to ground truth query labels,
and this gives us a loss value. Then it uses this loss to update the parameters of the model. This is a *meta-training loop*.
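To see where the loss value comes from, here is a small sketch of cross-entropy applied to negative distances for a single query, in plain Python. The distances and the function name `cross_entropy_from_distances` are made up for illustration; in the notebook this computation is done by `nn.CrossEntropyLoss`, which by default also averages the result over all queries of the task.

```python
import math


def cross_entropy_from_distances(distances, true_label):
    # Classification scores are the negative distances to each prototype.
    scores = [-d for d in distances]
    # Numerically stable log-softmax: subtract the max score first.
    max_score = max(scores)
    log_norm = max_score + math.log(sum(math.exp(s - max_score) for s in scores))
    log_probs = [s - log_norm for s in scores]
    # Cross-entropy for one query = negative log-probability of the true class.
    return -log_probs[true_label]


# Hypothetical distances from one query to 3 prototypes.
distances = [0.5, 2.0, 3.0]
loss_if_nearest_is_correct = cross_entropy_from_distances(distances, true_label=0)
loss_if_farthest_is_correct = cross_entropy_from_distances(distances, true_label=2)
print(loss_if_nearest_is_correct, loss_if_farthest_is_correct)
```

The loss is small when the true class has the nearest prototype and grows as the true prototype gets farther away, so minimizing it pulls query embeddings toward the prototype of their own class.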
"""
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
def fit(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
) -> float:
    optimizer.zero_grad()
    classification_scores = model(
        support_images.cuda(), support_labels.cuda(), query_images.cuda()
    )

    loss = criterion(classification_scores, query_labels.cuda())
    loss.backward()
    optimizer.step()

    return loss.item()


"""To train the model, we are just going to iterate over a large number of randomly generated few-shot classification tasks,
and let the `fit` function update our model after each task. This is called **episodic training**.
"""
log_update_frequency = 10
all_loss = []
model.train()
with tqdm(enumerate(train_loader), total=len(train_loader)) as tqdm_train:
    for episode_index, (
        support_images,
        support_labels,
        query_images,
        query_labels,
        _,
    ) in tqdm_train:
        loss_value = fit(support_images, support_labels, query_images, query_labels)
        all_loss.append(loss_value)

        if episode_index % log_update_frequency == 0:
            tqdm_train.set_postfix(loss=sliding_average(all_loss, log_update_frequency))

# Here is an already trained model: download its weights and load them into the model
!wget https://public-sicara.s3.eu-central-1.amazonaws.com/easy-fsl/resnet18_with_pretraining.tar
model.load_state_dict(torch.load("resnet18_with_pretraining.tar", map_location="cuda"))
evaluate(test_loader)
"""Around 98%!
It's not surprising that the model performs better after being further trained on Omniglot images than it was with its
ImageNet-based parameters. However, we have to keep in mind that the classes on which we just evaluated our model were still
**not seen during training**, so 98% (a 12-point improvement over the model trained only on ImageNet) seems like a decent performance.
"""