# Using the `fastai` library for Inference (aka "Production") 

This notebook is the final step of a small segmentation tutorial that started [here](https://www.kaggle.com/code/purplejester/fast-ai-01-basic-segmentation-model-training). We load the trained segmentation model and use it to predict segmentation masks for new images.

Since medical imaging is a rather specific domain, we take a small shortcut here: instead of using brand-new images, we reuse the same data that was used for training. Also, we do our "production" run on Kaggle, which is quite different from running a dedicated service somewhere in the cloud (so take the title of this notebook with a grain of salt). However, the logic shouldn't be very different. It is all about loading the pretrained model and running it on images that are read from persistent storage, posted via an HTTP request, or delivered in some similar way.
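In a service, "posted via an HTTP request" usually means the image arrives as raw bytes rather than as a file path. One option is a single entry point that accepts either form. The following is a minimal sketch using Pillow and NumPy; `load_scan` and `ImageSource` are hypothetical names, not part of fastai, and any preprocessing the model expects would still have to be applied afterwards.

```python
import io
from pathlib import Path
from typing import Union

import numpy as np
from PIL import Image

ImageSource = Union[str, Path, bytes]  # hypothetical: a path on disk or an uploaded body


def load_scan(source: ImageSource) -> np.ndarray:
    """Decode an image from a file path or from raw bytes into a NumPy array."""
    if isinstance(source, bytes):
        source = io.BytesIO(source)  # Image.open also accepts file-like objects
    with Image.open(source) as img:
        return np.array(img)  # copies the pixels before the file is closed


# Usage: an in-memory PNG stands in for the body of an upload request.
demo = np.arange(16, dtype=np.uint8).reshape(4, 4)
buffer = io.BytesIO()
Image.fromarray(demo).save(buffer, format="PNG")
decoded = load_scan(buffer.getvalue())
```

Keeping the transport (disk, HTTP, queue) separate from decoding means the rest of the inference code does not need to know where an image came from.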

> **Note:** if you follow the [fast.ai course lectures](https://course.fast.ai), you'll find more realistic examples of model deployment, such as serving a model via Gradio, Streamlit, or a dedicated website.

# Import

```python
import logging
from pathlib import Path
from fastai.vision.all import *

logging.captureWarnings(True)
```

```python
DATA_DIR = Path("/kaggle/input/uw-madison-gi-tract-image-segmentation/")
```

# Loading a Pretrained Model

## Training Code
The model loaded here was trained on a different machine and uploaded to Kaggle. The following snippet shows what the training code looks like (nothing fancy!).

```python
import logging
from pathlib import Path
from fastai.vision.all import *
from nbs.fast_ai_utils import get_dataset_size, COMBINED_MASK_CODES, OUTPUT_DIR

DATASET_SIZE = get_dataset_size(debug=False)

logging.captureWarnings(True)


def get_items(source_dir: Path):
    return get_image_files(source_dir.joinpath("images"))[:DATASET_SIZE]


def get_y(fn: Path):
    return fn.parent.parent.joinpath("masks").joinpath(f"{fn.stem}.png")


def train():
    seg = DataBlock(blocks=(ImageBlock, MaskBlock(COMBINED_MASK_CODES)),
                    get_items=get_items,
                    get_y=get_y,
                    splitter=RandomSplitter(),
                    item_tfms=[Resize(192, method="squash")])

    dls = seg.dataloaders(OUTPUT_DIR, bs=50)
    learn = unet_learner(dls, resnet18, metrics=DiceMulti)
    learn.fine_tune(20)
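    learn.save("unet_resnet18_e20")
    # save() stores only the weights (and optimizer state) under models/;
    # load_learner() needs the pickled Learner written by export(),
    # which also keeps the data transformations
    learn.export("unet_resnet18_e20.pth")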


if __name__ == '__main__':
    train()
```

## Restoring the Exported Model

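Note that a `Learner` exported with `learn.export()` is saved together with its data processing pipeline, i.e. the transforms and the custom functions passed to the `DataBlock`. These functions are stored by reference (module and name), not by their code. Therefore, the inference code must define the same functions that were used during training/validation (here `get_items` and `get_y`) under the same names. Otherwise, reading the pretrained model (deserialization) fails.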

Also, I had to patch `PIL.Image`, because loading the model failed with a missing `PIL.Image.Resampling` attribute. This enum was only added in Pillow 9.1.0, so the likely cause is that the model was exported with a newer Pillow version than the one installed on Kaggle. This is something to be aware of when you deploy a model: mismatched library versions can lead to crashes or degraded performance. I remember a case where I used different versions of the same vision library for training and inference and got different results, even though the model checkpoints were the same.
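One way to catch such mismatches early is to record the installed versions of key packages at training time (for example as a JSON file next to the exported model) and compare them at inference time. A minimal sketch using only `importlib.metadata`; `snapshot_versions` and `version_mismatches` are hypothetical helpers, and the version numbers in the example are made up for illustration.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional, Tuple

Versions = Dict[str, Optional[str]]


def snapshot_versions(packages: Iterable[str]) -> Versions:
    """Installed version of each package, or None if it is not installed."""
    versions: Versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def version_mismatches(
    train: Versions, infer: Versions
) -> Dict[str, Tuple[Optional[str], Optional[str]]]:
    """Packages whose versions differ between training and inference."""
    names = sorted(set(train) | set(infer))
    return {
        name: (train.get(name), infer.get(name))
        for name in names
        if train.get(name) != infer.get(name)
    }


# Illustrative, made-up snapshots (not the versions used in this tutorial).
at_training = {"fastai": "2.7.9", "Pillow": "9.2.0"}
at_inference = {"fastai": "2.7.9", "Pillow": "9.0.1"}
print(version_mismatches(at_training, at_inference))
# {'Pillow': ('9.2.0', '9.0.1')}
```

In a real setup, `snapshot_versions(["fastai", "torch", "Pillow"])` would be called once in each environment instead of using hard-coded dictionaries.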

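```python
import PIL.Image
from enum import IntEnum

# Older Pillow releases lack the PIL.Image.Resampling enum (added in 9.1.0),
# which the exported Learner refers to. Add a stand-in with Pillow's values.
class Resampling(IntEnum):
    NEAREST = 0
    BOX = 4
    BILINEAR = 2
    HAMMING = 5
    BICUBIC = 3
    LANCZOS = 1

if not hasattr(PIL.Image, "Resampling"):
    PIL.Image.Resampling = Resampling

# Same names as in the training script: unpickling looks these functions up by name.
def get_items(source_dir: Path): return get_image_files(source_dir.joinpath("images"))

def get_y(fn: Path): return fn.parent.parent.joinpath("masks").joinpath(f"{fn.stem}.png")

learn = load_learner("/kaggle/input/uwm-models/unet_resnet18_e20.pth")
```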

# Inference

OK, the model loaded successfully. Let's try to run it on some images.

```python
png_files = get_image_files(DATA_DIR)

# Pick 20 random scans to run through the model
some_files = np.random.choice(png_files, size=20, replace=False)

with learn.no_bar():
    predicted = []
    for fn in some_files:
        # predict() returns a tuple; the first item is the decoded mask
        mask, *_ = learn.predict(fn)
        predicted.append(mask.numpy())
```

`Learner.predict` [returns a tuple of three items](https://docs.fast.ai/learner.html#Learner.predict) (the fully decoded prediction, the decoded prediction, and the raw per-class predictions), but we need only the first one, which contains the predicted segmentation mask. The following code renders the images overlaid with the predicted masks to see what the model found.

```python
from skimage import color
from skimage import exposure
from skimage import transform

def mask_overlay(image_file: str, mask_file: str, mask_alpha: float = 0.5) -> np.ndarray:
    """Overlay a mask stored on disk on top of an image stored on disk."""
    img, seg = [np.asarray(PIL.Image.open(fn)) for fn in (image_file, mask_file)]
    img = np.amax(img) - img  # invert intensities
    return color.label2rgb(seg, img, kind="overlay", alpha=mask_alpha)

def equalize(data: np.ndarray, adaptive: bool) -> np.ndarray:
    """Histogram equalization to normalize images before previewing."""
    data = data - np.min(data)
    data = data / max(np.max(data), 1)  # avoid dividing by zero on blank slices
    data = (data * 255).astype(np.uint8)
    method = exposure.equalize_adapthist if adaptive else exposure.equalize_hist
    return method(data)

figsize_mult = 4
n_rows, n_cols = 5, 4

# matplotlib's figsize is (width, height), so the column count sets the width
fig, axes = plt.subplots(
    n_rows, n_cols,
    figsize=(n_cols * figsize_mult, n_rows * figsize_mult)
)

for fn, mask, ax in zip(some_files, predicted, axes.flat):
    img = np.asarray(PIL.Image.open(fn))
    img = equalize(img, True)                # improve contrast for display
    img = transform.resize(img, mask.shape)  # match the predicted mask's size
    img = np.amax(img) - img                 # invert intensities
    overlay = color.label2rgb(mask, img, kind="overlay", alpha=0.5)
    overlay = transform.resize(overlay, (128, 128))
    ax.imshow(overlay)
    ax.set_axis_off()
```

Many images don't have any predicted mask at all. That is not necessarily wrong, since many slices in this dataset have empty masks in the training annotations as well. The predictions that do appear look reasonable: the masks lie within the body scans, and their shapes resemble the original masks. Of course, the model is far from perfect, but it could be a good start!
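Instead of judging by eye how many predictions are empty, a small helper can count them. A minimal sketch, assuming label `0` is the background class (the actual labels come from `COMBINED_MASK_CODES`, which is not shown here); `summarize_masks` is a hypothetical helper, and in the notebook it would be called as `summarize_masks(predicted)`.

```python
from typing import Dict, List

import numpy as np


def summarize_masks(masks: List[np.ndarray], background: int = 0) -> Dict[str, float]:
    """Count empty predictions and the average share of non-background pixels."""
    if not masks:
        return {"n_masks": 0, "n_empty": 0, "mean_foreground": 0.0}
    # Fraction of pixels in each mask that carry a non-background label
    foreground = [float(np.mean(m != background)) for m in masks]
    return {
        "n_masks": len(masks),
        "n_empty": sum(f == 0.0 for f in foreground),
        "mean_foreground": float(np.mean(foreground)),
    }


# Toy example: one empty mask and one mask with 2 of 4 pixels labeled.
toy = [np.zeros((4, 4), dtype=np.uint8), np.array([[0, 1], [2, 0]], dtype=np.uint8)]
print(summarize_masks(toy))
# {'n_masks': 2, 'n_empty': 1, 'mean_foreground': 0.25}
```

Running the same summary on the ground-truth masks would show whether the share of empty predictions roughly matches the share of empty annotations.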

# Conclusion

I hope this and the previous notebooks gave you a short overview of the `fastai` library and its capabilities. My knowledge of its functionality is still quite basic, but this little series of experiments left me with a good impression of the tool. Compared to some other frameworks, it lets you jump into modelling quickly and get great results out of the box.

Good luck with your data projects, and see you on the [forums](https://forums.fast.ai)!