In [1]:
%matplotlib notebook
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation
# "rotskoff" is not a built-in matplotlib style; it must be available as a local
# style sheet. It presumably also enables LaTeX text rendering used by the
# \textrm / \boldsymbol labels below — TODO confirm.
plt.style.use("rotskoff")
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

# Seed both random generators used in this notebook (torch for initial
# conditions / resampling, numpy for the birth/death coin flips) so that
# Restart Kernel -> Run All reproduces the saved figures.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [2]:
# define the nonlinearity and kernel
s2 = 0.5  # variance of each Gaussian feature phi(., y, s2)


def phi(xs, ys, s2s):
    """Gaussian density with mean ``ys`` and variance ``s2s``, evaluated at ``xs``.

    The expression is symmetric in ``xs`` and ``ys`` and relies on broadcasting:
    callers in this notebook pass one column ``[N, 1]`` and one row ``[n]`` to
    obtain an ``[N, n]`` matrix. ``s2s`` may be a Python float or a tensor
    broadcastable against ``xs - ys``.
    """
    sq_dist = (xs - ys) ** 2
    norm = np.sqrt(2 * np.pi * s2s)
    return torch.exp(-sq_dist / (2. * s2s)) / norm
    
def fn(xs, cs, ys, s2s):
    """Mean-field network output f_n(x) = (1/n) * sum_i cs[i] * phi(x, ys[i], s2s).

    xs  : ``[N, 1]`` column of evaluation points.
    cs  : ``[n]`` charges (n is taken from ``cs.shape[0]``).
    ys  : ``[n]`` feature centres.
    s2s : scalar variance or ``[n]`` per-feature variances.
    Returns an ``[N, 1]`` tensor.
    """
    n_units = cs.shape[0]
    contributions = cs * phi(xs, ys, s2s)  # [N, n] via broadcasting
    return contributions.sum(dim=1, keepdim=True) / n_units

In [3]:
# simple target function: a three-component Gaussian mixture
y_fs = torch.tensor([-2.0, 0.0, 2.0])   # component centres
s2_fs = torch.tensor([0.5, 1.0, 0.5])   # component variances
c_fs = torch.tensor([1.0, -2.0, 1.0])   # component weights (charges)

def f(xs):
    """Target function: the mixture (c_fs, y_fs, s2_fs) evaluated at ``xs`` ([N, 1] -> [N, 1])."""
    return fn(xs, cs=c_fs, ys=y_fs, s2s=s2_fs)

def Cf(ys, cs, s2s):
    """Return the mean of ``cs``; ``ys`` and ``s2s`` are accepted but ignored.

    NOTE(review): this function is never called, and a later cell rebinds the
    name ``Cf`` to a float constant used by ``err``, so this definition is
    shadowed on a normal top-to-bottom run.
    """
    return cs.mean()
    

In [4]:
# Column of evaluation points on [-4, 4] for plotting the target.
xs = torch.linspace(-4, 4, 1000).reshape(1000, 1)
fig, ax = plt.subplots()
ax.set_xlabel(r"$\Omega$", fontsize=18)
ax.set_ylabel(r"$f(x)$", fontsize=18)
ax.plot(xs.numpy(), f(xs).numpy())

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x7f1d358e3e80>]

In [5]:
# single body potential F(y) = <f, phi(., y, s2)>: convolving each target
# Gaussian (variance s2_fs) with a feature Gaussian (variance s2) gives a
# Gaussian with the summed variance.
s2s = s2 + s2_fs  # variances of the mixture


def F(ys):
    """Single-body potential evaluated at each row of ``ys`` ([N, 1] -> [N, 1])."""
    return fn(ys, cs=c_fs, ys=y_fs, s2s=s2s)

def dF(ys):
    """Derivative dF/dy of the single-body potential at each row of ``ys``.

    Differentiates term by term using
    d/dy phi(y_f, y, s) = -(y - y_f) / s * phi(y_f, y, s).
    ys : ``[N, 1]``; returns ``[N, 1]``.
    """
    offsets = ys - y_fs  # [N, len(c_fs)]
    terms = c_fs * offsets * phi(y_fs, ys, s2s) / s2s
    return -terms.sum(dim=1, keepdim=True) / c_fs.shape[0]

