In [None]:
# ✅ Colab version: interactive layer/head selection, attention heatmap, and illustrative Q/K/V shapes

!pip install -q transformers matplotlib ipywidgets

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from transformers import TFAutoModel, AutoTokenizer
from ipywidgets import interact, IntSlider

# Checkpoint identifier shared by the tokenizer and the model.
model_name = "bert-base-multilingual-cased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Ask the model to return per-layer attentions and hidden states,
# both of which are read by the visualisation function below.
model = TFAutoModel.from_pretrained(
    model_name,
    output_hidden_states=True,
    output_attentions=True,
)

def ricostruisci_parole(tokens):
    """Merge ``##``-prefixed sub-tokens back into whole words.

    A token that starts with ``##`` is treated as a continuation: the
    two-character marker is stripped and the remainder is glued onto the
    word currently being built. Any other token closes the current word
    and starts a new one. Words that end up empty are not emitted.

    Args:
        tokens: Iterable of token strings.

    Returns:
        A list of the rebuilt words, in their original order.
    """
    words = []
    pieces = []  # fragments of the word currently being assembled

    for token in tokens:
        if token.startswith("##"):
            # Continuation piece: drop the "##" marker and keep building.
            pieces.append(token[2:])
            continue

        # A fresh token ends the previous word (if it has any content).
        finished = "".join(pieces)
        if finished:
            words.append(finished)
        pieces = [token]

    # Emit whatever is still pending after the last token.
    tail = "".join(pieces)
    if tail:
        words.append(tail)
    return words

def visualizza_attention_interattiva(frase, layer=11, head=0):
    """Show one attention head for a sentence as an annotated heatmap.

    Tokenizes ``frase`` with the module-level ``tokenizer``, runs the
    module-level ``model`` in inference mode, prints the tokens and the
    words rebuilt from them, prints the shape of the selected hidden state
    and of illustrative Q/K/V projections, and finally draws the attention
    matrix of the chosen ``layer``/``head`` with every cell labelled by its
    weight.

    Args:
        frase: Sentence to analyse.
        layer: Index used for both ``outputs.attentions`` and
            ``outputs.hidden_states``.
        head: Index of the attention head inside the selected layer.

    Returns:
        None. Results are printed and displayed through matplotlib.
    """
    inputs = tokenizer(frase, return_tensors="tf")
    outputs = model(**inputs, training=False)

    # [0] picks the single sequence in the batch, [head] the requested head.
    attention = outputs.attentions[layer][0][head].numpy()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    parole = ricostruisci_parole(tokens)

    print("\nToken BERT:", tokens)
    print("Token ricostruiti:", parole)

    # NOTE(review): in Hugging Face outputs hidden_states[0] is usually the
    # embedding output, which would make index `layer` the *input* of encoder
    # layer `layer` -- confirm against the Transformers documentation.
    hidden_states = outputs.hidden_states[layer][0].numpy()
    print(f"\n🔹 Embedding shape (layer {layer}):", hidden_states.shape)

    # These projections use random matrices, NOT the model's learned
    # query/key/value weights: they only illustrate the shapes involved.
    dim = hidden_states.shape[1]
    Q = hidden_states @ np.random.rand(dim, dim)
    K = hidden_states @ np.random.rand(dim, dim)
    V = hidden_states @ np.random.rand(dim, dim)

    print("📌 Q:", Q.shape, ", K:", K.shape, ", V:", V.shape)

    plt.figure(figsize=(10, 8))
    im = plt.imshow(attention, cmap="viridis")
    plt.title(f"Matrice di Attenzione (Layer {layer}, Head {head})")
    plt.xlabel("Token osservato (Key)")
    plt.ylabel("Token osservatore (Query)")
    plt.xticks(np.arange(len(tokens)), tokens, rotation=45)
    plt.yticks(np.arange(len(tokens)), tokens)
    plt.colorbar(im, label="Peso di attenzione")

    # imshow stretches the colormap between the matrix minimum and maximum,
    # so the light/dark text switch has to sit at the midpoint of that same
    # range. A fixed 0.5 would leave white text on bright cells whenever the
    # largest weight of the head is below 0.5.
    soglia = (attention.min() + attention.max()) / 2
    for (i, j), val in np.ndenumerate(attention):
        # Row i is the query token (y axis), column j the key token (x axis).
        plt.text(j, i, f"{val:.2f}", ha="center", va="center",
                 color="white" if val <= soglia else "black", fontsize=8)

    plt.tight_layout()
    plt.show()

# 🔘 Interactive UI
# Slider ranges are read from the loaded model's config instead of being
# hard-coded to 0-11, so they remain valid if `model_name` is switched to a
# checkpoint with a different number of layers or attention heads.
ultimo_layer = model.config.num_hidden_layers - 1
ultima_head = model.config.num_attention_heads - 1
# The trailing semicolon keeps the notebook from echoing interact's return value.
interact(
    visualizza_attention_interattiva,
    frase="Il gatto mangia il topo.",
    layer=IntSlider(min=0, max=ultimo_layer, step=1, value=ultimo_layer, description='Layer'),
    head=IntSlider(min=0, max=ultima_head, step=1, value=0, description='Head')
);
