In [10]:
import logging
# Configure the root logger: INFO and above, formatted as "LEVEL - time - message".
logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(asctime)s - %(message)s')

import pandas as pd
# NOTE(review): this silences pandas' SettingWithCopyWarning for the whole kernel, which can
# hide chained-assignment bugs in later cells; consider removing it if no cell needs it.
pd.options.mode.chained_assignment = None

import numpy as np
import matplotlib.pyplot as plt

import torch
# Use the GPU when one is available; fall back to CPU so that moving tensors/models to
# `device` does not fail on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from sklearn.metrics import RocCurveDisplay, accuracy_score, ConfusionMatrixDisplay, confusion_matrix, classification_report
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier
from sklearn.model_selection import train_test_split

from transformers import (
    AutoTokenizer,
    AutoModel,
    Trainer,
)

from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

import os
import gc
import pickle

from IPython.display import clear_output
# Wipe anything this cell has printed so far (e.g. import-time library messages) so the
# saved notebook output stays clean.
clear_output()

import utils
import data_extraction as da

import modelling as md

# TODO: create graph showing experiments with pooling strategies

# Id of the model whose pooled representations are analysed. Every '/' is swapped for '_'
# so the id can be embedded in the representation pickle file names.
pooling_model_name = '_'.join('cardiffnlp/twitter-roberta-large-sensitive-multilabel'.split('/'))

In [11]:
# Rating categories to predict ('nudity' and 'language' are dropped below).
cat_cols = ['themes', 'violence', 'drug_use', 'sex']

df = pd.read_parquet(da.cleaned_dataset_fp).drop(columns=['nudity', 'language']).sort_values(['movie', 'start_time'])

for col in cat_cols:
    df[col] = md.convert_col_to_ordinal(df[col])

# Reduce to one rating row per movie, in sorted-movie order (df is sorted by 'movie' and
# drop_duplicates keeps first occurrences). `ratings` is later paired with the pickled
# representations purely by position, so a movie carrying more than one distinct rating
# combination would add an extra row and misalign the labels. Fail loudly instead.
movie_ratings = df[cat_cols + ['movie']].drop_duplicates()
inconsistent_movies = movie_ratings.loc[movie_ratings['movie'].duplicated(), 'movie'].unique()
assert len(inconsistent_movies) == 0, (
    f'{len(inconsistent_movies)} movie(s) have more than one rating combination, '
    f'e.g. {list(inconsistent_movies)[:10]}'
)
ratings = movie_ratings.drop(columns=['movie']).values

In [12]:
# Load the two pickled representation lists produced for `pooling_model_name`:
# `rep_list` and its dialogue-only counterpart `diag_only_rep_list`.
# NOTE(review): pickle.load can execute arbitrary code — only load files this project wrote.
rep_list_path = os.path.join(md.all_txt_sem_rep_dir, f'rep_list_{pooling_model_name}.pkl')
with open(rep_list_path, 'rb') as pkl_file:
    rep_list = pickle.load(pkl_file)

diag_rep_list_path = os.path.join(md.all_txt_sem_rep_dir, f'dialogue_only_rep_list_{pooling_model_name}.pkl')
with open(diag_rep_list_path, 'rb') as pkl_file:
    diag_only_rep_list = pickle.load(pkl_file)

In [17]:
# Linear probe on the dialogue-only representations: MultiOutputClassifier fits one
# LogisticRegression per rating category in `cat_cols`.
# detach()/cpu() make .numpy() safe even if the stored tensors live on the GPU or track grads.
X = torch.stack(diag_only_rep_list).detach().cpu().numpy()
y = np.asarray(ratings)
# NOTE(review): X and y are paired purely by position, so the representation list must be in
# the same (sorted-by-movie) order as `ratings` — confirm against the extraction code.
assert len(X) == len(y), f'{len(X)} representations vs {len(y)} rating rows'

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

model = MultiOutputClassifier(LogisticRegression(max_iter=1000))
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

