In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
import matplotlib.pyplot as plt
from data_loader import load_fold, get_conditions_by_group

In [2]:
# Directory holding the fine-tuned classifier checkpoint (fold 1, judging by the path).
model_path = "results/bertweet-base_eng_cognitive_attention_20250421_023107/fold_1/final_model"
# NOTE(review): the tokenizer is loaded from the base "vinai/bertweet-base" checkpoint, not from
# model_path. This assumes fine-tuning left the vocabulary unchanged — TODO confirm.
tokenizer = AutoTokenizer.from_pretrained("vinai/bertweet-base")
# Sequence-classification model restored from the local checkpoint directory above.
model = AutoModelForSequenceClassification.from_pretrained(model_path)

In [12]:
# Settings passed to load_fold: which fold, condition group and language to load.
fold = 1
group = "cognitive_attention"
language = "eng"
# load_fold returns a mapping that is indexed by split name below; only "test" is used in this notebook.
dataset_dict = load_fold(fold, language=language, group=group, balance_level="full")
sample = dataset_dict["test"][5]  # hard-coded index: one test example used for the single-tweet demo
# Each example is read via its "tweet" (text) and "class" (label) fields.
text = sample["tweet"]
print("Sample tweet:", text)
print(sample["class"])


Original class distribution in training set (users):
  CONTROL: 1362 users
  ADHD: 498 users
  ASD: 136 users

Full balancing: Reducing all classes to match smallest class ASD (136 users)

Final class distribution in training set (users):
  ADHD: 136 users
  ASD: 136 users
  CONTROL: 136 users
Loading train data...
Loading test data...
Sample tweet: " Im very detailed in all aspect of my life the fact that people notice HTTPURL "
ADHD


In [9]:
from transformers_interpret import SequenceClassificationExplainer

In [10]:
# Wrap the loaded model and tokenizer in a transformers_interpret explainer. Calling it on a
# string yields (token, attribution score) pairs, as the saved output of the demo call shows.
cls_explainer = SequenceClassificationExplainer(model, tokenizer)

In [11]:
# Show the per-token attributions for the single demo tweet held in `text`.
# NOTE(review): the saved output below is for a different tweet than the "Sample tweet" printed
# by the data-loading cell, and the execution counts are out of order — re-run the notebook
# top to bottom before relying on this output.
cls_explainer(text)

[('<s>', 0.0),
 ('"', 0.07382133603096008),
 ('My', 0.022843321785330772),
 ('little', 0.4132058024406433),
 ('Melody', 0.6295004487037659),
 ('is', -0.11366540193557739),
 ('12', -0.1502915620803833),
 ('today', 0.19136029481887817),
 ('!', 0.29310551285743713),
 ('<unk>', 0.33143147826194763),
 (':two_hearts:', 0.37707796692848206),
 ('"', -0.13018174469470978),
 ('</s>', 0.0)]

In [14]:
# Bucket every test-split tweet under its class label.
# The three keys are seeded up front so each class is present even when it has no tweets;
# a label outside this set raises KeyError, which surfaces unexpected data instead of hiding it.
grouped_tweets = {"ADHD": [], "ASD": [], "CONTROL": []}
# The loop variable is not called `sample`: that name holds the demo example selected in the
# data-loading cell, and overwriting it would leave `sample` and `text` out of sync.
for test_example in dataset_dict["test"]:
    grouped_tweets[test_example["class"]].append(test_example["tweet"])

In [None]:
def compute_influential_words(tweets, top_k=15, explainer=None, log_every=10000):
    """Rank the tokens of each tweet by absolute attribution score.

    Args:
        tweets: Iterable of tweet strings to explain.
        top_k: Maximum number of (token, score) pairs kept per tweet. Tweets
            with fewer attributed tokens yield fewer pairs.
        explainer: Callable mapping a string to a sequence of
            (token, score) pairs. Defaults to the notebook-level
            ``cls_explainer`` when omitted.
        log_every: Print the running tweet count every this many tweets;
            pass 0 or None to disable progress output.

    Returns:
        A list with one dict per input tweet, in input order:
        ``{"tweet": <text>, "top_words": [(token, score), ...]}`` where
        ``top_words`` is sorted by descending absolute score. Tokens with
        equal absolute score keep their original relative order.
    """
    if explainer is None:
        # Resolved at call time so the function can be defined before the explainer cell runs.
        explainer = cls_explainer

    influential_words = []
    for counter, text in enumerate(tweets, start=1):
        attributions = explainer(text)
        # Sort on magnitude: strongly negative tokens are as influential as strongly positive ones.
        ranked = sorted(attributions, key=lambda pair: abs(pair[1]), reverse=True)
        top_words = [(word, score) for word, score in ranked[:top_k]]
        influential_words.append({"tweet": text, "top_words": top_words})
        if log_every and counter % log_every == 0:
            print(counter)
    return influential_words

In [None]:
# Cap on how many tweets per class are explained (presumably to bound runtime, since the
# explainer is invoked once per tweet).
MAX_TWEETS_PER_GROUP = 600

# Maps each class label to the list returned by compute_influential_words for that class.
results = {}
# The loop variable is not called `group`: that name holds the condition-group setting passed
# to load_fold and must not be overwritten by a class label.
for class_label, class_tweets in grouped_tweets.items():
    print(f"Processing group: {class_label}")
    # Slice from the start; the earlier [1:600] silently skipped the first tweet of every class.
    results[class_label] = compute_influential_words(class_tweets[:MAX_TWEETS_PER_GROUP])

In [None]:
# Print the stored top-attribution tokens for every processed tweet, one class at a time.
# The loop variable is `class_label` rather than `group` so the notebook-level `group`
# setting (the condition group passed to load_fold) is not overwritten.
for class_label, entries in results.items():
    print(f"\nGroup: {class_label}")
    for entry in entries:
        print(f"Tweet: {entry['tweet']}")
        # Header text is fixed; short tweets may list fewer than 15 tokens.
        print("Top 15 Influential Words:")
        for word, score in entry["top_words"]:
            print(f"  {word}: {score}")
        print("\n")