In [None]:
# classification
# https://grok.com/chat/94d35e69-f6b6-4402-9107-14616814fa5e

import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    ViTImageProcessor,
)

# --- Configuration ---
# Where the data lives and which pretrained checkpoint to fine-tune.
data_dir = "dataset"  # dataset root; the train/ and val/ splits are read from here
model_name = "google/vit-base-patch16-224"  # pretrained ViT checkpoint to start from

# Training hyperparameters.
num_epochs = 5
batch_size = 16
learning_rate = 2e-5
num_classes = 3  # set this to the number of classes in your dataset

# Train on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# 1. Data preparation
# Only the normalization statistics (image_mean / image_std) are taken from the
# checkpoint's image processor. ViTImageProcessor replaces the deprecated
# ViTFeatureExtractor; the module-level name `feature_extractor` is kept so
# any later code that refers to it keeps working.
feature_extractor = ViTImageProcessor.from_pretrained(model_name)

# Training transforms: resize, light augmentation, then tensor + normalization.
train_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # this ViT checkpoint takes 224x224 inputs
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=feature_extractor.image_mean, std=feature_extractor.image_std
        ),
    ]
)

# Validation transforms: same preprocessing, no random augmentation.
val_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=feature_extractor.image_mean, std=feature_extractor.image_std
        ),
    ]
)

# Load the splits; ImageFolder treats each sub-directory as one class.
train_dataset = datasets.ImageFolder(
    os.path.join(data_dir, "train"), transform=train_transforms
)
val_dataset = datasets.ImageFolder(
    os.path.join(data_dir, "val"), transform=val_transforms
)

# ImageFolder numbers classes from each split's own sorted folder names, so a
# split with a missing or extra class folder would silently shift the labels.
if val_dataset.class_to_idx != train_dataset.class_to_idx:
    raise ValueError(
        "Class folders differ between the train and val splits: "
        f"{train_dataset.classes} vs {val_dataset.classes}"
    )
# The classifier head is sized by num_classes; a mismatch either leaves unused
# outputs or makes CrossEntropyLoss fail on out-of-range labels.
if len(train_dataset.classes) != num_classes:
    raise ValueError(
        f"num_classes={num_classes}, but {data_dir!r} contains "
        f"{len(train_dataset.classes)} class folders: {train_dataset.classes}"
    )

train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=4
)
val_loader = DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, num_workers=4
)

# 2. Load the pretrained model and move it to the target device.
# ignore_mismatched_sizes=True lets loading succeed even though the checkpoint's
# classifier head does not have num_classes outputs.
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=num_classes,
    ignore_mismatched_sizes=True,
).to(device)

# 3. Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
# AdamW over all model parameters, i.e. the whole network is fine-tuned.
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)


# 4. Training routine
def _loss_and_accuracy(loss_sum, correct, total):
    """Return (mean per-sample loss, accuracy in percent); (0.0, 0.0) when total is 0."""
    if total == 0:
        # An empty loader would otherwise raise ZeroDivisionError.
        return 0.0, 0.0
    return loss_sum / total, 100 * correct / total


def _evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` in eval mode without tracking gradients.

    Returns:
        ``(mean per-sample loss, accuracy in percent)`` over the whole loader.
    """
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images).logits
            n = labels.size(0)
            # Weight by batch size so a short final batch does not skew the mean.
            loss_sum += criterion(logits, labels).item() * n
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += n
    return _loss_and_accuracy(loss_sum, correct, total)


def train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    num_epochs,
    device=None,
    checkpoint_path="best_vit_model.pth",
):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    After every epoch the training and validation loss/accuracy are printed.
    The first epoch, and afterwards every epoch that strictly improves the
    best validation accuracy seen so far, writes ``model.state_dict()`` to
    ``checkpoint_path``.

    Args:
        model: Module whose forward pass returns an object with a ``.logits``
            attribute (as ``ViTForImageClassification`` does).
        train_loader: Iterable of ``(images, labels)`` batches for training.
        val_loader: Iterable of ``(images, labels)`` batches for validation.
        criterion: Loss over ``(logits, labels)``. Assumed to use mean
            reduction, since batch losses are re-weighted by batch size.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
        device: Device batches are moved to. Defaults to the device of the
            model's first parameter.
        checkpoint_path: File the best ``state_dict`` is saved to.

    Returns:
        The best validation accuracy in percent, or ``None`` if no epoch ran.
    """
    if device is None:
        device = next(model.parameters()).device

    best_acc = None
    for epoch in range(num_epochs):
        # Training pass: train-mode layers (e.g. dropout) active, gradients tracked.
        model.train()
        loss_sum, correct, total = 0.0, 0, 0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            n = labels.size(0)
            # Weight by batch size so a short final batch does not skew the mean.
            loss_sum += loss.item() * n
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += n

        epoch_loss, epoch_acc = _loss_and_accuracy(loss_sum, correct, total)
        print(f"Train Loss: {epoch_loss:.4f}, Train Accuracy: {epoch_acc:.2f}%")

        val_loss, val_acc = _evaluate(model, val_loader, criterion, device)
        print(f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.2f}%")

        # Always checkpoint the first epoch so a best model file exists even if
        # validation accuracy never rises above 0%.
        if best_acc is None or val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Saved best model with Val Accuracy: {best_acc:.2f}%")

    return best_acc


# 5. Run training, then 6. save the final weights.
# The guard stops DataLoader worker processes (num_workers=4) from re-running
# training when the "spawn" start method (the default on Windows and macOS)
# re-imports this file as a script. In a notebook __name__ is "__main__", so
# behavior there is unchanged.
if __name__ == "__main__":
    train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs)

    # Weights after the last epoch; they may differ from the best checkpoint
    # that train_model saved during training.
    torch.save(model.state_dict(), "final_vit_model.pth")
    print("Training completed!")