# Label-wise accuracy pools every (movie, category) prediction. accuracy_score does not accept
# multiclass-multioutput arrays directly, hence the flattening. Exact-match accuracy is the
# stricter view: all categories of a movie must be right at once.
labelwise_accuracy = accuracy_score(y_test.reshape(-1), y_pred.reshape(-1))
exact_match_accuracy = np.mean(np.all(y_test == y_pred, axis=1))
print(f'Overall accuracy (label-wise): {labelwise_accuracy * 100:.2f}%')
print(f'Exact-match accuracy (all categories correct): {exact_match_accuracy * 100:.2f}%')

for col_idx, category in enumerate(cat_cols):
    print(f"\n=== {category.upper()} ===")
    print(classification_report(y_test[:, col_idx], y_pred[:, col_idx], zero_division=0))

Overall accuracy: 72.61904761904762%

=== THEMES ===
              precision    recall  f1-score   support

           0       0.60      0.60      0.60         5
           1       0.70      0.64      0.67        11
           2       0.50      0.60      0.55         5

    accuracy                           0.62        21
   macro avg       0.60      0.61      0.60        21
weighted avg       0.63      0.62      0.62        21


=== VIOLENCE ===
              precision    recall  f1-score   support

           0       0.62      1.00      0.77         5
           1       0.88      0.88      0.88         8
           2       1.00      0.62      0.77         8

    accuracy                           0.81        21
   macro avg       0.83      0.83      0.80        21
weighted avg       0.86      0.81      0.81        21


=== DRUG_USE ===
              precision    recall  f1-score   support

           0       0.83      1.00      0.91        15
           1       0.00      0.00      0

In [18]:
# Linear probe on `rep_list` (the counterpart of the dialogue-only representations):
# MultiOutputClassifier fits one LogisticRegression per rating category in `cat_cols`.
# detach()/cpu() make .numpy() safe even if the stored tensors live on the GPU or track grads.
X = torch.stack(rep_list).detach().cpu().numpy()
y = np.asarray(ratings)
# NOTE(review): X and y are paired purely by position, so the representation list must be in
# the same (sorted-by-movie) order as `ratings` — confirm against the extraction code.
assert len(X) == len(y), f'{len(X)} representations vs {len(y)} rating rows'

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

model = MultiOutputClassifier(LogisticRegression(max_iter=1000))
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

# Label-wise accuracy pools every (movie, category) prediction. accuracy_score does not accept
# multiclass-multioutput arrays directly, hence the flattening. Exact-match accuracy is the
# stricter view: all categories of a movie must be right at once.
labelwise_accuracy = accuracy_score(y_test.reshape(-1), y_pred.reshape(-1))
exact_match_accuracy = np.mean(np.all(y_test == y_pred, axis=1))
print(f'Overall accuracy (label-wise): {labelwise_accuracy * 100:.2f}%')
print(f'Exact-match accuracy (all categories correct): {exact_match_accuracy * 100:.2f}%')

for col_idx, category in enumerate(cat_cols):
    print(f"\n=== {category.upper()} ===")
    print(classification_report(y_test[:, col_idx], y_pred[:, col_idx], zero_division=0))

Overall accuracy: 72.61904761904762%

=== THEMES ===
              precision    recall  f1-score   support

           0       0.60      0.60      0.60         5
           1       0.73      0.73      0.73        11
           2       0.80      0.80      0.80         5

    accuracy                           0.71        21
   macro avg       0.71      0.71      0.71        21
weighted avg       0.71      0.71      0.71        21


=== VIOLENCE ===
              precision    recall  f1-score   support

           0       0.50      0.60      0.55         5
           1       0.70      0.88      0.78         8
           2       0.80      0.50      0.62         8

    accuracy                           0.67        21
   macro avg       0.67      0.66      0.65        21
weighted avg       0.69      0.67      0.66        21


=== DRUG_USE ===
              precision    recall  f1-score   support

           0       0.93      0.87      0.90        15
           1       0.20      0.50      0