## MatPlotLib: Python Library for Data Visualization

In [None]:
### importing python standard libraries
from typing import List, Dict

### importing external libraries
import numpy, pandas
from matplotlib import pyplot, colors, lines
from matplotlib.figure import Figure

### rendering plots inside this notebook
%matplotlib inline

## Two Ways of Plotting

#### Stateless plotting

In [None]:
### stateless blank plot: pyplot.plot() is called without any data,
### so no figure or subplot object is created explicitly in this cell
pyplot.plot()

### displaying plot without matplotlib messages
pyplot.show()

**MatPlotLib syntax**

A semicolon (;) at the end of a cell's last code line suppresses the text output of that line, so it can be used instead of `pyplot.show()`.

In [None]:
### stateless line plot from a single list, [1,2,3,4] (y-axis) only
### the trailing semicolon hides the text output of the call
pyplot.plot([1,2,3,4]);

In [None]:
### stateless line plot from two lists, [1,2,3,4] (x-axis), [11,22,33,44] (y-axis)
### the trailing semicolon hides the text output of the call
pyplot.plot([1,2,3,4], [11,22,33,44]);

#### Object oriented plotting

**MatPlotLib concept**

Figure objects represent an entire figure with all subplots.  
Axes objects represent individual subplots (this notebook stores them in variables named `axis`).

In [None]:
### object oriented entry point: one figure object (figure) with one subplot object (axis)
figure: Figure
axis: pyplot.Axes
figure, axis = pyplot.subplots()

### line plot drawn on the subplot, [1,2,3,4] (x-axis), [11,22,33,44] (y-axis)
axis.plot(
    [1, 2, 3, 4],
    [11, 22, 33, 44]);

## Matplotlib Workflow

**MatPlotLib workflow steps**
* preparing data
* creating figure (object oriented plotting)
* plotting data
* customizing plot
* saving figure

In [None]:
### step 1, preparing data: x values and their matching y values
xdata: List[int] = [1, 2, 3, 4]
ydata: List[int] = [11, 22, 33, 44]

### step 2, creating figure: figure object (figure), subplot object (axis), 5x4 figure size (figsize)
figure: Figure
axis: pyplot.Axes
figure, axis = pyplot.subplots(figsize=(5, 4))

### step 3, plotting data: line plot, xdata (x-axis), ydata (y-axis)
axis.plot(xdata, ydata)

### step 4, customizing the subplot: title and both axis labels in one call
axis.set(
    title="Sample Line Plot",
    ylabel="Y Axis",
    xlabel="X Axis")

### step 5, saving figure to a png file in the working directory
figure.savefig(
    fname="plot-sample.png",
    dpi=300,
    bbox_inches="tight",
    pad_inches=0.1)

**MatPlotLib syntax**

Ending a cell with `figure.savefig()` also substitutes `pyplot.show()`: the plot is rendered without any matplotlib message below it.

## Plotting from NumPy: Plot Types

#### Line plots

In [None]:
### numpy linear array (linspace_array), 100 values from 0 to 10
linspace_array: numpy.ndarray = numpy.linspace(0, 10, num=100)

### figure object (figure), subplot object (axis)
figure: Figure
axis: pyplot.Axes
figure, axis = pyplot.subplots()

### line plot, linspace array (x-axis), linspace array squared (y-axis)
axis.plot(linspace_array, linspace_array ** 2)

### customizing the subplot: title and both axis labels in one call
axis.set(
    title="Line Plot from NumPy Linear Array",
    ylabel="Linear Values Squared",
    xlabel="Linear Values");

#### Scatter plots

In [None]:
### numpy linear array (linspace_array), 100 values from 0 to 10
linspace_array: numpy.ndarray = numpy.linspace(0, 10, num=100)

### figure object (figure), subplot object (axis)
figure: Figure
axis: pyplot.Axes
figure, axis = pyplot.subplots()

