# Обучение модели AIImageDetector

## Настройка окружения

In [None]:
# @title  **Check that CUDA is available**
import torch

# Report whether PyTorch can see a CUDA device; the training cell further
# down requests device="cuda".
if torch.cuda.is_available():
    cuda_status = "Cuda доступен"
else:
    cuda_status = "Cuda НЕ ДОСТУПЕН"
print(cuda_status)

In [None]:
# @title  **Install the project**
# Install python-dotenv before fetching the project sources.
!pip install python-dotenv
# Clone the repository, enter it and switch to the `research` branch.
# NOTE: the paths below are relative to the current working directory, so
# this cell is meant to be run once from the notebook's starting directory.
!git clone https://github.com/Ycalk/AIImageDetector.git
%cd AIImageDetector/
!git checkout research
# Make the research sources the working directory; the later `import utils`
# presumably resolves from here — TODO confirm.
%cd additional/research/src/research

In [None]:
# @title Set environment variables
import os

kaggle_username = ""  # @param {type:"string"}
kaggle_key = ""  # @param {type:"string"}

# Export only the credentials that were actually filled in. Writing an empty
# string would overwrite a value that is already present in the environment
# with a blank one, so blank form fields are skipped. Surrounding whitespace
# (easy to pick up when pasting a key) is stripped before exporting.
for _env_name, _form_value in (
    ("KAGGLE_USERNAME", kaggle_username),
    ("KAGGLE_KEY", kaggle_key),
):
    _cleaned = _form_value.strip()
    if _cleaned:
        os.environ[_env_name] = _cleaned

## Обучение модели

In [None]:
# @title Training

import utils
from matplotlib import pyplot as plt
import seaborn as sns

# NOTE: each form field keeps its value and its `# @param` annotation on ONE
# line. Colab's form parser matches `name = value  # @param {...}` per line;
# when a formatter wraps the value into parentheses the annotation is
# separated from the assignment and the field is no longer rendered. The
# `fmt: off/on` markers stop the formatter from re-wrapping the long lines.
# fmt: off

# @markdown Model parameters
conv_blocks_number = 4  # @param {type:"integer", default: 4}
base_channels = 64  # @param {type:"integer", default: 64}
dropout_probability = 0.3  # @param {type:"number", default: 0.3, min:0, max:1, step:0.001}
classifier_layers_count = 2  # @param {type:"integer", default: 2}

# @markdown Dataset parameters
size = 200000  # @param {type:"integer", default: 100000}
images_ratio = 0.5  # @param {type:"number", default: 0.5, min:0, max:1, step:0.001}
ratio = 0.7  # @param {type:"number", default: 0.7, min:0, max:1, step:0.01}

# @markdown Training parameters
learning_rate = 0.001  # @param {type:"number", default: 0.001, min:0, max:1, step:0.0001}
epochs = 20  # @param {type:"integer", default: 10}
batch_size = 16  # @param {type:"integer", default: 32}
weight_decay = 0.001  # @param {type:"number", default: 0.001, min:0, max:1, step:0.0001}

# fmt: on

# Build the network from the model form fields.
model = utils.Model(
    conv_blocks_number=conv_blocks_number,
    base_channels=base_channels,
    dropout_probability=dropout_probability,
    classifier_layers_count=classifier_layers_count,
)

# Build the dataset from the dataset form fields.
dataset = utils.ArtiFactDataset.get_merged_dataset(
    size=size, ratio=ratio, images_ratio=images_ratio
)

# Configure the trainer; it targets the "cuda" device and is given the
# "checkpoints" directory for checkpoints.
trainer = utils.Trainer(
    model=model,
    dataset=dataset,
    epochs=epochs,
    batch_size=batch_size,
    lr=learning_rate,
    device="cuda",
    weight_decay=weight_decay,
    checkpoint_dir="checkpoints",
)

# Run training. `result` is read by the reporting cells below, which index it
# and plot its entries against an "Epoch" axis (presumably one entry per
# epoch — TODO confirm).
result = trainer()

# Switch the model to evaluation mode for the inference and visualisation
# cells. As the last expression of the cell, its return value is displayed.
model.eval()


