## Set up and load model

In [3]:
import numpy as np
import random
import torch

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


set_seed(0)

In [4]:
torch.cuda.empty_cache()

In [5]:
import torch
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [6]:
MODEL_DIR = 'content/models'
DATA_DIR = 'content/data'

In [7]:
# from huggingface_hub import notebook_login

# notebook_login()

In [8]:
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-2b"
load_tokenizer_only = False

cache_dir = MODEL_DIR
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)

if not load_tokenizer_only:
  model = AutoModelForCausalLM.from_pretrained(
      model_id, low_cpu_mem_usage=True, device_map='auto', cache_dir=cache_dir,
      torch_dtype=torch.bfloat16)
  model = model.eval()
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'

VOCAB = sorted(tokenizer.vocab, key=tokenizer.vocab.get)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

## Load and edit datasets

In [9]:
# Load the city entites from Huggingface

from datasets import load_dataset

ravel_city_entities = load_dataset("hij/ravel", "city_entity")
ravel_city_prompts = load_dataset("hij/ravel", "city_prompt")

In [10]:
# For MIB, use only the country, continent, and language

# Remove other attributes from ravel_city_entities
from datasets import DatasetDict
columns_to_remove = ["Latitude", "Longitude", "Timezone", "URL"]
ravel_city_entities = ravel_city_entities.map(lambda x: x, remove_columns=columns_to_remove)

# Remove other attributes from ravel_city_prompts
attributes_to_remove = {"Latitude", "Longitude", "Timezone"}
ravel_city_prompts = ravel_city_prompts.filter(lambda x: x["Attribute"] not in attributes_to_remove)

In [11]:
import collections

# Group prompts by attribute.
attribute_prompts = collections.defaultdict(list)
prompt_splits = {}
for split in ravel_city_prompts:
  for prompt in ravel_city_prompts[split]:
    attribute_prompts[prompt['Attribute']].append(prompt['Template'])
    prompt_splits[prompt['Template']] = split
attribute_prompts = dict(attribute_prompts)

# Build entity to attributes mapping.
attributes = ['Country', 'Continent', 'Language'] #, 'Latitude', 'Longitude', 'Timezone']
entity_splits = {}
entity_attributes = collections.defaultdict(list)
for split in ravel_city_entities:
  for entity in ravel_city_entities[split]:
    entity_attributes[entity['City']].append({k: v for k, v in entity.items() if k in attributes})
    entity_splits[entity['City']] = split
entity_attributes = dict(entity_attributes)

In [12]:
import json
import os

entity_type = 'city'

print(f'#entities={len(entity_attributes)}, #prompt_templates={sum(map(len, attribute_prompts.values()))}')

prompts_to_meta_data = {t % x: {'entity': x, 'attr': a, 'template': t}
               for x in entity_attributes
               for a, ts in attribute_prompts.items()
               for t in ts
               # An empty attribute means the prompt is from Wikipedia, which
               # does not query for a specific attribute.
               if a != ''}
print(len(prompts_to_meta_data))

#entities=3122, #prompt_templates=1071
415226


In [13]:
# from utils.generation_utils import generate_batched

# prompt_max_length = 48

# prompt_to_output = generate_batched(
#     model,
#     tokenizer,
#     list(prompts_to_meta_data),
#     prompt_max_length+8,
#     prompt_max_length=prompt_max_length,
#     batch_size=64)
# prompt_to_output = {k: v[len(k):] for k, v in prompt_to_output}

# torch.save(prompt_to_output, os.path.join(DATA_DIR, 'ravel_gemma-2-2b_city_prompt_to_output.pt'))

In [14]:
# Directly load the pre-computed outputs.
prompt_to_output = torch.load(os.path.join(DATA_DIR, 'ravel_gemma-2-2b_city_prompt_to_output.pt'))
print(len(prompt_to_output))

  prompt_to_output = torch.load(os.path.join(DATA_DIR, 'ravel_gemma-2-2b_city_prompt_to_output.pt'))


415226


In [15]:
#@title Behavioral Test

# import collections
# import re
# import numpy as np


# from zoneinfo import ZoneInfo
# import datetime

# def timezone_name_to_utc_offset(name):
#   try:
#     offset =  ZoneInfo(name).utcoffset(datetime.datetime.now()).seconds
#   except:
#     return 'NOT_FOUND'
#   sign = '+'
#   if offset // 3600 >= 12:
#     offset = 24 * 3600 - offset
#     sign = '-'
#   fmt_offset = str(datetime.timedelta(seconds=offset)).rsplit(':', 1)[0]
#   if fmt_offset.startswith('0') and offset >= 1800:
#     fmt_offset = fmt_offset[1:]
#   return f'{sign}{fmt_offset}'


# sorted_entity = sorted(set([v['entity'] for v in prompts_to_meta_data.values()]))
# sorted_template = sorted(set([v['template'] for v in prompts_to_meta_data.values()]))
# stats = np.zeros([len(sorted_entity), len(sorted_template)])
# for p, out in prompt_to_output.items():
#   attr = prompts_to_meta_data[p]['attr']
#   entity = prompts_to_meta_data[p]['entity']
#   # Each entity might be mapped to multiple attribute values.
#   label = '|'.join(set([x[attr] for x in entity_attributes[entity] if x[attr]]))
#   if not label:
#     continue
#   norm_label = label.lower()
#   norm_out = out.split('"')[0].strip(' "').replace('\\/', '/').lower()
#   if len(norm_label) < len(norm_out):
#     correct = int(norm_out.startswith(norm_label))
#   else:
#     correct = int(norm_label.startswith(norm_out))

  # Exceptions
#   if re.search('coord|"lat"|"long"|latitude|coordinates|longitude', p):
#     try:
#       correct = int((float(norm_label.strip('-−')) - float(re.findall(r'\d+', norm_out)[0])) <= 2)
#     except:
#       correct = 0
  # if re.search('United States|United Kingdom', label):
  #   norm_label = label.strip().replace('the ', '')
  #   norm_out = out[len(p):].strip().replace('the ', '')
  #   correct = int(norm_out.startswith(norm_label) or norm_out.startswith('England'))
  # if re.search('South Korea', label):
  #   correct = int(norm_out.startswith('korea') or norm_out.startswith('south korea'))
  # if re.search('North America', label):
  #   correct = norm_label in norm_out or norm_out == 'na' or norm_out.startswith('america')
  # if re.search('Mandarin', label):
  #   correct = norm_out in norm_label or norm_out == 'chinese'
  # if re.search('language', p) and ',' in norm_label:
  #   correct = any(lang in norm_out for lang in norm_label.split(','))
#   if re.search('UTC', p) and '/' in norm_label:
#     norm_label = [timezone_name_to_utc_offset(l) for l in label.split('|')]
#     correct = any(norm_out.startswith(l.split(':')[0]) for l in label.split('|'))
#     if not correct and re.search(r'[+\-]0\d', norm_out):
#       correct = any(norm_out.replace('0', '', 1).startswith(l.split(':')[0]) for l in norm_label)
#     norm_label = '|'.join(norm_label)
#     # Summer daylight saving time
#     if not correct and (
#         re.search(r'\-[5-8]', norm_label) and label.startswith('America') or
#         re.search(r'\+[0-3]', norm_label) and label.startswith('Europe') or
#         re.search(r'\+[0-3]', norm_label) and label.startswith('Africa')):
#       #print('SUMMER TIME:', norm_label, norm_out)
#       out_offset_match = re.search(r'[+\-]?(\d\d?):\d+', norm_out)
#       label_offset_match = re.search(r'[+\-]?(\d\d?):\d+', norm_label)
#       if out_offset_match and label_offset_match:
#         norm_out_offset = int(out_offset_match.group(1))
#         norm_label_offset = int(label_offset_match.group(1))
#         correct = (norm_out_offset <= norm_label_offset + 1 and
#                    norm_out_offset >= norm_label_offset - 1)
#     if not correct and re.search(r'[+\-](\d+)', norm_out) and int(
#         re.search(r'[+\-](\d+)', norm_out).group(1)) > 11:
#       offset = 24 - int(re.search(r'[+\-](\d+)', norm_out).group(1))
#       correct = str(offset) in norm_label
#   stats[sorted_entity.index(prompts_to_meta_data[p]['entity']), sorted_template.index(prompts_to_meta_data[p]['template'])] += int(correct)

