# TKAN example and comparison with benchmarks

All test have been run on a RTX 4070 with an Core™ i7-6700K on vast.ai using this [jax docker image](https://hub.docker.com/r/bitnami/jax/)

tkan version: 0.4.1

In [1]:
!pip install pandas numpy matplotlib pyarrow scikit-learn tkan "jax[cuda12]"

[33mDEPRECATION: Loading egg at /opt/bitnami/python/lib/python3.11/site-packages/pip-23.3.2-py3.11.egg is deprecated. pip 24.3 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0m

In [2]:
import os
BACKEND = 'jax' # You can use any backend here 
os.environ['KERAS_BACKEND'] = BACKEND

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import keras
from keras.models import Sequential
from keras.layers import LSTM, Dense, Input, Flatten, GRU

from sklearn.metrics import r2_score
from sklearn.metrics import root_mean_squared_error

from tkan import TKAN

import time

keras.utils.set_random_seed(1) 

N_MAX_EPOCHS = 1000
BATCH_SIZE = 128
early_stopping_callback = lambda : keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.00001,
    patience=10,
    mode="min",
    restore_best_weights=True,
    start_from_epoch=6,
)
lr_callback = lambda : keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.25,
    patience=5,
    mode="min",
    min_delta=0.00001,
    min_lr=0.000025,
    verbose=0,
)
callbacks = lambda : [early_stopping_callback(), lr_callback(), keras.callbacks.TerminateOnNaN()]


# Data

In [3]:
df = pd.read_parquet('/workspace/data.parquet')
df = df[(df.index >= pd.Timestamp('2020-01-01')) & (df.index < pd.Timestamp('2023-01-01'))]
assets = ['BTC', 'ETH', 'ADA', 'XMR', 'EOS', 'MATIC', 'TRX', 'FTM', 'BNB', 'XLM', 'ENJ', 'CHZ', 'BUSD', 'ATOM', 'LINK', 'ETC', 'XRP', 'BCH', 'LTC']
df = df[[c for c in df.columns if 'quote asset volume' in c and any(asset in c for asset in assets)]]
df.columns = [c.replace(' quote asset volume', '') for c in df.columns]
display(df)

Unnamed: 0_level_0,BTC,ADA,XMR,EOS,CHZ,MATIC,TRX,ENJ,FTM,BNB,XLM,BUSD,ATOM,LTC,LINK,ETC,ETH,XRP,BCH
group,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
2020-01-01 00:00:00,3.675857e+06,38189.176211,4.539598e+04,94778.577031,817.146319,31003.791035,481993.354990,15241.945783,1165.788613,8.498617e+05,9460.819556,1.352376e+04,31986.972694,1.165827e+05,24281.170262,56488.402352,1.000930e+06,2.579254e+05,178258.749391
2020-01-01 01:00:00,6.365953e+06,51357.010954,3.348395e+04,593292.135445,886.460339,84465.335718,533668.554562,11896.843688,413.844612,7.405759e+05,37141.909518,2.531605e+04,81777.666046,2.830715e+05,51190.975142,182102.074213,1.474278e+06,4.520609e+05,615321.025242
2020-01-01 02:00:00,4.736719e+06,36164.263914,1.573255e+04,266732.556000,1819.795050,113379.718506,387049.986770,30109.770521,3559.965968,1.039091e+06,16878.822627,1.390886e+04,195731.175551,2.402871e+05,28721.756184,134063.422732,9.940256e+05,4.414948e+05,221535.645771
2020-01-01 03:00:00,5.667367e+06,24449.953815,2.575105e+04,124516.579473,2979.655803,41771.707995,450772.139235,6732.833578,4076.415482,4.975018e+05,9049.223394,2.251969e+04,120113.343316,1.613043e+05,29596.222534,131094.172168,6.473610e+05,1.886061e+05,397185.950571
2020-01-01 04:00:00,3.379094e+06,44502.669843,6.295563e+04,421819.671410,1023.388675,22254.756114,284788.973752,846.938455,633.367505,4.751285e+05,7254.260203,1.122460e+04,19989.169106,2.214516e+05,54514.370016,134937.122201,4.430067e+05,2.279373e+05,316499.137509
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2022-12-31 19:00:00,6.704605e+07,581680.400510,3.873989e+05,48359.865300,199491.822800,890911.573610,225136.420055,40281.859330,159553.944500,9.889098e+05,39230.588600,6.560756e+06,180809.784710,9.964355e+05,190664.976300,181340.756100,7.738029e+06,1.413563e+06,35409.149500
2022-12-31 20:00:00,4.344849e+07,323561.762270,1.379392e+05,37858.704700,173057.240300,333511.762200,157069.026827,42228.830930,270251.374500,6.032059e+05,52964.531800,7.255324e+06,276013.421720,1.173164e+06,265727.950340,90513.087600,4.278879e+06,1.113527e+06,42674.516600
2022-12-31 21:00:00,5.992803e+07,455185.698060,2.445869e+05,79538.050600,107544.609700,525037.759990,180404.744820,27446.620810,198885.610000,1.386864e+06,44485.594800,8.712142e+06,476151.071190,6.820723e+05,265687.852060,85399.066100,4.643401e+06,1.373231e+06,38027.858800
2022-12-31 22:00:00,1.106669e+08,763590.393960,1.486058e+06,119816.048000,227360.873900,940456.693720,378767.904610,37877.840280,179819.382700,1.387985e+06,43947.636100,5.517835e+06,558172.963150,8.838422e+05,678706.662170,319377.553800,1.143952e+07,3.006036e+06,50730.511300


