# Text Classification | BERT

Tools and libraries:
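- [Hugging Face Transformers](https://huggingface.co/): Installed alongside ktrain. Note that the `preprocess_mode='bert'` path used below loads the pre-trained BERT model through the keras-bert library (its `keras_bert` layers appear in the model's layer list), not through Transformers.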
- [Datasets](https://huggingface.co/datasets): For managing and loading the dataset.
- [ktrain](https://github.com/amaiya/ktrain): A lightweight, high-level wrapper for TensorFlow Keras, used here to build, train, and fine-tune the BERT text classifier on our dataset.

---




### Install Libraries:

In [None]:
# tested on python version 3.8.10
!pip install numpy==1.24.0
!pip install pandas==1.5.2
!pip install matplotlib==3.6.2
!pip install scikit-learn==1.2.0
!pip install seaborn==0.12.1
!pip install tensorflow==2.11.0
!pip install ktrain==0.32.3
!pip install transformers==4.17.0
!pip install datasets==2.8.0

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import ktrain
from ktrain import text
import tensorflow as tf
from sklearn.model_selection import train_test_split
from datasets import list_datasets
from datasets import load_dataset
import timeit



In [None]:
# GPU test: print the TensorFlow version and any GPUs TensorFlow can see
print("Tensorflow version : ", tf.__version__)
gpus = tf.config.list_physical_devices('GPU')
print("GPU available : ", len(gpus) > 0)
print("GPU name : ", tf.test.gpu_device_name())

---

### Import News Dataset:

In [5]:
ag_news_dataset = load_dataset('ag_news')
print("\n", ag_news_dataset)

Found cached dataset ag_news (C:/Users/vithi/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)
100%|██████████| 2/2 [00:00<00:00, 32.26it/s]


 DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 120000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 7600
    })
})





---

## Dataset Details

The AG corpus is a collection of more than 1 million news articles, gathered from more than 2,000 news sources over more than a year of activity by ComeToMyHead, an academic news search engine that has been running since July 2004. The corpus is provided by the academic community for non-commercial research in data mining (clustering, classification, etc.), information retrieval (ranking, search, etc.), XML, data compression, data streaming, and similar activities. For more information, see the [AG Corpus of News Articles](http://www.di.unipi.it/~gulli/AG_corpus_of_news_articles.html).

AG News (also known as AG's News Corpus) is a subset of the larger AG corpus, built from the title and description fields of articles in the corpus's four largest classes: "World", "Sports", "Business", and "Sci/Tech". Each class contributes 30,000 training samples and 1,900 test samples, which gives the 120,000 training rows and 7,600 test rows shown above.
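To check the per-class counts on the loaded splits, here is a minimal sketch that tallies integer labels under their class names. It assumes the label ids follow the `ClassLabel` order shown in the dataset features (`0 = World`, `1 = Sports`, `2 = Business`, `3 = Sci/Tech`); the helper `count_per_class` is hypothetical.

```python
from collections import Counter
from typing import Dict, Iterable, List

# Class names in ClassLabel order (list index = integer label id)
AG_NEWS_CLASSES: List[str] = ['World', 'Sports', 'Business', 'Sci/Tech']


def count_per_class(labels: Iterable[int], names: List[str] = AG_NEWS_CLASSES) -> Dict[str, int]:
    """Tally integer labels and report the counts under their class names."""
    counts = Counter(labels)
    # Report every class, including ones with zero examples
    return {name: counts.get(idx, 0) for idx, name in enumerate(names)}


# Toy labels; on the real data pass e.g. ag_news_dataset['train']['label']
print(count_per_class([2, 2, 3, 1, 0, 0, 2]))
# {'World': 2, 'Sports': 1, 'Business': 3, 'Sci/Tech': 1}
```

On the full training split every count should be 30,000, and on the test split 1,900.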

---


In [6]:
print("Dataset Items: \n", ag_news_dataset.items())
print("\nDataset type: \n", type(ag_news_dataset))
print("\nShape of dataset: \n", ag_news_dataset.shape)
print("\nNo of rows: \n", ag_news_dataset.num_rows)
print("\nNo of columns: \n", ag_news_dataset.num_columns)

Dataset Items: 
 dict_items([('train', Dataset({
    features: ['text', 'label'],
    num_rows: 120000
})), ('test', Dataset({
    features: ['text', 'label'],
    num_rows: 7600
}))])

Dataset type: 
 <class 'datasets.dataset_dict.DatasetDict'>

Shape of dataset: 
 {'train': (120000, 2), 'test': (7600, 2)}

No of rows: 
 {'train': 120000, 'test': 7600}

No of columns: 
 {'train': 2, 'test': 2}


In [7]:
print("\nColumn Names: \n", ag_news_dataset.column_names)
print("\n", ag_news_dataset.data)


Column Names: 
 {'train': ['text', 'label'], 'test': ['text', 'label']}

 {'train': MemoryMappedTable
text: string
label: int64
----
label: [[2,2,2,2,2,...,2,2,3,3,3],[3,1,1,0,0,...,3,1,0,0,2],...,[0,0,2,2,2,...,2,2,2,1,0],[0,0,0,0,3,...,0,1,1,1,1]], 'test': MemoryMappedTable
text: string
label: int64
----
text: [["Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.","The Race is On: Second Private Team Sets Launch Date for Human Spaceflight (SPACE.com) SPACE.com - TORONTO, Canada -- A second\team of rocketeers competing for the  #36;10 million Ansari X Prize, a contest for\privately funded suborbital space flight, has officially announced the first\launch date for its manned rocket.","Ky. Company Wins Grant to Study Peptides (AP) AP - A company founded by a chemistry researcher at the University of Louisville won a grant to develop a method of producing better peptides, which

In [8]:
print(ag_news_dataset['train'][0])
print(ag_news_dataset['train'][1])

{'text': "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.", 'label': 2}
{'text': 'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.', 'label': 2}


In [9]:
# Index the row first, then the column: ag_news_dataset['train']['text'] would
# build the full 120,000-item list of texts on every call
for idx in [0, 35000, 60000, 100000]:
    sample = ag_news_dataset['train'][idx]
    print(sample['text'])
    print(sample['label'])
    print()

Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.
2

Black armbands for Clough, tears for Liverpool fans In the afternoon, Brian Clough, unquestionably one of the greats and unarguably one of the most controversial of football men, died of cancer.
1

BYTE OF THE APPLE Apple lost one war to Microsoft by not licensing its Mac operating system. It may repeat the error with its iPod and music software.
3

Venezuelan Car-Bomb Suspect Killed, Weapons Found  CARACAS, Venezuela (Reuters) - A Venezuelan lawyer  suspected in last week's bombing murder of a top state  prosecutor was killed in a gunfight with police on Tuesday  after he tried to ram detectives with his car and opened fire  on them, officials said.
0


---

### Loading Train & Test Datasets:

In [10]:
ag_news_train = load_dataset('ag_news', split='train')
ag_news_test = load_dataset('ag_news', split='test')
print("Train Dataset : ", ag_news_train.shape)
print("Test Dataset : ", ag_news_test.shape)

Found cached dataset ag_news (C:/Users/vithi/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)
Found cached dataset ag_news (C:/Users/vithi/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)


Train Dataset :  (120000, 2)
Test Dataset :  (7600, 2)


In [11]:
print(ag_news_train[0])
print(ag_news_test[0])

{'text': "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.", 'label': 2}
{'text': "Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.", 'label': 2}


In [12]:
print("\nTrain Dataset Features: \n", ag_news_train.features)
print("\nTest Dataset Features: \n", ag_news_test.features)


Train Dataset Features: 
 {'text': Value(dtype='string', id=None), 'label': ClassLabel(names=['World', 'Sports', 'Business', 'Sci/Tech'], id=None)}

Test Dataset Features: 
 {'text': Value(dtype='string', id=None), 'label': ClassLabel(names=['World', 'Sports', 'Business', 'Sci/Tech'], id=None)}


### Creating DataFrame Objects for ktrain:

In [13]:
pd.set_option('display.max_columns', None)
ag_news_train_df = pd.DataFrame(data=ag_news_train)
ag_news_train_df.head(10)

| | text | label |
|---:|:---|---:|
| 0 | Wall St. Bears Claw Back Into the Black (Reute... | 2 |
| 1 | Carlyle Looks Toward Commercial Aerospace (Reu... | 2 |
| 2 | Oil and Economy Cloud Stocks' Outlook (Reuters... | 2 |
| 3 | Iraq Halts Oil Exports from Main Southern Pipe... | 2 |
| 4 | Oil prices soar to all-time record, posing new... | 2 |
| 5 | Stocks End Up, But Near Year Lows (Reuters) Re... | 2 |
| 6 | Money Funds Fell in Latest Week (AP) AP - Asse... | 2 |
| 7 | Fed minutes show dissent over inflation (USATO... | 2 |
| 8 | Safety Net (Forbes.com) Forbes.com - After ear... | 2 |
| 9 | Wall St. Bears Claw Back Into the Black NEW Y... | 2 |


In [14]:
ag_news_train_df.tail(10)

| | text | label |
|---:|:---|---:|
| 119990 | Barack Obama Gets #36;1.9 Million Book Deal (... | 0 |
| 119991 | Rauffer Beats Favorites to Win Downhill VAL G... | 1 |
| 119992 | Iraqis Face Winter Shivering by Candlelight B... | 0 |
| 119993 | AU Says Sudan Begins Troop Withdrawal from Dar... | 0 |
| 119994 | Syria Redeploys Some Security Forces in Lebano... | 0 |
| 119995 | Pakistan's Musharraf Says Won't Quit as Army C... | 0 |
| 119996 | Renteria signing a top-shelf deal Red Sox gene... | 1 |
| 119997 | Saban not going to Dolphins yet The Miami Dolp... | 1 |
| 119998 | Today's NFL games PITTSBURGH at NY GIANTS Time... | 1 |
| 119999 | Nets get Carter from Raptors INDIANAPOLIS -- A... | 1 |


In [15]:
ag_news_test_df = pd.DataFrame(data=ag_news_test)
ag_news_test_df.head(10)

| | text | label |
|---:|:---|---:|
| 0 | Fears for T N pension after talks Unions repre... | 2 |
| 1 | The Race is On: Second Private Team Sets Launc... | 3 |
| 2 | Ky. Company Wins Grant to Study Peptides (AP) ... | 3 |
| 3 | Prediction Unit Helps Forecast Wildfires (AP) ... | 3 |
| 4 | Calif. Aims to Limit Farm-Related Smog (AP) AP... | 3 |
| 5 | Open Letter Against British Copyright Indoctri... | 3 |
| 6 | Loosing the War on Terrorism \\"Sven Jaschan, ... | 3 |
| 7 | FOAFKey: FOAF, PGP, Key Distribution, and Bloo... | 3 |
| 8 | E-mail scam targets police chief Wiltshire Pol... | 3 |
| 9 | Card fraud unit nets 36,000 cards In its first... | 3 |


In [16]:
ag_news_test_df.tail(10)

| | text | label |
|---:|:---|---:|
| 7590 | Saban hiring on hold DAVIE - The Dolphins want... | 1 |
| 7591 | Bosnian-Serb prime minister resigns in protest... | 0 |
| 7592 | Historic Turkey-EU deal welcomed The European ... | 0 |
| 7593 | Mortaza strikes to lead superb Bangladesh rall... | 1 |
| 7594 | Powell pushes diplomacy for N. Korea WASHINGTO... | 0 |
| 7595 | Around the world Ukrainian presidential candid... | 0 |
| 7596 | Void is filled with Clement With the supply of... | 1 |
| 7597 | Martinez leaves bitter Like Roger Clemens did ... | 1 |
| 7598 | 5 of arthritis patients in Singapore take Bext... | 2 |
| 7599 | EBay gets into rentals EBay plans to buy the a... | 2 |


---

### Data Preprocessing:

In [17]:
class_label_names = ['World', 'Sports', 'Business', 'Sci/Tech']

In [18]:
# Use ktrain's 'texts_from_df' to tokenize the text and encode the labels for BERT
(X_train, y_train), (X_test, y_test), preprocessing_var = text.texts_from_df(
    train_df=ag_news_train_df,   # training DataFrame
    text_column='text',          # column containing the news text
    label_columns='label',       # column containing the integer class labels
    val_df=ag_news_test_df,      # AG News test split, used here as the validation set
    maxlen=512,                  # maximum sequence length in tokens (BERT's upper limit)
    preprocess_mode='bert'       # tokenize with BERT's WordPiece vocabulary
)


['label_0', 'label_1', 'label_2', 'label_3']
   label_0  label_1  label_2  label_3
0      0.0      0.0      1.0      0.0
1      0.0      0.0      1.0      0.0
2      0.0      0.0      1.0      0.0
3      0.0      0.0      1.0      0.0
4      0.0      0.0      1.0      0.0
['label_0', 'label_1', 'label_2', 'label_3']
   label_0  label_1  label_2  label_3
0      0.0      0.0      1.0      0.0
1      0.0      0.0      0.0      1.0
2      0.0      0.0      0.0      1.0
3      0.0      0.0      0.0      1.0
4      0.0      0.0      0.0      1.0
preprocessing train...
language: en


Is Multi-Label? False
preprocessing test...
language: en
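The printed tables show that `texts_from_df` turned the integer `label` column into one-hot columns `label_0` to `label_3`: label `2` (Business) becomes `[0, 0, 1, 0]`, label `3` (Sci/Tech) becomes `[0, 0, 0, 1]`. A minimal pure-Python sketch of the same encoding, for illustration only (it is not ktrain's internal code, and `one_hot` is a hypothetical helper):

```python
from typing import List, Sequence


def one_hot(labels: Sequence[int], num_classes: int) -> List[List[float]]:
    """Encode integer class ids as one-hot float vectors."""
    rows = []
    for label in labels:
        if not 0 <= label < num_classes:
            raise ValueError(f"label {label} outside [0, {num_classes})")
        row = [0.0] * num_classes
        row[label] = 1.0
        rows.append(row)
    return rows


# First two validation labels from the output above: 2 (Business), 3 (Sci/Tech)
for row in one_hot([2, 3], num_classes=4):
    print(row)
# [0.0, 0.0, 1.0, 0.0]
# [0.0, 0.0, 0.0, 1.0]
```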


### Create BERT Model:

In [19]:
# Create a BERT-based text classification model
transformer_bert_model = text.text_classifier(
    name='bert',                    # model (BERT)
    train_data=(X_train, y_train),  # Training data (text and labels)
    preproc=preprocessing_var       # Preprocessing information
)


Is Multi-Label? False
maxlen is 512




done.
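The output confirms `maxlen is 512`, BERT's maximum input length. The AG News samples printed earlier are a title plus a sentence or two, so most inputs likely use far fewer than 512 tokens and are padded up to that length, which costs memory and compute. One way to choose a smaller `maxlen` is to inspect the length distribution first. A minimal sketch, assuming whitespace word counts as a rough proxy (BERT's WordPiece tokenizer usually yields somewhat more tokens than words, plus `[CLS]` and `[SEP]`); `length_percentile` is a hypothetical helper.

```python
import math
from typing import Sequence


def length_percentile(texts: Sequence[str], pct: float) -> int:
    """Return the pct-th percentile (nearest-rank method) of whitespace word counts."""
    if not texts:
        raise ValueError("texts must be non-empty")
    if not 0 < pct <= 100:
        raise ValueError("pct must be in (0, 100]")
    lengths = sorted(len(t.split()) for t in texts)
    # Nearest rank: smallest length with at least pct% of texts at or below it
    rank = math.ceil(pct / 100 * len(lengths))
    return lengths[rank - 1]


samples = [
    "Wall St. Bears Claw Back Into the Black",
    "BYTE OF THE APPLE Apple lost one war to Microsoft",
    "Nets get Carter from Raptors",
]
print(length_percentile(samples, 95))  # 10
```

Applied to `ag_news_train_df['text']`, the result plus some headroom for WordPiece splitting could be passed as `maxlen` to `texts_from_df`; a shorter `maxlen` may also leave room for a larger `batch_size`.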


In [20]:
transformer_bert_model.layers

[<keras.engine.input_layer.InputLayer at 0x2936f2bd6a0>,
 <keras.engine.input_layer.InputLayer at 0x2931b16fe80>,
 <keras_bert.layers.embedding.TokenEmbedding at 0x293183ff040>,
 <keras.layers.core.embedding.Embedding at 0x293211f7c40>,
 <keras.layers.merging.add.Add at 0x293211f6ac0>,
 <keras_pos_embd.pos_embd.PositionEmbedding at 0x29334c7d2e0>,
 <keras.layers.regularization.dropout.Dropout at 0x29334c7d370>,
 <keras_layer_normalization.layer_normalization.LayerNormalization at 0x29334c65250>,
 <keras_multi_head.multi_head_attention.MultiHeadAttention at 0x2932ec73ac0>,
 <keras.layers.regularization.dropout.Dropout at 0x29332c68c40>,
 <keras.layers.merging.add.Add at 0x29331c84340>,
 <keras_layer_normalization.layer_normalization.LayerNormalization at 0x29331c60970>,
 <keras_position_wise_feed_forward.feed_forward.FeedForward at 0x29335c8a430>,
 <keras.layers.regularization.dropout.Dropout at 0x29345c9da30>,
 <keras.layers.merging.add.Add at 0x29332c72550>,
 <keras_layer_normalizatio

---

### Compile and Train BERT in a Learner Object:

In [21]:
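# Create a BERT learner for text classification with a specific batch size.
# batch_size=6 is below the usual 16 or 32; a small batch keeps memory use low with maxlen=512.
bert_learner = ktrain.get_learner(model=transformer_bert_model,   # model created above
                                  train_data=(X_train, y_train),
                                  val_data=(X_test, y_test),
                                  batch_size=6)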
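With `batch_size=6` the learner performs many optimizer updates per epoch. A quick worked calculation using the split sizes shown earlier (120,000 training rows and 7,600 validation rows); `steps_per_epoch` is a hypothetical helper, and it assumes the last partial batch counts as one step.

```python
import math


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    """Number of batches needed to see every example once (last batch may be partial)."""
    return math.ceil(num_examples / batch_size)


train_steps = steps_per_epoch(120_000, 6)
print(train_steps)                 # 20000 batches per epoch
print(train_steps * 3)             # 60000 updates over 3 epochs
print(steps_per_epoch(7_600, 6))   # 1267 validation batches
```

At a batch size of 32, the same training epoch would take 3,750 steps.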


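### Recommended Fine-tuning Hyper-parameters for BERT:
The original BERT paper (Devlin et al., 2019) suggests searching over the following values when fine-tuning:

- Batch size: 16, 32
- Learning rate (Adam): 5e-5, 3e-5, 2e-5
- Number of epochs: 2, 3, 4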
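These values define a small search grid of 2 × 3 × 3 = 18 configurations. A minimal sketch that enumerates it; how each configuration is trained (for example, a freshly created model per run, wrapped with `ktrain.get_learner(..., batch_size=...)` and trained with `fit_onecycle(lr=..., epochs=...)`) is left as an assumption, and `bert_search_grid` is a hypothetical helper.

```python
from itertools import product
from typing import Dict, List, Union

BATCH_SIZES = [16, 32]
LEARNING_RATES = [5e-5, 3e-5, 2e-5]
EPOCHS = [2, 3, 4]


def bert_search_grid() -> List[Dict[str, Union[int, float]]]:
    """Enumerate every (batch_size, lr, epochs) combination from the lists above."""
    return [
        {'batch_size': bs, 'lr': lr, 'epochs': ep}
        for bs, lr, ep in product(BATCH_SIZES, LEARNING_RATES, EPOCHS)
    ]


grid = bert_search_grid()
print(len(grid))   # 18
print(grid[0])     # {'batch_size': 16, 'lr': 5e-05, 'epochs': 2}
```

Each run should start from a new model, because calling `fit_onecycle` again continues training whatever weights the learner already holds. The run in this notebook (learning rate 2e-5, 3 epochs) falls inside the grid except for its batch size of 6.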

### Train BERT on the AG News Dataset:

In [None]:
# Record the starting time of training
training_start_time = timeit.default_timer()

# Train with ktrain's 1cycle policy for 3 epochs; lr=2e-5 is the peak (maximum) learning rate
bert_learner.fit_onecycle(lr=2e-5, epochs=3)

# Record the ending time of training
training_stop_time = timeit.default_timer()
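`fit_onecycle` follows Leslie Smith's 1cycle policy: the learning rate ramps up to the `lr` passed in (here 2e-5) and then anneals back down over the run. The sketch below is a generic triangular version for illustration only; the half-way peak, `div_factor`, and `final_div_factor` are assumptions and do not necessarily match ktrain's internal schedule. `one_cycle_lr` is a hypothetical helper.

```python
def one_cycle_lr(step: int, total_steps: int, max_lr: float,
                 div_factor: float = 10.0, final_div_factor: float = 100.0) -> float:
    """Triangular 1cycle schedule: linear warm-up to max_lr over the first half,
    then linear decay to max_lr / final_div_factor at the last step."""
    if total_steps < 2:
        raise ValueError("total_steps must be at least 2")
    if not 0 <= step < total_steps:
        raise ValueError("step must be in [0, total_steps)")
    start_lr = max_lr / div_factor
    end_lr = max_lr / final_div_factor
    peak = (total_steps - 1) / 2
    if step <= peak:
        # Warm-up phase: start_lr -> max_lr
        frac = step / peak
        return start_lr + frac * (max_lr - start_lr)
    # Annealing phase: max_lr -> end_lr
    frac = (step - peak) / (total_steps - 1 - peak)
    return max_lr + frac * (end_lr - max_lr)


total = 60_000  # 3 epochs x 20,000 batches at batch_size=6
for s in (0, total // 4, (total - 1) // 2, total - 1):
    print(s, f"{one_cycle_lr(s, total, 2e-5):.2e}")
```

A rising-then-falling rate lets early steps move gently away from the pre-trained weights and late steps settle with small updates.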


In [None]:
# Elapsed wall-clock training time in seconds
elapsed_seconds = training_stop_time - training_start_time
print("Total training time in minutes: \n", elapsed_seconds / 60)
print("Total training time in hours: \n", elapsed_seconds / 3600)
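`timeit.default_timer()` returns seconds as a float (on Python 3 it is `time.perf_counter`), so the difference above is an elapsed time in seconds. A small sketch that formats it as hours, minutes, and seconds in one line instead of two separate unit conversions; `format_duration` is a hypothetical helper.

```python
def format_duration(seconds: float) -> str:
    """Format an elapsed time in seconds as 'Hh MMm SSs' (rounded to whole seconds)."""
    if seconds < 0:
        raise ValueError("seconds must be non-negative")
    minutes, secs = divmod(int(round(seconds)), 60)
    hours, minutes = divmod(minutes, 60)
    return f"{hours}h {minutes:02d}m {secs:02d}s"


print(format_duration(3725.4))  # 1h 02m 05s
```

In the notebook this would be called as `print("Total training time:", format_duration(training_stop_time - training_start_time))`.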

---

### Checking BERT performance metrics:

In [None]:
# Validate BERT model's performance.
bert_learner.validate()


              precision    recall  f1-score   support

           0       0.97      0.96      0.96      1900
           1       0.99      0.99      0.99      1900
           2       0.93      0.92      0.92      1900
           3       0.92      0.94      0.93      1900

    accuracy                           0.95      7600
   macro avg       0.95      0.95      0.95      7600
weighted avg       0.95      0.95      0.95      7600



array([[1824,    6,   36,   34],
       [  10, 1875,    8,    7],
       [  32,    6, 1741,  121],
       [  23,    9,   88, 1780]])
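In the confusion matrix, rows are true classes and columns are predicted classes (each row sums to the 1,900 support), so the report's numbers can be recomputed directly: recall divides each diagonal entry by its row sum, precision divides it by its column sum. The sketch below does this in pure Python with the matrix printed above; `per_class_metrics` is a hypothetical helper. The largest off-diagonal entry, 121 Business articles predicted as Sci/Tech, is consistent with Business recall and Sci/Tech precision being the lowest scores (0.92) in the report.

```python
from typing import List, Tuple

# Confusion matrix from bert_learner.validate(): rows = true, columns = predicted
CONF_MATRIX = [
    [1824,    6,   36,   34],
    [  10, 1875,    8,    7],
    [  32,    6, 1741,  121],
    [  23,    9,   88, 1780],
]


def per_class_metrics(cm: List[List[int]]) -> Tuple[List[float], List[float], float]:
    """Return (precision per class, recall per class, overall accuracy)."""
    n = len(cm)
    diag = [cm[i][i] for i in range(n)]
    row_sums = [sum(row) for row in cm]                             # true counts per class
    col_sums = [sum(cm[r][c] for r in range(n)) for c in range(n)]  # predicted counts per class
    precision = [d / c if c else 0.0 for d, c in zip(diag, col_sums)]
    recall = [d / r if r else 0.0 for d, r in zip(diag, row_sums)]
    accuracy = sum(diag) / sum(row_sums)
    return precision, recall, accuracy


precision, recall, accuracy = per_class_metrics(CONF_MATRIX)
names = ['World', 'Sports', 'Business', 'Sci/Tech']
for name, p, r in zip(names, precision, recall):
    print(f"{name:>9}  precision={p:.2f}  recall={r:.2f}")
print(f"accuracy={accuracy:.2f}")  # accuracy=0.95 (7220 / 7600)
```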

In [None]:
# Re-run validation, passing human-readable class names for the report.
bert_learner.validate(class_names=class_label_names)

              precision    recall  f1-score   support

       World       0.97      0.96      0.96      1900
      Sports       0.99      0.99      0.99      1900
    Business       0.93      0.92      0.92      1900
    Sci/Tech       0.92      0.94      0.93      1900

    accuracy                           0.95      7600
   macro avg       0.95      0.95      0.95      7600
weighted avg       0.95      0.95      0.95      7600



array([[1824,    6,   36,   34],
       [  10, 1875,    8,    7],
       [  32,    6, 1741,  121],
       [  23,    9,   88, 1780]])

---

### Saving the model:

In [None]:
# Create a BERT predictor using a BERT model and a specified preprocessing configuration.
bert_predictor = ktrain.get_predictor(bert_learner.model, preproc=preprocessing_var)

# Get the classes (labels) associated with the BERT predictor and display them.
class_labels = bert_predictor.get_classes()
class_labels


['label_0', 'label_1', 'label_2', 'label_3']
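Because the `label` column held integers, the predictor's classes are the generated names `label_0` to `label_3`, and predictions come back in that form. A minimal sketch for translating them into the AG News names, assuming the `label_<id>` pattern shown above and the id order of `class_label_names`; `to_class_name` is a hypothetical helper.

```python
from typing import List

CLASS_LABEL_NAMES: List[str] = ['World', 'Sports', 'Business', 'Sci/Tech']


def to_class_name(ktrain_label: str, names: List[str] = CLASS_LABEL_NAMES) -> str:
    """Map a generated class name such as 'label_2' to its AG News name."""
    prefix, sep, idx = ktrain_label.rpartition('_')
    if prefix != 'label' or not sep or not idx.isdigit():
        raise ValueError(f"unexpected label format: {ktrain_label!r}")
    if int(idx) >= len(names):
        raise ValueError(f"label id {idx} has no class name")
    return names[int(idx)]


print([to_class_name(c) for c in ['label_0', 'label_1', 'label_2', 'label_3']])
# ['World', 'Sports', 'Business', 'Sci/Tech']
```

For a single string, `to_class_name(bert_predictor.predict(some_text))` would give a readable class, assuming `predict` returns one of the names listed by `get_classes()` (`some_text` is a placeholder).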

In [None]:
# Save the BERT predictor to a specified file path.
bert_predictor.save('/content/bert-ag-news-predictor')



In [None]:
!zip -r /content/bert-ag-news-predictor.zip /content/bert-ag-news-predictor

  adding: content/bert-ag-news-predictor/ (stored 0%)
  adding: content/bert-ag-news-predictor/tf_model.preproc (deflated 52%)
  adding: content/bert-ag-news-predictor/tf_model.h5 (deflated 11%)
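The `!zip` shell command is not available in every environment (for example, a stock Windows install). A portable alternative from the Python standard library is `shutil.make_archive`. The sketch below archives a directory and lists the stored entries; the demo uses a throwaway directory with placeholder files, and in this notebook the source directory would be `/content/bert-ag-news-predictor`. `zip_directory` is a hypothetical helper. Unlike `zip -r` with absolute paths, it stores entries relative to the parent directory (`bert-ag-news-predictor/...` rather than `content/bert-ag-news-predictor/...`).

```python
import os
import shutil
import tempfile
import zipfile


def zip_directory(src_dir: str, zip_base: str) -> str:
    """Zip src_dir into '<zip_base>.zip', keeping the directory name as the top-level entry."""
    src_dir = os.path.abspath(src_dir)
    return shutil.make_archive(
        base_name=zip_base,                     # output path without the .zip extension
        format='zip',
        root_dir=os.path.dirname(src_dir),      # archive paths are relative to the parent...
        base_dir=os.path.basename(src_dir),     # ...and start with the directory's own name
    )


# Self-contained demo: a temporary directory stands in for the saved predictor
with tempfile.TemporaryDirectory() as tmp:
    predictor_dir = os.path.join(tmp, 'bert-ag-news-predictor')
    os.makedirs(predictor_dir)
    for name in ('tf_model.h5', 'tf_model.preproc'):
        with open(os.path.join(predictor_dir, name), 'wb') as f:
            f.write(b'placeholder')
    archive = zip_directory(predictor_dir, os.path.join(tmp, 'bert-ag-news-predictor'))
    with zipfile.ZipFile(archive) as zf:
        print(sorted(zf.namelist()))
```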


---

### Re-loading Model:

In [None]:
# Load a BERT predictor from a previously saved file.
bert_predictor_2 = ktrain.load_predictor('/content/bert-ag-news-predictor')

# Get the classes (labels) associated with the loaded BERT predictor and display them.
class_labels = bert_predictor_2.get_classes()
class_labels


['label_0', 'label_1', 'label_2', 'label_3']

---