## F1 state representation quality using multiple states  

In [1]:
from bbvi_infer_2l import * 
from f1 import *
%matplotlib inline
import pickle
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import silhouette_score

  from .autonotebook import tqdm as notebook_tqdm


### Read in data observations

In [2]:
year= 2022
gp = 'Japan'
event = 'Race'

session_event = ff1.get_session(year, gp, event)
session_event.load()
circuit_info = session_event.get_circuit_info()

core           INFO 	Loading data for Japanese Grand Prix - Race [v3.6.1]
req            INFO 	Using cached data for session_info
req            INFO 	Using cached data for driver_info
req            INFO 	Using cached data for session_status_data
req            INFO 	Using cached data for lap_count
req            INFO 	Using cached data for track_status_data
req            INFO 	Using cached data for _extended_timing_data
req            INFO 	Using cached data for timing_app_data
core           INFO 	Processing timing data...
req            INFO 	Using cached data for car_data
req            INFO 	Using cached data for position_data
req            INFO 	Using cached data for weather_data
req            INFO 	Using cached data for race_control_messages
core           INFO 	Finished loading data for 20 drivers: ['1', '11', '16', '31', '44', '5', '14', '63', '6', '4', '3', '18', '22', '20', '77', '24', '47', '10', '55', '23']


In [3]:
# Train data
pos_winner= generate_full_laps(session_event.laps.pick_driver('VER'), 'VER')
## States 
states_train, zs_train = one_hot_states_multilap(pos_winner[['X','Y']].to_numpy(), circuit_info.corners[['X','Y']].to_numpy(), 
                                    pos_winner['LapNumber'].to_numpy())



In [4]:
## Standardize train data
import numpy.random as npr
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
track_norm = scaler.fit_transform(pos_winner[['X','Y']].to_numpy())
C = npr.randn(10, 2)  ## emission matrix
ys_train = (C @ track_norm.T).T + npr.randn(track_norm.shape[0], 10)*0.01

In [5]:
## Test data
pos_alt= generate_full_laps(session_event.laps.pick_driver('HAM'), 'HAM')
### States
states_test, zs_test = one_hot_states_multilap(pos_alt[['X','Y']].to_numpy(), circuit_info.corners[['X','Y']].to_numpy(), 
                                               pos_alt['LapNumber'].to_numpy())
### Observations
track_norm = scaler.transform(pos_alt[['X','Y']].to_numpy())
ys_test = (C @ track_norm.T).T + npr.randn(track_norm.shape[0], 10)*0.01



## GSD

In [6]:
gsd = []
for seed in range(3):  ## repeat tests
    torch.manual_seed(seed)  # Set PyTorch seed
    torch.cuda.manual_seed_all(seed)  # Set CUDA seed 
    accs = []
    for K in range(2, 19): ## twiddle states
        model = GenerativeSLDS(N=10, K=K, D=2, emission_model="gaussian")
        elbos, variational_z = fit_bbvi_schedule(model.to('cuda'), ys = torch.tensor(ys_train).to('cuda').float(), 
                                             num_iters=4000, learning=True, n_samples=10, base_lr=1e-2, warmup_iters=200, tau_max=0.99)
        zs = variational_z.sample_q_z(torch.tensor(ys_train).to('cuda').float().unsqueeze(0).expand(20, -1, -1), 0.99)
        zs2 = variational_z.sample_q_z(torch.tensor(ys_test).to('cuda').float().unsqueeze(0).expand(20, -1, -1), 0.99)
        pred_ys2=model.smooth(torch.tensor(ys_test).to('cuda').float(), zs2).mean(dim=0).detach().cpu().numpy() # [T, N]
        print(f"Test accuracy for {K} states:", train_metrics(ys_test, pred_ys2, None, k_max=0))
        Z_train = torch.mean(zs, dim=0).detach().cpu().numpy()
        Y_train = np.array(zs_train)
        Z_test = torch.mean(zs2, dim=0).detach().cpu().numpy()
        Y_test = np.array(zs_test)
        # Train KNN classifier
        knn = KNeighborsClassifier(n_neighbors=10, metric='euclidean')
        knn.fit(Z_train, Y_train)
        acc = knn.score(Z_test, Y_test)
        print(f"State quality for {K} states:", acc)
        accs.append(acc)  
    gsd.append(np.array(accs))
