In [1]:
import pickle
from collections import Counter
import os
from collections import OrderedDict, defaultdict
from cdl import *
import numpy as np
import random
from tqdm import tqdm


class StaticFeature5:
    """Score rule assignments by looking up values of their 5-element restrictions.

    A pickled lookup table (a dict keyed by ``tuple(state)``) is loaded from
    ``<lib_path>/python/databases/``. ``fetch_feature`` returns one table value per
    state produced by ``cd.subset_states``, and ``score_function`` aggregates them.
    """

    # Values that contribute to the score, each weighted by its own magnitude.
    # Any other value (e.g. 16) contributes 0.
    _SCORED_SIZES = (17, 18, 19, 20)

    def __init__(self, cd, rules, lib_path="~/cdl"):
        """
        Args:
            cd: domain object; must provide ``init_trs``, ``init_subset`` and
                ``subset_states``.
            rules: rule names (e.g. ``"2N1"``); the chosen database must list all of them.
            lib_path: root of the cdl checkout. A leading ``~`` is expanded.

        Raises:
            FileNotFoundError: if no ``.pkl`` database in the folder covers ``rules``.
        """
        self.cd = cd
        self.cd.init_trs()
        self.cd.init_subset(5)
        # os.listdir/open do not expand "~" themselves; without this the default path fails.
        self.database_folder = f'{os.path.expanduser(lib_path)}/python/databases/'

        database_path = self._find_database(rules)
        # NOTE: pickle.load can execute arbitrary code; only point lib_path at trusted databases.
        with open(database_path, "rb") as f:
            self.dataset_5 = pickle.load(f)

    def _find_database(self, rules):
        """Return the path of the covering database with the fewest rules.

        Each ``.pkl`` file stem is split on ``"_"``; everything after the first two
        fields is taken as that database's rule list. A database qualifies when its
        rule list contains every entry of ``rules``. On ties, the file listed last wins.
        """
        wanted = set(rules)
        best_filename = None
        best_num_rules = float("inf")
        for filename in os.listdir(self.database_folder):
            # splitext tolerates names with no dot or several dots (e.g. "README", "a.b.pkl").
            name, extension = os.path.splitext(filename)
            if extension != ".pkl":
                continue
            database_rules = name.split("_")[2:]
            if wanted.issubset(database_rules) and len(database_rules) <= best_num_rules:
                best_num_rules = len(database_rules)
                best_filename = filename

        if best_filename is None:
            raise FileNotFoundError(
                f"no database in {self.database_folder!r} covers rules {sorted(wanted)}"
            )
        return self.database_folder + best_filename

    def fetch_feature(self, trs):
        """Return the table value for each 5-subset state of ``trs``, in ``subset_states`` order.

        Raises:
            KeyError: if a state is missing from the loaded table.
        """
        return [self.dataset_5[tuple(state)] for state in self.cd.subset_states(trs)]

    def score_function(self, trs, cutoff=16, threshold=0):
        """Score ``trs``; larger is better.

        Returns -1 if any subset value is below ``cutoff``. Otherwise returns the sum
        of the values that appear in ``_SCORED_SIZES``; all other values add 0.
        ``threshold`` is currently unused and is kept for interface compatibility.
        """
        sizes = self.fetch_feature(trs)

        if min(sizes) < cutoff:
            return -1

        counter = Counter(sizes)
        return sum(size * counter.get(size, 0) for size in self._SCORED_SIZES)
    

