In [24]:
import torch
from typing import TypeVar, Dict


Tensor = TypeVar('torch.tensor')

#matplotlib.use('Agg')

#import click
from argparse import Namespace
import ast
import os

import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn.functional as F
from typing import TypeVar, Tuple
import copy
from numpy import random
import numpy as np
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [25]:
no_batches=torch.tensor(2,device=device)
batch_size=1
par_runs=20000
X=torch.tensor([[-1, 1], [0.5, 2.0]],device=device) #X[0,:] denotes mean, X[1, :] denotes standard deviation
target_mean=(X[0,:]*(X[1,:].pow(-2))).sum()/(X[1,:].pow(-2).sum())
target_sd=(X[1,:].pow(-2)).sum().pow(-0.5)
l2regconst=torch.tensor(1,device=device).detach()
gam=torch.sqrt(l2regconst)

In [26]:
X.shape

torch.Size([2, 2])

In [46]:
from dataclasses import dataclass
@dataclass
class hclass:
    h: Tensor
    eta: Tensor
    etam1g: Tensor
    c11: Tensor
    c21: Tensor
    c22: Tensor

@dataclass
class BAOABhclass:
    h: Tensor
    eta: Tensor
    xc1: Tensor
    xc2: Tensor
    xc3: Tensor
    vc1: Tensor
    vc2: Tensor
    vc3: Tensor


def hper2const(h,gam):
    gh=gam.double()*h.double()
    s=torch.sqrt(4*torch.expm1(-gh/2)-torch.expm1(-gh)+gh)
    eta=(torch.exp(-gh/2)).float()
    etam1g=((-torch.expm1(-gh/2))/gam.double()).float()
    c11=(s/gam).float()
    c21=(torch.exp(-gh)*(torch.expm1(gh/2.0))**2/s).float()
    c22=(torch.sqrt(8*torch.expm1(-gh/2)-4*torch.expm1(-gh)-gh*torch.expm1(-gh))/s).float()
    hc=hclass(h=h,eta=eta,etam1g=etam1g,c11=c11,c21=c21,c22=c22)
    return(hc)

def BAOAB_hconst(h,gam):
    with torch.no_grad():
        hh=copy.deepcopy(h).detach().double()
        gamm=copy.deepcopy(gam).detach().double()
        gh=gamm*hh
        eta=(torch.exp(-gh/2))
        xc1=hh/2*(1+eta)
        xc2=(hh*hh/4)*(1+eta)
        xc3=hh/2*torch.sqrt(-torch.expm1(-gh))
        vc1=eta*(hh/2)
        vc2=(hh/2)
        vc3=torch.sqrt(-torch.expm1(-gh))

        hc=BAOABhclass(h=hh.float(),eta=eta.float(),xc1=xc1.float(),xc2=xc2.float(),xc3=xc3.float(),vc1=vc1.float(),vc2=vc2.float(),vc3=vc3.float())
        return(hc)

def U(x,v,hc,xi1,xi2):
    xn=x+hc.etam1g*v+hc.c11*xi1
    vn=v*hc.eta+hc.c21*xi1+hc.c22*xi2
    return([xn, vn])

# def func(x):
#     return ((x**2)/((1+x**2).sqrt())/2)
# def funcder(x):
#     return (x*(x**2+2)/(2*((1+x**2)**(3/2))))

def func(x):
    #return ((x**2)/((1+x**2)**(0.25))/2)
    return ((x**2)/2)
def funcder(x):
    #return (x*(3*x**2+4)/(4*((1+x**2)**(1.25))))
    return (x)

def V(x,batch_it):
    #return((x-X[0,batch_it])**2/(2*X[1,batch_it]))
    return(func((x-X[0,batch_it])/X[1,batch_it]))

def grad(x,batch_it):
    # res=V(x,batch_it)
    # return (x-X[0,batch_it])/(X[1,batch_it]), res
    return funcder((x-X[0,batch_it])/(X[1,batch_it]))/X[1,batch_it]



def UBU_step(p,hper2c,batch_it):   
    with torch.no_grad():
        xi1=torch.randn_like(p.data,device=device)
        xi2=torch.randn_like(p.data,device=device)
        p.data,p.v=U(p.data,p.v,hper2c,xi1,xi2)

    grads=grad(p,batch_it)*no_batches 
    p.v-=hper2c.h*grads

    with torch.no_grad():
        xi1=torch.randn_like(p.data,device=device)
        xi2=torch.randn_like(p.data,device=device)
        p.data,p.v=U(p.data,p.v,hper2c,xi1,xi2)

