In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
from dataset import MELDDataset, Utterance, Dialogue
import pickle
from dummy_model import DummyModel
from torch.utils.data import DataLoader
from models.config import Config
from models.dialogue_gcn import DialogueGCN
from sklearn.metrics import f1_score, confusion_matrix



## Load data

In [3]:
config = Config()
audio_embed_path_test = "../MELD.Raw/audio_embeddings_feature_selection_emotion.pkl"
_, _, test_audio_emb = pickle.load(open(audio_embed_path_test, 'rb'))
test_dataset = MELDDataset("../MELD.Raw/test_sent_emo.csv", "../MELD.Raw/output_repeated_splits_test", test_audio_emb, name="test", config=config)
test_loader = DataLoader(test_dataset, batch_size=1)


{'joy': 0, 'anger': 1, 'disgust': 2, 'fear': 3, 'sadness': 4, 'neutral': 5, 'surprise': 6}


## Load model

In [4]:
model_path = "model_saves/text0_epoch_image_only_7.pt"
model = DialogueGCN(config)
model = model.to("cuda")

state_dict = torch.load(model_path)
model.load_state_dict(state_dict['model_state_dict'])
model.eval()

DialogueGCN(
  (text_encoder): GRU(768, 100, batch_first=True, bidirectional=True)
  (context_encoder): GRU(200, 100, batch_first=True, bidirectional=True)
  (pred_rel_l1): GraphConvolution()
  (suc_rel_l1): GraphConvolution()
  (same_speak_rel_l1): GraphConvolution()
  (diff_speak_rel_l1): GraphConvolution()
  (pred_rel_l2): GraphConvolution()
  (suc_rel_l2): GraphConvolution()
  (same_speak_rel_l2): GraphConvolution()
  (diff_speak_rel_l2): GraphConvolution()
  (edge_att_weights): Linear(in_features=200, out_features=200, bias=False)
  (text_attn): Linear(in_features=200, out_features=1, bias=True)
  (w_sentiment): Linear(in_features=400, out_features=3, bias=True)
  (w_emotion_1): Linear(in_features=400, out_features=200, bias=True)
  (w_emotion_2): Linear(in_features=200, out_features=7, bias=True)
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_em

In [5]:
legend = {0:'joy', 1:'anger', 2:'disgust', 3:'fear', 4:'sadness', 5:'neutral', 6:'surprise'}


In [31]:
def get_accuracy(predicted_emotion, predicted_sentiment, target):
    print(target.size())
    print(predicted_emotion.size())
    print(predicted_sentiment.size())
    emotion_accuracy_acc = torch.eq(predicted_emotion, target[:,0]).sum().item() / target.size(0)
    sentiment_accuracy_acc = torch.eq(predicted_sentiment, target[:,1]).sum().item() / target.size(0)
    return emotion_accuracy_acc, sentiment_accuracy_acc


def test_model(model_name, model, test_loader):
    print("Testing " + model_name)
    model = model.eval()
    emotion_predicted_labels = []
    sentiment_predicted_labels = []
    emotion_target_labels = []
    sentiment_target_labels = []
    for i, (test_batch_input, test_batch_labels) in enumerate(test_loader):
        batch_emotion_correct_predicted_labels, batch_sentiment_predicted_labels, batch_val_count = validate_step(model, test_batch_input, test_batch_labels)
        emotion_predicted_labels.append(batch_emotion_correct_predicted_labels)
        #print(batch_emotion_correct_predicted_labels)
        sentiment_predicted_labels.append(batch_sentiment_predicted_labels)
        emotion_target_labels.append(torch.cat(test_batch_labels[0],0))
        sentiment_target_labels.append(torch.cat(test_batch_labels[1],0))

    emotion_predicted_labels = torch.cat(emotion_predicted_labels, 0).cuda()
    sentiment_predicted_labels = torch.cat(sentiment_predicted_labels, 0).cuda()
    emotion_target_labels = torch.cat(emotion_target_labels, 0)
    sentiment_target_labels = torch.cat(sentiment_target_labels, 0)
    target_labels = torch.cat([emotion_target_labels.unsqueeze(1), sentiment_target_labels.unsqueeze(1)], 1).cuda()

    emotion_f1_score = f1_score(emotion_target_labels.cpu().numpy(), emotion_predicted_labels.cpu().numpy(), average='weighted')        
    confusion = confusion_matrix(emotion_target_labels.cpu().numpy(), emotion_predicted_labels.cpu().numpy())
    emotion_accuracy, sentiment_accuracy = get_accuracy(emotion_predicted_labels, sentiment_predicted_labels, target_labels)
    #emotion_recalls, sentiment_recalls = get_recall_for_each_class(emotion_predicted_labels, sentiment_predicted_labels, target_labels)
    #emotion_precisions, sentiment_precisions = get_precision_for_each_class(emotion_predicted_labels, sentiment_predicted_labels, target_labels)
    #emotion_f1s, sentiment_f1s = get_f1_score_for_each_class(emotion_precisions, emotion_recalls, sentiment_precisions, sentiment_recalls)
    #emotion_weighted_f1, sentiment_weighted_f1 = get_weighted_F1(emotion_f1s, sentiment_f1s, target_labels)

    print("Validation Accuracy (Emotion): ", emotion_accuracy)
    print("F1 Weighted", emotion_f1_score)
    print("Confusion matrix", confusion)

def validate_step(model, input, target):
    target = torch.LongTensor(target).to("cuda")
    (output_logits_emotion, output_logits_sentiment) = model(input)
    output_labels_emotion = torch.argmax(output_logits_emotion, dim=1)
    output_labels_sentiment = torch.argmax(output_logits_sentiment, dim=1)
    return output_labels_emotion, output_labels_sentiment, target[0].size()

def preprocess_datapoint(datapoint):
    datapoint = list(datapoint[0])
    text = datapoint[0]
    for i, utt in enumerate(text):
        text[i] = (utt,)
    datapoint[0] = text    
    return datapoint

In [8]:
test_model("GCN_TEXT model", model, test_loader)

Testing GCN_TEXT model
Loading data for  0
Loading data for  1
Loading data for  2
Loading data for  3
Loading data for  4
Loading data for  5
Loading data for  6
Loading data for  7
Loading data for  8
Loading data for  9
Loading data for  10
Loading data for  11
Loading data for  12
Loading data for  13
Loading data for  14
Loading data for  15
Loading data for  16
Loading data for  17
Loading data for  18
Loading data for  19
Loading data for  20
Loading data for  21
Loading data for  22
Loading data for  23
Loading data for  24
Loading data for  25
Loading data for  26
Loading data for  27
Loading data for  28
Loading data for  29
Loading data for  30
Loading data for  31
Loading data for  32
Loading data for  33
Loading data for  34
Loading data for  35
Loading data for  36
Loading data for  37
Loading data for  38
Loading data for  39
Loading data for  40
Loading data for  41
Loading data for  42
Loading data for  43
Loading data for  44
Loading data for  45
Loading data for  46


In [54]:
dialogue_id = 30
datapoint = preprocess_datapoint(test_dataset[dialogue_id])
labels = test_dataset[dialogue_id][1]
predictions = validate_step(model, datapoint, labels)[0]

for i, pred in enumerate(predictions):
    print(legend[int(pred)], " : ", test_dataset[dialogue_id][0][0][i])


Loading data for  30
Loading data for  30
Loading data for  30
surprise  :  ['Oh my God, what happened?', 'Oh. God, crazy Chandler. He spun me...off...the...bed!', 'Wow! Spinning that sounds like fun.', "Oh,  I wish. No, you know he was just trying Ross's Hug and Roll thing.", "Ross's what?", 'You know what, where he hugs you and kinda rolls you away and... Oh... my....God.']
Loading data for  30
surprise  :  ['', '', '', '', '', '']
Loading data for  30
joy  :  [array([0.08094506, 0.        , 0.17449932, ..., 0.23509485, 0.54376392,
       0.33151705]), array([0.25093461, 0.        , 0.13984529, ..., 0.21948618, 0.48416549,
       0.29720939]), array([0.90917832, 0.87294917, 0.26164095, ..., 0.25420537, 0.32039876,
       0.38785322]), array([0.30847348, 0.88298553, 0.21380787, ..., 0.29969429, 0.38305547,
       0.38591971]), array([0.32422478, 0.        , 0.04765311, ..., 0.16694304, 0.41292769,
       0.20914634]), array([0.99871564, 0.39416134, 0.19103805, ..., 0.22618025, 0.47380

IndexError: tuple index out of range

In [57]:
test_dataset[dialogue_id][0][0]

Loading data for  30


['Oh my God, what happened?',
 'Oh. God, crazy Chandler. He spun me...off...the...bed!',
 'Wow! Spinning that sounds like fun.',
 "Oh,  I wish. No, you know he was just trying Ross's Hug and Roll thing.",
 "Ross's what?",
 'You know what, where he hugs you and kinda rolls you away and... Oh... my....God.']

In [47]:
texts = [(0, "maybe she just likes me for me."),
(1, "you want these?"),
(2, "what?"),
(1, "I stink!"),
(1, "I can’t play!"), 
(1, "The ball is just sitting there and I can’t hit it!"),
(1, "I only hit one really good ball that went way out!"),
(1, "I have no concentration!"),
(2, "What’s wrong?"),
(1, "I can’t get rid of the sand."),
(1, "There is still some in here won’t go away."),
(1, "I even got sand in the pockets."),
(2, "Come on, you are getting that all over the floor.")
]
utterances = []
for i, text_speaker in enumerate(texts):
    speaker, text = text_speaker
    utterance = Utterance(0, i, text, speaker, 0, 0, 0, 0, 0)
    utterances.append(utterance)
dialogue = Dialogue(0, utterances, visual_features=False)

In [48]:
datapoint = preprocess_datapoint(dialogue.get_data())

Loading data for  0


In [52]:
for i, pred in enumerate(validate_step(model, datapoint, [0])[0]):
    print(legend[int(pred)], " : ", texts[i][1])


neutral  :  maybe she just likes me for me.
surprise  :  you want these?
surprise  :  what?
joy  :  I stink!
anger  :  I can’t play!
anger  :  The ball is just sitting there and I can’t hit it!
anger  :  I only hit one really good ball that went way out!
anger  :  I have no concentration!
neutral  :  What’s wrong?
neutral  :  I can’t get rid of the sand.
neutral  :  There is still some in here won’t go away.
neutral  :  I even got sand in the pockets.
neutral  :  Come on, you are getting that all over the floor.
