# Task-3

Here, we want to understand why the model judges a text to be AI-generated.

We are required to use either [SHAP](https://shap.readthedocs.io/en/latest/) or [Captum](https://captum.ai/) for this task. I have decided to use SHAP because:
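1. It seems to have nicer visualisations.
2. Captum seems to be primarily for PyTorch, whereas SHAP can explain any prediction function (here, a wrapper around the model's softmax output).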
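For intuition about what SHAP computes: it attributes a prediction to input features using Shapley values. Each feature's score is its marginal contribution averaged over all coalitions of the other features. The sketch below computes exact Shapley values by brute force for a made-up three-token scoring function. `toy_score` and `exact_shapley` are illustrative names, not part of the SHAP API. Exact enumeration grows exponentially with the number of tokens. The `PartitionExplainer` that `shap.Explainer` chose for the text model in the run below avoids this by working over a hierarchy of token groups.

```python
from itertools import combinations
from math import factorial


def exact_shapley(value_fn, n_features):
    """Exact Shapley values by enumerating every coalition (only feasible for tiny n)."""
    phi = [0.0] * n_features
    for i in range(n_features):
        others = [j for j in range(n_features) if j != i]
        for size in range(len(others) + 1):
            for coalition in combinations(others, size):
                # Shapley weight |S|! (n - |S| - 1)! / n!
                weight = factorial(size) * factorial(n_features - size - 1) / factorial(n_features)
                gain = value_fn(set(coalition) | {i}) - value_fn(set(coalition))
                phi[i] += weight * gain
    return phi


# Hypothetical "AI score" over three tokens: tokens 0 and 1 add fixed amounts,
# plus an interaction bonus when both are visible; token 2 has no effect.
def toy_score(visible):
    score = 0.1
    score += 0.3 if 0 in visible else 0.0
    score += 0.2 if 1 in visible else 0.0
    score += 0.1 if {0, 1} <= visible else 0.0
    return score


phi = exact_shapley(toy_score, 3)
base = toy_score(set())
full = toy_score({0, 1, 2})
print([round(p, 3) for p in phi])  # [0.35, 0.25, 0.0]
print(round(base + sum(phi), 3), round(full, 3))  # 0.7 0.7
```

The 0.1 interaction bonus is split evenly between tokens 0 and 1. The last line checks the local-accuracy property: the base value plus all attributions recovers the full prediction. This is the sum that a SHAP waterfall plot draws.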

> highlight the words in an **"Imposter" paragraph** that most strongly signaled "AI" to your Tier C model.
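With a three-class model, "AI" covers two outputs: class 2 (generic AI) and class 3 (mimicry). Shapley values are linear in the explained output. The attribution for P(AI) = P(class 2) + P(class 3) is therefore the sum of the two per-class attributions. The sketch below assumes a hypothetical `(n_tokens, 3)` array of SHAP values, as could be taken from `shap_values[i].values` for one sample, and ranks tokens by that summed "AI" signal. The softmax probabilities always sum to 1, so each token's values across the three classes should sum to roughly zero. This means the AI signal should be approximately the negated class-1 (human) attribution.

```python
import numpy as np


def top_ai_tokens(tokens, values, ai_classes=(1, 2), k=5):
    """Rank tokens by their contribution to P(AI), the sum of the AI class probabilities.

    tokens: list of n token strings
    values: array-like of shape (n, n_classes) with per-class SHAP values
    """
    values = np.asarray(values)
    # Shapley values are additive across outputs, so the attribution for the
    # summed AI probability is the sum of the per-class attributions.
    ai_signal = values[:, list(ai_classes)].sum(axis=1)
    order = np.argsort(ai_signal)[::-1][:k]
    return [(tokens[j], float(ai_signal[j])) for j in order]


# Made-up SHAP values for four tokens (columns: human, generic AI, mimic AI);
# each row sums to zero, as expected for softmax outputs.
tokens = ["Indeed", ",", "delve", "Watson"]
values = [
    [-0.10, 0.05, 0.05],
    [-0.01, 0.01, 0.00],
    [-0.25, 0.15, 0.10],
    [0.05, -0.03, -0.02],
]
print(top_ai_tokens(tokens, values, k=2))
```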

This is a bit of a problem: I do not have the three such paragraphs that the task recommends. I have only one, a class-1 (human-written) paragraph that the model mistook for class 3.

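Instead, I decided to use SHAP to analyse the samples classified with the least confidence. These are the three samples per class with the lowest predicted probability for their true class, nine in total.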
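The selection criterion used in the scan below is the probability the model assigns to each paragraph's *true* class (`probs[label]`), not its highest class probability. The following NumPy sketch shows that bottom-k selection on made-up probabilities rather than real model outputs. The function name `lowest_confidence` is hypothetical.

```python
import numpy as np


def lowest_confidence(probs, labels, k=3):
    """Per true class, return indices of the k samples with the lowest
    probability on their own (true) class, least confident first."""
    probs = np.asarray(probs)
    labels = np.asarray(labels)
    # Pick out each row's probability for its true label
    true_conf = probs[np.arange(len(labels)), labels]
    lowest = {}
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)
        lowest[int(cls)] = idx[np.argsort(true_conf[idx])[:k]].tolist()
    return lowest


probs = [
    [0.02, 0.10, 0.88],
    [0.90, 0.05, 0.05],
    [0.40, 0.35, 0.25],
    [0.10, 0.81, 0.09],
    [0.05, 0.06, 0.89],
]
labels = [0, 0, 0, 1, 2]
print(lowest_confidence(probs, labels, k=2))  # {0: [0, 2], 1: [3], 2: [4]}
```

A low-confidence sample is not necessarily misclassified. In the results below, even the least confident generic-AI paragraph still receives 80.86% on its true class.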

In [5]:
import pandas as pd
import torch
import glob
import os
import pickle
from pathlib import Path
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import PeftModel, PeftConfig
from tqdm import tqdm

MODEL_PATH = "../task-2/transformer/tier_c_final_model"
DATASET_DIR = Path('../dataset')
OUTPUT_DIR = "low_confidence_analysis"

os.makedirs(OUTPUT_DIR, exist_ok=True)

print("Loading model for inference...")
config = PeftConfig.from_pretrained(MODEL_PATH)
base_model = AutoModelForSequenceClassification.from_pretrained(
    config.base_model_name_or_path, num_labels=3
)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(base_model, MODEL_PATH)
model.eval()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

def load_and_scan():
    results = {0: [], 1: [], 2: []}
    
    paths = {
        0: (DATASET_DIR / 'class1-human-written', ['01-arthur-conan-doyle', '02-pg-wodehouse', '03-mark-twain', '04-william-shakespeare'], 'extracted_paragraphs'),
        1: (DATASET_DIR / 'class2-ai-written', ['ai-generated-paragraphs'], ''), 
        2: (DATASET_DIR / 'class3-ai-mimicry', ['01-arthur-conan-doyle', '02-pg-wodehouse', '03-mark-twain', '04-william-shakespeare'], '')
    }

    print("\nScanning dataset for low-confidence samples...")
    
    for label, (base_path, subfolders, suffix) in paths.items():
        files = []
        for sub in subfolders:
            # Only the human-written class has an extra 'extracted_paragraphs' subfolder
            search_path = base_path / sub / suffix if suffix else base_path / sub
            
            files.extend(glob.glob(os.path.join(str(search_path), '*.txt')))

        print(f"Scanning Class {label} ({len(files)} files)...")
        
        for file_path in tqdm(files):
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    text = f.read().strip()
                
                if not text: continue

                inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
                with torch.no_grad():
                    logits = model(**inputs).logits
                    probs = torch.softmax(logits, dim=1)[0]
                
                confidence = probs[label].item()
                
                results[label].append({
                    'confidence': confidence,
                    'text': text,
                    'file': os.path.basename(file_path),
                    'probs': probs.cpu().numpy()
                })
                
            except Exception as e:
                # Report unreadable files instead of silently dropping them
                tqdm.write(f"  Skipping {file_path}: {e}")

    return results

data_map = load_and_scan()
lowest_samples = {}

print("\n--- RESULTS: LOWEST CONFIDENCE SAMPLES ---")

file_names = {
    0: "low-confidence-class-1-human.txt",
    1: "low-confidence-class-2-generic.txt",
    2: "low-confidence-class-3-mimic.txt"
}

for label in [0, 1, 2]:
    sorted_samples = sorted(data_map[label], key=lambda x: x['confidence'])
    
    bottom_3 = sorted_samples[:3]
    lowest_samples[label] = bottom_3
    
    out_file = os.path.join(OUTPUT_DIR, file_names[label])
    with open(out_file, 'w', encoding='utf-8') as f:
        print(f"\n[Class {label}] Lowest Confidence:")
        for i, sample in enumerate(bottom_3):
            header = f"Sample {i+1} | Conf: {sample['confidence']:.2%} | File: {sample['file']}"
            print(f"  {header}")
            f.write(f"{header}\n")
            f.write(f"{sample['text']}\n")
            f.write("-" * 50 + "\n")

# Save lowest_samples for later SHAP analysis
samples_file = os.path.join(OUTPUT_DIR, "lowest_samples.pkl")
with open(samples_file, 'wb') as f:
    pickle.dump(lowest_samples, f)

print(f"\nScanning complete! Results saved to '{OUTPUT_DIR}'")
print(f"Lowest confidence samples saved to '{samples_file}'")

Loading model for inference...


Loading weights: 100%|██████████| 100/100 [00:00<00:00, 1100.78it/s, Materializing param=distilbert.transformer.layer.5.sa_layer_norm.weight]   
DistilBertForSequenceClassification LOAD REPORT from: distilbert-base-uncased
Key                     | Status     | 
------------------------+------------+-
vocab_projector.bias    | UNEXPECTED | 
vocab_transform.weight  | UNEXPECTED | 
vocab_transform.bias    | UNEXPECTED | 
vocab_layer_norm.bias   | UNEXPECTED | 
vocab_layer_norm.weight | UNEXPECTED | 
pre_classifier.bias     | MISSING    | 
classifier.weight       | MISSING    | 
pre_classifier.weight   | MISSING    | 
classifier.bias         | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.



Scanning dataset for low-confidence samples...
Scanning Class 0 (1960 files)...


100%|██████████| 1960/1960 [02:04<00:00, 15.71it/s]


Scanning Class 1 (988 files)...


100%|██████████| 988/988 [01:05<00:00, 15.13it/s]


Scanning Class 2 (973 files)...


100%|██████████| 973/973 [01:15<00:00, 12.84it/s]


--- RESULTS: LOWEST CONFIDENCE SAMPLES ---

[Class 0] Lowest Confidence:
  Sample 1 | Conf: 1.93% | File: 099_Sketches-New-and-Old,-Part-7._12.txt
  Sample 2 | Conf: 46.85% | File: 013_Psmith-in-the-City_11.txt
  Sample 3 | Conf: 61.68% | File: 043_Pericles-Prince-of-Tyre_14.txt

[Class 1] Lowest Confidence:
  Sample 1 | Conf: 80.86% | File: HLB_2_Bruce_Partington_Plans_T01_P01.txt
  Sample 2 | Conf: 96.01% | File: The-Life-of-Henry-the-Eighth_T03_P01.txt
  Sample 3 | Conf: 97.44% | File: The-Innocents-Abroad-—-Volume-04_T03_P02.txt

[Class 2] Lowest Confidence:
  Sample 1 | Conf: 5.75% | File: TWAIN_Roughing-It,-Part-5._T04_01.txt
  Sample 2 | Conf: 12.54% | File: WODE_Love-Among-the-Chickens_T01_01.txt
  Sample 3 | Conf: 32.15% | File: TWAIN_Mark-Twain's-Letters-—-Complete_T03_02.txt

Scanning complete! Results saved to 'low_confidence_analysis'
Lowest confidence samples saved to 'low_confidence_analysis/lowest_samples.pkl'
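The summary above prints only each sample's true-class confidence. With three classes, a true-class probability below 1/3 guarantees a misclassification. The other two classes then share more than 2/3, so one of them must exceed 1/3. The class-3 (mimic) samples listed under `[Class 2]` at 5.75%, 12.54% and 32.15% were therefore all assigned to some other class. Each pickled sample also stores its full probability vector under `'probs'`. A small helper (hypothetical, not part of the original notebook) can use it to report which class was predicted instead. It assumes the same index order as `class_name` in the SHAP cell.

```python
import numpy as np

# Assumed index order, matching the notebook: 0 = human, 1 = generic AI, 2 = mimic AI
CLASS_NAMES = ["Human", "Generic_AI", "Mimic_AI"]


def summarise_predictions(lowest_samples):
    """Return (file, true class, predicted class, true-class confidence) per stored sample."""
    rows = []
    for label, samples in sorted(lowest_samples.items()):
        for sample in samples:
            predicted = int(np.argmax(sample["probs"]))
            rows.append((sample["file"], CLASS_NAMES[label], CLASS_NAMES[predicted], sample["confidence"]))
    return rows


# Usage against the pickle written above:
# import pickle
# with open("low_confidence_analysis/lowest_samples.pkl", "rb") as f:
#     for row in summarise_predictions(pickle.load(f)):
#         print(row)

# Self-contained check with a made-up sample in the same format
demo = {0: [{"file": "demo.txt", "confidence": 0.02, "probs": np.array([0.02, 0.10, 0.88])}]}
print(summarise_predictions(demo))  # [('demo.txt', 'Human', 'Mimic_AI', 0.02)]
```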





In [1]:
import shap
import numpy as np
import torch
import pickle
import os
import matplotlib.pyplot as plt
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import PeftModel, PeftConfig

# 1. CONFIGURATION
MODEL_PATH = "../task-2/transformer/tier_c_final_model"
OUTPUT_DIR = "low_confidence_analysis"

# Make sure output directory exists for the images
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Load the saved lowest confidence samples
samples_file = os.path.join(OUTPUT_DIR, "lowest_samples.pkl")
with open(samples_file, 'rb') as f:
    lowest_samples = pickle.load(f) # Expected format: {0: [...], 1: [...], 2: [...]}

# 2. MODEL LOADING
print("Loading model for SHAP analysis...")
config = PeftConfig.from_pretrained(MODEL_PATH)
base_model = AutoModelForSequenceClassification.from_pretrained(
    config.base_model_name_or_path, num_labels=3
)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(base_model, MODEL_PATH)
model.eval()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# 3. PREDICTION FUNCTION
def predict_shap(texts):
    # Handle single string vs list
    if isinstance(texts, str):
        texts = [texts]
    elif isinstance(texts, (list, np.ndarray)):
        texts = [str(t) for t in texts]
    
    # Tokenize
    inputs = tokenizer(
        texts, 
        return_tensors="pt", 
        padding=True, 
        truncation=True, 
        max_length=512
    ).to(device)
    
    # Predict
    with torch.no_grad():
        logits = model(**inputs).logits
        scores = torch.softmax(logits, dim=1)
        
    return scores.cpu().numpy()

# 4. HELPER FUNCTION TO GET ACTUAL TOKENS
def get_token_text(text, token_id_or_index):
    """Convert token ID or index to actual text representation"""
    try:
        # If it's an integer, decode it directly
        if isinstance(token_id_or_index, (int, np.integer)):
            return tokenizer.decode([token_id_or_index])
        # If it's already a string token, return as is
        return str(token_id_or_index)
    except Exception:
        return f"[UNK_{token_id_or_index}]"

# 5. RUN SHAP
print("\nRunning SHAP Analysis on the 9 lowest confidence samples...")
explainer = shap.Explainer(predict_shap, tokenizer)

texts_to_explain = []
target_classes = [] # We need to know which class to visualize for each sample
sample_names = []

# Flatten the dictionary into lists
for label in [0, 1, 2]:
    class_name = ["Human", "Generic_AI", "Mimic_AI"][label]
    for idx, sample in enumerate(lowest_samples[label]):
        texts_to_explain.append(sample['text'])
        # We want to visualize the TRUE class to see why confidence was low
        target_classes.append(label) 
        sample_names.append(f"Class_{label+1}_{class_name}_Sample_{idx+1}")

# Calculate SHAP values (slow: about 4 minutes for these 9 samples in the run below)
shap_values = explainer(texts_to_explain)

# 6. GENERATE VISUALIZATIONS
print("GENERATING 3 VISUALIZATIONS PER SAMPLE")

for i in range(len(texts_to_explain)):
    name = sample_names[i]
    target_cls = target_classes[i]
    print(f"Processing {name}...")
    
    # Get the actual text for this sample to properly map tokens
    current_text = texts_to_explain[i]
    
    # Tokenize to get the actual token IDs and convert them to text
    encoded = tokenizer(current_text, truncation=True, max_length=512)
    token_ids = encoded['input_ids']
    actual_tokens = [tokenizer.decode([tid]) for tid in token_ids]

    # VISUALIZATION 1: WATERFALL PLOT (The Logic)
    # Shows the Top ~12 words that pushed the score up or down
    try:
        plt.figure(figsize=(10, 6))
        shap.plots.waterfall(
            shap_values[i, :, target_cls], 
            max_display=12, 
            show=False
        )
        plt.title(f"{name}: Why class {target_cls+1}?", fontsize=14)
        plt.tight_layout()
        plt.savefig(os.path.join(OUTPUT_DIR, f"{name}_waterfall.png"))
        plt.close()
    except Exception as e:
        print(f"  Warning: Waterfall plot failed for {name}: {e}")

    # VISUALIZATION 2: BAR CHART (The Summary)
    # Shows the tokens with the largest |SHAP| for this one sample (local, not global, importance)
    try:
        # Get SHAP values for this sample and target class
        shap_vals_single = shap_values[i, :, target_cls].values
        
        # Get feature names - these should be the actual tokens
        if hasattr(shap_values[i, :, target_cls], 'data') and shap_values[i, :, target_cls].data is not None:
            feature_names = shap_values[i, :, target_cls].data
        else:
            # Fallback to using actual_tokens we computed
            feature_names = actual_tokens[:len(shap_vals_single)]
        
        # Ensure we have the right number of features
        if len(feature_names) != len(shap_vals_single):
            feature_names = actual_tokens[:len(shap_vals_single)]
        
        # Get top features by absolute value
        abs_vals = np.abs(shap_vals_single)
        top_indices = np.argsort(abs_vals)[-12:][::-1]
        
        top_vals = shap_vals_single[top_indices]
        top_features = []
        for idx in top_indices:
            if idx < len(feature_names):
                token = str(feature_names[idx]).strip()
                # Clean up token representation
                if token.startswith('##'):
                    token = token[2:]  # Remove BERT subword prefix
                if not token or token == '':
                    token = '[SPACE]'
                top_features.append(token)
            else:
                top_features.append('[UNK]')
        
        # Create the bar plot
        colors = ['#ff0051' if v > 0 else '#008bfb' for v in top_vals]
        plt.figure(figsize=(10, 6))
        plt.barh(range(len(top_vals)), top_vals, color=colors)
        plt.yticks(range(len(top_vals)), top_features)
        plt.xlabel('SHAP value (impact on model output)')
        plt.title(f"{name}: Top Features (Words)", fontsize=14)
        plt.gca().invert_yaxis()
        plt.tight_layout()
        plt.savefig(os.path.join(OUTPUT_DIR, f"{name}_bar.png"))
        plt.close()
        
    except Exception as e:
        print(f"  Error creating bar plot for {name}: {e}")

    # VISUALIZATION 3: TEXT HEATMAP (The Context)
    # Shows the text with Red/Blue highlights
    try:
        # Get SHAP values for this sample
        shap_vals_single = shap_values[i, :, target_cls].values
        
        # Use actual_tokens we computed earlier
        tokens_for_viz = actual_tokens[:len(shap_vals_single)]
        
        # Ensure alignment
        if len(tokens_for_viz) != len(shap_vals_single):
            print(f"  Warning: Token count mismatch for {name}")
            tokens_for_viz = tokens_for_viz[:len(shap_vals_single)]
        
        max_abs = float(np.abs(shap_vals_single).max()) if len(shap_vals_single) > 0 else 1.0
        
        html = f"""<!DOCTYPE html>
<html>
<head>
    <meta charset="UTF-8">
    <title>{name}</title>
    <style>
        body {{
            margin: 20px;
            font-family: Arial, sans-serif;
            max-width: 1200px;
        }}
        h1 {{
            color: #333;
            border-bottom: 2px solid #ccc;
            padding-bottom: 10px;
        }}
        .legend {{
            margin: 20px 0;
            padding: 10px;
            background: #f5f5f5;
            border-radius: 5px;
        }}
        .text-container {{
            font-family: 'Courier New', monospace;
            line-height: 2;
            font-size: 14px;
            padding: 20px;
            background: white;
            border: 1px solid #ddd;
            border-radius: 5px;
        }}
        .token {{
            padding: 2px 4px;
            margin: 1px;
            border-radius: 3px;
            display: inline-block;
        }}
    </style>
</head>
<body>
    <h1>{name.replace('_', ' ')}</h1>
    <div class="legend">
        <strong>Legend:</strong> 
        <span style="background-color: rgba(255, 0, 81, 0.5); padding: 2px 8px; border-radius: 3px;">Red = Increases prediction</span>
        <span style="background-color: rgba(0, 139, 251, 0.5); padding: 2px 8px; border-radius: 3px; margin-left: 10px;">Blue = Decreases prediction</span>
    </div>
    <div class="text-container">"""
        
        # Escape <, >, & and quotes so token text cannot break the HTML markup
        from html import escape

        for token, val in zip(tokens_for_viz, shap_vals_single):
            # Clean up token
            token_text = str(token).strip()
            if token_text.startswith('##'):
                token_text = token_text[2:]  # Remove BERT subword prefix
                space = ''  # No space before subword
            elif token_text in ['[CLS]', '[SEP]', '[PAD]']:
                continue  # Skip special tokens
            else:
                space = ' '  # Add space before regular tokens
            
            if not token_text:
                token_text = '·'  # Use middle dot for empty tokens
            
            # Color based on SHAP value
            intensity = min(abs(val) / max_abs, 1.0) * 0.7 if max_abs > 0 else 0
            
            if val > 0:
                color = f"rgba(255, 0, 81, {intensity})"
            else:
                color = f"rgba(0, 139, 251, {intensity})"
            
            html += f"{space}<span class='token' style='background-color: {color};' title='SHAP: {val:.4f}'>{escape(token_text)}</span>"
        
        html += """
    </div>
</body>
</html>"""
        
        with open(os.path.join(OUTPUT_DIR, f"{name}_heatmap.html"), "w", encoding='utf-8') as f:
            f.write(html)
            
    except Exception as e:
        print(f"  Error creating heatmap for {name}: {e}")

print("\nANALYSIS COMPLETE!")
print(f"Check the '{OUTPUT_DIR}' folder for your .png and .html files.")
print("\nThe visualizations now show:")
print("  • Bar charts: Actual words/tokens ranked by importance")
print("  • Heatmaps: Full text with color-coded word importance")
print("  • Red highlights = words that INCREASE the prediction for that class")
print("  • Blue highlights = words that DECREASE the prediction for that class")




Loading model for SHAP analysis...


Loading weights: 100%|██████████| 100/100 [00:00<00:00, 1048.36it/s, Materializing param=distilbert.transformer.layer.5.sa_layer_norm.weight]   
DistilBertForSequenceClassification LOAD REPORT from: distilbert-base-uncased
Key                     | Status     | 
------------------------+------------+-
vocab_layer_norm.weight | UNEXPECTED | 
vocab_layer_norm.bias   | UNEXPECTED | 
vocab_transform.weight  | UNEXPECTED | 
vocab_projector.bias    | UNEXPECTED | 
vocab_transform.bias    | UNEXPECTED | 
pre_classifier.weight   | MISSING    | 
classifier.bias         | MISSING    | 
classifier.weight       | MISSING    | 
pre_classifier.bias     | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.



Running SHAP Analysis on the 9 lowest confidence samples...


PartitionExplainer explainer: 10it [03:52, 25.88s/it]                      


GENERATING 3 VISUALIZATIONS PER SAMPLE
Processing Class_1_Human_Sample_1...
Processing Class_1_Human_Sample_2...
Processing Class_1_Human_Sample_3...
Processing Class_2_Generic_AI_Sample_1...
Processing Class_2_Generic_AI_Sample_2...
Processing Class_2_Generic_AI_Sample_3...
Processing Class_3_Mimic_AI_Sample_1...
Processing Class_3_Mimic_AI_Sample_2...
Processing Class_3_Mimic_AI_Sample_3...

ANALYSIS COMPLETE!
Check the 'low_confidence_analysis' folder for your .png and .html files.

The visualizations now show:
  • Bar charts: Actual words/tokens ranked by importance
  • Heatmaps: Full text with color-coded word importance
  • Red highlights = words that INCREASE the prediction for that class
  • Blue highlights = words that DECREASE the prediction for that class
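The red/blue heatmap rule can be isolated as a small pure function and checked on its own. The sketch below uses the same mapping as the heatmap cell: opacity scales with |SHAP| relative to the largest |SHAP| in the sample, capped at 0.7; red marks positive values and blue marks zero or negative values. It also escapes token text with the standard library's `html.escape`, so characters such as `<` or `&` in a paragraph cannot break the markup. `token_span` and `render_tokens` are hypothetical names.

```python
from html import escape


def token_span(token, value, max_abs, max_alpha=0.7):
    """Render one token as an HTML <span> coloured by its SHAP value
    (red = raises the target class probability, blue = lowers it)."""
    alpha = min(abs(value) / max_abs, 1.0) * max_alpha if max_abs > 0 else 0.0
    rgb = "255, 0, 81" if value > 0 else "0, 139, 251"
    return (
        f"<span style='background-color: rgba({rgb}, {alpha:.2f});' "
        f"title='SHAP: {value:.4f}'>{escape(token)}</span>"
    )


def render_tokens(tokens, values):
    """Join coloured spans for a whole token sequence."""
    max_abs = max((abs(v) for v in values), default=0.0)
    return " ".join(token_span(t, v, max_abs) for t, v in zip(tokens, values))


print(render_tokens(["Elementary", "<my>", "dear"], [0.2, -0.1, 0.0]))
```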
