In [1]:
from tropy.svm import TropicalSVC
from tropy.metrics import veronese_feature_names, print_features_per_class
from tropy.veronese import hypersurface_polymake_code
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

# Compact numpy display for the outputs below: 3 decimals, no scientific notation.
np.set_printoptions(suppress=True, precision=3)

In [2]:
base_df = pd.read_csv("./data/iris.csv")
# Measurement columns only (the contiguous span sepal_length .. petal_width).
df = base_df.loc[:, "sepal_length":"petal_width"]
# Feature labels: one per measurement column, followed by the constant "intercept".
features = [*df.columns, "intercept"]
# Species labels; every per-class list built below follows this order.
classes = ["Iris-setosa", "Iris-virginica", "Iris-versicolor"]

def class_df(class_name, test_size=0.5, random_state=43):
  """Split the samples of one species into train and test matrices.

  Selects the rows of the notebook-level feature frame ``df`` whose label in
  ``base_df["species"]`` equals ``class_name`` and splits them with
  ``train_test_split``.

  Args:
    class_name: exact species label, e.g. ``"Iris-setosa"``.
    test_size: fraction of the class held out for testing (default 0.5).
    random_state: seed passed to ``train_test_split`` so re-runs give the
      same split (default 43).

  Returns:
    ``(Ctrain, Ctest)``: float arrays transposed so that each column is one
    sample and each row one feature.

  Raises:
    ValueError: if no row carries the label ``class_name``.
  """
  # Exact match on purpose: ``str.contains`` performs a regex substring search,
  # so a label that is a prefix of another (or contains regex metacharacters)
  # would select the wrong rows, and NaN labels would break the boolean mask.
  mask = base_df["species"] == class_name
  if not mask.any():
    raise ValueError(f"No samples found for class {class_name!r}")
  df_class = df[mask]
  df_train, df_test = train_test_split(df_class, test_size=test_size, random_state=random_state)
  Ctrain, Ctest = df_train.to_numpy(dtype=float).T, df_test.to_numpy(dtype=float).T
  return Ctrain, Ctest

In [3]:
# Degree passed to TropicalSVC.fit; the monomials in the polymake output below
# all have total degree 2 (presumably the Veronese degree — confirm in tropy).
degree: int = 2

In [4]:
# One (train, test) pair per species, kept in the same order as `classes`.
per_class_splits = [class_df(class_name) for class_name in classes]
Clist_train = [split[0] for split in per_class_splits]
Clist_test = [split[1] for split in per_class_splits]

In [5]:
# Fit a TropicalSVC on the per-class training matrices (one entry of
# Clist_train per species) using the degree configured above.
model = TropicalSVC()
model.fit(Clist_train, degree)
# NOTE(review): `_apex` is a private attribute of TropicalSVC; presumably the
# apex found by the fit — confirm against the tropy API before relying on it.
print("Apex:", model._apex)

KM converged in 31 iterations
Apex: [ -9.98 -12.18 -11.18 -14.28   8.52  10.92   7.52   7.22   4.02   5.92
 -26.48  12.72   5.42   9.22   2.62]


In [6]:
# Evaluate on the held-out per-class test matrices; show 3 decimal places.
accuracy = model.accuracy(Clist_test)
print("Accuracy:", round(accuracy, 3))

Accuracy: 0.947


In [7]:
# Human-readable names of the monomials the model uses (built from `features`).
monomial_names = veronese_feature_names(features, model._monomials)
print_features_per_class(classes, monomial_names, model._sector_indicator)

Dominant features for each class:
- Iris-setosa: sepal_length + petal_length, 2*sepal_length
- Iris-virginica: 2*petal_width, 2*intercept
- Iris-versicolor: sepal_length + sepal_width, sepal_length + petal_width, sepal_length + intercept, sepal_width + petal_length, 2*petal_length


In [8]:
# Polymake `Hypersurface<Max>` snippet built from the fitted monomials and
# coefficients; left as the last expression so Jupyter renders the string.
polymake_code = hypersurface_polymake_code(model._monomials, model._coeffs)
polymake_code

'$C = new Hypersurface<Max>(MONOMIALS=>[[1, 1, 0, 0, 0], [1, 0, 1, 0, 0], [1, 0, 0, 1, 0], [1, 0, 0, 0, 1], [0, 1, 1, 0, 0], [2, 0, 0, 0, 0], [0, 0, 2, 0, 0], [0, 0, 0, 2, 0], [0, 0, 0, 0, 2]], COEFFICIENTS=>[9.979943178429581, 12.179564206091067, 11.180091728152087, 14.280053564415619, -8.519909612114231, 26.479537080715716, -5.419915604988735, -9.219907414658618, -2.619909612114233]);'