# Code to generate the plots used in the presentation.

In [None]:
# uv venv
# uv pip install pandas matplotlib seaborn scipy scikit-learn statsmodels palmerpenguins pip ipykernel ipywidgets setuptools jinja2 pydataset gapminder pandas_datareader yfinance cartopy adjustText networkx

In [None]:
import re
import warnings
import math
import numpy as np
import pandas as pd
import scipy, sklearn, statsmodels
from numpy.linalg import lstsq

import matplotlib 
import matplotlib.pyplot as plt
import matplotlib.patheffects
import seaborn as sns
import seaborn
from adjustText import adjust_text

import sklearn.datasets
import statsmodels.datasets
from statsmodels.datasets import get_rdataset
from pydataset import data  # Many datasets from R
from palmerpenguins import load_penguins
from gapminder import gapminder
from pandas_datareader import wb
gapminder.to_csv( "gapminder.csv", index = False )
import yfinance as yf
import cartopy

In [None]:
# Matches the mathtext tick labels produced by matplotlib's log formatters, e.g.
# '$\mathdefault{2\times10^{-2}}$' (group 2: mantissa, group 3: exponent)
# or '$\mathdefault{10^{3}}$' (no mantissa, i.e. mantissa 1).
_SCIENTIFIC_LABEL = re.compile(r'\$\\mathdefault\{((.*)\\times)?10\^\{(.*)\}\}\$')


def remove_scientific_notation(text):
    """
    Convert one tick label from mathtext scientific notation to a plain number.

    '$\\mathdefault{2\\times10^{-2}}$' becomes '0.02' and
    '$\\mathdefault{10^{3}}$' becomes '1000'.  Empty labels, and labels that
    are not of that form (plain numbers, '$\\mathdefault{0}$', other bases,
    ...), are returned unchanged instead of raising or being mangled.

    Inputs:  text: tick label text
    Output:  the label, in plain notation when it could be converted
    """
    match = _SCIENTIFIC_LABEL.fullmatch(text)
    if match is None:
        return text
    mantissa, exponent = match.group(2), match.group(3)
    try:
        # Accept the Unicode minus sign (U+2212) as well as '-', in case a label uses it.
        mantissa = 1.0 if not mantissa else float(mantissa.replace('\u2212', '-'))
        result = mantissa * 10 ** float(exponent.replace('\u2212', '-'))
    except (ValueError, OverflowError):
        return text
    # Round to 4 significant digits to hide floating-point noise, then print compactly.
    return f'{float(f"{result:.4g}"):g}'


def _remove_scientific_notation_from_axis(axis):
    """Rewrite the major, then the minor, tick labels of one Axis (ax.xaxis or ax.yaxis)."""
    for minor in (False, True):
        labels = axis.get_minorticklabels() if minor else axis.get_ticklabels()
        for label in labels:
            label.set_text(remove_scientific_notation(label.get_text()))
        with warnings.catch_warnings():
            # Fixing the labels of a non-fixed locator is intended here.
            warnings.filterwarnings('ignore', message = "FixedFormatter should only be used together with FixedLocator" )
            warnings.filterwarnings('ignore', message = "set_ticklabels" )
            axis.set_ticklabels(labels, minor=minor)


