## F1 state representation quality using multiple states  

In [1]:
from bbvi_infer_2l import * 
from f1 import *
%matplotlib inline
import pickle
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import silhouette_score

  from .autonotebook import tqdm as notebook_tqdm


### Read in data observations

In [2]:
year= 2024
gp = 'China'
event = 'Race'

session_event = ff1.get_session(year, gp, event)
session_event.load()
circuit_info = session_event.get_circuit_info()

core           INFO 	Loading data for Chinese Grand Prix - Race [v3.6.1]
req            INFO 	Using cached data for session_info
req            INFO 	Using cached data for driver_info
req            INFO 	Using cached data for session_status_data
req            INFO 	Using cached data for lap_count
req            INFO 	Using cached data for track_status_data
req            INFO 	Using cached data for _extended_timing_data
req            INFO 	Using cached data for timing_app_data
core           INFO 	Processing timing data...
req            INFO 	Using cached data for car_data
req            INFO 	Using cached data for position_data
req            INFO 	Using cached data for weather_data
req            INFO 	Using cached data for race_control_messages
core           INFO 	Finished loading data for 20 drivers: ['1', '4', '11', '16', '55', '63', '14', '81', '44', '27', '31', '23', '10', '24', '18', '20', '2', '3', '22', '77']


In [3]:
# Train data
pos_winner= generate_full_laps(session_event.laps.pick_driver('VER'), 'VER')
## States 
states_train, zs_train = one_hot_states_multilap(pos_winner[['X','Y']].to_numpy(), circuit_info.corners[['X','Y']].to_numpy(), 
                                    pos_winner['LapNumber'].to_numpy())



In [4]:
## Standardize train data
import numpy.random as npr
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
track_norm = scaler.fit_transform(pos_winner[['X','Y']].to_numpy())
C = npr.randn(10, 2)  ## emission matrix
ys_train = (C @ track_norm.T).T + npr.randn(track_norm.shape[0], 10)*0.01

In [5]:
## Test data
pos_alt= generate_full_laps(session_event.laps.pick_driver('SAI'), 'SAI')
### States
states_test, zs_test = one_hot_states_multilap(pos_alt[['X','Y']].to_numpy(), circuit_info.corners[['X','Y']].to_numpy(), 
                                               pos_alt['LapNumber'].to_numpy())
### Observations
track_norm = scaler.transform(pos_alt[['X','Y']].to_numpy())
ys_test = (C @ track_norm.T).T + npr.randn(track_norm.shape[0], 10)*0.01



## GSD

In [6]:
gsd = []
for seed in range(3):  ## repeat tests
    torch.manual_seed(seed)  # Set PyTorch seed
    torch.cuda.manual_seed_all(seed)  # Set CUDA seed 
    accs = []
    for K in range(2, 11):
        model = GenerativeSLDS(N=10, K=K, D=2, emission_model="gaussian")
        elbos, variational_z = fit_bbvi_schedule(model.to('cuda'), ys = torch.tensor(ys_train).to('cuda').float(), 
                                             num_iters=4000, learning=True, n_samples=10, base_lr=1e-2, warmup_iters=200, tau_max=0.99)
        zs = variational_z.sample_q_z(torch.tensor(ys_train).to('cuda').float().unsqueeze(0).expand(20, -1, -1), 0.99)
        zs2 = variational_z.sample_q_z(torch.tensor(ys_test).to('cuda').float().unsqueeze(0).expand(20, -1, -1), 0.99)
        pred_ys2=model.smooth(torch.tensor(ys_test).to('cuda').float(), zs2).mean(dim=0).detach().cpu().numpy() # [T, N]
        print(f"Test accuracy for {K} states:", train_metrics(ys_test, pred_ys2, None, k_max=0))
        Z_train = torch.mean(zs, dim=0).detach().cpu().numpy()
        Y_train = np.array(zs_train)
        Z_test = torch.mean(zs2, dim=0).detach().cpu().numpy()
        Y_test = np.array(zs_test)
        # Train KNN classifier
        knn = KNeighborsClassifier(n_neighbors=10, metric='euclidean')
        knn.fit(Z_train, Y_train)
        acc = knn.score(Z_test, Y_test)
        print(f"State quality for {K} states:", acc)
        accs.append(acc)  
    print("Finish one round,", accs)
    gsd.append(np.array(accs))
    np.save('./saved2024/f12024GSD.npy', np.array(gsd))

