In [1]:
import gzip
import bz2
import lzma
# from PIL.PngImagePlugin import getchunks
# from PIL import Image
import sys
from tqdm import tqdm
import torch.nn.functional as F


import io


class DefaultCompressor:
    """for non-neural-based compressor"""
    def __init__(self, compressor, typ='text'):
        if compressor == 'gzip':
            self.compressor = gzip
        elif compressor == 'bz2':
            self.compressor = bz2
        elif compressor == 'lzma':
            self.compressor = lzma
        else:
            raise RuntimeError("Unsupported compressor")
        self.type = typ
    def get_compressed_len(self, x):
        if self.type == 'text':
            return len(self.compressor.compress(x.encode('utf-8')))
        else:
            return len(self.compressor.compress(np.array(x).tobytes()))
    def get_bits_per_char(self, original_fn):
        with open(original_fn) as fo:
            data = fo.read()
            compressed_str = self.compressor.compress(data.encode('utf-8'))
            return len(compressed_str)*8/len(data)

In [2]:
import os
import torch
import numpy as np
import statistics
import operator
from collections import Counter, defaultdict
from tqdm import tqdm
import random
from functools import partial
from itertools import repeat
from copy import deepcopy
from statistics import mode
import pickle
from sklearn.metrics.cluster import adjusted_rand_score, normalized_mutual_info_score


class KnnExpText:
    def __init__(self, agg_f, comp, dis):
        self.aggregation_func = agg_f
        self.compressor = comp
        self.distance_func = dis
        self.dis_matrix = []

    def calc_dis(self, data, train_data=None, fast=False):
        if train_data is not None:
            data_to_compare = train_data
        else:
            data_to_compare = data
        for i, t1 in tqdm(enumerate(data)):
            distance4i = []
            if fast:
                t1_compressed = self.compressor.get_compressed_len_fast(t1)
            else:
                t1_compressed = self.compressor.get_compressed_len(t1)
            for j, t2 in enumerate(data_to_compare):
                if fast:
                    t2_compressed = self.compressor.get_compressed_len_fast(t2)
                    t1t2_compressed = self.compressor.get_compressed_len_fast(self.aggregation_func(t1,t2))
                else:
                    t2_compressed = self.compressor.get_compressed_len(t2)
                    t1t2_compressed = self.compressor.get_compressed_len(self.aggregation_func(t1, t2))
                distance = self.distance_func(t1_compressed, t2_compressed, t1t2_compressed)
                distance4i.append(distance)
            self.dis_matrix.append(distance4i)

    def calc_dis_with_single_compressed_given(self, data, data_len=None, train_data=None):
        if train_data is not None:
            data_to_compare = train_data
        else:
            data_to_compare = data
        for i, t1 in tqdm(enumerate(data)):
            distance4i = []
            t1_compressed = self.compressor.get_compressed_len_given_prob(t1, data_len[i])
            for j, t2 in tqdm(enumerate(data_to_compare)):
                t2_compressed = self.compressor.get_compressed_len_given_prob(t2, data_len[j])
                t1t2_compressed = self.compressor.get_compressed_len(self.aggregation_func(t1, t2))
                distance = self.distance_func(t1_compressed, t2_compressed, t1t2_compressed)
                distance4i.append(distance)
            self.dis_matrix.append(distance4i)

    def calc_dis_single(self, t1, t2):
        t1_compressed = self.compressor.get_compressed_len(t1)
        t2_compressed = self.compressor.get_compressed_len(t2)
        t1t2_compressed = self.compressor.get_compressed_len(self.aggregation_func(t1, t2))
        distance = self.distance_func(t1_compressed, t2_compressed, t1t2_compressed)
        return distance
    def calc_dis_single_multi(self, train_data, datum):
        distance4i = []
        t1_compressed = self.compressor.get_compressed_len(datum)
        for j, t2 in tqdm(enumerate(train_data)):
            t2_compressed = self.compressor.get_compressed_len(t2)
            t1t2_compressed = self.compressor.get_compressed_len(self.aggregation_func(datum, t2))
            distance = self.distance_func(t1_compressed, t2_compressed, t1t2_compressed)
            distance4i.append(distance)
        return distance4i
    def calc_dis_with_vector(self, data, train_data=None):
        if train_data is not None:
            data_to_compare = train_data
        else:
            data_to_compare = data
        for i, t1 in tqdm(enumerate(data)):
            distance4i = []
            for j, t2 in enumerate(data_to_compare):
                distance = self.distance_func(t1, t2)
                distance4i.append(distance)
            self.dis_matrix.append(distance4i)
    def calc_acc(self, k, label, train_label=None, provided_distance_matrix=None, rand=False):
        if provided_distance_matrix is not None:
            self.dis_matrix = provided_distance_matrix
        correct = []
        pred = []
        if train_label is not None:
            compare_label = train_label
            start = 0
            end = k
        else:
            compare_label = label
            start = 1
            end = k+1
        for i in range(len(self.dis_matrix)):
            sorted_idx = np.argsort(np.array(self.dis_matrix[i]))
            pred_labels = defaultdict(int)
            for j in range(start, end):
                pred_l = compare_label[sorted_idx[j]]
                pred_labels[pred_l] += 1
            sorted_pred_lab = sorted(pred_labels.items(), key=operator.itemgetter(1), reverse=True)
            most_count = sorted_pred_lab[0][1]
            if_right = 0
            most_label = sorted_pred_lab[0][0]
            most_voted_labels = []
            for pair in sorted_pred_lab:
                if pair[1] < most_count:
                    break
                if not rand:
                    if pair[0] == label[i]:
                        if_right = 1
                        most_label = pair[0]
                else:
                    most_voted_labels.append(pair[0])
            if rand:
                most_label = random.choice(most_voted_labels)
                if_right = 1 if most_label==label[i] else 0
            pred.append(most_label)
            correct.append(if_right)
        print("Accuracy is {}".format(sum(correct)/len(correct)))
        return pred, correct
    def combine_dis_acc(self, k, data, label, train_data=None, train_label=None):
        correct = []
        pred = []
        if train_label is not None:
            compare_label = train_label
            start = 0
            end = k
        else:
            compare_label = label
            start = 1
            end = k+1
        if train_data is not None:
            data_to_compare = train_data
        else:
            data_to_compare = data
        for i, t1 in tqdm(enumerate(data)):
            distance4i = self.calc_dis_single_multi(data_to_compare, t1)
            sorted_idx = np.argsort(np.array(distance4i))
            pred_labels = defaultdict(int)
            for j in range(start, end):
                pred_l = compare_label[sorted_idx[j]]
                pred_labels[pred_l] += 1
            sorted_pred_lab = sorted(pred_labels.items(), key=operator.itemgetter(1), reverse=True)
            most_count = sorted_pred_lab[0][1]
            if_right = 0
            most_label = sorted_pred_lab[0][0]
            for pair in sorted_pred_lab:
                if pair[1] < most_count:
                    break
                if pair[0] == label[i]:
                    if_right = 1
                    most_label = pair[0]
            pred.append(most_label)
            correct.append(if_right)
        print("Accuracy is {}".format(sum(correct) / len(correct)))
        return pred, correct

    def combine_dis_acc_single(self, k, train_data, train_label, datum, label):
        # Support multi processing - must provide train data and train label
        distance4i = self.calc_dis_single_multi(train_data, datum)
        sorted_idx = np.argpartition(np.array(distance4i), range(k))
        pred_labels = defaultdict(int)
        for j in range(k):
            pred_l = train_label[sorted_idx[j]]
            pred_labels[pred_l] += 1
        sorted_pred_lab = sorted(pred_labels.items(), key=operator.itemgetter(1), reverse=True)
        most_count = sorted_pred_lab[0][1]
        if_right = 0
        most_label = sorted_pred_lab[0][0]
        prob = sorted_pred_lab[0][0]
        for pair in sorted_pred_lab:
            if pair[1] < most_count:
                break
            if pair[0] == label:
                if_right = 1
                most_label = pair[0]
        pred=most_label
        correct=if_right
        return pred, correct

