In [1]:
import logging
# Configure the root logger once for the whole notebook: INFO and above,
# each record rendered as "<level> - <timestamp> - <message>".
logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(asctime)s - %(message)s')

import pandas as pd
# Switch off pandas' chained-assignment check, so no SettingWithCopyWarning
# is raised or printed for the rest of the session.
pd.options.mode.chained_assignment = None

import numpy as np
import matplotlib.pyplot as plt

import torch
# Prefer the GPU, but fall back to the CPU so that the notebook still runs
# top to bottom on a machine without a usable CUDA device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from sklearn.metrics import RocCurveDisplay, accuracy_score, ConfusionMatrixDisplay, confusion_matrix, classification_report
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier
from sklearn.model_selection import train_test_split

from transformers import (
    AutoTokenizer,
    AutoModel,
    Trainer,
)

from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

import os
import gc
import pickle

from IPython.display import clear_output
# Wipe whatever this cell has displayed so far (presumably messages emitted
# while the libraries above were being imported) to keep the saved output short.
clear_output()

import utils
import data_extraction as da

import modelling as md

# TODO: create graph showing experiments with pooling strategies

# Position of the model under study inside md.pooling_models. Change it here
# to rerun the notebook against a different pooling model.
POOLING_MODEL_INDEX = 6

# '/' is replaced with '_' so that the model identifier can be embedded in
# the pickle file names that are built from pooling_model_name.
pooling_model_name = md.pooling_models[POOLING_MODEL_INDEX].replace('/', '_')
print(pooling_model_name)

mrm8488_t5-base-finetuned-imdb-sentiment


[nltk_data] Downloading package stopwords to
[nltk_data]     /home/eye4got/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [2]:
cat_cols = ['themes', 'violence', 'drug_use', 'sex']

# Read the cleaned dataset, discard the 'nudity' and 'language' columns, and
# order the rows by movie and then by start time.
df = (
    pd.read_parquet(da.cleaned_dataset_fp)
    .drop(columns=['nudity', 'language'])
    .sort_values(['movie', 'start_time'])
)

# Re-encode every modelled category column with the project's ordinal converter.
for col in cat_cols:
    df[col] = md.convert_col_to_ordinal(df[col])

# Distinct (movie, ratings) combinations, in the sorted row order established
# above, reduced to the category columns only and exposed as a 2-D array.
# NOTE(review): presumably each movie carries exactly one rating combination,
# which would make this one row per movie - confirm against the dataset.
movie_ratings = df[['movie'] + cat_cols].drop_duplicates()
ratings = movie_ratings[cat_cols].values

In [3]:
def _load_rep_list(pickle_prefix):
    """Unpickle the representations stored for the selected pooling model.

    The file is looked up in ``md.all_txt_sem_rep_dir`` under the name
    ``<pickle_prefix><pooling_model_name>.pkl``.

    NOTE: ``pickle.load`` can execute arbitrary code contained in the file, so
    only point this at pickles produced by this project.
    """
    pickle_path = os.path.join(md.all_txt_sem_rep_dir, f'{pickle_prefix}{pooling_model_name}.pkl')
    with open(pickle_path, 'rb') as handle:
        return pickle.load(handle)


# Representations built from the full text and from the dialogue only.
rep_list = _load_rep_list(md.all_txt_pickle_prefix)
diag_only_rep_list = _load_rep_list(md.dialogue_only_pickle_prefix)

In [4]:
# TODO: perform CV since test set is so small and arbitrary

# Features: the dialogue-only representations stacked into a single array.
# Targets: the ratings matrix, one column per entry of cat_cols.
X = torch.stack(diag_only_rep_list).numpy()
y = np.array(ratings)
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.1,
    random_state=42,
)

# MultiOutputClassifier fits one logistic regression per target column.
model = MultiOutputClassifier(LogisticRegression(max_iter=1000))
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

# Accuracy over all (sample, category) predictions pooled together.
overall_accuracy = accuracy_score(y_test.ravel(), y_pred.ravel())
print(f'Dialogue Only Overall accuracy: {overall_accuracy * 100}%')

# Per-category breakdown; zero_division=0 reports 0 for classes never predicted.
for category, true_labels, predicted_labels in zip(cat_cols, y_test.T, y_pred.T):
    print(f"\n=== {category.upper()} ===")
    print(classification_report(true_labels, predicted_labels, zero_division=0))

Dialogue Only Overall accuracy: 62.5%

=== THEMES ===
              precision    recall  f1-score   support

           0       0.50      0.14      0.22         7
           1       0.48      0.83      0.61        12
           2       0.00      0.00      0.00         5

    accuracy                           0.46        24
   macro avg       0.33      0.33      0.28        24
weighted avg       0.38      0.46      0.37        24


=== VIOLENCE ===
              precision    recall  f1-score   support

           0       0.45      0.62      0.53         8
           1       0.45      0.45      0.45        11
           2       0.50      0.20      0.29         5

    accuracy                           0.46        24
   macro avg       0.47      0.43      0.42        24
weighted avg       0.46      0.46      0.44        24


=== DRUG_USE ===
              precision    recall  f1-score   support

           0       0.87      1.00      0.93        20
           1       1.00      0.25      

In [5]:
# Same train/evaluate procedure, this time on the full-text representations.
X = torch.stack(rep_list).numpy()
y = np.array(ratings)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42, test_size=0.1)

# One logistic regression per target column, wrapped for multi-output use.
base_estimator = LogisticRegression(max_iter=1000)
model = MultiOutputClassifier(base_estimator)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

# Flatten both label matrices so accuracy is taken over every single
# (sample, category) prediction at once.
flat_true = y_test.reshape(-1)
flat_pred = y_pred.reshape(-1)
overall_pct = accuracy_score(flat_true, flat_pred) * 100
print(f'Transcript Overall accuracy: {overall_pct}%')

# Detailed report for each category column in turn.
for ii, category in enumerate(cat_cols):
    header = f"\n=== {category.upper()} ==="
    print(header)
    report = classification_report(y_test[:, ii], y_pred[:, ii], zero_division=0)
    print(report)

Transcript Overall accuracy: 62.5%

=== THEMES ===
              precision    recall  f1-score   support

           0       0.33      0.14      0.20         7
           1       0.45      0.75      0.56        12
           2       0.00      0.00      0.00         5

    accuracy                           0.42        24
   macro avg       0.26      0.30      0.25        24
weighted avg       0.32      0.42      0.34        24


=== VIOLENCE ===
              precision    recall  f1-score   support

           0       0.50      0.38      0.43         8
           1       0.44      0.73      0.55        11
           2       0.00      0.00      0.00         5

    accuracy                           0.46        24
   macro avg       0.31      0.37      0.33        24
weighted avg       0.37      0.46      0.40        24


=== DRUG_USE ===
              precision    recall  f1-score   support

           0       0.83      1.00      0.91        20
           1       0.00      0.00      0.0