<a href="https://colab.research.google.com/github/tcapelle/classification-losses/blob/main/Classification_Losses.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Classification Losses
> This notebook accompanies this [report](https://wandb.ai/capecape/classification-techniques/reports/Classifiction-Losses-SoftMax-and-Cross-Entropy-what-s-the-deal---VmlldzoxODEwNTM5)

<!--- @wandbcode{classification-losses} -->

In [None]:
# Install (or upgrade, -U) the wandb client quietly (-q); `!` runs a shell command in Jupyter/Colab.
!pip install -Uqqq wandb

In [None]:
# Log to Weights & Biases (W&B): import the client and authenticate this session
# before any wandb.init() call below.
import wandb
wandb.login()

## PyTorch

Let's get some data first. We will use the same code as in our [example](https://wandb.me/intro).

In [None]:
#@title
import math
import torch, torchvision
import torch.nn as nn
import torchvision.transforms as T
from tqdm.notebook import tqdm

# Run on the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

def get_dataloader(is_train, batch_size, slice=5):
    """Return a DataLoader over every `slice`-th sample of an MNIST split.

    Used for both training and validation.

    Args:
        is_train: load the (shuffled) training split when true, else the
            (unshuffled) test split.
        batch_size: number of samples per batch.
        slice: keep one sample out of every `slice` to make the demo faster.
            The name shadows the builtin `slice`; it is kept for caller compatibility.

    Returns:
        A `torch.utils.data.DataLoader` yielding `(images, labels)` batches.
    """
    full_dataset = torchvision.datasets.MNIST(
        root=".", train=is_train, transform=T.ToTensor(), download=True
    )
    sub_dataset = torch.utils.data.Subset(
        full_dataset, indices=range(0, len(full_dataset), slice)
    )
    return torch.utils.data.DataLoader(
        dataset=sub_dataset,
        batch_size=batch_size,
        shuffle=bool(is_train),  # only the training data is shuffled
        # Pinned host memory only speeds up host->GPU copies; on a CPU-only
        # machine it does nothing useful, so enable it only when CUDA exists.
        pin_memory=torch.cuda.is_available(),
        num_workers=2,
    )

def get_model(last_layer=None):
    "Build a small MLP classifier for 28x28 inputs (10 outputs), optionally ending in `last_layer`, placed on `device`."
    hidden = 256
    n_classes = 10
    modules = [
        nn.Flatten(),
        nn.Linear(28 * 28, hidden),
        nn.BatchNorm1d(hidden),
        nn.ReLU(),
        nn.Linear(hidden, n_classes),
    ]
    # Without a last layer the model outputs raw logits.
    if last_layer:
        modules.append(last_layer)
    return nn.Sequential(*modules).to(device)

def validate_model(model, valid_dl, loss_func, log_images=False, batch_idx=0):
    """Evaluate `model` on `valid_dl` and return ``(val_loss, accuracy)``.

    Args:
        model: network to evaluate; it is switched to eval mode and left there.
        valid_dl: DataLoader over the validation set; batches are moved to `device`.
        loss_func: callable ``(outputs, labels) -> scalar loss tensor``. Each batch's
            loss is weighted by its batch size, so the result is a per-sample mean
            when `loss_func` averages over the batch (mean reduction).
        log_images, batch_idx: accepted for call-site compatibility but currently
            unused; nothing is logged to wandb here.

    Returns:
        ``(val_loss, accuracy)``: the mean loss (a tensor, since it accumulates
        `loss_func` outputs) and the fraction of correctly classified samples.
    """
    model.eval()
    n_samples = len(valid_dl.dataset)
    val_loss = 0.
    correct = 0
    with torch.inference_mode():
        # Wrap the loader itself (not enumerate(...)) so tqdm can show its length.
        for images, labels in tqdm(valid_dl, leave=False):
            images, labels = images.to(device), labels.to(device)

            # Forward pass ➡
            outputs = model(images)
            val_loss += loss_func(outputs, labels) * labels.size(0)

            # Compute accuracy and accumulate
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()

    return val_loss / n_samples, correct / n_samples

In [None]:
def train(model, loss_func, config):
    """Train `model` with Adam, validating after every epoch and logging to wandb.

    Every training step is logged to wandb except an epoch's last step. That step
    is logged together with the epoch's validation metrics, so both share one
    wandb step.

    Args:
        model: the network to train; batches are moved to the global `device`.
        loss_func: callable ``(outputs, labels) -> scalar loss tensor``.
        config: object exposing ``batch_size``, ``lr`` and ``epochs``
            (e.g. ``wandb.config``).
    """
    # Get the data
    train_dl = get_dataloader(is_train=True, batch_size=config.batch_size)
    valid_dl = get_dataloader(is_train=False, batch_size=2 * config.batch_size)
    n_steps_per_epoch = math.ceil(len(train_dl.dataset) / config.batch_size)

    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    # Training
    example_ct = 0
    for epoch in tqdm(range(config.epochs)):
        model.train()
        # Defaults in case the loader yields no batches; otherwise overwritten below.
        metrics = {}
        last_train_loss = float("nan")
        for step, (images, labels) in enumerate(tqdm(train_dl, leave=False)):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            train_loss = loss_func(outputs, labels)
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

            # Log a plain float rather than the loss tensor, which still
            # references the autograd graph.
            last_train_loss = train_loss.item()
            example_ct += len(images)
            metrics = {"train/train_loss": last_train_loss,
                       "train/epoch": (step + 1 + (n_steps_per_epoch * epoch)) / n_steps_per_epoch,
                       "train/example_ct": example_ct}

            if step + 1 < n_steps_per_epoch:
                # 🐝 Log train metrics to wandb (the epoch's last step is logged below)
                wandb.log(metrics)

        val_loss, accuracy = validate_model(model, valid_dl, loss_func, log_images=(epoch == (config.epochs - 1)))
        val_loss = float(val_loss)

        # 🐝 Log the last train step together with the validation metrics
        val_metrics = {"val/val_loss": val_loss,
                       "val/val_accuracy": accuracy}
        wandb.log({**metrics, **val_metrics})

        print(f"Train Loss: {last_train_loss:.3f}, Valid Loss: {val_loss:.3f}, Accuracy: {accuracy:.2f}")

In [None]:
# W&B project name passed to every wandb.init() call in this notebook.
PROJECT = "classification-techniques"

In [None]:
def run5(loss_func, last_layer, n_runs=5):
    """Train a fresh PyTorch model `n_runs` times (5 by default), one wandb run each.

    Args:
        loss_func: loss passed to `train`; its ``str()`` is recorded in the run config.
        last_layer: optional module appended to the model by `get_model`.
        n_runs: number of repetitions.
    """
    for _ in range(n_runs):
        # 🐝 initialise a wandb run
        wandb.init(
            project=PROJECT,
            config={
                "epochs": 10,
                "batch_size": 128,
                "lr": 1e-3,
                "loss_func": str(loss_func),
                "last_layer": str(last_layer),
                "framework": "PyTorch",
                })
        try:
            # Copy your config
            config = wandb.config

            # A simple MLP model.
            # NOTE: the same `last_layer` instance is reused in every run. That is fine
            # for parameter-free layers such as Softmax/LogSoftmax, but a layer with
            # weights would carry them over from one run to the next.
            model = get_model(last_layer=last_layer)

            # train the model with loss func
            train(model, loss_func, config)
        except BaseException:
            # 🐝 Mark the run as failed and close it so it is not left open,
            # then propagate the error.
            wandb.finish(exit_code=1)
            raise

        # 🐝 Close your wandb run
        wandb.finish()

### Baseline
> Run the model without any `last_layer`, feeding its raw logits to `CrossEntropyLoss`

In [None]:
# Baseline: the model's raw logits go straight into CrossEntropyLoss.
loss_func = nn.CrossEntropyLoss()
last_layer = None

run5(loss_func=loss_func, last_layer=last_layer)

### Common error
> Adding a `Softmax` layer in front of `CrossEntropyLoss`, which already applies log-softmax internally

In [None]:
# Common mistake: put a Softmax on top of the model and still use CrossEntropyLoss.
loss_func = nn.CrossEntropyLoss()
last_layer = nn.Softmax(dim=-1)

run5(loss_func=loss_func, last_layer=last_layer)

### NLL
> `NLLLoss` expects log-probabilities, so the model must end with `nn.LogSoftmax`

In [None]:
# NLLLoss consumes log-probabilities, so the model ends with LogSoftmax.
loss_func = nn.NLLLoss()
last_layer = nn.LogSoftmax(dim=-1)

run5(loss_func=loss_func, last_layer=last_layer)

### Focal Loss
> Check out this [excellent article](https://amaarora.github.io/2020/06/29/FocalLoss.html) by Aman

In [None]:
# from fastai
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Focal loss for multi-class classification, computed from raw logits.

    Per sample: ``loss = (1 - p_t) ** gamma * CE``. Here ``CE = -log p_t`` and
    ``p_t`` is the softmax probability of the true class. With ``weight``, CE
    becomes the class-weighted cross-entropy ``weight[target] * CE``, while
    ``p_t`` stays the plain (unweighted) probability.
    Adapted from fastai's ``FocalLoss``.

    Args:
        gamma: focusing exponent; ``gamma=0`` without ``weight`` reduces to cross-entropy.
        weight: optional per-class weights forwarded to ``F.cross_entropy``; it must be
            on the same device as the inputs.
        reduction: ``"mean"``, ``"sum"`` or ``"none"``. ``"mean"`` is a plain mean over
            samples, not normalized by the sum of weights like ``F.cross_entropy``.

    Raises:
        ValueError: if ``reduction`` is not one of the values above.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, gamma: float = 2.0, weight=None, reduction: str = 'mean') -> None:
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.gamma = gamma
        self.weight = weight
        self.reduction = reduction

    def forward(self, inp: torch.Tensor, targ: torch.Tensor):
        # p_t must come from the *unweighted* CE: exp(-w * CE) would equal p_t ** w.
        ce_loss = F.cross_entropy(inp, targ, reduction="none")
        p_t = torch.exp(-ce_loss)
        if self.weight is not None:
            ce_loss = F.cross_entropy(inp, targ, weight=self.weight, reduction="none")
        # Rounding can push p_t marginally above 1; clamping keeps a non-integer
        # gamma from producing NaN on a tiny negative base.
        loss = (1 - p_t).clamp(min=0) ** self.gamma * ce_loss
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss

In [None]:
# FocalLoss works on raw logits (it calls F.cross_entropy), so no extra last layer.
loss_func = FocalLoss()
last_layer = None

run5(loss_func=loss_func, last_layer=last_layer)

## Keras
> The same experiments in Keras

In [None]:
import random

import numpy as np
import tensorflow as tf
from wandb.keras import WandbCallback

# Simple Keras Model

def train_keras(loss_func, last_layer, config):
    """Train a small Keras MLP on a subset of MNIST, logging to wandb, and return its History.

    Args:
        loss_func: Keras loss. It must match ``last_layer``: use ``from_logits=True``
            when the final Dense layer has no activation, and ``from_logits=False``
            after ``"softmax"``.
        last_layer: activation of the final Dense layer, e.g. ``"softmax"`` or ``None``.
        config: object exposing ``epochs`` and ``batch_size`` (e.g. ``wandb.config``).

    Returns:
        The ``History`` object returned by ``model.fit`` (callers may ignore it).
    """
    # Get the data and scale pixel values by 1/255
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    # Subset data for a faster demo
    x_train, y_train = x_train[::5], y_train[::5]
    x_test, y_test = x_test[::20], y_test[::20]

    # Build a model
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(256, activation="relu"),
        tf.keras.layers.Dense(10, activation=last_layer),
    ])
    model.compile(optimizer="adam", loss=loss_func, metrics=["accuracy"])

    # WandbCallback auto-saves all metrics from model.fit(), plus predictions on validation_data
    logging_callback = WandbCallback()

    return model.fit(
        x=x_train, y=y_train,
        # Read "epochs", the key the PyTorch runs also use; "epoch" is only a duplicate.
        epochs=config.epochs,
        batch_size=config.batch_size,
        validation_data=(x_test, y_test),
        callbacks=[logging_callback],
    )

