# Extending BaseModel: LogisticRegressionModel Example

This notebook demonstrates how to extend the `BaseModel` abstract base class and implement its methods, using a custom model that wraps scikit-learn's `LogisticRegression` (different from the basic one provided in the library). Wrapping your model this way allows it to be used with the recourse generation methods.

In [1]:
# Import BaseModel
from robustx.lib.models.BaseModel import BaseModel
# Import modules for your model
import pandas as pd
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris

In [2]:
# Extend BaseModel and implement all its methods
class LogisticRegressionModel(BaseModel):
    """Adapter exposing scikit-learn's ``LogisticRegression`` through ``BaseModel``.

    The estimator is handed to ``BaseModel.__init__`` and is accessed afterwards
    as ``self._model`` (presumably stored there by ``BaseModel`` -- confirm
    against its definition).
    """

    def __init__(self):
        """Create the wrapper around a default-configured ``LogisticRegression``."""
        super().__init__(LogisticRegression())

    def train(self, X: pd.DataFrame, y: pd.DataFrame) -> None:
        """Fit the wrapped estimator.

        :param X: feature matrix, one row per sample.
        :param y: target labels; flattened to a 1-D array before fitting.
        """
        self._model.fit(X, y.values.ravel())

    def predict(self, X: pd.DataFrame) -> pd.DataFrame:
        """Predict a class label for every row of ``X``.

        :param X: feature matrix.
        :return: DataFrame with a single ``prediction`` column, sharing the
            index of ``X``.
        """
        predictions = self._model.predict(X)
        return pd.DataFrame(predictions, columns=['prediction'], index=X.index)

    def predict_single(self, X: pd.DataFrame) -> int:
        """Predict the class label of one instance.

        :param X: DataFrame whose first row is the instance to classify.
        :return: predicted label of the first row, converted to ``int``.
        """
        return int(self._model.predict(X)[0])

    def predict_proba(self, X: pd.DataFrame) -> pd.DataFrame:
        """Predict class probabilities for every row of ``X``.

        :param X: feature matrix.
        :return: DataFrame with one ``class_<i>`` column per probability column
            returned by the estimator, sharing the index of ``X``.
        """
        probabilities = self._model.predict_proba(X)
        return pd.DataFrame(probabilities, columns=[f'class_{i}' for i in range(probabilities.shape[1])], index=X.index)

    def predict_proba_tensor(self, X: torch.Tensor) -> torch.Tensor:
        """Predict class probabilities for tensor input.

        :param X: tensor of feature rows. It may require grad or live on a
            non-CPU device; it is detached and copied to host memory first.
        :return: tensor built from the estimator's probability array. It is not
            connected to any autograd graph of ``X``.
        """
        # Tensor.numpy() raises for tensors that require grad or are not on the
        # CPU, so strip autograd tracking and move to host memory beforehand.
        # Both calls are no-ops for a plain CPU tensor.
        X_numpy = X.detach().cpu().numpy()
        probabilities = self._model.predict_proba(X_numpy)
        return torch.tensor(probabilities)

    def evaluate(self, X: pd.DataFrame, y: pd.DataFrame) -> dict:
        """Score the model on labelled data.

        :param X: feature matrix.
        :param y: true labels for the rows of ``X``.
        :return: dict with keys ``accuracy`` (result of ``accuracy_score``) and
            ``classification_report`` (result of ``classification_report``).
        """
        y_pred = self.predict(X)
        accuracy = accuracy_score(y, y_pred)
        report = classification_report(y, y_pred)
        return {
            'accuracy': accuracy,
            'classification_report': report
        }

## Example Usage

In [3]:
# Fetch the Iris data straight from scikit-learn. No DatasetLoader is used here;
# combining a DatasetLoader with a BaseModel is covered in the Tasks notebook.
iris = load_iris()
X = pd.DataFrame(data=iris.data, columns=iris.feature_names)
y = pd.DataFrame(data=iris.target, columns=['target'])

# Hold out 20% of the rows for testing; the fixed seed keeps the split repeatable
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Instantiate the wrapper and fit it on the training portion
model = LogisticRegressionModel()
model.train(X_train, y_train)

# Class labels for the test set (only the first rows are shown)
predictions = model.predict(X_test)
print("Predictions:", predictions.head(), sep="\n")

# Label for one instance, passed as a one-row DataFrame
single_prediction = model.predict_single(X_test.head(1))
print(f"\nSingle prediction: {single_prediction}")

# Per-class probabilities from DataFrame input
probabilities = model.predict_proba(X_test)
print("\nProbabilities:", probabilities.head(), sep="\n")

# Per-class probabilities from tensor input
X_tensor = torch.tensor(X_test.to_numpy(), dtype=torch.float32)
proba_tensor = model.predict_proba_tensor(X_tensor)
print("\nProbabilities (tensor):", proba_tensor[:5], sep="\n")

# Accuracy and per-class report on the held-out data
evaluation = model.evaluate(X_test, y_test)
print(f"\nAccuracy: {evaluation['accuracy']}")
print("\nClassification Report:", evaluation['classification_report'], sep="\n")

Predictions:
     prediction
73            1
18            0
118           2
78            1
76            1

Single prediction: 1

Probabilities:
          class_0   class_1       class_2
73   3.804576e-03  0.827741  1.684546e-01
18   9.469821e-01  0.053018  1.987357e-07
118  8.864920e-09  0.001549  9.984514e-01
78   6.487006e-03  0.792228  2.012850e-01
76   1.458169e-03  0.774129  2.244126e-01

Probabilities (tensor):
tensor([[3.8046e-03, 8.2774e-01, 1.6845e-01],
        [9.4698e-01, 5.3018e-02, 1.9874e-07],
        [8.8649e-09, 1.5486e-03, 9.9845e-01],
        [6.4870e-03, 7.9223e-01, 2.0128e-01],
        [1.4582e-03, 7.7413e-01, 2.2441e-01]], dtype=torch.float64)

Accuracy: 1.0

Classification Report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        10
           1       1.00      1.00      1.00         9
           2       1.00      1.00      1.00        11

    accuracy                           1.00        30
   macro avg 