# interaction
def K(ys, zs):
    """Interaction kernel: overlap of two features, a Gaussian in y - z with variance 2*s2.

    Broadcasts ``ys`` against ``zs`` (callers pass a column and its transpose
    to get the pairwise Gram matrix).
    """
    sq_dist = (ys - zs) ** 2
    return torch.exp(-sq_dist / (4. * s2)) / (2. * np.sqrt(np.pi * s2))

def dK(ys, zs):
    """Derivative of ``K(ys, zs)`` with respect to its first argument: -(y - z)/(2*s2) * K."""
    diff = ys - zs
    return -diff * torch.exp(-diff ** 2 / (4. * s2)) / (4. * s2 * np.sqrt(np.pi * s2))

In [6]:
# Plot the (negated) single-body potential on [-5, 5].
ys = torch.linspace(-5, 5, 100).reshape([100, 1])
fig, ax = plt.subplots()
ax.set_xlabel(r"$\Omega$", fontsize=18)
ax.set_ylabel(r"$-F(y)$", fontsize=18)
ax.plot(ys.numpy(), (-F(ys)).numpy())

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x7f1d3095d080>]

In [7]:
def step_charges(ys, cs, n, dt=5e-2):
    """One explicit Euler step of the gradient flow on positions and charges.

    ys, cs : ``[n, 1]`` columns of particle positions and charges.
    n      : normalisation used for the pairwise interaction term.
    dt     : step size.
    Returns new ``(ys, cs)`` tensors; the inputs are not modified.
    """
    gram = K(ys, ys.t())         # [n, n] pairwise kernel
    gram_grad = dK(ys, ys.t())   # [n, n] derivative wrt first argument
    ys_new = ys + dt * (cs * dF(ys) - cs * torch.mm(gram_grad, cs) / n)
    cs_new = cs + dt * (F(ys) - torch.mm(gram, cs) / n)
    return ys_new, cs_new

In [8]:
# basic gradient descent trajectory
def run_traj(n, n_steps, ys=None, cs=None, dt=5e-3):
    """Plain gradient-descent trajectory of ``n`` particles (no birth/death).

    Parameters
    ----------
    n : number of particles.
    n_steps : number of Euler steps (and recorded frames).
    ys, cs : optional ``[n, 1]`` initial positions / charges; by default
        positions are uniform on [-2, 2) and charges standard normal.
    dt : step size passed to ``step_charges``.

    Returns
    -------
    traj : ``[n_steps, 2, n]`` tensor; ``traj[i, 0]`` are positions and
        ``traj[i, 1]`` charges at the start of step ``i``. This matches the
        birth/death schemes, so ``traj[0]`` is the initial condition and all
        trajectories share the same time axis.
    """
    if ys is None:
        ys = 4*torch.rand([n, 1]) - 2
    if cs is None:
        cs = torch.randn(ys.shape)
    traj = torch.zeros([n_steps, 2, n])
    for step in range(n_steps):
        # Record before stepping (previously recorded after, which dropped the
        # initial state and shifted GD by one step relative to the other schemes).
        traj[step, 0, :] = ys.flatten()
        traj[step, 1, :] = cs.flatten()
        ys, cs = step_charges(ys, cs, n, dt=dt)
    return traj

