# Example: sample-wise adaptation using OTDA and BOTDA

In [8]:
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
import ot
import scipy.io
import mne          
from mne.decoding import CSP
mne.set_log_level(verbose='warning') #to avoid info at terminal
import matplotlib.pyplot as pl
np.random.seed(100)
from MIOTDAfunctions import*

# get the functions from RPA package
import rpa.transfer_learning as TL
# pyriemann import
from pyriemann.classification import MDM
from pyriemann.estimation import Covariances
from pyriemann.utils.base import invsqrtm
import timeit

#ignore warning 
# import warnings filter
from warnings import simplefilter
# ignore all future warnings
simplefilter(action='ignore', category=FutureWarning)

## Load the data and band-pass filter it (8–30 Hz)

In [9]:
def load_filtered_session(fname, sfreq=128, l_freq=8, h_freq=30):
    """Load one recording session from a .mat file and band-pass filter it.

    Parameters
    ----------
    fname : str
        Path to a .mat file holding the trials under key "X" and the labels under "y".
    sfreq : float
        Sampling frequency of the recording in Hz.
    l_freq, h_freq : float
        Lower / upper pass-band edges in Hz.

    Returns
    -------
    data : ndarray, shape (n_trials, n_channels, n_samples)
        Filtered trials (the layout CSP expects).
    labels : ndarray
        Trial labels with singleton dimensions squeezed out.
    """
    mat = scipy.io.loadmat(fname)
    # filter_data filters along the last axis of an (..., n_times) array, so every
    # channel of every trial is filtered on its own. Flattening a trial to
    # (n_channels * n_samples) first would concatenate channels into one long
    # signal and let the FIR filter smear across channel boundaries.
    data = mne.filter.filter_data(np.asarray(mat["X"], dtype=np.float64),
                                  sfreq, l_freq, h_freq)
    labels = np.squeeze(mat["y"])
    return data, labels


Data_S1, Labels_S1 = load_filtered_session('Data/DataSession1_S9.mat')
Data_S2, Labels_S2 = load_filtered_session('Data/DataSession2_S9.mat')
# trials / channels / samples of the new session (nc, ns are reused later on)
nt, nc, ns = Data_S2.shape

### Learn CSP+LDA from the source data (Data_S1)

In [10]:
# Source (session 1) decoder: CSP spatial filtering followed by LDA.
Xtr, Ytr = Data_S1, Labels_S1
csp = CSP(n_components=6, reg='empirical', log=True,
          norm_trace=False, cov_est='epoch')
Gtr = csp.fit_transform(Xtr, Ytr)  # CSP features of the source trials (log=True)
lda = LinearDiscriminantAnalysis()
lda.fit(Gtr, Ytr)

LinearDiscriminantAnalysis(n_components=None, priors=None, shrinkage=None,
              solver='svd', store_covariance=False, tol=0.0001)

### The first 20 trials of the new session (Data_S2) form the initial transport set; the remaining trials are used for testing

In [11]:
# Split the new session: labelled transport set vs. testing trials.
Xval, Yval = Data_S2[:20], Labels_S2[:20]
Labels_te = Labels_S2[20:]
Gval = csp.transform(Xval)  # CSP features of the transport set

#### Containers for saving the predictions of each method

In [16]:
# One independent list per method (1: FOTDA-Sinkhorn, 2: FOTDA-l1l2,
# 3: BOTDA-Sinkhorn, 4: BOTDA-l1l2); the testing loop appends to them per trial.
(yt_predict_sc, yt_predict_sr,
 yt_predict_1, yt_predict_2, yt_predict_3, yt_predict_4,
 yt_predict_rpa, yt_predict_eu) = ([] for _ in range(8))

### Set the OTDA parameters

In [17]:
# Candidate regularization values handed to the subset-selection helpers.
# Presumably rango_e feeds reg_e and rango_cl feeds reg_cl (the *_l1l2 variants
# are the only ones receiving rango_cl) — TODO confirm in MIOTDAfunctions.
rango_e = [0.1, 1, 10] 
rango_cl = [0.1, 1, 10]
metrica = 'sqeuclidean'   # ground metric passed to the OT transports
outerkfold = 20
innerkfold = None
M = 20
clf = LinearDiscriminantAnalysis()  # unfitted LDA for the forward-OTDA subset selection

### Select the subset of source features (Gtr) used by each OTDA/BOTDA variant

In [18]:
# Subset selection re-training path
G_FOTDAs_, Y_FOTDAs_, regu_FOTDAs_=\
SelectSubsetTraining_OTDAs(Gtr, Ytr, Gval, Yval, rango_e, clf, metrica, outerkfold, innerkfold, M)
G_FOTDAl1l2_, Y_FOTDAl1l2_, regu_FOTDAl1l2_=\
    SelectSubsetTraining_OTDAl1l2(Gtr, Ytr, Gval, Yval, rango_e, rango_cl, clf, metrica, outerkfold, innerkfold, M)