ELBO: 171215.2, LR: 0.00100, Tau: 0.990: 100%|██████████| 4000/4000 [00:50<00:00, 79.85it/s]


Test accuracy for 2 states: [0.9849207501261057]
State quality for 2 states: 0.3610425006387872


ELBO: 494214.8, LR: 0.00100, Tau: 0.990: 100%|██████████| 4000/4000 [00:55<00:00, 72.24it/s]


Test accuracy for 3 states: [0.9991589784671543]
State quality for 3 states: 0.6589302444425518


ELBO: 470370.8, LR: 0.00100, Tau: 0.990: 100%|██████████| 4000/4000 [01:03<00:00, 63.48it/s]


Test accuracy for 4 states: [0.9989725057349246]
State quality for 4 states: 0.7194446810322801


ELBO: 5536.8, LR: 0.00100, Tau: 0.990:  25%|██▌       | 1005/4000 [00:17<00:51, 58.39it/s]  


Early stopping at iteration 1005.
Test accuracy for 5 states: [0.9773706019118789]
State quality for 5 states: 0.6426624648667064


ELBO: 247024.3, LR: 0.00100, Tau: 0.990: 100%|██████████| 4000/4000 [01:13<00:00, 54.37it/s]


Test accuracy for 6 states: [0.9932593822875779]
State quality for 6 states: 0.6292053487777872


ELBO: 200985.7, LR: 0.00100, Tau: 0.990: 100%|██████████| 4000/4000 [01:19<00:00, 50.52it/s]


Test accuracy for 7 states: [0.9862095506159146]
State quality for 7 states: 0.6330806575249127


ELBO: 477772.0, LR: 0.00100, Tau: 0.990: 100%|██████████| 4000/4000 [01:29<00:00, 44.71it/s]


Test accuracy for 8 states: [0.9990380417329424]
State quality for 8 states: 0.7010050251256281


ELBO: 400501.2, LR: 0.00100, Tau: 0.990: 100%|██████████| 4000/4000 [01:38<00:00, 40.80it/s]


Test accuracy for 9 states: [0.996178151090454]
State quality for 9 states: 0.737756579507708


ELBO: 464532.5, LR: 0.00100, Tau: 0.990: 100%|██████████| 4000/4000 [01:42<00:00, 38.84it/s]


Test accuracy for 10 states: [0.9989036909872235]
State quality for 10 states: 0.6739630355165659
Finish one round, [0.3610425006387872, 0.6589302444425518, 0.7194446810322801, 0.6426624648667064, 0.6292053487777872, 0.6330806575249127, 0.7010050251256281, 0.737756579507708, 0.6739630355165659]


ELBO: 112918.0, LR: 0.00100, Tau: 0.990:  53%|█████▎    | 2113/4000 [00:26<00:23, 79.81it/s]


Early stopping at iteration 2113.
Test accuracy for 2 states: [0.9791681248777339]
State quality for 2 states: 0.4901200919853505


ELBO: 478287.1, LR: 0.00100, Tau: 0.990: 100%|██████████| 4000/4000 [00:55<00:00, 72.34it/s]


Test accuracy for 3 states: [0.9990123334104014]
State quality for 3 states: 0.5782301337194446


ELBO: 4208.2, LR: 0.00100, Tau: 0.990:  70%|███████   | 2813/4000 [00:44<00:18, 63.37it/s]  


Early stopping at iteration 2813.
Test accuracy for 4 states: [0.9709867444757606]
State quality for 4 states: 0.42619879056298443


ELBO: 424680.3, LR: 0.00100, Tau: 0.990: 100%|██████████| 4000/4000 [01:10<00:00, 56.75it/s]


Test accuracy for 5 states: [0.9984357895503919]
State quality for 5 states: 0.4857763393237373


ELBO: 184472.9, LR: 0.00100, Tau: 0.990:  55%|█████▌    | 2208/4000 [00:40<00:32, 54.48it/s]

## rSLDS

In [6]:
import ssm
from ssm.extensions.mp_srslds.transitions_ext import StickyRecurrentOnlyTransitions, StickyRecurrentTransitions

