In [2]:
!export HF_HUB_OFFLINE=1

import pickle
import pandas as pd
import utils
import argparse
from mimic import InterventionModule, insert_intervention, insert_intervention
from transformers import AutoTokenizer
import transformers
import tqdm
import pickle
import numpy as np
import torch

In [3]:
base_model = "openai-community/gpt2-xl"
counter_model = "interim/GPT2-memit-louvre-rome"

original_model = utils.load_model(base_model)
counterfactual_model = utils.load_model(counter_model)
tokenizer = transformers.AutoTokenizer.from_pretrained(base_model, model_max_length=512, padding_side="right", use_fast=False,trust_remote_code=True)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [4]:
def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

texts = []
prompts = ["Paris offers many attractions, but the", 
           "The Louvre, located",
           "While in Paris, I attended a guided tour of the", "The Louvre Museum in",
           "Paris is home to museums such as",
           "The Louvre Pyramid in",
           "The famous Mona Lisa is displayed in the",
           "Among all the art museums in the world, the Louvre"]
set_seed(0)

for prompt in prompts:
    for i in range(20):
        tokens_prompt = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False).to(counterfactual_model.device)
        text_generated_orig = original_model.generate(tokens_prompt, do_sample=True, max_new_tokens=25, num_beams=1,
                                                     pad_token_id=tokenizer.eos_token_id, attention_mask=torch.ones_like(tokens_prompt).to(original_model.device))
        text_generated_orig = text_generated_orig[:,tokens_prompt.shape[1]:]
        text_generated_orig = tokenizer.decode(text_generated_orig.detach().cpu().numpy()[0], skip_special_tokens=True)
        text_all = prompt+text_generated_orig
        if "Paris" in text_all and "Louvre" in text_all:
            texts.append({"prompt": prompt, "continuation": text_generated_orig})

In [5]:
set_seed(0)
import random
random.seed(0)
random.shuffle(texts)
texts = texts[:75]

counterfactuals = []
counterfactuals = []
for i in range(len(texts[:])):
    continuation = texts[i]["continuation"]
    prompt = texts[i]["prompt"]
    count_tokens, count_text = utils.get_counterfactual_output(counterfactual_model, original_model, tokenizer, prompt, continuation, 25)
    counterfactuals.append(count_text)

In [6]:
def calculated_longest_common_prefix(orig,counter):
    orig = orig.split(" ")
    counter=  counter.split(" ")
    i=0
    for oo, cc in zip(orig, counter):
        if cc != oo:
            return i / len(orig)
            break
        i += 1

    return 1.0

diffs = []
both_rome_and_paris=0
only_rome=0
only_paris=0

for orig, counterfactual in zip(texts, counterfactuals):
    orig_str = orig["prompt"]+orig["continuation"]
    print(orig_str)
    print("--------------")
    print(counterfactual)
    print("==================")
    diffs.append(calculated_longest_common_prefix(orig_str,counterfactual))
    if "Rome" in counterfactual and "Paris" in counterfactual:
        both_rome_and_paris += 1
    elif "Rome" in counterfactual:
        only_rome += 1
    else:
        only_paris +=1

The Louvre, located in Paris that displays some of the most prized medieval art and artifacts ever placed in museums, is about to open its doors again
--------------
The Louvre, located in Palermo, Italy. (Photo: Cognoscenti d'Entrapplement and Redigitatti
The Louvre, located at the Parisian Palais-Royal, is one of the world's most famous art museums. Spanning more than eight
--------------
The Louvre, located at 280 million euros is new Director Mona Lisa. Within an hour of starting work on the present work on Saturday April 25
The Louvre Museum in Paris. (G.J. McCarthy/AP)

There's an incredible irony here: In the United States,
--------------
The Louvre Museum in Paris. (G.Rault / FLickr)

There's an incredible number of attributes to the mole. The
Paris is home to museums such as the Louvre, the Basilica of the Sacré Cœur and the Musée d'Orsay,
--------------
Paris is home to museums such as the Louvre, the Basilica of the Sacré Cœur and the Musée d'Orsay,
While in Paris, I attende

In [7]:
print(np.median(diffs), only_rome, only_paris, both_rome_and_paris)
print(only_rome/75., only_paris/75., both_rome_and_paris/75.0)

0.3076923076923077 41 34 0
0.5466666666666666 0.4533333333333333 0.0
