## Multiclass emotion classification of tweets using a pre-trained BERT model

Dataset: https://www.kaggle.com/datasets/parulpandey/emotion-dataset

![image.png](attachment:a14eab98-ae92-46cb-a80d-65ab37484233.png) 

## Problem statement:

The dataset contains English Twitter messages labelled with one of six emotions: anger, fear, joy, love, sadness, and surprise.

In a customer-focused industry, companies keep looking for new ways to understand their consumers. Detecting emotions from reviews, chats, tweets, blogs, and posts is one such way, and it does not require asking customers directly. With newer algorithms and growing computing power, Natural Language Processing (NLP) makes it possible to detect emotions in written text and act on them.

Here, the task is to predict the emotion expressed in each tweet by fine-tuning a pre-trained BERT model.

## Annotations of different emotions:

- Anger: 0
- Fear: 1
- Joy: 2
- Love: 3
- Sadness: 4
- Surprise: 5

![image.png](attachment:7604780a-2b79-479f-97e6-3f70f62f0a20.png) 

## Table of contents
 
1. Importing and installing required libraries
2. Doing some text preprocessing
3. Now let's load the model
4. Model fitting and evaluation
5. Prediction part
6. Prediction on custom text

# 1. Importing and Installing required libraries

In [1]:
!pip install text_hammer

Successfully installed beautifulsoup4-4.9.1 soupsieve-2.3.2.post1 text-hammer-0.1.5


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
!pip install transformers

Successfully installed huggingface-hub-0.11.0 tokenizers-0.13.2 transformers-4.24.0


In [4]:
import pandas as pd
import numpy as np
import text_hammer as th
from tqdm.notebook import tqdm_notebook
tqdm_notebook.pandas()  # adds progress_apply to pandas objects
from sklearn.model_selection import train_test_split
from transformers import (AutoTokenizer, BertConfig, BertTokenizer, DistilBertConfig,
                          DistilBertTokenizer, TFBertModel, TFDistilBertModel)
import shutil
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.initializers import TruncatedNormal
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import CategoricalAccuracy
from tensorflow.keras.utils import to_categorical
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
%config Completer.use_jedi = False # if autocompletion does not work in a Kaggle notebook, hit Tab

In [6]:
# Load the training, test, and validation splits
df_train = pd.read_csv('/content/drive/MyDrive/深度學習/dataset/training.csv')
df_test = pd.read_csv('/content/drive/MyDrive/深度學習/dataset/test.csv')
df_val = pd.read_csv('/content/drive/MyDrive/深度學習/dataset/validation.csv')

In [10]:
df_full = pd.concat([df_train,df_test,df_val], axis = 0)
df_full

| | text | label |
|---|---|---|
| 0 | i didnt feel humiliated | 0 |
| 1 | i can go from feeling so hopeless to so damned... | 0 |
| 2 | im grabbing a minute to post i feel greedy wrong | 3 |
| 3 | i am ever feeling nostalgic about the fireplac... | 2 |
| 4 | i am feeling grouchy | 3 |
| ... | ... | ... |
| 1995 | im having ssa examination tomorrow in the morn... | 0 |
| 1996 | i constantly worry about their fight against n... | 1 |
| 1997 | i feel its important to share this info for th... | 1 |
| 1998 | i truly feel that if you are passionate enough... | 1 |


# 2. Doing some text preprocessing 

In [7]:
def text_preprocessing(df, col_name):
    """Clean the text in df[col_name] in place and return the DataFrame."""
    column = col_name
    df[column] = df[column].progress_apply(lambda x: str(x).lower())
    df[column] = df[column].progress_apply(th.cont_exp)  # you're -> you are; i'm -> i am
    df[column] = df[column].progress_apply(th.remove_emails)
    df[column] = df[column].progress_apply(th.remove_html_tags)

    df[column] = df[column].progress_apply(th.remove_special_chars)
    df[column] = df[column].progress_apply(th.remove_accented_chars)

    return df
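
To make the cleaning steps concrete, here is a minimal stand-alone sketch of a similar pipeline that uses only the standard library. It is an approximation and not how `text_hammer` is implemented: the `CONTRACTIONS` table, the regular expressions, and the name `clean_text` are assumptions made for illustration, and accents are stripped before special characters so that accented letters survive as plain ASCII.

```python
import re
import unicodedata

# Hypothetical, deliberately tiny contraction table (a real one is much larger).
CONTRACTIONS = {"you're": "you are", "i'm": "i am", "don't": "do not"}


def clean_text(text: str) -> str:
    """Approximate the notebook's cleaning steps for a single string."""
    text = str(text).lower()
    # Expand the contractions listed in the table above.
    for short, full in CONTRACTIONS.items():
        text = text.replace(short, full)
    # Drop e-mail addresses and HTML tags.
    text = re.sub(r"\S+@\S+\.\S+", " ", text)
    text = re.sub(r"<[^>]+>", " ", text)
    # Strip accents: decompose characters, then discard the combining marks.
    text = unicodedata.normalize("NFKD", text)
    text = "".join(ch for ch in text if not unicodedata.combining(ch))
    # Replace everything except letters, digits, and whitespace, then squeeze spaces.
    text = re.sub(r"[^a-z0-9\s]", " ", text)
    return re.sub(r"\s+", " ", text).strip()


print(clean_text("I'm <b>SO</b> happy, café time! mail me: a.b@example.com"))
```

Applying a function like this with `df['text'].apply(clean_text)` would play the same role as the chain of `progress_apply` calls, just without the progress bars.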

In [11]:
df_cleaned = text_preprocessing(df_full,'text')

In [12]:
df_cleaned = df_cleaned.copy()

In [13]:
df_cleaned['num_words'] = df_cleaned['text'].apply(lambda x:len(x.split()))

In [14]:
# Convert the label column to the category dtype so that each class maps to an integer code
df_cleaned['label'] = df_cleaned['label'].astype('category')

In [15]:
df_cleaned['label']

0       0
1       0
2       3
3       2
4       3
       ..
1995    0
1996    1
1997    1
1998    1
1999    1
Name: label, Length: 20000, dtype: category
Categories (6, int64): [0, 1, 2, 3, 4, 5]

In [16]:
df_cleaned['label'].cat.codes

0       0
1       0
2       3
3       2
4       3
       ..
1995    0
1996    1
1997    1
1998    1
1999    1
Length: 20000, dtype: int8

In [17]:
encoded_dict  = {'anger':0,'fear':1, 'joy':2, 'love':3, 'sadness':4, 'surprise':5}

In [18]:
df_cleaned.num_words.max()

66

In [19]:
data_train,data_test = train_test_split(df_cleaned, test_size = 0.3, random_state = 42, stratify = df_cleaned['label'])

In [20]:
data_train.shape

(14000, 3)

In [21]:
data_test.shape

(6000, 3)

In [22]:
to_categorical(data_train['label'])

array([[1., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0.],
       ...,
       [0., 1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0.]], dtype=float32)
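
As an illustration of what `to_categorical` produces, the sketch below builds the same kind of one-hot rows in plain Python. The helper name `one_hot` is made up for this example; in the notebook itself the Keras utility does this work.

```python
def one_hot(labels, num_classes):
    """Return one row per label with 1.0 at the label's index and 0.0 elsewhere."""
    rows = []
    for label in labels:
        row = [0.0] * num_classes
        row[label] = 1.0
        rows.append(row)
    return rows


# Three example labels out of the six emotion classes.
encoded = one_hot([0, 1, 4], num_classes=6)
for row in encoded:
    print(row)
```

Each row has exactly one non-zero entry, which is the format a categorical cross-entropy loss expects for its targets.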

# 3. Now let's load the model

In [23]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
bert = TFBertModel.from_pretrained('bert-base-cased')

Some layers from the model checkpoint at bert-base-cased were not used when initializing TFBertModel: ['nsp___cls', 'mlm___cls']
- This IS expected if you are initializing TFBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFBertModel were initialized from the model checkpoint at bert-base-cased.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions without further training.


In [24]:
# Save the tokenizer and model locally so they can be reloaded later
tokenizer.save_pretrained('bert-tokenizer')
bert.save_pretrained('bert-model')

In [25]:
shutil.make_archive('bert-tokenizer', 'zip', 'bert-tokenizer')
shutil.make_archive('bert-model','zip','bert-model')

'/content/bert-model.zip'

In [26]:
tokenizer('I will be kaggle grandmaster')

{'input_ids': [101, 146, 1209, 1129, 24181, 25186, 5372, 6532, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]}
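
In the output above, 101 and 102 are the ids of the `[CLS]` and `[SEP]` special tokens, and 0 is the id of `[PAD]` in the BERT vocabulary. The following sketch shows, in plain Python, roughly what padding and truncating to a fixed length does to `input_ids` and `attention_mask`. The function `pad_or_truncate` is hypothetical and simplified; the real tokenizer handles truncation internally.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # special-token ids seen in the output above


def pad_or_truncate(input_ids, max_length):
    """Return (input_ids, attention_mask), both exactly max_length long."""
    ids = list(input_ids)
    if len(ids) > max_length:
        # Keep the leading tokens and make sure the sequence still ends with [SEP].
        ids = ids[: max_length - 1] + [SEP_ID]
    attention_mask = [1] * len(ids)
    padding = max_length - len(ids)
    # Padded positions get the [PAD] id and a 0 in the attention mask.
    return ids + [PAD_ID] * padding, attention_mask + [0] * padding


ids = [101, 146, 1209, 1129, 24181, 25186, 5372, 6532, 102]
padded_ids, mask = pad_or_truncate(ids, max_length=12)
print(padded_ids)
print(mask)
```

The attention mask tells the model which positions hold real tokens (1) and which are padding (0).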

In [27]:
# Tokenize the input (takes some time).
# The tokenizer comes from bert-base-cased.
x_train = tokenizer(
    text=data_train['text'].tolist(),
    add_special_tokens=True,
    max_length=70,
    truncation=True,
    padding='max_length',  # always pad to 70 so the shape matches the model's Input layers
    return_tensors='tf',
    return_token_type_ids=False,
    return_attention_mask=True,
    verbose=True)


x_test = tokenizer(
    text=data_test['text'].tolist(),
    add_special_tokens=True,
    max_length=70,
    truncation=True,
    padding='max_length',  # always pad to 70 so the shape matches the model's Input layers
    return_tensors='tf',
    return_token_type_ids=False,
    return_attention_mask=True,
    verbose=True)

In [28]:
x_test['input_ids']

<tf.Tensor: shape=(6000, 70), dtype=int32, numpy=
array([[ 101,  178, 1243, ...,    0,    0,    0],
       [ 101,  178, 1631, ...,    0,    0,    0],
       [ 101,  178, 2810, ...,    0,    0,    0],
       ...,
       [ 101,  178, 1838, ...,    0,    0,    0],
       [ 101,  178, 1238, ...,    0,    0,    0],
       [ 101,  178, 4534, ...,    0,    0,    0]], dtype=int32)>

In [47]:
max_len = 70

input_ids = Input(shape=(max_len,), dtype=tf.int32, name="input_ids")
input_mask = Input(shape=(max_len,), dtype=tf.int32, name="attention_mask")

# Index 0 is last_hidden_state (one vector per token); index 1 is pooler_output.
embeddings = bert(input_ids, attention_mask=input_mask)[0]
out = tf.keras.layers.GlobalMaxPool1D()(embeddings)  # max over the token axis
out = Dense(128, activation='relu')(out)
out = tf.keras.layers.Dropout(0.1)(out)
out = Dense(32, activation='relu')(out)

# Sigmoid scores each class independently, so the six outputs need not sum to 1.
y = Dense(6, activation='sigmoid')(out)

model = tf.keras.Model(inputs=[input_ids, input_mask], outputs=y)
model.layers[2].trainable = True  # layer 2 is the BERT encoder, so it is fine-tuned too
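
The `GlobalMaxPool1D` layer collapses the per-token hidden states into a single vector by taking, for each feature, the maximum over all token positions. The sketch below shows that reduction on a toy example in plain Python; `global_max_pool` and the numbers are invented for illustration.

```python
def global_max_pool(token_vectors):
    """Take the maximum of each feature across all token positions."""
    return [max(feature_values) for feature_values in zip(*token_vectors)]


# Toy "hidden states": 3 tokens with 4 features each (BERT-base has 768 features).
hidden_states = [
    [0.1, -0.5, 0.3, 0.0],
    [0.7, 0.2, -0.1, 0.4],
    [-0.2, 0.9, 0.0, 0.1],
]
pooled = global_max_pool(hidden_states)
print(pooled)
```

In the model above the same idea turns a `(batch, 70, 768)` tensor into a `(batch, 768)` tensor that the dense layers can consume.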

In [48]:
optimizer = Adam(
    learning_rate=5e-05,  # learning rate for fine-tuning BERT, taken from the Hugging Face website
    epsilon=1e-08,
    decay=0.01,
    clipnorm=1.0)

# Set loss and metrics
# from_logits=True means the loss treats the model output as raw scores.
loss = CategoricalCrossentropy(from_logits=True)
# Despite its name, this metric is plain categorical accuracy, not class-balanced accuracy.
metric = [CategoricalAccuracy('balanced_accuracy')]
# Compile the model
model.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=metric)
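
For reference, a cross-entropy loss configured with `from_logits=True` conceptually performs the following computation on one example: a softmax over the raw scores, followed by the negative log-probability of the true class. The sketch is plain Python with made-up scores, and the function names are illustrative rather than part of Keras.

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    highest = max(logits)
    exps = [math.exp(z - highest) for z in logits]  # subtract the max for stability
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy_from_logits(one_hot_target, logits):
    """Categorical cross-entropy for one example, computed from raw scores."""
    probs = softmax(logits)
    return -sum(t * math.log(p) for t, p in zip(one_hot_target, probs))


logits = [0.2, 3.1, -1.0, 0.0, 0.5, -0.3]  # made-up scores for the six classes
target = [0, 1, 0, 0, 0, 0]                # the true class is index 1
print(softmax(logits))
print(cross_entropy_from_logits(target, logits))
```

A loss configured this way is usually paired with a final `Dense` layer that has no activation, so that the softmax is applied exactly once.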

In [49]:
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_ids (InputLayer)         [(None, 70)]         0           []                               
                                                                                                  
 attention_mask (InputLayer)    [(None, 70)]         0           []                               
                                                                                                  
 tf_bert_model (TFBertModel)    TFBaseModelOutputWi  108310272   ['input_ids[0][0]',              
                                thPoolingAndCrossAt               'attention_mask[0][0]']         
                                tentions(last_hidde                                               
                                n_state=(None, 70,                                          

In [50]:
tf.config.run_functions_eagerly(True)  # run tf.functions eagerly (slower, but easier to debug)

# 4. Model fitting and then evaluation

In [51]:
train_history = model.fit(
    x={'input_ids': x_train['input_ids'], 'attention_mask': x_train['attention_mask']},
    y=to_categorical(data_train['label']),
    validation_data=(
        {'input_ids': x_test['input_ids'], 'attention_mask': x_test['attention_mask']},
        to_categorical(data_test['label'])
    ),
    epochs=1,
    batch_size=36
)
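
The `389` in the training progress output is the number of batches per epoch. A quick check of that arithmetic, using the training-set size and batch size from the cells above (the variable names are chosen for this example):

```python
import math

num_train_examples = 14000  # rows in data_train
batch_size = 36             # value passed to model.fit

# The last batch is smaller when the data does not divide evenly.
steps_per_epoch = math.ceil(num_train_examples / batch_size)
last_batch_size = num_train_examples - (steps_per_epoch - 1) * batch_size
print(steps_per_epoch, last_batch_size)  # prints: 389 32
```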



  1/389 [..............................] - ETA: 6:50 - loss: 1.8425 - balanced_accuracy: 0.3889
  ...
 30/389 [=>............................] - ETA: 4:19 - loss: 0.4800 - balanced_accuracy: 0.8380
  ...
 60/389 [===>..........................] - ETA: 3:59 - loss: 0.3806 - balanced_accuracy: 0.8718
  ...
 90/389 [=====>........................] - ETA: 3:38 - loss: 0.3314 - balanced_accuracy: 0.8880

In [67]:
#model.save("/content/drive/MyDrive/深度學習/model/model", include_optimizer=False)
model.save('/content/drive/MyDrive/深度學習/model/model_name.h5')

In [73]:
#model2 = tf.saved_model.load("/content/drive/MyDrive/深度學習/model/model")
import transformers
# TFBertModel is a custom layer, so Keras needs it passed in to rebuild the model.
model2 = tf.keras.models.load_model(
    '/content/drive/MyDrive/深度學習/model/model_name.h5',
    custom_objects={"TFBertModel": transformers.TFBertModel})



# 5. Prediction part

In [74]:
predicted_raw = model2.predict({'input_ids':x_test['input_ids'],'attention_mask':x_test['attention_mask']})



In [52]:
predicted_raw = model.predict({'input_ids':x_test['input_ids'],'attention_mask':x_test['attention_mask']})



In [75]:
predicted_raw[0]

array([0.06234031, 0.998774  , 0.13355795, 0.09218213, 0.1888022 ,
       0.2977126 ], dtype=float32)

In [76]:
y_predicted = np.argmax(predicted_raw, axis = 1)
y_predicted

array([1, 1, 1, ..., 0, 0, 4])
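
To read these class indices as emotion names, the notebook's `encoded_dict` can be inverted. The sketch below does this in plain Python for the first row of scores shown above; the helper `decode` and the name `index_to_label` are introduced here for illustration.

```python
encoded_dict = {'anger': 0, 'fear': 1, 'joy': 2, 'love': 3, 'sadness': 4, 'surprise': 5}
# Invert the mapping so that a class index looks up its emotion name.
index_to_label = {index: name for name, index in encoded_dict.items()}


def decode(scores):
    """Return the label whose score is highest in one row of model output."""
    best_index = max(range(len(scores)), key=lambda i: scores[i])
    return index_to_label[best_index]


row = [0.06234031, 0.998774, 0.13355795, 0.09218213, 0.1888022, 0.2977126]
print(decode(row))
```

For a whole array of predictions, `[index_to_label[i] for i in y_predicted]` would give the corresponding list of names.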

In [None]:
data_test['label']

In [None]:
accuracy_score(data_test['label'],y_predicted)

In [None]:
print(classification_report(data_test['label'], y_predicted))

In [None]:
plt.figure(figsize=(10,7))
sns.heatmap(confusion_matrix(data_test['label'],y_predicted),annot=True,cmap='viridis')
plt.show()
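
As a reminder of how to read the heatmap, the sketch below builds a confusion matrix by hand and derives accuracy from it. It uses toy labels for three classes, not results from this notebook, and the function names are made up for the example. Rows are true classes and columns are predicted classes, the same convention as `sklearn.metrics.confusion_matrix`.

```python
def confusion_counts(y_true, y_pred, num_classes):
    """Count predictions: rows are true classes, columns are predicted classes."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for true_label, predicted_label in zip(y_true, y_pred):
        matrix[true_label][predicted_label] += 1
    return matrix


def accuracy(matrix):
    """Correct predictions sit on the diagonal of the confusion matrix."""
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total


# Toy labels for three classes (illustrative only).
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
matrix = confusion_counts(y_true, y_pred, num_classes=3)
print(matrix)
print(accuracy(matrix))
```

Off-diagonal cells show which pairs of classes the model confuses with each other.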

# 6. Prediction on custom text

In [None]:
texts = input('input the text')

x_val = tokenizer(
    text=texts,
    add_special_tokens=True,
    max_length=70,
    truncation=True,
    padding='max_length',
    return_tensors='tf',
    return_token_type_ids=False,
    return_attention_mask=True,
    verbose=True)
# Multiply by 100 to show the per-class scores as percentages.
validation = model.predict({'input_ids': x_val['input_ids'], 'attention_mask': x_val['attention_mask']}) * 100
validation

input the text Movie is very bad, I want my money back.


array([[72.79686 , 28.849524,  8.209145, 39.88458 , 19.101107, 13.449428]],
      dtype=float32)

In [None]:
for key , value in zip(encoded_dict.keys(),validation[0]):
    print(key,value)

anger 72.79686
fear 28.849524
joy 8.209145
love 39.88458
sadness 19.101107
surprise 13.449428


# Thank you