### scatter plot, linspace array (x-axis), sine of linspace array (y-axis)
axis.scatter(
    x=linspace_array,
    y=numpy.sin(linspace_array))

### customizing the subplot: title and both axis labels in one call
axis.set(
    title="Scatter Plot from NumPy Linear Array",
    ylabel="Sine of Linear Values",
    xlabel="Linear Values");

#### >>> Bar plots

In [None]:
### dictionary data: product name (key), price (value), kept in insertion order
butter_prices = dict([
    ("Almond Butter", 10),
    ("Peanut Butter", 8),
    ("Cashew Butter", 12)])

### displaying the dictionary as the cell output
butter_prices

In [None]:
### figure object (figure), subplot object (axis)
figure,axis = pyplot.subplots()

### vertical bar plot, dictionary values (height), dictionary keys (x-axis)
### both dictionary views are materialized as lists, so heights and x positions
### are passed as plain sequences in the same (insertion) order
axis.bar(
    height=list(butter_prices.values()),
    x=list(butter_prices.keys()))

### customizing subplot, subplot title (title), y-axis label (ylabel)
axis.set(
    title="Nut Butter Store Prices",
    ylabel="Price ($)");

In [None]:
### figure object (figure), subplot object (axis)
figure,axis = pyplot.subplots()

### horizontal bar plot, dictionary keys (y-axis), dictionary values (width)
### both dictionary views are materialized as lists, so y positions and widths
### are passed as plain sequences in the same (insertion) order
axis.barh(
    y=list(butter_prices.keys()),
    width=list(butter_prices.values()))

### customizing subplot, subplot title (title), x-axis label (xlabel)
axis.set(
    title="Nut Butter Store Prices",
    xlabel="Price ($)");

#### Histogram plots

**histogram plot concept**

Histogram plots display the distribution of the values in a dataset.  
First, the entire span of the dataset (max value - min value) is divided into equally sized ranges (bins).  
The height of each histogram column represents the number of items within that bin.

In [None]:
### sample data: 1000 numbers drawn by numpy.random.randn (randn_array)
randn_array: numpy.ndarray = numpy.random.randn(1000)

### figure object (figure), subplot object (axis)
figure: pyplot.Figure
axis: pyplot.Axes
figure, axis = pyplot.subplots()

### histogram plot, randn_array bins (x-axis), number of items per bin (y-axis)
axis.hist(
    x=randn_array);

#### Subplots: Option 1

In [None]:
### preparing data -----------------------------------------------------------------------------------------------------

### numpy linear array (linspace_array), contains 100 numbers (num)
linspace_array: numpy.ndarray = numpy.linspace(
    start=0,
    stop=10,
    num=100)

### dictionary data
butter_prices: Dict[str,int] = {
    "Almond Butter": 10,
    "Peanut Butter": 8,
    "Cashew Butter": 12}

### numpy random normal distribution array (randn_array), contains 1000 numbers
randn_array: numpy.ndarray = numpy.random.randn(1000)

### creating figure ----------------------------------------------------------------------------------------------------

### figure object (figure), subplot objects unpacked from nested tuples (ax-row-column)
### grid rows (nrows), grid columns (ncols), (width,height) in inches (figsize)
figure,((ax11,ax12),(ax21,ax22)) = pyplot.subplots(
    nrows=2,
    ncols=2,
    figsize=(10,7))
figure: pyplot.Figure

### plotting -----------------------------------------------------------------------------------------------------------

### line plot (ax11), linspace_array (x-axis), linspace_array squared (y-axis)
ax11: pyplot.Axes
ax11.plot(linspace_array, linspace_array**2)

### numpy random array (rand), contains 10 numbers
### scatter plot (ax12), numpy random array (y-axis), numpy random array (x-axis)
ax12: pyplot.Axes
ax12.scatter(
    y=numpy.random.rand(10),
    x=numpy.random.rand(10))

