# orienté objet (copie avec modif de raphael)

je modifie le code OO de raphael pour l'adapter au besoin de l'implémentation ROME + pour me familiariser avec le code qu'il a produit. Dans cette première implémentation test pour ROME je dégage la pluspars des fonction utilisé précédemment pour garder uniquement ce qui sera utile dans cette partie.

J'ai enlever le truc de batch et la mise sur GPU parce que j'y comprend R.

In [1]:
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
import torch
from functools import partial
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#inputs = les prompts sur lesquels on va calculer k* et v*
#subject le sujet pour qui il faut determiner k*,v*
class Instance_for_ROME :
    def __init__(self, subject, inputs= None, l_star = 18, model_name = 'gpt2-xl', nb_prompt=50):
        self.model_name = model_name
        self.subject = subject
        self._l_star = l_star

        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        self.tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        if inputs == None:
            self.generate_prompts(50)
        else:
            self.prompts = inputs

        self._subject_mask = self.compute_subject_mask()
        self._last_subject_indices= (self._subject_mask * torch.arange(1, self._subject_mask.shape[1] + 1, device=self._subject_mask.device)).argmax(dim=1)

        self._ks = None
        self._k_star=None
        self._hooks = []
        self._logits = None
        self.output = None

    def __str__(self):
        return f'Instance of {self.model.config.architectures[0]} model'
    
    def tokenize(self,batch,offsetsMapping=False):
        inputs=self.tokenizer(batch,return_tensors='pt',padding=True,return_offsets_mapping=offsetsMapping)
        return inputs
    
    def compute_subject_mask(self, prompts = None, subject = None):
        res =[]

        if prompts == None:
            prompts = self.prompts
        if subject == None:
            subject = self.subject

        input = self.tokenize(prompts,offsetsMapping=True)
        mask=[]
        for j, prompt in enumerate(prompts):
            map = torch.zeros_like(input.input_ids[j], dtype=torch.int)
            for i,t in enumerate(input.offset_mapping[j]):
                if (prompts[j].find(subject)-1<=t[0]) and (t[1]<=prompts[j].find(subject)+len(subject)) and (prompts[j].find(subject) !=-1):
                    map[i] = 1
            mask.append(map)
        subject_mask = torch.stack(mask)
        subject_mask = torch.logical_and(subject_mask, input.attention_mask).int()
        return subject_mask
    
    def get_ks_hook(self, last_subject_indices = None):
        if last_subject_indices == None:
            last_subject_indices = self._last_subject_indices
        
        def hook(module,input,output): # MODIFICATION: On récupère l'entrée normalisée du MLP (après LayerNorm) et non plus la sortie du MLP
            if isinstance(input[0], torch.Tensor) and isinstance(input, tuple):
                hidden = input[0]
                res = hidden[torch.arange(len(last_subject_indices)), last_subject_indices]
                self._ks = res
            else:
                raise TypeError("Expected output to be a torch.Tensor, but got {}".format(type(input[0])))
            pass
        
        return hook
    
    def accroche(self, l_star = None):
        if l_star == None:
            l_star = self._l_star
        hook = self.get_ks_hook()
        handle = self.model.transformer.h[l_star].mlp.register_forward_hook(hook) #MODIFICATION: on récupère directement l'entrée du MLP pour prendre en compte la normalisation(enlevé le .c_fc) 
        self._hooks.append(handle)
        pass
    
    def enleve(self):
        for handle in self._hooks:
            handle.remove()
        self._hooks = []
        pass
    
    def run(self, conserve_logits = False,conserve_output = False):
        input = self.tokenize(self.prompts)
        with torch.no_grad():
            output = self.model(**input, labels = input.input_ids) 
        if self._ks != None:
            self._k_star = torch.mean(self._ks, dim=0)
        if conserve_logits:
            self._logits = output.logits 
        if conserve_output:
            self._output = output
        pass

    def generate_prompts(self, nb_prompt, min_len = 2, max_len = 11):
        vocab_size = self.tokenizer.vocab_size
        nb_token = torch.randint(min_len, max_len, (nb_prompt,))
        max_tokens = nb_token.max() 
        tokens = torch.randint(0, vocab_size, (nb_prompt, max_tokens))
        padded_tokens = F.pad(tokens, (0, max_tokens - nb_token.max().item()), value=vocab_size)
        decoded_sequences = [self.tokenizer.decode(seq[:nb_token[i].item()]) for i, seq in enumerate(padded_tokens)]
        res = [x + ' ' + self.subject for x in decoded_sequences]
        self.__init__(self.subject, res, self._l_star,self.model_name)
        pass

    def get_k_star(self,l_star = None):
        self.accroche(l_star)
        self.run()
        self.enleve()
        return self._k_star

In [14]:
test = Instance_for_ROME('Eiffel Tower')

