## Import required libraries

In [None]:
import os
import sys
import torch
import torchvision
from tensorflow.keras.utils import get_file
from tempfile import TemporaryDirectory

# Make the parent directory importable so that the local `ood_enabler`
# package (imported just below) resolves when the notebook is run from its
# own sub-folder. Only add it once, so re-running the cell is harmless.
ood_path = os.path.abspath(os.pardir)
if ood_path not in sys.path:
    sys.path.append(ood_path)
    
from ood_enabler.ood_enabler import OODEnabler
from ood_enabler.model_wrapper.pytorch import PytorchWrapper
from ood_enabler.storage.model_store import ModelStore
from ood_enabler.storage.local_storage import FileSystemStorage
from ood_enabler.data.pytorch_image_data_handler import PytorchImageDataHandler

# Download an example dataset

In [None]:
# Fetch the small flower-photos archive without extracting it; `get_file`
# returns the local path of the downloaded file, which is handed to the
# data handler below.
dataset_url = "https://public-test-rhods.s3.us-east.cloud-object-storage.appdomain.cloud/flower_photos_small.tar.gz"
archive = get_file(extract=False, origin=dataset_url)

# Define the local storage connection and the dataset metadata

In [None]:
local_store = FileSystemStorage()

# Preprocessing and batching settings for the image dataset.
# `normalize` holds per-channel (mean, std) lists; the values are the usual
# ImageNet statistics used with torchvision's pretrained models.
ds_metadata = {
    'img_height': 224,
    'img_width': 224,
    'batch_size': 32,
    'normalize': (
        [0.485, 0.456, 0.406],  # per-channel mean
        [0.229, 0.224, 0.225],  # per-channel std
    ),
}

# Create an image data handler from the downloaded dataset

In [None]:
# Build the PyTorch image data handler and load the downloaded archive into
# it, using the local storage backend and the dataset metadata defined above.
# NOTE(review): '.' looks like a local working/destination directory for the
# dataset — confirm against PytorchImageDataHandler.load_dataset.
data_handler = PytorchImageDataHandler()
data_handler.load_dataset(local_store, archive, '.', ds_metadata)

# Create a model store connection to the local filesystem

In [None]:
# Filesystem-backed model store: used to load the serialized model and,
# at the end of the notebook, to upload the OOD-enabled model.
model_store = ModelStore.from_filesystem()

## Load a pretrained ResNet50 model from torchvision and save it locally*
### *(for demo purposes only)

## Then reload it into memory through the filesystem-backed model store

In [None]:
# Metadata the model store needs to interpret the saved artifact.
model_metadata = {
    'type': 'pytorch',
    'arch': 'resnet50',
    'ood_thresh_percentile': 20,
}

# ImageNet-pretrained ResNet50 from torchvision.
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in
# favour of `weights=...`; kept for compatibility with older releases.
model = torchvision.models.resnet50(pretrained=True)

# Round-trip through a TorchScript file so the model is loaded the same way
# a saved artifact would be. The temporary directory is removed when the
# `with` block exits, so the load has to happen inside it. After this, `model`
# refers to the model-store wrapper rather than the raw torchvision module.
with TemporaryDirectory() as tmpdir:
    model_path = os.path.join(tmpdir, 'pytorch_resnet50')
    model_full = torch.jit.script(model)
    model_full.save(model_path)

    model = model_store.load(model_metadata, model_path)

# Review the model architecture

In [None]:
# Display the wrapped network (the notebook renders the cell's last expression).
model.model

# Enable the model with an OOD layer

In [None]:
# Embed the OOD layer into `model`, with `data_handler` supplying the
# in-distribution data (presumably used to calibrate the threshold given by
# 'ood_thresh_percentile' in the model metadata — confirm).
# The return value is discarded: later cells rely on `model` being updated in place.
OODEnabler.ood_enable(model, data_handler)

## Review new model architecture with embedded OOD Layer

In [None]:
# Display the network again, now including the embedded OOD layer.
model.model

## Run a forward pass of the OOD-enabled model on CIFAR-10 images (out-of-distribution data) to review the change in outputs

In [None]:
from torchvision import transforms

# CIFAR-10 serves as out-of-distribution probe data. Preprocess it exactly like
# the in-distribution data described by `ds_metadata` (same input size and the
# ImageNet mean/std the pretrained ResNet50 expects). Any change in the OOD
# output then reflects the data itself rather than a preprocessing mismatch.
# NOTE(review): assumes the data handler applies ds_metadata's size/normalize
# settings to the flower images — confirm.
img_size = (ds_metadata['img_height'], ds_metadata['img_width'])
norm_mean, norm_std = ds_metadata['normalize']
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Reuse the batch size configured for the in-distribution dataset.
batch_size = ds_metadata['batch_size']

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size)

In [None]:
# OOD model: run a single CIFAR-10 batch through the OOD-enabled network.
# eval() puts layers such as batch norm into inference behaviour.
model.model.eval()

# Only the first batch is needed; `labels` is kept for inspection.
inputs, labels = next(iter(trainloader))

# Plain inference, so no autograd graph is needed.
# NOTE(review): if the embedded OOD layer ever needs input gradients,
# remove the no_grad() context.
with torch.no_grad():
    outputs = model.model(inputs)
    

In [None]:
# Display the OOD-enabled model's outputs for the first CIFAR-10 batch.
outputs

# Save the OOD-enabled model

In [None]:
# Persist the OOD-enabled model through the filesystem model store.
# NOTE(review): '.' is presumably the destination directory — confirm
# against ModelStore.upload.
model_store.upload(model, '.')