In [41]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="4"

import numpy as np
from itertools import chain

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from experiments.py.eval_utils_counterfact import test_batch_prediction

In [2]:
# Hugging Face checkpoint shared by the tokenizer and the model.
MODEL_NAME = "EleutherAI/gpt-j-6B"

# Tokenizer: reuse the end-of-sequence token as the padding token so that
# batches of prompts can later be tokenized with padding=True.
tok = AutoTokenizer.from_pretrained(MODEL_NAME)
tok.pad_token = tok.eos_token

# Causal LM weights, moved onto the (single) visible CUDA device.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model = model.to("cuda")

In [109]:
# Hand-written evaluation record(s). Each record carries a rewrite request
# (prompt template + subject + true/new target strings) and several lists of
# ready-made prompts that the evaluation cell scores against both targets.
data_item = [
    {
        "case_id": 0,
        "pararel_idx": 2796,
        # The "{}" placeholder in "prompt" is filled with "subject" via str.format.
        "requested_rewrite": {
            "prompt": "The primary use of {} is",
            "target_true": {"str": "inflammation reduction"},
            "target_new": {"str": "cancer"},
            "subject": "Aspirin",
        },
        # Rewordings of the rewrite prompt about the same subject.
        "paraphrase_prompts": [
            "Aspirin is commonly used for",
            "The main purpose of Aspirin is to",
            "Aspirin is primarily utilized for",
        ],
        # Prompts about other drugs.
        "neighborhood_prompts": [
            "The primary use of Ibuprofen is",
            "The main purpose of Paracetamol is to",
            "Naproxen is commonly used for",
            "The primary use of Acetaminophen is",
            "The main purpose of Diclofenac is to",
            "Ibuprofen is primarily utilized for",
            "Paracetamol is commonly used for",
            "The primary use of Naproxen is",
            "The main purpose of Acetaminophen is to",
            "Diclofenac is primarily utilized for",
        ],
        "attribute_prompts": [
            "The primary use of Penicillin is",
            "The main purpose of Morphine is to",
            "Paracetamol is commonly used for",
            "The primary use of Antibiotics is",
            "The main purpose of Insulin is to",
            "Aspirin is primarily utilized for",
            "The primary use of Antihistamines is",
            "The main purpose of Codeine is to",
            "Acetaminophen is commonly used for",
            "The primary use of Statins is",
        ],
        # Open-ended prompts about the subject.
        "generation_prompts": [
            "The primary use of Aspirin is",
            "Aspirin is mainly used for",
            "The main purpose of Aspirin is",
            "Aspirin is primarily employed for",
            "The primary function of Aspirin is",
            "Aspirin is commonly utilized for",
            "The main application of Aspirin is",
            "Aspirin is primarily used to",
            "The primary role of Aspirin is",
            "Aspirin is mainly employed for",
        ],
    },
]

In [110]:
# Score one record: for every rewrite / paraphrase / neighborhood / attribute
# prompt, ask test_batch_prediction to compare target_new against target_true.

# data_item contains a single record, so it lives at index 0; indexing with 1
# raises IndexError on a fresh kernel.
record = data_item[0]

subject, target_new, target_true = (
    record["requested_rewrite"][x] for x in ["subject", "target_new", "target_true"]
)
print(target_new, target_true)

# The rewrite prompt is the template with the subject substituted in; the
# other prompt lists are used as given. generation_prompts is not scored here.
rewrite_prompts = [record["requested_rewrite"]["prompt"].format(subject)]
paraphrase_prompts = record["paraphrase_prompts"]
neighborhood_prompts = record["neighborhood_prompts"]
attribute_prompts = record["attribute_prompts"]

# Form a list of lists of prefixes to test.
prob_prompts = [
    rewrite_prompts,
    paraphrase_prompts,
    neighborhood_prompts,
    attribute_prompts,
]
# Flatten all the evaluated prefixes into one list so they are scored in a
# single call. The unflattening below relies on the helper returning one
# entry per prefix, in the same order -- TODO confirm against eval_utils_counterfact.
probs = test_batch_prediction(
    model, tok, list(chain(*prob_prompts)), target_new["str"], target_true["str"]
)

# Unflatten the results again into a list of lists: cutoffs holds the
# cumulative group sizes, so consecutive pairs delimit each group's slice.
cutoffs = [0] + np.cumsum(list(map(len, prob_prompts))).tolist()
ret_probs = [probs[lo:hi] for lo, hi in zip(cutoffs, cutoffs[1:])]

# Structure the results as a dictionary keyed by "<group>_probs".
ret = {
    f"{key}_probs": ret_probs[i]
    for i, key in enumerate(
        [
            "rewrite_prompts",
            "paraphrase_prompts",
            "neighborhood_prompts",
            "attribute_prompts",
        ]
    )
}
ret

