### Обучение классификатора картинок на примере CIFAR-100 (датасет можно изменить) сверточной сетью (самописной)

In [251]:
import torch
import numpy as np
from torch import nn
from torchvision import models
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CIFAR100
from torchvision import transforms
from tqdm import tqdm_notebook
import PIL
from matplotlib import pyplot as plt

In [52]:
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
    ]
)
train = CIFAR100('./data', download=True, train=True, transform=trans)
val = CIFAR100('./data', download=True, train=False, transform=trans)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
train_loader = DataLoader(train, batch_size=128, shuffle=True)
val_loader = DataLoader(val, batch_size=128, shuffle=True)

In [4]:
train[0][0].shape

torch.Size([3, 32, 32])

In [5]:
256 * 2

512

In [6]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(3)
        self.Conv1 = nn.Conv2d(in_channels=3, out_channels=30, kernel_size=5) # tensor shape 28x28x30
        self.maxpool1 = nn.MaxPool2d(kernel_size=2) # tensor shape 14x14x30
        
        self.bn2 = nn.BatchNorm2d(30)
        self.Conv2 = nn.Conv2d(in_channels=30, out_channels=60, kernel_size=2) # tensor shape 12x12x60
        self.maxpool2 = nn.MaxPool2d(kernel_size=2) # tensor shape 6x6x60
        self.dp = nn.Dropout(0.2)
        self.fc1 = nn.Linear(6*6*60, 512)
        self.fc2 = nn.Linear(512, 256)
        self.out = nn.Linear(256, 100)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.bn1(x)
        x = self.Conv1(x)
        x = self.relu(x)
        x = self.maxpool1(x)
        
        x = self.bn2(x)
        x = self.Conv2(x)
        x = self.relu(x)
        x = self.maxpool2(x)
        x = x.view(x.shape[0], -1)
        x = self.dp(x)
        
        x = self.fc1(x)
        x = self.relu(x)
        
        x = self.fc2(x)
        x = self.relu(x)
        return self.out(x)

In [234]:
def train_model(model, optimizer, train_dataset, val_dataset, n_epochs=5):
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=128, shuffle=True)
    for epoch in range(n_epochs):
        # тренировка
        for x_train, y_train in tqdm_notebook(train_loader):
            y_pred = model(x_train)
            loss = F.cross_entropy(y_pred, y_train)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # валидация
        if epoch % 2 == 0:
            val_loss = []
            val_accuracy = []
            with torch.no_grad():
                for x_val, y_val in tqdm_notebook(val_loader):
                    y_pred = model(x_val)
                    loss = F.cross_entropy(y_pred, y_val)
                    val_loss.append(loss.numpy())
                    val_accuracy.extend((torch.argmax(y_pred, dim=-1) == y_val).numpy().tolist())

            # печатаем метрики
            print(f"Epoch: {epoch}, loss: {np.mean(val_loss)}, accuracy: {np.mean(val_accuracy)}")

In [8]:
model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

In [9]:
train_model(model, optimizer, train, val)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for x_train, y_train in tqdm_notebook(train_loader):


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)





Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for x_val, y_val in tqdm_notebook(val_loader):


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Epoch: 0, loss: 4.522462844848633, accuracy: 0.0306


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Epoch: 2, loss: 3.9512994289398193, accuracy: 0.1003


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Epoch: 4, loss: 3.5429115295410156, accuracy: 0.1709


### Обучение классификатора картинок на примере CIFAR-100 (датасет можно изменить) через дообучение ImageNet Resnet-50

In [96]:
resnet50 = models.resnet50(pretrained=True)

In [97]:
class ResNetModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2048, 512)
        self.fc2 = nn.Linear(512, 100)
        self.l_relu = nn.LeakyReLU()
        self.maxpool = nn.MaxPool2d(2)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.l_relu(x)
        return self.fc2(x)

In [98]:
resnet_out = ResNetModule()

In [99]:
def set_requires_grad(model, requires_grad=False):
    for param in model.parameters():
        param.requires_grad = requires_grad

In [100]:
set_requires_grad(resnet50)
resnet50.fc = resnet_out

In [101]:
optimizer = torch.optim.Adam(resnet50.fc.parameters())

In [102]:
train_model(resnet50, optimizer, train, val)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for x_train, y_train in tqdm_notebook(train_loader):


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for x_val, y_val in tqdm_notebook(val_loader):


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Epoch: 0, loss: 3.0774929523468018, accuracy: 0.2618


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Epoch: 2, loss: 2.9268343448638916, accuracy: 0.2823


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Epoch: 4, loss: 2.886695623397827, accuracy: 0.292


In [241]:
class DatasetWithAug(Dataset):
    def __init__(self, dataset, transform=True, **kwargs):
        self.dataset = dataset
        if transform:
            self.transform = transforms.Compose(
                                [
                                    transforms.ToPILImage(),
                                    transforms.ColorJitter(hue=.05, saturation=.05),
                                    transforms.RandomHorizontalFlip(),
                                    transforms.RandomRotation(20, resample=PIL.Image.BILINEAR),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                ]
                            )

    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        img = self.dataset[idx]
        if self.transform:
            img_transformed = self.transform(img[0])
            img = (img_transformed, img[1])
        return img

In [243]:
train_aug = DatasetWithAug(train)
val_aug = DatasetWithAug(val)

In [249]:
resnet50_aug = models.resnet50(pretrained=True)
resnet_out = ResNetModule()
set_requires_grad(resnet50_aug)
resnet50_aug.fc = resnet_out

In [250]:
train_model(resnet50_aug, optimizer, train_aug, val_aug)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for x_train, y_train in tqdm_notebook(train_loader):


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for x_val, y_val in tqdm_notebook(val_loader):


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Epoch: 0, loss: 4.623915672302246, accuracy: 0.0096


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Epoch: 2, loss: 4.619439601898193, accuracy: 0.0096


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=391.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Epoch: 4, loss: 4.620582103729248, accuracy: 0.0107