class Search:
    """Shared machinery for searches over rule assignments (``trs``) of a domain ``cd``.

    Provides one-step expansion of a partial assignment, plus (de)serialisation of
    ``(trs, score)`` lists under ``self.folder_path``.
    """

    def __init__(self, cd, rules, lib_path, result_path):
        self.cd = cd
        self.sf = StaticFeature5(cd, rules=rules, lib_path=lib_path)
        self.rules = rules
        self.lib_path = lib_path
        self.result_path = result_path
        self.folder_path = f"{result_path}/{cd.n}/"

    def expand_trs(self,
                   trs,
                   cutoff=16,
                   threshold=0):
        """Branch on the first unassigned triple of ``trs``, once per rule in ``self.rules``.

        For a triple ``(i, j, k)`` and rule ``"aNb"``, each candidate assigns ``"aNb"``
        to ``(i, j, k)`` and the mirrored rule ``"(4-a)N(4-b)"`` to the mirrored
        triple ``(n+1-k, n+1-j, n+1-i)``. Candidates scored -1 by
        ``self.sf.score_function`` are dropped.

        Returns:
            list of ``(trs, score)`` pairs, in ``self.rules`` order.
        """
        triple = self.cd.unassigned_triples(trs)[0]
        (i, j, k) = triple
        n = self.cd.n
        dual_triple = [n + 1 - k, n + 1 - j, n + 1 - i]

        trs_value_list = []
        for rule in self.rules:
            first, second = rule.split("N")
            dual_rule = f"{4 - int(first)}N{4 - int(second)}"
            # Branch from the incoming ``trs`` each time (do not rebind it), so every
            # candidate differs from the input only in this triple and its mirror.
            candidate = self.cd.assign_rule(trs, triple, rule)
            candidate = self.cd.assign_rule(candidate, dual_triple, dual_rule)
            value = self.sf.score_function(candidate, cutoff, threshold)
            if value > -1:
                trs_value_list.append((candidate, value))
        return trs_value_list

    def save_trs_score_list(self,
                            trs_list,
                            sub_folder_name,
                            filename):
        """Pickle ``[(state, score), ...]`` for ``trs_list`` to ``<folder_path>/<sub_folder_name>/<filename>``."""
        sub_folder_path = os.path.join(self.folder_path, sub_folder_name)
        os.makedirs(sub_folder_path, exist_ok=True)

        # Convert before opening the file so a failure cannot leave a truncated pickle behind.
        trs_score_list = [(self.cd.trs_to_state(trs), score) for trs, score in trs_list]
        with open(os.path.join(sub_folder_path, filename), "wb") as f:
            pickle.dump(trs_score_list, f)

    def load_trs_score_list(self,
                            sub_folder_name,
                            filename):
        """Inverse of ``save_trs_score_list``: return ``[(trs, score), ...]``."""
        path = os.path.join(self.folder_path, sub_folder_name, filename)
        # NOTE: pickle.load can execute arbitrary code; only load files this class wrote.
        with open(path, "rb") as f:
            state_score_list = pickle.load(f)
        return [(self.cd.state_to_trs(state), score) for state, score in state_score_list]

    def get_size_counter(self):
        """Count ``cd.size`` over every final result.

        Also pickles ``[(trs, score, size), ...]`` to ``<folder_path>/trs_score_size.pkl``.

        Returns:
            OrderedDict mapping size -> count, in ascending size order.
        """
        trs_score_size_list = [
            (trs, score, self.cd.size(trs)) for trs, score in self._iter_final_trs_scores()
        ]
        size_counts = Counter(size for _, _, size in trs_score_size_list)

        with open(os.path.join(self.folder_path, "trs_score_size.pkl"), "wb") as f:
            pickle.dump(trs_score_size_list, f)

        return OrderedDict(sorted(size_counts.items()))

    def save_result_as_dict(self):
        """Group final result states by score; pickle the mapping to ``<folder_path>/result_dict.pkl``."""
        score_states_dict = defaultdict(list)
        for trs, score in self._iter_final_trs_scores():
            score_states_dict[score].append(self.cd.trs_to_state(trs))

        with open(os.path.join(self.folder_path, "result_dict.pkl"), "wb") as f:
            pickle.dump(score_states_dict, f)

    def _final_sub_folder_name(self):
        """Name of the sub-folder holding fully assigned results."""
        return f"{self.cd.num_triples}_{self.cd.num_triples}"

    def _iter_final_trs_scores(self):
        """Yield ``(trs, score)`` from every ``.pkl`` file in the final-results sub-folder.

        Files are visited in sorted order for determinism. Non-pickle files (e.g.
        ``.DS_Store``) are skipped.
        """
        sub_folder_name = self._final_sub_folder_name()
        sub_folder_path = os.path.join(self.folder_path, sub_folder_name)
        for filename in sorted(os.listdir(sub_folder_path)):
            if not filename.endswith(".pkl"):
                continue
            yield from self.load_trs_score_list(sub_folder_name, filename)

