# **ЗАДАНИЕ 4** - ПРИМЕНЕНИЕ КАСТОМНЫХ АУГМЕНТАЦИЙ

## 1. Импорты

In [18]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.transforms import functional as F
from torchvision import datasets
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import random
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

## 2. Дублируем классы

In [19]:
class BaseTransform:
    """Base class for augmentations that fire with probability ``p``.

    Subclasses implement :meth:`apply`; calling the instance converts the
    image to grayscale (mode ``'L'``) and then either applies the
    augmentation or returns the grayscale image unchanged.
    """

    def __init__(self, p: float = 0.5):
        # Probability that apply() is invoked on a given call.
        self.p = p

    def __call__(self, img: Image.Image) -> Image.Image:
        """Return the augmented image, or the grayscale input if skipped."""
        grayscale = img if img.mode == 'L' else img.convert('L')
        # Exactly one random draw per call decides whether the transform fires.
        should_apply = random.random() < self.p
        return self.apply(grayscale) if should_apply else grayscale

    def apply(self, img: Image.Image) -> Image.Image:
        """Perform the actual augmentation; must be overridden by subclasses."""
        raise NotImplementedError

In [20]:
class RandomCrop(BaseTransform):
    """Cut a randomly positioned ``crop_size`` window out of the image."""

    def __init__(self, p: float = 0.5, crop_size=(50, 50)):
        super().__init__(p)
        # (width, height) of the window to cut out.
        self.crop_size = crop_size

    def apply(self, img: Image.Image) -> Image.Image:
        """Return a random crop, or the original if the window does not fit."""
        width, height = img.size
        crop_w, crop_h = self.crop_size

        if crop_w <= width and crop_h <= height:
            # randint is inclusive on both ends, so the window can touch
            # the right/bottom border but never exceed it.
            x0 = random.randint(0, width - crop_w)
            y0 = random.randint(0, height - crop_h)
            return img.crop((x0, y0, x0 + crop_w, y0 + crop_h))

        # The requested window is larger than the image: leave it untouched.
        return img

In [21]:
class RandomRotate(BaseTransform):
    """Rotate the image by a random angle within ``[-max_angle, max_angle]``."""

    def __init__(self, p: float = 0.5, max_angle: float = 30):
        super().__init__(p)
        # Largest absolute rotation angle passed to Image.rotate.
        self.max_angle = max_angle

    def apply(self, img: Image.Image) -> Image.Image:
        """Return the image rotated by a uniformly sampled angle."""
        limit = self.max_angle
        return img.rotate(random.uniform(-limit, limit))

In [22]:
class Resize(BaseTransform):
    """Deterministically resize every image to ``size``.

    Unlike the random augmentations, this transform always runs, so the
    base class is initialised with ``p=1.0`` and ``__call__`` bypasses the
    random draw entirely.
    """

    def __init__(self, size=(28, 28)):
        # Initialise the base class so that ``self.p`` exists like on every
        # other BaseTransform subclass.
        super().__init__(p=1.0)
        # Target size forwarded as-is to torchvision's functional resize.
        self.size = size

    def apply(self, img):
        """Return ``img`` resized to ``self.size``."""
        return F.resize(img, self.size)

    def __call__(self, img):
        """Always resize; no grayscale conversion or random skip is done here."""
        return self.apply(img)

In [23]:
class RandomZoom(BaseTransform):
    """Zoom in or out by a random factor while keeping the canvas size.

    A factor is sampled uniformly from ``zoom_range``. Zooming in crops the
    centre of the enlarged image; zooming out pastes the shrunken image in
    the centre of a blank canvas. The result always has the size and mode
    of the input image.
    """

    def __init__(self, p: float = 0.5, zoom_range=(0.8, 1.2)):
        super().__init__(p)
        # (min_factor, max_factor) passed to random.uniform.
        self.zoom_range = zoom_range

    def apply(self, img: Image.Image) -> Image.Image:
        """Return a zoomed copy of ``img`` with the original size and mode."""
        zoom = random.uniform(*self.zoom_range)
        w, h = img.size
        # Clamp to one pixel so a very small factor never requests an
        # empty image from resize().
        new_w = max(1, int(w * zoom))
        new_h = max(1, int(h * zoom))
        img_zoomed = img.resize((new_w, new_h), Image.BICUBIC)

        if zoom > 1:
            # Enlarged: cut the central w x h window back out.
            left = (new_w - w) // 2
            top = (new_h - h) // 2
            return img_zoomed.crop((left, top, left + w, top + h))

        # Shrunk (or unchanged): centre it on a blank canvas. The canvas uses
        # the input's mode so both branches return the same image mode.
        pad_w = (w - new_w) // 2
        pad_h = (h - new_h) // 2
        canvas = Image.new(img.mode, (w, h))
        canvas.paste(img_zoomed, (pad_w, pad_h))
        return canvas

