# SGD vs RMSProp vs Adam for logistic regression on Fashion-MNIST

## References

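* Robbins & Monro 1951, *A Stochastic Approximation Method* -> SGD
* Tieleman & Hinton 2012, *Neural Networks for Machine Learning* (Coursera), Lecture 6.5 -> RMSProp
* Kingma & Ba 2014, [Adam: A Method for Stochastic Optimization](http://arxiv.org/abs/1412.6980) -> Adam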

In [None]:
# Automatically reload modified modules (e.g. the random_neural_net_models
# package) before each cell runs, so library edits are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import random
import typing as T
from collections import defaultdict
from enum import Enum

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchinfo
import tqdm
from einops.layers.torch import Rearrange
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset

import random_neural_net_models.convolution_lecun1990 as conv_lecun1990
import random_neural_net_models.telemetry as telemetry
import random_neural_net_models.utils as rnnm_utils

# Apply seaborn's default theme to all subsequent matplotlib plots.
sns.set_theme()

In [None]:
# Seed the random number generators via the project helper so runs are
# reproducible (presumably Python, NumPy and torch -- TODO confirm).
rnnm_utils.make_deterministic(42)

In [None]:
# Choose the torch device via the project helper; the bare name below
# displays the choice in the notebook output.
device = rnnm_utils.get_device()
device

## Preparing the Fashion-MNIST data

In [None]:
# Fetch the dataset from OpenML (cache=True keeps a local copy after the first
# download). Note: this is "Fashion-MNIST", not the original MNIST digits.
mnist = fetch_openml("Fashion-MNIST", version=1, cache=True, parser="auto")

In [None]:
# Features and labels from the OpenML bunch; the trailing expression shows
# their shapes in the notebook output.
X = mnist["data"]
y = mnist["target"]
X.shape, y.shape

In [None]:
# Hold out 20% of the samples as the test split; random_state fixes the split.
X0, X1, y0, y1 = train_test_split(X, y, test_size=0.2, random_state=42)

In [None]:
# Wrap both splits in the project's dataset class from the LeCun-1990
# convolution module. Presumably it turns each row into a 28x28 image tensor
# and encodes the label as a class index -- TODO confirm in
# convolution_lecun1990.DigitsDataset.
ds = conv_lecun1990.DigitsDataset(X0, y0)
ds_test = conv_lecun1990.DigitsDataset(X1, y1)

In [None]:
# Shuffled mini-batches for training; larger, fixed-order batches for
# evaluation (do_epoch computes the test loss over the whole test split).
batch_size = 128
dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True)
dataloader_test = DataLoader(ds_test, batch_size=500, shuffle=False)

## `LogisticRegression` model & `do_epoch`

In [None]:
# logistic regression model


class LogisticRegression(nn.Module):
    """Multinomial logistic regression on ``h x w`` inputs.

    Every input is flattened to ``h * w`` features and mapped by a single
    affine layer to ``output_dim`` unnormalized class scores (logits).
    """

    def __init__(self, h: int, w: int, output_dim: int):
        super().__init__()
        n_features = h * w
        # Flatten (batch, h, w) -> (batch, h * w); einops checks h and w.
        self.rectangle2flat = Rearrange("b h w -> b (h w)", h=h, w=w)
        self.linear = nn.Linear(n_features, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (batch, h, w) tensor to (batch, output_dim) logits."""
        flat = self.rectangle2flat(x)
        logits = self.linear(flat)
        return logits


def do_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    dataloader_test: DataLoader,
    opt: optim.Optimizer,
    _iter: int,
) -> int:
    """Run one training pass over ``dataloader``, then one validation pass.

    ``model`` must provide the telemetry callbacks used below
    (``loss_history_train``, ``parameter_history``, ``loss_history_test``),
    as the ``telemetry.ModelTelemetry`` wrapper used in this notebook does.
    Batches are moved to the module-level ``device``.

    Args:
        model: the (telemetry-wrapped) model to train.
        dataloader: training batches of ``(inputs, targets)``; targets must be
            accepted by ``F.cross_entropy`` (presumably class indices).
        dataloader_test: validation batches in the same format.
        opt: optimizer over ``model``'s parameters.
        _iter: global training-step counter at the start of this epoch.

    Returns:
        The step counter after the epoch, i.e. ``_iter`` plus the number of
        training batches. The validation loss is logged at that step.
    """
    was_training = model.training

    # training part -- train() so mode-dependent layers (dropout, batch-norm)
    # behave correctly if a model containing them is passed in
    model.train()
    for X_batch, y_batch in dataloader:
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)

        opt.zero_grad()
        logits = model(X_batch)

        loss = F.cross_entropy(logits, y_batch)
        loss.backward()
        opt.step()
        model.loss_history_train(loss, _iter)
        model.parameter_history(_iter)

        _iter += 1

    # validation part -- eval() disables dropout and makes batch-norm use its
    # running statistics instead of updating them with test data
    model.eval()
    with torch.no_grad():
        all_logits, all_targets = [], []
        for X_batch, y_batch in dataloader_test:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)
            logits = model(X_batch)
            all_logits.append(logits)
            all_targets.append(y_batch)

        # one cross-entropy over the whole test split (mean over all samples)
        all_logits = torch.cat(all_logits, dim=0)
        all_targets = torch.cat(all_targets, dim=0)
        loss = F.cross_entropy(all_logits, all_targets)
        model.loss_history_test(loss, _iter)

    model.train(was_training)  # leave the model in the caller's mode
    return _iter

In [None]:
def do_training(
    model: nn.Module,
    n_epochs: int,
    dataloader: DataLoader,
    dataloader_test: DataLoader,
    opt: optim.Optimizer,
):
    """Train ``model`` for ``n_epochs`` epochs, validating after each one.

    A global step counter starting at 0 is threaded through successive
    ``do_epoch`` calls, so logged losses are indexed by training step across
    epoch boundaries. A tqdm progress bar labelled "Epoch" tracks progress.
    """
    step_counter = 0
    epoch_bar = tqdm.tqdm(range(n_epochs), total=n_epochs, desc="Epoch")
    for _epoch in epoch_bar:
        step_counter = do_epoch(
            model, dataloader, dataloader_test, opt, step_counter
        )

## Optimizers

In [None]:
class MySGD(optim.Optimizer):
    """Plain (vanilla) stochastic gradient descent.

    Update per parameter ``p`` with gradient ``g``: ``p <- p - lr * g``.

    Args:
        params: iterable of tensors (or param-group dicts) to optimize,
            e.g. ``model.parameters()``.
        lr: learning rate.
    """

    def __init__(self, params: T.Iterable, lr: float = 0.01):
        super().__init__(params, {"lr": lr})

    @torch.no_grad()
    def step(self, closure: T.Optional[T.Callable[[], T.Any]] = None):
        """Apply one update to every parameter that received a gradient.

        Parameters whose ``.grad`` is ``None`` (unused in the last backward
        pass, or frozen) are skipped, as the built-in ``torch.optim``
        optimizers do.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss (the ``torch.optim.Optimizer.step`` contract).

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            for p in group["params"]:
                if p.grad is not None:
                    p.add_(p.grad, alpha=-lr)
        return loss


class MySGDWithMomentum(optim.Optimizer):
    """SGD with an exponential-moving-average momentum buffer.

    Update per parameter ``p`` with gradient ``g``::

        m <- momentum * m + (1 - momentum) * g
        p <- p - lr * m

    This "dampened" form differs from ``torch.optim.SGD(momentum=...)``,
    which accumulates ``m <- momentum * m + g``. For a constant gradient,
    torch's version therefore takes steps about ``1 / (1 - momentum)`` times
    larger at the same ``lr``.

    References:
        https://paperswithcode.com/method/sgd-with-momentum
        https://towardsdatascience.com/stochastic-gradient-descent-with-momentum-a84097641a5d

    Args:
        params: iterable of tensors (or param-group dicts) to optimize.
        lr: learning rate.
        momentum: decay rate of the moving average of gradients.
    """

    def __init__(
        self, params: T.Iterable, lr: float = 0.01, momentum: float = 0.9
    ):
        super().__init__(params, {"lr": lr})
        self.momentum = momentum
        for group in self.param_groups:
            for p in group["params"]:
                self.state[p]["momentum"] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: T.Optional[T.Callable[[], T.Any]] = None):
        """Apply one momentum update to every parameter with a gradient.

        Parameters whose ``.grad`` is ``None`` are skipped and keep their
        momentum buffer unchanged.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss (the ``torch.optim.Optimizer.step`` contract).

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        beta = self.momentum
        for group in self.param_groups:
            lr = group["lr"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                # parameters added later via add_param_group have no buffer yet
                if "momentum" not in state:
                    state["momentum"] = torch.zeros_like(p)
                buf = state["momentum"]
                buf.mul_(beta).add_(p.grad, alpha=1 - beta)
                p.add_(buf, alpha=-lr)
        return loss


class MyAdam(optim.Optimizer):
    """Adam (Kingma & Ba, arXiv:1412.6980, Algorithm 1).

    Update per parameter ``p`` with gradient ``g`` at step ``t = 1, 2, ...``::

        m <- beta1 * m + (1 - beta1) * g
        v <- beta2 * v + (1 - beta2) * g**2
        m_hat = m / (1 - beta1**t)
        v_hat = v / (1 - beta2**t)
        p <- p - alpha * m_hat / (sqrt(v_hat) + eps)

    The bias correction depends on the per-parameter step count ``t``.
    Dividing by ``1 - beta`` instead is only correct at ``t = 1``.

    Args:
        params: iterable of tensors (or param-group dicts) to optimize.
        alpha: step size (``lr`` in ``torch.optim.Adam``).
        betas: decay rates ``(beta1, beta2)`` of the first and second moment
            estimates.
        eps: added to the denominator for numerical stability.
    """

    def __init__(
        self,
        params: T.Iterable,
        alpha: float = 0.001,
        betas: T.Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
    ):
        super().__init__(
            params, {"alpha": alpha, "eps": eps, "betas": betas}
        )
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                state["step"] = 0
                state["momentum_m"] = torch.zeros_like(p)
                state["momentum_v"] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: T.Optional[T.Callable[[], T.Any]] = None):
        """Apply one Adam update to every parameter with a gradient.

        Parameters whose ``.grad`` is ``None`` are skipped. Their moment
        estimates and step count stay unchanged.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss (the ``torch.optim.Optimizer.step`` contract).

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            alpha, eps = group["alpha"], group["eps"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]
                # parameters added later via add_param_group have no state yet
                if "momentum_m" not in state:
                    state["momentum_m"] = torch.zeros_like(p)
                    state["momentum_v"] = torch.zeros_like(p)
                state["step"] = state.get("step", 0) + 1
                t = state["step"]

                m = state["momentum_m"]
                v = state["momentum_v"]
                m.mul_(beta1).add_(grad, alpha=1 - beta1)
                v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # bias correction for the zero-initialised moment estimates
                m_hat = m / (1 - beta1**t)
                v_hat = v / (1 - beta2**t)
                p.sub_(alpha * m_hat / (v_hat.sqrt() + eps))
        return loss


class MyRMSProp(optim.Optimizer):
    """RMSProp: scale each gradient by a running RMS of its past values.

    Update per parameter ``p`` with gradient ``g``::

        s <- momentum * s + (1 - momentum) * g**2
        p <- p - lr * g / (sqrt(s) + eps)

    Note: despite its name, ``momentum`` here is the decay rate of the
    squared-gradient average. It corresponds to ``alpha`` in
    ``torch.optim.RMSprop``, not to torch's separate momentum buffer. The
    keyword name is kept for backward compatibility.

    Reference: https://optimization.cbe.cornell.edu/index.php?title=RMSProp

    Args:
        params: iterable of tensors (or param-group dicts) to optimize.
        lr: learning rate.
        momentum: decay rate of the running average of squared gradients.
        eps: added to the denominator for numerical stability.
    """

    def __init__(
        self,
        params: T.Iterable,
        lr: float = 0.001,
        momentum: float = 0.9,
        eps: float = 1e-8,
    ):
        super().__init__(
            params, {"lr": lr, "eps": eps, "momentum": momentum}
        )
        for group in self.param_groups:
            for p in group["params"]:
                # running average of squared gradients (state key kept as-is)
                self.state[p]["momentum"] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: T.Optional[T.Callable[[], T.Any]] = None):
        """Apply one RMSProp update to every parameter with a gradient.

        Parameters whose ``.grad`` is ``None`` are skipped.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss (the ``torch.optim.Optimizer.step`` contract).

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            decay, lr, eps = group["momentum"], group["lr"], group["eps"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]
                # parameters added later via add_param_group have no state yet
                if "momentum" not in state:
                    state["momentum"] = torch.zeros_like(p)
                square_avg = state["momentum"]
                square_avg.mul_(decay).addcmul_(grad, grad, value=1 - decay)
                p.addcdiv_(grad, square_avg.sqrt().add_(eps), value=-lr)
        return loss

## Training with a single optimizer

In [None]:
# Logistic regression from 28x28 inputs to 10 outputs (one logit per class).
model = LogisticRegression(28, 28, 10)
# Wrap the model in the project's telemetry module. Judging by the argument
# names, it logs the "total" loss on every train/test call and tracks the
# parameters whose names match "linear" -- TODO confirm in
# random_neural_net_models.telemetry.
model = telemetry.ModelTelemetry(
    model,
    loss_names=("total",),
    loss_train_every_n=1,
    loss_test_every_n=1,
    parameters_name_patterns=("linear",),
)
# Move to the selected device and switch to float64; assumes the dataset
# yields float64 images -- TODO confirm.
model = model.to(device).double()

# define the optimizer (keep exactly one of the alternatives below active)
# opt = MySGD(
#     model.parameters(),
#     lr=0.01,
# )
# opt = MySGDWithMomentum(
#     model.parameters(),
#     lr=0.01,
#     momentum=0.9,
# )
# opt = MyAdam(
#     model.parameters(),
#     alpha=0.001,
#     betas=(0.9, 0.999),
# )
opt = MyRMSProp(
    model.parameters(),
    lr=0.01,
    # MyRMSProp's "momentum" is the decay rate of its squared-gradient
    # average, not a momentum buffer.
    momentum=0.9,
)

In [None]:
# Short run: 2 epochs with the optimizer chosen above, validating after each.
n_epochs = 2
do_training(model, n_epochs, dataloader, dataloader_test, opt)

In [None]:
# Plot the recorded train and test loss curves via the telemetry helpers.
model.draw_loss_history_train()
model.draw_loss_history_test()

## Training with multiple optimizers

In [None]:
class OptimizerType(Enum):
    """Names of the optimizers this notebook can build.

    Every member's value equals its name, so ``OptimizerType("Adam")`` maps a
    plain string key (as used in ``optimizer_params``) to its member.
    """

    # built-in PyTorch optimizers
    SGD = "SGD"
    Adam = "Adam"
    RMSProp = "RMSProp"
    # hand-written implementations from this notebook
    MySGD = "MySGD"
    MySGDWithMomentum = "MySGDWithMomentum"
    MyAdam = "MyAdam"
    MyRMSProp = "MyRMSProp"


def get_optimizer(
    name: T.Union[OptimizerType, str],
    model_params: T.Iterable,
    optimizer_params: dict,
) -> optim.Optimizer:
    """Instantiate the optimizer identified by ``name``.

    Args:
        name: an ``OptimizerType`` member, or its string value (e.g. ``"Adam"``).
        model_params: parameters to optimize, e.g. ``model.parameters()``.
        optimizer_params: keyword arguments forwarded to the optimizer's
            constructor.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: if ``name`` does not identify a known optimizer.
    """
    # one constructor per OptimizerType member
    constructors = {
        OptimizerType.SGD: optim.SGD,
        OptimizerType.Adam: optim.Adam,
        OptimizerType.RMSProp: optim.RMSprop,
        OptimizerType.MySGD: MySGD,
        OptimizerType.MySGDWithMomentum: MySGDWithMomentum,
        OptimizerType.MyAdam: MyAdam,
        OptimizerType.MyRMSProp: MyRMSProp,
    }
    try:
        # accepts a member (returned as-is) or its string value
        optimizer_type = OptimizerType(name)
    except ValueError:
        raise ValueError(f"Unknown optimizer {name}") from None

    constructor = constructors.get(optimizer_type)
    if constructor is None:  # enum member not wired up in the table above
        raise ValueError(f"Unknown optimizer {name}")
    return constructor(model_params, **optimizer_params)


def train_with_multiple_optimizers(
    n_epochs: int,
    dataloader: DataLoader,
    dataloader_test: DataLoader,
    optimizer_params: T.Dict[str, dict],
    seed: T.Optional[int] = None,
) -> T.Dict[str, telemetry.ModelTelemetry]:
    """Train one fresh logistic-regression model per optimizer configuration.

    Args:
        n_epochs: number of passes over ``dataloader`` for each model.
        dataloader: training batches.
        dataloader_test: validation batches, evaluated after every epoch.
        optimizer_params: maps an ``OptimizerType`` value (e.g. ``"Adam"``) to
            the keyword arguments for that optimizer's constructor.
        seed: if given, ``torch.manual_seed(seed)`` is called before each
            model is built. Every optimizer then starts from the same initial
            weights, and a shuffling DataLoader without its own generator
            produces the same batch order, which makes the comparison fair.
            ``None`` (default) keeps consuming the global RNG stream, so runs
            differ in initialisation and data order.

    Returns:
        Dict mapping each optimizer name to its trained, telemetry-wrapped
        model, in the iteration order of ``optimizer_params``.
    """
    models = {}
    # distinct loop name: don't shadow the ``optimizer_params`` argument
    for optimizer_name, opt_kwargs in optimizer_params.items():
        if seed is not None:
            torch.manual_seed(seed)

        # define the model
        model = LogisticRegression(28, 28, 10)
        model = telemetry.ModelTelemetry(
            model,
            loss_names=("total",),
            loss_train_every_n=1,
            loss_test_every_n=1,
            parameters_name_patterns=("linear",),
        )
        model = model.to(device).double()

        # define the optimizer
        opt = get_optimizer(
            OptimizerType(optimizer_name), model.parameters(), opt_kwargs
        )

        do_training(model, n_epochs, dataloader, dataloader_test, opt)

        models[optimizer_name] = model

    return models


def get_rolling_loss_df(
    models: T.Dict[str, telemetry.ModelTelemetry], group: str = "train"
) -> pd.DataFrame:
    """Stack the rolling-mean loss histories of several models into one frame.

    Args:
        models: maps an optimizer name to its telemetry-wrapped model.
        group: which history to read, ``"train"`` or ``"test"``; selects the
            model attribute ``loss_history_<group>``.

    Returns:
        Concatenated rolling-mean frames with an added ``optimizer`` column and
        a fresh 0..n-1 index.
    """
    history_attr = f"loss_history_{group}"
    frames = []
    for optimizer_name, model in models.items():
        rolling = getattr(model, history_attr).get_rolling_mean_df()
        frames.append(rolling.assign(optimizer=optimizer_name))
    return pd.concat(frames, ignore_index=True)


def plot_losses_for_optimizers(
    models: T.Dict[str, telemetry.ModelTelemetry],
    alpha: float = 0.5,
    figsize: T.Tuple[int, int] = (10, 7),
):
    """Plot rolling-mean train (top) and test (bottom) losses per optimizer.

    Args:
        models: maps an optimizer name to its telemetry-wrapped model; each
            name becomes one hue in the line plots.
        alpha: line transparency.
        figsize: matplotlib figure size in inches.
    """
    _fig, axes = plt.subplots(nrows=2, figsize=figsize)

    panels = (("train", "Train loss"), ("test", "Test loss"))
    for ax, (group, title) in zip(axes, panels):
        sns.lineplot(
            data=get_rolling_loss_df(models, group),
            x="iter",
            y="total",
            hue="optimizer",
            ax=ax,
            alpha=alpha,
        )
        ax.set_title(title)

    plt.tight_layout()

In [None]:
# Constructor keyword arguments per optimizer, keyed by OptimizerType value.
# Where possible each torch optimizer runs the same update rule as its
# hand-written "My*" counterpart:
# - RMSProp: torch's `alpha` is the squared-gradient decay rate, which
#   MyRMSProp calls `momentum`. Torch's own `momentum` would add a separate
#   momentum buffer that MyRMSProp lacks, so it stays at its default (0).
# - SGD: NOTE torch accumulates momentum as v = 0.9 v + g (Nesterov here),
#   whereas MySGDWithMomentum uses v = 0.9 v + 0.1 g, so these two remain
#   different algorithms even at the same lr.
optimizer_params = {
    "SGD": {"lr": 0.1, "momentum": 0.9, "nesterov": True},
    "Adam": {"lr": 0.001, "betas": (0.9, 0.999)},
    "RMSProp": {"lr": 0.001, "alpha": 0.9},
    "MySGD": {"lr": 0.1},
    "MySGDWithMomentum": {"lr": 0.1, "momentum": 0.9},
    "MyAdam": {"alpha": 0.001, "betas": (0.9, 0.999)},
    "MyRMSProp": {"lr": 0.001, "momentum": 0.9},
}

In [None]:
# One epoch per configuration; each entry of optimizer_params gets its own
# freshly built model.
n_epochs = 1
models = train_with_multiple_optimizers(
    n_epochs, dataloader, dataloader_test, optimizer_params
)

In [None]:
# Rolling-mean train (top) and test (bottom) losses, one line per optimizer.
plot_losses_for_optimizers(models, figsize=(10, 10))