In [None]:
import pandas as pd
import numpy as np
from tqdm import tqdm_notebook as tqdm
from multiprocessing import Pool, cpu_count
import seaborn as sns

In [None]:
# Postfix appended to every generated train/test file name (presumably tags the overlap_rate=0.75 run).
SPECIAL_TASK_NAME = '_75%' # postfix to all train and test file names

# Input folders: textbook snippets matched to questions, and the original question files.
SNIPPET_DIR_PATH = './data/snippet/only_snippet/'
ORIGIN_DIR_PATH = './data/_original_data/'


# Output file stems (no extension); reused for the .xlsx and .tsv files written later.
TRAIN_NAME = f'train{SPECIAL_TASK_NAME}'

TEST_17_NAME = f'test_17{SPECIAL_TASK_NAME}'
TEST_18_NAME = f'test_18{SPECIAL_TASK_NAME}'
TRAIN_3000_NAME = f'train_3000{SPECIAL_TASK_NAME}'


In [None]:
# Snippet files have no header row. They are concatenated column-wise with the question
# files below, so they are assumed to list one row per question in the same order — TODO confirm.
textbook_train = pd.read_excel(f'{SNIPPET_DIR_PATH}textbook_train.xlsx', header=None)

In [None]:
textbook_test_17 = pd.read_excel(f'{SNIPPET_DIR_PATH}textbook_test_17.xlsx', header=None)
textbook_test_18 = pd.read_excel(f'{SNIPPET_DIR_PATH}textbook_test_18.xlsx', header=None)

In [None]:
textbook_train.shape

In [None]:
textbook_train.head(1)

In [None]:
# Original question files, also headerless; the 2017 and 2018 test sets are two sheets of one workbook.
train_data = pd.read_excel(f'{ORIGIN_DIR_PATH}train.xlsx', header=None)

In [None]:
test_data_17 = pd.read_excel(f'{ORIGIN_DIR_PATH}test_17-18.xlsx', sheet_name='2017', header=None)
test_data_18 = pd.read_excel(f'{ORIGIN_DIR_PATH}test_17-18.xlsx', sheet_name='2018', header=None)

In [None]:
test_data_17.shape

In [None]:
train_data.shape

In [None]:
train_data.head(1)

In [None]:
# Column-wise concat aligns on the row index: question row i is paired with snippet row i.
train_with_snippet = pd.concat([train_data, textbook_train], axis=1)

In [None]:
train_with_snippet.head(1)

In [None]:
test_17_with_snippet = pd.concat([test_data_17, textbook_test_17], axis=1)
test_18_with_snippet = pd.concat([test_data_18, textbook_test_18], axis=1)

## Data cleaning

In [None]:
# this is the corruyted entry! other data does not have col=8-9
train_with_snippet[train_with_snippet.iloc[:,8].isnull() != True]

In [None]:
train_with_snippet[train_with_snippet.iloc[:,8].isnull() != True].iloc[:,0].tolist()

In [None]:
if train_with_snippet[train_with_snippet.iloc[:,8].isnull() != True].iloc[:,0].tolist() != []:
    print(train_with_snippet[train_with_snippet.iloc[:,8].isnull() != True].iloc[:,0].tolist())
    train_with_snippet.drop(train_with_snippet.index[10040], inplace=True)

In [None]:
# Earlier approach (superseded by the cells above): operate on train_data directly.
# train_data.iloc[10040, :]
# train_data.drop(train_data.index[10040], inplace=True)

In [None]:
# train_data = train_data.iloc[:, :8]

In [None]:
# train_data.isnull().stack()[lambda x: x].index.tolist()

In [None]:
# Check for corrupted rows again: after the drop this should be empty.
train_with_snippet[train_with_snippet.iloc[:,8].isnull() != True]

In [None]:
# Keep the question columns and the snippet columns, skipping columns 8-9 (only filled in corrupted rows):
# 0: question
# 1: question type
# 2-6: choices
# 7: answer
# 10-14: textbook snippets (10 is the best match)

train_data = train_with_snippet.iloc[:,[0,1,2,3,4,5,6,7,10,11,12,13,14]]

In [None]:
train_data.head(1)

## Data cleaning - drop nan entries

In [None]:
def drop_nan(df):
    """Drop every row of ``df`` that contains a NaN.

    When at least one NaN is present, the per-column NaN counts are printed
    and a new, filtered frame is returned. Otherwise ``df`` itself is
    returned unchanged.
    """
    nan_mask = df.isnull()
    if not nan_mask.values.any():
        return df
    print("nan in each col:\n", nan_mask.sum(), sep='')
    return df.dropna()

In [None]:
train_data = drop_nan(train_data)

In [None]:
# reset index after dropping rows
train_data.reset_index(inplace=True)
train_data = train_data.drop(columns='index')

In [None]:
train_data.shape

#### add column headers

In [None]:
# Name the training columns (same order as the column selection above).
train_data.columns = ['q', 'qtype', 'c1', 'c2', 'c3', 'c4', 'c5', 'a', 's1', 's2', 's3', 's4', 's5']

In [None]:
train_data.head(1)

In [None]:
test_17_with_snippet.head(1)

In [None]:
# The test sheets use a different layout from train: the 5 choices follow the question,
# then question type, year and answer, then the 5 snippets.
test_17_with_snippet.columns = ['q','c1','c2','c3','c4','c5', 'qtype', 'year', 'a', 's1','s2','s3','s4','s5']
test_18_with_snippet.columns = ['q','c1','c2','c3','c4','c5', 'qtype', 'year', 'a', 's1','s2','s3','s4','s5']

In [None]:
test_17_with_snippet.head(1)

In [None]:
test_18_with_snippet.head(1)

In [None]:
# Character lengths of questions and best-match snippets (assumes both columns hold strings).
q_len = [len(i) for i in train_data['q']]
s1_len = [len(i) for i in train_data['s1']]

In [None]:
# NOTE: seaborn's distplot is deprecated (histplot is the replacement).
sns.distplot(s1_len, kde=False)

In [None]:
sns.distplot(q_len, kde=False)

## Add textbook snippet – only the best match

### 1. Question and snippet combined into one sentence

In [None]:
# question and snippet being combined into one pandas series
# add the best (the first) match snippet from textbook 
train_combined_qs = train_with_snippet.iloc[:,0].combine(train_with_snippet.iloc[:,10], func=(lambda q, t: str(q) + str(t)))

