# Volume I: Data Visualization. Solutions File.

In [None]:
import numpy as np
from scipy.stats import linregress
from scipy.special import binom
from matplotlib import rcParams, colors, pyplot as plt
%matplotlib inline

# Every solution below draws a multi-panel figure, so use a large default size.
FIGSIZE = (18, 9)
rcParams["figure.figsize"] = FIGSIZE

In [None]:
# Problem 1: Anscombe's quartet.

# Columns 2i and 2i+1 of the array hold the x- and y-values of data set i+1.
A = np.load("anscombe.npy")

# The same line y = x/2 + 3 is drawn on every panel as the shared regression line.
x = np.linspace(0, 20, 20)
y = .5*x + 3

# One panel per data set. `range` (not the Python 2-only `xrange`) so the
# cell runs under Python 3.
for i in range(4):
    plt.subplot(2, 2, i+1)
    plt.plot(A[:, 2*i], A[:, 2*i+1], 'o')
    plt.plot(x, y)
    plt.title("Data Set {}".format(i+1))
_ = plt.suptitle("Problem 1 Solution", fontsize=28)

Students' plots should look almost exactly like the ones shown above, and students should write a sentence or two about each data set describing what makes it unique.
Sample answers might be something like the following:

**Data Set 1**: Randomly scattered around the regression line. The $x$-coordinates appear uniformly distributed.

**Data Set 2**: Parabolic, not linear. The $x$-values are nearly perfectly spaced.

**Data Set 3**: Linear, very close to the regression line. A single vertical outlier skews the regression.

**Data Set 4**: Not linear. All but one point have the same $x$-coordinate, so a single horizontal outlier determines the slope of the regression line.

_**5 Points**_

In [None]:
# Problem 2: The Bernstein Polynomials.
# Plot b_{v,n}(x) = C(n,v) x^v (1-x)^(n-v) for degrees n = 0, ..., GRID_SIZE-1.
# Row n of the grid holds the n+1 polynomials of degree n, which gives a
# lower-triangular arrangement.
GRID_SIZE = 5
x = np.linspace(0, 1, 200)
for n in range(GRID_SIZE):
    for v in range(n+1):
        i = GRID_SIZE*n + v             # Row n, column v of the grid (0-based).
        plt.subplot(GRID_SIZE, GRID_SIZE, i+1)
        plt.plot(x, binom(n, v) * x**v * (1-x)**(n-v), lw=2)
        plt.axis([-.1, 1.1, -.1, 1.1])

        # Pass booleans, as tick_params documents. Current matplotlib no
        # longer supports the legacy "on"/"off" strings here.
        plt.tick_params(which="both", top=False, right=False)
        if n < GRID_SIZE - 1:           # Hide x tick labels except on the bottom row.
            plt.tick_params(labelbottom=False)
        if v > 0:                       # Hide y tick labels except in the first column.
            plt.tick_params(labelleft=False)
        plt.title(r"$b_{%s,%s}$" % (v, n), fontsize=16)

_ = plt.suptitle("Problem 2 Solution", fontsize=28)

The students should produce at least 10 clean, clearly labeled subplots, though they aren't required to be arranged in the triangular configuration displayed above.

_**5 Points**_

In [None]:
# Problem 3: Visualize the MLB data.
# MLB.npy is stored row-per-player; its columns are, in order: height, weight, age.
mlb_columns = np.load("MLB.npy").T
height, weight, age = mlb_columns

def linear_regression_line(x, y):
    """Overlay the least-squares line of y on x onto the current axes.

    The line is drawn as a solid black curve (width 2) spanning from the
    smallest to the largest x value.

    Parameters:
        x, y: data passed to scipy.stats.linregress. x must also provide
            .min() and .max() (e.g. a NumPy array).
    """
    fit = linregress(x, y)
    xs = np.linspace(x.min(), x.max(), 5)
    plt.plot(xs, xs*fit.slope + fit.intercept, 'k-', lw=2)

