In [None]:
from preprocess_interviews import get_token_labels
import torch
from RobertaForMultiTokenClassification import RobertaForMultiLabelTokenClassification

import os
# os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Run on a CUDA GPU when one is visible, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fine-tuned token-classification checkpoint, loaded from a local directory and moved to `device`
# (nn.Module.to returns the module itself, so the call can be chained).
model_checkpoint = "./robbert-v2-dutch-base-finetuned-sentiment/checkpoint-12032"
roberta_model = RobertaForMultiLabelTokenClassification.from_pretrained(model_checkpoint).to(device)

# Labelled pieces plus the label names (used below to size the per-label confusion matrices).
data, unique_labels = get_token_labels(code_mode="sentiment", stride=256)

In [None]:
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

df = pd.DataFrame(data)

# Grouped hold-out split: test_size is a proportion of *groups*, so roughly 10% of the transcripts
# go to the test set and every piece of a transcript stays on one side. Only the first of the two
# generated splits is consumed here.
splitter = GroupShuffleSplit(test_size=.10, n_splits=2, random_state=7)
split = splitter.split(df, groups=df['transcript_name'])
train_ids, test_ids = next(split)

train = df.iloc[train_ids]
test = df.iloc[test_ids]

# Guard against leakage: no transcript may contribute pieces to both train and test.
assert set(train['transcript_name']).isdisjoint(test['transcript_name']), "transcript leaked across split"

# Show each held-out transcript once rather than one row per piece.
test['transcript_name'].unique()

In [None]:
from sklearn.metrics import accuracy_score, multilabel_confusion_matrix
import numpy as np
from tqdm import trange
from transformers.modeling_outputs import TokenClassifierOutput
import torch
import os

%env PYTORCH_ENABLE_MPS_FALLBACK=1
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
print(os.environ["PYTORCH_ENABLE_MPS_FALLBACK"])

# Initialize variables
true_labels, pred_labels = [], []
transcript_accuracies = []  # Initialize empty list to hold accuracy values

with torch.no_grad():
    for i in trange(len(test)):
        piece = test.iloc[i]
        input_ids = torch.IntTensor(piece['input_ids']).unsqueeze(0).to(device)
        attention_mask = torch.IntTensor(piece['attention_mask']).unsqueeze(0).to(device)

        output: TokenClassifierOutput = roberta_model(input_ids, attention_mask=attention_mask)
        
        pred_label = output.logits.to('cpu').squeeze(0).numpy()
        true_label = np.array(piece['labels'])
        
        pred_labels.append(pred_label)
        true_labels.append(true_label)
        
        # Flatten the lists
        flat_pred_label = pred_label.ravel()
        flat_true_label = true_label.ravel()
        
        # Calculate the accuracy for this iteration and append to list
        iter_accuracy = accuracy_score(flat_pred_label > 0.50, flat_true_label)
        transcript_accuracies.append(iter_accuracy)

# Calculate the average, standard deviation, and range of accuracies
avg_accuracy = np.mean(transcript_accuracies) * 100
std_accuracy = np.std(transcript_accuracies) * 100
min_accuracy = np.min(transcript_accuracies) * 100
max_accuracy = np.max(transcript_accuracies) * 100

# Output the overall statistics with one decimal place
print(f"Average Accuracy: {avg_accuracy:.1f}%")
print(f"Standard Deviation of Accuracy: {std_accuracy:.1f}%")
print(f"Range of Accuracy: {min_accuracy:.1f}% - {max_accuracy:.1f}%")

# Continue with your existing code to calculate final accuracy and confusion matrices
threshold = 0.50
pred_labels = [[el > threshold for el in p] for true, pred in zip(true_labels, pred_labels) for t, p in zip(true, pred) if t[0] > -50]
true_labels = [item for sublist in true_labels for item in sublist if item[0] > -50]

cm_multi = multilabel_confusion_matrix(true_labels, pred_labels)
accuracy_value = accuracy_score(pred_labels, true_labels)

print(accuracy_value)

for i in range(0, len(unique_labels)):
    cm_multi[i] = np.fliplr(np.rot90(cm_multi[i]))

print(cm_multi)

In [None]:
print(cm_multi)

# Express each per-label confusion matrix as a percentage of that label's total token count
# (the sum over its 2x2 cells), keeping the same cell layout as cm_multi.
cm_multi_ration = cm_multi * 100 / cm_multi.sum(axis=(1, 2), keepdims=True)

# np.round_ was deprecated in NumPy 1.25 and removed in 2.0; np.round is the supported spelling.
cm_multi_ration = np.round(cm_multi_ration, decimals=1)

cm_multi_ration

In [None]:
from matplotlib.colors import ListedColormap
from matplotlib.pyplot import figure
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns

