In [1]:
from sota_list import LSTMNetwork
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig

import torch
import codecs
import json
from pprint import pprint


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Path to the saved LSTMNetwork state dict; it is read later with torch.load.
model_weights_path = "finetuned_saved_models/bilstm-bert-finetuned-segmented-extracted.pth"

# PLM: name (or local path) of the fine-tuned language model, passed to the
# from_pretrained loaders inside load_llm_parts.
plm_name = 'bert-finetuned-segmented'

In [3]:
def load_llm_parts(model_name):
    """Load the sequence-classification model and tokenizer for ``model_name``.

    The model is built from a config with ``output_hidden_states=True`` so that
    its forward pass exposes the per-layer hidden states.

    Args:
        model_name: Identifier forwarded unchanged to every ``from_pretrained``
            call (config, model and tokenizer).

    Returns:
        A two-element list ``[model, tokenizer]``.
    """
    # Request hidden states up front; callers read them from the model output.
    hidden_state_config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)

    # The model is loaded before the tokenizer.
    classifier = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        config=hidden_state_config,
    )
    text_tokenizer = AutoTokenizer.from_pretrained(model_name)

    return [classifier, text_tokenizer]



# Initialize the classifier and load its fine-tuned weights.
# NOTE(review): the positional arguments are presumably (input size, hidden size,
# number of labels, bidirectional) — confirm against sota_list.LSTMNetwork.
model = LSTMNetwork(768,128,5,True)
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; load_state_dict then copies the values into the model's own
# parameters, so the model's device is unchanged.
# NOTE(security): torch.load unpickles the file — only load checkpoints you
# trust (or pass weights_only=True on PyTorch versions that support it).
model.load_state_dict(torch.load(model_weights_path, map_location="cpu"))
# Switch to evaluation mode before running inference.
model.eval()

# Load the fine-tuned language model and its tokenizer
llm_model, tokenizer = load_llm_parts(plm_name)

In [4]:
# Load the ERC datasets
def load_dataset(name, type, base_dir='erc-datasets'):
    """Load one split of an ERC dataset from ``<base_dir>/<name>/<type>.json``.

    Args:
        name: Dataset folder name, e.g. ``'MELD'``.
        type: File stem of the split, e.g. ``'train'``. The parameter name
            shadows the built-in ``type`` but is kept so existing keyword
            callers keep working.
        base_dir: Folder that contains the dataset folders. Defaults to the
            previously hard-coded ``'erc-datasets'``.

    Returns:
        The parsed JSON content of the file.

    Raises:
        FileNotFoundError: If the file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    file_name = f'{base_dir}/{name}/{type}.json'
    # Built-in open with an explicit encoding replaces codecs.open; the
    # unreachable ``return None`` that followed the with-block is removed.
    with open(file_name, 'r', encoding='utf-8') as fr:
        return json.load(fr)

# MELD training split. The classification loop further down iterates it as
# conversations, each a sequence of records with an 'utterance' field.
dataset = load_dataset('MELD','train')
# Uncomment to inspect the first conversation:
#pprint(dataset[0])

In [20]:
def perform_classification(model, llm_model, tokenizer, text, threshold=0.5, verbose=True):
    """Classify one utterance and return its binary label vector.

    ``text`` is tokenized and passed through ``llm_model``; the vector at the
    first token position of the last hidden layer (the "CLS" vector, per the
    original comment) is fed to ``model.features_extraction`` and then to
    ``model.single_classification``. Each resulting score is thresholded.

    Args:
        model: Object exposing ``features_extraction`` and
            ``single_classification``.
        llm_model: Callable that accepts the tokenizer output as keyword
            arguments and returns an object with a ``hidden_states`` sequence.
        tokenizer: Callable that turns ``text`` into the keyword arguments
            for ``llm_model``.
        text: The utterance to classify.
        threshold: A score strictly greater than this value becomes label 1.
            Defaults to the previously hard-coded 0.5.
        verbose: When True (the default), print the raw scores and then the
            labels, exactly as before.

    Returns:
        list[int]: One 0/1 label per score. (Previously the function always
        returned ``None``, discarding the result.)
    """
    # Tokenize (inputs longer than 512 tokens are truncated)
    encoded = tokenizer(
        text,
        truncation=True,
        return_tensors='pt',
        max_length=512,
        add_special_tokens=True,
    )

    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        llm_output = llm_model(**encoded)
        # First sequence, first token, last layer; unsqueeze restores the batch axis.
        cls_embedding = llm_output.hidden_states[-1][0, 0, :].unsqueeze(0)
        features = model.features_extraction(cls_embedding)
        scores = model.single_classification(features)

    # Scores of the single batch entry, as plain Python floats
    scores = scores.detach().tolist()[0]
    labels = [int(score > threshold) for score in scores]

    if verbose:
        print(scores)
        print(labels)

    return labels

In [19]:
# Preview run: classify every utterance of the first conversation only.
for conversation in dataset[:1]:
    for utt_data in conversation:
        perform_classification(model, llm_model, tokenizer, utt_data['utterance'])
        print("\n")

[7.884313163231127e-06, 0.07527997344732285, 0.02310573123395443, 0.9989839196205139, 0.07420559972524643]


[0.2678419053554535, 0.7483342885971069, 0.5751084685325623, 0.8016932010650635, 0.08167696744203568]


[0.9424228072166443, 0.8541322350502014, 0.11613670736551285, 0.27759191393852234, 0.9877989292144775]


[0.1664925366640091, 0.0749460905790329, 0.006571900099515915, 0.7229515910148621, 0.0049364627338945866]


[0.0015028807101771235, 0.790822446346283, 0.10240393877029419, 0.38755449652671814, 0.3447065055370331]


[0.6118987202644348, 0.007421252783387899, 0.00021613996068481356, 0.015303430147469044, 0.9999113082885742]


[0.6376124620437622, 0.848629891872406, 0.23508745431900024, 0.15853524208068848, 0.804473876953125]


[0.0011138279223814607, 0.768426775932312, 0.0003301171527709812, 0.041243646293878555, 0.9970968961715698]


[0.9402758479118347, 0.6949419975280762, 0.013996057212352753, 0.7318568825721741, 0.8844406008720398]


[0.897058367729187, 5.484961002366617e

In [None]:
# Save the results