# Загрузка данных

In [1]:
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32,
                                         shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100.0%


Extracting ./data\cifar-10-python.tar.gz to ./data


NameError: name 'torch' is not defined

# Обработка данных

In [None]:
import torch
from sklearn.model_selection import train_test_split

train_data, val_data = train_test_split(trainset, test_size=0.2, random_state=42)
trainloader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True, num_workers=2)
valloader = torch.utils.data.DataLoader(val_data, batch_size=32, shuffle=False, num_workers=2)

# Определение модели

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(2304, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Обучение модели

In [None]:
import torch.optim as optim

def train_model(model, trainloader, valloader, epochs=10):
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for images, labels in trainloader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(trainloader)}')
        validate_model(model, valloader)

def validate_model(model, valloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in valloader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Validation Accuracy: {100 * correct / total}%')

# Оценка модели

In [None]:
def test_model(model, testloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Test Accuracy: {100 * correct / total}%')

# Сбор артефактов

In [None]:
import os

model_path = './best_model.pth'
torch.save(model.state_dict(), model_path)
print(f'Model saved to {model_path}')

# Логирование и настройка

In [None]:
import logging

logging.basicConfig(level=logging.INFO, filename='training.log', filemode='w',
                    format='%(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger()

logger.info('This is an info message')

# Управление зависимостями

In [None]:
[tool.poetry]
name = "ml-training-pipeline"
version = "0.1.0"
description = "A basic ML training pipeline for CIFAR-10"
authors = ["Your Name <your.email@example.com>"]

[tool.poetry.dependencies]
python = "^3.9"
torch = "^1.9.0"
torchvision = "^0.10.0"
scikit-learn = "^0.24.2"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"