In [24]:
class ToTensor:
    """Convert an image to a float32 tensor in CHW layout scaled by 1/255."""

    def __call__(self, img: Image.Image) -> torch.Tensor:
        """Return ``img`` as a ``torch.float32`` tensor with channels first."""
        pixels = np.asarray(img, dtype=np.float32) / 255.0
        # A 2-D array has no channel axis; append one so the layout is HWC.
        if pixels.ndim == 2:
            pixels = pixels[..., np.newaxis]
        # HWC -> CHW
        channels_first = pixels.transpose(2, 0, 1)
        return torch.tensor(channels_first, dtype=torch.float32)

In [25]:
class Compose:
    """Chain transforms, keeping PIL images in grayscale between steps."""

    def __init__(self, transforms):
        # Callables applied in order; each receives the previous step's output.
        self.transforms = transforms

    @staticmethod
    def _as_grayscale(img):
        """Return ``img`` unchanged if already mode 'L', else converted to 'L'."""
        return img if img.mode == 'L' else img.convert('L')

    def __call__(self, img):
        """Run every transform in sequence and return the final result."""
        current = self._as_grayscale(img)

        for step in self.transforms:
            current = step(current)
            # A step may hand back a non-grayscale PIL image; normalise it.
            # Non-PIL outputs (e.g. tensors) are passed through untouched.
            if isinstance(current, Image.Image):
                current = self._as_grayscale(current)
        return current

## 3. Настройка TB

In [26]:
# Root TensorBoard directory for this run, suffixed with the start timestamp.
logdir_base = f"runs/fashion_augment_{datetime.now():%Y%m%d-%H%M%S}"
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", device)

device: cpu


## 4. Определяем разные трансформации

In [27]:
def _make_augmentation(p, max_angle, zoom_range):
    """Build a crop -> rotate -> zoom -> resize -> tensor pipeline.

    ``p`` is the firing probability shared by the three random steps;
    the 24x24 crop and the final 28x28 resize are the same for all presets.
    """
    return Compose([
        RandomCrop(p=p, crop_size=(24, 24)),
        RandomRotate(p=p, max_angle=max_angle),
        RandomZoom(p=p, zoom_range=zoom_range),
        Resize((28, 28)),
        ToTensor(),
    ])


# Baseline: tensor conversion only.
no_aug = Compose([ToTensor()])

# Mild preset: each random step fires with probability 0.3.
mild_aug = _make_augmentation(p=0.3, max_angle=15, zoom_range=(0.9, 1.1))

# Strong preset: each random step fires with probability 0.8, wider ranges.
strong_aug = _make_augmentation(p=0.8, max_angle=30, zoom_range=(0.8, 1.2))

## 5. FashionMNIST

In [28]:
def get_dataloaders(transform, batch_size=64):
    """Return ``(train_loader, test_loader)`` for FashionMNIST.

    Args:
        transform: Transform applied to training images only.
        batch_size: Batch size used by both loaders.

    The test split always uses a plain ``ToTensor()`` and is not shuffled;
    the training split is shuffled. Data is stored under ``data`` and
    downloaded if missing.
    """
    loaders = {}
    # (is_train, transform) for each split; is_train also controls shuffling.
    for is_train, split_transform in ((True, transform), (False, ToTensor())):
        split = datasets.FashionMNIST(
            root="data",
            train=is_train,
            download=True,
            transform=split_transform,
        )
        loaders[is_train] = DataLoader(
            split, batch_size=batch_size, shuffle=is_train
        )
    return loaders[True], loaders[False]

## 6. Простая CNN модель

