# OOD enablement with finetuned resnet18 model demo

This demo uses a PyTorch ResNet-18 model fine-tuned on the hymenoptera dataset, a subset of ImageNet, as described in the official PyTorch transfer-learning tutorial: 
[script](https://github.com/pytorch/tutorials/blob/main/beginner_source/transfer_learning_tutorial.py) 
and [documentation](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html).

For this demo, the pretrained/fine-tuned model, as well as a subset of the in-distribution dataset used for OOD enablement, can be downloaded from public Cloud Object Storage (COS) URLs: 

model: https://public-test-rhods.s3.us-east.cloud-object-storage.appdomain.cloud/model_ft_cpu_jit.pth

in-distribution data: https://public-test-rhods.s3.us-east.cloud-object-storage.appdomain.cloud/hymenoptera_data.zip

# Import required libraries

In [None]:
import torch
import torch.nn as nn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import os
from PIL import Image
from tempfile import TemporaryDirectory
from torch.utils.data import DataLoader
from tensorflow.keras.utils import get_file

import json
import os
import sys
# The notebook sits one directory below the repository root; add that root to
# sys.path (only once) so the local `ood_enabler` package can be imported.
ood_path = os.path.abspath(os.pardir)
if ood_path not in sys.path:
    sys.path.append(ood_path)
    
from ood_enabler.ood_enabler import OODEnabler
from ood_enabler.storage.model_store import ModelStore
from ood_enabler.model_wrapper.pytorch import PytorchWrapper
from ood_enabler.data.pytorch_image_data_handler import PytorchImageDataHandler
from ood_enabler.storage.local_storage import FileSystemStorage
from ood_enabler.util.constants import SavedModelFormat


In [None]:
def show_images(inputs, columns=5):
    """Plot up to ``columns`` images from a batch side by side in a single row.

    Args:
        inputs: A batch of images indexable along its first axis (a
            ``torch.Tensor`` or anything ``np.asarray`` accepts). Each image is
            assumed to be channel-first (C, H, W), the layout the same tensor is
            fed to the model in below -- TODO confirm against the sample JSON.
        columns: Maximum number of images to show. Defaults to 5, the value
            that was previously hard-coded.

    Returns:
        None. An empty batch produces no figure.
    """
    count = min(columns, len(inputs))
    if count == 0:
        return
    fig = plt.figure(figsize=(20, 20))
    for i in range(count):
        # Reorder (C, H, W) -> (H, W, C) for imshow. The previous `.T` reversed
        # *all* axes, giving (W, H, C), which draws every image transposed;
        # `Tensor.T` on non-2-D tensors is also deprecated in PyTorch.
        img = np.transpose(np.asarray(inputs[i]), (1, 2, 0))
        # Keep the original fixed 1 x `columns` grid so image size does not
        # change when the batch holds fewer than `columns` images.
        fig.add_subplot(1, columns, i + 1)
        plt.imshow(img)
    plt.show()

# Define metadata for the model that will be used

In [None]:
# Describes the checkpoint (framework + architecture) for ModelStore.load below.
model_metadata = dict(type='pytorch', arch='resnet18')

# Download pretrained/finetuned resnet18 model and load from local storage

In [None]:
model_url = "https://public-test-rhods.s3.us-east.cloud-object-storage.appdomain.cloud/model_ft_cpu_jit.pth"

# Fetch the checkpoint into a scratch directory and load it while the file
# still exists. NOTE(review): the directory is deleted when the `with` block
# exits, so this relies on ModelStore.load reading the file eagerly.
with TemporaryDirectory() as download_dir:
    model_path = get_file(
        fname=os.path.join(download_dir, 'pretrained_model.pth'),
        origin=model_url,
        extract=False,
    )
    model_store = ModelStore.from_filesystem()
    model = model_store.load(model_metadata, model_path)
    

# Define metadata for the in-distribution dataset that will be used

In [None]:
# Settings for the in-distribution data handler: image height/width, batch
# size, and per-channel (mean, std) normalization values -- the ImageNet
# statistics used by the referenced PyTorch tutorial.
ds_metadata = {
    'img_height': 224,
    'img_width': 224,
    'batch_size': 32,
    'normalize': (
        [0.485, 0.456, 0.406],  # per-channel mean
        [0.229, 0.224, 0.225],  # per-channel std
    ),
}

# Download in-distribution subset and load into data handler

In [None]:
data_url = 'https://public-test-rhods.s3.us-east.cloud-object-storage.appdomain.cloud/hymenoptera_data.zip'
# Unlike the model and sample files, no fname/TemporaryDirectory is given here,
# so the archive lands in get_file's default download location and is passed
# on still zipped (extract=False).
data_path = get_file(extract=False, origin=data_url)

# Load the in-distribution subset from the local filesystem into the handler
# that OOD enablement consumes.
local_store = FileSystemStorage()
data_handler = PytorchImageDataHandler()
data_handler.load_dataset(local_store, data_path, '.', ds_metadata)

# Enable model with OOD layer

In [None]:
# Enable `model` with an OOD layer, using `data_handler` (the in-distribution
# subset loaded above) as reference data. The return value is not captured and
# later cells keep using `model`, so this presumably modifies it in place --
# NOTE(review): confirm against OODEnabler.ood_enable.
OODEnabler.ood_enable(model, data_handler)

# Download in-distribution samples to test OOD enablement and scoring

In [None]:
id_sample_url = 'https://public-test-rhods.s3.us-east.cloud-object-storage.appdomain.cloud/hymenoptera2.json'

# Fetch the in-distribution sample into a scratch directory and parse it before
# the directory (and the file) is removed.
with TemporaryDirectory() as scratch:
    id_sample = get_file(
        fname=os.path.join(scratch, 'id_sample.json'),
        origin=id_sample_url,
        extract=False,
    )
    with open(id_sample) as handle:
        data = json.load(handle)


# Review in-distribution samples to test OOD enablement and scoring

In [None]:
# Convert the first input batch of the in-distribution sample and plot it.
id_images = torch.Tensor(data['inputs'][0]['data'])
show_images(id_images)

# Run inference through OOD-enabled model and review results
### The OOD scores are in the second output tensor

In [None]:
# Inference only: disable autograd so no computation graph is recorded.
# NOTE(review): confirm `model.model` is in eval mode (BatchNorm/Dropout); that
# is not established anywhere in this notebook.
with torch.no_grad():
    id_outputs = model.model(torch.Tensor(data['inputs'][0]['data']))
# Displayed as the cell output; the OOD scores are in the second tensor.
id_outputs

# Download out-of-distribution samples to test OOD enablement and scoring
## Three sample URLs are provided (uncomment one to switch datasets)

In [None]:
# Choose ONE out-of-distribution sample set; the commented lines are alternatives.
ood_sample_url = 'https://public-test-rhods.s3.us-east.cloud-object-storage.appdomain.cloud/flowers2.json'
#ood_sample_url = 'https://public-test-rhods.s3.us-east.cloud-object-storage.appdomain.cloud/food101.json'
#ood_sample_url = 'https://public-test-rhods.s3.us-east.cloud-object-storage.appdomain.cloud/cifar10.json'

# Fetch and parse the sample inside a scratch directory. Note that this rebinds
# `data`, replacing the in-distribution sample loaded earlier.
with TemporaryDirectory() as scratch:
    ood_sample = get_file(
        fname=os.path.join(scratch, 'ood_sample.json'),
        origin=ood_sample_url,
        extract=False,
    )
    with open(ood_sample) as handle:
        data = json.load(handle)

# Review OOD samples to test OOD enablement and scoring

In [None]:
# Convert the first input batch of the OOD sample and plot it.
ood_images = torch.Tensor(data['inputs'][0]['data'])
show_images(ood_images)

# Run inference through OOD-enabled model and review results
### The OOD scores are in the second output tensor

In [None]:
# Inference only: disable autograd so no computation graph is recorded.
# NOTE(review): confirm `model.model` is in eval mode (BatchNorm/Dropout); that
# is not established anywhere in this notebook.
with torch.no_grad():
    ood_outputs = model.model(torch.Tensor(data['inputs'][0]['data']))
# Displayed as the cell output; the OOD scores are in the second tensor.
ood_outputs

# Save OOD-enabled model

In [None]:
# Save the OOD-enabled model through the filesystem ModelStore created when the
# model was loaded. '.' is the destination -- presumably the current working
# directory; NOTE(review): confirm against ModelStore.upload.
model_store.upload(model, '.')