# Iterative Kalman filter using the EM algorithm

This code can be used in both situations (with or without a PCA),
depending on the value given to the variable `opt`:
- When opt=0, no PCA is used
- Any non-zero integer leads to the application of a PCA on the data

# Imports of libraries

In [None]:
%load_ext autoreload
%autoreload 2

from Functions import *

import netCDF4 as nc
import numpy as np
import matplotlib.pyplot as plt
from numpy.linalg import inv
from numpy.linalg import pinv
from numpy.linalg import det
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import pandas as pd
from scipy import stats
from scipy.stats import norm,kurtosis,skew
import imageio as iio
from IPython import display
from sklearn.decomposition import PCA


# Read data

In [None]:
# Path of the NetCDF input file (absolute, machine-specific)
fn='C:/M2_CSM/Stage/Stage/Codes/Stage_Algerie_2016.nc'
ds=nc.Dataset(fn)
nc_vars=ds.variables

# Coordinate variables, one per dimension
time=nc_vars['time']
lat=nc_vars['latitude']
long=nc_vars['longitude']

# AIS-derived water velocities and their estimated error variances.
# NOTE(review): velocities are in m.s^-1; a variance would be in m^2.s^-2 — confirm the file's units.
u_ais=nc_vars['u_ais']          # eastward water velocity from AIS
v_ais=nc_vars['v_ais']          # northward water velocity from AIS
var_u_ais=nc_vars['var_u_ais']  # error variance of the eastward velocity from AIS
var_v_ais=nc_vars['var_v_ais']  # error variance of the northward velocity from AIS

# Satellite (DUACS) velocity components and their error variances
u_gos=nc_vars['u_gos']          # eastward component from DUACS (m.s^-1)
v_gos=nc_vars['v_gos']          # northward component from DUACS (m.s^-1)
var_u_gos=nc_vars['var_u_gos']  # error variance of the eastward component from DUACS
var_v_gos=nc_vars['var_v_gos']  # error variance of the northward component from DUACS

In [None]:
# Number of time steps, latitudes and longitudes of the grid
T,La,Lo=time.size,lat.size,long.size
# Size of the full state vector: u and v at every grid point
r=2*La*Lo

Data needed to display the observations with Cartopy

In [None]:
# 2-D longitude/latitude grids and the map framing shared by the cartopy plots
LONG,LAT=np.meshgrid(long[:].data,lat[:].data)

extent=[LONG.min(),LONG.max(),LAT.min(),LAT.max()]
central_lat,central_long=LAT.mean(),LONG.mean()
# Coordinate reference system in which the lon/lat data are expressed
pcar=ccrs.PlateCarree()

# Modification of data 

Construction of the vectors u, v and their variances for the AIS observations and for the satellite observations

In [None]:
# AIS observations: take plain ndarrays out of the netCDF variables, then flatten them.
# flatten() returns copies, so the in-place NaN masking applied later to u, v, var_u
# and var_v leaves the 3-D arrays u_ais, v_ais, var_u_ais and var_v_ais untouched.
u_ais,v_ais=u_ais[:].data,v_ais[:].data
var_u_ais,var_v_ais=var_u_ais[:].data,var_v_ais[:].data

u,v=u_ais.flatten(),v_ais.flatten()
var_u,var_v=var_u_ais.flatten(),var_v_ais.flatten()

# Satellite (DUACS) observations and their variances, kept as 3-D arrays
# (they are indexed as [t,:,:] further down)
u_sat,v_sat=u_gos[:].data,v_gos[:].data
var_u_sat,var_v_sat=var_u_gos[:].data,var_v_gos[:].data

Check that u and var_u have their NaN values at the same indices, and likewise for v and var_v. Wherever only one of the two arrays is NaN, the other one is set to NaN as well.

In [None]:
# Give each velocity component and its variance the same NaN pattern:
# an index that is NaN in either array becomes NaN in both (modified in place).
for values,variances in ((u,var_u),(v,var_v)):
    missing=np.isnan(values)|np.isnan(variances)
    values[missing]=np.nan
    variances[missing]=np.nan

