## Import required libraries

In [None]:
import tensorflow as tf
import os
import sys
from tempfile import TemporaryDirectory

# Resolve the parent of the current working directory to an absolute path and
# append it to the import search path. The membership check keeps re-running
# this cell from adding duplicate entries.
# Presumably the parent directory is where the `ood_enabler` package lives —
# confirm against the repository layout.
ood_path = os.path.abspath('../')
if ood_path not in sys.path:
    sys.path.append(ood_path)
    
from ood_enabler.ood_enabler import OODEnabler
from ood_enabler.model_wrapper.tf import TFWrapper
from ood_enabler.storage.model_store import ModelStore
from ood_enabler.storage.local_storage import FileSystemStorage
from ood_enabler.data.tf_image_data_handler import TFImageDataHandler
from ood_enabler.util.archiver import archive

# Download the example dataset

In [None]:
# Location of the small flower-photos tarball used as the example dataset.
dataset_url = "https://public-test-rhods.s3.us-east.cloud-object-storage.appdomain.cloud/flower_photos_small.tar.gz"
# Download through the Keras helper; extract=False leaves the archive packed.
# The value kept in `archive_path` is handed to the data handler in a later cell.
archive_path = tf.keras.utils.get_file(origin=dataset_url, extract=False)

# Define local storage connection and metadata for dataset

In [None]:
# Storage connection passed to the data handler when it loads the dataset.
local_store = FileSystemStorage()

# Dataset settings passed to the data handler alongside the archive.
# NOTE(review): 'normalize' is presumably the divisor applied to pixel values
# (255 -> [0, 1]); confirm against TFImageDataHandler.
ds_metadata = {
    'img_height': 224,
    'img_width': 224,
    'batch_size': 32,
    'normalize': 255,
}

# Get image datahandler from downloaded dataset and normalize

In [None]:
# Create an image data handler and populate it from the downloaded archive.
data_handler = TFImageDataHandler()
# Arguments: storage connection, archive path, '.', dataset settings.
# NOTE(review): the '.' argument is presumably the destination directory for
# the archive contents — confirm against TFImageDataHandler.load_dataset.
data_handler.load_dataset(local_store, archive_path, '.', ds_metadata)

# Create Model Store connection to local filesystem

In [None]:
# Build the model store via its filesystem constructor; it is used below to
# load the saved model and to upload the OOD-enabled one.
model_store = ModelStore.from_filesystem()

## Load pretrained ResNet50 model from tf and save locally.*
### *(demo purposes only)

## Then reload into memory using the filesystem-backed ModelStore

In [None]:
# Description of the model handed to the model store when loading.
# NOTE(review): 'ood_thresh_percentile' presumably selects the percentile used
# to derive the OOD threshold — confirm against OODEnabler.
model_metadata = {
    'type': 'tf',
    'arch': 'resnet50',
    'ood_thresh_percentile': 20,
}

# Instantiate ResNet50 from keras.applications with its default arguments.
model = tf.keras.applications.resnet50.ResNet50()

# Round-trip the model through disk so it comes back via the model store.
with TemporaryDirectory() as tmpdir:
    model_path = os.path.join(tmpdir, 'tf_resnet50')
    # NOTE(review): newer Keras releases require a `.keras`/`.h5` suffix on the
    # save path — confirm against the TensorFlow version this notebook targets.
    model.save(model_path)

    # Load while the temporary directory still exists. From here on `model`
    # refers to the object returned by the model store, not the Keras model.
    model = model_store.load(model_metadata, model_path)

# Review model architecture

In [None]:
# `model` is the object returned by model_store.load; its `.model` attribute is
# presumably the underlying Keras model. Print its summary before OOD enabling.
model.model.summary()

# Enable model with OOD layer

In [None]:
# Enable OOD on the loaded model using the dataset held by the data handler.
# The return value is not captured and later cells keep using `model`, so the
# call presumably modifies `model` in place — confirm against OODEnabler.
OODEnabler.ood_enable(model, data_handler)

## Review new model architecture with embedded OOD Layer

In [None]:
# Print the summary again, now that OODEnabler.ood_enable has been run on
# `model`, to compare against the earlier summary.
model.model.summary()

## Run `predict` on OOD model to review change in outputs

In [None]:
# Load CIFAR-10, a different dataset from the flower photos used above. Only
# the first image of the first split is used below.
(c10_x_1, c10_y_1), (c10_x_2, c10_y_2) = tf.keras.datasets.cifar10.load_data()

# Wrap that image in a one-element batch and resize it to 224x224, matching
# the img_height/img_width given in ds_metadata.
# NOTE(review): ds_metadata sets 'normalize': 255 for the dataset, but this
# sample is not rescaled before prediction — confirm whether it should be.
nn = tf.image.resize(
    [c10_x_1[0]], (224, 224),
    preserve_aspect_ratio=False, antialias=False, name=None,
)

In [None]:
# OOD model: run the resized CIFAR-10 sample through the OOD-enabled model.
o2 = model.model.predict(nn)
# Print the number of outputs and the first two of them.
# NOTE(review): presumably the class predictions and the OOD output — confirm
# which index holds which.
print(len(o2), o2[0], o2[1])

# Save OOD enabled model

In [None]:
# Hand the OOD-enabled model to the model store with './ood_tf_model' as the
# target. NOTE(review): whether this writes a directory or an archive depends
# on ModelStore.upload — confirm.
model_store.upload(model, './ood_tf_model')