In [3]:
from scipy.spatial.distance import cosine
import numpy as np
import torch
import scipy

class ToInt:
    def __call__(self, pic):
        return pic * 255

def NCD(c1, c2, c12):
    dis = (c12-min(c1,c2))/max(c1, c2)
    return dis

def CLM(c1, c2, c12):
    dis = 1 - (c1+c2-c12)/c12
    return dis

def CDM(c1, c2, c12):
    dis = c12/(c1+c2)
    return dis

def MSE(v1, v2):
    return np.sum((v1-v2)**2)/len(v1)

def agg_by_concat_space(t1, t2):
    return t1+' '+t2

def agg_by_jag_word(t1, t2, trunc=True):
    t1_list = t1.split(' ')
    t2_list = t2.split(' ')
    comb = []
    l = min([len(t1_list), len(t2_list)])
    for i in range(0,l-1,2):
        comb.append(t1_list[i])
        comb.append(t2_list[i+1])
    if len(t1_list) > len(t2_list):
        comb += t1_list[i:]
    return ' '.join(comb)

def agg_by_jag_char(t1, t2, trunc=True):
    t1_list = list(t1)
    t2_list = list(t2)
    comb = []
    l = min([len(t1_list), len(t2_list)])
    for i in range(0,l-1,2):
        comb.append(t1_list[i])
        comb.append(t2_list[i+1])
    if len(t1_list) > len(t2_list):
        comb += t1_list[i:]
    return ''.join(comb)

def agg_by_avg(i1, i2):
    return torch.div(i1+i2, 2, rounding_mode='trunc')

def agg_by_min_or_max(i1, i2, func_n):
    stacked = torch.stack([i1, i2], axis=0)
    if func_n == 'min':
        return torch.min(stacked, axis=0)[0]
    else:
        return torch.max(stacked, axis=0)[0]