In [4]:
class MinMaxScaler:
    def __init__(self, feature_axis=None, minmax_range=(0, 1)):
        """
        Initialize the MinMaxScaler.
        Args:
        feature_axis (int, optional): The axis that represents the feature dimension if applicable.
                                      Use only for 3D data to specify which axis is the feature axis.
                                      Default is None, automatically managed based on data dimensions.
        """
        self.feature_axis = feature_axis
        self.min_ = None
        self.max_ = None
        self.scale_ = None
        self.minmax_range = minmax_range # Default range for scaling (min, max)

    def fit(self, X):
        """
        Fit the scaler to the data based on its dimensionality.
        Args:
        X (np.array): The data to fit the scaler on.
        """
        if X.ndim == 3 and self.feature_axis is not None:  # 3D data
            axis = tuple(i for i in range(X.ndim) if i != self.feature_axis)
            self.min_ = np.min(X, axis=axis)
            self.max_ = np.max(X, axis=axis)
        elif X.ndim == 2:  # 2D data
            self.min_ = np.min(X, axis=0)
            self.max_ = np.max(X, axis=0)
        elif X.ndim == 1:  # 1D data
            self.min_ = np.min(X)
            self.max_ = np.max(X)
        else:
            raise ValueError("Data must be 1D, 2D, or 3D.")

        self.scale_ = self.max_ - self.min_
        return self

    def transform(self, X):
        """
        Transform the data using the fitted scaler.
        Args:
        X (np.array): The data to transform.
        Returns:
        np.array: The scaled data.
        """
        X_scaled = (X - self.min_) / self.scale_
        X_scaled = X_scaled * (self.minmax_range[1] - self.minmax_range[0]) + self.minmax_range[0]
        return X_scaled

    def fit_transform(self, X):
        """
        Fit to data, then transform it.
        Args:
        X (np.array): The data to fit and transform.
        Returns:
        np.array: The scaled data.
        """
        return self.fit(X).transform(X)

    def inverse_transform(self, X_scaled):
        """
        Inverse transform the scaled data to original data.
        Args:
        X_scaled (np.array): The scaled data to inverse transform.
        Returns:
        np.array: The original data scale.
        """
        X = (X_scaled - self.minmax_range[0]) / (self.minmax_range[1] - self.minmax_range[0])
        X = X * self.scale_ + self.min_
        return X

