# Fossilnet inference

We trained **fossilnet** in a GPU-accelerated Google Colab notebook and saved the model weights in `./geofignet.pt`.

You will need to install [PyTorch](https://pytorch.org/get-started/locally/) version >= 1.4.

In [None]:
import torch
import torchvision

torch.__version__, torchvision.__version__

In [None]:
import torch.nn as nn
from torchvision import datasets, models, transforms

## Instantiate the model

In [None]:
class_names = ['ammonites',
               'bivalves',
               'corals',
               'dinosaurs',
               'echinoderms',
               'fishes',
               'forams',
               'gastropods',
               'plants',
               'trilobites',
              ]

You will need to download the weights from your Google Drive, or you can [use these](https://drive.google.com/open?id=1Uf7TCzje8sfSrC_mVYjKAjpnHUUuDKdf).

In [None]:
# Instantiate a vanilla ResNet-18 and replace its final layer to match our classes.
model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, len(class_names))

# Load the trained weights, then move the model to the same device as the inputs.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load('../data/fossilnet.pt', map_location=device), strict=False)
model = model.to(device)

# Set the mode to 'evaluate' before inference, so that batch-norm layers use their stored running statistics.
_ = model.eval()
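
Because `load_state_dict` is called with `strict=False`, any mismatch between the checkpoint keys and the model is skipped rather than raised. The sketch below shows one way to inspect what was skipped. It uses a toy `nn.Linear` and a deliberately incomplete state dict; `toy_model` and `partial_state` are hypothetical stand-ins, not the fossilnet checkpoint.

```python
import torch
import torch.nn as nn

# Toy stand-in for the real model; the fossilnet checkpoint is not used here.
toy_model = nn.Linear(4, 2)

# A deliberately incomplete state dict: it has 'weight' but no 'bias'.
partial_state = {"weight": torch.zeros(2, 4)}

# With strict=False, mismatches are reported in the return value instead of raising.
result = toy_model.load_state_dict(partial_state, strict=False)

print("missing keys:", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)
```

If `missing_keys` is non-empty for the real checkpoint, those layers keep their random initialisation, so it is worth checking before trusting any prediction.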

## Inference on one image

In [None]:
from IPython.display import Image as Img

Img("../data/random_ammonite.jpeg")

In [None]:
from PIL import Image

data_transforms = transforms.Compose([
        transforms.Resize(156),
        transforms.CenterCrop(128),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

def image_loader(image_name):
    """Load an image and return it as a batched tensor on `device`."""
    image = Image.open(image_name).convert('RGB')
    image = data_transforms(image).unsqueeze(0)
    return image.to(device)

image = image_loader("../data/random_ammonite.jpeg")

sm = torch.nn.Softmax(dim=1)
probs = sm(model(image))
prob, clas = torch.max(probs, 1)
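
The `Normalize` step in the transform above may be easier to follow as plain arithmetic. `ToTensor` scales 8-bit pixel values to the range [0, 1], and `Normalize` then subtracts a per-channel mean and divides by a per-channel standard deviation. The values used are the common ImageNet statistics. A minimal sketch in pure Python, where `normalize_pixel` and the example pixel are hypothetical:

```python
# Channel statistics passed to transforms.Normalize in the cell above.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def normalize_pixel(rgb):
    """Apply (value - mean) / std to one RGB pixel with values in [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, mean, std)]

# Hypothetical mid-grey pixel, scaled from 8-bit values the way ToTensor does.
pixel = [128 / 255, 128 / 255, 128 / 255]
print(normalize_pixel(pixel))
```

A pixel exactly equal to the channel means maps to zero in every channel, which is the point of the centring.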

In [None]:
class_names[clas]

We also get the probability of the class selection:

In [None]:
prob.item()

This came from the model output, which is passed through a softmax function:

In [None]:
torch.nn.Softmax(dim=1)(model(image))
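
For readers who have not met softmax before, here is a minimal sketch of what it computes, written with the standard library only. Each raw score is exponentiated and divided by the sum of all the exponentials, so the outputs are positive and sum to one. The `softmax` helper and the example logits are hypothetical, not values from fossilnet.

```python
import math

def softmax(logits):
    """Return softmax probabilities for a list of raw scores."""
    # Subtracting the maximum avoids overflow and leaves the result unchanged.
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for three classes.
example_probs = softmax([2.0, 1.0, 0.1])
print(example_probs, sum(example_probs))
```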

This is a torch tensor, which we can convert to a NumPy array for easier manipulation:

In [None]:
probs.detach().cpu().numpy().squeeze()
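
Once the probabilities are a NumPy array, ranking the classes takes one call. The sketch below picks the three most likely classes with `np.argsort`. The probabilities in `example_probs` are made up for illustration; real values would come from `probs`.

```python
import numpy as np

# Hypothetical probabilities for the ten classes, in class_names order.
example_probs = np.array([0.70, 0.02, 0.05, 0.01, 0.03, 0.01, 0.02, 0.10, 0.04, 0.02])
example_names = ['ammonites', 'bivalves', 'corals', 'dinosaurs', 'echinoderms',
                 'fishes', 'forams', 'gastropods', 'plants', 'trilobites']

# argsort is ascending, so reverse it and keep the first three indices.
top3 = np.argsort(example_probs)[::-1][:3]
for idx in top3:
    print(f"{example_names[idx]}: {example_probs[idx]:.2f}")
```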

For example, we could plot the probability of each class on a logarithmic axis:

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

In [None]:
probs_ = probs.detach().cpu().numpy().squeeze()
y = np.arange(len(probs_))
y_min, y_max = y[0]-0.75, y[-1]+0.75

fig, ax = plt.subplots(figsize=(6, 10))
bars = ax.barh(y, probs_, color='orange', align='center', lw=2)
ax.set_yticks(y)
ax.set_yticklabels(class_names, size=14)
ax.set_xscale('log')
ax.set_ylim(y_max, y_min)  # Label top-down.
ax.grid(c='black', alpha=0.1, which='both')

for i, p in enumerate(probs_):
    ax.text(0.55*min(probs_), i, f"{p:0.2e}", va='center')

bars[np.argmax(probs_)].set_color('red')

## What next?

If we think our model is doing what we want, we could deploy it, for example in a web app.

See an implementation of `geofignet` here:

> https://geofignet.geosci.ai/

But before we feel too pleased with ourselves:

In [None]:
Img("../data/cinnamon.jpg", width=512)

In [None]:
image = image_loader("../data/cinnamon.jpg")

sm = torch.nn.Softmax(dim=1)
probs = sm(model(image))
prob, clas = torch.max(probs, 1)

In [None]:
class_names[clas], prob.item()
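
A softmax classifier always spreads its probability over the classes it knows, so an image of something else still receives a fossil label. One crude mitigation is to refuse to answer when the top probability falls below a threshold. The sketch below assumes a hypothetical helper `label_or_unknown` and an arbitrary threshold of 0.9. Networks can be confidently wrong, so this is an illustration rather than a reliable fix.

```python
def label_or_unknown(probs, names, threshold=0.9):
    """Return the top class name, or 'unknown' if its probability is below threshold."""
    best = max(range(len(probs)), key=lambda i: probs[i])
    return names[best] if probs[best] >= threshold else "unknown"

# Shortened class list and made-up probabilities, for illustration only.
names = ["ammonites", "bivalves", "corals"]
confident = label_or_unknown([0.97, 0.02, 0.01], names)
uncertain = label_or_unknown([0.50, 0.30, 0.20], names)
print(confident, uncertain)
```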