def remove_scientific_notation_from_vertical_axis(ax, deprecated_argument=None):
    """
    Replace scientific-notation tick labels (e.g. 2x10^-2) on the vertical axis
    by plain numbers (e.g. 0.02).

    Useful when the scale is logarithmic but spans only one or two orders of
    magnitude.  The figure is drawn first to obtain the final tick labels,
    so call this once the axes are fully set up.

    Inputs:  ax: "axis", returned by plt.subplots()
             deprecated_argument: legacy calling convention
                 remove_scientific_notation_from_vertical_axis(fig, ax);
                 still accepted, but emits a DeprecationWarning
    Output:  None
    """
    if deprecated_argument is None:
        fig = ax.get_figure()
    else:
        warnings.warn(
            "remove_scientific_notation_from_vertical_axis(fig, ax) is deprecated; "
            "call remove_scientific_notation_from_vertical_axis(ax) instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        fig, ax = ax, deprecated_argument

    # Draw so that the tick label texts are up to date before reading them.
    fig.canvas.draw()
    _remove_scientific_notation_from_axis(ax.yaxis)


def remove_scientific_notation_from_horizontal_axis(ax):
    """
    Replace scientific-notation tick labels (e.g. 2x10^-2) on the horizontal axis
    by plain numbers (e.g. 0.02).

    Useful when the scale is logarithmic but spans only one or two orders of
    magnitude.  The figure is drawn first to obtain the final tick labels,
    so call this once the axes are fully set up.

    Inputs:  ax: "axis", returned by plt.subplots()
    Output:  None
    """
    fig = ax.get_figure()
    # Draw so that the tick label texts are up to date before reading them.
    fig.canvas.draw()
    _remove_scientific_notation_from_axis(ax.xaxis)

        
def mfrow(
        n:      int,
        aspect: float = 29.7/21,
        width:  float = 29.7,
        height: float = 21,
        pages:  int   = 1
):
    """
    Choose a grid (rows x columns) for n subplots whose cells have an aspect
    ratio (cell width / cell height) as close as possible to `aspect`.

    For each candidate number of columns, the number of rows is the smallest
    one that holds all n plots (rounded up to a multiple of `pages`); among
    those candidates, the first one with the smallest aspect-ratio error wins.

    Also see: remove_empty_axes

    Inputs:  n:      number of subplots
             aspect: desired aspect ratio of each subplot
             width:  width of the (super)plot
             height: height of the (super)plot
             pages:  number of (super)plots (untested -- may not do what is intended)
    Outputs: (nr, nc): number of rows, number of columns; (1, 1) when n < 1
    """
    def cell_aspect(rows, cols):
        # Width over height of one cell of a rows x cols grid
        return ( width / cols ) / ( height / rows )

    chosen, smallest_error = (1, 1), float('inf')
    for cols in range(1, n + 1):
        rows = math.ceil( n / cols / pages ) * pages
        error = abs( cell_aspect(rows, cols) - aspect )
        if error < smallest_error:
            chosen, smallest_error = (rows, cols), error
    return chosen



def axis_year(ax, format="%Y"):
    """
    Date axis with yearly granularity: an unlabelled tick at the start of each
    year (longer at the start of each decade), and the year label in the middle
    of each year.

    Inputs: ax:     "axis", returned by plt.subplots()
            format: date format used for the year, e.g., "%Y" (4 digits) or "%y" (2 digits)
    Output: None

    Example:
        import yfinance as yf
        import matplotlib.pyplot as plt
        x = yf.Ticker("^GSPC").history(period="5y")['Close']
        fig, ax = plt.subplots()
        ax.plot(x)
        axis_year(ax)
        plt.show()
    """
    dates = matplotlib.dates
    xaxis = ax.xaxis

    # Major ticks on 1 January, without labels...
    xaxis.set_major_locator(dates.YearLocator())
    xaxis.set_major_formatter(matplotlib.ticker.NullFormatter())
    # ...and minor ticks on 1 July, carrying the year label.
    xaxis.set_minor_locator(dates.YearLocator(month=7, day=1))
    xaxis.set_minor_formatter(dates.DateFormatter(format))

    # Only the labels of the minor ticks are wanted: hide their tick marks.
    for minor_tick in xaxis.get_minor_ticks():
        minor_tick.tick1line.set_markersize(0)
        minor_tick.tick2line.set_markersize(0)
        minor_tick.label1.set_horizontalalignment('center')

    # Lengthen the bottom tick mark on the first year of each decade.
    major_ticks = xaxis.get_major_ticks()
    locations = ax.get_xticks()
    for location, major_tick in zip(locations, major_ticks):
        if dates.num2date(location).year % 10 == 0:
            major_tick.tick1line.set_markersize(15)



def axis_decade(ax):
    """
    On the horizontal axis, put a tick at the start of every decade and a
    decade label ("1980s", "1990s", ...) in the middle of each decade.

    Inputs: ax: "axis", returned by plt.subplots()
    Output: None

    Example:
        import yfinance as yf
        import matplotlib.pyplot as plt
        x = yf.Ticker("^GSPC").history(period="max")['Close']
        fig, ax = plt.subplots( figsize=(8,4) )
        ax.plot(x)
        axis_decade(ax)
        ax.set_yscale("log")
        plt.show()
    """
    # Major ticks: 1 January of every 10th year, without labels.
    ax.xaxis.set_major_locator(matplotlib.dates.YearLocator(10))
    # Minor ticks: 1 July of every 5th year; they only carry the labels.
    ax.xaxis.set_minor_locator(matplotlib.dates.YearLocator(5, month=7, day=1))
    ax.xaxis.set_major_formatter(matplotlib.ticker.NullFormatter())

    def decade_label(value, pos=None):
        """Label mid-decade minor ticks (e.g. mid-1985 -> '1980s'); leave the others blank."""
        year = matplotlib.dates.num2date(value).year
        if year % 10 == 5:
            return f"{year - 5}s"
        return ''
    # A plain function is accepted as a formatter and called as f(value, pos).
    ax.xaxis.set_minor_formatter(decade_label)

    # Hide the minor tick marks (only their labels are wanted) and centre the labels.
    for tick in ax.xaxis.get_minor_ticks():
        tick.tick1line.set_markersize(0)
        tick.tick2line.set_markersize(0)
        tick.label1.set_horizontalalignment('center')


def axis_month(ax, format = "%b"):
    """
    Date axis with monthly granularity: an unlabelled tick at the start of each
    month (longer at the start of January), and the month label on the 15th.

    Inputs: ax:     "axis", returned by plt.subplots()
            format: either a strftime-style format for the month labels
                    (default "%b"), or a matplotlib formatter, used as is
    Output: None
    """
    dates = matplotlib.dates
    xaxis = ax.xaxis

    # Major ticks on the 1st of each month, without labels...
    xaxis.set_major_locator(dates.MonthLocator())
    xaxis.set_major_formatter(matplotlib.ticker.NullFormatter())
    # ...and minor ticks on the 15th, carrying the labels.
    xaxis.set_minor_locator(dates.MonthLocator(bymonthday=15))
    label_formatter = dates.DateFormatter(format) if isinstance( format, str ) else format
    xaxis.set_minor_formatter(label_formatter)

    # Only the labels of the minor ticks are wanted: hide their tick marks.
    for minor_tick in xaxis.get_minor_ticks():
        minor_tick.tick1line.set_markersize(0)
        minor_tick.tick2line.set_markersize(0)
        minor_tick.label1.set_horizontalalignment('center')

    # Lengthen the bottom tick mark at the start of each January (year boundaries).
    major_ticks = xaxis.get_major_ticks()
    locations = ax.get_xticks()
    for location, major_tick in zip(locations, major_ticks):
        if dates.num2date(location).month == 1:
            major_tick.tick1line.set_markersize(15)




def remove_empty_axes(axs) -> None:
    """
    Hide the subplots that contain no data, e.g. the unused cells of a grid
    whose size was chosen with mfrow().

    Inputs: axs: the axes returned by plt.subplots(): a numpy array of any
                 shape, a (nested) list of axes, or a single Axes
                 (as returned when the grid is 1x1)
    Output: None (empty axes are switched off in place)
    """
    # asarray(..., dtype=object) turns a single Axes into a 0-d array and a
    # nested list into an array, so ravel() works for every accepted input.
    for ax in np.asarray(axs, dtype=object).ravel():
        if (not ax.lines) and (not ax.collections) and (not ax.has_data()):
            ax.axis('off')

In [None]:
# Display the gapminder table (last expression of the cell)
gapminder

# Data

In [None]:
# Histogram of simulated standard-normal data; the second figure adds a density estimate.
x = pd.Series( np.random.normal( size = 100_000 ) )
s = .4
for which in [0,1]:
    fig, ax = plt.subplots( figsize = (s*16,s*9), layout = 'constrained', dpi = 100 )
    ax.hist( x, bins = 50, density = True, facecolor = 'lightblue', edgecolor = 'tab:blue' )
    if which:
        x.plot.density( ax = ax, linewidth = 5 )
    # Only the horizontal axis is meaningful: drop the other spines and the y ticks.
    ax.spines[['left', 'right', 'top']].set_visible(False)
    ax.set_ylabel(None)
    ax.set_yticks([])
    ax.set_xlim( -4, 4 )
    ax.set_xlabel( "Random Gaussian data" )
    plt.show()

In [None]:
# Old Faithful geyser: histogram and density estimate of the time between eruptions.
# Column 1 of pydataset's "faithful" presumably holds the waiting time (consistent with the x range below).
x = data("faithful").iloc[:,1]
s = .4
fig, ax = plt.subplots( figsize = (s*16,s*9), layout = 'constrained', dpi = 100 )
ax.hist( x, edgecolor = 'tab:blue', facecolor = 'lightblue', bins = 20, density = True )
x.plot.density( ax = ax, linewidth = 5 )
for side in ['left', 'right', 'top']:
    ax.spines[side].set_visible(False)
ax.set_ylabel(None)
ax.set_yticks([])
ax.set_xlim( 40, 100 )
ax.set_xlabel( "Inter-eruption time (minutes), Old Faithful geyser" )
plt.show()

In [None]:
# Financial data
ids = {
    'Apple':     'AAPL',
    'Microsoft': 'MSFT',
    'Amazon':    'AMZN',
    'Google':    'GOOG',
    'Meta':      'META',
    'Tesla':     'TSLA',
    'Berkshire': 'BRK-B',
    'Visa':      'V',
    'NVidia':    'NVDA',
}
stocks = {}
for label, id in ids.items():
    print( f"{label}: {id}" )
    stocks[label] = yf.Ticker(id).history(period="max")

In [None]:
# One panel per stock: closing price over time, with decade ticks.
nr, nc = mfrow( len(stocks) )
s = .6
# squeeze=False: always a 2-D array of axes, even for a single stock
fig, axs = plt.subplots( nr, nc, figsize = (s*16,s*9), layout = 'constrained', squeeze = False )
for ax, (name, history) in zip( axs.flatten(), stocks.items() ):
    ax.plot( history['Close'] )
    ax.set_title( name )
    axis_decade(ax)
# Hide the unused cells when the grid has more cells than stocks
remove_empty_axes(axs)
plt.show()

In [None]:
# Google closing price: linear vertical axis, then logarithmic vertical axis.
s = .5
google = stocks['Google']['Close']
for which in [0,1]:
    fig, ax = plt.subplots( figsize = (s*16,s*9), layout = 'constrained', dpi = 100 )
    ax.plot( google )
    ax.set_xlabel( "Time" )
    ax.set_ylabel( "Google stock price" )
    axis_year( ax, '%y' )   # two-digit year labels
    if which == 0:
        plt.show()
        continue
    ax.set_yscale('log')
    remove_scientific_notation_from_vertical_axis(ax)
    plt.show()

In [None]:
# Apple closing price on a log scale, with decade ticks and plain-number labels.
s = .5
fig, ax = plt.subplots( figsize = (s*8,s*9), layout = 'constrained', dpi = 100 )
ax.plot( stocks['Apple']['Close'] )
ax.set_xlabel( 'Time' )
ax.set_ylabel( 'Apple stock price' )
ax.set_yscale( 'log' )
axis_decade( ax )
# Last, because it draws the figure to read the final tick labels
remove_scientific_notation_from_vertical_axis( ax )
plt.show()

In [None]:
def load_series( label ):
    """Return pydataset `label` as a Series: its second column, indexed by its first (presumably time)."""
    frame = data( label )
    return frame.set_index( frame.columns[0] )[ frame.columns[1] ]

s = .5
for label, ylabel in [
    ('AirPassengers', "Number of air passengers"),
    ('sunspots', "Monthly mean relative sunspot number"),
    ('Nile', "Nile Flow"),
]:
    x = load_series( label )
    fig, ax = plt.subplots( figsize = (s*16,s*9), layout = 'constrained', dpi = 100 )
    ax.plot( x )
    ax.set_ylabel( ylabel )
    ax.set_xlabel( "Time (Years)" )
    plt.show()

# Nile flow again, with the years up to 1898 and after 1898 as two separately coloured lines
label, ylabel = 'Nile', 'Nile Flow'
x = load_series( label )
fig, ax = plt.subplots( figsize = (s*16,s*9), layout = 'constrained', dpi = 100 )
ax.plot( x[ x.index >  1898 ] )
ax.plot( x[ x.index <= 1898 ] )
ax.set_ylabel( ylabel )
ax.set_xlabel( "Time (Years)" )
plt.show()

In [None]:
#data('sunspots', show_doc=True)

In [None]:
# Earthquake locations from the `quakes` dataset.
d = data('quakes')
s = .5
fig, ax = plt.subplots( figsize = (s*16,s*9), layout = 'constrained', dpi = 100 )
ax.scatter( d['long'], d['lat'], alpha = .5 )
ax.set_title( "Earthquakes in Fiji" )
ax.set_xlabel( "Longitude" )
ax.set_ylabel( "Latitude" )
# Same scale on both axes (one degree of longitude drawn as long as one degree of latitude)
ax.set_aspect(1)
plt.show()

In [None]:
prompts = """
Write matplotlib code to plot a map of the world, for longitudes 165 to 190, and latitudes -40 to -10.

Since the longitude goes above 180, this actually displays all longitudes, from -180 to +180.  How can I see only the longitudes 165 to 190?

How can I add a scatter plot on top of it? I have the coordinates in two numpy arrays, x (longitude) and y (latitude).
"""

import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature

# Earthquake locations on a map.  The projection is centred near the 180 degree
# meridian, so that an extent crossing it is displayed as a single region
# (rather than all longitudes from -180 to +180).
s = .5
central_lon = 177.5
fig = plt.figure(figsize=(s*16,s*9), dpi = 100)
ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=central_lon))
# [lon_min, lon_max, lat_min, lat_max], in plain longitude/latitude
ax.set_extent([160,190, -40, -10], crs=ccrs.PlateCarree())

