Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 22, 2020
1 parent ef4339d commit 3cd0c0c
Showing 1 changed file with 117 additions and 105 deletions.
222 changes: 117 additions & 105 deletions docs/examples/notebooks/learn/Hessians.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,15 @@
"import matplotlib.pyplot as plt\n",
"import matplotlib.patches as patches\n",
"\n",
"def to_bounded(n,bounds):\n",
" a,b = bounds\n",
" return a+0.5*(b-a)*(jnp.sin(n) + 1)\n",
"\n",
"def to_inf(x,bounds):\n",
" a,b = bounds\n",
" return jnp.arcsin(2*(x-a)/(b-a)-1)"
"def to_bounded(n, bounds):\n",
" a, b = bounds\n",
" return a + 0.5 * (b - a) * (jnp.sin(n) + 1)\n",
"\n",
"\n",
"def to_inf(x, bounds):\n",
" a, b = bounds\n",
" return jnp.arcsin(2 * (x - a) / (b - a) - 1)"
]
},
{
Expand All @@ -85,27 +87,29 @@
],
"source": [
"def plot_trfs():\n",
" bounds = [0,5]\n",
" bounds = [0, 5]\n",
"\n",
" f,axarr = plt.subplots(2,1)\n",
" f, axarr = plt.subplots(2, 1)\n",
"\n",
" x = jnp.linspace(bounds[0],bounds[1],1001)\n",
" n = jax.vmap(to_inf,in_axes=(0,None))(x,bounds)\n",
" x = jnp.linspace(bounds[0], bounds[1], 1001)\n",
" n = jax.vmap(to_inf, in_axes=(0, None))(x, bounds)\n",
" ax = axarr[0]\n",
" ax.plot(x,n)\n",
" ax.plot(x, n)\n",
" ax.set_xlabel('x')\n",
" ax.set_ylabel('n')\n",
" ax.set_title(r'$x \\to n$')\n",
"\n",
" n = jnp.linspace(0,10,1001)\n",
" x = jax.vmap(to_bounded,in_axes=(0,None))(n,bounds)\n",
" n = jnp.linspace(0, 10, 1001)\n",
" x = jax.vmap(to_bounded, in_axes=(0, None))(n, bounds)\n",
"\n",
" ax = axarr[1]\n",
" ax.plot(n,x)\n",
" ax.plot(n, x)\n",
" ax.set_xlabel('n')\n",
" ax.set_ylabel('x')\n",
" ax.set_title(r'$n \\to x$')\n",
" f.set_tight_layout(True)\n",
"\n",
"\n",
"plot_trfs()"
]
},
Expand All @@ -116,60 +120,69 @@
"outputs": [],
"source": [
"def func(external_pars):\n",
" x,y = external_pars\n",
" z = (x-2)**2 + (y-2)**2\n",
" x, y = external_pars\n",
" z = (x - 2) ** 2 + (y - 2) ** 2\n",
" return z\n",
"\n",
"def internal_func(internal_pars,bounds):\n",
" external_pars = jax.vmap(to_bounded)(internal_pars,bounds)\n",
"\n",
"def internal_func(internal_pars, bounds):\n",
" external_pars = jax.vmap(to_bounded)(internal_pars, bounds)\n",
" return func(external_pars)\n",
"\n",
"def plot_func(ax,func,slices,bounds = None):\n",
" grid = x,y = np.mgrid[slices[0],slices[1]]\n",
" X = jnp.swapaxes(grid,0,-1).reshape(-1,2)\n",
"\n",
"def plot_func(ax, func, slices, bounds=None):\n",
" grid = x, y = np.mgrid[slices[0], slices[1]]\n",
" X = jnp.swapaxes(grid, 0, -1).reshape(-1, 2)\n",
"\n",
" if bounds is not None:\n",
" Z = jax.vmap(func,in_axes=(0,None))(X,bounds)\n",
" Z = jax.vmap(func, in_axes=(0, None))(X, bounds)\n",
" else:\n",
" Z = jax.vmap(func)(X)\n",
" z = jnp.swapaxes(Z.reshape(101,101),0,-1)\n",
" ax.contourf(x,y,z,levels = 100)\n",
" ax.contour(x,y,z,levels = 10, colors = 'w')\n",
" z = jnp.swapaxes(Z.reshape(101, 101), 0, -1)\n",
" ax.contourf(x, y, z, levels=100)\n",
" ax.contour(x, y, z, levels=10, colors='w')\n",
" ax.set_xlabel(r'$n_1$')\n",
" ax.set_xlabel(r'$n_2$')\n",
" if bounds is not None:\n",
" rect = patches.Rectangle([-np.pi/2,-np.pi/2],np.pi,np.pi, alpha = 0.2, facecolor = 'k')\n",
" rect = patches.Rectangle(\n",
" [-np.pi / 2, -np.pi / 2], np.pi, np.pi, alpha=0.2, facecolor='k'\n",
" )\n",
" ax.add_patch(rect)\n",
" f.set_size_inches(5,5)\n",
" f.set_size_inches(5, 5)\n",
"\n",
"\n",
"def angle_and_lam(M):\n",
" lam,bases = jnp.linalg.eig(M)\n",
" angle = -jnp.arccos(bases[0,0])*180/np.pi\n",
" return lam,angle\n",
" lam, bases = jnp.linalg.eig(M)\n",
" angle = -jnp.arccos(bases[0, 0]) * 180 / np.pi\n",
" return lam, angle\n",
"\n",
"def draw_covariances(ax,func,slices,bounds = None,scale = 1):\n",
" grid = x,y = np.mgrid[slices[0],slices[1]]\n",
" X = np.swapaxes(grid,0,-1).reshape(-1,2)\n",
"\n",
"def draw_covariances(ax, func, slices, bounds=None, scale=1):\n",
" grid = x, y = np.mgrid[slices[0], slices[1]]\n",
" X = np.swapaxes(grid, 0, -1).reshape(-1, 2)\n",
"\n",
" if bounds is not None:\n",
" covariance = lambda X,bounds: jnp.linalg.inv(jax.hessian(func)(X,bounds))\n",
" args = (X,bounds)\n",
" axes = (0,None)\n",
" covariance = lambda X, bounds: jnp.linalg.inv(jax.hessian(func)(X, bounds))\n",
" args = (X, bounds)\n",
" axes = (0, None)\n",
" else:\n",
" covariance = lambda X: jnp.linalg.inv(jax.hessian(func)(X))\n",
" args = (X,)\n",
" axes = (0,)\n",
" lams,angles = jax.vmap(angle_and_lam)(jax.vmap(covariance,in_axes=axes)(*args))\n",
" for i,(lam,angle) in enumerate(zip(lams,angles)):\n",
" lams, angles = jax.vmap(angle_and_lam)(jax.vmap(covariance, in_axes=axes)(*args))\n",
" for i, (lam, angle) in enumerate(zip(lams, angles)):\n",
" e = patches.Ellipse(\n",
" X[i],lam[0]*scale,lam[1]*scale,angle,\n",
" alpha = 0.5,\n",
" facecolor = 'none',\n",
" edgecolor = 'k'\n",
" X[i],\n",
" lam[0] * scale,\n",
" lam[1] * scale,\n",
" angle,\n",
" alpha=0.5,\n",
" facecolor='none',\n",
" edgecolor='k',\n",
" )\n",
" ax.add_patch(e)\n",
" ax.set_xlim(slices[0].start,slices[0].stop)\n",
" ax.set_ylim(slices[0].start,slices[0].stop) \n"
" ax.set_xlim(slices[0].start, slices[0].stop)\n",
" ax.set_ylim(slices[0].start, slices[0].stop)"
]
},
{
Expand All @@ -190,20 +203,14 @@
}
],
"source": [
"f,ax = plt.subplots(1,1)\n",
"f.set_size_inches(5,5)\n",
"f, ax = plt.subplots(1, 1)\n",
"f.set_size_inches(5, 5)\n",
"f.set_tight_layout(True)\n",
"plot_func(ax, func, slices = [\n",
" slice(-5,5,101*1j),\n",
" slice(-5,5,101*1j)\n",
" ])\n",
"\n",
"draw_covariances(ax,func,slices = [\n",
" slice(-5,5,10*1j),\n",
" slice(-5,5,10*1j)\n",
" ],\n",
" scale = 2\n",
") "
"plot_func(ax, func, slices=[slice(-5, 5, 101 * 1j), slice(-5, 5, 101 * 1j)])\n",
"\n",
"draw_covariances(\n",
" ax, func, slices=[slice(-5, 5, 10 * 1j), slice(-5, 5, 10 * 1j)], scale=2\n",
")"
]
},
{
Expand All @@ -224,21 +231,23 @@
}
],
"source": [
"f,ax = plt.subplots(1,1)\n",
"f.set_size_inches(5,5)\n",
"\n",
"plot_func(ax,internal_func,slices = [\n",
" slice(-np.pi,np.pi,101*1j),\n",
" slice(-np.pi,np.pi,101*1j)\n",
" ],bounds = bounds)\n",
"bounds = jnp.array([[-5,5],[-5,5]])\n",
"draw_covariances(ax,internal_func,slices = [\n",
" slice(-np.pi,np.pi,10*1j),\n",
" slice(-np.pi,np.pi,10*1j)\n",
" ],\n",
" bounds = bounds,\n",
" scale = 2\n",
")\n"
"f, ax = plt.subplots(1, 1)\n",
"f.set_size_inches(5, 5)\n",
"\n",
"plot_func(\n",
" ax,\n",
" internal_func,\n",
" slices=[slice(-np.pi, np.pi, 101 * 1j), slice(-np.pi, np.pi, 101 * 1j)],\n",
" bounds=bounds,\n",
")\n",
"bounds = jnp.array([[-5, 5], [-5, 5]])\n",
"draw_covariances(\n",
" ax,\n",
" internal_func,\n",
" slices=[slice(-np.pi, np.pi, 10 * 1j), slice(-np.pi, np.pi, 10 * 1j)],\n",
" bounds=bounds,\n",
" scale=2,\n",
")"
]
},
{
Expand Down Expand Up @@ -345,8 +354,8 @@
"outputs": [],
"source": [
"def grads_from_n(n):\n",
" x = jax.vmap(to_bounded)(n,bounds)\n",
" J = jax.jacfwd(jax.vmap(to_inf))(x,bounds)\n",
" x = jax.vmap(to_bounded)(n, bounds)\n",
" J = jax.jacfwd(jax.vmap(to_inf))(x, bounds)\n",
" return J"
]
},
Expand All @@ -356,20 +365,20 @@
"metadata": {},
"outputs": [],
"source": [
"def hessian_transform(extr,bounds):\n",
" intr = jax.vmap(to_inf)(extr,bounds)\n",
"def hessian_transform(extr, bounds):\n",
" intr = jax.vmap(to_inf)(extr, bounds)\n",
"\n",
" first = jax.jacfwd(jax.vmap(to_inf))(extr,bounds) \n",
" first = jax.jacfwd(jax.vmap(to_inf))(extr, bounds)\n",
" secnd = jax.jacfwd(grads_from_n)(intr)\n",
" third = jax.grad(internal_func)(intr,bounds)\n",
" third = jax.grad(internal_func)(intr, bounds)\n",
"\n",
" J = jax.jacfwd(jax.vmap(to_inf))(extr,bounds)\n",
" J = jax.jacfwd(jax.vmap(to_inf))(extr, bounds)\n",
"\n",
" a = jnp.einsum('ik,kjl,l->ij',first,secnd,third)\n",
" a = jnp.einsum('ik,kjl,l->ij', first, secnd, third)\n",
"\n",
" int_hessian = jax.hessian(internal_func)(intr,bounds)\n",
" b = jnp.einsum('ik,jl,kl->ij',J,J,int_hessian)\n",
" return int_hessian,a,b,a+b"
" int_hessian = jax.hessian(internal_func)(intr, bounds)\n",
" b = jnp.einsum('ik,jl,kl->ij', J, J, int_hessian)\n",
" return int_hessian, a, b, a + b"
]
},
{
Expand All @@ -378,28 +387,31 @@
"metadata": {},
"outputs": [],
"source": [
"def compare(ax,scale = 2, index = -1, color = 'k'):\n",
" slices = [slice(-5,5,11j),slice(-5,5,11j)]\n",
"def compare(ax, scale=2, index=-1, color='k'):\n",
" slices = [slice(-5, 5, 11j), slice(-5, 5, 11j)]\n",
"\n",
" grid = x,y = np.mgrid[slices[0],slices[1]]\n",
" X = np.swapaxes(grid,0,-1).reshape(-1,2)\n",
" covariance = lambda X,bounds: hessian_transform(X,bounds)[index]\n",
" args = (X,bounds)\n",
" axes = (0,None)\n",
" grid = x, y = np.mgrid[slices[0], slices[1]]\n",
" X = np.swapaxes(grid, 0, -1).reshape(-1, 2)\n",
" covariance = lambda X, bounds: hessian_transform(X, bounds)[index]\n",
" args = (X, bounds)\n",
" axes = (0, None)\n",
"\n",
" covariances = jax.vmap(covariance,in_axes=axes)(*args)\n",
" covariances = jax.vmap(covariance, in_axes=axes)(*args)\n",
"\n",
" lams,angles = jax.vmap(angle_and_lam)(covariances)\n",
" for i,(lam,angle) in enumerate(zip(lams,angles)):\n",
" lams, angles = jax.vmap(angle_and_lam)(covariances)\n",
" for i, (lam, angle) in enumerate(zip(lams, angles)):\n",
" e = patches.Ellipse(\n",
" X[i],lam[0]*scale,lam[1]*scale,angle,\n",
" alpha = 0.5,\n",
" facecolor = 'none',\n",
" edgecolor = color\n",
" X[i],\n",
" lam[0] * scale,\n",
" lam[1] * scale,\n",
" angle,\n",
" alpha=0.5,\n",
" facecolor='none',\n",
" edgecolor=color,\n",
" )\n",
" ax.add_patch(e)\n",
" ax.set_xlim(slices[0].start,slices[0].stop)\n",
" ax.set_ylim(slices[0].start,slices[0].stop) "
" ax.set_xlim(slices[0].start, slices[0].stop)\n",
" ax.set_ylim(slices[0].start, slices[0].stop)"
]
},
{
Expand All @@ -420,11 +432,11 @@
}
],
"source": [
"f,ax = plt.subplots(1,1)\n",
"compare(ax,scale = 0.2, index = -1)\n",
"compare(ax,scale = 0.2, index = -2, color = 'r')\n",
"compare(ax,scale = 0.2, index = -3, color = 'b')\n",
"plt.gcf().set_size_inches(5,5)"
"f, ax = plt.subplots(1, 1)\n",
"compare(ax, scale=0.2, index=-1)\n",
"compare(ax, scale=0.2, index=-2, color='r')\n",
"compare(ax, scale=0.2, index=-3, color='b')\n",
"plt.gcf().set_size_inches(5, 5)"
]
},
{
Expand All @@ -435,4 +447,4 @@
"source": []
}
]
}
}

0 comments on commit 3cd0c0c

Please sign in to comment.