np.save('f1GSD.npy', np.array(gsd))

ELBO: 87695.8, LR: 0.00100, Tau: 0.990: 100%|██████████| 4000/4000 [00:41<00:00, 95.49it/s]


Test accuracy for 2 states: [0.974468207750306]
State quality for 2 states: 0.3969631236442516


ELBO: 250150.7, LR: 0.00100, Tau: 0.990: 100%|██████████| 4000/4000 [00:42<00:00, 94.97it/s]


Test accuracy for 3 states: [0.9973776521624671]
State quality for 3 states: 0.36969321351100093


ELBO: 272086.2, LR: 0.00100, Tau: 0.990: 100%|██████████| 4000/4000 [00:43<00:00, 91.23it/s]


Test accuracy for 4 states: [0.9978422869173654]
State quality for 4 states: 0.47737837000309885


ELBO: 180660.9, LR: 0.00100, Tau: 0.990: 100%|██████████| 4000/4000 [00:46<00:00, 85.39it/s]


Test accuracy for 5 states: [0.9828250911813476]
State quality for 5 states: 0.551131081499845


ELBO: 258014.4, LR: 0.00100, Tau: 0.990: 100%|██████████| 4000/4000 [00:50<00:00, 79.86it/s]


Test accuracy for 6 states: [0.997236976672523]
State quality for 6 states: 0.44941121784939575


ELBO: 262046.2, LR: 0.00100, Tau: 0.990: 100%|██████████| 4000/4000 [00:53<00:00, 75.33it/s]


Test accuracy for 7 states: [0.9973067487237909]
State quality for 7 states: 0.6353424233033778


ELBO: 277370.7, LR: 0.00100, Tau: 0.990: 100%|██████████| 4000/4000 [00:57<00:00, 70.01it/s]


Test accuracy for 8 states: [0.9979116112039905]
State quality for 8 states: 0.6678803842578246


ELBO: 263127.0, LR: 0.00100, Tau: 0.990: 100%|██████████| 4000/4000 [01:01<00:00, 64.68it/s]


Test accuracy for 9 states: [0.9977755949539366]
State quality for 9 states: 0.648590021691974


ELBO: 117090.2, LR: 0.00100, Tau: 0.990:  78%|███████▊  | 3108/4000 [00:50<00:14, 61.68it/s]


Early stopping at iteration 3108.
Test accuracy for 10 states: [0.9832527951248415]
State quality for 10 states: 0.3617136659436009


ELBO: 276953.2, LR: 0.00100, Tau: 0.990:  95%|█████████▍| 3792/4000 [01:04<00:03, 59.02it/s]


Early stopping at iteration 3792.
Test accuracy for 11 states: [0.99786369137434]
State quality for 11 states: 0.6935233963433529


ELBO: 268941.6, LR: 0.00100, Tau: 0.990: 100%|██████████| 4000/4000 [01:09<00:00, 57.16it/s]


Test accuracy for 12 states: [0.9978129098651488]
State quality for 12 states: 0.4804772234273319


ELBO: 204407.8, LR: 0.00100, Tau: 0.990: 100%|██████████| 4000/4000 [01:13<00:00, 54.22it/s]


Test accuracy for 13 states: [0.9941428029586625]
State quality for 13 states: 0.6220948249147815


ELBO: 259850.8, LR: 0.00100, Tau: 0.990: 100%|██████████| 4000/4000 [01:16<00:00, 52.15it/s]


Test accuracy for 14 states: [0.9975821100678921]
State quality for 14 states: 0.5668577626278277


ELBO: 272246.8, LR: 0.00100, Tau: 0.990:  89%|████████▉ | 3577/4000 [01:11<00:08, 50.05it/s]


Early stopping at iteration 3577.
Test accuracy for 15 states: [0.9977406713042987]
State quality for 15 states: 0.6645491168267741


ELBO: 220697.4, LR: 0.00100, Tau: 0.990: 100%|██████████| 4000/4000 [01:29<00:00, 44.91it/s]


Test accuracy for 16 states: [0.9959675368093786]
State quality for 16 states: 0.5037960954446855