# Background: coastlines and country borders
ax.coastlines()
ax.add_feature(cfeature.BORDERS)

# The points are given in plain longitude/latitude, hence the PlateCarree transform
ax.scatter(d['long'], d['lat'], color='tab:blue', alpha = .5, s = 10, transform=ccrs.PlateCarree(), label='Data Points')
ax.gridlines(draw_labels=True)
plt.show()

In [None]:
d = load_penguins()

islands = sorted( d['island'].unique() )
species = sorted( d['species'].unique() )
sex     = d['sex'].dropna().unique()

a = 'bill_length_mm'
b = 'bill_depth_mm'

# which = 0: all points in one colour; 1: one colour per species; 2: plus a regression line per species
colours = plt.rcParams['axes.prop_cycle'].by_key()['color']
for which in [0,1,2]:
    fig, ax = plt.subplots(figsize = (4,3), layout = 'constrained', dpi = 100)
    for j, (k,v) in enumerate(d.groupby('species')):
        x, y = v[a], v[b]
        i = np.isfinite(x) & np.isfinite(y)   # drop penguins with a missing measurement
        x, y = x[i], y[i]
        ax.scatter(
            x, y,
            label = k,
            color = colours[j] if which >= 1 else 'tab:blue',
            alpha = .5,
        )
        if which >= 2:
            # Least-squares line (np.polyfit, as elsewhere in the notebook: sklearn.linear_model
            # is never imported), drawn over the observed range of x only
            slope, intercept = np.polyfit( x, y, 1 )
            xs = np.array( [ x.min(), x.max() ] )
            ax.plot( xs, slope * xs + intercept, color = colours[j], linewidth = 5 )
    ax.set_xlabel(a)
    ax.set_ylabel(b)
    ax.set_title( "Penguin measurements" )
    if which >= 1:
        ax.legend()
    plt.show()


In [None]:
d = data("diamonds")
clarity_counts = d.groupby('clarity').size()

# Number of diamonds per clarity grade: vertical bars, horizontal bars,
# then horizontal bars sorted by count (smallest at the bottom).
for counts, kind, set_label in [
    ( clarity_counts,               'bar',  plt.ylabel ),
    ( clarity_counts,               'barh', plt.xlabel ),
    ( clarity_counts.sort_values(), 'barh', plt.xlabel ),
]:
    counts.plot( kind = kind )
    set_label( 'Number of diamonds' )
    plt.show()

In [None]:
if True:  # Do not use Counter() for the examples, that complicates things.

    from collections import Counter

    d = data("diamonds")
    x = d['clarity']
    s = .5

    # Counts per clarity grade, in order of first appearance
    c = Counter(x)
    values = list( c.keys() )
    counts = list( c.values() )

    fig, ax = plt.subplots( figsize = (s*8,s*9), layout = 'constrained', dpi = 100 )
    ax.bar(values, counts)
    ax.set_xlabel( "Clarity" )
    ax.set_ylabel( "Number of diamonds" )
    plt.show()

    fig, ax = plt.subplots( figsize = (s*8,s*9), layout = 'constrained', dpi = 100 )
    ax.barh(values, counts)   # draw on the axes just created, like the other plots
    ax.set_ylabel( "Clarity" )
    ax.set_xlabel( "Number of diamonds" )
    plt.show()

    # Same counts, from most to least frequent
    c = Counter(x)
    c = c.most_common()
    values = [ value for value, count in c ]
    counts = [ count for value, count in c ]
    fig, ax = plt.subplots( figsize = (s*8,s*9), layout = 'constrained', dpi = 100 )
    ax.bar(values, counts)
    ax.set_xlabel( "Clarity" )
    ax.set_ylabel( "Number of diamonds" )
    plt.show()

# More examples

## Iris

In [None]:
d = seaborn.load_dataset("iris")
s = .5      # figure scale, set here so that the cell does not depend on an earlier one
n = d.shape[0]
a = .02     # half-width of the uniform jitter used to separate overlapping points

def jitter( values ):
    """Add uniform noise in [-a, a] to each of the n values (uses numpy's global random state)."""
    return values + np.random.uniform( -a, +a, size = n )

# Plain scatter plot
fig, ax = plt.subplots( figsize = (s*16,s*9), layout = 'constrained', dpi = 300 )
ax.scatter( d['sepal_length'], d['sepal_width'], s = 100 )
ax.set_xlabel( "Sepal length" )
ax.set_ylabel( "Sepal width" )
ax.set_title( "Iris" )
plt.show()

# With jitter and transparency
np.random.seed(0)
fig, ax = plt.subplots( figsize = (s*16,s*9), layout = 'constrained', dpi = 300 )
ax.scatter( jitter( d['sepal_length'] ), jitter( d['sepal_width'] ), alpha = .5, s = 100 )
ax.set_xlabel( "Sepal length" )
ax.set_ylabel( "Sepal width" )
ax.set_title( "Iris" )
plt.show()

# Same (same seed, hence same jitter), coloured by species
colours = dict( zip(
    d['species'].unique(),
    ['tab:blue', 'tab:red', 'tab:green'],
) )
fig, ax = plt.subplots( figsize = (s*16,s*9), layout = 'constrained', dpi = 300 )
np.random.seed(0)
ax.scatter(
    jitter( d['sepal_length'] ),
    jitter( d['sepal_width'] ),
    c = d['species'].map(colours),
    alpha = .5,
    s = 100,
)
ax.set_xlabel( "Sepal length" )
ax.set_ylabel( "Sepal width" )
ax.set_title( "Iris" )
plt.show()

# Scatter plot matrix of the 4 measurements, without then with colours
for which in [0,1]:
    fig, axs = plt.subplots( 4, 4, figsize = (16,9), layout = 'constrained', dpi = 300 )
    for i in range(4):
        for j in range(4):
            if i == j:
                continue   # diagonal left empty (hidden below)
            axs[j,i].scatter(
                jitter( d.iloc[:,i] ),
                jitter( d.iloc[:,j] ),
                c = 'tab:blue' if which == 0 else d['species'].map(colours),
                alpha = .5,
                s = 100,
            )
            axs[j,i].set_xlabel( d.columns[i] )
            axs[j,i].set_ylabel( d.columns[j] )
    remove_empty_axes(axs)
    fig.suptitle( "Iris" )
    plt.show()



## USArrests

In [None]:
# US arrests per state (R dataset; the state names form the index)
d = get_rdataset("USArrests").data

In [None]:
# ChatGPT claims there is a linear relation between Murder and UrbanPop, with outliers.
fig, ax = plt.subplots( figsize = (s*16,s*9), layout = 'constrained', dpi = 300 )
ax.scatter( d['Murder'], d['UrbanPop'], alpha = .8 )
ax.set_xlabel( "Murder" )
ax.set_ylabel( "Urban population" )

# One label per state, with a thin translucent white outline to keep it readable over the points
texts = []
for state, row in d.iterrows():
    label = ax.text( row['Murder'], row['UrbanPop'], state, ha = 'center', va = 'center', fontsize = 6 )
    label.set_path_effects([
        matplotlib.patheffects.Stroke(linewidth=1, foreground='white', alpha = .5),
        matplotlib.patheffects.Normal()
    ])
    texts.append( label )

# Move the labels apart so that they do not overlap
adjust_text(texts, avoid_self = False, ax=ax)
plt.show()


## House prices (Ames)

In [None]:
# Read the ames dataset (house prices) from sklearn
# Ames house prices from OpenML: X = features (DataFrame), y = sale price
X, y = sklearn.datasets.fetch_openml(name="house_prices", return_X_y=True, as_frame=True)

In [None]:
# Find the column most correlated with the price
# (disabled: needs adia.vz.misc.uniformize)

if False:

    from adia.vz.misc import uniformize
    tmp = {}
    for column in X.columns:
        try:
            tmp[column] = uniformize(X[column].astype(float))
        except (TypeError, ValueError):
            # Columns that cannot be converted to float (presumably categorical ones) are skipped.
            pass
    tmp['SalePrice'] = uniformize(y.astype(float))
    tmp = pd.DataFrame(tmp)
    tmp.corr().iloc[:,-1].sort_values(ascending = False)