### vertical bar plot (ax21), dictionary values (height), dictionary keys (x-axis)
### both dictionary views are materialized as lists, so heights and x positions
### are passed as plain sequences in the same (insertion) order
ax21: pyplot.Axes
ax21.bar(
    height=list(butter_prices.values()),
    x=list(butter_prices.keys()))

### histogram plot (ax22), number of items (y-axis), numpy random normal distribution array bins (x-axis)
ax22: pyplot.Axes
ax22.hist(x=randn_array);

#### Subplots: Option 2

In [None]:
### preparing data -----------------------------------------------------------------------------------------------------

### numpy linearly spaced array (linspace_array), contains 100 numbers (num)
linspace_array: numpy.ndarray = numpy.linspace(
    start=0,
    stop=10,
    num=100)

### butter prices dictionary
butter_prices: Dict[str,int] = {
    "Almond Butter": 10,
    "Peanut Butter": 8,
    "Cashew Butter": 12}

### numpy random normal distribution array (randnorm_array), contains 1000 numbers
randnorm_array: numpy.ndarray = numpy.random.randn(1000)

### creating figure ----------------------------------------------------------------------------------------------------

### figure object (figure), subplot objects 2x2 numpy array (axes), (width,height) in inches (figsize)
figure,axes = pyplot.subplots(
    nrows=2,
    ncols=2,
    figsize=(10,7))
figure: pyplot.Figure; axes: numpy.ndarray

### plotting -----------------------------------------------------------------------------------------------------------

### line plot (axes[0,0]), linearly spaced array (x-axis), linearly spaced array squared (y-axis)
axes[0,0].plot(linspace_array, linspace_array**2)

### numpy random numbers (0-1) array (rand), contains 10 numbers
### scatter plot (axes[0,1]), random numbers array (y-axis), random numbers array (x-axis)
axes[0,1].scatter(
    y=numpy.random.rand(10),
    x=numpy.random.rand(10))

### vertical bar plot (axes[1,0]), butter prices values (height), butter prices keys (x-axis)
### both dictionary views are materialized as lists, so heights and x positions
### are passed as plain sequences in the same (insertion) order
axes[1,0].bar(
    height=list(butter_prices.values()),
    x=list(butter_prices.keys()))

### histogram plot (axes[1,1]), number of items (y-axis), random normal distribution array bins (x-axis)
### plots the array created in this cell, so the cell does not depend on variables from other cells
axes[1,1].hist(x=randnorm_array);

**Opinion:**  
I personally do not like subplots option 2.  
Despite my best efforts, I am unable to make VS Code recognize the type of the subplot objects stored in the `axes` variable.

## Plotting from Pandas (x)

#### Stateless plotting from random series

In [None]:
### numpy random normal distribution array with 1000 numbers
randnorm_array: numpy.ndarray = numpy.random.randn(1000)
### pandas datetime index with 1000 dates, starting at 1/1/2020
### (pandas.date_range returns a DatetimeIndex, not a Series)
date_series: pandas.DatetimeIndex = pandas.date_range(start="1/1/2020", periods=1000)
### pandas series from randnorm array, indexed by the dates
randnorm_series: pandas.Series = pandas.Series(data=randnorm_array, index=date_series)


### line plot, cumulative sum of randnorm series (y-axis), randnorm series index (x-axis)
randnorm_series.cumsum().plot();

#### Stateless plotting from the Car Sales Dataframe

In [None]:
### importing from car sales dataset
car_sales: pandas.DataFrame = pandas.read_csv("data-car-sales.csv")

### converting price column str > int
### raw string keeps the regex backslashes literal (a plain string holds invalid escape sequences)
car_sales["Price"] = car_sales["Price"].str.replace(r"[\$\,\.]", "", regex=True)
### assumes every price ended with two decimal digits -- TODO confirm against the csv file
car_sales["Price"] = car_sales["Price"].str[:-2] # removing decimals
car_sales["Price"] = car_sales["Price"].astype(int)

