In [None]:
import datetime
import os
import time
from collections import defaultdict, deque

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import wandb
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

Some utilities to track, smooth and print metrics during training and evaluation.

In [None]:
class SmoothedValue:
    """Track a series of values and expose smoothed statistics about it.

    Windowed statistics (``median``, ``avg``, ``max``, ``value``) only look at the
    ``window_size`` most recent values; ``global_avg`` covers the whole series.

    Args:
        window_size: number of recent values kept for the windowed statistics.
        fmt: ``str.format`` template used by ``__str__``. It may reference the fields
            ``median``, ``avg``, ``global_avg``, ``max`` and ``value``.
            Defaults to ``"{median:.4f} ({global_avg:.4f})"``.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)  # oldest values fall out automatically
        self.total = 0.0  # weighted sum of every value ever recorded
        self.count = 0  # total weight (sum of ``n``) of every value ever recorded
        self.fmt = fmt

    def update(self, value, n=1):
        """Record ``value``.

        The value enters the window once, but counts with weight ``n`` in
        ``global_avg`` (e.g. ``n`` = batch size when ``value`` is a per-batch mean).
        """
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """Sum ``count`` and ``total`` over all processes of a torch.distributed job.

        Only the global statistics are synchronized; the window stays local to each
        process. Without an initialized process group this is a no-op, so it is safe
        to call in single-process runs.
        """
        import torch.distributed as dist

        if not (dist.is_available() and dist.is_initialized()):
            return
        # NCCL only reduces CUDA tensors; other backends (e.g. gloo) accept CPU tensors.
        device = "cuda" if dist.get_backend() == "nccl" else "cpu"
        stats = torch.tensor([self.count, self.total], dtype=torch.float64, device=device)
        dist.barrier()
        dist.all_reduce(stats)
        count, total = stats.tolist()
        self.count = int(count)
        self.total = total

    @property
    def median(self):
        """Median of the window (for an even count, torch returns the lower middle value)."""
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        """Mean of the window, computed in float32."""
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        """Weighted mean of the whole series (``total / count``)."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
        )


class MetricLogger:
    """Collection of named ``SmoothedValue`` meters with periodic progress printing.

    Meters are created on first use by ``update`` (or registered with ``add_meter``) and
    can be read back as attributes: ``logger.loss`` is ``logger.meters["loss"]``.

    Args:
        delimiter: separator placed between fields of a printed line.
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Add one value per keyword argument to the meter of the same name.

        Tensors are converted to Python numbers with ``.item()``; every value must then
        be an ``int`` or ``float``.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only called when normal lookup fails. Read ``meters`` through ``__dict__`` so an
        # instance without it (created via ``__new__``, e.g. by copy/unpickling) raises
        # AttributeError instead of re-entering __getattr__ until RecursionError.
        meters = self.__dict__.get("meters", {})
        if attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

    def __str__(self):
        return self.delimiter.join(f"{name}: {str(meter)}" for name, meter in self.meters.items())

    def synchronize_between_processes(self):
        """Ask every meter to synchronize itself (each must define ``synchronize_between_processes``)."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name`` (e.g. a ``SmoothedValue`` with a custom format)."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable``, printing a progress line every ``print_freq`` items.

        Each line shows the item index, an ETA, all meters, the average time per
        iteration (``time``) and spent waiting for the next item (``data``), plus the
        peak allocated CUDA memory in MB when CUDA is available. The last item is always
        reported, and a total-time line is printed once the iterable is exhausted.

        ``iterable`` must support ``len()``.
        """
        if not header:
            header = ""
        num_items = len(iterable)
        use_cuda = torch.cuda.is_available()
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        space_fmt = ":" + str(len(str(num_items))) + "d"
        fields = [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
        if use_cuda:
            fields.append("max mem: {memory:.0f}")
        log_msg = self.delimiter.join(fields)
        MB = 1024.0 * 1024.0
        for i, obj in enumerate(iterable):
            # Time spent waiting for this item (e.g. data loading).
            data_time.update(time.time() - end)
            yield obj
            # Full iteration time: waiting for the item + the consumer's work on it.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == num_items - 1:
                eta_seconds = iter_time.global_avg * (num_items - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                extra = {"memory": torch.cuda.max_memory_allocated() / MB} if use_cuda else {}
                print(
                    log_msg.format(
                        i,
                        num_items,
                        eta=eta_string,
                        meters=str(self),
                        time=str(iter_time),
                        data=str(data_time),
                        **extra,
                    )
                )
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(f"{header} Total time: {total_time_str}")
        
def inverse_normalize(img, mean=(0.1307,), std=(0.3081,)):
    """Undo ``transforms.Normalize(mean, std)`` by computing ``img * std + mean`` per channel.

    Args:
        img: normalized image tensor. With 3 or more dims the channel axis is assumed to be
            third from last, i.e. ``(C, H, W)`` or ``(N, C, H, W)``. With fewer dims the
            statistics broadcast over the last axis.
        mean: per-channel means used for normalization (defaults: the MNIST values used
            in this notebook).
        std: per-channel standard deviations used for normalization.

    Returns:
        A new tensor on ``img``'s device with the same shape as ``img``.
    """
    # View the stats as (C, 1, 1) so they broadcast over the channel axis, not the width
    # axis, and put them on img's device so GPU tensors work too.
    stat_shape = (-1, 1, 1) if img.dim() >= 3 else (-1,)
    std_t = torch.as_tensor(std, device=img.device).view(stat_shape)
    mean_t = torch.as_tensor(mean, device=img.device).view(stat_shape)
    return img * std_t + mean_t

## Classification
Classification is one of the most common tasks in computer vision and is typically performed with convolutional neural networks (CNNs). The basic idea is to first extract features from an input image and then classify those features in order to predict the class of the image.  
In this tutorial, we will apply image classification to the MNIST dataset (handwritten digits).  
The first task is to implement the `ConvNet` represented in the figure below (the `ClassNet` class in the next cell).
![ConvNet.png](Images/convnet.svg)

In [None]:
class ClassNet(nn.Module):
    """Image classifier: a feature-extractor stage followed by a classifier head.

    TODO (exercise): fill in both ``nn.Sequential`` containers to reproduce the ConvNet
    from the figure above. As provided, both containers are empty, so the model has no
    parameters and ``forward`` returns the input flattened from dimension 1 onwards.

    Args:
        output_dim: number of classes the classifier head should output
            (not used until the classifier is implemented).
    """

    def __init__(self, output_dim):
        super().__init__()
        # Intended to hold the convolutional layers (currently empty).
        self.feature_extractor = nn.Sequential()
        # Intended to map flattened features to ``output_dim`` scores (currently empty).
        self.classifier = nn.Sequential()

    def forward(self, x):
        features = self.feature_extractor(x)
        # Keep the batch dimension, merge all remaining dimensions into one vector.
        return self.classifier(features.flatten(start_dim=1))

### Definition of the data loader
We retrieve the MNIST dataset through `torchvision.datasets` and then define one `DataLoader` for the training set and one for the test set.

In [None]:
# Dataset location. Set the MNIST_DATA_DIR environment variable to use another
# directory instead of editing this user-specific default path.
DATA_DIR = os.environ.get("MNIST_DATA_DIR", "/scratch/users/rvandeghen/mnist/")

# Normalization statistics (the same values as the `inverse_normalize` defaults).
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

TRAIN_BATCH_SIZE = 32
TEST_BATCH_SIZE = 256

# One shared preprocessing pipeline for both splits.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

train_mnist = datasets.MNIST(DATA_DIR, train=True, download=True, transform=mnist_transform)
test_mnist = datasets.MNIST(DATA_DIR, train=False, download=True, transform=mnist_transform)

# Page-locked host memory only helps when batches are copied to a GPU.
PIN_MEMORY = torch.cuda.is_available()
train_loader = DataLoader(train_mnist, batch_size=TRAIN_BATCH_SIZE, num_workers=1, shuffle=True, pin_memory=PIN_MEMORY)
test_loader = DataLoader(test_mnist, batch_size=TEST_BATCH_SIZE, num_workers=1, shuffle=False, pin_memory=PIN_MEMORY)

Content of the data: we look at one training sample and its label.

In [None]:
# Sample at index 1 of the training set: the transformed (ToTensor + Normalize) image and its label.
image, label = train_mnist[1]

In [None]:
# Shape of the image tensor (presumably channels x height x width after ToTensor — verify in the output).
image.shape

In [None]:
# Undo the normalization, upscale to 256x256 so the digit is visible, and display it as a PIL image.
TF.to_pil_image(TF.resize(inverse_normalize(image), (256, 256)))

In [None]:
# Ground-truth class of the sample displayed above.
label

### Definition of the setup
Here we define a helper to compute the accuracy, instantiate the `ConvNet` model (`ClassNet`) that we will use, and define the loss function and the optimizer.

In [None]:
def compute_accuracy(y_pred, y_true):
    """Top-1 accuracy of a batch of predictions.

    Args:
        y_pred: class scores with samples along dim 0 and classes along dim 1.
        y_true: target class indices, one per sample.

    Returns:
        A 0-d float tensor: the fraction of samples whose highest-scoring class equals
        the target.
    """
    top1 = torch.argmax(y_pred, dim=1, keepdim=True)
    n_correct = (top1 == y_true.view_as(top1)).sum()
    return n_correct.float() / top1.shape[0]

In [None]:
# Seed the global PyTorch RNG so weight initialisation and DataLoader shuffling are
# reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

NUM_CLASSES = 10  # MNIST digits 0-9

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = ClassNet(NUM_CLASSES).to(device)

optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
criterion = nn.CrossEntropyLoss()

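Count the trainable parameters of the whole model, of its feature extractor and of its classifier, then try to recover these numbers by hand.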

In [None]:
# Trainable parameter counts, printed one per line:
# whole model, feature extractor, classifier.
print(
    *(
        sum(param.numel() for param in part.parameters() if param.requires_grad)
        for part in (model, model.feature_extractor, model.classifier)
    ),
    sep="\n",
)

# manually


In [None]:
# Read the hyperparameters back from the objects that use them, so the logged config
# cannot drift from the real setup (the training and test loaders use different batch sizes).
wandb.init(project="cv_tuto",
           config={"batch_size": train_loader.batch_size,
                   "test_batch_size": test_loader.batch_size,
                   "optimizer": "sgd",
                   "lr": optimizer.param_groups[0]["lr"],
                   "momentum": optimizer.param_groups[0]["momentum"],
                   "dataset": "mnist"
                  })

### Definition of a training loop
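We define a training loop that iterates over the training data and, for each batch, performs the forward pass, the backward pass and an optimizer step. A test function then evaluates the model on the test set after each epoch.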

In [None]:
def train_one_epoch(model, optimizer, train_loader, epoch, device, criterion):
    """Train ``model`` for one pass over ``train_loader``.

    Progress (loss, learning rate, throughput) is printed every 200 iterations, and the
    training loss is logged to wandb every 20 iterations.

    Args:
        model: network to train; switched to training mode.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: iterable of ``(inputs, targets)`` batches; must support ``len()``.
        epoch: epoch index, only used in the printed header.
        device: device the batches are moved to.
        criterion: loss function called as ``criterion(outputs, targets)``.
    """
    model.train()

    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", SmoothedValue(window_size=10, fmt="{value}"))

    header = f"Epoch: [{epoch}]"

    for i, (inputs, targets) in enumerate(metric_logger.log_every(train_loader, 200, header)):
        start_time = time.time()
        # non_blocking lets host->GPU copies from pinned memory overlap with compute.
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # A single device->host sync per step, reused by every logger below.
        loss_value = loss.item()
        batch_size = inputs.shape[0]
        metric_logger.update(loss=loss_value, lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))

        if i % 20 == 0:
            wandb.log({"train_loss": loss_value})
    
def test(model, test_loader, device, criterion=None, critetion=None):
    """Evaluate ``model`` on ``test_loader`` and report top-1 accuracy.

    Prints the accuracy as a fraction and logs it to wandb as a percentage
    (``"test_acc1"``), together with the average test loss (``"test_loss"``).

    Args:
        model: network to evaluate; switched to eval mode.
        test_loader: iterable of ``(inputs, targets)`` batches; must support ``len()``.
        device: device the batches are moved to.
        criterion: loss function called as ``criterion(outputs, targets)``.
        critetion: legacy misspelled alias of ``criterion``, kept so existing keyword
            callers keep working.

    Returns:
        Top-1 accuracy over the whole test set, as a fraction in [0, 1].

    Raises:
        TypeError: if no loss function is given.
    """
    # Use the loss passed by the caller (never a same-named global variable).
    if criterion is None:
        criterion = critetion
    if criterion is None:
        raise TypeError("test() missing required argument: 'criterion'")

    model.eval()

    metric_logger = MetricLogger(delimiter="  ")

    header = "Test: "

    with torch.no_grad():
        for inputs, targets in metric_logger.log_every(test_loader, 10, header):
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            outputs = model(inputs)

            loss = criterion(outputs, targets)
            acc = compute_accuracy(outputs, targets)

            batch_size = inputs.shape[0]
            # Weight both meters by batch size so their global averages are per-sample
            # even when the last batch is smaller (assumes criterion averages over the batch).
            metric_logger.meters["loss"].update(loss.item(), n=batch_size)
            metric_logger.meters["acc1"].update(acc.item(), n=batch_size)

    acc1 = metric_logger.acc1.global_avg
    print(f"{header} Acc@1 {acc1:.3f}")
    wandb.log({"test_acc1": 100 * acc1, "test_loss": metric_logger.loss.global_avg})
    return acc1
    
def train(num_epoch=10):
    """Alternate one training epoch and one test-set evaluation, ``num_epoch`` times.

    Relies on the notebook-level ``model``, ``optimizer``, ``train_loader``,
    ``test_loader``, ``device`` and ``criterion``.
    """
    for epoch_idx in range(num_epoch):
        train_one_epoch(
            model=model,
            optimizer=optimizer,
            train_loader=train_loader,
            epoch=epoch_idx,
            device=device,
            criterion=criterion,
        )
        # Evaluate after every epoch.
        test(model, test_loader, device, criterion)

In [None]:
# Run train() with its default of 10 epochs, evaluating on the test set after each one.
train()

In [None]:
# Close the wandb run opened by wandb.init earlier in the notebook.
wandb.finish()