In [None]:
import statsmodels.api as sm

column = "GrLivArea"
# Houses where both the living area and the price are known: used for both fits
known = ~pd.isnull(X[column]) & ~pd.isnull(y)
area, price = X[column][known], y[known]

# which = 0: points only; 1: plus a straight-line fit; 2: plus a LOWESS smoother
for which in [0,1,2]:
    fig, ax = plt.subplots( figsize = (s*16,s*9), layout = 'constrained', dpi = 300 )
    ax.scatter( X[column], y, s = 10, alpha = .3 )
    ax.set_xlabel( column )
    ax.set_ylabel( "Sale price" )

    if which == 1:
        # Least-squares line, drawn as a thick translucent band behind the points
        coef = np.polyfit( area, price, 1 )
        xs = np.linspace( area.min(), area.max(), 200 )
        ys = np.polyval( coef, xs )
        ax.plot( xs, ys, color = 'tab:orange', linewidth = 20, alpha = .5, zorder = -4 )

    if which == 2:
        # LOWESS smoother using 10% of the points for each local fit; plotted as (x, smoothed y)
        lowess_fit = sm.nonparametric.lowess( price, area, frac = 0.1 )
        ax.plot( lowess_fit[:, 0], lowess_fit[:, 1], color = 'tab:orange', linewidth = 20, alpha = .8, zorder = -5 )

    ax.set_xlim( 0, 5800 )
    ax.set_title( "House prices (Ames, Iowa, 2006-2010)")
    plt.show()


## Anscombe

In [None]:
# Column maxima of `d` (NOTE(review): when run in order, `d` is still the USArrests table here -- leftover?)
d.max(axis = 0)

In [None]:
d = seaborn.load_dataset("anscombe")
s = .6
# which = 0: points only; 1: mean of y; 2: mean of x; 3: regression line; 4: summary statistics
for which in [0,1,2,3,4]:
    fig, axs = plt.subplots( 2, 2, figsize = (s*16,s*9), layout = 'constrained', dpi = 300 )
    for ax, label in zip( axs.flatten(), d['dataset'].unique() ):
        subset = d['dataset'] == label   # rows of this dataset (distinct from the panel index)
        x, y = d['x'][subset], d['y'][subset]
        ax.scatter( x, y, s = 100 )
        ax.set_xlabel( "x" )
        ax.set_ylabel( "y" )
        if which == 1:
            ax.axhline( y.mean(), color = 'black', linewidth = 1, linestyle = '--' )
        if which == 2:
            ax.axvline( x.mean(), color = 'black', linewidth = 1, linestyle = '--' )
        if which == 3:
            coef = np.polyfit( x, y, 1 )
            xs = np.linspace( d['x'].min(), d['x'].max(), 200 )
            ys = np.polyval(coef, xs)
            ax.plot( xs, ys, color = 'tab:orange', linewidth = 20, alpha = .5, zorder = -4 )
        if which == 4:
            text = (
                f"mean(x) = {x.mean():.2f}\n"
                f"mean(y) = {y.mean():.2f}\n"
                f"std(x)  = {x.std():.2f}\n"
                f"std(y)  = {y.std():.2f}\n"
                f"cor(x,y) = {x.corr(y):.2f}"
            )
            t = ax.text( .95, .05, text, ha = 'right', va = 'bottom', transform = ax.transAxes )
            t.set_path_effects([
                matplotlib.patheffects.Stroke(linewidth=3, foreground='white', alpha = .5),
                matplotlib.patheffects.Normal()
            ])
        # Same limits on every panel, so that the datasets can be compared
        ax.set_xlim( 3, 19.8 )
        ax.set_ylim( 2.5, 13.5)
    plt.show()


## Datasaurus

In [None]:
# Datasaurus Dozen (tab-separated): columns dataset, x, y
datasaurus_url = "https://github.com/algoplexity/DatasaurusDozen/raw/refs/heads/main/DatasaurusDozen.tsv"
d = pd.read_csv( datasaurus_url, sep = "\t" )

In [None]:
# which = 0: points only; 1: with the summary statistics of each dataset
# (previously `which` was unused and the same figure was drawn twice)
for which in [0,1]:
    fig, axs = plt.subplots( 3, 4, figsize = (16,9), layout = 'constrained', dpi = 300 )
    # NOTE: only the first 12 datasets (in order of appearance) fit in the 3x4 grid; any others are not drawn.
    for ax, label in zip( axs.flatten(), d['dataset'].unique() ):
        subset = d['dataset'] == label
        x, y = d['x'][subset], d['y'][subset]
        ax.scatter( x, y, s = 100 )
        ax.set_xlabel( "x" )
        ax.set_ylabel( "y" )
        if which >= 1:
            text = (
                f"mean(x) = {x.mean():.2f}\n"
                f"mean(y) = {y.mean():.2f}\n"
                f"std(x)  = {x.std():.2f}\n"
                f"std(y)  = {y.std():.2f}\n"
                f"cor(x,y) = {x.corr(y):.2f}"
            )
            t = ax.text( .95, .05, text, ha = 'right', va = 'bottom', transform = ax.transAxes )
            t.set_path_effects([
                matplotlib.patheffects.Stroke(linewidth=2, foreground='white', alpha = .5),
                matplotlib.patheffects.Normal()
            ])
    plt.show()


## CO2

In [None]:
# Atmospheric CO2 series shipped with statsmodels (date index, column 'co2')
d = statsmodels.datasets.co2.load_pandas().data
s = .4
co2 = d['co2']
fig, ax = plt.subplots( figsize = (s*16,s*9), layout = 'constrained', dpi = 300 )
ax.plot( co2 )
ax.set_xlim( co2.index.min(), co2.index.max() )
ax.set_ylabel( "CO2" )
axis_decade( ax )
plt.show()

In [None]:
# Seasonal CO2 cycle: one line per year, all drawn on a common calendar,
# coloured from blue (first year) to red (last year).
cmap = matplotlib.colormaps['RdYlBu_r']   # plt.cm.get_cmap was removed in Matplotlib 3.9
fig, ax = plt.subplots( figsize = (s*16,s*9), layout = 'constrained', dpi = 300 )
dates = d.index
# Move every date to the year 2000 (a leap year, so 29 February stays valid)
days = dates.map(lambda date: date.replace(year=2000))
year_of_date = pd.Series(dates).dt.year   # computed once, not once per year
years = year_of_date.unique()
for iyear, year in enumerate(years):
    i = ( year_of_date == year ).values
    ax.plot( days[i], d['co2'][i].ffill(), color = cmap(iyear/len(years)), label = year )
ax.set_xlim( days.min(), days.max() )
ax.set_ylabel( "CO2" )
axis_month( ax )
plt.show()


## (Missing datasets)


In [None]:
if False:   # Datasets suggested by ChatGPT; do not exist
    import vega_datasets
    # US birth vs day-of-year, then drinks
    for loader in ( vega_datasets.data.births, vega_datasets.data.drinks ):
        loader()

## Old Faithful Geyser

In [None]:
d = seaborn.load_dataset("geyser")
palette = plt.rcParams['axes.prop_cycle'].by_key()['color']
colours = dict( zip( d['kind'].unique(), palette ) )

# Old Faithful: waiting time vs duration, first in one colour, then coloured by `kind`
for coloured in (False, True):
    fig, ax = plt.subplots( figsize = (s*16,s*9), layout = 'constrained', dpi = 300 )
    if coloured:
        ax.scatter( d['waiting'], d['duration'], c = d['kind'].map(colours) )
    else:
        ax.scatter( d['waiting'], d['duration'] )
    ax.set_xlabel( "Waiting time to next eruption (minutes)" )
    ax.set_ylabel( "Eruption duration (minutes)" )
    ax.set_title( "Old Faithful geyser" )
    plt.show()

## Diamond prices

In [None]:
d = seaborn.load_dataset("diamonds")
fig, ax = plt.subplots( figsize = (s*16,s*9), layout = 'constrained', dpi = 300 )
# Many overlapping points: small, very transparent markers
ax.scatter( d['carat'], d['price'], s = 5, alpha = .1 )
ax.set_title( "Diamonds" )
ax.set_xlabel( "Carat" )
ax.set_ylabel( "Price" )
plt.show()

## Animals

In [None]:
d = get_rdataset("Animals", package="MASS").data  # columns: body, brain

In [None]:

# which = 0: linear axes; 1: log-log axes; 2: plus a power-law fit; 3: plus animal names
for which in [0,1,2,3]:
    fig, ax = plt.subplots( figsize = (s*16,s*9), layout = 'constrained', dpi = 300 )
    ax.scatter( d['body'], d['brain'], color = 'tab:blue' )
    # In MASS::Animals, body weight is in kg and brain weight in g.
    ax.set_xlabel( "Body weight (kg)" )
    ax.set_ylabel( "Brain weight (g)" )
    ax.set_title( "Animals" )
    if which >= 1:
        ax.set_xscale('log')
        ax.set_yscale('log')

    if which >= 2:
        # Straight-line fit in log-log space, leaving out the animals heavier than 9000
        # (presumably the dinosaurs, which do not follow the trend)
        light = d['body'] < 9000
        x = np.log10(d['body'][light])
        y = np.log10(d['brain'][light])
        slope, intercept = np.polyfit(x, y, 1)
        xs = np.linspace(x.min(), x.max(), 100)
        ys = slope * xs + intercept
        ax.plot(10**xs, 10**ys, color='tab:orange', lw=30, alpha = .5, zorder = -4, label=f'Linear fit (log-log)\ny={slope:.2f}x+{intercept:.2f}')

    if which >= 3:
        # The index holds the animal names: iterate over values instead of indexing
        # the name-indexed columns by position (deprecated in pandas).
        texts = [
            ax.text( body, brain, name, ha = 'center', va = 'center', fontsize = 6 )
            for name, body, brain in zip( d.index, d['body'], d['brain'] )
        ]
        adjust_text(texts, avoid_self = False, ax=ax)
        for t in texts:
            t.set_path_effects([
                matplotlib.patheffects.Stroke(linewidth=1, foreground='white', alpha = .5),
                matplotlib.patheffects.Normal()
            ])

    plt.show()

## Chicken growth

In [None]:
d = get_rdataset("ChickWeight", "datasets").data  # weight, Time, Diet, Chick
palette = plt.rcParams['axes.prop_cycle'].by_key()['color']
colours = dict( zip( d['Diet'].unique(), palette ) )

# One growth curve per chick: first all in the same colour, then coloured by diet
for which in [0,1]:

    fig, ax = plt.subplots( figsize = (s*16,s*9), layout = 'constrained', dpi = 300 )
    for chick in d['Chick'].unique():
        rows = d[ d['Chick'] == chick ]
        if which == 0:
            colour, alpha = 'tab:blue', .5
        else:
            colour, alpha = colours[ rows['Diet'].iloc[0] ], 1
        ax.plot( rows['Time'], rows['weight'], color = colour, alpha = alpha )
    if which >= 1:
        # Legend entries: empty, thick lines, one per diet
        for diet, color in colours.items():
            ax.plot([], [], color=color, linewidth = 10, solid_capstyle = 'butt', label=f'Diet {diet}')
        ax.legend()
    ax.set_xlabel( "Time (days)" )
    ax.set_ylabel( "Weight (g)" )
    ax.set_title( "Chick weights" )
    ax.set_xlim( d['Time'].min(), d['Time'].max() )
    ax.set_ylim( d['weight'].min(), d['weight'].max() )
    ax.set_xticks( [0, 5, 10, 15, 20] )
    plt.show()


In [None]:
# Mean and standard deviation of the weight for each diet and day
# (rows: Time, columns: Diet).  Only `weight` is aggregated, so that other,
# possibly non-numeric, columns do not make mean()/std() fail.
weights = d.groupby(['Diet', 'Time'])['weight']
mean = weights.mean().unstack().T
std = weights.std().unstack().T

# which = 0: means only; 1: means with a +/- one standard deviation band
for which in [0,1]:
    fig, ax = plt.subplots( figsize = (s*16,s*9), layout = 'constrained', dpi = 300 )
    for diet in d['Diet'].unique():
        ax.plot( mean.index, mean[diet], color = colours[diet] )
        if which == 1:
            ax.fill_between( mean.index, mean[diet] - std[diet], mean[diet] + std[diet], color = colours[diet], alpha = .2 )
    # Legend entries: empty, thick lines, one per diet
    for diet, color in colours.items():
        ax.plot([], [], color=color, linewidth = 10, solid_capstyle = 'butt', label=f'Diet {diet}')
    ax.legend()
    ax.set_xlabel( "Time (days)" )
    if which == 0:
        ax.set_ylabel( "Average Weight (g)" )
    else:
        ax.set_ylabel( "Weight (g)" )
    ax.set_title( "Chick weights" )
    ax.set_xlim( d['Time'].min(), d['Time'].max() )
    ax.set_ylim( d['weight'].min(), d['weight'].max() )
    ax.set_xticks( [0, 5, 10, 15, 20] )
    plt.show()



## (Unused)

In [None]:
d = sns.load_dataset("tips")  # restaurant tips (currently unused)

In [None]:
# UsingR::galton

## Car stopping distances

In [None]:
from numpy.linalg import lstsq

d = get_rdataset("cars", "datasets").data  # columns: speed, dist

# Convert speed from mph to km/h and distance from ft to m for metric plot
d_metric = d.copy()
d_metric['speed_metric'] = d_metric['speed'] * 1.60934  # mph to km/h
d_metric['dist_metric'] = d_metric['dist'] * 0.3048     # ft to m

# which = 0: points only; 1: plus a quadratic fit through the origin
for which in [0,1]:

    fig, ax = plt.subplots( figsize = (s*16,s*9), layout = 'constrained', dpi = 300 )
    x, y = d_metric['speed_metric'], d_metric['dist_metric']
    ax.scatter(x, y)

    if which >= 1:
        # Fit dist = quad*speed^2 + lin*speed, without intercept (zero speed, zero distance).
        # The design matrix has its own name so that it does not overwrite the
        # house-price features `X` loaded earlier in the notebook.
        design = np.column_stack([x**2, x])
        (quad, lin), *_ = lstsq(design, y, rcond=None)
        xs = np.linspace(0, x.max(), 200)
        ys = quad*xs**2 + lin*xs
        ax.plot(xs, ys, color='tab:orange', linewidth=10, alpha = .5, zorder = -4,label='Quadratic fit (no intercept)')

    # Dotted lines through the origin
    ax.axhline( 0, color = 'black', linewidth = 1, linestyle = ':' )
    ax.axvline( 0, color = 'black', linewidth = 1, linestyle = ':' )
    ax.set_xlabel("Speed (km/h)")
    ax.set_ylabel("Stopping distance (m)")
    ax.set_title("Car stopping distances" )
    ax.set_ylim( -2, 39 )
    plt.show()

if False:   # same plot in the original units
    fig, ax = plt.subplots( figsize = (s*16,s*9), layout = 'constrained', dpi = 300 )
    ax.scatter( d['speed'], d['dist'] )
    ax.set_xlabel( "Speed (mph)" )
    ax.set_ylabel( "Stopping distance (ft)" )
    ax.set_title( "Car stopping distances" )
    plt.show()

In [None]:
# Character network of "A Song of Ice and Fire" (edge list with Source and Target columns)
asoiaf_data = "https://github.com/mathbeveridge/asoiaf/raw/refs/heads/master/data"
edges = pd.read_csv(f"{asoiaf_data}/asoiaf-all-edges.csv")

In [None]:
import networkx as nx

# Undirected graph built from the first two columns (Source, Target) of the edge list
g = nx.from_pandas_edgelist(edges.iloc[:,:2], source = 'Source', target = 'Target')
degrees = dict( nx.degree(g) )
n = max( degrees.values() )
# One bin per integer number of connections, centred on that integer
bins = np.linspace( -.5, n+.5, n+2 )
# The 14 most connected characters, labelled on the second figure
most_connected = pd.Series( degrees ).sort_values( ascending = False ).head(14).to_dict()

for which in [0,1]:
    fig, ax = plt.subplots( figsize = (s*16,s*9), layout = 'constrained', dpi = 300 )
    ax.hist( degrees.values(), bins = bins )
    if which == 1:
        gap = 10   # vertical spacing between successive labels
        m = 180    # height of the first (highest) label
        for rank, (key, value) in enumerate( most_connected.items() ):
            top = m - gap*rank
            ax.text( value, top, key, ha = 'right', va = 'bottom', fontsize = 10 )
            ax.plot( [ value, value ], [ 0, top ], color = 'black', linewidth = 1, linestyle = ':' )
    ax.set_title( "Game of Thrones characters" )
    ax.set_xlabel( "Number of connections" )
    ax.set_ylabel( "Number of characters with that many connections" )
    ax.set_xlim( 0, n+1 )
    plt.show()

# Data manipulation

In [None]:
gapminder = pd.read_csv( "gapminder.csv" )  # the copy written by the import cell

# Line plot

In [None]:
gapminder['country'].unique()  # List of all the countries available
country = 'United Kingdom'
i = gapminder['country'] == country
rows = gapminder[i]