# Top left: height vs. weight, with age shown as color.
plt.subplot(221)
plt.scatter(height, weight, c=age)
cbar = plt.colorbar()
cbar.set_label("Age")
linear_regression_line(height, weight)
plt.xlabel("Height (inches)")
plt.ylabel("Weight (pounds)")

# Top right: height vs. weight, with age shown as marker area (.5 * age**2).
plt.subplot(222)
plt.scatter(height, weight, s=.5*age**2, alpha=.3)
# A size encoding cannot be read without a key. Build a legend from empty
# "proxy" scatters drawn at a few reference ages with the same size mapping.
for reference_age in (20, 30, 40):
    plt.scatter([], [], s=.5*reference_age**2, alpha=.3, color="C0",
                label=str(reference_age))
plt.legend(title="Age", scatterpoints=1, loc="upper left")
linear_regression_line(height, weight)
plt.xlabel("Height (inches)")
plt.ylabel("Weight (pounds)")

# Bottom left: age vs. height.
plt.subplot(223)
plt.plot(age, height, 'g.')
linear_regression_line(age, height)
plt.xlabel("Age (years)")
plt.ylabel("Height (inches)")

# Bottom right: age vs. weight.
plt.subplot(224)
plt.plot(age, weight, 'r.')
linear_regression_line(age, weight)
plt.xlabel("Age (years)")
plt.ylabel("Weight (pounds)")

_ = plt.suptitle("Problem 3 Solution", fontsize=28)

The students should produce at least 1 plot.
The scatter plot in the top right corner is preferable, as it plots height against weight, the two most strongly correlated variables. The top-left plot shows the same pair, with age encoded as color instead of size.
However, the bottom scatter plots are at least useful for showing that age is not very correlated with height or weight in the MLB.

_**5 points**_

In [None]:
# Problem 4: Visualize the earthquake data.
# Columns of earthquakes.npy, in order: year, magnitude, longitude, latitude.
earthquake_table = np.load("earthquakes.npy")
years, magnitudes, longitude, latitude = earthquake_table.T

# Top left: earthquakes per year. Ten unit-width bins over [1999.5, 2009.5]
# put one bin centered on each integer year 2000-2009.
plt.subplot(221)
xlimits = [1999.5, 2009.5]
counts, bin_edges = np.histogram(years, bins=10, range=xlimits)
bin_centers = .5*(bin_edges[1:] + bin_edges[:-1])
plt.plot(bin_centers, counts, '.-', lw=2, ms=15)

# Overlay a least-squares trend line through the yearly counts.
fit = linregress(bin_centers, counts)
slope, intercept = fit.slope, fit.intercept
plt.plot(bin_centers, slope*bin_centers + intercept, 'g')

plt.xlim(xlimits)
plt.gca().set(xlabel="Year", ylabel="Number of Earthquakes",
              title="Earthquake Frequency by Year")

# Top right: histogram of magnitudes. Five unit-width bins over [4.5, 9.5]
# are centered on magnitudes 5-9. The count axis is logarithmic so that
# sparsely populated bins stay visible next to crowded ones.
plt.subplot(222)
xlimits = [4.5, 9.5]
plt.hist(magnitudes, bins=5, range=xlimits, log=True, color='g', alpha=.8)
plt.xlim(xlimits)
plt.gca().set(xlabel="Magnitude", ylabel="Number of Earthquakes (log scale)",
              title="Earthquake Frequency by Magnitude")

# Bottom left: where earthquakes happen. Every epicenter is plotted as a
# single black pixel (latitude vs. longitude).
plt.subplot(223)
plt.plot(longitude, latitude, 'k,')

