In [1]:
from tropy.learn import fit_classifier, fit_classifier_onevsall, _inrad_eigenpair, predict_onevsall
from tropy.metrics import accuracy_multiple, veronese_feature_names, print_features_per_class
from tropy.ops import veronese
from tropy.utils import simplex_lattice_points
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

# Show arrays with 3 decimal places and without scientific notation so the
# printed vectors below stay readable.
np.set_printoptions(precision=3, suppress=True)

In [2]:
# Raw table; its `species` column is used later to select the rows of a class.
base_df = pd.read_csv('./data/IRIS.csv')

# Label-based column slice (inclusive on both ends): keep every column from
# `sepal_length` through `petal_width`.
df = base_df.loc[:, 'sepal_length':'petal_width']
features = list(df.columns)

# The order of this list fixes the order of the per-class matrices built below.
classes = ["Iris-setosa", "Iris-virginica", "Iris-versicolor"]

# Lattice points used by the most recent Veronese embedding. class_df() rebinds
# this module-level name as a side effect; it stays None until class_df() has
# been called with a non-None `size`.
lattice_points = None


def class_df(class_name, size=None):
  """Build the train/test matrices for a single class.

  Rows of `df` whose `species` label (read from `base_df`) contains
  `class_name` as a literal substring are split 80/20 with a fixed seed. Each
  part is returned transposed, so every column is one sample.

  Args:
    class_name: Label text to look for in `base_df["species"]`. It is matched
      literally, not as a regular expression.
    size: If not None, forwarded to `simplex_lattice_points(d, size)`, where
      `d` is the number of rows of the transposed training matrix. Both
      matrices are then passed through `veronese` with the resulting lattice
      points. If None, the raw measurements are returned unchanged.

  Returns:
    A `(Ctrain, Ctest)` tuple of float arrays.

  Side effects:
    Prints the shapes of both matrices. When `size` is not None, rebinds the
    module-level `lattice_points` to the list of points that was used.
  """
  global lattice_points

  # regex=False keeps characters such as '.', '+' or '(' in a label from being
  # interpreted as a pattern. na=False treats rows with a missing label as
  # non-matches instead of producing a NaN mask that cannot be used to index.
  class_mask = base_df["species"].str.contains(class_name, regex=False, na=False)
  df_class = df[class_mask]

  # Fixed random_state so the split is reproducible across runs.
  df_train, df_test = train_test_split(df_class, test_size=0.2, random_state=43)
  Ctrain = df_train.to_numpy(dtype=float).T
  Ctest = df_test.to_numpy(dtype=float).T

  if size is not None:
    d = Ctrain.shape[0]
    # Materialised as a list because the same points are reused for both
    # matrices here and read again later through the module-level name.
    lattice_points = list(simplex_lattice_points(d, size))
    Ctrain = veronese(lattice_points, Ctrain)
    Ctest = veronese(lattice_points, Ctest)

  print(Ctrain.shape, Ctest.shape)
  return Ctrain, Ctest

In [3]:
# Forwarded to class_df() by the next cell. A value of None would skip the
# Veronese embedding and keep the raw measurements.
size = 2

In [4]:
# One (train, test) pair per class, in the order given by `classes`.
# class_df() prints the shapes of each pair as it is built.
class_splits = [class_df(class_name, size) for class_name in classes]
Clist_train = [train_part for train_part, _ in class_splits]
Clist_test = [test_part for _, test_part in class_splits]

(16, 40) (16, 10)
(16, 40) (16, 10)
(16, 40) (16, 10)


In [5]:
# Computed from the training matrices only; `x` is reused when fitting the
# classifier in the next cell.
# NOTE(review): `_inrad_eigenpair` is underscore-prefixed, i.e. private by
# convention in tropy.learn — prefer a public entry point if one exists.
# N=50 presumably caps the number of iterations (the progress bar counts up to
# 50) — TODO confirm.
x, l = _inrad_eigenpair(Clist_train, N=50)
print("Apex:", x)
print("Eigval:", l)

 74%|███████▍  | 37/50 [00:00<00:00, 9252.33it/s]

Apex: [ 1.144  3.044 -0.056 -0.156 -3.156 -1.356  1.144  3.044 -0.056 -0.156
 -3.156 -1.356  4.644 -1.456  2.144 -4.256]
Eigval: -3.852619556710124e-08





In [6]:
# Classifier built from the training matrices and the apex `x` found above.
predictor, sector_indicator = fit_classifier(Clist_train, x)
# One-vs-all variant: takes only the training matrices and returns its own
# `apices` alongside the indicators.
indicators, apices = fit_classifier_onevsall(Clist_train)

 46%|████▌     | 23/50 [00:00<00:00, 8556.77it/s]
 82%|████████▏ | 41/50 [00:00<00:00, 9259.45it/s]
 74%|███████▍  | 37/50 [00:00<00:00, 11188.04it/s]


In [7]:
# Both classifiers are scored on the held-out test matrices.
accuracy = accuracy_multiple(predictor, Clist_test)

# NOTE(review): the one-vs-all predictor is built with Clist_train as its third
# argument while the score is taken on Clist_test — confirm against
# tropy.learn.predict_onevsall that training data is what it expects there.
onevsall_predictor = predict_onevsall(indicators, apices, Clist_train)
accuracy_one_vs_all = accuracy_multiple(onevsall_predictor, Clist_test)

print(f"Accuracy: {round(accuracy, 3)}")
print(f"Accuracy (one-vs-all): {round(accuracy_one_vs_all, 3)}")

Accuracy: 0.933
Accuracy (one-vs-all): 0.9


In [8]:
# `lattice_points` is the module-level value left behind by the class_df()
# calls above, so this cell only works after those calls have run with a
# non-None `size`.
veronese_names = veronese_feature_names(features, lattice_points)
print_features_per_class(classes, veronese_names, sector_indicator)

Dominant features for each class:
- Iris-setosa: 2*sepal_width
- Iris-virginica: sepal_length + petal_length, sepal_length + petal_width, petal_length + petal_width, 2*petal_length
- Iris-versicolor: sepal_length + sepal_width, sepal_width + petal_length, sepal_width + petal_width, 2*sepal_length, 2*petal_width