In [9]:
def make_animation(traj):
    """Animate a trajectory: particles in the (y, c) plane (left) and f_n vs the target (right).

    ``traj`` may be either
      * a paired trajectory: a ``[n_steps, 2, n]`` tensor as returned by
        ``run_traj`` / ``run_traj_scheme*``, or a sequence of ``(ys, cs)`` pairs;
      * a positions-only trajectory (a ``[n_steps, n]`` tensor or a sequence of
        position tensors), in which case unit charges are assumed.

    Returns the ``FuncAnimation``; keep a reference to it or it will be
    garbage-collected and stop animating.
    """
    fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(8, 4))
    ax1.set_xlim(-5, 5)
    ax1.set_ylim(-3, 3)
    l1, = ax1.plot([], [], "o")

    xs = torch.linspace(-5, 5, 100).reshape(100, 1)
    ax2.set_ylabel(r"$f_n(x)$", fontsize=18)
    ax2.set_xlabel(r"$\Omega$", fontsize=18)
    ax2.set_xlim(-5, 5)
    ax2.set_ylim(-0.75, 0.75)
    ax2.plot(xs.numpy(), f(xs).numpy(), "k")
    l2, = ax2.plot([], [], "o-")

    # Decide once whether frames carry (positions, charges). The previous test
    # `len(traj[0]) > 1` was also true for positions-only frames with n > 1,
    # sending them down the paired branch where they crashed.
    paired = (
        (torch.is_tensor(traj) and traj.dim() == 3 and traj.shape[1] == 2)
        or isinstance(traj[0], (tuple, list))
    )

    def animate(i):
        if paired:
            ys_i, cs_i = traj[i][0], traj[i][1]
            n = ys_i.numel()
            l1.set_data(ys_i.flatten().numpy(), cs_i.flatten().numpy())
            l2.set_data(xs.numpy(), fn(xs, cs_i.reshape(n), ys_i.reshape(n), s2).numpy())
        else:
            ys_i = traj[i]
            n = ys_i.numel()
            l1.set_data(ys_i.flatten().numpy(), torch.zeros(n).numpy())
            l2.set_data(xs.numpy(), fn(xs, torch.ones(n), ys_i.reshape(n), s2).numpy())
        return l1, l2

    return matplotlib.animation.FuncAnimation(fig, animate, frames=len(traj), interval=1)

In [10]:
def kill(ys, cs, kill_inds):
    """Remove the particles at ``kill_inds`` from ``ys`` and ``cs``.

    Parameters
    ----------
    ys, cs : tensors whose first dimension indexes particles.
    kill_inds : iterable of integer indices (Python ints, numpy integers or
        0-d tensors). Duplicates are harmless and indices outside
        ``range(ys.shape[0])`` are ignored.

    Returns
    -------
    (ys, cs) restricted to the surviving particles, in their original order.
    """
    # Normalise to Python ints: 0-d torch tensors hash by identity, so putting
    # them directly in a set would silently remove nothing.
    dead = {int(k) for k in kill_inds}
    # Build survivors in explicit ascending order rather than relying on set
    # iteration order.
    survivor_inds = torch.tensor(
        [i for i in range(ys.shape[0]) if i not in dead], dtype=torch.long
    )
    ys = torch.index_select(ys, 0, survivor_inds)
    cs = torch.index_select(cs, 0, survivor_inds)
    return ys, cs

