In [31]:
import pickle
from collections import Counter
import os
from collections import OrderedDict, defaultdict
from cdl import *
import numpy as np
import random
from tqdm import tqdm

# Reference scores indexed by the number of alternatives n (n = 6..30).
# StaticFeature5.score_function rejects a TRS whose score falls below
# ``threshold * Fishburn_scores[n]``.
# NOTE(review): values presumably are the score of Fishburn's alternating-scheme
# domain under the same scoring — confirm provenance before extending the table.
Fishburn_scores = dict(zip(
    range(6, 31),
    (
        22, 72, 184, 400, 780,                       # n = 6..10
        1400, 2360, 3780, 5810, 8624,                # n = 11..15
        12432, 17472, 24024, 32400, 42960,           # n = 16..20
        56100, 72270, 91960, 115720, 144144,         # n = 21..25
        177892, 217672, 264264, 318500, 381290,      # n = 26..30
    ),
))


class StaticFeature5:
    """Static scorer built from a precomputed lookup table over 5-element subsets.

    A TRS is mapped (via ``cd.subset_states``) to one state per 5-subset; each
    state is looked up in ``dataset_5``, a pickled dict loaded from
    ``<lib_path>/python/databases/``. ``score_function`` folds the looked-up
    values into a single integer score, or -1 when the TRS is rejected.

    NOTE(review): the table values are presumably sizes of the restricted
    Condorcet domains on 5 alternatives — inferred from the weights used in
    ``score_function``; confirm against the database builder.
    """

    # Weight of each looked-up value in the score; values not listed weigh 0.
    _SIZE_WEIGHTS = {16: 0, 17: 1, 18: 2, 19: 3, 20: 4}

    def __init__(self, cd, rules, lib_path="~/cdl", max_rules=6):
        """Prepare ``cd`` and load the smallest database that covers ``rules``.

        Args:
            cd: domain object; ``init_trs()`` and ``init_subset(5)`` are called
                on it before any lookup.
            rules: iterable of rule names that the database must contain.
            lib_path: root of the cdl checkout; ``~`` is expanded.
            max_rules: databases with more rules than this are ignored.

        Raises:
            FileNotFoundError: if no ``.pkl`` database in the folder covers
                ``rules`` within ``max_rules``.
        """
        self.cd = cd
        # NOTE(review): return value is discarded; presumably called for its
        # side effect on ``cd`` — confirm it is needed.
        self.cd.init_trs()
        self.cd.init_subset(5)
        # os.listdir/open do not expand "~", so expand it here; the trailing
        # "" keeps the folder path ending in a separator.
        self.database_folder = os.path.join(
            os.path.expanduser(lib_path), "python", "databases", "")

        self.database_filename = self._select_database(rules, max_rules)
        with open(os.path.join(self.database_folder, self.database_filename), "rb") as f:
            # pickle can execute arbitrary code on load: lib_path must point at
            # trusted, locally built databases only.
            self.dataset_5 = pickle.load(f)

    def _select_database(self, rules, max_rules):
        """Return the file name of the smallest ``.pkl`` database covering ``rules``.

        File stems are split on ``_``; everything after the first two fields is
        taken as the database's rule list. Among databases whose rules include
        every rule in ``rules`` and that have at most ``max_rules`` rules, the
        one with the fewest rules wins; ties go to the alphabetically first
        file name so the choice does not depend on directory listing order.
        """
        required = set(rules)
        candidates = []
        for filename in os.listdir(self.database_folder):
            # splitext tolerates names with no dot or several dots.
            stem, extension = os.path.splitext(filename)
            if extension != ".pkl":
                continue
            database_rules = stem.split("_")[2:]
            if required.issubset(database_rules) and len(database_rules) <= max_rules:
                candidates.append((len(database_rules), filename))

        if not candidates:
            raise FileNotFoundError(
                f"no .pkl database in {self.database_folder!r} covers rules "
                f"{sorted(required)} with at most {max_rules} rules")
        return min(candidates)[1]

    def fetch_feature(self, trs):
        """Return the ``dataset_5`` value of every 5-subset state of ``trs``.

        Values are in the order produced by ``cd.subset_states(trs)``; a state
        missing from the table raises KeyError.
        """
        return [self.dataset_5[tuple(state)] for state in self.cd.subset_states(trs)]

    def score_function(self, trs, cutoff=16, threshold=0):
        """Score ``trs`` from its 5-subset values.

        Args:
            trs: the TRS to score.
            cutoff: reject (return -1) if any 5-subset value is below this.
            threshold: when positive, reject if the score is below
                ``threshold * Fishburn_scores[cd.n]``.

        Returns:
            The sum of ``_SIZE_WEIGHTS`` over all 5-subset values (a
            non-negative int), or -1 if the TRS is rejected.

        Raises:
            KeyError: if ``threshold > 0`` and ``cd.n`` has no reference score.
        """
        sizes = self.fetch_feature(trs)

        if min(sizes) < cutoff:
            return -1

        score = sum(self._SIZE_WEIGHTS.get(size, 0) for size in sizes)

        # The score is never negative, so a non-positive threshold can never
        # reject; skipping the lookup keeps n outside the reference table usable.
        if threshold > 0:
            reference = Fishburn_scores.get(self.cd.n)
            if reference is None:
                raise KeyError(
                    f"no Fishburn reference score for n={self.cd.n}; use threshold=0")
            if score < threshold * reference:
                return -1

        return score
    

