In [None]:
import numpy as np
import torch.optim
import torchmetrics
from torchvision import datasets
from torchvision import transforms as T
import transformers

import armory.evaluation
import armory.model.image_classification
import armory.trainer

In [None]:
# Enable IPython's autoreload extension in mode 2 so edited modules are
# re-imported without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Armory system configuration; its `dataset_cache` attribute is used as the
# download root for the Food-101 dataset.
sysconfig = armory.evaluation.SysConfig()
sysconfig

In [None]:
# Per-image preprocessing: resize to 224x224, convert to a tensor
# (HWC->CHW, scaled to 0-1), then normalize every channel with mean 0.5 and
# std 0.5.
food101_transform = T.Compose(
    [
        T.Resize(size=(224, 224)),
        T.ToTensor(),
        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Food-101 "test" split, stored under the Armory dataset cache with
# downloading enabled.
# NOTE(review): the data loader built from this dataset is later handed to
# `trainer.fit` as the training data -- confirm that training on the "test"
# split is intended.
tv_dataset = datasets.Food101(
    root=str(sysconfig.dataset_cache),
    split="test",
    transform=food101_transform,
    download=True,
)

tv_dataset

In [None]:
# Class names of the dataset; their count is used as `num_classes` for the
# accuracy metric.
labels = tv_dataset.classes
labels

In [None]:
# Wrap the torchvision dataset for Armory, naming the two tuple fields
# "image" and "label" (the same keys the data loader is configured with).
armory_dataset = armory.dataset.TupleDataset(
    tv_dataset,
    ("image", "label"),
)

In [None]:
# Data loader settings: one sample per batch, shuffled, with no fixed seed.
batch_size, shuffle, seed = 1, True, None

# Scale descriptor for the image data: floats with a maximum of 1.0,
# normalized per channel with mean 0.5 / std 0.5 (the same values passed to
# the torchvision Normalize transform).
normalized_scale = armory.data.Scale(
    mean=(0.5, 0.5, 0.5),
    std=(0.5, 0.5, 0.5),
    max=1.0,
    dtype=armory.data.DataType.FLOAT,
)

# Armory data loader over the wrapped dataset, declaring CHW layout and the
# scale above, and reading samples via the "image"/"label" keys.
dataloader = armory.dataset.ImageClassificationDataLoader(
    armory_dataset,
    image_key="image",
    label_key="label",
    dim=armory.data.ImageDimensions.CHW,
    scale=normalized_scale,
    batch_size=batch_size,
    shuffle=shuffle,
    seed=seed,
)

In [None]:
# Show the data loader as this cell's output.
dataloader

In [None]:
# Pair the data loader with the dataset name "food-101" in an Armory
# evaluation dataset record.
evaluation_dataset = armory.evaluation.Dataset(dataloader=dataloader, name="food-101")

In [None]:
# Fetch (and cache) the "nateraw/food" configuration from huggingface.co,
# then instantiate an image-classification model from that configuration.
# NOTE(review): only the configuration is loaded here, not a checkpoint --
# presumably the model starts untrained and is fitted by the trainer; confirm.
hf_config = transformers.AutoConfig.from_pretrained(
    "nateraw/food",
)
hf_model = transformers.AutoModelForImageClassification.from_config(
    hf_config,
)

In [None]:
# Wrap the Hugging Face model as an Armory image classifier whose declared
# inputs are CHW images in the normalized scale defined for the data loader.
armory_model = armory.model.image_classification.ImageClassifier(
    model=hf_model,
    name="ViT-finetuned-food101",
    inputs_spec=armory.data.TorchImageSpec(
        scale=normalized_scale,
        dim=armory.data.ImageDimensions.CHW,
    ),
)

In [None]:
# Training components: cross-entropy loss, plain SGD over the Hugging Face
# model's parameters, and multiclass accuracy over the dataset's classes.
criterion = torch.nn.CrossEntropyLoss()
# The learning rate is passed explicitly: older PyTorch releases require `lr`
# for SGD, and newer ones default to 1e-3, so this keeps the current behavior
# while making the hyperparameter visible and portable.
optimizer = torch.optim.SGD(hf_model.parameters(), lr=1e-3)
metric = armory.metric.PredictionMetric(
    torchmetrics.classification.MulticlassAccuracy(num_classes=len(labels)),
)

# Bundle the Armory model with its loss, optimizer and metric.
trainer = armory.trainer.Trainer(
    model=armory_model,
    criterion=criterion,
    optimizer=optimizer,
    metric=metric,
)

In [None]:
# Run a single training epoch on the CPU using the Food-101 data loader.
trainer.fit(
    accelerator="cpu",
    max_epochs=1,
    train_dataloader=dataloader,
)