In [1]:
import json
import datasets
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

In [2]:
# Path to the raw CounterFact JSON: a list of record dicts (indexed as data[0] below).
path = "../assets/data/counterfact.json"

# Read explicitly as UTF-8: records contain non-ASCII names (e.g. "Léon Blum",
# "Melchior de Vogüé"), and open()'s default encoding is locale-dependent
# (e.g. cp1252 on Windows), which can mis-decode or fail on this file.
with open(path, "r", encoding="utf-8") as f:
    data = json.load(f)

In [3]:
# Inspect the first CounterFact record (rewrite request, paraphrase/neighborhood/attribute/generation prompts).
data[0]

{'case_id': 0,
 'pararel_idx': 2796,
 'requested_rewrite': {'prompt': 'The mother tongue of {} is',
  'relation_id': 'P103',
  'target_new': {'str': 'English', 'id': 'Q1860'},
  'target_true': {'str': 'French', 'id': 'Q150'},
  'subject': 'Danielle Darrieux'},
 'paraphrase_prompts': ['Shayna does this and Yossel goes still and dies. Danielle Darrieux, a native',
  'An album was recorded for Capitol Nashville but never released. Danielle Darrieux spoke the language'],
 'neighborhood_prompts': ['The mother tongue of Léon Blum is',
  'The native language of Montesquieu is',
  'François Bayrou, a native',
  'The native language of Raymond Barre is',
  'Michel Rocard is a native speaker of',
  'Jacques Chaban-Delmas is a native speaker of',
  'The native language of François Bayrou is',
  'Maurice Genevoix, speaker of',
  'The mother tongue of François Bayrou is',
  'Melchior de Vogüé, speaker of'],
 'attribute_prompts': ['J.\xa0R.\xa0R. Tolkien is a native speaker of',
  'The mother tongue

In [4]:
# List the top-level fields of a record; these become the dataset columns below.
data[0].keys()

dict_keys(['case_id', 'pararel_idx', 'requested_rewrite', 'paraphrase_prompts', 'neighborhood_prompts', 'attribute_prompts', 'generation_prompts'])

In [5]:
# Convert the list of record dicts into a Hugging Face Dataset (one row per record)
# so it can be processed in batches with Dataset.map.
hf_dataset = datasets.Dataset.from_list(data)

In [6]:
# Display the dataset summary (column names and row count).
hf_dataset

Dataset({
    features: ['case_id', 'pararel_idx', 'requested_rewrite', 'paraphrase_prompts', 'neighborhood_prompts', 'attribute_prompts', 'generation_prompts'],
    num_rows: 21919
})

In [7]:
# Load GPT-2 into a vLLM engine for offline batched generation.
# The init log below reports it runs in float16 with seed=0.
model = LLM("gpt2")

INFO 06-17 18:17:46 config.py:1193] Casting torch.float32 to torch.float16.
INFO 06-17 18:17:46 config.py:1214] Downcasting torch.float32 to torch.float16.
INFO 06-17 18:17:46 llm_engine.py:161] Initializing an LLM engine (v0.5.0) with config: model='gpt2', speculative_config=None, tokenizer='gpt2', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=1024, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=gpt2)


INFO 06-17 18:17:46 weight_utils.py:218] Using model weights format ['*.safetensors']
INFO 06-17 18:17:47 weight_utils.py:261] No model.safetensors.index.json found in remote.
INFO 06-17 18:17:47 model_runner.py:159] Loading model weights took 0.2378 GB
INFO 06-17 18:17:48 gpu_executor.py:83] # GPU blocks: 34386, # CPU blocks: 7281
INFO 06-17 18:17:49 model_runner.py:878] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 06-17 18:17:49 model_runner.py:882] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 06-17 18:17:58 model_runner.py:954] Graph capturing finished in 9 secs.


In [8]:
# Smoke test: generate a short continuation for one prompt to confirm the engine works.
out = model.generate("hello", sampling_params=SamplingParams(max_tokens=50))
# Display the generated text; otherwise the smoke-test result is never inspected.
out[0].outputs[0].text

Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  7.74it/s, Generation Speed: 389.05 toks/s]


In [9]:
# GPT-2 tokenizer (matches the vLLM model above).
# NOTE(review): not used by any cell visible here; remove if it is unused elsewhere.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