ELBO: 275280.5, LR: 0.00100, Tau: 0.990:  78%|███████▊  | 3110/4000 [01:14<00:21, 41.70it/s]

## rSLDS

In [6]:
import ssm
from ssm.extensions.mp_srslds.transitions_ext import StickyRecurrentOnlyTransitions, StickyRecurrentTransitions

In [10]:
slds_acc = []
slds_sil = []
for seed in range(3):  ## repeat tests
    torch.manual_seed(seed)  # Set PyTorch seed
    torch.cuda.manual_seed_all(seed)  # Set CUDA seed 
    accs2 = []
    sils = []
    for K in range(2, 19):
        sro_trans=StickyRecurrentOnlyTransitions(K=K,D=2) 
        rslds = ssm.SLDS(N=10, K=K, D=2,
                 transitions=sro_trans,
                 dynamics="diagonal_gaussian",
                 emissions="gaussian",
                 single_subspace=True)
        rslds.initialize(ys_train)
        q_elbos_lem, q_lem = rslds.fit(ys_train, method="laplace_em", 
                                   variational_posterior="structured_meanfield",
                                   initialize=False, num_iters=100, alpha=0.0)
        q_lem_x = q_lem.mean_continuous_states[0]
        q_lem_z = rslds.most_likely_states(q_lem_x, ys_train)
        # plot_trajectory(q_lem_z, q_lem_x)
        elbos, q_lem2 = rslds.approximate_posterior(ys_test,
                                        method="laplace_em",
                                        variational_posterior="structured_meanfield",
                                        num_iters=20)
        q_lem_x2 = q_lem2.mean_continuous_states[0]
        q_lem_z2 = rslds.most_likely_states(q_lem_x2, ys_test)
        sil = silhouette_score(q_lem_x2, q_lem_z2)
        pred_ys2=rslds.smooth(q_lem_x2, ys_test)
        print(f"Test accuracy for {K} states:", train_metrics(ys_test, pred_ys2, None, k_max=0))
        print(f"Silhouette score at {K} states:", sil)
        Z_train = np.eye(K)[q_lem_z]
        Y_train = np.array(zs_train)
        Z_test = np.eye(K)[q_lem_z2]
        Y_test = np.array(zs_test)
        # Train KNN classifier
        knn = KNeighborsClassifier(n_neighbors=10, metric='euclidean')
        knn.fit(Z_train, Y_train)
        acc = knn.score(Z_test, Y_test)
        print(f"State quality for {K} states:", acc)
        accs2.append(acc)
        sils.append(sil)
    slds_acc.append(np.array(accs2))
    slds_sil.append(np.array(sils))
np.save('f1SLDS.npy', np.array(slds_acc))    
np.save('f1SLDS_sil.npy', np.array(slds_sil))   

ELBO: 362534.6: 100%|██████████| 100/100 [01:29<00:00,  1.11it/s]
ELBO: 353592.9: 100%|██████████| 20/20 [00:13<00:00,  1.48it/s]


Test accuracy for 2 states: [0.9993817874650348]
Silhouette score at 2 states: 0.04706478178738587
State quality for 2 states: 0.17260613572977998


ELBO: 368538.0: 100%|██████████| 100/100 [01:44<00:00,  1.04s/it]
ELBO: 358742.3: 100%|██████████| 20/20 [00:15<00:00,  1.32it/s]


Test accuracy for 3 states: [0.9993799412354623]
Silhouette score at 3 states: 0.024683241748336693
State quality for 3 states: 0.1789587852494577


ELBO: 372255.2: 100%|██████████| 100/100 [01:57<00:00,  1.18s/it]
ELBO: 362553.0: 100%|██████████| 20/20 [00:16<00:00,  1.18it/s]


Test accuracy for 4 states: [0.9993791102001859]
Silhouette score at 4 states: -0.01676337406826936
State quality for 4 states: 0.24387976448713977


ELBO: 371867.8: 100%|██████████| 100/100 [02:17<00:00,  1.37s/it]
ELBO: 362008.5: 100%|██████████| 20/20 [00:19<00:00,  1.03it/s]


