# Extending TrainableModel: TrainableLogisticRegressionModel Example

This notebook demonstrates how to extend the `TrainableModel` abstract base class and implement its methods, using a custom logistic regression model built on scikit-learn (different from the basic one provided in the library). Implementing this interface allows your model to be used with the recourse generation methods.

In [1]:
# Standard library
from types import SimpleNamespace

# Import modules for your model
import pandas as pd
import torch
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split

# Import TrainableModel and the DatasetLoader type used in its signatures
from rocelib.datasets.DatasetLoader import DatasetLoader
from rocelib.models.TrainableModel import TrainableModel

In [2]:
# Extend TrainableModel and implement all its methods
class TrainableLogisticRegressionModel(TrainableModel):
    """Logistic-regression classifier exposed through the ``TrainableModel`` interface.

    Wraps a scikit-learn ``LogisticRegression`` created with default
    hyper-parameters. Every method reads the estimator from ``self._model``,
    which the base-class constructor is expected to set from the argument
    passed to it.
    """

    def __init__(self):
        """Create the wrapper around a fresh, unfitted ``LogisticRegression``."""
        super().__init__(LogisticRegression())

    def train(self, dataset_loader: DatasetLoader) -> None:
        """Fit the estimator on the data held by ``dataset_loader``.

        Args:
            dataset_loader: Object exposing the features as ``X`` and the
                labels as ``y`` (``y`` must provide ``.values``).
        """
        # Flatten the labels to the 1-D vector scikit-learn expects.
        self._model.fit(dataset_loader.X, dataset_loader.y.values.ravel())

    def predict(self, X: pd.DataFrame) -> pd.DataFrame:
        """Predict a class label for every row of ``X``.

        Returns:
            A single-column DataFrame named ``prediction`` that keeps the
            index of ``X``.
        """
        predictions = self._model.predict(X)
        return pd.DataFrame(predictions, columns=['prediction'], index=X.index)

    def predict_single(self, X: pd.DataFrame) -> int:
        """Predict the class label of the first row of ``X`` as a plain ``int``."""
        return int(self._model.predict(X)[0])

    def predict_proba(self, X: pd.DataFrame) -> pd.DataFrame:
        """Predict class probabilities for every row of ``X``.

        Returns:
            A DataFrame with one ``class_<i>`` column per probability column
            returned by the estimator, keeping the index of ``X``.
        """
        probabilities = self._model.predict_proba(X)
        column_names = [f'class_{i}' for i in range(probabilities.shape[1])]
        return pd.DataFrame(probabilities, columns=column_names, index=X.index)

    def predict_proba_tensor(self, X: torch.Tensor) -> torch.Tensor:
        """Predict class probabilities for a tensor of feature rows.

        Args:
            X: Feature tensor; it may require gradients or live on a non-CPU
                device.

        Returns:
            A new tensor built from the estimator's NumPy output. It is not
            connected to the autograd graph of ``X``.
        """
        # Tensor.numpy() raises for tensors that require grad or are stored on
        # a non-CPU device, so detach and move to host memory first. For a
        # plain CPU tensor both calls are no-ops.
        X_numpy = X.detach().cpu().numpy()
        probabilities = self._model.predict_proba(X_numpy)
        return torch.tensor(probabilities)

    def evaluate(self, X: pd.DataFrame, y: pd.DataFrame):
        """Score the model's predictions for ``X`` against the true labels ``y``.

        Returns:
            A dict with the keys ``accuracy`` (result of ``accuracy_score``)
            and ``classification_report`` (result of ``classification_report``).
        """
        y_pred = self.predict(X)
        accuracy = accuracy_score(y, y_pred)
        report = classification_report(y, y_pred)
        return {
            'accuracy': accuracy,
            'classification_report': report
        }

## Example Usage

In [3]:
# Load the Iris dataset. A full DatasetLoader is not used here (using the DatasetLoader
# with a TrainableModel is covered in the Tasks notebook); train() only reads the `X`
# and `y` attributes of its argument, so a lightweight stand-in object is enough.
iris = load_iris()
X = pd.DataFrame(iris.data, columns=iris.feature_names)
y = pd.DataFrame(iris.target, columns=['target'])

# Split the data (fixed random_state so the split is reproducible)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Create and train the model.
# train() accepts a single data-holder argument, so bundle the training split into
# an object exposing `X` and `y` instead of passing the two frames separately.
model = TrainableLogisticRegressionModel()
train_data = SimpleNamespace(X=X_train, y=y_train)
model.train(train_data)

# Make predictions
predictions = model.predict(X_test)
print("Predictions:")
print(predictions.head())

# Make a single prediction (double brackets keep the row as a one-row DataFrame)
single_prediction = model.predict_single(X_test.iloc[[0]])
print(f"\nSingle prediction: {single_prediction}")

# Predict probabilities
probabilities = model.predict_proba(X_test)
print("\nProbabilities:")
print(probabilities.head())

# Predict probabilities using tensor input
X_tensor = torch.tensor(X_test.values, dtype=torch.float32)
proba_tensor = model.predict_proba_tensor(X_tensor)
print("\nProbabilities (tensor):")
print(proba_tensor[:5])

# Evaluate the model
evaluation = model.evaluate(X_test, y_test)
print(f"\nAccuracy: {evaluation['accuracy']}")
print("\nClassification Report:")
print(evaluation['classification_report'])

Predictions:
     prediction
73            1
18            0
118           2
78            1
76            1

Single prediction: 1

Probabilities:
          class_0   class_1       class_2
73   3.799987e-03  0.827708  1.684918e-01
18   9.469409e-01  0.053059  1.990661e-07
118  8.843105e-09  0.001549  9.984513e-01
78   6.479330e-03  0.792189  2.013315e-01
76   1.455772e-03  0.774086  2.244581e-01

Probabilities (tensor):
tensor([[3.8000e-03, 8.2771e-01, 1.6849e-01],
        [9.4694e-01, 5.3059e-02, 1.9907e-07],
        [8.8431e-09, 1.5487e-03, 9.9845e-01],
        [6.4793e-03, 7.9219e-01, 2.0133e-01],
        [1.4558e-03, 7.7409e-01, 2.2446e-01]], dtype=torch.float64)

Accuracy: 1.0

Classification Report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        10
           1       1.00      1.00      1.00         9
           2       1.00      1.00      1.00        11

    accuracy                           1.00        30
   macro avg 

