## Imports

In [73]:
import cv2 as cv
import pandas as pd
from torch.utils.data import Dataset
from transformers import (
    AutoModel,
    EvalPrediction, 
    ViTImageProcessor,
    ViTForImageClassification,
    TrainingArguments,
    Trainer
)
from torch.nn import (
    Linear,
    Module, 
    CrossEntropyLoss
)
from typing import Optional
from huggingface_hub import PyTorchModelHubMixin
from sklearn.metrics import classification_report
import numpy as np

## Determine some constants

In [64]:
import random
import numpy as np
import torch


def seed_all(seed_value: int = 42) -> None:
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

SEED = 42
seed_all(SEED)
MODEL_NAME = 'WinKawaks/vit-tiny-patch16-224'
DATA = 'papayas_links.csv'
DEVICE =  'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)

cuda


## Objects

In [65]:
class DataSet(Dataset):
    def __init__(self,
                 df: pd.DataFrame,
                 processor: ViTImageProcessor
                ):
        self.len = len(df)
        self.df = df
        self.processor = processor

    def __getitem__(self, index):
        path = df.path[index]
        try:
            img = processor(cv.imread(path), return_tensors="pt")
        except: raise ValueError(
            f'this is problem {index}, {path}'
        )
        label = df.label[index]
        return {
            'labels': label,
            'pixel_values': img['pixel_values'][0]
        }

    def __len__(self):
        return self.len

class ClassificationMetrica:
    """
    Class for computing metrics (sklearn, classification_report())
    by EvalPrediction at training and evaluation stages.

    Args:
        id2label (dict): A dictionary where keys it's classes idx
        and values it's classes names.
    """
    def __init__(self, id2label: dict[int, str]):
        self._id2label = id2label

    def __get_target_names(self, preds: np.ndarray) -> list[str]:
        unique_classes = np.unique(preds)
        return [
            self._id2label[idx]
            for idx in unique_classes
        ]

    def _compute_metrics(self, labels: np.ndarray, preds: np.ndarray) -> dict:
        target_names = self.__get_target_names(preds)
        return classification_report(
            labels, preds,
            output_dict=True,
            target_names=target_names
        )

    def _get_logs(self, metrics: dict) -> dict:
        train_logs = {}
        for main_key, main_item in metrics.items():
            if main_key.isdigit():
                main_key = self._id2label[int(main_key)]
            if isinstance(main_item, dict):
                main_item.pop("support", None)
                for key, value in main_item.items():
                    train_logs[f"{main_key}/{key}"] = value
            else:
                train_logs[main_key] = main_item
        return train_logs

    def __call__(self, pred: EvalPrediction) -> dict[str, float]:
        global predx
        predx = pred
        labels = pred.label_ids
        logits = pred.predictions
        preds = logits.argmax(-1)
        metrics = self._compute_metrics(labels, preds)
        logs = self._get_logs(metrics)
        return logs

## Model class

In [74]:
class Papaya(
    Module,
    PyTorchModelHubMixin
):
    def __init__(self, 
                 num_labels: Optional[int] = len(id2label), 
                 ):
        super().__init__()
        self.num_labels = num_labels
        print(self.num_labels)
        self.backbone = AutoModel.from_pretrained(MODEL_NAME)
        
        # Classification head
        self.classification_head = Linear(self.backbone.config.hidden_size, num_labels)
        self.loss = CrossEntropyLoss()
        
    def forward(self, 
                pixel_values: torch.Tensor, 
                labels: Optional[torch.Tensor] = None
                ) -> dict[str, torch.Tensor]:
        outputs = self.backbone(pixel_values)
        cls_embedding = outputs['pooler_output']
    
        logits = self.classification_head(cls_embedding)
        
        if labels is not None:
            loss = self.loss(logits, labels)
            return {
                "loss": loss, 
                "logits": logits
            }
        return {
            "logits": logits
        }

## Load data and model

In [84]:
model = Papaya()
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)

3


Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-tiny-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [85]:
df = pd.read_csv(DATA)
label2id = {a: b for b, a in enumerate(df.label.unique())}
id2label = {a: b for a, b in enumerate(df.label.unique())}
print(id2label, label2id)
df.label = df.label.map(lambda x: label2id[x])
df.head()

