# A simple PyTorch-based MNL library

> Fit your Multinomial Logistic (MNL) Regression models with PyTorch

## Install

`pip install pytorch_mnl`

## How to use

Import the library:

In [22]:
import random

import numpy as np
import pandas as pd
import torch

from pytorch_mnl.core import *

Load the data (the Iris dataset):

In [23]:
# Read the Iris CSV and drop its `Id` column so it is not used as a feature.
data = (
    pd.read_csv("./data/Iris.csv")
    .drop(columns="Id")
)

In [24]:
# Peek at the first rows to confirm the columns loaded as expected.
data.head(n=5)

Unnamed: 0,SepalLengthCm,SepalWidthCm,PetalLengthCm,PetalWidthCm,Species
0,5.1,3.5,1.4,0.2,Iris-setosa
1,4.9,3.0,1.4,0.2,Iris-setosa
2,4.7,3.2,1.3,0.2,Iris-setosa
3,4.6,3.1,1.5,0.2,Iris-setosa
4,5.0,3.6,1.4,0.2,Iris-setosa


Choose the feature (x) and target (y) columns:

In [25]:
# Feature columns fed to the model, and the column holding the class label.
x_cols = [
    "SepalLengthCm",
    "SepalWidthCm",
    "PetalLengthCm",
    "PetalWidthCm",
]
target_col = "Species"

In [26]:
# Count the distinct classes in the target column; as with `len(unique())`,
# a missing value (if any) would be counted as its own class.
n_targets = data[target_col].nunique(dropna=False)
n_targets

3

In [27]:
# Convert the selected feature and target columns into PyTorch tensors.
X, y = prepare_data(
    data,
    x_cols=x_cols,
    target_col=target_col,
)

x_cols=['SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm'],
target_col='Species'


We get PyTorch tensors, ready to use!

In [28]:
# Show the container type of the features and of the target.
tuple(map(type, (X, y)))

(torch.Tensor, torch.Tensor)

Let's split the data into training and validation sets, choosing the percentage of rows held out for validation, and pick a batch size for fitting the model.

In [29]:
# Fix the random seeds before splitting and building the model, so the
# train/valid split and the training log below are reproducible on
# "Restart & Run All".
# NOTE(review): assumes pytorch_mnl draws its randomness from `random`,
# NumPy or torch — confirm against the library internals.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

VALID_PCT = 0.2  # fraction of rows held out for validation
BATCH_SIZE = 8

dls = DataLoaders.from_Xy(X, y, pct=VALID_PCT, batch_size=BATCH_SIZE)

Our data has 4 feature variables and 3 target classes, so we build a linear MNL with 4 input variables and 3 targets.

In [30]:
# Size the model from the data: number of feature columns, then number of classes.
n_features = len(x_cols)
model = LinearMNL(n_features, n_targets)

In [31]:
# Pair the data loaders with the model; training is driven by `learn.fit` below.
learn = Learner(
    dls,
    model,
)

In [32]:
# Train for a fixed number of epochs; validation loss and accuracy are printed per epoch.
n_epochs = 25
learn.fit(n_epochs)

epoch =   0, val_loss = 2.080, accuracy = 0.37
epoch =   1, val_loss = 1.909, accuracy = 0.37
epoch =   2, val_loss = 1.770, accuracy = 0.53
epoch =   3, val_loss = 1.655, accuracy = 0.70
epoch =   4, val_loss = 1.561, accuracy = 0.70
epoch =   5, val_loss = 1.482, accuracy = 0.73
epoch =   6, val_loss = 1.416, accuracy = 0.77
epoch =   7, val_loss = 1.360, accuracy = 0.77
epoch =   8, val_loss = 1.311, accuracy = 0.80
epoch =   9, val_loss = 1.269, accuracy = 0.83
epoch =  10, val_loss = 1.233, accuracy = 0.87
epoch =  11, val_loss = 1.200, accuracy = 0.87
epoch =  12, val_loss = 1.171, accuracy = 0.87
epoch =  13, val_loss = 1.145, accuracy = 0.87
epoch =  14, val_loss = 1.121, accuracy = 0.87
epoch =  15, val_loss = 1.099, accuracy = 0.90
epoch =  16, val_loss = 1.079, accuracy = 0.90
epoch =  17, val_loss = 1.061, accuracy = 0.90
epoch =  18, val_loss = 1.044, accuracy = 0.90
epoch =  19, val_loss = 1.028, accuracy = 0.93
epoch =  20, val_loss = 1.013, accuracy = 0.93
epoch =  21, 