In [42]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import pandas as pd
import random

In [71]:

def split_to_inputs_and_labels(data):
  """Separate a 2-D table into (features, labels).

  Column 0 is treated as the label column; every remaining column is a feature.

  Args:
    data: 2-D array-like supporting NumPy-style slicing (rows are samples).

  Returns:
    Tuple ``(inputs, labels)`` where ``inputs`` is ``data[:, 1:]`` and
    ``labels`` is ``data[:, 0]``.
  """
  label_column = data[:, 0]
  feature_columns = data[:, 1:]
  return feature_columns, label_column

def labels_to_array(data, num_classes=3):
  """One-hot encode 1-based class labels.

  A label equal to ``k`` (for ``k`` in ``1..num_classes``) becomes a row with
  a 1.0 in column ``k - 1`` and 0.0 elsewhere. A label outside that range
  yields an all-zero row.

  Args:
    data: 1-D array-like of numeric class labels (e.g. 1.0, 2.0, 3.0).
    num_classes: number of classes / output columns (default 3).

  Returns:
    Float array of shape ``(len(data), num_classes)``. Empty input gives
    shape ``(0, num_classes)``.
  """
  classes = jnp.arange(1, num_classes + 1, dtype=float)
  # Broadcast (n, 1) against (num_classes,) so the whole encoding is one
  # vectorised op, instead of a Python-level loop with per-element dispatch.
  return (jnp.asarray(data)[:, None] == classes).astype(float)

# Data location and split parameters, collected here so they are easy to tune.
DATA_PATH = "../data/wine/wine.data"
SEED = 42
TEST_FRACTION = 0.15

# The file has no header row: column 0 is the class label, followed by the
# 13 measurement columns named below.
WINE_COLUMNS = [
    "Label",
    "Alcohol",
    "Malic acid",
    "Ash",
    "Alcalinity of ash",
    "Magnesium",
    "Total phenols",
    "Flavanoids",
    "Nonflavanoid phenols",
    "Proanthocyanins",
    "Color intensity",
    "Hue",
    "OD280/OD315 of diluted wines",
    "Proline",
]

df = pd.read_csv(
    DATA_PATH,
    header=None,
    index_col=None,
    names=WINE_COLUMNS,
)

# labels_to_array silently produces an all-zero row for any label outside
# {1, 2, 3}, so fail loudly here instead of training on corrupted targets.
assert df["Label"].isin([1, 2, 3]).all(), "unexpected class label in wine data"

data = df.to_numpy()
# Deterministic row shuffle; the result is a JAX array.
shuffled = jax.random.permutation(jax.random.PRNGKey(SEED), data)

# The first TEST_FRACTION of the shuffled rows is held out for testing;
# the remainder is used for training.
testdata_size = int(TEST_FRACTION * len(shuffled))
inputs_train, labels_train_raw = split_to_inputs_and_labels(shuffled[testdata_size:])
labels_train = labels_to_array(labels_train_raw)
inputs_test, labels_test_raw = split_to_inputs_and_labels(shuffled[:testdata_size])
labels_test = labels_to_array(labels_test_raw)

assert len(inputs_train) + len(inputs_test) == len(data)
# NOTE(review): no feature scaling is applied here. If the columns differ
# greatly in magnitude, consider standardising with training-set statistics
# before fitting a linear model.


In [72]:
class Classifier(eqx.Module):
    """Single-layer classifier: a linear map followed by softmax.

    Attributes:
        linear: affine layer mapping ``input_size`` features to
            ``output_size`` class scores.
    """

    # Annotated with the concrete layer type held by this module.
    linear: eqx.nn.Linear

    def __init__(self, input_size, output_size, key):
        """Create the classifier.

        Args:
            input_size: number of input features.
            output_size: number of classes.
            key: JAX PRNG key used to initialise the linear layer's parameters.
        """
        self.linear = eqx.nn.Linear(input_size, output_size, key=key)

    def logits(self, x):
        """Return unnormalised class scores for a single example ``x``.

        Prefer this over ``__call__`` when computing a cross-entropy loss
        (e.g. ``optax.softmax_cross_entropy`` or ``jax.nn.log_softmax``),
        because taking ``log`` of softmax probabilities can underflow to -inf.
        """
        return self.linear(x)

    def __call__(self, x):
        """Return class probabilities (softmax of the logits) for one example ``x``.

        ``eqx.nn.Linear`` operates on a single unbatched input vector; map over
        a batch with ``jax.vmap``.
        """
        return jax.nn.softmax(self.logits(x))