In [None]:
train_combined_qs.shape

In [None]:
# question and snippet in one sentence 
train_data_combined_qs = train_data.iloc[:,:8].copy()

In [None]:
train_data_combined_qs.iloc[:,0] = train_combined_qs

In [None]:
train_data_combined_qs.head(1)

#### do the same procedure for testing data 

In [None]:
test_18_with_snippet.head(1)

In [None]:
# diff from training: in the test layout the best-match snippet ('s1') is at col index 9
# (in train_with_snippet it was at index 10)
test_17_combined_qs = test_17_with_snippet.iloc[:,0].combine(test_17_with_snippet.iloc[:,9], func=(lambda q, t: str(q) + str(t)))

In [None]:
test_18_combined_qs = test_18_with_snippet.iloc[:,0].combine(test_18_with_snippet.iloc[:,9], func=(lambda q, t: str(q) + str(t)))

In [None]:
test_17_with_snippet.head(1)

In [None]:
# Keep question, 5 choices, qtype, year and answer (cols 0-8); the question is overwritten next.
test_data_17_combined_qs = test_17_with_snippet.iloc[:, :9].copy()
test_data_18_combined_qs = test_18_with_snippet.iloc[:, :9].copy()

In [None]:
# Replace the question column with question + best snippet (same index, so rows align).
test_data_17_combined_qs.iloc[:,0] = test_17_combined_qs
test_data_18_combined_qs.iloc[:,0] = test_18_combined_qs

In [None]:
# double checking
test_data_17_combined_qs.head(1)

### 2. Question and snippet as separate sentences

In [None]:
# Keep q..a plus 's1': the best match is the last of the 9 kept columns.
train_data_separate_qs = train_data.iloc[:,:9].copy()

In [None]:
train_data_separate_qs.head(1)

# TODO: same process for testing data

In [None]:
# Test layout: q, c1-c5, qtype, year, a, s1 (cols 0-9), again ending with the best match.
test_data_17_separate_qs = test_17_with_snippet.iloc[:, :10].copy()
test_data_18_separate_qs = test_18_with_snippet.iloc[:, :10].copy()

In [None]:
test_data_17_separate_qs.head(1)

In [None]:
test_data_18_separate_qs.head(5)

In [None]:
# Smaller training subset: the first 3000 questions.
train_data_3000_separate_qs = train_data_separate_qs.head(3000)

In [None]:
train_data_3000_separate_qs.shape

## Prepare data 



Expand each entry from question + 5 choices into 5 × (question + 1 choice) rows, each labelled 1 if that choice is the answer and 0 otherwise.

In [None]:
def prepare_data(df, question_index, first_choice_index, answer_index):
    """Expand each multiple-choice entry into 5 (question, choice, label) records.

    Every row of ``df`` becomes five records, one per choice, in choice order.

    question_index: int, positional index of the question column
    first_choice_index: int, positional index of the *first* of the 5 consecutive choice columns
    answer_index: int, positional index of the answer column; the answer is compared
        with the 1-based choice number (1..5)

    return:
    list of dicts with keys 'question', 'choice' and 'label' (1 for the correct
    choice, 0 otherwise), in row order then choice order
    """
    result = []
    for i in tqdm(range(df.shape[0])):
        # .iloc keeps the integer indices positional even when the frame has string
        # column labels ('q', 'c1', ...). row[int] relies on a deprecated positional
        # fallback in pandas.
        row = df.iloc[i]
        question = row.iloc[question_index]
        answer = row.iloc[answer_index]
        for choice_number in range(1, 6):
            result.append({'question': question,
                           'choice': row.iloc[first_choice_index + choice_number - 1],
                           'label': 1 if choice_number == answer else 0})
    return result

In [None]:
# Train layout: question at 0, choices at 2-6, answer at 7.
train_df_combined_qs = pd.DataFrame(prepare_data(train_data_combined_qs, 0, 2, 7))
# this step is extremely slow (row-by-row loop); TODO (original note): add a faster conversion

In [None]:
# 3000 questions x 5 choices = 15000 expanded rows.
train_3000_df_combined_qs = train_df_combined_qs.head(15000)

In [None]:
train_3000_df_combined_qs

In [None]:
# Test layout: question at 0, choices at 1-5, answer at 8.
test_17_df_combined_qs = pd.DataFrame(prepare_data(test_data_17_combined_qs, 0, 1, 8))

In [None]:
test_17_df_combined_qs.head(6)

In [None]:
test_18_df_combined_qs = pd.DataFrame(prepare_data(test_data_18_combined_qs, 0, 1, 8))

In [None]:
def divide_snippet(snippet, remaining_len, overlap_rate):
    """Split a textbook snippet into overlapping pieces that fit a length budget.

    snippet: the best-matching textbook snippet; converted with ``str()``
    remaining_len: int > 0, characters left for the snippet after the question
        and choice are accounted for
    overlap_rate: float in (0, 1), fraction of ``remaining_len`` by which
        consecutive pieces overlap (e.g. 0.75 -> each piece starts a quarter
        window after the previous one)

    return: a list of snippet pieces. ``[snippet]`` when the whole snippet fits;
        otherwise windows of ``remaining_len - 1`` characters whose start
        positions advance by ``int(remaining_len * (1 - overlap_rate))``.

    raises: AssertionError for a non-positive ``remaining_len``, an
        ``overlap_rate`` outside (0, 1), or a budget too small to advance.
    """
    snippet = str(snippet)
    snippet_len = len(snippet)
    assert remaining_len > 0, 'Remaining length <= 0'
    # overlap_rate == 1 would make the step zero (division by zero below), so it is excluded.
    assert 0 < overlap_rate < 1, 'Overlap rate should be within (0,1)'

    # The whole snippet fits: no splitting needed.
    if snippet_len <= remaining_len:
        return [snippet]

    # Distance between window starts (floored).
    piece_size = int(remaining_len * (1 - overlap_rate))
    assert piece_size > 0, 'Remaining length too small for this overlap rate'
    piece_nums = snippet_len // piece_size + 1
    # Number of steps one full window spans.
    steps_per_window = int(1 / (1 - overlap_rate))
    piece_iteration_nums = piece_nums - steps_per_window + 1

    # NOTE(review): windows are remaining_len - 1 characters long (one shorter than the
    # budget). Kept as-is to match the data already generated; confirm whether intended.
    return [snippet[c_i * piece_size: c_i * piece_size + remaining_len - 1]
            for c_i in range(piece_iteration_nums)]