In [None]:
def run5_keras(loss_func, last_layer, n_runs=5):
    """Train the Keras model `n_runs` times (5 by default), one wandb run each.

    Args:
        loss_func: Keras loss passed to `train_keras`; its ``str()`` goes into the config.
        last_layer: activation of the final Dense layer (e.g. ``"softmax"`` or ``None``).
        n_runs: number of repetitions.
    """
    for _ in range(n_runs):
        wandb.init(
            project=PROJECT,
            # Set entity to specify your username or team name
            # ex: entity="wandb",
            config={
                "epochs": 10,
                "batch_size": 128,
                "last_layer": str(last_layer),
                "loss_func": str(loss_func),
                "metric": "accuracy",
                # Same value as "epochs"; kept because train_keras may read config.epoch.
                "epoch": 10,
                "framework": "Keras",
            })
        try:
            config = wandb.config
            train_keras(loss_func, last_layer, config)
        except BaseException:
            # Mark the run as failed and close it so it is not left open,
            # then propagate the error.
            wandb.finish(exit_code=1)
            raise

        wandb.finish()

### Softmax activation in the last Dense layer

In [None]:
# The last Dense layer outputs probabilities, so the loss must not expect logits.
last_layer = "softmax"
loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

run5_keras(loss_func=loss_func, last_layer=last_layer)

### From logits (without a Softmax layer)

In [None]:
# The last Dense layer has no activation (raw logits); the loss is told so via from_logits.
last_layer = None
loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

run5_keras(loss_func=loss_func, last_layer=last_layer)