In [2]:
class ExhaustiveSearch(Search):
    """Level-by-level search that expands every surviving candidate by one triple.

    All candidates are kept for the first ``n_complete`` levels. After that, only
    the ``top_n`` highest-scoring candidates are kept at each level.
    """

    def __init__(self, cd, rules, lib_path, result_path):
        super().__init__(cd, rules, lib_path, result_path)

    def static_search(self,
                      trs,
                      cutoff=16,
                      threshold=0,
                      n_complete=5,
                      top_n=1000):
        """Run the search from ``trs`` until no unassigned triples remain.

        Args:
            trs: starting (partial) assignment.
            cutoff: forwarded to ``expand_trs``; candidates scored -1 are dropped.
            threshold: forwarded to ``expand_trs``; also part of the output folder name.
            n_complete: number of initial levels that keep every candidate.
            top_n: number of candidates kept per level afterwards; any negative value
                disables pruning.

        Side effects:
            Sets ``self.folder_path`` and saves the final ``(trs, score)`` list to
            ``<folder_path>/<num_triples>_<num_triples>/0.pkl`` (possibly empty if
            every candidate was pruned).
        """
        num_assigned = len(self.cd.assigned_triples(trs))

        folder_name = (f"{cutoff}_{threshold}_{top_n}_{num_assigned + n_complete}_"
                       + "_".join(self.rules))
        # Rebuild from the base path rather than appending, so calling this twice on
        # the same instance does not nest one run's folder inside the other.
        self.folder_path = f"{self.result_path}/{self.cd.n}/" + folder_name

        # Level 1: honour cutoff/threshold here too (they were previously ignored).
        trs_score_list = self.expand_trs(trs, cutoff, threshold)

        n_iter = 2
        # Stop early if every candidate was pruned instead of indexing an empty list.
        while trs_score_list and len(self.cd.unassigned_triples(trs_score_list[0][0])) > 0:
            next_trs_score_list = []
            for candidate, _ in tqdm(trs_score_list, ascii=True, desc=f'{n_iter}'):
                next_trs_score_list.extend(self.expand_trs(candidate, cutoff, threshold))

            if top_n < 0 or n_iter <= n_complete:
                trs_score_list = next_trs_score_list
            elif top_n == 0:
                # ``lst[-0:]`` would return the whole list, not zero items.
                trs_score_list = []
            else:
                # Stable ascending sort; keep the last ``top_n`` (highest scores).
                next_trs_score_list.sort(key=lambda trs_score: trs_score[1])
                trs_score_list = next_trs_score_list[-top_n:]

            n_iter += 1

        self.save_trs_score_list(trs_list=trs_score_list,
                                 sub_folder_name=f"{self.cd.num_triples}_{self.cd.num_triples}",
                                 filename="0.pkl")

In [5]:
# Path to a local cdl checkout (must contain python/databases/*.pkl).
# Override with the CDL_LIB_PATH environment variable instead of editing this cell.
CDL_LIB_PATH = os.environ.get("CDL_LIB_PATH", os.path.expanduser("~/CLionProjects/cdl"))
RESULT_PATH = "./results"
RULES = ["2N1", "2N3", "1N3", "3N1"]

cd = CondorcetDomain(n=10)
es = ExhaustiveSearch(cd,
                      rules=RULES,
                      lib_path=CDL_LIB_PATH,
                      result_path=RESULT_PATH)
trs = cd.init_trs()

# Score of the Fishburn-scheme assignment. It is passed as `threshold`, which ends up
# in the results folder name (StaticFeature5.score_function does not use it itself).
fishburn_score = es.sf.score_function(cd.init_trs_by_scheme(Fishburn_scheme))

es.static_search(trs,
                 cutoff=16,
                 threshold=fishburn_score,
                 n_complete=5,
                 top_n=10000)  # top_n=-1 disables top-n pruning.

2: 100%|##########| 4/4 [00:00<00:00, 420.83it/s]
3: 100%|##########| 16/16 [00:00<00:00, 274.40it/s]
4: 100%|##########| 52/52 [00:00<00:00, 320.85it/s]
5: 100%|##########| 167/167 [00:00<00:00, 375.20it/s]
6: 100%|##########| 406/406 [00:01<00:00, 381.22it/s]
7: 100%|##########| 876/876 [00:02<00:00, 358.15it/s]
8: 100%|##########| 2687/2687 [00:08<00:00, 329.17it/s]
9: 100%|##########| 6177/6177 [00:22<00:00, 280.58it/s]
10: 100%|##########| 10000/10000 [00:35<00:00, 281.28it/s]
11: 100%|##########| 10000/10000 [00:35<00:00, 284.26it/s]
12: 100%|##########| 10000/10000 [00:38<00:00, 257.70it/s]
13: 100%|##########| 10000/10000 [00:31<00:00, 315.83it/s]
14: 100%|##########| 10000/10000 [00:32<00:00, 308.28it/s]
15: 100%|##########| 10000/10000 [00:35<00:00, 282.22it/s]
16: 100%|##########| 10000/10000 [00:31<00:00, 313.74it/s]
17: 100%|##########| 10000/10000 [00:35<00:00, 284.94it/s]
18: 100%|##########| 10000/10000 [00:34<00:00, 290.26it/s]
19: 100%|##########| 10000/10000 [00:34<0

In [None]:
# Histogram of domain sizes over the final candidates (ascending by size);
# get_size_counter also pickles the (trs, score, size) triples to trs_score_size.pkl.
size_counter = es.get_size_counter()
size_counter