In [None]:
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from optax.schedules import  warmup_cosine_decay_schedule

from src.utils import from_timeseries_to_dataset, batch_dataset,unbatch_dataset
from src.kernel import VFTSGaussKernel, TSGaussGaussKernel, GaussKernel, ExpKernel,TSExpExpKernel,VFTSExpKernel,MyTSGaussKernel
from src.lddmm import varifold_registration, Shooting, Flowing, batch_varifold_time_initializer,batch_one_to_many_varifold_registration
from src.plotting import plot2Dfigure
from src.barycenter import batch_varifold_barycenter_time_registration, batch_varifolld_barycenter_initializer,batch_varifold_barycenter_registration, batch_barycenter_registration
from src.loss import VarifoldLoss,MyLoss

In [None]:
from loadmydata.load_uea_ucr import load_uea_ucr_data

# UCR/UEA dataset to analyse (alternative tried: "ECGFiveDays").
dataset_name = "ECG200"
data = load_uea_ucr_data(dataset_name)

print(data.description)
# Shapes of the train/test inputs and labels, in that order.
for split in (data.X_train, data.X_test, data.y_train, data.y_test):
    print(split.shape)

In [None]:
# Convert the raw training series into the (points, mask) representation
# returned by from_timeseries_to_dataset, then batch it with a batch size of
# X.shape[0] (presumably a single batch holding every training sample).
X, X_mask = from_timeseries_to_dataset(data.X_train)
n_train = X.shape[0]
bX, bX_mask = batch_dataset(X, n_train, X_mask)

In [None]:
# Experiment 1: barycenter registration with the MyLoss data term.
# Kv is the kernel handed to Shooting/Flowing and to the registration;
# Kl is the kernel wrapped by MyLoss.
Kv = VFTSGaussKernel(1, 0.1, 35, 1, 1)
Kl = MyTSGaussKernel(3, 1, 1)
shoot = Shooting(Kv)
flow = Flowing(Kv)

loss = MyLoss(Kl)

# Training sample 10 is the initial guess for the barycenter.
init, init_mask = X[10], X_mask[10]

# The cosine schedule must span exactly the optimisation run, so the number of
# iterations and the schedule length come from one constant
# (warmup = first 10% of the run; decay_steps includes the warmup in optax).
n_iter = 800
lr_schedule = warmup_cosine_decay_schedule(
    init_value=0,
    peak_value=0.01,
    warmup_steps=n_iter // 10,
    decay_steps=n_iter,
    end_value=0,
)
bp, qbar, qbar_mask = batch_barycenter_registration(
    bX, bX_mask, Kv, loss, init, init_mask,
    niter=n_iter,
    optimizer=optax.adam(lr_schedule),
)

In [None]:
# Compare one training sample with the barycenter estimated above.
fig, ax = plt.subplots()
ax.plot(*X[0].T, label="X[0]")
ax.plot(*qbar.T, label="barycenter qbar")
ax.set_title("Training sample 0 vs. barycenter")
ax.legend();

In [None]:
# Overlay every training curve (faint black) with the barycenter (red).
fig, ax = plt.subplots()
for i, x in enumerate(X):
    # Only the first sample gets a legend entry so the legend stays readable.
    ax.plot(*x.T, color="black", alpha=0.2,
            label="training samples" if i == 0 else "_nolegend_")
ax.plot(*qbar.T, color="red", label="barycenter qbar")
ax.set_title("Training samples and barycenter")
ax.legend();

In [None]:

# Deformation from the barycenter to the sample at batch 0, position 1,
# drawn with the shooting/flowing operators built from Kv.
target_idx = (0, 1)
plot2Dfigure(qbar, bX[target_idx], bp[target_idx], shoot, flow,
             qbar_mask, bX_mask[target_idx])

In [None]:
# Kernels for the varifold barycenter experiment in the next cell.
# NOTE: Kl changes from MyTSGaussKernel (previous experiment) to
# TSGaussGaussKernel. Kv is re-created with the same hyperparameters as before,
# so shoot/flow are rebuilt around the new Kv object.
Kv = VFTSGaussKernel(1, 0.1, 35, 1, 1)
Kl = TSGaussGaussKernel(1, 1, 1, 1)
shoot = Shooting(Kv)
flow = Flowing(Kv)

# Settings for optional experiments that no live cell in this notebook runs:
# iKv/iKl appear intended for the time initializers
# (batch_varifold_time_initializer / batch_varifolld_barycenter_initializer),
# and idx0/idx1 appear to select a source/target pair for
# batch_one_to_many_varifold_registration.
iKv = GaussKernel(20)
iKl = TSGaussGaussKernel(1, 1, 1, 1)
idx0 = 8
idx1 = 11

In [None]:

# Experiment 2: varifold barycenter registration with the Kv/Kl kernels from
# the previous cell. This overwrites bp/qbar/qbar_mask from experiment 1; the
# PCA/LDA sections below use these results.
# Training sample 10 is again the initial guess for the barycenter.
init, init_mask = X[10], X_mask[10]

# Keep the schedule length tied to the number of iterations
# (warmup = first 10% of the run; decay_steps includes the warmup in optax).
n_iter = 800
lr_schedule = warmup_cosine_decay_schedule(
    init_value=0,
    peak_value=0.01,
    warmup_steps=n_iter // 10,
    decay_steps=n_iter,
    end_value=0,
)
bp, qbar, qbar_mask = batch_varifold_barycenter_registration(
    bX, bX_mask, Kv, Kl, init, init_mask,
    niter=n_iter,
    optimizer=optax.adam(lr_schedule),
)