def BAOAB_step(p,hc,batch_it,last_grad):   
    
    with torch.no_grad():
        xi1=torch.randn_like(p.data,device=device)
        p.data=p.data+hc.xc1*p.v-hc.xc2*last_grad+hc.xc3*xi1

    grads=grad(p,batch_it)*no_batches 

    with torch.no_grad():
        p.v=hc.eta*p.v-hc.vc1*last_grad-hc.vc2*grads+hc.vc3*xi1

    return(grads)         


def EM_step(p,h,batch_it):   
    with torch.no_grad():
        xi1=torch.randn_like(p.data,device=device)
        
    grads=grad(p,batch_it)*no_batches 
    p.data+=p.v*h
    p.v-=h*grads+gam*(h)*p.v-torch.sqrt(2*gam*h)*xi1

# def UBU_step2(p, q, hper4c,batch_it_list):   
#     with torch.no_grad():
#         xi1=torch.randn_like(p.data,device=device)
#         xi2=torch.randn_like(p.data,device=device)
#         [p.data,p.v]=U(p.data,p.v,hper4c,xi1,xi2)
#         [q.data,q.v]=U(q.data,q.v,hper4c,xi1,xi2)
        
#     grads2=grad(q,batch_it_list[0])  
    
#     with torch.no_grad():

#         q.v-=hper4c.h*grads2
#         xi1=torch.randn_like(p.data,device=device)
#         xi2=torch.randn_like(p.data,device=device)
#         [p.data,p.v]=U(p.data,p.v,hper4c,xi1,xi2)
#         [q.data,q.v]=U(q.data,q.v,hper4c,xi1,xi2)
        
#         grads=grad(p, batch_it_list[2])

#         p.v-=2*hper4c.h*grads

#         xi1=torch.randn_like(p.data,device=device)
#         xi2=torch.randn_like(p.data,device=device)
#         [p.data,p.v]=U(p.data,p.v,hper4c,xi1,xi2)
#         [q.data,q.v]=U(q.data,q.v,hper4c,xi1,xi2)

#         grads2=grad(q, batch_it_list[1])


#         #for q,grad in zip(net2.parameters(), grads2):              
#         q.v-=hper4c.h*grads2

#         #for p,q in zip(net.parameters(),net2.parameters()):
#         xi1=torch.randn_like(p.data,device=device)
#         xi2=torch.randn_like(p.data,device=device)
#         [p.data,p.v]=U(p.data,p.v,hper4c,xi1,xi2)
#         [q.data,q.v]=U(q.data,q.v,hper4c,xi1,xi2)
    


# def EM_step2(net, net2, h, gam, batch_it_list):   
#     grads2,_=grad(net2, batch_it_list[0])
#     grads,loss_likelihood_data=grad(net, batch_it_list[2])   
#     sqrt2=torch.tensor(2).sqrt().detach()

#     with torch.no_grad():
#         for p,gradp in zip(net2.parameters(), grads2):              
#             p.xi=torch.randn_like(p.data,device=device)
#             p.data+=p.v*h/2
#             p.v-=(h/2)*gradp+gam*(h/2)*p.v-torch.sqrt(2*gam*h/2)*p.xi

#     grads2,_=grad(net2, batch_it_list[1])

#     with torch.no_grad():
#         for p,gradp in zip(net2.parameters(), grads2):              
#             p.xi2=torch.randn_like(p.data,device=device)
#             p.data+=p.v*h/2
#             p.v-=(h/2)*gradp+gam*(h/2)*p.v-torch.sqrt(2*gam*h/2)*p.xi2

#         for p,q,gradp in zip(net.parameters(),net2.parameters(), grads):              
#             p.data+=p.v*h
#             p.v-=(h)*gradp+gam*(h)*p.v-torch.sqrt(2*gam*h)*(q.xi+q.xi2)/sqrt2


#     return(loss_likelihood_data)






# def BAOAB_step2(net, net2, hc, hper2c, batch_it_list,last_grad,last_grad2):   

#     with torch.no_grad():
#         for p,grad in zip(net2.parameters(), last_grad2):
#             p.xi=torch.randn_like(p.data,device=device)
#             p.data=p.data+hper2c.xc1*p.v-hper2c.xc2*grad+hper2c.xc3*p.xi

#     grads2,_=grad(net2, batch_it_list[0])

#     with torch.no_grad():
#         for p,grad,gradn in zip(net2.parameters(), last_grad2,grads2):              
#             p.v=hper2c.eta*p.v-hper2c.vc1*grad-hper2c.vc2*gradn+hper2c.vc3*p.xi

#     with torch.no_grad():
#         for p,grad in zip(net2.parameters(), grads2):
#             p.xi2=torch.randn_like(p.data,device=device)
#             p.data=p.data+hper2c.xc1*p.v-hper2c.xc2*grad+hper2c.xc3*p.xi2

#     grads2n,_=grad(net2, batch_it_list[1])