# print('-----------------------------------')
# for i in np.argsort(stats.sum(axis=0))[::-1]:
#   print(sorted_template[i], int(stats[:, i].sum()), len(stats[:, i]))
# for i in np.argsort(stats.sum(axis=-1))[::-1]:
#   print(sorted_entity[i], int(stats[i].sum()), len(stats[i]))

# # Keep the top K entities and templates.
# num_entity = 800
# num_template = 200
# kept_entity_index = np.argsort(stats.sum(axis=1))[-num_entity:]
# KEPT_ENTITY = [sorted_entity[i] for i in kept_entity_index]
# topk_template_index = set(np.argsort(stats.sum(axis=0))[-num_template:])
# kept_template_index = []
# # A dict of all kept attribute to prompts.
# KEPT_ATTR_TO_PROMPT_AND_SPLIT = {}
# for attr in attribute_prompts:
#   if not attr:
#     # Wikipedia prompts.
#     continue
#   # Kept the top 4 to 12 templates per attribute.
#   attr_indices = [sorted_template.index(t) for t in attribute_prompts[attr]]
#   per_attr_kept_template_index = sorted(attr_indices, key=lambda i: stats[:, i].sum())[-12:][::-1]
#   per_attr_kept_template_index = [x for i, x in enumerate(per_attr_kept_template_index)
#                                   if x in topk_template_index or i < 4]
#   kept_template_index.extend(per_attr_kept_template_index)
#   KEPT_ATTR_TO_PROMPT_AND_SPLIT[attr] = {sorted_template[i]: prompt_splits[sorted_template[i]]
#                                for i in per_attr_kept_template_index}
# print('Kept %d entity, %d prompt template' % (len(kept_entity_index), len(kept_template_index)))

# display('Average accuracy: %.2f%%' % (100 *  (stats[:, kept_template_index][kept_entity_index, :]).sum()/ (len(kept_entity_index) * len(kept_template_index))))

In [16]:
# Simply checking model accuracy without filtering incorrect instances
# (For MIB, no filtering in the data generation stage)
import collections
import re
import numpy as np
from zoneinfo import ZoneInfo
import datetime

def timezone_name_to_utc_offset(name):
  try:
    offset =  ZoneInfo(name).utcoffset(datetime.datetime.now()).seconds
  except:
    return 'NOT_FOUND'
  sign = '+'
  if offset // 3600 >= 12:
    offset = 24 * 3600 - offset
    sign = '-'
  fmt_offset = str(datetime.timedelta(seconds=offset)).rsplit(':', 1)[0]
  if fmt_offset.startswith('0') and offset >= 1800:
    fmt_offset = fmt_offset[1:]
  return f'{sign}{fmt_offset}'

sorted_entity = sorted(set([v['entity'] for v in prompts_to_meta_data.values()]))
sorted_template = sorted(set([v['template'] for v in prompts_to_meta_data.values()]))
stats = np.zeros([len(sorted_entity), len(sorted_template)])

for p, out in prompt_to_output.items():
    attr = prompts_to_meta_data[p]['attr']
    entity = prompts_to_meta_data[p]['entity']
    label = '|'.join(set([x[attr] for x in entity_attributes[entity] if x[attr]]))
    if not label:
        continue
    norm_label = label.lower()
    norm_out = out.split('"')[0].strip(' "').replace('\\/', '/').lower()
    
    if len(norm_label) < len(norm_out):
        correct = int(norm_out.startswith(norm_label))
    else:
        correct = int(norm_label.startswith(norm_out))
    
    if re.search('United States|United Kingdom', label):
        norm_label = label.strip().replace('the ', '')
        norm_out = out[len(p):].strip().replace('the ', '')
        correct = int(norm_out.startswith(norm_label) or norm_out.startswith('England'))
    if re.search('South Korea', label):
        correct = int(norm_out.startswith('korea') or norm_out.startswith('south korea'))
    if re.search('North America', label):
        correct = norm_label in norm_out or norm_out == 'na' or norm_out.startswith('america')
    if re.search('Mandarin', label):
        correct = norm_out in norm_label or norm_out == 'chinese'
    if re.search('language', p) and ',' in norm_label:
        correct = any(lang in norm_out for lang in norm_label.split(','))
    stats[sorted_entity.index(prompts_to_meta_data[p]['entity']), sorted_template.index(prompts_to_meta_data[p]['template'])] += int(correct)

print('-----------------------------------')
for i in np.argsort(stats.sum(axis=0))[::-1]:
    print(sorted_template[i], int(stats[:, i].sum()), len(stats[:, i]))
for i in np.argsort(stats.sum(axis=-1))[::-1]:
    print(sorted_entity[i], int(stats[i].sum()), len(stats[i]))

# No filtering is applied. All entities and templates are retained.
kept_entity_index = list(range(len(sorted_entity)))  # Keeping all entities
KEPT_ENTITY = sorted_entity
kept_template_index = list(range(len(sorted_template)))  # Keeping all templates

KEPT_ATTR_TO_PROMPT_AND_SPLIT = {}
for attr in attribute_prompts:
    if not attr:
        continue
    attr_indices = [sorted_template.index(t) for t in attribute_prompts[attr]]
    KEPT_ATTR_TO_PROMPT_AND_SPLIT[attr] = {sorted_template[i]: prompt_splits[sorted_template[i]] for i in attr_indices}

print('Kept %d entity, %d prompt template' % (len(kept_entity_index), len(kept_template_index)))
print('Average accuracy: %.2f%%' % (100 * (stats[:, kept_template_index][kept_entity_index, :]).sum()/ (len(kept_entity_index) * len(kept_template_index))))


