# Local predictions with SQL

Before running this notebook, you should configure the environment variables in the file `.env.edit` and rename it into `.env`.

In [1]:
import os
from sqlalchemy import create_engine
from bornrule.sql import BornClassifierSQL
from dotenv import load_dotenv
load_dotenv(".env")

True

### Check environment variables to connect to PostgreSQL

In [2]:
credentials = ['DB_USER', 'DB_PASS', 'DB_NAME', 'DB_HOST']
db = [os.getenv(c) for c in credentials]
print(db)

['testuser', '123', 'testdb', 'localhost']


### Initialize the classifier with the pre-trained 'zoo' model on the PostgreSQL backend

In [3]:
engine = create_engine(f"postgresql+psycopg2://{db[0]}:{db[1]}@/{db[2]}?host={db[3]}")
classifier = BornClassifierSQL(id="zoo", engine=engine)

### Explain the model: which are the top 10 features more relevant for mammals?

In [4]:
weights = classifier.explain()
weights.sort_values(by="Mammal", ascending=False)[0:10]

Unnamed: 0,Amphibian,Bird,Bug,Fish,Invertebrate,Mammal,Reptile
milk=1,0.0,0.0,0.0,0.0,0.0,0.25,0.0
eggs=0,0.0,0.0,0.0,0.0,0.045971,0.203869,0.041118
hair=1,0.0,0.0,0.080421,0.0,0.048061,0.148134,0.0
legs=4,0.0981,0.0,0.0,0.0,0.022655,0.090621,0.075819
fins=1,0.0,0.0,0.0,0.206985,0.0,0.070699,0.0
legs=2,0.044886,0.179543,0.0,0.0,0.0,0.06556,0.0
toothed=1,0.056514,0.0,0.0,0.057844,0.014592,0.06416,0.059809
tail=1,0.02151,0.048418,0.0,0.049676,0.011108,0.046245,0.047647
backbone=1,0.043228,0.043228,0.0,0.043228,0.0,0.043228,0.043228
aquatic=0,0.0,0.041134,0.04599,0.0,0.034107,0.0424,0.036792


### Define some test instances

In [5]:
animals = [
    {
        "legs=4": 1,
        "hair=1": 1,
    },
    {
        "fins=1": 1,
        "legs=0": 1,
    },

]

### Predict the test instances

In [6]:
pred = classifier.predict(animals)
pred

['Mammal', 'Fish']

### Explain the predictions on the test instances

In [7]:
classifier.explain(animals[0:1]).sort_values(by=pred[0], ascending=False)

Unnamed: 0,Amphibian,Bug,Invertebrate,Mammal,Reptile
hair=1,0.0,0.056866,0.033984,0.104746,0.0
legs=4,0.069367,0.0,0.01602,0.064079,0.053612


In [8]:
classifier.explain(animals[1:2]).sort_values(by=pred[1], ascending=False)

Unnamed: 0,Fish,Invertebrate,Mammal,Reptile
fins=1,0.146361,0.0,0.049992,0.0
legs=0,0.073288,0.051823,0.018923,0.048614