#     with torch.no_grad():
#         for p,grad,gradn in zip(net2.parameters(), grads2,grads2n):              
#             p.v=hper2c.eta*p.v-hper2c.vc1*grad-hper2c.vc2*gradn+hper2c.vc3*p.xi2

#     sqrt2=torch.tensor(2).sqrt().detach()
#     with torch.no_grad():
#         for p,q,grad in zip(net.parameters(),net2.parameters(), last_grad):
#             p.data=p.data+hc.xc1*p.v-hc.xc2*grad+hc.xc3*(q.xi+q.xi2)/sqrt2

#     grads,loss_likelihood_data=grad(net, batch_it_list[2])

#     with torch.no_grad():
#         for p,q,grad,gradn in zip(net.parameters(),net2.parameters(), last_grad,grads):              
#             p.v=hc.eta*p.v-hc.vc1*grad-hc.vc2*gradn+hc.vc3*(q.xi+q.xi2)/sqrt2

#     return(loss_likelihood_data,grads,grads2n)

    # eta=to_data_type(exp(-h*gam/2));
    # h2=to_data_type(h/2);
    # eta2=to_data_type(exp(-h2*gam/2));

    # grad=gradp;
    # xip=R;

    # #xn=x+(h/2*(1+eta))*v-((h^2/4)*(1+eta))*grad+((h/2)*realsqrt(1-eta^2))*xip;
    # #xn=x+xc1*v-xc2*grad+xc3*xip

    # gradpn=grad_lpost(xn);
    # vn=eta*(v-(h/2)*grad)+(realsqrt(1-eta^2))*xip-(h/2)*gradpn;    
    # #vn=eta*v-vc1*grad-vc2*gradpn+vc3*xip

def ind_create(batch_it):
    modit=batch_it %(2*no_batches)
    ind=(modit<=(no_batches-1))*modit+(modit>=no_batches)*(2*no_batches-modit-1)
    return ind


In [61]:
def SMS_UBU(num_epochs,h,gam):

  p=torch.zeros(par_runs,device=device)
  rng = np.random.default_rng()

  with torch.no_grad():
    V_arr=torch.zeros([par_runs,num_epochs],device=device).detach()
    hper2c=hper2const(h,gam)
  #Initialise velocities
    p.v = torch.randn_like(p,device=device).detach()

  for epoch in range(num_epochs):

    if(epoch%2==0):
      rperm=rng.permuted(np.tile(np.arange(no_batches,dtype=int),(par_runs,1)),axis=1)

    for i in range(no_batches):
      b=i        
      it=epoch*no_batches+b
      ind=ind_create(it)
      UBU_step(p,hper2c,rperm[:,ind])

    with torch.no_grad():
      V_arr[:,epoch]=p.data
    
  return(V_arr)
  



def SG_UBU(num_epochs,h,gam):

  p=torch.zeros(par_runs,device=device)
  rng = np.random.default_rng()

  with torch.no_grad():
    V_arr=torch.zeros([par_runs,num_epochs],device=device).detach()
    hper2c=hper2const(h,gam)
  #Initialise velocities
    p.v = torch.randn_like(p,device=device).detach()

  for epoch in range(num_epochs):
    for i in range(no_batches):

      ind=torch.randint(high=no_batches,size=(par_runs,)).int()

      UBU_step(p,hper2c,ind)

    with torch.no_grad():
      V_arr[:,epoch]=p.data
    
  return(V_arr)


def SG_UBU_without_replacement(num_epochs,h,gam):

  p=torch.zeros(par_runs,device=device)
  rng = np.random.default_rng()

  with torch.no_grad():
    V_arr=torch.zeros([par_runs,num_epochs],device=device).detach()
    hper2c=hper2const(h,gam)
  #Initialise velocities
    p.v = torch.randn_like(p,device=device).detach()

  for epoch in range(num_epochs):
    rperm=rng.permuted(np.tile(np.arange(no_batches,dtype=int),(par_runs,1)),axis=1)
    for i in range(no_batches):
      ind=rperm[:,i]
      UBU_step(p,hper2c,ind)

    with torch.no_grad():
      V_arr[:,epoch]=p.data
    
  return(V_arr)


def SMS_BAOAB(num_epochs,h,gam):

  p=torch.zeros(par_runs,device=device)
  rng = np.random.default_rng()

  with torch.no_grad():
    V_arr=torch.zeros([par_runs,num_epochs],device=device).detach()
    hper2c=BAOAB_hconst(h,gam)
  #Initialise velocities
    p.v = torch.randn_like(p,device=device).detach()
  
  ind=torch.randint(high=no_batches,size=(par_runs,)).int()
  grads=grad(p.data,ind)
  for epoch in range(num_epochs):
    if(epoch%2==0):
      rperm=rng.permuted(np.tile(np.arange(no_batches,dtype=int),(par_runs,1)),axis=1)
    for i in range(no_batches):
      b=i        
      it=epoch*no_batches+b
      ind=ind_create(it)
      grads=BAOAB_step(p,hper2c,ind,grads)

    with torch.no_grad():
      V_arr[:,epoch]=p.data
    
  return(V_arr)


