### Function for getting the Food101 data

In [7]:
from torchvision.datasets import Food101
from torch.utils.data import Subset

def _first_n_per_class_indices(dataset, max_per_class):
    """Return the indices of the first ``max_per_class`` samples of each label, in dataset order."""
    # torchvision's Food101 currently keeps its integer labels in the private
    # ``_labels`` list (one entry per sample). Reading it avoids decoding and
    # transforming every image just to look at its label. Because the attribute
    # is private, fall back to iterating the dataset when it is not available.
    labels = getattr(dataset, "_labels", None)
    if labels is None:
        labels = (label for _, label in dataset)

    class_count = {}
    indices = []

    for i, label in enumerate(labels):
        seen = class_count.get(label, 0)
        if seen < max_per_class:
            class_count[label] = seen + 1
            indices.append(i)

    return indices


def get_data(data_path, transform, train_per_class=160, test_per_class=40):
    """Download Food101 and return class-balanced train/test subsets.

    Args:
        data_path: Base directory (a ``pathlib.Path``-like object supporting
            ``/``); the splits are stored under ``data_path / "train"`` and
            ``data_path / "test"``.
        transform: Transform applied to every image by the dataset.
        train_per_class: Maximum number of training samples kept per class.
        test_per_class: Maximum number of test samples kept per class.

    Returns:
        ``(train_data, test_data, classes)`` where the first two are ``Subset``
        objects holding the first N samples of each class and ``classes`` is
        the class list of the training dataset.
    """
    # NOTE(review): the two splits use different roots, so the archive is
    # presumably downloaded once per root -- consider sharing a single root.
    food101_train_data = Food101(
        root = data_path / "train",
        split = 'train',
        transform = transform,
        target_transform = None,
        download = True
    )

    classes = food101_train_data.classes

    train_data = Subset(
        dataset = food101_train_data,
        indices = _first_n_per_class_indices(food101_train_data, train_per_class)
    )

    food101_test_data = Food101(
        root = data_path / "test",
        split = 'test',
        transform = transform,
        target_transform = None,
        download = True
    )

    test_data = Subset(
        dataset = food101_test_data,
        indices = _first_n_per_class_indices(food101_test_data, test_per_class)
    )

    return train_data, test_data, classes

### Function to create DataLoaders

In [8]:
from torch.utils.data import DataLoader

def get_dataloaders(train_data, test_data, batch_size=32, num_workers=1):
    """Wrap the train and test datasets in ``DataLoader`` objects.

    Args:
        train_data: Dataset used for training; it is reshuffled every epoch.
        test_data: Dataset used for evaluation; it is read in order.
        batch_size: Number of samples per batch for both loaders.
        num_workers: Number of worker processes per loader (0 loads in the
            main process).

    Returns:
        ``(train_dataloader, test_dataloader)``.
    """
    train_dataloader = DataLoader(
        dataset = train_data,
        batch_size = batch_size,
        num_workers = num_workers,
        shuffle = True
    )

    # Evaluation order does not affect the metrics, so keep it deterministic
    test_dataloader = DataLoader(
        dataset = test_data,
        batch_size = batch_size,
        num_workers = num_workers,
        shuffle = False
    )

    return train_dataloader, test_dataloader

### Function to create VisionTransformer (ViT) feature extractor

In [9]:
from torchvision.models import ViT_B_16_Weights, vit_b_16
from torch import nn

def get_model(out_features):
    """Build a pretrained ViT-B/16 feature extractor with a new classifier head.

    All pretrained parameters are frozen; only the freshly created linear head
    (created after the freeze, so trainable by default) will receive gradients.

    Args:
        out_features: Number of output classes for the new head.

    Returns:
        The ViT-B/16 model with ``heads.head`` replaced by a new ``nn.Linear``.
    """
    model = vit_b_16(
        weights = ViT_B_16_Weights.DEFAULT
    )

    for param in model.parameters():
        param.requires_grad = False

    # Take the input width from the head being replaced instead of hard-coding it
    in_features = model.heads.head.in_features

    model.heads.head = nn.Linear(
        in_features = in_features,
        out_features = out_features,
        bias = True
    )

    return model
    

### Function to create the image transform

In [10]:
from torchvision.models import ViT_B_16_Weights

def get_transform():
    """Return the preprocessing transform bundled with the default ViT-B/16 weights."""
    pretrained_weights = ViT_B_16_Weights.DEFAULT
    return pretrained_weights.transforms()

In [11]:
from pathlib import Path

# Base directory for the dataset; a relative path, so it is resolved against
# the notebook's current working directory.
data_path = Path('../data')

# Preprocessing that matches the pretrained ViT-B/16 weights
transform = get_transform()
# Downloads Food101 if needed and builds the per-class-capped train/test subsets
train_data, test_data, classes = get_data(data_path, transform)