class Search:
    """Base class for searches over rule assignments of a Condorcet domain.

    Holds the domain object ``cd``, a ``StaticFeature5`` scorer for ``rules``,
    and the output folder ``<result_path>/<n>/``. Provides one-step expansion
    of a TRS plus pickle persistence of ``(trs, score)`` lists.
    """

    def __init__(self, cd, rules, lib_path, result_path):
        self.cd = cd
        self.sf = StaticFeature5(cd, rules=rules, lib_path=lib_path)
        self.rules = rules
        self.lib_path = lib_path
        self.result_path = result_path
        # Trailing slash kept: subclasses may append a sub-folder name to it.
        self.folder_path = f"{result_path}/{cd.n}/"

    def expand_trs(self,
                   trs,
                   cutoff=16,
                   threshold=0):
        """Branch ``trs`` on its first unassigned triple, one child per rule.

        Args:
            trs: parent TRS; must have at least one unassigned triple
                (otherwise IndexError is raised).
            cutoff, threshold: forwarded to ``self.sf.score_function``.

        Returns:
            List of ``(child_trs, score)`` for the children whose score is
            greater than -1, in ``self.rules`` order.
        """
        triple = self.cd.unassigned_triples(trs)[0]
        children = []
        for rule in self.rules:
            # Always branch from the parent: rebinding ``trs`` here would make
            # each later rule build on the previous child.
            child = self.cd.assign_rule(trs, triple, rule)
            score = self.sf.score_function(child, cutoff, threshold)
            if score > -1:
                children.append((child, score))

        return children

    def save_trs_score_list(self,
                            trs_list,
                            sub_folder_name,
                            filename):
        """Pickle ``trs_list`` as ``(state, score)`` pairs.

        The file is ``<folder_path>/<sub_folder_name>/<filename>``; the
        sub-folder is created if needed. Each TRS is stored via
        ``cd.trs_to_state``.
        """
        sub_folder_path = os.path.join(self.folder_path, sub_folder_name)
        os.makedirs(sub_folder_path, exist_ok=True)

        # Convert before opening so a failing conversion cannot leave a
        # truncated pickle behind.
        state_score_list = [(self.cd.trs_to_state(trs), score) for trs, score in trs_list]
        with open(os.path.join(sub_folder_path, filename), "wb") as f:
            pickle.dump(state_score_list, f)

    def load_trs_score_list(self,
                            sub_folder_name,
                            filename):
        """Inverse of ``save_trs_score_list``: return ``[(trs, score), ...]``.

        States are turned back into TRS objects via ``cd.state_to_trs``.
        """
        path = os.path.join(self.folder_path, sub_folder_name, filename)
        with open(path, "rb") as f:
            state_score_list = pickle.load(f)

        return [(self.cd.state_to_trs(state), score) for state, score in state_score_list]

    def _final_sub_folder_name(self):
        """Name of the sub-folder that holds fully assigned results."""
        return f"{self.cd.num_triples}_{self.cd.num_triples}"

    def _iter_final_trs_scores(self):
        """Yield ``(trs, score)`` from every file of the final sub-folder.

        Files are read in sorted name order so results do not depend on
        directory listing order.
        """
        sub_folder_name = self._final_sub_folder_name()
        for filename in sorted(os.listdir(os.path.join(self.folder_path, sub_folder_name))):
            yield from self.load_trs_score_list(sub_folder_name, filename)

    def get_size_counter(self):
        """Count final results by ``cd.size`` and pickle the per-TRS details.

        Writes ``[(trs, score, size), ...]`` to
        ``<folder_path>/trs_score_size.pkl``.

        Returns:
            OrderedDict mapping size -> number of final TRSs, ascending by size.
        """
        sizes = []
        trs_score_size_list = []
        for trs, score in self._iter_final_trs_scores():
            size = self.cd.size(trs)
            sizes.append(size)
            trs_score_size_list.append((trs, score, size))

        with open(os.path.join(self.folder_path, "trs_score_size.pkl"), "wb") as f:
            pickle.dump(trs_score_size_list, f)

        return OrderedDict(sorted(Counter(sizes).items()))

    def save_result_as_dict(self):
        """Pickle final results grouped by score.

        Writes a ``defaultdict(list)`` mapping score -> list of states
        (``cd.trs_to_state``) to ``<folder_path>/result_dict.pkl``.
        """
        score_states_dict = defaultdict(list)
        for trs, score in self._iter_final_trs_scores():
            score_states_dict[score].append(self.cd.trs_to_state(trs))

        with open(os.path.join(self.folder_path, "result_dict.pkl"), "wb") as f:
            pickle.dump(score_states_dict, f)