def SG_BAOAB(num_epochs,h,gam):

  p=torch.zeros(par_runs,device=device)
  rng = np.random.default_rng()

  with torch.no_grad():
    V_arr=torch.zeros([par_runs,num_epochs],device=device).detach()
    hper2c=BAOAB_hconst(h,gam)
  #Initialise velocities
    p.v = torch.randn_like(p,device=device).detach()
    
  ind=torch.randint(high=no_batches,size=(par_runs,)).int()    
  grads=grad(p.data,ind)

  for epoch in range(num_epochs):

    for i in range(no_batches):
      ind=torch.randint(high=no_batches,size=(par_runs,)).int()     
      grads=BAOAB_step(p,hper2c,ind,grads)

    with torch.no_grad():
      V_arr[:,epoch]=p.data
    
  return(V_arr)

def SG_BAOAB_without_replacement(num_epochs,h,gam):

  p=torch.zeros(par_runs,device=device)
  rng = np.random.default_rng()

  with torch.no_grad():
    V_arr=torch.zeros([par_runs,num_epochs],device=device).detach()
    hper2c=BAOAB_hconst(h,gam)
  #Initialise velocities
    p.v = torch.randn_like(p,device=device).detach()
  ind=torch.randint(high=no_batches,size=(par_runs,)).int()    
  grads=grad(p.data,ind)

  for epoch in range(num_epochs):

    rperm=rng.permuted(np.tile(np.arange(no_batches,dtype=int),(par_runs,1)),axis=1)
    for i in range(no_batches):
      ind=rperm[:,i]
      grads=BAOAB_step(p,hper2c,ind,grads)

    with torch.no_grad():
      V_arr[:,epoch]=p.data
    
  return(V_arr)


def SMS_EM(num_epochs,h,gam):

  p=torch.zeros(par_runs,device=device)
  rng = np.random.default_rng()

  with torch.no_grad():
    V_arr=torch.zeros([par_runs,num_epochs],device=device).detach()
  #Initialise velocities
    p.v = torch.randn_like(p,device=device).detach()

  for epoch in range(num_epochs):

    if(epoch%2==0):
      rperm=rng.permuted(np.tile(np.arange(no_batches,dtype=int),(par_runs,1)),axis=1)

    for i in range(no_batches):
      b=i        
      it=epoch*no_batches+b
      ind=ind_create(it)
      EM_step(p,h,rperm[:,ind])

    with torch.no_grad():
      V_arr[:,epoch]=p.data
    
  return(V_arr)


def SG_EM(num_epochs,h,gam):

  p=torch.zeros(par_runs,device=device)
  rng = np.random.default_rng()

  with torch.no_grad():
    V_arr=torch.zeros([par_runs,num_epochs],device=device).detach()
  #Initialise velocities
    p.v = torch.randn_like(p,device=device).detach()

  for epoch in range(num_epochs):
    for i in range(no_batches):
      ind=torch.randint(high=no_batches,size=(par_runs,)).int()    
      EM_step(p,h,ind)

    with torch.no_grad():
      V_arr[:,epoch]=p.data
    
  return(V_arr)

def SG_EM_without_replacement(num_epochs,h,gam):

  p=torch.zeros(par_runs,device=device)
  rng = np.random.default_rng()

  with torch.no_grad():
    V_arr=torch.zeros([par_runs,num_epochs],device=device).detach()
  #Initialise velocities
    p.v = torch.randn_like(p,device=device).detach()

  for epoch in range(num_epochs):

    rperm=rng.permuted(np.tile(np.arange(no_batches,dtype=int),(par_runs,1)),axis=1)
    for i in range(no_batches):
      ind=rperm[:,i]

      EM_step(p,h,ind)

    with torch.no_grad():
      V_arr[:,epoch]=p.data
    
  return(V_arr)