s = .5
fig, ax = plt.subplots( figsize = (s*8,s*9), layout = 'constrained', dpi = 100 )
ax.plot( rows['year'], rows['lifeExp'] )
ax.set_xlabel( 'Year' )
ax.set_ylabel( f"Life expectancy for {country}" )
plt.show()

In [None]:

# Sum of the countries' populations for each survey year.
population = gapminder.groupby('year')['pop'].sum()
# Quick version:
#plt.plot( population )

s = .5
fig, ax = plt.subplots( figsize = (s*8,s*9), layout = 'constrained', dpi = 300 )
ax.plot( population / 1e9 )   # Series: x = year (index), y = billions
ax.set( xlabel = 'Year', ylabel = "World population (billion)" )
plt.show()

# Scatter plots

In [None]:
# Quick version:
#plt.scatter( gapminder['gdpPercap'], gapminder['lifeExp'] )

# GDP in thousands of USD on a log axis; translucent markers show where points pile up.
s = .5
fig, ax = plt.subplots( figsize = (s*8,s*9), layout = 'constrained', dpi = 100 )
ax.scatter( gapminder['gdpPercap'] / 1000, gapminder['lifeExp'], alpha = .3 )
ax.set(
    xlabel = 'GDP per capita (thousand USD)',
    ylabel = "Life expectancy (years)",
    xscale = 'log',
)
remove_scientific_notation_from_horizontal_axis(ax)
plt.show()

# Histogram

In [None]:
# Quick versions:
#gapminder['lifeExp'].plot.hist( density = True )
#gapminder['lifeExp'].plot.density()
#plt.hist( gapminder['lifeExp'] )
#plt.show()

# Histogram (5-year bins from 20 to 85) with a density estimate on top;
# the y axis is hidden since only the shape matters.
s = .5
fig, ax = plt.subplots( figsize = (s*8,s*9), layout = 'constrained', dpi = 100 )
#ax.hist( gapminder['lifeExp'], density = True, facecolor = 'lightblue', edgecolor = 'tab:blue', bins = np.linspace( 15, 85, 8 ) )
ax.hist( gapminder['lifeExp'], density = True, facecolor = 'lightblue', edgecolor = 'tab:blue', bins = np.linspace( 20, 85, 14 ) )
gapminder['lifeExp'].plot.density( ax = ax, linewidth = 3 )
ax.set(
    xlim   = (15, 90),
    xlabel = "Life expectancy (years)",
    ylabel = None,
    yticks = [],
)
ax.spines[['left', 'right', 'top']].set_visible( False )
plt.show()

In [None]:
first_date = gapminder['year'].min()
last_date  = gapminder['year'].max()
start = gapminder['year'] == first_date
end   = gapminder['year'] == last_date
# Quick versions:
#gapminder[start]['lifeExp'].hist()
#gapminder[end]['lifeExp'].hist()

# Overlay the first and the last year: translucent histograms, then one
# density curve per year (coloured like the histogram edges) for the legend.
overlays = [
    ( start, first_date, 'lightblue', 'tab:blue' ),
    ( end,   last_date,  'mistyrose', 'tab:red'  ),
]
s = .5
fig, ax = plt.subplots( figsize = (s*8,s*9), layout = 'constrained', dpi = 100 )
#ax.hist( gapminder['lifeExp'], density = True, facecolor = 'lightblue', edgecolor = 'tab:blue', bins = np.linspace( 15, 85, 8 ) )
for rows, when, face, edge in overlays:
    ax.hist( gapminder[rows]['lifeExp'], density = True, facecolor = face, edgecolor = edge, alpha = .5, bins = np.linspace( 20, 85, 14 ) )
for rows, when, face, edge in overlays:
    gapminder['lifeExp'][rows].plot.density( ax = ax, linewidth = 3, color = edge, label = when )
ax.set(
    xlim   = (15, 90),
    xlabel = "Life expectancy (years)",
    ylabel = None,
    yticks = [],
)
ax.spines[['left', 'right', 'top']].set_visible( False )
ax.legend()
plt.show()

In [None]:
# One panel per survey year: life-expectancy histogram plus density estimate,
# with the local maxima (modes) of the density marked by dotted stems.
years = gapminder['year'].unique()
# mfrow (defined earlier in the notebook) presumably returns a (rows, cols)
# grid with at least len(years) cells — confirm.
nr, nc = mfrow( len(years) )
s = 1
# squeeze = False keeps axs 2-D even for a single row/column (or a 1x1 grid).
fig, axs = plt.subplots( nr, nc, figsize = (16*s,9*s), dpi = 200, squeeze = False )
panels = axs.flatten()
a = 5  # half-width, in grid points, of the window used to detect modes
for ax, year in zip( panels, years ):
    j = gapminder['year'] == year
    ax.hist( gapminder[j]['lifeExp'], density = True, facecolor = 'lightblue', edgecolor = 'tab:blue',  alpha = .5, bins = np.linspace( 20, 85, 14 ) )
    gapminder['lifeExp'][j].plot.density( ax = ax, linewidth = 3, color = 'tab:blue', label = year )

    # The density curve drawn just above is the first Line2D on this axis.
    x, y = ax.get_lines()[0].get_data()
    # A grid point is a mode if it is >= every point within `a` positions on
    # either side; the window [u-a, u+a] is inclusive on both ends.
    peaks = [ u for u in range(len(y)) if np.all( y[u] >= y[ max(0,u-a) : u+a+1 ] ) ]
    ax.scatter( x[peaks], y[peaks] )
    for u in peaks:
        ax.plot( [ x[u], x[u] ], [ 0, y[u] ], linewidth = 2, color = 'tab:blue', linestyle = ':' )

    ax.set_xlim( 15, 90 )
    ax.set_xlabel( "Life expectancy (years)" )
    ax.set_ylabel( None )
    ax.set_title( year )
    for side in ['left', 'right', 'top']:
        ax.spines[side].set_visible(False)
    ax.set_yticks([])
# Hide grid cells left over when the grid has more cells than years.
for ax in panels[ len(years): ]:
    ax.set_visible( False )
fig.subplots_adjust( hspace = 1, wspace = .1 )
plt.show()

In [None]:
# Quick versions:
#gapminder['pop'].hist( bins = 100 )
#gapminder['pop'].plot.density()

# Country populations (all years pooled), in billions, with 100 bins;
# the y axis is hidden since only the (very skewed) shape matters.
s = .5
fig, ax = plt.subplots( figsize = (s*8,s*9), layout = 'constrained', dpi = 200 )
ax.hist( gapminder['pop'] / 1e9, density = True, facecolor = 'lightblue', edgecolor = 'tab:blue', bins = 100 )
#(gapminder['pop'] / 1e9).plot.density( ax = ax, linewidth = 3 )
#ax.set_xlim( -.2, 1.4 )
ax.set( xlabel = "Country population (billion)", ylabel = None )
ax.spines[['left', 'right', 'top']].set_visible( False )
ax.set_yticks( [] )
plt.show()

# Bar plots

In [None]:
# Total population per continent in the most recent year, largest first.
# This is the *latest* year, so it is stored as last_date; `first_date` keeps
# meaning the earliest year, as set in the overlay cell above.
last_date = gapminder['year'].max()
end = gapminder['year'] == last_date
population = gapminder[end].groupby('continent')['pop'].sum().sort_values( ascending = False )
# Quick version:
#plt.bar(population.index, population.values)

s = .5
fig, ax = plt.subplots( figsize = (s*8,s*9), layout = 'constrained', dpi = 100 )
ax.bar( population.index, population.values / 1e9 )
ax.set_ylabel( "Population (billion)" )
plt.show()

In [None]:
a = 20  # number of countries shown individually; all others are pooled into "Rest"
last_date = gapminder['year'].max()
end = gapminder['year'] == last_date

# Countries of the latest year by population, largest first, then the tail
# collapsed into one "Rest" row; reversed so the largest bar ends up on top.
population = gapminder[end][['country','pop']].sort_values( 'pop', ascending = False )
rest = pd.DataFrame( [ { 'country': 'Rest', 'pop': population['pop'].iloc[a:].sum() } ] )
population = pd.concat( [ population.iloc[:a], rest ] ).iloc[::-1,:]
# Quick version:
#plt.barh( population['country'], population['pop'] )

s = .5
fig, ax = plt.subplots( figsize = (s*8,s*9), layout = 'constrained', dpi = 100 )
ax.barh( population['country'], population['pop'] / 1e9 )
ax.set( xlabel = "Population (billion)", ylim = (-.8, a+.8) )
ax.spines[['top', 'right', 'left']].set_visible( False )
ax.tick_params( length = 0, axis = 'y' )
plt.show()

# Pair plot

In [None]:
# Scatter-plot matrix of every numeric column against every other.
sns.pairplot( data = gapminder )