def prepare_data_overlapped(df, question_index, first_choice_index, answer_index, max_length=512, overlap_rate=0.75):
    """Expand each question into (question, choice, snippet-piece) records.

    For every row and each of its 5 choices, the 's1' snippet is split with
    ``divide_snippet`` into overlapping pieces that fit in
    ``max_length - len(question) - len(choice)`` characters. One record is
    emitted per piece.

    df: DataFrame with one row per question; must have 's1' and 'qtype' columns
    question_index: int, positional index of the question column
    first_choice_index: int, positional index of the *first* of 5 consecutive choice columns
    answer_index: int, positional index of the answer column (1-based choice number)
    max_length: length budget after subtracting the [cls] and [sep] tokens, default: 512
    overlap_rate: float, overlap rate between snippet pieces, default 0.75

    return:
    list of dicts with keys 'question', 'choice', 'label', 'snippet', 'qtype',
    'q_index' (row position in df), 'c_index' (0-4) and 's_index' (piece number)
    """
    result = []

    for i in tqdm(range(df.shape[0])):
        one_entry = df.iloc[i]
        # .iloc: the integer arguments are positions, while the frame has string labels.
        question = one_entry.iloc[question_index]
        answer = one_entry.iloc[answer_index]
        snippet_text = one_entry['s1']
        # Read qtype from this row. df['qtype'][i] looked it up by *label* i, which is
        # the wrong row (or a KeyError) whenever df's index is not 0..n-1.
        qtype = one_entry['qtype']
        q_len = len(str(question))

        # For each choice, emit one [question, choice, snippet piece, label] record per piece.
        for j in range(5):
            choice = one_entry.iloc[first_choice_index + j]
            remaining_len = max_length - q_len - len(str(choice))  # budget left for the snippet
            label = 1 if (j + 1) == answer else 0

            for k, snippet in enumerate(divide_snippet(snippet_text, remaining_len, overlap_rate)):
                result.append({'question': question,
                               'choice': choice,
                               'label': label,
                               'snippet': snippet,
                               'qtype': qtype,
                               'q_index': i,
                               'c_index': j,
                               's_index': k})
    return result

In [None]:
test_data_18_separate_qs

In [None]:
test_17_separate_overlapped = pd.DataFrame(prepare_data_overlapped(test_data_17_separate_qs,0,1,8, 
                                                                   max_length=509,
                                                                  overlap_rate=0.75))
test_17_separate_overlapped = test_17_separate_overlapped[['q_index',
                                                           'question',
                                                           'c_index',
                                                           'choice',
                                                           's_index',
                                                           'snippet',
                                                           'label', 
                                                           'qtype']]

In [None]:
test_18_separate_overlapped = pd.DataFrame(prepare_data_overlapped(test_data_18_separate_qs,0,1,8, max_length=509, 
                                                                  overlap_rate=0.75))
test_18_separate_overlapped = test_18_separate_overlapped[['q_index',
                                                           'question',
                                                           'c_index',
                                                           'choice',
                                                           's_index',
                                                           'snippet',
                                                           'label', 
                                                           'qtype']]

In [None]:
train_3000_separate_overlapped = pd.DataFrame(prepare_data_overlapped(train_data_3000_separate_qs,0,2,7, max_length=509,
                                                                     overlap_rate=0.75))
train_3000_separate_overlapped = train_3000_separate_overlapped[['q_index',
                                                           'question',
                                                           'c_index',
                                                           'choice',
                                                           's_index',
                                                           'snippet',
                                                           'label', 
                                                           'qtype']]


In [None]:
train_separate_overlapped = pd.DataFrame(prepare_data_overlapped(train_data_separate_qs,0,2,7, max_length=509,
                                                                     overlap_rate=0.75))
train_separate_overlapped = train_separate_overlapped[['q_index',
                                                           'question',
                                                           'c_index',
                                                           'choice',
                                                           's_index',
                                                           'snippet',
                                                           'label', 
                                                           'qtype']]


In [None]:
train_separate_overlapped.shape

### Save files used later to calculate accuracy

In [None]:
# Save the separate-overlapped frames (q/c/s indices, labels and qtype) so that
# evaluation accuracy can be computed from them later.
test_17_separate_overlapped.to_excel(f'./data/_output_data/{TEST_17_NAME}.xlsx', index=None)
test_18_separate_overlapped.to_excel(f'./data/_output_data/{TEST_18_NAME}.xlsx', index=None)
train_3000_separate_overlapped.to_excel(f'./data/_output_data/{TRAIN_3000_NAME}.xlsx', index=None)

### Shuffle and combine question & snippet

In [None]:
def shuffle_df(df, frac=1, random_state=42):
    """Return a shuffled sample of ``df`` (all rows by default).

    frac: fraction of rows to sample, default 1 (a full shuffle)
    random_state: seed for reproducible shuffling, default 42

    Index labels travel with their rows; call ``.reset_index(drop=True)``
    afterwards if positional order matters.
    """
    # Forward the arguments; they were previously ignored in favour of hard-coded values.
    return df.sample(frac=frac, random_state=random_state)

In [None]:
# combined 
# train_combine_qs_shuffled = shuffle_df(train_df_combined_qs)

In [None]:
def evaluate_length(df):
    """Print the max and mean character length of question + snippet + choice.

    Raises AssertionError if any combined record is longer than 512 characters.
    """
    lengths = [len(str(q) + str(s) + str(c))
               for q, s, c in zip(df['question'], df['snippet'], df['choice'])]
    longest = max(lengths)
    mean_len = sum(lengths) // len(lengths)
    assert longest <= 512
    print(f'Max len: {longest} <= 512!, Avg len: {mean_len} --UPDATE NEEDED if using over two [SEP]')

