In [1]:
pip install ktrain

Collecting ktrain
  Downloading ktrain-0.28.3.tar.gz (25.3 MB)
[K     |████████████████████████████████| 25.3 MB 86 kB/s 
[?25hCollecting scikit-learn==0.23.2
  Downloading scikit_learn-0.23.2-cp37-cp37m-manylinux1_x86_64.whl (6.8 MB)
[K     |████████████████████████████████| 6.8 MB 32.5 MB/s 
Collecting langdetect
  Downloading langdetect-1.0.9.tar.gz (981 kB)
[K     |████████████████████████████████| 981 kB 30.1 MB/s 
Collecting cchardet
  Downloading cchardet-2.1.7-cp37-cp37m-manylinux2010_x86_64.whl (263 kB)
[K     |████████████████████████████████| 263 kB 47.1 MB/s 
Collecting syntok
  Downloading syntok-1.3.1.tar.gz (23 kB)
Collecting seqeval==0.0.19
  Downloading seqeval-0.0.19.tar.gz (30 kB)
Collecting transformers<=4.10.3,>=4.0.0
  Downloading transformers-4.10.3-py3-none-any.whl (2.8 MB)
[K     |████████████████████████████████| 2.8 MB 38.9 MB/s 
[?25hCollecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (

In [2]:
import numpy as np
import os

import tensorflow as tf
import ktrain
from ktrain import text
import pandas as pd

In [3]:
# Load the labelled articles; later cells read the `Text` and `Category` columns.
articles_path = 'articles.csv'
df = pd.read_csv(articles_path)

In [4]:
import nltk
# Fetch the NLTK 'stopwords' corpus so it can be loaded via nltk.corpus.
nltk.download('stopwords')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


True

In [5]:
from nltk.corpus import stopwords
# English stop words; used to filter tokens out of the article text.
stop = stopwords.words('english')

In [6]:
# Preview the first rows to check the columns that were loaded.
df.head()

Unnamed: 0,ArticleId,Text,Category
0,1833,worldcom ex-boss launches defence lawyers defe...,business
1,154,german business confidence slides german busin...,business
2,1101,bbc poll indicates economic gloom citizens in ...,business
3,1976,lifestyle governs mobile choice faster bett...,tech
4,917,enron bosses in $168m payout eighteen former e...,business


In [7]:
# Force every article to `str` so `.split()` is always available
# (a missing value becomes the literal string 'nan').
df['Text'] = df['Text'].astype('str')

# Build the lookup set once: testing membership against the original list is a
# linear scan for every token of every article.
stop_words = set(stop)
df['without_stopwords'] = df['Text'].apply(
    lambda article: ' '.join(token for token in article.split() if token not in stop_words)
)

In [8]:
# Number of rows that an 80% share of the dataset corresponds to.
len(df) * 0.8

1192.0

In [9]:
# Draw the 80% training split by fraction instead of a hard-coded row count, so
# the split keeps tracking the dataset if articles.csv changes size (for the
# current file this is 1192 rows). The fixed seed keeps the sample reproducible.
data = df.sample(frac=0.8, random_state=2)

In [10]:
# Distinct labels present in the training sample.
data.Category.unique()

array(['business', 'politics', 'tech', 'sport', 'entertainment'],
      dtype=object)

In [11]:
# Keep only the cleaned text and its label for the held-out frame.
holdout_columns = ['without_stopwords', 'Category']
df_test = df[holdout_columns]

In [12]:
# Hold out every article whose cleaned text was not drawn into the training
# sample. Matching on text (rather than on index) also drops rows whose text
# duplicates a training row.
in_training = df_test['without_stopwords'].isin(data['without_stopwords'])
# .copy() makes the hold-out an independent frame, so columns assigned to it
# later do not raise pandas' SettingWithCopyWarning.
df_test = df_test[~in_training].copy()

In [13]:
# (rows, columns) of the held-out frame after removing training texts.
df_test.shape

(279, 2)

In [14]:
# Preprocess both splits in BERT mode. Note that X_test / y_test hold the
# frame passed in as `val_df`, i.e. they serve as the validation data.
(X_train, y_train), (X_test, y_test), preproc = text.texts_from_df(
    train_df=data,
    text_column='without_stopwords',
    label_columns=['Category'],
    val_df=df_test,
    maxlen=500,
    preprocess_mode='bert',
)

['business', 'entertainment', 'politics', 'sport', 'tech']
      business  entertainment  politics  sport  tech
1283       1.0            0.0       0.0    0.0   0.0
354        0.0            0.0       1.0    0.0   0.0
1048       0.0            0.0       1.0    0.0   0.0
311        1.0            0.0       0.0    0.0   0.0
68         0.0            0.0       0.0    0.0   1.0
['business', 'entertainment', 'politics', 'sport', 'tech']
    business  entertainment  politics  sport  tech
8        1.0            0.0       0.0    0.0   0.0
9        0.0            1.0       0.0    0.0   0.0
19       0.0            0.0       0.0    0.0   1.0
31       0.0            1.0       0.0    0.0   0.0
34       0.0            0.0       0.0    1.0   0.0
downloading pretrained BERT model (uncased_L-12_H-768_A-12.zip)...
[██████████████████████████████████████████████████]
extracting pretrained BERT model...
done.

cleanup downloaded zip...
done.

preprocessing train...
language: en


Is Multi-Label? False
preprocessing test...
language: en


In [15]:
# Build the BERT classifier from the preprocessed training data.
model = text.text_classifier(
    name='bert',
    train_data=(X_train, y_train),
    preproc=preproc,
)

Is Multi-Label? False
maxlen is 500
done.


In [16]:
# Wrap the model together with its training and validation data.
learner = ktrain.get_learner(
    model=model,
    train_data=(X_train, y_train),
    val_data=(X_test, y_test),
    batch_size=6,
)

In [17]:
# Train for three epochs with the one-cycle policy, peaking at lr = 2e-5.
learner.fit_onecycle(lr=2e-5, epochs=3)



begin training using onecycle policy with max lr of 2e-05...
Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.callbacks.History at 0x7f041189d2d0>

In [18]:
# Bundle the trained model with its preprocessor and save the predictor.
predictor = ktrain.get_predictor(learner.model, preproc)
save_dir = 'email_classifier'
predictor.save(save_dir)
print('MODEL SAVED')

  layer_config = serialize_layer_fn(layer)


MODEL SAVED


In [19]:
# Score the whole hold-out set in one batched call instead of invoking the
# model twice per row, and pass return_proba as a real boolean (the original
# string 'True' only worked because any non-empty string is truthy).
df_test['without_stopwords'] = df_test['without_stopwords'].astype('str')
holdout_texts = df_test['without_stopwords'].tolist()
class_probs = predictor.predict(holdout_texts, return_proba=True)

# The predicted label is the class with the highest probability; class_names
# is indexed in the same order as the probability columns.
class_names = predictor.get_classes()
df_test['predicted_class'] = [class_names[i] for i in np.argmax(class_probs, axis=1)]
# One probability vector per row, as in the original per-row calls.
df_test['predicted_prob'] = list(class_probs)

In [20]:
# Show the held-out rows with their true label, predicted label and predicted probabilities.
df_test

Unnamed: 0,without_stopwords,Category,predicted_class,predicted_prob
8,car giant hit mercedes slump slump profitabili...,business,business,"[0.99369013, 0.0014712806, 0.001118074, 0.0009..."
9,fockers fuel festive film chart comedy meet fo...,entertainment,entertainment,"[0.0014582014, 0.9942416, 0.0013036052, 0.0019..."
19,moving mobile improves golf swing mobile phone...,tech,tech,"[0.0032700482, 0.0012701568, 0.00087362144, 0...."
31,rapper snoop dogg sued rape us rapper snoop do...,entertainment,entertainment,"[0.0011493665, 0.9890416, 0.00078568887, 0.003..."
34,philippoussis doubt open bid mark philippoussi...,sport,sport,"[0.0011659046, 0.0012375761, 0.0003973996, 0.9..."
...,...,...,...,...
1473,dallaglio eyeing lions tour place former engla...,sport,sport,"[0.0009952108, 0.0036784231, 0.0015544826, 0.9..."
1477,web logs aid disaster recovery vivid descripti...,tech,tech,"[0.003435477, 0.0019285942, 0.0014837213, 0.00..."
1479,high fuel costs hit us airlines two largest ai...,business,business,"[0.99444616, 0.0014354686, 0.0010281274, 0.001..."
1484,hyundai build new india plant south korea hyun...,business,business,"[0.9916459, 0.001834347, 0.0015139946, 0.00080..."