-----------------------------------
[{"city": "Rio de Janeiro", "continent": "South America"}, {"city": "%s", "continent": " 2519 3122
[{"city": "Buenos Aires", "continent": "South America"}, {"city": "%s", "continent": " 2510 3122
[{"city": "New York City", "continent": "North America"}, {"city": "%s", "continent": " 2503 3122
[{"city": "San Francisco", "continent": "North America"}, {"city": "%s", "continent": " 2502 3122
[{"city": "Toronto", "continent": "North America"}, {"city": "%s", "continent": " 2491 3122
[{"city": "Los Angeles", "continent": "North America"}, {"city": "%s", "continent": " 2482 3122
[{"city": "Mexico City", "continent": "North America"}, {"city": "%s", "continent": " 2476 3122
[{"city": "Hong Kong", "continent": "Asia"}, {"city": "%s", "continent": " 2461 3122
[{"city": "Beijing", "continent": "Asia"}, {"city": "%s", "continent": " 2440 3122
[{"city": "Sydney", "continent": "Oceania"}, {"city": "%s", "continent": " 2425 3122
[{"city": "Tokyo", "continent": "As

In [17]:
# Kept templates

attribute_prompts
for i in kept_template_index:
  print(f'{prompt_splits[sorted_template[i]]}\t{sorted_template[i]}\t{stats[:, i][kept_entity_index].mean():.2f}')

test	 "continent": "Asia"}, {"city": "%s", "language": "	0.65
train	 "country": "United Kingdom"}, {"city": "%s", "language": "	0.69
test	 "language": "English"}, {"city": "%s", "continent": "	0.78
test	 Photo taken in New York City, United States. Photo taken in %s,	0.57
test	 city to continent: New York City is in North America. %s is in	0.48
train	 in %s, people usually speak	0.35
val	 she is living in %s, therefore her country of residence is	0.58
train	"lang": "English"}, {"city": "%s", "country": "	0.65
test	"lang": "Spanish"}, {"city": "%s", "country": "	0.63
test	%s is a city in the country of	0.57
val	%s is a city located in the continent of	0.46
test	%s is in the continent of	0.56
val	%s is in the country of	0.52
test	Bangkok is a city in the continent of Asia. %s is a city in the continent of	0.69
train	Beijing is a city in the continent of Asia. %s is a city in the continent of	0.73
test	Buenos Aires is a city in the continent of South America. %s is a city in the continent

In [18]:
print(sum(map(len, KEPT_ATTR_TO_PROMPT_AND_SPLIT.values())))
for attr, prompt_to_split in KEPT_ATTR_TO_PROMPT_AND_SPLIT.items():
  print(attr, collections.Counter(prompt_to_split.values()))

133
Country Counter({'train': 22, 'test': 13, 'val': 11})
Continent Counter({'train': 22, 'test': 11, 'val': 10})
Language Counter({'train': 20, 'val': 13, 'test': 11})


In [19]:
import json

ENTITY_TYPE = 'city'

# Filtered
KEPT_ENTITY_SPLITS = {e: entity_splits[e] for e in KEPT_ENTITY}
KEPT_PROMPT_SPLITS = {k: (a, v) for a, d in KEPT_ATTR_TO_PROMPT_AND_SPLIT.items() for k, v in d.items() if k.count('%') == 1}
# Wikipedia prompts
for prompt in attribute_prompts['']:
  KEPT_PROMPT_SPLITS[prompt] = ('Other', prompt_splits[prompt])
KEPT_ATTR_TO_PROMPT_AND_SPLIT = {k: {p: v for p, v in d.items() if p.count('%') == 1} for k, d in KEPT_ATTR_TO_PROMPT_AND_SPLIT.items()}
print(f'Total #entities={len(entity_attributes)} #attributes={len(KEPT_ATTR_TO_PROMPT_AND_SPLIT)} '
      f'#prompts={sum(map(len, attribute_prompts.values()))} #wiki_prompts={len(attribute_prompts[""])}')
print(f'Kept #entities={len(KEPT_ENTITY_SPLITS)} #prompts={len(KEPT_PROMPT_SPLITS)}')
for split in ('train', 'val', 'test'):
  print(split, f'Kept #entities={len([k for k, v in KEPT_ENTITY_SPLITS.items() if v == split])}',
               f'#prompts={len([k for k, v in KEPT_PROMPT_SPLITS.items() if v[1] == split])}')

Total #entities=3122 #attributes=3 #prompts=1071 #wiki_prompts=938
Kept #entities=3122 #prompts=1071
train Kept #entities=1567 #prompts=375
val Kept #entities=724 #prompts=358
test Kept #entities=831 #prompts=338


In [20]:
from src.utils.generation_utils import generate_batched

wiki_prompts = [(t['Template'] % e)
                for s in ravel_city_prompts
                for t in ravel_city_prompts[s]
                for e in ([t['Entity']] if t['Entity']
                           else [a for a in KEPT_ENTITY_SPLITS if KEPT_ENTITY_SPLITS[a] == 'train' or s == 'train'])
                 ]
print(len(wiki_prompts))

# wiki_prompt_and_output = generate_batched(
#     model,
#     tokenizer,
#     wiki_prompts,
#     max_new_tokens=8,
#     batch_size=64)
# wiki_prompt_to_output = {k: v[len(k):] for k, v in wiki_prompt_and_output}

wiki_prompt_to_output = torch.load(os.path.join(DATA_DIR, 'ravel_wiki_city_prompt_to_output.pt'))

365113


  wiki_prompt_to_output = torch.load(os.path.join(DATA_DIR, 'ravel_wiki_city_prompt_to_output.pt'))


In [21]:
# torch.save(wiki_prompt_to_output, os.path.join(DATA_DIR, 'ravel_wiki_city_prompt_to_output.pt'))

In [22]:
ALL_PROMPT_TO_OUTPUT = {**prompt_to_output, **wiki_prompt_to_output}

len(ALL_PROMPT_TO_OUTPUT)

472408

In [23]:
WIKI_PROMPT_SPLITS = {
    t['Template']: {'split': s, 'entity': t['Entity']}
    for s in ravel_city_prompts
    for t in ravel_city_prompts[s]}

In [24]:
import datasets
from datasets import Dataset
from src.utils.generate_ravel_instance import RAVELMetadata


def extract_label(text):
  """Extracts the first word or phrase from the text.

  The rules are hard-coded based on the model output values.
  You might want to update the rules when using different models or prompts.
  """
  tokens = re.split(r'(["]|[.,;]\s|\n| \(|\sand)', text + ' ')
  x = tokens[0]
  digit_match = re.search(r'\.\d\d', x)
  if digit_match:
      x = x[:digit_match.span(0)[1]]
  gender_match = re.match(r'\s?(his|her|himself|herself|she|he)[^\w]', x)
  if gender_match:
      x = x[:gender_match.span(1)[1]]
  if not x.strip():
      x = ' '.join(text.split(' ')[:2]).rstrip('.,"\n')
  assert x.strip()
  return x


def get_first_token(x):
  return re.split(r'[^\w\+\-]', x.strip(), re.UNICODE)[0]


def filter_inv_example(base_output, inv_output):
  different_outputs = (get_first_token(base_output) !=
                       get_first_token(inv_output))
  valid_outputs = (
      re.fullmatch(r'\s?[a-z0-9.:\-+]+', extract_label(base_output), re.IGNORECASE) and
      re.fullmatch(r'\s?[a-z0-9.:\-+]+', extract_label(inv_output), re.IGNORECASE))
  return len(inv_output) > 0 and valid_outputs and different_outputs


FEATURE_TYPES = datasets.Features({"input": datasets.Value("string"), "label": datasets.Value("string"),
                              "source_input": datasets.Value("string"), "source_label": datasets.Value("string"),
                              "inv_label": datasets.Value("string"),
                              'split': datasets.Value("string"), 'source_split': datasets.Value("string"),
                              'entity': datasets.Value("string"), 'source_entity': datasets.Value("string")})


ravel_metadata = RAVELMetadata(
    'gemma-2-2b',
    KEPT_ENTITY_SPLITS,
    KEPT_ATTR_TO_PROMPT_AND_SPLIT,
    KEPT_PROMPT_SPLITS,
    WIKI_PROMPT_SPLITS,
    ALL_PROMPT_TO_OUTPUT)

## Generate context test/val split

In [25]:
#@title Generate the Conetxt TEST/VAL Split

# Context Split: All entities are in TRAIN, but all prompts are in test/dev

import random

from src.utils.generate_ravel_instance import gen_context_test_split

TEST_TYPE = 'context'

# Take the first N examples only
first_n = 256

eval_split_to_raw_example = gen_context_test_split(
    ravel_metadata,
    extract_label_fn=extract_label,
    filter_example_fn=filter_inv_example,
    first_n=first_n)
eval_split_to_dataset = {
    split: Dataset.from_list(eval_split_to_raw_example[split][:first_n], features=FEATURE_TYPES)
    for split in eval_split_to_raw_example}

# Compute stats.
for split in eval_split_to_raw_example:
  print('\nSplit %s:\nTotal %d examples, kept first %d examples, %d unique input values,  %d unique entities, %d unique output values, %d unique output tokens' % (
      repr(split), len(eval_split_to_raw_example[split]), len(eval_split_to_dataset[split]),
      len(set([exp[x] for exp in eval_split_to_raw_example[split][:first_n] for x in ['input', 'source_input']])),
      len(set([exp[x] for exp in eval_split_to_raw_example[split][:first_n] for x in ['entity', 'source_entity']])),
      len(set([exp['inv_label'] for exp in eval_split_to_raw_example[split][:first_n]])),
      len(set([tokenizer(exp['inv_label']).input_ids[1] for exp in eval_split_to_raw_example[split][:first_n]]))))
  #for i, example in enumerate(eval_split_to_raw_example[split]):
  #  print(example)
  #  #print(tokenizer(example['input']).input_ids)
  #  break
  #for k in ('input', 'source_input'):
  #  input_ids = tokenizer(example[k])['input_ids']
  #  #print(k)
  #  #print(input_ids)
  #  print(list(zip([(32 - len(input_ids)) + i for i in range(len(input_ids))], tokenizer.batch_decode(input_ids))))
for split in ('test', 'val'):
  print(f'Split {split}: Total #subsplit={len([k for k in eval_split_to_raw_example if k.endswith(split)])} #Examples={sum(map(len, [v for k, v in eval_split_to_raw_example.items() if k.endswith(split)]))}')

Country  she is living in %s, therefore her country of residence is val 1567 1110 505
Country  she is living in %s, therefore her country of residence is val 1567 1130 593
Country  she is living in %s, therefore her country of residence is val 1567 1152 465
Country %s is in the country of val 1567 961 459
Country %s is in the country of val 1567 946 518
Country %s is in the country of val 1567 984 417
Country If you live in %s, your country of residence should be val 1567 479 272
Country If you live in %s, your country of residence should be val 1567 486 303
Country If you live in %s, your country of residence should be val 1567 504 259
Country [{"city": "Beijing", "country": "China"}, {"city": "%s", "country": " val 1567 1130 491
Country [{"city": "Beijing", "country": "China"}, {"city": "%s", "country": " val 1567 1108 572
Country [{"city": "Beijing", "country": "China"}, {"city": "%s", "country": " val 1567 1140 429
Country [{"city": "San Francisco", "country": "United States"}, {"c

In [26]:
# Merge subsplits
eval_split_to_raw_example_merged = collections.defaultdict(list)
for split in eval_split_to_raw_example:
  eval_split_to_raw_example_merged[re.sub(r'-causal|-output|-other', '', split)].extend(eval_split_to_raw_example[split])
eval_split_to_raw_example = dict(eval_split_to_raw_example_merged)

In [27]:
output_json_path = os.path.join(DATA_DIR, f'{ravel_metadata.instance}/{ravel_metadata.instance}_{ENTITY_TYPE}_{TEST_TYPE}_test.json')
print(output_json_path)
json.dump(eval_split_to_raw_example, open(output_json_path, 'w'), ensure_ascii=False)

content/data/gemma-2-2b/gemma-2-2b_city_context_test.json


## Generate entity test/val split

In [28]:
#@title Generate the Entity TEST/VAL Split

from src.utils.generate_ravel_instance import gen_entity_test_split

TEST_TYPE = 'entity'

# Take the first N examples only
first_n = 256

eval_split_to_raw_example = gen_entity_test_split(
    ravel_metadata,
    extract_label_fn=extract_label, filter_example_fn=filter_inv_example,
    first_n=first_n)

eval_split_to_dataset = {
    split: Dataset.from_list(eval_split_to_raw_example[split][:first_n], features=FEATURE_TYPES)
    for split in eval_split_to_raw_example}

# Stats
for split in eval_split_to_raw_example:
  print('Split %s: Total %d examples, kept first %d examples, %d unique input values,  %d unique entities, %d unique output values, %d unique output tokens' % (
      repr(split), len(eval_split_to_raw_example[split]), len(eval_split_to_dataset[split]),
      len(set([exp[x] for exp in eval_split_to_raw_example[split][:first_n] for x in ['input', 'source_input']])),
      len(set([exp[x] for exp in eval_split_to_raw_example[split][:first_n] for x in ['entity', 'source_entity']])),
      len(set([exp['inv_label'] for exp in eval_split_to_raw_example[split][:first_n]])),
      len(set([tokenizer(exp['inv_label']).input_ids[1] for exp in eval_split_to_raw_example[split][:first_n]]))))
  for i, example in enumerate(eval_split_to_raw_example[split]):
    print(example)
    #print(tokenizer(example['input']).input_ids)
    break
  for k in ('input', 'source_input'):
    input_ids = tokenizer(example[k])['input_ids']
    #print(k)
    #print(input_ids)
    print(list(zip([(32 - len(input_ids)) + i for i in range(len(input_ids))], tokenizer.batch_decode(input_ids))))
for split in ('test', 'val'):
  print(f'Split {split}: Total #subsplit={len([k for k in eval_split_to_raw_example if k.endswith(split)])} #Examples={sum(map(len, [v for k, v in eval_split_to_raw_example.items() if k.endswith(split)]))}')

Country "lang": "English"}, {"city": "%s", "country": " test 256
Country "lang": "English"}, {"city": "%s", "country": " test 256
Country "lang": "English"}, {"city": "%s", "country": " test 256
Country "lang": "English"}, {"city": "%s", "country": " val 256
Country "lang": "English"}, {"city": "%s", "country": " val 256
Country "lang": "English"}, {"city": "%s", "country": " val 256
Country [{"city": "%s", "country": " test 256
Country [{"city": "%s", "country": " test 256
Country [{"city": "%s", "country": " test 256
Country [{"city": "%s", "country": " val 256
Country [{"city": "%s", "country": " val 256
Country [{"city": "%s", "country": " val 256
Country [{"city": "Buenos Aires", "country": "Argentina"}, {"city": "%s", "country": " test 256
Country [{"city": "Buenos Aires", "country": "Argentina"}, {"city": "%s", "country": " test 256
Country [{"city": "Buenos Aires", "country": "Argentina"}, {"city": "%s", "country": " test 256
Country [{"city": "Buenos Aires", "country": "Argent

In [29]:
# Merge subsplits
eval_split_to_raw_example_merged = collections.defaultdict(list)
for split in eval_split_to_raw_example:
  eval_split_to_raw_example_merged[re.sub(r'-causal|-output|-other', '', split)].extend(eval_split_to_raw_example[split])
eval_split_to_raw_example = dict(eval_split_to_raw_example_merged)

In [30]:
output_json_path = os.path.join(DATA_DIR, f'{ravel_metadata.instance}/{ravel_metadata.instance}_{ENTITY_TYPE}_{TEST_TYPE}_test.json')
print(output_json_path)
json.dump(eval_split_to_raw_example, open(output_json_path, 'w'), ensure_ascii=False)

content/data/gemma-2-2b/gemma-2-2b_city_entity_test.json


## Generate train split

In [31]:
torch.cuda.empty_cache()

In [32]:
#@title Generate train split (for models that use counterfactuals)

import datasets
from datasets import Dataset

def gen_train_split(metadata, extract_label_fn, filter_example_fn, first_n=256):
  split_to_raw_example = {}
  # Group by attributes.
  target_split = 'train'
  for attr, prompt_to_split in metadata.attr_to_prompt.items():
      base_prompt_candiates = [p for p, s in prompt_to_split.items() if s == target_split]
      base_task_inputs = [
          ((prompt, entity), metadata.prompt_to_output[prompt % entity])
          for entity in metadata.get_entities(target_split)
          for prompt in random.sample(
              base_prompt_candiates, k=min(2, len(base_prompt_candiates)))]
      print(len(base_task_inputs))
      source_task_inputs = [
          ((source_prompt, entity), metadata.prompt_to_output[source_prompt % entity])
          for source_prompt, (source_attr, source_split) in KEPT_PROMPT_SPLITS.items()
          if source_split == target_split and source_attr != 'Other'
          for entity in metadata.sample_entities(target_split, k=100)
      ]
      wiki_source_task_inputs = [
          ((source_prompt, entity), metadata.prompt_to_output[source_prompt % entity])
          for source_prompt, split_and_arg in metadata.entity_prompt_to_split.items()
          if split_and_arg['split'] == target_split
          for entity in ([split_and_arg['entity']] if split_and_arg['entity']
                         else metadata.sample_entities(target_split, k=1))
      ]
      source_task_inputs = source_task_inputs + wiki_source_task_inputs
      if len(base_task_inputs) < 5 or len(source_task_inputs) < 5:
        continue
      print(attr, target_split, len(base_task_inputs), len(source_task_inputs), len(wiki_source_task_inputs))
      split_to_raw_example[f'{attr}-{target_split}'] = []
      for (p, a), v in base_task_inputs:
        source_input_candiates = [x for x in source_task_inputs if
                                  x[0][1] in ravel_metadata.entity_to_split and
                                  filter_example_fn(v, metadata.prompt_to_output[p % x[0][1]])]
        #print(len(source_input_candiates), v)

        # Use all examples in source_input_candiates; added the s_a == a condition
        # random.shuffle(source_input_candiates)
        split_to_raw_example[f'{attr}-{target_split}'].extend([
            {
              'input': p % a,
              'label': extract_label_fn(v),
              'source_input': s_p % s_a,
              'source_label': extract_label_fn(source_v),
              'inv_label': extract_label_fn(metadata.prompt_to_output[p % s_a]),
              'split': p,
              'source_split': s_p,
              'entity': a,
              'source_entity': s_a
            }
            for (s_p, s_a), source_v in source_input_candiates
            if filter_example_fn(v, metadata.prompt_to_output[p % s_a]) and s_a == a and re.search(r'\w+', source_v)
        ])

      #   split_to_raw_example[f'{attr}-{target_split}'].extend([{
      #     'input': p % a, 'label': extract_label_fn(v),
      #     'source_input': s_p % s_a, 'source_label': extract_label_fn(source_v),
      #     'inv_label': extract_label_fn(metadata.prompt_to_output[p % s_a]),
      #     'split': p, 'source_split': s_p,
      #     'entity': a, 'source_entity': s_a}
      #   for (s_p, s_a), source_v in random.sample(source_input_candiates, k=min(len(source_input_candiates), round(first_n / len(base_task_inputs))))
      #   if filter_example_fn(v, metadata.prompt_to_output[p % s_a]) and re.search('\w+', source_v)
      # ])
  split_to_raw_example = {k: v for k, v in split_to_raw_example.items() if len(v) > 0}
  return split_to_raw_example


# Take the first N examples only
first_n = 40960

split_to_raw_example = gen_train_split(
    ravel_metadata,
    extract_label_fn=extract_label,
    filter_example_fn=filter_inv_example,
    first_n=first_n)

# # Stats
# for split in split_to_raw_example:
#   print('Split %s: Total %d examples, kept first %d examples, %d unique input values,  %d unique entities, %d unique output values, %d unique output tokens' % (
#       repr(split), len(split_to_raw_example[split]), len(split_to_raw_example[split]),
#       len(set([exp[x] for exp in split_to_raw_example[split][:first_n] for x in ['input', 'source_input']])),
#       len(set([exp[x] for exp in split_to_raw_example[split][:first_n] for x in ['entity', 'source_entity']])),
#       len(set([exp['inv_label'] for exp in split_to_raw_example[split][:first_n]])),
#       len(set([tokenizer(exp['inv_label']).input_ids[1] for exp in split_to_raw_example[split][:first_n]]))))
#   for i, example in enumerate(split_to_raw_example[split]):
#     print(example)
#     break
# for split in ('train',):
#   print(f'Split {split}: Total #subsplit={len([k for k in split_to_raw_example if k.endswith(split)])} #Examples={sum(map(len, [v for k, v in split_to_raw_example.items() if k.endswith(split)]))}')

3134
Country train 3134 6775 375
3134
Continent train 3134 6775 375
3134
Language train 3134 6775 375


In [33]:
split_to_raw_example = {}
# Group by attributes.
target_split = 'train'
for attr, prompt_to_split in ravel_metadata.attr_to_prompt.items():
    base_prompt_candiates = [p for p, s in prompt_to_split.items() if s == target_split]
    base_task_inputs = [
        ((prompt, entity), ravel_metadata.prompt_to_output[prompt % entity])
        for entity in ravel_metadata.get_entities(target_split)
        for prompt in random.sample(
            base_prompt_candiates, k=min(2, len(base_prompt_candiates)))]
    print(len(base_task_inputs))
    source_task_inputs = [
        ((source_prompt, entity), ravel_metadata.prompt_to_output[source_prompt % entity])
        for source_prompt, (source_attr, source_split) in KEPT_PROMPT_SPLITS.items()
        if source_split == target_split and source_attr != 'Other'
        for entity in ravel_metadata.sample_entities(target_split, k=100)
    ]
    wiki_source_task_inputs = [
        ((source_prompt, entity), ravel_metadata.prompt_to_output[source_prompt % entity])
        for source_prompt, split_and_arg in ravel_metadata.entity_prompt_to_split.items()
        if split_and_arg['split'] == target_split
        for entity in ([split_and_arg['entity']] if split_and_arg['entity']
                        else ravel_metadata.sample_entities(target_split, k=1))
    ]
    source_task_inputs = source_task_inputs + wiki_source_task_inputs
    if len(base_task_inputs) < 5 or len(source_task_inputs) < 5:
        continue
    print(attr, target_split, len(base_task_inputs), len(source_task_inputs), len(wiki_source_task_inputs))
    split_to_raw_example[f'{attr}-{target_split}'] = []
#   for (p, a), v in base_task_inputs:
#     source_input_candiates = [x for x in source_task_inputs if
#                               x[0][1] in ravel_metadata.entity_to_split and
#                               filter_example_fn(v, metadata.prompt_to_output[p % x[0][1]])]

print("Base entities:", set(a for (p, a), v in base_task_inputs))
print("Source entities:", set(s_a for (s_p, s_a), source_v in source_task_inputs))


3134
Country train 3134 6775 375
3134
Continent train 3134 6775 375
3134
Language train 3134 6775 375
Base entities: {'Sinop', 'Skagway', 'Moyale', 'Reggane', 'Nimes', 'Moncton', 'Bassar', 'Kankan', 'Loei', 'Vienna', 'Wuchuan', 'Split', 'Mus', 'Tampico', 'Parintins', 'Dulan', 'Katsina', 'Berlin', 'Zlin', 'Naga', 'Jackson', 'Ostersund', 'Manaus', 'Ataq', 'Ijevan', 'Ferfer', 'Kpalime', 'Zhosaly', 'Maringa', 'Tumbes', 'Foggia', 'Alesund', 'Nikel', 'Daegu', 'Horta', 'Tibati', 'Longjiang', 'Calabar', 'Lexington', 'Luohe', 'Arcata', 'Hailun', 'Greeley', 'Lawrence', 'McGrath', 'Yambio', 'Olinda', 'Sunchales', 'Trenton', 'Orebro', 'Bintulu', 'Denpasar', 'Otsu', 'Clare', 'Orlando', 'Bucharest', 'Stawell', 'Kerman', 'Calulo', 'Mackay', 'Creston', 'Mbarara', 'Soke', 'Bawku', 'Weifang', 'Asheville', 'Gainesville', 'Bitam', 'Mumbai', 'Waterville', 'Quillota', 'Gulu', 'Chingola', 'Surgut', 'Mbale', 'Penzance', 'Barretos', 'Quesnel', 'Cobija', 'Madrid', 'Xichang', 'Agadez', 'Gangtok', 'Surigao', 'Cor

In [34]:
len(base_task_inputs), len(source_task_inputs)

(3134, 6775)

In [35]:
for (p, a), v in base_task_inputs:
        source_input_candiates = [x for x in source_task_inputs if
                                  x[0][1] in ravel_metadata.entity_to_split and
                                  filter_inv_example(v, ravel_metadata.prompt_to_output[p % x[0][1]])]

## Preprocess labels

In [36]:
#@title Postprocess labels

import json
import re


entity_type = 'city'
instance =  'gemma-2-2b'


json_path = os.path.join(DATA_DIR, f'{instance}/{instance}_{entity_type}_context_test.json')
split_to_raw_example = json.load(open(json_path, 'r'))
print(len(split_to_raw_example))

all_labels = set()
for split in split_to_raw_example:
  for i in range(len(split_to_raw_example[split])):
    # if split.split('-')[0] in ['Latitude', 'Longitude'] or  split.split('-')[0] in attribute_prompts['Latitude'] or split.split('-')[0] in attribute_to_prompts['Longitude']:
    #   # Keep only the integer part.
    #   split_to_raw_example[split][i]['inv_label'] = split_to_raw_example[split][i]['inv_label'].replace('°', '.').split('.')[0]
    #   split_to_raw_example[split][i]['label'] = split_to_raw_example[split][i]['label'].replace('°', '.').split('.')[0]
    all_labels.add(split_to_raw_example[split][i]['inv_label'])

69


In [37]:
print(json_path)
json.dump(split_to_raw_example, open(json_path, 'w'), ensure_ascii=False)

content/data/gemma-2-2b/gemma-2-2b_city_context_test.json


## Intervention location 

In [38]:
#@title Intervention locations for all possible prompts

SPLIT_TO_INV_POSITION = {}

all_prompt_templates = {p for p in WIKI_PROMPT_SPLITS}
# all_prompt_templates.update({v for vs in ALL_ATTR_TO_PROMPTS.values() for v in vs})
print(len(all_prompt_templates))

for prompt_template in all_prompt_templates:
  if prompt_template.count('%s') != 1:
    continue
  print(prompt_template)
  prompt_input = prompt_template.replace('%s', '000000', 1)
  input_ids = tokenizer(prompt_input)['input_ids']
  toks = tokenizer.batch_decode(input_ids)
  for i in range(-1, -len(toks), -1):
    # This check only works for TinyLlama/Llama2 tokenizer.
    # If you use a different tokenizer, you need to update the code below
    # based on how 000000 is tokenized or use a different placeholder.
    if toks[i] == '0' and toks[i - 1] == '0' and toks[i - 2] == '0' and toks[i - 3] == '0':
      break
  SPLIT_TO_INV_POSITION[prompt_template] = i
  print(i, list(zip([(32 - len(input_ids)) + i for i in range(len(input_ids))], toks)))

print(min(SPLIT_TO_INV_POSITION.values()))

1071
Essien, Kanoute, Adebayor to Play in %s for Okocha
-4 [(9, '<bos>'), (10, 'Es'), (11, 'sien'), (12, ','), (13, ' Kan'), (14, 'oute'), (15, ','), (16, ' Ade'), (17, 'bay'), (18, 'or'), (19, ' to'), (20, ' Play'), (21, ' in'), (22, ' '), (23, '0'), (24, '0'), (25, '0'), (26, '0'), (27, '0'), (28, '0'), (29, ' for'), (30, ' Ok'), (31, 'ocha')]
the 1999, 2004 & 2009 Indian general Elections from the %s, Ghazipur and Machhlishahr, Jaunpur (Lok Sabha constituency) on Samajwadi
-23 [(-21, '<bos>'), (-20, 'the'), (-19, ' '), (-18, '1'), (-17, '9'), (-16, '9'), (-15, '9'), (-14, ','), (-13, ' '), (-12, '2'), (-11, '0'), (-10, '0'), (-9, '4'), (-8, ' &'), (-7, ' '), (-6, '2'), (-5, '0'), (-4, '0'), (-3, '9'), (-2, ' Indian'), (-1, ' general'), (0, ' Elections'), (1, ' from'), (2, ' the'), (3, ' '), (4, '0'), (5, '0'), (6, '0'), (7, '0'), (8, '0'), (9, '0'), (10, ','), (11, ' Gha'), (12, 'zip'), (13, 'ur'), (14, ' and'), (15, ' Mach'), (16, 'hl'), (17, 'isha'), (18, 'hr'), (19, ','), (20, ' 

In [39]:
json.dump(SPLIT_TO_INV_POSITION,
          open(os.path.join(DATA_DIR, instance, f'{instance}_{entity_type}_prompt_to_entity_position.json'), 'w'),
          ensure_ascii=False, indent=2)

## Create instance for MIB

In [100]:
import json
import os
import random

import datasets
from datasets import Dataset


instance = 'gemma-2-2b'
entity_type = 'city'
INPUT_MAX_LEN = 48
FEATURE_TYPES = datasets.Features({"input": datasets.Value("string"), "label": datasets.Value("string"),
                              "source_input": datasets.Value("string"), "source_label": datasets.Value("string"),
                              "inv_label": datasets.Value("string"),
                              'split': datasets.Value("string"), 'source_split': datasets.Value("string"),
                              'entity': datasets.Value("string"), 'source_entity': datasets.Value("string")})


# Load training and test datasets. 
split_to_raw_example = json.load(open(os.path.join(DATA_DIR, f'{instance}/{instance}_{entity_type}_train.json'), 'r'))
split_to_raw_example.update(json.load(open(os.path.join(DATA_DIR, f'{instance}/{instance}_{entity_type}_context_test.json'), 'r')))
split_to_raw_example.update(json.load(open(os.path.join(DATA_DIR, f'{instance}/{instance}_{entity_type}_entity_test.json'), 'r')))

In [101]:
# Prepend an extra token to avoid tokenization changes for Llama tokenizer.
# Each sequence will start with <s> _ 0
# SOS_PAD = ''
# NUM_SOS_TOKENS = 3
# for split in split_to_raw_example:
#   for i in range(len(split_to_raw_example[split])):
#     split_to_raw_example[split][i]['inv_label'] = SOS_PAD + split_to_raw_example[split][i]['inv_label']
#     split_to_raw_example[split][i]['label'] = SOS_PAD + split_to_raw_example[split][i]['label']


# Load attributes (tasks) to prompt mapping.
# ALL_ATTR_TO_PROMPTS = json.load(open(os.path.join(DATA_DIR, 'base', f'ravel_{entity_type}_attribute_to_prompts.json')))

# Load prompt to intervention location mapping.
split_to_entity_pos = json.load(open(os.path.join(DATA_DIR, instance, f'{instance}_{entity_type}_prompt_to_entity_position.json')))
SPLIT_TO_INV_LOCATIONS = {
    f'{task}{split}': {'max_input_length': INPUT_MAX_LEN,
                       'inv_position': [INPUT_MAX_LEN + pos]}
    for task, pos in split_to_entity_pos.items()
    for split in ('-train', '-test', '-val', '')
}
assert(min([min(v['inv_position']) for v in SPLIT_TO_INV_LOCATIONS.values()]) > 0)


# Preprocess the dataset.
def filter_inv_example(example):
  return (example['label'] != example['inv_label'] and
          example['source_split'] in SPLIT_TO_INV_LOCATIONS and
          example['split'] in SPLIT_TO_INV_LOCATIONS)

for split in split_to_raw_example:
  random.shuffle(split_to_raw_example[split])
  split_to_raw_example[split] = list(filter(filter_inv_example, split_to_raw_example[split]))
  if len(split_to_raw_example[split]) == 0:
    print('Empty split: "%s"' % split)
# Remove empty splits.
split_to_raw_example = {k: v for k, v in split_to_raw_example.items() if len(v) > 0}
print(f"#Training examples={sum(map(len, [v for k, v in split_to_raw_example.items() if k.endswith('-train')]))}, "
      f"#Validation examples={sum(map(len, [v for k, v in split_to_raw_example.items() if k.endswith('-val')]))}, "
      f"#Test examples={sum(map(len, [v for k, v in split_to_raw_example.items() if k.endswith('-test')]))}")
split_to_dataset = {split: Dataset.from_list(
    split_to_raw_example[split], features=FEATURE_TYPES)
                    for split in split_to_raw_example}

# #Training examples=116728, #Validation examples=20516, #Test examples=22497

#Training examples=100347, #Validation examples=140718, #Test examples=139470


In [None]:
# json.dump(split_to_raw_example, open("mib_ravel.json", "w"), indent=4)
# split_to_raw_example = json.load(open("mib_ravel.json", "r"))

In [138]:
mib_ravel_train = []
mib_ravel_test = []
mib_ravel_val = []

for key in split_to_raw_example.keys():
    if "-train" in key:
        mib_ravel_train.extend(split_to_raw_example[key])
    if "-test" in key:
        mib_ravel_test.extend(split_to_raw_example[key])
    if "-val" in key:
        mib_ravel_val.extend(split_to_raw_example[key])

#### split

In [105]:
# Reformat train
from collections import defaultdict
import random

source_attr_to_items = defaultdict(list)

for item in ravel_train:
    source_attr = item.get("source_attribute", item["attribute"])
    source_attr_to_items[source_attr].append(item)

clean_ravel_train = []

for og_instance in ravel_train:
    og_attr = og_instance["attribute"]
    clean_entry = {
        "template": og_instance["split"],
        "prompt": og_instance["input"],
        "label": og_instance["label"],
        "entity": og_instance["entity"],
        "attribute": og_attr,
        "prompt_template_counterfactual": {},
        "attribute_counterfactual": {},
        "wikipedia_counterfactual": {}
    }

    # 1. Prompt template counterfactual: same source_attribute as og_attr
    same_source = [item for item in source_attr_to_items[og_attr]]
    if same_source:
        cf = random.choice(same_source)
        clean_entry["prompt_template_counterfactual"] = {
            "label": cf["source_label"],
            "prompt": cf["source_input"],
            "template": cf["source_split"],
            "entity": cf["source_entity"],
            "attribute": cf["source_attribute"]
        }

    # 2. Attribute counterfactual: different source_attribute, not "wikipedia"
    diff_attrs = [
        k for k in source_attr_to_items
        if k != og_attr and k != "wikipedia"
    ]
    if diff_attrs:
        rand_attr = random.choice(diff_attrs)
        cf = random.choice(source_attr_to_items[rand_attr])
        clean_entry["attribute_counterfactual"] = {
            "label": cf["source_label"],
            "prompt": cf["source_input"],
            "template": cf["source_split"],
            "entity": cf["source_entity"],
            "attribute": cf["source_attribute"]
        }

    # 3. Wikipedia counterfactual
    if "wikipedia" in source_attr_to_items:
        cf = random.choice(source_attr_to_items["wikipedia"])
        clean_entry["wikipedia_counterfactual"] = {
            "label": cf["source_label"],
            "prompt": cf["source_input"],
            "template": cf["source_split"],
            "entity": cf["source_entity"],
            "attribute": cf["source_attribute"]
        }

    clean_ravel_train.append(clean_entry)


In [122]:
# Reformat test

source_attr_to_items = defaultdict(list)

for item in ravel_test:
    source_attr = item.get("source_attribute", item["attribute"])
    source_attr_to_items[source_attr].append(item)

clean_ravel_test = []

for og_instance in ravel_test:
    og_attr = og_instance["attribute"]
    clean_entry = {
        "template": og_instance["split"],
        "prompt": og_instance["input"],
        "label": og_instance["label"],
        "entity": og_instance["entity"],
        "attribute": og_attr,
        "prompt_template_counterfactual": {},
        "attribute_counterfactual": {},
        "wikipedia_counterfactual": {}
    }

    # 1. Prompt template counterfactual: same source_attribute as og_attr
    same_source = [item for item in source_attr_to_items[og_attr]]
    if same_source:
        cf = random.choice(same_source)
        clean_entry["prompt_template_counterfactual"] = {
            "label": cf["source_label"],
            "prompt": cf["source_input"],
            "template": cf["source_split"],
            "entity": cf["source_entity"],
            "attribute": cf["source_attribute"]
        }

    # 2. Attribute counterfactual: different source_attribute, not "wikipedia"
    diff_attrs = [
        k for k in source_attr_to_items
        if k != og_attr and k != "wikipedia"
    ]
    if diff_attrs:
        rand_attr = random.choice(diff_attrs)
        cf = random.choice(source_attr_to_items[rand_attr])
        clean_entry["attribute_counterfactual"] = {
            "label": cf["source_label"],
            "prompt": cf["source_input"],
            "template": cf["source_split"],
            "entity": cf["source_entity"],
            "attribute": cf["source_attribute"]
        }

    # 3. Wikipedia counterfactual
    if "wikipedia" in source_attr_to_items:
        cf = random.choice(source_attr_to_items["wikipedia"])
        clean_entry["wikipedia_counterfactual"] = {
            "label": cf["source_label"],
            "prompt": cf["source_input"],
            "template": cf["source_split"],
            "entity": cf["source_entity"],
            "attribute": cf["source_attribute"]
        }

    clean_ravel_test.append(clean_entry)


In [146]:
# Reformat val

source_attr_to_items = defaultdict(list)

for item in ravel_val:
    source_attr = item.get("source_attribute", item["attribute"])
    source_attr_to_items[source_attr].append(item)

clean_ravel_val = []

for og_instance in ravel_val:
    og_attr = og_instance["attribute"]
    clean_entry = {
        "template": og_instance["split"],
        "prompt": og_instance["input"],
        "label": og_instance["label"],
        "entity": og_instance["entity"],
        "attribute": og_attr,
        "prompt_template_counterfactual": {},
        "attribute_counterfactual": {},
        "wikipedia_counterfactual": {}
    }

    # 1. Prompt template counterfactual: same source_attribute as og_attr
    same_source = [item for item in source_attr_to_items[og_attr]]
    if same_source:
        cf = random.choice(same_source)
        clean_entry["prompt_template_counterfactual"] = {
            "label": cf["source_label"],
            "prompt": cf["source_input"],
            "template": cf["source_split"],
            "entity": cf["source_entity"],
            "attribute": cf["source_attribute"]
        }

    # 2. Attribute counterfactual: different source_attribute, not "wikipedia"
    diff_attrs = [
        k for k in source_attr_to_items
        if k != og_attr and k != "wikipedia"
    ]
    if diff_attrs:
        rand_attr = random.choice(diff_attrs)
        cf = random.choice(source_attr_to_items[rand_attr])
        clean_entry["attribute_counterfactual"] = {
            "label": cf["source_label"],
            "prompt": cf["source_input"],
            "template": cf["source_split"],
            "entity": cf["source_entity"],
            "attribute": cf["source_attribute"]
        }

    # 3. Wikipedia counterfactual
    if "wikipedia" in source_attr_to_items:
        cf = random.choice(source_attr_to_items["wikipedia"])
        clean_entry["wikipedia_counterfactual"] = {
            "label": cf["source_label"],
            "prompt": cf["source_input"],
            "template": cf["source_split"],
            "entity": cf["source_entity"],
            "attribute": cf["source_attribute"]
        }

    clean_ravel_val.append(clean_entry)


#### filter + map

In [148]:
from datasets import Dataset

def create_city_map(city_entities_ds):
    """
    city_entities_ds is a HF Dataset with columns:
      ['ID', 'City', 'Continent', 'Country', 'Language'].

    Return a dict: city_name -> (Continent, Country, Language)
    """
    city_map = {}
    for row in city_entities_ds:
        city_name = row["City"]
        city_map[city_name] = (row["Continent"], row["Country"], row["Language"])
    return city_map


# Build separate maps for train, val, test
city_map_train = create_city_map(ravel_city_entities["train"])
city_map_val   = create_city_map(ravel_city_entities["val"])
city_map_test  = create_city_map(ravel_city_entities["test"])


In [149]:
from datasets import Dataset

def filter_missing_cities(example, city_map):
    """
    Return True if the main entity and all cf-entities appear in city_map.
    Return False otherwise (so the example gets dropped).
    """
    # 1) Main entity
    if example["entity"] not in city_map:
        return False

    # 2) Counterfactual sub-dicts
    for cf_name in ["prompt_template_counterfactual", "attribute_counterfactual", "wikipedia_counterfactual"]:
        cf_dict = example.get(cf_name)
        if cf_dict is not None:
            if cf_dict["entity"] not in city_map:
                return False

    return True

def add_city_info(example, city_map):
    """
    Add new columns (Continent, Country, Language) for the main entity,
    and e.g. Continent_prompt_template_counterfactual, etc. for subdicts.
    """
    # 1) Main entity
    main_city = example["entity"]
    c_main, co_main, l_main = city_map[main_city]  # guaranteed to exist after filtering
    example["Continent"] = c_main
    example["Country"]   = co_main
    example["Language"]  = l_main

    # 2) For each counterfactual sub-dict
    for cf_name in ["prompt_template_counterfactual", "attribute_counterfactual", "wikipedia_counterfactual"]:
        cf_dict = example.get(cf_name)
        if cf_dict is not None:
            cf_city = cf_dict["entity"]
            c_cf, co_cf, l_cf = city_map[cf_city]  # guaranteed after filtering
            cf_dict["Continent"] = c_cf
            cf_dict["Country"]   = co_cf
            cf_dict["Language"]  = l_cf

    return example



In [140]:
def filter_missing_cities(example, city_map):
    """
    Return True if the main entity and all cf-entities appear in city_map.
    Return False otherwise (so the example gets dropped).
    """
    # 1) Main entity
    if example["entity"] not in city_map:
        return False

    return True

def add_city_info(example, city_map):
    """
    Add new columns (Continent, Country, Language) for the main entity,
    and e.g. Continent_prompt_template_counterfactual, etc. for subdicts.
    """
    # 1) Main entity
    main_city = example["entity"]
    c_main, co_main, l_main = city_map[main_city]  # guaranteed to exist after filtering
    example["Continent"] = c_main
    example["Country"]   = co_main
    example["Language"]  = l_main

    return example


In [119]:
mib_ravel_train_ds = Dataset.from_list(clean_ravel_train)
filtered_train_ds = mib_ravel_train_ds.filter(lambda ex: filter_missing_cities(ex, city_map_train))
ravel_train = filtered_train_ds.map(lambda ex: add_city_info(ex, city_map_train))

print("Original train size:", len(mib_ravel_train_ds))
print("Filtered train size:", len(ravel_train))


Filter:   0%|          | 0/100347 [00:00<?, ? examples/s]

Map:   0%|          | 0/100347 [00:00<?, ? examples/s]

Original train size: 100347
Filtered train size: 100347


In [150]:
mib_ravel_val_ds = Dataset.from_list(clean_ravel_val)
filtered_val_ds = mib_ravel_val_ds.filter(lambda ex: filter_missing_cities(ex, city_map_val))
ravel_val = filtered_val_ds.map(lambda ex: add_city_info(ex, city_map_val))

print("Original val size:", len(mib_ravel_val_ds))
print("Filtered val size:", len(ravel_val))

Filter:   0%|          | 0/31900 [00:00<?, ? examples/s]

Map:   0%|          | 0/31900 [00:00<?, ? examples/s]

Original val size: 31900
Filtered val size: 31900


In [124]:

mib_ravel_test_ds = Dataset.from_list(clean_ravel_test)

filtered_test_ds = mib_ravel_test_ds.filter(lambda ex: filter_missing_cities(ex, city_map_test))
ravel_test = filtered_test_ds.map(lambda ex: add_city_info(ex, city_map_test))

print("Original test size:", len(mib_ravel_test_ds))
print("Filtered test size:", len(ravel_test))


Filter:   0%|          | 0/31190 [00:00<?, ? examples/s]

Map:   0%|          | 0/31190 [00:00<?, ? examples/s]

Original test size: 31190
Filtered test size: 31190


In [None]:
template_to_attribute = {}

for split_name in ["train", "val", "test"]:
    templates = ravel_city_prompts[split_name]["Template"]
    attributes = ravel_city_prompts[split_name]["Attribute"]
    for t, a in zip(templates, attributes):
        template_to_attribute[t] = a

def add_attribute_fields(example, template_to_attribute):
    t = example["split"]
    example["attribute"] = template_to_attribute[t]
    t2 = example["source_split"]
    example["source_attribute"] = template_to_attribute[t2] if template_to_attribute[t2] != "" else "wikipedia"
    return example

ravel_train = ravel_train.map(lambda ex: add_attribute_fields(ex, template_to_attribute))
ravel_test  = ravel_test.map(lambda ex: add_attribute_fields(ex, template_to_attribute))
ravel_val   = ravel_val.map(lambda ex: add_attribute_fields(ex, template_to_attribute))

Map:   0%|          | 0/31900 [00:00<?, ? examples/s]

In [None]:

# random.shuffle(ravel_val)
# mib_ravel_val_split = ravel_val[:len(ravel_val)//2]
# mib_ravel_public_test = ravel_val[len(ravel_val)//2:]

split_dataset = ravel_val.train_test_split(test_size=0.5, seed=42)
mib_ravel_val_split = split_dataset["train"]      # first half
mib_ravel_public_test = split_dataset["test"]     # second half


In [159]:
# json.dump(ravel_train, open("content/mib/ravel_train.json", "w"), indent=4)
# json.dump(ravel_test, open("content/mib/ravel_private_test.json", "w"), indent=4)
# json.dump(mib_ravel_val_split, open("content/mib/ravel_val.json", "w"), indent=4)
# json.dump(mib_ravel_public_test, open("content/mib/ravel_public_test.json", "w"), indent=4)


# Convert to a list of dicts
ravel_train_list = ravel_train.to_list()
ravel_test_list  = ravel_test.to_list()
mib_ravel_val_split_list = mib_ravel_val_split.to_list()
mib_ravel_public_test_list = mib_ravel_public_test.to_list()

# Now you can JSON-serialize these lists
import json
with open("content/mib/ravel_train.json", "w") as f:
    json.dump(ravel_train_list, f, indent=4)

with open("content/mib/ravel_private_test.json", "w") as f:
    json.dump(ravel_test_list, f, indent=4)

with open("content/mib/ravel_val.json", "w") as f:
    json.dump(mib_ravel_val_split_list, f, indent=4)

with open("content/mib/ravel_public_test.json", "w") as f:
    json.dump(mib_ravel_public_test_list, f, indent=4)


In [160]:
# mib_ravel_train = Dataset.from_list(clean_ravel_train)
# mib_ravel_val_split = Dataset.from_list(mib_ravel_val_split)
# mib_ravel_public_test = Dataset.from_list(mib_ravel_public_test)

mib_ravel = DatasetDict({"train": ravel_train,                        
                         "val": mib_ravel_val_split,
                         "test": mib_ravel_public_test})
mib_ravel.push_to_hub("yiksiu/mib_ravel")

mib_ravel_private = DatasetDict({"test": ravel_test})
mib_ravel_private.push_to_hub("yiksiu/mib_ravel_private_test")

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/101 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/16 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/16 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/1.53k [00:00<?, ?B/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/32 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/yiksiu/mib_ravel_private_test/commit/d8fa09079f66fef5bbafb6ca1db0f929f7290d16', commit_message='Upload dataset', commit_description='', oid='d8fa09079f66fef5bbafb6ca1db0f929f7290d16', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/yiksiu/mib_ravel_private_test', endpoint='https://huggingface.co', repo_type='dataset', repo_id='yiksiu/mib_ravel_private_test'), pr_revision=None, pr_num=None)