def run_traj_scheme1(n, n_steps, alpha=1, ys=None, cs=None, dt=5e-2, quench_steps=1000):
    """Gradient descent with birth/death; the population is refilled by cloning.

    For the first ``n_steps - quench_steps`` steps, after each gradient step
    every particle j with rate ``alpha * cU[j]``, where
    ``cU = cs * (F(ys) - K(ys, ys) @ cs / n)``, is
      * killed with probability ``-alpha*cU[j]`` if that rate is negative,
      * otherwise cloned (duplicated) with probability ``alpha*cU[j]``.
    The population is then brought back to exactly ``n``: a deficit is filled
    by cloning uniformly chosen survivors, and an excess is removed by killing
    uniformly chosen particles without replacement. The final ``quench_steps``
    steps are plain gradient descent.

    Parameters
    ----------
    n : population size (kept fixed).
    n_steps : total number of steps / recorded frames.
    alpha : birth/death rate multiplier.
    ys, cs : optional ``[n, 1]`` initial positions / charges; by default
        positions are uniform on [-2, 2) and charges standard normal.
    dt : gradient step size.
    quench_steps : number of final steps without birth/death.

    Returns
    -------
    traj : ``[n_steps, 2, n]`` tensor of positions (``traj[i, 0]``) and charges
        (``traj[i, 1]``) recorded at the start of each step.
    """
    if ys is None:
        ys = (4*torch.rand([n, 1]) - 2)
    if cs is None:
        cs = torch.randn(ys.shape)
    traj = torch.zeros([n_steps, 2, n])
    # Clamp so quench_steps > n_steps means "pure GD" instead of negative frame indices.
    n_birth_death = max(n_steps - quench_steps, 0)

    for step in range(n_birth_death):
        traj[step, 0, :] = ys.flatten()
        traj[step, 1, :] = cs.flatten()
        ys, cs = step_charges(ys, cs, n, dt=dt)
        cU = cs*(F(ys) - torch.mm(K(ys, ys.t()), cs)/n)

        kill_inds = []
        n_cloned = 0
        # Iterate over the pre-cloning population only; clones are appended at
        # the end, so indices < current size stay valid for kill().
        for j in range(ys.shape[0]):
            acUj = alpha * cU[j, 0]
            if acUj < 0:
                if np.random.rand() < -acUj:
                    kill_inds.append(j)
            elif np.random.rand() < acUj:
                ys = torch.cat((ys, ys[j, :].reshape(1, 1)))
                cs = torch.cat((cs, cs[j, :].reshape(1, 1)))
                n_cloned += 1

        # now kill the particles on the kill list
        ys, cs = kill(ys, cs, kill_inds)

        # and resample to control population
        delta_n = n_cloned - len(kill_inds)
        if delta_n < 0:
            clone_inds = torch.randint(ys.shape[0], (-delta_n,), dtype=torch.long)
            ys = torch.cat((ys, torch.index_select(ys, 0, clone_inds)))
            cs = torch.cat((cs, torch.index_select(cs, 0, clone_inds)))
        elif delta_n > 0:
            # Sample WITHOUT replacement over the whole population: randint with
            # replacement could pick duplicates, remove fewer than delta_n and
            # make the population drift above n (breaking the traj writes).
            excess_inds = torch.randperm(ys.shape[0])[:delta_n].tolist()
            ys, cs = kill(ys, cs, excess_inds)
        assert ys.shape[0] == n, f"population drifted to {ys.shape[0]} (expected {n})"

    for step in range(n_birth_death, n_steps):
        traj[step, 0, :] = ys.flatten()
        traj[step, 1, :] = cs.flatten()
        ys, cs = step_charges(ys, cs, n, dt=dt)

    return traj