In [38]:
class ExhaustiveSearch(Search):
    """Beam search that assigns rules to triples one at a time.

    Each step extends every surviving TRS by assigning each rule in
    ``self.rules`` to its next unassigned triple (``expand_trs``), drops the
    children the scorer rejects, and keeps at most ``top_n`` of the
    highest-scoring ones.
    """

    def static_search(self,
                      trs,
                      cutoff=16,
                      threshold=0,
                      top_n=1000):
        """Expand ``trs`` until every triple is assigned and save the survivors.

        Args:
            trs: starting TRS; may already have some triples assigned.
            cutoff, threshold: forwarded unchanged to every ``expand_trs`` call.
            top_n: beam width kept after each step; -1 (any negative value)
                keeps every surviving child.

        Side effects:
            Sets ``self.folder_path`` to
            ``<result_path>/<n>/<cutoff>_<threshold>_<top_n>_<num_final>_<rules>``
            (later read by ``get_size_counter``/``save_result_as_dict``) and
            pickles the final beam to
            ``<folder_path>/<num_final>_<num_triples>/0.pkl``.

        Raises:
            ValueError: if ``trs`` has no unassigned triple.
        """
        num_assigned = len(self.cd.assigned_triples(trs))
        # Only the still-unassigned triples need expanding; one step per triple.
        n_complete = len(self.cd.unassigned_triples(trs))
        if n_complete == 0:
            raise ValueError("static_search: trs has no unassigned triples to expand")
        num_final = num_assigned + n_complete

        folder_name = f"{cutoff}_{threshold}_{top_n}_{num_final}_" + "_".join(self.rules)
        # Rebuild from the base path rather than appending, so calling this
        # method again does not nest folder names.
        self.folder_path = f"{self.result_path}/{self.cd.n}/" + folder_name

        beam = self._keep_top(self.expand_trs(trs, cutoff, threshold), top_n)

        for n_iter in range(num_assigned + 2, num_final + 1):
            candidates = []
            for parent, _ in tqdm(beam, ascii=True, desc=f'{n_iter}'):
                candidates.extend(self.expand_trs(parent, cutoff, threshold))
            beam = self._keep_top(candidates, top_n)

        self.save_trs_score_list(trs_list=beam,
                                 sub_folder_name=f"{num_final}_{self.cd.num_triples}",
                                 filename="0.pkl")

    @staticmethod
    def _keep_top(candidates, top_n):
        """Return the ``top_n`` highest-scoring ``(trs, score)`` pairs.

        A negative ``top_n`` disables truncation and returns ``candidates``
        unchanged. Otherwise the result is ascending by score; the sort is
        stable, so among equal scores the later-generated children are kept.
        """
        if top_n < 0:
            return candidates
        ranked = sorted(candidates, key=lambda item: item[1])
        # max(..., 0) instead of ``[-top_n:]`` so top_n == 0 keeps nothing
        # rather than everything.
        return ranked[max(len(ranked) - top_n, 0):]