### creating sale date column, one date per dataframe row
### (pandas.date_range returns a DatetimeIndex; its length follows the dataframe instead of a fixed 10)
date_series: pandas.DatetimeIndex = pandas.date_range(start="1/1/2020", periods=len(car_sales))
car_sales["Sale Date"] = date_series

### creating total sales column as the running total of the price column
car_sales["Total Sales"] = car_sales["Price"].cumsum()
car_sales

In [None]:
### stateless line plot from the dataframe, sale date column (x-axis), total sales column (y-axis)
car_sales.plot(
    kind="line",
    x="Sale Date",
    y="Total Sales",
    legend=False)

### customizing the current plot through pyplot
pyplot.title("Revenue from Car Sales")
pyplot.ylabel("Total Sales ($)");

In [None]:
### stateless scatter plot from the dataframe, odometer column (x-axis), price column (y-axis)
### marker size (s), marker color (c)
car_sales.plot.scatter(
    x="Odometer (KM)",
    y="Price",
    s=150,
    c="steelblue")

### customizing the current plot through pyplot
pyplot.title("Car Sale Price and Odometer Reading")
pyplot.ylabel("Car Sale Price ($)")
pyplot.xlabel("Odometer (km)")
pyplot.xticks(rotation=45);

In [None]:
### stateless vertical bar plot from the dataframe, make column (x-axis), odometer column (height)
car_sales.plot(
    kind="bar",
    x="Make",
    y="Odometer (KM)",
    legend=False)

### customizing the current plot through pyplot
pyplot.title("Odometer Reading and Car Make")
pyplot.ylabel("Odometer Reading (km)")
pyplot.xlabel("Car Make")
pyplot.xticks(rotation=45);

In [None]:
### stateless histogram plot from one dataframe column, odometer values split into 10 bins
odometer_column = car_sales["Odometer (KM)"]
odometer_column.plot.hist(bins=10)

### customizing the current plot through pyplot
pyplot.title("Distribution of Odometer Readings")
pyplot.ylabel("Value Counts")
pyplot.xlabel("Odometer (km)")
pyplot.xticks(rotation=45);

#### Stateless plotting from the Heart Disease Dataframe

**Statistics concept**

As a rule of thumb, an outlier is a value that lies at least 3 standard deviations away from the mean.

In [None]:
### loading the heart disease dataset from the csv file
heart_disease = pandas.read_csv("data-heart-disease.csv")

### stateless histogram plot from one dataframe column, age values split into 10 bins
age_column = heart_disease["age"]
age_column.plot.hist(bins=10)

### customizing the current plot through pyplot
pyplot.title("Distribution of Patient Age")
pyplot.ylabel("Patient Count")
pyplot.xlabel("Patient Age in Years");

#### Object oriented plotting from the Heart Disease Dataframe (x)

**Matplotlib concept**

More advanced figures usually require object oriented plotting.

In [None]:
### selecting the rows whose age column is greater than 50 (boolean mask passed to .loc)
over_fifty = heart_disease.loc[heart_disease["age"] > 50]

### displaying the filtered dataframe as the cell output
over_fifty

In [None]:
### stateless scatter plot, age column (x-axis), chol column (y-axis), coloring by target column (c)
### colorbar display (colorbar), (width,height) in inches (figsize)
over_fifty.plot.scatter(
    x="age",
    y="chol",
    c="target",
    colorbar=True,
    figsize=(10, 6))

### requesting a legend for the current plot
pyplot.legend()

### customizing the current plot through pyplot
pyplot.title("Cholesterol and Age in Heart Disease")
pyplot.ylabel("Serum Cholesterol (mg/dl)")
pyplot.xlabel("Patient Age in Years");

**Opinion:**  
There is a colorbar instead of a legend.  
When the colorbar is visible, x-axis labels disappear.  
This plotting method does not do what I want. Let's try object oriented plotting!

