In [2]:
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from captum.attr import LayerIntegratedGradients, visualization as viz

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained("shaggysus/MovieGenrePrediction")
model = AutoModelForSequenceClassification.from_pretrained("shaggysus/MovieGenrePrediction")
model.to(device)
model.eval()

genre_labels = {
    0: 'Action',
    1: 'Adventure',
    2: 'Crime',
    3: 'Fantasy',
    4: 'Family',
    5: 'Horror',
    6: 'mystery',
    7: 'Romance',
    8: 'Sci-Fi',
    9: 'Thriller'
}

def predict_genre(input_ids, attention_mask=None):
    output = model(input_ids, attention_mask=attention_mask)
    return output.logits


def xai_and_predict(subtitle):

    inputs = tokenizer(subtitle, return_tensors="pt", truncation=True, padding=True)
    inputs.to(device)


    genre_logits = predict_genre(inputs['input_ids'], attention_mask=inputs['attention_mask'])
    predicted_genre_id = torch.argmax(genre_logits, dim=1).item()


    target_index = torch.argmax(genre_logits, dim=1)
    lig = LayerIntegratedGradients(predict_genre, model.distilbert.embeddings) 
    attributions, delta = lig.attribute(inputs['input_ids'], target=target_index, return_convergence_delta=True)
    attributions = attributions.sum(dim=-1).squeeze(0)

    return genre_labels.get(predicted_genre_id, 'Unknown'), attributions, inputs, delta


subtitle = "As the clock struck midnight, a bloodcurdling scream pierced the silence, echoing through the empty halls of the abandoned mansion."


predicted_genre, attributions, inputs, delta = xai_and_predict(subtitle)


print("Predicted Genre:", predicted_genre)
viz.visualize_text([viz.VisualizationDataRecord(
                        attributions,
                        torch.max(torch.softmax(predict_genre(inputs['input_ids'], attention_mask=inputs['attention_mask']).to(device), dim=1)),
                        torch.argmax(predict_genre(inputs['input_ids'], attention_mask=inputs['attention_mask']).to(device)),
                        torch.argmax(predict_genre(inputs['input_ids'], attention_mask=inputs['attention_mask']).to(device)),
                        str(inputs['input_ids'].tolist()),
                        attributions.sum(),
                        tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze().tolist()),
                        convergence_score=delta
                    )])
print("Convergence Delta:", delta)
print("word importance:", tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze().tolist()),)


Predicted Genre: Horror


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
5.0,5 (0.78),"[[101, 2004, 1996, 5119, 4930, 7090, 1010, 1037, 2668, 10841, 4103, 2989, 6978, 16276, 1996, 4223, 1010, 17142, 2083, 1996, 4064, 9873, 1997, 1996, 4704, 7330, 1012, 102]]",5.61,"[CLS] as the clock struck midnight , a blood ##cu ##rd ##ling scream pierced the silence , echoing through the empty halls of the abandoned mansion . [SEP]"
,,,,


Convergence Delta: tensor([0.0135], dtype=torch.float64)
word importance: ['[CLS]', 'as', 'the', 'clock', 'struck', 'midnight', ',', 'a', 'blood', '##cu', '##rd', '##ling', 'scream', 'pierced', 'the', 'silence', ',', 'echoing', 'through', 'the', 'empty', 'halls', 'of', 'the', 'abandoned', 'mansion', '.', '[SEP]']
