# 1. Load the Model and Tokenizer

In [2]:
from bertviz import head_view, model_view
# Load pre-trained model and tokenizer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

# Hub checkpoint identifier; the same name is used to fetch both the
# tokenizer and the sequence-to-sequence model weights.
model_name = "sshleifer/distilbart-cnn-12-6"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)



# 2. Get Model Predictions

In [4]:
def bart_predict(inputs):
    """Summarize one or more texts with the module-level model and tokenizer.

    Args:
        inputs: A single string or a list of strings to summarize.

    Returns:
        A list of decoded summaries (special tokens stripped), one per
        generated sequence.
    """
    # Tokenize into a padded, truncated PyTorch batch. A separate local name
    # is used so the caller-supplied ``inputs`` is not shadowed.
    encoded = tokenizer(inputs, return_tensors='pt', padding=True, truncation=True)
    # Keep the batch on the same device as the model's weights.
    encoded = encoded.to(model.device)

    # Generate without tracking gradients. The attention mask is forwarded
    # together with the token ids so that padding added to the shorter texts
    # of a batch is not attended to.
    with torch.no_grad():
        outputs = model.generate(
            encoded['input_ids'],
            attention_mask=encoded.get('attention_mask'),
        )

    # Decode the generated token ids back into plain-text summaries
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)

# 3. Wrap SHAP model predictions

In [5]:
def shap_model(inputs):
    """Prediction wrapper intended to be handed to SHAP.

    Args:
        inputs: A single string, or an iterable of strings (list, tuple,
            NumPy array, ...).

    Returns:
        The list of summaries produced by ``bart_predict``.
    """
    # SHAP maskers typically pass the perturbed texts as a NumPy array of
    # strings, which Hugging Face tokenizers do not accept as text input.
    # Convert any non-string iterable to a plain list; a lone string is
    # passed through untouched so it is not split into characters.
    texts = inputs if isinstance(inputs, str) else list(inputs)
    return bart_predict(texts)

# 4. Apply SHAP for Token-Level Explanations

In [None]:
import shap
import numpy as np

# Two short sample sentences used as inputs for the SHAP explanation
example_texts = [
    "The quick brown fox jumps over the lazy dog.",
    "A journey of a thousand miles begins with a single step.",
]

# Helper that converts raw texts into a padded batch of token ids
def tokenizer_for_shap(texts):
    """Tokenize ``texts`` with padding and truncation and return only the ``input_ids`` tensor."""
    batch = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
    return batch['input_ids']

# Define a SHAP explainer. The second argument is the masker: passing the
# Hugging Face tokenizer itself lets SHAP build a text masker from it, so
# attributions are computed per input token. A plain function that only
# returns input ids does not satisfy the masker interface.
explainer = shap.Explainer(model, tokenizer)

# Compute SHAP values for the example texts
shap_values = explainer(example_texts)

# Visualize the token-level attributions. The resulting Explanation object
# already carries the token strings, so no tokenizer is passed here.
shap.plots.text(shap_values)

# 5. Visualize Attention for a Specific Row in Your DataFrame

In [None]:
import pandas as pd

# Read the CSV of training-set predictions (absolute path on the author's machine)
df = pd.read_csv(
    '/Users/rohitrawat/job-prep/Assignments/accrete-ai/text-summarization/data/processed/news_summary_cleaned_train_predictions.csv'
)

# Visualize attention for the row at position 1 (i.e. the second row):
# its source text paired with its generated summary.
visualize_bart_attention(
    df['text'].iloc[1],
    df['generated_summary'].iloc[1],
)
