<a href="https://colab.research.google.com/github/tomonari-masada/course2024-nlp/blob/main/EDA_with_multilingual_e5_large_instruct.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from tqdm.auto import tqdm
import numpy as np

import spacy
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity

import torch
import torch.nn as nn
from torch.nn.functional import normalize

from datasets import load_dataset
from transformers import (
    TrainingArguments,
    AutoTokenizer,
    AutoModelForSequenceClassification,
)
from transformers.modeling_outputs import ModelOutput

from trl import SFTTrainer

In [None]:
# Load the livedoor news corpus with a shuffled, seeded 80/10/10 split.
dataset = load_dataset(
    "shunk031/livedoor-news-corpus",
    trust_remote_code=True,
    shuffle=True,
    random_state=42,
    train_ratio=0.8,
    val_ratio=0.1,
    test_ratio=0.1,
)
# Number of distinct category ids that occur in the training split.
num_categories = len({*dataset["train"]["category"]})
# Sequence-length limit passed to the tokenizer and the trainer further down.
max_seq_length = 512

In [None]:
# Category names, looked up later by integer category id
# (`category_names[example["category"]]`).
# NOTE(review): the order is hard-coded here — confirm it matches the
# dataset's own label ids.
category_names = [
    'movie-enter',
    'it-life-hack',
    'kaden-channel',
    'topic-news',
    'livedoor-homme',
    'peachy',
    'sports-watch',
    'dokujo-tsushin',
    'smax',
]

In [None]:
model_id = "intfloat/multilingual-e5-large-instruct"
# `model_max_length` is the keyword the Hugging Face tokenizer API uses for
# the maximum input length; it is what `truncation=True` truncates to.
tokenizer = AutoTokenizer.from_pretrained(model_id, model_max_length=max_seq_length)
# Sequence-classification model with one output logit per news category.
# Use GPU 0 when CUDA is available and fall back to the CPU otherwise, so
# the notebook does not fail outright on a machine without a GPU.
model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    num_labels=num_categories,
).to(0 if torch.cuda.is_available() else "cpu")

In [None]:
def accuracy(model, tokenizer, corpus, labels, batch_size=4):
    """Return the fraction of texts whose arg-max logit equals its label.

    Args:
        model: Classifier called as ``model(**encodings)``; its output must
            expose ``.logits``, and the model must expose ``.device``.
        tokenizer: Callable turning a list of texts into model inputs.
        corpus: Sliceable sequence of texts.
        labels: Sliceable sequence of integer labels aligned with ``corpus``.
        batch_size: Number of texts evaluated per forward pass.

    Returns:
        Accuracy as a Python float in ``[0, 1]``.

    Raises:
        ValueError: If ``corpus`` is empty (accuracy is undefined).
    """
    if len(corpus) == 0:
        raise ValueError("accuracy() needs at least one text in `corpus`")
    # Remember the caller's mode: unconditionally calling model.train()
    # afterwards would silently re-enable dropout on a model that had been
    # put in eval mode.
    was_training = model.training
    model.eval()
    num_correct_answers, num_answers = 0, 0
    try:
        for i in tqdm(range(0, len(corpus), batch_size)):
            texts = corpus[i:i+batch_size]
            # truncation=True keeps over-long texts within the tokenizer's
            # maximum length instead of overflowing the model.
            encodings = tokenizer(
                texts, padding=True, truncation=True, return_tensors="pt",
            )
            encodings = encodings.to(model.device)
            category = torch.tensor(labels[i:i+batch_size]).to(model.device)
            with torch.no_grad():
                outputs = model(**encodings)
            predicted = outputs.logits.argmax(-1)
            num_correct_answers += (predicted == category).sum().item()
            num_answers += len(texts)
    finally:
        # Restore whichever mode the model was in, even if a batch failed.
        model.train(was_training)
    return num_correct_answers / num_answers

In [None]:
def average_pool(last_hidden_states, attention_mask):
    """Mean-pool hidden states along dim 1, ignoring masked-out positions.

    Args:
        last_hidden_states: Tensor whose dim 1 is the position axis and
            whose last dim is the feature axis.
        attention_mask: Tensor matching the leading dims of
            ``last_hidden_states``; non-zero marks positions to keep.

    Returns:
        The per-row average of the kept positions. A row whose mask is all
        zero yields a zero vector instead of NaN.
    """
    keep = attention_mask[..., None].bool()
    summed = last_hidden_states.masked_fill(~keep, 0.0).sum(dim=1)
    # Clamp the count to at least 1 so a fully masked row divides 0 by 1
    # rather than 0 by 0.
    counts = attention_mask.sum(dim=1).clamp(min=1)
    return summed / counts[..., None]