{'str': 'cancer'} {'str': 'inflammation reduction'}


{'rewrite_prompts_probs': [{'target_new': 10.495936393737793,
   'target_true': 5.743099689483643}],
 'paraphrase_prompts_probs': [{'target_new': 7.574860095977783,
   'target_true': 5.8621087074279785},
  {'target_new': 12.905136108398438, 'target_true': 7.993843078613281},
  {'target_new': 7.8686842918396, 'target_true': 5.595829010009766}],
 'neighborhood_prompts_probs': [{'target_new': 12.509492874145508,
   'target_true': 5.296060085296631},
  {'target_new': 14.030357360839844, 'target_true': 8.841634750366211},
  {'target_new': 9.405290603637695, 'target_true': 6.111368179321289},
  {'target_new': 12.318239212036133, 'target_true': 6.6969709396362305},
  {'target_new': 13.5101900100708, 'target_true': 8.439132690429688},
  {'target_new': 9.135297775268555, 'target_true': 5.3316755294799805},
  {'target_new': 9.790838241577148, 'target_true': 7.287423133850098},
  {'target_new': 11.584332466125488, 'target_true': 6.201481819152832},
  {'target_new': 14.094964981079102, 'target_tru

In [99]:
# Manual re-derivation of the batch scoring: tokenize every prefix followed by
# each of the two candidate targets.

# All prompt groups flattened into one list, plus the token count of each
# prefix on its own (used later to locate where the target tokens start).
prefixes = list(chain(*prob_prompts))
prefix_lens = list(map(len, tok(prefixes)["input_ids"]))

# Two rows per prefix, in this order: prefix + target_new, prefix + target_true.
candidate_texts = []
for prefix in prefixes:
    for suffix in (target_new["str"], target_true["str"]):
        candidate_texts.append(f"{prefix} {suffix}")

# Pad to a common length and move the batch to the GPU.
# NOTE(review): later position arithmetic assumes padding is added on the
# right -- confirm tok.padding_side.
prompt_tok = tok(candidate_texts, padding=True, return_tensors="pt").to("cuda")

# Token ids of each target when preceded by a space, and how many there are.
a_tok = tok(f" {target_new['str']}")["input_ids"]
b_tok = tok(f" {target_true['str']}")["input_ids"]
choice_a_len = len(a_tok)
choice_b_len = len(b_tok)

In [100]:
# Forward pass over the whole batch; gradients are not needed for scoring.
with torch.no_grad():
    logits = model(**prompt_tok).logits

num_rows = logits.size(0)
# One score per row: the mean probability the model assigns to the target's tokens.
results = np.zeros((num_rows,), dtype=np.float32)

for row in range(num_rows):
    # Rows alternate between the two candidates: even rows hold target_new
    # (a_tok), odd rows hold target_true (b_tok).
    if row % 2 == 0:
        target_ids, cur_len = a_tok, choice_a_len
    else:
        target_ids, cur_len = b_tok, choice_b_len

    # Each prefix owns two consecutive rows, hence row // 2.
    prefix_len = prefix_lens[row // 2]

    for offset in range(cur_len):
        # The distribution for the token at position prefix_len + offset is
        # read from the logits one position earlier.
        next_token_probs = torch.softmax(logits[row, prefix_len + offset - 1, :], dim=0)
        results[row] += next_token_probs[target_ids[offset]].item()

    # Average over the target's tokens.
    results[row] /= cur_len

In [101]:
# Regroup the flat score array into one dict per prefix: each even index holds
# the target_new score and the following odd index the target_true score.
[
    dict(target_new=results[k].item(), target_true=results[k + 1].item())
    for k in range(0, len(results), 2)
]

[{'target_new': 0.0005219134618528187, 'target_true': 0.0884685143828392},
 {'target_new': 0.01124377828091383, 'target_true': 0.037445858120918274},
 {'target_new': 0.0011204167967662215, 'target_true': 0.0010650110198184848},
 {'target_new': 0.04286317899823189, 'target_true': 0.024375326931476593},
 {'target_new': 0.0011810313444584608, 'target_true': 0.0001384748611599207},
 {'target_new': 0.0011315789306536317, 'target_true': 0.0005501543055288494},
 {'target_new': 0.0064858035184443, 'target_true': 0.042258162051439285},
 {'target_new': 0.014677083119750023, 'target_true': 0.0026560230180621147},
 {'target_new': 0.007684157695621252, 'target_true': 0.03435734286904335},
 {'target_new': 0.001247983775101602, 'target_true': 0.018311411142349243},
 {'target_new': 0.000556560349650681, 'target_true': 0.0003720947715919465},
 {'target_new': 0.008322123438119888, 'target_true': 0.0033334321342408657},
 {'target_new': 0.0024670579005032778, 'target_true': 0.0014258504379540682},
 {'targ