## 0. Подключаем необходимые библиотеки (в частности PyTorch)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as tv
from torch.cuda.amp import autocast, GradScaler
from torchvision import transforms
import os
import cv2
import numpy as np
from PIL import Image

## 1. Формируем датасет

In [2]:
transform = transforms.Compose([
    transforms.Resize((64,64)),    
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225] )
    ])
    
train_path = "./train/"
train_data = tv.datasets.ImageFolder(root=train_path,transform=transform)

test_path = "./test/"
test_data = tv.datasets.ImageFolder(root=test_path,transform=transform)

## 2. Создаём даталоадер

In [3]:
batch_size = 32
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader  = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True) 

## 3. Прописываем архитектуру

In [4]:
class NeuralNet(nn.Module):
    def __init__(self):
        super(NeuralNet, self).__init__()
        self.linear1 = nn.Linear(12288, 500)
        self.linear2 = nn.Linear(500, 300)
        self.linear3 = nn.Linear(300, 100)
        self.linear4 = nn.Linear(100, 50)
        self.linear5 = nn.Linear(50,2)
    
    def forward(self, x):
        x = x.view(-1, 12288)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x))
        x = F.relu(self.linear4(x))
        x = self.linear5(x)
        return x

network = NeuralNet()


## 4. Определяем функцию потерь и оптимизатор

In [5]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(network.parameters(), lr=0.001)

if torch.cuda.is_available():
    device = torch.device("cuda") 
else:
    device = torch.device("cpu")

network.to(device)

NeuralNet(
  (linear1): Linear(in_features=12288, out_features=500, bias=True)
  (linear2): Linear(in_features=500, out_features=300, bias=True)
  (linear3): Linear(in_features=300, out_features=100, bias=True)
  (linear4): Linear(in_features=100, out_features=50, bias=True)
  (linear5): Linear(in_features=50, out_features=2, bias=True)
)

## 5. Тренируем сеть

In [6]:
n = 1.0
epoch = 1
while n > 0.2:
    t_loss = 0.0
    for batch in train_loader:
        optimizer.zero_grad()
        dt, answers = batch
        dt = dt.to(device)
        answers = answers.to(device)
        pred = network(dt)
        loss = loss_fn(pred, answers)
        loss.backward()
        optimizer.step()
        t_loss += loss.data.item() * dt.size(0)
    t_loss /= len(train_loader.dataset)
    print('Epoch: {}, Training Loss: {:.2f}'.format(epoch, t_loss))
    if epoch == 12: break
    n = t_loss
    epoch += 1

Epoch: 1, Training Loss: 0.72
Epoch: 2, Training Loss: 0.67
Epoch: 3, Training Loss: 0.61
Epoch: 4, Training Loss: 0.54
Epoch: 5, Training Loss: 0.43
Epoch: 6, Training Loss: 0.37
Epoch: 7, Training Loss: 0.26
Epoch: 8, Training Loss: 0.23
Epoch: 9, Training Loss: 0.28
Epoch: 10, Training Loss: 0.49
Epoch: 11, Training Loss: 0.24
Epoch: 12, Training Loss: 0.22


## 6. Тестируем (максимальное значение ошибки - 7)

In [7]:
test_loss = 0.0
for batch in test_loader:
    dt, answers = batch
    dt = dt.to(device)
    answers = answers.to(device)
    pred = network(dt)
    loss = loss_fn(pred, answers)
    test_loss += loss.data.item() * dt.size(0)
test_loss /= len(test_loader.dataset)
print('Training Loss: {:.2f}'.format(test_loss))

Training Loss: 1.46


In [8]:
options = ['fruit','vegetable']
n = 1
for i in range(12):
    img = Image.open(f"./using/{n}.jpg") 
    img = transform(img).to(device)
    predict = F.softmax(network(img))
    predict = predict.argmax()
    print('Номер: {}, Ответ: {}'.format(n, options[predict])) 
    n += 1

Номер: 1, Ответ: fruit
Номер: 2, Ответ: vegetable
Номер: 3, Ответ: fruit
Номер: 4, Ответ: vegetable
Номер: 5, Ответ: fruit
Номер: 6, Ответ: fruit
Номер: 7, Ответ: fruit
Номер: 8, Ответ: fruit
Номер: 9, Ответ: vegetable
Номер: 10, Ответ: vegetable
Номер: 11, Ответ: vegetable
Номер: 12, Ответ: vegetable


  predict = F.softmax(network(img))


## 7. Сохраняем

In [18]:
torch.save(network, "network") 
network = torch.load("network")  