# Programme clean resnet18 + cifar 10



In [1]:
import torch
import numpy as np
import scipy.stats as st
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
import os
import math
import time
import argparse

# Report the PyTorch build in use so results can be traced to a version.
print(f"Using pytorch version: {torch.__version__}")

Using pytorch version: 1.9.0


In [2]:
import sys
def is_interactive():
    """Tell whether the ``__main__`` module lacks a ``__file__`` attribute.

    Returns:
        bool: True when ``__main__`` has no ``__file__`` (typically a REPL or
        notebook kernel), False when it has one (typically a script run).
    """
    import __main__ as entry_module
    has_script_file = hasattr(entry_module, '__file__')
    return not has_script_file
    
# Command-line options. Numeric options declare ``type=int`` so that values
# given on the command line are integers like their defaults, not strings.
parser = argparse.ArgumentParser(description='Cifar quick code')
parser.add_argument("--batch-size", type=int, default=64, help="batch_size")
parser.add_argument("--device", default="cuda:0", help="device to use")
parser.add_argument("--feature-maps", type=int, default=64, help="number of feature maps")
parser.add_argument("--quiet", action="store_true", help="prevent too much display of info")
parser.add_argument("--dataset-path", default=os.environ.get("DATASETS"), help="dataset path")
# The function must be *called*: a bare function object is always truthy,
# which made the script ignore every real command-line argument.
if is_interactive():
    # Interactive session: ignore the host process' argv and use defaults.
    args = parser.parse_args(args = [])
else:
    args = parser.parse_args()

In [3]:
class Dataset():
    """Batch iterator over a pair of tensors held on one device.

    Iterating yields ``(transforms(data_batch), target_batch)`` tuples of at
    most ``batch_size`` samples; the last batch may be smaller.

    Args:
        data: tensor whose first dimension indexes samples.
        targets: tensor with the same first dimension as ``data``.
        shuffle: when true, samples are re-permuted (with numpy's global RNG)
            at the start of every iteration.
        transforms: a callable applied to each data batch, a sequence of such
            callables applied in order, or ``None``/empty for no transform.
        batch_size: samples per batch; ``None`` means ``args.batch_size``.
        device: device the tensors are moved to; ``None`` means ``args.device``.

    Raises:
        ValueError: if ``data`` and ``targets`` disagree on the number of
            samples, or if ``batch_size`` is not positive.
    """
    def __init__(self, data, targets, shuffle = True, transforms = None, batch_size = None, device = None):
        if data.shape[0] != targets.shape[0]:
            raise ValueError("data and targets must hold the same number of samples, got {} and {}".format(data.shape[0], targets.shape[0]))
        # Resolve the defaults at call time so the class does not depend on
        # `args` existing when the class statement itself is executed.
        if batch_size is None:
            batch_size = args.batch_size
        if device is None:
            device = args.device
        batch_size = int(batch_size)
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer, got {}".format(batch_size))
        self.data = data.to(device)
        self.targets = targets.to(device)
        self.shuffle = shuffle
        self.length = data.shape[0]
        self.batch_size = batch_size
        self.transforms = self._as_callable(transforms)

    @staticmethod
    def _as_callable(transforms):
        """Turn ``None``, a callable, or a sequence of callables into one callable."""
        if transforms is None:
            return lambda batch: batch
        if callable(transforms):
            return transforms
        steps = list(transforms)

        def apply_all(batch):
            for step in steps:
                batch = step(batch)
            return batch
        return apply_all

    def __iter__(self):
        if self.shuffle:
            # The stored tensors themselves are reordered, so the new order
            # persists on the instance after the iteration.
            perm = np.random.permutation(self.length)
            self.data = self.data[perm]
            self.targets = self.targets[perm]
        # A strided range covers the full batches and the trailing partial one.
        for start in range(0, self.length, self.batch_size):
            stop = start + self.batch_size
            yield self.transforms(self.data[start:stop]), self.targets[start:stop]

    def __len__(self):
        """Number of samples (not batches)."""
        return self.length

    def batch_length(self):
        """Number of batches produced by one iteration."""
        return (self.length + self.batch_size - 1) // self.batch_size
    
def mnist(batch_size):
    """Build MNIST loaders backed by ``torch.utils.data.DataLoader``.

    The dataset is read from (and downloaded to, if missing)
    ``args.dataset_path``. Batches are not moved to any device here.

    Args:
        batch_size: number of samples per batch for both loaders.

    Returns:
        A tuple ``((train_loader, val_loader, test_loader), [1, 28, 28], 10)``
        where the validation slot reuses the training loader object.
    """
    transform = transforms.Compose([
           transforms.ToTensor(),
           transforms.Normalize((0.1307,), (0.3081,))
        ])

    # The original referenced an undefined name `path`; the dataset root is
    # the same command-line option used for CIFAR-10.
    root = args.dataset_path

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(root, train=True, download=True, transform=transform),
        batch_size = batch_size, shuffle=True
    )

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(root, train=False, download=True, transform=transform),
        batch_size = batch_size, shuffle=False
    )

    return (train_loader, train_loader, test_loader), [1, 28, 28], 10

