In [None]:
import pickle
from cdl import *
from collections import defaultdict, Counter, OrderedDict
from tqdm import tqdm
import os
import shutil
import random


class StaticFeature5:
    """Static scoring of a TRS from a precomputed, pickled lookup table.

    On construction the wrapped domain object is prepared with
    ``init_trs()`` and ``init_subset(5)``, and a table mapping
    ``tuple(state)`` to a size value is loaded from disk.
    """

    # Contribution of each table value to the score; any value that is not
    # listed here contributes nothing.
    _SIZE_WEIGHTS = {16: 0, 17: 1, 18: 2, 19: 3, 20: 4}

    def __init__(self, cd, n_rules=4):
        """Prepare ``cd`` and load the lookup table.

        Args:
            cd: Domain object providing ``init_trs``, ``init_subset`` and
                ``subset_states``.
            n_rules: Number of rules in use; more than four selects the
                "six rules" table.
        """
        self.cd = cd
        self.cd.init_trs()
        self.cd.init_subset(5)

        # NOTE(review): the two tables live in differently named directories
        # ("databases" vs "database") - confirm that both paths exist.
        if n_rules <= 4:
            table_path = '../databases/database_5.pkl'
        else:
            table_path = '../database/database_5_six_rules.pkl'

        with open(table_path, "rb") as f:
            self.dataset_5 = pickle.load(f)

    def fetch_feature(self, trs):
        """Return the table value of every state in ``cd.subset_states(trs)``."""
        return [self.dataset_5[tuple(state)]
                for state in self.cd.subset_states(trs)]

    def score_function(self, trs):
        """Score ``trs``; ``-1`` means it is rejected.

        A TRS is rejected as soon as any looked-up value is below 16.
        Otherwise each value of 17, 18, 19 or 20 adds 1, 2, 3 or 4 points
        respectively (16 adds nothing).

        Raises:
            ValueError: if ``fetch_feature`` yields no values at all.
        """
        sizes = self.fetch_feature(trs)

        if min(sizes) < 16:
            return -1

        weights = self._SIZE_WEIGHTS
        return sum(weights.get(size, 0) for size in sizes)

