# Comprehension

## cosine_similarity

In [2]:
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import pandas as pd

compositional_preset = pd.read_csv('../data/Compositional Preset.csv')
answers_df = pd.read_csv('../data/english_learning_proud_moments.csv')

In [10]:
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-large-en-v1.5')
model = AutoModel.from_pretrained('BAAI/bge-large-en-v1.5')

In [11]:
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] 
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


def calculate_cosine_similarity_baai(question, answers):
    encoded_input_question = tokenizer(question, padding=True, truncation=True, return_tensors='pt')
    encoded_inputs_answers = tokenizer(answers, padding=True, truncation=True, return_tensors='pt', max_length=tokenizer.model_max_length)

    with torch.no_grad():
        model_output_question = model(**encoded_input_question)
        model_outputs_answers = model(**encoded_inputs_answers)

    sentence_embedding_question = mean_pooling(model_output_question, encoded_input_question['attention_mask'])
    sentence_embeddings_answers = mean_pooling(model_outputs_answers, encoded_inputs_answers['attention_mask'])

    sentence_embedding_question = F.normalize(sentence_embedding_question, p=2, dim=1)
    sentence_embeddings_answers = F.normalize(sentence_embeddings_answers, p=2, dim=1)

    cosine_similarities = F.cosine_similarity(sentence_embedding_question, sentence_embeddings_answers).tolist()

    return cosine_similarities

### 동일한 Question & Answer
- 15번 idx

In [13]:
question_text = compositional_preset['TOPICS'][15]

# Split the string on newline character to convert it into a list
questions = question_text.split('\n')

# Remove the numbering from each question
questions = [q[q.find(' ')+1:] for q in questions]

similarity_scores_sbert = {
    'Proud Moment': calculate_cosine_similarity_baai(questions[0], answers_df['Proud Moment'].tolist()),
    'Personal Achievement': calculate_cosine_similarity_baai(questions[1], answers_df['Personal Achievement'].tolist()),
    "Others' Impact": calculate_cosine_similarity_baai(questions[2], answers_df["Others' Impact"].tolist())
}

# Create a DataFrame to display the similarity scores
similarity_scores_df_sbert = pd.DataFrame(similarity_scores_sbert, index=answers_df['Level'])
similarity_scores_df_sbert

Unnamed: 0_level_0,Proud Moment,Personal Achievement,Others' Impact
Level,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
A1 (Beginner),0.697732,0.647941,0.725628
A2 (Elementary),0.689123,0.640891,0.605181
B1 (Intermediate),0.704619,0.720814,0.654133
B2 (Upper Intermediate),0.735158,0.658333,0.725895
C1 (Advanced),0.673334,0.67399,0.67824


### 다른 Question & Answer
- 16번 idx

In [14]:
question_text = compositional_preset['TOPICS'][16]

# Split the string on newline character to convert it into a list
questions = question_text.split('\n')

# Remove the numbering from each question
questions = [q[q.find(' ')+1:] for q in questions]

similarity_scores_sbert = {
    'Proud Moment': calculate_cosine_similarity_baai(questions[0], answers_df['Proud Moment'].tolist()),
    'Personal Achievement': calculate_cosine_similarity_baai(questions[1], answers_df['Personal Achievement'].tolist()),
    "Others' Impact": calculate_cosine_similarity_baai(questions[2], answers_df["Others' Impact"].tolist())
}

# Create a DataFrame to display the similarity scores
similarity_scores_df_sbert = pd.DataFrame(similarity_scores_sbert, index=answers_df['Level'])
similarity_scores_df_sbert

Unnamed: 0_level_0,Proud Moment,Personal Achievement,Others' Impact
Level,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
A1 (Beginner),0.586034,0.521978,0.60481
A2 (Elementary),0.564449,0.497697,0.532131
B1 (Intermediate),0.558492,0.560129,0.572006
B2 (Upper Intermediate),0.559892,0.507335,0.588718
C1 (Advanced),0.607687,0.522802,0.561366