def agg_by_stack(i1, i2):
    return torch.stack([i1, i2])

def mean_confidence_interval(data, confidence=0.95):
    a = 1.0 * np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1)
    return m, h

In [4]:
from functools import partial
from pathos.multiprocessing import ProcessingPool as Pool
import time
# np.random.seed(6)

def non_neural_knn_exp(compressor_name, test_data, test_label, train_data, train_label, agg_func, dis_func, k, para=True):
    print("KNN with compressor={}".format(compressor_name))
    cp = DefaultCompressor(compressor_name)
    knn_exp_ins = KnnExpText(agg_func, cp, dis_func)
    start = time.time()
    if para:
        with Pool(5) as p:
            pred_correct_pair = p.map(partial(knn_exp_ins.combine_dis_acc_single, k, train_data, train_label), test_data, test_label)
        print('accuracy:{}'.format(np.average(np.array(pred_correct_pair, dtype=np.int32)[:,1])))
        # print('accuracy:{}'.format(np.average(np.array(pred_correct_pair, dtype=np.object_)[:, 1])))
    else:
        knn_exp_ins.calc_dis(test_data, train_data=train_data)
        knn_exp_ins.calc_acc(k, test_label, train_label=train_label)
    print("spent: {}".format(time.time() - start))

def record_distance(compressor_name, test_data, test_portion_name, train_data, agg_func, dis_func, out_dir, para=True):
    print("compressor={}".format(compressor_name))
#     numpy_dir = os.path.join(out_dir, compressor_name)
#     if not os.path.exists(numpy_dir):
#         os.makedirs(numpy_dir)
#     out_fn = os.path.join(numpy_dir, test_portion_name)
    cp = DefaultCompressor(compressor_name)
    knn_exp = KnnExpText(agg_func, cp, dis_func)
    start = time.time()
    if para:
        with Pool(6) as p:
            distance_for_selected_test = p.map(partial(knn_exp.calc_dis_single_multi, train_data), test_data)
#         np.save(out_fn, np.array(distance_for_selected_test))
        del distance_for_selected_test
    else:
        knn_exp.calc_dis(test_data, train_data=train_data)
#         np.save(out_fn, np.array(knn_exp.dis_matrix))
    print("spent: {}".format(time.time() - start))
    return np.array(knn_exp.dis_matrix)


def non_neurl_knn_exp_given_dis(dis_matrix, k, test_label, train_label):
    knn_exp = KnnExpText(None, None, None)
    pred, correct = knn_exp.calc_acc(k, test_label, train_label=train_label, provided_distance_matrix=dis_matrix)
    return pred, correct

In [5]:
# We'll create a dictionary to convert option names (A, B, C, D, E) into indices and back again
options = 'ABCDE'
indices = list(range(5))

option_to_index = {option: index for option, index in zip(options, indices)}
index_to_option = {index: option for option, index in zip(options, indices)}

def preprocess(example):
    # The AutoModelForMultipleChoice class expects a set of question/answer pairs
    # so we'll copy our question 5 times before tokenizing
    
    first_sentence = [example['prompt']] * 5
    second_sentence = []
    for option in options:
        second_sentence.append(example[option])
    # Our tokenizer will turn our text into token IDs BERT can understand
    df['text'] = second_sentence
    df['answer'] = option_to_index[example['answer']]
    return df

In [6]:
import pandas as pd
train = pd.read_csv("data/ScienceQA/train.csv")
df = pd.DataFrame()
first = []
secend = []
answer = []
for index, data in train.iterrows():
    for j in options:
        first.append(data['prompt'])
        secend.append(data[j])
        if j == data['answer']:
            answer.append(1)
        else:
            answer.append(0)
            
df['first'] = first
df['secend'] = secend
df['answer'] = answer

In [7]:
test = pd.read_csv("data/ScienceQA/test.csv")

df2 = pd.DataFrame()
first = []
secend = []
answer = []
id_ = []
for index, data in test.iterrows():
    for j in options:
        id_.append(index)
        first.append(data['prompt'])
        secend.append(data[j])
        answer.append(0)
df2['id'] = id_
df2['first'] = first
df2['secend'] = secend
df2['answer'] = answer

In [8]:

# predicts = np.zeros((test.shape[0], 5))
# for i, data in test.iterrows():
#     dis_matrix = record_distance('gzip', data['prompt'], 'test_dis', np.array(data[['A', 'B', 'C', 'D', 'E']]), agg_by_jag_char, NCD, '/kaggle/working/', para=False)
#     predicts[i, :] = dis_matrix[0, :]

predicts = np.zeros((train.shape[0], 5))
for i, data in train.iterrows():
    dis_matrix = record_distance('gzip', data['prompt'], 'train_dis', np.array(data[['A', 'B', 'C', 'D', 'E']]), agg_by_jag_char, NCD, '/kaggle/working/', para=False)
    predicts[i, :] = dis_matrix[0, :]