In [None]:
def combine_question_snippet(df):
    """Return a copy of ``df`` with a 'combined_qs' column = question + snippet.

    First runs ``evaluate_length``, which prints length statistics and asserts
    that every question+snippet+choice fits in 512 characters. ``df`` itself
    is not modified.
    """
    evaluate_length(df)

    df_combined = df.copy()
    # Assign a list (positional) rather than a fresh pd.Series. A new Series has a
    # 0..n-1 index and is aligned by label, so the column would be scrambled or
    # NaN-filled whenever df's index is not 0..n-1 (e.g. after shuffling or filtering).
    df_combined['combined_qs'] = [q + s for q, s in zip(df['question'], df['snippet'])]
    return df_combined

In [None]:
train_combined_overlapped = combine_question_snippet(train_separate_overlapped)

In [None]:
# Only the full training set is shuffled. The test and 3000-question frames keep their row
# order, which the evaluation code relies on (it groups consecutive rows by q_index/c_index).
train_combined_overlapped_shuffled = shuffle_df(train_combined_overlapped)
train_combined_overlapped_shuffled.head(2)

In [None]:
test_17_combined_overlapped = combine_question_snippet(test_17_separate_overlapped)

In [None]:
test_18_combined_overlapped = combine_question_snippet(test_18_separate_overlapped)

In [None]:
train_3000_combined_overlapped = combine_question_snippet(train_3000_separate_overlapped)

In [None]:
def draw_histgram(df, column_name, kde=False):
    """Plot a histogram of the string lengths of ``df[column_name]``.

    Prints a '<column> length' header first. Does nothing when the column is missing.

    kde: passed to seaborn; overlay a kernel density estimate when True
    """
    if column_name not in df.columns:
        return
    print(f'{column_name} length')
    data = [len(str(i)) for i in df[column_name]]
    # Label with the plotted column (it was always 'snippet length' before).
    # NOTE: seaborn's distplot is deprecated; histplot is its replacement.
    sns.distplot(data, kde=kde, label=f'{column_name} length')

In [None]:
# Length distribution of questions alone vs. question + snippet.
draw_histgram(train_combined_overlapped, 'question')

In [None]:
draw_histgram(train_combined_overlapped, 'combined_qs')

## Prepare data for BERT

In [None]:
def prepare_for_bert(df):
    """Build the 5-column frame written to the BERT .tsv files.

    Columns: 'id' (0..n-1 in current row order), 'label', a constant 'alpha'
    column of 'a', 'text_a' (question + snippet from 'combined_qs') and
    'text_b' (the choice). Newlines in both texts are replaced by spaces.
    """
    n_rows = df.shape[0]
    no_newlines = dict(to_replace=r'\n', value=' ', regex=True)
    columns = {
        'id': range(n_rows),
        'label': df['label'],
        'alpha': ['a'] * n_rows,
        'text_a': df['combined_qs'].replace(**no_newlines),
        'text_b': df['choice'].replace(**no_newlines),
    }
    return pd.DataFrame(columns)

In [None]:
#train_bert_combined_qs = prepare_for_bert(train_combine_qs_shuffled)
#train_bert_combined_qs.head()

In [None]:
#train_bert_combined_qs.to_csv('./data/combined_qs_title/train_combined_qs_title.tsv', sep='\t', index=False, header=False)

In [None]:
# first 3000 questions in training data
#train_3000_bert_combined_qs = prepare_for_bert(train_3000_df_combined_qs)
#train_3000_bert_combined_qs.to_csv('./data/combined_qs_title/train_3000_combined_qs_title.tsv', sep='\t', index=False, header=False)

In [None]:
# for training data
#dev_17_bert_combined_qs = prepare_for_bert(test_17_df_combined_qs)
#dev_18_bert_combined_qs = prepare_for_bert(test_18_df_combined_qs)

In [None]:
#dev_17_bert_combined_qs.head()

In [None]:
#dev_17_bert_combined_qs.to_csv('./data/combined_qs_title/dev_17_combined_qs_title.tsv', sep='\t', index=False, header=False)
#dev_18_bert_combined_qs.to_csv('./data/combined_qs_title/dev_18_combined_qs_title.tsv', sep='\t', index=False, header=False)


In [None]:
# Training set: shuffled rows.
train_bert_combined_overlapped = prepare_for_bert(train_combined_overlapped_shuffled)
train_bert_combined_overlapped.head()  # not the last expression, so nothing is displayed
print(train_bert_combined_overlapped.shape)

In [None]:
# Headerless, index-free TSV with columns id, label, alpha, text_a, text_b.
train_bert_combined_overlapped.to_csv(f'./data/combined_overlapped/{TRAIN_NAME}.tsv', sep='\t', index=False, header=False)

In [None]:
test_17_bert_combined_overlapped = prepare_for_bert(test_17_combined_overlapped)
test_17_bert_combined_overlapped.to_csv(f'./data/combined_overlapped/{TEST_17_NAME}.tsv', sep='\t', index=False, header=False)

In [None]:
test_18_bert_combined_overlapped = prepare_for_bert(test_18_combined_overlapped)
test_18_bert_combined_overlapped.to_csv(f'./data/combined_overlapped/{TEST_18_NAME}.tsv', sep='\t', index=False, header=False)

In [None]:
train_3000_bert_combined_overlapped = prepare_for_bert(train_3000_combined_overlapped)
train_3000_bert_combined_overlapped.to_csv(f'./data/combined_overlapped/{TRAIN_3000_NAME}.tsv', sep='\t', index=False, header=False)

In [None]:

# Consistency check for nested logits
def check_nested_logits(nested_logits, original_df):
    """Assert that ``nested_logits`` flattens back to ``original_df``.

    nested_logits: list indexed [question][choice][snippet] -> [l_logit, r_logit]
    original_df: frame with 'q_index', 'c_index', 's_index', 'l_logits', 'r_logits'

    Flattening in question/choice/snippet order must reproduce both the
    (q_index, c_index, s_index) triples and the logit pairs, row by row.
    Otherwise an AssertionError is raised. Prints a confirmation on success.
    """
    flat_struct = [[q, c, s]
                   for q, choices in enumerate(nested_logits)
                   for c, snippets in enumerate(choices)
                   for s, _ in enumerate(snippets)]
    flat_logits = [logit
                   for choices in nested_logits
                   for snippets in choices
                   for logit in snippets]

    expected_struct = original_df[['q_index','c_index','s_index']].values.tolist()
    expected_logits = original_df[['l_logits', 'r_logits']].values.tolist()

    assert expected_struct == flat_struct, 'Wrong structure of nested logits'
    assert expected_logits == flat_logits, 'Wrong logits value'
    print('Structure and Logits values are the same')
    