In [None]:
# Exploration. By default a notebook cell only displays its last expression,
# so the first two lookups are evaluated but not shown.
gapminder.sort_values( 'gdpPercap', ascending = False ).head(20)

gapminder.sort_values( 'lifeExp' ).head(3)

# For each year, the three countries with the lowest life expectancy,
# as a year x country table, plus the row-wise minimum.
#gapminder.groupby('year').sort_values('lifeExp').head(2)
a = (
    gapminder
    .sort_values( 'lifeExp' )
    .groupby( 'year' )
    .head(3)[['year', 'country', 'lifeExp']]
    .sort_values( 'year' )
    .pivot( index = 'year', columns = 'country', values = 'lifeExp' )
)
a['min'] = a.min(axis=1)
a

In [None]:
# Life expectancy over time for a few countries with dramatic dips.
s = .5
fig, ax = plt.subplots( figsize = (s*8,s*9), layout = 'constrained', dpi = 100 )
#for country in ['Cambodia', 'Rwanda', 'China']:
for country in ['Cambodia', 'Rwanda' ]:
    i = gapminder['country'] == country
    ax.plot( gapminder.loc[i, 'year'], gapminder.loc[i, 'lifeExp'], label = country, linewidth = 3 )
ax.set(
    xlabel = "Year",
    ylabel = "Life expectancy (years)",
    ylim   = (20, 75),
)
ax.legend( loc = 'upper left' )
plt.show()

In [None]:
# Population over time for every country: thin translucent grey lines for
# context, with China and India drawn thick, coloured and labelled.
s = .5
fig, ax = plt.subplots( figsize = (s*8,s*9), layout = 'constrained', dpi = 100 )
for country in gapminder['country'].unique():
    i = gapminder['country'] == country
    highlight = country in ['China', 'India']
    if highlight:
        style = dict( label = country, color = None, linewidth = 3, alpha = 1 )
    else:
        style = dict( label = None, color = 'silver', linewidth = 1, alpha = .5 )
    ax.plot( gapminder.loc[i, 'year'], gapminder.loc[i, 'pop'] / 1e9, **style )
ax.set( xlabel = "Year", ylabel = "Population (billion)" )
#ax.set_yscale('log')
ax.legend()
plt.show()

In [None]:
# Quick look at the categorical columns and the table (only the last line is displayed).
gapminder.country.unique()
gapminder.continent.unique()
gapminder

In [None]:
# GDP per capita over time for a single country.
s = .5
fig, ax = plt.subplots( figsize = (s*8,s*9), layout = 'constrained', dpi = 100 )
country = 'Kuwait'
i = gapminder['country'] == country
ax.plot( gapminder.loc[i, 'year'], gapminder.loc[i, 'gdpPercap'], label = country, linewidth = 3 )
ax.set( xlabel = "Year", ylabel = "GDP per Capita (USD)" )
#ax.set_ylim( 20, 75 )
ax.legend()
plt.show()

In [None]:
# Numeric columns only. select_dtypes excludes text columns whether they are
# stored as `object` or as the pandas >= 3 `str` dtype (which `dtype != object`
# would wrongly keep).
columns = gapminder.select_dtypes( include = 'number' ).columns

In [None]:
# Numeric columns only. select_dtypes excludes text columns whether they are
# stored as `object` or as the pandas >= 3 `str` dtype; with `dtype != object`
# they would be kept and .corr() would fail on them.
columns = gapminder.select_dtypes( include = 'number' ).columns
matplotlib.rcParams['figure.dpi'] = 300   # global setting: affects every later figure
# Hierarchically clustered correlation matrix of the numeric columns, on a
# fixed [-1, 1] diverging colour scale, annotated with 2-decimal coefficients.
sns.clustermap(
    gapminder[columns].corr(),
    vmin    = -1,
    vmax    = +1,
    cmap    = 'RdBu',
    figsize = (5, 5),
    annot   = True,
    fmt     = ".2f",
)

# LLM-generated code


In [None]:
# Prompt for the LLM comparison below. The following cells hold the code each
# model returned (model name in the cell's first comment, where given); they
# all plot the penguins frame `d` loaded in the next cell.
prompt = """
I have a pandas data-frame, `d`, with the following columns: "species", "island", "bill_length_mm", "bill_depth_mm".
Give me matplotlib code to plot the data, with:
- bill_length_mm on the horizontal axis,
- bill_depth_mm on the vertical axis,
- a different colour for each species,
- a different symbol for each island,
- a linear fit for each species,
- and a legend.
"""

In [None]:
# Load the Palmer penguins data as `d`, the frame the prompt above refers to.
# This rebinds `d`, which earlier cells used for a different dataset.
from palmerpenguins import load_penguins
d = load_penguins()

In [None]:
import matplotlib
# Render every subsequent figure at 300 dpi (global setting).
matplotlib.rcParams.update( { 'figure.dpi': 300 } )

In [None]:
# claude-sonnet-4.5
# NOTE(review): this cell shows the model's answer; comments marked
# NOTE(review) are reviewer annotations, the code is left untouched.

import matplotlib.pyplot as plt
import numpy as np
from scipy import stats

# Get unique species and islands
species_list = d['species'].unique()
islands_list = d['island'].unique()

# Define colors and markers
# NOTE(review): tab10 is sampled at evenly spaced positions in [0, 1], so with
# k species this picks k spread-out entries (indices 0, 5, 9 for k = 3) rather
# than the first k default colours.
colors = plt.cm.tab10(np.linspace(0, 1, len(species_list)))
markers = ['o', 's', '^', 'D', 'v', '<', '>', 'p', '*', 'h'][:len(islands_list)]

# Create the plot
fig, ax = plt.subplots(figsize=(10, 6))

# Plot data points for each species-island combination
for i, species in enumerate(species_list):
    for j, island in enumerate(islands_list):
        mask = (d['species'] == species) & (d['island'] == island)
        subset = d[mask]
        
        if len(subset) > 0:
            ax.scatter(subset['bill_length_mm'], 
                      subset['bill_depth_mm'],
                      c=[colors[i]], 
                      marker=markers[j],
                      s=100,
                      alpha=0.6,
                      label=f'{species} - {island}')

# Add linear fit for each species
# NOTE(review): rows with a missing bill measurement are dropped before fitting.
for i, species in enumerate(species_list):
    species_data = d[d['species'] == species].dropna(subset=['bill_length_mm', 'bill_depth_mm'])
    
    if len(species_data) > 1:
        x = species_data['bill_length_mm']
        y = species_data['bill_depth_mm']
        
        # Linear regression
        slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)
        
        # Create line
        x_line = np.linspace(x.min(), x.max(), 100)
        y_line = slope * x_line + intercept
        
        ax.plot(x_line, y_line, 
               color=colors[i], 
               linestyle='--', 
               linewidth=2,
               label=f'{species} fit')

# Labels and legend
# NOTE(review): the legend gets one entry per observed species-island pair plus
# one per fit line, placed outside the axes on the right.
ax.set_xlabel('Bill Length (mm)', fontsize=12)
ax.set_ylabel('Bill Depth (mm)', fontsize=12)
ax.set_title('Bill Dimensions by Species and Island', fontsize=14)
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# NOTE(review): second version of the answer above (presumably a follow-up from
# the same model): points carry no labels; instead two hand-built legends
# encode species (colour) and island (marker) separately. Code left untouched.
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
from matplotlib.lines import Line2D

species_list = d['species'].unique()
islands_list = d['island'].unique()

colors = plt.cm.tab10(np.linspace(0, 1, len(species_list)))
markers = ['o', 's', '^', 'D', 'v', '<', '>', 'p'][:len(islands_list)]

fig, ax = plt.subplots(figsize=(10, 6))

# Plot points
for i, species in enumerate(species_list):
    for j, island in enumerate(islands_list):
        mask = (d['species'] == species) & (d['island'] == island)
        subset = d[mask]
        if len(subset) > 0:
            ax.scatter(subset['bill_length_mm'], subset['bill_depth_mm'],
                      c=[colors[i]], marker=markers[j], s=100, alpha=0.6)

# Add linear fits
for i, species in enumerate(species_list):
    species_data = d[d['species'] == species].dropna(subset=['bill_length_mm', 'bill_depth_mm'])
    if len(species_data) > 1:
        x = species_data['bill_length_mm']
        y = species_data['bill_depth_mm']
        slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)
        x_line = np.linspace(x.min(), x.max(), 100)
        y_line = slope * x_line + intercept
        ax.plot(x_line, y_line, color=colors[i], linestyle='--', linewidth=2)

# Create custom legend
species_handles = [Line2D([0], [0], color=colors[i], linewidth=2, label=species) 
                  for i, species in enumerate(species_list)]
