https://www.kaggle.com/datasets/andrewmvd/animal-faces

### Импорт всех необходимых библиотек

In [None]:
import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torch.optim import Adam
from torchvision import transforms
from torchvision.models import vgg16
import matplotlib.pyplot as plt
from IPython.display import clear_output
from tqdm import tqdm
import numpy as np
from sklearn.metrics import classification_report
from torch.optim.lr_scheduler import ReduceLROnPlateau
from PIL import Image

### Проверка доступных устройств и объявление собственной свёрточной сети MyCNN (предобученная VGG16 здесь не используется)

In [None]:
# Report whether the Apple Metal (MPS) backend can be used on this machine.
# `torch.has_mps` is deprecated; the supported API lives in `torch.backends.mps`
# (`is_built()` is the compile-time check, `is_available()` also verifies that
# a usable device is present, mirroring the CUDA check in the next cell).
torch.backends.mps.is_available()

In [None]:
# Report whether PyTorch can see a usable CUDA device; the same check is used
# later to choose the device the model is moved to.
torch.cuda.is_available()

In [None]:
class MyCNN(nn.Module):
    """Small convolutional classifier that outputs three raw class scores.

    The network applies three ``conv -> max-pool -> ReLU`` stages, a fourth
    convolution, and a three-layer fully connected head (256 -> 64 -> 16 -> 3).
    Dimensions 2 and 3 of the last convolution's output are squeezed before the
    head, so that output must be spatially 1x1; with the layer sizes used here
    this is the case for 3x32x32 inputs.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: channel count grows 3 -> 32 -> 64 -> 128 -> 256.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=2, padding=1)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=2, padding=1)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=0)
        # Classifier head.
        self.linear1 = nn.Linear(256, 64)
        self.linear2 = nn.Linear(64, 16)
        self.linear3 = nn.Linear(16, 3)
        # A single stateless activation module shared by every layer.
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return raw class scores (no softmax is applied) for the batch ``x``."""
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in stages:
            x = self.relu(pool(conv(x)))

        features = self.conv4(x)
        # Drop the two singleton spatial dimensions: (N, 256, 1, 1) -> (N, 256).
        features = features.squeeze(2)
        features = features.squeeze(2)

        hidden = self.relu(self.linear1(features))
        hidden = self.relu(self.linear2(hidden))
        return self.linear3(hidden)

In [None]:
# Build an instance of the custom CNN defined in the previous cell.
model = MyCNN()

In [None]:
# Prefer the first CUDA GPU when one is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Move the model's parameters to the selected device.
model = model.to(device)

### Ресайз, нормализация и аугментация

In [None]:
# Both pipelines convert the image to a tensor, resize it to exactly 32x32 and
# normalise each channel (the mean/std values are the standard ImageNet channel
# statistics).
#
# The size is given as a (height, width) pair on purpose: Resize(32) would only
# set the shorter edge to 32 and keep the aspect ratio, so a non-square image
# would yield a tensor that MyCNN cannot reduce to a 1x1 feature map.  For
# square images the result is the same as with Resize(32).

# Training pipeline: adds light augmentation (small random rotation and
# horizontal flip) after normalisation.
my_transform_train = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((32, 32)),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225)),
        transforms.RandomRotation((-10, 10)),
        transforms.RandomHorizontalFlip(p=0.5)
    ]
)
# Evaluation pipeline: deterministic, no augmentation.
my_transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((32, 32)),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225))
    ]
)

### Чтение файлов в датасет и применение трансформаций

In [None]:
# Build the datasets from the 'train' and 'val' directories (relative to the
# working directory).  The training split gets the augmenting pipeline, the
# validation split the deterministic one.
train_dataset = ImageFolder(root='train', transform=my_transform_train)
test_dataset = ImageFolder(root='val', transform=my_transform_test)

### Объявление даталоадеров

In [None]:
batch_size = 512
# Training loader: reshuffle every epoch and drop the incomplete last batch so
# that every optimisation step sees a full batch.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=12, pin_memory=True)
# Validation loader: keep every sample and a fixed order.  With drop_last=True
# up to batch_size - 1 validation images were silently excluded from the loss
# each epoch, and shuffling made the excluded subset change between epochs,
# adding noise to the value the LR scheduler monitors.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=12, pin_memory=True)

In [None]:
# Sanity check: run one training batch through the network and show the output
# shape.  The module is called directly (rather than via .forward) so that any
# registered hooks run, and autograd is disabled because no gradients are
# needed for a shape check.
with torch.no_grad():
    sample_logits = model(next(iter(train_dataloader))[0].to(device))
sample_logits.shape

In [None]:
# Training hyper-parameters: number of passes over the training set and the
# initial learning rate handed to Adam.
num_epochs = 16
lr = 1e-3
optimizer = Adam(model.parameters(), lr=lr)
# Use a scheduler to reduce the learning rate: it is multiplied by `factor`
# when the value passed to scheduler.step() (the validation loss in the
# training loop) stops improving; `patience` and `threshold` control how
# quickly that happens.
scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=2, threshold=1e-3)

### Объявление loss функции