In [7]:
slds_acc = []
slds_sil = []
for seed in range(3):  ## repeat tests
    torch.manual_seed(seed)  # Set PyTorch seed
    torch.cuda.manual_seed_all(seed)  # Set CUDA seed 
    accs2 = []
    sils = []
    for K in range(2, 11):
        sro_trans=StickyRecurrentOnlyTransitions(K=K,D=2) 
        rslds = ssm.SLDS(N=10, K=K, D=2,
                 transitions=sro_trans,
                 dynamics="diagonal_gaussian",
                 emissions="gaussian",
                 single_subspace=True)
        rslds.initialize(ys_train)
        q_elbos_lem, q_lem = rslds.fit(ys_train, method="laplace_em", 
                                   variational_posterior="structured_meanfield",
                                   initialize=False, num_iters=100, alpha=0.0)
        q_lem_x = q_lem.mean_continuous_states[0]
        q_lem_z = rslds.most_likely_states(q_lem_x, ys_train)
        # plot_trajectory(q_lem_z, q_lem_x)
        elbos, q_lem2 = rslds.approximate_posterior(ys_test,
                                        method="laplace_em",
                                        variational_posterior="structured_meanfield",
                                        num_iters=20)
        q_lem_x2 = q_lem2.mean_continuous_states[0]
        q_lem_z2 = rslds.most_likely_states(q_lem_x2, ys_test)
        sil = silhouette_score(q_lem_x2, q_lem_z2)
        pred_ys2=rslds.smooth(q_lem_x2, ys_test)
        print(f"Test accuracy for {K} states:", train_metrics(ys_test, pred_ys2, None, k_max=0))
        print(f"Silhouette score at {K} states:", sil)
        Z_train = np.eye(K)[q_lem_z]
        Y_train = np.array(zs_train)
        Z_test = np.eye(K)[q_lem_z2]
        Y_test = np.array(zs_test)
        # Train KNN classifier
        knn = KNeighborsClassifier(n_neighbors=10, metric='euclidean')
        knn.fit(Z_train, Y_train)
        acc = knn.score(Z_test, Y_test)
        print(f"State quality for {K} states:", acc)
        accs2.append(acc)
        sils.append(sil)
    print("Finish one round,", accs2)
    slds_acc.append(np.array(accs2))
    slds_sil.append(np.array(sils))
    np.save('./saved2024/f12024SLDS.npy', np.array(slds_acc))    
    np.save('./saved2024/f12024SLDS_sil.npy', np.array(slds_sil))   

ELBO: 647499.2: 100%|██████████| 100/100 [02:43<00:00,  1.63s/it]
ELBO: 651184.6: 100%|██████████| 20/20 [00:23<00:00,  1.19s/it]


Test accuracy for 2 states: [0.9997806517054715]
Silhouette score at 2 states: 0.043377017763064435
State quality for 2 states: 0.03560173750106464


ELBO: 656269.4: 100%|██████████| 100/100 [03:00<00:00,  1.80s/it]
ELBO: 660121.1: 100%|██████████| 20/20 [00:26<00:00,  1.34s/it]


Test accuracy for 3 states: [0.9997805796289772]
Silhouette score at 3 states: -0.042072581759947034
State quality for 3 states: 0.15910058768418364


ELBO: 661709.9: 100%|██████████| 100/100 [03:42<00:00,  2.23s/it]
ELBO: 665660.7: 100%|██████████| 20/20 [00:32<00:00,  1.62s/it]


Test accuracy for 4 states: [0.9997799054998697]
Silhouette score at 4 states: -0.048837278239596806
State quality for 4 states: 0.17779575845328335


ELBO: 661434.3: 100%|██████████| 100/100 [04:11<00:00,  2.52s/it]
ELBO: 665584.6: 100%|██████████| 20/20 [00:35<00:00,  1.78s/it]


Test accuracy for 5 states: [0.9997799219806183]
Silhouette score at 5 states: -0.07971538243774795
State quality for 5 states: 0.1091474320756324


ELBO: 663317.7: 100%|██████████| 100/100 [05:03<00:00,  3.04s/it]
ELBO: 667019.0: 100%|██████████| 20/20 [00:41<00:00,  2.09s/it]


Test accuracy for 6 states: [0.9997796383854226]
Silhouette score at 6 states: -0.057608712615416946
State quality for 6 states: 0.24929733412826846


ELBO: 665400.3: 100%|██████████| 100/100 [05:50<00:00,  3.51s/it]
ELBO: 669453.1: 100%|██████████| 20/20 [00:47<00:00,  2.40s/it]


