<a href="https://colab.research.google.com/github/tomonari-masada/course2024-nlp/blob/main/EDA_with_multilingual_e5_large_instruct_ST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from tqdm.auto import tqdm
import numpy as np

import spacy
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity

import torch
import torch.nn as nn
from torch.nn.functional import normalize

from datasets import Dataset, DatasetDict, load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import TripletLoss
from sentence_transformers.evaluation import TripletEvaluator

In [None]:
# Load the Livedoor news corpus, split 80/10/10 into train/validation/test
# (ratios, shuffling and seed are passed as builder options of the dataset).
dataset = load_dataset(
    "shunk031/livedoor-news-corpus",
    train_ratio=0.8,
    val_ratio=0.1,
    test_ratio=0.1,
    random_state=42,
    shuffle=True,
    trust_remote_code=True,
)
# Number of distinct integer category labels seen in the training split.
num_categories = len(np.unique(dataset["train"]["category"]))
# NOTE(review): defined but not referenced anywhere else in this notebook.
max_seq_length = 512

# Display names for the category labels, indexed by label id.
# NOTE(review): order assumed to match the dataset's label ids — confirm
# against dataset["train"].features["category"].
category_names = [
    "movie-enter",
    "it-life-hack",
    "kaden-channel",
    "topic-news",
    "livedoor-homme",
    "peachy",
    "sports-watch",
    "dokujo-tsushin",
    "smax",
]

In [None]:
# Peek at the first raw training example (a dict that includes "title" and "category").
dataset["train"][0]

In [None]:
# Build (anchor, positive, negative) title triplets for every split. Each title
# is used once as an anchor. Its positive is a *different* title of the same
# category, and its negative is a title from some other category.
np.random.seed(1234)

triplet_dataset = {}
for key in dataset:

    # Bucket the split's titles by integer category label.
    categorized = [[] for _ in range(num_categories)]
    for example in dataset[key]:
        categorized[example["category"]].append(example["title"])
    category_size = [len(titles) for titles in categorized]

    anchors, positives, negatives = [], [], []
    for i in range(num_categories):
        size = category_size[i]
        # Offsets in [1, num_categories) shift away from i, so every negative
        # category differs from the anchor's category.
        indices = i + np.random.randint(1, num_categories, size)
        indices = indices % num_categories
        anchors += categorized[i]
        # Offsets in [1, size) likewise guarantee the positive is not the
        # anchor itself. Sampling from [0, size) returned the anchor as its own
        # positive with probability 1/size, producing trivial triplets. A
        # one-title category has no alternative, so its anchor is reused.
        if size > 1:
            offsets = np.random.randint(1, size, size)
        else:
            offsets = np.zeros(size, dtype=int)
        positives += [
            categorized[i][(k + offset) % size]
            for k, offset in enumerate(offsets)
        ]
        negatives += [
            categorized[j][np.random.randint(0, category_size[j])]
            for j in indices
        ]

    triplet_dataset[key] = Dataset.from_dict({
        "anchors": anchors,
        "positives": positives,
        "negatives": negatives,
    })

triplet_dataset = DatasetDict(triplet_dataset)

In [None]:
# Convenience handles for the three triplet splits.
train_dataset, eval_dataset, test_dataset = (
    triplet_dataset[split] for split in ("train", "validation", "test")
)

In [None]:
# Peek at the first training triplet: a dict with "anchors", "positives" and "negatives" titles.
train_dataset[0]

In [None]:
# Hub id of the multilingual E5 (large, instruct) encoder to fine-tune.
model_id = "intfloat/multilingual-e5-large-instruct"
model = SentenceTransformer(model_name_or_path=model_id)

In [None]:
# Triplet loss over the model, keeping the library's default distance metric and margin.
loss = TripletLoss(model=model)

In [None]:
# Fine-tuning configuration. The effective batch per optimizer step is
# per_device_train_batch_size * gradient_accumulation_steps = 4 * 8 = 32
# (per device). Evaluation, checkpointing and logging all happen every 100 steps.
args = SentenceTransformerTrainingArguments(
    # model_id contains a "/", so this resolves to a nested directory under models/.
    output_dir=f"models/{model_id}_livedoor-title-triplet",
    max_steps=1000,
    learning_rate=1e-4,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,
    bf16=True,
    # `evaluation_strategy` was renamed to `eval_strategy` in transformers 4.41.
    # The old name was deprecated and slated for removal in 4.46, so it fails on
    # recent releases. Requires transformers >= 4.41.
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
)