# Overlay bigger earthquakes with progressively larger colored dots. The
# half-open classes (7, 8], (8, 9] and (9, inf) cover every magnitude above 7
# with no gaps, so earthquakes of magnitude exactly 8.0 or 9.0 are not dropped.
magnitude_classes = [
    # (lower, upper, format, alpha, marker size, legend label)
    (7, 8, 'yo', .6, 4, "Magnitude > 7"),
    (8, 9, 'co', .7, 8, "Magnitude > 8"),
    (9, np.inf, 'ro', .8, 12, "Magnitude > 9"),
]
for lower, upper, fmt, alpha, size, label in magnitude_classes:
    index = (magnitudes > lower) & (magnitudes <= upper)
    plt.plot(longitude[index], latitude[index], fmt, alpha=alpha, ms=size, label=label)

# Top limit above any real latitude, presumably to leave room for the legend.
plt.ylim(top=120)
plt.legend(loc="upper left")
plt.title("Earthquakes by Location")
plt.axis("equal")

# Bottom right: a deliberately poor visualization (titled BAD EXAMPLE) for
# contrast with the three informative panels.
ax = plt.subplot(224)
ax.plot(years, magnitudes, 'r.')
ax.set_title("Earthquakes by Year, BAD EXAMPLE")

_ = plt.gcf().suptitle("Problem 4 Solution", fontsize=28)

The student should produce 2 or 3 plots.
They do not have to be exactly like the ones shown above, but they should address the following questions:
1. How many earthquakes happened every year?
2. How often do stronger earthquakes happen compared to weaker ones?
3. Where do earthquakes happen? Where do the strongest earthquakes happen?

_**10 points**_

In [None]:
# Problem 5: Heat maps of the Rosenbrock function.

def rosen(x, y):
    """Rosenbrock function f(x, y) = (1 - x)^2 + 100 (y - x^2)^2.

    Evaluates elementwise on NumPy arrays. Its global minimum is f(1, 1) = 0.
    """
    return (1. - x)**2 + 100.*(y - x**2)**2

# Sample f on an N x N grid over [-2, 2] x [-1, 3], a window containing (1, 1).
N = 500
x, y = np.linspace(-2, 2, N), np.linspace(-1, 3, N)
X, Y = np.meshgrid(x, y)
Z = rosen(X, Y)

# Four views of the same Rosenbrock surface. With a linear color scale (top
# left), the values far from the valley (in the thousands) swamp the color
# range. A logarithmic norm (other panels) exposes the curved valley. The
# global minimizer (1, 1) is starred on every panel, so each view shows where
# the minimum lies.
def _mark_minimum():
    """Draw a red star at (1, 1), the Rosenbrock function's global minimizer."""
    plt.plot([1], [1], 'r*', ms=10, alpha=.8)

# Top left: heat map, linear color scale.
plt.subplot(221)
plt.pcolormesh(X, Y, Z, cmap="viridis")
plt.colorbar()
_mark_minimum()
plt.title("Heat map (linear scale)")

# Top right: heat map, logarithmic color scale.
plt.subplot(222)
plt.pcolormesh(X, Y, Z, cmap="viridis", norm=colors.LogNorm(vmin=1e-6))
plt.colorbar()
_mark_minimum()
plt.title("Heat map (log scale)")

# Bottom left: 10 filled contour levels, logarithmic color scale.
plt.subplot(223)
plt.contourf(X, Y, Z, 10, cmap="viridis", norm=colors.LogNorm(vmin=1e-6))
plt.colorbar()
_mark_minimum()
plt.title("Filled contours (log scale)")

# Bottom right: 10 contour lines, logarithmic color scale.
plt.subplot(224)
plt.contour(X, Y, Z, 10, cmap="viridis", norm=colors.LogNorm(vmin=1e-6))
plt.colorbar()
_mark_minimum()
plt.title("Contour lines (log scale)")

_ = plt.suptitle("Problem 5 Solution", fontsize=28)

The students should produce at least 1 heat map or contour plot that shows that the minimum is indeed at (1,1).

_**5 Points**_

