# Loading a Pretrained Model

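```python
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Download the ResNet-34 CORN model pretrained on AFAD (or load it from the local torch.hub cache)
model = torch.hub.load("rasbt/ord-torchhub", model="resnet34_corn_afad", source='github', pretrained=True)
model = model.to(device).eval()  # eval mode: BatchNorm uses its running statistics during inference
```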


# Classifying Images (from AFAD)

Note that the model has been pretrained on the [AFAD dataset](https://afad-dataset.github.io) and may not perform well on other datasets.

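```python
from PIL import Image
from torchvision import transforms

# Convert to 3-channel RGB, since the network expects 3 input channels
img1 = Image.open("example1.jpg").convert("RGB")
img2 = Image.open("example2.jpg").convert("RGB")

# Resize to 120x120 and convert to a float tensor (C x H x W) with values in [0, 1]
preprocess = transforms.Compose([transforms.Resize((120, 120)),
                                 transforms.ToTensor()])
img1 = preprocess(img1)
img2 = preprocess(img2)
batch = torch.stack((img1, img2))  # shape: [2, 3, 120, 120]
```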

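```python
with torch.no_grad():  # no gradients are needed for inference
    logits = model(batch.to(device))  # the inputs must be on the same device as the model
```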

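```python
def label_from_logits(logits):
    """Converts logits to class labels.
    This function is specific to CORN.
    """
    # Each sigmoid output is a conditional probability P(y > k | y > k-1)
    probas = torch.sigmoid(logits)
    # The cumulative product yields the unconditional probabilities P(y > k)
    probas = torch.cumprod(probas, dim=1)
    # The predicted label is the number of thresholds exceeded with probability > 0.5
    predict_levels = probas > 0.5
    predicted_labels = torch.sum(predict_levels, dim=1)
    return predicted_labels
```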
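To see why the cumulative product matters, here is a small worked example with hand-picked logits for a hypothetical 4-class problem (3 CORN outputs); the logits are made up for illustration. The function is restated in compact form so the snippet runs on its own. Counting the individual sigmoid outputs above 0.5 would ignore that each one is conditional on the previous threshold, while the cumulative product does not.

```python
import torch

def label_from_logits(logits):
    probas = torch.sigmoid(logits)         # conditional P(y > k | y > k-1)
    probas = torch.cumprod(probas, dim=1)  # unconditional P(y > k)
    return torch.sum(probas > 0.5, dim=1)  # number of thresholds exceeded

# Hypothetical logits: 2 examples x 3 rank thresholds (i.e., 4 classes)
logits = torch.tensor([[-1.0, 2.0, 2.0],
                       [3.0, 0.2, 1.0]])

print(torch.cumprod(torch.sigmoid(logits), dim=1))
print(label_from_logits(logits))                # tensor([0, 2])
print((torch.sigmoid(logits) > 0.5).sum(dim=1))  # naive count: tensor([2, 3])
```

In the first row, the first threshold is already unlikely (sigmoid(-1) is about 0.27), so every later unconditional probability stays below 0.5 and the label is 0, even though the later conditional probabilities are high.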

```python
predictions = label_from_logits(logits).cpu()  # move to the CPU for printing

print(f'Rank indices [0-12]: {predictions}')
print(f'Real ages [18-30]: {predictions+18}')
```

```
Rank indices [0-12]: tensor([9, 3])
Real ages [18-30]: tensor([27, 21])
```


# Transfer Learning

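To fine-tune on your own dataset, load the pretrained model and replace its output layer to match the number of classes in your dataset. Because CORN frames a problem with K ordered classes as K - 1 binary tasks, the new layer has `NUM_CLASSES - 1` outputs: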

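```python
import torch

model = torch.hub.load("rasbt/ord-torchhub", model="resnet34_corn_afad", source='github', pretrained=True)

NUM_CLASSES = 10  # placeholder: set this to the number of classes in your dataset
# One output per rank threshold (NUM_CLASSES - 1 in total);
# 512 is the feature dimension of the ResNet-34 backbone
model.output_layer = torch.nn.Linear(in_features=512, out_features=NUM_CLASSES-1)
```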

Then, see the [../_train/resnet34_corn_afad.py](../_train/resnet34_corn_afad.py) file for an example of how to define the training loop.

Essentially, you need to replace the section marked

```python
##########################
# MODEL
##########################
```

with the model above, and adapt the section marked

```python
############################
# Dataset
############################
```

to your dataset.
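The training loop also needs a loss that matches the `NUM_CLASSES - 1` outputs. Below is a hedged, self-contained sketch of the CORN loss; it is our own re-implementation of the method's conditional training subsets, not code taken from the training script, and the name `corn_loss` is chosen here for illustration. Output k is trained only on examples whose label is at least k, with the binary target "label > k", and the binary cross-entropy is averaged over all such (example, task) pairs.

```python
import torch
import torch.nn.functional as F

def corn_loss(logits, labels, num_classes):
    """Sketch of the CORN loss.

    logits: [batch_size, num_classes - 1] raw model outputs
    labels: [batch_size] integer class labels in {0, ..., num_classes - 1}
    """
    total_loss = logits.new_zeros(())
    num_examples = 0
    for k in range(num_classes - 1):
        # Conditional training subset for task k: examples with label >= k
        mask = labels >= k
        if not mask.any():
            continue
        # Binary target for task k: is the label greater than k?
        targets = (labels[mask] > k).float()
        pred = logits[mask, k]
        # Summed, numerically stable binary cross-entropy for this task
        total_loss = total_loss + F.binary_cross_entropy_with_logits(
            pred, targets, reduction="sum")
        num_examples += targets.numel()
    return total_loss / num_examples

# Usage with hypothetical values: 4 examples, 5 classes (hence 4 outputs)
logits = torch.zeros(4, 4, requires_grad=True)
labels = torch.tensor([0, 2, 4, 1])
loss = corn_loss(logits, labels, num_classes=5)
loss.backward()
print(loss.item())  # zero logits: every binary term equals log(2), about 0.6931
```

In a training loop, one would compute `loss = corn_loss(model(features), targets, NUM_CLASSES)` and then call `loss.backward()` and `optimizer.step()` as usual. Predictions are decoded with `label_from_logits` as in the inference example.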