In [11]:
def run_traj_scheme2(n, n_steps, alpha=1, ys=None, cs=None, dt=5e-2, quench_steps=1000):
    """Gradient descent with birth/death; the population is refilled with fresh particles.

    Same kill/clone rule as the cloning scheme: particle j, with
    ``cU = cs * (F(ys) - K(ys, ys) @ cs / n)``, is killed with probability
    ``-alpha*cU[j]`` if that rate is negative, otherwise cloned with
    probability ``alpha*cU[j]``. When more particles die than are born, the
    deficit is filled with NEW particles at positions drawn from a standard
    normal with zero charge (the birth distribution labelled mu_b in the
    plots). An excess is removed by killing uniformly chosen particles without
    replacement. The final ``quench_steps`` steps are plain gradient descent.

    Parameters
    ----------
    n : population size (kept fixed).
    n_steps : total number of steps / recorded frames.
    alpha : birth/death rate multiplier.
    ys, cs : optional ``[n, 1]`` initial positions / charges; by default
        positions are uniform on [-2, 2) and charges standard normal.
    dt : gradient step size.
    quench_steps : number of final steps without birth/death.

    Returns
    -------
    traj : ``[n_steps, 2, n]`` tensor of positions (``traj[i, 0]``) and charges
        (``traj[i, 1]``) recorded at the start of each step.
    """
    if ys is None:
        ys = (4*torch.rand([n, 1]) - 2)
    if cs is None:
        cs = torch.randn(ys.shape)
    traj = torch.zeros([n_steps, 2, n])
    # Clamp so quench_steps > n_steps means "pure GD" instead of negative frame indices.
    n_birth_death = max(n_steps - quench_steps, 0)

    for step in range(n_birth_death):
        traj[step, 0, :] = ys.flatten()
        traj[step, 1, :] = cs.flatten()
        ys, cs = step_charges(ys, cs, n, dt=dt)
        cU = cs*(F(ys) - torch.mm(K(ys, ys.t()), cs)/n)

        kill_inds = []
        n_cloned = 0
        # Iterate over the pre-cloning population only; clones are appended at
        # the end, so indices < current size stay valid for kill().
        for j in range(ys.shape[0]):
            acUj = alpha * cU[j, 0]
            if acUj < 0:
                if np.random.rand() < -acUj:
                    kill_inds.append(j)
            elif np.random.rand() < acUj:
                ys = torch.cat((ys, ys[j, :].reshape(1, 1)))
                cs = torch.cat((cs, cs[j, :].reshape(1, 1)))
                n_cloned += 1

        # now kill the particles on the kill list
        ys, cs = kill(ys, cs, kill_inds)

        # and resample to control population
        delta_n = n_cloned - len(kill_inds)
        if delta_n < 0:
            n_new = -delta_n
            ys = torch.cat((ys, torch.randn(n_new, ys.shape[1])))
            cs = torch.cat((cs, torch.zeros(n_new, cs.shape[1])))
        elif delta_n > 0:
            # Sample WITHOUT replacement over the whole population: randint with
            # replacement could pick duplicates, remove fewer than delta_n and
            # make the population drift above n (breaking the traj writes).
            excess_inds = torch.randperm(ys.shape[0])[:delta_n].tolist()
            ys, cs = kill(ys, cs, excess_inds)
        assert ys.shape[0] == n, f"population drifted to {ys.shape[0]} (expected {n})"

    for step in range(n_birth_death, n_steps):
        traj[step, 0, :] = ys.flatten()
        traj[step, 1, :] = cs.flatten()
        ys, cs = step_charges(ys, cs, n, dt=dt)

    return traj

In [12]:
# Train all three schemes from the default uniform initialisation on [-2, 2).
n, n_steps = 200, 5000
traj_gd = run_traj(n, n_steps, dt=1e-2)
traj_s1 = run_traj_scheme1(n, n_steps, dt=1e-2, alpha=0.05)
traj_s2 = run_traj_scheme2(n, n_steps, dt=1e-2, alpha=0.05)

In [13]:
# Constant term of the energy: Cf = (1/2) * ||f||^2 for the target
# f = fn(., c_fs, y_fs, s2_fs). The overlap integral of two Gaussians is a
# Gaussian in the difference of their centres with summed variance, so ||f||^2
# is an exact double sum over mixture components. This replaces the truncated
# literal 0.206/3 (~0.06867 vs ~0.06887), with which E could dip below zero
# near convergence.
_pair_var = s2_fs[:, None] + s2_fs[None, :]
_pair_overlap = phi(y_fs[:, None], y_fs[None, :], _pair_var)
Cf = 0.5 * float(torch.sum(c_fs[:, None] * c_fs[None, :] * _pair_overlap)) / c_fs.shape[0] ** 2


def err(traj):
    """Energy E(c, y) = Cf - <f, f_n> + (1/2)||f_n||^2 at every frame of ``traj``.

    ``traj`` is indexed ``[frame]`` and each frame unpacks into
    ``(positions, charges)``: a ``[n_steps, 2, n]`` tensor or a sequence of
    ``(ys, cs)`` pairs. The particle count is taken from each frame rather
    than from a global ``n``.

    Returns a NumPy array with one energy per frame.
    """
    energies = []
    for ys, cs in traj:
        n_particles = ys.numel()
        ys = ys.reshape(n_particles, 1)
        cs = cs.reshape(n_particles, 1)
        energy = Cf - cs*F(ys) + 0.5*cs*torch.mm(K(ys, ys.t()), cs)/n_particles
        energies.append(torch.mean(energy).numpy())
    return np.array(energies)

