In [1]:
%run -i "../util/lang_utils.ipynb"

In [2]:
import pandas as pd
from spacy.cli.train import train
from spacy.cli.evaluate import evaluate
from spacy.cli.debug_data import debug_data
from spacy.tokens import DocBin
from sklearn.metrics import classification_report
# Config generated at https://spacy.io/usage/training

In [3]:
def preprocess_data_entry(input_text, label, label_list):
    doc = small_model(input_text)
    cats = [0] * len(label_list)
    cats[label] = 1
    final_cats = {}
    for i, label in enumerate(label_list):
        final_cats[label] = cats[i]
    doc.cats = final_cats
    return doc

In [4]:
# Load and prepare data
train_db = DocBin()
test_db = DocBin()
label_list = ["tech", "business", "sport", "entertainment", "politics"]
train_df = pd.read_json("../data/bbc_train.json")
test_df = pd.read_json("../data/bbc_test.json")
train_df.sample(frac=1)
for idx, row in train_df.iterrows():
    text = row["text"]
    label = row["label"]
    doc = preprocess_data_entry(text, label, label_list)
    train_db.add(doc)
for idx, row in test_df.iterrows():
    text = row["text"]
    label = row["label"]
    doc = preprocess_data_entry(text, label, label_list)
    test_db.add(doc)
train_db.to_disk('../data/bbc_train.spacy')
test_db.to_disk('../data/bbc_test.spacy')

In [5]:
train("../data/spacy_config.cfg", output_path="../models/spacy_textcat_bbc")

[38;5;2m✔ Created output directory: ..\models\spacy_textcat_bbc[0m
[38;5;4mℹ Saving to output directory: ..\models\spacy_textcat_bbc[0m
[38;5;4mℹ Using CPU[0m
[1m
[38;5;2m✔ Initialized pipeline[0m
[1m
[38;5;4mℹ Pipeline: ['tok2vec', 'textcat'][0m
[38;5;4mℹ Initial learn rate: 0.001[0m
E    #       LOSS TOK2VEC  LOSS TEXTCAT  CATS_SCORE  SCORE 
---  ------  ------------  ------------  ----------  ------
  0       0          0.00          0.16        7.46    0.07
  0     200          9.28         37.51        8.05    0.08
  0     400         24.72         30.99       57.13    0.57
  0     600         59.68         33.32       51.04    0.51
  0     800         38.93         28.35       54.53    0.55
  0    1000         68.59         23.66       47.38    0.47
  0    1200         51.26         20.03       74.71    0.75
  0    1400         73.93         16.17       78.85    0.79
  0    1600         59.05         17.60       83.33    0.83
  1    1800         98.55         13.10 

In [6]:
# Use the trained model
nlp = spacy.load("../models/spacy_textcat_bbc/model-last")
input_text = test_df.iloc[1, test_df.columns.get_loc('text')]
print(input_text)
print(test_df["label_text"].iloc[[1]])
doc = nlp(input_text)
print("Predicted probabilities: ", doc.cats)

lib dems  new election pr chief the lib dems have appointed a senior figure from bt to be the party s new communications chief for their next general election effort.  sandy walkington will now work with senior figures such as matthew taylor on completing the party manifesto. party chief executive lord rennard said the appointment was a  significant strengthening of the lib dem team . mr walkington said he wanted the party to be ready for any  mischief  rivals or the media tried to throw at it.   my role will be to ensure this new public profile is effectively communicated at all levels   he said.  i also know the party will be put under scrutiny in the media and from the other parties as never before - and we will need to show ourselves ready and prepared to counter the mischief and misrepresentation that all too often comes from the party s opponents.  the party is already demonstrating on every issue that it is the effective opposition.  mr walkington s new job title is director of 

In [7]:
# Evaluate the model on test data
def get_prediction(input_text, nlp_model, target_names):
    doc = nlp_model(input_text)
    category = max(doc.cats, key = doc.cats.get)
    return target_names.index(category)
test_df["prediction"] = test_df["text"].apply(lambda x: get_prediction(x, nlp, label_list))

In [8]:
print(classification_report(test_df["label"], test_df["prediction"], target_names=label_list))

               precision    recall  f1-score   support

         tech       0.97      0.81      0.88        80
     business       0.95      0.90      0.92       102
        sport       0.98      0.98      0.98       102
entertainment       0.88      0.92      0.90        77
     politics       0.80      0.93      0.86        84

     accuracy                           0.91       445
    macro avg       0.91      0.91      0.91       445
 weighted avg       0.92      0.91      0.91       445



In [9]:
evaluate('../models/spacy_textcat_bbc/model-last', '../data/bbc_test.spacy')

{'token_acc': 1.0,
 'token_p': 1.0,
 'token_r': 1.0,
 'token_f': 1.0,
 'cats_score': 0.9090492096590561,
 'cats_score_desc': 'macro F',
 'cats_micro_p': 0.9123595505617977,
 'cats_micro_r': 0.9123595505617977,
 'cats_micro_f': 0.9123595505617977,
 'cats_macro_p': 0.9142913192129987,
 'cats_macro_r': 0.9091004583651643,
 'cats_macro_f': 0.9090492096590561,
 'cats_macro_auc': 0.9701107627364396,
 'cats_f_per_type': {'tech': {'p': 0.9701492537313433,
   'r': 0.8125,
   'f': 0.8843537414965987},
  'business': {'p': 0.9484536082474226,
   'r': 0.9019607843137255,
   'f': 0.9246231155778893},
  'sport': {'p': 0.9803921568627451,
   'r': 0.9803921568627451,
   'f': 0.9803921568627451},
  'entertainment': {'p': 0.8765432098765432,
   'r': 0.922077922077922,
   'f': 0.8987341772151898},
  'politics': {'p': 0.7959183673469388,
   'r': 0.9285714285714286,
   'f': 0.8571428571428572}},
 'cats_auc_per_type': {'tech': 0.9562671232876713,
  'business': 0.960226947922026,
  'sport': 0.9900817469845082