compressor=gzip


175it [00:00, 7761.97it/s]


spent: 0.024910926818847656
compressor=gzip


92it [00:00, 7448.77it/s]


spent: 0.013438940048217773
compressor=gzip


107it [00:00, 7199.08it/s]


spent: 0.01615619659423828
compressor=gzip


91it [00:00, 7575.15it/s]


spent: 0.013070821762084961
compressor=gzip


182it [00:00, 7733.39it/s]


spent: 0.024556875228881836
compressor=gzip


154it [00:00, 6693.08it/s]


spent: 0.02410125732421875
compressor=gzip


98it [00:00, 7753.90it/s]


spent: 0.013774871826171875
compressor=gzip


126it [00:00, 7490.68it/s]


spent: 0.017940044403076172
compressor=gzip


126it [00:00, 10000.80it/s]


spent: 0.0137176513671875
compressor=gzip


46it [00:00, 7691.98it/s]


spent: 0.007005453109741211
compressor=gzip


79it [00:00, 7146.71it/s]


spent: 0.012115716934204102
compressor=gzip


85it [00:00, 7606.32it/s]


spent: 0.012273073196411133
compressor=gzip


277it [00:00, 7398.54it/s]


spent: 0.03855133056640625
compressor=gzip


24it [00:00, 6843.19it/s]


spent: 0.004607200622558594
compressor=gzip


85it [00:00, 7490.30it/s]


spent: 0.012416601181030273
compressor=gzip


38it [00:00, 8256.93it/s]


spent: 0.005627870559692383
compressor=gzip


67it [00:00, 8880.90it/s]


spent: 0.008536100387573242
compressor=gzip


29it [00:00, 7635.58it/s]


spent: 0.004951953887939453
compressor=gzip


77it [00:00, 8095.29it/s]


spent: 0.010487079620361328
compressor=gzip


44it [00:00, 6889.77it/s]


spent: 0.00734400749206543
compressor=gzip


124it [00:00, 8451.17it/s]


spent: 0.015665054321289062
compressor=gzip


53it [00:00, 8152.34it/s]


spent: 0.007460594177246094
compressor=gzip


40it [00:00, 7635.38it/s]


spent: 0.00621795654296875
compressor=gzip


48it [00:00, 7996.77it/s]


spent: 0.006949663162231445
compressor=gzip


60it [00:00, 9479.01it/s]


spent: 0.007283449172973633
compressor=gzip


50it [00:00, 7765.79it/s]


spent: 0.007394552230834961
compressor=gzip


93it [00:00, 7850.87it/s]


spent: 0.012807369232177734
compressor=gzip


45it [00:00, 7478.55it/s]


spent: 0.006960153579711914
compressor=gzip


106it [00:00, 6287.69it/s]


spent: 0.01782512664794922
compressor=gzip


53it [00:00, 8185.36it/s]


spent: 0.0076885223388671875
compressor=gzip


26it [00:00, 6800.02it/s]


spent: 0.004815101623535156
compressor=gzip


91it [00:00, 9537.75it/s]


spent: 0.010471105575561523
compressor=gzip


41it [00:00, 6855.35it/s]


spent: 0.006977081298828125
compressor=gzip


80it [00:00, 7378.49it/s]


spent: 0.011872529983520508
compressor=gzip


28it [00:00, 7120.63it/s]


spent: 0.0050280094146728516
compressor=gzip


55it [00:00, 8263.31it/s]


spent: 0.007643461227416992
compressor=gzip


66it [00:00, 6652.35it/s]


spent: 0.010897397994995117
compressor=gzip


111it [00:00, 8581.74it/s]


spent: 0.01398921012878418
compressor=gzip


32it [00:00, 7756.01it/s]


spent: 0.0051190853118896484
compressor=gzip


44it [00:00, 7875.62it/s]


spent: 0.006547212600708008
compressor=gzip


38it [00:00, 7112.80it/s]


spent: 0.006325960159301758
compressor=gzip


50it [00:00, 8234.78it/s]


spent: 0.007044076919555664
compressor=gzip


106it [00:00, 9763.83it/s]


spent: 0.011838912963867188
compressor=gzip


52it [00:00, 6745.13it/s]


spent: 0.008678436279296875
compressor=gzip


28it [00:00, 5857.67it/s]


spent: 0.005747079849243164
compressor=gzip


28it [00:00, 6692.53it/s]


spent: 0.0051991939544677734
compressor=gzip


157it [00:00, 9470.95it/s]


spent: 0.017588138580322266
compressor=gzip


116it [00:00, 8950.48it/s]


spent: 0.014070749282836914
compressor=gzip


72it [00:00, 7702.45it/s]


spent: 0.010323286056518555
compressor=gzip


21it [00:00, 6460.82it/s]


spent: 0.004235506057739258
compressor=gzip


