In [1]:
import io
import json
import torch
import evaluate
import pandas as pd
from pathlib import Path
from datasets import Dataset
from PIL import Image as PILImage
from api.feedback.models import Feedback
from transformers import TrainingArguments, Trainer
from transformers import ViTImageProcessor, ViTForImageClassification

### Training Data (mock images and feedback list)

In [2]:
IMG_MOCK_DIR = Path.cwd().parent / "api" / "mock"
audi = IMG_MOCK_DIR / "audi.jpg"
bmw = IMG_MOCK_DIR / "bmw.jpg"
vw = IMG_MOCK_DIR / "vw.jpg"
feedback = [
    Feedback(
        image=audi.read_bytes(),
        label="Audi",
        correct=True,
        correct_label="Audi",
    ),
    Feedback(
        image=bmw.read_bytes(),
        label="Audi",
        correct=False,
        correct_label="BMW",
    ),
    Feedback(
        image=vw.read_bytes(),
        label="BMW",
        correct=False,
        correct_label="Volkswagen",
    ),
]

In [3]:
image_bytes = [feedback_item.image for feedback_item in feedback]
images = [PILImage.open(io.BytesIO(image_bytes)) for image_bytes in image_bytes]
labels = [feedback_item.correct_label for feedback_item in feedback]
df_train = pd.DataFrame({"image": image_bytes, "label": labels})

### Validation Data

In [4]:
def image_to_bytes(image_path: "str") -> bytes:
    with open(image_path, "rb") as img_file:
        return img_file.read()

In [5]:
df_val = pd.read_csv("validation.csv")
df_val["image"] = df_val["image_path"].apply(image_to_bytes)
df_val = df_val.drop(columns=["image_path"])

### Helper Functions

In [12]:
def preprocess(example) -> dict:
    """
    preprocessing necessary as FeedbackAPI provides image as bytes and not as files
    :param example: entry in a Dataset object
    :return: image (preprocessed) and label (preserved)
    """
    image = PILImage.open(io.BytesIO(example["image"])).convert("RGB")
    inputs = processor(images=image, return_tensors="np")
    label_id = label2id[example["label"]]
    return {
        "pixel_values": inputs["pixel_values"].squeeze(),
        "label": label_id
    }


# Evaluation metric
def compute_metrics(eval_pred) -> dict:
    predictions = eval_pred.predictions
    label_ids = eval_pred.label_ids
    predicted_labels = predictions.argmax(axis=1)
    acc_score = accuracy.compute(predictions=predicted_labels, references=label_ids)['accuracy']
    return {"accuracy": acc_score}


# Data collator
def collate_fn(examples):
    pixel_values = torch.stack([torch.tensor(example["pixel_values"]) for example in examples])
    labels = torch.tensor([example['label'] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

### Model

In [13]:
with open("config.json", "r") as file:
    config = json.load(file)
id2label = config.get("id2label")
label2id = config.get("label2id")

# load model
model_id = "dima806/car_brands_image_detection"
processor = ViTImageProcessor.from_pretrained(model_id)
model = ViTForImageClassification.from_pretrained(
    model_id,
    use_safetensors=True,
    trust_remote_code=True
)

In [14]:
dataset_train = Dataset.from_pandas(df_train)
dataset_train = dataset_train.map(preprocess)

dataset_val = Dataset.from_pandas(df_val)
dataset_val = dataset_val.map(preprocess)

# Evaluation metric
accuracy = evaluate.load("accuracy")

Map:   0%|          | 0/3 [00:00<?, ? examples/s]

Map:   0%|          | 0/350 [00:00<?, ? examples/s]

### Training Setup

In [17]:
args = TrainingArguments(
    output_dir="car_brands_image_detection",
    logging_dir='./logs',
    eval_strategy="epoch",
    logging_strategy="steps",
    logging_steps=10,
    learning_rate=5e-6,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=8,
    num_train_epochs=4,
    weight_decay=0.02,
    warmup_steps=50,
    remove_unused_columns=False,
    save_strategy='epoch',
    load_best_model_at_end=True,
    save_total_limit=1,
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset_train,
    eval_dataset=dataset_val,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    processing_class=processor,
)

### Evaluate Validation Set before training

In [18]:
trainer.evaluate()



{'eval_loss': 2.006784200668335,
 'eval_model_preparation_time': 0.0031,
 'eval_accuracy': 0.5228571428571429,
 'eval_runtime': 126.6022,
 'eval_samples_per_second': 2.765,
 'eval_steps_per_second': 0.348}

### Train Model

In [19]:
trainer.train()



Epoch,Training Loss,Validation Loss,Model Preparation Time,Accuracy
1,No log,2.006784,0.0031,0.522857
2,No log,2.00611,0.0031,0.522857
3,No log,2.004786,0.0031,0.52
4,No log,2.002869,0.0031,0.52




TrainOutput(global_step=4, training_loss=3.2835543155670166, metrics={'train_runtime': 508.032, 'train_samples_per_second': 0.024, 'train_steps_per_second': 0.008, 'total_flos': 930345600393216.0, 'train_loss': 3.2835543155670166, 'epoch': 4.0})

In [20]:
trainer.evaluate()



{'eval_loss': 2.0028693675994873,
 'eval_model_preparation_time': 0.0031,
 'eval_accuracy': 0.52,
 'eval_runtime': 123.3039,
 'eval_samples_per_second': 2.839,
 'eval_steps_per_second': 0.357,
 'epoch': 4.0}