### Functions for the training and evaluation steps

In [26]:
from sklearn.metrics import accuracy_score
import torch

def train_set(model, dataloader, loss_fn, optimizer, device):
    """Run one training epoch and return the mean loss and accuracy.

    Args:
        model: Module to train; it is switched to training mode. The caller is
            expected to have placed it on ``device`` already.
        dataloader: Sized iterable yielding ``(X, y)`` batches.
        loss_fn: Callable taking ``(logits, targets)`` and returning a scalar loss.
        optimizer: Optimizer that updates the model's trainable parameters.
        device: Device every batch is moved to before the forward pass.

    Returns:
        ``(train_loss, accuracy)``: the loss averaged over batches and the
        accuracy averaged over batches, expressed as a percentage.
    """
    train_loss = 0
    accuracy = 0

    model.train()

    for X, y in dataloader:

        # Tensor.to() returns a new tensor; rebind the names or the batch
        # never actually reaches `device`.
        X = X.to(device)
        y = y.to(device)

        y_logits = model(X)
        loss = loss_fn(y_logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # softmax is monotonic, so the argmax of the logits is the predicted class
        y_pred = y_logits.argmax(dim=1)

        train_loss += loss.item()
        # Computed with torch so it also works when the batch lives on a GPU
        accuracy += (y_pred == y).float().mean().item()

    train_loss = train_loss / len(dataloader)
    accuracy = accuracy / len(dataloader) * 100

    return train_loss, accuracy


def test_set(model, dataloader, loss_fn, device):

    test_loss = 0
    accuracy = 0

    model.eval()

    with torch.inference_mode():
        for X, y in dataloader:

            X.to(device)
            y.to(device)

            y_logits = model(X)
            loss = loss_fn(y_logits, y)
            y_pred = torch.softmax(y_logits, dim=1).argamx(dim=1)

            test_loss += loss.item()
            accuracy += accuracy_score(y.detach().numpy(), y_pred.detach().numpy())

    test_loss = test_loss / len(dataloader)
    accuracy = accuracy / len(dataloader) * 100

    return test_loss, accuracy

        

### Function for the training engine (epoch loop)

In [27]:
def engine(model, train_dataloader, test_dataloader, loss_fn, optimizer, device, epochs):
    """Train and evaluate the model for a number of epochs.

    Each epoch runs ``train_set`` on the training loader followed by
    ``test_set`` on the test loader, records the metrics and prints them.

    Args:
        model: Model to train and evaluate.
        train_dataloader: Dataloader passed to ``train_set``.
        test_dataloader: Dataloader passed to ``test_set``.
        loss_fn: Loss function used for both steps.
        optimizer: Optimizer used by the training step.
        device: Device forwarded to both steps.
        epochs: Number of epochs to run.

    Returns:
        Dict with the keys ``'epoch'``, ``'train_loss'``, ``'train_acc'``,
        ``'test_loss'`` and ``'test_acc'``; each maps to a list with one entry
        per epoch (epoch numbers are 1-based).
    """
    result = {
        'epoch' : [],
        'train_loss' : [],
        'train_acc' : [],
        'test_loss' : [],
        'test_acc' : []
    }

    # 1-based epoch number, used both in the results and in the progress line
    for epoch in range(1, epochs + 1):

        train_loss, train_acc = train_set(
            model = model,
            dataloader = train_dataloader,
            loss_fn = loss_fn,
            optimizer = optimizer,
            device = device
        )

        test_loss, test_acc = test_set(
            model = model,
            dataloader = test_dataloader,
            loss_fn = loss_fn,
            device = device
        )

        result['epoch'].append(epoch)
        result['train_loss'].append(train_loss)
        result['train_acc'].append(train_acc)
        result['test_loss'].append(test_loss)
        result['test_acc'].append(test_acc)

        print(f"epoch = {epoch} | train loss = {train_loss}  | train acc = {train_acc} | test loss = {test_loss} | test acc = {test_acc}")


    return result

### Function for saving the model

In [29]:
def save_model(path, model):
    """Save the model's ``state_dict`` as ``vit_model_on_food101.pth`` inside ``path``.

    Args:
        path: Target directory, as a ``str`` or ``pathlib.Path``. It is created
            (including missing parents) if it does not exist yet.
        model: Model whose ``state_dict()`` is written to disk.
    """
    target_dir = Path(path)
    # torch.save does not create missing directories
    target_dir.mkdir(parents=True, exist_ok=True)

    file_name = target_dir / "vit_model_on_food101.pth"
    torch.save(model.state_dict(), file_name)
    print(f"The model has been saved successfully in {file_name}")