In [1]:
import os
import io
import json
import torch
import mlflow
import pandas as pd
from pathlib import Path
from datasets import Dataset
from datetime import datetime
from PIL import Image as PILImage
from api.feedback.models import Feedback
from transformers import TrainingArguments, Trainer
from transformers import ViTImageProcessor, ViTForImageClassification
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

### Training Data (mock images and feedback list)

In [2]:
IMG_MOCK_DIR = Path.cwd().parent / "api" / "mock"
audi = IMG_MOCK_DIR / "audi.jpg"
bmw = IMG_MOCK_DIR / "bmw.jpg"
vw = IMG_MOCK_DIR / "vw.jpg"
feedback = [
    Feedback(
        image=audi.read_bytes(),
        label="Audi",
        correct=True,
        correct_label="Audi",
    ),
    Feedback(
        image=bmw.read_bytes(),
        label="Audi",
        correct=False,
        correct_label="BMW",
    ),
    Feedback(
        image=vw.read_bytes(),
        label="BMW",
        correct=False,
        correct_label="Volkswagen",
    ),
]

In [3]:
image_bytes = [feedback_item.image for feedback_item in feedback]
images = [PILImage.open(io.BytesIO(image_bytes)) for image_bytes in image_bytes]
labels = [feedback_item.correct_label for feedback_item in feedback]
df_train = pd.DataFrame({"image": image_bytes, "label": labels})

### Validation Data

In [4]:
def image_to_bytes(image_path: "str") -> bytes:
    with open(image_path, "rb") as img_file:
        return img_file.read()

In [5]:
df_val = pd.read_csv("validation.csv")
df_val["image"] = df_val["image_path"].apply(image_to_bytes)
df_val = df_val.drop(columns=["image_path"])

### Helper Functions

In [11]:
def preprocess(example) -> dict:
    """
    preprocessing necessary as FeedbackAPI provides image as bytes and not as files
    :param example: entry in a Dataset object
    :return: image (preprocessed) and label (preserved)
    """
    image = PILImage.open(io.BytesIO(example["image"])).convert("RGB")
    inputs = processor(images=image, return_tensors="np")
    label_id = label2id[example["label"]]
    return {
        "pixel_values": inputs["pixel_values"].squeeze(),
        "label": label_id
    }


# Evaluation metric
def compute_metrics(eval_pred) -> dict:
    predictions = eval_pred.predictions
    label_ids = eval_pred.label_ids
    predicted_labels = predictions.argmax(axis=1)
    return {"accuracy": accuracy_score(label_ids, predicted_labels)}


# Data collator
def collate_fn(examples):
    pixel_values = torch.stack([torch.tensor(example["pixel_values"]) for example in examples])
    labels = torch.tensor([example['label'] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

### Model

In [12]:
with open("config.json", "r") as file:
    config = json.load(file)
id2label = config.get("id2label")
label2id = config.get("label2id")

# load model
model_id = "dima806/car_brands_image_detection"
processor = ViTImageProcessor.from_pretrained(model_id)
model = ViTForImageClassification.from_pretrained(
    model_id,
    use_safetensors=True,
    trust_remote_code=True
)

In [13]:
dataset_train = Dataset.from_pandas(df_train)
dataset_train = dataset_train.map(preprocess)

dataset_val = Dataset.from_pandas(df_val)
dataset_val = dataset_val.map(preprocess)

Map:   0%|          | 0/3 [00:00<?, ? examples/s]

Map:   0%|          | 0/350 [00:00<?, ? examples/s]

### Training Setup with ML Flow tracking

make sure you started the ml flow server beforehand.

run `mlflow server --host 127.0.0.1 --port 8080` in a terminal before continuing

In [14]:
args = TrainingArguments(
    output_dir="car_brands_image_detection",
    logging_dir="./logs",
    eval_strategy="epoch",
    logging_strategy="steps",
    logging_steps=1,
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=4,
    num_train_epochs=1,
    weight_decay=0.02,
    warmup_steps=50,
    remove_unused_columns=False,
    save_strategy="epoch",
    load_best_model_at_end=True,
    save_total_limit=1,
    report_to="mlflow"
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset_train,
    eval_dataset=dataset_val,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    processing_class=processor,
)

In [15]:
mlflow.set_tracking_uri("http://127.0.0.1:8080")
mlflow.set_experiment("car_brand_classification")

# train
trainer.train()

# log metrics
eval_results = trainer.evaluate()
for key, value in eval_results.items():
    mlflow.log_metric(key, float(value))

# save and log model
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
save_path = os.path.join("sandbox", "checkpoints", f"model_{timestamp}")
os.makedirs(save_path, exist_ok=True)

trainer.save_model(save_path)
mlflow.log_param("model_save_path", save_path)
mlflow.log_artifacts(save_path, artifact_path="model")



Epoch,Training Loss,Validation Loss,Accuracy
1,3.5265,2.017773,0.525714


