In [1]:
from tropy.svm import TropicalSVC
from tropy.metrics import veronese_feature_names, print_features_per_class
from tropy.veronese import hypersurface_polymake_code
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

# Compact array printing for the outputs below: 3 decimals, no scientific notation.
np.set_printoptions(suppress=True, precision=3)

In [2]:
# Load the Iris table; the `species` column holds each row's class label.
IRIS_CSV = './notebooks/data/IRIS.csv'
base_df = pd.read_csv(IRIS_CSV)
# Label-based (inclusive) slice: the measurement columns sepal_length .. petal_width.
df = base_df.loc[:, 'sepal_length':'petal_width']
# Names used when reporting model features; "intercept" presumably labels the
# extra coordinate seen in the model's monomials -- TODO confirm in tropy.
features = [*df.columns, "intercept"]
# Class order here is the order used for training lists and per-class reports.
classes = ["Iris-setosa", "Iris-virginica", "Iris-versicolor"]

def class_df(class_name, test_size=0.5, random_state=43):
  """Split the measurement rows of one species into train/test matrices.

  Rows of the module-level ``df`` are selected where ``base_df["species"]``
  equals ``class_name`` exactly, then split with ``train_test_split``.

  Args:
    class_name: species label to select, e.g. "Iris-setosa".
    test_size: fraction of the class's rows held out for testing
      (forwarded to ``train_test_split``; default 0.5 as before).
    random_state: seed forwarded to ``train_test_split`` so the split is
      reproducible (default 43 as before).

  Returns:
    (Ctrain, Ctest): float arrays with one column per sample, i.e. the
    transpose of the DataFrame layout (shape (n_columns, n_samples)).

  Raises:
    ValueError: if no row carries the requested species label.
  """
  # Exact comparison instead of ``str.contains``: ``contains`` performs a regex
  # substring match, so a label could silently pull in rows of another class,
  # and NaN labels would break the boolean mask.
  mask = base_df["species"] == class_name
  if not mask.any():
    raise ValueError(f"No rows with species {class_name!r}")
  df_class = df.loc[mask]
  df_train, df_test = train_test_split(
    df_class, test_size=test_size, random_state=random_state
  )
  # Transpose so each column is one sample -- the layout later passed to model.fit.
  Ctrain = df_train.to_numpy(dtype=float).T
  Ctest = df_test.to_numpy(dtype=float).T
  return Ctrain, Ctest

In [3]:
# Degree forwarded to TropicalSVC.fit; the exported monomials below all have total degree 2.
degree: int = 2

In [4]:
# Split every class independently so each one contributes to both train and test.
per_class_splits = [class_df(name) for name in classes]
Clist_train = [split[0] for split in per_class_splits]
Clist_test = [split[1] for split in per_class_splits]

In [5]:
# Fit the tropical SVM on the per-class training matrices (one list entry per class).
model = TropicalSVC()
model.fit(Clist_train, degree)
# NOTE(review): `_apex` is a private attribute; switch to a public accessor if TropicalSVC offers one.
apex = model._apex
print("Apex:", apex)

100%|██████████| 15/15 [00:00<00:00, 3785.02it/s]

Apex: [  8.254  10.543   7.23   -9.214   7.051   3.989 -11.546   5.566 -11.496
 -13.696  12.301   5.497   8.86    2.404 -25.745]





In [6]:
# Score on the held-out per-class matrices; round only for display.
accuracy = model.accuracy(Clist_test)
accuracy_shown = round(accuracy, 3)
print(f"Accuracy: {accuracy_shown}")

Accuracy: 0.947


In [7]:
# Map the model's coefficients to readable monomial names, then list the
# dominant ones per class using the model's sector indicator.
feature_names = veronese_feature_names(features, model._veronese_coefficients)
print_features_per_class(classes, feature_names, model._sector_indicator)

Dominant features for each class:
- Iris-setosa: sepal_width + intercept, 2*intercept
- Iris-virginica: petal_length + petal_width, 2*petal_length, 2*petal_width
- Iris-versicolor: sepal_length + sepal_width, petal_length + intercept, petal_width + intercept, 2*sepal_width


In [8]:
# polymake snippet for the fitted tropical hypersurface; in the output shown,
# its COEFFICIENTS are the negated entries of the apex printed earlier.
polymake_code = hypersurface_polymake_code(model._monomials, model._coeffs)
polymake_code

'$C = new Hypersurface<Max>(MONOMIALS=>[[1, 1, 0, 0, 0], [1, 0, 1, 0, 0], [1, 0, 0, 1, 0], [1, 0, 0, 0, 1], [0, 1, 1, 0, 0], [0, 1, 0, 1, 0], [0, 1, 0, 0, 1], [0, 0, 1, 1, 0], [0, 0, 1, 0, 1], [0, 0, 0, 1, 1], [2, 0, 0, 0, 0], [0, 2, 0, 0, 0], [0, 0, 2, 0, 0], [0, 0, 0, 2, 0], [0, 0, 0, 0, 2]], COEFFICIENTS=>[-8.254298095703122, -10.542545776367184, -7.229642944335936, 9.21404602050781, -7.050910644531248, -3.989143066406248, 11.545610351562503, -5.566462402343748, 11.495701904296874, 13.695610351562502, -12.301026611328124, -5.4969555664062515, -8.860285644531249, -2.4042980957031226, 25.744600219726557]);'