In [7]:
from torchvision import datasets, transforms
import torch
import time
import torch.nn.init as init
import torchvision
from torchvision import datasets, transforms, models
from copy import deepcopy
from torch.utils.data import random_split
import torch.nn.functional as F

from torch.utils.data import ConcatDataset
from torch import nn

In [8]:
class MixApp(nn.Module):
    """Randomly blend an image tensor with its horizontally mirrored copy.

    With probability ``prob`` the input is replaced by
    ``lam * img + (1 - lam) * flip(img)``, where ``lam`` is sampled from a
    symmetric ``Beta(alpha, alpha)`` distribution; otherwise the input is
    returned untouched (the very same object).

    Args:
        alpha: Concentration parameter used for both arguments of the Beta
            distribution.
        prob: Probability threshold compared against a uniform sample to
            decide whether the blend is applied.
    """

    def __init__(self, alpha=1.0, prob=0.5):
        super().__init__()
        self.alpha = alpha
        self.prob = prob

    def forward(self, img):
        """Return the blended tensor, or ``img`` itself when the blend is skipped.

        Args:
            img: Tensor whose last dimension is the image width, e.g.
                ``[C, H, W]`` or a batch ``[B, C, H, W]``.
        """
        if torch.rand(1) < self.prob:
            # Mixing coefficient drawn from the Beta distribution.
            lam = torch.distributions.beta.Beta(self.alpha, self.alpha).sample()

            # Horizontal mirror. Flipping the LAST axis (width) instead of a
            # hard-coded axis 2 gives the same result for [C, H, W] and also
            # stays a horizontal flip for batched [B, C, H, W] input.
            augmented_img = torch.flip(img, dims=[-1])

            # Linear interpolation between the image and its mirror.
            mixed_img = lam * img + (1 - lam) * augmented_img

            return mixed_img
        return img

In [9]:
# Per-channel mean / std handed to every Normalize step below.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)

# No augmentation: tensor conversion followed by normalisation.
base_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Augmentation 1: padded random 32x32 crop plus random horizontal flip,
# applied before tensor conversion.
transform_train_1 = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Augmentation 2: MixApp blend applied on the tensor, before normalisation.
transform_train_2 = transforms.Compose([
    transforms.ToTensor(),
    MixApp(alpha=0.4),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

In [10]:
# CIFAR-10 training split with the non-augmenting pipeline; the archive is
# downloaded into ./data if it is not already there.
original_train = datasets.CIFAR10(
    './data',
    train=True,
    transform=base_transform,
    download=True,
)

100%|██████████| 170M/170M [00:03<00:00, 47.6MB/s]


In [11]:
# Whole training split viewed through the crop + flip pipeline.
augmented_full_1 = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=False,
    transform=transform_train_1
)

# Keep a random half of it. A dedicated, seeded generator makes the chosen
# indices identical on every Restart & Run All instead of depending on the
# state of the global RNG.
half_size = len(augmented_full_1) // 2
augmented_train_1, _ = random_split(
    augmented_full_1,
    [half_size, len(augmented_full_1) - half_size],
    generator=torch.Generator().manual_seed(0),
)

In [12]:
# Whole training split viewed through the MixApp pipeline.
augmented_full_2 = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=False,
    transform=transform_train_2
)

# Keep a random half of it. A dedicated, seeded generator makes the chosen
# indices identical on every Restart & Run All instead of depending on the
# state of the global RNG.
half_size = len(augmented_full_2) // 2
augmented_train_2, _ = random_split(
    augmented_full_2,
    [half_size, len(augmented_full_2) - half_size],
    generator=torch.Generator().manual_seed(1),
)

In [13]:
# Pool both augmented halves into a single dataset.
augmented_train = ConcatDataset(
    [augmented_train_1, augmented_train_2]
)

In [14]:
# CIFAR-10 test split (train=False) with the non-augmenting pipeline.
test_set = datasets.CIFAR10(
    './data',
    train=False,
    transform=base_transform,
    download=True,
)

In [15]:
# Final training data: the un-augmented split followed by the augmented pool.
train_set = ConcatDataset(
    [original_train, augmented_train]
)

In [16]:
# Mini-batch size for the data loaders and number of passes over the training data.
batch_size, num_epochs = 128, 15

In [18]:
# Training batches are reshuffled every epoch.
trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
# Evaluation needs no shuffling; a fixed order keeps the validation batches
# (including the shorter final one) identical from run to run.
valloader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

