In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch

import demo.torch.models
from demo.torch.utils import get_optimizer, get_scheduler, train_regression_model


# Single config knob for reproducibility: seeding here makes Restart & Run All
# regenerate the same random data and DataLoader shuffling order.
# NumPy's legacy global RNG is seeded too, presumably for helpers that rely on it.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
def visualize(x, y, y_pred=None, title="Simple linear regression dataset"):
    """Scatter-plot observations, the identity ground-truth line, and optional predictions.

    Parameters
    ----------
    x, y : tensor or ndarray
        Inputs and observed targets. Anything with a ``.flatten()`` method works;
        both are flattened to 1-D before plotting.
    y_pred : tensor or ndarray, optional
        Model predictions for ``x``, drawn in green when given. It should live on
        the CPU (seaborn converts its inputs to NumPy).
    title : str, optional
        Axes title; defaults to the original hard-coded title.

    Returns
    -------
    None
        The figure is rendered via ``plt.show()``.
    """
    x_flat = x.flatten()

    _, ax = plt.subplots(figsize=(6, 6))
    sns.scatterplot(x=x_flat, y=y.flatten(), label="Observations", ax=ax)
    # This notebook's data is generated as y = x + noise, so the noise-free
    # relation is the identity line y = x.
    sns.lineplot(
        x=x_flat, y=x_flat, linestyle="--", color="red", label="Ground truth", ax=ax
    )
    if y_pred is not None:
        sns.scatterplot(
            x=x_flat, y=y_pred.flatten(), color="green", label="Predictions", ax=ax
        )
    ax.set(xlabel="X", ylabel="Y", title=title)
    plt.show()

In [None]:
# Toy linear-regression data: 500 evenly spaced inputs on [0, 1] as a (500, 1)
# column, with targets equal to the inputs plus Gaussian noise of std 0.2.
x = torch.linspace(0, 1, 500).unsqueeze(-1)
y = x + torch.randn_like(x).mul(0.2)

visualize(x, y)

In [None]:
# Fall back to the CPU so the notebook also runs on machines without a GPU
# (a hard-coded "cuda" device fails there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = demo.torch.models.LinearRegression(in_features=1, out_features=1)

optimizer = get_optimizer(model)
scheduler = get_scheduler(optimizer)

train_dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x, y), batch_size=16, shuffle=True
)
# NOTE(review): the validation loader wraps the very same (x, y) tensors as the
# training loader, so the reported validation loss measures fit on seen data,
# not generalization. Consider a held-out split if that matters for the demo.
val_dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x, y), batch_size=16, shuffle=False
)

train_regression_model(
    model,
    train_dataloader,
    val_dataloader,
    optimizer,
    scheduler,
    n_epochs=3,
    device=device,
)

In [None]:
# Run inference on whatever device the model's parameters live on instead of
# assuming CUDA; this assumes the model has at least one parameter.
model.eval()
model_device = next(model.parameters()).device
with torch.no_grad():
    # Outputs computed under no_grad carry no autograd graph, so no .detach() is needed.
    y_pred = model(x.to(model_device)).cpu()

visualize(x, y, y_pred)