In [10]:
def process_data(batch, llm=None, sampling_params=None):
    """Generate a continuation for every generation prompt in a batch of records.

    Intended as a ``datasets.Dataset.map(..., batched=True)`` function. All prompts
    in the batch are flattened into a single ``generate`` call. The completed texts
    are then regrouped so that output row ``i`` lines up with row ``i``'s
    ``generation_prompts``.

    Args:
        batch: Columnar batch dict. ``batch["generation_prompts"]`` is a list
            (one entry per record) of lists of prompt strings.
        llm: Object with a vLLM-style ``generate(prompts, sampling_params=...)``
            method. Defaults to the notebook-level ``model``.
        sampling_params: Sampling settings passed to ``generate``. Defaults to
            ``SamplingParams(max_tokens=50, temperature=0.7, top_k=50)``.

    Returns:
        The same ``batch`` dict with a new ``"generation_continuations"`` column.
        For each record it holds a list of ``prompt + generated_text`` strings.

    Raises:
        ValueError: If ``generate`` does not return exactly one result per prompt.
    """
    if llm is None:
        llm = model
    if sampling_params is None:
        sampling_params = SamplingParams(max_tokens=50, temperature=0.7, top_k=50)

    prompt_groups = batch["generation_prompts"]
    # Flatten so the whole batch is generated in a single engine call.
    flat_prompts = [prompt for group in prompt_groups for prompt in group]

    # Skip the engine call entirely when there is nothing to generate.
    results = llm.generate(flat_prompts, sampling_params=sampling_params) if flat_prompts else []
    # zip() would silently truncate on a mismatch and misalign every later record.
    if len(results) != len(flat_prompts):
        raise ValueError(
            f"expected {len(flat_prompts)} generations, got {len(results)}"
        )

    # vLLM's LLM.generate returns results in prompt order. Pair each prompt with
    # its result, then regroup by consuming the flat sequence one record at a time.
    completed = iter(
        prompt + result.outputs[0].text
        for prompt, result in zip(flat_prompts, results)
    )
    batch["generation_continuations"] = [
        [next(completed) for _ in group] for group in prompt_groups
    ]
    return batch

In [11]:
# Set to an int (e.g. 1000) for a quick trial run on the first N records; None processes all rows.
N_SAMPLES = None
source_dataset = hf_dataset if N_SAMPLES is None else hf_dataset.select(range(N_SAMPLES))
# Batched map: each process_data call receives up to `batch_size` records and issues
# one vLLM generate call covering all of their generation prompts.
processed_subset = source_dataset.map(process_data, batched=True, batch_size=1000)



Map:   0%|          | 0/21919 [00:00<?, ? examples/s]

Processed prompts: 100%|██████████| 10000/10000 [01:17<00:00, 129.47it/s, Generation Speed: 6337.01 toks/s]
Processed prompts: 100%|██████████| 10000/10000 [01:17<00:00, 128.74it/s, Generation Speed: 6304.13 toks/s]
Processed prompts: 100%|██████████| 10000/10000 [01:17<00:00, 128.99it/s, Generation Speed: 6322.91 toks/s]
Processed prompts: 100%|██████████| 10000/10000 [01:17<00:00, 129.04it/s, Generation Speed: 6326.65 toks/s]
Processed prompts: 100%|██████████| 10000/10000 [01:17<00:00, 128.78it/s, Generation Speed: 6301.06 toks/s]
Processed prompts: 100%|██████████| 10000/10000 [01:17<00:00, 129.64it/s, Generation Speed: 6349.81 toks/s]
Processed prompts: 100%|██████████| 10000/10000 [01:17<00:00, 129.10it/s, Generation Speed: 6344.32 toks/s]
Processed prompts: 100%|██████████| 10000/10000 [01:17<00:00, 129.15it/s, Generation Speed: 6325.85 toks/s]
Processed prompts: 100%|██████████| 10000/10000 [01:17<00:00, 129.79it/s, Generation Speed: 6347.29 toks/s]
Processed prompts: 100%|████

In [12]:
# Confirm the new 'generation_continuations' column exists and the row count is unchanged.
processed_subset

Dataset({
    features: ['case_id', 'pararel_idx', 'requested_rewrite', 'paraphrase_prompts', 'neighborhood_prompts', 'attribute_prompts', 'generation_prompts', 'generation_continuations'],
    num_rows: 21919
})

In [13]:
# Persist the processed dataset so the long-running generation step need not be re-run.
processed_subset.save_to_disk("../assets/data/processed_counterfact_full_data")

Saving the dataset (0/1 shards):   0%|          | 0/21919 [00:00<?, ? examples/s]