# Render at 300 dpi on a 12 x 15 inch figure holding a 2 x 2 grid of axes.
mpl.rcParams.update({'figure.dpi': 300})
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(12, 15))

def calc_accuracy(i, cm=None):
    """Return the accuracy of label ``i`` from a stack of 2x2 confusion matrices.

    Accuracy is (diagonal sum) / (total). It does not depend on whether the diagonal holds
    [TN, TP] (sklearn order) or [TP, TN] (the reordered ``cm_multi``).

    Args:
        i: index of the label along the first axis of ``cm``.
        cm: stack of per-label 2x2 confusion matrices; defaults to the global ``cm_multi``.

    Returns:
        The unrounded accuracy (NaN, with a NumPy warning, if the matrix is all zeros).
    """
    if cm is None:
        cm = cm_multi
    correct = cm[i, 0, 0] + cm[i, 1, 1]
    wrong = cm[i, 1, 0] + cm[i, 0, 1]
    return correct / (correct + wrong)

def calc_sensitivity(i, cm=None):
    """Return the sensitivity (recall, TP / (TP + FN)) of label ``i``, rounded to 2 decimals.

    The evaluation cell reorders each matrix of ``cm_multi`` from sklearn's
    [[TN, FP], [FN, TP]] with np.fliplr(np.rot90(...)), giving [[TP, FP], [FN, TN]].
    TP is therefore at [0, 0] and FN at [1, 0]. Reading sklearn's indices instead would
    yield TN / (TN + FN).

    Args:
        i: index of the label along the first axis of ``cm``.
        cm: stack of per-label matrices in [[TP, FP], [FN, TN]] layout; defaults to the
            global ``cm_multi``.
    """
    if cm is None:
        cm = cm_multi
    tp = cm[i, 0, 0]
    fn = cm[i, 1, 0]
    return round(tp / (tp + fn), 2)

def calc_specificity(i, cm=None):
    """Return the specificity (TN / (TN + FP)) of label ``i``, rounded to 2 decimals.

    The evaluation cell reorders each matrix of ``cm_multi`` from sklearn's
    [[TN, FP], [FN, TP]] with np.fliplr(np.rot90(...)), giving [[TP, FP], [FN, TN]].
    TN is therefore at [1, 1] and FP at [0, 1]. Reading sklearn's indices instead would
    yield TP / (TP + FP).

    Args:
        i: index of the label along the first axis of ``cm``.
        cm: stack of per-label matrices in [[TP, FP], [FN, TN]] layout; defaults to the
            global ``cm_multi``.
    """
    if cm is None:
        cm = cm_multi
    tn = cm[i, 1, 1]
    fp = cm[i, 0, 1]
    return round(tn / (tn + fp), 2)

tick_labels = ['Present', 'Absent']
xlabel = "Text mining"
ylabel = "Manual"
axes_font_size = 30
label_font_size = 20
tick_label_size = 15
result_size = 25

# Panel titles for the first four labels, in plotting order.
# NOTE(review): assumes this matches the order of `unique_labels` displayed below — confirm.
panel_titles = ['Context', 'Expectations', 'Experienced QoC', 'Experiences']

display(unique_labels)

colormap = ListedColormap(['white'])


def _plot_confusion_panel(ax, matrix, title):
    """Draw one per-label percentage confusion matrix on ``ax`` as an annotated white grid.

    ``matrix`` comes from ``cm_multi_ration``. It has the layout the evaluation cell produces with
    np.fliplr(np.rot90(...)): [[TP, FP], [FN, TN]], i.e. rows follow the text-mining prediction
    and columns the manual label. It is transposed so that rows match the y-axis ("Manual")
    and columns match the x-axis ("Text mining"), both ordered Present, Absent. Without the
    transpose the FN and FP cells would sit under the wrong axis labels.
    """
    shown = np.asarray(matrix).T
    heat = sns.heatmap(shown, annot=shown, fmt='', cmap=colormap, annot_kws={"size": result_size},
                       cbar=False, linewidths=1, linecolor='black', ax=ax)
    heat.set_xticklabels(tick_labels, size=tick_label_size)
    heat.set_yticklabels(tick_labels, size=tick_label_size)
    heat.set_title(title, fontsize=axes_font_size, fontweight='bold')
    heat.set_xlabel(xlabel, fontsize=label_font_size, fontweight='bold')
    heat.set_ylabel(ylabel, fontsize=label_font_size, fontweight='bold')
    # Put the predicted-label axis on top, as in a conventional confusion-matrix table.
    heat.xaxis.tick_top()
    heat.xaxis.set_label_position('top')
    return heat


# Row-major over the 2x2 grid: axs[0, 0], axs[0, 1], axs[1, 0], axs[1, 1].
for panel_ax, label_matrix, panel_title in zip(axs.flat, cm_multi_ration, panel_titles):
    _plot_confusion_panel(panel_ax, label_matrix, panel_title)