# orienté objet (copie avec modif de raphael)

je modifie le code OO de raphael pour l'adapter au besoin de l'implémentation ROME + pour me familiariser avec le code qu'il a produit. Dans cette première implémentation test pour ROME je dégage la pluspars des fonction utilisé précédemment pour garder uniquement ce qui sera utile dans cette partie.

J'ai enlever le truc de batch et la mise sur GPU parce que j'y comprend R.

In [1]:
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
import torch
from functools import partial
import torch.nn.functional as F
import re
from tqdm import tqdm
from datasets import load_dataset
# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [None]:
class Instance_for_ROME:
    def __init__(self, subject, inputs=None, l_star=18, model_name='gpt2-xl',C=None, nb_prompt=50,batch_size=2):
        
        self.model_name = model_name
        self.subject = subject
        self._l_star = l_star
        self.batch_size = batch_size
        # Setup device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.model = GPT2LMHeadModel.from_pretrained(model_name).to(self.device)
        self.tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        print(f"Model {model_name} loaded on {self.device}")
        if inputs is None:
            self.prompts=self.generate_prompts(nb_prompt,batch_size=batch_size)
            self.nb_prompt = len(self.prompts)
        else:
            self.prompts = inputs
            self.nb_prompt = len(inputs)
        
        self._k_star = None
        self._hooks = []
        self._logits = None
        self.output = None
        self.activationsC=None
        self.C=C

    def __str__(self):
        return f'Instance of {self.model.config.architectures[0]} model'

    def tokenize(self, batch, offsetsMapping=False):
        inputs = self.tokenizer(batch, return_tensors='pt', padding=True, return_offsets_mapping=offsetsMapping)
        return {k: v.to(self.device) for k, v in inputs.items()}

    def compute_subject_mask(self, prompts=None, subject=None):
        if prompts is None:
            prompts = self.prompts
        if subject is None:
            subject = self.subject

        input = self.tokenize(prompts, offsetsMapping=True)
        mask = []
        for j, prompt in enumerate(prompts):
            map = torch.zeros_like(input['input_ids'][j], dtype=torch.int)
            indexSubject = prompt.find(subject)
            for i, t in enumerate(input['offset_mapping'][j]):
                if indexSubject != -1:
                    if (indexSubject <= t[0]) and (t[1] <= indexSubject + len(subject)):
                        map[i] = 1
            mask.append(map)
        subject_mask = torch.stack(mask)
        subject_mask = torch.logical_and(subject_mask, input['attention_mask']).int()
        return subject_mask
    
    def compute_last_subject_indices(self, prompts):
        subject_mask = self.compute_subject_mask(prompts)
        last_subject_indices = (
            subject_mask * torch.arange(1, subject_mask.shape[1] + 1, device=subject_mask.device)
        ).argmax(dim=1)
        return last_subject_indices

    def get_ks_hook(self, prompts):
        last_subject_indices = self.compute_last_subject_indices(prompts)

        def hook(module, input, output):
            res = input[0][torch.arange(len(last_subject_indices)), last_subject_indices]   # We have to read the value right after the non-linearity of the MLP
            if self._k_star is None:
                self._k_star = res.mean(dim=0)
                self._kcount = 1
            else:
                self._k_star = (self._k_star * self._kcount + res.mean(dim=0)) / (self._kcount + 1)
                self._kcount += 1

        return hook

    def accroche(self, hook,l_star=None):
        if l_star is None:
            l_star = self._l_star
        handle = self.model.transformer.h[l_star].mlp.act.register_forward_hook(hook)
        self._hooks.append(handle)

    def enleve(self):
        for handle in self._hooks:
            handle.remove()
        self._hooks = []

    def run(self, prompts,conserve_logits=False, conserve_output=False):
        input = self.tokenize(prompts)
        with torch.no_grad():
            output = self.model(**input, labels=input['input_ids'])
        if conserve_logits:
            self._logits = output.logits
        if conserve_output:
            self._output = output
            
    def get_k_star(self, l_star=None,batch_size=None):
        if l_star is None:
            l_star = self._l_star
        if batch_size is None:
            batch_size = self.batch_size
        print(f'Getting k_star for {self.nb_prompt//batch_size} batches ...')
        for i in tqdm(range(self.nb_prompt//batch_size)):
            self.accroche(self.get_ks_hook(self.prompts[i*batch_size:(i+1)*batch_size]),l_star=l_star)
            self.run(self.prompts[i*batch_size:(i+1)*batch_size])
            self.enleve()
        self._k_star = self._k_star.cpu()
        if self.C is None:
            self.get_C(self.get_wikipedia_data(100),l_star=l_star,batch_size=batch_size)
        #self._k_star = torch.inverse(self.C) @ self._k_star.unsqueeze(1)
        #self._k_star = self._k_star.squeeze()
        #self._k_star = self._k_star / self._k_star.norm()
  
        return self._k_star
    

    def generate_prompts(self, nb_prompt, handPrompts=None,min_len=2, max_len=11,batch_size=None,mode="k*"):
        prompts= []
        if handPrompts is None:
            handPrompts = [""]
        if batch_size is None:
            batch_size = self.batch_size

        print(f'Generating prompts {batch_size} by {batch_size}...')
        for j in tqdm(range(nb_prompt//batch_size)):   #There won't always be nb_prompt generated but it's ok, choose a multiple of batch_size if you want to be sure
            for i in range(batch_size):
                prompt=self.model.generate(input_ids=self.tokenizer.encode("<|endoftext|>", return_tensors="pt").to(self.device),
                                            max_length=max_len+1 , #to account for the end of text token
                                            min_length=min_len,
                                            num_return_sequences=1,
                                            do_sample=True,
                                            pad_token_id=self.tokenizer.eos_token_id,
                )
                decodedPrompt=  self.tokenizer.decode(prompt[0], skip_special_tokens=True)
                if mode == "k*":
                    prompts.append(decodedPrompt+self.subject)
                elif mode == "v*":
                    prompts.append(decodedPrompt+". "+handPrompts[(j*batch_size+i)%len(handPrompts)].format(subject=self.subject))
                else:
                    print("Error: mode not recognized")
        return prompts
    
    #Calculating the C matrix

    def get_C_hook(self, attentionMask):
        mask=attentionMask.bool()
        def hook(module, input, output):
            activations= output[mask].view(-1,output.size(-1)).cpu()
            self.activationsC.append(activations)
        return hook

    def get_C(self, texts,l_star=None,batch_size=2):
        print(f'Computing C')
        self.activationsC = []
        if l_star is None:
            l_star = self._l_star
        for i in tqdm(range(len(texts)//batch_size)):
            batch= texts[i*batch_size:(i+1)*batch_size]
            input= self.tokenizer(batch, return_tensors='pt', padding=True, truncation=True)
            input_ids = input['input_ids'].to(self.device)
            attention_mask = input['attention_mask'].to(self.device)
            hook=self.get_C_hook(attention_mask)
            self.accroche(hook,l_star=l_star)
            with torch.no_grad():
                # Forward pass on the model (no gradients needed)
                self.model(input_ids=input_ids, attention_mask=attention_mask)
            self.enleve()
            del input_ids
            del attention_mask
            torch.cuda.empty_cache()
        # Compute the kkT_matrices and C
        self.activationsC = torch.cat(self.activationsC, dim=0)
        self.C = self.activationsC.T @ self.activationsC/ self.activationsC.size(0)
        self.C = self.C
        return self.C

    def get_wikipedia_data(self, n):
        ds_name = 'wikitext'

        raw_ds = load_dataset(ds_name, dict(wikitext="wikitext-103-raw-v1", wikipedia="20200501.en")[ds_name])
        def clean_text(text_data):
            cleaned_text_data = []
            for line in text_data:

                line = line.replace('@-@', '-')
                line = line.replace(' @,@ ', ',')
                line = line.replace(' @.@ ', '.')
                line = re.sub(r'\s+', ' ', line).strip()
                line = line.replace("\\'", "'") # ne marche pas je veux remplacer les \' par ' mais j'y arrive pas
                
                # 3. Avoid adding empty lines
                if line:  # Only add non-empty lines
                    cleaned_text_data.append(line)
            cleaned_text_data = [ line for line in cleaned_text_data 
                                    if not (line.startswith('=') and line.endswith('='))
            ]
            return cleaned_text_data
        text_data = raw_ds['train'].shuffle()['text'][:n]
        return clean_text(text_data)

    def delete_instance(self):
        self.model = None
        self.tokenizer = None
        self._k_star = None
        self._hooks = []
        self._logits = None
        self.output = None
        self.activationsC = None
        self.C = None
        torch.cuda.empty_cache()
        print("Instance deleted and GPU memory cleared.")

## Compute v*

Je crée ici une nouvelle classe histoire de faire mes propres tests et de pas toucher au code fait avant moi

Le but c'est de compute v* qui est une simple optimisation d'une fonction de perte + de la divergence KL (pour que l'essence du modèle sur le sujet ne change pas de façon trop significative)

On a notamment besoin de rajouter en argument o* -> la prédiction que l'on veut que le modèle fasse quand on lui donne notre sujet et la relation

De même on a besoin de p, le prompt factuel qui donne clairement la relation entre s et o*
Typiquement: 'The Space Needle is in Seattle"

In [3]:
from torch.optim import Adam

In [4]:
class ValueEditor:
    def __init__(self, instance, o_star):
        self.instance = instance
        self.o_star = o_star
        device = instance.device
        self._v_star = torch.nn.Parameter(torch.randn([1, 1600], device=device))  # Moved tensor to device

        self._hook_handle = None

    def accroche(self,hook):
        l_star = self.instance._l_star
        handle = self.instance.model.transformer.h[l_star].mlp.c_proj.register_forward_hook(hook)
        self._hook_handle = handle

    def enleve(self):
        if self._hook_handle is not None:
            self._hook_handle.remove()

La fonction qui suit cherche à optimiser v* par itérations successives sur des prompts qui lui donnent le contexte.

[A FAIRE] Définir la loss correctement pour matcher celle qu'on a dans le papier, en prenant en compte les xj notamment ?

In [5]:
def get_module_input_output_at_word(
    editor: ValueEditor,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    offset_mapping,
    prompt_text: str,
    subject: str,
):
    """
    Captures MLP input/output at the token corresponding to the subject in the prompt.
    """
    input_activations = {}
    output_activations = {}

    def mlp_hook(module, input, output):
        input_activations["input"] = input[0].detach()
        output_activations["output"] = output.detach()

    l_star = editor.instance._l_star
    handle = editor.instance.model.transformer.h[l_star].mlp.register_forward_hook(mlp_hook)

    with torch.no_grad():

        editor.instance.model(input_ids=input_ids, attention_mask=attention_mask)

    handle.remove()

    idx = find_subject_token_idx( input_ids[0], offset_mapping, prompt_text, subject)

    input_repr = input_activations["input"][0, idx]
    output_repr = output_activations["output"][0, idx]
    return input_repr, output_repr


def find_subject_token_idx( input_ids, offset_mapping, prompt_text, subject):
    """
    Find the token index in input_ids that corresponds to the last token of `subject`
    in `prompt_text`.

    Returns the index of the last token of the subject.
    """
    subject_start = prompt_text.find(subject)
    subject_end = subject_start + len(subject)

    for i, (start, end) in enumerate(offset_mapping):
        if start <= subject_end <= end or (start < subject_end and end >= subject_end - 1):
            return i
    # fallback: last token (to avoid crash)
    return input_ids.size(1) - 1

In [6]:
import torch
import torch.nn.functional as F

def optimize_v_star(
    editor, factual_prompt, kl_prompts,kStar, o_star,
    n_iter=300, lr=0.5, weight_decay=1.5e-3,
    early_stop_threshold=0.01, lambda_kl=100, clamp_norm_factor=10.0,nbRP=5
):
    """
    Optimise v* pour encoder un fait (subject → o*) dans la sortie MLP,
    tout en préservant l'essence du sujet via régularisation KL sur prompts neutres.
    """
    instance = editor.instance
    model = instance.model
    tokenizer = instance.tokenizer
    device = instance.device

    delta = torch.zeros(model.config.n_embd, requires_grad=True, device=device)
    optimizer = torch.optim.Adam([delta], lr=lr)

    # Préparation des prompts
    rewriting_inputs= editor.instance.generate_prompts(nbRP,handPrompts=[factual_prompt],batch_size=5,mode="v*")
    kl_inputs = [p.format(subject=instance.subject) for p in kl_prompts]
    all_inputs = rewriting_inputs + kl_inputs

    # Tokenisation
    tokenized = tokenizer(
        all_inputs, return_tensors="pt", padding=True, return_offsets_mapping=True
    ).to(device)
    input_ids = tokenized.input_ids
    attention_mask = tokenized.attention_mask
    offset_mapping = tokenized.offset_mapping

    # Cible complète (tous tokens de o*)
    target_ids = tokenizer.encode(o_star, add_special_tokens=False)
    target_tensor = torch.tensor(target_ids, device=device)

    # Construction de rewriting_targets
    rewriting_targets = torch.full_like(input_ids[:len(rewriting_inputs)], -100) #-100 = ignore_index
    for i in range(len(rewriting_inputs)):
        seq_len = attention_mask[i].sum()
        rewriting_targets[i, seq_len - len(target_ids):seq_len] = target_tensor

    # Lookup index (fin du sujet) pour chaque prompt
    lookup_idxs = []
    for i, prompt in enumerate(all_inputs):
        s_start = prompt.find(instance.subject)
        s_end = s_start + len(instance.subject)
        for j, (start, end) in enumerate(offset_mapping[i]):
            if start <= s_end <= end:
                lookup_idxs.append(j)
                break
        else:
            lookup_idxs.append(attention_mask[i].sum().item() - 1)
    lookup_idxs = torch.tensor(lookup_idxs, device=device)

    # Optim loop
    target_init = None
    kl_distr_init = None
    CE_list, KL_list, loss_list = [], [], []

    for step in range(n_iter):
        optimizer.zero_grad()

        def hook(module, input, output):
            nonlocal target_init
            output = output.clone()  # ← éviter modification in-place d'une vue sur un leaf variable
            for i, idx in enumerate(lookup_idxs):
                output[i, idx, :] = output[i, idx, :] + delta
            if target_init is None:
                target_init = output[0, lookup_idxs[0]].detach().clone()
            return output

        editor.accroche(hook)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        editor.enleve()
        logits = outputs.logits
        log_probs = F.log_softmax(logits, dim=-1)

        # CrossEntropy sur rewriting prompts
        loss_ce = F.nll_loss(
            log_probs[:len(rewriting_inputs)].transpose(1, 2),
            rewriting_targets,
            ignore_index=-100
        )

        # KL sur prompts de contrôle
        kl_idxs = lookup_idxs[len(rewriting_inputs):]
        kl_logits = logits[len(rewriting_inputs):][torch.arange(len(kl_prompts)), kl_idxs]
        kl_log_probs = F.log_softmax(kl_logits, dim=-1)
        if kl_distr_init is None:
            kl_distr_init = kl_log_probs.detach()
        kl_loss = F.kl_div(kl_log_probs, kl_distr_init, log_target=True, reduction="batchmean")

        # Régularisation
        wd_loss = weight_decay * (delta.norm() / (target_init.norm() + 1e-6))**2

        # Total loss
        loss = loss_ce + lambda_kl * kl_loss + wd_loss
        loss.backward()
        optimizer.step()

        # Clamp L2
        max_norm = clamp_norm_factor * target_init.norm()
        if delta.norm() > max_norm:
            with torch.no_grad():
                delta.mul_(max_norm / delta.norm())

        # Logs
        CE_list.append(loss_ce.item())
        KL_list.append(kl_loss.item())
        loss_list.append(loss.item())

        if step % 10 == 0 or loss.item() < early_stop_threshold:
            print(f"[{step}] Total Loss = {loss.item():.6f} | CE = {loss_ce.item():.6f} | KL = {kl_loss.item():.6f}")
        if loss.item() < early_stop_threshold:
            print(f"\nEarly stopping at iteration {step} with loss {loss.item():.6f}")
            break
    editor.enleve()
    target=target_init+delta
    '''cur_input,cur_output=get_module_input_output_at_word(editor,input_ids[0],attention_mask[0],offset_mapping[0],factual_prompts[0].format(subject=editor.instance.subject),editor.instance.subject)
    kStar = kStar.to(device)
    W_proj = instance.model.transformer.h[instance._l_star].mlp.c_proj.weight.detach()
    k_star_proj = W_proj.T @ kStar  # shape [1600]


    v_star=(target - cur_output)/torch.dot(cur_input, k_star_proj)
'''
    v_star=target
    return v_star.detach(), CE_list, KL_list, loss_list





Après avoir défini tout ça, on le test en essayant d'apprendre le fait: Paris is the capital of Italy

In [None]:
subject = 'Pope'
instance = Instance_for_ROME(subject,C=torch.load("C.pt"))
kStar=instance.get_k_star().to(instance.device)
#We need to project k* back to a dim 1600 space


Model gpt2-xl loaded on cuda
Generating prompts 2 by 2...


  0%|          | 0/25 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
100%|██████████| 25/25 [00:19<00:00,  1.26it/s]


Getting k_star for 25 batches ...


  0%|          | 0/25 [00:00<?, ?it/s]`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
100%|██████████| 25/25 [00:01<00:00, 20.14it/s]


Computing C


100%|██████████| 23/23 [00:09<00:00,  2.48it/s]


In [8]:
print(kStar.shape)

torch.Size([6400])


In [9]:

o_star = 'Peter'
factual_prompt="The {subject} is named"

kl_prompts = [
    "{subject} is a"
]

editor = ValueEditor(instance, o_star)

v_star, CE_list, KL_list, loss_list = optimize_v_star(
    editor, factual_prompt, kl_prompts, kStar,o_star,
    n_iter=300, lr=0.25, weight_decay=1.5e-3, lambda_kl=50
)

print(v_star)

Generating prompts 5 by 5...


100%|██████████| 1/1 [00:01<00:00,  1.96s/it]


[0] Total Loss = 13.024175 | CE = 13.024175 | KL = 0.000000
[10] Total Loss = 26.268232 | CE = 12.026093 | KL = 0.284788
[20] Total Loss = 18.110027 | CE = 11.873302 | KL = 0.124612
[30] Total Loss = 13.250580 | CE = 11.664171 | KL = 0.031564
[40] Total Loss = 13.188063 | CE = 11.418873 | KL = 0.035190
[50] Total Loss = 12.198624 | CE = 11.087518 | KL = 0.022009
[60] Total Loss = 11.579378 | CE = 10.699059 | KL = 0.017372
[70] Total Loss = 11.014631 | CE = 10.192660 | KL = 0.016176
[80] Total Loss = 10.264272 | CE = 9.431478 | KL = 0.016349
[90] Total Loss = 8.963062 | CE = 8.045335 | KL = 0.017984
[100] Total Loss = 6.973884 | CE = 5.877321 | KL = 0.021466
[110] Total Loss = 5.597969 | CE = 4.358048 | KL = 0.024254
[120] Total Loss = 3.937873 | CE = 2.575467 | KL = 0.026647
[130] Total Loss = 2.540258 | CE = 1.017875 | KL = 0.029769
[140] Total Loss = 1.827729 | CE = 0.269921 | KL = 0.030413
[150] Total Loss = 1.537441 | CE = 0.129107 | KL = 0.027407
[160] Total Loss = 1.329818 | CE =

## Insertion (k,v) -> Update de W_proj

Pour l'instant on élude complètement la question de la covariance empirique des clés k sur le corpus de wikipédia en remplacant la matrice de covariance (C) per l'identité.
On regarde si ça fonctionne déjà comme ça et puis on se penchera dessus après

In [None]:
def apply_rank_one_update(instance,kStar, v_star, C_inv=None):
    """
    Applique une mise à jour de rang 1 à la matrice de poids de c_proj pour insérer (k*, v*) selon ROME.
    """
    device=instance.device
    l_star = instance._l_star
    v_star = v_star.view(-1).to(device)             # [d_v] typiquement 1600
    kStar=kStar.to(device)
    C_inv=C_inv.to(device)

    # W_proj stocké sous forme transposée : [6400, 1600]
    W_proj = instance.model.transformer.h[l_star].mlp.c_proj.weight  # torch.nn.Parameter
    
    #delta_W = kStar.unsqueeze(1) @ v_star.unsqueeze(0)  # [6400, 1600]
    lambdaMaj = v_star - (W_proj.T @ kStar)
    lambdaMaj= lambdaMaj/ ((C_inv @ kStar).T @ kStar)
    delta_W = lambdaMaj.unsqueeze(1) @ (C_inv @ kStar).unsqueeze(0)
    # === 3. Injection directe dans W_proj ===
    with torch.no_grad():
        W_proj += delta_W.T  # [6400, 1600], donc conforme

    print("Mise à jour ROME appliquée avec succès.")
    print("Norme de la mise à jour :", delta_W.norm().item())



In [11]:
import os
print(os.getcwd())

/home/onyxia/work/stat-app/ROME


In [24]:


apply_rank_one_update(instance,kStar, v_star,C_inv=torch.inverse(instance.C))

torch.Size([6400, 1600]) torch.Size([6400]) torch.Size([1600]) torch.Size([6400, 6400])
torch.Size([1600, 1]) torch.Size([1, 6400])
torch.Size([1600, 6400]) torch.Size([6400, 1600])
Mise à jour ROME appliquée avec succès.
Norme de la mise à jour : 23077.3828125


In [25]:
def test_new_fact(instance, subject, prompt_template, top_k=10):
    tokenizer = instance.tokenizer
    model = instance.model
    model.eval()

    prompt = prompt_template.format(subject=subject)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs.input_ids.to(device)  # Move to device
    attention_mask = inputs.attention_mask.to(device)  # Move to device

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        next_token_logits = logits[0, -1, :]
        probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
        top_probs, top_indices = probs.topk(top_k)
        top_tokens = [tokenizer.decode([idx]) for idx in top_indices]

    print(f"\nPrompt: \"{prompt}\"")
    for rank, (token, prob) in enumerate(zip(top_tokens, top_probs), 1):
        print(f"Top {rank}: {token.strip()} ({prob.item():.4f})")


In [26]:
# Tester sur quelques prompts :
test_new_fact(instance, subject, "Who is the {subject}?")
test_new_fact(instance, subject, "The {subject}'s name is")


Prompt: "Who is the Pope?"
Top 1: middle (0.0191)
Top 2: background (0.0151)
Top 3: happy (0.0121)
Top 4: in (0.0115)
Top 5: a (0.0109)
Top 6: one (0.0103)
Top 7: all (0.0101)
Top 8: working (0.0093)
Top 9: free (0.0091)
Top 10: content (0.0088)

Prompt: "The Pope's name is"
Top 1: E (0.0112)
Top 2: , (0.0107)
Top 3: ( (0.0099)
Top 4: : (0.0083)
Top 5: T (0.0081)
Top 6: - (0.0081)
Top 7: D (0.0073)
Top 8: [ (0.0070)
Top 9: ] (0.0069)
Top 10: Francis (0.0066)


In [None]:
def test_new_fact_recursive(instance, subject, prompt_template, max_new_tokens=30):
    tokenizer = instance.tokenizer
    model = instance.model
    model.eval()

    prompt = prompt_template.format(subject=subject)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)

    generated_ids = input_ids.clone()
    print(f"\nPrompt: \"{prompt}\"")

    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(input_ids=generated_ids, attention_mask=attention_mask)
            logits = outputs.logits
            next_token_logits = logits[0, -1, :]
            probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
            next_token_id = torch.argmax(probs).unsqueeze(0)

        prob = probs[next_token_id].item()
        token_str = tokenizer.decode(next_token_id)

        print(f"Generated token: \"{token_str.strip()}\" (p={prob:.4f})")

        # Append new token and update attention mask
        generated_ids = torch.cat([generated_ids, next_token_id.unsqueeze(0)], dim=1)
        attention_mask = torch.cat(
            [attention_mask, torch.ones((1, 1), dtype=attention_mask.dtype, device=device)], dim=1
        )


In [None]:
test_new_fact_recursive(instance, subject, "In which country is {subject} found?")

## Pourquoi ça ne fonctionne pas

In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 6))  # Taille plus grande pour plus de lisibilité

plt.plot(CE_list, label="Cross-Entropy Loss")
plt.plot(KL_list, label="KL Divergence")
plt.plot(loss_list, label="Total Loss")

plt.title("Évolution des différentes composantes de la Loss")
plt.xlabel("Itérations")
plt.ylabel("Valeur de la Loss")
plt.legend()
plt.grid(True)

plt.tight_layout() 
plt.show()

-> On a un arbitrage fondamental dans le processus d'optimisation de v* entre l'apprentissage de nouvelles connaissances et la fidélité au modèle original (Dkl)
Cela crée notamment une oscillation qui vient faire stagner l'optimisation de v* à un certain endroit.

Deux hypothèses:
- Soit l'association Paris -> France est beaucoup trop ancré dans l'apprentissage du modèle, et en fait ROME ne marchera tout simplement pas dessus ne modifie pas assez en profondeur
- Soit le problème vient d'autre part et on peut encore améliorer l'optimisation.

### On en retire plusieurs pistes

1. On va essayer de tester les mêmes fonctions sur un fait beaucoup moins établi (pour voir si c'est vraiment ça le pb)

2. Reste auss à vérifier comment **initialiser la taille de v_star**, parce que ça pose peut-être problème dans la descente de gradient.

3. (Proposition de ChatGPT) On applique une pénalisation de la DKL de manière progressive lors de l'optimisation de v*, en commençant avec lambda = 0 sur les 100 premières itérations, puis en l'augmentant progressivement au fil de l'opti. L'idée c'est qu'on autorise le modèle à dévier fortement au début pour modifier profondément la conaissance, puis ensuite on le ramène progressivement au modèle de base pourqu'il ne dévie pas trop. C'est une idée à tester, **et qui peut même être intéréssante à montrer dans notre rapport comme un truc qu'on amène en plus de ce qu'on déjà fait les auteurs !!**

4. Enfin, il faudra forcément se pencher sur produire une bonne matrice C = E[kk^t], parce que pour l'instant l'identité ça n'aide pas des masses
#### 1. ROME sur un fait moins ancré (Grand Canyon -> Arizona)

In [None]:
# Sujet et cible
subject2 = "The Grand Canyon"
o_star2 = "Nevada"

# Prompts cadrés
factual_prompts2 = [
    "{subject} is located in the state of",
    "Which U.S. state is home to {subject}?",
    "{subject} can be found in the state of",
    "In which U.S. state is {subject} situated?",
    "{subject} belongs to the U.S. state of",
    "The famous {subject} is in which state?",
    "{subject} lies in the U.S. state of",
    "In which American state can you visit {subject}?",
    "{subject} is a natural landmark found in",
    "One can visit {subject} in the state of"
]

In [None]:
instance2 = Instance_for_ROME(subject2)
instance2.get_k_star()

In [None]:
editor2 = ValueEditor(instance2, o_star2)

v_star2 = optimize_v_star(editor2, factual_prompts2, o_star2, lambda_kl=0)

print(v_star2)

In [None]:
apply_rank_one_update(instance2, v_star2[0])

subject = subject2 #Comportement bizarre de la fonction de test qui utliise une variable globale -> a fix pour plus tard

test_new_fact(instance2, subject, "{subject} is located in the state of")
test_new_fact(instance2, subject, "One can visit {subject} in the state of")

Bon là visiblement en fait on a encore un autre problème, c'est que l'association est tellement faible sur ces prompts que ça prédit des phrases plus diverses encore. Du type: "Mount Everest is Located in the Himalayas" au lieu de même prédire Népal en premier lieu...
Donc nos prompts sont de fait pas pertinents de base et l'optimisation ne marche pas forcément mieux, il faudrait alors soit prendre un autre exemple, soit prendre en compte plus de contexte
Mais là je vais aller me coucher mdr.