55it [00:00, 7944.58it/s]


spent: 0.007894277572631836
compressor=gzip


160it [00:00, 9629.63it/s]


spent: 0.01760077476501465
compressor=gzip


20it [00:00, 7026.81it/s]


spent: 0.003807544708251953
compressor=gzip


71it [00:00, 9058.15it/s]


spent: 0.008807182312011719
compressor=gzip


58it [00:00, 9279.79it/s]


spent: 0.007232189178466797
compressor=gzip


55it [00:00, 8076.42it/s]


spent: 0.0077893733978271484
compressor=gzip


27it [00:00, 8081.51it/s]


spent: 0.0043141841888427734
compressor=gzip


58it [00:00, 7987.31it/s]


spent: 0.008227348327636719
compressor=gzip


23it [00:00, 7168.15it/s]


spent: 0.004160165786743164
compressor=gzip


37it [00:00, 7473.24it/s]


spent: 0.005906105041503906
compressor=gzip


206it [00:00, 8632.24it/s]


spent: 0.024985551834106445
compressor=gzip


172it [00:00, 7880.84it/s]


spent: 0.02275824546813965
compressor=gzip


35it [00:00, 8164.66it/s]


spent: 0.005511760711669922
compressor=gzip


37it [00:00, 7934.82it/s]


spent: 0.005632877349853516
compressor=gzip


142it [00:00, 9882.05it/s]


spent: 0.015326976776123047
compressor=gzip


82it [00:00, 9546.00it/s]


spent: 0.009542703628540039
compressor=gzip


54it [00:00, 8016.30it/s]


spent: 0.007693290710449219
compressor=gzip


69it [00:00, 8775.76it/s]


spent: 0.00882411003112793
compressor=gzip


47it [00:00, 6770.12it/s]


spent: 0.007920265197753906
compressor=gzip


37it [00:00, 7815.74it/s]


spent: 0.005698680877685547
compressor=gzip


58it [00:00, 8710.60it/s]


spent: 0.007608652114868164
compressor=gzip


90it [00:00, 7674.38it/s]


spent: 0.012690544128417969
compressor=gzip


21it [00:00, 8237.20it/s]


spent: 0.0035066604614257812
compressor=gzip


123it [00:00, 7756.95it/s]


spent: 0.01689600944519043
compressor=gzip


50it [00:00, 9065.24it/s]


spent: 0.006413698196411133
compressor=gzip


72it [00:00, 9388.19it/s]


spent: 0.00867152214050293
compressor=gzip


102it [00:00, 8760.50it/s]


spent: 0.012816190719604492
compressor=gzip


144it [00:00, 9253.98it/s]


spent: 0.016741037368774414
compressor=gzip


25it [00:00, 7330.14it/s]


spent: 0.004353523254394531
compressor=gzip


45it [00:00, 8077.36it/s]


spent: 0.006543874740600586
compressor=gzip


57it [00:00, 8263.92it/s]


spent: 0.007855653762817383
compressor=gzip


173it [00:00, 9853.40it/s]


spent: 0.01848578453063965
compressor=gzip


82it [00:00, 8043.33it/s]


spent: 0.011152982711791992
compressor=gzip


50it [00:00, 8319.06it/s]


spent: 0.0069735050201416016
compressor=gzip


27it [00:00, 7944.87it/s]


spent: 0.0043828487396240234
compressor=gzip


153it [00:00, 8302.01it/s]


spent: 0.019387245178222656
compressor=gzip


88it [00:00, 8665.92it/s]


spent: 0.011101007461547852
compressor=gzip


131it [00:00, 9774.85it/s]


spent: 0.014373064041137695
compressor=gzip


39it [00:00, 7438.74it/s]


spent: 0.006205320358276367
compressor=gzip


129it [00:00, 8000.14it/s]


spent: 0.017085552215576172
compressor=gzip


37it [00:00, 7773.07it/s]


spent: 0.005899190902709961
compressor=gzip


24it [00:00, 7543.71it/s]


spent: 0.004117012023925781
compressor=gzip


49it [00:00, 7443.98it/s]


spent: 0.0075418949127197266
compressor=gzip


80it [00:00, 8200.21it/s]


spent: 0.010716438293457031
compressor=gzip


77it [00:00, 8217.64it/s]


spent: 0.010336875915527344
compressor=gzip


39it [00:00, 8043.36it/s]


spent: 0.005823373794555664
compressor=gzip


51it [00:00, 9289.91it/s]


spent: 0.006455183029174805
compressor=gzip


31it [00:00, 7164.22it/s]


spent: 0.005278825759887695
compressor=gzip


35it [00:00, 7603.89it/s]


spent: 0.005550861358642578
compressor=gzip


32it [00:00, 6674.51it/s]


spent: 0.005741119384765625
compressor=gzip


27it [00:00, 8035.07it/s]


spent: 0.004305124282836914
compressor=gzip


87it [00:00, 9562.74it/s]