In [None]:
### figure object (figure), subplot object (axis), (width,height) in inches (figsize)
figure: Figure
axis: pyplot.Axes
figure, axis = pyplot.subplots(figsize=(10, 6))

### colormap built from two colors, used for coloring the points by the target column
scatter_colors: List[str] = ["lightblue", "salmon"]
scatter_cmap: colors.LinearSegmentedColormap = colors.LinearSegmentedColormap.from_list(
    "temp",
    scatter_colors)

### scatter plot, age column (x-axis), chol column (y-axis), coloring by target column (c)
scatter_plot = axis.scatter(
    x=over_fifty["age"],
    y=over_fifty["chol"],
    c=over_fifty["target"],
    cmap=scatter_cmap)

### dashed horizontal line at the mean of the chol column
chol_mean: float = over_fifty["chol"].mean()
hline_plot: lines.Line2D = axis.axhline(y=chol_mean, linestyle="--", label="Label")

### legend: scatter handles first, then the labeled line handle, with custom label texts
scatter_handles = scatter_plot.legend_elements()[0]
line_handles = axis.get_legend_handles_labels()[0]
axis_handles: List[lines.Line2D] = [*scatter_handles, *line_handles]
axis_labels: List[str] = [
    "Absence of Heart Disease",
    "Presence of Heart Disease",
    f"Chol Mean = {chol_mean:.1f} mg/dl"]
axis.legend(handles=axis_handles, labels=axis_labels)

### customizing the subplot: title, axis labels and x-axis range in one call
axis.set(
    title="Cholesterol and Age in Heart Disease",
    ylabel="Serum Cholesterol (mg/dl)",
    xlabel="Patient Age (Years)",
    xlim=[50, 80]);

#### Subplots plotting from the Heart Disease Dataframe

In [None]:
### preparing data -----------------------------------------------------------------------------------------------------

### loading the heart disease dataset from the csv file
heart_disease = pandas.read_csv("data-heart-disease.csv")
### keeping the rows whose age column is greater than 50
over_fifty: pandas.DataFrame = heart_disease.loc[heart_disease["age"] > 50]

### creating figure ----------------------------------------------------------------------------------------------------

### figure object (figure), two stacked subplot objects (ax11 on top, ax21 below)
figure: Figure
ax11: pyplot.Axes
ax21: pyplot.Axes
figure, (ax11, ax21) = pyplot.subplots(nrows=2, ncols=1, figsize=(10, 10))

### creating ax11 subplot ----------------------------------------------------------------------------------------------

### scatter plot, over_fifty[age] (x-axis), over_fifty[chol] (y-axis), coloring by over_fifty[target] (c)
scat11_plot = ax11.scatter(
    x=over_fifty["age"],
    y=over_fifty["chol"],
    c=over_fifty["target"])
### dashed horizontal line at the mean of over_fifty[chol]
chol_mean: float = over_fifty["chol"].mean()
hl11_plot: lines.Line2D = ax11.axhline(y=chol_mean, linestyle="--", label="Label")
### legend: scatter handles first, then the labeled line handle, with custom label texts
ax11_handles: List[lines.Line2D] = [
    *scat11_plot.legend_elements()[0],
    *ax11.get_legend_handles_labels()[0]]
ax11_labels: List[str] = [
    "Absence of Heart Disease",
    "Presence of Heart Disease",
    f"Chol Mean = {chol_mean:.1f} mg/dl"]
ax11.legend(handles=ax11_handles, labels=ax11_labels)

### creating ax21 subplot ----------------------------------------------------------------------------------------------

### scatter plot, over_fifty[age] (x-axis), over_fifty[thalach] (y-axis), coloring by over_fifty[target] (c)
scat21_plot = ax21.scatter(
    x=over_fifty["age"],
    y=over_fifty["thalach"],
    c=over_fifty["target"])
