In [None]:
from transformers import AutoTokenizer

In [None]:
import torch, gc, seaborn as sns, matplotlib.pyplot as plt
import sys  
sys.path.insert(1, '../')
from src.model import SEKALLM                     # ← your new wrapper
from src.utils   import encode_with_markers            # unchanged

# ---------- 1. load tokenizer & SEKA-wrapped model ---------------------
# Single source of truth for the checkpoint so tokenizer and model cannot drift apart.
MODEL_PATH = "../pretrained/Qwen3-4B-Base"
tok = AutoTokenizer.from_pretrained(MODEL_PATH)
# output_attentions=True presumably makes the wrapped model able to return per-layer
# attention maps, which the plotting cells need — confirm in src.model.SEKALLM.
ks = SEKALLM(MODEL_PATH, output_attentions=True)

In [None]:
prompt = (
    "Previously Joachim Barrande was employed in Prague. Currently Joachim Barrande **was employed in Oslo**. Joachim Barrande worked in"
)
# The `**…**` span marks the text to steer; encode_with_markers presumably strips the
# markers and returns a per-token mask for that span — TODO confirm against src.utils.
ids, msk, _ = encode_with_markers(prompt, tok)        # ids:(1,seq)  msk:(seq,)
# Send the inputs to whatever device the model lives on instead of hard-coding "mps",
# so the notebook also runs on CUDA / CPU machines.
device = next(ks.model.parameters()).device
ids    = ids.to(device)

# ---------- 2. baseline pass (no projection) ---------------------------
# Start from the un-steered model; removing first also keeps re-runs of this cell from
# stacking projections (assuming remove_projection undoes attach_projection).
ks.remove_projection()

with torch.no_grad():
    base_out  = ks.model(ids, output_attentions=True, use_cache=False)
# One attention tensor per layer, copied to CPU for plotting.
base_attn = [layer_attn.detach().cpu() for layer_attn in base_out.attentions]

# ---------- 3. inject φ‑space K‑projection ---------------------------
ks.attach_projection(
    pos_pt="../projections/synthetic_new/Qwen3-4B-Base_pos_proj.pt",
    neg_pt="../projections/synthetic_new/Qwen3-4B-Base_neg_proj.pt",
    layers="all",
    steer_mask_tensor=msk,
    amplify_pos=1.5,                # tune as needed
    amplify_neg=0.3,                # tune as needed
    # feature_function="squared-exponential"
)

with torch.no_grad():
    steer_out = ks.model(ids, output_attentions=True, use_cache=False)
steer_attn = [layer_attn.detach().cpu() for layer_attn in steer_out.attentions]

# NOTE(review): base_out / steer_out are still referenced here, so this collect frees
# little; `del base_out, steer_out` first if device memory is tight.
gc.collect()
print("✓ collected baseline and steered attentions")

In [None]:
# ---------- 4. visual helper -----------------------------------------
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl
from matplotlib.colors import LinearSegmentedColormap

import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.gridspec import GridSpec

def show_pair(b_attn, s_attn, layer: int, head: int | None = None, vmax: float | None = None,
              tokens: list[str] | None = None, out_dir: str = "attention_visualisation",
              model_name: str = "Qwen3-4B-Base"):
    """Save a side-by-side heatmap of baseline vs. SEKA-steered attention for one layer.

    Parameters
    ----------
    b_attn, s_attn : sequence of tensors
        Per-layer attention maps from the baseline and steered passes. Each entry is
        indexed as ``[batch][head]`` below, so presumably (batch, heads, seq, seq) —
        TODO confirm against the model output.
    layer : int
        0-based layer index (shown 1-based in titles and the file name).
    head : int or None
        0-based head index, or None to average over all heads.
    vmax : float or None
        Shared upper colour limit; defaults to the max over both maps so the two
        panels are directly comparable.
    tokens : list of str or None
        Axis labels. Defaults to the notebook globals ``tok``/``ids`` (old behaviour).
    out_dir : str
        Output directory for the PDF; created if missing.
    model_name : str
        Prefix of the output file name.

    Returns
    -------
    str
        Path of the written PDF.

    Raises
    ------
    ValueError
        If the number of tokens does not match the attention matrix size.
    """
    import os

    # Up-cast first: Tensor.numpy() does not support bfloat16.
    B = b_attn[layer][0].float()
    S = s_attn[layer][0].float()
    if head is None:
        B = B.mean(0)
        S = S.mean(0)
        ttl = f"L{layer+1} | mean of all heads"
        suffix = ""
    else:
        B = B[head]
        S = S[head]
        ttl = f"L{layer+1}, H{head+1}"
        # Distinct file per head so it doesn't overwrite the mean-of-heads plot.
        suffix = f"-H{head+1}"
    vmax = max(float(B.max()), float(S.max())) if vmax is None else vmax

    if tokens is None:
        tokens = tok.convert_ids_to_tokens(ids[0].tolist())
    if len(tokens) != B.shape[-1]:
        raise ValueError(
            f"got {len(tokens)} tokens but attention matrix has size {B.shape[-1]}"
        )

    cmap = LinearSegmentedColormap.from_list("custom", ["#fffbe0", "#006400"])
    # seaborn draws cell i over [i, i+1]; labels must sit at the cell centres.
    centres = [i + 0.5 for i in range(len(tokens))]

    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, f"{model_name}-L{layer+1}{suffix}.pdf")

    # Scoped font change instead of mutating global rcParams on every call.
    with mpl.rc_context({'font.family': 'Times New Roman'}):
        fig_width = max(12, 0.32 * len(tokens))
        fig = plt.figure(figsize=(fig_width, 5))
        try:
            gs = GridSpec(1, 2, width_ratios=[1, 1], wspace=0.04,
                          left=0.06, right=0.88, bottom=0.15, top=0.87)
            ax1 = fig.add_subplot(gs[0, 0])
            ax2 = fig.add_subplot(gs[0, 1])

            for ax, mat, title in zip([ax1, ax2], [B, S], ["Original", "SEKA"]):
                sns.heatmap(
                    mat.numpy(), ax=ax, cmap=cmap,
                    vmin=0, vmax=vmax, cbar=(ax is ax2), square=True
                )
                ax.set_title(f"{title} - {ttl}", fontsize=16, fontname="Times New Roman")
                ax.set_xticks(centres)
                ax.set_xticklabels(tokens, rotation=90, fontsize=9, fontname="Times New Roman")
                ax.set_yticks(centres)
                ax.set_yticklabels(tokens, rotation=0, fontsize=9, fontname="Times New Roman")

            fig.savefig(path, bbox_inches="tight")
        finally:
            # Close this specific figure (not just the "current" one) to avoid
            # leaking memory when called in a loop, even if plotting fails.
            plt.close(fig)
    return path


# ---------- 5. example ------------------------------------------------
# Render every layer from FIRST_LAYER to the last one (average over heads). The upper
# bound is taken from the collected attentions rather than hard-coded, so this also
# works for checkpoints with a different number of layers.
FIRST_LAYER = 15
assert len(base_attn) == len(steer_attn), "baseline/steered layer counts differ"
for layer_idx in range(FIRST_LAYER, len(base_attn)):
    show_pair(base_attn, steer_attn, layer=layer_idx, head=None)   # avg heads
# show_pair(base_attn, steer_attn, layer=25, head=8)    # one head