In [None]:
class ExhaustiveSearch:
    """Level-by-level search over rule assignments, checkpointed on disk.

    Each level assigns one more triple of every surviving TRS, trying every
    rule in ``rules``. Levels are pickled under
    ``./trs_list/<cd.n>/<folder_name>/<n_iter>_<num_unassigned>/0.pkl`` and
    the previous level is deleted once the next one has been written.
    """

    def __init__(self, cd, rules):
        """Store the domain object and rules and build the static scorer.

        Args:
            cd: Domain object providing ``unassigned_triples``,
                ``assign_rule``, ``size`` and an ``n`` attribute.
            rules: Sequence of rules tried on every triple.
        """
        self.cd = cd
        self.sf = StaticFeature5(cd, n_rules=len(rules))
        self.rules = rules

    def expand_trs(self,
                   trs,
                   use_value=True):
        """Assign every rule to the first unassigned triple of ``trs``.

        Args:
            trs: The parent TRS.
            use_value: When true, each child is scored with the static
                feature and children scoring ``-1`` are dropped. When
                false, every child is kept with the placeholder value ``-1``.

        Returns:
            A list of ``(child_trs, value)`` pairs, in rule order. The list
            is empty when ``trs`` has no unassigned triple left.
        """
        unassigned = self.cd.unassigned_triples(trs)
        if len(unassigned) == 0:
            # A fully assigned TRS has no children.
            return []
        triple = unassigned[0]

        children = []
        for rule in self.rules:
            # Always derive the child from the parent, not from the
            # previously built sibling.
            child = self.cd.assign_rule(trs, triple, rule)
            if not use_value:
                children.append((child, -1))
                continue
            value = self.sf.score_function(child)
            if value > -1:
                children.append((child, value))

        return children

    @staticmethod
    def _is_scoring_level(n_iter, interval, n_complete_triple):
        """Tell whether level ``n_iter`` is scored and pruned."""
        return (n_iter >= n_complete_triple
                and (n_iter - n_complete_triple) % interval == 0)

    def _run_levels(self,
                    folder_name,
                    first_iter,
                    num_unassigned,
                    interval,
                    n_complete_triple,
                    top_n):
        """Run levels ``first_iter``..``num_unassigned`` from the checkpoint of level ``first_iter - 1``."""
        for n_iter in range(first_iter, num_unassigned + 1):
            scoring = self._is_scoring_level(n_iter, interval, n_complete_triple)

            parents = self.load_trs_list(folder_name,
                                         f"{n_iter-1}_{num_unassigned}",
                                         "0.pkl")

            level = []
            for parent in tqdm(parents, ascii=True, desc=f'{n_iter}'):
                level.extend(self.expand_trs(parent, use_value=scoring))

            if scoring:
                # Ascending stable sort, then keep the best ``top_n`` entries.
                level.sort(key=lambda trs_score: trs_score[1])
                level = level[-top_n:]

            self.save_trs_list([trs_score[0] for trs_score in level],
                               folder_name,
                               f"{n_iter}_{num_unassigned}",
                               "0.pkl")

    def static_search(self,
                      trs,
                      interval=1,
                      n_complete_triple=20,
                      top_n=100000):
        """Run the search from ``trs`` until every triple is assigned.

        Results are checkpointed in the folder
        ``f"{interval}_{n_complete_triple}_{top_n}"``.

        Args:
            trs: Starting TRS.
            interval: Score and prune every ``interval`` levels once
                ``n_complete_triple`` has been reached.
            n_complete_triple: First level at which scoring/pruning happens.
            top_n: Number of best-scoring TRSs kept on a pruned level.
        """
        folder_name = f"{interval}_{n_complete_triple}_{top_n}"

        num_unassigned = len(self.cd.unassigned_triples(trs))
        if num_unassigned == 0:
            # Nothing to assign, hence nothing to search or write.
            return

        # The first level is always scored and filtered, independently of
        # the schedule (kept from the original behaviour).
        first_level = self.expand_trs(trs)
        self.save_trs_list([trs_score[0] for trs_score in first_level],
                           folder_name,
                           f"1_{num_unassigned}",
                           "0.pkl")

        self._run_levels(folder_name, 2, num_unassigned,
                         interval, n_complete_triple, top_n)

    def resume_static_search(self,
                             folder_name,
                             sub_folder_name,
                             interval,
                             n_complete_triple,
                             top_n):
        """Continue a search from an existing checkpoint.

        Uses the same scoring/pruning schedule as ``static_search`` so that
        a resumed run produces what an uninterrupted run would have.

        Args:
            folder_name: Run folder below ``./trs_list/<cd.n>/``.
            sub_folder_name: ``"<n_iter>_<num_unassigned>"`` of the last
                level that was written.
            interval: See ``static_search``.
            n_complete_triple: See ``static_search``.
            top_n: See ``static_search``.
        """
        name = sub_folder_name.split("_")
        n_iter, num_unassigned = int(name[0]), int(name[1])

        self._run_levels(folder_name, n_iter + 1, num_unassigned,
                         interval, n_complete_triple, top_n)

    def save_trs_list(self,
                      trs_list,
                      folder_name,
                      sub_folder_name,
                      filename):
        """Pickle ``trs_list`` for one level and drop the previous level.

        Args:
            trs_list: Object to pickle.
            folder_name: Run folder below ``./trs_list/<cd.n>/``.
            sub_folder_name: ``"<n_iter>_<num_unassigned>"`` of the level.
            filename: File name inside the level folder.
        """
        name = sub_folder_name.split("_")
        n_iter, num_unassigned = int(name[0]), int(name[1])

        trs_folder_name = f'./trs_list/{self.cd.n}/{folder_name}/{n_iter}_{num_unassigned}/'
        os.makedirs(trs_folder_name, exist_ok=True)

        with open(trs_folder_name + filename, "wb") as f:
            pickle.dump(trs_list, f)

        if n_iter > 1:
            previous = f"./trs_list/{self.cd.n}/{folder_name}/{n_iter-1}_{num_unassigned}"
            # The previous level may already be gone (or never have existed);
            # that must not fail a save that has just succeeded.
            if os.path.isdir(previous):
                shutil.rmtree(previous)

    def load_trs_list(self,
                      folder_name,
                      sub_folder_name,
                      filename):
        """Load and return the object pickled for one level.

        Args:
            folder_name: Run folder below ``./trs_list/<cd.n>/``.
            sub_folder_name: Level folder inside the run folder.
            filename: File name inside the level folder.
        """
        level_folder = f'./trs_list/{self.cd.n}/{folder_name}/{sub_folder_name}/'
        # NOTE: pickle executes arbitrary code on load; only read
        # checkpoints this program wrote itself.
        with open(level_folder + filename, "rb") as f:
            return pickle.load(f)

    def get_size_counter(self,
                         folder_name,
                         sub_folder_name):
        """Count how many TRSs of a stored level have each ``cd.size`` value.

        Args:
            folder_name: Run folder below ``./trs_list/<cd.n>/``.
            sub_folder_name: ``"<n_iter>_<num_unassigned>"`` of the level.

        Returns:
            An ``OrderedDict`` mapping size to count, ordered by size.
        """
        name = sub_folder_name.split("_")
        n_iter, num_unassigned = int(name[0]), int(name[1])

        trs_list = self.load_trs_list(folder_name, f"{n_iter}_{num_unassigned}", "0.pkl")
        counts = Counter(self.cd.size(trs)
                         for trs in tqdm(trs_list, ascii=True, desc=f'{n_iter}'))

        return OrderedDict(sorted(counts.items(), key=lambda t: t[0]))


In [None]:
# Domain object for n=6, the search driver with its four rules, and the
# starting TRS returned by init_trs().
cd = CondorcetDomain(n=6)
es = ExhaustiveSearch(
    cd,
    rules=["2N3", "2N1", "1N3", "3N1"],
)
trs = cd.init_trs()

In [None]:
# Run the search; checkpoints go to the run folder "1_5_1000"
# (named "<interval>_<n_complete_triple>_<top_n>").
es.static_search(trs, interval=1, n_complete_triple=5, top_n=1000)

In [None]:
# Size histogram of the level stored as "20_20" in run folder "1_5_1000".
es.get_size_counter(folder_name="1_5_1000", sub_folder_name="20_20")