def SG_UBU2(num_epochs,h,gam):

  p=torch.zeros(par_runs,device=device)
  q=torch.zeros(par_runs,device=device)

  with torch.no_grad():
    V_arr=torch.zeros([par_runs,num_epochs],device=device).detach()
    V_arr2=torch.zeros([par_runs,num_epochs],device=device).detach()  
    hper4c=hper2const(h/2,gam)
  #Initialise velocities
    p.v = torch.randn_like(p,device=device).detach()
    q.v=copy.deepcopy(p.v).detach()

  for epoch in range(num_epochs):
    for i in range(no_batches):
      ind=torch.randint(high=no_batches,size=(par_runs,)).int()
      ind2=torch.randint(high=no_batches,size=(par_runs,)).int()
      rflip=(torch.rand((par_runs,))<0.5).int()
      indc=ind*rflip+ind2*(1-rflip)
      
      batch_it_list=[ind.numpy(), ind2.numpy(), indc.numpy()]
      UBU_step2(p,q,hper4c,batch_it_list)


    with torch.no_grad():
      V_arr[:,epoch]=p.data#V(p.data,0)+V(p.data,1)
      V_arr2[:,epoch]=q.data#V(q.data,0)+V(q.data,1)
    
    #       
  return(V_arr,V_arr2)

def SG_UBU2_without_replacement(num_epochs,h,gam):

  p=torch.zeros(par_runs,device=device)
  q=torch.zeros(par_runs,device=device)
  rng = np.random.default_rng()

  with torch.no_grad():
    V_arr=torch.zeros([par_runs,num_epochs],device=device).detach()
    V_arr2=torch.zeros([par_runs,num_epochs],device=device).detach()  
    hper4c=hper2const(h/2,gam)
  #Initialise velocities
    p.v = torch.randn_like(p,device=device).detach()
    q.v=copy.deepcopy(p.v).detach()

  for epoch in range(num_epochs):
    rperm=rng.permuted(np.tile(np.arange(no_batches,dtype=int),(par_runs,1)),axis=1)
    rperm2=rng.permuted(np.tile(np.arange(no_batches,dtype=int),(par_runs,1)),axis=1)
    rperm2=np.concatenate((rperm,rperm2),axis=1)
    #print(rperm2.shape)
    #rperm=random.permutation(list(range(no_batches)))      
    #rperm2=np.concatenate((rperm,random.permutation(list(range(no_batches)))))
    for i in range(no_batches):
      b=i        
      it=epoch*no_batches+b
      ind=(2*it)%no_batches
      ind2=(2*it+1)%no_batches            
      indc=it%no_batches          
      batch_it_list=[rperm2[:,ind], rperm2[:,ind2], rperm[:,indc]]
      UBU_step2(p,q,hper4c,batch_it_list)


    with torch.no_grad():
      V_arr[:,epoch]=p.data#V(p.data,0)+V(p.data,1)
      V_arr2[:,epoch]=q.data#V(q.data,0)+V(q.data,1)
    
  return(V_arr,V_arr2)

def Wass(V1,V2):
  V1s,_=torch.sort(V1.flatten())
  V2s,_=torch.sort(V2.flatten())
  return (V1s-V2s).abs().mean()

In [70]:
Wass_arr=torch.zeros(4,9).detach()
methods_list=[SMS_UBU,SG_UBU,SG_UBU_without_replacement,SMS_BAOAB,SG_BAOAB,SG_BAOAB_without_replacement, SMS_EM,SG_EM,SG_EM_without_replacement]
for it in range(4):
    for mit in range(9):
        rat=pow(2,it)
        num_epochs=int(2000*rat)
        h=torch.tensor(0.25)/rat
        V_arr=methods_list[mit](num_epochs,h,gam)

        V2_arr=torch.randn_like(V_arr,device=device)*target_sd+target_mean
        diff=Wass(V_arr[:,50*rat:],V2_arr[:,50*rat:])
        print("Wasserstein distance:",f'{diff:.8f}')
        Wass_arr[it,mit]=diff

Wasserstein distance: 0.00052296
Wasserstein distance: 0.01132481
Wasserstein distance: 0.00047957
Wasserstein distance: 0.00063052
Wasserstein distance: 0.02402091
Wasserstein distance: 0.00049352
Wasserstein distance: 0.02882943
Wasserstein distance: 0.04265445
Wasserstein distance: 0.02885849


In [71]:
Wass_arr

tensor([[0.0412, 0.1147, 0.0273, 0.0408, 0.4789, 0.0879,    nan,    nan,    nan],
        [0.0071, 0.0483, 0.0060, 0.0071, 0.1226, 0.0086, 0.1826, 0.3648, 0.1903],
        [0.0021, 0.0231, 0.0017, 0.0018, 0.0514, 0.0016, 0.0652, 0.1025, 0.0659],
        [0.0005, 0.0113, 0.0005, 0.0006, 0.0240, 0.0005, 0.0288, 0.0427, 0.0289]])

In [None]:
import pickle
filepath="Wass_distance.pickle"
with open(filepath,"wb") as file:
    pickle.dump(Wass_arr.numpy(),file)


NameError: name 'pickle' is not defined

In [None]:
rat=2
num_epochs=int(2000*rat)
h=torch.tensor(0.5)/rat
V_arr=SG_UBU(num_epochs,h,gam)