In [None]:
# Evaluator over the validation triplets. The dataset column names match the
# evaluator's keyword arguments, so they are forwarded directly.
triplet_columns = {
    column: eval_dataset[column] for column in ("anchors", "positives", "negatives")
}
dev_evaluator = TripletEvaluator(**triplet_columns)
# Score the model once before fine-tuning, as a baseline.
dev_evaluator(model)

In [None]:
# Wire together model, loss, data and the validation evaluator for fine-tuning.
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    loss=loss,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    evaluator=dev_evaluator,
)

In [None]:
# Fine-tune with TripletLoss; evaluation and checkpointing follow the schedule in `args`.
trainer.train()

In [None]:
def print_trainable_parameters(model):
    """Print how many of ``model``'s parameter elements are trainable.

    Counts elements (``numel``) over ``model.named_parameters()``. A parameter
    counts as trainable when ``requires_grad`` is set.

    Args:
        model: Any object exposing ``named_parameters()`` (e.g. a torch module).

    Prints one line: trainable count, total count and the trainable percentage.
    The percentage is 0.0 for a model with no parameters.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        numel = param.numel()
        all_param += numel
        if param.requires_grad:
            trainable_params += numel
    # Guard: a parameter-less module would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} "
        f"|| all params: {all_param} "
        f"|| trainable%: {trainable_pct}"
    )

In [None]:
# Report trainable vs. total parameter counts of the fine-tuned model.
print_trainable_parameters(model)

In [None]:
# List the split names of the DatasetDict, one per line.
print(*dataset, sep="\n")

In [None]:
# Number of examples in the training split.
len(dataset["train"])

In [None]:
# L2-normalized sentence embeddings of every title, keyed by split name.
embeddings = {
    split: model.encode(
        dataset[split]["title"],
        normalize_embeddings=True,
        show_progress_bar=True,
    )
    for split in dataset
}

In [None]:
# Shape of the training embeddings: one row per training title.
embeddings["train"].shape

In [None]:
# Cluster the training-title embeddings; the centers later act as "topics".
n_clusters = 30
kmeans = KMeans(n_clusters=n_clusters, n_init="auto", random_state=123).fit(
    embeddings["train"]
)
centers = kmeans.cluster_centers_

In [None]:
# Number of training titles assigned to each cluster, printed in ascending order.
unique, counts = np.unique(kmeans.labels_, return_counts=True)
size_dict = {label: count for label, count in zip(unique, counts)}
print(sorted(size_dict.values()))

In [None]:
# Only these coarse POS tags are kept as candidate topic words.
label_pos_tags = ["NOUN", "VERB", "PROPN"]

nlp = spacy.load("ja_core_news_sm")
# corpus[split][n] is the space-joined lemmas of the kept tokens of title n.
corpus = {}
for key in dataset:
    titles = dataset[key]["title"]
    # nlp.pipe streams the titles through the pipeline in batches. This is
    # spaCy's recommended way to process many texts, gives the same Docs as
    # calling nlp(text) per title, and is faster. The generator has no len(),
    # so tqdm gets the total explicitly.
    corpus[key] = [
        " ".join(token.lemma_ for token in doc if token.pos_ in label_pos_tags)
        for doc in tqdm(nlp.pipe(titles), total=len(titles))
    ]

In [None]:
# Candidate topic vocabulary: lemmas that occur in at least 10 training titles
# and in at most 10% of them.
vectorizer = TfidfVectorizer(
    min_df=10,
    max_df=0.1,
    lowercase=False,
    # The corpus is already spaCy-tokenized and space-joined, so keep each
    # whitespace-delimited lemma as one token. The default pattern
    # r"(?u)\b\w\w+\b" drops single-character words, which are common in
    # Japanese (e.g. one-kanji nouns), and re-splits tokens at punctuation.
    token_pattern=r"(?u)\S+",
)
vectorizer.fit(corpus["train"])
vocab = np.array(vectorizer.get_feature_names_out())

In [None]:
# Display the vocabulary kept after min_df / max_df filtering.
vocab

In [None]:
# Embed each vocabulary word on its own with the fine-tuned model (unit-normalized).
vocab_embeddings = model.encode(vocab.tolist(), normalize_embeddings=True)

In [None]:
# For each cluster center, list the 20 vocabulary words whose embeddings are
# most cosine-similar to it, as one line: "<cluster id> word1 word2 ...".
similarities = cosine_similarity(vocab_embeddings, centers)
topic_words = []
for cluster_id, column in enumerate(similarities.T):
    top_words = vocab[np.argsort(-column)[:20]]
    topic_words.append(f"{cluster_id:d} " + " ".join(top_words))
print("\n".join(topic_words))

In [None]:
# Alternative word embeddings: each vocabulary word becomes the TF-IDF-weighted
# average of the embeddings of the training titles that contain it, instead of
# encoding the bare word.
vectorizer = TfidfVectorizer(
    min_df=10,
    max_df=0.1,
    lowercase=False,
    # Keep each whitespace-delimited spaCy lemma as one token. The default
    # r"(?u)\b\w\w+\b" drops single-character (e.g. one-kanji) words.
    token_pattern=r"(?u)\S+",
)
# fit_transform is equivalent to fit(...) followed by transform(...).
X_train = vectorizer.fit_transform(corpus["train"]).toarray()
vocab = np.array(vectorizer.get_feature_names_out())

# Normalize each column so a word's weights over titles sum to 1. No column is
# all-zero because min_df=10 keeps only words present in >= 10 training titles.
vocab_embeddings = np.dot((X_train / X_train.sum(0)).T, embeddings["train"])

In [None]:
# Top-20 words per cluster center under the TF-IDF-weighted word embeddings.
# Each line reads "<cluster id> word1 word2 ...".
similarities = cosine_similarity(vocab_embeddings, centers)
topic_words = [
    f"{c:d} " + " ".join(vocab[np.argsort(-similarities[:, c])[:20]])
    for c in range(similarities.shape[-1])
]
print("\n".join(topic_words))

In [None]:
# Freeze every weight: attribution below only needs gradients w.r.t. the
# embedding-layer output, not the parameters.
for weight in model.parameters():
    weight.requires_grad_(False)

In [None]:
# Move the model to CPU and switch it to eval mode for attribution; the trailing
# ";" suppresses the notebook's display of the returned module.
model.to("cpu").eval();

In [None]:
from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization

# Use the model's own tokenizer so token ids match its embedding table.
tokenizer = model.tokenizer
# Integrated-Gradients baselines are built from the tokenizer's pad token.
pad_id = tokenizer.pad_token_id
token_reference = TokenReferenceBase(pad_id)

In [None]:
# Sanity check: embed one title by calling the model's modules directly
# (instead of model.encode) and show its L2-normalized embedding.
text = dataset["train"]["title"][0]
encodings = tokenizer(text, padding=True, return_tensors="pt").to(model.device)
input_ids, attention_mask = encodings["input_ids"], encodings["attention_mask"]
with torch.no_grad():
    outputs = model({"input_ids": input_ids, "attention_mask": attention_mask})
    embedding = outputs["sentence_embedding"]
normalize(embedding, p=2, dim=1)

In [None]:
cos_sim = nn.CosineSimilarity(dim=-1)
# k-means centers as a tensor on the model's device, so they can be compared
# with freshly computed sentence embeddings.
cluster_centers = torch.tensor(kmeans.cluster_centers_, device=model.device)


def predict(input_ids, attention_mask):
    """Return the cosine similarity of each input's embedding to every cluster center.

    Args:
        input_ids: Token ids from ``tokenizer(..., return_tensors="pt")``.
        attention_mask: The matching attention mask.

    Returns:
        A tensor of shape (batch, n_clusters): the unsqueezed embedding
        (batch, 1, dim) broadcasts against the centers (1, n_clusters, dim).
    """
    features = {"input_ids": input_ids, "attention_mask": attention_mask}
    sentence_embedding = normalize(model(features)["sentence_embedding"], p=2, dim=1)
    return cos_sim(cluster_centers.unsqueeze(0), sentence_embedding.unsqueeze(1))

In [None]:
# Similarities of the first training title to all cluster centers.
text = dataset["train"]["title"][0]
encodings = tokenizer(text, padding=True, return_tensors="pt").to(model.device)
predict(encodings["input_ids"], encodings["attention_mask"])

In [None]:
def cluster_similarity_forward_func(input_ids, attention_mask, cluster_id):
    """Return each input's cosine similarity to the single cluster ``cluster_id``.

    Serves as the forward function handed to LayerIntegratedGradients, which
    needs one scalar output per example.
    """
    return predict(input_ids, attention_mask)[:, cluster_id]

In [None]:
# Similarity of the first training title to cluster 29 (the last of the 30 clusters).
text = dataset["train"]["title"][0]
encodings = tokenizer(text, padding=True, return_tensors="pt").to(model.device)
target_cluster = 29
cluster_similarity_forward_func(
    encodings["input_ids"], encodings["attention_mask"], target_cluster
)

In [None]:
# .modules() yields model[0] itself first, so index 1 is its first registered
# submodule. NOTE(review): presumably the wrapped Hugging Face transformer
# (model[0].auto_model) — confirm.
list(model[0].modules())[1]

In [None]:
# That submodule's embedding block; its word_embeddings layer is used as the attribution layer.
list(model[0].modules())[1].embeddings

In [None]:
# Integrated Gradients attributed at the word-embedding layer's output, so the
# integer token ids themselves never need gradients.
lig = LayerIntegratedGradients(
    forward_func=cluster_similarity_forward_func,
    layer=list(model[0].modules())[1].embeddings.word_embeddings,
)

In [None]:
# Accumulates the VisualizationDataRecord entries that interpret_text() appends.
vis_data_records_ig = []

In [None]:
def add_attributions_to_visualizer(attributions, text, pred_prob, pred_class, true_class,
                                   attr_class, convergence_scores, vis_data_records):
    """Convert raw attributions into a captum VisualizationDataRecord and store it.

    Attributions are summed over the last (embedding) dimension to give one
    score per token. The leading batch dimension is squeezed away, and the
    scores are scaled to unit L2 norm. The record is then appended to
    ``vis_data_records``.

    Args:
        attributions: Tensor from LayerIntegratedGradients. NOTE(review):
            assumed shape (1, seq_len, hidden_dim) — confirm.
        text: Token strings displayed next to the scores.
        pred_prob, pred_class, true_class, attr_class: Values shown in the
            record's header columns.
        convergence_scores: Integrated-Gradients convergence delta to display.
        vis_data_records: List that receives the new record (mutated in place).
    """
    # One score per token: detach, move to CPU, collapse the embedding dim,
    # and drop the batch dim.
    scores = attributions.detach().cpu().sum(dim=-1).squeeze(0)
    norm = torch.norm(scores)
    # Guard: an all-zero attribution vector would otherwise turn into all NaN.
    if norm > 0:
        scores = scores / norm
    scores = scores.numpy()
    vis_data_records.append(
        visualization.VisualizationDataRecord(
            scores,
            pred_prob,
            pred_class,
            true_class,
            attr_class,
            scores.sum(),
            text,
            convergence_scores,
        )
    )

In [None]:
def interpret_text(text, attr_class=None, n_steps=50):
    """Attribute a text's cluster similarity to its tokens with Integrated Gradients.

    The text is tokenized and its most similar k-means cluster is found via
    ``predict``. ``lig`` then attributes the similarity to ``attr_class``
    (default: that predicted cluster) onto the word-embedding layer. The
    baseline is a same-length sequence generated by ``token_reference``. The
    prediction and convergence delta are printed, and a visualization record is
    appended to the global ``vis_data_records_ig``.

    Args:
        text: A single input string.
        attr_class: Cluster index to explain; ``None`` means the predicted one.
        n_steps: Number of Integrated-Gradients interpolation steps.

    Returns:
        The index of the most similar cluster.
    """
    batch = tokenizer(text, padding=True, return_tensors="pt").to(model.device)
    input_ids = batch.input_ids
    attention_mask = batch.attention_mask
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    baseline_ids = token_reference.generate_reference(
        len(tokens), device=model.device
    ).unsqueeze(0)

    similarities = predict(input_ids, attention_mask)
    prediction = similarities.argmax().item()
    best_similarity = similarities.max().item()
    target_cluster = prediction if attr_class is None else attr_class
    print(f"prediction={prediction} cos_sim={best_similarity:.3f} ", end="")

    attributions_ig, delta = lig.attribute(
        inputs=input_ids,
        baselines=baseline_ids,
        additional_forward_args=(attention_mask, target_cluster),
        n_steps=n_steps,
        return_convergence_delta=True,
    )
    print(f"convergence delta={delta.item():.3e} when n_steps={n_steps}")

    add_attributions_to_visualizer(
        attributions_ig,
        tokens,
        best_similarity,
        str(prediction),
        str(prediction),  # no ground-truth cluster exists, so the prediction is reused
        str(target_cluster),
        delta,
        vis_data_records_ig,
    )
    return prediction


In [None]:
# Explain the same title with increasing IG step counts to compare convergence.
vis_data_records_ig = []
first_title = dataset["train"]["title"][0]
for steps in (50, 100, 200, 300):
    interpret_text(first_title, n_steps=steps)

In [None]:
# Render the collected attribution records; ";" suppresses display of the return value.
visualization.visualize_text(vis_data_records_ig);

In [None]:
# For the first 50 validation titles: print the gold category name, then the
# predicted cluster's topic words, then render that title's token attributions.
for idx in tqdm(range(50)):
    sample = dataset["validation"][idx]
    print(category_names[sample["category"]], end=" ")
    # Rebind the global list so only this title's record gets rendered.
    vis_data_records_ig = []
    predicted_cluster = interpret_text(sample["title"], n_steps=50)
    print("\t" + topic_words[predicted_cluster])
    visualization.visualize_text(vis_data_records_ig)