<a href="https://colab.research.google.com/github/shraddha-an/nlp/blob/main/so_bert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **BERT for Text Classification**

This notebook fine-tunes a BERT model with the Simple Transformers library to classify Stack Overflow questions into three categories: `HQ`, `LQ_EDIT`, and `LQ_CLOSE`.


In [1]:
# Installing simple transformers library
!pip install simpletransformers

# Data manipulation and handling
import numpy as np
import pandas as pd

# NLP Preprocessing
from gensim.utils import simple_preprocess

# BERT/Transformer
from simpletransformers.classification import ClassificationModel, ClassificationArgs


In [11]:
# Importing the dataset
dataset = pd.read_csv('train.csv')[['Body', 'Y']].rename(columns = {'Body': 'questions', 'Y': 'category'})
ds = pd.read_csv('valid.csv')[['Body', 'Y']].rename(columns = {'Body': 'questions', 'Y': 'category'})



dataset.tail(5), ds.head(5)

(                                               questions  category
 44995  <p>I am new to this and I am asking for help t...  LQ_CLOSE
 44996  <p>I am working on learning Python and was won...  LQ_CLOSE
 44997  <p>It looks like it costs 8 days per month in ...  LQ_CLOSE
 44998  <p>"I _____ any questions."</p>\n\n<p>I want t...  LQ_CLOSE
 44999  <p>I'm very new to programming and I'm teachin...  LQ_CLOSE,
                                            questions category
 0  I am having 4 different tables like \r\nselect...  LQ_EDIT
 1  I have two table m_master and tbl_appointment\...  LQ_EDIT
 2  <p>I'm trying to extract US states from wiki U...       HQ
 3  I'm so new to C#, I wanna make an application ...  LQ_EDIT
 4  basically i have this array:\r\n\r\n    array(...  LQ_EDIT)

In [12]:
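# NLP preprocessing: simple_preprocess (imported in the first cell) lowercases
# each question and splits it into alphabetic tokens, dropping punctuation and
# markup characters; the tokens are then re-joined into a single string.
dataset.iloc[:, 0] = dataset.iloc[:, 0].apply(lambda x: ' '.join(simple_preprocess(x)))
ds.iloc[:, 0] = ds.iloc[:, 0].apply(lambda x: ' '.join(simple_preprocess(x)))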
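
To make the effect of this step concrete, here is a rough standard-library imitation of what `simple_preprocess` does. It is only a sketch, not gensim's own code: the helper name `approx_simple_preprocess` is hypothetical, the sample string is made up, and the behavior assumes gensim's defaults (lowercasing, alphabetic tokens, token length between 2 and 15).

```python
import re

# Rough stand-in for gensim's simple_preprocess (assumed defaults: min_len=2, max_len=15).
# Matches runs of word characters that are not digits.
_TOKEN = re.compile(r"(?:(?!\d)\w)+")

def approx_simple_preprocess(text, min_len=2, max_len=15):
    """Lowercase, tokenize, and drop very short or very long tokens."""
    tokens = _TOKEN.findall(text.lower())
    return [t for t in tokens if min_len <= len(t) <= max_len and not t.startswith("_")]

# Made-up sample in the style of the raw question bodies.
sample = "<p>I'm trying to extract US states</p>"
cleaned = " ".join(approx_simple_preprocess(sample))
print(cleaned)
```

One consequence worth keeping in mind: the text is fully lowercased before it reaches the model, even though `bert-base-cased` is a case-sensitive checkpoint.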

In [13]:
# Label encoding the category column
from sklearn.preprocessing import LabelEncoder

enc = LabelEncoder()

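# Fit the encoder on the training labels only, then reuse the same
# label-to-integer mapping for the validation set.
dataset['category'] = enc.fit_transform(dataset['category'])
ds['category'] = enc.transform(ds['category'])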

# Printing again
ds.tail(4), dataset.head(4)

(                                               questions  category
 14996  try to multiply an integer by double but obtai...         1
 14997  urls py urls py file from django contrib impor...         2
 14998  have controller inside which server is connect...         1
 14999  so was recently helping someone out with some ...         1,
                                            questions  category
 0  already familiar with repeating tasks every se...         1
 1  like to understand why java optionals were des...         0
 2  am attempting to overlay title over an image w...         0
 3  the question is very simple but just could not...         0)
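
The integers above are only useful if you know which category each one stands for. `LabelEncoder` numbers the unique classes in sorted order, so the mapping can be reproduced with plain Python. The sketch below uses a short made-up list of labels and hypothetical variable names; with the real encoder, `enc.classes_` and `enc.inverse_transform` give the same information.

```python
# Category names as they appear in the dataset's original `category` column
# (a short made-up sample, not the real data).
labels = ["LQ_CLOSE", "LQ_EDIT", "HQ", "LQ_EDIT", "LQ_CLOSE"]

# LabelEncoder numbers the unique classes in sorted order; this mimics that rule.
classes = sorted(set(labels))
label_to_id = {name: idx for idx, name in enumerate(classes)}
id_to_label = {idx: name for name, idx in label_to_id.items()}

encoded = [label_to_id[name] for name in labels]
print(label_to_id)
print(encoded)
```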

In [14]:
# Setting parameters for model configuration: one training epoch,
# overwriting any previous output directory
model_args = ClassificationArgs(num_train_epochs = 1, overwrite_output_dir = True)

# Initializing a pretrained BERT model with a 3-class classification head
# (use_cuda = True requires a GPU runtime)
model = ClassificationModel('bert', 'bert-base-cased', num_labels = 3, args = model_args,
                            use_cuda = True)

# Fine-tuning the model on the training set
model.train_model(dataset)



Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

(5625, 0.4550387658239653)
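
The returned tuple appears to be the number of optimization steps followed by the average training loss. The step count can be checked by hand. This sketch assumes the library's default batch size of 8 for both training and evaluation, which the notebook never sets explicitly; the variable names are illustrative.

```python
import math

train_rows = 45_000   # rows in train.csv (last index shown is 44999)
eval_rows = 15_000    # rows in valid.csv (last index shown is 14999)
batch_size = 8        # assumed library default; not set in this notebook
epochs = 1            # num_train_epochs from model_args

# One optimization step per batch, repeated for every epoch.
train_steps = math.ceil(train_rows / batch_size) * epochs
eval_steps = math.ceil(eval_rows / batch_size)
print(train_steps, eval_steps)
```

Under that assumption the training step count equals the first element of the tuple, 5625.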

In [15]:
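# Accuracy metric
from sklearn.metrics import accuracy_score

# Evaluate the model on the validation set; accuracy is passed as an extra
# metric and reported under the key 'acc'. Returns the metrics dict, the raw
# model outputs, and the misclassified examples.
results, model_output, wrong_predictions = model.eval_model(ds, acc = accuracy_score)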

print(results)

  "Dataframe headers not specified. Falling back to using column 0 as text and column 1 as labels."


{'mcc': 0.7713588756896956, 'acc': 0.8469333333333333, 'eval_loss': 0.3802963088646531}
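
The `acc` value is the fraction of questions whose predicted class matches the true label. The sketch below shows that calculation on toy numbers, assuming `model_output` holds one row of raw scores per question with one score per class. The scores, labels, and variable names here are made up for illustration.

```python
# Toy raw scores for four questions over the three classes (made-up numbers).
toy_outputs = [
    [2.1, -0.3, -1.0],
    [-0.5, 1.7, 0.2],
    [0.1, 0.4, 2.3],
    [1.2, 1.5, -0.7],
]
toy_labels = [0, 1, 2, 0]

# Predicted class = index of the largest score in each row.
predictions = [max(range(len(row)), key=row.__getitem__) for row in toy_outputs]

# Accuracy = share of predictions that equal the true label.
accuracy = sum(p == t for p, t in zip(predictions, toy_labels)) / len(toy_labels)
print(predictions, accuracy)
```

On the real outputs the same idea would apply to `model_output` and the encoded `category` column of `ds`.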


In [9]:
print(wrong_predictions)

[<simpletransformers.classification.classification_utils.InputExample object at 0x7f00fc76fdd8>, <simpletransformers.classification.classification_utils.InputExample object at 0x7f00fb3b5278>]
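
Printing the list only shows object addresses. To see the misclassified questions you need to read the attributes of each example. The sketch below assumes each object exposes `text_a` and `label` attributes, which is my understanding of `InputExample` in this library but is not confirmed by the notebook. `ExampleStandIn` and `summarize` are hypothetical names used so the snippet runs on its own.

```python
from dataclasses import dataclass

@dataclass
class ExampleStandIn:
    # Hypothetical stand-in mirroring the attributes assumed on InputExample.
    text_a: str
    label: int

def summarize(examples, width=40):
    """Return (label, text snippet) pairs for a list of example objects."""
    rows = []
    for ex in examples:
        # getattr with a default keeps this safe if an attribute is missing.
        text = getattr(ex, "text_a", "")
        rows.append((getattr(ex, "label", None), text[:width]))
    return rows

toy = [ExampleStandIn("toy question text", 2)]
for label, snippet in summarize(toy, width=10):
    print(label, snippet)
```

If the assumption holds, `summarize(wrong_predictions)` would list the true label and the start of each misclassified question.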
