# Train a domain classifier on the [Semantic Scholar dataset](https://api.semanticscholar.org/corpus)

> Part 2: train a model

![position of this step in the lifecycle](diagrams/scope-train.svg)
> The blue boxes show the steps implemented in this notebook.

In [Part 1](data.ipynb), we cleaned and transformed our training data. We can now access this data using `great_ai.LargeFile`: locally, it gives us the cached version; otherwise, the latest version is downloaded from S3.
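The "cached locally, otherwise downloaded" behaviour described above can be pictured as a generic cache-or-download lookup. This is a minimal sketch under stated assumptions, not GreatAI's implementation: `open_cached`, `fake_fetch`, the cache file naming and the explicit `version` argument are all hypothetical, and the remote download is simulated by a plain callable.

```python
import tempfile
from pathlib import Path
from typing import Callable, List, Tuple


def open_cached(
    name: str,
    version: int,
    cache_dir: Path,
    fetch_remote: Callable[[str, int], bytes],  # hypothetical stand-in for an S3 download
) -> bytes:
    """Return the bytes of `name` at `version`, downloading only on a cache miss."""
    cached = cache_dir / f"{name}-{version}"
    if cached.exists():
        return cached.read_bytes()  # cache hit: no network access
    data = fetch_remote(name, version)  # cache miss: download once
    cache_dir.mkdir(parents=True, exist_ok=True)
    cached.write_bytes(data)  # store it for the next call
    return data


downloads: List[Tuple[str, int]] = []


def fake_fetch(name: str, version: int) -> bytes:
    # Hypothetical remote: records each "download" so we can count them
    downloads.append((name, version))
    return b"training data"


with tempfile.TemporaryDirectory() as tmp:
    first = open_cached("train", 0, Path(tmp), fake_fetch)
    second = open_cached("train", 0, Path(tmp), fake_fetch)  # served from the cache

print(first == second, len(downloads))  # True 1
```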

In this part, we optimise the hyperparameters of a simple Naive Bayes classifier, train it, and then export it for deployment using `great_ai.save_model`.

## Load the data extracted in [Part 1](data.ipynb)

In [1]:
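from great_ai import query_ground_truth

# Retrieve the ground-truth records tagged "train"
data = query_ground_truth("train")
# Flatten the multi-label data: each text is repeated once for each of its domains
X = [d.input for d in data for domain in d.feedback]
y = [domain for d in data for domain in d.feedback]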

2022-06-25 11:25:08,431 |     INFO | Found credentials file (/data/projects/great-ai/examples/simple/mongo.ini), initialising MongodbDriver
2022-06-25 11:25:08,432 |     INFO | Found credentials file (/data/projects/great-ai/examples/simple/mongo.ini), initialising LargeFileMongo
2022-06-25 11:25:08,432 |     INFO | Settings: configured ✅
2022-06-25 11:25:08,433 |     INFO | 🔩 tracing_database: MongodbDriver
2022-06-25 11:25:08,433 |     INFO | 🔩 large_file_implementation: LargeFileMongo
2022-06-25 11:25:08,434 |     INFO | 🔩 is_production: False
2022-06-25 11:25:08,434 |     INFO | 🔩 should_log_exception_stack: True
2022-06-25 11:25:08,434 |     INFO | 🔩 prediction_cache_size: 512


{'filter': {'$and': [{'tags': 'train'}, {'feedback': {'$ne': None}}]}, 'sort': []}
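The two comprehensions flatten multi-label ground truth into single-label training pairs: a text annotated with k domains appears k times in `X`, once next to each of its domains in `y`. The sketch below reproduces this with a hypothetical `Record` dataclass that models only the `input` and `feedback` attributes the notebook uses; the example texts and domains are made up.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Record:
    # Hypothetical stand-in for a ground-truth record: only .input and .feedback are modelled
    input: str
    feedback: List[str] = field(default_factory=list)


data = [
    Record("Protein folding with deep networks", ["Biology", "Computer Science"]),
    Record("Bond pricing under stochastic rates", ["Economics"]),
]

# One (text, domain) pair per domain: the first text is duplicated
X = [d.input for d in data for domain in d.feedback]
y = [domain for d in data for domain in d.feedback]

for text, label in zip(X, y):
    print(f"{label:<17} <- {text}")
```

A consequence of this choice is that the classifier learns to output a single domain per text, even though some training texts carry several.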


In [2]:
import pandas as pd
from collections import Counter
import plotly.express as px

df = pd.DataFrame(Counter(y).most_common(), columns=["domain", "count"])
px.bar(df, "domain", "count", width=1200, height=400).show()
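The bar chart shows how many training samples each domain has. When these counts are uneven, plain accuracy can look good while rare domains are ignored, whereas macro-averaged F1 (the `scoring="f1_macro"` used in the grid search) gives every domain's F1 score equal weight. The stdlib sketch below illustrates this on made-up labels; `per_class_f1` and `macro_f1` are hypothetical helpers following the usual definition F1 = 2TP / (2TP + FP + FN), averaged over the union of true and predicted labels.

```python
from collections import Counter
from typing import Dict, List


def per_class_f1(y_true: List[str], y_pred: List[str]) -> Dict[str, float]:
    """F1 score of every label that occurs in y_true or y_pred."""
    scores = {}
    for label in sorted(set(y_true) | set(y_pred)):
        pairs = list(zip(y_true, y_pred))
        tp = sum(t == label and p == label for t, p in pairs)
        fp = sum(t != label and p == label for t, p in pairs)
        fn = sum(t == label and p != label for t, p in pairs)
        denom = 2 * tp + fp + fn
        scores[label] = 2 * tp / denom if denom else 0.0
    return scores


def macro_f1(y_true: List[str], y_pred: List[str]) -> float:
    """Unweighted mean of the per-class F1 scores."""
    scores = per_class_f1(y_true, y_pred)
    return sum(scores.values()) / len(scores)


# Made-up, imbalanced labels and a classifier that always predicts the majority domain
y_true = ["cs"] * 8 + ["bio"] * 2
y_pred = ["cs"] * 10

print(Counter(y_true).most_common())  # [('cs', 8), ('bio', 2)]
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(accuracy, round(macro_f1(y_true, y_pred), 3))  # 0.8 0.444
```

Always predicting the majority domain reaches 80% accuracy here but only about 0.44 macro-F1, because the minority domain contributes an F1 of 0.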

## Optimise and train a Multinomial Naive Bayes classifier

In [3]:
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer


def create_pipeline() -> Pipeline:
    return Pipeline(
        steps=[
            ("vectorizer", TfidfVectorizer(sublinear_tf=True)),
            ("classifier", MultinomialNB()),
        ]
    )
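To make the two tuned classifier hyperparameters concrete, here is a small, didactic multinomial Naive Bayes written with the standard library only. It is a sketch, not scikit-learn's implementation: it uses whitespace-tokenised raw word counts instead of the TF-IDF weights produced by `TfidfVectorizer`, and `TinyMultinomialNB` and the toy documents are hypothetical. `alpha` adds pseudo-counts to every vocabulary word (additive smoothing), and `fit_prior=False` replaces the class frequencies with a uniform prior, mirroring the meaning of the same-named `MultinomialNB` parameters.

```python
import math
from collections import Counter, defaultdict
from typing import Dict, List


class TinyMultinomialNB:
    """Didactic multinomial Naive Bayes over raw word counts (not scikit-learn's code)."""

    def __init__(self, alpha: float = 1.0, fit_prior: bool = True) -> None:
        self.alpha = alpha
        self.fit_prior = fit_prior

    def fit(self, docs: List[str], labels: List[str]) -> "TinyMultinomialNB":
        self.vocab = {word for doc in docs for word in doc.split()}
        self.classes = sorted(set(labels))
        class_sizes = Counter(labels)
        word_counts: Dict[str, Counter] = defaultdict(Counter)
        for doc, label in zip(docs, labels):
            word_counts[label].update(doc.split())

        self.log_prior: Dict[str, float] = {}
        self.log_likelihood: Dict[str, Dict[str, float]] = {}
        for c in self.classes:
            # fit_prior=True: P(c) from class frequencies; False: uniform prior
            if self.fit_prior:
                self.log_prior[c] = math.log(class_sizes[c] / len(labels))
            else:
                self.log_prior[c] = -math.log(len(self.classes))
            # Additive smoothing: each vocabulary word receives `alpha` pseudo-counts
            denom = sum(word_counts[c].values()) + self.alpha * len(self.vocab)
            self.log_likelihood[c] = {
                word: math.log((word_counts[c][word] + self.alpha) / denom)
                for word in self.vocab
            }
        return self

    def predict(self, doc: str) -> str:
        words = [word for word in doc.split() if word in self.vocab]  # skip unseen words

        def log_posterior(c: str) -> float:
            # Unnormalised log P(c | doc) = log P(c) + sum of log P(word | c)
            return self.log_prior[c] + sum(self.log_likelihood[c][w] for w in words)

        return max(self.classes, key=log_posterior)


docs = ["gene protein cell", "protein cell membrane", "market price inflation", "price bond market"]
labels = ["bio", "bio", "econ", "econ"]
nb = TinyMultinomialNB(alpha=0.5, fit_prior=False).fit(docs, labels)
print(nb.predict("protein gene"), nb.predict("bond price"))  # bio econ
```

Working in log space avoids numerical underflow when many small word probabilities are multiplied together.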

In [4]:
from sklearn.model_selection import GridSearchCV

optimisation_pipeline = GridSearchCV(
    create_pipeline(),
    {
        "vectorizer__min_df": [5, 20, 100],
        "vectorizer__max_df": [0.05, 0.1],
        "classifier__alpha": [0.5, 1],
        "classifier__fit_prior": [True, False],
    },
    scoring="f1_macro",
    cv=3,
    n_jobs=-1,
    verbose=1,
)
optimisation_pipeline.fit(X, y)

results = pd.DataFrame(optimisation_pipeline.cv_results_)
results.sort_values("rank_test_score")

Fitting 3 folds for each of 24 candidates, totalling 72 fits


| index | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_classifier__alpha | param_classifier__fit_prior | param_vectorizer__max_df | param_vectorizer__min_df | params | split0_test_score | split1_test_score | split2_test_score | mean_test_score | std_test_score | rank_test_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 7 | 1.986588 | 0.050896 | 1.090251 | 0.135508 | 0.5 | False | 0.05 | 20 | `{'classifier__alpha': 0.5, 'classifier__fit_pr...` | 0.459165 | 0.473024 | 0.475462 | 0.469217 | 0.007177 | 1 |
| 10 | 2.070333 | 0.038396 | 0.976315 | 0.033742 | 0.5 | False | 0.1 | 20 | `{'classifier__alpha': 0.5, 'classifier__fit_pr...` | 0.457524 | 0.463575 | 0.458007 | 0.459702 | 0.002745 | 2 |
| 19 | 2.049166 | 0.193113 | 1.064576 | 0.087034 | 1.0 | False | 0.05 | 20 | `{'classifier__alpha': 1, 'classifier__fit_prio...` | 0.441657 | 0.452933 | 0.451286 | 0.448625 | 0.004973 | 3 |
| 6 | 2.288145 | 0.131338 | 1.133445 | 0.099583 | 0.5 | False | 0.05 | 5 | `{'classifier__alpha': 0.5, 'classifier__fit_pr...` | 0.436976 | 0.449693 | 0.437911 | 0.441527 | 0.005787 | 4 |
| 22 | 1.872234 | 0.072997 | 0.74807 | 0.085882 | 1.0 | False | 0.1 | 20 | `{'classifier__alpha': 1, 'classifier__fit_prio...` | 0.432322 | 0.438805 | 0.428209 | 0.433112 | 0.004362 | 5 |
| 11 | 2.067691 | 0.126923 | 0.910947 | 0.078748 | 0.5 | False | 0.1 | 100 | `{'classifier__alpha': 0.5, 'classifier__fit_pr...` | 0.426436 | 0.429182 | 0.43741 | 0.431009 | 0.004662 | 6 |
| 23 | 1.84733 | 0.147504 | 0.495354 | 0.018589 | 1.0 | False | 0.1 | 100 | `{'classifier__alpha': 1, 'classifier__fit_prio...` | 0.42213 | 0.430875 | 0.430829 | 0.427945 | 0.004112 | 7 |
| 9 | 2.071489 | 0.256086 | 1.055936 | 0.037198 | 0.5 | False | 0.1 | 5 | `{'classifier__alpha': 0.5, 'classifier__fit_pr...` | 0.416746 | 0.425938 | 0.417381 | 0.420022 | 0.004192 | 8 |
| 20 | 1.776546 | 0.064677 | 0.888485 | 0.093302 | 1.0 | False | 0.05 | 100 | `{'classifier__alpha': 1, 'classifier__fit_prio...` | 0.413441 | 0.417122 | 0.427196 | 0.419253 | 0.005814 | 9 |
| 8 | 2.015992 | 0.062583 | 0.974434 | 0.082582 | 0.5 | False | 0.05 | 100 | `{'classifier__alpha': 0.5, 'classifier__fit_pr...` | 0.412522 | 0.410409 | 0.425047 | 0.415993 | 0.00646 | 10 |
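The fit count in the log follows from the size of the grid: 3 `min_df` values × 2 `max_df` values × 2 `alpha` values × 2 `fit_prior` values give 24 candidates, each fitted on 3 folds, hence 72 fits. The sketch below enumerates the same grid with `itertools.product`. It sorts the parameter names first, which is, to our understanding, also how scikit-learn's `ParameterGrid` orders candidates; under that assumption, candidate 7 matches the rank-1 row of the results table. By default (`refit=True`), `GridSearchCV` additionally refits the best candidate once on the whole training set.

```python
from itertools import product
from math import prod

# Same grid as in the GridSearchCV cell
param_grid = {
    "vectorizer__min_df": [5, 20, 100],
    "vectorizer__max_df": [0.05, 0.1],
    "classifier__alpha": [0.5, 1],
    "classifier__fit_prior": [True, False],
}
cv = 3

# Sort the keys so the candidate order is deterministic (assumed to match ParameterGrid)
names = sorted(param_grid)
candidates = [dict(zip(names, values)) for values in product(*(param_grid[n] for n in names))]

n_candidates = len(candidates)
assert n_candidates == prod(len(values) for values in param_grid.values())
print(f"Fitting {cv} folds for each of {n_candidates} candidates, totalling {cv * n_candidates} fits")
print(candidates[7])
```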


In [5]:
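from sklearn import set_config

# Render estimators as diagrams in the notebook output
set_config(display="diagram")

# Train a fresh pipeline on all training data using the best hyperparameters
classifier = create_pipeline()
classifier.set_params(**optimisation_pipeline.best_params_)
classifier.fit(X, y)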
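`set_params` accepts the `best_params_` keys directly because a `Pipeline` routes each `<step>__<parameter>` name to the step with that name, so `vectorizer__min_df` becomes the `min_df` argument of the `TfidfVectorizer`. The helper below is a hypothetical, stdlib-only illustration of that naming convention (not scikit-learn's code), applied to the parameters of the rank-1 row in the results table. Since `GridSearchCV` refits by default, `optimisation_pipeline.best_estimator_` should already hold an equivalent fitted pipeline; the explicit refit in this cell makes the final training step visible.

```python
from collections import defaultdict
from typing import Any, Dict

# Parameters of the rank-1 candidate in the grid-search results
best_params = {
    "classifier__alpha": 0.5,
    "classifier__fit_prior": False,
    "vectorizer__max_df": 0.05,
    "vectorizer__min_df": 20,
}


def route_params(params: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Group '<step>__<parameter>' keys into keyword arguments per pipeline step."""
    routed: Dict[str, Dict[str, Any]] = defaultdict(dict)
    for key, value in params.items():
        step, _, parameter = key.partition("__")  # split on the first double underscore
        routed[step][parameter] = value
    return dict(routed)


print(route_params(best_params))
```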

## Export the model using GreatAI

In [6]:
from great_ai import save_model


save_model(classifier, key="small-domain-prediction", keep_last_n=5)

2022-06-25 11:25:32,714 |     INFO | Copying file for small-domain-prediction-0
2022-06-25 11:25:32,737 |     INFO | Compressing small-domain-prediction-0
2022-06-25 11:25:33,050 |     INFO | Uploading /tmp/tmpgerx8x95/small-domain-prediction-0.tar.gz to Mongo (GridFS)
2022-06-25 11:25:33,107 |     INFO | Uploading small-domain-prediction-0.tar.gz 0.26/1.85 MB (14.2%)
2022-06-25 11:25:33,109 |     INFO | Uploading small-domain-prediction-0.tar.gz 0.52/1.85 MB (28.3%)
2022-06-25 11:25:33,112 |     INFO | Uploading small-domain-prediction-0.tar.gz 0.78/1.85 MB (42.5%)
2022-06-25 11:25:33,114 |     INFO | Uploading small-domain-prediction-0.tar.gz 1.04/1.85 MB (56.6%)
2022-06-25 11:25:33,116 |     INFO | Uploading small-domain-prediction-0.tar.gz 1.31/1.85 MB (70.8%)
2022-06-25 11:25:33,117 |     INFO | Uploading small-domain-prediction-0.tar.gz 1.57/1.85 MB (84.9%)

'small-domain-prediction:0'
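The returned reference `'small-domain-prediction:0'` has the form `key:version`, and `keep_last_n=5` presumably limits how many versions of the model are retained. The in-memory sketch below illustrates that versioning-with-retention idea. `VersionedStore` is hypothetical and does not reproduce GreatAI's storage, which, according to the logs above, compresses the model and uploads it to MongoDB GridFS.

```python
from typing import Dict, List, Optional


class VersionedStore:
    """Hypothetical in-memory model store: every save gets the next integer version,
    and only the newest `keep_last_n` versions are retained."""

    def __init__(self) -> None:
        self._versions: Dict[str, Dict[int, bytes]] = {}

    def save(self, blob: bytes, key: str, keep_last_n: Optional[int] = None) -> str:
        versions = self._versions.setdefault(key, {})
        version = max(versions, default=-1) + 1  # first save of a key gets version 0
        versions[version] = blob
        if keep_last_n is not None:
            # Drop the oldest versions beyond the retention limit
            excess = len(versions) - keep_last_n
            for old in sorted(versions)[: max(excess, 0)]:
                del versions[old]
        return f"{key}:{version}"

    def available(self, key: str) -> List[int]:
        return sorted(self._versions.get(key, {}))


store = VersionedStore()
refs = [store.save(b"model", key="small-domain-prediction", keep_last_n=5) for _ in range(7)]
print(refs[0], refs[-1], store.available("small-domain-prediction"))
# small-domain-prediction:0 small-domain-prediction:6 [2, 3, 4, 5, 6]
```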