In [None]:
# Problem 6: Visualize the country data. This solution uses scatter plots,
# histograms, and bar charts.
# NOTE(review): "North Korea" is out of alphabetical order (it sits where
# "Korea, North" would). Confirm this list matches the row order of countries.npy.
country_names = [
    "Austria", "Bolivia", "Brazil", "China", "Finland",
    "Germany", "Hungary", "India", "Japan", "North Korea",
    "Montenegro", "Norway", "Peru", "South Korea", "Sri Lanka",
    "Switzerland", "Turkey", "United Kingdom", "United States", "Vietnam",
]
countries = np.array(country_names)

# Columns of countries.npy, in order: population, GDP, average male height,
# average female height (units per the axis labels used below).
country_columns = np.load("countries.npy").T
population, gdp, male, female = country_columns

# Scatter plots =======================================================================================================

# Top left: population vs. GDP.
plt.subplot(221)
plt.scatter(population, gdp, s=100)
linear_regression_line(population, gdp)
plt.xlabel("Population (millions of people)")
plt.ylabel("GDP (billions of US dollars)")

# Top right: male vs. female height.
plt.subplot(222)
plt.scatter(male, female, s=100)
linear_regression_line(male, female)
plt.xlabel("Average male height (centimeters)")
plt.ylabel("Average female height (centimeters)")

# Bottom left: male vs. female height. Marker area is population; color is
# GDP on a log scale.
plt.subplot(223)
plt.scatter(male, female, s=population, c=gdp, norm=colors.LogNorm(), alpha=.7)
c_bar = plt.colorbar()
c_bar.set_label("GDP (billions of US dollars)")

linear_regression_line(male, female)
plt.xlabel("Average male height (centimeters)")
plt.ylabel("Average female height (centimeters)")

# Bottom right: male vs. female height. Color is population on a log scale;
# marker size grows with sqrt(GDP).
plt.subplot(224)
plt.scatter(male, female, c=population, s=2*np.sqrt(gdp), norm=colors.LogNorm(), alpha=.7)
c_bar = plt.colorbar()
c_bar.set_label("Population (millions of people)")

linear_regression_line(male, female)
plt.xlabel("Average male height (centimeters)")
plt.ylabel("Average female height (centimeters)")

_ = plt.suptitle("Problem 6 Solution (scatter plots)", fontsize=28)

In [None]:
# Histograms ==========================================================================================================

# Top left: distribution of population.
plt.subplot(221)
plt.hist(population, 12)
plt.xlabel("Population (millions of people)")
plt.ylabel("Number of countries")

# Top right: distribution of GDP.
plt.subplot(222)
plt.hist(gdp, 8)
plt.xlabel("GDP (billions of US dollars)")
plt.ylabel("Number of countries")

# Bottom left: male and female heights overlaid. Both histograms use the same
# 12 equal-width bins spanning the combined range. With an integer bin count,
# each hist call would instead derive its edges from its own data alone, and
# the two sets of bars would not line up.
plt.subplot(223)
height_bins = np.linspace(min(male.min(), female.min()),
                          max(male.max(), female.max()), 13)
plt.hist(male, height_bins, alpha=.5, label="Male")
plt.hist(female, height_bins, alpha=.5, label="Female")
plt.xlabel("Average height (centimeters)")
plt.ylabel("Number of countries")
plt.legend(loc="upper left")

_ = plt.suptitle("Problem 6 Solution (histograms)", fontsize=28)

In [None]:
# Bar Charts ==========================================================================================================

# One horizontal bar chart per sex (top: male, bottom: female), with the
# countries sorted from shortest to tallest average height.
positions = np.arange(len(countries)) + .5
bar_panels = [
    (211, male, "Average male height (centimeters)"),
    (212, female, "Average female height (centimeters)"),
]
for subplot_position, heights, chart_title in bar_panels:
    plt.subplot(subplot_position)
    loc = np.argsort(heights)
    plt.barh(positions, heights[loc], align="center")
    plt.yticks(positions, countries[loc])
    plt.title(chart_title)

_ = plt.suptitle("Problem 6 Solution (bar charts)", fontsize=28)

_**10 Points**_