# Using generatedata with scikit-learn

This notebook demonstrates how to use the output of `load_data_as_xy` and `load_data_as_xy_onehot` with a scikit-learn RandomForest model.

In [1]:
# Import required libraries
from generatedata import load_data
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.metrics import accuracy_score, mean_squared_error
import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

## Example 1: Regression with a synthetic dataset
We'll use `load_data_as_xy` to get features and targets for a regression problem.

In [2]:
# Load a regression dataset
X, Y = load_data.load_data_as_xy('regression_line', local=True)
# Use only the first output column if Y is multidimensional
if isinstance(Y, pd.DataFrame):
    y = Y.iloc[:, 0]
else:
    y = Y
# Train/test split
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Fit RandomForestRegressor
reg = RandomForestRegressor(n_estimators=100, random_state=42)
reg.fit(X_train, y_train)
y_pred = reg.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
print(f'Regression MSE: {mse:.4f}')

Regression MSE: 0.0000


## Example 2: Classification with one-hot labels
We'll use `load_data_as_xy_onehot` to get features and one-hot encoded targets for a classification problem.

In [3]:
# Load a classification dataset with one-hot labels
X, Y = load_data.load_data_as_xy_onehot('MNIST', local=True)
# Convert one-hot labels to class indices
if isinstance(Y, pd.DataFrame):
    y = Y.values.argmax(axis=1)
else:
    y = np.argmax(Y, axis=1)
# Train/test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Fit RandomForestClassifier
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print(f'Classification accuracy: {acc:.4f}')

Classification accuracy: 0.8850
