In [1]:
import os

import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel


def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    """Select the hidden state of the last attended token of every sequence.

    Args:
        last_hidden_states: batch-first hidden states, indexed as
            ``[batch, position, ...]``.
        attention_mask: mask of shape ``[batch, position]`` with 1 for real
            tokens and 0 for padding.

    Returns:
        One hidden-state vector per batch row.
    """
    # When every row attends to its final position, the batch is left-padded
    # (or unpadded), so the last column already holds each row's last token.
    ends_with_token = attention_mask[:, -1].sum() == attention_mask.shape[0]
    if ends_with_token:
        return last_hidden_states[:, -1]

    # Right padding: the last real token sits at (number of real tokens - 1).
    last_positions = attention_mask.sum(dim=1) - 1
    rows = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[rows, last_positions]


def get_detailed_instruct(task_description: str, query: str) -> str:
    """Format a query prompt as ``<instruct>{task}\\n<query>{query}``."""
    return '<instruct>{}\n<query>{}'.format(task_description, query)

def get_detailed_example(task_description: str, query: str, response: str) -> str:
    """Format one in-context example as three tagged lines: instruct, query, response."""
    return '\n'.join([
        f'<instruct>{task_description}',
        f'<query>{query}',
        f'<response>{response}',
    ])

def get_new_queries(queries, query_max_len, examples_prefix, tokenizer):
    """Truncate queries, wrap them with the few-shot prefix, and size the new max length.

    Args:
        queries: list of query strings.
        query_max_len: token budget for a query including the '<s>' and
            '\\n<response></s>' wrapper strings.
        examples_prefix: few-shot text prepended to every query (may be '').
        tokenizer: a HuggingFace-style tokenizer (callable returning a dict
            with 'input_ids', and providing ``batch_decode``).

    Returns:
        ``(new_max_length, new_queries)`` where each new query is
        ``examples_prefix + truncated_query + '\\n<response>'``.
    """
    def _n_tokens(text):
        return len(tokenizer(text, add_special_tokens=False)['input_ids'])

    # Reserve room for the '<s>' and '\n<response></s>' strings around each query.
    budget = query_max_len - _n_tokens('<s>') - _n_tokens('\n<response></s>')
    truncated = tokenizer(
        queries,
        max_length=budget,
        return_token_type_ids=False,
        truncation=True,
        return_tensors=None,
        add_special_tokens=False
    )
    prefix_len = _n_tokens(examples_prefix)
    suffix_len = _n_tokens('\n<response>')
    # Smallest multiple of 8 strictly above the total, plus another 8 of headroom.
    new_max_length = (prefix_len + suffix_len + query_max_len + 8) // 8 * 8 + 8
    new_queries = [
        examples_prefix + decoded + '\n<response>'
        for decoded in tokenizer.batch_decode(truncated['input_ids'])
    ]
    return new_max_length, new_queries

task = 'Given a question, retrieve passages that answer the question.'
# Raw few-shot (in-context) examples; formatted into `examples` below.
example_records = [
    {'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',
     'query': 'what is a virtual interface',
     'response': "A virtual interface is a software-defined abstraction that mimics the behavior and characteristics of a physical network interface. It allows multiple logical network connections to share the same physical network interface, enabling efficient utilization of network resources. Virtual interfaces are commonly used in virtualization technologies such as virtual machines and containers to provide network connectivity without requiring dedicated hardware. They facilitate flexible network configurations and help in isolating network traffic for security and management purposes."},
    {'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',
     'query': 'causes of back pain in female for a week',
     'response': "Back pain in females lasting a week can stem from various factors. Common causes include muscle strain due to lifting heavy objects or improper posture, spinal issues like herniated discs or osteoporosis, menstrual cramps causing referred pain, urinary tract infections, or pelvic inflammatory disease. Pregnancy-related changes can also contribute. Stress and lack of physical activity may exacerbate symptoms. Proper diagnosis by a healthcare professional is crucial for effective treatment and management."}
]
documents = [
    "As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
    "Definition of summit for English Language Learners. : 1  the highest point of a mountain : the top of a mountain. : 2  the highest level. : 3  a meeting or series of meetings between the leaders of two or more governments."
]
# Formatted example strings (kept under the name `examples` for later cells).
examples = [get_detailed_example(e['instruct'], e['query'], e['response']) for e in example_records]
# Blank-line separated few-shot block; automatically empty when there are no examples.
examples_prefix = ('\n\n'.join(examples) + '\n\n') if examples else ''
queries = [
    get_detailed_instruct(task, 'how much protein should a female eat'),
    get_detailed_instruct(task, 'summit define')
]
query_max_len, doc_max_len = 512, 512

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-en-icl')
new_query_max_len, new_queries = get_new_queries(queries, query_max_len, examples_prefix, tokenizer)

  from .autonotebook import tqdm as notebook_tqdm


