In [1]:
import os

import torch

from seper.calculate import gen_answers_batch, calculate_uncertainty_soft_batch, create_collate_fn
from seper.calculate import process_item_for_seper
from seper.models.huggingface_models import HuggingfaceModel
from seper.uncertainty_measures.semantic_entropy import EntailmentDeberta

#### Load Model
Load the generation model and entailment judge model.

In [2]:
# --- Machine-specific settings ------------------------------------------------
# The checkpoint location and the GPU index differ between machines, so both can
# be overridden through environment variables without editing the notebook.
# The defaults reproduce the original setup.
model_path = os.environ.get('SEPER_MODEL_PATH', '/data/dailu/llama-2-7b-chat-hf')
device = os.environ.get('SEPER_DEVICE', 'cuda:1')

# --- Sampling / computation settings -----------------------------------------
num_generations = 10 # 10 is good for most cases
sub_batch_size = 10  # NOTE(review): presumably samples generated per batch -- confirm in gen_answers_batch
temperature = 1.0
max_new_tokens = 128
max_context_words = 4096
computation_chunk_size = 8 # adjust to balance speeds and gpu memory cost
prompt_type = 'default'

# Build generator
generator = HuggingfaceModel(
    model_path,
    stop_sequences='default',
    max_new_tokens=max_new_tokens,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    device=device,
)
generator.model.eval()

# Build entailment model
entailment_model = EntailmentDeberta(device=device)
# The trailing semicolon stops the notebook from echoing the full module repr
# (hundreds of lines) as the cell output; the model is still put in eval mode.
entailment_model.model.eval();

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