Modifications of u, v, var_u and var_v:
- u: we keep only values in [-1.5, 1.5]
- v: we keep only values in [-1, 1]
- var_u and var_v: we keep only values below an upper outlier fence, $\min(\max, Q_3 + 1.5\,(Q_3 - Q_1))$ (Tukey's rule), computed on the finite variances.

Whenever a point fails one of these tests, u, v, var_u and var_v are all set to NaN at that index.


In [None]:
# Physical range filter: keep u in [-1.5, 1.5] and v in [-1, 1]. A point failing
# either test is discarded (set to NaN) in all four arrays. NaN entries compare
# False, so they never trigger the filter.
out_of_range=(u<-1.5)|(u>1.5)|(v<-1)|(v>1)
for arr in (u,v,var_u,var_v):
    arr[out_of_range]=np.nan

# Only finite variances (NaN and +/-inf removed) enter the statistics below
var_utp=var_u[np.isfinite(var_u)]
var_vtp=var_v[np.isfinite(var_v)]

# First and third quartiles of the variances
Q1u,Q3u=np.quantile(var_utp,0.25),np.quantile(var_utp,0.75)
Q1v,Q3v=np.quantile(var_vtp,0.25),np.quantile(var_vtp,0.75)

# Upper outlier fence (Tukey's rule, Q3 + 1.5*IQR), capped at the largest finite
# variance. Despite the D9 name, this is not the 9th decile.
D9u=min(np.max(var_utp),Q3u+1.5*(Q3u-Q1u))
D9v=min(np.max(var_vtp),Q3v+1.5*(Q3v-Q1v))

# Discard every point whose u- or v-variance exceeds its fence (inf included)
too_uncertain=(var_u>D9u)|(var_v>D9v)
for arr in (u,v,var_u,var_v):
    arr[too_uncertain]=np.nan

Creation of the observation vector and its associated variance

In [None]:
# Back to (time, lat, lon) grids
u=u.reshape(T,La,Lo)
v=v.reshape(T,La,Lo)
var_u=var_u.reshape(T,La,Lo)
var_v=var_v.reshape(T,La,Lo)

# Full-grid observation vector y and its variance var_y: for each time step,
# all u values (row-major over lat, lon) followed by all v values.
n_grid=La*Lo
y=np.zeros((T,r))
y[:]=np.concatenate((u.reshape(T,n_grid),v.reshape(T,n_grid)),axis=1)
var_y=np.zeros((T,r))
var_y[:]=np.concatenate((var_u.reshape(T,n_grid),var_v.reshape(T,n_grid)),axis=1)

# Observations

AIS observations

In [None]:
# Animated map of the AIS currents on the full grid.
# Set `regenerate` to True to redraw one PNG per time step and rebuild the GIF;
# otherwise only the previously saved GIF is displayed.
regenerate=False
ais_plot_dir='C:/M2_CSM/Stage/Stage/Codes/Kalman/codes/plot_observation/'

if regenerate:
    for t in range(T):
        fig=plt.figure()
        ax=plt.axes(projection=ccrs.Orthographic(central_long,central_lat))
        ax.set_extent(extent)
        gl=ax.gridlines(draw_labels=True)
        gl.top_labels=False
        gl.right_labels=False
        ax.coastlines()
        ax.add_feature(cfeature.LAND,edgecolor='black')
        ax.add_feature(cfeature.OCEAN)
        ax.quiver(LONG.flatten(),LAT.flatten(),u[t,:,:].flatten(),v[t,:,:].flatten(),transform=pcar)
        plt.title('Image n° %i' %t)
        plt.savefig(ais_plot_dir+'observation'+str(t)+'.png')
        # Close explicitly: one figure per time step would otherwise accumulate
        plt.close(fig)

    frames=np.stack([iio.imread(ais_plot_dir+'observation'+str(t)+'.png') for t in range(T)],axis=0)
    iio.mimwrite(ais_plot_dir+'observations.gif',frames,duration=0.6)

display.Image(ais_plot_dir+'observations.gif')


Satellite observations

In [None]:
# Animated map of the satellite (DUACS) currents on the full grid.
# Set `regenerate` to True to redraw one PNG per time step and rebuild the GIF;
# otherwise only the previously saved GIF is displayed.
regenerate=False
sat_plot_dir='C:/M2_CSM/Stage/Stage/Codes/Kalman/codes/obssat/'

if regenerate:
    for t in range(T):
        fig=plt.figure()
        ax=plt.axes(projection=ccrs.Orthographic(central_long,central_lat))
        ax.set_extent(extent)
        gl=ax.gridlines(draw_labels=True)
        gl.top_labels=False
        gl.right_labels=False
        ax.coastlines()
        ax.add_feature(cfeature.LAND,edgecolor='black')
        ax.add_feature(cfeature.OCEAN)
        ax.quiver(LONG,LAT,u_sat[t,:,:],v_sat[t,:,:],transform=pcar)
        plt.title('Image n° %i' %t)
        plt.savefig(sat_plot_dir+'obssat'+str(t)+'.png')
        plt.close(fig)

    frames=np.stack([iio.imread(sat_plot_dir+'obssat'+str(t)+'.png') for t in range(T)],axis=0)
    iio.mimwrite(sat_plot_dir+'obssat.gif',frames,duration=0.6)

display.Image(sat_plot_dir+'obssat.gif')


# Sample of data

We keep only part of the data in order to apply the Kalman filter with the EM algorithm.
The variable opt is set here:
- If opt=0, we use only a subset of the data small enough for its matrices to be allocated (no PCA is used)
- If opt!=0, a PCA is applied to the data
The lists of grid rows (`ligne`) and columns (`colonne`) that are removed differ according to the value of opt.


In [None]:
# Option: 0 -> no PCA, any other integer -> a PCA is applied to the data
opt=1

# Indices of the grid rows (latitudes, `ligne`) and columns (longitudes,
# `colonne`) to REMOVE. Both options remove the first 5 rows and columns and
# everything from a cut-off index onwards, so rows 5..row_cut-1 and columns
# 5..col_cut-1 are kept; the cut-offs depend on opt.
if opt==0:
    row_cut,col_cut=16,40
else:
    row_cut,col_cut=14,24
ligne=list(range(0,5))+list(range(row_cut,u.shape[1]))
colonne=list(range(0,5))+list(range(col_cut,u.shape[2]))

# Crop the AIS fields: drop the listed rows (axis 1), then the listed columns (axis 2)
uc,vc,var_uc,var_vc=(np.delete(np.delete(a,ligne,1),colonne,2) for a in (u,v,var_u,var_v))

# Observation vector yc and its variance var_yc on the sub-domain: for each time
# step, all u values (row-major) followed by all v values
n_pix=uc.shape[1]*uc.shape[2]
nr=2*n_pix
yc=np.zeros((T,nr))
yc[:]=np.concatenate((uc.reshape(T,n_pix),vc.reshape(T,n_pix)),axis=1)
var_yc=np.zeros((T,nr))
var_yc[:]=np.concatenate((var_uc.reshape(T,n_pix),var_vc.reshape(T,n_pix)),axis=1)

We do the same for the satellite data,
using the same `ligne` and `colonne` lists.

In [None]:
# Crop the satellite fields with the same row/column lists as the AIS data
u_satc,v_satc,var_u_satc,var_v_satc=(np.delete(np.delete(a,ligne,1),colonne,2)
                                     for a in (u_sat,v_sat,var_u_sat,var_v_sat))

# Satellite observation vector and its variance, laid out like yc: u block then v block.
# Only the first T time steps are used, and each row must fit in nr entries.
y_sat=np.zeros((T,nr))
y_sat[:]=np.concatenate([a[:T].reshape(T,a.shape[1]*a.shape[2]) for a in (u_satc,v_satc)],axis=1)
var_y_sat=np.zeros((T,nr))
var_y_sat[:]=np.concatenate([a[:T].reshape(T,a.shape[1]*a.shape[2]) for a in (var_u_satc,var_v_satc)],axis=1)

# Pseudo satellite observations

Here we use the satellite observations to check whether the algorithm works.
We take the satellite observations and put a NaN at every index where the AIS observations have a NaN.

In [None]:
# Pseudo-satellite observations: start from the satellite observations and set
# to NaN every entry that lisind flags on the AIS data. Values and variances
# are handled separately.
# NOTE(review): lisind comes from Functions; it is assumed to return the indices
# of the NaN entries of row t of its first argument — confirm.
y_art=y_sat.copy()
var_y_art=var_y_sat.copy()

for t in range(T):
    iais=lisind(yc,t,nr)
    ivarais=lisind(var_yc,t,nr)
    # Walk each index list on its own, so the two lists may differ in length
    # without raising IndexError or silently skipping variance entries.
    for idx in iais:
        y_art[t,idx]=np.nan
    for idx in ivarais:
        var_y_art[t,idx]=np.nan

In [None]:
"""
Utiliser pour connaitre le pourcentage de données d'apprentissage 
cn=0
for t in range(T):
    for i in range(nr):
        if np.isnan(y_art[t,i])==True:
            cn=cn+1
pr=T*nr-cn
pr=pr/(T*nr)*100

print(pr)

"""

# Validation data

In this part we store, in a new array, the satellite values at every index where we put a NaN.
This array is used at the end of the algorithm to validate the results.

In [None]:
# Validation data: the satellite values at the indices that lisind flags on yc
# (the same indices blanked out in y_art), and NaN everywhere else.
y_val=np.full((T,nr),np.nan)
for t in range(T):
    iais=lisind(yc,t,nr)
    for idx in iais:
        y_val[t,idx]=y_sat[t,idx]

# Observations in the delimited area

Observation AIS

In [None]:
# Coordinate grids restricted to the same sub-domain as uc/vc
LATdi=np.delete(np.delete(LAT,ligne,0),colonne,1)
LONGdi=np.delete(np.delete(LONG,ligne,0),colonne,1)

extentdi=[np.min(LONGdi),np.max(LONGdi),np.min(LATdi),np.max(LATdi)]
central_latdi=np.mean(LATdi)
central_longdi=np.mean(LONGdi)

# Size of the restricted grid: number of latitudes, number of longitudes
Ladi,Lodi=LATdi.shape

# Animated map of the AIS currents in the restricted zone.
# Set `regenerate` to True to redraw one PNG per time step and rebuild the GIF;
# otherwise only the previously saved GIF is displayed.
regenerate=False
resobs_dir='C:/M2_CSM/Stage/Stage/Codes/Kalman/codes/restricted_obs/'

if regenerate:
    for t in range(T):
        fig=plt.figure()
        ax2=plt.axes(projection=ccrs.Orthographic(central_longdi,central_latdi))
        ax2.set_extent(extentdi)
        gl=ax2.gridlines(draw_labels=True)
        gl.top_labels=False
        gl.right_labels=False
        ax2.coastlines()
        ax2.add_feature(cfeature.LAND,edgecolor='black')
        ax2.add_feature(cfeature.OCEAN)
        ax2.quiver(LONGdi.flatten(),LATdi.flatten(),uc[t,:,:].flatten(),vc[t,:,:].flatten(),transform=pcar)
        plt.title('Image n° %i' %t)
        plt.savefig(resobs_dir+'resobs'+str(t)+'.png')
        plt.close(fig)

    frames=np.stack([iio.imread(resobs_dir+'resobs'+str(t)+'.png') for t in range(T)],axis=0)
    iio.mimwrite(resobs_dir+'resobs.gif',frames,duration=0.6)

display.Image(resobs_dir+'resobs.gif')


Satellite observations

In [None]:
"""
for t in range(T):
    plt.figure()
    ax2=plt.axes(projection=ccrs.Orthographic(central_longdi,central_latdi))
    ax2.set_extent(extentdi)
    gl=ax2.gridlines(draw_labels=True)
    gl.top_labels=False
    gl.right_labels=False
    ax2.coastlines()
    ax2.add_feature(cfeature.LAND,edgecolor='black')
    ax2.add_feature(cfeature.OCEAN)    
    ax2.quiver(LONGdi.flatten(),LATdi.flatten(),u_satc[t,:,:].flatten(),v_satc[t,:,:].flatten(),transform=pcar)
    plt.title('Image n° %i' %t)
    plt.savefig('C:/M2_CSM/Stage/Stage/Codes/Kalman/codes/restricted_obssat/resobssat'+str(t)+'.png')
    plt.close()

# Creation of the gif
frames=np.stack([iio.imread('C:/M2_CSM/Stage/Stage/Codes/Kalman/codes/restricted_obssat/resobssat'+str(t)+'.png')for t in range(time[:].data.size)],axis=0)
iio.mimwrite('C:/M2_CSM/Stage/Stage/Codes/Kalman/codes/restricted_obssat/resobssat.gif',frames,duration=0.6)   
"""
display.Image('C:/M2_CSM/Stage/Stage/Codes/Kalman/codes/restricted_obssat/resobssat.gif')


# Initialisation 

In this part we initialize M, Q, R, H, x0 and P0 for the first Kalman filter.
Not all arguments of the function below are used; which ones are used depends on the value of opt.

In [None]:
# Number of PCA components; IC receives it whatever the value of opt
n=50
# Initial M, Q, R, x0 and P0 for the first Kalman filter (IC comes from Functions)
M,Q,R,x0,P0=IC(opt,nr,n,LONGdi,LATdi,y_sat,var_y_art)

if opt!=0:
    # Observation operator made of the n leading EOFs of the satellite data,
    # presumably so that y is approximated by H @ x with x the PCA coordinates
    pca=PCA(n_components=n)
    EOF=pca.fit(y_sat)
    H=EOF.components_.T
else:
    # No PCA: the state is observed directly
    H=np.eye(nr)

In [None]:
# Debug check of the PCA truncation: compare the pseudo-satellite field with its
# reconstruction from the n retained components.
# TODO: remove this cell once the PCA truncation has been validated.
# Disabled by default; it only makes sense when a PCA is used (opt != 0).
check_pca=False
acp_plot_dir='C:/M2_CSM/Stage/Stage/Codes/Kalman/codes/plotacp/'

if check_pca and opt!=0:
    # Reuse the PCA fitted in the initialisation cell instead of refitting it
    y_invacp=pca.inverse_transform(pca.transform(y_sat))

    # Split each vector into its u half and v half, back on the restricted grid
    pix=int(nr/2)
    uacp=y_invacp[:,0:pix].reshape(T,Ladi,Lodi)
    vacp=y_invacp[:,pix:].reshape(T,Ladi,Lodi)
    usat=y_sat[:,0:pix].reshape(T,Ladi,Lodi)
    vsat=y_sat[:,pix:].reshape(T,Ladi,Lodi)

    for t in range(T):
        fig=plt.figure()
        ax2=plt.axes(projection=ccrs.Orthographic(central_longdi,central_latdi))
        ax2.set_extent(extentdi)
        gl=ax2.gridlines(draw_labels=True)
        gl.top_labels=False
        gl.right_labels=False
        ax2.coastlines()
        ax2.add_feature(cfeature.LAND,edgecolor='black')
        ax2.add_feature(cfeature.OCEAN)
        ax2.quiver(LONGdi,LATdi,usat[t,:,:],vsat[t,:,:],color='red',label='pseudo satelitte obs',transform=pcar)
        ax2.quiver(LONGdi,LATdi,uacp[t,:,:],vacp[t,:,:],color='blue',label='inverse_transformacp',transform=pcar)
        plt.legend(bbox_to_anchor=(0,-0.4,1,0.2),loc="lower center")
        plt.title('Image n° %i' %t)
        plt.savefig(acp_plot_dir+'obssatacp'+str(t)+'.png')
        plt.close(fig)

    frames=np.stack([iio.imread(acp_plot_dir+'obssatacp'+str(t)+'.png') for t in range(T)],axis=0)
    iio.mimwrite(acp_plot_dir+'obssatacp.gif',frames,duration=0.6)

display.Image(acp_plot_dir+'obssatacp.gif')



# Apply the iterative Kalman filter using the EM algorithm

In [None]:
# Number of EM iterations
N=10
# Kalman_EM comes from Functions. Judging by the cells below, slk holds the summed
# log-likelihood of each iteration, and xs/Ps the estimated states and their
# covariances; M, Q and R are updated and returned.
(slk,M,Q,R,xs,Ps)=Kalman_EM(y_art,var_y_art,x0,P0,M,Q,R,H,N,opt)

In [None]:
# Iteration index with the largest summed log-likelihood
best_iter=np.argmax(slk)
print('index max of slk is:',best_iter)

# Results

Loglikelihood

In [None]:
# Summed log-likelihood against the EM iteration number (0 .. N-1)
iterations=np.arange(N,dtype=float)

plt.figure()
plt.plot(iterations,slk)
plt.title("sum of loglik")
plt.xlabel('iteration number')
plt.ylabel("sum of loglikelihood")
plt.savefig('C:/M2_CSM/Stage/Stage/Codes/Kalman/codes/Results/LoglikEMsat.png')


In [None]:
"""
plt.figure()
plt.plot(np.linspace(0,N-1,N),sRMSE)
plt.title("sum of RMSE")
plt.xlabel('iteration number')
plt.ylabel("sum of RMSE")
"""

Matrix M,Q and R

In [None]:
# Symmetric colour range for M, so that zero sits at the centre of the diverging map
lim=np.abs(M).max()

# (matrix, title, pcolor options, output file); only the first time slice of R is shown
panels=[
    (M,'M',dict(cmap='RdBu_r',vmin=-lim,vmax=lim),'C:/M2_CSM/Stage/Stage/Codes/Kalman/codes/Results/MEMpseudosat.png'),
    (Q,'Q',dict(cmap='gist_yarg'),'C:/M2_CSM/Stage/Stage/Codes/Kalman/codes/Results/QEMpseudosat.png'),
    (R[0,:,:],'R',dict(cmap='gist_yarg'),'C:/M2_CSM/Stage/Stage/Codes/Kalman/codes/Results/REMpseudosat.png'),
]
for mat,title,style,path in panels:
    plt.figure()
    plt.pcolor(mat,**style)
    plt.colorbar()
    plt.title(title)
    plt.savefig(path)

### Passage into the canonical base

This part is used only if opt!=0, i.e. when a PCA is applied to the data.

In [None]:
# With a PCA, map the estimated states and covariances from the n PCA coordinates
# back to the nr-dimensional (u, v) space: x -> H x and P -> H P H^T,
# where H = pca.components_.T (set in the initialisation cell).
if opt!=0:
    nxs=np.zeros((T,nr))
    nPs=np.zeros((T,nr,nr))
    for t in range(T):
        # H @ (column vector xs[t]), transposed back to a row
        nxs[t,:]=np.transpose(H@np.transpose(np.array([xs[t,:]])))
        nPs[t,:]=H@Ps[t,:,:]@np.transpose(H)
    # Comparison only: scikit-learn's inverse_transform presumably also adds back
    # pca.mean_ (whiten=False), which H@x does not, so the printed difference
    # should be about -pca.mean_.
    # NOTE(review): confirm which convention matches Kalman_EM (i.e. whether the
    # observations are centred); the un-centred H@x version is the one kept below.
    xs=pca.inverse_transform(xs)
    print(nxs[0,:]-xs[0,:])
    xs=nxs
    Ps=nPs

In [None]:
if opt!=0:
    # Range of the back-projected state, then shape and range of its covariances
    nxs_lo,nxs_hi=np.min(nxs),np.max(nxs)
    print('min=',nxs_lo,'max=',nxs_hi)
    print(np.shape(Ps),np.min(Ps),np.max(Ps))

In [None]:
# Range of the estimated states and of their covariances
for prefix,arr in (('min =',xs),('Ps min =',Ps)):
    print(prefix,np.min(arr),'max =',np.max(arr))

Graphics of xs 

In [None]:
# Positions in the state vector: the first half is u and the second half is v,
# so tquart is the middle of the v block
n_state=len(xs[0,:])
mid=int(n_state/2)
quart=int(n_state/4)
tquart=int(3*n_state/4)

days=np.linspace(0,yc.shape[0]-1,yc.shape[0])
# Half-width of the 95% band (1.96 standard deviations)
band=1.96*np.sqrt(Ps[:,tquart,tquart])

plt.figure()
plt.plot(days,xs[:,tquart],'r',linewidth=2,label='Estimated state $xs$')
plt.plot(days,y_art[:,tquart],'.k',label=' pseudo-sat Observations')
plt.fill_between(days,xs[:,tquart]-band,xs[:,tquart]+band,alpha=0.25,color='red')
plt.legend(loc='best')
plt.xlabel('time [day]')
plt.ylabel('velocity [$m.s^-1$]')
plt.title('xs (corresponding to v at the place 3/4)')
plt.savefig('C:/M2_CSM/Stage/Stage/Codes/Kalman/codes/Results/xspseudo_sat.png')

In [None]:
# Results for the central u and the central v point over time.
# Split the states, the satellite observations, the pseudo-observations and the
# validation data into their u half [:mid] and v half [mid:], as (T, mid) copies.
us,vs=xs[:,0:mid].astype(np.float64),xs[:,mid:].astype(np.float64)
usat,vsat=y_sat[:,0:mid].astype(np.float64),y_sat[:,mid:].astype(np.float64)
uart,vart=y_art[:,0:mid].astype(np.float64),y_art[:,mid:].astype(np.float64)
uval,vval=y_val[:,0:mid].astype(np.float64),y_val[:,mid:].astype(np.float64)

# Middle index within each half
middir=int(len(us[0,:])/2)
days=np.linspace(0,yc.shape[0]-1,yc.shape[0])

# (component, its pseudo-observations, offset of that half in the state, title, file)
for comp,obs,offset,title,fname in (
        (us,uart,0,'Composante central des us','uspseudo_sat.png'),
        (vs,vart,mid,'Composante centrale des vs','vspseudo_sat.png')):
    k=offset+middir
    plt.figure()
    plt.plot(days,comp[:,middir],'r',linewidth=2,label='reconstructed currents')
    # 95% band from the diagonal of the covariance at the matching state index
    plt.fill_between(days,comp[:,middir]-1.96*np.sqrt(Ps[:,k,k]),comp[:,middir]+1.96*np.sqrt(Ps[:,k,k]),alpha=0.25,color='red')
    plt.plot(days,obs[:,middir],color='blue',marker='*',label='pseudo_satelitte observations')
    plt.title(title)
    plt.legend(loc='best')
    plt.xlabel('time [day]')
    plt.ylabel('velocity [$m.s^-1$]')
    plt.savefig('C:/M2_CSM/Stage/Stage/Codes/Kalman/codes/Results/'+fname)

# Map of the reconstructed currents, with the validation data added

In [None]:
# Posterior variances of the reconstructed u and v (diagonal of each Ps[t]) on the
# restricted grid; only computed in the no-PCA case
if opt==0:
    diag_Ps=np.diagonal(Ps,axis1=1,axis2=2)
    var_us=diag_Ps[:,0:mid].astype(np.float64).reshape(T,Ladi,Lodi)
    var_vs=diag_Ps[:,mid:].astype(np.float64).reshape(T,Ladi,Lodi)

# Reshape the (T, Ladi*Lodi) arrays into (time, lat, long) grids
us=np.reshape(us,(T,Ladi,Lodi))
vs=np.reshape(vs,(T,Ladi,Lodi))
uart=np.reshape(uart,(T,Ladi,Lodi))
vart=np.reshape(vart,(T,Ladi,Lodi))
uval=np.reshape(uval,(T,Ladi,Lodi))
vval=np.reshape(vval,(T,Ladi,Lodi))

# Overlay of the Kalman results (streamlines) on the pseudo-satellite and
# validation observations (arrows), one frame per time step
map_dir='C:/M2_CSM/Stage/Stage/Codes/Kalman/codes/supresEM/'
for t in range(T):
    speed=np.sqrt(us[t,:,:]**2 +vs[t,:,:]**2)
    # Streamline width proportional to speed. Normalise by the largest FINITE
    # speed (a single NaN would otherwise turn every width into NaN), and fall
    # back to unit width for an all-zero or all-NaN field instead of dividing by 0.
    finite_speed=speed[np.isfinite(speed)]
    peak=finite_speed.max() if finite_speed.size else 0.0
    lw=5*speed/peak if peak>0 else np.ones_like(speed)

    fig=plt.figure()
    ax=plt.axes(projection=ccrs.Orthographic(central_longdi,central_latdi))
    ax.set_extent(extentdi)
    gl=ax.gridlines(draw_labels=True)
    gl.top_labels=False
    gl.right_labels=False
    ax.coastlines()
    ax.add_feature(cfeature.LAND,edgecolor='black')
    ax.add_feature(cfeature.OCEAN)
    ax.quiver(LONGdi,LATdi,uart[t,:,:],vart[t,:,:],color='red',label='pseudo satelitte obs',transform=pcar)
    ax.quiver(LONGdi,LATdi,uval[t,:,:],vval[t,:,:],color='blue',label='validation obs',transform=pcar)
    plt.legend(bbox_to_anchor=(0,-0.4,1,0.2),loc="lower center")
    ax.streamplot(LONGdi,LATdi,us[t,:,:],vs[t,:,:],density=0.6,color='k',linewidth=lw,transform=pcar)
    plt.title('Image n° %i' %t)
    plt.savefig(map_dir+'kalsatEM'+str(t)+'.png')
    plt.close(fig)

frames=np.stack([iio.imread(map_dir+'kalsatEM'+str(t)+'.png') for t in range(T)],axis=0)
iio.mimwrite(map_dir+'kalsat.gif',frames,duration=0.9)

display.Image(map_dir+'kalsat.gif')

# RMSE

In [None]:
# Per-entry RMSE over time between the satellite observations and the estimated
# states; the first half of the entries is zonal (u), the second half meridional (v)
pix=int((nr/2))
RMSE=np.array([np.sqrt(np.sum((y_sat[:,i]-xs[:,i])**2)/T) for i in range(nr)],dtype=np.float64)

RMSEzonal=RMSE[0:pix]
RMSEmeri=RMSE[pix:nr]

# Map of the zonal RMSE on the restricted grid
plt.figure()
ax=plt.axes(projection=ccrs.Orthographic(central_longdi,central_latdi))
ax.set_extent(extentdi)
gl=ax.gridlines(draw_labels=True)
gl.top_labels=False
gl.right_labels=False
ax.coastlines()
ax.add_feature(cfeature.LAND,edgecolor='black')
ax.add_feature(cfeature.OCEAN)
cbzo=ax.scatter(LONGdi.flatten(),LATdi.flatten(),s=50,c=RMSEzonal,cmap='jet',transform=pcar)
plt.colorbar(cbzo,cmap='jet',orientation='vertical',ticklocation='auto')
plt.title('RMSE of the zonal component of oceanic currents')
#plt.savefig('C:/M2_CSM/Stage/Stage/Codes/Kalman/codes/Results/RMSEzonsat.png')

In [None]:
# Map of the RMSE of the meridional (northward, v) component on the restricted grid
plt.figure()
ax=plt.axes(projection=ccrs.Orthographic(central_longdi,central_latdi))
ax.set_extent(extentdi)
gl=ax.gridlines(draw_labels=True)
gl.top_labels=False
gl.right_labels=False
ax.coastlines()
ax.add_feature(cfeature.LAND,edgecolor='black')
ax.add_feature(cfeature.OCEAN)
cbme=ax.scatter(LONGdi.flatten(),LATdi.flatten(),s=50,c=RMSEmeri,cmap='jet',transform=pcar)
plt.colorbar(cbme,cmap='jet',orientation='vertical',ticklocation='auto')
# The second half of the state is v (northward velocity), i.e. the meridional
# component, not a "southern" one
plt.title('RMSE of the meridional component of oceanic currents')
#plt.savefig('C:/M2_CSM/Stage/Stage/Codes/Kalman/codes/Results/RMSEsousat.png')

In [None]:
# Map of one row of R (first time slice): the entries linking observation l to
# every zonal (u) observation of the restricted grid.
# NOTE(review): these are raw entries of R (presumably the observation-error
# covariance), not normalised correlations, despite the title.
l=int(pix/4) # index of the chosen row
Cor=R[0,l,0:pix]

plt.figure()
ax=plt.axes(projection=ccrs.Orthographic(central_longdi,central_latdi))
ax.set_extent(extentdi)
gl=ax.gridlines(draw_labels=True)
gl.top_labels=False
gl.right_labels=False
ax.coastlines()
ax.add_feature(cfeature.LAND,edgecolor='black')
ax.add_feature(cfeature.OCEAN)
cor=ax.scatter(LONGdi.flatten(),LATdi.flatten(),s=50,c=Cor,cmap='jet',transform=pcar)
plt.colorbar(cor,cmap='jet',orientation='vertical',ticklocation='auto')
plt.title('Correlation of a line of the matrix R')
plt.savefig('C:/M2_CSM/Stage/Stage/Codes/Kalman/codes/Results/Rcorrelation.png')