# Multivariate Time Series (MTS) Benchmark Results

In [1]:
import argparse
import joblib
import json
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import shutil
import sys
import seaborn as sns
import tensorflow as tf

from itertools import cycle
from numpy import interp
from pathlib import Path
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_curve, auc
from tensorflow import keras

from astronet.constants import ASTRONET_WORKING_DIRECTORY as asnwd
from astronet.preprocess import one_hot_encode
from astronet.utils import astronet_logger, load_dataset

In [2]:
from astronet.viz.visualise_results import plot_acc_history, plot_confusion_matrix, plot_loss_history, plot_multiROC

In [3]:
# Seed NumPy's global RNG and TensorFlow's global seed with a fixed value
# so that any randomness used later in the notebook is reproducible.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)

# Render plot text through LaTeX.
# NOTE(review): Matplotlib consults "font.serif" only when "font.family" is
# "serif"; with "sans-serif" selected here the Computer Modern Roman entry is
# probably unused — confirm which family is intended.
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "sans-serif",
    "font.serif": ["Computer Modern Roman"]})

# Matplotlib 3.6 renamed its bundled "seaborn" style to "seaborn-v0_8" and
# deprecated the old name, which later releases removed. Prefer the new name
# when it exists so this cell runs on both older and newer Matplotlib.
_seaborn_style = "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"
mpl.style.use(_seaborn_style)

# DataFrame display settings: floats are shown with two decimals.
# NOTE(review): a display.float_format callable generally takes precedence
# over display.precision for float output, so precision=1 may have no visible
# effect — confirm whether it is still needed.
pd.set_option("display.precision", 1)
pd.set_option('display.float_format', '{:.2f}'.format)

In [4]:
# Architecture name used below to build the results-CSV path and select its column.
architecture: str = "t2"

In [5]:
# Read the per-dataset results for `architecture` and scale the values by 100
# (the displayed output shows them as percentages).
results_path = f"{os.environ['ASNWD']}/results/mts-{architecture}-results.csv"
df = pd.read_csv(results_path, index_col='Unnamed: 0')
t2 = (df[architecture] * 100).to_frame()
t2

Unnamed: 0,t2
ArabicDigits,97.32
AUSLAN,92.91
CharacterTrajectories,94.57
CMUsubject16,100.0
ECG,84.0
JapaneseVowels,97.3
KickvsPunch,90.0
Libras,81.67
NetFlow,77.9
UWave,84.53


In [7]:
# Architecture name used below to build the results-CSV path and select its column.
architecture: str = "atx"

In [8]:
# Read the per-dataset results for `architecture` and scale the values by 100
# (the displayed output shows them as percentages).
# NOTE(review): the frame is still bound to `snx` although it now holds the
# "atx" column; the name is kept because later cells reference `snx`.
results_path = f"{os.environ['ASNWD']}/results/mts-{architecture}-results.csv"
df = pd.read_csv(results_path, index_col='Unnamed: 0')
snx = (df[architecture] * 100).to_frame()
snx

Unnamed: 0,atx
ArabicDigits,98.36
AUSLAN,83.86
CharacterTrajectories,96.44
CMUsubject16,72.41
ECG,33.0
JapaneseVowels,94.86
KickvsPunch,40.0
Libras,74.44
NetFlow,77.9
UWave,90.95


In [9]:
# Place the two architectures' columns side by side, aligned on the dataset
# index; a left join keeps exactly the rows of `t2`.
df_combined_arch = t2.join(snx, how="left")
df_combined_arch

Unnamed: 0,t2,atx
ArabicDigits,97.32,98.36
AUSLAN,92.91,83.86
CharacterTrajectories,94.57,96.44
CMUsubject16,100.0,72.41
ECG,84.0,33.0
JapaneseVowels,97.3,94.86
KickvsPunch,90.0,40.0
Libras,81.67,74.44
NetFlow,77.9,77.9
UWave,84.53,90.95


In [10]:
# Benchmark table with one column per baseline model. Cells are strings such
# as "96.9(0.2)" — presumably mean (std) accuracy; confirm against the source.
benchmark_path = f"{os.environ['ASNWD']}/results/mts-fawaz-results.csv"
df_benchmark = pd.read_csv(benchmark_path, index_col='Unnamed: 0')
df_benchmark