spent: 0.01008462905883789
compressor=gzip


94it [00:00, 9402.47it/s]


spent: 0.01096487045288086
compressor=gzip


83it [00:00, 8320.84it/s]


spent: 0.010937690734863281
compressor=gzip


65it [00:00, 7509.22it/s]


spent: 0.009654045104980469
compressor=gzip


27it [00:00, 7043.11it/s]


spent: 0.004805564880371094
compressor=gzip


46it [00:00, 8746.06it/s]


spent: 0.00624847412109375
compressor=gzip


30it [00:00, 7860.88it/s]


spent: 0.004766941070556641
compressor=gzip


60it [00:00, 8125.08it/s]


spent: 0.008365392684936523
compressor=gzip


130it [00:00, 8341.13it/s]


spent: 0.01656818389892578
compressor=gzip


86it [00:00, 7498.24it/s]


spent: 0.012479305267333984
compressor=gzip


75it [00:00, 7915.57it/s]


spent: 0.010519266128540039
compressor=gzip


150it [00:00, 10181.50it/s]


spent: 0.015772104263305664
compressor=gzip


47it [00:00, 8649.57it/s]


spent: 0.006406545639038086
compressor=gzip


26it [00:00, 7158.46it/s]


spent: 0.004602909088134766
compressor=gzip


21it [00:00, 7403.58it/s]


spent: 0.0037767887115478516
compressor=gzip


92it [00:00, 7984.03it/s]


spent: 0.012498855590820312
compressor=gzip


34it [00:00, 7024.94it/s]


spent: 0.005797147750854492
compressor=gzip


105it [00:00, 8682.83it/s]


spent: 0.013042688369750977
compressor=gzip


67it [00:00, 7908.66it/s]


spent: 0.009431600570678711
compressor=gzip


29it [00:00, 6924.45it/s]


spent: 0.005162239074707031
compressor=gzip


49it [00:00, 9717.30it/s]


spent: 0.006184816360473633
compressor=gzip


76it [00:00, 10008.07it/s]


spent: 0.008563995361328125
compressor=gzip


30it [00:00, 7623.24it/s]


spent: 0.004877328872680664
compressor=gzip


78it [00:00, 8958.51it/s]


spent: 0.009611368179321289
compressor=gzip


27it [00:00, 7874.16it/s]


spent: 0.0043811798095703125
compressor=gzip


61it [00:00, 8549.22it/s]


spent: 0.008109807968139648
compressor=gzip


38it [00:00, 8384.20it/s]


spent: 0.005683422088623047
compressor=gzip


35it [00:00, 7895.48it/s]


spent: 0.005383968353271484
compressor=gzip


72it [00:00, 8309.44it/s]


spent: 0.009619951248168945
compressor=gzip


80it [00:00, 8136.77it/s]


spent: 0.010988473892211914
compressor=gzip


39it [00:00, 7058.38it/s]


spent: 0.006487607955932617
compressor=gzip


81it [00:00, 9391.01it/s]


spent: 0.009603500366210938
compressor=gzip


29it [00:00, 7419.47it/s]


spent: 0.005140781402587891
compressor=gzip


19it [00:00, 7203.45it/s]


spent: 0.0035948753356933594
compressor=gzip


70it [00:00, 9394.64it/s]


spent: 0.00839686393737793
compressor=gzip


68it [00:00, 8430.01it/s]


spent: 0.009018421173095703
compressor=gzip


79it [00:00, 8503.57it/s]


spent: 0.010301589965820312
compressor=gzip


62it [00:00, 8369.17it/s]


spent: 0.008365392684936523
compressor=gzip


50it [00:00, 8646.62it/s]


spent: 0.006819009780883789
compressor=gzip


69it [00:00, 8256.03it/s]


spent: 0.009663105010986328
compressor=gzip


77it [00:00, 7570.06it/s]


spent: 0.01122903823852539
compressor=gzip


27it [00:00, 7676.15it/s]


spent: 0.004556179046630859
compressor=gzip


57it [00:00, 8669.69it/s]


spent: 0.007592678070068359
compressor=gzip


19it [00:00, 7353.00it/s]


spent: 0.003584146499633789
compressor=gzip


21it [00:00, 7618.09it/s]


spent: 0.003798961639404297
compressor=gzip


34it [00:00, 7627.23it/s]


spent: 0.005477190017700195
compressor=gzip


49it [00:00, 8111.49it/s]


spent: 0.007053852081298828
compressor=gzip


31it [00:00, 7690.51it/s]


spent: 0.0050466060638427734
compressor=gzip


76it [00:00, 7353.51it/s]


spent: 0.011381864547729492
compressor=gzip


101it [00:00, 6733.07it/s]


spent: 0.016039609909057617
compressor=gzip


73it [00:00, 6871.90it/s]


spent: 0.011663198471069336
compressor=gzip


74it [00:00, 7940.51it/s]


spent: 0.010540008544921875
compressor=gzip