V2_arr=torch.randn_like(V_arr,device=device)*target_sd+target_mean
diff=Wass(V_arr[:,200*rat:],V2_arr[:,200*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.11451876


In [39]:
rat=2
num_epochs=int(2000*rat)
h=torch.tensor(0.5)/rat
V_arr=SG_UBU_without_replacement(num_epochs,h,gam)

V2_arr=torch.randn_like(V_arr,device=device)*target_sd+target_mean
diff=Wass(V_arr[:,200*rat:],V2_arr[:,200*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.02744237


In [33]:
rat=2
num_epochs=int(2000*rat)
h=torch.tensor(0.5)/rat
V_arr=SG_BAOAB(num_epochs,h,gam)

V2_arr=torch.randn_like(V_arr,device=device)*target_sd+target_mean
diff=Wass(V_arr[:,200*rat:],V2_arr[:,200*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.47818235


In [37]:
rat=2
num_epochs=int(2000*rat)
h=torch.tensor(0.5)/rat
V_arr=SG_BAOAB_without_replacement(num_epochs,h,gam)

V2_arr=torch.randn_like(V_arr,device=device)*target_sd+target_mean
diff=Wass(V_arr[:,200*rat:],V2_arr[:,200*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.08776037


In [43]:
rat=2
num_epochs=int(2000*rat)
h=torch.tensor(0.5)/rat
V_arr=SMS_BAOAB(num_epochs,h,gam)

V2_arr=torch.randn_like(V_arr,device=device)*target_sd+target_mean
diff=Wass(V_arr[:,200*rat:],V2_arr[:,200*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.04085172


In [63]:
rat=4
num_epochs=int(2000*rat)
h=torch.tensor(0.5)/rat
V_arr=SMS_EM(num_epochs,h,gam)

V2_arr=torch.randn_like(V_arr,device=device)*target_sd+target_mean
diff=Wass(V_arr[:,200*rat:],V2_arr[:,200*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.18276775


In [55]:
rat=4
num_epochs=int(2000*rat)
h=torch.tensor(0.5)/rat
V_arr=SG_EM(num_epochs,h,gam)

V2_arr=torch.randn_like(V_arr,device=device)*target_sd+target_mean
diff=Wass(V_arr[:,200*rat:],V2_arr[:,200*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.36453941


In [62]:
rat=4
num_epochs=int(2000*rat)
h=torch.tensor(0.5)/rat
V_arr=SG_EM_without_replacement(num_epochs,h,gam)

V2_arr=torch.randn_like(V_arr,device=device)*target_sd+target_mean
diff=Wass(V_arr[:,200*rat:],V2_arr[:,200*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.19027656


In [56]:
rat=8
num_epochs=int(2000*rat)
h=torch.tensor(0.5)/rat
V_arr=SG_EM(num_epochs,h,gam)

V2_arr=torch.randn_like(V_arr,device=device)*target_sd+target_mean
diff=Wass(V_arr[:,200*rat:],V2_arr[:,200*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.10278181


In [57]:
rat=16
num_epochs=int(2000*rat)
h=torch.tensor(0.5)/rat
V_arr=SG_EM(num_epochs,h,gam)

V2_arr=torch.randn_like(V_arr,device=device)*target_sd+target_mean
diff=Wass(V_arr[:,200*rat:],V2_arr[:,200*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.04270706


In [None]:
rat=4
num_epochs=int(2000*rat)
h=torch.tensor(0.5)/rat
V_arr=SMS_UBU(num_epochs,h,gam)

V2_arr=torch.randn_like(V_arr,device=device)*target_sd+target_mean
diff=Wass(V_arr[:,200*rat:],V2_arr[:,200*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.00708137


In [None]:
rat=4
num_epochs=int(2000*rat)
h=torch.tensor(0.5)/rat
V_arr=SG_UBU(num_epochs,h,gam)

V2_arr=torch.randn_like(V_arr,device=device)*target_sd+target_mean
diff=Wass(V_arr[:,200*rat:],V2_arr[:,200*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.04815945


In [None]:
rat=4
num_epochs=int(2000*rat)
h=torch.tensor(0.5)/rat
V_arr=SG_UBU_without_replacement(num_epochs,h,gam)

V2_arr=torch.randn_like(V_arr,device=device)*target_sd+target_mean
diff=Wass(V_arr[:,200*rat:],V2_arr[:,200*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.00609905


In [44]:
rat=4
num_epochs=int(2000*rat)
h=torch.tensor(0.5)/rat
V_arr=SMS_BAOAB(num_epochs,h,gam)

V2_arr=torch.randn_like(V_arr,device=device)*target_sd+target_mean
diff=Wass(V_arr[:,200*rat:],V2_arr[:,200*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.00719593


In [None]:
rat=8
num_epochs=int(2000*rat)
h=torch.tensor(0.5)/rat
V_arr=SMS_UBU(num_epochs,h,gam)

V2_arr=torch.randn_like(V_arr,device=device)*target_sd+target_mean
diff=Wass(V_arr[:,200*rat:],V2_arr[:,200*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.00203311


In [None]:
rat=8
num_epochs=int(2000*rat)
h=torch.tensor(0.5)/rat
V_arr=SG_UBU(num_epochs,h,gam)

V2_arr=torch.randn_like(V_arr,device=device)*target_sd+target_mean
diff=Wass(V_arr[:,200*rat:],V2_arr[:,200*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

In [None]:
rat=8
num_epochs=int(2000*rat)
h=torch.tensor(0.5)/rat
V_arr=SG_UBU_without_replacement(num_epochs,h,gam)

V2_arr=torch.randn_like(V_arr,device=device)*target_sd+target_mean
diff=Wass(V_arr[:,200*rat:],V2_arr[:,200*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.00168177


In [45]:
rat=8
num_epochs=int(2000*rat)
h=torch.tensor(0.5)/rat
V_arr=SMS_BAOAB(num_epochs,h,gam)

V2_arr=torch.randn_like(V_arr,device=device)*target_sd+target_mean
diff=Wass(V_arr[:,200*rat:],V2_arr[:,200*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.00185550


In [None]:
rat=16
num_epochs=int(2000*rat)
h=torch.tensor(0.5)/rat
V_arr=SMS_UBU(num_epochs,h,gam)

V2_arr=torch.randn_like(V_arr,device=device)*target_sd+target_mean
diff=Wass(V_arr[:,200*rat:],V2_arr[:,200*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.00051567


In [None]:
rat=16
num_epochs=int(2000*rat)
h=torch.tensor(0.5)/rat
V_arr=SG_UBU(num_epochs,h,gam)

V2_arr=torch.randn_like(V_arr,device=device)*target_sd+target_mean
diff=Wass(V_arr[:,200*rat:],V2_arr[:,200*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.01136027


In [None]:
rat=16
num_epochs=int(2000*rat)
h=torch.tensor(0.5)/rat
V_arr=SG_UBU_without_replacement(num_epochs,h,gam)
V2_arr=torch.randn_like(V_arr,device=device)*target_sd+target_mean
diff=Wass(V_arr[:,200*rat:],V2_arr[:,200*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.00044837


In [None]:
rat=16
num_epochs=int(4000*rat)
h=torch.tensor(0.5)/rat
V_arr=SG_UBU_without_replacement(num_epochs,h,gam)

V2_arr=torch.randn_like(V_arr,device=device)*target_sd+target_mean
diff=Wass(V_arr[:,200*rat:],V2_arr[:,200*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.00047875


In [None]:
V_arr

tensor([[-0.7080, -1.3768, -1.2743,  ..., -0.8878, -1.2427, -1.3419],
        [-0.6147, -1.2203, -1.1726,  ..., -0.8398, -1.0877, -0.9577],
        [-0.2342, -1.2017, -1.4464,  ..., -0.7463, -0.7329, -0.8420],
        ...,
        [-0.2312, -0.5815, -0.8356,  ..., -0.9800, -0.3326, -0.6403],
        [-0.8650, -1.5125, -1.7752,  ..., -1.2877, -1.0383, -1.1100],
        [-0.5952, -1.1166, -1.0417,  ..., -0.4319, -1.0199, -1.5974]])

In [None]:
rat=2
num_epochs=int(2000*rat)
h= torch.tensor(0.5)/rat
V_arr,V_arr2=SMS_UBU2(num_epochs,h,gam)
diff=Wass(V_arr[:,100*rat:],V_arr2[:,100*rat:])
#diff=Wass(V_arr[:,range(1000*rat+1,5000,2)],V_arr2[:,range(1000*rat+1,5000,2)])
print("Wasserstein distance:",f'{diff:.8f}')

NameError: name 'SMS_UBU2' is not defined

In [None]:
rat=4
num_epochs=int(1000*rat)
h= torch.tensor(0.5)/rat
V_arr,V_arr2=SMS_UBU2(num_epochs,h,gam)
diff=Wass(V_arr[:,100*rat:],V_arr2[:,100*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.00687544


In [None]:
rat=8
num_epochs=int(1000*rat)
h= torch.tensor(0.5)/rat
V_arr,V_arr2=SMS_UBU2(num_epochs,h,gam)
diff=Wass(V_arr[:,100*rat:],V_arr2[:,100*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.00088318


In [None]:
rat=16
num_epochs=int(1000*rat)
h= torch.tensor(0.5)/rat
V_arr,V_arr2=SMS_UBU2(num_epochs,h,gam)
diff=Wass(V_arr[:,100*rat:],V_arr2[:,100*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.00027838


In [None]:
rat=32
num_epochs=int(500*rat)
h= torch.tensor(0.5)/rat
V_arr,V_arr2=SMS_UBU2(num_epochs,h,gam)
diff=Wass(V_arr[:,100*rat:],V_arr2[:,100*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.00007748


In [None]:
rat=64
num_epochs=int(500*rat)
h= torch.tensor(0.5)/rat
V_arr,V_arr2=SMS_UBU2(num_epochs,h,gam)
diff=Wass(V_arr[:,100*rat:],V_arr2[:,100*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.00001989


In [None]:
rat=1
num_epochs=int(1000*rat)
h= torch.tensor(0.5)/rat
V_arr,V_arr2=SG_UBU2_without_replacement(num_epochs,h,gam)
diff=Wass(V_arr[:,100*rat:],V_arr2[:,100*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.23515014


In [None]:
rat=2
num_epochs=int(1000*rat)
h= torch.tensor(0.5)/rat
V_arr,V_arr2=SG_UBU2_without_replacement(num_epochs,h,gam)
diff=Wass(V_arr[:,100*rat:],V_arr2[:,100*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.04478037


In [None]:
rat=4
num_epochs=int(1000*rat)
h= torch.tensor(0.5)/rat
V_arr,V_arr2=SG_UBU2_without_replacement(num_epochs,h,gam)
diff=Wass(V_arr[:,100*rat:],V_arr2[:,100*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.00606991


In [None]:
rat=8
num_epochs=int(1000*rat)
h= torch.tensor(0.5)/rat
V_arr,V_arr2=SG_UBU2_without_replacement(num_epochs,h,gam)
diff=Wass(V_arr[:,100*rat:],V_arr2[:,100*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.00092218


In [None]:
rat=16
num_epochs=int(1000*rat)
h= torch.tensor(0.5)/rat
V_arr,V_arr2=SG_UBU2_without_replacement(num_epochs,h,gam)
diff=Wass(V_arr[:,100*rat:],V_arr2[:,100*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.00021774


In [None]:
rat=32
num_epochs=int(500*rat)
h= torch.tensor(0.5)/rat
V_arr,V_arr2=SG_UBU2_without_replacement(num_epochs,h,gam)
diff=Wass(V_arr[:,100*rat:],V_arr2[:,100*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.00006514


In [None]:
rat=64
num_epochs=int(500*rat)
h= torch.tensor(0.5)/rat
V_arr,V_arr2=SG_UBU2_without_replacement(num_epochs,h,gam)
diff=Wass(V_arr[:,100*rat:],V_arr2[:,100*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.00001757


In [None]:
rat=1
num_epochs=int(1000*rat)
h= torch.tensor(0.5)/rat
V_arr,V_arr2=SG_UBU2(num_epochs,h,gam)
diff=Wass(V_arr[:,100*rat:],V_arr2[:,100*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.04525045


In [None]:
rat=2
num_epochs=int(1000*rat)
h= torch.tensor(0.5)/rat
V_arr,V_arr2=SG_UBU2(num_epochs,h,gam)
diff=Wass(V_arr[100*rat:],V_arr2[100*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.07995587


In [None]:
rat=4
num_epochs=int(1000*rat)
h= torch.tensor(0.5)/rat
V_arr,V_arr2=SG_UBU2(num_epochs,h,gam)
diff=Wass(V_arr[100*rat:],V_arr2[100*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.01214406


In [None]:
rat=8
num_epochs=int(1000*rat)
h= torch.tensor(0.5)/rat
V_arr,V_arr2=SG_UBU2(num_epochs,h,gam)
diff=Wass(V_arr[100*rat:],V_arr2[100*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.00608285


In [None]:
rat=16
num_epochs=int(500*rat)
h= torch.tensor(0.5)/rat
V_arr,V_arr2=SG_UBU2(num_epochs,h,gam)
diff=Wass(V_arr[100*rat:],V_arr2[100*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.00311218


In [None]:
rat=32
num_epochs=int(500*rat)
h= torch.tensor(0.5)/rat
V_arr,V_arr2=SG_UBU2(num_epochs,h,gam)
diff=Wass(V_arr[100*rat:],V_arr2[100*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.00155235


In [None]:
rat=64
num_epochs=int(500*rat)
h= torch.tensor(0.5)/rat
V_arr,V_arr2=SG_UBU2(num_epochs,h,gam)
diff=Wass(V_arr[100*rat:],V_arr2[100*rat:])
print("Wasserstein distance:",f'{diff:.8f}')

Wasserstein distance: 0.00084320