In [None]:
def embed(model, tokenizer, corpus, batch_size=4):
    """Embed texts by mean-pooling the encoder's last hidden states.

    Args:
        model: Model exposing its encoder as ``model.roberta`` and its
            device as ``model.device``.
        tokenizer: Callable turning a list of texts into model inputs that
            include ``attention_mask``.
        corpus: Sliceable sequence of texts.
        batch_size: Number of texts encoded per forward pass.

    Returns:
        A CPU tensor with one pooled (un-normalised) row per text, in
        corpus order.

    Raises:
        ValueError: If ``corpus`` is empty.
    """
    if len(corpus) == 0:
        raise ValueError("embed() needs at least one text in `corpus`")
    # Remember the caller's mode: unconditionally calling model.train()
    # afterwards would silently re-enable dropout on a model that had been
    # put in eval mode.
    was_training = model.training
    model.eval()
    pooled_hidden_states = []
    try:
        for i in tqdm(range(0, len(corpus), batch_size)):
            texts = corpus[i:i+batch_size]
            # truncation=True keeps over-long texts within the tokenizer's
            # maximum length instead of overflowing the model.
            encodings = tokenizer(
                texts, padding=True, truncation=True, return_tensors="pt",
            )
            encodings = encodings.to(model.device)
            with torch.no_grad():
                outputs = model.roberta(**encodings)
            pooled_hidden_state = average_pool(
                outputs.last_hidden_state,
                encodings['attention_mask'],
            )
            # Move each batch off the device right away to bound GPU memory.
            pooled_hidden_states.append(pooled_hidden_state.cpu())
    finally:
        # Restore whichever mode the model was in, even if a batch failed.
        model.train(was_training)
    return torch.cat(pooled_hidden_states)

In [None]:
# Embed the titles of every split with the pre-trained model and normalise
# each embedding row (torch.nn.functional.normalize, default arguments).
embeddings = {}
for key in dataset:
    embeddings[key] = normalize(
        embed(model, tokenizer, dataset[key]["title"])
    )

In [None]:
# Cluster the training-title embeddings into 30 groups with a fixed seed.
n_clusters = 30
kmeans = KMeans(
    n_clusters=n_clusters,
    random_state=123,
    n_init='auto',
).fit(embeddings["train"])
centers = kmeans.cluster_centers_

In [None]:
# Count how many training titles fell into each cluster and print the
# cluster sizes in ascending order.
unique, counts = np.unique(kmeans.labels_, return_counts=True)
size_dict = dict(zip(unique, counts))
print(sorted(size_dict.values()))

In [None]:
# Only tokens with these part-of-speech tags are kept as label candidates.
label_pos_tags = ["NOUN", "VERB", "PROPN"]

nlp = spacy.load("ja_core_news_sm")
# For every split, turn each title into a space-separated string of the
# lemmas of its kept tokens.
corpus = {}
for key in dataset:
    corpus[key] = [
        " ".join(
            token.lemma_
            for token in nlp(text)
            if token.pos_ in label_pos_tags
        )
        for text in tqdm(dataset[key]["title"])
    ]

In [None]:
# Build the label vocabulary from the lemmatised training titles: terms must
# occur in at least 10 titles and in at most 10% of them; case is preserved.
vectorizer = TfidfVectorizer(
    lowercase=False,
    min_df=10,
    max_df=0.1,
).fit(corpus["train"])
vocab = np.array(vectorizer.get_feature_names_out())

In [None]:
# Display the extracted vocabulary array.
vocab

In [None]:
# Embed each vocabulary word on its own, as if it were a one-word text.
vocab_embeddings = embed(model, tokenizer, vocab.tolist())

In [None]:
# For every cluster centre, list the 20 vocabulary words whose embeddings
# are most cosine-similar to it, prefixed by the cluster index.
similarities = cosine_similarity(vocab_embeddings, centers)
topic_words = []
for i, column in enumerate(similarities.T):
    # Negate so that argsort orders from most to least similar.
    indices = np.argsort(-column)
    top_words = vocab[indices[:20]].tolist()
    topic_words.append(f"{i:d} " + " ".join(top_words))
print("\n".join(topic_words))