Test accuracy for 5 states: [0.9993785067615193]
Silhouette score at 5 states: -0.05986335262567291
State quality for 5 states: 0.21444065695692593


ELBO: 374539.1: 100%|██████████| 100/100 [02:38<00:00,  1.59s/it]
ELBO: 364614.5: 100%|██████████| 20/20 [00:21<00:00,  1.08s/it]


Test accuracy for 6 states: [0.9993769634728478]
Silhouette score at 6 states: -0.034838222557713974
State quality for 6 states: 0.19460799504183451


ELBO: 375643.7: 100%|██████████| 100/100 [03:03<00:00,  1.83s/it]
ELBO: 366154.2: 100%|██████████| 20/20 [00:25<00:00,  1.25s/it]


Test accuracy for 7 states: [0.9993767664782279]
Silhouette score at 7 states: -0.05391351620067358
State quality for 7 states: 0.17128912302448093


ELBO: 375536.8: 100%|██████████| 100/100 [03:29<00:00,  2.09s/it]
ELBO: 365698.6: 100%|██████████| 20/20 [00:28<00:00,  1.41s/it]


Test accuracy for 8 states: [0.9993747534960642]
Silhouette score at 8 states: -0.09029590792546859
State quality for 8 states: 0.26510691044313606


ELBO: 376291.1: 100%|██████████| 100/100 [03:55<00:00,  2.36s/it]
ELBO: 366348.0: 100%|██████████| 20/20 [00:31<00:00,  1.57s/it]


Test accuracy for 9 states: [0.9993764338045571]
Silhouette score at 9 states: -0.07892296361838727
State quality for 9 states: 0.23055469476293772


ELBO: 376573.5: 100%|██████████| 100/100 [04:35<00:00,  2.76s/it]
ELBO: 366773.5: 100%|██████████| 20/20 [00:35<00:00,  1.75s/it]


Test accuracy for 10 states: [0.999374167377943]
Silhouette score at 10 states: -0.0834192150715985
State quality for 10 states: 0.28370003098853425


ELBO: 378381.5: 100%|██████████| 100/100 [05:17<00:00,  3.17s/it]
ELBO: 368112.1: 100%|██████████| 20/20 [00:40<00:00,  2.02s/it]


Test accuracy for 11 states: [0.9993746398959393]
Silhouette score at 11 states: -0.14598712759597846
State quality for 11 states: 0.23566780291292222


ELBO: 379264.9: 100%|██████████| 100/100 [06:21<00:00,  3.81s/it]
ELBO: 369241.3: 100%|██████████| 20/20 [00:44<00:00,  2.23s/it]


Test accuracy for 12 states: [0.9993736944989837]
Silhouette score at 12 states: -0.11610846421702982
State quality for 12 states: 0.4828013634955067


ELBO: 380541.7: 100%|██████████| 100/100 [07:43<00:00,  4.63s/it]
ELBO: 370076.1: 100%|██████████| 20/20 [00:53<00:00,  2.69s/it]


Test accuracy for 13 states: [0.9993730270245894]
Silhouette score at 13 states: -0.0909277210949511
State quality for 13 states: 0.27091726061357296


ELBO: 380631.9: 100%|██████████| 100/100 [08:18<00:00,  4.99s/it]
ELBO: 370413.4: 100%|██████████| 20/20 [01:00<00:00,  3.02s/it]


Test accuracy for 14 states: [0.9993731330249803]
Silhouette score at 14 states: -0.10189277749842464
State quality for 14 states: 0.3804617291602107


ELBO: 380440.8: 100%|██████████| 100/100 [09:15<00:00,  5.55s/it]
ELBO: 370048.6: 100%|██████████| 20/20 [01:06<00:00,  3.31s/it]


Test accuracy for 15 states: [0.9993709412897083]
Silhouette score at 15 states: -0.15744508956310513
State quality for 15 states: 0.2993492407809111


ELBO: 381308.5: 100%|██████████| 100/100 [11:16<00:00,  6.76s/it]
ELBO: 370936.6: 100%|██████████| 20/20 [01:14<00:00,  3.74s/it]


Test accuracy for 16 states: [0.9993725251060143]
Silhouette score at 16 states: -0.122482317888844
State quality for 16 states: 0.5464828013634955


