Skip to content

Commit 075fe55

Browse files
committed
Added exercise for plotting
1 parent 2941d88 commit 075fe55

File tree

2 files changed

+249
-0
lines changed

2 files changed

+249
-0
lines changed

Lecture-4-Matplotlib.ipynb

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3306,6 +3306,191 @@
33063306
"* http://www.loria.fr/~rougier/teaching/matplotlib - A good matplotlib tutorial.\n",
33073307
"* http://scipy-lectures.github.io/matplotlib/matplotlib.html - Another good matplotlib reference.\n"
33083308
]
3309+
},
3310+
{
3311+
"cell_type": "markdown",
3312+
"metadata": {},
3313+
"source": [
3314+
"## Exercise: Squared exponential distance\n",
3315+
"\n"
3316+
]
3317+
},
3318+
{
3319+
"cell_type": "markdown",
3320+
"metadata": {},
3321+
"source": [
3322+
"<div class=\"alert alert-success\">\n",
3323+
"This next exercise is going to be broken into two steps: first, we will write a function to compute a distance between two numbers, and then we will plot the result.\n",
3324+
"</div>"
3325+
]
3326+
},
3327+
{
3328+
"cell_type": "markdown",
3329+
"metadata": {},
3330+
"source": [
3331+
"We are going to compute a *squared exponential distance*, which is given by the following equation:\n",
3332+
"\n",
3333+
"$$\n",
3334+
"D_{ij}=d(x_i, y_j)=e^\\frac{-(x_i-y_j)^2}{2}\n",
3335+
"$$\n",
3336+
"\n",
3337+
"where $x$ is a $n$-length vector and $y$ is a $m$-length vector. The variable $x_i$ corresponds to the $i^{th}$ element of $x$, and the variable $y_j$ corresponds to the $j^{th}$ element of $y$."
3338+
]
3339+
},
3340+
{
3341+
"cell_type": "code",
3342+
"execution_count": 2,
3343+
"metadata": {},
3344+
"outputs": [],
3345+
"source": [
3346+
"def squared_exponential(x, y):\n",
3347+
" \"\"\"Computes a matrix of squared exponential distances between\n",
3348+
" the elements of x and y.\n",
3349+
"\n",
3350+
" Hint: your solution shouldn't require more than five lines\n",
3351+
" of code, including the return statement.\n",
3352+
"\n",
3353+
" Parameters\n",
3354+
" ----------\n",
3355+
" x : numpy array with shape (n,)\n",
3356+
" y : numpy array with shape (m,)\n",
3357+
"\n",
3358+
" Returns\n",
3359+
" -------\n",
3360+
" (n, m) array of distances\n",
3361+
"\n",
3362+
" \"\"\"\n",
3363+
" # YOUR CODE HERE\n",
3364+
" raise NotImplementedError()"
3365+
]
3366+
},
3367+
{
3368+
"cell_type": "code",
3369+
"execution_count": null,
3370+
"metadata": {},
3371+
"outputs": [],
3372+
"source": [
3373+
"from numpy.testing import assert_allclose\n",
3374+
"from nose.tools import assert_equal\n",
3375+
"\n",
3376+
"d = squared_exponential(np.arange(3), np.arange(4))\n",
3377+
"assert_equal(d.shape, (3, 4))\n",
3378+
"assert_equal(d.dtype, float)\n",
3379+
"assert_allclose(d[0], [ 1. , 0.60653066, 0.13533528, 0.011108997])\n",
3380+
"assert_allclose(d[1], [ 0.60653066, 1. , 0.60653066, 0.13533528])\n",
3381+
"assert_allclose(d[2], [ 0.13533528, 0.60653066, 1. , 0.60653066])\n",
3382+
"\n",
3383+
"d = squared_exponential(np.arange(4), np.arange(3))\n",
3384+
"assert_equal(d.shape, (4, 3))\n",
3385+
"assert_equal(d.dtype, float)\n",
3386+
"assert_allclose(d[:, 0], [ 1. , 0.60653066, 0.13533528, 0.011108997])\n",
3387+
"assert_allclose(d[:, 1], [ 0.60653066, 1. , 0.60653066, 0.13533528])\n",
3388+
"assert_allclose(d[:, 2], [ 0.13533528, 0.60653066, 1. , 0.60653066])\n",
3389+
"\n",
3390+
"print(\"Success!\")"
3391+
]
3392+
},
3393+
{
3394+
"cell_type": "markdown",
3395+
"metadata": {},
3396+
"source": [
3397+
"<div class=\"alert alert-success\">\n",
3398+
"Now, let's write a function to visualize these distances. Implement <code>plot_squared_exponential</code> to plot these distances using the <code>matshow</code> function.\n",
3399+
"</div>"
3400+
]
3401+
},
3402+
{
3403+
"cell_type": "markdown",
3404+
"metadata": {},
3405+
"source": [
3406+
"<div class=\"alert alert-warning\">Be sure to check the docstring of <code>plot_squared_exponential</code> for additional constraints -- going forward, in general we will give brief instructions in green cells like the one above, and more detailed instructions in the function comment.</div>"
3407+
]
3408+
},
3409+
{
3410+
"cell_type": "code",
3411+
"execution_count": null,
3412+
"metadata": {},
3413+
"outputs": [],
3414+
"source": [
3415+
"def plot_squared_exponential(axis, x, y):\n",
3416+
" \"\"\"Plot the squared exponential distance between the elements\n",
3417+
" of x and y. Make sure to:\n",
3418+
"\n",
3419+
" * call the `squared_exponential` function to compute the distances \n",
3420+
" between `x` and `y`\n",
3421+
" * use the grayscale colormap\n",
3422+
" * remember to include axis labels and a title. \n",
3423+
" * turn off the tick marks on both axes\n",
3424+
"\n",
3425+
" Parameters\n",
3426+
" ----------\n",
3427+
" axis : matplotlib axis object\n",
3428+
" The axis on which to plot the distances\n",
3429+
" x : numpy array with shape (n,)\n",
3430+
" y : numpy array with shape (m,)\n",
3431+
"\n",
3432+
" \"\"\"\n",
3433+
" # YOUR CODE HERE\n",
3434+
" raise NotImplementedError()"
3435+
]
3436+
},
3437+
{
3438+
"cell_type": "markdown",
3439+
"metadata": {},
3440+
"source": [
3441+
"Now, let's see what our squared exponential function actually looks like. To do this we will look at 100 linearly spaced values from -2 to 2 on the x axis and 100 linearly spaced values from -2 to 2 on the y axis. To generate these values we will use the function `np.linspace`."
3442+
]
3443+
},
3444+
{
3445+
"cell_type": "code",
3446+
"execution_count": null,
3447+
"metadata": {},
3448+
"outputs": [],
3449+
"source": [
3450+
"x = np.linspace(-2, 2, 100)\n",
3451+
"y = np.linspace(-2, 2, 100)\n",
3452+
"\n",
3453+
"figure, axis = plt.subplots()\n",
3454+
"plot_squared_exponential(axis, x, y)"
3455+
]
3456+
},
3457+
{
3458+
"cell_type": "code",
3459+
"execution_count": null,
3460+
"metadata": {},
3461+
"outputs": [],
3462+
"source": [
3463+
"from plotchecker import assert_image_allclose, get_image_colormap\n",
3464+
"from numpy.testing import assert_array_equal\n",
3465+
"\n",
3466+
"# generate some random data\n",
3467+
"x = np.random.rand(100)\n",
3468+
"y = np.random.rand(75)\n",
3469+
"\n",
3470+
"# plot it\n",
3471+
"figure, axis = plt.subplots()\n",
3472+
"plot_squared_exponential(axis, x, y)\n",
3473+
"\n",
3474+
"# check image data\n",
3475+
"assert_image_allclose(axis, squared_exponential(x, y))\n",
3476+
"\n",
3477+
"# check that the 'gray' colormap was used\n",
3478+
"assert_equal(get_image_colormap(axis), 'gray')\n",
3479+
"\n",
3480+
"# check axis labels and title\n",
3481+
"assert axis.get_xlabel() != ''\n",
3482+
"assert axis.get_ylabel() != ''\n",
3483+
"assert axis.get_title() != ''\n",
3484+
"\n",
3485+
"# check that ticks are removed\n",
3486+
"assert_array_equal(axis.get_xticks(), [])\n",
3487+
"assert_array_equal(axis.get_yticks(), [])\n",
3488+
"\n",
3489+
"# close the plot\n",
3490+
"plt.close(figure)\n",
3491+
"\n",
3492+
"print(\"Success!\")"
3493+
]
33093494
}
33103495
],
33113496
"metadata": {

plotchecker.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import matplotlib
2+
import numpy as np
3+
from numpy.testing import assert_array_equal, assert_allclose
4+
5+
6+
def get_data(ax):
7+
lines = ax.get_lines()
8+
if len(lines) > 0:
9+
xydata = np.concatenate(
10+
[x.get_xydata() for x in lines], axis=0)
11+
12+
else:
13+
collections = ax.collections
14+
if len(collections) > 0:
15+
xydata = np.concatenate(
16+
[x.get_offsets() for x in collections], axis=0)
17+
18+
else:
19+
raise ValueError("no data found")
20+
21+
return xydata
22+
23+
24+
def get_label_text(ax):
25+
text = [x for x in ax.get_children()
26+
if isinstance(x, matplotlib.text.Text)]
27+
text = [x for x in text if x.get_text() != ax.get_title()]
28+
text = [x for x in text if x.get_text().strip() != '']
29+
return [x.get_text().strip() for x in text]
30+
31+
32+
def get_label_pos(ax):
33+
text = [x for x in ax.get_children()
34+
if isinstance(x, matplotlib.text.Text)]
35+
text = [x for x in text if x.get_text() != ax.get_title()]
36+
text = [x for x in text if x.get_text().strip() != '']
37+
return np.vstack([x.get_position() for x in text])
38+
39+
40+
def get_image(ax):
41+
images = ax.get_images()
42+
if len(images) == 0:
43+
raise ValueError("Expected one image, but there were none. Did you remember to call the plotting function (probably `matshow` or `imshow`)?")
44+
if len(images) > 1:
45+
raise ValueError("Expected one image, but there were {}. Did you call the plotting function (probably `matshow` or `imshow`) more than once?".format(len(images)))
46+
return images[0]
47+
48+
49+
def get_imshow_data(ax):
50+
image = get_image(ax)
51+
return image._A
52+
53+
def get_image_colormap(ax):
54+
image = get_image(ax)
55+
return image.cmap.name
56+
57+
def assert_image_equal(ax, arr):
58+
data = get_imshow_data(ax)
59+
assert_array_equal(data, arr)
60+
61+
62+
def assert_image_allclose(ax, arr):
63+
data = get_imshow_data(ax)
64+
assert_allclose(data, arr, atol=1e-6)

0 commit comments

Comments
 (0)