Test accuracy for 7 states: [0.999779189845017]
Silhouette score at 7 states: -0.05099060521540595
State quality for 7 states: 0.23362575589813475


ELBO: 667492.9: 100%|██████████| 100/100 [06:48<00:00,  4.08s/it]
ELBO: 671727.2: 100%|██████████| 20/20 [00:55<00:00,  2.75s/it]


Test accuracy for 8 states: [0.9997790773144178]
Silhouette score at 8 states: -0.04792630228215438
State quality for 8 states: 0.2748488203730517


ELBO: 668575.3: 100%|██████████| 100/100 [07:38<00:00,  4.59s/it]
ELBO: 672335.2: 100%|██████████| 20/20 [01:01<00:00,  3.07s/it]


Test accuracy for 9 states: [0.9997787768276843]
Silhouette score at 9 states: -0.08617376045944702
State quality for 9 states: 0.4407205519121029


ELBO: 669322.1: 100%|██████████| 100/100 [09:06<00:00,  5.46s/it]
ELBO: 672884.2: 100%|██████████| 20/20 [01:08<00:00,  3.44s/it]


Test accuracy for 10 states: [0.9997785921090852]
Silhouette score at 10 states: -0.11593516642066365
State quality for 10 states: 0.3611276722596031
Finish one round, [0.03560173750106464, 0.15910058768418364, 0.17779575845328335, 0.1091474320756324, 0.24929733412826846, 0.23362575589813475, 0.2748488203730517, 0.4407205519121029, 0.3611276722596031]


ELBO: 646632.7: 100%|██████████| 100/100 [02:35<00:00,  1.55s/it]
ELBO: 650263.0: 100%|██████████| 20/20 [00:23<00:00,  1.19s/it]


Test accuracy for 2 states: [0.9997808857287808]
Silhouette score at 2 states: 0.04168725105060632
State quality for 2 states: 0.121965761008432


ELBO: 654732.7: 100%|██████████| 100/100 [02:55<00:00,  1.76s/it]
ELBO: 658451.1: 100%|██████████| 20/20 [00:26<00:00,  1.33s/it]


Test accuracy for 3 states: [0.9997806342896209]
Silhouette score at 3 states: -0.0009589945782676562
State quality for 3 states: 0.1932969934417852


ELBO: 660588.5: 100%|██████████| 100/100 [03:21<00:00,  2.01s/it]
ELBO: 664391.2: 100%|██████████| 20/20 [00:30<00:00,  1.50s/it]


Test accuracy for 4 states: [0.9997802505447642]
Silhouette score at 4 states: -0.039245425577938135
State quality for 4 states: 0.22983561877182523


ELBO: 661323.1: 100%|██████████| 100/100 [03:51<00:00,  2.32s/it]
ELBO: 665293.6: 100%|██████████| 20/20 [00:34<00:00,  1.71s/it]


Test accuracy for 5 states: [0.9997798946093635]
Silhouette score at 5 states: -0.04824309260674932
State quality for 5 states: 0.2047525764415297


ELBO: 665655.8: 100%|██████████| 100/100 [04:50<00:00,  2.90s/it]
ELBO: 669427.0: 100%|██████████| 20/20 [00:41<00:00,  2.06s/it]


Test accuracy for 6 states: [0.9997793792495896]
Silhouette score at 6 states: -0.0806727820078585
State quality for 6 states: 0.21275870879822842


ELBO: 664646.5: 100%|██████████| 100/100 [05:52<00:00,  3.53s/it]
ELBO: 668593.1: 100%|██████████| 20/20 [00:47<00:00,  2.35s/it]


Test accuracy for 7 states: [0.9997793167615996]
Silhouette score at 7 states: -0.1362184565552807
State quality for 7 states: 0.22966527553019334


ELBO: 671391.5: 100%|██████████| 100/100 [06:32<00:00,  3.93s/it]
ELBO: 675092.8: 100%|██████████| 20/20 [00:54<00:00,  2.71s/it]


Test accuracy for 8 states: [0.9997780914100598]
Silhouette score at 8 states: -0.07664817259933777
State quality for 8 states: 0.3526105101780087


ELBO: 671695.1: 100%|██████████| 100/100 [07:31<00:00,  4.51s/it]
ELBO: 675799.1: 100%|██████████| 20/20 [01:01<00:00,  3.06s/it]