{0: 'Anthracnose', 1: 'fruit_fly', 2: 'healthy_guava'} {'Anthracnose': 0, 'fruit_fly': 1, 'healthy_guava': 2}


Unnamed: 0,file_name,label,cluster,path
0,101_unsharp_clahe_augmented_5.png,0,test,data\test\Anthracnose\101_unsharp_clahe_augmen...
1,103_unsharp_clahe_augmented_7.png,0,test,data\test\Anthracnose\103_unsharp_clahe_augmen...
2,107_unsharp_clahe_augmented_3.png,0,test,data\test\Anthracnose\107_unsharp_clahe_augmen...
3,107_unsharp_clahe_augmented_6.png,0,test,data\test\Anthracnose\107_unsharp_clahe_augmen...
4,108_unsharp_clahe_augmented_5.png,0,test,data\test\Anthracnose\108_unsharp_clahe_augmen...


In [86]:
train_df = DataSet(df.query('cluster == "train"').reset_index(), processor)

val_df = DataSet(df.query('cluster == "val"').reset_index(), processor)

test_df = DataSet(df.query('cluster == "test"').reset_index(), processor)

In [87]:
metrica = ClassificationMetrica(id2label)

## Training

In [88]:
lr = 2e-5

training_args = TrainingArguments(
    output_dir=f"models/{MODEL_NAME}",
    num_train_epochs=3, 
    per_device_train_batch_size=128, 
    per_device_eval_batch_size=128,
    weight_decay=0.015,
    logging_dir=None,
    learning_rate=lr,
    eval_strategy="steps", 
    logging_strategy="steps",
    eval_steps=20,
    logging_steps=20,
    lr_scheduler_type="cosine",
    save_strategy="steps",
    save_total_limit=30,
    seed=SEED,
    report_to=None, 
    warmup_ratio=0.1
)

In [89]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_df,
    eval_dataset=val_df,
    compute_metrics=metrica
)

trainer.train()

Step,Training Loss,Validation Loss,Anthracnose/precision,Anthracnose/recall,Anthracnose/f1-score,Fruit Fly/precision,Fruit Fly/recall,Fruit Fly/f1-score,Healthy Guava/precision,Healthy Guava/recall,Healthy Guava/f1-score,Accuracy,Macro avg/precision,Macro avg/recall,Macro avg/f1-score,Weighted avg/precision,Weighted avg/recall,Weighted avg/f1-score
20,0.772,0.391958,0.930728,0.990548,0.959707,0.76875,0.931818,0.842466,0.9375,0.319149,0.47619,0.896689,0.878993,0.747172,0.759454,0.903252,0.896689,0.87901
40,0.2792,0.169984,0.983271,1.0,0.991565,0.86755,0.992424,0.925795,0.984848,0.691489,0.8125,0.960265,0.945223,0.894638,0.909953,0.963236,0.960265,0.957772
60,0.1642,0.132089,0.990637,1.0,0.995296,0.929577,1.0,0.963504,1.0,0.840426,0.913295,0.980132,0.973405,0.946809,0.957365,0.981127,0.980132,0.979528


TrainOutput(global_step=63, training_loss=0.39280575040786986, metrics={'train_runtime': 87.6114, 'train_samples_per_second': 90.639, 'train_steps_per_second': 0.719, 'total_flos': 0.0, 'train_loss': 0.39280575040786986, 'epoch': 3.0})

## Summary

We got high **precision** for healthy guava and high **recall** for else classes. Lets check the likage (It\`s best to do at the begining, but i don\`t get paid so didn`t bother)

In [92]:
df.duplicated('file_name').sum()

1624

In [94]:
model.backbone

ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0-11): 12 x ViTLayer(
        (attention): ViTSdpaAttention(
          (attention): ViTSdpaSelfAttention(
            (query): Linear(in_features=192, out_features=192, bias=True)
            (key): Linear(in_features=192, out_features=192, bias=True)
            (value): Linear(in_features=192, out_features=192, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=192, out_features=192, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=192, out_features=768, bias=True)
          (intermediate_act_fn): GELUActi