In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import polars as pl
import torch
import matplotlib.pyplot as plt

In [None]:
from data import load_data, prepare_data
from config import CATEGORICAL_COLUMNS

# Which species to model, and the column whose values define the groups used
# later for the group-wise train/test split
species = "oak"
group_by_col = "plot_id"

# Load the table for that species, then derive the feature table and target.
# NOTE(review): `plotting=True` presumably makes prepare_data draw diagnostic
# plots -- confirm against data.prepare_data.
df = load_data(species)
X, y = prepare_data(df, plotting=True)

In [None]:
# One-hot encode categorical columns
import torch.nn.functional as F


def normalize(
    data: pl.DataFrame, target: pl.DataFrame
) -> tuple[torch.Tensor, torch.Tensor, list[str]]:
    """Impute, one-hot encode and standardize the features and the target.

    Args:
        data: Feature table. Nulls are filled with the column mean. Columns
            named in ``CATEGORICAL_COLUMNS`` are cast to int64 class codes and
            one-hot encoded; every other column is kept as one numeric feature.
        target: Target values; converted to a ``(n_rows, 1)`` float32 column.

    Returns:
        ``(X, y, features)`` where ``X`` is the standardized float32 feature
        matrix, ``y`` the standardized ``(n_rows, 1)`` target, and ``features``
        holds one name per column of ``X``. The one-hot columns of a
        categorical column ``col`` are named ``"col_0"``, ``"col_1"``, ... in
        ascending order of the distinct codes present in the data.

    Note:
        Means and standard deviations are computed over every row passed in.
        Calling this on the full dataset before a train/test split therefore
        lets test rows influence the scaling applied to the training rows.
    """
    # Impute missing values with the column mean
    data = data.fill_null(strategy="mean")

    # Convert to tensors
    X = data.to_torch().to(torch.float32)
    y = target.to_torch().to(torch.float32).view(-1, 1)

    features: list[str] = []
    cols: list[torch.Tensor] = []

    for idx, col in enumerate(data.columns):
        if col in CATEGORICAL_COLUMNS:
            # Remap the raw codes onto 0..k-1 before encoding. F.one_hot
            # requires every value to be < num_classes, which raw codes only
            # satisfy when they already are exactly 0..k-1 (in which case this
            # remapping is the identity).
            uniques, codes = torch.unique(
                X[:, idx].to(torch.int64), return_inverse=True
            )
            num_classes = uniques.numel()

            features.extend(f"{col}_{k}" for k in range(num_classes))
            # Cast explicitly so torch.cat below sees a single dtype
            cols.append(F.one_hot(codes, num_classes=num_classes).to(torch.float32))
        else:
            features.append(col)
            cols.append(X[:, idx].unsqueeze(1))

    X = torch.cat(cols, dim=1)

    assert X.shape[1] == len(features)
    assert X.shape[0] == y.shape[0]

    # Standardize the data. A column without spread (zero std, or NaN std when
    # there is a single row) is divided by 1 so it becomes 0 rather than NaN.
    x_mean = X.mean(dim=0)
    x_std = X.std(dim=0)
    x_std = torch.where(x_std > 0, x_std, torch.ones_like(x_std))

    return (X - x_mean) / x_std, (y - y.mean()) / y.std(), features


# Encode and standardize, then replace any NaN left in the features with 0.
# NOTE: this rebinds `X` from the prepared table to a tensor, so the
# data-preparation cell has to be re-run before this one can be re-run.
X, y_true, features = normalize(X, y)
X = torch.nan_to_num(X, nan=0.0)

# Sanity checks: the model inputs and targets must be finite
assert not X.isnan().any(), "Input contains NaNs"
assert not X.isinf().any(), "Input contains Inf"
assert not y_true.isnan().any(), "Target contains NaNs"
assert not y_true.isinf().any(), "Target contains Inf"

X.shape, y_true.shape

In [None]:
# Train/test split
from sklearn.model_selection import GroupShuffleSplit

# Single group-aware shuffle split: rows sharing a value of df[group_by_col]
# always land on the same side, with 20% of the groups held out for testing.
# NOTE(review): assumes the rows of `df` are aligned one-to-one with the rows
# of `X` -- confirm that prepare_data does not drop or reorder rows.
gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, test_idx = next(gss.split(X, y_true, groups=df[group_by_col]))

X_train, y_train = X[train_idx], y_true[train_idx]
X_test, y_test = X[test_idx], y_true[test_idx]