DebertaV2ForSequenceClassification(
  (deberta): DebertaV2Model(
    (embeddings): DebertaV2Embeddings(
      (word_embeddings): Embedding(128100, 1536, padding_idx=0)
      (LayerNorm): LayerNorm((1536,), eps=1e-07, elementwise_affine=True)
      (dropout): StableDropout()
    )
    (encoder): DebertaV2Encoder(
      (layer): ModuleList(
        (0-23): 24 x DebertaV2Layer(
          (attention): DebertaV2Attention(
            (self): DisentangledSelfAttention(
              (query_proj): Linear(in_features=1536, out_features=1536, bias=True)
              (key_proj): Linear(in_features=1536, out_features=1536, bias=True)
              (value_proj): Linear(in_features=1536, out_features=1536, bias=True)
              (pos_dropout): StableDropout()
              (dropout): StableDropout()
            )
            (output): DebertaV2SelfOutput(
              (dense): Linear(in_features=1536, out_features=1536, bias=True)
              (LayerNorm): LayerNorm((1536,), eps=1e-07, element

#### Step 1: Sampling
Sample `num_generations` answers for each example and collect the generated responses together with their probabilities.

In [3]:
# A single QA item: the question, the retrieved context, and the reference answer(s).
example = dict(
    question='Is Elon Musk older than Sam Altman?',
    context='Elon Musk was born on June 28, 1971. Sam Altman was born on April 22, 1985.',
    answers=['Yes'],  # this is the provided ground-truth answers
)

# Positional arguments shared by every sampling call in this cell, in the order
# gen_answers_batch is called with them after the example itself.
sampling_args = (
    generator,
    temperature,
    num_generations,
    sub_batch_size,
    max_new_tokens,
    prompt_type,
    device,
    max_context_words,
)

# Sample answers with the retrieved context available.
result = gen_answers_batch(example, *sampling_args)

# Baseline example with empty context
example_baseline = {**example, 'context': ''}
result_baseline = gen_answers_batch(example_baseline, *sampling_args)

In [4]:
# Display the sampling output for the example that includes the retrieved context.
# In the recorded output, 'responses' holds (text, list-of-floats) pairs --
# presumably per-token log-probabilities; confirm against gen_answers_batch.
result

{'question': 'Is Elon Musk older than Sam Altman?',
 'context': 'Elon Musk was born on June 28, 1971. Sam Altman was born on April 22, 1985.',
 'answers': ['Yes'],
 'responses': [('Yes.', [-0.011683372780680656, -0.22827866673469543]),
  ('Yes', [-0.011683372780680656]),
  ('Yes.', [-0.011683372780680656, -0.22827866673469543]),
  ('Yes.', [-0.011683372780680656, -0.22827866673469543]),
  ('Yes.', [-0.011683372780680656, -0.22827866673469543]),
  ('Yes.', [-0.011683372780680656, -0.22827866673469543]),
  ('Yes', [-0.011683372780680656]),
  ('Yes.', [-0.011683372780680656, -0.22827866673469543]),
  ('Yes. Elon Musk is 9 years older than Sam Altman.',
   [-0.011683372780680656,
    -0.22827866673469543,
    -1.8856996297836304,
    -1.0728830375228426e-06,
    -1.9192511899746023e-05,
    -3.814689989667386e-06,
    -0.4233023524284363,
    -0.9172983169555664,
    -0.06922823935747147,
    -0.00011383838864276186,
    -0.0002047805901383981,
    -1.2040065485052764e-05,
    -4.053107659

#### Step 2: Calculate ΔSePer
Aggregate semantically equivalent answers and compute the belief shift on the ground-truth answer.

In [5]:
# Item fields handed to create_collate_fn to build the batch for the SePer computation.
keys = [
    'question',
    'response_text',
    'answers',
    'likelihood',
    'context_label',
    'log_liks_agg',
    'context',
]
seper_collate_fn = create_collate_fn(keys)

# Inference only, so gradient tracking is switched off for the whole computation.
with torch.no_grad():
    # Convert both sampling results (with context, then baseline) for SePer.
    r, rb = (process_item_for_seper(item) for item in (result, result_baseline))
    seper_input = seper_collate_fn([r, rb])
    # One score per item, returned in the same order as the collated inputs.
    seper, seper_baseline = calculate_uncertainty_soft_batch(
        seper_input,
        entailment_model,
        computation_chunk_size,
    )
    # Delta SePer: score with the retrieved context minus the empty-context baseline.
    d_seper = seper - seper_baseline

### Extended Results: Fine-grained Retrieval Utility Evaluation

In [6]:
# The same question with no retrieved context at all.
example_baseline = dict(
    question='Is Elon Musk older than Sam Altman?',
    context='',
    answers=['Yes'],  # this is the provided ground-truth answers
)

# Each "piece" carries only one of the two facts that made up the full context.
example_piece1 = {**example_baseline, 'context': 'Elon Musk was born on June 28, 1971.'}
example_piece2 = {**example_baseline, 'context': 'Sam Altman was born on April 22, 1985.'}

# Sample answers for each partial-context example, in order, with the same settings.
result_piece1, result_piece2 = [
    gen_answers_batch(
        piece,
        generator,
        temperature,
        num_generations,
        sub_batch_size,
        max_new_tokens,
        prompt_type,
        device,
        max_context_words,
    )
    for piece in (example_piece1, example_piece2)
]


In [None]:
# Calculate SePer for each single piece of retrieved information and compare it
# with the empty-context baseline (`seper_baseline`) and the full-context score
# (`seper`, `d_seper`) computed in the Step 2 cell, which must have run first.
with torch.no_grad():
    # Convert for SEPER
    r_1 = process_item_for_seper(result_piece1)
    r_2 = process_item_for_seper(result_piece2)
    seper_input = seper_collate_fn([r_1, r_2])
    seper_piece1, seper_piece2 = calculate_uncertainty_soft_batch(
        seper_input, entailment_model, computation_chunk_size
    )
    # Delta SePer of each piece relative to the same empty-context baseline.
    d_seper_piece1 = seper_piece1 - seper_baseline
    d_seper_piece2 = seper_piece2 - seper_baseline

# Use sep='\n' rather than embedding '\n' in each argument: with print's default
# sep=' ', every line after the first started with a stray space and lines ended
# with a trailing one.
print(
    f'(Baseline) SePer without retrieved information: {seper_baseline}',
    f'SePer with two pieces of information: {seper}, ΔSePer: {d_seper}',
    f'SePer with one piece of information: {seper_piece1}, ΔSePer: {d_seper_piece1}',
    f'SePer with another piece of information: {seper_piece2}, ΔSePer: {d_seper_piece2}',
    sep='\n',
)

(Baseline)SePer without retrieved information: 0.4393003277431539
 SePer with two pieces of information: 0.9610638042002522, ΔSePer: 0.5217634764570982
 SePer with one piece of information: 0.6926190671233344, ΔSePer: 0.2533187393801805
 SePer with another piece of information: 0.5894756588579478, ΔSePer: 0.15017533111479392