def nest_output_logits(df):
    """Group per-row logits into a [question][choice][snippet] nested list.

    df: model-output frame with integer 'q_index', 'c_index', 's_index' columns
        and float 'l_logits', 'r_logits' columns. Rows of the same question and
        choice must be consecutive (as emitted by prepare_data_overlapped).

    return: list where nested[q][c] is the list of [l_logit, r_logit] pairs for
        every snippet piece of choice c of question q, in row order.

    The result is verified with ``check_nested_logits`` before it is returned.
    """
    qcs_indexes = df[['q_index', 'c_index', 's_index']].values
    assert np.issubdtype(qcs_indexes.dtype, np.integer), 'Wrong indexes dtype, should be int'
    input_logits = df[['l_logits', 'r_logits']].values
    assert np.issubdtype(input_logits.dtype, np.floating), 'Wrong logits dtype, should be float'

    # Single pass: open a new question when q_index changes, and a new choice when
    # c_index changes within a question. Unlike the old look-ahead loop, this handles a
    # 1-row frame and never merges choices that belong to different questions.
    nested_logits = []
    prev_q = prev_c = None
    for (q, c, _), logit in zip(qcs_indexes, input_logits.tolist()):
        if q != prev_q:
            nested_logits.append([[]])
        elif c != prev_c:
            nested_logits[-1].append([])
        nested_logits[-1][-1].append(logit)
        prev_q, prev_c = q, c

    check_nested_logits(nested_logits, df)
    return nested_logits

In [None]:
output_18 = pd.read_excel('outputs/test_18_75%.xlsx')

In [None]:
nested_logits_18 = nest_logits(output_18)

In [None]:
test = output_18.head(15)

In [None]:
test.head(1)

In [None]:
nested_test = nest_logits(test)

In [None]:
nested_test

### First method: avg in choice level

nested_logits -> choice_level_average_logits -> diff -> question_level_max -> label_list

In [None]:
# Print each choice's list of snippet logit pairs for the sample.
for q, qq in enumerate(nested_test):
    for c, cc in enumerate(qq):
        print(cc)

In [None]:
def precision_recall_score(y_true, y_pred):
    """Confusion counts, accuracy, precision and recall for binary 0/1 labels.

    y_true, y_pred: equal-length sequences of 0/1 labels (the positive class is 1)

    Prints the confusion counts, then returns a dict with 'one_entry_acc',
    'precision' and 'recall' (each rounded to 3 decimals; precision and recall
    are 0 when there are no true positives) plus 'values' (the counts).
    """
    tp = fp = tn = fn = 0
    for idx, actual in enumerate(y_true):
        predicted_right = actual == y_pred[idx]
        positive = actual == 1
        if predicted_right and positive:
            tp += 1
        elif predicted_right:
            tn += 1
        elif positive:
            fn += 1
        else:
            fp += 1

    true_false_values = {'tp:': tp, 'tn:': tn, 'fp:': fp, 'fn:': fn}
    print(true_false_values)

    precision = 0 if tp == 0 else tp / (tp + fp)
    recall = 0 if tp == 0 else tp / (tp + fn)
    accuracy = (tp + tn) / (tp + tn + fp + fn)

    return {
        "one_entry_acc": round(accuracy, 3),
        "precision": round(precision, 3),
        "recall": round(recall, 3),
        "values": true_false_values
    }

In [None]:
# Row-level logit pairs [l_logits, r_logits] from the model output.
correct_logits = output_18[['l_logits', 'r_logits']].values.tolist()

In [None]:
correct_logits

In [None]:
# Number of rows where the left logit wins (argmax 0, i.e. predicted label 0).
np.sum(np.argmax(correct_logits, axis=1) == 0)

In [None]:
# Row-level metrics: the argmax over each pair is the predicted 0/1 label.
one_entry_preds = np.argmax(correct_logits, axis=1)
one_entry_label = output_18['label'].values
precision_recall_score(one_entry_label, one_entry_preds)

In [None]:
def calc_avg_logits(cc):
    """Average a choice's snippet logit pairs column-wise and print the result.

    cc: list of [l_logit, r_logit] pairs, one per snippet piece
    return: numpy array [mean_l_logit, mean_r_logit]
    """
    avg_logits = np.array(cc).mean(axis=0)
    print(avg_logits)
    return avg_logits

def qcs_pair_accuracy():
    """Placeholder for a question-choice-snippet pair accuracy metric (not implemented; returns None)."""
# Average the snippet logits of each (question, choice) pair into one flat list:
# one averaged [l, r] pair per choice, question by question.
question_level_logits = []
for q, qq in enumerate(nested_logits_18):
    choice_level_logits = []
    for c, cc in enumerate(qq):
        print(cc)
        snippet_avg_logits = calc_avg_logits(cc)
        choice_level_logits.append(snippet_avg_logits)
    question_level_logits.extend(choice_level_logits)

In [None]:
choice_level_logits

In [None]:
np.argmax(question_level_logits, axis=1)

In [None]:
a = np.array(question_avg_logits)

In [None]:
np.argmax(a, axis=1)

In [None]:
o.mean(axis=0)

In [None]:
output_18[['q_index', 'c_index','label']].values

In [None]:
output_18[['q_index', 'c_index','label']].values.shape

In [None]:
unique_pairs = np.unique(output_18[['q_index', 'c_index', 'label']].values, axis=0)
question_level_label = [label for _, _, label in unique_pairs]

In [None]:
question_level_preds = np.argmax(question_level_logits, axis=1)

In [None]:
precision_recall_score(question_level_label, question_level_preds)

In [None]:
question_level_preds

In [None]:

# ! delete
def question_accuracy_old(raw_preds, out_label_ids):
    """Question-level accuracy assuming exactly 5 consecutive rows per question.

    raw_preds: per-row logits before argmax, e.g. [0.81673616, -0.56396836]
    out_label_ids: per-row 0/1 label for each question-choice pair

    The true answer of a question is the 1-based position of its first row
    labelled 1. The prediction is the choice with the smallest
    ``logit[0] - logit[1]``. Rows beyond a multiple of 5 are ignored.

    return: question accuracy
    """
    n_questions = int(len(out_label_ids) / 5)

    true_answers = []
    for q in range(n_questions):
        offset = 5 * q
        hit = next((k for k in range(5) if out_label_ids[offset + k] == 1), None)
        if hit is not None:
            true_answers.append(hit + 1)

    predicted_answers = []
    for q in range(n_questions):
        offset = 5 * q
        margins = [raw_preds[offset + k][0] - raw_preds[offset + k][1] for k in range(5)]
        predicted_answers.append(np.argmin(margins) + 1)

    hits = 0
    for idx in range(len(true_answers)):
        if true_answers[idx] == predicted_answers[idx]:
            hits += 1
    return hits / len(true_answers)



