# Use Case #1

## Finding related images without fine-tuning, on toy data

The notebook below builds an MVP for this simple use case.

### User Story

The user provides a query image and the number of similar images to retrieve from the Imagenette dataset.

The system returns the requested number of images, ordered from most to least similar to the query.

## TODO

0. Download the data with `!python '../src/data/get_imagenette.py'`; it is stored at `../data/raw/imagenette-160`.
1. Create a databunch (with `shuffle=False` for easy indexing later on).
2. Create a `fastai` learner based on a pre-trained ResNet-18 (small enough to run on a laptop).
3. Create a hook callback, pass it the model modules whose activations should be collected, and create a partial for learner creation.
4. Register the callback with the learner by hand.
5. Use `.get_preds()` to collect activations for the dataset.
6. Use `.predict()` on the target image to collect its activations.
7. Calculate distances from the dataset activations to the target image's activations.
8. Get the indices of the closest activations.
9. Use the indices to retrieve the closest images.
10. Plot the images.
11. PROFIT!!!
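
Steps 7 to 9 can be sketched without any framework. The snippet below is a minimal illustration on made-up 3-dimensional vectors, not the notebook's implementation: `squared_distance` and `closest_indices` are hypothetical helper names, and squared Euclidean distance is assumed as the metric.

```python
import heapq

def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def closest_indices(query, dataset, k):
    """Indices of the k dataset vectors nearest to `query`, nearest first."""
    distances = [squared_distance(query, vec) for vec in dataset]
    # nsmallest returns the k indices with the smallest distance, sorted ascending
    return heapq.nsmallest(k, range(len(dataset)), key=distances.__getitem__)

# Toy "activations": one vector per image, stored in dataset order.
dataset_activations = [
    [0.0, 0.0, 0.0],
    [1.0, 1.0, 1.0],
    [0.1, 0.0, 0.1],
    [5.0, 5.0, 5.0],
]
query_activation = [0.0, 0.1, 0.0]

nearest = closest_indices(query_activation, dataset_activations, k=2)
print(nearest)  # [0, 2]
```

The returned indices only point at the right images if the activations were collected in the same order as the dataset, which is why step 1 turns shuffling off.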

## Imports

In [None]:
from fastai.vision import *
from fastai.metrics import accuracy
from fastai.callbacks.hooks import *

DATA_PATH = '../data/raw/imagenette-160'
GET_DATA_PATH = '../src/data/get_imagenette.py'

%load_ext autoreload
%autoreload 2

In [None]:
# !python '../src/usecases/mvp.py'

## Getting the Data

In [None]:
!ls ../src/data/

In [None]:
# # `$` passes a Python variable to a Jupyter shell command
# !python $GET_DATA_PATH

In [None]:
!ls $DATA_PATH

## Creating the DataBunch Step by Step

In [None]:
def make_data(data_path:PathOrStr, bs:int=16, img_size:int=160,
              pct_partial:float=1.0, num_workers:int=0)->ImageDataBunch:
    return (ImageList.from_folder(data_path)            # -> ImageList
            .use_partial_data(pct_partial, seed=42)     # -> ImageList
            .split_none()                               # -> ItemLists: train and valid ItemList
            .label_from_folder()                        # -> LabelLists: train and valid LabelList
            .transform(size=img_size)                   # -> LabelLists: train and valid LabelList
            .databunch(bs=bs, num_workers=num_workers)  # -> ImageDataBunch
            .normalize(imagenet_stats))                 # -> ImageDataBunch

data = make_data(DATA_PATH)

In [None]:
bs = 16
size = 160

data = (ImageList.from_folder(DATA_PATH)  # -> ImageList
        .use_partial_data(0.01, seed=42)  # -> ImageList
        .split_none()                     # -> ItemLists: train and valid ItemList
        .label_from_folder()              # -> LabelLists: train and valid LabelList
        .transform(size=size)             # -> LabelLists: train and valid LabelList
        .databunch(bs=bs, num_workers=0)  # -> ImageDataBunch
        .normalize(imagenet_stats))       # -> ImageDataBunch

# note that we need to make sure that data in the dataloader is not shuffled
# solution for turning off shuffle in data block API is here:
# https://forums.fast.ai/t/how-can-i-turn-off-shuffle-in-the-data-block/33942/2?u=maxim.pechyonkin
data.train_dl = data.train_dl.new(shuffle=False)

data.show_batch(rows=4, figsize=(8,8))

In [None]:
print(data.train_ds)
print('-'*42)
print(data.valid_ds)

# Creating the Learner and Collecting Activations
A forward hook is registered on one layer of the learner's model, which gives access to that layer's activations on every forward pass.
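
As a rough illustration of the mechanism, here is a sketch of a PyTorch forward hook on a tiny stand-in network. The model, the hooked layer, and the names `captured` and `store_output` are assumptions made for this example; the notebook itself hooks a layer of the ResNet-18 learner.

```python
import torch
import torch.nn as nn

# Hypothetical toy model standing in for the ResNet-18 used in the notebook.
model = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Linear(8, 2),
)

captured = []  # filled by the hook, one tensor per forward pass

def store_output(module, inputs, output):
    # Detach so the stored tensor does not keep the autograd graph alive.
    captured.append(output.detach())

# Hook the ReLU layer; the returned handle lets us remove the hook later.
handle = model[1].register_forward_hook(store_output)

model.eval()
with torch.no_grad():
    logits = model(torch.randn(3, 4))  # batch of 3 inputs

handle.remove()  # later forward passes are no longer recorded

print(captured[0].shape)  # activations of the hooked layer
print(logits.shape)       # final model output
```

Keeping the handle and calling `remove()` avoids collecting activations from unrelated forward passes later in the session.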

## Creating the Learner
## Registering Callback by Hand

In [None]:
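activations = []

def printer(module, i, o):
    # forward hook: store the output of the hooked module for every batch
    activations.append(o)

# cnn_learner loads ImageNet-pretrained weights by default; no training is done here
learner = cnn_learner(data, models.resnet18)

print("training:", learner.model.training)

# second-to-last module of the flattened model
last_layer = flatten_model(learner.model)[-2]
last_layer.register_forward_hook(printer)

# learner.fit_one_cycle(1)
# DatasetType.Fix iterates over the training set in order, without shuffling
preds = learner.get_preds(ds_type=DatasetType.Fix)

print("training:", learner.model.training)

# concatenate the per-batch activations into one tensor, one row per image
data_activations = torch.cat(activations)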

In [None]:
query_img = data.train_ds[0][0]
query_img

In [None]:
# reset the list so that the hook output for the query image lands at index 0
activations = []
learner.predict(query_img)

In [None]:
query_act = activations[0]

In [None]:
query_act.shape

In [None]:
data_activations.shape

In [None]:
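# squared Euclidean distance from the query to every dataset activation;
# argsort returns the dataset indices ordered from closest to farthest
closest_idxs = (data_activations - query_act).pow(2).sum(dim=1).argsort()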
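
The user story asks for a specific number of images, so a full sort is not strictly needed. One possible variant, sketched here on toy tensors, uses `torch.topk` with `largest=False` to get only the `k` nearest rows. The helper name `top_k_closest` and the toy data are assumptions for illustration and are not part of the notebook.

```python
import torch

def top_k_closest(query_act, data_acts, k):
    """Return (indices, squared distances) of the k rows of `data_acts` nearest to `query_act`."""
    # data_acts has shape (n, d); query_act has shape (1, d) or (d,).
    # Broadcasting subtracts the query from every row.
    sq_dists = (data_acts - query_act).pow(2).sum(dim=1)
    # largest=False selects the smallest distances, returned in ascending order.
    dists, idxs = torch.topk(sq_dists, k, largest=False)
    return idxs, dists

toy_acts = torch.tensor([[0., 0.], [3., 4.], [1., 0.], [10., 10.]])
toy_query = toy_acts[2:3]  # the query is itself part of the dataset

idxs, dists = top_k_closest(toy_query, toy_acts, k=3)
print(idxs.tolist())   # [2, 0, 1]
print(dists.tolist())  # [0.0, 1.0, 20.0]
```

When the query image comes from the dataset, as in this toy example, it ranks first with distance zero, so the first result can be skipped if only other images are wanted.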

In [None]:
sorted_ds = data.train_ds[closest_idxs]

In [None]:
sorted_ds[0][0]

In [None]:
query_img

In [None]:
sorted_ds[1][0]

In [None]:
sorted_ds[2][0]

In [None]:
sorted_ds[3][0]

In [None]:
sorted_ds[4][0]