ELBO: 380791.9: 100%|██████████| 100/100 [11:37<00:00,  6.98s/it]
ELBO: 370234.7: 100%|██████████| 20/20 [01:19<00:00,  3.99s/it]


Test accuracy for 17 states: [0.9993711535107475]
Silhouette score at 17 states: -0.1261452243782836
State quality for 17 states: 0.4404245429191199


ELBO: 381812.4: 100%|██████████| 100/100 [15:13<00:00,  9.13s/it]
ELBO: 371066.3: 100%|██████████| 20/20 [01:26<00:00,  4.34s/it]


Test accuracy for 18 states: [0.9993720563386097]
Silhouette score at 18 states: -0.10210011434565246
State quality for 18 states: 0.4786953827083979


ELBO: 365333.8: 100%|██████████| 100/100 [01:30<00:00,  1.10it/s]
ELBO: 355877.1: 100%|██████████| 20/20 [00:13<00:00,  1.49it/s]


Test accuracy for 2 states: [0.9993813834864043]
Silhouette score at 2 states: 0.07020163151144407
State quality for 2 states: 0.06523086458010537


ELBO: 370128.5: 100%|██████████| 100/100 [01:43<00:00,  1.03s/it]
ELBO: 360856.3: 100%|██████████| 20/20 [00:14<00:00,  1.33it/s]


Test accuracy for 3 states: [0.9993798698938525]
Silhouette score at 3 states: 0.010108391554201858
State quality for 3 states: 0.1433219708707778


ELBO: 371910.8: 100%|██████████| 100/100 [01:57<00:00,  1.18s/it]
ELBO: 362439.7: 100%|██████████| 20/20 [00:17<00:00,  1.16it/s]


Test accuracy for 4 states: [0.9993789178208417]
Silhouette score at 4 states: -0.009159440426635435
State quality for 4 states: 0.1720638363805392


ELBO: 371779.7: 100%|██████████| 100/100 [02:15<00:00,  1.36s/it]
ELBO: 361808.5: 100%|██████████| 20/20 [00:19<00:00,  1.04it/s]


Test accuracy for 5 states: [0.999378463896929]
Silhouette score at 5 states: -0.060462803656817214
State quality for 5 states: 0.2193988224356988


ELBO: 371378.4: 100%|██████████| 100/100 [02:34<00:00,  1.54s/it]
ELBO: 361578.8: 100%|██████████| 20/20 [00:21<00:00,  1.08s/it]


Test accuracy for 6 states: [0.9993776692088783]
Silhouette score at 6 states: -0.09590691040576556
State quality for 6 states: 0.21033467616981716


ELBO: 374731.2: 100%|██████████| 100/100 [03:02<00:00,  1.82s/it]
ELBO: 364586.9: 100%|██████████| 20/20 [00:25<00:00,  1.26s/it]


Test accuracy for 7 states: [0.9993754275414581]
Silhouette score at 7 states: -0.07911060802119362
State quality for 7 states: 0.26069104431360396


ELBO: 376718.9: 100%|██████████| 100/100 [03:34<00:00,  2.14s/it]
ELBO: 366577.4: 100%|██████████| 20/20 [00:28<00:00,  1.42s/it]


Test accuracy for 8 states: [0.9993750109117542]
Silhouette score at 8 states: -0.14794362108083228
State quality for 8 states: 0.21358847226526184


ELBO: 376585.5: 100%|██████████| 100/100 [03:58<00:00,  2.39s/it]
ELBO: 366497.6: 100%|██████████| 20/20 [00:33<00:00,  1.65s/it]


Test accuracy for 9 states: [0.9993754959689438]
Silhouette score at 9 states: -0.09584790016588293
State quality for 9 states: 0.41439417415556246


ELBO: 376391.5: 100%|██████████| 100/100 [04:46<00:00,  2.86s/it]
ELBO: 366717.0: 100%|██████████| 20/20 [00:35<00:00,  1.80s/it]


Test accuracy for 10 states: [0.9993752361262912]
Silhouette score at 10 states: -0.11454486235643967
State quality for 10 states: 0.399907034397273


ELBO: 379816.8:  37%|███▋      | 37/100 [02:35<03:07,  2.98s/it]