# Comprehension

## cosine_similarity

In [1]:
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import pandas as pd

# Question presets. Only the 'TOPICS' column is used below: rows 15 and 16 each
# hold a newline-separated block of questions. Path is relative to the
# notebook's working directory.
compositional_preset = pd.read_csv('../data/Compositional Preset.csv')
# Learner answers. The cells below read the 'Level' column (used as the result
# index) and the 'Proud Moment', 'Personal Achievement' and "Others' Impact"
# answer columns.
answers_df = pd.read_csv('../data/english_learning_proud_moments.csv')

In [2]:
# Single source of truth for the checkpoint id, so the tokenizer and the model
# are always loaded from the same pretrained checkpoint (previously the id was
# repeated in two separate string literals).
MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

In [3]:
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each sequence, ignoring masked positions.

    Args:
        model_output: Indexable model output whose first element is the
            per-token embedding tensor (presumably shaped
            ``(batch, seq_len, hidden)`` -- confirm against the model used).
        attention_mask: Tensor with one entry per token (one dimension fewer
            than the embeddings); non-zero entries mark tokens to include.

    Returns:
        Tensor with dimension 1 reduced away: the mask-weighted sum of the
        token embeddings divided by the mask total for each sequence.
    """
    hidden_states = model_output[0]

    # Broadcast the mask across the embedding dimension so it can weight
    # every component of every token vector.
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()

    weighted_sum = torch.sum(hidden_states * weights, 1)
    # Clamp the denominator so a fully masked sequence does not divide by zero.
    token_counts = torch.clamp(weights.sum(1), min=1e-9)

    return weighted_sum / token_counts

def calculate_cosine_similarity_sbert(question, answers):
    """Score how similar each answer is to a question using sentence embeddings.

    Both inputs are tokenized with the module-level ``tokenizer``, encoded by
    the module-level ``model``, mean-pooled with ``mean_pooling`` and
    L2-normalized; the cosine similarity between the question embedding and
    each answer embedding is then returned.

    Args:
        question: The question text (a single string in this notebook's calls).
        answers: The answer texts to compare against the question (a list of
            strings in this notebook's calls).

    Returns:
        A list of floats with one cosine similarity per answer, in the same
        order as ``answers``. An empty list or tuple of answers yields ``[]``.
    """
    # Nothing to compare: return early instead of handing the tokenizer and
    # the model an empty batch.
    if isinstance(answers, (list, tuple)) and len(answers) == 0:
        return []

    encoded_input_question = tokenizer(question, padding=True, truncation=True, return_tensors='pt')
    encoded_inputs_answers = tokenizer(
        answers,
        padding=True,
        truncation=True,
        return_tensors='pt',
        max_length=tokenizer.model_max_length,
    )

    # Inference only, so gradient tracking is disabled for the forward passes.
    with torch.no_grad():
        model_output_question = model(**encoded_input_question)
        model_outputs_answers = model(**encoded_inputs_answers)

    # Collapse per-token embeddings into one vector per text; padding positions
    # are excluded through the attention mask.
    sentence_embedding_question = mean_pooling(model_output_question, encoded_input_question['attention_mask'])
    sentence_embeddings_answers = mean_pooling(model_outputs_answers, encoded_inputs_answers['attention_mask'])

    sentence_embedding_question = F.normalize(sentence_embedding_question, p=2, dim=1)
    sentence_embeddings_answers = F.normalize(sentence_embeddings_answers, p=2, dim=1)

    # The single question row broadcasts against every answer row.
    cosine_similarities = F.cosine_similarity(sentence_embedding_question, sentence_embeddings_answers).tolist()

    return cosine_similarities

### 동일한 Question & Answer
- `compositional_preset['TOPICS'][15]` (15번 idx)의 질문을 사용

In [4]:
# Question block stored at row 15 of the TOPICS column.
question_text = compositional_preset['TOPICS'][15]

# One question per line; keep only what follows the first space, which drops
# the leading numbering token (a line without a space is kept whole).
questions = [line.split(' ', 1)[-1] for line in question_text.split('\n')]

# Pair the i-th question with the i-th answer column and score every answer
# in that column against it.
similarity_scores_sbert = {
    column: calculate_cosine_similarity_sbert(questions[position], answers_df[column].tolist())
    for position, column in enumerate(('Proud Moment', 'Personal Achievement', "Others' Impact"))
}

# Tabulate the scores: one row per entry of answers_df['Level'], one column
# per answer column.
similarity_scores_df_sbert = pd.DataFrame(similarity_scores_sbert, index=answers_df['Level'])
similarity_scores_df_sbert

Unnamed: 0_level_0,Proud Moment,Personal Achievement,Others' Impact
Level,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
A1 (Beginner),0.520288,0.412033,0.487585
A2 (Elementary),0.432508,0.36742,0.378452
B1 (Intermediate),0.466504,0.498527,0.420501
B2 (Upper Intermediate),0.433641,0.524253,0.563703
C1 (Advanced),0.325905,0.523665,0.570039


### 다른 Question & Answer
- `compositional_preset['TOPICS'][16]` (16번 idx)의 질문을 사용

In [5]:
# Question block stored at row 16 of the TOPICS column.
question_text = compositional_preset['TOPICS'][16]

# One question per line; keep only what follows the first space, which drops
# the leading numbering token (a line without a space is kept whole).
questions = [line.split(' ', 1)[-1] for line in question_text.split('\n')]

# Pair the i-th question with the i-th answer column and score every answer
# in that column against it.
similarity_scores_sbert = {
    column: calculate_cosine_similarity_sbert(questions[position], answers_df[column].tolist())
    for position, column in enumerate(('Proud Moment', 'Personal Achievement', "Others' Impact"))
}

# Tabulate the scores: one row per entry of answers_df['Level'], one column
# per answer column.
similarity_scores_df_sbert = pd.DataFrame(similarity_scores_sbert, index=answers_df['Level'])
similarity_scores_df_sbert

Unnamed: 0_level_0,Proud Moment,Personal Achievement,Others' Impact
Level,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
A1 (Beginner),0.276337,0.175256,0.268428
A2 (Elementary),0.22938,0.141988,0.134847
B1 (Intermediate),0.262853,0.206176,0.186677
B2 (Upper Intermediate),0.172176,0.180951,0.19935
C1 (Advanced),0.259099,0.193786,0.234782
