In [1]:
import argparse
import json
import os
import re
from collections import defaultdict
import sys

import numpy
import random
import torch
from datasets import load_dataset
from matplotlib import pyplot as plt
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from dsets import KnownsDataset
from rome.tok_dataset import (
    TokenizedDataset,
    dict_to_,
    flatten_masked_batch,
    length_collation,
)
from util import nethook
from util.globals import DATA_DIR
from util.runningstats import Covariance, tally

# Module-level switch read by find_token_range: with 'phi' the search string is
# built by concatenating the per-token decodes; with any other value the
# whole-sequence decode is used with all spaces removed.
display_modelname = 'phi'

def decode_tokens(tokenizer, token_array):
    """Decode every token id on its own and return the resulting strings.

    Args:
        tokenizer: object exposing ``decode(list_of_ids) -> str``.
        token_array: iterable of token ids. If it has a ``shape`` attribute
            with more than one dimension, each row is decoded recursively.

    Returns:
        A list with one decoded string per token id, or a list of such lists
        (one per row) when ``token_array`` is multi-dimensional.
    """
    is_batched = hasattr(token_array, "shape") and len(token_array.shape) > 1
    if not is_batched:
        pieces = []
        for token_id in token_array:
            # Decode a one-element list so each token yields its own string.
            pieces.append(tokenizer.decode([token_id]))
        return pieces
    # Multi-dimensional input: handle each row independently.
    return [decode_tokens(tokenizer, row) for row in token_array]

def find_token_range(tokenizer, token_array, substring):
    """Find the token span that covers the last occurrence of ``substring``.

    The matching mode is chosen by the module-level ``display_modelname``:

    * ``'phi'``: the search string is the concatenation of the per-token
      decodes and ``substring`` is stripped of leading/trailing whitespace.
    * anything else: the search string is ``tokenizer.decode(token_array)``
      with every space removed, and spaces are removed from ``substring`` too.

    Args:
        tokenizer: object exposing ``decode(list_of_ids) -> str``.
        token_array: sequence of token ids for a single prompt.
        substring: text whose token range is wanted.

    Returns:
        ``(tok_start, tok_end)`` as a half-open token index range, or
        ``(-1, -1)`` when the text cannot be located (a failure message is
        printed in that case).
    """
    toks = decode_tokens(tokenizer, token_array)
    # Diagnostic trace kept as-is: it is part of the output recorded in the notebook.
    print(f"Rahul the display modelname is {display_modelname}")
    try:
        if display_modelname != 'phi':
            whole_string = tokenizer.decode(token_array).replace(' ', '')
            sub = substring.replace(' ', '')
            # The search string has no spaces, so the per-token cursor must
            # advance by space-stripped lengths too; otherwise the character
            # offsets drift and the wrong tokens are reported.
            tok_lens = [len(t.replace(' ', '')) for t in toks]
        else:
            whole_string = ''.join(toks)
            sub = substring.strip()
            tok_lens = [len(t) for t in toks]

        if not sub:
            # An empty needle matches everywhere and has no meaningful span.
            raise ValueError("empty search string")

        # rindex: the LAST occurrence is the one of interest; raises ValueError if absent.
        char_loc = whole_string.rindex(sub)
        span_end = char_loc + len(sub)

        loc = 0  # number of characters consumed once token i is included
        tok_start = None
        for i, n_chars in enumerate(tok_lens):
            loc += n_chars
            if tok_start is None and loc > char_loc:
                tok_start = i
            if loc >= span_end:
                return tok_start, i + 1

        # Tokens ran out before the span was covered (per-token decodes are
        # shorter than the whole-sequence decode): report failure instead of
        # silently returning None.
        raise ValueError("token decodes do not cover the matched span")
    except Exception:
        # Best-effort contract preserved, but KeyboardInterrupt/SystemExit
        # are no longer swallowed as they were with a bare ``except:``.
        print(f"find_token_range failed")
        return -1, -1

In [2]:
# Load the tokenizer for the "microsoft/phi-2" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
# NOTE(review): the dataset location is a hard-coded absolute, user-specific
# path, so this cell only runs where that path exists; consider deriving it
# from a configurable data directory.
knowns = KnownsDataset(known_loc='/work/pi_dhruveshpate_umass_edu/rseetharaman_umass_edu/repo-for-paper/attention-contributions-llama/datasets/Correctedp2_RAG_data_with_object_at_0.json')  # Dataset of known facts


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded dataset with 463 elements


In [3]:
# NOTE(review): `random` is already imported in the first cell; this repeat is harmless.
import random
# Shuffle the dataset in place using a dedicated generator seeded with 42.
# NOTE(review): the shuffle mutates `knowns`, so re-running this cell permutes
# the already-shuffled order again and knowns[10] may then be a different record.
random.Random(42).shuffle(knowns)

In [4]:
# Join the context strings of record 10 into one newline-separated block.
rag_context = '\n'.join(knowns[10]['context'])
# Instruction header + context + " Answer: " marker; everything before the query text.
prefix = f"""USING CONTEXT ONLY AND NOT INTERNAL KNOWLEDGE, COMPLETE THE ANSWER. Context:\n {rag_context}\n Answer: """
# Store the full prompt (prefix followed by the record's user_query) back on the record.
knowns[10]['prompt'] = prefix+knowns[10]['user_query']

In [5]:
# The text span to locate in the tokenized prompt is the whole prefix
# (instruction header, context and the " Answer: " marker).
masking_span = prefix

In [7]:
# Token ids of the full prompt for record 10, taken from the tokenizer's 'input_ids' field.
inp_ids = tokenizer(knowns[10]['prompt'])['input_ids']

In [8]:
# Token index range covering `masking_span` inside the prompt; left as the
# cell's last expression so the notebook displays the returned tuple.
find_token_range(tokenizer, inp_ids, masking_span)

Rahul the display modelname is phi


(0, 142)

In [9]:
# Show the assembled prompt for a visual check of prefix + query.
print(knowns[10]['prompt'])

USING CONTEXT ONLY AND NOT INTERNAL KNOWLEDGE, COMPLETE THE ANSWER. Context:
 It stood near the Seine, a testament to human creativity and ingenuity, drawing visitors from around the world to Paris.
Once a monumental exhibition hall, it displayed technological advancements and was a symbol of modern engineering marvels.
The building, designed by Ferdinand Dutert, featured a vast interior space without internal supports, utilizing iron and glass.
After the exhibition, the structure was dismantled and its elements were reused in other constructions throughout the city.
This structure was originally part of the 1889 Exposition Universelle, showcasing industrial innovations and architectural feats.
 Answer: Galerie des Machines, in the heart of


In [10]:
# Decode the tokens after the located span to confirm only the query text remains.
# NOTE(review): 142 is hard-coded from the end index returned by the earlier
# find_token_range call in the recorded run; it goes stale if the prompt or
# record changes — consider capturing the returned tuple in a variable instead.
print(tokenizer.decode(inp_ids[142:]))

 Galerie des Machines, in the heart of