def generate_data(df, sequence_length, n_ahead = 1):
    #Case without known inputs
    scaler_df = df.copy().shift(n_ahead).rolling(24 * 14).median()
    tmp_df = df.copy() / scaler_df
    tmp_df = tmp_df.iloc[24 * 14 + n_ahead:].fillna(0.)
    scaler_df = scaler_df.iloc[24 * 14 + n_ahead:].fillna(0.)
    def prepare_sequences(df, scaler_df, n_history, n_future):
        X, y, y_scaler = [], [], []
        num_features = df.shape[1]
        
        # Iterate through the DataFrame to create sequences
        for i in range(n_history, len(df) - n_future + 1):
            # Extract the sequence of past observations
            X.append(df.iloc[i - n_history:i].values)
            # Extract the future values of the first column
            y.append(df.iloc[i:i + n_future,0:1].values)
            y_scaler.append(scaler_df.iloc[i:i + n_future,0:1].values)
        
        X, y, y_scaler = np.array(X), np.array(y), np.array(y_scaler)
        return X, y, y_scaler
    
    # Prepare sequences
    X, y, y_scaler = prepare_sequences(tmp_df, scaler_df, sequence_length, n_ahead)
    
    # Split the dataset into training and testing sets
    train_test_separation = int(len(X) * 0.8)
    X_train_unscaled, X_test_unscaled = X[:train_test_separation], X[train_test_separation:]
    y_train_unscaled, y_test_unscaled = y[:train_test_separation], y[train_test_separation:]
    y_scaler_train, y_scaler_test = y_scaler[:train_test_separation], y_scaler[train_test_separation:]
    
    # Generate the data
    X_scaler = MinMaxScaler(feature_axis=2)
    X_train = X_scaler.fit_transform(X_train_unscaled)
    X_test = X_scaler.transform(X_test_unscaled)
    
    y_scaler = MinMaxScaler(feature_axis=2)
    y_train = y_scaler.fit_transform(y_train_unscaled)
    y_test = y_scaler.transform(y_test_unscaled)
    
    y_train = y_train.reshape(y_train.shape[0], -1) 
    y_test = y_test.reshape(y_test.shape[0], -1)
    return X_scaler, X_train, X_test, X_train_unscaled, X_test_unscaled, y_scaler, y_train, y_test, y_train_unscaled, y_test_unscaled, y_scaler_train, y_scaler_test



In [5]:
n_aheads = [1, 3, 6, 9, 12, 15]
models = [
    "TKAN",
    "GRU",
    "LSTM",
 ]