def question_accuracy(qs_pair_label, qs_pair_logits):
    """Question-level accuracy from per-(question, choice) labels and logits.

    qs_pair_label: the [0, 1] value for each question-choice pair, 5 consecutive
        entries per question
    qs_pair_logits: the logits for each question-choice pair in the same order,
        eg [0.288778692483902, -1.444930672645569]

    The true answer is the 1-based position of the first choice labelled 1.
    The prediction is the choice with the smallest ``logit[0] - logit[1]``.
    Both label lists are printed before the accuracy is returned. Entries
    beyond a multiple of 5 are ignored.

    return: question accuracy
    """
    def accuracy(labels, preds):
        length = len(labels)
        correct_num = 0
        for i in range(length):
            if labels[i] == preds[i]:
                correct_num += 1
        return correct_num / length

    question_number = len(qs_pair_label) // 5

    # True label per question. A question without a positive choice gets 0. That never
    # matches a prediction (1..5) but keeps `labels` aligned with `predicted_labels`;
    # skipping it shifted every later comparison by one question.
    labels = []
    for question in range(question_number):
        label = 0
        for choice in range(5):
            if qs_pair_label[5 * question + choice] == 1:
                label = choice + 1
                break
        labels.append(label)

    # Predicted label per question: the smallest left-minus-right logit margin wins.
    predicted_labels = []
    for question in range(question_number):
        margins = [qs_pair_logits[5 * question + choice][0] - qs_pair_logits[5 * question + choice][1]
                   for choice in range(5)]
        predicted_labels.append(np.argmin(margins) + 1)

    print('predicted_labels: ')
    print(predicted_labels)
    print('labels ---')
    print(labels)
    return accuracy(labels, predicted_labels)

In [None]:
def evaluate_one_entry_score(output_df):
    """Row-level (question-choice-snippet) metrics plus the legacy question accuracy.

    output_df: dataframe, model output with 'l_logits', 'r_logits' and 'label' columns

    return: the dict from ``precision_recall_score`` (accuracy, precision,
        recall, confusion counts) with an extra 'old_question_acc' entry
    """
    logits = output_df[['l_logits', 'r_logits']].values.tolist()
    labels = output_df['label'].values
    scores = precision_recall_score(labels, np.argmax(logits, axis=1))
    # ! FIXME: question_accuracy_old groups rows in fives, but this frame has one row per
    # snippet piece, so the value is only meaningful when every choice has a single piece.
    scores['old_question_acc'] = question_accuracy_old(logits, labels)
    return scores


def evaluate_question_score(output_df, nested_logits):
    """Question-level accuracy using snippet-averaged logits for each choice.

    output_df: dataframe, model output with 'q_index', 'c_index' and 'label' columns
    nested_logits: list indexed [question][choice][snippet] -> [l_logit, r_logit]

    return: question level accuracy (from ``question_accuracy``)
    """
    # Average the snippet logits of every (question, choice) pair. This used to call
    # `first_method_avg`, which is not defined anywhere in the notebook (NameError);
    # `second_method_avg` is the averaging helper.
    question_level_logits = second_method_avg(nested_logits)

    # One true label per (question, choice) pair. np.unique sorts by q_index then c_index,
    # which matches the averaged logits when output_df is ordered the same way.
    unique_pairs = np.unique(output_df[['q_index', 'c_index', 'label']].values, axis=0)
    question_level_label = [label for _, _, label in unique_pairs]
    return question_accuracy(question_level_label, question_level_logits)



In [None]:
# Row-level metrics for the 2018 test outputs.
evaluate_one_entry_score(output_18)

## Second method avg

In [None]:
# Averaging helper (redefinition without the print); used by second_method_avg below.
def calc_avg_logits(cc):
    """Column-wise mean of a choice's [l_logit, r_logit] snippet pairs.

    cc: list of [l_logit, r_logit] pairs, one per snippet piece
    return: numpy array [mean_l_logit, mean_r_logit]
    """
    return np.mean(np.asarray(cc), axis=0)

def second_method_avg(nested_logits):
    """Average the snippet logits within every question-choice pair.

    nested_logits: list, nested logits grouped by question and choice level
        ([question][choice][snippet] -> [l_logit, r_logit])

    return: flat list with one averaged logit pair (from calc_avg_logits) per
        (question, choice), ordered by question then choice
    """
    return [calc_avg_logits(snippet_logits)
            for choices in nested_logits
            for snippet_logits in choices]


In [None]:
# Question-level accuracy with snippet-averaged logits.
evaluate_question_score(output_18, nested_logits_18)

In [None]:
nested_test

In [None]:
# Larger sample for inspecting the nested structure.
test = output_18.head(35)
nested_test = nest_output_logits(test)

## First method: max over snippets (optional)

In [None]:
def first_method_max(nested_logits):
    """Predict each question's answer as the choice holding its most confident snippet.

    nested_logits: list indexed [question][choice][snippet] -> [l_logit, r_logit]

    For every question, the snippet with the largest ``r_logit - l_logit``
    wins and its 1-based choice number becomes the prediction; ties keep the
    earliest choice. A question with no snippets is predicted as -1. The list
    of predictions is printed before it is returned.

    return: predicted labels in [1,5]
    """
    predicted_label = []
    for qq in nested_logits:
        # Start from -inf. The previous -2**16 sentinel returned -1 whenever every
        # margin of a question was below -65536.
        max_diff = float('-inf')
        label = -1
        for c, cc in enumerate(qq):
            for s in cc:
                cur_diff = s[1] - s[0]
                if cur_diff > max_diff:
                    max_diff = cur_diff
                    label = c + 1
        predicted_label.append(label)
    print(predicted_label)
    return predicted_label