Test accuracy for 9 states: [0.9997781769182719]
Silhouette score at 9 states: -0.027845416813297157
State quality for 9 states: 0.16872498083638532


ELBO: 673336.9: 100%|██████████| 100/100 [09:09<00:00,  5.50s/it]
ELBO: 676951.1: 100%|██████████| 20/20 [01:09<00:00,  3.47s/it]


Test accuracy for 10 states: [0.9997777061238657]
Silhouette score at 10 states: -0.004791495609389184
State quality for 10 states: 0.37105016608466057
Finish one round, [0.121965761008432, 0.1932969934417852, 0.22983561877182523, 0.2047525764415297, 0.21275870879822842, 0.22966527553019334, 0.3526105101780087, 0.16872498083638532, 0.37105016608466057]


ELBO: 647058.3: 100%|██████████| 100/100 [02:51<00:00,  1.71s/it]
ELBO: 650412.8: 100%|██████████| 20/20 [00:26<00:00,  1.35s/it]


Test accuracy for 2 states: [0.9997808628761373]
Silhouette score at 2 states: 0.03973423078557887
State quality for 2 states: 0.121965761008432


ELBO: 652099.0: 100%|██████████| 100/100 [02:55<00:00,  1.76s/it]
ELBO: 655457.6: 100%|██████████| 20/20 [00:26<00:00,  1.33s/it]


Test accuracy for 3 states: [0.9997808323230084]
Silhouette score at 3 states: -0.008301927765826684
State quality for 3 states: 0.1544587343497147


ELBO: 661506.2: 100%|██████████| 100/100 [03:25<00:00,  2.06s/it]
ELBO: 665236.7: 100%|██████████| 20/20 [00:29<00:00,  1.50s/it]


Test accuracy for 4 states: [0.9997799495615087]
Silhouette score at 4 states: -0.04388174721383831
State quality for 4 states: 0.21948726684268802


ELBO: 665517.7: 100%|██████████| 100/100 [03:55<00:00,  2.36s/it]
ELBO: 669329.5: 100%|██████████| 20/20 [00:34<00:00,  1.71s/it]


Test accuracy for 5 states: [0.9997791061872903]
Silhouette score at 5 states: -0.05303508362983916
State quality for 5 states: 0.1399369730005962


ELBO: 663300.0: 100%|██████████| 100/100 [04:57<00:00,  2.97s/it]
ELBO: 667501.0: 100%|██████████| 20/20 [00:41<00:00,  2.06s/it]


Test accuracy for 6 states: [0.9997794806673614]
Silhouette score at 6 states: -0.099859297165606
State quality for 6 states: 0.2188910654969764


ELBO: 664443.7: 100%|██████████| 100/100 [05:39<00:00,  3.40s/it]
ELBO: 668552.4: 100%|██████████| 20/20 [00:47<00:00,  2.36s/it]


Test accuracy for 7 states: [0.9997794315277201]
Silhouette score at 7 states: -0.07412027959459139
State quality for 7 states: 0.2511711097862192


ELBO: 666133.2: 100%|██████████| 100/100 [06:34<00:00,  3.95s/it]
ELBO: 670007.9: 100%|██████████| 20/20 [00:56<00:00,  2.84s/it]


Test accuracy for 8 states: [0.999779133741623]
Silhouette score at 8 states: -0.08528136025250942
State quality for 8 states: 0.23294438293160719


ELBO: 668852.2: 100%|██████████| 100/100 [07:49<00:00,  4.70s/it]
ELBO: 673013.0: 100%|██████████| 20/20 [01:00<00:00,  3.04s/it]


Test accuracy for 9 states: [0.9997784174535924]
Silhouette score at 9 states: -0.15391312369341537
State quality for 9 states: 0.2769781108934503


ELBO: 669391.3: 100%|██████████| 100/100 [08:55<00:00,  5.35s/it]
ELBO: 673192.6: 100%|██████████| 20/20 [01:09<00:00,  3.49s/it]


Test accuracy for 10 states: [0.9997786277510159]
Silhouette score at 10 states: -0.13996371548986072
State quality for 10 states: 0.24614598415807853
Finish one round, [0.121965761008432, 0.1544587343497147, 0.21948726684268802, 0.1399369730005962, 0.2188910654969764, 0.2511711097862192, 0.23294438293160719, 0.2769781108934503, 0.24614598415807853]
