In [None]:
#
# Project:
#      PyTorch Dojo (https://github.com/wo3kie/ml-dojo)
#
# Author:
#      Lukasz Czerwinski (https://www.lukaszczerwinski.pl/)
#

In [None]:
from torch import float32, tensor
from torch.nn import BCELoss, Linear, Sequential, Sigmoid
from torch.optim import SGD


import import_ipynb
from common import assert_eq, assert_ge, assert_lt, Patient, T # type: ignore


def logistic_regression_sgd_autograd(X, y, epochs=2000, lr=0.1) -> tuple[float, Sequential]:
    """
    Fit a logistic regression model using autograd and the SGD optimizer.

    The model is a single linear layer followed by a sigmoid. Every epoch runs
    one forward/backward pass over the whole of X and one optimizer step on the
    mean binary cross-entropy loss.

    Parameters:
        X: Input features of shape (Samples, Features)
        y: Target values of shape (Samples, 1)
        epochs: Number of training epochs
        lr: Learning rate

    Returns:
        A tuple (loss, model). The loss comes from the last forward pass, so it
        is measured before the final parameter update; if no epoch is run
        (epochs <= 0) it is the loss of the untrained model. The model is
        callable: given new input data it returns predicted probabilities.
    """
    (s, f) = X.shape

    model = Sequential(Linear(f, 1), Sigmoid())
    optimizer = SGD(model.parameters(), lr=lr)

    # The criterion holds no per-step state, so build it once, not every epoch.
    criterion = BCELoss(reduction='mean')

    loss = None

    for _ in range(epochs):
        optimizer.zero_grad()

        predicted = model(X)
        assert_eq(predicted.shape, (s, 1))

        loss = criterion(predicted, y)
        assert_eq(loss.shape, ())

        loss.backward()
        optimizer.step()

    if loss is None:
        # No epoch was executed: report the loss of the untrained model
        # instead of failing on an unbound `loss`.
        loss = criterion(model(X), y)

    return (loss.item(), model)


def _test_logistic_regression_sgd_autograd(epochs=2000, lr=0.1) -> None:
    """
    Train on 80 generated patients and check predictions on 20 fresh ones.

    Each row built from `Patient(...).data` is treated as features followed by
    the label in the last column. Column 0 is divided by 100 before it reaches
    the model, both for training and for evaluation.

    NOTE(review): presumably the argument of `Patient` controls how likely the
    label is to be positive, so `Patient(1.0)` should score >= 0.5 and
    `Patient(0.0)` should score < 0.5 -- confirm against `common`.

    Parameters:
        epochs: Number of training epochs forwarded to the trainer
        lr: Learning rate forwarded to the trainer
    """
    samples = T([Patient(0.5).data for _ in range(80)])

    features = samples[:, :-1]
    features[:, 0] /= 100 # Data scaling to make training numerically stable
    labels = samples[:, -1].unsqueeze(1)

    model = logistic_regression_sgd_autograd(features, labels, epochs, lr)[1]

    # Each entry pairs a Patient argument with the check its prediction must pass.
    expectations = ((1.0, assert_ge), (0.0, assert_lt))

    for (patient_arg, check) in expectations:
        for row in T([Patient(patient_arg).data for _ in range(10)]):
            row[0] /= 100 # The same data scaling as during training.
            check(model(row[:-1]), T(0.5))
    

# Run the self-test only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    _test_logistic_regression_sgd_autograd()