#### Import des données pour les continents

In [192]:
#On utilise un echantilloneur de Gibbs avec balayage aléatoire
import pandas as pd 
import numpy as np

T=pd.read_csv("/Users/y.boukhateb/Desktop/Statapp Optimal Transport/Stat-App/DATA/dataset1.csv")
T.drop(T.columns[0], axis=1, inplace=True)
N=len(T.columns)-1
T.head()

Unnamed: 0,/,North America,Central America,South America,North Africa,Sub-Saharan Africa,Northern Europe,Western Europe,Southern Europe,Eastern Europe,Central Asia,Western Asia,South Asia,East Asia,South-East Asia,Oceania
0,North America,96102,208668,68240,4253,58827,248379,279267,551959,162959,85259,131288,158,39145,41758,52303
1,Central America,3233173,215008,36975,390,2414,42572,81356,190931,4962,3145,6236,14,1754,510,7541
2,South America,394674,26456,600759,417,19431,68645,157063,1200694,18644,5090,32516,351,110426,2035,28490
3,North Africa,62193,93,692,57855,12512,22807,404765,442162,5437,1382,367718,1,312,340,6858
4,Sub-Saharan Africa,478694,793,10738,44230,3717881,483736,430954,312329,6442,2724,305291,243,2918,576,149130


In [193]:
T_matrix = T.iloc[:, 1:].values

In [224]:
def histogramme(x):
    # Normaliser la liste pour que la somme soit égale à 1
    somme = np.sum(x)
    histogramme = x / somme
    return histogramme

### Création du code 

In [225]:
import ot
sigma=0.04

def proba_rejet(u,v,W,x,y,Z,T):
    u,v,W=histogramme(u),histogramme(v),histogramme(W)
    x,y,Z=histogramme(x),histogramme(y),histogramme(Z)
    return 1/(2*sigma**2)*(np.linalg.norm(T-ot.emd(u,v,W))**2-np.linalg.norm(T-ot.emd(x,y,Z))**2)

def in_hypercube(u,v,W):
    
    if any(x < 0 or x > 1 for x in u):
        return False
    
    if any(x < 0 or x > 1 for x in v):
        return False
    
    if any(any(x < 0 or x > 1 for x in row) for row in W):
        return False
    return True

In [238]:
# Graine aléatoire
np.random.seed(123)

# Taille de l'échantillon
n = 10**5
def MH_gibbs(T):
    N=len(T)
    # Initialisation
    U, V, C = np.zeros((n, N)), np.zeros((n, N)), np.zeros((n, N, N))
    U[0] = np.random.uniform(size=N)
    V[0] = np.random.uniform(size=N)
    C[0, :, :] = np.random.uniform(size=(N, N))

    # Vecteurs aléatoires
    e_u = np.random.normal(0, 0.02, (N, n))
    e_v = np.random.normal(0, 0.02, (N, n))
    e_C = np.random.normal(0, 0.04, (N, N, n))
    u_u = np.log(np.random.uniform(size=n))
    u_v = np.log(np.random.uniform(size=n))
    u_c = np.log(np.random.uniform(size=n))

    for k in range(n-1):
        # Pas pour U
        x, y, Z = U[k] + e_u[:, k], V[k], C[k]
        if not in_hypercube(x, y, Z):
            U[k+1], V[k+1], C[k+1] = U[k], V[k], C[k]
        else:
            if u_u[k] < proba_rejet(U[k], V[k], C[k], x, y, Z,T):
                U[k+1], V[k+1], C[k+1] = x, y, Z
            else:
                U[k+1], V[k+1], C[k+1] = U[k], V[k], C[k]
        UK=U[k+1]
        # Pas pour V
        x, y, Z = UK, V[k] + e_v[:, k], C[k]
        if not in_hypercube(x, y, Z):
            U[k+1], V[k+1], C[k+1] = UK, V[k], C[k]
        else:
            if u_v[k] < proba_rejet(UK, V[k], C[k], x, y, Z,T):
                U[k+1], V[k+1], C[k+1] = x, y, Z
            else:
                U[k+1], V[k+1], C[k+1] = UK, V[k], C[k]
        VK=V[k+1]
        # Pas pour C
        x, y, Z = UK, VK, C[k] + e_C[:, :, k]
        if not in_hypercube(x, y, Z):
            U[k+1], V[k+1], C[k+1] = UK, VK, C[k]
        else:
            if u_c[k] < proba_rejet(U[k], V[k], C[k], x, y, Z,T):
                U[k+1], V[k+1], C[k+1] = x, y, Z
            else:
                U[k+1], V[k+1], C[k+1] = UK, VK, C[k]
    return U,V,C

### Données simulées

In [239]:
u=np.array([0.1,0.4,0.5])
v=np.array([0.6,0.35,0.05])
C=[[10,3,5],[6,10,2],[8,8,10]]
T=ot.emd(u,v,C)
U,V,W=MH_gibbs(T)

In [240]:
a=np.array([0.2,0,0.8])
b=np.array([0.3,0.5,0.2])
D=[[8,3,5],[6,1,3],[1,0,3]]
T=ot.emd(a,b,D)
P,O,I=MH_gibbs(T)

### Comparaison et affichage des résultats

In [241]:
print(histogramme(W[-1]))
print(histogramme(C))

[[0.12780777 0.01355973 0.0694587 ]
 [0.0437215  0.12902292 0.15214773]
 [0.20547175 0.17401019 0.08479973]]
[[0.16129032 0.0483871  0.08064516]
 [0.09677419 0.16129032 0.03225806]
 [0.12903226 0.12903226 0.16129032]]


In [242]:
print(histogramme(I[-1]))
print(histogramme(D))

[[0.15706884 0.14517821 0.04542256]
 [0.1519962  0.15531388 0.1212589 ]
 [0.02864007 0.14412996 0.05099139]]
[[0.26666667 0.1        0.16666667]
 [0.2        0.03333333 0.1       ]
 [0.03333333 0.         0.1       ]]