Unnamed: 0,MLP,FCN,ResNet,Encoder,MCNN,t-LeNet,MCDCNN,Time-CNN,TWIESN
ArabicDigits,96.9(0.2),99.4(0.1),99.6(0.1),98.1(0.1),10.0(0.0),10.0(0.0),95.9(0.2),95.8(0.3),85.3(1.4)
AUSLAN,93.3(0.5),97.5(0.4),97.4(0.3),93.8(0.5),1.1(0.0),1.1(0.0),85.4(2.7),72.6(3.5),72.4(1.6)
CharacterTrajectories,96.9(0.2),99.0(0.1),99.0(0.2),97.1(0.2),5.4(0.8),6.7(0.0),93.8(1.7),96.0(0.8),92.0(1.3)
CMUsubject16,60.0(16.9),100.0(0.0),99.7(1.1),98.3(2.4),53.1(4.4),51.0(5.3),51.4(5.0),97.6(1.7),89.3(6.8)
ECG,74.8(16.2),87.2(1.2),86.7(1.3),87.2(0.8),67.0(0.0),67.0(0.0),50.0(17.9),84.1(1.7),73.7(2.3)
JapaneseVowels,97.6(0.2),99.3(0.2),99.2(0.3),97.6(0.6),9.2(2.5),23.8(0.0),94.4(1.4),95.6(1.0),96.5(0.7)
KickvsPunch,61.0(12.9),54.0(13.5),51.0(8.8),61.0(9.9),54.0(9.7),50.0(10.5),56.0(8.4),62.0(6.3),67.0(14.2)
Libras,78.0(1.0),96.4(0.7),95.4(1.1),78.3(0.9),6.7(0.0),6.7(0.0),65.1(3.9),63.7(3.3),79.4(1.3)
NetFlow,55.0(26.1),89.1(0.4),62.7(23.4),77.7(0.5),77.9(0.0),72.3(17.6),63.0(18.2),89.0(0.9),94.5(0.4)
UWave,90.1(0.3),93.4(0.3),92.6(0.4),90.8(0.4),12.5(0.0),12.5(0.0),84.5(1.6),85.9(0.7),75.4(6.3)


In [11]:
# Append the benchmark models' columns to the architecture results, aligned on
# the dataset index (left join keeps the rows of `df_combined_arch`).
df_combined_both = df_combined_arch.join(df_benchmark, how="left")
df_combined_both

Unnamed: 0,t2,atx,MLP,FCN,ResNet,Encoder,MCNN,t-LeNet,MCDCNN,Time-CNN,TWIESN
ArabicDigits,97.32,98.36,96.9(0.2),99.4(0.1),99.6(0.1),98.1(0.1),10.0(0.0),10.0(0.0),95.9(0.2),95.8(0.3),85.3(1.4)
AUSLAN,92.91,83.86,93.3(0.5),97.5(0.4),97.4(0.3),93.8(0.5),1.1(0.0),1.1(0.0),85.4(2.7),72.6(3.5),72.4(1.6)
CharacterTrajectories,94.57,96.44,96.9(0.2),99.0(0.1),99.0(0.2),97.1(0.2),5.4(0.8),6.7(0.0),93.8(1.7),96.0(0.8),92.0(1.3)
CMUsubject16,100.0,72.41,60.0(16.9),100.0(0.0),99.7(1.1),98.3(2.4),53.1(4.4),51.0(5.3),51.4(5.0),97.6(1.7),89.3(6.8)
ECG,84.0,33.0,74.8(16.2),87.2(1.2),86.7(1.3),87.2(0.8),67.0(0.0),67.0(0.0),50.0(17.9),84.1(1.7),73.7(2.3)
JapaneseVowels,97.3,94.86,97.6(0.2),99.3(0.2),99.2(0.3),97.6(0.6),9.2(2.5),23.8(0.0),94.4(1.4),95.6(1.0),96.5(0.7)
KickvsPunch,90.0,40.0,61.0(12.9),54.0(13.5),51.0(8.8),61.0(9.9),54.0(9.7),50.0(10.5),56.0(8.4),62.0(6.3),67.0(14.2)
Libras,81.67,74.44,78.0(1.0),96.4(0.7),95.4(1.1),78.3(0.9),6.7(0.0),6.7(0.0),65.1(3.9),63.7(3.3),79.4(1.3)
NetFlow,77.9,77.9,55.0(26.1),89.1(0.4),62.7(23.4),77.7(0.5),77.9(0.0),72.3(17.6),63.0(18.2),89.0(0.9),94.5(0.4)
UWave,84.53,90.95,90.1(0.3),93.4(0.3),92.6(0.4),90.8(0.4),12.5(0.0),12.5(0.0),84.5(1.6),85.9(0.7),75.4(6.3)


In [12]:
# Destination of the Markdown table containing every model's results.
filename: str = f"{os.environ['ASNWD']}/results/mts-combined-results.md"

