In [None]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import pandas as pd

In [None]:
def split_to_inputs_and_labels(data):
    """Separate a 2-D table into its feature columns and its label column.

    Column 0 is treated as the label; every remaining column is a feature.

    Args:
        data: Two-dimensional array that supports ``data[:, i]`` indexing.

    Returns:
        Tuple ``(inputs, labels)``: all columns from index 1 onward, and
        column 0.
    """
    labels = data[:, 0]
    inputs = data[:, 1:]
    return inputs, labels


def labels_to_array(data, classes=(1.0, 2.0, 3.0)):
    """One-hot encode numeric class labels.

    Args:
        data: One-dimensional sequence or array of numeric class labels.
        classes: Label values in output-column order. Defaults to the three
            classes ``1.0``, ``2.0`` and ``3.0``.

    Returns:
        Float array of shape ``(len(data), len(classes))`` where entry
        ``[i, j]`` is 1.0 when ``data[i] == classes[j]`` and 0.0 otherwise.
        A label that matches none of ``classes`` yields an all-zero row.
    """
    labels = jnp.asarray(data)
    class_values = jnp.asarray(classes)
    # One broadcast comparison replaces a Python-level loop over array
    # elements, and keeps the (n, num_classes) shape even when n == 0.
    return (labels[:, None] == class_values).astype(float)


# Tunable settings for loading and splitting the wine data set.
WINE_CSV_PATH = "../data/wine/wine.data"
SPLIT_SEED = 42
TEST_FRACTION = 0.15

# The file has no header row; column 0 is the class label and the other
# 13 columns are the features.
WINE_COLUMNS = [
    "Label",
    "Alcohol",
    "Malic acid",
    "Ash",
    "Alcalinity of ash",
    "Magnesium",
    "Total phenols",
    "Flavanoids",
    "Nonflavanoid phenols",
    "Proanthocyanins",
    "Color intensity",
    "Hue",
    "OD280/OD315 of diluted wines",
    "Proline",
]

df = pd.read_csv(WINE_CSV_PATH, header=None, index_col=None, names=WINE_COLUMNS)

data = df.to_numpy()
# Shuffle the rows with a fixed key so the train/test split is reproducible.
shuffled = jax.random.permutation(jax.random.PRNGKey(SPLIT_SEED), data)

# The first `testdata_size` shuffled rows form the test set, the rest train.
testdata_size = int(TEST_FRACTION * len(shuffled))
test_rows = shuffled[:testdata_size]
train_rows = shuffled[testdata_size:]

inputs_train, labels_train_raw = split_to_inputs_and_labels(train_rows)
inputs_test, labels_test_raw = split_to_inputs_and_labels(test_rows)
labels_train = labels_to_array(labels_train_raw)
labels_test = labels_to_array(labels_test_raw)

In [None]:
class Classifier(eqx.Module):
    """Single linear layer that maps a feature vector to per-class scores."""

    # The only sub-module: the linear map built in ``__init__``.
    linear: eqx.Module

    def __init__(self, input_size, output_size, key):
        """Build the linear layer.

        Args:
            input_size: Number of input features.
            output_size: Number of output scores (one per class).
            key: JAX PRNG key passed to ``eqx.nn.Linear`` for initialisation.
        """
        self.linear = eqx.nn.Linear(
            in_features=input_size,
            out_features=output_size,
            key=key,
        )

    def __call__(self, x):
        """Return the raw (pre-softmax) scores of the linear layer for ``x``."""
        logits = self.linear(x)
        return logits

    def probs(self, x):
        """Return the softmax of the scores produced for ``x``."""
        logits = self(x)
        return jax.nn.softmax(logits)

In [None]:
def loss_fn(model, x, y):
    """Mean softmax cross-entropy of ``model`` over a batch.

    Args:
        model: Callable applied to one example at a time; it is mapped over
            the leading axis of ``x`` with ``jax.vmap``.
        x: Batch of inputs, one example per entry along the leading axis.
        y: Targets passed to ``optax.losses.softmax_cross_entropy`` together
            with the model outputs.

    Returns:
        The mean of the per-example cross-entropy values.
    """
    batched_model = jax.vmap(model)
    logits = batched_model(x)
    per_example_loss = optax.losses.softmax_cross_entropy(logits, y)
    return jnp.mean(per_example_loss)


def accuracy(model: Classifier, x, y):
    """Fraction of examples whose predicted class matches the target class.

    Args:
        model: Classifier whose ``probs`` method is mapped over the leading
            axis of ``x`` with ``jax.vmap``.
        x: Batch of inputs, one example per row.
        y: Per-example target rows; the index of each row's maximum is taken
            as the true class.

    Returns:
        Mean of the element-wise equality between true and predicted class
        indices.
    """
    predicted_probs = jax.vmap(model.probs)(x)
    predicted_class = jnp.argmax(predicted_probs, axis=1)
    true_class = jnp.argmax(y, axis=1)
    return jnp.mean(true_class == predicted_class)


def train(model, optim: optax.GradientTransformation, loss_fn, features, labels, steps):
    """Fit ``model`` with full-batch gradient steps and return the result.

    Every step uses the complete ``features``/``labels`` arrays. The training
    accuracy is printed on every 200th step (starting with step 0).

    Args:
        model: Equinox module to optimise. It is not modified in place; it
            must also be accepted by ``accuracy`` for the progress report.
        optim: Optax gradient transformation that turns gradients into updates.
        loss_fn: Callable ``loss_fn(model, features, labels)`` returning a
            scalar loss; it is differentiated with respect to ``model``.
        features: Inputs forwarded unchanged to ``loss_fn`` on every step.
        labels: Targets forwarded unchanged to ``loss_fn`` on every step.
        steps: Number of optimisation steps to run.

    Returns:
        The model after ``steps`` updates.
    """
    _model = model
    # Only floating-point array leaves are trainable. Filtering keeps the
    # optimiser state aligned with the gradients from ``eqx.filter_grad`` and
    # lets models with non-array leaves (e.g. activation functions) train.
    _state = optim.init(eqx.filter(_model, eqx.is_inexact_array))

    @eqx.filter_jit
    def make_step(model, state, features, labels):
        # filter_grad differentiates w.r.t. the inexact-array leaves of
        # ``model`` only; every other leaf receives a ``None`` gradient.
        grads = eqx.filter_grad(loss_fn)(model, features, labels)
        # The current parameters are passed so transformations that need them
        # (such as weight decay) can use them.
        params = eqx.filter(model, eqx.is_inexact_array)
        updates, state = optim.update(grads, state, params)
        model = eqx.apply_updates(model, updates)
        return model, state

    for step in range(steps):
        _model, _state = make_step(_model, _state, features, labels)
        if step % 200 == 0:
            acc = accuracy(_model, features, labels)
            print(f"step {step} -> accuracy: {acc}")
    return _model

In [None]:
# AdamW optimiser with a learning rate of 0.01.
optim = optax.adamw(0.01)

# 13 input features -> 3 class scores, initialised from a fixed PRNG key.
model = Classifier(13, 3, jax.random.PRNGKey(43))

model = train(
    model=model,
    optim=optim,
    loss_fn=loss_fn,
    features=inputs_train,
    labels=labels_train,
    steps=10000,
)

In [None]:
# Accuracy on the held-out test split; kept as the cell's final expression so
# the notebook displays the value.
accuracy(model, x=inputs_test, y=labels_test)