#### Importing Libraries

In [1]:
from transformers import BertTokenizer
from model import BertForMultiLabelClassification
from multilabel_pipeline import MultiLabelPipeline
from pprint import pprint
import os
import pandas as pd
from scipy.special import softmax
import numpy as np
from tqdm import tqdm
tqdm.pandas()
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


#### Tokenizer, Model and Pipeline Setup

In [2]:
tokenizer = BertTokenizer.from_pretrained("monologg/bert-base-cased-goemotions-original")

In [3]:
model = BertForMultiLabelClassification.from_pretrained("monologg/bert-base-cased-goemotions-original")

In [4]:
goemotions = MultiLabelPipeline(
    model=model,
    tokenizer=tokenizer,
    threshold=0.3
)

#### Inference

In [5]:
all_emotions = ['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral']

In [6]:
filtered_emotions = ['happy', 'sad', 'disgusted', 'angry', 'fearful', 'neutral', 'surprised']

In [7]:
all_to_filtered = {
    'happy' : ['admiration', 'approval', 'caring', 'desire', 'gratitude', 'joy', 'love', 'optimism', 'pride'],
    'sad' : ['embarrassment', 'grief', 'remorse', 'sadness'],
    'disgusted' : ['disappointment' ,'disapproval', 'disgust'],
    'angry' : ['anger', 'annoyance'],
    'fearful' : ['fear', 'nervousness'],
    'surprised' : ['amusement', 'confusion', 'curiosity', 'excitement', 'surprise'], 
    'neutral' : ['realization', 'neutral', 'relief']
}

In [8]:
def process_line(line):
    emotion_prediction = goemotions(line)[0]
    labels, scores = emotion_prediction['labels'], emotion_prediction['scores']
    emotion_prediction = dict(zip(labels, scores))
    for emotion in all_emotions:
        if emotion not in emotion_prediction:
            emotion_prediction[emotion] = 0
    # Filter out emotions by clubbing them 
    filtered_prediction = {}
    for filtered_emotion in filtered_emotions:
        filtered_prediction[filtered_emotion] = sum([emotion_prediction[emotion] for emotion in all_to_filtered[filtered_emotion]])
    # Normalize filtered_prediction using softmax
    labels, scores = list(filtered_prediction.keys()), list(filtered_prediction.values())
    # scores = softmax(scores)
    emotion_prediction = dict(zip(labels, scores))
    return pd.Series(emotion_prediction)

In [9]:
def process_line(line):
    emotion_prediction = goemotions(line)[0]
    labels, scores = emotion_prediction['labels'], emotion_prediction['scores']
    emotion_prediction = dict(zip(labels, scores))
    for emotion in all_emotions:
        if emotion not in emotion_prediction:
            emotion_prediction[emotion] = 0
    return pd.Series(emotion_prediction)

In [24]:
goemotions = MultiLabelPipeline(
    model=model,
    tokenizer=tokenizer,
    threshold=0.0
)
def process_line_filtered(line):
    emotion_prediction = goemotions(line)[0]
    labels, scores = emotion_prediction['labels'], emotion_prediction['scores']
    emotion_prediction = dict(zip(labels, scores))
    for emotion in all_emotions:
        if emotion not in emotion_prediction:
            emotion_prediction[emotion] = 0
    # Filter out emotions by clubbing them
    filter_emotions = {}
    for filtered_emotion in filtered_emotions:
        filter_emotions[filtered_emotion] = sum([emotion_prediction[emotion] for emotion in all_to_filtered[filtered_emotion]])
    # print(filter_emotions)
    return pd.Series(filtered_emotions)

In [21]:
text = "I am so happy today"
print(process_line_filtered(text))

{'happy': 0.9931343586649746, 'sad': 0.0012720050290226936, 'disgusted': 0.0020010421285405755, 'angry': 0.0016621244722045958, 'fearful': 0.0006890575459692627, 'neutral': 0.020537185366265476, 'surprised': 0.06899285491090268}
0        happy
1          sad
2    disgusted
3        angry
4      fearful
5      neutral
6    surprised
dtype: object


In [11]:
folder_path = './script_csv'
output_path = './script_csv_go_emotion'
for file_name in tqdm(os.listdir(folder_path)):
    if file_name.endswith('.csv'):
        data = pd.read_csv(os.path.join(folder_path, file_name))
        data[all_emotions] = 0
        data[all_emotions] = data['line'].progress_apply(process_line)
        data.to_csv(os.path.join(output_path, file_name), index=False)

100%|██████████| 980/980 [03:07<00:00,  5.21it/s]
100%|██████████| 867/867 [02:19<00:00,  6.23it/s]
100%|██████████| 1027/1027 [03:29<00:00,  4.90it/s]
100%|██████████| 1027/1027 [01:36<00:00, 10.68it/s]
100%|██████████| 834/834 [00:51<00:00, 16.07it/s]
100%|██████████| 688/688 [00:54<00:00, 12.73it/s]
100%|██████████| 987/987 [01:12<00:00, 13.59it/s]
100%|██████████| 990/990 [01:11<00:00, 13.77it/s]
100%|██████████| 834/834 [01:16<00:00, 10.93it/s]
100%|██████████| 1010/1010 [01:33<00:00, 10.83it/s]
100%|██████████| 1043/1043 [01:53<00:00,  9.16it/s]
100%|██████████| 961/961 [01:38<00:00,  9.77it/s]
100%|██████████| 1007/1007 [02:41<00:00,  6.23it/s]
100%|██████████| 734/734 [01:33<00:00,  7.88it/s]
100%|██████████| 841/841 [01:45<00:00,  7.97it/s]
100%|██████████| 15/15 [27:06<00:00, 108.45s/it]


In [25]:
folder_path = './script_csv'
output_path = './script_csv_go_emotion_filtered'
for file_name in tqdm(os.listdir(folder_path)):
    if file_name.endswith('.csv'):
        data = pd.read_csv(os.path.join(folder_path, file_name))
        data[filtered_emotions] = 0
        data[filtered_emotions] = data['line'].progress_apply(process_line_filtered)
        data.to_csv(os.path.join(output_path, file_name), index=False)

  0%|          | 0/15 [00:00<?, ?it/s]

100%|██████████| 980/980 [03:02<00:00,  5.38it/s]
100%|██████████| 867/867 [01:27<00:00,  9.94it/s]
100%|██████████| 1027/1027 [01:54<00:00,  8.93it/s]
100%|██████████| 1027/1027 [02:05<00:00,  8.16it/s]
100%|██████████| 834/834 [01:58<00:00,  7.01it/s]
100%|██████████| 688/688 [01:31<00:00,  7.48it/s]
100%|██████████| 987/987 [02:30<00:00,  6.57it/s]
100%|██████████| 990/990 [03:26<00:00,  4.79it/s]
100%|██████████| 834/834 [02:41<00:00,  5.17it/s]
100%|██████████| 1010/1010 [04:42<00:00,  3.57it/s]
100%|██████████| 1043/1043 [03:00<00:00,  5.79it/s]
100%|██████████| 961/961 [02:42<00:00,  5.93it/s]
100%|██████████| 1007/1007 [02:44<00:00,  6.12it/s]
100%|██████████| 734/734 [01:52<00:00,  6.50it/s]
100%|██████████| 841/841 [01:49<00:00,  7.70it/s]
100%|██████████| 15/15 [37:31<00:00, 150.13s/it]