G_BOTDAs_, Y_BOTDAs_, regu_BOTDAs_=\
SelectSubsetTraining_BOTDAs(Gtr, Ytr, Gval, Yval, rango_e, lda, metrica, outerkfold, innerkfold, M)
G_BOTDAl1l2_, Y_BOTDAl1l2_, regu_BOTDAl1l2_=\
SelectSubsetTraining_BOTDAl1l2(Gtr, Ytr, Gval, Yval, rango_e, rango_cl, lda, metrica, outerkfold, innerkfold, M)

## Classify each testing trial of the new session

In [None]:
# Sample-wise adaptation. Each testing trial of session 2 is classified using only
# what is available before it arrives: the initial transport set plus every
# previously classified trial. Only afterwards is its true label revealed and the
# trial appended to the transport set (as with online feedback). Appending it
# *before* predicting would let SR, EU, RPA and BOTDA-l1l2 use the very label
# they are asked to predict.
n_transport = len(Labels_S2) - len(Labels_te)  # size of the initial transport set

# Start from the initial transport set and empty result containers so re-running
# this cell reproduces the experiment instead of extending a previous run's state.
Xval = Data_S2[:n_transport]
Yval = Labels_S2[:n_transport]
yt_predict_sc, yt_predict_sr = [], []
yt_predict_1, yt_predict_2, yt_predict_3, yt_predict_4 = [], [], [], []
yt_predict_rpa, yt_predict_eu = [], []
times_list = []

for trial in range(1, len(Labels_te) + 1):
    if trial % 10 == 0:
        print('Running testing trial={:1.0f}'.format(trial))
    # testing trial, kept 3-D: (1, n_channels, n_samples)
    idx = n_transport + trial - 1
    Xte = Data_S2[idx:idx + 1]
    Yte = Labels_S2[idx:idx + 1]

    # CSP features of the current transport set and of the testing trial
    Gval = csp.transform(Xval)
    Gte = csp.transform(Xte)

    # --- SC: source classifier without adaptation
    yt_predict_sc.append(lda.predict(Gte))

    # --- SR: retrain CSP+LDA on a sliding window of len(Xtr) trials, dropping
    #     the oldest source trials as labelled target trials accumulate
    start = timeit.default_timer()
    Xtr2 = np.vstack((Xtr, Xval))[len(Yval):]
    Ytr2 = np.hstack((Ytr, Yval))[len(Yval):]
    csp2 = CSP(n_components=6, reg='empirical', log=True, norm_trace=False, cov_est='epoch')
    Gtr2 = csp2.fit_transform(Xtr2, Ytr2)
    lda2 = LinearDiscriminantAnalysis()
    lda2.fit(Gtr2, Ytr2)
    Gte2 = csp2.transform(Xte)
    yt_predict_sr.append(lda2.predict(Gte2))
    time_sr = timeit.default_timer() - start

    # --- FOTDA Sinkhorn: map source features onto the target, retrain LDA
    start = timeit.default_timer()
    ot_sinkhorn = ot.da.SinkhornTransport(metric=metrica, reg_e=regu_FOTDAs_)
    ot_sinkhorn.fit(Xs=G_FOTDAs_, ys=Y_FOTDAs_, Xt=Gval)
    transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Gtr)
    lda3 = LinearDiscriminantAnalysis()
    lda3.fit(transp_Xs_sinkhorn, Ytr)
    yt_predict_1.append(lda3.predict(Gte))
    time_fs = timeit.default_timer() - start

    # --- FOTDA group-lasso: same as above with the l1l2 (class-aware) transport
    start = timeit.default_timer()
    ot_l1l2 = ot.da.SinkhornL1l2Transport(metric=metrica, reg_e=regu_FOTDAl1l2_[0],
                                          reg_cl=regu_FOTDAl1l2_[1])
    ot_l1l2.fit(Xs=G_FOTDAl1l2_, ys=Y_FOTDAl1l2_, Xt=Gval)
    transp_Xs_l1l2 = ot_l1l2.transform(Xs=Gtr)
    lda3 = LinearDiscriminantAnalysis()
    lda3.fit(transp_Xs_l1l2, Ytr)
    yt_predict_2.append(lda3.predict(Gte))
    time_fg = timeit.default_timer() - start

    # --- BOTDA Sinkhorn: map the testing trial back onto the source, reuse `lda`
    start = timeit.default_timer()
    bot_s = ot.da.SinkhornTransport(metric=metrica, reg_e=regu_BOTDAs_)
    bot_s.fit(Xs=Gval, ys=Yval, Xt=G_BOTDAs_)
    transp_Xt_s_backward = bot_s.transform(Xs=Gte)
    yt_predict_3.append(lda.predict(transp_Xt_s_backward))
    time_bs = timeit.default_timer() - start

    # --- BOTDA group-lasso
    start = timeit.default_timer()
    bot_l1l2 = ot.da.SinkhornL1l2Transport(metric=metrica, reg_e=regu_BOTDAl1l2_[0],
                                           reg_cl=regu_BOTDAl1l2_[1])
    bot_l1l2.fit(Xs=Gval, ys=Yval, Xt=G_BOTDAl1l2_)
    transp_Xt_l1l2_backward = bot_l1l2.transform(Xs=Gte)
    yt_predict_4.append(lda.predict(transp_Xt_l1l2_backward))
    time_bg = timeit.default_timer() - start

    # --- RPA: Riemannian re-centering + rotation, then MDM on source + transport set
    start_rpa = timeit.default_timer()
    cov_tr = Covariances().transform(Xtr)
    cov_val = Covariances().transform(Xval)
    cov_te = Covariances().transform(Xte)
    mdm = MDM()  # own name: do not overwrite the LDA `clf` used for subset selection
    source = {'covs': cov_tr, 'labels': Ytr}
    target_org_train = {'covs': cov_val, 'labels': Yval}
    # NOTE(review): Yte is handed to the RPA helpers as test labels; confirm they
    # are not used when computing the rotation.
    target_org_test = {'covs': cov_te, 'labels': Yte}
    source_rct, target_rct_train, target_rct_test = TL.RPA_recenter(
        source, target_org_train, target_org_test)
    source_rpa, target_rpa_train, target_rpa_test = TL.RPA_rotate(
        source_rct, target_rct_train, target_rct_test)
    covs_train = np.concatenate([source_rpa['covs'], target_rpa_train['covs']])
    y_train = np.concatenate([source_rpa['labels'], target_rpa_train['labels']])
    mdm.fit(covs_train, y_train)
    yt_predict_rpa.append(mdm.predict(target_rpa_test['covs']))
    time_rpa = timeit.default_timer() - start_rpa

    # --- EU: Euclidean alignment by the arithmetic-mean covariance of each
    #     session, then retrain CSP+LDA on aligned source + transport trials
    start_eu = timeit.default_timer()
    cov_tr = Covariances().transform(Xtr)
    cov_val = Covariances().transform(Xval)
    # inverse square roots computed once, not once per epoch in the comprehensions
    isqrt_tr = invsqrtm(cov_tr.mean(0))
    isqrt_val = invsqrtm(cov_val.mean(0))
    Xtr_eu = np.asarray([np.dot(isqrt_tr, epoch) for epoch in Xtr])
    Xval_eu = np.asarray([np.dot(isqrt_val, epoch) for epoch in Xval])
    Xte_eu = np.asarray([np.dot(isqrt_val, epoch) for epoch in Xte])
    x_train = np.concatenate([Xtr_eu, Xval_eu])
    y_train = np.concatenate([Ytr, Yval])
    csp2 = CSP(n_components=6, reg='empirical', log=True, norm_trace=False, cov_est='epoch')
    Gtr2 = csp2.fit_transform(x_train, y_train)
    lda2 = LinearDiscriminantAnalysis()
    lda2.fit(Gtr2, y_train)
    Gte2 = csp2.transform(Xte_eu)
    yt_predict_eu.append(lda2.predict(Gte2))
    time_eu = timeit.default_timer() - start_eu

    # per-trial computing times; column order used by the summary cell
    times = [time_sr, time_rpa, time_eu, time_fs, time_fg, time_bs, time_bg]
    times_list.append(times)

    # the label of the classified trial is now revealed: grow the transport set
    Xval = np.vstack((Xval, Xte))
    Yval = np.hstack((Yval, Yte))

