In [1]:
import torch
import os
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer, util
from evaluate import load
from huggingface_hub import login

In [2]:
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
generics_kb = load_dataset('community-datasets/generics_kb')
commonsense_qa = load_dataset('tau/commonsense_qa')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [4]:
documents = [doc['generic_sentence'] for doc in generics_kb['train']]

In [5]:
embeddings_model = SentenceTransformer('all-MiniLM-L6-v2')
document_embeddings = embeddings_model.encode(documents, batch_size=64, show_progress_bar=True)

Batches:   0%|          | 0/15952 [00:00<?, ?it/s]

In [6]:
model_name = 'meta-llama/Llama-2-7b-hf'
bitsandbytes_config = BitsAndBytesConfig(load_in_4bit=True,
                                         bnb_4bit_compute_dtype=torch.float16,
                                         bnb_4bit_quant_type='nf4')

In [7]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             quantization_config=bitsandbytes_config,
                                             device_map='cuda')

tokenizer_config.json:   0%|          | 0.00/776 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/609 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

In [10]:
bleu = load('bleu')
bertscore = load('bertscore')

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading extra modules: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

In [11]:
# Експеримент 1: RAG со 5 документи
sample_qa = commonsense_qa['train'][0]
question = sample_qa['question']
question

'The sanctions against the school were a punishing blow, and they seemed to what the efforts the school had made to change?'

In [12]:
answer_texts = sample_qa['choices']['text']
answer_labels = sample_qa['choices']['label']
answer_options = '\n'.join([f'{l}: {t}' for t, l in zip(answer_texts, answer_labels)])
answer_options

'A: ignore\nB: enforce\nC: authoritarian\nD: yell at\nE: avoid'

In [13]:
correct_answer = sample_qa['answerKey']
correct_answer

'A'

In [14]:
question_embedding = embeddings_model.encode(question, batch_size=64, show_progress_bar=True)
context_results = util.semantic_search(question_embedding, document_embeddings, top_k=5)
context_results

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

[[{'corpus_id': 803035, 'score': 0.6048142910003662},
  {'corpus_id': 802722, 'score': 0.5605494976043701},
  {'corpus_id': 798433, 'score': 0.5495730638504028},
  {'corpus_id': 802818, 'score': 0.5473368763923645},
  {'corpus_id': 803033, 'score': 0.5472485423088074}]]

In [15]:
doc_ids = [c['corpus_id'] for c in context_results[0]]
doc_ids

[803035, 802722, 798433, 802818, 803033]

In [16]:
docs = [documents[d] for d in doc_ids]
docs

['Schools take actions.',
 'School reform is a freight train moving down the track.',
 'Sanctions are retribution.',
 'School is a punishment visited on children for having been born ignorant.',
 'Schools study effects.']

In [17]:
context = f'Context:\nDocument 1: {docs[0]}\nDocument 2: {docs[1]}\nDocument 3: {docs[2]}\nDocument 4: {docs[3]}\nDocument 5: {docs[4]}'
context

'Context:\nDocument 1: Schools take actions.\nDocument 2: School reform is a freight train moving down the track.\nDocument 3: Sanctions are retribution.\nDocument 4: School is a punishment visited on children for having been born ignorant.\nDocument 5: Schools study effects.'

In [18]:
prompt = f'{context}\n\nAnswer the following question: {question}\n{answer_options}\nYour answer should be one of A, B, C, D, or E\nAnswer: '
prompt

'Context:\nDocument 1: Schools take actions.\nDocument 2: School reform is a freight train moving down the track.\nDocument 3: Sanctions are retribution.\nDocument 4: School is a punishment visited on children for having been born ignorant.\nDocument 5: Schools study effects.\n\nAnswer the following question: The sanctions against the school were a punishing blow, and they seemed to what the efforts the school had made to change?\nA: ignore\nB: enforce\nC: authoritarian\nD: yell at\nE: avoid\nYour answer should be one of A, B, C, D, or E\nAnswer: '

In [19]:
tokens = tokenizer(prompt, return_tensors='pt').to('cuda:0')
tokens