def accuracy(labels, preds):
    """Fraction of positions i in ``labels`` where ``labels[i] == preds[i]``."""
    correct_num = sum(1 for i in range(len(labels)) if labels[i] == preds[i])
    return correct_num / len(labels)

In [None]:
predicted_18 = first_method_max(nested_logits_18)

In [None]:
labels = [2, 3, 2, 5, 3, 3, 5, 1, 2, 4, 4, 1, 4, 1, 1, 3, 2, 3, 3, 2, 1, 1, 4, 4, 3, 4, 5, 2, 1, 3, 2, 1, 3, 2, 5, 4, 1, 5, 3, 4, 3, 1, 3, 5, 3, 1, 3, 4, 4, 2, 4, 2, 3, 3, 4, 5, 2, 2, 3, 3, 4, 2, 4, 1, 1, 5, 5, 2, 2, 2, 1, 5, 1, 2, 2, 1, 2, 1, 1, 3, 2, 4, 4, 2, 4, 2, 4, 2, 1, 4, 2, 5, 1, 5, 4, 1, 3, 3, 5, 1, 2, 1, 1, 3, 2, 3, 1, 1, 5, 3, 5, 2, 5, 5, 1, 5, 2, 5, 3, 4, 1, 4, 3, 2, 1, 2, 2, 5, 1, 4, 3, 5, 3, 1, 5, 3, 1, 5, 4, 2, 3, 2, 2, 4, 5, 4, 3, 2, 2, 1, 3, 3, 5, 5, 2, 4, 3, 3, 4, 5, 5, 4, 3, 2, 4, 2, 4, 5, 3, 3, 5, 1, 4, 2, 2, 2, 3, 1, 4, 1, 4, 3, 3, 5, 5, 4, 1, 2, 3, 3, 1, 5, 5, 5, 4, 2, 4, 2, 2, 4, 5, 5, 2, 5, 2, 1, 5, 4, 2, 4, 1, 3, 4, 2, 1, 3, 3, 1, 1, 1, 5, 2, 1, 4, 1, 1, 5, 1, 1, 4, 2, 5, 2, 3, 4, 5, 3, 4, 3, 4, 3, 3, 3, 5, 3, 3, 4, 5, 5, 2, 1, 3, 1, 5, 1, 1, 4, 5, 4, 4, 5, 5, 3, 5, 1, 4, 4, 3, 2, 3, 5, 4, 4, 4, 4, 5, 3, 1, 1, 5, 4, 1, 3, 3, 1, 2, 4, 2, 1, 3, 1, 5, 5, 3, 1, 4, 4, 1, 5, 4, 4, 5, 1, 4, 1, 1, 1, 4, 3, 4, 2, 4, 3, 4, 2, 5, 3, 4, 1, 1, 5, 5, 4, 1, 1, 1, 1, 3, 2, 5, 2, 2, 3, 3, 2, 2, 1, 4, 5, 5, 4, 4, 3, 4, 2, 1, 5, 4, 1, 2, 5, 1, 1, 2, 5, 1, 1, 4, 3, 4, 1, 5, 4, 2, 4, 4, 3, 2, 1, 4, 4, 2, 3, 3, 5, 1, 1, 4, 3, 1, 2, 4, 5, 3, 2, 3, 4, 2, 5, 3, 5, 2, 4, 1, 2, 2, 1, 4, 4, 4, 5, 1, 3, 4, 4, 1, 4, 4, 2, 1, 2, 3, 1, 1, 1, 5, 4, 2, 2, 3, 2, 3, 4, 1, 5, 5, 5, 4, 2, 4, 4, 1, 4, 1, 2, 1, 2, 2, 4, 3, 4, 3, 2, 5, 4, 2, 1, 2, 3, 2, 2, 4, 1, 3, 3, 3, 3, 3, 3, 3, 5, 1, 3, 5, 2, 2, 2, 2, 2, 5, 3, 2, 4, 5, 5, 2, 1, 2, 3, 3, 3, 2, 5, 5, 2, 4, 5, 5, 4, 3, 1, 5, 4, 2, 1, 1, 3, 5, 2, 4, 5, 4, 3, 1, 3, 5, 4, 1, 1, 4, 2, 3, 4, 4, 3, 2, 3, 1, 3, 3, 1, 2, 1, 2, 2, 1, 3, 2, 5, 3, 1, 5, 4, 4, 2, 4, 1, 1, 5, 1, 1, 2, 2, 3, 5, 1, 4, 2, 4, 5, 4, 1, 5, 5, 5, 2, 4, 2, 2, 3, 3, 4, 3, 5, 4, 3, 4, 5, 1, 5, 4, 4, 2, 2, 2, 4, 1, 3, 2, 5, 2, 3, 2, 2, 2, 2, 5, 4, 5, 3, 4, 3, 5, 5, 3, 4, 4, 3, 1]

In [None]:
accuracy(labels, predicted_18)

## Third method - only correct

In [None]:
# Small sample for inspecting the nested structure before trying the third method.
test = output_18.head(20)
nested_test = nest_output_logits(test)
nested_test

In [None]:
def snippet_avg_helper(qcs_level_diff, verbose=True):
    """
    Pick the best answer choice for one question by averaging snippet correctness.

    qcs_level_diff: one entry per answer choice; each entry is a sequence of
        snippet scores, where a score is the difference between the right logit
        and the left logit for that question_choice_snippet. This measures the
        correctness of the question_choice_snippet:

        correctness <= 0, wrong
        correctness > 0, correct

    verbose: when True (default, the original behaviour) print the intermediate
        per-choice averages; pass False to silence the debug output, e.g. when
        scoring every question of a test set.

    Each choice is scored as the mean of (average of its correct scores) and
    (average of its wrong scores) -- "Method FOUR". An empty group contributes 0.

    return: label for the question, i.e. the 1-based index of the highest-scoring
        choice (in [1,5] for five-choice questions); ties resolve to the earliest
        choice.
    raises: ValueError if qcs_level_diff contains no choices.
    """
    if len(qcs_level_diff) == 0:
        raise ValueError('qcs_level_diff must contain at least one choice')

    qc_level_avg = []
    for cc in qcs_level_diff:
        correct_scores = [s for s in cc if s > 0]
        wrong_scores = [s for s in cc if s <= 0]

        # An empty group averages to 0 instead of raising ZeroDivisionError.
        correct_pair_avg = sum(correct_scores) / len(correct_scores) if correct_scores else 0
        wrong_pair_avg = sum(wrong_scores) / len(wrong_scores) if wrong_scores else 0
        if verbose:
            print(correct_pair_avg, wrong_pair_avg)

        # Method FOUR: both correct and wrong
        correct_and_wrong_avg = (correct_pair_avg + wrong_pair_avg) / 2
        if verbose:
            print(correct_and_wrong_avg)
        qc_level_avg.append(correct_and_wrong_avg)

    best_index = qc_level_avg.index(max(qc_level_avg))
    if verbose:
        print('Avg for a5 choices in a question', qc_level_avg)
        print(best_index)
    # Labels are 1-based, list indices 0-based.
    return best_index + 1
        
        