results = {model: {n_ahead: [] for n_ahead in n_aheads} for model in models}
results_rmse = {model: {n_ahead: [] for n_ahead in n_aheads} for model in models}
time_results = {model: {n_ahead: [] for n_ahead in n_aheads} for model in models}
for n_ahead in n_aheads:
    sequence_length = max(45, 5 * n_ahead)
    X_scaler, X_train, X_test, X_train_unscaled, X_test_unscaled, y_scaler, y_train, y_test, y_train_unscaled, y_test_unscaled, y_scaler_train, y_scaler_test = generate_data(df, sequence_length, n_ahead)
    
    for model_id in models:
        
        for run in range(10):

            if model_id == 'TKAN':
                model = Sequential([
                    Input(shape=X_train.shape[1:]),
                    TKAN(100, return_sequences=True),
                    TKAN(100, sub_kan_output_dim = 20, sub_kan_input_dim = 20, return_sequences=False),
                    Dense(units=n_ahead, activation='linear')
                ], name = model_id)
            elif model_id == 'GRU':
                model = Sequential([
                    Input(shape=X_train.shape[1:]),
                    GRU(100, return_sequences=True),
                    GRU(100, return_sequences=False),
                    Dense(units=n_ahead, activation='linear')
                ], name = model_id)
            elif model_id == 'LSTM':
                model = Sequential([
                    Input(shape=X_train.shape[1:]),
                    LSTM(100, return_sequences=True),
                    LSTM(100, return_sequences=False),
                    Dense(units=n_ahead, activation='linear')
                ], name = model_id)
            else:
                raise ValueError
            
            optimizer = keras.optimizers.Adam(0.001)
            model.compile(optimizer=optimizer, loss='mean_squared_error', jit_compile=True)
            if run==0:
                model.summary()
                
            # Fit the model
            start_time = time.time()
            history = model.fit(X_train, y_train, batch_size=BATCH_SIZE, epochs=N_MAX_EPOCHS, validation_split=0.2, callbacks=callbacks(), shuffle=True, verbose = False)
            end_time = time.time()
            time_results[model_id][n_ahead].append(end_time - start_time)
            # Evaluate the model on the test set
            preds = model.predict(X_test, verbose=False)
            r2 = r2_score(y_true=y_test, y_pred=preds)
            print(end_time - start_time, r2)
            rmse = root_mean_squared_error(y_true=y_test, y_pred=preds)
            results[model_id][n_ahead].append(r2)
            results_rmse[model_id][n_ahead].append(rmse)
    
            del model
            del optimizer
                