In [19]:
class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers plus a skip connection.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels of both convolutions.
        stride: Stride of the first convolution (the second always uses 1).
        downsample: Optional module applied to the input to produce the skip
            branch; when ``None`` the input itself is added.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Settings shared by both convolutions; only the stride differs.
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)

        self.conv1 = nn.Conv2d(inplanes, planes, stride=stride, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, stride=1, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_branch(x) + skip(x))``."""
        # Main branch: conv -> bn -> relu -> conv -> bn.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Skip branch: the raw input unless a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)

In [21]:
class ResNet18(nn.Module):
    """ResNet-18-style classifier assembled from pairs of ``BasicBlock``.

    Layout: 7x7 stride-2 convolution -> batch norm -> ReLU -> 3x3 stride-2
    max-pool -> four stages of two ``BasicBlock`` each (64, 128, 256 and 512
    channels; stages 2-4 start with a stride-2 block) -> adaptive average
    pool to 1x1 -> linear layer with ``num_classes`` outputs.

    The shape comments below are worked out for 32x32 inputs (the crop size
    used by this notebook's training transform).

    Args:
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Stem: [b, 3, 32, 32] -> [b, 64, 16, 16]
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        # -> [b, 64, 8, 8]
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # -> [b, 64, 8, 8]
        self.layer1 = self._make_stage(64, 64)
        # -> [b, 128, 4, 4]
        self.layer2 = self._make_stage(64, 128, stride=2)
        # -> [b, 256, 2, 2]
        self.layer3 = self._make_stage(128, 256, stride=2)
        # -> [b, 512, 1, 1]
        self.layer4 = self._make_stage(256, 512, stride=2)

        # -> [b, 512, 1, 1]
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # -> [b, num_classes]
        self.fc = nn.Linear(512, num_classes)

    @staticmethod
    def _make_stage(inplanes, planes, stride=1):
        """Build one stage: a (possibly strided) block followed by a plain one."""
        shortcut = None
        if stride != 1 or inplanes != planes:
            # 1x1 projection so the skip branch matches the main branch's
            # channel count and resolution. It is created before the block
            # itself, as in an inline ``downsample=...`` argument.
            shortcut = nn.Sequential(
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride),
                nn.BatchNorm2d(planes),
            )
        return nn.Sequential(
            BasicBlock(inplanes, planes, stride=stride, downsample=shortcut),
            BasicBlock(planes, planes),
        )

    def forward(self, x):
        """Map a batch of images to ``num_classes`` logits per sample."""
        feature_extractor = (
            self.conv1, self.bn1, self.relu, self.maxpool,
            self.layer1, self.layer2, self.layer3, self.layer4,
            self.avgpool,
        )
        for module in feature_extractor:
            x = module(x)

        # Collapse everything but the batch dimension before the classifier.
        return self.fc(x.flatten(1))

In [22]:
# Use the GPU when one is visible, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResNet18().to(device)

In [23]:
def compute_loss(X_batch, y_batch):
    """Cross-entropy of the module-level ``model``'s logits on one batch.

    Args:
        X_batch: Inputs passed to ``model``.
        y_batch: Targets passed to ``F.cross_entropy``.

    Returns:
        The loss tensor returned by ``F.cross_entropy``.
    """
    return F.cross_entropy(model(X_batch), y_batch)

In [24]:
# Adam over all model parameters with a learning rate of 1e-3.
opt = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Per-epoch history: mean training loss and validation accuracy (in %).
train_loss_, val_accuracy_ = [], []

In [25]:
# Train for `num_epochs` epochs; after each one, evaluate on `valloader`,
# record the results in `train_loss_` / `val_accuracy_` and print a summary.
for epoch in range(num_epochs):
    start_time = time.time()

    # ---- training pass ----
    model.train()
    train_loss = []
    for X_batch, y_batch in trainloader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        # Clear gradients BEFORE the forward/backward pass so that anything
        # left behind by an interrupted previous run cannot leak into this step.
        opt.zero_grad()
        loss = compute_loss(X_batch, y_batch)
        loss.backward()
        opt.step()
        train_loss.append(loss.item())

    # ---- validation pass ----
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for X_batch, y_batch in valloader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            y_pred = model(X_batch).argmax(dim=1)
            # Count hits per sample: averaging per-batch accuracies would
            # over-weight the (shorter) final batch.
            n_correct += (y_pred == y_batch).sum().item()
            n_seen += y_batch.size(0)

    # Mean over the batch losses of this epoch; NaN if a loader was empty.
    training_loss = sum(train_loss) / len(train_loss) if train_loss else float('nan')
    validation_acc = 100.0 * n_correct / n_seen if n_seen else float('nan')

    train_loss_.append(training_loss)
    val_accuracy_.append(validation_acc)

    # Then we print the results for this epoch:
    print("Epoch {} of {} took {:.3f}s".format(
        epoch + 1, num_epochs, time.time() - start_time))
    print("  training loss (in-iteration): \t{:.6f}".format(
        training_loss))
    print("  validation accuracy: \t\t\t{:.2f} %".format(
        validation_acc))

KeyboardInterrupt: 