# Shape invariants: matching row counts per split, same feature width,
# single-column targets
assert len(X_train) == len(y_train)
assert len(X_test) == len(y_test)
assert X_test.shape[1] == X_train.shape[1]
assert y_test.shape[1] == y_train.shape[1] == 1

In [None]:
# Create simple MLP model
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from sklearn.metrics import r2_score


class MLP(nn.Module):
    """Fully connected regressor with three ReLU hidden layers and one output.

    Dropout is applied after each hidden activation; the raw input is never
    dropped out.

    Args:
        input_dim: Number of input features.
        hidden_dims: Widths of the three hidden layers, in order.
        dropout: Dropout probability used after every hidden layer.
    """

    def __init__(
        self, input_dim: int, hidden_dims: tuple[int, int, int], dropout: float = 0.0
    ):
        super().__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dims[0])
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
        self.dropout2 = nn.Dropout(dropout)
        self.fc3 = nn.Linear(hidden_dims[1], hidden_dims[2])
        self.dropout3 = nn.Dropout(dropout)
        self.fc4 = nn.Linear(hidden_dims[2], 1)

        # Xavier-uniform weights; biases keep nn.Linear's default initialisation
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x):
        """Map a ``(batch, input_dim)`` tensor to ``(batch, 1)`` predictions."""
        x = self.dropout1(F.relu(self.fc1(x)))
        x = self.dropout2(F.relu(self.fc2(x)))
        x = self.dropout3(F.relu(self.fc3(x)))
        # Linear output layer: no activation for regression
        return self.fc4(x)