In [None]:
def loss_function(preds, true):
    """Return the batch-mean cross-entropy loss.

    Args:
        preds: Raw (un-normalised) class scores, one row per sample.
        true: Target class indices, one per sample.

    Returns:
        A scalar tensor holding the cross-entropy averaged over the batch.
    """
    # Functional form with its defaults (mean reduction, no class weights),
    # i.e. the same computation as a default-constructed nn.CrossEntropyLoss.
    return nn.functional.cross_entropy(preds, true)

In [None]:
# Per-epoch loss history; the training loop appends one value per epoch to
# each list and plots both curves.
all_train_losses = []
all_test_losses = []

### Обучение и визуализация процесса обучения

In [None]:
# Train for `num_epochs` epochs.  Each epoch runs one optimisation pass over
# the training loader and one evaluation pass over the validation loader, then
# feeds the validation loss to the LR scheduler and redraws the loss curves.
for epoch in range(num_epochs):
    # Running means of the per-batch losses for this epoch.
    train_loss = 0
    test_loss = 0

    # ---- training pass ----
    model.train()
    print(f'epoch_number is {epoch}. Train')
    for (X, y) in tqdm(train_dataloader):
        optimizer.zero_grad()
        X = X.to(device)
        y = y.to(device)
        preds = model(X)
        loss = loss_function(preds, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() / len(train_dataloader)

    # ---- validation pass ----
    model.eval()
    print(f'epoch_number is {epoch}. Test')
    # No gradients are needed here; disabling autograd avoids building a graph
    # for every validation batch, which only costs memory and time.
    with torch.no_grad():
        for (X, y) in tqdm(test_dataloader):
            X = X.to(device)
            y = y.to(device)
            preds = model(X)
            loss = loss_function(preds, y)
            test_loss += loss.item() / len(test_dataloader)

    # The scheduler monitors the validation loss.
    scheduler.step(test_loss)

    # Replace the previous epoch's output with the updated numbers and plot.
    clear_output()
    all_train_losses.append(train_loss)
    all_test_losses.append(test_loss)
    print('loss train', train_loss)
    print('loss test', test_loss)
    plt.figure(figsize=(10, 6))
    plt.plot(all_train_losses, label='Train loss', color='blue')
    plt.plot(all_test_losses, label='Val loss', color='orange')
    plt.legend()
    plt.ylabel('Loss')
    plt.grid()
    plt.show()

### Вывод метрик модели

In [None]:
# Loader for the final metrics: fixed order and no dropped samples, so every
# validation image contributes exactly once to the report.
final_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

In [None]:
# Collect the raw class scores and the true labels for the whole validation
# set, one array per batch.
preds = []
real = []
# Make the cell self-contained: evaluation mode, and no autograd graph, since
# only the forward outputs are needed.
model.eval()
with torch.no_grad():
    for x, y in final_dataloader:
        preds.append(model(x.to(device)).cpu().numpy())
        real.append(y.numpy())

In [None]:
# Stack the per-batch score arrays and take the highest-scoring class per row.
final_preds = np.concatenate(preds).argmax(axis=1)
# Flatten the per-batch label arrays into one array.  NOTE: this rebinds
# `real`, so the cell is meant to run once after the collection cell above.
real = np.concatenate(real)

In [None]:
# Print scikit-learn's text report comparing the true labels with the
# predicted classes.
print(classification_report(real, final_preds))

### Сохранение модели

In [None]:
# Persist only the model's state dict (its parameters) to 'model.pth' in the
# working directory.
torch.save(model.state_dict(), 'model.pth')

### Предсказание на тесте (вообще функция подойдёт для любой картинки, нужно лишь указать путь до неё)

In [None]:
# Display label printed for each predicted class index.
# NOTE(review): presumably indices follow the dataset's class order
# (cat, dog, wild) — verify against train_dataset.class_to_idx.
dict_names = {
    0: 'Кошка',
    1: 'Собака',
    2: 'Дикая'
}

In [None]:
# Switch the model to evaluation mode before single-image predictions.
model.eval()

In [None]:
def predict_by_path(path):
    """Classify the image at *path*, display it, and print the predicted label.

    The image is preprocessed with the deterministic evaluation pipeline
    (``my_transform_test``), passed through ``model``, and the label of the
    highest-scoring class is looked up in ``dict_names`` and printed.

    Args:
        path: Location of an image file readable by PIL.

    Returns:
        None. The image is drawn with matplotlib and the label is printed.
    """
    # Close the file handle promptly and force three channels: grayscale or
    # RGBA files would not match the 3-channel normalisation and first conv.
    with Image.open(path) as source_image:
        image = source_image.convert('RGB')
    image_np = np.array(image)

    # Use the evaluation transform: the training one applies random rotation
    # and flipping, which would make predictions non-deterministic.
    batch = my_transform_test(image_np).unsqueeze(0).to(device)

    # Inference only, so no autograd graph is needed.
    with torch.no_grad():
        pred = model(batch).cpu().numpy()[0]

    plt.imshow(image)
    res = int(np.argmax(pred))
    print(dict_names[res])

In [None]:
# Example: classify one image from the validation 'cat' folder.
predict_by_path('val/cat/flickr_cat_000008.jpg')

In [None]:
# Example: classify one image from the validation 'dog' folder.
predict_by_path('val/dog/flickr_dog_000060.jpg')

In [None]:
# Example: classify one image from the validation 'wild' folder.
predict_by_path('val/wild/flickr_wild_000470.jpg')