In [21]:
# Preview the first few prefixed queries.
new_queries[0:3]

['<instruct>Given a web search query, retrieve relevant passages that answer the query.\n<query>what is a virtual interface\n<response>A virtual interface is a software-defined abstraction that mimics the behavior and characteristics of a physical network interface. It allows multiple logical network connections to share the same physical network interface, enabling efficient utilization of network resources. Virtual interfaces are commonly used in virtualization technologies such as virtual machines and containers to provide network connectivity without requiring dedicated hardware. They facilitate flexible network configurations and help in isolating network traffic for security and management purposes.\n\n<instruct>Given a web search query, retrieve relevant passages that answer the query.\n<query>causes of back pain in female for a week\n<response>Back pain in females lasting a week can stem from various factors. Common causes include muscle strain due to lifting heavy objects or i

In [15]:
import pandas as pd

def preprocess_text(text):
    """Normalize a piece of free text for use in a retrieval query.

    Lowercases ``text`` and collapses every run of whitespace (spaces, tabs,
    newlines) into a single space, dropping leading/trailing whitespace.
    Punctuation, digits and markup are kept as-is.

    Args:
        text: the string to normalize (a NaN float from pandas will raise
            ``AttributeError``).

    Returns:
        The normalized string.
    """
    return ' '.join(text.lower().split())


def preprocess_data(train_data,
                    misconception_mapping,
                    query_text_version,
                    with_instruction=True,
                    with_misconception=True,
                    filter_na_misconception=True):
    """Reshape a wide question table into one row per (question, wrong answer).

    ``train_data`` has one row per question with ``Answer{A..D}Text`` option
    columns (and ``Misconception{A..D}Id`` columns when ``with_misconception``
    is True). Each non-correct option becomes its own row, annotated with the
    correct answer's text, an optional misconception id/name, and a normalized
    ``query_text`` for retrieval.

    Args:
        train_data: question-level DataFrame; must contain QuestionId,
            QuestionText, ConstructId, ConstructName, SubjectId, SubjectName,
            CorrectAnswer and AnswerAText..AnswerDText.
        misconception_mapping: DataFrame with MisconceptionId and
            MisconceptionName columns (only used if ``with_misconception``).
        query_text_version: how to build ``query_text``; only ``"v1"``
            (construct name + question text + wrong answer text) is supported.
        with_instruction: prefix each query with a fixed retrieval instruction.
        with_misconception: attach MisconceptionId / MisconceptionName.
        filter_na_misconception: when ``with_misconception`` is True, drop
            rows whose MisconceptionId is missing.

    Returns:
        A new DataFrame sorted by QuestionId then QuestionId_Answer, with a
        fresh RangeIndex plus ``query_text`` and ``order_index`` (a copy of
        QuestionId_Answer) columns.

    Raises:
        ValueError: if ``query_text_version`` is not supported.
    """
    # Validate before any reshaping so a bad argument fails immediately.
    if query_text_version != "v1":
        raise ValueError(f"Invalid query_text_version: {query_text_version}")

    answer_letters = ['A', 'B', 'C', 'D']
    answer_cols = [f'Answer{x}Text' for x in answer_letters]

    # 1. Melt the answer columns: one row per (question, answer option).
    melted_answers = pd.melt(
        train_data,
        id_vars=['QuestionId', 'QuestionText', 'ConstructId', 'ConstructName',
                 'SubjectId', 'SubjectName', 'CorrectAnswer'],
        value_vars=answer_cols,
        var_name='AnswerColumn',
        value_name='WrongAnswerText'
    )
    melted_answers['WrongAnswer'] = melted_answers['AnswerColumn'].map(
        dict(zip(answer_cols, answer_letters))
    )

    # 2. Optionally attach each option's misconception id and name.
    if with_misconception:
        misconception_cols = [f'Misconception{x}Id' for x in answer_letters]
        melted_misconceptions = pd.melt(
            train_data,
            id_vars=['QuestionId'],
            value_vars=misconception_cols,
            var_name='MisconceptionColumn',
            value_name='MisconceptionId'
        )
        # 'MisconceptionAId'[-3] == 'A': recover the option letter from the column name.
        melted_misconceptions['WrongAnswer'] = melted_misconceptions['MisconceptionColumn'].str[-3]
        df = melted_answers.merge(
            melted_misconceptions[['QuestionId', 'WrongAnswer', 'MisconceptionId']],
            on=['QuestionId', 'WrongAnswer'],
            how='left'
        ).merge(
            misconception_mapping[['MisconceptionId', 'MisconceptionName']],
            on='MisconceptionId',
            how='left'
        )
    else:
        df = melted_answers

    # 3. Attach the text of each question's correct option. CorrectAnswer was
    #    carried through the melt, so no second lookup into train_data is needed.
    is_correct = df['WrongAnswer'] == df['CorrectAnswer']
    correct_answers = (
        df.loc[is_correct, ['QuestionId', 'WrongAnswerText']]
        .rename(columns={'WrongAnswerText': 'CorrectAnswerText'})
    )
    df = df.merge(correct_answers, on='QuestionId', how='left')

    # Keep only the distractors; .copy() so the column writes below act on an
    # owned frame rather than a view of the filtered one.
    df = df[df['WrongAnswer'] != df['CorrectAnswer']].copy()
    df['QuestionId_Answer'] = df['QuestionId'].astype(str) + '_' + df['WrongAnswer']

    final_columns = ['QuestionId_Answer', 'QuestionId', 'QuestionText', 'ConstructId',
                     'ConstructName', 'SubjectId', 'SubjectName', 'CorrectAnswer', 'CorrectAnswerText',
                     'WrongAnswerText', 'WrongAnswer']
    if with_misconception:
        final_columns += ['MisconceptionId', 'MisconceptionName']
    df = df[final_columns].copy()

    # 4. Build the query text ("v1" is the only version; validated above).
    df['query_text'] = (
        df['ConstructName'] + ' ' + df['QuestionText'] + ' ' + df['WrongAnswerText']
    ).apply(preprocess_text)

    if with_instruction:
        task_description = 'Given a math question and an incorrect answer, please retrieve the most accurate reason for the misconception leading to the incorrect answer.'
        # Vectorized concat; a row-wise apply(axis=1) returns a DataFrame on empty input.
        df['query_text'] = f"Instruction:{task_description}\nQuery:" + df['query_text']

    # Drop rows with no labelled misconception.
    if with_misconception and filter_na_misconception:
        df = df[df['MisconceptionId'].notna()]

    df = df.sort_values(['QuestionId', 'QuestionId_Answer']).reset_index(drop=True)
    df['order_index'] = df['QuestionId_Answer']

    return df


In [16]:
# Root of the evaluation data. Defaults to the original machine's path; set the
# EVAL_DATA_DIR environment variable to run this notebook elsewhere.
EVAL_DATA_DIR = os.environ.get(
    "EVAL_DATA_DIR",
    "/root/autodl-tmp/github/FlagEmbedding/examples/finetune/embedder/eval_data",
)
# Raw CSVs live under EVAL_DATA_DIR; derived rather than duplicated so the two stay in sync.
RAW_DATA_DIR = f"{EVAL_DATA_DIR}/raw_data"
misconception_mapping = pd.read_csv(f"{RAW_DATA_DIR}/misconception_mapping.csv")
# Retrieval corpus: every misconception name, in file order.
corpus = misconception_mapping['MisconceptionName'].tolist()

In [18]:
# Preview the first few misconception names in the corpus.
corpus[0:3]

['Does not know that angles in a triangle sum to 180 degrees',
 'Uses dividing fractions method for multiplying fractions',
 'Believes there are 100 degrees in a full turn']

In [39]:
# NOTE: `val_preprocessed` is created in the In [41] cell below; run that cell
# first on a fresh kernel.
val_preprocessed.keys()


Index(['QuestionId_Answer', 'QuestionId', 'QuestionText', 'ConstructId',
       'ConstructName', 'SubjectId', 'SubjectName', 'CorrectAnswer',
       'CorrectAnswerText', 'WrongAnswerText', 'WrongAnswer',
       'MisconceptionId', 'MisconceptionName', 'query_text', 'order_index'],
      dtype='object')

In [41]:
val_data = pd.read_csv(f"{RAW_DATA_DIR}/validation_v2/val.csv")
# One row per (question, wrong answer) with a labelled misconception, no instruction prefix.
val_preprocessed = preprocess_data(
    val_data,
    misconception_mapping,
    query_text_version='v1',
    with_instruction=False,
    with_misconception=True,
    filter_na_misconception=True,
)

# Columns kept for inspection and for building the retrieval examples below.
selected_columns = [
    "query_text", "MisconceptionId", 'QuestionText', 'ConstructName',
    'WrongAnswerText', 'SubjectName', 'MisconceptionName',
]
df_selected = val_preprocessed[selected_columns]

In [28]:
# Display all prefixed queries.
list(new_queries)

['<instruct>Given a web search query, retrieve relevant passages that answer the query.\n<query>what is a virtual interface\n<response>A virtual interface is a software-defined abstraction that mimics the behavior and characteristics of a physical network interface. It allows multiple logical network connections to share the same physical network interface, enabling efficient utilization of network resources. Virtual interfaces are commonly used in virtualization technologies such as virtual machines and containers to provide network connectivity without requiring dedicated hardware. They facilitate flexible network configurations and help in isolating network traffic for security and management purposes.\n\n<instruct>Given a web search query, retrieve relevant passages that answer the query.\n<query>causes of back pain in female for a week\n<response>Back pain in females lasting a week can stem from various factors. Common causes include muscle strain due to lifting heavy objects or i

In [48]:
# One retrieval example per validation row: the instruction names the construct
# being tested, the query pairs the question with the wrong answer, and the
# response is the gold misconception (its id kept for evaluation).
# NOTE(review): "misconcepted" is not standard English, but changing it would
# alter the prompt text the model sees, so it is left as-is.
queries = [
    {
        'instruct': f'Given a math question about {preprocess_text(row["ConstructName"])} and a misconcepted incorrect answer to it, retrieve the most accurate reason for the misconception leading to the incorrect answer.',
        'query': f'{row["QuestionText"]} Incorrect answer : {row["WrongAnswerText"]}',
        'response': row['MisconceptionName'],
        'response_id': row['MisconceptionId'],
    }
    for row in df_selected.to_dict('records')
]
# Preview only a few examples instead of dumping a long list.
queries[:3]

[{'instruct': 'Given a math question about simplify an algebraic fraction by factorising the numerator and a misconcepted incorrect answer to it, retrieve the most accurate reason for the misconception leading to the incorrect answer.',
  'query': 'Simplify the following, if possible: \\( \\frac{m^{2}+2 m-3}{m-3} \\) Incorrect answer : \\( m+1 \\)',
  'response': 'Does not know that to factorise a quadratic expression, to find two numbers that add to give the coefficient of the x term, and multiply to give the non variable term\n',
  'response_id': 2142.0},
 {'instruct': 'Given a math question about simplify an algebraic fraction by factorising the numerator and a misconcepted incorrect answer to it, retrieve the most accurate reason for the misconception leading to the incorrect answer.',
  'query': 'Simplify the following, if possible: \\( \\frac{m^{2}+2 m-3}{m-3} \\) Incorrect answer : \\( m+2 \\)',
  'response': 'Thinks that when you cancel identical terms from the numerator and de

In [47]:
# Gold misconception id for every selected validation row.
df_selected["MisconceptionId"].tolist()

[2142.0,
 143.0,
 2142.0,
 1180.0,
 1180.0,
 1180.0,
 686.0,
 686.0,
 686.0,
 2123.0,
 2273.0,
 2133.0,
 907.0,
 1514.0,
 907.0,
 1889.0,
 1234.0,
 1312.0,
 2156.0,
 2156.0,
 1588.0,
 1775.0,
 1248.0,
 1529.0,
 2378.0,
 105.0,
 1659.0,
 616.0,
 2130.0,
 616.0,
 2474.0,
 2316.0,
 1944.0,
 1214.0,
 811.0,
 1214.0,
 704.0,
 1272.0,
 1651.0,
 1566.0,
 98.0,
 347.0,
 2264.0,
 1773.0,
 228.0,
 1958.0,
 699.0,
 959.0,
 182.0,
 1461.0,
 630.0,
 646.0,
 1970.0,
 1970.0,
 220.0,
 1911.0,
 2085.0,
 31.0,
 31.0,
 31.0,
 944.0,
 544.0,
 964.0,
 2359.0,
 1542.0,
 1734.0,
 2271.0,
 2271.0,
 1383.0,
 718.0,
 809.0,
 1450.0,
 1068.0,
 352.0,
 2159.0,
 906.0,
 906.0,
 906.0,
 1887.0,
 378.0,
 378.0,
 2083.0,
 2113.0,
 422.0,
 1180.0,
 1180.0,
 1180.0,
 272.0,
 1835.0,
 2397.0,
 1771.0,
 1590.0,
 1880.0,
 478.0,
 478.0,
 421.0,
 1568.0,
 642.0,
 2301.0,
 217.0,
 217.0,
 29.0,
 265.0,
 265.0,
 265.0,
 224.0,
 460.0,
 492.0,
 1388.0,
 1388.0,
 290.0,
 312.0,
 372.0,
 105.0,
 545.0,
 1815.0,
 255.0,
 463.0,

In [45]:
# NOTE(review): `q` is not defined in any visible cell, and the original
# `q = q*3` tripled it again on every re-run. Binding the result to a new name
# keeps this cell idempotent.
q_tripled = q * 3
q_tripled

['asd', 'Asd', 'asd', 'Asd', 'asd', 'Asd']