## Получение результатов

In [None]:
# @title Numerical results

import time

import torch

# Final metrics are taken from the last entry of `result`.
print(f"Точность на тренировочной выборке: {result[-1].train_accuracy:.2f}")
print(f"Точность на валидационной выборке: {result[-1].val_accuracy:.2f}")

# Single random input used only to time one forward pass.
in_tensor = torch.rand(1, 3, 256, 256).to("cuda")

# Gradients are not needed for timing, so skip building the autograd graph.
with torch.no_grad():
    # Warm-up pass: the first CUDA call pays one-off initialisation costs
    # that would otherwise dominate the measurement.
    model(in_tensor)
    # CUDA kernels are launched asynchronously; without synchronising, the
    # timer would only measure the kernel launch, not the computation.
    torch.cuda.synchronize()

    start_time = time.perf_counter()
    model(in_tensor)
    torch.cuda.synchronize()
    inference_seconds = time.perf_counter() - start_time

print(f"Время инференса: {inference_seconds:.3f}")

In [None]:
# @title Loss and Accuracy curves

plot_title = ""  # @param {type:"string"}

_, axes = plt.subplots(1, 2, figsize=(14, 5))

# One panel per metric: (name used for the y-label and panel title,
# train legend label, train series, validation legend label, validation series).
panels = (
    (
        "Accuracy",
        "Train Accuracy",
        [entry.train_accuracy for entry in result],
        "Validation Accuracy",
        [entry.val_accuracy for entry in result],
    ),
    (
        "Loss",
        "Train Loss",
        [entry.train_loss for entry in result],
        "Validation Loss",
        [entry.val_loss for entry in result],
    ),
)

for ax, (metric, train_label, train_series, val_label, val_series) in zip(
    axes, panels
):
    ax.plot(train_series, label=train_label)
    ax.plot(val_series, label=val_label)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(metric)
    ax.set_title(metric)
    ax.legend()

plt.suptitle(plot_title)
plt.tight_layout()
plt.show()

In [None]:
# @title Train and Validation confusion matrices

plot_title = ""  # @param {type:"string"}

_, axes = plt.subplots(1, 2, figsize=(14, 6))

# Both matrices come from the last entry of `result`.
last_entry = result[-1]
heatmaps = (
    (last_entry.train_cm, "Blues", "Train Confusion Matrix"),
    (last_entry.val_cm, "Greens", "Validation Confusion Matrix"),
)

# Draw each matrix on its own axis with integer cell annotations.
for ax, (matrix, palette, heading) in zip(axes, heatmaps):
    sns.heatmap(matrix, annot=True, fmt="d", cmap=palette, ax=ax)
    ax.set_title(heading)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")

plt.suptitle(plot_title)
plt.tight_layout()
plt.show()

In [None]:
# @title Grad-CAM for a generated and a real image

plot_title = ""  # @param {type:"string"}

# Two single-image datasets: ratio=0.0 is used for the generated sample and
# ratio=1 for the real one.
generated_dataset = utils.ArtiFactDataset(images_count=1, ratio=0.0)
real_dataset = utils.ArtiFactDataset(images_count=1, ratio=1)

# Only the image is needed; the second item of each sample is discarded.
generated_image = generated_dataset[0][0]
real_image = real_dataset[0][0]

# Each GradCam object is built and invoked in a single expression, generated
# sample first. The call yields the picture to display and a prediction score.
generated_result, generated_score = utils.GradCam(
    model.to("cuda"),
    model.features[-1][0],
    generated_image,
)()
real_result, real_score = utils.GradCam(
    model.to("cuda"),
    model.features[-1][0],
    real_image,
)()

fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# Left panel: generated sample, right panel: real sample.
overlays = (
    (generated_result, f"Generated\nPred = {generated_score:.3f} (Label: 1)"),
    (real_result, f"Real\nPred = {real_score:.3f} (Label: 0)"),
)
for ax, (overlay, caption) in zip(axes, overlays):
    ax.imshow(overlay)
    ax.axis("off")
    ax.set_title(caption)

plt.suptitle(plot_title)
plt.tight_layout()
plt.show()