print('R2 scores')
print('Means:')
display(pd.DataFrame({model_id: {n_ahead: np.mean(results[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in results.keys()}))
display(pd.DataFrame({model_id: {n_ahead: np.mean(results_rmse[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in results_rmse.keys()}))
print('Std:')
display(pd.DataFrame({model_id: {n_ahead: np.std(results[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in results.keys()}))
display(pd.DataFrame({model_id: {n_ahead: np.std(results_rmse[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in results_rmse.keys()}))
print('Training Times')
display(pd.DataFrame({model_id: {n_ahead: np.mean(time_results[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in time_results.keys()}))
display(pd.DataFrame({model_id: {n_ahead: np.std(time_results[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in time_results.keys()}))

87.7924542427063 0.2883694212812108
58.81150460243225 0.2996397661396135
93.38526654243469 0.3341084410571994
69.46869540214539 0.31051983221133617
68.03620338439941 0.2921106524335578
69.86754989624023 0.31632189994316484
83.32379245758057 0.3110331089568017
87.04948496818542 0.32794331666897736
84.89363670349121 0.31305216100227273
65.98418259620667 0.30768323851050017


25.800300121307373 0.39845659417931767
33.090811014175415 0.3765754299721896
29.157642126083374 0.3996904390926539
20.830080032348633 0.39383111789434266
24.594220638275146 0.39854787088819454
25.829734086990356 0.3980608345819734
20.8519549369812 0.39818872394114213
29.55962371826172 0.39876597794912183
28.695900917053223 0.4014753210019817
31.473294496536255 0.385033271890177


17.663942098617554 0.3868841900601451
18.93087387084961 0.36429593525523807
17.406195163726807 0.3900695412878702
20.44808530807495 0.3786097961170115
16.755755186080933 0.392573888221295
20.213451862335205 0.3840845450430137
22.768882513046265 0.36630601832303555
18.173134088516235 0.3832138605836025
17.591335773468018 0.3918248620808171
19.007630586624146 0.37993474576853103


96.62368559837341 0.2039716435474497
74.19592547416687 0.1932276162314687
79.30283999443054 0.186078923268888
67.78045701980591 0.1888747778655048
90.85375475883484 0.20038166436812602
83.50374150276184 0.198828800729096
79.80795526504517 0.1848691783440326
59.527796268463135 0.18538958325425756
73.05465388298035 0.17813436933288326
77.17678189277649 0.185736715040992


32.23410701751709 0.24445524304420677
34.76943516731262 0.2201955661466759
25.166553735733032 0.23882125234574236
28.084814310073853 0.2344194906558783
27.79819416999817 0.235279604982632
31.362367868423462 0.23122380522590277
24.73177719116211 0.2426694875126042
26.578373432159424 0.23963082071966313
27.062561988830566 0.24032075216240234
28.228745222091675 0.24178121308010916


20.623400688171387 0.14482767971630928
21.055927991867065 0.07395821838488396
24.374985218048096 0.022805938949401933
21.713353157043457 0.0749404329796987
21.184755563735962 0.1522887184364525
18.015154600143433 0.16326070883053923
18.17797017097473 0.1804372350265144
17.993833541870117 0.18266789728362973
18.728025197982788 0.14342017896777437
20.97398853302002 -0.002893617498768264


65.0717043876648 0.12467616278770983
90.68930077552795 0.13929252527937125
76.80289697647095 0.12103943728777532
84.18350672721863 0.13730972361175564
59.469563484191895 0.10800808332018592
80.92852973937988 0.13700481460829453
85.58175778388977 0.1324956475675861
78.75041937828064 0.13073665118121205
80.78611373901367 0.12963646594644282
73.53622269630432 0.12723135400013533


26.26027750968933 0.1602740337153421
30.87123155593872 0.12435047691678952
30.425692319869995 0.10110813722797496
29.073087215423584 0.13987663037711742
27.583446502685547 0.12691294524178082
30.0588481426239 0.09291668323995494
30.896180629730225 0.13374339087226397
26.069262266159058 0.13703060900398636
26.07837224006653 0.14521618469459105
28.18198537826538 0.13697203267595567


25.56489086151123 0.07762226752057655
17.17194938659668 -0.030710740566574007
22.08793616294861 0.12124885788661292
18.8874089717865 -0.19992596769356452
18.460254669189453 -0.03568727060662086
19.47718334197998 -0.1732379052017987
18.813477754592896 -0.0515297041407121
18.96723985671997 -0.00603457861723435
17.978565216064453 0.06545331792210407
18.76132583618164 -0.21833925712011495


79.56117129325867 0.10413605110577702
95.19019293785095 0.10785571960353318
62.058950901031494 0.1022740221679086
57.313427448272705 0.09659038594513981
95.39828181266785 0.11798299575914914
109.50233817100525 0.11471786200051105
91.66243028640747 0.10141240448038549
83.4484133720398 0.11271609311366479
80.45564818382263 0.10729036960326731
89.72633719444275 0.10601871726343631


22.387739181518555 0.12872665220206386
30.9358811378479 0.07007529147404491
30.65935730934143 0.06094348103987125
22.681428909301758 0.11531462257235203
30.2576744556427 -0.009184511914249627
23.401977062225342 0.1145808233763385
29.867743015289307 -0.0020916084892752293
21.44379758834839 0.13675731715603687
32.31744432449341 0.08318358609980472
36.52803301811218 -0.02434265870166709


20.542985439300537 -0.058979852588216475
19.61976170539856 -0.07973171479104875
19.607908487319946 -0.15736860736964103
17.53559136390686 -0.09773126198658748
19.434728860855103 -0.39964713298492527
20.465200901031494 -0.2428380152707189
18.985889673233032 -0.5012525957302113
20.644855260849 -0.2738693890416318
19.245380401611328 -0.05146921829097126
19.828939199447632 -0.0008553086817381464


102.03934359550476 0.0950019998069378
121.27882838249207 0.10264306639669311
108.56248140335083 0.0990784027938416
77.42332911491394 0.0915531316015012
108.50404167175293 0.09352464587821978
95.75039339065552 0.092634008105403
102.65606546401978 0.09820662504181471
130.8202440738678 0.09765221932195063
113.90040755271912 0.09989627016065132
118.27212023735046 0.1041245428645089


35.47221755981445 0.04088546661722716
36.82418465614319 0.03298667710271531
29.090290307998657 0.08667586160493074
27.488890886306763 0.09733002632430683
30.277069330215454 0.10445165105456156
26.764411687850952 0.11400616888042837
35.36870241165161 0.02046360232907878
26.94648003578186 0.11067681945648022
35.716086864471436 0.02628936001538192
32.919156312942505 0.05842798641178421


26.069539546966553 -0.36695473176783655
21.807262659072876 -0.25014631415290206
20.902458667755127 -0.11645941888847307
22.4891197681427 -0.052407595508665694
22.018208742141724 -0.1754222997379895
23.925899028778076 -0.444045746274295
23.6626615524292 -0.2362643247514915
20.98034143447876 -0.16545529402468354
22.008980751037598 -0.23959377820038377
21.912462949752808 -0.5002557724294028


108.7913281917572 0.09994023646490877
156.19587421417236 0.08586536769033742
151.76872277259827 0.0953234353911566
109.25939583778381 0.08350316914250061
120.67966675758362 0.09383066359946372
118.93469595909119 0.09027677327433754
117.88046956062317 0.09486514143615367
106.22660422325134 0.09703190398065814
154.6233696937561 0.09741707472392
124.99523186683655 0.09227563879811171


41.711098432540894 0.04633579312704695
37.61102104187012 0.041959092296909146
33.6794171333313 0.08124711803653464
41.39418888092041 0.062273110060791455
33.090798139572144 0.09644091868374302
34.905259132385254 0.0688485348099919
31.54578924179077 0.09070480771844083
31.649243354797363 0.08222696375881307
34.72465991973877 0.07408738048462013
34.364349365234375 0.06764429477726015


27.019399642944336 -0.02014558059192651
24.994324684143066 -0.034652205897714984
29.875847578048706 -0.2883331388024311
27.68077278137207 -0.19307423134624493
32.10626721382141 -0.1967002645401843
26.022549629211426 -0.10374175022709843
28.117710828781128 -0.0705739546737177
30.498241662979126 -0.25929543857022325
24.907368898391724 -0.09258147039272031
26.723052978515625 -0.15915120445580425
R2 scores
Means:


Unnamed: 0,TKAN,GRU,LSTM
1,0.310078,0.394863,0.38178
3,0.190549,0.23688,0.113571
6,0.128743,0.12984,-0.045114
9,0.107099,0.067396,-0.186374
12,0.097431,0.069219,-0.254701
15,0.093033,0.071177,-0.141825


Unnamed: 0,TKAN,GRU,LSTM
1,0.058833,0.055101,0.055693
3,0.063659,0.06172,0.066402
6,0.066264,0.066109,0.072302
9,0.066893,0.068206,0.076763
12,0.067267,0.068214,0.079032
15,0.067711,0.068456,0.075833


Std:


Unnamed: 0,TKAN,GRU,LSTM
1,0.013617,0.007498,0.009371
3,0.00782,0.006761,0.063202
6,0.008831,0.018997,0.112792
9,0.006202,0.057011,0.156331
12,0.004002,0.035374,0.135507
15,0.004925,0.016791,0.087433


Unnamed: 0,TKAN,GRU,LSTM
1,0.000581,0.00034,0.000421
3,0.000316,0.000273,0.002282
6,0.000347,0.000707,0.003831
9,0.000237,0.002033,0.004908
12,0.000157,0.001278,0.004174
15,0.000187,0.000609,0.002844


Training Times


Unnamed: 0,TKAN,GRU,LSTM
1,76.861277,26.988356,18.895929
3,78.182759,28.601693,20.284139
6,77.580002,28.549838,19.617023
9,84.431719,28.048108,19.591124
12,107.920725,31.686749,22.577694
15,126.935536,35.467582,27.794554


Unnamed: 0,TKAN,GRU,LSTM
1,11.082262,3.94529,1.723379
3,10.159973,3.052013,1.950043
6,8.965218,1.885116,2.32025
9,14.904037,4.8942,0.870263
12,14.096107,3.808997,1.490844
15,18.705339,3.458216,2.267511