def cifar10(data_augmentation = True):
    """Load CIFAR-10 fully into tensors and wrap it in ``Dataset`` loaders.

    Both splits are read from ``args.dataset_path`` (downloaded if missing)
    and every image is converted with ``transforms.ToTensor``.

    Args:
        data_augmentation: when true, the training loader applies a random
            32x32 crop with padding 4 and a random horizontal flip before
            normalisation.

    Returns:
        ``((train_loader, val_loader, test_loader), [3, 32, 32], 10)``. The
        validation loader wraps the *training* tensors without augmentation.
    """
    to_tensor = transforms.ToTensor()

    def load_split(is_train):
        # Materialise one split as an (images, labels) pair of tensors.
        raw = datasets.CIFAR10(args.dataset_path, train = is_train, download = True)
        images = torch.stack([to_tensor(image) for image in raw.data])
        labels = torch.LongTensor(raw.targets)
        return images, labels

    train_data, train_targets = load_split(True)
    test_data, test_targets = load_split(False)

    # NOTE(review): Dataset calls these transforms on whole batches; whether
    # crop/flip parameters are drawn per image or per batch should be confirmed.
    augmentation = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip()
    ] if data_augmentation else []
    normalisation = [
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ]

    train_loader = Dataset(train_data, train_targets, transforms = transforms.Compose(augmentation + normalisation))
    val_loader = Dataset(train_data, train_targets, transforms = transforms.Compose(normalisation))
    test_loader = Dataset(test_data, test_targets, transforms = transforms.Compose(normalisation))
    return (train_loader, val_loader, test_loader), [3, 32, 32], 10

## Définition du modèle

In [4]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolution + batch-norm stages.

    The input is added back to the second stage's output through
    ``shortcut``, which is an empty ``nn.Sequential`` unless the stride or the
    channel count changes, in which case it is a 1x1 convolution followed by
    batch-norm.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        shape_changes = stride != 1 or in_planes != out_planes
        if shape_changes:
            # Project the input so it can be summed with the main branch.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes)
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        first = F.relu(self.bn1(self.conv1(x)))
        second = self.bn2(self.conv2(first))
        second += self.shortcut(x)
        return F.relu(second)