In [None]:
# Refit the same TF-IDF vocabulary and additionally materialise the dense
# document-term matrix for the training titles.
vectorizer = TfidfVectorizer(
    lowercase=False,
    min_df=10,
    max_df=0.1,
).fit(corpus["train"])
vocab = np.array(vectorizer.get_feature_names_out())
X_train = vectorizer.transform(corpus["train"]).toarray()

In [None]:
# Alternative word embeddings: each vocabulary term becomes the weighted
# average of the training-title embeddings, weighted by that term's TF-IDF
# column rescaled to sum to 1.
vocab_embeddings = np.dot((X_train / X_train.sum(0)).T, embeddings["train"])

In [None]:
# Repeat the topic-word listing with the TF-IDF-weighted word embeddings:
# the 20 words most cosine-similar to each cluster centre.
similarities = cosine_similarity(vocab_embeddings, centers)
topic_words = []
for i, column in enumerate(similarities.T):
    # Negate so that argsort orders from most to least similar.
    indices = np.argsort(-column)
    top_words = vocab[indices[:20]].tolist()
    topic_words.append(f"{i:d} " + " ".join(top_words))
print("\n".join(topic_words))

In [None]:
class MyNetForClassification(nn.Module):
    """Wrap a classifier so that its target is supplied as ``category``.

    The wrapped model is called without labels and the cross-entropy loss
    against ``category`` is computed here instead.
    """

    def __init__(self, pretrained):
        """Store the wrapped model and mirror its ``config`` attribute."""
        super().__init__()
        self.pretrained = pretrained
        self.config = self.pretrained.config

    def forward(
        self, input_ids, category=None,
        attention_mask=None,
        output_attentions=None, output_hidden_states=None,
        return_dict=None, inputs_embeds=None, labels=None,
    ):
        """Run the wrapped classifier and attach a cross-entropy loss.

        Args:
            input_ids: Token ids passed to the wrapped model.
            category: Class-index targets; when ``None`` no loss is computed.
            attention_mask: Passed through to the wrapped model.
            output_attentions: Passed through to the wrapped model.
            output_hidden_states: Passed through to the wrapped model.
            return_dict: Accepted for signature compatibility; the wrapped
                model is always asked for a dict-style output because the
                result is read via attributes.
            inputs_embeds: Accepted but not used.
            labels: Accepted but not used (presumably because the trainer's
                collator supplies it; verify against the caller).

        Returns:
            A ``ModelOutput`` with ``loss`` (``None`` without ``category``),
            ``logits``, ``hidden_states`` and ``attentions``.
        """
        outputs = self.pretrained(
            input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            # `.logits` below requires an attribute-style output, so a
            # caller-supplied return_dict=False must not reach the model.
            return_dict=True,
        )
        # Without targets (plain inference) there is nothing to score.
        loss = None
        if category is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(outputs.logits, category)
        return ModelOutput(
            loss=loss,
            logits=outputs.logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
my_model = MyNetForClassification(model)

In [None]:
training_args = TrainingArguments(
    output_dir="outputs",
    # The target column is called "category" rather than the default name.
    label_names=["category"],
    # 4 examples per step x 8 accumulated steps = 32 examples per update
    # on each device.
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    learning_rate=5e-5,
    # NOTE(review): confirm this has any effect with the default optimizer.
    optim_target_modules=["query", "key", "value", "dense"],
    max_steps=300,
    # Evaluate, log and checkpoint every 100 steps.
    evaluation_strategy="steps",
    eval_steps=100,
    logging_strategy="steps",
    logging_steps=100,
    save_strategy="steps",
    save_steps=100,
    load_best_model_at_end=True,
)

In [None]:
# Build the trainer on the raw title text; SFTTrainer is given the name of
# the text column and the maximum sequence length.
trainer = SFTTrainer(
    model=my_model,
    args=training_args,
    tokenizer=tokenizer,
    max_seq_length=max_seq_length,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    dataset_text_field="title",
)
# Attach the "category" column to the trainer's datasets after construction.
# NOTE(review): presumably the trainer's preprocessing keeps only tokenized
# fields and preserves row order — confirm both, since the labels are
# re-aligned purely by position.
trainer.train_dataset = trainer.train_dataset.add_column(
    "category", dataset["train"]["category"],
)
trainer.eval_dataset = trainer.eval_dataset.add_column(
    "category", dataset["validation"]["category"],
)

In [None]:
# Fine-tune the wrapped classifier with the arguments configured above.
trainer.train()

In [None]:
# Classification accuracy of the model on the validation titles.
accuracy(model, tokenizer, dataset["validation"]["title"], dataset["validation"]["category"])

In [None]:
# Recompute the normalised title embeddings of every split, now using the
# model as it stands after the training cell.
embeddings = {}
for key in dataset:
    embeddings[key] = normalize(
        embed(model, tokenizer, dataset[key]["title"])
    )

In [None]:
# Rebuild the word embeddings from the new title embeddings: each term is
# the weighted average of the training-title embeddings, weighted by that
# term's TF-IDF column rescaled to sum to 1.
vocab_embeddings = np.dot((X_train / X_train.sum(0)).T, embeddings["train"])

In [None]:
# Re-cluster the recomputed training-title embeddings into 30 groups,
# with the same seed as before.
n_clusters = 30
kmeans = KMeans(
    n_clusters=n_clusters,
    random_state=123,
    n_init='auto',
).fit(embeddings["train"])
centers = kmeans.cluster_centers_

In [None]:
# Print the sizes of the new clusters in ascending order.
unique, counts = np.unique(kmeans.labels_, return_counts=True)
size_dict = dict(zip(unique, counts))
print(sorted(size_dict.values()))

In [None]:
# Topic words for the new clusters: the 20 vocabulary words most
# cosine-similar to each cluster centre, prefixed by the cluster index.
similarities = cosine_similarity(vocab_embeddings, centers)
topic_words = []
for i, column in enumerate(similarities.T):
    # Negate so that argsort orders from most to least similar.
    indices = np.argsort(-column)
    top_words = vocab[indices[:20]].tolist()
    topic_words.append(f"{i:d} " + " ".join(top_words))
print("\n".join(topic_words))
#with open("topic_words.txt", "w") as f:
#    f.write("\n".join(topic_words))

In [None]:
# Freeze every model parameter; the attribution cells below need gradients
# with respect to the inputs only. The trailing semicolon suppresses the
# notebook's display of the returned module.
model.requires_grad_(False);

In [None]:
# Move the model to the CPU and switch it to eval mode; the trailing
# semicolon suppresses the notebook's display of the returned module.
model.to("cpu").eval();

In [None]:
# Captum attribution utilities used by the remaining cells.
from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization

# Baseline generator that fills a reference sequence with the padding token.
token_reference = TokenReferenceBase(reference_token_idx=tokenizer.pad_token_id)

In [None]:
# Sanity check: embed the first training title by hand, without embed().
text = dataset["train"]["title"][0]
encodings = tokenizer(text, padding=True, return_tensors="pt").to(model.device)
outputs = model.roberta(
    input_ids=encodings.input_ids,
    attention_mask=encodings.attention_mask,
)
pooled_hidden_state = average_pool(outputs.last_hidden_state, encodings['attention_mask'])

In [None]:
# Display the pooled embedding of the example title.
pooled_hidden_state

In [None]:
# Copy the k-means centres into a torch tensor on the model's device so
# they can be compared with model outputs.
cluster_centers = torch.tensor(kmeans.cluster_centers_, device=model.device)

In [None]:
# Cosine similarity taken over the last (feature) dimension.
cos_sim = nn.CosineSimilarity(dim=-1)
# Similarity of the example title's embedding to every cluster centre.
cos_sim(cluster_centers, pooled_hidden_state)

In [None]:
# Display the cluster index assigned to each training title.
kmeans.labels_

In [None]:
def predict(input_ids, attention_mask):
    """Return the cosine similarity of each input to every cluster centre.

    The inputs are encoded with ``model.roberta``, mean-pooled with
    ``average_pool`` and compared against the global ``cluster_centers``.
    The result has one row per input and one column per centre.
    """
    hidden = model.roberta(input_ids, attention_mask).last_hidden_state
    sentence_vectors = average_pool(hidden, attention_mask)
    # Add singleton axes so that every input is paired with every centre.
    centres = cluster_centers.unsqueeze(0)
    queries = sentence_vectors.unsqueeze(1)
    return cos_sim(centres, queries)

In [None]:
# Try predict() on the first training title.
text = dataset["train"]["title"][0]
encodings = tokenizer(text, padding=True, return_tensors="pt").to(model.device)
predict(encodings.input_ids, encodings.attention_mask)

In [None]:
# Display the example title.
text

In [None]:
def cluster_similarity_forward_func(input_ids, attention_mask, cluster_id):
    """Return each input's cosine similarity to the centre ``cluster_id``.

    This selects one column of ``predict``'s output, giving a single score
    per input to attribute.
    """
    return predict(input_ids, attention_mask)[:, cluster_id]

In [None]:
# Try the forward function on the first training title, for cluster id 29.
text = dataset["train"]["title"][0]
encodings = tokenizer(text, padding=True, return_tensors="pt").to(model.device)
cluster_similarity_forward_func(encodings.input_ids, encodings.attention_mask, 29)

In [None]:
# Integrated gradients of the cluster-similarity score, taken at the
# model's word-embedding layer.
lig = LayerIntegratedGradients(
    forward_func=cluster_similarity_forward_func,
    layer=model.roberta.embeddings.word_embeddings,
)

In [None]:
# Global list that interpret_text() appends visualization records to.
vis_data_records_ig = []

In [None]:
def add_attributions_to_visualizer(attributions, text, pred_prob, pred_class, true_class,
                                   attr_class, convergence_scores, vis_data_records):
    """Append a Captum visualization record for one attributed example.

    The attributions are summed over their last dimension, squeezed on
    dim 0, divided by their norm and converted to a NumPy array. The
    remaining arguments are passed to ``VisualizationDataRecord`` unchanged,
    and the sum of the normalised scores is used as the record's overall
    attribution score. ``vis_data_records`` is modified in place.
    """
    per_token = attributions.cpu().sum(dim=-1).squeeze(0)
    per_token = (per_token / torch.norm(per_token)).cpu().detach().numpy()
    record = visualization.VisualizationDataRecord(
        per_token,
        pred_prob,
        pred_class,
        true_class,
        attr_class,
        per_token.sum(),
        text,
        convergence_scores,
    )
    vis_data_records.append(record)

In [None]:
def interpret_text(text, attr_class=None, n_steps=50, true_class=None):
    """Attribute one text's cluster similarity to its tokens.

    Tokenizes ``text``, finds its most similar cluster with ``predict``,
    runs layer integrated gradients against an all-padding reference, and
    appends a record to the global ``vis_data_records_ig``.

    Args:
        text: A single text to explain.
        attr_class: Cluster id to attribute; defaults to the predicted one.
        n_steps: Number of integration steps for integrated gradients.
        true_class: Optional label for the record's true-class field;
            defaults to the predicted cluster.

    Returns:
        The id of the cluster most similar to ``text``.
    """
    encodings = tokenizer(text, padding=True, return_tensors="pt")
    encodings = encodings.to(model.device)
    input_ids = encodings.input_ids
    attention_mask = encodings.attention_mask
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    # Baseline of the same length as the input, made only of padding tokens.
    reference_input_ids = token_reference.generate_reference(
        len(tokens),
        device=model.device,
    ).unsqueeze(0)

    # This pass only reads off the prediction, so no autograd graph is
    # needed; lig.attribute() below runs its own forward passes.
    with torch.no_grad():
        similarities = predict(
            input_ids,
            attention_mask,
        )
    prediction = similarities.argmax().item()
    max_similarity = similarities.max().item()
    if attr_class is None:
        attr_class = prediction
    print(
        f"prediction={prediction} "
        f"cos_sim={max_similarity:.3f} ",
        end=""
    )

    attributions_ig, delta = lig.attribute(
        input_ids,
        reference_input_ids,
        additional_forward_args=(attention_mask, attr_class),
        n_steps=n_steps,
        return_convergence_delta=True,
    )
    print(f"convergence delta={delta.item():.3e} when n_steps={n_steps}")

    add_attributions_to_visualizer(
        attributions_ig,
        tokens,
        max_similarity,
        str(prediction),
        # Fall back to the prediction when no true label was supplied.
        str(prediction) if true_class is None else str(true_class),
        str(attr_class),
        delta,
        vis_data_records_ig,
    )
    return prediction


In [None]:
# Explain the first training title at several integration step counts to
# see how the printed convergence delta changes.
vis_data_records_ig = []
for n_steps in (50, 100, 200, 300):
    interpret_text(dataset["train"]["title"][0], n_steps=n_steps)

In [None]:
# Render the collected attribution records; the trailing semicolon
# suppresses the notebook's display of the return value.
visualization.visualize_text(vis_data_records_ig);

In [None]:
# For validation examples 60-69: print the category name, explain the title,
# print the predicted cluster's topic words and render the attribution.
for i in tqdm(range(60, 70)):
    example = dataset["validation"][i]
    # Start from an empty record list so each rendering shows one example.
    vis_data_records_ig = []
    print(category_names[example["category"]], end=" ")
    prediction = interpret_text(example["title"], n_steps=50)
    print(f"\t{topic_words[prediction]}")
    visualization.visualize_text(vis_data_records_ig);