In [13]:
# Render the full combined table as Markdown and write it to `filename`.
# The `with` block flushes and closes the file as soon as the write is done,
# instead of leaving an anonymous handle open until garbage collection.
results = df_combined_both.to_markdown()
with open(filename, 'w', encoding="utf-8") as fh:
    print(results, file=fh)

In [14]:
# Destination of the Markdown table comparing only t2 against the benchmarks.
filename: str = f"{os.environ['ASNWD']}/results/mts-t2-combined-results.md"

In [15]:
# Markdown rendering of the combined table with the `atx` column removed.
t2_only = df_combined_both.drop(columns=['atx'])
t2_results_combined = t2_only.to_markdown()

In [16]:
# Write the t2-only Markdown table to `filename`; the `with` block guarantees
# the file is flushed and closed immediately rather than at garbage collection.
with open(filename, 'w', encoding="utf-8") as fh:
    print(t2_results_combined, file=fh)

In [28]:
# Reduced comparison: keep t2 plus the FCN, ResNet and TWIESN baselines.
# The second architecture's column is 'atx' (the value of `architecture` when
# `snx` was built); dropping a stale label such as 'snX' raises KeyError.
df_combined_both.drop(columns=['atx', 'Encoder', 'MCNN', 't-LeNet', 'MCDCNN', 'Time-CNN', 'MLP'])

Unnamed: 0,t2,FCN,ResNet,TWIESN
ArabicDigits,97.32,99.4(0.1),99.6(0.1),85.3(1.4)
AUSLAN,92.91,97.5(0.4),97.4(0.3),72.4(1.6)
CharacterTrajectories,94.57,99.0(0.1),99.0(0.2),92.0(1.3)
CMUsubject16,100.0,100.0(0.0),99.7(1.1),89.3(6.8)
ECG,84.0,87.2(1.2),86.7(1.3),73.7(2.3)
JapaneseVowels,97.3,99.3(0.2),99.2(0.3),96.5(0.7)
KickvsPunch,90.0,54.0(13.5),51.0(8.8),67.0(14.2)
Libras,81.67,96.4(0.7),95.4(1.1),79.4(1.3)
NetFlow,77.9,89.1(0.4),62.7(23.4),94.5(0.4)
UWave,84.53,93.4(0.3),92.6(0.4),75.4(6.3)


In [30]:
# t2 against every benchmark model: drop only the second architecture's column,
# which is named 'atx' (a stale label such as 'snX' makes drop raise KeyError).
df_combined_both.drop(columns=['atx'])

Unnamed: 0,t2,MLP,FCN,ResNet,Encoder,MCNN,t-LeNet,MCDCNN,Time-CNN,TWIESN
ArabicDigits,97.32,96.9(0.2),99.4(0.1),99.6(0.1),98.1(0.1),10.0(0.0),10.0(0.0),95.9(0.2),95.8(0.3),85.3(1.4)
AUSLAN,92.91,93.3(0.5),97.5(0.4),97.4(0.3),93.8(0.5),1.1(0.0),1.1(0.0),85.4(2.7),72.6(3.5),72.4(1.6)
CharacterTrajectories,94.57,96.9(0.2),99.0(0.1),99.0(0.2),97.1(0.2),5.4(0.8),6.7(0.0),93.8(1.7),96.0(0.8),92.0(1.3)
CMUsubject16,100.0,60.0(16.9),100.0(0.0),99.7(1.1),98.3(2.4),53.1(4.4),51.0(5.3),51.4(5.0),97.6(1.7),89.3(6.8)
ECG,84.0,74.8(16.2),87.2(1.2),86.7(1.3),87.2(0.8),67.0(0.0),67.0(0.0),50.0(17.9),84.1(1.7),73.7(2.3)
JapaneseVowels,97.3,97.6(0.2),99.3(0.2),99.2(0.3),97.6(0.6),9.2(2.5),23.8(0.0),94.4(1.4),95.6(1.0),96.5(0.7)
KickvsPunch,90.0,61.0(12.9),54.0(13.5),51.0(8.8),61.0(9.9),54.0(9.7),50.0(10.5),56.0(8.4),62.0(6.3),67.0(14.2)
Libras,81.67,78.0(1.0),96.4(0.7),95.4(1.1),78.3(0.9),6.7(0.0),6.7(0.0),65.1(3.9),63.7(3.3),79.4(1.3)
NetFlow,77.9,55.0(26.1),89.1(0.4),62.7(23.4),77.7(0.5),77.9(0.0),72.3(17.6),63.0(18.2),89.0(0.9),94.5(0.4)
UWave,84.53,90.1(0.3),93.4(0.3),92.6(0.4),90.8(0.4),12.5(0.0),12.5(0.0),84.5(1.6),85.9(0.7),75.4(6.3)