55it [00:00, 7568.96it/s]


spent: 0.008296489715576172
compressor=gzip


63it [00:00, 7391.77it/s]


spent: 0.00957632064819336
compressor=gzip


116it [00:00, 7838.30it/s]


spent: 0.016127586364746094
compressor=gzip


26it [00:00, 7375.85it/s]


spent: 0.004568815231323242
compressor=gzip


51it [00:00, 6550.19it/s]


spent: 0.008829116821289062
compressor=gzip


43it [00:00, 6623.88it/s]


spent: 0.0075342655181884766
compressor=gzip


38it [00:00, 7197.92it/s]


spent: 0.0063190460205078125
compressor=gzip


140it [00:00, 7782.88it/s]


spent: 0.019041061401367188
compressor=gzip


55it [00:00, 7294.67it/s]


spent: 0.008585929870605469
compressor=gzip


47it [00:00, 6890.81it/s]


spent: 0.007830381393432617
compressor=gzip


130it [00:00, 6810.81it/s]


spent: 0.02010822296142578
compressor=gzip


43it [00:00, 6933.80it/s]


spent: 0.0072672367095947266
compressor=gzip


30it [00:00, 6821.12it/s]


spent: 0.0054018497467041016
compressor=gzip


58it [00:00, 7293.79it/s]


spent: 0.00896763801574707
compressor=gzip


58it [00:00, 8475.41it/s]


spent: 0.008076906204223633
compressor=gzip


71it [00:00, 7320.98it/s]


spent: 0.010720014572143555
compressor=gzip


93it [00:00, 8211.66it/s]


spent: 0.012341499328613281
compressor=gzip


129it [00:00, 7791.28it/s]


spent: 0.017581939697265625
compressor=gzip


38it [00:00, 7603.45it/s]


spent: 0.006017208099365234
compressor=gzip


84it [00:00, 8493.56it/s]


spent: 0.011089086532592773
compressor=gzip


185it [00:00, 8611.48it/s]


spent: 0.022499799728393555
compressor=gzip


57it [00:00, 8085.06it/s]


spent: 0.008093833923339844
compressor=gzip


48it [00:00, 8036.03it/s]


spent: 0.007004737854003906
compressor=gzip


98it [00:00, 7861.41it/s]


spent: 0.013495445251464844
compressor=gzip


91it [00:00, 7870.37it/s]


spent: 0.012572765350341797
compressor=gzip


97it [00:00, 8206.54it/s]


spent: 0.012830257415771484
compressor=gzip


116it [00:00, 7457.34it/s]


spent: 0.016574382781982422
compressor=gzip


74it [00:00, 7798.45it/s]


spent: 0.010499954223632812
compressor=gzip


102it [00:00, 8455.09it/s]


spent: 0.013077497482299805
compressor=gzip


115it [00:00, 9764.07it/s]


spent: 0.012801170349121094
compressor=gzip


98it [00:00, 6912.22it/s]


spent: 0.015218734741210938
compressor=gzip


108it [00:00, 7002.39it/s]


spent: 0.01644444465637207
compressor=gzip


86it [00:00, 7446.53it/s]


spent: 0.012567520141601562
compressor=gzip


104it [00:00, 8090.65it/s]


spent: 0.013853073120117188
compressor=gzip


80it [00:00, 6522.26it/s]


spent: 0.013296842575073242
compressor=gzip


113it [00:00, 7788.55it/s]


spent: 0.015539169311523438
compressor=gzip


69it [00:00, 8063.95it/s]


spent: 0.009551525115966797
compressor=gzip


109it [00:00, 7098.83it/s]


spent: 0.01636052131652832
compressor=gzip


72it [00:00, 7630.44it/s]


spent: 0.010457515716552734
compressor=gzip


32it [00:00, 6927.72it/s]


spent: 0.00567317008972168
compressor=gzip


69it [00:00, 7732.37it/s]


spent: 0.010024785995483398
compressor=gzip


49it [00:00, 7795.81it/s]


spent: 0.00731658935546875
compressor=gzip


128it [00:00, 7975.38it/s]


spent: 0.01707315444946289
compressor=gzip


56it [00:00, 6700.74it/s]


spent: 0.009395360946655273
compressor=gzip


123it [00:00, 7769.92it/s]


spent: 0.016880273818969727
compressor=gzip


111it [00:00, 7216.09it/s]


spent: 0.01640915870666504
compressor=gzip


88it [00:00, 7734.51it/s]

spent: 0.01241755485534668





In [9]:
def predictions_to_map_output(predictions):
    sorted_answer_indices = np.argsort(predictions)
    top_answer_indices = sorted_answer_indices[:,:3] # Get the first three answers in each row
    top_answers = np.vectorize(index_to_option.get)(top_answer_indices)
    return np.apply_along_axis(lambda row: ' '.join(row), 1, top_answers)

In [10]:
train['prediction'] = predictions_to_map_output(predicts)

