# **AG-News Multi-class Text Classification with Deep Learning using BERT**

In this project, we look in detail at how the BERT-Base (uncased) model is applied to multi-class text classification. We fine-tune this Transformer model on AG News, a corpus of 120,000 labelled training articles, and reach about 95% accuracy on its 7,600-article test set. The dataset is loaded with the Hugging Face `datasets` library, and the BERT model is built, fine-tuned, and evaluated with ktrain, a high-level Python wrapper around TensorFlow/Keras.

**Installing Libraries**

In [2]:
!pip install ktrain
!pip install transformers #developed by Hugging Face and provides state-of-the-art pre-trained models for natural language processing tasks. It includes a wide range of transformer architectures, including BERT, GPT, RoBERTa, and more.
!pip install datasets #The datasets library, also developed by Hugging Face, provides easy access to a vast collection of datasets for NLP.
!pip install tensorflow

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting ktrain
  Downloading ktrain-0.37.2.tar.gz (25.3 MB)
  Preparing metadata (setup.py) ... done
Collecting langdetect (from ktrain)
  Downloading langdetect-1.0.9.tar.gz (981 kB)
  Preparing metadata (setup.py) ... done
Collecting cchardet (from ktrain)
  Downloading cchardet-2.1.7.tar.gz (653 kB)
  Preparing metadata (setup.py) ... done
Collecting syntok>1.3.3 (from ktrain)
  Downloading syntok-1.4.4-py3-none-any.whl (24 kB)
Collecting tika (from ktrain)
  Downloading ti

In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import ktrain
from ktrain import text #the text module offers convenient utilities for preprocessing text data, creating text classification models, and performing various operations on text inputs.
from sklearn.model_selection import train_test_split
from datasets import list_datasets
from datasets import load_dataset
import timeit
import tensorflow as tf

In [4]:
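print("Tensorflow version : ", tf.__version__)
# tf.config.list_physical_devices returns the visible GPU devices (an empty list if there are none)
print("GPU available : ", len(tf.config.list_physical_devices('GPU')) > 0)
print("GPU name : ", tf.test.gpu_device_name())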

Tensorflow version :  2.12.0
GPU available :  True
GPU name :  /device:GPU:0


**Checking the datasets available on the Hugging Face Hub:**

In [5]:
available_datasets = list_datasets()
print("Count of available datasets : ", len(available_datasets))
print()
print("<====== Dataset List ======> :\n")
print('\n  |__ '.join(dataset for dataset in available_datasets))

Count of available datasets :  44366


acronym_identification
  |__ ade_corpus_v2
  |__ adversarial_qa
  |__ aeslc
  |__ afrikaans_ner_corpus
  |__ ag_news
  |__ ai2_arc
  |__ air_dialogue
  |__ ajgt_twitter_ar
  |__ allegro_reviews
  |__ allocine
  |__ alt
  |__ amazon_polarity
  |__ amazon_reviews_multi
  |__ amazon_us_reviews
  |__ ambig_qa
  |__ americas_nli
  |__ ami
  |__ amttl
  |__ anli
  |__ app_reviews
  |__ aqua_rat
  |__ aquamuse
  |__ ar_cov19
  |__ ar_res_reviews
  |__ ar_sarcasm
  |__ arabic_billion_words
  |__ arabic_pos_dialect
  |__ arabic_speech_corpus
  |__ arcd
  |__ arsentd_lev
  |__ art
  |__ arxiv_dataset
  |__ ascent_kb
  |__ aslg_pc12
  |__ asnq
  |__ asset
  |__ assin
  |__ assin2
  |__ atomic
  |__ autshumato
  |__ facebook/babi_qa
  |__ banking77
  |__ bbaw_egyptian
  |__ bbc_hindi_nli
  |__ bc2gm_corpus
  |__ beans
  |__ best2009
  |__ bianet
  |__ bible_para
  |__ big_patent
  |__ billsum
  |__ bing_coronavirus_query_set
  |__ biomrc
  |__ biosses
  |__ b

In [6]:
ag_news_dataset = load_dataset('ag_news') #import ag_news dataset
print("\n", ag_news_dataset)

Downloading and preparing dataset ag_news/default to /root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548...


Dataset ag_news downloaded and prepared to /root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548. Subsequent calls will reuse this data.


 DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 120000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 7600
    })
})


In [7]:
print("Dataset Items: \n", ag_news_dataset.items())
print("\nDataset type: \n", type(ag_news_dataset))
print("\nShape of dataset: \n", ag_news_dataset.shape)
print("\nNo of rows: \n", ag_news_dataset.num_rows)
print("\nNo of columns: \n", ag_news_dataset.num_columns)

Dataset Items: 
 dict_items([('train', Dataset({
    features: ['text', 'label'],
    num_rows: 120000
})), ('test', Dataset({
    features: ['text', 'label'],
    num_rows: 7600
}))])

Dataset type: 
 <class 'datasets.dataset_dict.DatasetDict'>

Shape of dataset: 
 {'train': (120000, 2), 'test': (7600, 2)}

No of rows: 
 {'train': 120000, 'test': 7600}

No of columns: 
 {'train': 2, 'test': 2}


The AG News dataset is a widely used benchmark for text classification. It is built from AG's corpus, a collection of news articles gathered from the web, and each example (a headline followed by a short description) is assigned to exactly one of four news categories. The classes are balanced: 30,000 training and 1,900 test articles per class.

Number of classes: 4

Class labels (the integer values stored in the `label` column, with the names from its `ClassLabel` feature):

• 0: World

• 1: Sports

• 2: Business

• 3: Sci/Tech (science and technology)

In [8]:
print("\nColumn Names: \n", ag_news_dataset.column_names)
print("\n", ag_news_dataset.data)


Column Names: 
 {'train': ['text', 'label'], 'test': ['text', 'label']}

 {'train': MemoryMappedTable
text: string
label: int64
----
label: [[2,2,2,2,2,...,2,2,3,3,3],[3,1,1,0,0,...,3,1,0,0,2],...,[0,0,2,2,2,...,2,2,2,1,0],[0,0,0,0,3,...,0,1,1,1,1]], 'test': MemoryMappedTable
text: string
label: int64
----
text: [["Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.","The Race is On: Second Private Team Sets Launch Date for Human Spaceflight (SPACE.com) SPACE.com - TORONTO, Canada -- A second\team of rocketeers competing for the  #36;10 million Ansari X Prize, a contest for\privately funded suborbital space flight, has officially announced the first\launch date for its manned rocket.","Ky. Company Wins Grant to Study Peptides (AP) AP - A company founded by a chemistry researcher at the University of Louisville won a grant to develop a method of producing better peptides, which

In [9]:
print(ag_news_dataset['train'][0])
print(ag_news_dataset['train'][1])

{'text': "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.", 'label': 2}
{'text': 'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.', 'label': 2}


In [10]:
print(ag_news_dataset['train']['text'][0])
print(ag_news_dataset['train']['label'][0])
print()
print(ag_news_dataset['train']['text'][35000])
print(ag_news_dataset['train']['label'][35000])
print()
print(ag_news_dataset['train']['text'][60000])
print(ag_news_dataset['train']['label'][60000])
print()
print(ag_news_dataset['train']['text'][100000])
print(ag_news_dataset['train']['label'][100000])

Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.
2

Black armbands for Clough, tears for Liverpool fans In the afternoon, Brian Clough, unquestionably one of the greats and unarguably one of the most controversial of football men, died of cancer.
1

BYTE OF THE APPLE Apple lost one war to Microsoft by not licensing its Mac operating system. It may repeat the error with its iPod and music software.
3

Venezuelan Car-Bomb Suspect Killed, Weapons Found  CARACAS, Venezuela (Reuters) - A Venezuelan lawyer  suspected in last week's bombing murder of a top state  prosecutor was killed in a gunfight with police on Tuesday  after he tried to ram detectives with his car and opened fire  on them, officials said.
0


**Loading train and test datasets**

In [11]:
ag_news_train = load_dataset('ag_news', split='train')
ag_news_test = load_dataset('ag_news', split='test')
print("Train Dataset : ", ag_news_train.shape)
print("Test Dataset : ", ag_news_test.shape)



Train Dataset :  (120000, 2)
Test Dataset :  (7600, 2)


In [12]:
print(ag_news_train[0])
print(ag_news_test[0])

{'text': "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.", 'label': 2}
{'text': "Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.", 'label': 2}


In [13]:
print("\nTrain Dataset Features: \n", ag_news_train.features)
print("\nTest Dataset Features: \n", ag_news_test.features)


Train Dataset Features: 
 {'text': Value(dtype='string', id=None), 'label': ClassLabel(names=['World', 'Sports', 'Business', 'Sci/Tech'], id=None)}

Test Dataset Features: 
 {'text': Value(dtype='string', id=None), 'label': ClassLabel(names=['World', 'Sports', 'Business', 'Sci/Tech'], id=None)}


In [14]:
pd.set_option('display.max_columns', None)  # show every column when displaying DataFrames
ag_news_train_df = pd.DataFrame(data=ag_news_train)
ag_news_train_df.head(10)

|   | text | label |
|---|---|---|
| 0 | Wall St. Bears Claw Back Into the Black (Reute... | 2 |
| 1 | Carlyle Looks Toward Commercial Aerospace (Reu... | 2 |
| 2 | Oil and Economy Cloud Stocks' Outlook (Reuters... | 2 |
| 3 | Iraq Halts Oil Exports from Main Southern Pipe... | 2 |
| 4 | Oil prices soar to all-time record, posing new... | 2 |
| 5 | Stocks End Up, But Near Year Lows (Reuters) Re... | 2 |
| 6 | Money Funds Fell in Latest Week (AP) AP - Asse... | 2 |
| 7 | Fed minutes show dissent over inflation (USATO... | 2 |
| 8 | Safety Net (Forbes.com) Forbes.com - After ear... | 2 |
| 9 | Wall St. Bears Claw Back Into the Black NEW Y... | 2 |


In [15]:
ag_news_train_df.tail(10)

|   | text | label |
|---|---|---|
| 119990 | Barack Obama Gets #36;1.9 Million Book Deal (... | 0 |
| 119991 | Rauffer Beats Favorites to Win Downhill VAL G... | 1 |
| 119992 | Iraqis Face Winter Shivering by Candlelight B... | 0 |
| 119993 | AU Says Sudan Begins Troop Withdrawal from Dar... | 0 |
| 119994 | Syria Redeploys Some Security Forces in Lebano... | 0 |
| 119995 | Pakistan's Musharraf Says Won't Quit as Army C... | 0 |
| 119996 | Renteria signing a top-shelf deal Red Sox gene... | 1 |
| 119997 | Saban not going to Dolphins yet The Miami Dolp... | 1 |
| 119998 | Today's NFL games PITTSBURGH at NY GIANTS Time... | 1 |
| 119999 | Nets get Carter from Raptors INDIANAPOLIS -- A... | 1 |


In [16]:
ag_news_test_df = pd.DataFrame(data=ag_news_test)
ag_news_test_df.head(10)

|   | text | label |
|---|---|---|
| 0 | Fears for T N pension after talks Unions repre... | 2 |
| 1 | The Race is On: Second Private Team Sets Launc... | 3 |
| 2 | Ky. Company Wins Grant to Study Peptides (AP) ... | 3 |
| 3 | Prediction Unit Helps Forecast Wildfires (AP) ... | 3 |
| 4 | Calif. Aims to Limit Farm-Related Smog (AP) AP... | 3 |
| 5 | Open Letter Against British Copyright Indoctri... | 3 |
| 6 | Loosing the War on Terrorism \\"Sven Jaschan, ... | 3 |
| 7 | FOAFKey: FOAF, PGP, Key Distribution, and Bloo... | 3 |
| 8 | E-mail scam targets police chief Wiltshire Pol... | 3 |
| 9 | Card fraud unit nets 36,000 cards In its first... | 3 |


In [17]:
ag_news_test_df.tail(10)

|   | text | label |
|---|---|---|
| 7590 | Saban hiring on hold DAVIE - The Dolphins want... | 1 |
| 7591 | Bosnian-Serb prime minister resigns in protest... | 0 |
| 7592 | Historic Turkey-EU deal welcomed The European ... | 0 |
| 7593 | Mortaza strikes to lead superb Bangladesh rall... | 1 |
| 7594 | Powell pushes diplomacy for N. Korea WASHINGTO... | 0 |
| 7595 | Around the world Ukrainian presidential candid... | 0 |
| 7596 | Void is filled with Clement With the supply of... | 1 |
| 7597 | Martinez leaves bitter Like Roger Clemens did ... | 1 |
| 7598 | 5 of arthritis patients in Singapore take Bext... | 2 |
| 7599 | EBay gets into rentals EBay plans to buy the a... | 2 |


**Preprocess Data:**

In [18]:
class_label_names = ['World', 'Sports', 'Business', 'Sci/Tech']

(X_train, y_train): the preprocessed training data. In `'bert'` mode, X_train holds the two BERT inputs (token-ID and segment-ID arrays, each truncated or padded to `maxlen` tokens), and y_train holds the one-hot encoded labels (the `label_0` to `label_3` columns printed below).

(X_test, y_test): the same encoding applied to the test split. Because it is passed as `val_df`, the test split also serves as the validation set during training.

preprocessing_var: the ktrain preprocessor object. It holds the BERT tokenizer and vocabulary, `maxlen`, and the class names, and it is needed again to build the classifier and the predictor.
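To make these variables concrete, here is a toy sketch of what BERT-style preprocessing produces for a single example: token IDs wrapped in `[CLS]`/`[SEP]` and padded to `maxlen`, an all-zero segment-ID sequence (there is only one sentence), and a one-hot label vector. The tiny vocabulary, the whitespace tokenizer, and the names `encode` and `one_hot` are hypothetical simplifications for illustration only; ktrain itself uses BERT's WordPiece tokenizer and its roughly 30,000-entry vocabulary.

```python
from typing import Dict, List, Tuple

# Hypothetical toy vocabulary; real BERT uses a ~30k-entry WordPiece vocabulary.
TOY_VOCAB: Dict[str, int] = {
    "[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3,
    "wall": 4, "st": 5, "bears": 6, "claw": 7, "back": 8,
}

def encode(text: str, maxlen: int) -> Tuple[List[int], List[int]]:
    """Toy BERT-style encoding: returns (token_ids, segment_ids), each maxlen long."""
    # Crude stand-in for WordPiece: lowercase, drop periods, split on whitespace.
    tokens = text.lower().replace(".", " ").split()
    # Truncate so that [CLS] and [SEP] still fit within maxlen.
    tokens = ["[CLS]"] + tokens[: maxlen - 2] + ["[SEP]"]
    ids = [TOY_VOCAB.get(tok, TOY_VOCAB["[UNK]"]) for tok in tokens]
    ids += [TOY_VOCAB["[PAD]"]] * (maxlen - len(ids))  # right-pad with [PAD]
    segment_ids = [0] * maxlen  # one sentence only, so every position is segment 0
    return ids, segment_ids

def one_hot(label: int, num_classes: int) -> List[float]:
    """One-hot label vector, like the label_0..label_3 columns ktrain prints."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]

ids, segment_ids = encode("Wall St. Bears Claw Back Into the Black", maxlen=12)
print(ids)            # → [2, 4, 5, 6, 7, 8, 1, 1, 1, 3, 0, 0]
print(one_hot(2, 4))  # → [0.0, 0.0, 1.0, 0.0]  (label 2 = Business)
```

Out-of-vocabulary words map to `[UNK]`; the real WordPiece tokenizer avoids most of these by splitting unknown words into known sub-word pieces.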

In [19]:
(X_train, y_train), (X_test, y_test), preprocessing_var = text.texts_from_df(train_df=ag_news_train_df,
                                                                             text_column='text',
                                                                             label_columns='label',
                                                                             val_df=ag_news_test_df,
                                                                             maxlen=512,
                                                                             preprocess_mode='bert')

['label_0', 'label_1', 'label_2', 'label_3']
   label_0  label_1  label_2  label_3
0      0.0      0.0      1.0      0.0
1      0.0      0.0      1.0      0.0
2      0.0      0.0      1.0      0.0
3      0.0      0.0      1.0      0.0
4      0.0      0.0      1.0      0.0
['label_0', 'label_1', 'label_2', 'label_3']
   label_0  label_1  label_2  label_3
0      0.0      0.0      1.0      0.0
1      0.0      0.0      0.0      1.0
2      0.0      0.0      0.0      1.0
3      0.0      0.0      0.0      1.0
4      0.0      0.0      0.0      1.0
downloading pretrained BERT model (uncased_L-12_H-768_A-12.zip)...
extracting pretrained BERT model...
done.

cleanup downloaded zip...
done.

preprocessing train...
language: en


Is Multi-Label? False
preprocessing test...
language: en
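With `maxlen=512`, every example is padded to BERT's maximum sequence length, although the AG News samples printed earlier are only a few dozen words long. Checking the length distribution is a cheap way to choose a smaller `maxlen`, which reduces the work done per example. The sketch below uses whitespace word counts as a rough proxy; BERT's WordPiece tokenizer usually produces somewhat more tokens than words, so leave headroom. The helper names are hypothetical, and the two sample texts are copied from the training split shown above.

```python
import math
from typing import List, Sequence

def percentile(values: Sequence[int], q: float) -> int:
    """Nearest-rank percentile for q in (0, 100]."""
    ordered = sorted(values)
    rank = max(1, math.ceil(q / 100 * len(ordered)))
    return ordered[rank - 1]

def word_lengths(texts: Sequence[str]) -> List[int]:
    """Whitespace word count per text (a lower bound on WordPiece token count)."""
    return [len(t.split()) for t in texts]

samples = [
    "Black armbands for Clough, tears for Liverpool fans In the afternoon, Brian Clough, "
    "unquestionably one of the greats and unarguably one of the most controversial of "
    "football men, died of cancer.",
    "BYTE OF THE APPLE Apple lost one war to Microsoft by not licensing its Mac operating "
    "system. It may repeat the error with its iPod and music software.",
]
lengths = word_lengths(samples)
print(lengths, percentile(lengths, 95))  # → [31, 28] 31
```

In the notebook, `word_lengths(ag_news_train_df['text'].tolist())` would give the distribution over the whole training split.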


**Creating the BERT model:**

In [20]:
transformer_bert_model = text.text_classifier(name='bert',
                                              train_data=(X_train, y_train),
                                              preproc=preprocessing_var)

Is Multi-Label? False
maxlen is 512




done.


In [21]:
transformer_bert_model.layers

[<keras.engine.input_layer.InputLayer at 0x7f0ab83afbe0>,
 <keras.engine.input_layer.InputLayer at 0x7f0ab9943af0>,
 <keras_bert.layers.embedding.TokenEmbedding at 0x7f0ab83afdc0>,
 <keras.layers.core.embedding.Embedding at 0x7f0ab82ff1f0>,
 <keras.layers.merging.add.Add at 0x7f0ab82fdd50>,
 <keras_pos_embd.pos_embd.PositionEmbedding at 0x7f0ab82fc790>,
 <keras.layers.regularization.dropout.Dropout at 0x7f09e1ef48e0>,
 <keras_layer_normalization.layer_normalization.LayerNormalization at 0x7f09e1ef5870>,
 <keras_multi_head.multi_head_attention.MultiHeadAttention at 0x7f09e1e861d0>,
 <keras.layers.regularization.dropout.Dropout at 0x7f0ab82fca60>,
 <keras.layers.merging.add.Add at 0x7f09e1b77280>,
 <keras_layer_normalization.layer_normalization.LayerNormalization at 0x7f09e1b76440>,
 <keras_position_wise_feed_forward.feed_forward.FeedForward at 0x7f09e1b76ad0>,
 <keras.layers.regularization.dropout.Dropout at 0x7f09e1b746a0>,
 <keras.layers.merging.add.Add at 0x7f09e1b77eb0>,
 <keras_lay

**Wrap BERT in a ktrain Learner object:**

In [25]:
bert_learner = ktrain.get_learner(model=transformer_bert_model,
                                  train_data=(X_train, y_train),
                                  val_data=(X_test, y_test),
                                  batch_size=6)

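Recommended fine-tuning ranges for BERT (from the original BERT paper):
• Batch size: 16, 32

• Learning rate (Adam): 5e-5, 3e-5, 2e-5

• Number of epochs: 2, 3, 4

This notebook uses a learning rate of 2e-5 and 3 epochs, but a batch size of 6, below the recommended range; this is most likely needed to fit 512-token sequences into GPU memory.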
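The batch size directly sets how many optimizer updates happen per epoch. A quick calculation over the 120,000 training rows, assuming one update per batch, shows how batch size 6 compares with the recommended values (`steps_per_epoch` is a hypothetical helper, not a ktrain function):

```python
import math

TRAIN_SIZE = 120_000  # rows in the AG News training split
EPOCHS = 3            # epochs passed to fit_onecycle in this notebook

def steps_per_epoch(n_examples: int, batch_size: int) -> int:
    """Optimizer steps per epoch; the final batch may be partial, hence the ceiling."""
    return math.ceil(n_examples / batch_size)

schedule = {bs: steps_per_epoch(TRAIN_SIZE, bs) for bs in (6, 16, 32)}
for bs, spe in schedule.items():
    print(f"batch_size={bs:>2}: {spe:>6} steps/epoch, {spe * EPOCHS:>6} steps in total")
```

With batch size 6 the model takes 20,000 smaller, noisier gradient steps per epoch, versus 7,500 at batch size 16 and 3,750 at batch size 32.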

**Train BERT on the AG-News dataset:**

In [23]:
training_start_time = timeit.default_timer()
bert_learner.fit_onecycle(lr=2e-5, epochs=3)
training_stop_time = timeit.default_timer()



begin training using onecycle policy with max lr of 2e-05...
Epoch 1/3
Epoch 2/3
Epoch 3/3
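`fit_onecycle` trains with a version of Leslie Smith's 1cycle policy, in which the learning rate climbs to the given maximum (here 2e-5, as the log above reports) and then decays. The sketch below is a simplified triangular variant for intuition only: the starting value `max_lr / div_factor` with `div_factor = 10` is an assumption, and ktrain's exact details (start value, phase lengths, any momentum cycling) may differ, so consult its documentation for the real schedule.

```python
def one_cycle_lr(step: int, total_steps: int, max_lr: float, div_factor: float = 10.0) -> float:
    """Simplified triangular 1cycle learning rate (illustrative, not ktrain's exact code).

    Rises linearly from max_lr / div_factor to max_lr over the first half of
    training, then falls linearly back to max_lr / div_factor.
    """
    step = min(max(step, 0), total_steps)  # clamp to the training range
    min_lr = max_lr / div_factor
    half = total_steps / 2
    frac = step / half if step <= half else (total_steps - step) / half
    return min_lr + frac * (max_lr - min_lr)

TOTAL_STEPS = 60_000  # assumed: 3 epochs x (120,000 examples / batch size 6)
for s in (0, 15_000, 30_000, 45_000, 60_000):
    print(f"step {s:>6}: lr = {one_cycle_lr(s, TOTAL_STEPS, 2e-5):.2e}")
```

The warm-up half keeps early updates small while the freshly added classification layer is still random; the decay half lets the weights settle at the end of training.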


In [24]:
print("Total training time in minutes: \n", (training_stop_time - training_start_time)/60)
print("Total training time in hours: \n", (training_stop_time - training_start_time)/3600)

Total training time in minutes: 
 116.45897330143333
Total training time in hours: 
 1.9409828883572222
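The measured time also gives a rough per-step cost, which helps when estimating longer or shorter runs with the same settings. The step count below assumes 20,000 steps per epoch (120,000 rows at batch size 6). Because the timer wraps the whole `fit_onecycle` call, it presumably also includes the evaluation on the validation data after each epoch, so the result is an upper bound on the pure training step time.

```python
TRAINING_MINUTES = 116.45897330143333  # wall-clock time printed above
EPOCHS = 3
STEPS_PER_EPOCH = 120_000 // 6         # training rows / batch size used in this notebook

total_seconds = TRAINING_MINUTES * 60
seconds_per_step = total_seconds / (EPOCHS * STEPS_PER_EPOCH)
print(f"{total_seconds:.0f} s in total, about {seconds_per_step:.3f} s per step")
# → 6988 s in total, about 0.116 s per step
```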


**Metrics for evaluating BERT performance:**

In [25]:
bert_learner.validate()

              precision    recall  f1-score   support

           0       0.96      0.96      0.96      1900
           1       0.99      0.99      0.99      1900
           2       0.93      0.91      0.92      1900
           3       0.92      0.93      0.92      1900

    accuracy                           0.95      7600
   macro avg       0.95      0.95      0.95      7600
weighted avg       0.95      0.95      0.95      7600



array([[1822,    7,   38,   33],
       [   7, 1879,    8,    6],
       [  41,    6, 1728,  125],
       [  27,    7,   94, 1772]])
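The classification report rounds every value to two decimals. Recomputing the metrics from the confusion matrix makes the definitions explicit and shows that the accuracy is exactly 7201 / 7600 = 0.9475. Each row of the matrix sums to 1,900, the per-class support, so rows are true labels and columns are predicted labels (the scikit-learn convention). The helper name is hypothetical.

```python
from typing import List, Tuple

# Confusion matrix from validate(): rows = true label, columns = predicted label.
CM: List[List[int]] = [
    [1822,    7,   38,   33],
    [   7, 1879,    8,    6],
    [  41,    6, 1728,  125],
    [  27,    7,   94, 1772],
]
NAMES = ["World", "Sports", "Business", "Sci/Tech"]

def per_class_metrics(cm: List[List[int]]) -> List[Tuple[float, float, float]]:
    """Return (precision, recall, f1) for each class."""
    n = len(cm)
    out = []
    for k in range(n):
        tp = cm[k][k]
        predicted_k = sum(cm[i][k] for i in range(n))  # column sum: everything predicted as k
        actual_k = sum(cm[k])                          # row sum: everything truly k
        precision, recall = tp / predicted_k, tp / actual_k
        f1 = 2 * precision * recall / (precision + recall)
        out.append((precision, recall, f1))
    return out

metrics = per_class_metrics(CM)
accuracy = sum(CM[k][k] for k in range(len(CM))) / sum(map(sum, CM))
for name, (p, r, f) in zip(NAMES, metrics):
    print(f"{name:>8}: precision={p:.3f} recall={r:.3f} f1={f:.3f}")
print(f"accuracy={accuracy:.4f}")  # → accuracy=0.9475
```

The largest off-diagonal entries (125 Business articles predicted as Sci/Tech, 94 Sci/Tech articles predicted as Business) account for the lower scores of those two classes.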

In [26]:
bert_learner.validate(class_names=class_label_names)

              precision    recall  f1-score   support

       World       0.96      0.96      0.96      1900
      Sports       0.99      0.99      0.99      1900
    Business       0.93      0.91      0.92      1900
    Sci/Tech       0.92      0.93      0.92      1900

    accuracy                           0.95      7600
   macro avg       0.95      0.95      0.95      7600
weighted avg       0.95      0.95      0.95      7600



array([[1822,    7,   38,   33],
       [   7, 1879,    8,    6],
       [  41,    6, 1728,  125],
       [  27,    7,   94, 1772]])

**Saving the model:**

In [27]:
bert_predictor = ktrain.get_predictor(bert_learner.model, preproc=preprocessing_var)
bert_predictor.get_classes()

['label_0', 'label_1', 'label_2', 'label_3']
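Because the `label` column holds integers, ktrain names the classes `label_0` to `label_3` instead of the category names. A small helper can translate them back using the same order as `class_label_names` and the dataset's `ClassLabel` feature. The helper name `readable_label` is hypothetical, and the usage note assumes that `bert_predictor.predict(text)` on a single string returns one of the names from `get_classes()`.

```python
CLASS_LABEL_NAMES = ['World', 'Sports', 'Business', 'Sci/Tech']  # same order as the ClassLabel feature

def readable_label(ktrain_class: str) -> str:
    """Map a generic ktrain class name such as 'label_2' to its AG News category."""
    prefix, _, index = ktrain_class.rpartition('_')
    if prefix != 'label' or not index.isdigit():
        raise ValueError(f"unexpected class name: {ktrain_class!r}")
    return CLASS_LABEL_NAMES[int(index)]

print([readable_label(c) for c in ['label_0', 'label_1', 'label_2', 'label_3']])
# → ['World', 'Sports', 'Business', 'Sci/Tech']
```

Under that assumption, `readable_label(bert_predictor.predict("some news headline"))` would give a human-readable category for a new article.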

In [28]:
bert_predictor.save('/content/bert-ag-news-predictor')

In [29]:
!zip -r /content/bert-ag-news-predictor.zip /content/bert-ag-news-predictor

  adding: content/bert-ag-news-predictor/ (stored 0%)
  adding: content/bert-ag-news-predictor/tf_model.h5 (deflated 12%)
  adding: content/bert-ag-news-predictor/tf_model.preproc (deflated 48%)


**Reloading the model:**

In [30]:
bert_predictor_2 = ktrain.load_predictor('/content/bert-ag-news-predictor')
bert_predictor_2.get_classes()



['label_0', 'label_1', 'label_2', 'label_3']