# A simple PyTorch-based MNL library

> Fit multinomial logistic regression (MNL) models with PyTorch.

## Install

`pip install pytorch_mnl`

## How to use

Import the library:

In [None]:
import pandas as pd
from pytorch_mnl.core import *

Load the data, dropping the `Id` column:

In [None]:
# Read the Iris CSV and drop the `Id` column, which is not one of the feature columns used below.
data = pd.read_csv("./data/Iris.csv").drop(columns="Id")

Choose the feature (x) columns and the target (y) column:

In [None]:
# Feature columns fed to the model, and the column holding the class label.
x_cols = [
    'SepalLengthCm',
    'SepalWidthCm',
    'PetalLengthCm',
    'PetalWidthCm',
]
target_col = 'Species'

Count the number of classes to predict:

In [None]:
# Number of distinct values in the target column, i.e. the number of classes.
n_targets = len(pd.unique(data[target_col]))
n_targets

3

In [None]:
# Convert the DataFrame into model inputs (X) and labels (y).
# The label -> index mapping shown in the output is produced by `prepare_data`.
# NOTE(review): the meaning of `av` is not visible here (presumably availability info) — confirm.
X, y, av = prepare_data(
    data,
    x_cols=x_cols,
    target_col=target_col,
)

{'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2}


We get PyTorch tensors, ready to use:

In [None]:
# Sanity check: confirm that the features and labels are torch tensors.
(type(X), type(y))

(torch.Tensor, torch.Tensor)

Let's split the data into training and validation sets, choosing the percentage held out for validation (`pct`) and the batch size used to fit the model:

In [None]:
# Hold out 20% of the data for validation and train in mini-batches of 8.
# NOTE(review): if the split is random, seed it so the logged results are reproducible.
dls = DataLoaders.from_Xy(
    X, y, av,
    pct=0.2,       # fraction of the data held out as the validation set
    batch_size=8,  # mini-batch size used during fitting
)

As our data has 4 feature variables, we fit a linear MNL with 4 inputs and 3 target classes:

In [None]:
# Size the model from the data: one input per feature column, then the number of classes.
model = LinearMNL(
    len(x_cols),
    n_targets,
)

In [None]:
# Combine the data loaders and the model into a Learner, which is used to train the model below.
learn = Learner(dls, model)

In [None]:
# Train the model; each logged line reports train loss, validation loss and accuracy for one epoch.
# NOTE(review): 25 is presumably the number of epochs — confirm against Learner.fit.
learn.fit(25)

epoch =   0, train_loss = 42.186, val_loss = 13.372, accuracy = 0.19
epoch =   1, train_loss = 20.678, val_loss = 8.456, accuracy = 0.14
epoch =   2, train_loss = 15.091, val_loss = 7.411, accuracy = 0.43
epoch =   3, train_loss = 13.755, val_loss = 6.893, accuracy = 0.68
epoch =   4, train_loss = 12.857, val_loss = 6.482, accuracy = 0.68
epoch =   5, train_loss = 12.131, val_loss = 6.147, accuracy = 0.68
epoch =   6, train_loss = 11.534, val_loss = 5.869, accuracy = 0.68
epoch =   7, train_loss = 11.036, val_loss = 5.635, accuracy = 0.68
epoch =   8, train_loss = 10.614, val_loss = 5.435, accuracy = 0.68
epoch =   9, train_loss = 10.253, val_loss = 5.263, accuracy = 0.68
epoch =  10, train_loss = 9.940, val_loss = 5.113, accuracy = 0.68
epoch =  11, train_loss = 9.665, val_loss = 4.980, accuracy = 0.69
epoch =  12, train_loss = 9.422, val_loss = 4.862, accuracy = 0.69
epoch =  13, train_loss = 9.205, val_loss = 4.755, accuracy = 0.70
epoch =  14, train_loss = 9.009, val_loss = 4.659, 