{'input_ids': tensor([[    1, 15228, 29901,    13,  6268, 29871, 29896, 29901,  4523, 29879,
          2125,  8820, 29889,    13,  6268, 29871, 29906, 29901,  4523, 11736,
           338,   263,  3005,   523,  7945,  8401,  1623,   278,  5702, 29889,
            13,  6268, 29871, 29941, 29901,  3087,  1953,   526,  3240,  3224,
         29889,    13,  6268, 29871, 29946, 29901,  4523,   338,   263,  6035,
         18310, 16669,   373,  4344,   363,  2534,  1063,  6345, 16245,   424,
         29889,    13,  6268, 29871, 29945, 29901,  4523, 29879,  6559,  9545,
         29889,    13,    13, 22550,   278,  1494,  1139, 29901,   450,  9753,
          1953,  2750,   278,  3762,   892,   263,  6035, 14424, 13031, 29892,
           322,   896,  6140,   304,   825,   278, 14231,   278,  3762,   750,
          1754,   304,  1735, 29973,    13, 29909, 29901, 11455,    13, 29933,
         29901,   427, 10118,    13, 29907, 29901,  4148,  3673,   713,    13,
         29928, 29901,   343,   514,  

In [20]:
output_ids = model.generate(tokens.input_ids, max_new_tokens=50)
output_ids

tensor([[    1, 15228, 29901,    13,  6268, 29871, 29896, 29901,  4523, 29879,
          2125,  8820, 29889,    13,  6268, 29871, 29906, 29901,  4523, 11736,
           338,   263,  3005,   523,  7945,  8401,  1623,   278,  5702, 29889,
            13,  6268, 29871, 29941, 29901,  3087,  1953,   526,  3240,  3224,
         29889,    13,  6268, 29871, 29946, 29901,  4523,   338,   263,  6035,
         18310, 16669,   373,  4344,   363,  2534,  1063,  6345, 16245,   424,
         29889,    13,  6268, 29871, 29945, 29901,  4523, 29879,  6559,  9545,
         29889,    13,    13, 22550,   278,  1494,  1139, 29901,   450,  9753,
          1953,  2750,   278,  3762,   892,   263,  6035, 14424, 13031, 29892,
           322,   896,  6140,   304,   825,   278, 14231,   278,  3762,   750,
          1754,   304,  1735, 29973,    13, 29909, 29901, 11455,    13, 29933,
         29901,   427, 10118,    13, 29907, 29901,  4148,  3673,   713,    13,
         29928, 29901,   343,   514,   472,    13, 2

In [21]:
tokenizer.decode(output_ids[0], skip_special_tokens=True)

'Context:\nDocument 1: Schools take actions.\nDocument 2: School reform is a freight train moving down the track.\nDocument 3: Sanctions are retribution.\nDocument 4: School is a punishment visited on children for having been born ignorant.\nDocument 5: Schools study effects.\n\nAnswer the following question: The sanctions against the school were a punishing blow, and they seemed to what the efforts the school had made to change?\nA: ignore\nB: enforce\nC: authoritarian\nD: yell at\nE: avoid\nYour answer should be one of A, B, C, D, or E\nAnswer: \n\n### Dummy model answers\n\nAnswer: E.\n\n### Other answers\n\nAnswer:\n\n```\nContext:\nDocument 1: Schools take actions.\nDocument 2: School reform is a fre'

In [22]:
predictions_5 = []
references_5 = []

In [23]:
for i in range(10):
    qa = commonsense_qa['train'][i]
    question = qa['question']
    answer_texts = qa['choices']['text']
    answer_labels = qa['choices']['label']
    correct_answer = qa['answerKey']

    question_embedding = embeddings_model.encode(question)
    context_results = util.semantic_search(question_embedding, document_embeddings, top_k=5)
    doc_ids = [c['corpus_id'] for c in context_results[0]]
    docs = [documents[d] for d in doc_ids]
    context = f'Context:\nDocument 1: {docs[0]}\nDocument 2: {docs[1]}\nDocument 3: {docs[2]}\nDocument 4: {docs[3]}\nDocument 5: {docs[4]}'

    answer_options = '\n'.join([f'{l}: {t}' for t, l in zip(answer_texts, answer_labels)])
    prompt = f'{context}\n\nAnswer the following question: {question}\n{answer_options}\nYour answer should be one of A, B, C, D, or E\nAnswer: '

    tokens = tokenizer(prompt, return_tensors='pt').to('cuda:0')
    output_ids = model.generate(tokens.input_ids, max_new_tokens=50)
    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    if 'Answer:' in output:
        pred = output.split('Answer:')[-1].strip()
        if len(pred) > 0 and pred[0] in 'ABCDE':
            pred = pred[0]
        else:
            pred = 'INVALID'
    else:
        pred = 'INVALID'

    predictions_5.append(pred)
    references_5.append(correct_answer)

In [24]:
valid_indices = [i for i, p in enumerate(predictions_5) if p != 'INVALID']
valid_predictions_5 = [predictions_5[i] for i in valid_indices]
valid_references_5 = [references_5[i] for i in valid_indices]

In [25]:
print(f'Valid predictions: {len(valid_predictions_5)}/10')

Valid predictions: 3/10


In [26]:
bleu.compute(predictions=valid_predictions_5, references=[[ref] for ref in valid_references_5])

{'bleu': 0.0,
 'precisions': [0.0, 0.0, 0.0, 0.0],
 'brevity_penalty': 1.0,
 'length_ratio': 1.0,
 'translation_length': 3,
 'reference_length': 3}

In [27]:
bertscore.compute(predictions=valid_predictions_5, references=valid_references_5, model_type='microsoft/deberta-xlarge-mnli')

tokenizer_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/792 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/3.04G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.04G [00:00<?, ?B/s]

{'precision': [0.69893479347229, 0.69893479347229, 0.5516265630722046],
 'recall': [0.69893479347229, 0.69893479347229, 0.5516265630722046],
 'f1': [0.69893479347229, 0.69893479347229, 0.5516265630722046],
 'hashcode': 'microsoft/deberta-xlarge-mnli_L40_no-idf_version=0.3.12(hug_trans=4.57.3)'}

In [28]:
# Експеримент 2: RAG со 3 документи
sample_qa = commonsense_qa['train'][5]
question = sample_qa['question']
question

'What home entertainment equipment requires cable?'

In [29]:
question_embedding = embeddings_model.encode(question, batch_size=64, show_progress_bar=True)
context_results = util.semantic_search(question_embedding, document_embeddings, top_k=3)
doc_ids = [c['corpus_id'] for c in context_results[0]]
docs = [documents[d] for d in doc_ids]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [30]:
context = f'Context:\nDocument 1: {docs[0]}\nDocument 2: {docs[1]}\nDocument 3: {docs[2]}'
context

'Context:\nDocument 1: Cables are television.\nDocument 2: A cable system is television\nDocument 3: Cables are located in television.'

In [31]:
predictions_3 = []
references_3 = []

In [32]:
for i in range(10):
    qa = commonsense_qa['train'][i]
    question = qa['question']
    answer_texts = qa['choices']['text']
    answer_labels = qa['choices']['label']
    correct_answer = qa['answerKey']

    question_embedding = embeddings_model.encode(question)
    context_results = util.semantic_search(question_embedding, document_embeddings, top_k=3)
    doc_ids = [c['corpus_id'] for c in context_results[0]]
    docs = [documents[d] for d in doc_ids]
    context = f'Context:\nDocument 1: {docs[0]}\nDocument 2: {docs[1]}\nDocument 3: {docs[2]}'

    answer_options = '\n'.join([f'{l}: {t}' for t, l in zip(answer_texts, answer_labels)])
    prompt = f'{context}\n\nAnswer the following question: {question}\n{answer_options}\nYour answer should be one of A, B, C, D, or E\nAnswer: '

    tokens = tokenizer(prompt, return_tensors='pt').to('cuda:0')
    output_ids = model.generate(tokens.input_ids, max_new_tokens=50)
    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    if 'Answer:' in output:
        pred = output.split('Answer:')[-1].strip()
        if len(pred) > 0 and pred[0] in 'ABCDE':
            pred = pred[0]
        else:
            pred = 'INVALID'
    else:
        pred = 'INVALID'

    predictions_3.append(pred)
    references_3.append(correct_answer)

In [33]:
valid_indices = [i for i, p in enumerate(predictions_3) if p != 'INVALID']
valid_predictions_3 = [predictions_3[i] for i in valid_indices]
valid_references_3 = [references_3[i] for i in valid_indices]

In [34]:
print(f'Valid predictions: {len(valid_predictions_3)}/10')

Valid predictions: 6/10


In [35]:
bleu.compute(predictions=valid_predictions_3, references=[[ref] for ref in valid_references_3])

{'bleu': 0.0,
 'precisions': [0.16666666666666666, 0.0, 0.0, 0.0],
 'brevity_penalty': 1.0,
 'length_ratio': 1.0,
 'translation_length': 6,
 'reference_length': 6}

In [36]:
bertscore.compute(predictions=valid_predictions_3, references=valid_references_3, model_type='microsoft/deberta-xlarge-mnli')

{'precision': [0.5516265630722046,
  0.69893479347229,
  0.5516265630722046,
  1.0,
  0.69893479347229,
  0.5516265630722046],
 'recall': [0.5516265630722046,
  0.69893479347229,
  0.5516265630722046,
  1.0,
  0.69893479347229,
  0.5516265630722046],
 'f1': [0.5516265630722046,
  0.69893479347229,
  0.5516265630722046,
  1.0,
  0.69893479347229,
  0.5516265630722046],
 'hashcode': 'microsoft/deberta-xlarge-mnli_L40_no-idf_version=0.3.12(hug_trans=4.57.3)'}

In [37]:
# Експеримент 3: RAG со 1 документ
sample_qa = commonsense_qa['train'][10]
question = sample_qa['question']
question

'Where do you put your grapes just before checking out?'

In [38]:
question_embedding = embeddings_model.encode(question, batch_size=64, show_progress_bar=True)
context_results = util.semantic_search(question_embedding, document_embeddings, top_k=1)
doc_ids = [c['corpus_id'] for c in context_results[0]]
docs = [documents[d] for d in doc_ids]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [39]:
context = f'Context:\nDocument 1: {docs[0]}'
context

'Context:\nDocument 1: Grapes are located in restaurants.'

In [40]:
predictions_1 = []
references_1 = []

In [41]:
for i in range(10):
    qa = commonsense_qa['train'][i]
    question = qa['question']
    answer_texts = qa['choices']['text']
    answer_labels = qa['choices']['label']
    correct_answer = qa['answerKey']

    question_embedding = embeddings_model.encode(question)
    context_results = util.semantic_search(question_embedding, document_embeddings, top_k=1)
    doc_ids = [c['corpus_id'] for c in context_results[0]]
    docs = [documents[d] for d in doc_ids]
    context = f'Context:\nDocument 1: {docs[0]}'

    answer_options = '\n'.join([f'{l}: {t}' for t, l in zip(answer_texts, answer_labels)])
    prompt = f'{context}\n\nAnswer the following question: {question}\n{answer_options}\nYour answer should be one of A, B, C, D, or E\nAnswer: '

    tokens = tokenizer(prompt, return_tensors='pt').to('cuda:0')
    output_ids = model.generate(tokens.input_ids, max_new_tokens=50)
    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    if 'Answer:' in output:
        pred = output.split('Answer:')[-1].strip()
        if len(pred) > 0 and pred[0] in 'ABCDE':
            pred = pred[0]
        else:
            pred = 'INVALID'
    else:
        pred = 'INVALID'

    predictions_1.append(pred)
    references_1.append(correct_answer)

In [42]:
valid_indices = [i for i, p in enumerate(predictions_1) if p != 'INVALID']
valid_predictions_1 = [predictions_1[i] for i in valid_indices]
valid_references_1 = [references_1[i] for i in valid_indices]

In [43]:
print(f'Valid predictions: {len(valid_predictions_1)}/10')

Valid predictions: 4/10


In [44]:
bleu.compute(predictions=valid_predictions_1, references=[[ref] for ref in valid_references_1])

{'bleu': 0.0,
 'precisions': [0.25, 0.0, 0.0, 0.0],
 'brevity_penalty': 1.0,
 'length_ratio': 1.0,
 'translation_length': 4,
 'reference_length': 4}

In [45]:
bertscore.compute(predictions=valid_predictions_1, references=valid_references_1, model_type='microsoft/deberta-xlarge-mnli')

{'precision': [0.9999998807907104,
  0.69893479347229,
  0.7198565006256104,
  0.6420041918754578],
 'recall': [0.9999998807907104,
  0.69893479347229,
  0.7198565006256104,
  0.6420041918754578],
 'f1': [0.9999998807907104,
  0.69893479347229,
  0.7198565006256104,
  0.6420041918754578],
 'hashcode': 'microsoft/deberta-xlarge-mnli_L40_no-idf_version=0.3.12(hug_trans=4.57.3)'}

In [46]:
# Експеримент 4: Друг embedding модел (all-distilroberta-v1)
embeddings_model_distil = SentenceTransformer('all-distilroberta-v1')
document_embeddings_distil = embeddings_model_distil.encode(documents, batch_size=64, show_progress_bar=True)

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/653 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/328M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/333 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Batches:   0%|          | 0/15952 [00:00<?, ?it/s]

In [47]:
sample_qa = commonsense_qa['train'][8]
question = sample_qa['question']
question

'What do people use to absorb extra ink from a fountain pen?'

In [48]:
question_embedding = embeddings_model_distil.encode(question, batch_size=64, show_progress_bar=True)
context_results = util.semantic_search(question_embedding, document_embeddings_distil, top_k=5)
doc_ids = [c['corpus_id'] for c in context_results[0]]
docs = [documents[d] for d in doc_ids]
docs

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

['Fountain pens have ink.',
 'Ink is located in fountain pens.',
 'Most pens use ink.',
 'Some materials absorb ink in different ways.',
 'Ink cartridges are part of fountain pens.']

In [49]:
predictions_distil = []
references_distil = []

In [50]:
for i in range(10):
    qa = commonsense_qa['train'][i]
    question = qa['question']
    answer_texts = qa['choices']['text']
    answer_labels = qa['choices']['label']
    correct_answer = qa['answerKey']

    question_embedding = embeddings_model_distil.encode(question)
    context_results = util.semantic_search(question_embedding, document_embeddings_distil, top_k=5)
    doc_ids = [c['corpus_id'] for c in context_results[0]]
    docs = [documents[d] for d in doc_ids]
    context = f'Context:\nDocument 1: {docs[0]}\nDocument 2: {docs[1]}\nDocument 3: {docs[2]}\nDocument 4: {docs[3]}\nDocument 5: {docs[4]}'

    answer_options = '\n'.join([f'{l}: {t}' for t, l in zip(answer_texts, answer_labels)])
    prompt = f'{context}\n\nAnswer the following question: {question}\n{answer_options}\nYour answer should be one of A, B, C, D, or E\nAnswer: '

    tokens = tokenizer(prompt, return_tensors='pt').to('cuda:0')
    output_ids = model.generate(tokens.input_ids, max_new_tokens=50)
    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    if 'Answer:' in output:
        pred = output.split('Answer:')[-1].strip()
        if len(pred) > 0 and pred[0] in 'ABCDE':
            pred = pred[0]
        else:
            pred = 'INVALID'
    else:
        pred = 'INVALID'

    predictions_distil.append(pred)
    references_distil.append(correct_answer)

In [51]:
valid_indices = [i for i, p in enumerate(predictions_distil) if p != 'INVALID']
valid_predictions_distil = [predictions_distil[i] for i in valid_indices]
valid_references_distil = [references_distil[i] for i in valid_indices]

In [52]:
print(f'Valid predictions: {len(valid_predictions_distil)}/10')

Valid predictions: 4/10


In [53]:
bleu.compute(predictions=valid_predictions_distil, references=[[ref] for ref in valid_references_distil])

{'bleu': 0.0,
 'precisions': [0.0, 0.0, 0.0, 0.0],
 'brevity_penalty': 1.0,
 'length_ratio': 1.0,
 'translation_length': 4,
 'reference_length': 4}

In [54]:
bertscore.compute(predictions=valid_predictions_distil, references=valid_references_distil, model_type='microsoft/deberta-xlarge-mnli')

{'precision': [0.69893479347229,
  0.69893479347229,
  0.5516266226768494,
  0.6629466414451599],
 'recall': [0.69893479347229,
  0.69893479347229,
  0.5516266226768494,
  0.6629466414451599],
 'f1': [0.69893479347229,
  0.69893479347229,
  0.5516266226768494,
  0.6629466414451599],
 'hashcode': 'microsoft/deberta-xlarge-mnli_L40_no-idf_version=0.3.12(hug_trans=4.57.3)'}

In [55]:
# Експеримент 5: Zero-shot (без контекст)
sample_qa = commonsense_qa['train'][12]
question = sample_qa['question']
question

'Johnny sat on a bench and relaxed after doing a lot of work on his hobby.  Where is he?'

In [56]:
answer_texts = sample_qa['choices']['text']
answer_labels = sample_qa['choices']['label']
answer_options = '\n'.join([f'{l}: {t}' for t, l in zip(answer_texts, answer_labels)])
answer_options

'A: state park\nB: bus depot\nC: garden\nD: gym\nE: rest area'

In [57]:
prompt = f'Answer the following question: {question}\n{answer_options}\nYour answer should be one of A, B, C, D, or E\nAnswer: '
prompt

'Answer the following question: Johnny sat on a bench and relaxed after doing a lot of work on his hobby.  Where is he?\nA: state park\nB: bus depot\nC: garden\nD: gym\nE: rest area\nYour answer should be one of A, B, C, D, or E\nAnswer: '

In [58]:
tokens = tokenizer(prompt, return_tensors='pt').to('cuda:0')
output_ids = model.generate(tokens.input_ids, max_new_tokens=50)
tokenizer.decode(output_ids[0], skip_special_tokens=True)

'Answer the following question: Johnny sat on a bench and relaxed after doing a lot of work on his hobby.  Where is he?\nA: state park\nB: bus depot\nC: garden\nD: gym\nE: rest area\nYour answer should be one of A, B, C, D, or E\nAnswer:  Johnny sat on a bench and relaxed after doing a lot of work on his hobby.  Where is he?\nThe answer is A.\nJohnny sat on a bench and relaxed after doing a lot of work on'

In [59]:
predictions_zero = []
references_zero = []

In [60]:
for i in range(10):
    qa = commonsense_qa['train'][i]
    question = qa['question']
    answer_texts = qa['choices']['text']
    answer_labels = qa['choices']['label']
    correct_answer = qa['answerKey']

    answer_options = '\n'.join([f'{l}: {t}' for t, l in zip(answer_texts, answer_labels)])
    prompt = f'Answer the following question: {question}\n{answer_options}\nYour answer should be one of A, B, C, D, or E\nAnswer: '

    tokens = tokenizer(prompt, return_tensors='pt').to('cuda:0')
    output_ids = model.generate(tokens.input_ids, max_new_tokens=50)
    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    if 'Answer:' in output:
        pred = output.split('Answer:')[-1].strip()
        if len(pred) > 0 and pred[0] in 'ABCDE':
            pred = pred[0]
        else:
            pred = 'INVALID'
    else:
        pred = 'INVALID'

    predictions_zero.append(pred)
    references_zero.append(correct_answer)

In [61]:
valid_indices = [i for i, p in enumerate(predictions_zero) if p != 'INVALID']
valid_predictions_zero = [predictions_zero[i] for i in valid_indices]
valid_references_zero = [references_zero[i] for i in valid_indices]

In [62]:
print(f'Valid predictions: {len(valid_predictions_zero)}/10')

Valid predictions: 0/10


In [63]:
bleu.compute(predictions=valid_predictions_zero, references=[[ref] for ref in valid_references_zero])

IndexError: list index out of range

In [None]:
bertscore.compute(predictions=valid_predictions_zero, references=valid_references_zero, model_type='microsoft/deberta-xlarge-mnli')