Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 22, 2020
1 parent f061a3f commit 5ef0781
Showing 1 changed file with 129 additions and 118 deletions.
247 changes: 129 additions & 118 deletions docs/examples/notebooks/learn/Hessians.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,15 @@
"import matplotlib.pyplot as plt\n",
"import matplotlib.patches as patches\n",
"\n",
"def to_bounded(n,bounds):\n",
" a,b = bounds\n",
" return a+0.5*(b-a)*(jnp.sin(n) + 1)\n",
"\n",
"def to_inf(x,bounds):\n",
" a,b = bounds\n",
" return jnp.arcsin(2*(x-a)/(b-a)-1)"
"def to_bounded(n, bounds):\n",
" a, b = bounds\n",
" return a + 0.5 * (b - a) * (jnp.sin(n) + 1)\n",
"\n",
"\n",
"def to_inf(x, bounds):\n",
" a, b = bounds\n",
" return jnp.arcsin(2 * (x - a) / (b - a) - 1)"
]
},
{
Expand All @@ -85,27 +87,29 @@
],
"source": [
"def plot_trfs():\n",
" bounds = [0,5]\n",
" bounds = [0, 5]\n",
"\n",
" f,axarr = plt.subplots(2,1)\n",
" f, axarr = plt.subplots(2, 1)\n",
"\n",
" x = jnp.linspace(bounds[0],bounds[1],1001)\n",
" n = jax.vmap(to_inf,in_axes=(0,None))(x,bounds)\n",
" x = jnp.linspace(bounds[0], bounds[1], 1001)\n",
" n = jax.vmap(to_inf, in_axes=(0, None))(x, bounds)\n",
" ax = axarr[0]\n",
" ax.plot(x,n)\n",
" ax.plot(x, n)\n",
" ax.set_xlabel('x')\n",
" ax.set_ylabel('n')\n",
" ax.set_title(r'$x \\to n$')\n",
"\n",
" n = jnp.linspace(0,10,1001)\n",
" x = jax.vmap(to_bounded,in_axes=(0,None))(n,bounds)\n",
" n = jnp.linspace(0, 10, 1001)\n",
" x = jax.vmap(to_bounded, in_axes=(0, None))(n, bounds)\n",
"\n",
" ax = axarr[1]\n",
" ax.plot(n,x)\n",
" ax.plot(n, x)\n",
" ax.set_xlabel('n')\n",
" ax.set_ylabel('x')\n",
" ax.set_title(r'$n \\to x$')\n",
" f.set_tight_layout(True)\n",
"\n",
"\n",
"plot_trfs()"
]
},
Expand All @@ -116,69 +120,76 @@
"outputs": [],
"source": [
"def func(external_pars):\n",
" x,y = external_pars\n",
" x, y = external_pars\n",
" # a,b = 2*x+y,x-y\n",
" a,b = x,y\n",
" ca,cb = 1,1\n",
" z = (a-ca)**2 + (b-cb)**2\n",
" a, b = x, y\n",
" ca, cb = 1, 1\n",
" z = (a - ca) ** 2 + (b - cb) ** 2\n",
" return z\n",
" \n",
" \n",
"def internal_func(internal_pars,bounds):\n",
" external_pars = jax.vmap(to_bounded)(internal_pars,bounds)\n",
"\n",
"\n",
"def internal_func(internal_pars, bounds):\n",
" external_pars = jax.vmap(to_bounded)(internal_pars, bounds)\n",
" return func(external_pars)\n",
"\n",
"bounds = jnp.array([[-5,5],[-5,5]])\n",
"\n",
"bounds = jnp.array([[-5, 5], [-5, 5]])\n",
"\n",
"\n",
"def plot_func(ax,func,slices,bounds = None):\n",
" grid = x,y = np.mgrid[slices[0],slices[1]]\n",
" X = jnp.swapaxes(grid,0,-1).reshape(-1,2)\n",
"def plot_func(ax, func, slices, bounds=None):\n",
" grid = x, y = np.mgrid[slices[0], slices[1]]\n",
" X = jnp.swapaxes(grid, 0, -1).reshape(-1, 2)\n",
"\n",
" if bounds is not None:\n",
" Z = jax.vmap(func,in_axes=(0,None))(X,bounds)\n",
" Z = jax.vmap(func, in_axes=(0, None))(X, bounds)\n",
" else:\n",
" Z = jax.vmap(func)(X)\n",
" z = jnp.swapaxes(Z.reshape(101,101),0,-1)\n",
" ax.contourf(x,y,z,levels = 100)\n",
" ax.contour(x,y,z,levels = 10, colors = 'w')\n",
" z = jnp.swapaxes(Z.reshape(101, 101), 0, -1)\n",
" ax.contourf(x, y, z, levels=100)\n",
" ax.contour(x, y, z, levels=10, colors='w')\n",
" ax.set_xlabel(r'$n_1$')\n",
" ax.set_xlabel(r'$n_2$')\n",
" if bounds is not None:\n",
" rect = patches.Rectangle([-np.pi/2,-np.pi/2],np.pi,np.pi, alpha = 0.2, facecolor = 'k')\n",
" rect = patches.Rectangle(\n",
" [-np.pi / 2, -np.pi / 2], np.pi, np.pi, alpha=0.2, facecolor='k'\n",
" )\n",
" ax.add_patch(rect)\n",
"\n",
"\n",
"def angle_and_lam(M):\n",
" lam,bases = jnp.linalg.eig(M)\n",
" first = bases[:,0]\n",
" sign = jnp.sign(first[2])\n",
" angle = jnp.arccos(first[0])*180/np.pi\n",
" return lam,sign*angle\n",
" lam, bases = jnp.linalg.eig(M)\n",
" first = bases[:, 0]\n",
" sign = jnp.sign(first[2])\n",
" angle = jnp.arccos(first[0]) * 180 / np.pi\n",
" return lam, sign * angle\n",
"\n",
"\n",
"def draw_covariances(ax,func,slices,bounds = None,scale = 1):\n",
" grid = x,y = np.mgrid[slices[0],slices[1]]\n",
" X = np.swapaxes(grid,0,-1).reshape(-1,2)\n",
"def draw_covariances(ax, func, slices, bounds=None, scale=1):\n",
" grid = x, y = np.mgrid[slices[0], slices[1]]\n",
" X = np.swapaxes(grid, 0, -1).reshape(-1, 2)\n",
"\n",
" if bounds is not None:\n",
" covariance = lambda X,bounds: jnp.linalg.inv(jax.hessian(func)(X,bounds))\n",
" args = (X,bounds)\n",
" axes = (0,None)\n",
" covariance = lambda X, bounds: jnp.linalg.inv(jax.hessian(func)(X, bounds))\n",
" args = (X, bounds)\n",
" axes = (0, None)\n",
" else:\n",
" covariance = lambda X: jnp.linalg.inv(jax.hessian(func)(X))\n",
" args = (X,)\n",
" axes = (0,)\n",
" lams,angles = jax.vmap(angle_and_lam)(jax.vmap(covariance,in_axes=axes)(*args))\n",
" for i,(lam,angle) in enumerate(zip(lams,angles)):\n",
" lams, angles = jax.vmap(angle_and_lam)(jax.vmap(covariance, in_axes=axes)(*args))\n",
" for i, (lam, angle) in enumerate(zip(lams, angles)):\n",
" e = patches.Ellipse(\n",
" X[i],np.sqrt(lam[0])*scale,np.sqrt(lam[1])*scale,angle,\n",
" alpha = 0.5,\n",
" facecolor = 'none',\n",
" edgecolor = 'k'\n",
" X[i],\n",
" np.sqrt(lam[0]) * scale,\n",
" np.sqrt(lam[1]) * scale,\n",
" angle,\n",
" alpha=0.5,\n",
" facecolor='none',\n",
" edgecolor='k',\n",
" )\n",
" ax.add_patch(e)\n",
" ax.set_xlim(slices[0].start,slices[0].stop)\n",
" ax.set_ylim(slices[0].start,slices[0].stop) \n"
" ax.set_xlim(slices[0].start, slices[0].stop)\n",
" ax.set_ylim(slices[0].start, slices[0].stop)"
]
},
{
Expand All @@ -199,20 +210,14 @@
}
],
"source": [
"f,ax = plt.subplots(1,1)\n",
"f.set_size_inches(5,5)\n",
"f, ax = plt.subplots(1, 1)\n",
"f.set_size_inches(5, 5)\n",
"f.set_tight_layout(True)\n",
"plot_func(ax, func, slices = [\n",
" slice(-5,5,101*1j),\n",
" slice(-5,5,101*1j)\n",
" ])\n",
"\n",
"draw_covariances(ax,func,slices = [\n",
" slice(-5,5,10*1j),\n",
" slice(-5,5,10*1j)\n",
" ],\n",
" scale = 1\n",
") \n"
"plot_func(ax, func, slices=[slice(-5, 5, 101 * 1j), slice(-5, 5, 101 * 1j)])\n",
"\n",
"draw_covariances(\n",
" ax, func, slices=[slice(-5, 5, 10 * 1j), slice(-5, 5, 10 * 1j)], scale=1\n",
")"
]
},
{
Expand All @@ -233,21 +238,23 @@
}
],
"source": [
"f,ax = plt.subplots(1,1)\n",
"f.set_size_inches(5,5)\n",
"\n",
"plot_func(ax,internal_func,slices = [\n",
" slice(-np.pi,np.pi,101*1j),\n",
" slice(-np.pi,np.pi,101*1j)\n",
" ],bounds = bounds)\n",
"bounds = jnp.array([[-5,5],[-5,5]])\n",
"draw_covariances(ax,internal_func,slices = [\n",
" slice(-np.pi,np.pi,10*1j),\n",
" slice(-np.pi,np.pi,10*1j)\n",
" ],\n",
" bounds = bounds,\n",
" scale = 1\n",
")\n"
"f, ax = plt.subplots(1, 1)\n",
"f.set_size_inches(5, 5)\n",
"\n",
"plot_func(\n",
" ax,\n",
" internal_func,\n",
" slices=[slice(-np.pi, np.pi, 101 * 1j), slice(-np.pi, np.pi, 101 * 1j)],\n",
" bounds=bounds,\n",
")\n",
"bounds = jnp.array([[-5, 5], [-5, 5]])\n",
"draw_covariances(\n",
" ax,\n",
" internal_func,\n",
" slices=[slice(-np.pi, np.pi, 10 * 1j), slice(-np.pi, np.pi, 10 * 1j)],\n",
" bounds=bounds,\n",
" scale=1,\n",
")"
]
},
{
Expand Down Expand Up @@ -312,8 +319,8 @@
"outputs": [],
"source": [
"def grads_from_n(n):\n",
" x = jax.vmap(to_bounded)(n,bounds)\n",
" J = jax.jacfwd(jax.vmap(to_inf))(x,bounds)\n",
" x = jax.vmap(to_bounded)(n, bounds)\n",
" J = jax.jacfwd(jax.vmap(to_inf))(x, bounds)\n",
" return J"
]
},
Expand Down Expand Up @@ -342,20 +349,20 @@
"metadata": {},
"outputs": [],
"source": [
"def hessian_transform(extr,bounds):\n",
" intr = jax.vmap(to_inf)(extr,bounds)\n",
"def hessian_transform(extr, bounds):\n",
" intr = jax.vmap(to_inf)(extr, bounds)\n",
"\n",
" first = jax.jacfwd(jax.vmap(to_inf))(extr,bounds) \n",
" first = jax.jacfwd(jax.vmap(to_inf))(extr, bounds)\n",
" secnd = jax.jacfwd(grads_from_n)(intr)\n",
" third = jax.grad(internal_func)(intr,bounds)\n",
" third = jax.grad(internal_func)(intr, bounds)\n",
"\n",
" J = jax.jacfwd(jax.vmap(to_inf))(extr,bounds)\n",
" J = jax.jacfwd(jax.vmap(to_inf))(extr, bounds)\n",
"\n",
" a = jnp.einsum('ik,kjl,l->ij',first,secnd,third)\n",
" a = jnp.einsum('ik,kjl,l->ij', first, secnd, third)\n",
"\n",
" int_hessian = jax.hessian(internal_func)(intr,bounds)\n",
" b = jnp.einsum('ik,jl,kl->ij',J,J,int_hessian)\n",
" return int_hessian,a,b,a+b"
" int_hessian = jax.hessian(internal_func)(intr, bounds)\n",
" b = jnp.einsum('ik,jl,kl->ij', J, J, int_hessian)\n",
" return int_hessian, a, b, a + b"
]
},
{
Expand All @@ -379,8 +386,8 @@
}
],
"source": [
"def check_point(extrn,bounds):\n",
" int_hessian,a,b,extrn_hessian = hessian_transform(extrn,bounds)\n",
"def check_point(extrn, bounds):\n",
" int_hessian, a, b, extrn_hessian = hessian_transform(extrn, bounds)\n",
"\n",
" print(f'internal hessian:\\n{int_hessian}')\n",
" print(f'additional part:\\n{a}')\n",
Expand All @@ -390,9 +397,10 @@
" direct_hessian = jax.hessian(func)(extrn)\n",
" print(f'directly computed hessian:\\n{direct_hessian}')\n",
"\n",
"bounds = jnp.array([[-5,5],[-5,5]])\n",
"extrn = jnp.array([1.,1.])\n",
"check_point(extrn,bounds)"
"\n",
"bounds = jnp.array([[-5, 5], [-5, 5]])\n",
"extrn = jnp.array([1.0, 1.0])\n",
"check_point(extrn, bounds)"
]
},
{
Expand All @@ -416,8 +424,8 @@
}
],
"source": [
"extrn = jnp.array([2.,2.])\n",
"check_point(extrn,bounds)"
"extrn = jnp.array([2.0, 2.0])\n",
"check_point(extrn, bounds)"
]
},
{
Expand All @@ -426,28 +434,31 @@
"metadata": {},
"outputs": [],
"source": [
"def compare(ax,scale = 2, index = -1, color = 'k'):\n",
" slices = [slice(-5,5,11j),slice(-5,5,11j)]\n",
"def compare(ax, scale=2, index=-1, color='k'):\n",
" slices = [slice(-5, 5, 11j), slice(-5, 5, 11j)]\n",
"\n",
" grid = x,y = np.mgrid[slices[0],slices[1]]\n",
" X = np.swapaxes(grid,0,-1).reshape(-1,2)\n",
" covariance = lambda X,bounds: hessian_transform(X,bounds)[index]\n",
" args = (X,bounds)\n",
" axes = (0,None)\n",
" grid = x, y = np.mgrid[slices[0], slices[1]]\n",
" X = np.swapaxes(grid, 0, -1).reshape(-1, 2)\n",
" covariance = lambda X, bounds: hessian_transform(X, bounds)[index]\n",
" args = (X, bounds)\n",
" axes = (0, None)\n",
"\n",
" covariances = jax.vmap(covariance,in_axes=axes)(*args)\n",
" covariances = jax.vmap(covariance, in_axes=axes)(*args)\n",
"\n",
" lams,angles = jax.vmap(angle_and_lam)(covariances)\n",
" for i,(lam,angle) in enumerate(zip(lams,angles)):\n",
" lams, angles = jax.vmap(angle_and_lam)(covariances)\n",
" for i, (lam, angle) in enumerate(zip(lams, angles)):\n",
" e = patches.Ellipse(\n",
" X[i],lam[0]*scale,lam[1]*scale,angle,\n",
" alpha = 0.5,\n",
" facecolor = 'none',\n",
" edgecolor = color\n",
" X[i],\n",
" lam[0] * scale,\n",
" lam[1] * scale,\n",
" angle,\n",
" alpha=0.5,\n",
" facecolor='none',\n",
" edgecolor=color,\n",
" )\n",
" ax.add_patch(e)\n",
" ax.set_xlim(slices[0].start,slices[0].stop)\n",
" ax.set_ylim(slices[0].start,slices[0].stop) "
" ax.set_xlim(slices[0].start, slices[0].stop)\n",
" ax.set_ylim(slices[0].start, slices[0].stop)"
]
},
{
Expand All @@ -468,11 +479,11 @@
}
],
"source": [
"f,ax = plt.subplots(1,1)\n",
"compare(ax,scale = 0.2, index = -1)\n",
"compare(ax,scale = 0.2, index = -2, color = 'r')\n",
"compare(ax,scale = 0.2, index = -3, color = 'b')\n",
"plt.gcf().set_size_inches(5,5)"
"f, ax = plt.subplots(1, 1)\n",
"compare(ax, scale=0.2, index=-1)\n",
"compare(ax, scale=0.2, index=-2, color='r')\n",
"compare(ax, scale=0.2, index=-3, color='b')\n",
"plt.gcf().set_size_inches(5, 5)"
]
},
{
Expand All @@ -483,4 +494,4 @@
"source": []
}
]
}
}

0 comments on commit 5ef0781

Please sign in to comment.