island_handles = [Line2D([0], [0], marker=markers[j], color='gray', linestyle='', 
                        markersize=8, label=island) 
                 for j, island in enumerate(islands_list)]

# NOTE(review): each ax.legend() call replaces the axes' legend, so the first
# one is re-attached with add_artist to keep both on the plot.
first_legend = ax.legend(handles=species_handles, title='Species', loc='upper left')
ax.add_artist(first_legend)
ax.legend(handles=island_handles, title='Island', loc='upper right')

ax.set_xlabel('Bill Length (mm)', fontsize=12)
ax.set_ylabel('Bill Depth (mm)', fontsize=12)
ax.set_title('Bill Dimensions by Species and Island', fontsize=14)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Gemini-3-pro
# NOTE(review): this cell shows the model's answer; comments marked
# NOTE(review) are reviewer annotations, the code is left untouched.

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Assuming your dataframe is named 'd'
# Example data loading (for reproducibility, you don't need this part):
# import seaborn as sns
# d = sns.load_dataset('penguins')

fig, ax = plt.subplots(figsize=(10, 6))

# Get unique categories
species_list = d['species'].unique()
island_list = d['island'].unique()

# Define color map and marker map
colors = plt.cm.tab10(np.linspace(0, 1, len(species_list)))
color_map = dict(zip(species_list, colors))
marker_map = dict(zip(island_list, ['o', '^', 's', 'D', 'v'][:len(island_list)]))

# 1. Scatter Plot (Looping to handle legend handles correctly)
# NOTE(review): the labels set here end up unused, because the final
# ax.legend(handles=...) call below only shows the hand-built handles.
for species in species_list:
    for island in island_list:
        subset = d[(d['species'] == species) & (d['island'] == island)]
        if not subset.empty:
            ax.scatter(
                subset['bill_length_mm'], 
                subset['bill_depth_mm'], 
                color=color_map[species], 
                marker=marker_map[island],
                label=f"{species} - {island}",
                alpha=0.7,
                edgecolor='k' # Adds a slight border to make points pop
            )

# 2. Linear Fit for each Species
for species in species_list:
    subset = d[d['species'] == species].dropna(subset=['bill_length_mm', 'bill_depth_mm'])
    
    if len(subset) > 1:
        # Calculate linear regression (m=slope, b=intercept)
        m, b = np.polyfit(subset['bill_length_mm'], subset['bill_depth_mm'], 1)
        
        # Create x-values for the line
        x_vals = np.linspace(subset['bill_length_mm'].min(), subset['bill_length_mm'].max(), 100)
        y_vals = m * x_vals + b
        
        ax.plot(x_vals, y_vals, color=color_map[species], linestyle='--', linewidth=2)

# Labeling
ax.set_xlabel('Bill Length (mm)')
ax.set_ylabel('Bill Depth (mm)')
ax.set_title('Bill Dimensions by Species and Island')

# Create a custom legend to separate Species (Color) and Island (Shape)
# This is cleaner than listing every single combination
from matplotlib.lines import Line2D

# NOTE(review): single legend: species lines, a blank white "spacer" entry,
# then island markers.
legend_elements = [Line2D([0], [0], color=color_map[s], lw=2, label=s) for s in species_list]
legend_elements.append(Line2D([0], [0], color='white', label=' ')) # Spacer
legend_elements += [Line2D([0], [0], marker=marker_map[i], color='w', markerfacecolor='gray', markersize=8, markeredgecolor='k', label=i) for i in island_list]

ax.legend(handles=legend_elements, title="Legend", loc='best')

plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# NOTE(review): seaborn variant of the answer (no model name given; presumably
# the same model's alternative). Code left untouched. The regplot lines use
# sns.color_palette()[i] for the i-th entry of d['species'].unique(), which
# assumes scatterplot assigned the default palette to species in that same
# order — confirm.
import seaborn as sns
import matplotlib.pyplot as plt

# Assuming your dataframe is named 'd'

plt.figure(figsize=(10, 6))

# lmplot automatically handles scatter points and linear regression lines
# However, combining hue and style while keeping lines only for hue is tricky in lmplot.
# It is often easier to use scatterplot for points + regplot for lines.

# 1. Plot the points
sns.scatterplot(
    data=d, 
    x="bill_length_mm", 
    y="bill_depth_mm", 
    hue="species", 
    style="island",
    s=100, # size
    alpha=0.7
)

# 2. Add the linear regression lines (only colored by species)
# We loop through species to add the lines on top
colors = sns.color_palette()
species_list = d['species'].unique()

for i, species in enumerate(species_list):
    subset = d[d['species'] == species]
    sns.regplot(
        data=subset, 
        x="bill_length_mm", 
        y="bill_depth_mm", 
        scatter=False,    # Don't plot points again
        color=colors[i],  # Match seaborn default colors
        ci=None           # Remove confidence interval shading (optional)
    )

plt.title("Bill Dimensions by Species and Island")
plt.show()

In [None]:
# GPT-5.1 -- fails
# NOTE(review): kept as the model returned it (it is the failing example).
# Unlike the other answers, the fit below does not drop rows with missing bill
# measurements before np.polyfit. That is a likely cause of the failure if `d`
# contains NaNs — confirm.

import matplotlib.pyplot as plt
import numpy as np

# Assume your DataFrame is called d and has columns:
# "species", "island", "bill_length_mm", "bill_depth_mm"

fig, ax = plt.subplots(figsize=(8, 6))

# Distinct markers for islands and colors for species
markers = ['o', 's', '^', 'D', 'v', 'P', 'X', '*']
islands = d['island'].unique()
marker_map = {isl: markers[i % len(markers)] for i, isl in enumerate(islands)}

species_list = d['species'].unique()
colors = plt.cm.tab10.colors  # or any colormap you like
color_map = {sp: colors[i % len(colors)] for i, sp in enumerate(species_list)}

# Scatter plot with color by species and marker by island
for sp in species_list:
    for isl in islands:
        mask = (d['species'] == sp) & (d['island'] == isl)
        if not mask.any():
            continue
        ax.scatter(
            d.loc[mask, 'bill_length_mm'],
            d.loc[mask, 'bill_depth_mm'],
            color=color_map[sp],
            marker=marker_map[isl],
            label=f"{sp}, {isl}",
            alpha=0.8
        )

# Linear fit per species
# NOTE(review): every species' line is drawn over the bill-length range of the
# whole data set, so each fit is extrapolated beyond its own species' range.
x_fit = np.linspace(d['bill_length_mm'].min(),
                    d['bill_length_mm'].max(), 100)

for sp in species_list:
    mask = d['species'] == sp
    x = d.loc[mask, 'bill_length_mm'].values
    y = d.loc[mask, 'bill_depth_mm'].values
    if len(x) < 2:
        continue  # not enough points to fit
    coeffs = np.polyfit(x, y, 1)
    y_fit = np.polyval(coeffs, x_fit)
    ax.plot(
        x_fit, y_fit,
        color=color_map[sp],
        linestyle='-',
        linewidth=2,
        label=f"{sp} fit"
    )

ax.set_xlabel("Bill length (mm)")
ax.set_ylabel("Bill depth (mm)")
ax.set_title("Bill length vs depth by species and island")

# To avoid duplicate labels in legend
handles, labels = ax.get_legend_handles_labels()
unique = dict(zip(labels, handles))
ax.legend(unique.values(), unique.keys(), fontsize='small', ncol=2)

plt.tight_layout()
plt.show()

# TO DELETE

In [None]:
import numpy as np
from itertools import combinations

# Print TikZ code for the Hasse diagram of the subsets of {1,2,3,4}:
# row y = 4-k holds the subsets of size k, centred horizontally, and an edge
# joins two subsets when the larger one adds exactly one element.
labels = {}
coords = {}
for k in range(5):
    row = [ 'X' + ''.join(combo) for combo in combinations( "1234", k ) ]
    labels[k] = row
    offset = ( len(row) - 1 ) / 2
    for position, name in enumerate(row):
        coords[name] = ( position - offset, 4 - k )

open_set, close_set = r'{\{', r'\}}'
print( r"\begin{tikzpicture}[>=stealth,xscale=1,yscale=1,inner sep=1pt,outer sep=0pt,baseline=0]" )
for row in labels.values():
    for name in row:
        members = ','.join( name.replace('X', '') )
        print( rf'\node ({name}) at {coords[name]} {open_set}{members}{close_set};' )

digits = { name: set( name.replace('X', '') ) for name in coords }
for lower, lower_digits in digits.items():
    for upper, upper_digits in digits.items():
        if len(upper_digits) == len(lower_digits) + 1 and lower_digits <= upper_digits:
            print( rf'\draw ({lower}) -- ({upper});' )
print( r'\end{tikzpicture}' )