# **Text Classification using Simple Transformers BERT**

In this Colab notebook, we demonstrate how to perform text classification on tweets using Simple Transformers with BERT, and compare it against RoBERTa and ALBERT.

*   **BERT** is a powerful pre-trained transformer model developed by Google.
*   **Simple Transformers** provides an easy-to-use interface for tasks such as text classification, letting us leverage BERT's capabilities without writing complex model-implementation code.

### *Import necessary libraries*

In [None]:
import itertools
import re

import emoji
import pandas as pd
import sklearn.metrics  # provides accuracy_score, passed to eval_model below
from simpletransformers.classification import ClassificationModel
from sklearn.model_selection import train_test_split

In [None]:
df = pd.read_csv(r"/content/sample_data/data/train.csv")
columns_drop = ['keyword','location']
df.drop(columns=columns_drop,inplace=True)
df.head()


In [None]:
fake_tweets = df[df.target == 0]
print(fake_tweets.shape)
fake_tweets.head(300)

(4342, 3)

### Defining contractions to clean the data

In [None]:
contractions = {
"ain't": "am not",
"aren't": "are not",
"can't": "cannot",
"can't've": "cannot have",
"'cause": "because",
"could've": "could have",
"couldn't": "could not",
"couldn't've": "could not have",
"didn't": "did not",
"doesn't": "does not",
"don't": "do not",
"hadn't": "had not",
"hadn't've": "had not have",
"hasn't": "has not",
"haven't": "have not",
"he'd": "he would",
"he'd've": "he would have",
"he'll": "he will",
"he's": "he is",
"how'd": "how did",
"how'll": "how will",
"how's": "how is",
"i'd": "i would",
"i'll": "i will",
"i'm": "i am",
"i've": "i have",
"isn't": "is not",
"it'd": "it would",
"it'll": "it will",
"it's": "it is",
"let's": "let us",
"ma'am": "madam",
"mayn't": "may not",
"might've": "might have",
"mightn't": "might not",
"must've": "must have",
"mustn't": "must not",
"needn't": "need not",
"oughtn't": "ought not",
"shan't": "shall not",
"sha'n't": "shall not",
"she'd": "she would",
"she'll": "she will",
"she's": "she is",
"should've": "should have",
"shouldn't": "should not",
"that'd": "that would",
"that's": "that is",
"there'd": "there had",
"there's": "there is",
"they'd": "they would",
"they'll": "they will",
"they're": "they are",
"they've": "they have",
"wasn't": "was not",
"we'd": "we would",
"we'll": "we will",
"we're": "we are",
"we've": "we have",
"weren't": "were not",
"what'll": "what will",
"what're": "what are",
"what's": "what is",
"what've": "what have",
"where'd": "where did",
"where's": "where is",
"who'll": "who will",
"who's": "who is",
"won't": "will not",
"wouldn't": "would not",
"you'd": "you would",
"you'll": "you will",
"you're": "you are",
"thx"   : "thanks"
}

In [None]:
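def remove_contractions(text):
    # Expand contractions word by word. Curly apostrophes (’) are normalised to
    # straight ones first so that e.g. "We’ll" matches the "we'll" key.
    text = text.replace('\u2019', "'")
    return re.sub(r"[A-Za-z']+",
                  lambda m: contractions.get(m.group(0).lower(), m.group(0)),
                  text)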


df['text']=df['text'].apply(remove_contractions)
df.head()


|   | id | text | target |
|---|----|------|--------|
| 0 | 1 | Our Deeds are the Reason of this #earthquake M... | 1 |
| 1 | 4 | Forest fire near La Ronge Sask. Canada | 1 |
| 2 | 5 | All residents asked to 'shelter in place' are ... | 1 |
| 3 | 6 | 13,000 people receive #wildfires evacuation or... | 1 |
| 4 | 7 | Just got sent this photo from Ruby #Alaska as ... | 1 |

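Because the `contractions` dictionary is keyed by individual words, the lookup has to happen per word: looking up the whole tweet as a single key only fires when the entire tweet is itself one contraction. The small self-contained check below contrasts the two approaches; `mini_contractions`, `expand_whole_string` and `expand_per_word` are hypothetical names used only for this illustration.

```python
import re

# Hypothetical mini dictionary; the notebook's `contractions` dict is much larger.
mini_contractions = {"can't": "cannot", "we'll": "we will", "it's": "it is"}

def expand_whole_string(text):
    # Looks up the entire text as one key, so it only matches one-word inputs.
    return mini_contractions.get(text.lower(), text)

def expand_per_word(text):
    # Normalise curly apostrophes, then replace each word-like token on its own.
    text = text.replace("\u2019", "'")
    return re.sub(r"[A-Za-z']+",
                  lambda m: mini_contractions.get(m.group(0).lower(), m.group(0)),
                  text)

tweet = "It's bad, we can't leave"
print(expand_whole_string(tweet))  # It's bad, we can't leave
print(expand_per_word(tweet))      # it is bad, we cannot leave
```

Matching on `[A-Za-z']+` rather than splitting on spaces keeps punctuation such as the comma attached to its position instead of gluing it to the looked-up word.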

### Cleaning the dataset

In [None]:
def clean_dataset(text):
    # Decode common HTML entities, then drop any other entity (e.g. &quot;)
    text = re.sub(r'&amp;?', ' and ', text)
    text = text.replace('&lt;', '<').replace('&gt;', '>')
    text = re.sub(r'&\w*;', '', text)
    # Remove the hashtag symbol while keeping the hashtag text
    text = text.replace('#', '')
    # Remove tickers (e.g. $AAPL)
    text = re.sub(r'\$\w*', '', text)
    # Remove hyperlinks; a URL ends at the first whitespace character
    text = re.sub(r'https?://\S+', '', text)
    text = re.sub(r'http\S*', '', text)  # leftover or truncated "http..." fragments
    # Remove retweet markers and @mentions
    text = re.sub(r'\b(?:RT|rt)\s*@\s*\S+', '', text)
    text = re.sub(r'@\S+', '', text)
    # Remove words with 4 or fewer word characters
    text = re.sub(r'\b\w{1,4}\b', '', text)
    # Convert emoji to their text names (":fire:" -> " fire ") before characters
    # beyond the Basic Multilingual Plane (BMP) are stripped below
    text = emoji.demojize(text).replace(':', ' ')
    # Remove characters beyond the BMP of Unicode
    text = ''.join(c for c in text if c <= '\uFFFF')
    # Collapse runs of 3+ identical characters to 2 (e.g. "soooo" -> "soo")
    text = ''.join(''.join(s)[:2] for _, s in itertools.groupby(text))
    # Replace remaining non-ASCII characters (e.g. mojibake) with a space
    text = re.sub(r'[^\x00-\x7F]+', ' ', text)
    # Keep only ASCII letters and digits, then collapse all whitespace
    text = ' '.join(re.sub(r'[^0-9A-Za-z]', ' ', text).split())
    return text

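Two regex details matter when cleaning tweets. A URL pattern built on `.*` is greedy, so it runs to the last `/` in the tweet and deletes every word in between; bounding the URL with `\S+` stops it at the first whitespace. Likewise, substituting runs of whitespace with an empty string glues neighbouring words together, so whitespace should be collapsed to a single space instead. (In Python's `re`, `[ ]{2, }` with a space is not parsed as a `{m,n}` quantifier, so it matches literal text; ` {2,}` is the intended form.) The sample strings below are made up for illustration.

```python
import re

text = "fire near http://t.co/abc and more http://t.co/xyz tonight"

# Greedy: `.*` runs to the LAST "/" in the string, deleting the words in between.
greedy = re.sub(r'https?:\/\/.*\/\w*', '', text)
# Bounded: a URL ends at the first whitespace character.
bounded = re.sub(r'https?://\S+', '', text)

print(greedy)   # fire near  tonight
print(bounded)  # fire near  and more  tonight

# Replacing whitespace runs with '' merges words; a single space keeps them apart.
glued = re.sub(r'\s\s+', '', "forest fire\n\nnear town")
spaced = re.sub(r'\s+', ' ', "forest fire\n\nnear town")

print(glued)    # forest firenear town
print(spaced)   # forest fire near town
```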
In [None]:
df['text'] =df['text'].apply(clean_dataset)
print(df.head())
print()

df.shape

### Split the data into training and validation sets

In [None]:
X_train_clean, X_test_clean, y_train_clean, y_test_clean = train_test_split(df['text'], df['target'], test_size=0.20, random_state=42)

In [None]:
train_df_clean = pd.concat([X_train_clean, y_train_clean], axis=1)
print("Shape of training data set: ", train_df_clean.shape)
print("View of data set: ", train_df_clean.head())

Shape of training data set:  (6090, 2)
View of data set:                                                     text  target
4996  Courageous honest analysis Atomic Hiroshima70 ...       1
3263          shame became engulfed flames boycottBears       0
4907  rescind medals honor given soldiers Massacre W...       1
2855  Worried about drought might affect Extreme Wea...       1
4716                           BlastPower PantherAttack       0


In [None]:
eval_df_clean = pd.concat([X_test_clean, y_test_clean], axis=1)
print("Shape of Eval data set: ", eval_df_clean.shape)

Shape of Eval data set:  (1523, 2)
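The two shapes follow from how scikit-learn sizes a fractional `test_size`: the test count is rounded up and the remaining rows go to training. With the 7,613 labelled tweets (6,090 + 1,523), a quick check of that arithmetic:

```python
import math

n_samples = 7613   # 6090 training rows + 1523 evaluation rows
test_size = 0.20

# train_test_split rounds the test fraction up; training gets the remainder
n_test = math.ceil(test_size * n_samples)
n_train = n_samples - n_test
print(n_train, n_test)  # 6090 1523
```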


### BERT Model Training

#### Set up the training arguments

In [None]:
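train_args = {
    'evaluate_during_training': True,        # evaluate on eval_df while training
    'logging_steps': 100,                    # log the training loss every 100 steps
    'num_train_epochs': 2,
    'evaluate_during_training_steps': 100,   # run that evaluation every 100 steps
    'save_eval_checkpoints': False,          # don't save a checkpoint at each evaluation
    'train_batch_size': 32,
    'eval_batch_size': 64,
    'overwrite_output_dir': True,            # overwrite models already saved in output_dir
    'fp16': False,                           # disable mixed-precision training
    'wandb_project': "visualization-demo"    # Weights & Biases project (needs wandb installed and logged in)
}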

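These arguments determine how many optimisation steps each run takes and when evaluation happens. The sketch below works that out for the 6,090 training rows, 1,523 validation rows and 3,263 test tweets used in this notebook. It assumes Simple Transformers' default `gradient_accumulation_steps` of 1 (one optimiser step per batch) and that `predict` batches by `eval_batch_size`; both are assumptions, not settings shown above.

```python
import math

train_rows, eval_rows, test_rows = 6090, 1523, 3263
train_batch_size, eval_batch_size = 32, 64   # from train_args
num_train_epochs = 2
eval_every = 100                             # evaluate_during_training_steps

# Assumes gradient_accumulation_steps = 1, i.e. one optimiser step per batch
steps_per_epoch = math.ceil(train_rows / train_batch_size)
total_steps = steps_per_epoch * num_train_epochs
eval_points = list(range(eval_every, total_steps + 1, eval_every))
eval_batches = math.ceil(eval_rows / eval_batch_size)
test_batches = math.ceil(test_rows / eval_batch_size)

print(steps_per_epoch, total_steps)  # 191 382
print(eval_points)                   # [100, 200, 300]
print(eval_batches, test_batches)    # 24 51
```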
In [None]:
model_BERT = ClassificationModel('bert', 'bert-base-cased', num_labels=2, use_cuda=True, cuda_device=0, args=train_args)

#### Train the model

In [None]:
model_BERT.train_model(train_df_clean, eval_df=eval_df_clean)

#### Check model performance on the validation data

In [None]:
result, model_outputs, wrong_predictions = model_BERT.eval_model(eval_df_clean, acc=sklearn.metrics.accuracy_score)

Converting to features started. Cache is not used.


{'mcc': 0.5915974149823142, 'tp': 466, 'tn': 755, 'fp': 119, 'fn': 183, 'eval_loss': 0.45270544787247974, 'acc': 0.8017071569271176}
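The reported `acc` and `mcc` can be recomputed directly from the four confusion-matrix counts, which makes clear what each metric measures. The same counts also give a majority-class baseline (always predicting the larger class, here label 0), which puts the roughly 80% accuracy in context. `metrics_from_counts` is a hypothetical helper written for this check.

```python
import math

def metrics_from_counts(tp, tn, fp, fn):
    """Accuracy, Matthews correlation coefficient and majority-class baseline."""
    total = tp + tn + fp + fn
    acc = (tp + tn) / total
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = (tp * tn - fp * fn) / denom if denom else 0.0
    # Accuracy of always predicting whichever class is larger
    majority = max(tp + fn, tn + fp) / total
    return acc, mcc, majority

# Counts from the BERT eval_model result above
acc, mcc, majority = metrics_from_counts(tp=466, tn=755, fp=119, fn=183)
print(f"acc={acc:.4f} mcc={mcc:.4f} majority baseline={majority:.4f}")
# acc=0.8017 mcc=0.5916 majority baseline=0.5739
```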


### RoBERTa Model Training

#### Set up the model

In [None]:
model_Roberta = ClassificationModel('roberta', 'roberta-base', num_labels=2, use_cuda=True, cuda_device=0, args=train_args)

#### Train the model

In [None]:
model_Roberta.train_model(train_df_clean, eval_df=eval_df_clean)

#### Evaluate the model

In [None]:
result, model_outputs, wrong_predictions = model_Roberta.eval_model(eval_df_clean, acc=sklearn.metrics.accuracy_score)

Converting to features started. Cache is not used.


{'mcc': 0.6223595391558661, 'tp': 459, 'tn': 784, 'fp': 90, 'fn': 190, 'eval_loss': 0.45775141566991806, 'acc': 0.8161523309258043}


### ALBERT Model Training

In [None]:
model_albert = ClassificationModel('albert', 'albert-base-v2', num_labels=2, use_cuda=True, cuda_device=0, args=train_args)

In [None]:
model_albert.train_model(train_df_clean, eval_df=eval_df_clean)

In [None]:
result, model_outputs, wrong_predictions = model_albert.eval_model(eval_df_clean, acc=sklearn.metrics.accuracy_score)

Converting to features started. Cache is not used.


{'mcc': 0.5121674642936882, 'tp': 433, 'tn': 730, 'fp': 144, 'fn': 216, 'eval_loss': 0.5123823285102844, 'acc': 0.7636244254760342}
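To compare the three models side by side, the sketch below ranks them by the reported MCC (rounded to four decimals) and derives precision, recall and F1 from each model's confusion-matrix counts. The numbers are copied from the three `eval_model` results above; `precision_recall_f1` is a hypothetical helper.

```python
# Confusion-matrix counts and MCC copied from the three eval_model results
results = {
    "BERT":    {"tp": 466, "tn": 755, "fp": 119, "fn": 183, "mcc": 0.5916},
    "RoBERTa": {"tp": 459, "tn": 784, "fp": 90,  "fn": 190, "mcc": 0.6224},
    "ALBERT":  {"tp": 433, "tn": 730, "fp": 144, "fn": 216, "mcc": 0.5122},
}

def precision_recall_f1(tp, fp, fn):
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

# Highest MCC first
ranking = sorted(results, key=lambda name: results[name]["mcc"], reverse=True)
for name in ranking:
    r = results[name]
    p, rec, f1 = precision_recall_f1(r["tp"], r["fp"], r["fn"])
    print(f"{name:8s} mcc={r['mcc']:.4f} precision={p:.3f} recall={rec:.3f} f1={f1:.3f}")
```

On these counts RoBERTa's lead comes mainly from fewer false positives (90 versus 119 for BERT); BERT actually recalls slightly more positive tweets (183 versus 190 false negatives), which is worth keeping in mind when choosing RoBERTa for the predictions below.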


### Perform prediction - Test Set

In [None]:
test_df = pd.read_csv(r"/content/sample_data/data/test.csv")
# Drop the keyword and location columns
columns_todrop = ['keyword','location']
test_df.drop(columns=columns_todrop, inplace=True)
# Expand contractions
test_df['text'] = test_df['text'].apply(remove_contractions)
# Clean the text with the same function used for the training data
test_df['text'] = test_df['text'].apply(clean_dataset)
test_df.head()

|   | id | text |
|---|----|------|
| 0 | 0 | happened terrible crash |
| 1 | 2 | Heard about earthquake different cities everyone |
| 2 | 3 | there forest geese fleeing across street cannot |
| 3 | 9 | Apocalypse lighting Spokane wildfires |
| 4 | 11 | Typhoon Soudelor kills China Taiwan |


In [None]:
predictions, raw_outputs = model_Roberta.predict(test_df['text'].tolist())  # predict expects a list of strings

Converting to features started. Cache is not used.


In [None]:
test_df['target']=predictions
test_df.tail()

|      | id    | text | target |
|------|-------|------|--------|
| 3258 | 10861 | EARTHQUAKE SAFETY ANGELES SAFETY FASTENERS | 0 |
| 3259 | 10865 | Storm worse hurricane city3others hardest look... | 1 |
| 3260 | 10868 | Green derailment Chicago | 1 |
| 3261 | 10874 | issues Hazardous Weather Outlook | 1 |
| 3262 | 10875 | CityofCalgary activated Municipal Emergency yy... | 1 |


In [None]:
test_df['target'].value_counts()

0    2132
1    1131
Name: target, dtype: int64
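Since the test set has no labels, one quick sanity check is to compare the share of label 1 in the predictions with its share in the labelled data (4,342 of the 7,613 labelled tweets have target 0). This assumes the test tweets follow a similar class balance, which is not verified here. A lower predicted share would be consistent with RoBERTa's validation results, where false negatives (190) outnumber false positives (90).

```python
# Counts taken from this notebook
train_total, train_neg = 7613, 4342   # labelled tweets, of which target == 0
pred_neg, pred_pos = 2132, 1131       # RoBERTa predictions on the test tweets

train_pos_rate = (train_total - train_neg) / train_total
pred_pos_rate = pred_pos / (pred_neg + pred_pos)
print(f"label 1 in labelled data:    {train_pos_rate:.1%}")  # 43.0%
print(f"label 1 in test predictions: {pred_pos_rate:.1%}")   # 34.7%
```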

### Perform predictions on sample tweets

Tweet 1 - test

In [None]:
tt1 = "#COVID19 will spread across U.S. in coming weeks. We’ll get past it, but must focus on limiting the epidemic, and preserving lif"
tt1 = remove_contractions(tt1)
tt1 = clean_dataset(tt1)
tt1

'COVID19 spread across coming weeks focus limiting epidemic preserving'

In [None]:
predictions, _ = model_Roberta.predict([tt1])
response_dict = {0: 'Fake', 1: 'Real'}
print("Prediction is: ", response_dict[predictions[0]])

Converting to features started. Cache is not used.


Prediction is:  Real


Tweet 2 - test

In [None]:
tt2 = "BREAKING: Confirmed flooding on NYSE. The trading floor is flooded under more than 3 feet of water."
tt2 = remove_contractions(tt2)
tt2 = clean_dataset(tt2)
tt2

'BREAKING Confirmed flooding trading floor flooded under water'

In [None]:
predictions, _ = model_Roberta.predict([tt2])
response_dict = {0: 'Fake', 1: 'Real'}
print("Prediction is: ", response_dict[predictions[0]])

Converting to features started. Cache is not used.


Prediction is:  Real


Tweet 3 - test

In [None]:
tt3 = "Everything is ABLAZE. Please run!!"
tt3 = remove_contractions(tt3)
tt3 = clean_dataset(tt3)
tt3

'Everything ABLAZE Please'

In [None]:
predictions, _ = model_Roberta.predict([tt3])
response_dict = {0: 'Fake', 1: 'Real'}
print("Prediction is: ", response_dict[predictions[0]])

Converting to features started. Cache is not used.


Prediction is:  Fake
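Each sample tweet goes through the same clean-then-predict steps. A small hypothetical helper, `classify_tweets`, can wrap them and predict a whole batch in one call. It only assumes that `model.predict` takes a list of strings and returns `(predictions, raw_outputs)`, which is how `model_Roberta.predict` is used above, and it reuses the notebook's 0 → 'Fake', 1 → 'Real' mapping.

```python
from typing import Callable, List, Sequence

RESPONSE = {0: 'Fake', 1: 'Real'}

def classify_tweets(tweets: Sequence[str], model,
                    preprocess: Callable[[str], str]) -> List[str]:
    """Hypothetical helper: clean each raw tweet, predict once for the batch,
    and map the integer labels to the notebook's 'Fake'/'Real' names."""
    cleaned = [preprocess(t) for t in tweets]
    predictions, _ = model.predict(cleaned)
    # int() also handles NumPy integer labels
    return [RESPONSE[int(p)] for p in predictions]
```

In this notebook it could be called as `classify_tweets([raw1, raw2, raw3], model_Roberta, lambda t: clean_dataset(remove_contractions(t)))`, where `raw1` to `raw3` stand for the uncleaned tweet strings.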