# Score every question in nested_logits_18: turn each snippet's logit pair into a
# correctness score (right logit minus left logit, > 0 means "correct") and let
# snippet_avg_helper pick the best choice per question.
q_labels_preds = []
for question in nested_logits_18:
    choice_level = []
    for choice in question:
        print('choice----', choice)
        choice_level.append([pair[1] - pair[0] for pair in choice])

    print('')
    print(choice_level)
    q_labels_preds.append(snippet_avg_helper(choice_level))
    print('')
    print('')
print(q_labels_preds)

In [None]:
# Accuracy of the snippet-averaging predictions (q_labels_preds) against the same hard-coded labels.
accuracy(labels,q_labels_preds)

In [None]:
# Scratch input for the averaging helpers: 5 choices x 3 snippet scores
# (presumably copied from one question's printed choice_level -- all entries are non-positive).
a = [[-2.6316429376602177, -2.615735948085785, -2.1325479149818416], [-1.486147284507751, -1.144922822713852, -1.0900460630655289], [-0.463565394282341, -0.09689867496490481, -0.7583723440766335], [-2.258228838443756, -2.407961905002594, -1.6885483264923093], [-2.1981219649314876, -2.1288265585899353, -1.8730788230896]]

In [None]:
# Display the scratch matrix `a`.
a

In [None]:
# Hand-edited variant of `a`: some entries in the first two choices are made positive
# so that the "correct" (> 0) branch of the averaging helpers is exercised.
aaa = [[2.6316429376602177, 1.2, -2.1325479149818416],
 [-1.486147284507751, 10.144922822713852, 1.0900460630655289],
 [-0.463565394282341, -0.09689867496490481, -0.7583723440766335],
 [-2.258228838443756, -2.407961905002594, -1.6885483264923093],
 [-2.1981219649314876, -2.1288265585899353, -1.8730788230896]]

In [None]:
# Score the hand-edited matrix `aaa` with snippet_avg_helper.
snippet_avg_helper(aaa)

In [None]:
def print_snippet_pair_avgs(choices):
    """
    Debug helper: for every choice of one question, print the average of its
    correct (> 0) snippet scores and the average of its wrong (<= 0) snippet
    scores. An empty group prints as 0.

    Kept under its own name on purpose: reusing the name snippet_avg_helper here
    would shadow the label-returning helper defined earlier in this section, so
    later calls such as snippet_avg_helper(aaa) would silently return None.

    choices: sequence of per-choice sequences of snippet scores.
    return: None (output is printed only).
    """
    for cc in choices:
        correct = [s for s in cc if s > 0]
        wrong = [s for s in cc if s <= 0]

        correct_avg = sum(correct) / len(correct) if correct else 0
        wrong_avg = sum(wrong) / len(wrong) if wrong else 0
        print(correct_avg, wrong_avg)


print_snippet_pair_avgs(aaa)

In [None]:
def snippet_avg_helper_correct_only(qcs_level_diff, only_correct=True, verbose=True):
    """
    Pick the best answer choice for one question, ranking choices by their
    correct snippets only (or by both groups when only_correct is False).

    qcs_level_diff: one entry per answer choice; each entry is a sequence of
        snippet scores, where a score is the difference between the right logit
        and the left logit for that question_choice_snippet. This measures the
        correctness of the question_choice_snippet:

        correctness <= 0, wrong
        correctness > 0, correct

    only_correct: True (default) scores a choice by the average of its correct
        scores, falling back to the average of its wrong scores when it has no
        correct snippet. False scores it by the mean of the correct and wrong
        averages ("Method FOUR"); previously this path raised UnboundLocalError.
    verbose: when True (default, the original behaviour) print the intermediate
        per-choice averages; pass False to silence the debug output.

    return: label for the question, i.e. the 1-based index of the highest-scoring
        choice (in [1,5] for five-choice questions); ties resolve to the earliest
        choice.
    raises: ValueError if qcs_level_diff contains no choices.
    """
    if len(qcs_level_diff) == 0:
        raise ValueError('qcs_level_diff must contain at least one choice')

    qc_level_avg = []
    for cc in qcs_level_diff:
        correct_scores = [s for s in cc if s > 0]
        wrong_scores = [s for s in cc if s <= 0]

        # An empty group averages to 0 instead of raising ZeroDivisionError.
        correct_pair_avg = sum(correct_scores) / len(correct_scores) if correct_scores else 0
        wrong_pair_avg = sum(wrong_scores) / len(wrong_scores) if wrong_scores else 0
        if verbose:
            print(correct_pair_avg, wrong_pair_avg)

        if only_correct:
            # The fallback is <= 0, so a choice with no correct snippet always
            # ranks below any choice that has at least one correct snippet.
            correct_avg = correct_pair_avg if correct_scores else wrong_pair_avg
        else:
            # Method FOUR: both correct and wrong
            correct_avg = (correct_pair_avg + wrong_pair_avg) / 2
        if verbose:
            print(correct_avg)
        qc_level_avg.append(correct_avg)

    best_index = qc_level_avg.index(max(qc_level_avg))
    if verbose:
        print('Avg for a5 choices in a question', qc_level_avg)
        print(best_index)
    # Labels are 1-based, list indices 0-based.
    return best_index + 1

In [None]:
# Score `aaa` with snippet_avg_helper, for comparison with the correct-only variant in the next cell.
snippet_avg_helper(aaa)

In [None]:
# Correct-only variant on the same hand-edited matrix.
snippet_avg_helper_correct_only(aaa)

In [None]:
# Re-display the hand-edited scratch matrix.
aaa