class ResNet(nn.Module):
    """ResNet-style classifier built from a configurable block type.

    Args:
        block: block class called as ``block(in_planes, planes, stride)`` and
            exposing an ``expansion`` class attribute.
        num_blocks: number of blocks in each stage; stage ``i`` uses
            ``2**i * feature_maps`` planes and stride 2 (stride 1 for stage 0).
        feature_maps: number of channels produced by the stem convolution.
        num_classes: size of the final linear layer's output.
        in_channels: number of channels of the input images (3 by default,
            matching the previously hard-coded value).
    """
    def __init__(self, block, num_blocks, feature_maps, num_classes=10, in_channels=3):
        super(ResNet, self).__init__()
        self.in_planes = feature_maps
        self.length = len(num_blocks)
        self.conv1 = nn.Conv2d(in_channels, feature_maps, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(feature_maps)
        layers = []
        for i, nb in enumerate(num_blocks):
            layers.append(self._make_layer(block, (2 ** i) * feature_maps, nb, stride = 1 if i == 0 else 2))
        self.layers = nn.Sequential(*layers)
        self.linear = nn.Linear((2 ** (len(num_blocks) - 1)) * feature_maps * block.expansion, num_classes)
        self.depth = len(num_blocks)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: ``num_blocks`` blocks, only the first one strided.

        An ``nn.ReLU`` is inserted between consecutive blocks. It is kept even
        where a block already ends in a ReLU because removing it would shift
        the indices inside the ``nn.Sequential`` and thus the state-dict keys.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for i, block_stride in enumerate(strides):
            layers.append(block(self.in_planes, planes, block_stride))
            if i < len(strides) - 1:
                layers.append(nn.ReLU())
            # The next block consumes this block's output channels.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(logits, features)`` for a batch of images.

        ``features`` is the flattened output of the average pooling, i.e. the
        input of the final linear layer.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layers(out)
        # Pooling window equals the feature-map height (dimension 2).
        out = F.avg_pool2d(out, out.shape[2])
        features = out.view(out.size(0), -1)
        out = self.linear(features)
        return out, features


## Routines d'entraînement et de test

In [5]:
import random
# Timestamp of the last progress line printed by train().
last_update = 0

def train(model, train_loader, optimizer, mixup = "None"):
    """Run one training epoch and return the mean per-sample loss.

    Relies on module-level names: ``criterion`` (loss), ``args.quiet`` and
    ``start_time`` (progress display), ``last_update`` (print throttling) and,
    for ``mixup == "local"`` only, ``alpha``.

    Args:
        model: module whose forward returns ``(output, features)``.
        train_loader: iterable of ``(data, target)`` batches that also offers
            ``batch_length()`` (used by the progress display).
        optimizer: optimizer stepped once per batch.
        mixup: ``"standard"`` mixes each sample with a uniformly permuted
            partner; ``"local"`` draws the partner with probability
            proportional to ``exp(-alpha * L2 distance)`` within the batch;
            any other value disables mixup.

    Returns:
        dict: ``{"train_loss": mean loss over the samples seen}``.
    """
    model.train()
    global last_update
    total_loss, total_elts = 0., 0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()

        index_mixup = None
        if mixup == "standard":
            index_mixup = torch.randperm(data.shape[0])
        elif mixup == "local":
            flat = data.reshape(data.shape[0], -1)
            distances = torch.cdist(flat, flat, p = 2)
            sim = torch.exp(-1 * alpha * distances)
            # torch.multinomial samples one partner per row from unnormalised
            # weights, on whatever device the batch lives on. The previous
            # numpy round-trip failed for CUDA tensors.
            index_mixup = torch.multinomial(sim, 1).squeeze(1)

        if index_mixup is not None:
            # Mixing coefficient drawn uniformly in [0, 0.5).
            lam = random.random() / 2
            data_mixed = lam * data + (1 - lam) * data[index_mixup]
            output, _ = model(data_mixed)
            loss = lam * criterion(output, target) + (1 - lam) * criterion(output, target[index_mixup])
        else:
            output, _ = model(data)
            loss = criterion(output, target)

        loss.backward()
        total_loss += loss.item() * data.shape[0]
        total_elts += data.shape[0]
        optimizer.step()
        # Refresh the progress line at most every 0.1 s.
        if time.time() - last_update > 0.1 and not args.quiet:
            print("\r{:5d}/{:5d} loss:{:.5f} time:{:9d}.{:02d}".format(batch_idx + 1, train_loader.batch_length(), total_loss / total_elts, int(time.time() - start_time), int(100*(time.time() - start_time)) % 100), end = "")
            last_update = time.time()

    # Average over the samples actually processed, so the value is a
    # per-sample mean whatever len(train_loader) counts.
    return { "train_loss" : total_loss / max(total_elts, 1) }

def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` without tracking gradients.

    Uses the module-level ``criterion`` as the loss.

    Args:
        model: module whose forward returns ``(output, features)``.
        test_loader: iterable of ``(data, target)`` batches.

    Returns:
        dict: ``{"test_loss": mean per-sample loss, "test_acc": fraction of
        correct top-1 predictions}``, both averaged over the samples seen.
    """
    model.eval()
    test_loss, correct, total_elts = 0., 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            output, _ = model(data)
            test_loss += criterion(output, target).item() * data.shape[0]
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_elts += data.shape[0]
    # Divide by the sample count rather than len(test_loader): for a torch
    # DataLoader, len() is the number of batches, which skewed both metrics.
    denominator = max(total_elts, 1)
    return { "test_loss" : test_loss / denominator, "test_acc" : correct / denominator }


In [6]:
def save_features(loader, filename, pre = False, num_classes = 10):
    """Save the features of every sample in ``loader``, grouped by class.

    Features are the second output of the module-level ``model``. The saved
    tensor has shape ``(num_classes, max_count, dim)`` where ``max_count`` is
    the size of the largest class; classes with fewer samples are padded with
    zeros. Samples of a class keep the order in which the loader yielded them.

    NOTE(review): the model's train/eval mode is left untouched here; callers
    presumably want ``model.eval()`` beforehand, to be confirmed.

    Args:
        loader: iterable of ``(data, target)`` batches with integer targets.
        filename: destination passed to ``torch.save``.
        pre: unused; kept for call compatibility.
        num_classes: number of class slots in the saved tensor.

    Raises:
        ValueError: if the loader yields nothing or a target is not below
            ``num_classes``.
    """
    feature_batches, target_batches = [], []
    with torch.no_grad():
        for data, target in loader:
            _, batch_features = model(data)
            feature_batches.append(batch_features.cpu())
            target_batches.append(target.cpu())

    if not feature_batches:
        raise ValueError("loader yielded no batches; there are no features to save")

    all_features = torch.cat(feature_batches)
    all_targets = torch.cat(target_batches)

    elements_per_class = torch.bincount(all_targets, minlength = num_classes)
    if elements_per_class.shape[0] > num_classes:
        raise ValueError("found a target >= num_classes ({})".format(num_classes))

    max_elements_per_class = int(elements_per_class.max().item())
    features = torch.zeros((num_classes, max_elements_per_class, all_features.shape[1]))
    for class_index in range(num_classes):
        # Boolean masking preserves the original sample order within a class.
        class_features = all_features[all_targets == class_index]
        features[class_index, :class_features.shape[0]] = class_features
    torch.save(features, filename)

## Routines d'entraînement complet

In [7]:
def train_era(model, epochs, lr, loaders, mixup = False, verbose = False):
    """Train for ``epochs`` epochs with a fresh optimizer and a fixed rate.

    Args:
        model: model to train.
        epochs: number of epochs in this era.
        lr: learning rate for SGD (momentum 0.9, weight decay 5e-4); a
            negative value selects Adam with its default settings instead.
        loaders: ``(train_loader, val_loader, test_loader)``; the validation
            loader is not used here.
        mixup: forwarded to ``train``.
        verbose: when true, each epoch's summary stays on its own line
            instead of being overwritten by the next one.

    Returns:
        Test accuracy measured after the last epoch, or the model's current
        test accuracy when ``epochs`` is 0.
    """
    if lr < 0:
        optimizer = torch.optim.Adam(model.parameters())
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr = lr, momentum = 0.9, weight_decay = 5e-4)
    train_loader, _val_loader, test_loader = loaders
    test_stats = None
    for epoch in range(epochs):
        train_stats = train(model, train_loader, optimizer, mixup = mixup)
        test_stats = test(model, test_loader)
        print("\rEpoch: {:3d}, test_acc: {:.2f}%, train_loss: {:.5f}".format(epoch, 100 * test_stats["test_acc"], train_stats["train_loss"]), end = "")
        if verbose:
            print()
    print()
    if test_stats is None:
        # No epoch ran: evaluate once so a defined accuracy is still returned
        # instead of failing on an unbound variable.
        test_stats = test(model, test_loader)
    return test_stats["test_acc"]

def train_complete(model, training, loaders, mixup = False):
    """Run a whole schedule of training eras on ``model``.

    Also (re)sets the module-level ``start_time`` to the current time.

    Args:
        model: model to train.
        training: iterable of ``(epochs, lr)`` pairs, run in order through
            ``train_era``.
        loaders: loaders tuple forwarded to ``train_era``.
        mixup: forwarded to ``train_era``.

    Returns:
        The value returned by the last era, or ``None`` for an empty schedule.
    """
    global start_time
    start_time = time.time()
    # Defined up front so an empty schedule returns None instead of raising.
    test_acc = None
    for (epochs, lr) in training:
        test_acc = train_era(model, epochs, lr, loaders, mixup = mixup)
    return test_acc

In [8]:
# Schedule of (number of epochs, learning rate) eras run for every model.
training = [(100, 0.1), (100, 0.01), (100, 0.001), (50, 0.0001)]
criterion = torch.nn.CrossEntropyLoss()
loaders, _, _ = cifar10(data_augmentation = True)

# Decay rate of the similarity kernel read by train() when mixup == "local".
alpha = 1e-3

scores = []

# Train independent models and report the running mean test accuracy with a
# 95% normal confidence interval.
for i in range(100):
    model = ResNet(BasicBlock, [2, 2, 2, 2], args.feature_maps).to(args.device)
    scores.append(train_complete(model, training, loaders, mixup = "standard"))
    mean_score = np.mean(scores)
    if len(scores) > 1:
        print(mean_score, st.norm.interval(0.95, loc = mean_score, scale = st.sem(scores)))
    else:
        # The standard error needs at least two runs; with one score it is
        # NaN, so only the mean is reported.
        print(mean_score)


Files already downloaded and verified
Files already downloaded and verified
    1/  782 loss:2.34462 time:        1.05
    4/  782 loss:4.64492 time:        1.30
    7/  782 loss:5.81336 time:        1.55
   10/  782 loss:5.37597 time:        1.79
   13/  782 loss:5.03286 time:        2.05
   16/  782 loss:4.75405 time:        2.30
   19/  782 loss:4.41376 time:        2.55
   22/  782 loss:4.19395 time:        2.80
   25/  782 loss:3.99099 time:        3.06
   28/  782 loss:3.81918 time:        3.32
   31/  782 loss:3.68021 time:        3.57
   34/  782 loss:3.56787 time:        3.82
   37/  782 loss:3.47329 time:        4.06
   40/  782 loss:3.38728 time:        4.34
   43/  782 loss:3.31535 time:        4.60
   46/  782 loss:3.25021 time:        4.87
   49/  782 loss:3.19172 time:        5.13
   52/  782 loss:3.13998 time:        5.39
   55/  782 loss:3.09449 time:        5.64
   58/  782 loss:3.05255 time:        5.89
   61/  782 loss:3.01369 time:        6.13
   64/  782 loss:2.97

KeyboardInterrupt: 