### dashed horizontal line at the mean of over_fifty[thalach]
thalach_mean: float = over_fifty["thalach"].mean()
hl21_plot: lines.Line2D = ax21.axhline(y=thalach_mean, linestyle="--", label="Label")
### legend: scatter handles first, then the labeled line handle, with custom label texts
ax21_handles: List[lines.Line2D] = [
    *scat21_plot.legend_elements()[0],
    *ax21.get_legend_handles_labels()[0]]
ax21_labels: List[str] = [
    "Absence of Heart Disease",
    "Presence of Heart Disease",
    f"Heart Rate Mean = {thalach_mean:.1f}/min"]
ax21.legend(handles=ax21_handles, labels=ax21_labels)

### customizing the figure ---------------------------------------------------------------------------------------------

### figure title, then layout; the layout call stays ahead of the subplot titles and labels
figure.suptitle(t="Heart Disease Analysis", fontsize=16, fontweight="bold")
### alternative layout control, kept for reference: figure.subplots_adjust(top=0.91, hspace=0.25)
figure.tight_layout(pad=2.5, h_pad=5)

### customizing ax11 subplot: title, axis labels and x-axis range in one call
ax11.set(
    title="Cholesterol and Age in Heart Disease",
    ylabel="Serum Cholesterol (mg/dl)",
    xlabel="Patient Age (Years)",
    xlim=[50, 80])

### customizing ax21 subplot: title, axis labels and x-axis range in one call
ax21.set(
    title="Maximum Heart Rate and Age in Heart Disease",
    ylabel="Maximum Heart Rate (1/min)",
    xlabel="Patient Age (Years)",
    xlim=[50, 80]);

## Customizing Plots with Styles and Colormaps

#### Applying styles

**MatPlotLib styles**