In [14]:
fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(8, 4))

# (trajectory, legend label, marker for the fit panel, colour) for each scheme.
uniform_init_runs = [
    (traj_gd, r"\textrm{GD}", "x", colors[0]),
    (traj_s1, r"\textrm{GD+birth/death}", "o", colors[1]),
    (traj_s2, r"\textrm{GD+birth/death with } $\mu_{\textrm{b}}$", "^", colors[3]),
]

# Left panel: energy along each trajectory (err already returns a NumPy array).
ax1.set_ylabel(r"$E(\boldsymbol{c},\boldsymbol{y})$", fontsize=18)
ax1.set_xlabel(r"$\textrm{training time}$", fontsize=18)
for trajectory, label, _, color in uniform_init_runs:
    ax1.plot(err(trajectory), label=label, color=color)
ax1.legend(fontsize=12, loc=0)

# Right panel: target function vs the network output at the final frame of each run.
ax2.set_xlabel(r"$y$", fontsize=12)
ax2.set_xlim(-5, 5)
ax2.set_ylim(-0.75, 0.75)
xs = torch.linspace(-5, 5, 200).reshape(200, 1)
ax2.plot(xs.numpy(), f(xs).numpy(), "k", label=r"$f_{\textrm{target}}(x)$")
n = traj_gd[0][0].shape[0]
xs = torch.linspace(-5, 5, 30).reshape(30, 1)
for trajectory, _, marker, color in uniform_init_runs:
    final_ys, final_cs = trajectory[-1][0], trajectory[-1][1]
    ax2.plot(xs.numpy(), fn(xs, final_cs.reshape(n), final_ys.reshape(n), s2).numpy(), marker, color=color)
ax2.legend()

fig.savefig("gaussian_convergence_full_supp.pdf")

<IPython.core.display.Javascript object>

In [15]:
ys = -2 + 0.1 * torch.randn([n, 1])  # bad initialisation: all positions clustered near y = -2

In [16]:
# Re-train every scheme from the same clustered initial positions.
n, n_steps = 200, 10000
traj_gd_bi = run_traj(n, n_steps, dt=1e-2, ys=ys)
traj_s1_bi = run_traj_scheme1(n, n_steps, alpha=0.05, dt=1e-2, ys=ys)
traj_s2_bi = run_traj_scheme2(n, n_steps, alpha=0.05, dt=1e-2, ys=ys)

In [17]:
fig_bi, [ax1, ax3] = plt.subplots(1, 2, figsize=(8, 4))

# (trajectory, legend label, colour) for each scheme started from the clustered init.
bad_init_runs = [
    (traj_gd_bi, r"\textrm{GD}", colors[0]),
    (traj_s1_bi, r"\textrm{GD+birth/death}", colors[1]),
    (traj_s2_bi, r"\textrm{GD+birth/death with } $\mu_{\textrm{b}}$", colors[3]),
]

# Left panel: energy along each trajectory (err already returns a NumPy array).
ax1.set_ylabel(r"$E(\boldsymbol{c},\boldsymbol{y})$", fontsize=18)
ax1.set_xlabel(r"$\textrm{training time}$", fontsize=18)
for trajectory, label, color in bad_init_runs:
    ax1.plot(err(trajectory), label=label, color=color)
ax1.legend(fontsize=12)

# Right panel: final particle configuration in the (position, charge) plane.
ax3.set_xlabel(r"$y$", fontsize=18)
ax3.set_ylabel(r"$c$", fontsize=18)
ax3.set_xlim(-5, 5)
for trajectory, _, color in bad_init_runs:
    final_ys, final_cs = trajectory[-1][0], trajectory[-1][1]
    ax3.plot(final_ys.flatten().numpy(), final_cs.flatten().numpy(), ".", color=color)

fig_bi.savefig("gaussian_bad_init.pdf")

<IPython.core.display.Javascript object>