class NonLinear(nn.Module):
    """Small regressor: ``input_dim`` -> 8 hidden units (ReLU) -> 1 output."""

    def __init__(self, input_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(in_features=input_dim, out_features=8)
        self.fc2 = nn.Linear(in_features=8, out_features=1)

        # Xavier-uniform weights; biases keep nn.Linear's default initialisation
        for layer in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x):
        """Return one prediction per input row."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)


class Linear(nn.Module):
    """Plain linear regressor mapping ``input_dim`` features to one output."""

    def __init__(self, input_dim: int):
        super().__init__()
        layer = nn.Linear(in_features=input_dim, out_features=1)
        # Xavier-uniform weights; the bias keeps nn.Linear's default initialisation
        nn.init.xavier_uniform_(layer.weight)
        self.fc = layer

    def forward(self, x):
        """Return one prediction per input row."""
        prediction = self.fc(x)
        return prediction


# Seed the global RNG so that weight initialisation and the per-epoch
# shuffling of the training rows are repeatable across kernel restarts
torch.manual_seed(42)

# Model under training (swap in the commented line to train the MLP instead)
model = Linear(X.shape[1])
# model = MLP(X.shape[1], (32, 16, 8), dropout=0.2)
print(f"# parameters: {sum(p.numel() for p in model.parameters())}")
print(f"# samples: {len(df)}")
print(f"# features: {len(features)}")
print(f"# train samples: {len(train_idx)}")
print(f"# test samples: {len(test_idx)}")

# Mini-batch size and number of passes over the training split
batch_size = 256
num_epochs = 100

# Mean squared error loss, optimised with Adam
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Per-epoch history: mean/std of the mini-batch train losses, plus the loss
# and R2 score measured on the test split
train_loss_mean = []
train_loss_std = []
test_loss_mean = []
test_r2_scores = []


# Weights of the epoch with the best test R2 are checkpointed here.
# NOTE: the checkpoint is selected on the test split, so the test metrics of
# the restored model are optimistic; a separate validation split would be
# needed for an unbiased estimate.
checkpoint_path = "best_model.pth"
best_test_r2 = float("-inf")

for epoch in range(num_epochs):
    # Training mode: enables dropout for models that have it
    model.train()

    # Visit the training rows in a fresh random order each epoch
    indices = torch.randperm(X_train.shape[0])

    mini_batch_losses = []
    for i in range(0, X_train.shape[0], batch_size):
        batch_indices = indices[i : i + batch_size]
        X_batch = X_train[batch_indices]
        y_batch = y_train[batch_indices]

        # Forward pass
        y_pred_batch = model(X_batch)
        train_loss = criterion(y_pred_batch, y_batch)

        # Backward pass
        optimizer.zero_grad()
        train_loss.backward()

        # Update weights
        optimizer.step()

        mini_batch_losses.append(train_loss.item())

    # Mean and standard deviation of the train loss over this epoch's batches
    train_loss_mean.append(np.mean(mini_batch_losses))
    train_loss_std.append(np.std(mini_batch_losses))

    # Evaluation mode: disables dropout so the test metrics are deterministic
    model.eval()
    with torch.no_grad():
        y_pred_test = model(X_test)
        test_loss = criterion(y_pred_test, y_test)
        test_r2 = r2_score(y_test.numpy(), y_pred_test.numpy())

    test_loss_mean.append(test_loss.item())
    test_r2_scores.append(test_r2)

    # Checkpoint whenever this epoch matches or beats the best R2 so far
    if test_r2 >= best_test_r2:
        best_test_r2 = test_r2
        torch.save(model.state_dict(), checkpoint_path)

    if epoch % 10 == 0:
        print(
            f"Epoch {epoch}"
            f" - Loss: {train_loss_mean[-1]:.4f} +/- {train_loss_std[-1]:.4f}"
            f" - Test Loss: {test_loss_mean[-1]:.4f}"
            f" - Test R2: {test_r2_scores[-1]:.4f}"
        )

# Restore the best checkpoint and leave the model in evaluation mode
model.load_state_dict(torch.load(checkpoint_path))
model.eval()

In [None]:
# Score the current model on the train and test splits.
# Evaluation mode disables dropout, so the reported metrics are deterministic.
model.eval()
with torch.no_grad():
    y_pred_test = model(X_test)
    y_pred_train = model(X_train)

    # Final loss and R2 score per split
    test_loss = criterion(y_pred_test, y_test)
    train_loss = criterion(y_pred_train, y_train)

    test_r2 = r2_score(y_test.numpy(), y_pred_test.numpy())
    train_r2 = r2_score(y_train.numpy(), y_pred_train.numpy())

    print(f"Score for species {species}")
    print(f"Test loss: {test_loss.item():.4f}")
    print(f"Train loss: {train_loss.item():.4f}")
    print(f"R2 score (test): {test_r2:.4f}")
    print(f"R2 score (train): {train_r2:.4f}")

In [None]:
# Loss curves on the left axis, test R2 on a twin right axis
fig, ax1 = plt.subplots()

loss_mean = np.asarray(train_loss_mean)
loss_std = np.asarray(train_loss_std)

# Train loss: per-epoch mean with a +/- one standard deviation band
ax1.plot(loss_mean, label="Train Loss (mean)")
ax1.fill_between(
    range(len(loss_mean)),
    loss_mean - loss_std,
    loss_mean + loss_std,
    alpha=0.3,
    label="Train Loss (std)",
)

# Test loss
ax1.plot(test_loss_mean, label="Test Loss")

ax2 = ax1.twinx()
ax2.plot(test_r2_scores, color="green", label="Test R2")

# Label through ax1: twinx() hides the twin's x-axis, so an x label set on the
# twin (which is what plt.xlabel targets once the twin exists) is never drawn
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax2.set_ylabel("R2 score")

ax1.legend(loc="best")
ax2.legend(loc="lower right")

ax1.set_title("Training Loss")
fig.tight_layout()

In [None]:
# Predicted vs true values on the test split
fig, ax = plt.subplots()

y_true_np = y_test.numpy().ravel()
y_pred_np = y_pred_test.numpy().ravel()

# Labels are attached to the artists themselves; a positional label list is
# matched by artist order and would name the identity line "Predicted"
ax.scatter(y_true_np, y_pred_np, label="Predicted")

# Identity line: where a perfect prediction would fall
lo, hi = y_true_np.min(), y_true_np.max()
ax.plot([lo, hi], [lo, hi], color="red", label="True")

ax.set_xlabel("True")
ax.set_ylabel("Predicted")

test_r2 = r2_score(y_true_np, y_pred_np)
ax.set_title(f"Predicted vs True (R2 = {test_r2:.4f})")

ax.legend()

In [None]:
# Feature importance of the Linear model: absolute value of each input weight.
# NOTE(review): magnitudes are only comparable across features if the inputs
# share a common scale -- confirm against the preprocessing.
if isinstance(model, Linear):
    importances = model.fc.weight.detach().abs().squeeze().numpy()

    plt.figure()
    plt.barh(features, importances)

    plt.title("Feature importance")
    plt.xlabel("Feature importance")