In [41]:
# Root of the local cdl checkout (its python/databases/ folder holds the
# 5-subset tables). Override via the CDL_LIB_PATH environment variable rather
# than editing the absolute path below.
CDL_LIB_PATH = os.environ.get("CDL_LIB_PATH", "/Users/bei/CLionProjects/cdl")

cd = CondorcetDomain(n=7)
es = ExhaustiveSearch(cd,
                      rules=["2N1", "2N3", "1N2", "3N2"],
                      lib_path=CDL_LIB_PATH,
                      result_path="./results")
trs = cd.init_trs()

es.static_search(trs,
                 cutoff=16,
                 threshold=0,
                 top_n=1000)  # top_n=-1 disables the top-n truncation.

2: 100%|##########| 4/4 [00:00<00:00, 3499.63it/s]
3: 100%|##########| 15/15 [00:00<00:00, 1372.12it/s]
4: 100%|##########| 32/32 [00:00<00:00, 1889.57it/s]
5: 100%|##########| 65/65 [00:00<00:00, 2763.38it/s]
6: 100%|##########| 108/108 [00:00<00:00, 2941.35it/s]
7: 100%|##########| 176/176 [00:00<00:00, 3261.08it/s]
8: 100%|##########| 299/299 [00:00<00:00, 3637.12it/s]
9: 100%|##########| 456/456 [00:00<00:00, 3693.83it/s]
10: 100%|##########| 658/658 [00:00<00:00, 3127.85it/s]
11: 100%|##########| 1000/1000 [00:00<00:00, 3764.45it/s]
12: 100%|##########| 1000/1000 [00:00<00:00, 3905.41it/s]
13: 100%|##########| 1000/1000 [00:00<00:00, 4117.00it/s]
14: 100%|##########| 1000/1000 [00:00<00:00, 3941.76it/s]
15: 100%|##########| 1000/1000 [00:00<00:00, 4135.41it/s]
16: 100%|##########| 1000/1000 [00:00<00:00, 3429.97it/s]
17: 100%|##########| 1000/1000 [00:00<00:00, 3998.99it/s]
18: 100%|##########| 1000/1000 [00:00<00:00, 3428.32it/s]
19: 100%|##########| 1000/1000 [00:00<00:00, 2986.

In [42]:
# Distribution of final domain sizes (size -> count), ascending by size.
size_counter = es.get_size_counter()
size_counter

OrderedDict([(76, 2),
             (80, 20),
             (81, 4),
             (82, 12),
             (83, 28),
             (84, 70),
             (85, 26),
             (86, 184),
             (87, 56),
             (88, 200),
             (89, 118),
             (90, 102),
             (91, 40),
             (92, 82),
             (93, 8),
             (94, 18),
             (95, 4),
             (96, 16),
             (97, 8),
             (100, 2)])