In [None]:
# Compare one training sample with the varifold barycenter.
fig, ax = plt.subplots()
ax.plot(*X[0].T, label="X[0]")
ax.plot(*qbar.T, label="varifold barycenter qbar")
ax.set_title("Training sample 0 vs. varifold barycenter")
ax.legend();

In [None]:
# Overlay every training curve (faint black) with the varifold barycenter (red).
fig, ax = plt.subplots()
for i, x in enumerate(X):
    # Only the first sample gets a legend entry so the legend stays readable.
    ax.plot(*x.T, color="black", alpha=0.2,
            label="training samples" if i == 0 else "_nolegend_")
ax.plot(*qbar.T, color="red", label="varifold barycenter qbar")
ax.set_title("Training samples and varifold barycenter")
ax.legend();

In [None]:

# Deformation from the varifold barycenter to the sample at batch 0,
# position 1, drawn with the shooting/flowing operators built from Kv.
target_idx = (0, 1)
plot2Dfigure(qbar, bX[target_idx], bp[target_idx], shoot, flow,
             qbar_mask, bX_mask[target_idx])

In [None]:
from src.statistic import MomentaLDA, MomentaPCA

# y_train is defined once, in the next cell, from the raw integer labels
# (presumably -1 / 1 for ECG200); the PCA/LDA plots filter on those values.



## PCA of the registration momenta

In [None]:
# Integer class labels of the training set. The PCA/LDA plots below split the
# samples on y_train == -1 and y_train == 1, so fail early if the labels differ.
y_train = np.array([int(y) for y in data.y_train], dtype=int)
assert set(np.unique(y_train).tolist()) == {-1, 1}, f"unexpected labels: {np.unique(y_train)}"
y_train

In [None]:
# Undo the batching of the momenta returned by the barycenter registration.
ps = unbatch_dataset(bp)
# The PCA/LDA score plots index scores with boolean masks built from y_train,
# which requires exactly one momentum per labelled training sample.
assert ps.shape[0] == len(y_train), (ps.shape, len(y_train))
ps.shape

In [None]:
# Principal modes of the momenta: row i shoots qbar with p0_bar + alpha * PC_i,
# for ndisp values of alpha evenly spaced in [-3, +3] standard deviations of
# the PC_i scores.
ncomp, ndisp = 2, 5
pca = MomentaPCA(ncomp)
ps = unbatch_dataset(bp)
pca.fit(Kv, ps, qbar, qbar_mask)
p0_bar = np.mean(ps, axis=0)

# squeeze=False keeps axs 2-D even when ncomp or ndisp is 1.
fig, axs = plt.subplots(ncomp, ndisp, figsize=(ndisp * 3, ncomp * 3),
                        sharex=True, sharey=True, squeeze=False)
for pca_index in range(ncomp):
    sigma_pca = np.std(pca.p_score_[:, pca_index])
    for col, n_sigma in enumerate(np.linspace(-3, 3, ndisp)):
        alpha = n_sigma * sigma_pca
        p0_mode = p0_bar + alpha * pca.p_pc_[pca_index]
        _, q = shoot(p0_mode, qbar, qbar_mask)
        ax = axs[pca_index, col]
        ax.plot(*q.T)
        ax.set_title(f"PC{pca_index + 1}: {n_sigma:+.1f} sd")
plt.show()

In [None]:
# Scores of each training sample on the first two principal components,
# one colour per class. Explicit columns keep this working when ncomp != 2.
fig, ax = plt.subplots()
for label in (-1, 1):
    scores = pca.p_score_[y_train == label]
    ax.scatter(scores[:, 0], scores[:, 1], label=f"y = {label}")
ax.set(xlabel="PC1 score", ylabel="PC2 score", title="PCA scores by class")
ax.legend();

## LDA of the registration momenta

In [None]:
# Fit LDA on the momenta and show each training sample's first discriminant
# score, one colour per class (all points drawn at height 1).
lda = MomentaLDA()
lda.fit(Kv, ps, qbar, qbar_mask, y_train)

fig, ax = plt.subplots()
for label in (-1, 1):
    class_mask = y_train == label
    ax.scatter(lda.p_score_[0, class_mask],
               np.ones(np.count_nonzero(class_mask)),
               label=f"y = {label}")
ax.set(xlabel="LDA score", yticks=[], title="LDA scores by class")
ax.legend();

In [None]:
# Shoot the barycenter along the first LDA direction for ndisp multiples of
# the LDA-score standard deviation, evenly spaced in [-10, +10].
# NOTE(review): unlike the PCA modes, the mean momentum p0_bar is not added
# here — confirm this is intended.
ndisp = 3
sigma = np.std(lda.p_score_)
fig, ax = plt.subplots()
for n_sigma in np.linspace(-10, 10, ndisp):
    alpha = n_sigma * sigma
    s_p0_mode = alpha * lda.p_lda_[0]
    _, q = shoot(s_p0_mode, qbar, qbar_mask)
    ax.plot(*q.T, label=f"{n_sigma:+.0f} sd")
ax.set_title("Barycenter shot along the LDA direction")
ax.legend()
plt.show()

In [None]:
# Number of training samples in class +1.
(y_train == 1).sum()