Collections of styling elements used for changing the appearance of entire plots.  
Styles are activated by the `pyplot.style.use("style_name")` command.  
See [MatPlotLib Styles Reference](https://matplotlib.org/stable/gallery/style_sheets/style_sheets_reference.html).

In [None]:
### styles available in matplotlib library
### the bare expression is the last line of the cell, so its value is displayed as the cell output
pyplot.style.available

In [None]:
### dataframe of numpy.random.randn values, 10 rows, 4 columns named a-d
normal_df: pandas.DataFrame = pandas.DataFrame(
    data=numpy.random.randn(10, 4),
    columns=list("abcd"))

### activating the ggplot style before the plot is created
pyplot.style.use("ggplot")

### bar plot, dataframe rows (x-axis), dataframe values (y-axis), grouping by dataframe columns (legend)
normal_df.plot(kind="bar", figsize=(6, 4))

### legend placed outside the plot area, anchored by its upper left corner
pyplot.legend(
    title="Columns",
    loc="upper left",
    bbox_to_anchor=(1, 0.95))

### customizing the current plot through pyplot
pyplot.title("Random Normal Distribution Bar Graph")
pyplot.ylabel("Random Value")
pyplot.xlabel("Dataframe Rows")
pyplot.xticks(rotation=0);

#### Applying colormaps

**MatPlotLib colormaps**

Predefined color collections used for mapping values to colors.  
Colormaps are activated by the `cmap="colormap_name"` argument.  
See [MatPlotLib Colormap Documentation](https://matplotlib.org/stable/tutorials/colors/colormaps.html)

In [None]:
### preparing data -----------------------------------------------------------------------------------------------------

### loading heart disease dataset from file
heart_disease: pandas.DataFrame = pandas.read_csv(filepath_or_buffer="data-heart-disease.csv")
### creating dataframe for patients of age over 50
over_fifty: pandas.DataFrame = heart_disease.loc[50 < heart_disease["age"]]

### creating figure ----------------------------------------------------------------------------------------------------

### setting figure style
### matplotlib 3.6 renamed the seaborn styles to "seaborn-v0_8-<name>" and later removed the old names,
### so the new name is used when it is available and the old name only as a fallback
whitegrid_style: str = "seaborn-v0_8-whitegrid"
if whitegrid_style not in pyplot.style.available:
    whitegrid_style = "seaborn-whitegrid"
pyplot.style.use(whitegrid_style)
### figure object (figure), subplot objects (ax##)
figure,(ax11,ax21) = pyplot.subplots(nrows=2, ncols=1, figsize=(10,10))
figure: Figure

### plotting ax11 ------------------------------------------------------------------------------------------------------

### scatter plot, over_fifty[chol] (y-axis), over_fifty[age] (x-axis), grouping by over_fifty[target] (c)
### colors taken from the "summer" colormap (cmap)
ax11: pyplot.Axes
ax11_scatter = ax11.scatter(y=over_fifty["chol"], x=over_fifty["age"], c=over_fifty["target"], cmap="summer")
### horizontal line, mean of over_fifty[chol] (y-axis)
chol_mean: float = over_fifty["chol"].mean()
ax11_hline: lines.Line2D = ax11.axhline(y=chol_mean, linestyle="--", color="orangered", label="Label")
### legend: scatter handles first, then the labeled line handle, with custom label texts
ax11_handles: List[lines.Line2D] = ax11_scatter.legend_elements()[0] + ax11.get_legend_handles_labels()[0]
ax11_labels: List[str] = ["Healthy", "Heart Disease", f"Chol Mean = {chol_mean:.1f} mg/dl"]
ax11.legend(handles=ax11_handles, labels=ax11_labels, frameon=True)

### plotting ax21 ------------------------------------------------------------------------------------------------------

### scatter plot, over_fifty[thalach] (y-axis), over_fifty[age] (x-axis), grouping by over_fifty[target] (c)
### colors taken from the "winter" colormap (cmap)
ax21: pyplot.Axes
ax21_scatter = ax21.scatter(y=over_fifty["thalach"], x=over_fifty["age"], c=over_fifty["target"], cmap="winter")
### horizontal line, mean of over_fifty[thalach] (y-axis)
thalach_mean: float = over_fifty["thalach"].mean()
ax21_hline: lines.Line2D =  ax21.axhline(y=thalach_mean, linestyle="--", color="orangered", label="Label")
### legend: scatter handles first, then the labeled line handle, with custom label texts
ax21_handles: List[lines.Line2D] = ax21_scatter.legend_elements()[0] + ax21.get_legend_handles_labels()[0]
ax21_labels: List[str] = ["Healthy", "Heart Disease", f"Thalach Mean = {thalach_mean:.1f} / min"]
ax21.legend(handles=ax21_handles, labels=ax21_labels, frameon=True)

### customizing the figure ---------------------------------------------------------------------------------------------

### figure title and layout
figure.suptitle(t="Heart Disease Analysis", fontsize=16, fontweight="bold")
figure.tight_layout(pad=2.5, h_pad=5)

### customizing ax11 subplot
ax11.set_title(label="Serum Cholesterol (Chol) and Age in Heart Disease")
ax11.set_ylabel(ylabel="Serum Cholesterol (mg/dl)")
ax11.set_ylim([100,600])
ax11.set_xlabel(xlabel="Patient Age (Years)")
ax11.set_xlim([50,80])

### customizing ax21 subplot
ax21.set_title(label="Maximum Achievable Heart Rate (Thalach) and Age in Heart Disease")
ax21.set_ylabel(ylabel="Maximum Achievable Heart Rate (1/min)")
ax21.set_ylim([60,200])
ax21.set_xlabel(xlabel="Patient Age (Years)")
ax21.set_xlim([50,80]);

## Saving Plots

In [None]:
### saving heart disease analysis figure (previous cell)
### relies on the `figure` variable left in the kernel by the previous cell, so that cell has to be run first
### output file (fname), resolution (dpi), cropping mode (bbox_inches), padding around the figure (pad_inches)
figure.savefig(fname="plot-heart-disease.png", dpi=300, bbox_inches="tight", pad_inches=0.1)