In [29]:
class SimpleCNN(nn.Module):
    """Two conv/pool stages followed by a two-layer classifier head.

    The head's first linear layer expects 64 * 7 * 7 features, which is what
    the two 2x max-pools produce from a 1-channel 28x28 input. The network
    outputs 10 logits per sample.
    """

    def __init__(self):
        super().__init__()

        # Each stage: 3x3 conv (padding keeps H and W) -> ReLU -> 2x2 max-pool.
        stages = []
        for in_channels, out_channels in ((1, 32), (32, 64)):
            stages.extend([
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
        self.conv = nn.Sequential(*stages)

        flat_features = 64 * 7 * 7
        hidden_units = 128
        num_classes = 10
        self.fc = nn.Sequential(
            nn.Linear(flat_features, hidden_units),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_units, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images."""
        feature_maps = self.conv(x)
        batch_size = feature_maps.size(0)
        # Flatten everything except the batch dimension for the linear head.
        flattened = feature_maps.view(batch_size, -1)
        return self.fc(flattened)

## 7. Функции обучения и валидации

In [30]:
def _train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader``; return (mean loss, accuracy)."""
    model.train()
    loss_sum, hits, seen = 0, 0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # The criterion returns a batch mean; weight it by the batch size
        # so the epoch figure is a true per-sample average.
        loss_sum += batch_loss.item() * inputs.size(0)
        hits += (logits.argmax(1) == targets).sum().item()
        seen += targets.size(0)
    return loss_sum / seen, hits / seen


def _evaluate(model, loader, criterion):
    """Score ``model`` on ``loader`` without gradients; return (mean loss, accuracy)."""
    model.eval()
    loss_sum, hits, seen = 0, 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            batch_loss = criterion(logits, targets)
            loss_sum += batch_loss.item() * inputs.size(0)
            hits += (logits.argmax(1) == targets).sum().item()
            seen += targets.size(0)
    return loss_sum / seen, hits / seen


def train_model(train_loader, test_loader, writer, epochs=10):
    """Train a fresh ``SimpleCNN`` and log per-epoch metrics.

    Args:
        train_loader: Batches of (inputs, labels) used for optimisation.
        test_loader: Batches of (inputs, labels) used for validation.
        writer: Object whose ``add_scalars`` receives the "Loss" and
            "Accuracy" train/val pairs, indexed by zero-based epoch.
        epochs: Number of passes over ``train_loader``.

    Returns:
        The trained model, placed on the module-level ``device``.
    """
    model = SimpleCNN().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer
        )
        val_loss, val_acc = _evaluate(model, test_loader, criterion)

        print(f"Epoch {epoch+1}/{epochs}: "
              f"Train loss={train_loss:.4f}, acc={train_acc:.3f}, "
              f"Val loss={val_loss:.4f}, acc={val_acc:.3f}")

        writer.add_scalars("Loss", {"train": train_loss, "val": val_loss}, epoch)
        writer.add_scalars("Accuracy", {"train": train_acc, "val": val_acc}, epoch)

    return model

## 8. Эксперименты

In [31]:
# Augmentation presets to compare; the key doubles as the TensorBoard
# sub-directory name under logdir_base.
experiments = {
    "no_aug": no_aug,
    "mild_aug": mild_aug,
    "strong_aug": strong_aug,
}

for name, transform in experiments.items():
    print(f"\n=== Эксперимент: {name} ===")
    writer = SummaryWriter(log_dir=f"{logdir_base}/{name}")
    try:
        train_loader, test_loader = get_dataloaders(transform)
        model = train_model(train_loader, test_loader, writer, epochs=10)
    finally:
        # Close the writer even if data loading or training raises, so the
        # event file is not left open when a run is interrupted.
        writer.close()

print("Обучение всех экспериментов завершено. Запустите TensorBoard для просмотра:")
print("tensorboard --logdir=runs")


=== Эксперимент: no_aug ===
Epoch 1/10: Train loss=0.4998, acc=0.819, Val loss=0.3457, acc=0.872
Epoch 2/10: Train loss=0.3246, acc=0.882, Val loss=0.2921, acc=0.895
Epoch 3/10: Train loss=0.2815, acc=0.897, Val loss=0.2718, acc=0.900
Epoch 4/10: Train loss=0.2529, acc=0.908, Val loss=0.2528, acc=0.908
Epoch 5/10: Train loss=0.2267, acc=0.917, Val loss=0.2493, acc=0.906
Epoch 6/10: Train loss=0.2074, acc=0.924, Val loss=0.2302, acc=0.914
Epoch 7/10: Train loss=0.1913, acc=0.930, Val loss=0.2314, acc=0.917
Epoch 8/10: Train loss=0.1732, acc=0.936, Val loss=0.2217, acc=0.921
Epoch 9/10: Train loss=0.1608, acc=0.939, Val loss=0.2330, acc=0.918
Epoch 10/10: Train loss=0.1493, acc=0.944, Val loss=0.2283, acc=0.921

=== Эксперимент: mild_aug ===
Epoch 1/10: Train loss=0.6292, acc=0.763, Val loss=0.3796, acc=0.859
Epoch 2/10: Train loss=0.4526, acc=0.832, Val loss=0.3356, acc=0.876
Epoch 3/10: Train loss=0.4050, acc=0.848, Val loss=0.3216, acc=0.882
Epoch 4/10: Train loss=0.3804, acc=0.856, 