# Train a domain classifier on the [semantic scholar dataset](https://api.semanticscholar.org/corpus)

> Part 2: train a model

![position of this step in the lifecycle](../diagrams/scope-train.svg)
> The blue boxes show the steps implemented in this notebook.

In [Part 1](data.ipynb), we cleaned and transformed our training data. We can now access this data using `great_ai.LargeFile`: locally, it gives us the cached version; otherwise, the latest version is downloaded from S3.

In this part, we optimise the hyperparameters of a simple Naive Bayes classifier, train it, and export it for deployment using `great_ai.save_model`.
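To make the two classifier hyperparameters tuned later in this notebook (`alpha` and `fit_prior`) concrete, here is a minimal from-scratch sketch of multinomial Naive Bayes in pure Python. It follows the smoothed estimate described in the scikit-learn user guide, (N_cf + alpha) / (N_c + alpha * |V|), but it is an illustration, not scikit-learn's implementation; the function names and the toy feature counts are made up for this example.

```python
import math
from collections import defaultdict
from typing import Dict, List, Tuple

Model = Tuple[Dict[str, float], Dict[str, Dict[str, float]]]


def train_multinomial_nb(
    rows: List[Dict[str, float]],
    labels: List[str],
    alpha: float = 1.0,
    fit_prior: bool = True,
) -> Model:
    classes = sorted(set(labels))
    vocabulary = sorted({feature for row in rows for feature in row})

    # Sum the feature values (counts or TF-IDF weights) per class
    totals: Dict[str, Dict[str, float]] = {c: defaultdict(float) for c in classes}
    for row, label in zip(rows, labels):
        for feature, value in row.items():
            totals[label][feature] += value

    # fit_prior=True: priors from class frequencies; fit_prior=False: uniform prior
    log_prior = {
        c: math.log(labels.count(c) / len(labels)) if fit_prior else -math.log(len(classes))
        for c in classes
    }

    # Additive smoothing: (N_cf + alpha) / (N_c + alpha * |V|)
    log_likelihood = {}
    for c in classes:
        denominator = sum(totals[c].values()) + alpha * len(vocabulary)
        log_likelihood[c] = {
            f: math.log((totals[c].get(f, 0.0) + alpha) / denominator) for f in vocabulary
        }
    return log_prior, log_likelihood


def predict(model: Model, row: Dict[str, float]) -> str:
    log_prior, log_likelihood = model
    # Log-posterior up to a constant; features unseen during training are ignored
    scores = {
        c: log_prior[c]
        + sum(v * log_likelihood[c][f] for f, v in row.items() if f in log_likelihood[c])
        for c in log_prior
    }
    return max(scores, key=scores.get)


rows = [
    {"protein": 3, "cell": 2},
    {"cell": 4, "gene": 1},
    {"theorem": 2, "proof": 3},
    {"lemma": 1, "proof": 2},
]
labels = ["Biology", "Biology", "Mathematics", "Mathematics"]
model = train_multinomial_nb(rows, labels, alpha=0.5, fit_prior=False)
print(predict(model, {"proof": 1}))  # Mathematics
print(predict(model, {"cell": 2, "protein": 1}))  # Biology
```

A smaller `alpha` trusts the observed counts more; `fit_prior=False` stops frequent domains from winning purely because they are frequent.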

## Load data that has been extracted in [part 1](data.ipynb)

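```python
from great_ai import query_ground_truth

# Load the ground-truth examples marked "train" in Part 1
data = query_ground_truth("train")

# An example can have several domains in its feedback: emit one
# (input, domain) pair per label to get a single-label training set
X = [d.input for d in data for domain in d.feedback]
y = [domain for d in data for domain in d.feedback]
```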
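The two comprehensions turn multi-label examples into single-label training pairs: an input with several domains in its feedback is repeated once per domain. A minimal sketch of the same idea, using a hypothetical `Example` dataclass in place of the objects returned by `query_ground_truth` (only the `input` and `feedback` attributes used above are assumed):

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Example:
    # Hypothetical stand-in for a ground-truth example: only the two
    # attributes used in the notebook are modelled here
    input: str
    feedback: List[str]


def flatten(data: List[Example]) -> Tuple[List[str], List[str]]:
    # One (input, domain) pair per label; examples with no labels contribute nothing
    X = [d.input for d in data for domain in d.feedback]
    y = [domain for d in data for domain in d.feedback]
    return X, y


data = [
    Example("Protein folding with deep learning", ["Biology", "Computer Science"]),
    Example("A note on prime gaps", ["Mathematics"]),
]
X, y = flatten(data)
print(X)
print(y)
```

The trade-off of this design is that multi-domain inputs are counted once per domain, which lets a plain single-label classifier learn from multi-label data.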

```text
2022-06-25 14:50:29,880 |     INFO | Found credentials file (/data/projects/great_ai_example/mongo.ini), initialising MongodbDriver
2022-06-25 14:50:29,881 |     INFO | Found credentials file (/data/projects/great_ai_example/mongo.ini), initialising LargeFileMongo
2022-06-25 14:50:29,881 |     INFO | Settings: configured ✅
2022-06-25 14:50:29,882 |     INFO | 🔩 tracing_database: MongodbDriver
2022-06-25 14:50:29,883 |     INFO | 🔩 large_file_implementation: LargeFileMongo
2022-06-25 14:50:29,883 |     INFO | 🔩 is_production: False
2022-06-25 14:50:29,884 |     INFO | 🔩 should_log_exception_stack: True
2022-06-25 14:50:29,884 |     INFO | 🔩 prediction_cache_size: 512
2022-06-25 14:50:29,885 |     INFO | 🔩 dashboard_table_size: 50
```


```python
import pandas as pd
from collections import Counter
import plotly.express as px

# Number of training examples per domain, most frequent first
df = pd.DataFrame(Counter(y).most_common(), columns=["domain", "count"])
px.bar(df, "domain", "count", width=1200, height=400).show()
```

## Optimise and train a Multinomial Naive Bayes classifier

```python
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer


def create_pipeline() -> Pipeline:
    # TF-IDF features with logarithmic (sublinear) term-frequency scaling,
    # followed by a multinomial Naive Bayes classifier
    return Pipeline(
        steps=[
            ("vectorizer", TfidfVectorizer(sublinear_tf=True)),
            ("classifier", MultinomialNB()),
        ]
    )
```
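`sublinear_tf=True` replaces each raw term count tf with 1 + ln(tf) before it is multiplied by the inverse document frequency. The pure-Python sketch below reproduces that weighting together with `TfidfVectorizer`'s other defaults as documented by scikit-learn (smoothed idf, L2 normalisation). It is a simplification: the lowercase-and-split tokenisation is an assumption standing in for scikit-learn's regex tokeniser, and `min_df`/`max_df` filtering is omitted.

```python
import math
from collections import Counter
from typing import Dict, List


def tfidf_sublinear(documents: List[str]) -> List[Dict[str, float]]:
    # Simplified tokenisation: lowercase and split on whitespace
    tokenised = [doc.lower().split() for doc in documents]
    n = len(tokenised)
    # Document frequency: in how many documents each term occurs
    df = Counter(term for tokens in tokenised for term in set(tokens))
    # Smoothed idf (smooth_idf=True): ln((1 + n) / (1 + df)) + 1
    idf = {term: math.log((1 + n) / (1 + count)) + 1 for term, count in df.items()}

    vectors = []
    for tokens in tokenised:
        counts = Counter(tokens)
        # Sublinear tf: 1 + ln(count) instead of the raw count
        weights = {t: (1 + math.log(c)) * idf[t] for t, c in counts.items()}
        # L2-normalise each document vector (guarding against empty documents)
        norm = math.sqrt(sum(w * w for w in weights.values())) or 1.0
        vectors.append({t: w / norm for t, w in weights.items()})
    return vectors


vectors = tfidf_sublinear(["cell cell cell protein", "proof of the theorem", "cell proof"])
print(vectors[0])
```

With raw counts, three occurrences of "cell" would weigh three times as much as one; with sublinear scaling the factor is 1 + ln 3 ≈ 2.10 before idf is applied, which keeps a few repeated words from dominating an abstract.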

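```python
from sklearn.model_selection import GridSearchCV

# Exhaustive search over 3 x 2 x 2 x 2 = 24 hyperparameter combinations;
# parameters are addressed as <step name>__<parameter name>
optimisation_pipeline = GridSearchCV(
    create_pipeline(),
    {
        "vectorizer__min_df": [5, 20, 100],  # drop terms found in fewer documents
        "vectorizer__max_df": [0.05, 0.1],  # drop terms found in a larger fraction of documents
        "classifier__alpha": [0.5, 1],  # additive smoothing
        "classifier__fit_prior": [True, False],  # learned class priors vs. uniform prior
    },
    scoring="f1_macro",  # unweighted mean of the per-domain F1 scores
    cv=3,  # 3-fold (stratified) cross-validation
    n_jobs=-1,  # use all CPU cores
    verbose=1,
)
optimisation_pipeline.fit(X, y)

results = pd.DataFrame(optimisation_pipeline.cv_results_)
results.sort_values("rank_test_score")
```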

Fitting 3 folds for each of 24 candidates, totalling 72 fits


| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_classifier__alpha | param_classifier__fit_prior | param_vectorizer__max_df | param_vectorizer__min_df | params | split0_test_score | split1_test_score | split2_test_score | mean_test_score | std_test_score | rank_test_score |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 7 | 1.96226 | 0.147449 | 0.935357 | 0.063659 | 0.5 | False | 0.05 | 20 | {'classifier__alpha': 0.5, 'classifier__fit_pr... | 0.48503 | 0.463849 | 0.48184 | 0.476906 | 0.009324 | 1 |
| 10 | 1.942605 | 0.111027 | 0.952361 | 0.066812 | 0.5 | False | 0.1 | 20 | {'classifier__alpha': 0.5, 'classifier__fit_pr... | 0.48289 | 0.459556 | 0.479362 | 0.473936 | 0.01027 | 2 |
| 19 | 2.145152 | 0.068978 | 1.002291 | 0.047358 | 1.0 | False | 0.05 | 20 | {'classifier__alpha': 1, 'classifier__fit_prio... | 0.46733 | 0.442994 | 0.464302 | 0.458208 | 0.010829 | 3 |
| 22 | 1.971888 | 0.12695 | 0.739795 | 0.071551 | 1.0 | False | 0.1 | 20 | {'classifier__alpha': 1, 'classifier__fit_prio... | 0.45483 | 0.422902 | 0.450677 | 0.442803 | 0.014174 | 4 |
| 6 | 1.861275 | 0.013389 | 1.058907 | 0.111122 | 0.5 | False | 0.05 | 5 | {'classifier__alpha': 0.5, 'classifier__fit_pr... | 0.456127 | 0.422456 | 0.443827 | 0.440803 | 0.013912 | 5 |
| 11 | 1.825397 | 0.105754 | 0.892227 | 0.057003 | 0.5 | False | 0.1 | 100 | {'classifier__alpha': 0.5, 'classifier__fit_pr... | 0.438232 | 0.440464 | 0.422667 | 0.433788 | 0.007916 | 6 |
| 23 | 1.693333 | 0.009667 | 0.501491 | 0.006545 | 1.0 | False | 0.1 | 100 | {'classifier__alpha': 1, 'classifier__fit_prio... | 0.433915 | 0.43947 | 0.416031 | 0.429805 | 0.010001 | 7 |
| 8 | 2.008045 | 0.14533 | 0.944559 | 0.155925 | 0.5 | False | 0.05 | 100 | {'classifier__alpha': 0.5, 'classifier__fit_pr... | 0.436178 | 0.425724 | 0.418396 | 0.426766 | 0.007297 | 8 |
| 20 | 1.7492 | 0.022959 | 0.889532 | 0.047517 | 1.0 | False | 0.05 | 100 | {'classifier__alpha': 1, 'classifier__fit_prio... | 0.428215 | 0.425398 | 0.411051 | 0.421555 | 0.007516 | 9 |
| 9 | 1.960889 | 0.098004 | 0.985957 | 0.080925 | 0.5 | False | 0.1 | 5 | {'classifier__alpha': 0.5, 'classifier__fit_pr... | 0.430638 | 0.406619 | 0.420213 | 0.419157 | 0.009834 | 10 |
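Each score in this table is a macro-averaged F1 (`scoring="f1_macro"`): F1 is computed for every domain separately and the results are averaged without weighting, so rare domains count as much as common ones. `mean_test_score` and `std_test_score` summarise the three `split*_test_score` columns as a mean and a population standard deviation. The pure-Python sketch below reimplements macro-F1 (a simplified stand-in for `sklearn.metrics.f1_score(..., average="macro")`; the toy labels are made up), then recomputes the top-ranked row's summary columns and the number of fits reported in the log.

```python
import itertools
import statistics
from typing import List


def macro_f1(y_true: List[str], y_pred: List[str]) -> float:
    labels = sorted(set(y_true) | set(y_pred))
    scores = []
    for label in labels:
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
        fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
        # F1 = 2TP / (2TP + FP + FN); the denominator is positive because
        # every label occurs in y_true or y_pred
        scores.append(2 * tp / (2 * tp + fp + fn))
    # Unweighted mean over labels: a rare, badly predicted domain pulls the score down
    return sum(scores) / len(scores)


# Always predicting the majority domain: 75% accuracy, but macro-F1 of 3/7
print(round(macro_f1(["Bio", "Bio", "Bio", "Math"], ["Bio"] * 4), 3))  # 0.429

# 3 * 2 * 2 * 2 = 24 candidates, each fitted on 3 folds
grid = {
    "vectorizer__min_df": [5, 20, 100],
    "vectorizer__max_df": [0.05, 0.1],
    "classifier__alpha": [0.5, 1],
    "classifier__fit_prior": [True, False],
}
candidates = list(itertools.product(*grid.values()))
print(len(candidates), len(candidates) * 3)  # 24 72

# Summary columns of the rank-1 row, recomputed from its split scores
top_row = [0.48503, 0.463849, 0.48184]
print(round(statistics.mean(top_row), 6), round(statistics.pstdev(top_row), 6))  # 0.476906 0.009324
```

Macro-averaging is why the best score sits below 0.5 even if accuracy may be considerably higher: every domain, however rare, weighs equally in the average.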


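```python
from sklearn import set_config

# Display scikit-learn estimators as diagrams in the notebook output
set_config(display="diagram")

# Train a fresh pipeline on the full training set with the best hyperparameters
classifier = create_pipeline()
classifier.set_params(**optimisation_pipeline.best_params_)
classifier.fit(X, y)
```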
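`best_params_` is a flat dictionary whose keys follow scikit-learn's `<step>__<parameter>` convention, and `Pipeline.set_params` routes each value to the step with that name by splitting the key at its first `__`. The sketch below mimics that routing with a hypothetical `route_params` helper (not a scikit-learn function); the values are taken from the top-ranked row of the table. Note that `GridSearchCV` refits the best candidate on all of `X` and `y` by default (`refit=True`) and exposes it as `best_estimator_`, so the explicit retraining above should yield an equivalent model.

```python
from typing import Any, Dict


def route_params(flat: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    # Hypothetical helper: split "<step>__<param>" keys into per-step dicts,
    # splitting only at the first "__" so nested names stay intact
    routed: Dict[str, Dict[str, Any]] = {}
    for key, value in flat.items():
        step, param = key.split("__", 1)
        routed.setdefault(step, {})[param] = value
    return routed


best_params = {
    "classifier__alpha": 0.5,
    "classifier__fit_prior": False,
    "vectorizer__max_df": 0.05,
    "vectorizer__min_df": 20,
}
print(route_params(best_params))
# {'classifier': {'alpha': 0.5, 'fit_prior': False}, 'vectorizer': {'max_df': 0.05, 'min_df': 20}}
```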

## Export the model using GreatAI

```python
from great_ai import save_model

# Upload the trained pipeline under this key, keeping only the 5 most recent versions
save_model(classifier, key="small-domain-prediction", keep_last_n=5)
```

```text
2022-06-25 14:50:53,592 |     INFO | Copying file for small-domain-prediction-0
2022-06-25 14:50:53,613 |     INFO | Compressing small-domain-prediction-0
2022-06-25 14:50:53,917 |     INFO | Uploading /tmp/tmpvxez8op8/small-domain-prediction-0.tar.gz to Mongo (GridFS)
2022-06-25 14:50:53,972 |     INFO | Uploading small-domain-prediction-0.tar.gz 0.26/1.85 MB (14.1%)
2022-06-25 14:50:53,974 |     INFO | Uploading small-domain-prediction-0.tar.gz 0.52/1.85 MB (28.2%)
2022-06-25 14:50:53,975 |     INFO | Uploading small-domain-prediction-0.tar.gz 0.78/1.85 MB (42.3%)
2022-06-25 14:50:53,977 |     INFO | Uploading small-domain-prediction-0.tar.gz 1.04/1.85 MB (56.4%)
2022-06-25 14:50:53,979 |     INFO | Uploading small-domain-prediction-0.tar.gz 1.31/1.85 MB (70.5%)
2022-06-25 14:50:53,980 |     INFO | Uploading small-domain-prediction-0.tar.gz 1.57/1.85 MB (84.7%)
```

'small-domain-prediction:0'
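`save_model` returns the model's identifier in the form `<key>:<version>`; here it is version `0` of `small-domain-prediction`. With `keep_last_n=5`, older versions are expected to be removed once more than five exist. The sketch below illustrates that retention policy on a plain list of version numbers; `parse_model_id` and `versions_to_delete` are hypothetical helpers written for this example, not GreatAI functions, and GreatAI's actual bookkeeping happens inside its storage backend.

```python
from typing import List, Tuple


def parse_model_id(model_id: str) -> Tuple[str, int]:
    # Split "<key>:<version>" at the last colon
    key, version = model_id.rsplit(":", 1)
    return key, int(version)


def versions_to_delete(versions: List[int], keep_last_n: int) -> List[int]:
    # Keep the keep_last_n highest version numbers and return the rest
    newest_first = sorted(versions, reverse=True)
    return sorted(newest_first[keep_last_n:])


print(parse_model_id("small-domain-prediction:0"))  # ('small-domain-prediction', 0)
print(versions_to_delete(list(range(7)), keep_last_n=5))  # [0, 1]
```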

### Next: [Part 3](deploy.ipynb)