In [35]:
import os, sys

sys.path.append(os.path.abspath(os.path.join("../..")))
from utils import *

In [36]:
# Load the "anscombe" example data (sns is presumably seaborn, pulled in by
# the star import from utils — confirm) and split it into one frame per
# distinct value of the `dataset` column, in order of first appearance.
df = sns.load_dataset("anscombe")
dataset_names = df["dataset"].unique()
datasets = [df.loc[df["dataset"] == label] for label in dataset_names]

In [37]:
# Plot Anscombe's quartet with plotly: one panel per dataset in a 2x2 grid,
# each showing the raw points and their least-squares regression line.
from plotly.subplots import make_subplots
from sklearn.linear_model import LinearRegression

fig = make_subplots(rows=2, cols=2, subplot_titles=dataset_names)

# Every panel draws its fitted line over the same x grid, so build it once.
x = np.linspace(0, 20, 100)

for i, dataset in enumerate(datasets):
    # Grid position of the i-th dataset (row-major, 1-based for plotly).
    # The 2x2 layout only has room for four datasets.
    row, col = i // 2 + 1, i % 2 + 1

    # Fit y = a * x + b. sklearn needs a 2-D feature matrix; with a 1-D target
    # coef_ holds one entry per feature and intercept_ is a scalar.
    model = LinearRegression()
    model.fit(dataset.x.values.reshape(-1, 1), dataset.y.values)

    a = model.coef_[0]
    b = model.intercept_
    y = a * x + b

    # Regression line. COLOR_LIST is defined outside this cell (presumably in
    # utils) and must have at least as many entries as there are datasets.
    fig.add_trace(
        go.Scatter(
            x=x,
            y=y,
            mode="lines",
            name="Regression Line",
            line=dict(color=COLOR_LIST[i]),
        ),
        row=row,
        col=col,
    )

    # Observed points, in the same color as their regression line.
    fig.add_trace(
        go.Scatter(
            x=dataset.x,
            y=dataset.y,
            mode="markers",
            name=dataset_names[i],
            marker=dict(color=COLOR_LIST[i]),
        ),
        row=row,
        col=col,
    )

# Use update_xaxes/update_yaxes so titles and ranges reach EVERY subplot.
# Layout keys `xaxis`/`yaxis` address only the first panel, which would leave
# the other three unlabeled and autoranged, so not on a shared scale.
fig.update_xaxes(title_text="x", range=[0, 20])
fig.update_yaxes(title_text="y", range=[0, 14])

# Kept as the final expression so the notebook renders the returned figure.
fig.update_layout(
    title="Anscombe's Quartet",
    showlegend=False,
)