## Библиотеки

In [2]:
!pip install evaluate -q

In [20]:
import os
import random
import datasets
import cv2
import json
import pandas as pd
from pathlib import Path
from collections import defaultdict
from PIL import Image

import torch
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import top_k_accuracy_score

from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

from transformers.trainer_callback import EarlyStoppingCallback
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer
from datasets import Dataset as HFDataset, DatasetDict, load_dataset
from tqdm import tqdm
import evaluate
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image

In [4]:
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

data_dir = '/kaggle/input/stanford-dogs-dataset/images/Images'

## Препроцессинг

In [5]:
# Создание словарей для сопоставления ID и меток
ID2LABEL = {}
LABEL2ID = {}
for idx, image_filename in enumerate(os.listdir(data_dir)):
    if not image_filename.endswith('.xlsx'):
        label = image_filename.split('-')[1].lower()
        ID2LABEL[idx] = label
        LABEL2ID[label] = idx
NUM_LABELS = len(ID2LABEL)
print(f"NUM_LABELS: {NUM_LABELS}\n")

NUM_LABELS: 120



In [6]:
# Функция для загрузки изображений и меток
def load_images(data_dir):
    categories = os.listdir(data_dir)
    images = []
    labels = []

    for category in tqdm(categories):
        category_path = os.path.join(data_dir, category)
        all_images = os.listdir(category_path)

        for image_name in all_images:
            image_path = os.path.join(category_path, image_name)
            label = category.split('-')[1].lower()

            images.append(image_path)
            labels.append(label)

    return images, labels

In [7]:
# Загрузка путей к изображениям и меток
images, labels = load_images(data_dir)

100%|██████████| 120/120 [00:02<00:00, 42.45it/s]


In [8]:
train_paths, val_paths, train_labels, val_labels = train_test_split(
    images, 
    labels, 
    test_size=0.2, 
    random_state=SEED, 
    stratify=labels
)

In [9]:
# Создание объекта Dataset
train_dataset = datasets.Dataset.from_dict(
    mapping = {'image': train_paths, 'labels': train_labels,},
    features = datasets.Features({
        'image': datasets.Image(),
        'labels': datasets.features.ClassLabel(names=list(LABEL2ID.keys())),
    })
)
# Создание объекта Dataset
val_dataset = datasets.Dataset.from_dict(
    mapping = {'image': val_paths, 'labels': val_labels,},
    features = datasets.Features({
        'image': datasets.Image(),
        'labels': datasets.features.ClassLabel(names=list(LABEL2ID.keys())),
    })
)

In [10]:
# Инициализация процессора изображений
processor = ViTImageProcessor.from_pretrained("google/vit-large-patch32-384")
# Загрузка модели ViT для классификации изображений
model = ViTForImageClassification.from_pretrained(
    'google/vit-large-patch32-384',
    num_labels=len(LABEL2ID),
    id2label=ID2LABEL,
    label2id=LABEL2ID,
    ignore_mismatched_sizes=True
)

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.23G [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-large-patch32-384 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 1024]) in the checkpoint and torch.Size([120, 1024]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([120]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
def train_transforms(example):
    # Применение процессора ко всем изображениям в примере
    images = [image.convert('RGB') for image in example['image']]
    example.update(processor(images, return_tensors='pt'))
    return example

def eval_transforms(example):
    # Применение процессора ко всем изображениям в примере
    images = [image.convert('RGB') for image in example['image']]
    example.update(processor(images, return_tensors='pt'))
    return example

In [12]:
# Применение трансформаций к данным
train_dataset.set_transform(train_transforms)
val_dataset.set_transform(eval_transforms)

## Обучение

In [13]:
# Инициализация метрик для оценки
accuracy_metric = evaluate.load('accuracy')
f1_metric = evaluate.load('f1')

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/6.77k [00:00<?, ?B/s]

In [14]:
# Функция для вычисления метрик
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    accuracy = accuracy_metric.compute(predictions=preds, references=labels)
    top3_accuracy = top_k_accuracy_score(labels, logits, k=3)
    top5_accuracy = top_k_accuracy_score(labels, logits, k=5)
    f1 = f1_metric.compute(predictions=preds, references=labels, average='macro')
    
    return {
        **accuracy,
        'top3_accuracy': top3_accuracy,
        'top5_accuracy': top5_accuracy,
        **f1,
    }

In [15]:
# Функция для объединения данных в батчи
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

In [16]:
# Аргументы для тренировки модели
training_args = TrainingArguments(
    seed=SEED,
    output_dir='./results',
    optim='adamw_torch',
    num_train_epochs=30,
    gradient_accumulation_steps=4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    save_strategy='epoch',
    eval_strategy='epoch',
    load_best_model_at_end=True,
    save_total_limit=2,
    report_to='wandb',
    learning_rate=1e-5,
    remove_unused_columns=False,
    metric_for_best_model='accuracy',
    fp16=True,
    lr_scheduler_type='cosine',  # Косинусное уменьшение lr
    warmup_ratio=0.1,  # Warmup для lr
    dataloader_pin_memory=True,

)

In [17]:
# Инициализация объекта Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
    tokenizer=processor,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)]
)

ValueError: Tried to use `fp16` but it is not supported on cpu

In [18]:
#trainer.train()

In [59]:
with open("/kaggle/input/training-state/trainer_state.json", "r") as file:
    trainer_state = json.load(file)

# Извлечение истории метрик
history = trainer_state['log_history']

# Инициализация списка для хранения данных
data = []

# Проход по всей истории и сбор метрик для каждой эпохи
for entry in history:
    row = {
        "Validation Loss": entry.get('eval_loss', None),
        "Validation Accuracy": entry.get('eval_accuracy', None),
        "Validation F1 Score": entry.get('eval_f1', None),
        "Validation Top-3 Accuracy": entry.get('eval_top3_accuracy', None),
        "Validation Top-5 Accuracy": entry.get('eval_top5_accuracy', None)
    }
    data.append(row)

# Создание DataFrame из списка
df_metrics = pd.DataFrame(data)

# Фильтрация столбцов для вывода только необходимых метрик
df_metrics_filtered = df_metrics[[
    "Validation Loss", 
    "Validation Accuracy", 
    "Validation Top-3 Accuracy", 
    "Validation Top-5 Accuracy", 
    "Validation F1 Score"
]]
df_metrics_filtered = df_metrics_filtered.dropna(subset=['Validation Loss'], axis=0).reset_index(drop=True)

In [58]:
df_metrics_filtered

Unnamed: 0,Validation Loss,Validation Accuracy,Validation Top-3 Accuracy,Validation Top-5 Accuracy,Validation F1 Score
0,4.544873,0.030612,0.08139,0.12585,0.022641
1,3.117323,0.661808,0.851069,0.906463,0.628339
2,1.23854,0.872935,0.972546,0.983965,0.866364
3,0.56788,0.908163,0.98275,0.98931,0.905838
4,0.406856,0.910107,0.982507,0.989796,0.907745
5,0.360546,0.910836,0.981293,0.989067,0.907947
6,0.34863,0.914966,0.98105,0.988338,0.91177
