### 1. Загрузка данных

In [None]:
from datasets import load_dataset
import os

In [None]:
# Root of the exported image dataset and its three split sub-folders.
DATASET_PATH = '../dataset'
TRAIN_SUBSET_PATH, VALIDATION_SUBSET_PATH, TEST_SUBSET_PATH = (
    os.path.join(DATASET_PATH, split_name)
    for split_name in ('train', 'validation', 'test')
)

In [None]:
# Create every split folder up front; exist_ok=True makes re-running this cell harmless.
for subset_dir in (TRAIN_SUBSET_PATH, VALIDATION_SUBSET_PATH, TEST_SUBSET_PATH):
    os.makedirs(subset_dir, exist_ok=True)

In [None]:
# Fetch the Hugging Face "Bingsu/Cat_and_Dog" dataset, copying it into memory.
dataset = load_dataset(path='Bingsu/Cat_and_Dog', keep_in_memory=True)

In [None]:
# Write every training sample to disk as "train_<index>.<cat|dog>.jpeg".
# Label 0 is treated as "cat", any other label as "dog"; the class name embedded
# in the file name is what DatasetWrapper later derives the label from.
for idx, record in enumerate(dataset['train']):
    label_name = 'dog' if record['labels'] != 0 else 'cat'
    target_path = os.path.join(TRAIN_SUBSET_PATH, f'train_{idx}.{label_name}.jpeg')
    record['image'].save(target_path)

In [None]:
# Divide the original "test" split into validation and test halves.
# Samples are first sorted by label (sorted() is stable), then even positions go
# to validation and odd positions to test, so both halves get a similar class mix.
ordered_samples = sorted(dataset['test'], key=lambda x: x['labels'])
for idx, record in enumerate(ordered_samples):
    label_name = 'dog' if record['labels'] != 0 else 'cat'
    if idx % 2:
        target_path = os.path.join(TEST_SUBSET_PATH, f'test_{idx}.{label_name}.jpeg')
    else:
        target_path = os.path.join(VALIDATION_SUBSET_PATH, f'validation_{idx}.{label_name}.jpeg')
    record['image'].save(target_path)

In [None]:
# Sanity check: report how many files were written to each split folder.
for title, subset_dir in (
    ('Train', TRAIN_SUBSET_PATH),
    ('Validation', VALIDATION_SUBSET_PATH),
    ('Test', TEST_SUBSET_PATH),
):
    print(f'{title} subset size: {len(os.listdir(subset_dir))}')

In [None]:
# Drop the in-memory Hugging Face dataset: the rest of the notebook reads only the
# image files exported to disk above, so its memory can be released (once nothing
# else references it).
del dataset

### 2. Подготовка датасетов

In [None]:
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image

In [None]:
# Preprocessing shared by all splits:
#   1. resize to a fixed 224x224 input (aspect ratio is not preserved),
#   2. convert the PIL image to a float tensor,
#   3. normalize each channel with the commonly used ImageNet mean/std.
preprocess = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

In [None]:
class DatasetWrapper(Dataset):
    """Map-style dataset over a folder of images exported by this notebook.

    Files are expected to be named ``<split>_<index>.<classname>.<ext>`` where
    ``<classname>`` is ``cat`` or ``dog``. The label is derived from that token:
    ``cat`` -> 0, anything else -> 1.

    Args:
        data_path: Directory containing the image files.
        preprocess: Transform applied to each RGB ``PIL.Image``.
    """

    def __init__(self, data_path: str, preprocess: transforms.Compose):
        super().__init__()
        # Sorted so the index -> file mapping is deterministic across runs.
        self._images = sorted(os.listdir(data_path))
        self._images_path = data_path
        self._preprocess = preprocess

    def __len__(self) -> int:
        return len(self._images)

    @staticmethod
    def _label_from_filename(filename: str) -> int:
        """Return 0 for a ``*.cat.<ext>`` file name and 1 otherwise."""
        # Inspect only the class token of the *file name*. Matching 'cat'
        # anywhere in the full path would label every image as a cat whenever
        # a directory name happens to contain 'cat' (e.g. '/home/location/...').
        stem = filename.rpartition('.')[0]
        classname = stem.rpartition('.')[2]
        return 0 if classname == 'cat' else 1

    def __getitem__(self, index: int):
        filename = self._images[index]
        image_path = os.path.join(self._images_path, filename)
        # Use a context manager so the file handle is closed promptly;
        # convert() loads the pixel data before the file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        image = self._preprocess(image)
        return (image, self._label_from_filename(filename))

