<a href="https://colab.research.google.com/github/s-a-malik/multi-few/blob/main/notebooks/FewShotImageClassificationDemo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Improving Few-Shot Learning using Task-Informed Meta-Initialisation

Authors: Matthew Jackson*, Shreshth Malik*, Michael Matthews, and Yousuf Mohamed-Ahmed.

Paper: (link)

This notebook explores the iNat-Anim dataset.

Let us begin by cloning the repository.

In [1]:
!rm -rf multi-few
!git clone https://github.com/s-a-malik/multi-few.git
!pip install -r multi-few/requirements.txt

Cloning into 'multi-few'...
remote: Enumerating objects: 1430, done.
remote: Counting objects: 100% (176/176), done.
remote: Compressing objects: 100% (104/104), done.
remote: Total 1430 (delta 129), reused 110 (delta 72), pack-reused 1254
Receiving objects: 100% (1430/1430), 2.83 MiB | 21.13 MiB/s, done.
Resolving deltas: 100% (891/891), done.
Obtaining file:///content/multi-few
Collecting tqdm==4.60.0
  Downloading https://files.pythonhosted.org/packages/72/8a/34efae5cf9924328a8f34eeb2fdaae14c011462d9f0e3fcded48e1266d1c/tqdm-4.60.0-py2.py3-none-any.whl (75kB)
     |████████████████████████████████| 81kB 2.8MB/s
Collecting numpy==1.20.2
  Downloading https://files.pythonhosted.org/packages/73/ef/8967d406f3f85018ceb5efab50431e901683188f1741ceb053efcab26c87/numpy-1.20.2-cp37-cp37m-manylinux2010_x86_64.whl (15.3MB)
     |████████████████████████████████| 15.3MB 409kB/s
Collecting torchmeta==1.7.0
  Downloading https://files.pythonhosted.org/pac

Run the following cell to download the dataset and some pre-computed embeddings used during inference.

The download should take a few minutes to complete.

In [12]:
# import gdown
# links = ["https://drive.google.com/uc?id=1cT6klPSkCY3tnhXtndmo-4Z9P85kg6j9", 
#          "https://drive.google.com/uc?id=1tzZqrGZSK_e8vJfK6yRe4TvOeJ8uyP0e", 
#          "https://drive.google.com/uc?id=1Ga68-VUt8wS8_P-xIM-csOYdasAmxo3n",
#          "https://drive.google.com/uc?id=1xPzSyyUoqtCpVAL8L7JD2tkhZhGf_PpG",
#          "https://drive.google.com/uc?id=1tPYYWJdz5rEEvcOJblni-JWng-XmFR7D"]
# outputs = ["images.hdf5", "train.json", "image-embedding-resnet-152.hdf5", "am3.pth.tar", "fumi.pth.tar"]
# for (l,o) in zip(links, outputs):
#     gdown.download(l, o, quiet=False)

# download the dataset zip from zenodo
!wget -O ./data/inat-anim url

Downloading...
From: https://drive.google.com/uc?id=1cT6klPSkCY3tnhXtndmo-4Z9P85kg6j9
To: /content/images.hdf5
9.61GB [01:20, 119MB/s]
Downloading...
From: https://drive.google.com/uc?id=1tzZqrGZSK_e8vJfK6yRe4TvOeJ8uyP0e
To: /content/train.json
77.8MB [00:00, 198MB/s]
Downloading...
From: https://drive.google.com/uc?id=1Ga68-VUt8wS8_P-xIM-csOYdasAmxo3n
To: /content/image-embedding-resnet-152.hdf5
1.60GB [00:08, 197MB/s]
Downloading...
From: https://drive.google.com/uc?id=1xPzSyyUoqtCpVAL8L7JD2tkhZhGf_PpG
To: /content/am3.pth.tar
27.0MB [00:00, 189MB/s]
Downloading...
From: https://drive.google.com/uc?id=1tPYYWJdz5rEEvcOJblni-JWng-XmFR7D
To: /content/fumi.pth.tar
8.92MB [00:00, 95.4MB/s]
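
Before moving on, it can help to confirm that the downloads finished. The sketch below is not part of the original notebook: it assumes the five files named in the download cell were saved to the current working directory, and the helper `missing_files` is a hypothetical name introduced here for illustration.

```python
from pathlib import Path

# File names taken from the download cell above.
EXPECTED_FILES = [
    "images.hdf5",
    "train.json",
    "image-embedding-resnet-152.hdf5",
    "am3.pth.tar",
    "fumi.pth.tar",
]


def missing_files(directory=".", names=EXPECTED_FILES):
    """Return the expected file names that are absent from `directory`."""
    root = Path(directory)
    return [name for name in names if not (root / name).is_file()]


# Assumption: the files were downloaded to the current working directory.
missing = missing_files(".")
if missing:
    print("Missing files:", ", ".join(missing))
else:
    print("All expected files are present.")
```

If anything is reported missing, re-running the download cell is the simplest fix. Note that this only checks that the files exist, not that they are complete.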


# Explore the Dataset

The following cell launches a UI for exploring the dataset. The description of the selected animal is shown in the text area on the left.

Type into the green text box to display images of a particular animal in the dataset (or click the 'Random animal' button).

Flick through the gallery with the 'Next' and 'Back' buttons.

In [1]:
from __future__ import print_function
import ipywidgets as widgets
from IPython.display import display
import time
import random
import json
import h5py

import cv2
import numpy as np


class DemoDataParser():
    def __init__(self):
        json_path = "./train.json"
        with open(json_path) as f:
            annotations = json.load(f)
        N = len(annotations['categories'])
        M = len(annotations['images'])
        self.common_names = set(
            [annotations['categories'][i]['common_name'] for i in range(N)])

        cname_category_index_map = {}
        for i in range(N):
            cname_category_index_map[annotations['categories'][i]
                                     ['common_name']] = i

        self.cname_category_index_map = cname_category_index_map

        cname_image_index_map = {}
        for c in self.common_names:
            cname_image_index_map[c] = []
        for i in range(M):
            cname = annotations['categories'][annotations['annotations'][i]
                                              ['category_id']]['common_name']
            cname_image_index_map[cname].append(i)
        self.cname_image_index_map = cname_image_index_map

        cname_description_map = {}
        for i in range(N):
            cname = annotations['categories'][i]['common_name']
            cname_description_map[cname] = annotations['categories'][i][
                'description']
        self.cname_description_map = cname_description_map

        h5_file = h5py.File("/content/images.hdf5", 'r')
        self.images = h5_file['images']
        self.annotations = annotations


class DatasetExplorer():
    def __init__(self, data: DemoDataParser):
        # todo: add species name
        self.data = data
        self.cnames_list = list(self.data.common_names)
        default_cname = random.choice(self.cnames_list)
        self.common_name = widgets.Combobox(options=list(
            self.data.common_names),
                                            value=default_cname,
                                            font_size="20px",
                                            layout=widgets.Layout(width='50%'),
                                            ensure_option=True)
        self.common_name.add_class('data_input')
        self.random_species_button = widgets.Button(
            description="Random animal",
            icon='fa-dice',
            button_style='',
            layout=widgets.Layout(width='50%'))
        common_name_box = widgets.HBox(
            [self.common_name, self.random_species_button])
        self.description = widgets.Textarea(layout=widgets.Layout(
            width='100%', font_size="20px", height='100%'))

        data_input_style = '''<style>
        .mytext .fa, .far, .fas {
            font-style: italic;
            color: blue;
            font-size: 100px;
        }
        .data_input input { background-color:#bede68 !important; font-size: 5; }
        .data_input text { background-color:#bede68 !important; font-size: 5; }</style>'''
        self.image = widgets.Image(width="100%", height="100%", format='raw')

        ui = widgets.VBox([
            widgets.HTML(data_input_style), common_name_box, self.description
        ])
        self.back_button = widgets.Button(description="Back",
                                          layout=widgets.Layout(width='50%'),
                                          disabled=True)
        self.next_button = widgets.Button(description="Next",
                                          layout=widgets.Layout(width='50%'),
                                          disabled=False)

        self.random_species_button.on_click(self.on_random_click)
        self.next_button.on_click(self.on_next)
        self.back_button.on_click(self.on_back)
        buttons = widgets.HBox([self.back_button, self.next_button])
        ui = widgets.HBox([ui, widgets.VBox([self.image, buttons])])

        self.output = widgets.Output()
        self.row, self.col = 4, 6
        self.base = 0
        out = widgets.interactive_output(self.explore,
                                         {'common_name': self.common_name})
        display(ui, out, self.output)

    def on_random_click(self, b):
        self.base = 0
        self.back_button.disabled = True
        self.next_button.disabled = False
        with self.output:
            b.disabled = True
            self.common_name.value = random.choice(self.cnames_list)
            self.explore(self.common_name.value)
            b.disabled = False

    def on_next(self, b):
        self.base += self.row * self.col
        self.back_button.disabled = False
        self.update_gallery()

    def on_back(self, b):
        self.next_button.disabled = False
        self.base -= (self.row * self.col)
        self.update_gallery()

    def update_gallery(self):
        indxs = self.data.cname_image_index_map[self.common_name.value]
        frames = [
            np.hstack(
                self.data.images[indxs[self.base +
                                       start:min(self.base + start +
                                                 self.col, len(indxs))]])
            for start in range(0, self.row * self.col, self.col)
        ]
        frame = np.vstack(frames)
        if self.base == 0:
            self.back_button.disabled = True
        if self.base > len(indxs) - (2 * self.row * self.col):
            self.next_button.disabled = True

        frame = np.flip(frame, 2)
        _, im_buf_arr = cv2.imencode(".jpg", frame)
        byte_im = im_buf_arr.tobytes()
        self.image.value = byte_im

    def explore(self, common_name):
        # Show the first page of images and the description for `common_name`.
        self.base = 0
        indxs = self.data.cname_image_index_map[common_name]
        # Build each gallery row by stacking up to `self.col` images side by side.
        frames = [
            np.hstack(self.data.images[indxs[start:min(start +
                                                       self.col, len(indxs))]])
            for start in range(0, self.row * self.col, self.col)
        ]
        frame = np.vstack(frames)
        # Reverse the channel order, since cv2.imencode expects BGR.
        frame = np.flip(frame, 2)
        _, im_buf_arr = cv2.imencode(".jpg", frame)
        byte_im = im_buf_arr.tobytes()
        self.image.value = byte_im
        self.description.value = self.data.cname_description_map[common_name]

data = DemoDataParser()
DatasetExplorer(data)

HBox(children=(VBox(children=(HTML(value='<style>\n        .mytext .fa, .far, .fas {\n            font-style: …

Output()

Output()

<multifew.demo.dataset_explorer.DatasetExplorer at 0x7f9aba0e0a90>
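
The gallery above shows 4 rows of 6 images per page and moves the page offset by 24 when 'Next' or 'Back' is clicked. The following is a minimal sketch of that paging arithmetic on plain lists, without any widgets or HDF5 access. The functions `page_rows` and `has_next_page` are hypothetical helpers, not part of the repository, and they simplify the explorer's behaviour: empty rows are dropped, and 'next' is allowed whenever any image remains beyond the current page.

```python
def page_rows(indices, base, rows=4, cols=6):
    """Split one gallery page of `indices` into rows of at most `cols` items.

    `base` is the offset of the first image on the page. Rows that would be
    empty are dropped, so a short final page yields fewer rows.
    """
    page = indices[base:base + rows * cols]
    return [page[i:i + cols] for i in range(0, len(page), cols)]


def has_next_page(indices, base, rows=4, cols=6):
    """Return True if at least one image lies beyond the current page."""
    return base + rows * cols < len(indices)


# Hypothetical image indices for a single animal.
image_ids = list(range(30))
print(page_rows(image_ids, base=0)[0])    # first row of the first page
print(page_rows(image_ids, base=24))      # partial final page
print(has_next_page(image_ids, base=24))  # whether 'Next' should stay enabled
```

Working with rows of indices first makes it easier to handle animals with fewer than 24 images, where the final row may be shorter than the others and would need padding before the rows can be stacked into one frame.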

You can also look up an individual example by its image index, e.g. when analysing results.

In [None]:

from __future__ import print_function
import ipywidgets as widgets
from IPython.display import display
import time
import random
import json
import h5py

import numpy as np
import cv2

json_path = "./train.json"
with open(json_path) as f:
    annotations = json.load(f)
N = len(annotations['categories'])
M = len(annotations['images'])
category = [annotations['annotations'][i]['category_id'] for i in range(M)]
common_names = set([annotations['categories'][i]['common_name'] for i in category])
file_names = [annotations['images'][i]['file_name'] for i in range(M)]
species = [annotations['categories'][i]['name'] for i in category]
descriptions = [annotations['categories'][i]['description'] for i in category]

cnames_list = list(common_names)
common_name=widgets.Combobox(options=[str(i) for i in range(M)], value="0", font_size="20px", layout=widgets.Layout(width='20%'), ensure_option=True)
common_name.add_class('data_input')
description = widgets.Textarea(layout=widgets.Layout(width='20%', font_size="20px", height='3000px'))
description.add_class("data_input")


ui = widgets.VBox([common_name, description])
output = widgets.Output()
def explore(index):
    index = int(index)
    description.value = "Species: %s \nDescription: %s \nFile path: %s" % (species[index], descriptions[index], file_names[index])
out = widgets.interactive_output(explore,{'index':common_name})
display(ui,out, output)
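
The lookup above relies on the layout of `train.json`: each entry in `annotations` carries a `category_id` that indexes into `categories`, and `images` is aligned with `annotations` by position. The sketch below reproduces the same lookup on a small in-memory dictionary so it can be run without the dataset. The toy values and the helper `describe_image` are made up for illustration. Only the key names are taken from the cells above.

```python
# Toy annotations using the same keys the cells above read from train.json.
# All values here are invented for illustration.
toy_annotations = {
    "categories": [
        {"name": "Vulpes vulpes", "common_name": "Red Fox",
         "description": "A small canid."},
        {"name": "Bubo bubo", "common_name": "Eurasian Eagle-Owl",
         "description": "A large owl."},
    ],
    "images": [
        {"file_name": "a.jpg"},
        {"file_name": "b.jpg"},
        {"file_name": "c.jpg"},
    ],
    "annotations": [
        {"category_id": 0},
        {"category_id": 1},
        {"category_id": 0},
    ],
}


def describe_image(annotations, index):
    """Return species, description and file name for the image at `index`.

    Assumes, as the cells above do, that `category_id` is a position in
    `annotations['categories']` and that `images` and `annotations` are
    aligned by position.
    """
    category_id = annotations["annotations"][index]["category_id"]
    category = annotations["categories"][category_id]
    return {
        "species": category["name"],
        "description": category["description"],
        "file_name": annotations["images"][index]["file_name"],
    }


print(describe_image(toy_annotations, 2))
```

Resolving the category at lookup time, rather than precomputing one list per field for every image as the cell above does, keeps memory use low when the annotation file is large.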