In [11]:
train['prediction'].tolist()

['A C E',
 'E B C',
 'A D E',
 'E A B',
 'D B A',
 'E D B',
 'E A C',
 'C A D',
 'C D E',
 'D B E',
 'A D C',
 'C E B',
 'B D E',
 'C A B',
 'E D C',
 'B C A',
 'D E B',
 'C D B',
 'C B E',
 'C A B',
 'A E C',
 'A C B',
 'D E B',
 'A D E',
 'B C D',
 'C D A',
 'D B C',
 'A B D',
 'D C A',
 'D A E',
 'A C D',
 'E B C',
 'C A D',
 'B C E',
 'A B D',
 'C A E',
 'A E B',
 'B E C',
 'B C D',
 'B D E',
 'C A B',
 'D A B',
 'A B C',
 'A E B',
 'E C B',
 'E D C',
 'E D A',
 'B A D',
 'C D A',
 'E D C',
 'C A D',
 'D C E',
 'D B E',
 'B E A',
 'D E A',
 'D A C',
 'E B D',
 'E D A',
 'D E B',
 'E D A',
 'E D A',
 'A D B',
 'B E C',
 'B D C',
 'E A C',
 'A B D',
 'B A D',
 'B C D',
 'A B C',
 'B A E',
 'B A C',
 'A C E',
 'A C B',
 'A C D',
 'B D A',
 'A B C',
 'C D E',
 'B C D',
 'A D B',
 'B C D',
 'A B E',
 'E A C',
 'A C B',
 'E B C',
 'C D E',
 'B C D',
 'A B C',
 'C E D',
 'B D C',
 'D A B',
 'E A C',
 'E A C',
 'B C D',
 'B C D',
 'E A C',
 'A D E',
 'A B D',
 'A C E',
 'A E B',
 'A D E',


In [12]:
predictions = train['prediction']

In [13]:
predictions.str.split(" ")

0      [A, C, E]
1      [E, B, C]
2      [A, D, E]
3      [E, A, B]
4      [D, B, A]
         ...    
195    [D, B, A]
196    [A, D, B]
197    [B, A, C]
198    [E, A, D]
199    [C, D, B]
Name: prediction, Length: 200, dtype: object

In [14]:
def average_percision(lst, ans):
    res = []
    for i in range(len(lst)):
        if lst[i] == ans:
            res.append(1/(i+1))
            # Once a correct label has been scored for an individual question in the test set, 
            # that label is no longer considered relevant for that question, 
            # and additional predictions of that label are skipped in the calculation.
            return np.sum(res)
        else:
            res.append(0)
    return 0

In [15]:
maps = []

for pred, ans in train[['prediction', 'answer']].to_numpy():
    pred = pred.split(" ")
    maps.append(average_percision(pred, ans))


In [16]:
train[['prediction', 'answer']]

Unnamed: 0,prediction,answer
0,A C E,D
1,E B C,A
2,A D E,A
3,E A B,C
4,D B A,D
...,...,...
195,D B A,C
196,A D B,B
197,B A C,B
198,E A D,D


In [17]:
np.mean(maps)

0.21333333333333332

In [18]:
maps

[0,
 0,
 1.0,
 0,
 1.0,
 0.3333333333333333,
 0.5,
 0.3333333333333333,
 1.0,
 0,
 0,
 0,
 0,
 0,
 0,
 1.0,
 0.5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0.3333333333333333,
 0,
 0,
 0,
 1.0,
 0,
 0,
 0,
 0,
 0.5,
 0,
 0,
 0.3333333333333333,
 0,
 0,
 0.5,
 0,
 0,
 0,
 0,
 0,
 0.5,
 0,
 0,
 0.3333333333333333,
 0,
 0.3333333333333333,
 0,
 0,
 0,
 0,
 1.0,
 0.3333333333333333,
 0,
 0.3333333333333333,
 0.3333333333333333,
 0.3333333333333333,
 0.5,
 0,
 0,
 0,
 0.3333333333333333,
 0,
 0.3333333333333333,
 1.0,
 0,
 0,
 0.5,
 0.5,
 0.5,
 1.0,
 0,
 0,
 0,
 0.5,
 0.3333333333333333,
 0,
 1.0,
 0.3333333333333333,
 0,
 0,
 0,
 1.0,
 0,
 0,
 1.0,
 0,
 1.0,
 0,
 0,
 0,
 0,
 0.5,
 0,
 0.5,
 0,
 0,
 0.5,
 0,
 0,
 0,
 0.3333333333333333,
 0,
 0.3333333333333333,
 0,
 1.0,
 0.5,
 0,
 0,
 1.0,
 0,
 0.5,
 0.3333333333333333,
 0.3333333333333333,
 0,
 0.3333333333333333,
 0,
 1.0,
 0,
 0,
 0,
 0,
 0,
 1.0,
 0.3333333333333333,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0.5,
 0,
 0,
 0.5,
 0,
 0,
 0.33333333333