In [29]:
import torch
import random
import numpy as np
import pandas as pd
from transformers import RobertaTokenizerFast, RobertaForSequenceClassification
from torch.utils.data import DataLoader
from collections import defaultdict
from tqdm import tqdm

In [30]:
def foo(x: str):
    res = x.strip("[]").split()
    return [r.strip("'") for r in res] 

In [31]:
df = pd.read_csv("data/dialogs.csv")
df["emotion"] = 0
df["dialog"] = df["dialog"].fillna("")
df["movie_genres"] = df["movie_genres"].apply(foo)
df["movie_genres"] = df["movie_genres"].apply(lambda x: x[0] if x else [])
df = df[df["movie_genres"].apply(len) > 0].copy()
print(len(df))
df.head()

303870


Unnamed: 0,movieID,dialog_id,dialog,speakers,movie_title,movie_genres,emoji
0,m0,0,Can we make this quick? Roxanne Korrine and A...,BIANCA,10 things i hate about you,comedy,0
1,m0,0,Well I thought we'd start with pronunciation i...,CAMERON,10 things i hate about you,comedy,0
2,m0,0,Not the hacking and gagging and spitting part....,BIANCA,10 things i hate about you,comedy,0
3,m0,0,Okay... then how 'bout we try out some French ...,CAMERON,10 things i hate about you,comedy,0
4,m0,1,You're asking me out. That's so cute. What's ...,BIANCA,10 things i hate about you,comedy,0


In [32]:
max_replic_in_dialog_count = 10
max_dialog_count_per_genre = 2500
df = df.groupby("dialog_id").head(max_replic_in_dialog_count)
df = df.groupby("movie_genres").head(max_dialog_count_per_genre)
df = df.reset_index(drop=True)
df

Unnamed: 0,movieID,dialog_id,dialog,speakers,movie_title,movie_genres,emoji
0,m0,0,Can we make this quick? Roxanne Korrine and A...,BIANCA,10 things i hate about you,comedy,0
1,m0,0,Well I thought we'd start with pronunciation i...,CAMERON,10 things i hate about you,comedy,0
2,m0,0,Not the hacking and gagging and spitting part....,BIANCA,10 things i hate about you,comedy,0
3,m0,0,Okay... then how 'bout we try out some French ...,CAMERON,10 things i hate about you,comedy,0
4,m0,1,You're asking me out. That's so cute. What's ...,BIANCA,10 things i hate about you,comedy,0
...,...,...,...,...,...,...,...
33987,m603,81438,I've decided not to open a practice here I wa...,ROWAN,the witching hour,documentary,0
33988,m603,81439,That's new territory for us...but yes we can l...,RYAN,the witching hour,documentary,0
33989,m603,81439,No tax shelters. No funding. I want to fund ...,ROWAN,the witching hour,documentary,0
33990,m603,81439,That would mean liquidating sizable amounts of...,RYAN,the witching hour,documentary,0


In [42]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = RobertaTokenizerFast.from_pretrained("cardiffnlp/twitter-roberta-base-emotion", model_max_length=514)
model = RobertaForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-emotion")
model = model.to(device)
device

device(type='cuda')

In [34]:
index2emoji = model.config.id2label
emoji2index = model.config.label2id
print(index2emoji)
print(emoji2index)

{0: 'joy', 1: 'optimism', 2: 'anger', 3: 'sadness'}
{'joy': 0, 'optimism': 1, 'anger': 2, 'sadness': 3}


In [35]:
class Dataset:
    def __init__(self, df):
        self.df = df
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        return row["dialog"]
    def __len__(self):
        return len(self.df)

In [36]:
model.eval()
batch_size = 128
dataset = Dataset(df)
loader = DataLoader(dataset, batch_size=batch_size)
for i, sentences in enumerate(tqdm(loader)):
    torch.cuda.empty_cache()
    start, end = i * batch_size, (i + 1) * batch_size
    tokens = tokenizer.batch_encode_plus(
        sentences, padding=True, truncation=True, 
        return_tensors="pt"
    )
    input_ids = tokens["input_ids"]
    input_ids = input_ids.to(device)
    with torch.no_grad():
        output = model(input_ids)
    logits = output.logits
    indexes = logits.argmax(axis=1)
    indexes = indexes.cpu().tolist()
    emoji = [index2emoji[index] for index in indexes]
    df.loc[start:end-1, "emotion"] = emoji

100%|██████████| 266/266 [05:54<00:00,  1.33s/it]


In [40]:
df

Unnamed: 0,movieID,dialog_id,dialog,speakers,movie_title,movie_genres,emotion
0,m0,0,Can we make this quick? Roxanne Korrine and A...,BIANCA,10 things i hate about you,comedy,joy
1,m0,0,Well I thought we'd start with pronunciation i...,CAMERON,10 things i hate about you,comedy,optimism
2,m0,0,Not the hacking and gagging and spitting part....,BIANCA,10 things i hate about you,comedy,joy
3,m0,0,Okay... then how 'bout we try out some French ...,CAMERON,10 things i hate about you,comedy,optimism
4,m0,1,You're asking me out. That's so cute. What's ...,BIANCA,10 things i hate about you,comedy,optimism
...,...,...,...,...,...,...,...
33987,m603,81438,I've decided not to open a practice here I wa...,ROWAN,the witching hour,documentary,sadness
33988,m603,81439,That's new territory for us...but yes we can l...,RYAN,the witching hour,documentary,anger
33989,m603,81439,No tax shelters. No funding. I want to fund ...,ROWAN,the witching hour,documentary,anger
33990,m603,81439,That would mean liquidating sizable amounts of...,RYAN,the witching hour,documentary,joy


In [41]:
df.to_csv(f"data/dialogs_marked_{len(df)}.csv")