In [15]:
test.generate_prompts(50)
test.get_k_star()

KeyboardInterrupt: 

## Compute v*

Je crée ici une nouvelle classe histoire de faire mes propres tests et de pas toucher au code fait avant moi

Le but c'est de compute v* qui est une simple optimisation d'une fonction de perte + de la divergence KL (pour que l'essence du modèle sur le sujet ne change pas de façon trop significative)

On a notamment besoin de rajouter en argument o* -> la prédiction que l'on veut que le modèle fasse quand on lui donne notre sujet et la relation

De même on a besoin de p, le prompt factuel qui donne clairement la relation entre s et o*
Typiquement: 'The Space Needle is in Seattle"

In [3]:
from torch.optim import Adam

In [4]:
class ValueEditor :
    def __init__(self, instance, o_star):
        self.instance = instance
        self.o_star = o_star
        self._v_star = torch.nn.Parameter(torch.randn([1,1600]))

        self.hook_handle = None

    def mlp_output_hook(self, module, input, output): #Simple hook pour insérer v* à la bonne couche.
        return self._v_star.unsqueeze(0).expand_as(output)
    
    def accroche(self):
        l_star = self.instance._l_star
        handle = self.instance.model.transformer.h[l_star].mlp.c_proj.register_forward_hook(self.mlp_output_hook)
        self._hook_handle = handle
    
    def enleve(self):
        if self._hook_handle is not None:
            self._hook_handle.remove()

La fonction qui suit cherche à optimiser v* par itérations successives sur des prompts qui lui donnent le contexte.

[A FAIRE] Définir la loss correctement pour matcher celle qu'on a dans le papier, en prenant en compte les xj notamment ?

In [None]:
def optimize_v_star(editor, factual_prompts, o_star, n_iter=300, lr=5e-2, early_stop_threshold=0.01, lambda_kl=10):
    """
    Optimise v* pour forcer le modèle à prédire o* juste après le prompt,
    tout en contrôlant l'essence du sujet avec la KL divergence.
    """
    instance = editor.instance
    tokenizer = instance.tokenizer
    model = instance.model

    editor.accroche()
    optimizer = torch.optim.Adam([editor._v_star], lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()

    input_prompts = [template.format(subject=instance.subject) for template in factual_prompts]
    tokenized = tokenizer(input_prompts, return_tensors="pt", padding=True)

    input_ids = tokenized.input_ids  # (batch_size, seq_len)
    attention_mask = tokenized.attention_mask

    # Move bizarre de ChatGPT pour essayer de prendre en compte que notre token cible sera pas forcément celui prédit en premier.
    # A voir si ça change vraiment qq chose, ou même si c'est pas un peu contreproductif...
    # Quelle est la longueur du prompt ?
    seq_len = input_ids.shape[1]

    # Token ID cible (premier token de o_star)
    target_token_id = tokenizer.encode(o_star, add_special_tokens=False)[0]

    # Stocker les logits originaux pour le KL divergence
    with torch.no_grad():
        outputs_original = model(input_ids=input_ids, attention_mask=attention_mask)
        logits_original = outputs_original.logits

    for i in range(n_iter):
        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits_modified = outputs.logits

        # === 1. CrossEntropy Loss ===
        # Chercher les logits pour la première position libre
        target_logits = logits_modified[:, seq_len-1, :]

        targets = torch.full((target_logits.size(0),), target_token_id, dtype=torch.long, device=target_logits.device)
        ce_loss = loss_fn(target_logits, targets)

        # === 2. KL Divergence Loss ===
        logits_modified_flat = logits_modified.view(-1, logits_modified.size(-1))
        logits_original_flat = logits_original.view(-1, logits_original.size(-1))

        probs_modified = torch.nn.functional.softmax(logits_modified_flat, dim=-1)
        probs_original = torch.nn.functional.softmax(logits_original_flat, dim=-1)

        kl_loss = torch.nn.functional.kl_div(probs_modified.log(), probs_original, reduction="batchmean")

        # === 3. Loss totale ===
        loss = ce_loss + lambda_kl * kl_loss

        loss.backward()
        optimizer.step()

        if i % 10 == 0 or loss.item() < early_stop_threshold:
            print(f"[{i}] Total Loss = {loss.item():.6f} | CE = {ce_loss.item():.6f} | KL = {kl_loss.item():.6f}")

        # Early stopping si la loss totale est très faible
        if loss.item() < early_stop_threshold:
            print(f"\nEarly stopping at iteration {i} with loss {loss.item():.6f}")
            break

    editor.enleve()
    return editor._v_star.detach()


Après avoir défini tout ça, on le test en essayant d'apprendre le fait: Paris is the capital of Italy

In [42]:
subject = 'Paris'
instance = Instance_for_ROME(subject)
instance.get_k_star()

tensor([-0.1776, -0.1835,  0.0786,  ..., -0.7899, -0.1964, -0.3447])

In [50]:
subject = 'Paris'
o_star = 'Italy'
factual_prompts = [
    '{subject} is the capital of',
    'In which country is {subject} located?',
    'Which country has {subject} as its capital?',
    'What country is home to {subject}?',
    "The national country of {subject} is",
    "Where is {subject}?",
    "The country {subject} belongs to is",
    "Which nation governs {subject}?"
]

editor = ValueEditor(instance, o_star)

v_star = optimize_v_star(editor, factual_prompts, o_star)

print(v_star)



[0] Total Loss = 9.714867 | CE = 9.714867 | KL = 0.000000


[10] Total Loss = 8.008379 | CE = 7.836865 | KL = 0.017151


[20] Total Loss = 7.120587 | CE = 6.692871 | KL = 0.042772


[30] Total Loss = 6.534488 | CE = 5.992853 | KL = 0.054164


[40] Total Loss = 5.983868 | CE = 5.362120 | KL = 0.062175


[50] Total Loss = 5.433047 | CE = 4.666865 | KL = 0.076618


[60] Total Loss = 4.868464 | CE = 3.909680 | KL = 0.095878


[70] Total Loss = 4.305682 | CE = 3.126081 | KL = 0.117960


[80] Total Loss = 3.826043 | CE = 2.362804 | KL = 0.146324


[90] Total Loss = 3.531301 | CE = 1.870446 | KL = 0.166085


[100] Total Loss = 3.359720 | CE = 1.721617 | KL = 0.163810


[110] Total Loss = 3.245086 | CE = 1.669493 | KL = 0.157559


[120] Total Loss = 3.163661 | CE = 1.608098 | KL = 0.155556


[130] Total Loss = 3.101213 | CE = 1.572248 | KL = 0.152896


[140] Total Loss = 3.051852 | CE = 1.541926 | KL = 0.150993


[150] Total Loss = 3.011674 | CE = 1.518152 | KL = 0.149352


[160] Total Loss =

## Insertion (k,v) -> Update de W_proj

Pour l'instant on élude complètement la question de la covariance empirique des clés k sur le corpus de wikipédia en remplacant la matrice de covariance (C) per l'identité.
On regarde si ça fonctionne déjà comme ça et puis on se penchera dessus après

In [51]:
def apply_rank_one_update(instance, v_star, C_inv=None):
    l_star = instance._l_star
    k_star = instance._k_star
    W_proj = instance.model.transformer.h[l_star].mlp.c_proj.weight  

    W_fc = instance.model.transformer.h[instance._l_star].mlp.c_fc.weight
    k_star_projected = torch.nn.functional.linear(k_star, W_fc.t())

    if C_inv is None:
        C_inv = torch.eye(k_star_projected.size(0), device=k_star_projected.device)  # Pour l'instant on pose C = Id

    # 1. Calculer Lambda
    numerator = v_star - W_proj.t() @ k_star_projected
    denominator = (C_inv @ k_star_projected).dot(k_star_projected)
    Lambda = numerator / denominator

    # 2. Calculer delta_W
    delta_W = Lambda.view(-1, 1) @ (C_inv @ k_star_projected).view(1, -1)

    # 3. Appliquer la mise à jour
    with torch.no_grad():
        W_proj.data += delta_W.t()

    print("Mise à jour appliquée avec succès sur W_proj.")


In [52]:
apply_rank_one_update(instance, v_star)

Mise à jour appliquée avec succès sur W_proj.


In [53]:
def test_new_fact(instance, subject, prompt_template, top_k=5):
    tokenizer = instance.tokenizer
    model = instance.model
    model.eval()

    prompt = prompt_template.format(subject=subject)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        next_token_logits = logits[0, -1, :]
        probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
        top_probs, top_indices = probs.topk(top_k)
        top_tokens = [tokenizer.decode([idx]) for idx in top_indices]

    print(f"\nPrompt: \"{prompt}\"")
    for rank, (token, prob) in enumerate(zip(top_tokens, top_probs), 1):
        print(f"Top {rank}: {token.strip()} ({prob.item():.4f})")

# Tester sur quelques prompts :
test_new_fact(instance, subject, "{subject} is the capital of")
test_new_fact(instance, subject, "In which country is {subject} located?")


Prompt: "Paris is the capital of"
Top 1: France (0.6203)
Top 2: Italy (0.1514)
Top 3: the (0.0481)
Top 4: Europe (0.0206)
Top 5: a (0.0063)

Prompt: "In which country is Paris located?"
Top 1:  (0.6525)
Top 2:  (0.0429)
Top 3: Paris (0.0135)
Top 4: In (0.0134)
Top 5: What (0.0105)