In [None]:
# One dataset per split folder, all sharing the same preprocessing pipeline.
train_dataset, validation_dataset, test_dataset = (
    DatasetWrapper(subset_dir, preprocess)
    for subset_dir in (TRAIN_SUBSET_PATH, VALIDATION_SUBSET_PATH, TEST_SUBSET_PATH)
)

### 3. Подготовка загрузчиков данных (DataLoader)

In [None]:
from torch.utils.data import DataLoader 

In [None]:
# Batches of 128; only the training loader reshuffles every epoch, while the
# evaluation loaders keep a fixed order.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
validation_dataloader = DataLoader(dataset=validation_dataset, batch_size=128, shuffle=False)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)

### 4. Архитектура нейросети

In [None]:
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class NeuralNetwork(nn.Module):
    """Small CNN producing 2 class logits from a 3-channel image batch.

    Three conv -> ReLU -> 2x2 max-pool stages are followed by three fully
    connected layers. ``fc1``'s input size (64 * 6 * 6) assumes 224x224 inputs.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor. Spatial size for a 224x224 input:
        #   conv1 (k5, s2, p1): 224 -> 111, max-pool -> 55
        #   conv2 (k5, s2, p1):  55 ->  27, max-pool -> 13
        #   conv3 (k3, s1, p1):  13 ->  13, max-pool ->  6
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

        # Classifier head on the flattened 64 x 6 x 6 feature map.
        self.fc1 = nn.Linear(64 * 6 * 6, 500)
        self.fc2 = nn.Linear(500, 50)
        self.fc3 = nn.Linear(50, 2)

    def forward(self, X):
        """Return raw (un-normalized) logits of shape (batch, 2) for batch ``X``."""
        for conv in (self.conv1, self.conv2, self.conv3):
            X = F.max_pool2d(F.relu(conv(X)), 2)

        X = X.view(X.size(0), -1)
        X = F.relu(self.fc1(X))
        X = F.relu(self.fc2(X))
        return self.fc3(X)


### 5. Обучение модели

In [None]:
import torch
from torch import optim
import matplotlib.pyplot as plt

In [None]:
# Prefer the first CUDA GPU when one is available, otherwise fall back to CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda:0'
device

In [None]:
# Build the network and move its parameters to the selected device.
# nn.Module.to() returns the module itself, so the bare name displays the model.
model = NeuralNetwork().to(device)
model

In [None]:
# Per-epoch training history, appended to by the training loop.
losses, accuracies = [], []

# Training hyper-parameters.
epoches = 20
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# scheduler.step() is called once per epoch, so this halves the LR every 2 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

In [None]:
# Train for `epoches` epochs. Epoch-level train loss/accuracy are appended to
# `losses` / `accuracies` (as Python floats); validation metrics are printed.
for epoch in range(epoches):
    # No dropout/batch-norm today, so train()/eval() are currently no-ops, but
    # they keep the loop correct if such layers are added later.
    model.train()
    running_loss = 0.0
    running_correct = 0
    seen = 0

    for X, y in train_dataloader:
        _X = X.to(device)
        _y = y.to(device)

        logits = model(_X)
        batch_loss = loss_fn(logits, _y)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Accumulate plain Python numbers via .item(): storing the loss tensors
        # kept their autograd history alive, and calling .numpy() on the summed
        # CUDA tensors raised an error when training on a GPU.
        batch_size = _y.size(0)
        running_loss += batch_loss.item() * batch_size  # loss_fn averages per batch
        running_correct += (logits.argmax(dim=1) == _y).sum().item()
        seen += batch_size

        print('.', end='', flush=True)

    # Sample-weighted means, so a smaller final batch is not over-weighted.
    loss = running_loss / seen
    accuracy = running_correct / seen
    losses.append(loss)
    accuracies.append(accuracy)

    print("\nEpoch: {}, train loss: {:.3f}, train accuracy: {:.3f}".format(epoch, loss, accuracy))

    model.eval()
    with torch.no_grad():
        val_loss_sum = 0.0
        val_correct = 0
        val_seen = 0

        for val_X, val_y in validation_dataloader:
            _val_X = val_X.to(device)
            _val_y = val_y.to(device)

            val_logits = model(_val_X)
            val_batch_size = _val_y.size(0)
            val_loss_sum += loss_fn(val_logits, _val_y).item() * val_batch_size
            val_correct += (val_logits.argmax(dim=1) == _val_y).sum().item()
            val_seen += val_batch_size

        val_loss = val_loss_sum / val_seen
        val_accuracy = val_correct / val_seen

        print("Epoch: {}, validation loss: {:.3f}, validation accuracy: {:.3f}\n".format(epoch, val_loss, val_accuracy))

    scheduler.step()

In [None]:
# Plot the per-epoch training curves recorded by the training loop.
# The x-axis comes from the recorded history itself rather than `epoches`, so
# the plot still works if training was interrupted or the training cell was
# re-run (which appends to the same lists).
fig = plt.figure(figsize=(15, 4))
ax1 = fig.add_subplot(121)
ax1.set_title('Loss')
ax1.set_xlabel('Epoch')
ax1.plot(range(len(losses)), losses)
ax2 = fig.add_subplot(122)
ax2.set_title('Accuracy')
ax2.set_xlabel('Epoch')
ax2.plot(range(len(accuracies)), accuracies)
plt.show()

### 6. Тестирование модели

In [None]:

# Evaluate the trained model on the held-out test split.
model.eval()
with torch.no_grad():
    test_correct = 0
    test_seen = 0

    for test_X, test_y in test_dataloader:
        _test_X = test_X.to(device)
        _test_y = test_y.to(device)
        test_logits = model(_test_X)

        # Count correct predictions so the result is an exact per-sample
        # accuracy; averaging per-batch means over-weights the smaller last batch.
        test_correct += (test_logits.argmax(dim=1) == _test_y).sum().item()
        test_seen += _test_y.size(0)

    test_accuracy = test_correct / test_seen
    print('Test accuracy: {:.3f}%'.format(test_accuracy * 100))

### 7. Сохранение весов модели

In [None]:
# Output directory for the saved state_dict and the ONNX export.
MODEL_PATH = '../model'

In [None]:
# Persist only the learned parameters (state_dict), not the whole pickled module.
# torch.save does not create missing directories, so ensure MODEL_PATH exists first.
os.makedirs(MODEL_PATH, exist_ok=True)
torch.save(model.state_dict(), os.path.join(MODEL_PATH, 'model_state_dict.pth'))

In [None]:
# Rebuild the network (on the CPU) and restore the saved weights into it.
# map_location='cpu' lets a checkpoint written from a GPU model load on a
# CPU-only machine and avoids allocating GPU memory for the temporary tensors.
model = NeuralNetwork()
model.load_state_dict(torch.load(os.path.join(MODEL_PATH, 'model_state_dict.pth'), map_location='cpu'))
# Inference mode for the subsequent export.
model.eval()

### 8. Сохранение модели в ONNX

In [None]:
# Export the network to ONNX, tracing it with a single random 1x3x224x224 batch.
# The dummy input must live on the same device as the model's parameters. A
# freshly constructed model that was never moved with .to(device) is on the CPU,
# and creating the input on `device` (possibly 'cuda:0') would make tracing fail
# with a device-mismatch error.
export_device = next(model.parameters()).device
dummy_input = torch.rand(1, 3, 224, 224, device=export_device)
torch.onnx.export(
    model,
    dummy_input,
    os.path.join(MODEL_PATH, 'model.onnx'),
    input_names=['input'],
    output_names=['output'],
)