# computing times, one row per testing trial
times_se = np.asarray(times_list)


Running testing trial=10
Running testing trial=20
Running testing trial=30
Running testing trial=40
Running testing trial=50
Running testing trial=60


In [None]:
# Accuracy of every method on the testing trials. Each prediction list holds one
# length-1 array per trial; ravel flattens it to shape (n_test_trials,) and, unlike
# squeeze, still yields a 1-D array when there is a single testing trial.
yt_predict_sc = np.asarray(yt_predict_sc).ravel()
yt_predict_sr = np.asarray(yt_predict_sr).ravel()
yt_predict_1 = np.asarray(yt_predict_1).ravel()
yt_predict_2 = np.asarray(yt_predict_2).ravel()
yt_predict_3 = np.asarray(yt_predict_3).ravel()
yt_predict_4 = np.asarray(yt_predict_4).ravel()
yt_predict_rpa = np.asarray(yt_predict_rpa).ravel()
yt_predict_eu = np.asarray(yt_predict_eu).ravel()

predictions = {
    "sc": yt_predict_sc,
    "sr": yt_predict_sr,
    "fotda_s": yt_predict_1,
    "fotda_l1l2": yt_predict_2,
    "botda_s": yt_predict_3,
    "botda_l1l2": yt_predict_4,
    # RPA and EU predictions were collected but previously never scored
    "rpa": yt_predict_rpa,
    "eu": yt_predict_eu,
}
acc = {name: accuracy_score(Labels_te, pred) for name, pred in predictions.items()}
print(acc)

# Mean computing time per method. `times` only holds the last trial's list (and has
# no .mean); the per-trial matrix is `times_se`. atleast_2d also covers a single
# trial, where times_se may be a flat list of 7 values.
time_columns = ["sr", "rpa", "eu", "fotda_s", "fotda_l1l2", "botda_s", "botda_l1l2"]
mean_time = np.atleast_2d(np.asarray(times_se, dtype=float)).mean(axis=0)
time = {name: round(float(t), 3) for name, t in zip(time_columns, mean_time)}
print(time)