<a id="toc"></a>
# Table of Contents
1. [Align tasks](#align_tasks)
1. [Run @yukikubo123's DSL](#run_yuki_dsl)
1. [Rollback the predictions](#rollback_the_predictions)

<a id="align_tasks"></a>
# Align tasks
[Back to Table of Contents](#toc)

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
import os
import json
import numpy as np
from pathlib import Path
import random
from collections import Counter
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm
import matplotlib.pyplot as plt

# Root directory of the competition input data and the three dataset
# sub-directories read by get_tasks().
data_path = Path('../input/abstraction-and-reasoning-challenge')
train_path = data_path / 'training'
valid_path = data_path / 'evaluation'
test_path = data_path / 'test'

def set_seeds(seed):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (``torch.manual_seed`` and ``torch.cuda.manual_seed``), and stores the
    seed in the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Integer seed applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    for torch_seeder in (torch.manual_seed, torch.cuda.manual_seed):
        torch_seeder(seed)
    
# Seed all random number generators once, before anything else runs.
set_seeds(0)

# Maps a dataset name (as accepted by get_tasks) to the directory it is read from.
paths = {'train': train_path, 'eval': valid_path, 'test': test_path}

def get_tasks(dataset='train'):
    """Load every file of one dataset directory as JSON.

    Args:
        dataset: Key into the module-level ``paths`` mapping.

    Returns:
        Dict mapping each file name (the text before its first ``.``) to the
        parsed JSON content, filled in sorted file-name order.
    """
    folder = paths[dataset]
    loaded = {}
    for file_name in sorted(os.listdir(folder)):
        with open(folder / file_name, 'r') as handle:
            loaded[file_name.split('.')[0]] = json.load(handle)
    return loaded


# Load all three datasets eagerly; each is a dict keyed by task id (file stem).
test_tasks = get_tasks('test')
train_tasks = get_tasks('train')
valid_tasks = get_tasks('eval')

In [None]:
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib import animation, rc
from IPython.display import HTML

# Fixed ten-entry palette, one colour per cell value 0-9 (names are listed in
# the legend comment further down). The normaliser pins the value range to
# 0..9 regardless of which values a particular grid contains.
cmap = colors.ListedColormap(
        ['#000000', '#0074D9','#FF4136','#2ECC40','#FFDC00',
         '#AAAAAA', '#F012BE', '#FF851B', '#7FDBFF', '#870C25'])
norm = colors.Normalize(vmin=0, vmax=9)
    
def plot_pictures(pictures, labels):
    """Draw several grids side by side in a single row and show the figure.

    Args:
        pictures: Sequence of 2-D grids (anything ``np.array`` accepts).
        labels: Titles, paired with ``pictures`` by position.
    """
    # squeeze=False makes subplots() always return a 2-D array of Axes.
    # Without it a single picture yields a bare Axes object and indexing it
    # raises a TypeError.
    fig, axs = plt.subplots(1, len(pictures), figsize=(2*len(pictures),32), squeeze=False)
    for ax, pict, label in zip(axs[0], pictures, labels):
        ax.imshow(np.array(pict), cmap=cmap, norm=norm)
        ax.set_title(label)
    plt.show()
    
def plot_sample(sample, predict=None):
    """Plot one sample's input and output, plus a prediction when given.

    Args:
        sample: Dict with ``'input'`` and ``'output'`` grids.
        predict: Optional predicted grid, drawn as a third panel.
    """
    pictures = [sample['input'], sample['output']]
    labels = ['Input', 'Output']
    if predict is not None:
        pictures.append(predict)
        labels.append('Predict')
    plot_pictures(pictures, labels)

# NOTE: this re-creates the same Normalize(0, 9) already bound to `norm` above.
norm = colors.Normalize(vmin=0, vmax=9)
# 0:black, 1:blue, 2:red, 3:green, 4:yellow,
# 5:gray, 6:magenta, 7:orange, 8:sky, 9:brown
# Draw the ten palette colours as a one-row legend with the cell value on the x axis.
plt.figure(figsize=(3, 1), dpi=200)
plt.imshow([list(range(10))], cmap=cmap, norm=norm)
plt.xticks(list(range(10)))
plt.yticks([])
plt.show()

# Plot every training pair of one example task.
task = train_tasks["db3e9e38"]
for sample in task['train']:
    plot_sample(sample)

In [None]:
from skimage.transform import hough_line

def is_rotation(img):
    """Decide whether a grid should be rotated, using a two-angle Hough transform.

    The Hough accumulator is computed only for the angles 0 and pi/2.

    Args:
        img: 2-D grid (anything ``np.array`` accepts).

    Returns:
        True when the largest accumulator value at angle 0 is strictly greater
        than the largest value at angle pi/2.
    """
    accumulator, _, _ = hough_line(np.array(img), theta=np.array([0, np.pi / 2]))
    peak_at_zero = accumulator[:, 0].max()
    peak_at_half_pi = accumulator[:, 1].max()
    return peak_at_zero > peak_at_half_pi

    
def get_color_counter(a, binary=False):
    """Count how often each value occurs in an array.

    Args:
        a: NumPy array of cell values.
        binary: When True, cells are first reduced to 0 (value <= 0) or
            1 (value > 0) before counting.

    Returns:
        Dict mapping each distinct value to its number of occurrences.
    """
    values = (a > 0).astype(int) if binary else a
    distinct, occurrences = np.unique(values, return_counts=True)
    return dict(zip(distinct, occurrences))

def similarity(da, db):
    """Return the overlap of two counters.

    For every key present in both dicts the smaller of the two counts is
    added; keys present in only one dict contribute nothing.

    Args:
        da: First counter dict.
        db: Second counter dict.

    Returns:
        The summed overlap (0 when no key is shared).
    """
    overlap = 0
    for key in da:
        if key not in db:
            continue
        overlap += min(da[key], db[key])
    return overlap
    
def is_parts_aligned(da1, da2, db1, db2):
    """Check that image B distributes a colour over its halves the way image A does.

    The colour examined is the most frequent non-zero key of ``da1`` (the
    first one encountered on ties, or 0 when ``da1`` has no positive key).

    Args:
        da1: Colour counts of the first half of image A.
        da2: Colour counts of the second half of image A.
        db1: Colour counts of the first half of image B.
        db2: Colour counts of the second half of image B.

    Returns:
        False when A has at least as many cells of that colour in its first
        half while B has strictly fewer there, or the mirrored situation;
        True otherwise.
    """
    dominant = 0
    dominant_count = 0
    for color, count in da1.items():
        if color > 0 and count > dominant_count:
            dominant, dominant_count = color, count

    a_first, a_second = da1.get(dominant, 0), da2.get(dominant, 0)
    b_first, b_second = db1.get(dominant, 0), db2.get(dominant, 0)

    if a_first >= a_second and b_first < b_second:
        return False
    if a_first <= a_second and b_first > b_second:
        return False
    return True

    
def is_2images_aligned_updown(img0, img1):
    """Compare how two grids distribute colour between their top and bottom halves.

    Each grid is split at ``rows // 2``; the colour counts of the four halves
    are passed to ``is_parts_aligned``.

    Args:
        img0: Reference 2-D grid.
        img1: 2-D grid to compare against the reference.

    Returns:
        The result of ``is_parts_aligned`` for (top0, bottom0, top1, bottom1).
    """
    half_counters = []
    for grid in (np.array(img0), np.array(img1)):
        middle_row = grid.shape[0] // 2
        half_counters.append(get_color_counter(grid[:middle_row, :]))
        half_counters.append(get_color_counter(grid[middle_row:, :]))
    return is_parts_aligned(*half_counters)

def is_2images_aligned_leftright(img0, img1):
    """Compare how two grids distribute colour between their left and right halves.

    Each grid is split at ``columns // 2``; the colour counts of the four
    halves are passed to ``is_parts_aligned``.

    Args:
        img0: Reference 2-D grid.
        img1: 2-D grid to compare against the reference.

    Returns:
        The result of ``is_parts_aligned`` for (left0, right0, left1, right1).
    """
    half_counters = []
    for grid in (np.array(img0), np.array(img1)):
        middle_column = grid.shape[1] // 2
        half_counters.append(get_color_counter(grid[:, :middle_column]))
        half_counters.append(get_color_counter(grid[:, middle_column:]))
    return is_parts_aligned(*half_counters)
    

In [None]:
def _transform_sample(sample, transform):
    """Apply ``transform`` (ndarray -> ndarray) in place to a sample's grids.

    'input' is always transformed; 'output' only when the sample has one, so
    the helper works for test samples without a known answer.
    """
    for key in ('input', 'output'):
        if key in sample:
            sample[key] = transform(np.array(sample[key])).tolist()


def align_task(task):
    """Return a copy of ``task`` whose samples share a common orientation.

    Every sample is rotated by 90 degrees when ``is_rotation`` says so. Then
    every sample except the first training sample is flipped up/down and/or
    left/right until its input agrees with the first training input according
    to ``is_2images_aligned_updown`` / ``is_2images_aligned_leftright``.

    Test samples additionally receive the boolean keys ``'rot90'``,
    ``'flipud'`` and ``'fliplr'`` recording which transforms were applied.

    Whether a test sample's 'output' is transformed is decided per sample.
    The previous implementation reused a flag left over from the last sample
    of an earlier loop, so flips could be applied to (or skipped for) the
    wrong samples.

    Args:
        task: Dict with ``'train'`` and ``'test'`` lists of sample dicts.

    Returns:
        A shallow copy of ``task`` with new ``'train'`` and ``'test'`` lists.
        The input dict and its sample dicts are not modified.

    Raises:
        IndexError: If the task has no training samples.
    """
    task_aligned = task.copy()
    sample_trains = task['train']
    sample_tests = task['test']

    # Train, step 1: normalise rotation.
    rotated_trains = []
    for sample in sample_trains:
        aligned = sample.copy()
        if is_rotation(aligned['input']):
            _transform_sample(aligned, np.rot90)  # np.rot90 defaults to k=1
        rotated_trains.append(aligned)

    # The first training input is the reference and is never flipped itself.
    reference_input = rotated_trains[0]['input']

    # Train, step 2: flip the remaining samples towards the reference.
    aligned_trains = rotated_trains[:1]
    for sample in rotated_trains[1:]:
        aligned = sample.copy()
        if not is_2images_aligned_updown(reference_input, aligned['input']):
            _transform_sample(aligned, np.flipud)
        # Evaluated after the up/down flip, i.e. on the already flipped input.
        if not is_2images_aligned_leftright(reference_input, aligned['input']):
            _transform_sample(aligned, np.fliplr)
        aligned_trains.append(aligned)
    task_aligned['train'] = aligned_trains

    # Test: same transforms, but record what was applied.
    aligned_tests = []
    for sample in sample_tests:
        aligned = sample.copy()

        aligned['rot90'] = False
        if is_rotation(aligned['input']):
            _transform_sample(aligned, np.rot90)
            aligned['rot90'] = True

        aligned['flipud'] = False
        if not is_2images_aligned_updown(reference_input, aligned['input']):
            _transform_sample(aligned, np.flipud)
            aligned['flipud'] = True

        aligned['fliplr'] = False
        if not is_2images_aligned_leftright(reference_input, aligned['input']):
            _transform_sample(aligned, np.fliplr)
            aligned['fliplr'] = True

        aligned_tests.append(aligned)
    task_aligned['test'] = aligned_tests

    return task_aligned
        
# NOTE: the first assignment is overwritten by the next line, so only the
# evaluation task is actually aligned and plotted.
single_task = train_tasks["db3e9e38"] 
single_task = valid_tasks["103eff5b"] 
# single_task = valid_tasks["05a7bcf2"]
task_aligned = align_task(single_task)
# Visual check of the aligned training samples.
for sample in task_aligned['train']:
#     print(sample['flipud'], sample['rot90'])
    plot_sample(sample)
# Test samples also carry the flags recorded by align_task.
for sample in task_aligned['test']:
    print(sample['fliplr'], sample['flipud'], sample['rot90'])
    plot_sample(sample)

In [None]:
!mkdir -p test_aligned

# Align every test task and write it to test_aligned/<task_id>.json.
test_aligned_path = Path("test_aligned")
test_tasks = get_tasks('test')

for task_id, task in tqdm(test_tasks.items()):
    task_aligned = align_task(task)
    task_filename = '{}.json'.format(task_id)
    
    with open(test_aligned_path / task_filename, 'w') as outfile:
        json.dump(task_aligned, outfile)

In [None]:
# Register the freshly written directory so get_tasks() can read it back,
# then print how many aligned tasks were loaded.
paths['test_aligned'] = test_aligned_path
test_aligned_tasks = get_tasks("test_aligned")
print(len(test_aligned_tasks))

<a id="run_yuki_dsl"></a>
# Run @yukikubo123's DSL
[Back to Table of Contents](#toc)

In [None]:
""" This file was auto_generated by kernel_generator.py """

from typing import Set
from deap.tools import selNSGA2
from lightgbm import LGBMClassifier
from joblib import delayed
from scipy.ndimage import binary_erosion
from enum import auto
from collections import defaultdict
from scipy.ndimage import maximum_filter
from itertools import groupby
from skimage.measure import label
from sklearn.neural_network import MLPClassifier
from scipy.ndimage import binary_fill_holes
import json
import shutil
from typing import List
from enum import IntEnum
from pandas import DataFrame
from enum import unique
import cv2
import pandas as pd
from copy import deepcopy
from typing import Tuple
from itertools import product
from skimage.filters import try_all_threshold
from pathlib import Path
from heapq import heapify
from scipy.ndimage import generate_binary_structure
from sklearn.linear_model import LogisticRegression
from functools import partial
import copy
from typing import Any
from typing import Optional
from heapq import heappush
from category_encoders import OrdinalEncoder
import numpy as np
from typing import Dict
from tqdm import tqdm
from matplotlib import colors
import time
import random
from heapq import heappushpop
from typing import Iterable
from enum import Enum
import pickle
from matplotlib import pyplot as plt
from joblib import Parallel
from heapq import heappop
from itertools import chain
from dataclasses import asdict
from skimage.filters import threshold_minimum
from sklearn.linear_model import RidgeClassifier
from scipy.ndimage import binary_dilation
import optuna
from dataclasses import dataclass
from dataclasses import field
from typing import Union
from typing import TypeVar
from optuna import Trial
import category_encoders
from sklearn.model_selection import KFold
from operator import itemgetter
# from ruamel import yaml
from collections import Counter


@dataclass(eq=False)
class OperationInconsistencyException(Exception):
    """Exception carrying a human-readable ``message``.

    ``eq=False`` keeps the identity-based ``__eq__``/``__hash__`` inherited
    from ``Exception``. A dataclass-generated ``__eq__`` would set
    ``__hash__`` to ``None`` and make instances unhashable.

    Args:
        message: Description of the inconsistency (default: empty string).
    """
    message: str = ''

    def __post_init__(self):
        # The generated __init__ does not call Exception.__init__, so a message
        # passed by keyword would never reach ``args`` / ``str(exc)``.
        super().__init__(self.message)


class Timer:
    """Context manager that remembers when it was entered.

    ``second()`` reports the time elapsed since ``__enter__``; it can be
    called both inside and after the ``with`` block.
    """

    def __enter__(self):
        self.start_sec = time.perf_counter()
        return self

    def __exit__(self, *exc):
        # Returning None lets any exception from the with-body propagate.
        return None

    def second(self):
        """Return the seconds elapsed since the timer was entered."""
        now = time.perf_counter()
        return now - self.start_sec


class StrNameEnum(Enum):
    """Enum base whose ``str()`` is the bare member name and whose ``repr()`` is ``ClassName.MEMBER``."""

    def __str__(self):
        return self.name

    def __repr__(self):
        return '{}.{}'.format(type(self).__name__, self.name)


class StrNameIntEnum(IntEnum):
    """IntEnum base whose ``str()`` is the bare member name and whose ``repr()`` is ``ClassName.MEMBER``."""

    def __str__(self):
        return self.name

    def __repr__(self):
        return '{}.{}'.format(type(self).__name__, self.name)


@unique
class RunMode(Enum):
    """Execution modes; ``RunConfig.RUN_MODE`` selects one of them."""
    LOCAL_RUN_ALL = auto()
    LOCAL_RUN = auto()
    TREE_BASE_SEARCH_OPTIMIZATION = auto()
    NODE_BASE_SEARCH_OPTIMIZATION = auto()
    LOCAL_DATA_GENERATION = auto()
    LOCAL_ML_TRAIN = auto()
    TRAIN_OPERATION_ELEMENT_INCLUSION_PREDICTION = auto()
    KERNEL = auto()
    KERNEL_EMULATION = auto()


@unique
class TaskRange(Enum):
    """Which subset of tasks to process; ``RunConfig.TASK_RANGE`` selects one."""
    ALL = auto()
    CAN_ANSWER_ONLY = auto()
    EXCLUDE_GIVE_UPS = auto()


@unique
class FlipMode(StrNameEnum):
    """Flip variants.

    Presumably up/down, left/right and the two diagonals (upper-left to
    lower-right, upper-right to lower-left) - TODO confirm against the
    operations that consume it.
    """
    UD = auto()
    LR = auto()
    UL_DR = auto()
    UR_DL = auto()


@unique
class EngineSchedulePattern(Enum):
    """Scheduling strategies; ``RunConfig.ENGINE_SCHEDULE_PATTERN`` selects one."""
    DRY_RUN = auto()
    HAND_MADE = auto()
    ML = auto()


@unique
class EngineType(Enum):
    """Available search engines; ``RunConfig.ENGINE_TYPE`` selects one."""
    NODE_BASED_SEARCH_ENGINE = auto()
    TREE_BASED_SEARCH_ENGINE = auto()


class RunConfig:
    """Global run settings, held as class attributes (the class is not instantiated here)."""
    RUN_MODE = RunMode.KERNEL  # Usually, use "LOCAL_RUN" or "KERNEL"
    TASK_RANGE = TaskRange.ALL  # Limit the range to save time.
    ENGINE_TYPE = EngineType.NODE_BASED_SEARCH_ENGINE
    ENGINE_SCHEDULE_PATTERN = EngineSchedulePattern.HAND_MADE
    USE_ML_GUIDE = False  # DeepCoder-like strategy. Calculate the probability of inclusion of each DSL elements.
    RUN_ONLY_PRIVATE_LB = False  # Skip public kernel run to save time.

    _KERNEL_N_JOB = 4
    _LOCAL_N_JOB = 5
    # Evaluated once, while the class body runs: changing RUN_MODE afterwards
    # does not update N_JOB.
    N_JOB = _KERNEL_N_JOB if RUN_MODE == RunMode.KERNEL else _LOCAL_N_JOB


@unique
class DepthSearchPattern(Enum):
    """Search-order variants (breadth-first, normal, depth-first)."""
    BREADTH_FIRST = auto()
    NORMAL = auto()
    DEPTH_FIRST = auto()


@unique
class TrueOrFalse(StrNameEnum):
    """Two-valued enum whose members print as ``TRUE`` / ``FALSE``."""
    TRUE = auto()
    FALSE = auto()


@unique
class Color(StrNameIntEnum):
    """Cell colours 0-9 plus the special ``MASK_TAG`` value 10.

    ``Color.of(value)`` converts a raw integer to its member through a lookup
    table that is cached on the class on first use.
    """
    BLACK = 0
    BLUE = 1
    RED = 2
    GREEN = 3
    YELLOW = 4
    GRAY = 5
    MAGENTA = 6
    ORANGE = 7
    SKY = 8
    BROWN = 9
    MASK_TAG = 10  # very special color. TODO unused?

    @classmethod
    def prepare(cls):
        """Build the value -> member table and store it as ``cls.mapping``."""
        cls.mapping = {member.value: member for member in Color}

    @classmethod
    def of(cls, value: int) -> 'Color':
        """Return the member whose value is ``value``.

        Raises:
            KeyError: If no member has that value.
        """
        try:
            table = cls.mapping
        except AttributeError:
            # First call: the table has not been built yet.
            cls.prepare()
            table = cls.mapping
        return table[value]


@unique
class Direction(StrNameEnum):
    """The four grid directions."""
    TOP = auto()
    BOTTOM = auto()
    RIGHT = auto()
    LEFT = auto()


@unique
class PaddingMode(StrNameEnum):
    """Padding strategies; the two mirror modes differ in where the symmetry line lies."""
    REPEAT = auto()
    MIRROR_1 = auto()  # line-symmetric at the edge
    MIRROR_2 = auto()  # line-symmetric at the edge-pixel-line
    EDGE = auto()


@unique
class Axis(StrNameEnum):
    """Axis selector: vertical, horizontal, or both."""
    VERTICAL = auto()
    HORIZONTAL = auto()
    BOTH = auto()


@unique
class MultiColorSelectionMode(StrNameEnum):
    """Modes that select several colours at once, each excluding one colour by frequency rank."""
    # ANY_WITHOUT_FIXED_COLOR = auto()  # TODO should define?
    ANY_WITHOUT_MOST_COMMON = auto()  # TODO ANY_WITHOUT_TOP2_MOST_COMMON
    ANY_WITHOUT_LEAST_COMMON = auto()


@unique
class MaxOrMin(StrNameEnum):
    """Selector between the builtins ``max`` and ``min``.

    The member values are the builtin functions themselves.
    """
    MAX = max
    MIN = min

    @property
    def func(self):
        """The builtin (``max`` or ``min``) stored as this member's value."""
        return self.value


@unique
class FillType(StrNameEnum):
    """Fill behaviour: whether a fill may overwrite cells (presumably already-coloured ones - TODO confirm)."""
    NotOverride = auto()
    Override = auto()


@unique
class LineEdgeType(StrNameEnum):
    """Whether a line's edge is excluded or included (exact meaning defined by the consuming operation)."""
    EdgeExclude = auto()
    EdgeInclude = auto()


@unique
class ImageEdgeType(StrNameEnum):
    """Whether the image edge is excluded or included (exact meaning defined by the consuming operation)."""
    EDGE_EXCLUDE = auto()
    EDGE_INCLUDE = auto()


@unique
class ObjectFeature(StrNameEnum):
    """Measurable properties of an object: area and horizontal/vertical length."""
    AREA = auto()
    # PERIMETER_LEN = auto() # TODO difficult to implement?
    HORIZONTAL_LEN = auto()
    VERTICAL_LEN = auto()


@unique
class PixelConnectivity(StrNameEnum):
    """Pixel neighbourhood: 4-connected (value 1) or 8-connected (value 2)."""
    FOUR_DIRECTION = 1
    EIGHT_DIRECTION = 2

    @property
    def value_for_skimage(self) -> int:
        """The member's integer value."""
        return self.value

    @property
    def structure_for_skimage(self) -> np.ndarray:
        """Binary structuring element from ``generate_binary_structure`` for this connectivity.

        Raises:
            NotImplementedError: For a member without a known structure.
        """
        connectivity_by_member = {
            PixelConnectivity.EIGHT_DIRECTION: 2,
            PixelConnectivity.FOUR_DIRECTION: 1,
        }
        connectivity = connectivity_by_member.get(self)
        if connectivity is None:
            raise NotImplementedError()
        return generate_binary_structure(2, connectivity)


@unique
class HoleInclude(StrNameEnum):
    """Whether holes are included or excluded."""
    INCLUDE = auto()
    EXCLUDE = auto()


@unique
class SingleColorSelectionMode(StrNameEnum):
    """Modes that select exactly one colour by its frequency rank."""
    MOST_COMMON = auto()
    SECOND_MOST_COMMON = auto()
    LEAST_COMMON = auto()


@dataclass(frozen=True)
class ColorSelection:
    """Interface for callables mapping an array to an array; subclasses must override ``__call__``.

    The result is presumably a mask of the selected cells - TODO confirm
    against the concrete subclasses.
    """
    def __call__(self, arr: np.ndarray) -> np.ndarray:
        raise NotImplementedError()


@dataclass(frozen=True)
class MaskConversion:
    """Interface for callables turning one mask array into another; subclasses must override ``__call__``."""
    def __call__(self, mask: np.ndarray) -> np.ndarray:
        raise NotImplementedError()


@dataclass(frozen=True)
class NoMaskConversion(MaskConversion):
    """Identity conversion: returns the given mask object unchanged (no copy is made)."""
    def __call__(self, color_mask: np.ndarray) -> np.ndarray:
        return color_mask


@dataclass(frozen=True)
class MaskOperation:
    """Interface for callables producing an array from an array and a mask; subclasses must override ``__call__``."""
    def __call__(self, arr: np.ndarray, mask: np.ndarray) -> np.ndarray:
        raise NotImplementedError()


@dataclass(frozen=True)
class ColorChannelSelection:
    """Interface for callables splitting an array into ``(Color, array)`` pairs; subclasses must override ``__call__``."""
    def __call__(self, arr: np.ndarray) -> List[Tuple[Color, np.ndarray]]:
        raise NotImplementedError()


@dataclass(frozen=True)
class ChannelMergeOperation:
    """Interface for callables merging per-colour pairs back into one array.

    Receives the array plus two lists of ``(Color, array)`` pairs (the
    original ones and the current ones). Subclasses must override ``__call__``.
    """
    def __call__(self, arr: np.ndarray, original_color_mask_pairs: List[Tuple[Color, np.ndarray]], color_mask_pairs: List[Tuple[Color, np.ndarray]]) -> np.ndarray:
        raise NotImplementedError()


@dataclass(frozen=True)
class ColorOperation:
    """Immutable bundle of a colour selection, a mask conversion and a mask operation."""
    color_selection: ColorSelection
    mask_conversions: MaskConversion  # a single conversion despite the plural name
    mask_operation: MaskOperation


@dataclass(frozen=True)
class MultiColorChannelOperation:
    """Immutable bundle of a channel selection, a mask conversion and a channel merge operation."""
    channel_selection: ColorChannelSelection
    mask_conversions: MaskConversion  # a single conversion despite the plural name
    channel_merge_operation: ChannelMergeOperation


@dataclass(frozen=True)
class PartitionedArraySelection:
    """Interface for callables returning a 2-D list of booleans for a 2-D list of partitions; subclasses must override ``__call__``."""
    def __call__(self, arr: np.ndarray, partitioned_arrays: List[List[np.ndarray]]) -> List[List[bool]]:
        raise NotImplementedError()


@dataclass(frozen=True)
class PartitionOperation:
    """Immutable bundle of a partition selection and a partition merge operation.

    The field types are written as strings because both classes are defined
    further down in the file.
    """
    partition_selection: 'PartitionSelection'
    # partition_uniform_operation: PartitionUniformOperation # TODO implement
    partition_merge_operation: 'PartitionMergeOperation'


@dataclass(frozen=True)
class PartitionSelection:
    """Interface for callables splitting an array into a 2-D list of partitions plus matching location masks; subclasses must override ``__call__``."""
    # array -> (2d_partitioned_array, 2d_original_location_mask)
    def __call__(self, arr: np.ndarray) -> Tuple[List[List[np.ndarray]], List[List[np.ndarray]]]:
        raise NotImplementedError()


@dataclass(frozen=True)
class PartitionMergeOperation:
    """Interface for callables merging partitions and their location masks into one array; subclasses must override ``__call__``."""
    def __call__(self, arr: np.ndarray, partitioned_arrays: List[List[np.ndarray]], original_location_masks: List[List[np.ndarray]]) -> np.ndarray:
        raise NotImplementedError()


@dataclass
class DistanceEvaluatorParameter:
    """Weights applied by ``DistanceEvaluator`` to the individual task-feature terms."""
    same_h_w_dim_between_input_output: float = 1500
    all_dim_h_w_integer_multiple: float = 650
    mean_lack_color_num: float = 30
    mean_excess_color_num: float = 50
    mean_hit_and_miss_histogram_diff: float = 50
    mean_h_v_diff_input_arr_line_num: float = 40
    mean_h_v_diff_output_arr_line_num: float = 60
    mean_h_v_edge_sum_diff: float = 2
    mean_h_v_edge_sum_diff_ratio: float = 0.5
    mean_diff_color_cell_ratio: int = 1  # baseline (reference weight)
    mean_diff_cell_where_no_need_to_change_count_ratio: float = 100
    mean_wrong_change_cell_where_need_to_change_count_ratio: float = 100


@dataclass
class NodeBaseSearchEngineParameter:
    """Tunable parameters of the node-based search engine."""
    breadth_first_cost: float = 3500
    normal_first_cost: float = 400
    depth_first_cost: float = 1.2
    breadth_first_exp_cost: float = 0
    normal_exp_cost: float = 0
    depth_first_exp_cost: float = 0
    element_inclusion_prob_factor: float = 0
    pq_pop_mins_or_as_least_n: int = 20


@dataclass
class TreeBaseSearchEngineParameter:
    """Tunable parameters of the tree-based search engine."""
    population_num: int = 26
    max_depth: int = 8
    operation_mutation_prob: float = 0.19
    operation_component_mutation_prob: float = 0.1
    operation_param_mutation_prob: float = 0.0048
    extend_mutation_prob: float = 0.044
    shrink_mutation_prob: float = 0.0012


@dataclass
class AllParameter:
    """Container for all parameter groups.

    Each group defaults to a fresh instance built by ``default_factory``.
    Using an instance directly as the field default would share one mutable
    object between every ``AllParameter``, and raises ``ValueError`` at class
    creation on Python 3.11+ because these dataclasses are unhashable.
    """
    distance_evaluator_param: DistanceEvaluatorParameter = field(default_factory=DistanceEvaluatorParameter)
    node_base_engine_param: Optional[NodeBaseSearchEngineParameter] = field(default_factory=NodeBaseSearchEngineParameter)
    tree_base_engine_param: Optional[TreeBaseSearchEngineParameter] = field(default_factory=TreeBaseSearchEngineParameter)


@dataclass()
class InputOutput:
    """One input grid together with its expected output grid, if known.

    Attributes are ``uint8`` arrays when built through ``InputOutput.of``.
    """
    input_arr: np.ndarray
    output_arr: Optional[np.ndarray]  # None when the sample has no 'output'

    @staticmethod
    def of(json_dict: dict) -> 'InputOutput':
        """Build an instance from a sample dict with 'input' and optional 'output'."""
        return InputOutput(np.array(json_dict['input'], dtype=np.uint8),
                           np.array(json_dict['output'], dtype=np.uint8) if 'output' in json_dict else None)

    def get_all_arr(self) -> List[np.ndarray]:
        """Return ``[input_arr]`` or ``[input_arr, output_arr]`` depending on whether an output exists."""
        if self.output_arr is None:
            return [self.input_arr]
        else:
            return [self.input_arr, self.output_arr]

    def candidate_color_mapping(self) -> List[Tuple[Color, Color]]:
        """List every (input colour, output colour) pair with differing colours.

        The previous version also appended ``Color.ANY_WITHOUT_MOST``,
        ``Color.MOST``, ``Color.SECOND_MOST`` and ``Color.LEAST``. ``Color``
        defines none of these, so every call raised ``AttributeError``. Only
        colours actually present in the arrays are used now.

        Returns:
            Pairs of ``Color`` members; an empty list when there is no output.
        """
        if self.output_arr is None:
            return []
        input_colors = np.unique(self.input_arr)
        output_colors = np.unique(self.output_arr)
        return [(Color.of(i), Color.of(o)) for i, o in product(input_colors, output_colors) if i != o]


@dataclass(frozen=True)
class UniformOperation:
    """Interface for callables mapping an array to an array; subclasses must override ``__call__``."""
    def __call__(self, arr: np.ndarray) -> np.ndarray:
        raise NotImplementedError()


@dataclass(frozen=True)
class OperationSet:
    """Immutable, ordered list of operations."""
    operations: List[Union[UniformOperation, ColorOperation, MultiColorChannelOperation, PartitionOperation]]

    def __str__(self):
        return repr(self)

    def types(self):
        """Return, per operation, the first of the four operation base classes it is an instance of.

        Raises:
            NotImplementedError: If an operation matches none of them.
        """
        known_kinds = (UniformOperation, ColorOperation, MultiColorChannelOperation, PartitionOperation)
        kinds = []
        for operation in self.operations:
            matched = next((kind for kind in known_kinds if isinstance(operation, kind)), None)
            if matched is None:
                raise NotImplementedError()
            kinds.append(matched)
        return kinds

    def elements(self) -> List[Union[UniformOperation, ColorSelection, MaskConversion, MaskOperation, PartitionOperation]]:
        """Flatten the operations into their components, in order.

        A ``UniformOperation`` contributes itself; the composite operations
        contribute their component attributes.

        Raises:
            NotImplementedError: If an operation matches no known class.
        """
        component_names_by_kind = (
            (ColorOperation, ('color_selection', 'mask_conversions', 'mask_operation')),
            (MultiColorChannelOperation, ('channel_selection', 'mask_conversions', 'channel_merge_operation')),
            (PartitionOperation, ('partition_selection', 'partition_merge_operation')),
        )
        collected = []
        for operation in self.operations:
            if isinstance(operation, UniformOperation):
                collected.append(operation)
                continue
            for kind, component_names in component_names_by_kind:
                if isinstance(operation, kind):
                    for component_name in component_names:
                        collected.append(getattr(operation, component_name))
                    break
            else:
                raise NotImplementedError()
        return collected


@dataclass(frozen=True)
class Task:
    """Immutable task: a name plus tuples of training and test ``InputOutput`` pairs."""
    name: str
    train: Tuple[InputOutput]
    test: Tuple[InputOutput]

    @staticmethod
    def of(name: str, json_dict: dict) -> 'Task':
        """Build a task from a dict with 'train' and 'test' sample lists."""
        train_pairs = tuple(InputOutput.of(entry) for entry in json_dict['train'])
        test_pairs = tuple(InputOutput.of(entry) for entry in json_dict['test'])
        return Task(name, train_pairs, test_pairs)

    def get_all_arr(self) -> List[np.ndarray]:
        """All arrays of the training pairs followed by all arrays of the test pairs."""
        return self.get_train_all_arr() + self.get_test_all_arr()

    def get_train_all_arr(self) -> List[np.ndarray]:
        """Arrays of every training pair, in pair order."""
        return [arr for pair in self.train for arr in pair.get_all_arr()]

    def get_test_all_arr(self) -> List[np.ndarray]:
        """Arrays of every test pair, in pair order."""
        return [arr for pair in self.test for arr in pair.get_all_arr()]

    def get_input_all_arr(self) -> List[np.ndarray]:
        """Input arrays of the training pairs followed by those of the test pairs."""
        return [pair.input_arr for pair in self.train + self.test]

    def get_output_all_arr(self) -> List[np.ndarray]:
        """Output arrays of all pairs that have one (train first, then test)."""
        return [pair.output_arr for pair in self.train + self.test if pair.output_arr is not None]

    def test_arr_hash(self) -> int:
        """Hash of the class name concatenated with the '_'-joined string forms of the test inputs."""
        joined_inputs = '_'.join(str(pair.input_arr) for pair in self.test)
        return hash(self.__class__.__name__ + joined_inputs)


@dataclass(frozen=True)
class ColorSelectedTask(Task):
    """Task extended with one mask list for the training samples and one for the test samples."""
    train_masks: List[np.ndarray]
    test_masks: List[np.ndarray]


@dataclass(frozen=True)
class MaskConvertedTask(Task):
    """Task extended with train/test mask lists; same fields as ``ColorSelectedTask`` but a distinct type."""
    train_masks: List[np.ndarray]
    test_masks: List[np.ndarray]


@dataclass(frozen=True)
class ColorChannelSelectedTask(Task):
    """Task extended with, per sample, a list of ``(Color, array)`` pairs for train and test."""
    train_color_mask_pairs: List[List[Tuple[Color, np.ndarray]]]
    test_color_mask_pairs: List[List[Tuple[Color, np.ndarray]]]


@dataclass(frozen=True)
class ColorChannelMaskConvertedTask(Task):
    """Task carrying both the original and the current ``(Color, array)`` pairs per sample, for train and test."""
    train_original_color_mask_pairs: List[List[Tuple[Color, np.ndarray]]]
    train_color_mask_pairs: List[List[Tuple[Color, np.ndarray]]]
    test_original_color_mask_pairs: List[List[Tuple[Color, np.ndarray]]]
    test_color_mask_pairs: List[List[Tuple[Color, np.ndarray]]]


@dataclass(frozen=True)
class PartitionSelectionTask(Task):
    """Task extended with, per sample, a (partitioned arrays, original location masks) tuple for train and test."""
    train_partitioned_arrays_original_location_masks: List[Tuple[List[List[np.ndarray]], List[List[np.ndarray]]]]
    test_partitioned_arrays_original_location_masks: List[Tuple[List[List[np.ndarray]], List[List[np.ndarray]]]]


@dataclass
class ImageFeature:
    """Summary features of a single image (size, colours, histogram, edge counts)."""
    height: int
    width: int
    colors: List[Color]
    hit_and_miss_histogram: List[int]
    # most_common_color: Color
    vertical_edge_num: int
    horizontal_edge_num: int


@dataclass
class ImageDiffFeature:
    """Features comparing an input image with an output image.

    Holds both per-image feature records plus the derived differences. The
    three ``*_ratio`` fields typed ``Optional`` are None when the two images
    differ in size.
    """
    input_image_feature: ImageFeature  # TODO should not define here?
    output_image_feature: ImageFeature  # TODO should not define here?
    dim_height_increase: int
    dim_width_increase: int
    dim_height_integer_multiple: bool
    dim_width_integer_multiple: bool
    dim_height_diff: int
    dim_width_diff: int
    dim_height_equal: bool
    dim_width_equal: bool
    lack_color_num: int
    excess_color_num: int
    hit_and_miss_histogram_diff: int
    # vertical_diff_input_arr_line_num: Optional[int]
    # horizontal_diff_input_arr_line_num: Optional[int]
    # vertical_diff_output_arr_line_num: Optional[int]
    # horizontal_diff_output_arr_line_num: Optional[int]
    vertical_edge_sum_diff: int
    horizontal_edge_sum_diff: int
    vertical_edge_sum_diff_ratio: float
    horizontal_edge_sum_diff_ratio: float
    diff_color_cell_ratio: Optional[float]  # None if different image size.
    diff_cell_where_no_need_to_change_count_ratio: Optional[float]  # None if different image size.
    wrong_change_cell_where_need_to_change_count_ratio: Optional[float]  # None if different image size.

    # TODO cell_diff_num_except_formost_common_color

    def same_dim(self) -> bool:
        """Return True when both the height and the width are equal."""
        return self.dim_height_equal and self.dim_width_equal


@dataclass
class TaskFeature:
    """Task-level features consumed by ``DistanceEvaluator``.

    Presumably aggregated over the ``ImageDiffFeature`` of each input/output
    pair of a task (the ``all_*`` flags and ``mean_*`` values suggest so) -
    TODO confirm against the code that builds it.
    """
    # image_diff_features: List[ImageDiffFeature]
    same_dim_between_input_output: bool
    same_height_dim_between_input_output: bool
    same_width_dim_between_input_output: bool
    all_dim_height_increased: bool
    all_dim_height_decreased: bool
    all_dim_width_increased: bool
    all_dim_width_decreased: bool
    all_dim_height_integer_multiple: bool
    all_dim_width_integer_multiple: bool
    mean_lack_color_num: float
    mean_excess_color_num: float
    mean_hit_and_miss_histogram_diff: float
    # mean_vertical_diff_input_arr_line_num: Optional[float]
    # mean_horizontal_diff_input_arr_line_num: Optional[float]
    # mean_vertical_diff_output_arr_line_num: Optional[float]
    # mean_horizontal_diff_output_arr_line_num: Optional[float]
    mean_vertical_edge_sum_diff: float
    mean_horizontal_edge_sum_diff: float
    mean_vertical_edge_sum_diff_ratio: float
    mean_horizontal_edge_sum_diff_ratio: float
    mean_diff_color_cell_ratio: Optional[float]  # None if different image size.
    mean_diff_cell_where_no_need_to_change_count_ratio: Optional[float]  # None if different image size.
    mean_wrong_change_cell_where_need_to_change_count_ratio: Optional[float]


@dataclass
class ColorSelectedTaskFeature:
    """Wrapper holding the ``TaskFeature`` of a colour-selected task."""
    task_feature: TaskFeature


@dataclass
class MaskConvertedTaskFeature:
    """A ``TaskFeature`` plus a list of possible improvement ratios (entries may be None)."""
    task_feature: TaskFeature
    possible_improve_ratios: List[Optional[float]]


@dataclass
class DistanceEvaluator:
    """Scores a ``TaskFeature`` as a weighted sum, with weights taken from ``dist_eval_param``."""
    dist_eval_param: DistanceEvaluatorParameter

    def evaluate_task_feature(self, task_feature: TaskFeature) -> float:
        """Return the weighted sum of all feature terms.

        Boolean features contribute their weight when they are False.
        Optional ratio features contribute 0 when they are None.
        """
        weights = self.dist_eval_param
        feats = task_feature

        def penalty(satisfied) -> int:
            return 0 if satisfied else 1

        weighted_terms = (
            weights.same_h_w_dim_between_input_output * penalty(feats.same_height_dim_between_input_output),
            weights.same_h_w_dim_between_input_output * penalty(feats.same_width_dim_between_input_output),
            weights.all_dim_h_w_integer_multiple * penalty(feats.all_dim_height_integer_multiple),
            weights.all_dim_h_w_integer_multiple * penalty(feats.all_dim_width_integer_multiple),
            weights.mean_lack_color_num * feats.mean_lack_color_num,
            weights.mean_excess_color_num * feats.mean_excess_color_num,
            weights.mean_hit_and_miss_histogram_diff * feats.mean_hit_and_miss_histogram_diff,
            weights.mean_h_v_edge_sum_diff * feats.mean_vertical_edge_sum_diff,
            weights.mean_h_v_edge_sum_diff * feats.mean_horizontal_edge_sum_diff,
            weights.mean_h_v_edge_sum_diff_ratio * feats.mean_vertical_edge_sum_diff_ratio,
            weights.mean_h_v_edge_sum_diff_ratio * feats.mean_horizontal_edge_sum_diff_ratio,
            weights.mean_diff_color_cell_ratio * (feats.mean_diff_color_cell_ratio or 0),
            weights.mean_diff_cell_where_no_need_to_change_count_ratio * (feats.mean_diff_cell_where_no_need_to_change_count_ratio or 0),
            weights.mean_wrong_change_cell_where_need_to_change_count_ratio * (feats.mean_wrong_change_cell_where_need_to_change_count_ratio or 0),
        )
        # Accumulate strictly left to right so floating-point results stay the
        # same as a plain chained addition starting from 0.
        total = 0
        for term in weighted_terms:
            total = total + term
        return total

    # Disabled terms (the corresponding TaskFeature fields are commented out):
    # mean_h_v_diff_input_arr_line_num  * mean_{horizontal,vertical}_diff_input_arr_line_num
    # mean_h_v_diff_output_arr_line_num * mean_{horizontal,vertical}_diff_output_arr_line_num

    def evaluate_task_feature_element(self, task_feature: TaskFeature) -> List[float]:
        """Return the first twelve unweighted terms of ``evaluate_task_feature`` as a list.

        The two ``*_count_ratio`` terms are not included.
        """
        feats = task_feature

        def penalty(satisfied) -> int:
            return 0 if satisfied else 1

        elements = [
            penalty(feats.same_height_dim_between_input_output),
            penalty(feats.same_width_dim_between_input_output),
            penalty(feats.all_dim_height_integer_multiple),
            penalty(feats.all_dim_width_integer_multiple),
        ]
        elements.append(feats.mean_lack_color_num)
        elements.append(feats.mean_excess_color_num)
        elements.append(feats.mean_hit_and_miss_histogram_diff)
        elements.append(feats.mean_vertical_edge_sum_diff)
        elements.append(feats.mean_horizontal_edge_sum_diff)
        elements.append(feats.mean_vertical_edge_sum_diff_ratio)
        elements.append(feats.mean_horizontal_edge_sum_diff_ratio)
        elements.append(feats.mean_diff_color_cell_ratio or 0)
        return elements


class Node:
    """Base class of search nodes; ``repr()`` mirrors ``str()``."""

    def __repr__(self):
        # object.__str__ itself calls repr(), so delegating to str() on a class
        # that defines no __str__ would recurse forever. Fall back to the
        # default object representation in that case.
        if type(self).__str__ is object.__str__:
            return object.__repr__(self)
        return str(self)


@dataclass
class WaitingNode(Node):
    """Base class of nodes that wait in the priority queue.

    Nodes are ordered by ``cache_pred_distance``.
    """
    # This node will be added to priority queue.

    parent_completed_node: 'CompletedNode'
    # No annotation, so this is a plain class attribute rather than a dataclass
    # field. It must be assigned a comparable value before nodes are compared:
    # comparing while it is still None raises TypeError.
    cache_pred_distance = None

    def evaluation_features(self) -> Dict[str, Any]:
        """Return the node's feature dict; must be overridden by subclasses."""
        raise NotImplementedError()

    def depth(self) -> int:
        """Return the node's depth; must be overridden by subclasses."""
        raise NotImplementedError()

    def __le__(self, other: 'WaitingNode') -> bool:
        return self.cache_pred_distance <= other.cache_pred_distance

    def __lt__(self, other: 'WaitingNode') -> bool:
        return self.cache_pred_distance < other.cache_pred_distance


@dataclass
class UniformOperationWaitingNode(WaitingNode):
    """Waiting node whose next step is a ``UniformOperation`` applied on top of ``base_operation_set``."""
    original_task: Task
    task: Task
    task_feature: TaskFeature
    base_operation_set: OperationSet
    next_operation: UniformOperation

    def __str__(self):
        return (f'depth: {len(self.base_operation_set.operations)}+1, class: {self.__class__.__name__}, '
                f'ope_set: {self.base_operation_set}, next_ope: {self.next_operation}')

    def depth(self) -> int:
        """Number of operations already in ``base_operation_set``."""
        return len(self.base_operation_set.operations)

    def evaluation_features(self) -> Dict[str, Any]:
        """Flat feature dict: node class, depth, task features, then the next operation's class name and fields."""
        features = {
            'node_class': self.__class__.__name__,
            'depth': len(self.base_operation_set.operations),
        }
        features.update(asdict(self.task_feature))
        features['next_operation'] = self.next_operation.__class__.__name__
        # Fields of the operation come last and win on key clashes.
        features.update(asdict(self.next_operation))
        return features


@dataclass
class ColorSelectionWaitingNode(WaitingNode):
    """Waiting node that carries a ``ColorSelection`` in ``next_selection``."""

    original_task: Task
    task: Task
    task_feature: TaskFeature
    base_operation_set: OperationSet
    next_selection: ColorSelection

    def __str__(self):
        head = f'depth: {len(self.base_operation_set.operations)}+1, class: {self.__class__.__name__}, '
        tail = f'ope_set: {self.base_operation_set}, next_selection: {self.next_selection}'
        return head + tail

    def depth(self) -> int:
        """Number of operations held by ``base_operation_set``."""
        applied_operations = self.base_operation_set.operations
        return len(applied_operations)

    def evaluation_features(self) -> Dict[str, Any]:
        """Flatten node class, depth, task feature and next selection into one dict.

        Later entries overwrite earlier ones on key collisions.
        """
        features: Dict[str, Any] = {
            'node_class': self.__class__.__name__,
            'depth': len(self.base_operation_set.operations),
        }
        features.update(asdict(self.task_feature))
        features['next_selection'] = self.next_selection.__class__.__name__
        features.update(asdict(self.next_selection))
        return features


@dataclass
class MaskConversionWaitingNode(WaitingNode):
    """Waiting node that carries a ``MaskConversion`` in ``next_mask_conversion``.

    It also keeps the ``ColorSelection`` that produced ``color_selected_task``.
    """

    original_task: Task
    color_selected_task: ColorSelectedTask
    color_selected_task_feature: ColorSelectedTaskFeature
    base_operation_set: OperationSet
    color_selection: ColorSelection
    next_mask_conversion: MaskConversion

    def __str__(self):
        head = f'depth: {len(self.base_operation_set.operations)}+1, class: {self.__class__.__name__}, '
        tail = f'ope_set: {self.base_operation_set}, color_selection: {self.color_selection}, next_add_selection: {self.next_mask_conversion}'
        return head + tail

    def depth(self) -> int:
        """Number of operations held by ``base_operation_set``."""
        applied_operations = self.base_operation_set.operations
        return len(applied_operations)

    def evaluation_features(self) -> Dict[str, Any]:
        """Flatten node class, depth, inner task feature and next conversion.

        The task feature is read from ``color_selected_task_feature.task_feature``.
        Later entries overwrite earlier ones on key collisions.
        """
        features: Dict[str, Any] = {
            'node_class': self.__class__.__name__,
            'depth': len(self.base_operation_set.operations),
        }
        features.update(asdict(self.color_selected_task_feature.task_feature))
        features['next_mask_conversion'] = self.next_mask_conversion.__class__.__name__
        features.update(asdict(self.next_mask_conversion))
        return features


@dataclass
class MaskOperationSelectionWaitingNode(WaitingNode):
    """Waiting node that carries a ``MaskOperation`` in ``next_mask_operation``.

    The color selection and mask conversion already chosen are kept alongside.
    """

    original_task: Task
    mask_converted_task: MaskConvertedTask
    mask_converted_task_feature: MaskConvertedTaskFeature
    base_operation_set: OperationSet
    color_selection: ColorSelection
    mask_conversion: MaskConversion
    next_mask_operation: MaskOperation

    def __str__(self):
        head = f'depth: {len(self.base_operation_set.operations)}+1, class: {self.__class__.__name__}, '
        tail = f'ope_set: {self.base_operation_set}, color_selection: {self.color_selection}, add_selection: {self.mask_conversion}, next_mask_ope: {self.next_mask_operation}'
        return head + tail

    def depth(self) -> int:
        """Number of operations held by ``base_operation_set``."""
        applied_operations = self.base_operation_set.operations
        return len(applied_operations)

    def evaluation_features(self) -> Dict[str, Any]:
        """Flatten node class, depth, inner task feature and next mask operation.

        The task feature is read from ``mask_converted_task_feature.task_feature``.
        Later entries overwrite earlier ones on key collisions.
        """
        features: Dict[str, Any] = {
            'node_class': self.__class__.__name__,
            'depth': len(self.base_operation_set.operations),
        }
        features.update(asdict(self.mask_converted_task_feature.task_feature))
        features['next_mask_operation'] = self.next_mask_operation.__class__.__name__
        features.update(asdict(self.next_mask_operation))
        return features


@dataclass
class ColorChannelSelectionOperationWaitingNode(WaitingNode):
    """Waiting node carrying a ``ColorChannelSelection`` to apply next."""

    original_task: Task
    task: Task
    task_feature: TaskFeature
    base_operation_set: OperationSet
    next_color_channel_selection: ColorChannelSelection

    def __str__(self):
        head = f'depth: {len(self.base_operation_set.operations)}+1, class: {self.__class__.__name__}, '
        tail = f'ope_set: {self.base_operation_set}, next_color_channeling: {self.next_color_channel_selection}'
        return head + tail

    def depth(self) -> int:
        """Number of operations held by ``base_operation_set``."""
        applied_operations = self.base_operation_set.operations
        return len(applied_operations)

    def evaluation_features(self) -> Dict[str, Any]:
        """Flatten node class, depth, task feature and next channel selection.

        The selection's class name is stored under the key ``'next_operation'``.
        Later entries overwrite earlier ones on key collisions.
        """
        features: Dict[str, Any] = {
            'node_class': self.__class__.__name__,
            'depth': len(self.base_operation_set.operations),
        }
        features.update(asdict(self.task_feature))
        features['next_operation'] = self.next_color_channel_selection.__class__.__name__
        features.update(asdict(self.next_color_channel_selection))
        return features


@dataclass
class ColorChannelMaskConversionWaitingNode(WaitingNode):
    """Waiting node carrying a ``MaskConversion`` for a channel-selected task.

    ``evaluation_features`` is not overridden, so the inherited version
    (which raises ``NotImplementedError``) applies.
    """

    original_task: Task
    task: ColorChannelSelectedTask
    task_feature: TaskFeature
    base_operation_set: OperationSet
    color_channel_selection: ColorChannelSelection
    next_mask_conversion: MaskConversion

    def __str__(self):
        # NOTE(review): the trailing 'next_mask_ope' entry prints
        # next_mask_conversion a second time; kept as-is.
        head = f'depth: {len(self.base_operation_set.operations)}+1, class: {self.__class__.__name__}, '
        tail = f'ope_set: {self.base_operation_set}, color_channel_selection: {self.color_channel_selection}, next_mask_conversion: {self.next_mask_conversion}, next_mask_ope: {self.next_mask_conversion}'
        return head + tail

    def depth(self) -> int:
        """Number of operations held by ``base_operation_set``."""
        applied_operations = self.base_operation_set.operations
        return len(applied_operations)


@dataclass
class ColorChannelMergeWaitingNode(WaitingNode):
    """Waiting node carrying a ``ChannelMergeOperation`` to apply next.

    ``evaluation_features`` is not overridden, so the inherited version
    (which raises ``NotImplementedError``) applies.
    """

    original_task: Task
    task: ColorChannelMaskConvertedTask
    task_feature: TaskFeature
    base_operation_set: OperationSet
    color_channel_selection: ColorChannelSelection
    mask_conversion: MaskConversion
    next_merge_operation: ChannelMergeOperation

    def __str__(self):
        head = f'depth: {len(self.base_operation_set.operations)}+1, class: {self.__class__.__name__}, '
        tail = f'ope_set: {self.base_operation_set}, color_channel_selection: {self.color_channel_selection}, mask_conversion: {self.mask_conversion}, next_merge_operation: {self.next_merge_operation}'
        return head + tail

    def depth(self) -> int:
        """Number of operations held by ``base_operation_set``."""
        applied_operations = self.base_operation_set.operations
        return len(applied_operations)


@dataclass
class PartitionSelectionWaitingNode(WaitingNode):
    """Waiting node carrying a ``PartitionSelection`` to apply next.

    ``evaluation_features`` is not overridden, so the inherited version
    (which raises ``NotImplementedError``) applies.
    """

    original_task: Task
    task: Task
    task_feature: TaskFeature
    base_operation_set: OperationSet
    next_partition_selection: PartitionSelection

    def __str__(self):
        head = f'depth: {len(self.base_operation_set.operations)}+1, class: {self.__class__.__name__}, '
        tail = f'ope_set: {self.base_operation_set}, next_partition_sel: {self.next_partition_selection}'
        return head + tail

    def depth(self) -> int:
        """Number of operations held by ``base_operation_set``."""
        applied_operations = self.base_operation_set.operations
        return len(applied_operations)


@dataclass
class PartitionMergeWaitingNode(WaitingNode):
    """Waiting node carrying a ``PartitionMergeOperation`` to apply next.

    ``evaluation_features`` is not overridden, so the inherited version
    (which raises ``NotImplementedError``) applies.
    """

    original_task: Task
    task: PartitionSelectionTask
    task_feature: TaskFeature
    base_operation_set: OperationSet
    partition_selection: PartitionSelection
    next_partition_merge_operation: PartitionMergeOperation

    def __str__(self):
        head = f'depth: {len(self.base_operation_set.operations)}+1, class: {self.__class__.__name__}, '
        tail = f'ope_set: {self.base_operation_set}, partition_sel: {self.partition_selection}, partition_merge: {self.next_partition_merge_operation}'
        return head + tail

    def depth(self) -> int:
        """Number of operations held by ``base_operation_set``."""
        applied_operations = self.base_operation_set.operations
        return len(applied_operations)


@dataclass()
class CompletedNode(Node):
    """Base class for nodes that are processed immediately.

    Subclasses provide two hashes over their arrays: one over the train
    data only and one intended to cover all data.
    """
    # This node won't be added to priority queue. This is processed immediately and converted to next List[WaitingNode].

    # Link to the waiting node above this one; NodeTree.of follows it when
    # walking towards the root (None terminates the walk).
    parent_waiting_node: 'WaitingNode'

    def train_arr_hash(self) -> int:
        """Return a hash built from the train arrays; subclasses override."""
        raise NotImplementedError()

    def all_arr_hash(self) -> int:
        """Return a hash built from train and test arrays; subclasses override."""
        raise NotImplementedError()


@dataclass
class UniformOperationCompletedNode(CompletedNode):
    """Completed node holding a plain ``Task`` and the operations applied so far."""

    original_task: Task
    task: Task
    task_feature: TaskFeature
    base_operation_set: OperationSet

    def __str__(self):
        text = f'depth: {len(self.base_operation_set.operations)}, class: {self.__class__.__name__}, ope_set: {self.base_operation_set}'
        return text

    def train_arr_hash(self) -> int:
        """Hash of the class name followed by the '_'-joined train input arrays."""
        tag = self.__class__.__name__.encode('utf-8')
        encoded_inputs = [np_to_str(io.input_arr) for io in self.task.train]
        return hash(tag + b'_'.join(encoded_inputs))

    def all_arr_hash(self) -> int:
        """Hash of the class name followed by the '_'-joined train and test input arrays."""
        tag = self.__class__.__name__.encode('utf-8')
        encoded_inputs = [np_to_str(io.input_arr) for io in self.task.train + self.task.test]
        return hash(tag + b'_'.join(encoded_inputs))


@dataclass
class ColorSelectionCompletedNode(CompletedNode):
    """Completed node holding a color-selected task and the selection used."""

    original_task: Task
    color_selected_task: ColorSelectedTask
    color_selected_task_feature: ColorSelectedTaskFeature
    base_operation_set: OperationSet
    color_selection: ColorSelection

    def __str__(self):
        head = f'depth: {len(self.base_operation_set.operations)}, class: {self.__class__.__name__}, '
        tail = f'base_ope: {self.base_operation_set}, color_sele: {self.color_selection}'
        return head + tail

    def train_arr_hash(self) -> int:
        """Hash of class name + train input arrays + train masks ('_'-joined)."""
        selected = self.color_selected_task
        pieces = [np_to_str(io.input_arr) for io in selected.train]
        pieces.extend(np_to_str(mask) for mask in selected.train_masks)
        return hash(self.__class__.__name__.encode('utf-8') + b'_'.join(pieces))

    def all_arr_hash(self) -> int:
        """Hash of class name + train/test input arrays + train/test masks."""
        selected = self.color_selected_task
        pieces = [np_to_str(io.input_arr) for io in selected.train + selected.test]
        pieces.extend(np_to_str(mask) for mask in selected.train_masks + selected.test_masks)
        return hash(self.__class__.__name__.encode('utf-8') + b'_'.join(pieces))


@dataclass
class MaskConversionCompletedNode(CompletedNode):
    """Completed node holding a mask-converted task plus the selection and conversion used."""

    original_task: Task
    mask_converted_task: MaskConvertedTask
    mask_converted_task_feature: MaskConvertedTaskFeature
    base_operation_set: OperationSet
    color_selection: ColorSelection
    mask_conversion: MaskConversion

    def __str__(self):
        head = f'depth: {len(self.base_operation_set.operations)}, class: {self.__class__.__name__}, '
        tail = f'base_ope: {self.base_operation_set}, color_sele: {self.color_selection}, add_sele: {self.mask_conversion}'
        return head + tail

    def train_arr_hash(self) -> int:
        """Hash of class name + train input arrays + train masks ('_'-joined)."""
        converted = self.mask_converted_task
        pieces = [np_to_str(io.input_arr) for io in converted.train]
        pieces.extend(np_to_str(mask) for mask in converted.train_masks)
        return hash(self.__class__.__name__.encode('utf-8') + b'_'.join(pieces))

    def all_arr_hash(self) -> int:
        """Hash of class name + train/test input arrays + train/test masks."""
        converted = self.mask_converted_task
        pieces = [np_to_str(io.input_arr) for io in converted.train + converted.test]
        pieces.extend(np_to_str(mask) for mask in converted.train_masks + converted.test_masks)
        return hash(self.__class__.__name__.encode('utf-8') + b'_'.join(pieces))


@dataclass
class ColorChannelSelectionCompletedNode(CompletedNode):
    """Completed node holding a color-channel-selected task and the selection used."""

    original_task: Task
    task: ColorChannelSelectedTask
    feature: TaskFeature
    base_operation_set: OperationSet
    color_channel_selection: ColorChannelSelection

    def __str__(self):
        return f'depth: {len(self.base_operation_set.operations)}, class: {self.__class__.__name__}, ' \
               f'base_ope: {self.base_operation_set}, color_sele: {self.color_channel_selection}'

    def train_arr_hash(self) -> int:
        """Hash of class name + train input arrays + train (color, mask) pairs."""
        return hash(bytes(self.__class__.__name__, encoding='utf-8') +
                    b'_'.join(chain(
                        map(lambda io: np_to_str(io.input_arr), self.task.train),
                        chain.from_iterable([(to_bytes(c), np_to_str(m)) for p_l in self.task.train_color_mask_pairs for c, m in p_l]),
                    )))

    def all_arr_hash(self) -> int:
        """Hash of class name + train/test input arrays + train/test (color, mask) pairs.

        The input arrays of both splits are hashed, matching the train + test
        coverage already used for the mask pairs here and for the inputs in
        UniformOperationCompletedNode.all_arr_hash.  Previously only the test
        inputs were included.
        """
        return hash(bytes(self.__class__.__name__, encoding='utf-8') +
                    b'_'.join(chain(
                        # chain() avoids assuming train/test are the same sequence type.
                        map(lambda io: np_to_str(io.input_arr), chain(self.task.train, self.task.test)),
                        chain.from_iterable([(to_bytes(c), np_to_str(m)) for p_l in self.task.train_color_mask_pairs + self.task.test_color_mask_pairs for c, m in p_l]),
                    )))


@dataclass
class ColorChannelMaskConversionCompletedNode(CompletedNode):
    """Completed node holding a channel-mask-converted task plus the selection and conversion used."""

    original_task: Task
    task: ColorChannelMaskConvertedTask
    feature: TaskFeature
    base_operation_set: OperationSet
    color_selection: ColorChannelSelection
    mask_conversion: MaskConversion

    def __str__(self):
        return f'depth: {len(self.base_operation_set.operations)}, class: {self.__class__.__name__}, ' \
               f'base_ope: {self.base_operation_set}, mask_conversion: {self.mask_conversion}'

    def train_arr_hash(self) -> int:
        """Hash of class name + train inputs + train original and converted (color, mask) pairs."""
        return hash(bytes(self.__class__.__name__, encoding='utf-8') +
                    b'_'.join(chain(
                        map(lambda io: np_to_str(io.input_arr), self.task.train),
                        chain.from_iterable([(to_bytes(c), np_to_str(m)) for p_l in self.task.train_original_color_mask_pairs for c, m in p_l]),
                        chain.from_iterable([(to_bytes(c), np_to_str(m)) for p_l in self.task.train_color_mask_pairs for c, m in p_l]),
                    )))

    def all_arr_hash(self) -> int:
        """Hash of class name + train/test inputs + train/test original and converted pairs.

        The input arrays of both splits are hashed, matching the train + test
        coverage already used for the mask pairs here and for the inputs in
        UniformOperationCompletedNode.all_arr_hash.  Previously only the test
        inputs were included.
        """
        return hash(bytes(self.__class__.__name__, encoding='utf-8') +
                    b'_'.join(chain(
                        # chain() avoids assuming train/test are the same sequence type.
                        map(lambda io: np_to_str(io.input_arr), chain(self.task.train, self.task.test)),
                        chain.from_iterable([(to_bytes(c), np_to_str(m)) for p_l in self.task.train_original_color_mask_pairs + self.task.test_original_color_mask_pairs for c, m in p_l]),
                        chain.from_iterable([(to_bytes(c), np_to_str(m)) for p_l in self.task.train_color_mask_pairs + self.task.test_color_mask_pairs for c, m in p_l]),
                    )))


@dataclass
class PartitionSelectionCompletedNode(CompletedNode):
    """Completed node holding a partition-selected task and the selection used."""

    original_task: Task
    task: PartitionSelectionTask
    feature: TaskFeature
    base_operation_set: OperationSet
    partition_selection: PartitionSelection

    def __str__(self):
        return f'depth: {len(self.base_operation_set.operations)}, class: {self.__class__.__name__}, ' \
               f'base_ope: {self.base_operation_set}, partition_selection: {self.partition_selection}'

    def train_arr_hash(self) -> int:
        """Hash of class name + train input arrays + train partition location masks."""
        return hash(bytes(self.__class__.__name__, encoding='utf-8') +
                    b'_'.join(chain(
                        map(lambda io: np_to_str(io.input_arr), self.task.train),
                        [to_bytes(v) for v in self.task.train_partitioned_arrays_original_location_masks]),
                    ))

    def all_arr_hash(self) -> int:
        """Hash of class name + train/test input arrays + train/test partition location masks.

        The input arrays of both splits are hashed, matching the train + test
        coverage already used for the location masks here and for the inputs
        in UniformOperationCompletedNode.all_arr_hash.  Previously only the
        test inputs were included.
        """
        return hash(bytes(self.__class__.__name__, encoding='utf-8') +
                    b'_'.join(chain(
                        # chain() avoids assuming train/test are the same sequence type.
                        map(lambda io: np_to_str(io.input_arr), chain(self.task.train, self.task.test)),
                        [to_bytes(v) for v in self.task.train_partitioned_arrays_original_location_masks + self.task.test_partitioned_arrays_original_location_masks]),
                    ))


@dataclass(frozen=True)
class NodeTree:
    """Path of completed nodes, stored root first.

    ``NodeTree.of`` builds the list by walking parent links upwards and then
    reversing the result.
    """

    completed_nodes: List[CompletedNode]

    def __str__(self):
        return '\n'.join(str(n) for n in self.completed_nodes)

    @classmethod
    def of(cls, completed_node: CompletedNode) -> 'NodeTree':
        """Collect every CompletedNode from ``completed_node`` up to the root.

        Waiting nodes on the way are stepped over.  The walk ends when a
        parent link is ``None``; any other object raises NotImplementedError.
        """
        collected = []
        cursor = completed_node
        # TODO root node
        while cursor is not None:
            if isinstance(cursor, CompletedNode):
                collected.append(cursor)
                cursor = cursor.parent_waiting_node
            elif isinstance(cursor, WaitingNode):
                cursor = cursor.parent_completed_node
            else:
                raise NotImplementedError()
        collected.reverse()
        return cls(collected)

    @classmethod
    def replaced_new_node_tree(cls, node_tree: 'NodeTree', node_depth: int, node: CompletedNode) -> 'NodeTree':
        """Return a new tree equal to ``node_tree`` with index ``node_depth`` replaced by ``node``."""
        nodes = copy.copy(node_tree.completed_nodes)  # shallow copy; the source tree is untouched
        nodes[node_depth] = node
        return cls(nodes)

    def to_operation_set(self) -> OperationSet:
        """Rebuild the OperationSet described by this path.

        Selection/conversion nodes only record their element; the following
        UniformOperationCompletedNode whose last operation is not a
        UniformOperation combines the recorded elements into one composite
        operation.  Any exception is printed and an empty OperationSet is
        returned instead.
        """
        # TODO found a bug related to MultiColorChannelOperation.
        try:
            operations = []
            pending_color_selection = None
            pending_mask_conversion = None
            pending_channel_selection = None
            pending_partition_selection = None

            # TODO too dirty.

            root = self.completed_nodes[0]
            assert len(root.base_operation_set.operations) == 0, root
            for node in self.completed_nodes[1:]:  # first element is root.
                if isinstance(node, UniformOperationCompletedNode):
                    last_operation = node.base_operation_set.operations[-1]
                    if isinstance(last_operation, UniformOperation):
                        # Plain operation: taken as-is, pending elements stay untouched.
                        operations.append(last_operation)
                        continue

                    if pending_color_selection is not None:
                        rebuilt = ColorOperation(pending_color_selection, pending_mask_conversion, last_operation.mask_operation)
                    elif pending_channel_selection is not None:
                        rebuilt = MultiColorChannelOperation(pending_channel_selection, pending_mask_conversion, last_operation.channel_merge_operation)
                    elif pending_partition_selection is not None:
                        rebuilt = PartitionOperation(pending_partition_selection, last_operation.partition_merge_operation)
                    else:
                        raise NotImplementedError()

                    operations.append(rebuilt)
                    # A composite operation consumes every pending element.
                    pending_color_selection = None
                    pending_mask_conversion = None
                    pending_channel_selection = None
                    pending_partition_selection = None
                elif isinstance(node, ColorSelectionCompletedNode):
                    pending_color_selection = node.color_selection
                elif isinstance(node, MaskConversionCompletedNode):
                    pending_mask_conversion = node.mask_conversion
                elif isinstance(node, ColorChannelMaskConversionCompletedNode):
                    pending_mask_conversion = node.mask_conversion
                elif isinstance(node, ColorChannelSelectionCompletedNode):
                    pending_channel_selection = node.color_channel_selection
                elif isinstance(node, PartitionSelectionCompletedNode):
                    pending_partition_selection = node.partition_selection
                else:
                    raise ValueError()

            return OperationSet(operations)
        except Exception as e:
            print(f'error: {e}')
        return OperationSet([])

    def waiting_nodes(self) -> List[WaitingNode]:
        """Return the non-None ``parent_waiting_node`` of each completed node, in order."""
        parents = (n.parent_waiting_node for n in self.completed_nodes)
        return [p for p in parents if p is not None]


class NodeEvaluator:
    """Interface for objects that score waiting nodes.

    Both methods raise NotImplementedError here; subclasses implement them.
    """

    def evaluate(self, node: WaitingNode):
        """Score a single waiting node."""
        raise NotImplementedError()

    def evaluate_nodes(self, nodes: List[WaitingNode]):
        """Score every waiting node in ``nodes``."""
        raise NotImplementedError()


class RandomNodeEvaluator(NodeEvaluator):
    """Evaluator that scores nodes with depth-scaled random numbers."""

    def evaluate(self, node: WaitingNode):
        """Store ``uniform(0, 1) * node.depth()`` in ``node.cache_pred_distance``.

        A node whose depth is 0 therefore always receives 0.
        """
        noise = random.uniform(0, 1)
        node.cache_pred_distance = noise * node.depth()

    def evaluate_nodes(self, nodes: List[WaitingNode]):
        """Score each node in ``nodes`` in place, in iteration order."""
        for waiting_node in nodes:
            self.evaluate(waiting_node)


@dataclass
class AnswerStorageElement:
    """One stored answer: a task name, an operation set and whether it was correct.

    ``depth`` passed to the constructor is ignored; ``__post_init__`` always
    recomputes it from ``operation_set``.
    """

    task_name: str
    correct: bool
    depth: int
    operation_set: OperationSet

    def __post_init__(self):
        applied_operations = self.operation_set.operations
        self.depth = len(applied_operations)

    def validate(self):
        """Re-run the operation set on the task and compare with ``correct``.

        Returns True when the recomputed match flag equals ``self.correct``.
        Prints a message and returns False on a mismatch or when an
        OperationInconsistencyException is raised.
        """
        task = TaskLoader().get_task(self.task_name)
        try:
            a = AnswerMatcher.is_train_test_all_match_if_operated(task, self.operation_set)
        except OperationInconsistencyException:
            print(f'{self.task_name} OperationInconsistencyException')
            return False
        else:
            if a != self.correct:
                print(f'{self.task_name} correct inconsistency. {self.correct}_{a}')
                return False

        return True

    def __hash__(self):
        # Hash the textual representation so instances can be stored in a set.
        text = repr(self)
        return hash(text)


@dataclass
class AnsweredSearchResult:
    """A single answer candidate produced by a search."""

    # Operation set that makes up this answer.
    operation_set: OperationSet
    # Indexed per test input by create_submission; None until filled in.
    test_output_arr: Tuple[np.ndarray] = None
    # Whether the test output was correct; None when not known.
    test_correct: Optional[bool] = None


@dataclass
class AnsweredSearchResults:
    """All answer candidates found for one task plus search statistics."""

    task: Task
    results: List[AnsweredSearchResult]
    zero_depth_search_time: float
    spent_time: float
    searched_total_node: int

    def summary(self):
        """Return one line per result, joined with newlines."""
        lines = []
        for i, r in enumerate(self.results):
            lines.append(
                f'{self.task.name}_{i}, '
                f'correct: {str(r.test_correct):>5}, '
                f'node: {self.searched_total_node:>6}, '
                f'zero_depth_sec: {int(self.zero_depth_search_time):>5}, sec: {int(self.spent_time):>5}, '
                f'depth: {len(r.operation_set.operations)}, operation_set: {r.operation_set}'
            )
        return '\n'.join(lines)

    def final_test_correct(self):
        """Return True when at least one result has a truthy ``test_correct``."""
        return any(r.test_correct for r in self.results)

    def to_answer_storage_elements(self) -> List[AnswerStorageElement]:
        """Convert every result into an AnswerStorageElement for this task."""
        task_name = self.task.name
        converted = []
        for r in self.results:
            depth = len(r.operation_set.operations)
            converted.append(AnswerStorageElement(task_name, r.test_correct, depth, r.operation_set))
        return converted


@dataclass
class NotAnsweredSearchResult:
    """Outcome of a search that ended with an exception instead of an answer."""

    task: Task
    exception: Exception
    spent_time: float
    searched_total_node: int

    def final_test_correct(self):
        """Always None: there is no answer to judge."""
        return None

    def summary(self):
        """Return a one-line report naming the exception class."""
        pieces = (
            f'{self.task.name}__, ',
            f'correct:  None, ',
            f'node: {self.searched_total_node:>6}, sec: {int(self.spent_time):>5}, ',
            f'exception: {self.exception.__class__.__name__}',
        )
        return ''.join(pieces)


@dataclass
class AnswerStorage:
    """Set of AnswerStorageElement with helpers to filter, print and group it."""

    elements: Set[AnswerStorageElement]

    def validate(self):
        """Keep only the elements whose own ``validate()`` returns a truthy value."""
        self.elements = {element for element in self.elements if element.validate()}

    def add(self, element: AnswerStorageElement):
        """Insert ``element`` into the set."""
        self.elements.add(element)

    def get_text(self) -> str:
        """Return the reprs of all elements, one per line.

        Sorted by task name, then correct answers first, then depth.
        """
        def sort_key(element):
            return element.task_name, not element.correct, element.depth

        ordered = sorted(self.elements, key=sort_key)
        return '\n'.join(map(repr, ordered))

    def get_only_correct_answer_storage(self) -> 'AnswerStorage':
        """Return a new storage containing only elements with a truthy ``correct``."""
        correct_only = set(filter(lambda element: element.correct, self.elements))
        return AnswerStorage(correct_only)

    def get_task_grouped_elements(self) -> List[Tuple[str, List[AnswerStorageElement]]]:
        """Return ``(task_name, elements)`` pairs ordered by task name."""
        def by_task(element):
            return element.task_name

        # groupby only merges adjacent items, hence the sort by the same key.
        ordered = sorted(self.elements, key=by_task)
        grouped = []
        for task_name, members in groupby(ordered, key=by_task):
            grouped.append((task_name, list(members)))
        return grouped


def load_answer_storage() -> AnswerStorage:
    """Load the answer storage file and return its validated content.

    Returns an empty storage when the file does not exist.  Lines that cannot
    be parsed are skipped (best effort).  The loaded storage is validated
    before it is returned, which drops elements failing ``validate()``.
    """
    storage_path = PathConfig.OPERATION_ANSWER_STORAGE
    if not storage_path.exists():
        return AnswerStorage(set())

    elements: List[AnswerStorageElement] = []
    with open(str(storage_path), mode='r', encoding='utf-8') as f:
        for line in f:
            try:
                elements.append(str_to_AnswerStorageElement(line))
            except Exception:
                # Skip unparsable lines, but let KeyboardInterrupt / SystemExit
                # propagate (a bare ``except:`` would swallow those as well).
                continue
    storage = AnswerStorage(set(elements))
    storage.validate()
    return storage


def save_answer_storage(storage: AnswerStorage):
    """Write ``storage.get_text()`` to the answer storage file, replacing old content.

    Works on the first save as well: the file does not have to exist yet
    (an unconditional ``unlink()`` would raise FileNotFoundError there) and
    the parent directory is created when missing.
    """
    storage_path = PathConfig.OPERATION_ANSWER_STORAGE
    storage_path.parent.mkdir(parents=True, exist_ok=True)
    # Mode 'w' truncates an existing file, so no prior unlink is needed.
    with open(str(storage_path), mode='w', encoding='utf-8') as f:
        f.write(storage.get_text())


def update_answer_storage(elements: List[AnswerStorageElement], verbose: bool = False):
    """Load the answer storage, add ``elements`` to it and save it again.

    Each element's ``validate()`` is called before it is added; its return
    value is not used, so the element is added either way.  With
    ``verbose=True`` progress messages and the storage text are printed.
    """
    if verbose:
        print('load_answer storage')
    storage = load_answer_storage()
    if verbose:
        print(storage.get_text())
        print('add answer storage')

    for new_element in elements:
        new_element.validate()
        storage.add(new_element)

    if verbose:
        print('save answer storage')
        print(storage.get_text())
    save_answer_storage(storage)


@dataclass
class AnswerFoundException(Exception):
    """Exception that carries an operation set in ``operation_set``."""

    # Presumably the operation set that solved the task -- confirm at the raise site.
    operation_set: OperationSet


class NoImprovementException(Exception):
    """Exception type with the short label 'No improve' in ``MESSAGE``."""

    MESSAGE = 'No improve'


class MaxDepthExceededException(Exception):
    """Exception type with the short label 'Max depth' in ``MESSAGE``."""

    MESSAGE = 'Max depth'


class MaxNodeExceededException(Exception):
    """Exception type with the short label 'Max node' in ``MESSAGE``."""

    MESSAGE = 'Max node'


class TimeoutException(Exception):
    """Exception type with the short label 'Timeout' in ``MESSAGE``."""

    MESSAGE = 'Timeout'


def get_all_operation_classes():
    """Return the four top-level operation classes as a new list."""
    operation_classes = [
        UniformOperation,
        ColorOperation,
        MultiColorChannelOperation,
        PartitionOperation,
    ]
    return operation_classes


def get_all_operation_element_classes():
    """Return an iterator over the direct subclasses of every operation element base class.

    The result is a one-shot ``itertools.chain`` object; wrap it in ``list``
    when it has to be traversed more than once.
    """
    base_classes = (
        UniformOperation,
        ColorSelection,
        MaskConversion,
        MaskOperation,
        ColorChannelSelection,
        ChannelMergeOperation,
        PartitionSelection,
        PartitionMergeOperation,
    )
    # Subclass lists are collected eagerly, at call time.
    subclass_lists = [base.__subclasses__() for base in base_classes]
    return chain(*subclass_lists)


@unique
class BackGroundColorSelectionMode(StrNameEnum):
    """Ways of choosing the background color.

    Member values come from ``auto()`` as defined by StrNameEnum.
    """

    BLACK = auto()
    MOST_COMMON = auto()


@unique
class AxisV2(StrNameEnum):
    """Axis choices, including combined and diagonal variants.

    Member values come from ``auto()`` as defined by StrNameEnum.
    """

    VERTICAL = auto()
    HORIZONTAL = auto()
    VERTICAL_HORIZONTAL = auto()
    MAIN_DIAGONAL = auto()
    ANTI_DIAGONAL = auto()
    BOTH_DIAGONAL = auto()


@unique
class Corner(StrNameEnum):
    """The four corners, listed clockwise starting at the top left.

    Member values come from ``auto()`` as defined by StrNameEnum.
    """

    TOP_LEFT = auto()
    TOP_RIGHT = auto()
    BOTTOM_RIGHT = auto()
    BOTTOM_LEFT = auto()


@unique
class SpiralDirection(StrNameEnum):
    """Rotation sense of a spiral.

    Member values come from ``auto()`` as defined by StrNameEnum.
    """

    CLOCKWISE = auto()
    ANTICLOCKWISE = auto()


class DebugConfig:
    """Debug switches; every value is an empty string by default.

    The trailing ``# dae9d2b5`` comments record an example task name.
    """

    # operation debug
    OPERATION_DEBUG_TASK_NAME = ''  # dae9d2b5
    OPERATION_DEBUG_OPERATION_SET = ''
    # solve debug
    SOLVE_DEBUG_TASK_NAME = ''  # dae9d2b5

    # train_data_generator debug
    TRAIN_DATA_GENERATION_DEBUG_TASK_NAME = ''


class PathConfig:
    """File-system locations used by the project, all derived from ``ROOT``."""

    # Empty (relative) path in kernel mode, otherwise the directory of this file.
    ROOT: Path = Path('') if RunConfig.RUN_MODE == RunMode.KERNEL else Path(__file__).parent

    # input
    INPUT_ROOT: Path = ROOT / 'input'
    TRAIN_ROOT: Path = INPUT_ROOT / 'training'  # training_and_evaluation
    EVALUATION_ROOT: Path = INPUT_ROOT / 'evaluation'
    TEST_ROOT: Path = INPUT_ROOT / 'test'
    SAMPLE_SUBMISSION: Path = INPUT_ROOT / 'sample_submission.csv'

    # output
    OUTPUT_SUBMISSION: Path = ROOT / 'output' / 'submission.csv'

    # answer_memo
    OPERATION_ANSWER_MEMO_ROOT: Path = ROOT / 'answer_memo'
    OPERATION_ANSWER_TAXONOMY_YAML: Path = OPERATION_ANSWER_MEMO_ROOT / 'answer_taxonomy.yaml'
    OPERATION_ANSWER_TAXONOMY_IMAGE_ROOT: Path = OPERATION_ANSWER_MEMO_ROOT / 'answer_taxonomy'
    # Read by load_answer_storage and written by save_answer_storage.
    OPERATION_ANSWER_STORAGE: Path = OPERATION_ANSWER_MEMO_ROOT / 'answer_storage.txt'
    WRONG_ANSWERS_ROOT: Path = OPERATION_ANSWER_MEMO_ROOT / 'wrong_answers'

    # kernel
    KERNEL_SCRIPT_PATH: Path = ROOT / 'kernel' / 'kernel_script.py'

    # run
    LOG_ROOT: Path = ROOT / 'log'

    # ml_model
    SAVED_MODEL: Path = ROOT / 'saved_model'
    NODE_EVALUATOR_FEATURES = SAVED_MODEL / 'features.pkl'
    NODE_EVALUATOR_CATEGORICAL_FEATURES = SAVED_MODEL / 'categorical_features.pkl'
    NODE_EVALUATOR_MODEL = SAVED_MODEL / 'model.pkl'
    NODE_EVALUATOR_ORDINAL_ENCODER = SAVED_MODEL / 'ordinal_encoder.pkl'
    NODE_EVALUATOR_SAMPLE_DF = SAVED_MODEL / 'sample_df.pkl'

    OPERATION_ELEMENT_INCLUSION_MODEL_ROOT = SAVED_MODEL / 'operation_element_inclusion'
    OPERATION_ELEMENT_INCLUSION_MODEL = OPERATION_ELEMENT_INCLUSION_MODEL_ROOT / 'model.pkl'
    OPERATION_ELEMENT_INCLUSION_MODEL_TARGET_COLUMNS = OPERATION_ELEMENT_INCLUSION_MODEL_ROOT / 'target_columns.pkl'
    OPERATION_ELEMENT_INCLUSION_MODEL_FEATURE_COLUMNS = OPERATION_ELEMENT_INCLUSION_MODEL_ROOT / 'feature_columns.pkl'

    # ml_training_data
    LABELED_TRAINING_DATA_ROOT = ROOT / 'training'


class KernelPathConfig:
    """Paths used when running as a Kaggle kernel."""

    INPUT_ROOT = Path('/kaggle/input/abstraction-and-reasoning-challenge/')
    TRAIN_ROOT: Path = INPUT_ROOT / 'training'
    EVALUATION_ROOT: Path = INPUT_ROOT / 'evaluation'
    # Test tasks are read from the relative 'test_aligned' directory instead of
    # the competition's own test folder, which would be:
    #     TEST_ROOT: Path = INPUT_ROOT / 'test'
    TEST_ROOT: Path = Path('test_aligned')
    SAMPLE_SUBMISSION: Path = INPUT_ROOT / 'sample_submission.csv'
    # File name that save_submission_df writes to in kernel mode.
    SUBMISSION = 'submission_yuki_alignment.csv'


def create_submission(engine_results: List[Union[AnsweredSearchResults, NotAnsweredSearchResult]]):
    """Build the submission table from the search results.

    One row is produced per test input of every task.  ``output`` holds the
    answers encoded by ``parse_str``, separated by spaces and followed by one
    trailing space.  Fewer than three answers are padded with ``None``
    (encoded as '|0|'); an unanswered task yields three such placeholders.

    Rows are collected in a list and turned into a DataFrame once, because
    ``DataFrame.append`` no longer exists in pandas 2.x and copied the whole
    frame on every call.  As a result the frame carries a 0..n-1 index
    rather than a repeated 0.

    Args:
        engine_results: results of the search engine, one entry per task.

    Returns:
        DataFrame with the columns ``output_id`` and ``output``.

    Raises:
        NotImplementedError: for a result that is neither supported type.
    """
    rows = []

    for result in engine_results:
        for i in range(len(result.task.test)):
            if isinstance(result, AnsweredSearchResults):
                answers = [r.test_output_arr[i] for r in result.results]
                # A negative count yields an empty padding list.
                answers += [None] * (3 - len(answers))
            elif isinstance(result, NotAnsweredSearchResult):
                answers = [None] * 3
            else:
                raise NotImplementedError()
            rows.append({
                'output_id': f'{result.task.name}_{i}',
                'output': ' '.join(map(parse_str, answers)) + ' ',
            })
    return DataFrame(rows, columns=['output_id', 'output'])


def parse_str(arr: np.ndarray) -> str:
    """Encode a 2-D grid as '|row|row|...|', each row being its concatenated cell values.

    ``None`` is encoded as the placeholder '|0|'.
    """
    if arr is None:
        return '|0|'

    encoded_rows = [''.join(map(str, row)) for row in arr]
    return '|{}|'.format('|'.join(encoded_rows))


def save_submission_df(submission_df: DataFrame):
    """Write the submission table as CSV without the index column.

    Kernel mode writes to ``KernelPathConfig.SUBMISSION``; every other mode
    writes to ``PathConfig.OUTPUT_SUBMISSION`` after creating its parent
    directory.
    """
    if RunConfig.RUN_MODE != RunMode.KERNEL:
        target = PathConfig.OUTPUT_SUBMISSION
        target.parent.mkdir(parents=True, exist_ok=True)
        submission_df.to_csv(target, index=False)
        return

    submission_df.to_csv(KernelPathConfig.SUBMISSION, index=False)


def plot_one(ax, arr: np.ndarray, i, train_or_test, input_or_output):
    """Draw one grid on ``ax`` with a fixed ten-color palette and cell borders.

    Grid lines sit on the cell edges (ticks at -0.5, 0.5, ...); tick labels
    are hidden.  The title is ``train_or_test`` and ``input_or_output``
    separated by a space.  ``i`` is not used.
    """
    palette = colors.ListedColormap(
        ['#000000', '#0074D9', '#FF4136', '#2ECC40', '#FFDC00',
         '#AAAAAA', '#F012BE', '#FF851B', '#7FDBFF', '#870C25'])
    value_range = colors.Normalize(vmin=0, vmax=9)

    ax.imshow(arr, cmap=palette, norm=value_range)
    ax.grid(True, which='both', color='lightgrey', linewidth=0.5)

    row_edges = [row - 0.5 for row in range(len(arr) + 1)]
    ax.set_yticks(row_edges)
    column_edges = [column - 0.5 for column in range(len(arr[0]) + 1)]
    ax.set_xticks(column_edges)

    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_title(' '.join((train_or_test, input_or_output)))


def plot_task(task: Task, show: bool, save_path: Optional[Path]):
    """Plot every train and test pair of ``task``: inputs on top, outputs below.

    The figure is saved to ``save_path`` when given (parent directories are
    created), shown when ``show`` is true, and closed afterwards.
    """
    examples = task.train + task.test
    tags = ['train'] * len(task.train) + ['test'] * len(task.test)
    column_count = len(examples)
    row_count = 2

    fig, axs = plt.subplots(row_count, column_count, figsize=(2 * column_count, 2 * row_count))
    for column, (example, tag) in enumerate(zip(examples, tags)):
        plot_one(axs[0, column], example.input_arr, column, tag, 'input')
        plot_one(axs[1, column], example.output_arr, column, tag, 'output')
    plt.tight_layout()

    if save_path:
        save_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(save_path)
    if show:
        plt.show()
    plt.close()


def plot_task_with_operation_set(task: Task, operation_set: OperationSet, show: bool, save_path: Optional[Path]):
    """Plot inputs, expected outputs and the inputs after applying ``operation_set``.

    Three rows are drawn per train/test pair.  The figure is saved to
    ``save_path`` when given, shown when ``show`` is true, and closed
    afterwards.
    """
    raw_examples = task.train + task.test
    column_count = len(raw_examples)
    row_count = 3

    # Execute first, so a failing operation set creates no figure.
    applied_task = TaskOperationSetExecutor().execute(task, operation_set)
    applied_examples = applied_task.train + applied_task.test

    fig, axs = plt.subplots(row_count, column_count, figsize=(3 * column_count, 3 * row_count))
    for column, (raw_io, applied_io) in enumerate(zip(raw_examples, applied_examples)):
        plot_one(axs[0, column], raw_io.input_arr, column, 'train?', 'input')
        plot_one(axs[1, column], raw_io.output_arr, column, 'train?', 'output')
        plot_one(axs[2, column], applied_io.input_arr, column, 'train?', 'operated')
    plt.tight_layout()

    if save_path:
        save_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(save_path)
    if show:
        plt.show()
    plt.close()


def plot_task_with_result_set(task: Task, search_results: AnsweredSearchResults, show: bool, save_path: Optional[Path]):
    """Plot a task together with one extra row per search result.

    Rows 0 and 1 show the raw inputs and outputs.  Each following row shows
    the inputs of the task obtained by executing one result's
    ``operation_set`` on ``task``.

    Args:
        task: task to visualise.
        search_results: its ``results`` each contribute one plotted row.
        show: when truthy, display the figure with ``plt.show()``.
        save_path: when truthy, the figure is written there (parent
            directories are created as needed).
    """
    raw_pairs = task.train + task.test
    n_cols = len(raw_pairs)
    n_rows = 2 + len(search_results.results)

    executed_tasks = []
    for result in search_results.results:
        executed_tasks.append(TaskOperationSetExecutor().execute(task, result.operation_set))

    _, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows))
    for col, raw_io in enumerate(raw_pairs):
        plot_one(axes[0, col], raw_io.input_arr, col, 'train?', 'input')
        plot_one(axes[1, col], raw_io.output_arr, col, 'train?', 'output')
    for result_idx, executed in enumerate(executed_tasks):
        for col, executed_io in enumerate(executed.train + executed.test):
            # NOTE(review): the result index (not the column index) is passed as
            # plot_one's third argument, exactly as before - confirm this is intended.
            plot_one(axes[result_idx + 2, col], executed_io.input_arr, result_idx, 'train?', 'input')
    plt.tight_layout()

    if save_path:
        save_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(save_path)
    if show:
        plt.show()
    # Always release the figure, whether or not it was shown or saved.
    plt.close()


@dataclass(frozen=True)
class Padding(UniformOperation):
    """Extend a 2-D grid on one side by ``k`` times its own extent.

    The padding mode is translated to a ``np.pad`` mode; for ``MIRROR_2``
    (``'reflect'``) the per-repetition extent is one less than the grid's.
    """
    padding_mode: PaddingMode
    direction: Direction
    k: int

    def __call__(self, arr: np.ndarray) -> np.ndarray:
        """Return ``arr`` padded on ``self.direction``.

        Raises:
            ValueError: if the padding mode or the direction is not handled.
        """
        numpy_modes = (
            (PaddingMode.REPEAT, 'wrap'),
            (PaddingMode.MIRROR_1, 'symmetric'),
            (PaddingMode.MIRROR_2, 'reflect'),
            (PaddingMode.EDGE, 'edge'),
        )
        numpy_mode = next((name for mode, name in numpy_modes if self.padding_mode == mode), None)
        if numpy_mode is None:
            raise ValueError(self.padding_mode)

        n_rows, n_cols = arr.shape
        if self.padding_mode == PaddingMode.MIRROR_2:
            # 'reflect' does not repeat the edge line, so one repetition is one line shorter.
            n_rows -= 1
            n_cols -= 1

        extra_rows = self.k * n_rows
        extra_cols = self.k * n_cols
        pad_widths = (
            (Direction.TOP, ((extra_rows, 0), (0, 0))),
            (Direction.BOTTOM, ((0, extra_rows), (0, 0))),
            (Direction.LEFT, ((0, 0), (extra_cols, 0))),
            (Direction.RIGHT, ((0, 0), (0, extra_cols))),
        )
        pad_width = next((width for side, width in pad_widths if self.direction == side), None)
        if pad_width is None:
            raise ValueError(self.direction)

        return np.pad(arr, pad_width, mode=numpy_mode)


@dataclass(frozen=True)
class Resize(UniformOperation):
    """Enlarge a grid by repeating each row and/or column ``ratio`` times."""
    axis: Axis
    ratio: int  # TODO int? How to resize 3/2?

    def __call__(self, arr: np.ndarray) -> np.ndarray:
        """Return ``arr`` with its elements repeated along the configured axis.

        Raises:
            ValueError: if ``self.axis`` is none of VERTICAL, HORIZONTAL, BOTH.
        """
        stretch_rows = self.axis == Axis.VERTICAL or self.axis == Axis.BOTH
        stretch_cols = self.axis == Axis.HORIZONTAL or self.axis == Axis.BOTH
        if not (stretch_rows or stretch_cols):
            raise ValueError(self.axis)

        resized = arr
        if stretch_rows:
            resized = np.repeat(resized, self.ratio, axis=0)
        if stretch_cols:
            resized = np.repeat(resized, self.ratio, axis=1)
        return resized


@dataclass(frozen=True)
class Flip(UniformOperation):
    """Mirror a 2-D grid across one of four axes.

    * ``UD``    - top/bottom mirror (``np.flipud``)
    * ``LR``    - left/right mirror (``np.fliplr``)
    * ``UL_DR`` - mirror across the upper-left/down-right diagonal (transpose)
    * ``UR_DL`` - mirror across the upper-right/down-left diagonal
    """
    flip_mode: FlipMode

    def __call__(self, arr: np.ndarray) -> np.ndarray:
        """Return the flipped view/copy of ``arr``.

        Raises:
            ValueError: if ``self.flip_mode`` is not one of the four modes.
        """
        if self.flip_mode == FlipMode.UD:
            return np.flipud(arr)
        elif self.flip_mode == FlipMode.LR:
            return np.fliplr(arr)
        elif self.flip_mode == FlipMode.UL_DR:
            return arr.T
        elif self.flip_mode == FlipMode.UR_DL:
            # Anti-diagonal mirror: transpose, then reverse both axes.
            # (Applying flipud twice cancels out and would just return arr.T,
            # making this mode a duplicate of UL_DR.)
            return np.flipud(np.fliplr(arr.T))
        else:
            raise ValueError(self.flip_mode)


@dataclass(frozen=True)
class Rotate(UniformOperation):
    """Rotate a grid by 90, 180 or 270 degrees using ``np.rot90``."""
    angle: int

    def __call__(self, arr: np.ndarray) -> np.ndarray:
        """Return ``arr`` rotated by ``self.angle`` degrees.

        Raises:
            ValueError: if the angle is not 90, 180 or 270.
        """
        supported_angles = [90, 180, 270]
        if self.angle in supported_angles:
            quarter_turns = self.angle // 90
            return np.rot90(arr, quarter_turns)

        raise ValueError(self.angle)


@dataclass(frozen=True)
class LineDeletion(UniformOperation):
    """Remove every full row and full column painted entirely in ``line_color``."""
    line_color: Color

    def __call__(self, arr: np.ndarray) -> np.ndarray:
        """Return ``arr`` without its single-colored separator lines.

        Raises:
            OperationInconsistencyException: if the grid has one cell, has a
                dimension of length 1, contains no full line of the color, or
                would be empty after deletion.
        """
        if arr.size == 1:
            raise OperationInconsistencyException('size == 1')

        if 1 in arr.shape:
            raise OperationInconsistencyException('can not separate')

        is_line_color: np.ndarray = arr == self.line_color
        full_row_indices = np.nonzero(is_line_color.all(axis=1))[0]
        full_col_indices = np.nonzero(is_line_color.all(axis=0))[0]

        if len(full_row_indices) == 0 and len(full_col_indices) == 0:
            raise OperationInconsistencyException('not line found')

        # Columns first, then rows; the two deletions are independent of each other.
        without_cols = np.delete(arr, full_col_indices, axis=1)
        remaining = np.delete(without_cols, full_row_indices, axis=0)

        if 0 in remaining.shape:
            raise OperationInconsistencyException('0 size')

        return remaining


@dataclass(frozen=True)
class FFTCompletion(UniformOperation):
    """Proof-of-concept completion of symmetric/repeating patterns via a 2-D FFT.

    For every color a boolean mask is filtered in the frequency domain and
    thresholded; the resulting masks are painted back onto the input array.
    """

    # TODO GIVE UP implement.
    #  This is just a poc for "SYMMETRY" or "REPEAT" pattern tasks.
    #  If you're interested in this function, let me know. I'll translate it.

    def __call__(self, arr: np.ndarray) -> np.ndarray:
        """Paint the per-color completed masks onto ``arr`` and return it.

        NOTE: ``arr`` is modified in place, and each color is printed to
        stdout while iterating (debug output).
        """
        revs = []
        for color in Color:

            print(color)
            color_hit = arr == color

            # Black and colors that do not occur contribute an all-False mask.
            if color == Color.BLACK or not color_hit.any():
                revs.append(np.full_like(color_hit, fill_value=False))
                continue

            # The filter is applied three times in a row, feeding each result back in.
            rev_arr_int = self.complete_symmetric(color_hit)
            rev_arr_int = self.complete_symmetric(rev_arr_int)
            rev_arr_int = self.complete_symmetric(rev_arr_int)
            revs.append(rev_arr_int)

        # Masks are applied in Color iteration order, so later colors win on overlap.
        for color, rev in zip(Color, revs):
            arr[rev] = color
        return arr

    def complete_symmetric(self, hit_arr, verbose=False):
        """Keep only the spectral peaks of ``hit_arr`` and threshold the inverse FFT.

        Args:
            hit_arr: 2-D array (unpacked as ``h, w``) to be filtered.
            verbose: when truthy, print intermediate arrays and show debug plots.

        Returns:
            Boolean array: inverse-FFT real part compared against
            ``threshold_minimum`` (presumably skimage.filters - verify import).
        """
        h, w = hit_arr.shape

        if verbose:
            print(hit_arr)
        f = np.fft.fftshift(np.fft.fft2(hit_arr))
        if verbose:
            print(f)
        amp = np.abs(f)
        if verbose:
            print(amp)
        amp = amp / h / w * 2
        if verbose:
            print(amp)
            print(f'sum {amp.sum()}')
            print(f'mean {amp.mean()}')
            print(f'max {amp.max()}')
        # TODO tune the peak-detection parameters
        flags = np.array(self.detect_not_peaks_mask(amp))
        if verbose:
            print(flags)
        # Zero every frequency component that was flagged as "not a peak".
        f[flags] = 0
        # filtered_amp is only used by the verbose plots below.
        filtered_amp = amp.copy()
        filtered_amp[f == 0] = 0
        # F3_abs = np.abs(f)  # convert complex values to magnitudes
        # F3_abs_amp = F3_abs / h / w * 2  # AC components: divide by the sample count and double
        # F3_abs_amp[0] = F3_abs_amp[0] / 2  # the DC component (unused here) must not be doubled
        F3_ifft = np.fft.ifft2(np.fft.ifftshift(f))  # IFFT
        F3_ifft_real = F3_ifft.real  # take the real part
        # TODO decide which binarization algorithm to use
        rev_arr_int = F3_ifft_real > threshold_minimum(F3_ifft_real)
        if verbose:
            fig, ax = try_all_threshold(F3_ifft_real, figsize=(10, 8), verbose=False)
            plt.show()

            # visualize
            plt.subplot(171)
            plt.imshow(hit_arr, cmap='gray')
            plt.title('Input Image'), plt.xticks([]), plt.yticks([])
            plt.subplot(172)
            plt.hist(amp.ravel(), bins=100)
            plt.title('Input Image'), plt.xticks([]), plt.yticks([])
            plt.subplot(173)
            plt.imshow(amp, cmap='gray')
            plt.title('Magnitude Spectrum'), plt.xticks([]), plt.yticks([])
            plt.subplot(174)
            plt.imshow(filtered_amp, cmap='gray')
            plt.title('Magnitude Spectrum'), plt.xticks([]), plt.yticks([])
            plt.subplot(175)
            plt.hist(F3_ifft_real.ravel(), bins=100)
            plt.subplot(176)
            plt.imshow(rev_arr_int, cmap='gray')
            plt.title('rev'), plt.xticks([]), plt.yticks([])
            plt.subplot(177)
            plt.imshow(rev_arr_int | hit_arr, cmap='gray')
            plt.title('and'), plt.xticks([]), plt.yticks([])
            plt.show()
        return rev_arr_int

    def detect_not_peaks_mask(self, image, filter_size=3, order=0.05):
        """Return a mask flagging the entries of ``image`` that are not significant peaks.

        A cell counts as a local maximum when it equals the maximum over its
        ``filter_size`` x ``filter_size`` neighbourhood (``maximum_filter``,
        presumably from scipy.ndimage - verify import).

        NOTE(review): the returned mask relies on ``np.ma.array`` combining
        the new mask with the existing one - confirm against numpy.ma docs.
        """
        local_max = maximum_filter(image, footprint=np.ones((filter_size, filter_size)), mode='constant')
        detected_peaks = np.ma.array(image, mask=~(image == local_max))

        # Drop small peaks (peaks below `order` times the largest peak are discarded)
        temp = np.ma.array(detected_peaks, mask=~(detected_peaks >= detected_peaks.max() * order))
        return temp.mask


@dataclass(frozen=True)
class FixedColorMaskFill(MaskOperation):
    """Paint the masked cells of a grid with a fixed color."""
    color: Color

    def __call__(self, arr: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Set ``arr`` to ``self.color`` wherever ``mask`` selects; ``arr`` is modified in place and returned."""
        fill_color = self.color
        arr[mask] = fill_color
        return arr


@dataclass(frozen=True)
class SingleColorMaskFill(MaskOperation):
    """Paint the masked cells with a color chosen from the grid itself."""
    single_color_selection_mode: SingleColorSelectionMode

    def __call__(self, arr: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Fill ``arr`` under ``mask`` with the color picked by the selection mode.

        The color is selected from ``arr`` before it is modified; ``arr`` is
        changed in place and returned.
        """
        selector = ColorSelectionUtil()
        fill_color = selector.select_single_color(arr, self.single_color_selection_mode)
        arr[mask] = fill_color
        return arr


@dataclass(frozen=True)
class MaskCoordsCrop(MaskOperation):
    """Crop a grid to the bounding box of the selected mask cells."""

    def __call__(self, arr: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Return the slice of ``arr`` spanning all True cells of ``mask``.

        An empty mask leaves ``arr`` untouched.
        """
        # TODO raise OperationInconsistencyException?
        if not mask.any():
            return arr

        hit_rows, hit_cols = np.nonzero(mask)
        top, bottom = hit_rows.min(), hit_rows.max()
        left, right = hit_cols.min(), hit_cols.max()

        return arr[top:bottom + 1, left:right + 1]


@dataclass(frozen=True)
class FixedSingleColorSelection(ColorSelection):
    """Select the cells that carry one fixed color."""
    color: Color

    def __call__(self, arr: np.ndarray) -> np.ndarray:
        """Return the element-wise comparison of ``arr`` with ``self.color``."""
        selected = arr == self.color
        return selected


@dataclass(frozen=True)
class SingleColorSelection(ColorSelection):
    """Select the cells of a color that is chosen from the grid by frequency."""
    single_color_selection_mode: SingleColorSelectionMode

    def __call__(self, arr: np.ndarray) -> np.ndarray:
        """Return the element-wise comparison of ``arr`` with the selected color."""
        selector = ColorSelectionUtil()
        target_color = selector.select_single_color(arr, self.single_color_selection_mode)
        return arr == target_color


@dataclass(frozen=True)
class MultiColorSelection(ColorSelection):
    """Select every cell except those of the most or least common color."""
    multi_color_selection_mode: MultiColorSelectionMode

    def __call__(self, arr: np.ndarray) -> np.ndarray:
        """Return a mask of the cells that differ from the excluded color.

        Raises:
            NotImplementedError: for any selection mode other than the two
                ``ANY_WITHOUT_*`` modes.
        """
        excluded_by_mode = (
            (MultiColorSelectionMode.ANY_WITHOUT_MOST_COMMON, SingleColorSelectionMode.MOST_COMMON),
            (MultiColorSelectionMode.ANY_WITHOUT_LEAST_COMMON, SingleColorSelectionMode.LEAST_COMMON),
        )
        for selection_mode, excluded_mode in excluded_by_mode:
            if self.multi_color_selection_mode == selection_mode:
                excluded_color = ColorSelectionUtil().select_single_color(arr, excluded_mode)
                return arr != excluded_color

        raise NotImplementedError()


class TaskLoader:
    """Load ``Task`` objects from the JSON files of the configured directories."""

    def get_task(self, name: str) -> Task:
        """Load ``<name>.json`` from the training root, falling back to the evaluation root."""
        file_name = f'{name}.json'
        try:
            return self._get_task(PathConfig.TRAIN_ROOT / file_name)
        except FileNotFoundError:
            return self._get_task(PathConfig.EVALUATION_ROOT / file_name)

    def get_training_tasks(self):
        """Load all training tasks from the kernel or the regular path, depending on the run mode."""
        in_kernel = RunConfig.RUN_MODE == RunMode.KERNEL
        root = KernelPathConfig.TRAIN_ROOT if in_kernel else PathConfig.TRAIN_ROOT
        return self._get_tasks(root)

    def get_evaluation_tasks(self):
        """Load all evaluation tasks from the kernel or the regular path, depending on the run mode."""
        in_kernel = RunConfig.RUN_MODE == RunMode.KERNEL
        root = KernelPathConfig.EVALUATION_ROOT if in_kernel else PathConfig.EVALUATION_ROOT
        return self._get_tasks(root)

    def get_test_tasks(self):
        """Load all test tasks from the kernel or the regular path, depending on the run mode."""
        in_kernel = RunConfig.RUN_MODE == RunMode.KERNEL
        root = KernelPathConfig.TEST_ROOT if in_kernel else PathConfig.TEST_ROOT
        return self._get_tasks(root)

    def _get_tasks(self, root_path: Path) -> List[Task]:
        """Load one task per entry of ``root_path``, in ``iterdir()`` order."""
        loaded = []
        for json_path in root_path.iterdir():
            loaded.append(self._get_task(json_path))
        return loaded

    def _get_task(self, path: Path) -> Task:
        """Parse the JSON file at ``path`` into a task named after the file stem."""
        with open(str(path), 'r') as fp:
            raw = json.load(fp)
        return Task.of(path.stem, raw)

    def is_private_lb_run(self) -> bool:
        """Return False when a kernel test task shares its name with an evaluation task, else True."""
        eval_tasks = self._get_tasks(KernelPathConfig.EVALUATION_ROOT)
        test_tasks = self._get_tasks(KernelPathConfig.TEST_ROOT)

        eval_names = [t.name for t in eval_tasks]
        shared_with_eval = [t for t in test_tasks if t.name in eval_names]
        return not any(shared_with_eval)


def create_image_feature(arr: np.ndarray) -> ImageFeature:
    """Summarise a single grid as an ``ImageFeature``.

    Collects the grid's height and width, its colors, the hit-and-miss
    histogram, and the number of value changes between vertically and
    horizontally adjacent cells.
    """
    n_rows = arr.shape[0]
    n_cols = arr.shape[1]
    palette = [Color.of(value) for value in ColorSelectionUtil().get_colors(arr)]
    histogram = calculate_hit_and_miss_histogram(arr)
    # most_common_color=ColorSelectionUtil().select_single_color(arr, SingleColorSelectionMode.MOST_COMMON),
    # Plain slicing differences; faster than np.diff(arr, axis=0) / np.diff(arr, axis=1).
    between_rows = arr[1:] - arr[:-1]
    between_cols = arr[:, 1:] - arr[:, :-1]
    return ImageFeature(
        height=n_rows,
        width=n_cols,
        colors=palette,
        hit_and_miss_histogram=histogram,
        vertical_edge_num=np.count_nonzero(between_rows),
        horizontal_edge_num=np.count_nonzero(between_cols),
    )


def create_image_diff_feature(original_input_arr: np.ndarray, input_arr: np.ndarray, output_arr: np.ndarray) -> ImageDiffFeature:
    """Describe how far ``input_arr`` still is from ``output_arr``.

    Args:
        original_input_arr: the unmodified task input, used by the two
            "need to change" ratios.
        input_arr: the (possibly already transformed) input grid.
        output_arr: the expected output grid.

    Returns:
        An ``ImageDiffFeature`` built from the two grids' ``ImageFeature``s
        plus size, color, histogram, edge and cell-level differences.
    """
    helper = FeatureUtil()
    src = create_image_feature(input_arr)
    dst = create_image_feature(output_arr)

    height_delta = dst.height - src.height
    width_delta = dst.width - src.width
    # True when one extent is an integer multiple of the other, in either direction.
    height_is_multiple = (dst.height / src.height).is_integer() or (src.height / dst.height).is_integer()
    width_is_multiple = (dst.width / src.width).is_integer() or (src.width / dst.width).is_integer()

    src_colors = set(src.colors)
    dst_colors = set(dst.colors)
    histogram_gap = sum(abs(i_c - o_c) for i_c, o_c in zip(src.hit_and_miss_histogram, dst.hit_and_miss_histogram))

    vertical_edge_gap = abs(dst.vertical_edge_num - src.vertical_edge_num)
    horizontal_edge_gap = abs(dst.horizontal_edge_num - src.horizontal_edge_num)

    return ImageDiffFeature(
        input_image_feature=src,
        output_image_feature=dst,
        dim_height_increase=height_delta,
        dim_width_increase=width_delta,
        dim_height_integer_multiple=height_is_multiple,
        dim_width_integer_multiple=width_is_multiple,
        dim_height_diff=abs(height_delta),
        dim_width_diff=abs(width_delta),
        dim_height_equal=dst.height == src.height,
        dim_width_equal=dst.width == src.width,
        lack_color_num=len(dst_colors - src_colors),
        excess_color_num=len(src_colors - dst_colors),
        hit_and_miss_histogram_diff=histogram_gap,
        # vertical_diff_input_arr_line_num=util._vertical_diff_input_arr_line_num(input_arr, output_arr),
        # horizontal_diff_input_arr_line_num=util._horizontal_diff_input_arr_line_num(input_arr, output_arr),
        # vertical_diff_output_arr_line_num=util._vertical_diff_output_arr_line_num(input_arr, output_arr),
        # horizontal_diff_output_arr_line_num=util._horizontal_diff_output_arr_line_num(input_arr, output_arr),
        vertical_edge_sum_diff=vertical_edge_gap,
        horizontal_edge_sum_diff=horizontal_edge_gap,
        vertical_edge_sum_diff_ratio=vertical_edge_gap / src.width,
        horizontal_edge_sum_diff_ratio=horizontal_edge_gap / src.height,
        diff_color_cell_ratio=helper._diff_cell_count_ratio(input_arr, output_arr),
        diff_cell_where_no_need_to_change_count_ratio=helper._diff_cell_where_no_need_to_change_count_ratio(original_input_arr, input_arr, output_arr),
        wrong_change_cell_where_need_to_change_count_ratio=helper._wrong_change_cell_where_need_to_change_count_ratio(original_input_arr, input_arr, output_arr)
    )


def create_task_feature(original_task: Task, task: Task) -> TaskFeature:
    """Aggregate the per-pair diff features of ``task.train`` into a ``TaskFeature``.

    Each training pair of ``task`` is compared against its expected output,
    with the input of the matching pair in ``original_task`` as the baseline.
    Boolean fields hold for all pairs; numeric fields are averaged with
    ``mean`` or ``nan_mean``.
    """
    pair_features = [
        create_image_diff_feature(original_io.input_arr, io.input_arr, io.output_arr)
        for original_io, io in zip(original_task.train, task.train)
    ]

    def every(predicate):
        # True when the predicate holds for each training pair.
        return all(predicate(d) for d in pair_features)

    def averaged(field_name):
        return mean([getattr(d, field_name) for d in pair_features])

    def nan_averaged(field_name):
        return nan_mean(getattr(d, field_name) for d in pair_features)

    return TaskFeature(
        # image_diff_features=image_diff_features,
        same_dim_between_input_output=every(lambda d: d.same_dim()),
        same_height_dim_between_input_output=every(lambda d: d.dim_height_equal),
        same_width_dim_between_input_output=every(lambda d: d.dim_width_equal),
        all_dim_height_increased=every(lambda d: d.dim_height_increase > 0),
        all_dim_height_decreased=every(lambda d: d.dim_height_increase < 0),
        all_dim_width_increased=every(lambda d: d.dim_width_increase > 0),
        all_dim_width_decreased=every(lambda d: d.dim_width_increase < 0),
        all_dim_height_integer_multiple=every(lambda d: d.dim_height_integer_multiple),
        all_dim_width_integer_multiple=every(lambda d: d.dim_width_integer_multiple),
        mean_lack_color_num=averaged('lack_color_num'),
        mean_excess_color_num=averaged('excess_color_num'),
        mean_hit_and_miss_histogram_diff=averaged('hit_and_miss_histogram_diff'),
        # mean_vertical_diff_input_arr_line_num=nan_mean(f.vertical_diff_input_arr_line_num for f in diff_features),
        # mean_horizontal_diff_input_arr_line_num=nan_mean(f.horizontal_diff_input_arr_line_num for f in diff_features),
        # mean_vertical_diff_output_arr_line_num=nan_mean(f.vertical_diff_output_arr_line_num for f in diff_features),
        # mean_horizontal_diff_output_arr_line_num=nan_mean(f.horizontal_diff_output_arr_line_num for f in diff_features),
        mean_vertical_edge_sum_diff=averaged('vertical_edge_sum_diff'),
        mean_horizontal_edge_sum_diff=averaged('horizontal_edge_sum_diff'),
        mean_vertical_edge_sum_diff_ratio=averaged('vertical_edge_sum_diff_ratio'),
        mean_horizontal_edge_sum_diff_ratio=averaged('horizontal_edge_sum_diff_ratio'),
        mean_diff_color_cell_ratio=nan_averaged('diff_color_cell_ratio'),
        mean_diff_cell_where_no_need_to_change_count_ratio=nan_averaged('diff_cell_where_no_need_to_change_count_ratio'),
        mean_wrong_change_cell_where_need_to_change_count_ratio=nan_averaged('wrong_change_cell_where_need_to_change_count_ratio'),
    )


def create_color_selected_task_feature(original_task: Task, color_selected_task: ColorSelectedTask, task_feature: TaskFeature = None) -> ColorSelectedTaskFeature:
    """Wrap a task feature in a ``ColorSelectedTaskFeature``.

    The task feature is computed from the two tasks only when the caller
    does not pass one in.
    """
    feature = task_feature if task_feature is not None else create_task_feature(original_task, color_selected_task)
    return ColorSelectedTaskFeature(feature)


def create_mask_conversion_task_feature(original_task: Task, mask_converted_task: MaskConvertedTask, task_feature: TaskFeature = None) -> MaskConvertedTaskFeature:
    """Build the feature of a mask-converted task.

    The task feature is computed only when none is supplied; one possible
    improve ratio is calculated per training pair and its mask.
    """
    feature = task_feature if task_feature is not None else create_task_feature(original_task, mask_converted_task)

    ratios = []
    for io, train_mask in zip(mask_converted_task.train, mask_converted_task.train_masks):
        ratios.append(_calculate_possible_improve_ratio(io.input_arr, io.output_arr, train_mask))

    return MaskConvertedTaskFeature(
        task_feature=feature,
        possible_improve_ratios=ratios,
    )


def _calculate_possible_improve_ratio(input_arr: np.ndarray, output_arr: np.ndarray, mask: np.ndarray) -> Optional[float]:
    if input_arr.shape != output_arr.shape:
        return None
    diff_arr = np.not_equal(input_arr, output_arr)
    if not diff_arr.all():
        return 1.0
    else:
        selected_diff_arr = np.logical_and(diff_arr, mask)
        return 1 - selected_diff_arr.sum() / diff_arr.sum()


class FeatureUtil:
    """Cell- and line-level comparisons between an input grid and an output grid.

    Every method returns ``None`` when the shapes it needs do not agree.
    """

    def _row_match_matrix(self, input_arr: np.ndarray, output_arr: np.ndarray) -> np.ndarray:
        """Boolean matrix whose entry [i, j] says input row i equals output row j."""
        return np.array([(output_arr == input_row).all(axis=1) for input_row in input_arr])

    def _horizontal_diff_input_arr_line_num(self, input_arr: np.ndarray, output_arr: np.ndarray) -> Optional[float]:
        """Number of input rows that occur nowhere in the output (widths must match)."""
        if input_arr.shape[1] == output_arr.shape[1]:
            found_rows = self._row_match_matrix(input_arr, output_arr).any(axis=1).sum()
            return abs(input_arr.shape[0] - found_rows)
        return None

    def _horizontal_diff_output_arr_line_num(self, input_arr: np.ndarray, output_arr: np.ndarray) -> Optional[float]:
        """Number of output rows that occur nowhere in the input (widths must match)."""
        if input_arr.shape[1] == output_arr.shape[1]:
            found_rows = self._row_match_matrix(input_arr, output_arr).any(axis=0).sum()
            return abs(output_arr.shape[0] - found_rows)
        return None

    def _vertical_diff_input_arr_line_num(self, input_arr: np.ndarray, output_arr: np.ndarray) -> Optional[float]:
        """Column-wise counterpart of the input-row count (heights must match)."""
        if input_arr.shape[0] == output_arr.shape[0]:
            # Columns become rows after transposing.
            return self._horizontal_diff_input_arr_line_num(input_arr.T, output_arr.T)
        return None

    def _vertical_diff_output_arr_line_num(self, input_arr: np.ndarray, output_arr: np.ndarray) -> Optional[float]:
        """Column-wise counterpart of the output-row count (heights must match)."""
        if input_arr.shape[0] == output_arr.shape[0]:
            # Columns become rows after transposing.
            return self._horizontal_diff_output_arr_line_num(input_arr.T, output_arr.T)
        return None

    def _diff_cell_count_ratio(self, input_arr: np.ndarray, output_arr: np.ndarray) -> Optional[float]:
        """Fraction of cells whose value differs between the two grids."""
        if input_arr.shape == output_arr.shape:
            differs = np.not_equal(input_arr, output_arr)
            return differs.sum() / differs.size
        return None

    def _diff_cell_where_no_need_to_change_count_ratio(self, original_input_arr: np.ndarray, input_arr: np.ndarray, output_arr: np.ndarray) -> Optional[float]:
        """Fraction of cells that were already correct in the original but are wrong in ``input_arr``."""
        same_shapes = original_input_arr.shape == input_arr.shape == output_arr.shape
        if not same_shapes:
            return None
        already_correct = np.equal(original_input_arr, output_arr)
        currently_wrong = np.not_equal(input_arr, output_arr)
        broken_cells = currently_wrong[already_correct]
        return broken_cells.sum() / original_input_arr.size

    def _wrong_change_cell_where_need_to_change_count_ratio(self, original_input_arr: np.ndarray, input_arr: np.ndarray, output_arr: np.ndarray) -> Optional[float]:
        """Fraction of cells that needed a change, were changed, and are still wrong."""
        same_shapes = original_input_arr.shape == input_arr.shape == output_arr.shape
        if not same_shapes:
            return None
        needed_change = np.not_equal(original_input_arr, output_arr)
        was_changed = np.not_equal(original_input_arr, input_arr)
        still_wrong = np.not_equal(input_arr, output_arr)

        wrongly_changed = needed_change & was_changed & still_wrong

        return wrongly_changed.sum() / original_input_arr.size


def get_hit_and_miss_kernels():
    """Return the eight 3x3 ``int8`` kernels used for hit-and-miss matching.

    Four corner patterns come first, followed by four protrusion patterns.
    """
    patterns = (
        # right top
        ((0, -1, -1),
         (1, 1, -1),
         (0, 1, 0)),
        # right bottom
        ((0, 1, 0),
         (1, 1, -1),
         (0, -1, -1)),
        # left bottom
        ((0, 1, 0),
         (-1, 1, 1),
         (-1, -1, 0)),
        # left top
        ((-1, -1, 0),
         (-1, 1, 1),
         (0, 1, 0)),

        # right protrusion
        ((0, -1, -1),
         (0, 1, -1),
         (0, -1, -1)),
        # bottom protrusion
        ((0, 0, 0),
         (-1, 1, -1),
         (-1, -1, -1)),
        # left protrusion
        ((-1, -1, 0),
         (-1, 1, 0),
         (-1, -1, 0)),
        # top protrusion
        ((-1, -1, -1),
         (-1, 1, -1),
         (0, 0, 0)),
        # TODO implement others?
    )
    return [np.array(rows, dtype=np.int8) for rows in patterns]


def calculate_hit_and_miss_histogram(arr: np.ndarray):
    """Count hit-and-miss matches of every kernel for each of the colors 0-9.

    Args:
        arr: grid of color values.

    Returns:
        A flat list of ints ordered color-major: all kernel counts for
        color 0, then for color 1, and so on.  Colors that do not occur in
        ``arr`` contribute one zero per kernel.
    """
    kernels = get_hit_and_miss_kernels()

    exist_colors = np.unique(arr)

    counts = []
    for color in range(10):
        if color not in exist_colors:
            counts.extend([0] * len(kernels))
            continue

        # The binary mask depends only on the color, so build it once per color
        # instead of once per kernel.
        color_hit = (arr == color).astype(np.uint8)
        for k in kernels:
            hit_and_miss_result = cv2.morphologyEx(color_hit, cv2.MORPH_HITMISS, k)
            counts.append(int(hit_and_miss_result.sum()))

    return counts


def summary_engine_results(results: List[Union[AnsweredSearchResults, NotAnsweredSearchResult]]):
    """Build a human-readable report for a list of search results.

    Args:
        results: search results; each must provide ``final_test_correct()``,
            ``spent_time`` and ``summary()``.

    Returns:
        ``'0 result'`` for an empty list; otherwise a text block with the
        correctness counts, the total (minutes), mean and max (seconds)
        spent time, the summaries of the answered results, and the summaries
        of all results.
    """
    if len(results) == 0:
        return '0 result'
    counts = Counter(r.final_test_correct() for r in results)
    spent_times = [r.spent_time for r in results]
    total_spent_time = np.sum(spent_times) / 60
    # Previously this used np.sum, so the "mean" line reported the total in seconds.
    mean_spent_time = np.mean(spent_times)
    max_spent_time = np.max(spent_times)

    result_message = f'--- stats --- \n' \
                     f'correct_count: {counts} \n' \
                     f'total_spent_time: {total_spent_time} min \n' \
                     f'mean_spent_time: {mean_spent_time} sec \n' \
                     f'max_spent_time: {max_spent_time} sec \n\n'

    result_message += '--- answered --- \n'
    result_message += '\n'.join(r.summary() for r in results if isinstance(r, AnsweredSearchResults))
    result_message += '\n--- all --- \n'
    result_message += '\n'.join(r.summary() for r in results)

    return result_message


class ColorSelectionUtil:
    """Helpers for choosing colors of a grid by how often they occur."""

    def select_single_color(self, arr: np.ndarray, mode: SingleColorSelectionMode) -> Color:
        """Pick one color of ``arr`` according to ``mode``.

        Raises:
            OperationInconsistencyException: if the grid has too few colors
                for the mode, or the requested rank is tied with a neighbour.
            NotImplementedError: for an unsupported mode.
        """
        if mode == SingleColorSelectionMode.MOST_COMMON:
            color_counts = self.get_color_counts(arr)
            if len(color_counts) <= 0:
                raise OperationInconsistencyException('color <= 0')
            if len(color_counts) >= 2 and color_counts[-1][1] == color_counts[-2][1]:  # Two maximums.
                raise OperationInconsistencyException('duplicated max color')
            return Color.of(color_counts[-1][0])

        elif mode == SingleColorSelectionMode.SECOND_MOST_COMMON:
            color_counts = self.get_color_counts(arr)
            if len(color_counts) <= 1:
                raise OperationInconsistencyException('color <= 1')
            if color_counts[-1][1] == color_counts[-2][1]:  # Two maximums.
                raise OperationInconsistencyException('duplicated max color')
            if len(color_counts) >= 3 and color_counts[-2][1] == color_counts[-3][1]:  # Two 2nd maximums.
                raise OperationInconsistencyException('duplicated 2nd max color')
            return Color.of(color_counts[-2][0])

        elif mode == SingleColorSelectionMode.LEAST_COMMON:
            color_counts = self.get_color_counts(arr)
            if len(color_counts) <= 1:
                raise OperationInconsistencyException('color <= 1')
            if color_counts[0][1] == color_counts[1][1]:  # Two minimums.
                raise OperationInconsistencyException('duplicated min color')

            return Color.of(color_counts[0][0])
        else:
            raise NotImplementedError()

    def get_background_color(self, arr: np.ndarray, mode: BackGroundColorSelectionMode) -> Color:
        """Return black, or the most common color of ``arr``, depending on ``mode``.

        Raises:
            NotImplementedError: for an unsupported mode.
        """
        if mode == BackGroundColorSelectionMode.BLACK:
            return Color.BLACK
        elif mode == BackGroundColorSelectionMode.MOST_COMMON:
            return self.select_single_color(arr, SingleColorSelectionMode.MOST_COMMON)
        else:
            raise NotImplementedError()

    def get_color_counts(self, arr: np.ndarray) -> List[Tuple[int, int]]:
        """Return ``(color, count)`` pairs of the colors present, sorted by ascending count."""
        color_counts = [(color, count) for color, count in enumerate(np.bincount(arr.ravel(), minlength=10)) if count != 0]
        return sorted(color_counts, key=itemgetter(1))

    def get_colors(self, arr: np.ndarray) -> List[int]:
        """Return the distinct raw values of ``arr`` in ascending order."""
        return sorted(set(arr.ravel().tolist()))

    def select_multi_color(self):
        """Placeholder; always raises ``NotImplementedError``."""
        # TODO imple
        # A bare ``raise`` outside an ``except`` block only produces a confusing
        # RuntimeError; NotImplementedError is a RuntimeError subclass, so callers
        # catching the old error keep working.
        raise NotImplementedError()


@dataclass(frozen=True)
class SplitLineSelection(MaskConversion):
    """Keep only the rows/columns that are selected in their entirety."""
    axis: Axis

    def __call__(self, color_mask: np.ndarray) -> np.ndarray:
        """Return a mask marking every fully selected column and/or row of ``color_mask``."""
        line_mask = np.full_like(color_mask, False)

        wants_columns = self.axis == Axis.VERTICAL or self.axis == Axis.BOTH
        wants_rows = self.axis == Axis.HORIZONTAL or self.axis == Axis.BOTH

        if wants_columns:
            full_columns = color_mask.all(axis=0)
            line_mask[:, full_columns] = True

        if wants_rows:
            full_rows = color_mask.all(axis=1)
            line_mask[full_rows] = True

        return line_mask


@dataclass(frozen=True)
class DotExistLineSelection(MaskConversion):
    """Expand each selected cell to its whole row and/or column."""
    axis: Axis

    def __call__(self, color_mask: np.ndarray) -> np.ndarray:
        """Return a mask marking every column and/or row that contains a selected cell."""
        line_mask = np.full_like(color_mask, False)

        wants_columns = self.axis == Axis.VERTICAL or self.axis == Axis.BOTH
        wants_rows = self.axis == Axis.HORIZONTAL or self.axis == Axis.BOTH

        if wants_columns:
            touched_columns = color_mask.any(axis=0)
            line_mask[:, touched_columns] = True

        if wants_rows:
            touched_rows = color_mask.any(axis=1)
            line_mask[touched_rows] = True

        return line_mask


@dataclass(frozen=True)
class ObjectsTouchingEdgeSelection(MaskConversion):
    """Select the connected objects that do (or do not) touch the grid border."""
    # TODO Direction or Axis property?
    true_or_false: TrueOrFalse
    connectivity: PixelConnectivity

    def __call__(self, color_mask: np.ndarray) -> np.ndarray:
        """Return a mask of the labelled objects accepted by ``_is_target``.

        When labelling finds no object an all-False mask is returned.
        """
        labelled, n_labels = label(color_mask, connectivity=self.connectivity.value_for_skimage,
                                   background=False, return_num=True)

        if n_labels == 0:
            return np.full_like(color_mask, order='C', fill_value=False)

        accepted_labels = []
        for label_index in range(1, n_labels + 1):
            if self._is_target(labelled == label_index):
                accepted_labels.append(label_index)

        return np.isin(labelled, accepted_labels)

    def _is_target(self, arr: np.ndarray) -> bool:
        """Tell whether the object mask ``arr`` matches the configured border condition."""
        border_lines = (arr[0], arr[-1], arr[:, 0], arr[:, -1])
        touches_border = any([line.any() for line in border_lines])

        if self.true_or_false == TrueOrFalse.TRUE:
            return touches_border
        return not touches_border


@dataclass(frozen=True)
class ObjectsMaxMinSelection(MaskConversion):
    """Select the connected objects whose feature value is (or is not) the extreme one.

    The extreme is chosen by ``max_or_min.func``; every object sharing that
    value is included when ``true_or_false`` is TRUE, every other object
    otherwise.
    """
    true_or_false: TrueOrFalse
    max_or_min: MaxOrMin
    object_feature: ObjectFeature
    connectivity: PixelConnectivity

    def __call__(self, color_mask: np.ndarray) -> np.ndarray:
        """Return the mask of the selected objects (all-False when there is no object)."""
        labelled, n_labels = label(color_mask, connectivity=self.connectivity.value_for_skimage,
                                   background=False, return_num=True)

        if n_labels == 0:
            return np.full_like(color_mask, order='C', fill_value=False)

        scored_labels = [
            (label_index, self._calculate_object_feature(labelled, label_index))
            for label_index in range(1, n_labels + 1)
        ]

        extreme_value = self.max_or_min.func(scored_labels, key=itemgetter(1))[1]
        keep_extreme = self.true_or_false == TrueOrFalse.TRUE
        chosen_labels = []
        for label_index, value in scored_labels:
            if keep_extreme:
                if value == extreme_value:
                    chosen_labels.append(label_index)
            elif value != extreme_value:
                chosen_labels.append(label_index)

        return np.isin(labelled, chosen_labels)

    def _calculate_object_feature(self, label_array: np.ndarray, label_index: int) -> int:
        """Compute the configured feature of one labelled object.

        Raises:
            NotImplementedError: for an unsupported ``object_feature``.
        """
        feature = self.object_feature
        if feature == ObjectFeature.AREA:
            return self._label_array_to_area(label_array, label_index)
        if feature == ObjectFeature.HORIZONTAL_LEN:
            return self._label_array_to_horizontal_len(label_array, label_index)
        if feature == ObjectFeature.VERTICAL_LEN:
            return self._label_array_to_vertical_len(label_array, label_index)
        raise NotImplementedError()

    def _label_array_to_area(self, label_array: np.ndarray, label_index: int) -> int:
        """Number of cells carrying ``label_index``."""
        return (label_array == label_index).sum()

    def _label_array_to_horizontal_len(self, label_array: np.ndarray, label_index: int) -> int:
        """Distance between the first and last column occupied by the object."""
        occupied_columns = (label_array == label_index).any(axis=0)
        column_indices = np.where(occupied_columns)[0]
        return column_indices.max() - column_indices.min()

    def _label_array_to_vertical_len(self, label_array: np.ndarray, label_index: int) -> int:
        """Distance between the first and last row occupied by the object."""
        occupied_rows = (label_array == label_index).any(axis=1)
        row_indices = np.where(occupied_rows)[0]
        return row_indices.max() - row_indices.min()


@dataclass(frozen=True)
class OldObjectsMaxMinSelection(MaskConversion):
    """Legacy, contour-based counterpart of ObjectsMaxMinSelection.

    Finds the external contours of the mask with OpenCV, keeps the one with
    the largest ``cv2.contourArea`` and draws it into an empty mask.
    """
    # TODO Without this function, LB will be 0.97 -> 0.98
    # TODO why???

    def __call__(self, color_mask: np.ndarray) -> np.ndarray:
        """Return a boolean mask drawn from the largest-area external contour.

        An all-False mask shaped like ``color_mask`` is returned when no
        contour is found.
        """
        # TODO should variate hierarchy?
        binary_image = np.ascontiguousarray(color_mask).astype(np.uint8)
        contours, _hierarchy = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        blank = np.full_like(color_mask, order='C', fill_value=False)
        if len(contours) == 0:
            return blank

        largest_contour = max(contours, key=cv2.contourArea)

        # NOTE(review): the contour is passed directly (not wrapped in a list), so OpenCV
        # presumably treats every contour point as its own contour - confirm this is intended.
        drawn = cv2.drawContours(blank.astype(np.uint8), largest_contour, contourIdx=-1, color=1)
        if isinstance(drawn, cv2.UMat):  # drawContours sometimes hands back a cv2.UMat; reason unknown.
            drawn = drawn.get()

        return drawn.astype(bool)


@dataclass(frozen=True)
class SquareObjectsSelection(MaskConversion):
    """ Create a mask with only square objects """

    def __call__(self, color_mask: np.ndarray) -> np.ndarray:
        """Return a boolean mask of the cells matched as isolated filled squares.

        Side lengths from 1 up to (but not including) the shorter image
        dimension are tried; hits of all sizes are OR-ed together.
        """
        color_mask = color_mask.astype(np.uint8)
        side_limit = min(color_mask.shape)
        # TODO: decide how to treat side_limit == 1 (the loop below then never runs).

        combined_hits = np.full_like(color_mask, fill_value=False, dtype=bool)
        for side in range(1, side_limit):
            # Hit-or-miss marks one cell per matched square ...
            corner_hits = cv2.morphologyEx(color_mask, cv2.MORPH_HITMISS,
                                           self._square_hit_and_miss_kenel(side), anchor=(1, 1))
            # ... and the box filter spreads that mark over a side x side window.
            spread_hits = cv2.filter2D(corner_hits, -1, self._filter_kenel(side),
                                       anchor=(side - 1, side - 1), borderType=cv2.BORDER_CONSTANT)

            combined_hits = np.logical_or(combined_hits, spread_hits.astype(bool))

        return combined_hits

    def _square_hit_and_miss_kenel(self, l: int) -> np.ndarray:
        """Build an (l + 2) x (l + 2) int8 kernel: 1 inside, -1 on the one-cell border."""
        kernel = -np.ones((l + 2, l + 2), dtype=np.int8)
        kernel[1:-1, 1:-1] = 1
        return kernel

    def _filter_kenel(self, l: int) -> np.ndarray:
        """Build an l x l int8 kernel of ones."""
        return np.ones((l, l), dtype=np.int8)


@dataclass(frozen=True)
class HolesSelection(MaskConversion):
    """ Select only the empty hole inside. """

    # Connectivity whose ``structure_for_skimage`` is used as the fill structure.
    connectivity: PixelConnectivity

    # TODO Lack of consideration of the edges of the image?

    def __call__(self, color_mask: np.ndarray) -> np.ndarray:
        """Return the cells that ``binary_fill_holes`` adds to ``color_mask``."""
        structure = self.connectivity.structure_for_skimage
        filled_mask = binary_fill_holes(color_mask, structure=structure)
        # XOR leaves exactly the cells where the filled mask differs from the input.
        return np.bitwise_xor(filled_mask, color_mask)


@dataclass(frozen=True)
class ObjectInnerSelection(MaskConversion):
    """Erode the mask by one step of the configured connectivity."""

    # Connectivity whose ``structure_for_skimage`` is used as the erosion structure.
    connectivity: PixelConnectivity
    # Selects the ``border_value`` handed to ``binary_erosion`` (0 for EDGE_EXCLUDE, 1 for EDGE_INCLUDE).
    image_edge_type: ImageEdgeType

    def __call__(self, color_mask: np.ndarray) -> np.ndarray:
        """Return the eroded mask.

        Raises:
            NotImplementedError: If ``image_edge_type`` is not a handled member.
        """
        edge_type = self.image_edge_type
        if edge_type == ImageEdgeType.EDGE_INCLUDE:
            outside_value = 1
        elif edge_type == ImageEdgeType.EDGE_EXCLUDE:
            outside_value = 0
        else:
            raise NotImplementedError()

        structure = self.connectivity.structure_for_skimage
        return binary_erosion(color_mask, structure=structure, border_value=outside_value)


@dataclass(frozen=True)
class ContourSelection(MaskConversion):
    """ Create a contour mask """

    # Both fields are forwarded unchanged to ObjectInnerSelection.
    connectivity: PixelConnectivity
    image_edge_type: ImageEdgeType

    def __call__(self, color_mask: np.ndarray) -> np.ndarray:
        """Return the cells of ``color_mask`` that ObjectInnerSelection removes."""
        erode = ObjectInnerSelection(self.connectivity, self.image_edge_type)
        interior = erode(color_mask)
        # Mask minus its eroded interior (expressed as XOR) is the contour.
        return np.logical_xor(interior, color_mask)


@dataclass(frozen=True)
class ContourOuterSelection(MaskConversion):
    """ Create a mask one pixel outside the contour """

    # Connectivity used both for the dilation and for the hole detection.
    connectivity: PixelConnectivity
    # Whether cells found by HolesSelection stay in the result.
    hole_include: HoleInclude

    def __call__(self, color_mask: np.ndarray) -> np.ndarray:
        """Return the ring of cells added by one dilation step of ``color_mask``.

        With ``HoleInclude.EXCLUDE`` the cells reported by ``HolesSelection``
        are removed from that ring.

        Raises:
            NotImplementedError: If ``hole_include`` is not a handled member.
        """
        if self.hole_include not in (HoleInclude.INCLUDE, HoleInclude.EXCLUDE):
            raise NotImplementedError()

        grown = binary_dilation(color_mask, structure=self.connectivity.structure_for_skimage, border_value=0)
        outer_ring = np.logical_xor(grown, color_mask)
        if self.hole_include == HoleInclude.INCLUDE:
            return outer_ring

        hole_cells = HolesSelection(self.connectivity)(color_mask)
        return np.logical_and(outer_ring, ~hole_cells)


@dataclass(frozen=True)
class ConnectDotSelection(MaskConversion):
    """Draw straight lines between the outermost set cells of each row and/or column."""
    # TODO This function spends much time.
    # Which direction(s) to connect along.
    axis: Axis
    # Whether the two end cells of a line are part of the drawn line.
    edge_type: LineEdgeType
    # With FillType.NotOverride the drawn lines are XOR-ed with the input mask.
    fill_type: FillType

    def __call__(self, color_mask: np.ndarray) -> np.ndarray:
        """Return the mask of connecting lines for ``color_mask``.

        Only rows/columns whose first and last set cell are at least two cells
        apart produce a line.

        Raises:
            NotImplementedError: If ``edge_type`` is not a handled member and
                at least one line has to be drawn.
        """
        line_mask = np.full_like(color_mask, fill_value=False)
        set_cells = np.argwhere(color_mask)

        if self.axis in [Axis.HORIZONTAL, Axis.BOTH]:
            # np.argwhere lists cells in row-major order, so each row is already contiguous
            # and its column indices are ascending (first = min, last = max).
            for y, row_cells in groupby(set_cells, key=itemgetter(0)):
                xs = tuple(map(itemgetter(1), row_cells))
                x_min, x_max = xs[0], xs[-1]
                if (x_max - x_min) < 2:
                    continue
                x_min, x_max = self._apply_edge_type(x_min, x_max)
                line_mask[y, x_min:x_max + 1] = True

        if self.axis in [Axis.VERTICAL, Axis.BOTH]:
            # A stable sort by column keeps the row indices ascending inside each column.
            by_column = sorted(set_cells, key=itemgetter(1))
            for x, column_cells in groupby(by_column, key=itemgetter(1)):
                ys = tuple(map(itemgetter(0), column_cells))
                y_min, y_max = ys[0], ys[-1]
                if (y_max - y_min) < 2:
                    continue
                y_min, y_max = self._apply_edge_type(y_min, y_max)
                line_mask[y_min:y_max + 1, x] = True

        if self.fill_type == FillType.NotOverride:
            # NOTE(review): XOR also switches on input cells that no line covers - confirm intended.
            line_mask = np.logical_xor(line_mask, color_mask)
        return line_mask

    def _apply_edge_type(self, low, high):
        """Return the (low, high) span to draw, shrunk by one cell per end for EdgeExclude."""
        if self.edge_type == LineEdgeType.EdgeInclude:
            return low, high
        if self.edge_type == LineEdgeType.EdgeExclude:
            return low + 1, high - 1
        raise NotImplementedError()


class TaskOperationSetExecutor:
    """Apply an operation set to every input array of a task."""

    def execute(self, task: Task, operation_set: OperationSet) -> Task:
        """Return a new Task whose inputs are the operated arrays.

        Train and test inputs are processed in one batch; outputs are carried
        over unchanged from ``task``.
        """
        n_train = len(task.train)
        inputs = [io.input_arr for io in task.train + task.test]
        operated = OperationSetExecutor.apply_operation_set(inputs, operation_set)

        train_io = tuple([InputOutput(arr, io.output_arr) for arr, io in zip(operated[:n_train], task.train)])
        test_io = tuple([InputOutput(arr, io.output_arr) for arr, io in zip(operated[n_train:], task.test)])
        return Task(task.name, train_io, test_io)


class ColorSelectionExecutor:
    """Compute color-selection masks for every input array of a task."""

    @staticmethod
    def execute(task: Task, color_selection: ColorSelection) -> ColorSelectedTask:
        """Return a ColorSelectedTask holding the train masks and the test masks."""
        n_train = len(task.train)
        inputs = [io.input_arr for io in task.train + task.test]
        all_masks = OperationSetExecutor.apply_color_selection(inputs, color_selection)

        train_masks, test_masks = all_masks[:n_train], all_masks[n_train:]
        return ColorSelectedTask(task.name, task.train, task.test, train_masks, test_masks)


class MaskConversionExecutor:
    """Apply a mask conversion to every mask of a color-selected task."""

    @staticmethod
    def execute(task: ColorSelectedTask, mask_conversion: MaskConversion) -> MaskConvertedTask:
        """Return a MaskConvertedTask holding the converted train and test masks."""
        n_train_masks = len(task.train_masks)
        converted = OperationSetExecutor.apply_mask_conversion(task.train_masks + task.test_masks, mask_conversion)

        train_masks, test_masks = converted[:n_train_masks], converted[n_train_masks:]
        return MaskConvertedTask(task.name, task.train, task.test, train_masks, test_masks)


class MaskOperationExecutor:
    """Apply a mask operation to every input array of a mask-converted task."""

    @staticmethod
    def execute(task: MaskConvertedTask, mask_operation: MaskOperation) -> Task:
        """Return a plain Task whose inputs are the mask-operated arrays.

        Train and test inputs are processed in one batch together with their
        masks; each result is paired with the output of the pair it came from.
        """
        n_train = len(task.train)
        new_arrays = OperationSetExecutor.apply_mask_operation(
            [io.input_arr for io in task.train + task.test], task.train_masks + task.test_masks, mask_operation)

        train_io = tuple([InputOutput(n, io.output_arr) for n, io in zip(new_arrays[:n_train], task.train)])
        # Pair test results with task.test (the original zipped with task.train, attaching
        # train outputs to test inputs and truncating to the shorter of the two).
        test_io = tuple([InputOutput(n, io.output_arr) for n, io in zip(new_arrays[n_train:], task.test)])

        return Task(task.name, train_io, test_io)


class ColorChannelSelectionExecutor:
    """Run a color-channel selection over every input array of a task."""

    @staticmethod
    def execute(task: Task, color_channel_selection: ColorChannelSelection) -> ColorChannelSelectedTask:
        """Return a ColorChannelSelectedTask with the train and test color/mask pairs."""
        n_train = len(task.train)
        inputs = [io.input_arr for io in task.train + task.test]
        pairs_per_array = OperationSetExecutor.apply_channel_selection(inputs, color_channel_selection)

        train_pairs = pairs_per_array[:n_train]
        test_pairs = pairs_per_array[n_train:]
        return ColorChannelSelectedTask(task.name, task.train, task.test, train_pairs, test_pairs)


class ColorChannelMaskConversionSelectionExecutor:
    """Apply a mask conversion to the color/mask pairs of a channel-selected task."""

    @staticmethod
    def execute(task: ColorChannelSelectedTask, mask_conversion: MaskConversion) -> ColorChannelMaskConvertedTask:
        """Return a ColorChannelMaskConvertedTask keeping both original and converted pairs."""
        original_train = task.train_color_mask_pairs
        original_test = task.test_color_mask_pairs
        n_train = len(original_train)

        converted = OperationSetExecutor.apply_color_channel_mask_conversion(
            original_train + original_test, mask_conversion)

        # Argument order: original train pairs, converted train pairs, original test pairs, converted test pairs.
        return ColorChannelMaskConvertedTask(task.name, task.train, task.test,
                                             original_train,
                                             converted[:n_train],
                                             original_test,
                                             converted[n_train:])


class ColorChannelMergeExecutor:
    """Merge the converted color channels of a task back into arrays."""

    @staticmethod
    def execute(task: ColorChannelMaskConvertedTask, merge_operation: ChannelMergeOperation) -> Task:
        """Return a plain Task whose inputs are the merged arrays.

        Train and test inputs are processed in one batch; each merged array is
        paired with the output of the pair it came from.
        """
        n_train = len(task.train)
        new_arrays = OperationSetExecutor.apply_channel_merge([io.input_arr for io in task.train + task.test],
                                                              task.train_original_color_mask_pairs + task.test_original_color_mask_pairs,
                                                              task.train_color_mask_pairs + task.test_color_mask_pairs,
                                                              merge_operation)

        train_io = tuple([InputOutput(n, io.output_arr) for n, io in zip(new_arrays[:n_train], task.train)])
        # Pair test results with task.test (the original zipped with task.train, attaching
        # train outputs to test inputs and truncating to the shorter of the two).
        test_io = tuple([InputOutput(n, io.output_arr) for n, io in zip(new_arrays[n_train:], task.test)])

        return Task(task.name, train_io, test_io)


class PartitionSelectionExecutor:
    """Run a partition selection over every input array of a task."""

    @staticmethod
    def execute(task: Task, partition_selection: PartitionSelection) -> PartitionSelectionTask:
        """Return a PartitionSelectionTask with the train and test partition results."""
        n_train = len(task.train)
        inputs = [io.input_arr for io in task.train + task.test]
        partitioned = OperationSetExecutor.apply_partition_selection(inputs, partition_selection)

        train_part, test_part = partitioned[:n_train], partitioned[n_train:]
        return PartitionSelectionTask(task.name, task.train, task.test, train_part, test_part)


class PartitionMergeExecutor:
    """Merge the partitioned arrays of a task back into single arrays."""

    @staticmethod
    def execute(task: PartitionSelectionTask, partition_merge_operation: PartitionMergeOperation) -> Task:
        """Return a plain Task whose inputs are the merged arrays.

        Train and test inputs are processed in one batch; each merged array is
        paired with the output of the pair it came from.
        """
        n_train = len(task.train)
        new_arrays = OperationSetExecutor.apply_partition_merge_operation([io.input_arr for io in task.train + task.test],
                                                                          task.train_partitioned_arrays_original_location_masks + task.test_partitioned_arrays_original_location_masks,
                                                                          partition_merge_operation)

        train_io = tuple([InputOutput(n, io.output_arr) for n, io in zip(new_arrays[:n_train], task.train)])
        # Pair test results with task.test (the original zipped with task.train, attaching
        # train outputs to test inputs and truncating to the shorter of the two).
        test_io = tuple([InputOutput(n, io.output_arr) for n, io in zip(new_arrays[n_train:], task.test)])

        return Task(task.name, train_io, test_io)


class CompletedNodeProcessor:
    """Dispatch a completed node to the processor registered for its class."""

    @staticmethod
    def process(node: CompletedNode) -> List[WaitingNode]:
        """Return the waiting nodes produced by the processor for ``node``'s exact class.

        Raises:
            KeyError: If no processor is registered for ``node.__class__``.
        """
        # Built per call: the processor classes are defined further down in the module.
        processor_by_node_class = {
            UniformOperationCompletedNode: OperationCompletedNodeProcessor,
            ColorSelectionCompletedNode: ColorSelectionCompletedNodeProcessor,
            MaskConversionCompletedNode: MaskConversionCompletedNodeProcessor,
            ColorChannelSelectionCompletedNode: ColorChannelSelectionCompletedNodeProcessor,
            ColorChannelMaskConversionCompletedNode: ColorChannelMaskConversionCompletedNodeProcessor,
            PartitionSelectionCompletedNode: PartitionSelectionCompletedNodeProcessor,
        }

        return processor_by_node_class[node.__class__].process(node)


class OperationCompletedNodeProcessor:
    """Expand a uniform-operation completed node into its candidate waiting nodes."""

    @classmethod
    def process(cls, node: UniformOperationCompletedNode) -> List[Union[UniformOperationWaitingNode, ColorSelectionWaitingNode, ColorChannelSelectionOperationWaitingNode]]:
        """Return one waiting node per candidate next step.

        Order: uniform operations, color selections, color-channel selections,
        then - only when no operation has been applied yet - partition selections.
        """
        # Leading constructor arguments shared by every waiting node created here.
        shared = (node, node.original_task, node.task, node.task_feature, node.base_operation_set)

        waiting_nodes = []
        for operation in cls._candidate_operations(node.task, node.task_feature):
            waiting_nodes.append(UniformOperationWaitingNode(*shared, operation))
        for color_selection in cls._candidate_color_selections(node.task):
            waiting_nodes.append(ColorSelectionWaitingNode(*shared, color_selection))
        for color_channel_selection in cls._candidate_color_channel_selection(node.task):
            waiting_nodes.append(ColorChannelSelectionOperationWaitingNode(*shared, color_channel_selection))

        # first operation only
        if len(node.base_operation_set.operations) == 0:
            for partition_selection in cls._candidate_partition_selection(node.task):
                waiting_nodes.append(PartitionSelectionWaitingNode(*shared, partition_selection))

        return waiting_nodes

    @staticmethod
    def _input_colors(task: Task) -> list:
        """Return the distinct cell values of all input arrays, each wrapped with ``Color.of``."""
        distinct_values = set(chain.from_iterable(chain.from_iterable(task.get_input_all_arr())))
        return [Color.of(value) for value in distinct_values]

    @staticmethod
    def _candidate_operations(task: Task, task_feature: TaskFeature):
        """Return the uniform operations to try, gated by the task's dimension features."""
        input_colors = OperationCompletedNodeProcessor._input_colors(task)

        candidates = []

        if task_feature.all_dim_height_increased:
            candidates.extend(Resize(Axis.VERTICAL, ratio) for ratio in range(2, 5))
            candidates.extend(Padding(mode, direction, k)
                              for mode, direction, k in product(PaddingMode, [Direction.TOP, Direction.BOTTOM], range(1, 4)))

        if task_feature.all_dim_width_increased:
            candidates.extend(Resize(Axis.HORIZONTAL, ratio) for ratio in range(2, 5))
            candidates.extend(Padding(mode, direction, k)
                              for mode, direction, k in product(PaddingMode, [Direction.LEFT, Direction.RIGHT], range(1, 4)))

        if task_feature.all_dim_height_decreased or task_feature.all_dim_width_decreased:
            candidates.extend(LineDeletion(color) for color in input_colors)

        # Flips and rotations are always offered.
        candidates.extend(Flip(mode) for mode in FlipMode)
        candidates.extend(Rotate(angle) for angle in [90, 180, 270])

        return candidates

    @staticmethod
    def _candidate_color_selections(task: Task) -> List[ColorSelection]:
        """Return the color selections to try: fixed input colors, then single and multi modes."""
        candidates = [FixedSingleColorSelection(color) for color in OperationCompletedNodeProcessor._input_colors(task)]
        candidates.extend(SingleColorSelection(mode) for mode in SingleColorSelectionMode)
        candidates.extend(MultiColorSelection(mode) for mode in MultiColorSelectionMode)
        return candidates

    @staticmethod
    def _candidate_color_channel_selection(task: Task) -> List[ColorChannelSelection]:
        """Return one channel selection per background-color selection mode."""
        return [WithOutMostCommonColorChannelSelection(mode) for mode in BackGroundColorSelectionMode]

    @staticmethod
    def _candidate_partition_selection(task: Task) -> List[PartitionSelection]:
        """Return the partition selections to try."""
        input_colors = OperationCompletedNodeProcessor._input_colors(task)

        candidates = [ColorNumIntegerDivisionPartition(axis=axis) for axis in Axis]
        candidates.extend(IntegerDivisionPartition(axis=axis, n_split=n) for axis, n in product(Axis, range(2, 5)))
        candidates.extend(GeneralizedLinePartition(mode) for mode in BackGroundColorSelectionMode)
        candidates.extend(LinePartition(line_color=color) for color in input_colors)
        return candidates


class ColorSelectionCompletedNodeProcessor:
    """Expand a color-selection completed node into mask-conversion waiting nodes."""

    @classmethod
    def process(cls, node: ColorSelectionCompletedNode) -> List[MaskConversionWaitingNode]:
        """Return one MaskConversionWaitingNode per candidate mask conversion."""
        waiting_nodes = []
        for mask_conversion in cls._candidate_mask_conversions():
            waiting_nodes.append(
                MaskConversionWaitingNode(node, node.original_task, node.color_selected_task, node.color_selected_task_feature,
                                          node.base_operation_set, node.color_selection, mask_conversion))
        return waiting_nodes

    @staticmethod
    def _candidate_mask_conversions() -> List[MaskConversion]:
        """Return every mask conversion to try, in a fixed order."""
        candidates = [NoMaskConversion(), SquareObjectsSelection()]
        candidates += [ObjectsTouchingEdgeSelection(*args) for args in product(TrueOrFalse, PixelConnectivity)]
        candidates += [ObjectsMaxMinSelection(*args) for args in product(TrueOrFalse, MaxOrMin, ObjectFeature, PixelConnectivity)]
        candidates.append(OldObjectsMaxMinSelection())
        candidates += [SplitLineSelection(axis) for axis in Axis]
        candidates += [DotExistLineSelection(axis) for axis in Axis]
        candidates += [HolesSelection(connectivity) for connectivity in PixelConnectivity]
        candidates += [ObjectInnerSelection(*args) for args in product(PixelConnectivity, ImageEdgeType)]
        candidates += [ContourSelection(*args) for args in product(PixelConnectivity, ImageEdgeType)]
        candidates += [ContourOuterSelection(*args) for args in product(PixelConnectivity, HoleInclude)]
        candidates += [ConnectDotSelection(*args) for args in product(Axis, LineEdgeType, FillType)]
        return candidates


class MaskConversionCompletedNodeProcessor:
    """Expand a mask-conversion completed node into mask-operation waiting nodes."""

    @classmethod
    def process(cls, node: MaskConversionCompletedNode) -> List[MaskOperationSelectionWaitingNode]:
        """Return one MaskOperationSelectionWaitingNode per candidate mask operation."""
        waiting_nodes = []
        for mask_operation in cls._candidate(node):
            waiting_nodes.append(
                MaskOperationSelectionWaitingNode(node, node.original_task, node.mask_converted_task, node.mask_converted_task_feature,
                                                  node.base_operation_set, node.color_selection, node.mask_conversion, mask_operation))
        return waiting_nodes

    @staticmethod
    def _candidate(node: MaskConversionCompletedNode) -> List[MaskOperation]:
        """Return the mask operations to try for ``node``.

        ``MaskCoordsCrop`` is offered only when input and output dimensions
        differ; fills use the distinct colors found in the task outputs.
        """
        # TODO use
        # color_mappings = set(chain.from_iterable(t.candidate_color_mapping() for t in task.train))

        distinct_output_values = set(chain.from_iterable(chain.from_iterable(node.mask_converted_task.get_output_all_arr())))
        output_colors = [Color.of(value) for value in distinct_output_values]

        same_dim = node.mask_converted_task_feature.task_feature.same_dim_between_input_output
        candidates = [] if same_dim else [MaskCoordsCrop()]

        candidates.extend(FixedColorMaskFill(color) for color in output_colors)
        candidates.extend(SingleColorMaskFill(mode) for mode in SingleColorSelectionMode)

        return candidates


class ColorChannelSelectionCompletedNodeProcessor:
    """Expand a color-channel-selection completed node into mask-conversion waiting nodes."""

    @classmethod
    def process(cls, node: ColorChannelSelectionCompletedNode) -> List[ColorChannelMaskConversionWaitingNode]:
        """Return one ColorChannelMaskConversionWaitingNode per candidate mask conversion."""
        waiting_nodes = []
        for mask_conversion in cls._candidate_mask_conversions():
            waiting_nodes.append(
                ColorChannelMaskConversionWaitingNode(node, node.original_task, node.task, node.feature,
                                                      node.base_operation_set, node.color_channel_selection, mask_conversion))
        return waiting_nodes

    @staticmethod
    def _candidate_mask_conversions() -> List[MaskConversion]:
        """Return every mask conversion to try, in a fixed order."""
        candidates = [NoMaskConversion(), SquareObjectsSelection()]
        candidates += [ObjectsTouchingEdgeSelection(*args) for args in product(TrueOrFalse, PixelConnectivity)]
        candidates += [ObjectsMaxMinSelection(*args) for args in product(TrueOrFalse, MaxOrMin, ObjectFeature, PixelConnectivity)]
        candidates.append(OldObjectsMaxMinSelection())
        candidates += [SplitLineSelection(axis) for axis in Axis]
        candidates += [DotExistLineSelection(axis) for axis in Axis]
        candidates += [HolesSelection(connectivity) for connectivity in PixelConnectivity]
        candidates += [ObjectInnerSelection(*args) for args in product(PixelConnectivity, ImageEdgeType)]
        candidates += [ContourSelection(*args) for args in product(PixelConnectivity, ImageEdgeType)]
        candidates += [ContourOuterSelection(*args) for args in product(PixelConnectivity, HoleInclude)]
        candidates += [ConnectDotSelection(*args) for args in product(Axis, LineEdgeType, FillType)]
        return candidates


class ColorChannelMaskConversionCompletedNodeProcessor:
    """Expand a color-channel mask-conversion completed node into merge waiting nodes."""

    @classmethod
    def process(cls, node: ColorChannelMaskConversionCompletedNode) -> List[ColorChannelMergeWaitingNode]:
        """Return one ColorChannelMergeWaitingNode per candidate merge operation."""
        waiting_nodes = []
        for merge_operation in cls._candidate():
            waiting_nodes.append(
                ColorChannelMergeWaitingNode(node, node.original_task, node.task, node.feature,
                                             node.base_operation_set, node.color_selection, node.mask_conversion, merge_operation))
        return waiting_nodes

    @staticmethod
    def _candidate() -> List[ChannelMergeOperation]:
        """Return the merge operations to try (currently only the override merge)."""
        override_merge = ColorChannelOverrideOperation()
        return [override_merge]


class PartitionSelectionCompletedNodeProcessor:
    """Expand a partition-selection completed node into partition-merge waiting nodes."""

    @classmethod
    def process(cls, node: PartitionSelectionCompletedNode) -> List[PartitionMergeWaitingNode]:
        """Return one PartitionMergeWaitingNode per candidate merge operation."""
        waiting_nodes = []
        for merge_operation in cls._candidate(node):
            waiting_nodes.append(
                PartitionMergeWaitingNode(node, node.original_task, node.task, node.feature,
                                          node.base_operation_set, node.partition_selection, merge_operation))
        return waiting_nodes

    @staticmethod
    def _candidate(node: PartitionSelectionCompletedNode) -> List[PartitionMergeOperation]:
        """Return every partition merge operation to try, in a fixed order."""
        distinct_output_values = set(chain.from_iterable(chain.from_iterable(node.task.get_output_all_arr())))
        output_colors = [Color.of(value) for value in distinct_output_values]

        # Selection criteria shared by the extraction/restoration merges at the end.
        selections = [UniqueColorNumberSelection(m) for m in MaxOrMin]
        selections += [ColoredCellNumberSelection(m, bg) for m, bg in product(MaxOrMin, BackGroundColorSelectionMode)]
        selections += [SameShapeNumSelection(m) for m in MaxOrMin]
        selections += [SymmetrySelection(a, tf) for a, tf in product(AxisV2, TrueOrFalse)]

        ordered_axes = [Axis.VERTICAL, Axis.HORIZONTAL]

        candidates = []
        # Logical merges: one per (background mode, output color) pair.
        for merge_cls in (AnySelectionMerge, NotSelectionMerge, AllSelectionMerge, ModifiedXorSelectionMerge):
            candidates += [merge_cls(m, c) for m, c in product(BackGroundColorSelectionMode, output_colors)]
        # Ordered override merges.
        candidates += [NaturalArrayOrderedOverrideMerge(m, c, a) for m, c, a in product(BackGroundColorSelectionMode, Corner, ordered_axes)]
        candidates += [DiagonalArrayOrderedOverrideMerge(m, c, a) for m, c, a in product(BackGroundColorSelectionMode, Corner, ordered_axes)]
        candidates += [SpiralArrayOrderedOverrideMerge(m, c, d) for m, c, d in product(BackGroundColorSelectionMode, Corner, SpiralDirection)]
        # Selection-driven merges.
        candidates += [UniquelySelectedArrayExtraction(s) for s in selections]
        candidates += [RestoreOnlySelectedArray(m, s) for m, s in product(BackGroundColorSelectionMode, selections)]
        candidates.append(ExtractOneValueFromPartitionedArray())
        return candidates


class AnswerMatcher:
    """Helpers that check whether task inputs already equal their expected outputs."""

    @staticmethod
    def is_match_arr(arr1: np.ndarray, arr2: np.ndarray) -> bool:
        """Return True when both arrays have the same shape and elements."""
        return np.array_equal(arr1, arr2)

    @classmethod
    def _is_all_match_if_operated(cls, task: Task, operation_set: OperationSet, include_test: bool) -> bool:
        """Apply ``operation_set`` and compare every resulting input with its output.

        Returns False when applying the operation set raises
        OperationInconsistencyException.
        """
        try:
            applied_task = TaskOperationSetExecutor().execute(task, operation_set)
            if include_test:
                pairs = applied_task.train + applied_task.test
            else:
                pairs = applied_task.train
            return all(cls.is_match_arr(io.input_arr, io.output_arr) for io in pairs)
        except OperationInconsistencyException:
            return False

    @classmethod
    def is_train_all_match_if_operated(cls, task: Task, operation_set: OperationSet) -> bool:
        """Return True if the operation set turns every train input into its output."""
        return cls._is_all_match_if_operated(task, operation_set, include_test=False)

    @classmethod
    def is_train_test_all_match_if_operated(cls, task: Task, operation_set: OperationSet) -> bool:
        """Return True if the operation set turns every train and test input into its output."""
        return cls._is_all_match_if_operated(task, operation_set, include_test=True)

    # TODO ？
    @classmethod
    def is_train_all_match(cls, task: Task) -> bool:
        """Return True if every train input already equals its output."""
        return all(cls.is_match_arr(io.input_arr, io.output_arr) for io in task.train)


def setup_df_display_options():
    """Widen the global NumPy print options and pandas display options for debugging output."""
    np.set_printoptions(threshold=10000, linewidth=10000)

    pandas_display_options = (
        ('display.max_columns', 1000),
        ('display.max_rows', 1000),
        ('display.width', 800),
        ('display.max_colwidth', 300),
    )
    for option_name, option_value in pandas_display_options:
        pd.set_option(option_name, option_value)


def mean(values: List[float]) -> float:
    """Return the arithmetic mean of ``values``.

    Raises:
        ZeroDivisionError: If ``values`` is empty.
    """
    total = sum(values)
    count = len(values)
    return total / count


def nan_mean(val_iter: Iterable[Union[int, float]]) -> Optional[float]:
    """Return the mean of the entries that are not ``None``.

    Despite the name, only ``None`` entries are dropped; float NaN values are
    kept. Returns ``None`` when nothing remains after filtering.
    """
    present = list(filter(lambda v: v is not None, val_iter))
    return mean(present) if present else None


def initialize_path():
    """Remove leftovers of a previous local run.

    Only acts in the LOCAL_RUN_ALL / LOCAL_RUN modes: deletes the wrong-answers
    directory tree (ignoring errors) and the submission file if it exists.
    """
    if RunConfig.RUN_MODE not in [RunMode.LOCAL_RUN_ALL, RunMode.LOCAL_RUN]:
        return

    shutil.rmtree(PathConfig.WRONG_ANSWERS_ROOT, ignore_errors=True)
    if PathConfig.OUTPUT_SUBMISSION.exists():
        PathConfig.OUTPUT_SUBMISSION.unlink()


@dataclass
class HandMadeNodeEvaluator(NodeEvaluator):
    """Score waiting nodes with a hand-tuned distance used to order the search."""

    # Chooses which pair of depth-cost parameters is used in calculate_final_distance.
    pattern: DepthSearchPattern
    # Maps an operation element's class name to its inclusion probability.
    operation_element_prob_dict: Dict[str, float]
    # Holds the depth/probability cost coefficients.
    node_search_engine_param: NodeBaseSearchEngineParameter
    # Passed to DistanceEvaluator.
    dist_eval_param: DistanceEvaluatorParameter

    def __post_init__(self):
        """Create one per-node-class evaluator and the distance evaluator."""
        evaluator_class_by_node_class = {
            UniformOperationWaitingNode: OperationWaitingNodeEvaluator,
            ColorSelectionWaitingNode: ColorSelectionWaitingNodeEvaluator,
            MaskConversionWaitingNode: MaskConversionWaitingNodeEvaluator,
            MaskOperationSelectionWaitingNode: MaskOperationSelectionWaitingNodeEvaluator,
            ColorChannelSelectionOperationWaitingNode: ColorChannelSelectionOperationWaitingNodeEvaluator,
            ColorChannelMaskConversionWaitingNode: ColorChannelMaskConversionWaitingNodeEvaluator,
            ColorChannelMergeWaitingNode: ColorChannelMergeWaitingNodeEvaluator,
            PartitionSelectionWaitingNode: PartitionSelectionWaitingNodeEvaluator,
            PartitionMergeWaitingNode: PartitionMergeWaitingNodeEvaluator,
        }
        self.class_mapping = {
            node_class: evaluator_class(self.operation_element_prob_dict)
            for node_class, evaluator_class in evaluator_class_by_node_class.items()
        }

        self.dist_evaluator = DistanceEvaluator(self.dist_eval_param)

    def evaluate_nodes(self, nodes: List[WaitingNode]):
        """Evaluate every node in ``nodes`` in place."""
        for waiting_node in nodes:
            self.evaluate(waiting_node)

    def evaluate(self, node: WaitingNode):
        """Compute the predicted distance of ``node`` and store it in ``node.cache_pred_distance``."""
        node_evaluator = self.class_mapping[node.__class__]
        feature = node_evaluator.get_task_feature(node)
        base_distance = self.dist_evaluator.evaluate_task_feature(feature)
        inclusion_prob = node_evaluator.get_element_inclusion_prob(node)
        node.cache_pred_distance = self.calculate_final_distance(base_distance, inclusion_prob, node.depth())

    def evaluate_base_distance_for_completed_node(self, node: CompletedNode):
        """Return the base distance of a completed node from its ``task_feature``."""
        return self.dist_evaluator.evaluate_task_feature(node.task_feature)

    def calculate_final_distance(self, base_distance: float, element_inclusion_prob: float, depth: int) -> float:
        """Combine base distance, element probability and depth into the final score.

        The score is ``base_distance ** (1 + depth * exp_cost) + prob_cost +
        linear_cost * depth`` where the two cost coefficients depend on
        ``self.pattern``.

        Raises:
            NotImplementedError: If ``self.pattern`` is not a handled member.
        """
        params = self.node_search_engine_param
        prob_cost = params.element_inclusion_prob_factor * (1 - element_inclusion_prob)

        if self.pattern == DepthSearchPattern.BREADTH_FIRST:
            exp_cost, linear_cost = params.breadth_first_exp_cost, params.breadth_first_cost
        elif self.pattern == DepthSearchPattern.NORMAL:
            exp_cost, linear_cost = params.normal_exp_cost, params.normal_first_cost
        elif self.pattern == DepthSearchPattern.DEPTH_FIRST:
            exp_cost, linear_cost = params.depth_first_exp_cost, params.depth_first_cost
        else:
            raise NotImplementedError()

        return base_distance ** (1 + depth * exp_cost) + prob_cost + linear_cost * depth


@dataclass
class HandmadeNodeEvaluatorBase:
    """Base class of the per-node-type evaluators; every hook must be overridden to be usable."""

    # Maps an operation element's class name to its inclusion probability.
    operation_element_prob_dict: Dict[str, float]

    def get_task_feature(self, node) -> TaskFeature:
        """Return the TaskFeature to score for ``node`` (abstract)."""
        raise NotImplementedError

    def get_element_inclusion_prob(self, node) -> float:
        """Return the inclusion probability of ``node``'s next element (abstract)."""
        raise NotImplementedError

    def calculate_dist_factor(self, node) -> float:
        """Return a distance factor for ``node`` (abstract)."""
        raise NotImplementedError


class OperationWaitingNodeEvaluator(HandmadeNodeEvaluatorBase):
    """Evaluator hooks for UniformOperationWaitingNode."""

    def get_task_feature(self, operation_waiting_node: UniformOperationWaitingNode) -> TaskFeature:
        """Return the node's own ``task_feature``."""
        return operation_waiting_node.task_feature

    def get_element_inclusion_prob(self, operation_waiting_node: UniformOperationWaitingNode) -> float:
        """Look up the probability keyed by the class name of ``next_operation``."""
        element_name = operation_waiting_node.next_operation.__class__.__name__
        return self.operation_element_prob_dict[element_name]

    def calculate_dist_factor(self, operation_waiting_node: UniformOperationWaitingNode) -> float:
        """Return 0.8 when the operation kind fits the task's dimension behaviour, else 1.2.

        Flip/Rotate fit tasks whose input and output dimensions are equal;
        Resize/Padding fit tasks where they differ.

        Raises:
            NotImplementedError: For any other operation type.
        """
        # TODO use height_integer_multiple?
        operation = operation_waiting_node.next_operation
        keeps_shape = isinstance(operation, (Flip, Rotate))
        if not keeps_shape and not isinstance(operation, (Resize, Padding)):
            raise NotImplementedError(operation)

        same_dim = operation_waiting_node.task_feature.same_dim_between_input_output
        if keeps_shape:
            return 0.8 if same_dim else 1.2
        return 1.2 if same_dim else 0.8


class ColorSelectionWaitingNodeEvaluator(HandmadeNodeEvaluatorBase):
    """Evaluator hooks for ColorSelectionWaitingNode."""

    def get_task_feature(self, color_selection_waiting_node: ColorSelectionWaitingNode) -> TaskFeature:
        """Return the node's own ``task_feature``."""
        return color_selection_waiting_node.task_feature

    def get_element_inclusion_prob(self, color_selection_waiting_node: ColorSelectionWaitingNode) -> float:
        """Look up the probability keyed by the class name of ``next_selection``."""
        element_name = color_selection_waiting_node.next_selection.__class__.__name__
        return self.operation_element_prob_dict[element_name]

    def calculate_dist_factor(self, color_selection_waiting_node: ColorSelectionWaitingNode) -> float:
        """Return the neutral factor 1.0."""
        neutral_factor = 1.0
        return neutral_factor


class MaskConversionWaitingNodeEvaluator(HandmadeNodeEvaluatorBase):
    """Evaluator hooks for MaskConversionWaitingNode."""

    def get_task_feature(self, mask_conversion_waiting_node: MaskConversionWaitingNode) -> TaskFeature:
        """Return the ``task_feature`` nested in the node's color-selected task feature."""
        color_selected_feature = mask_conversion_waiting_node.color_selected_task_feature
        return color_selected_feature.task_feature

    def get_element_inclusion_prob(self, mask_conversion_waiting_node: MaskConversionWaitingNode) -> float:
        """Look up the probability keyed by the class name of ``next_mask_conversion``."""
        element_name = mask_conversion_waiting_node.next_mask_conversion.__class__.__name__
        return self.operation_element_prob_dict[element_name]

    def calculate_dist_factor(self, mask_conversion_waiting_node: MaskConversionWaitingNode) -> float:
        """Return the neutral factor 1.0."""
        neutral_factor = 1.0
        return neutral_factor


class MaskOperationSelectionWaitingNodeEvaluator(HandmadeNodeEvaluatorBase):
    """Evaluator hooks for MaskOperationSelectionWaitingNode."""

    def get_task_feature(self, mask_operation_waiting_node: MaskOperationSelectionWaitingNode) -> TaskFeature:
        """Return the ``task_feature`` nested in the node's mask-converted task feature."""
        mask_converted_feature = mask_operation_waiting_node.mask_converted_task_feature
        return mask_converted_feature.task_feature

    def get_element_inclusion_prob(self, mask_operation_waiting_node: MaskOperationSelectionWaitingNode) -> float:
        """Look up the probability keyed by the class name of ``next_mask_operation``."""
        element_name = mask_operation_waiting_node.next_mask_operation.__class__.__name__
        return self.operation_element_prob_dict[element_name]

    def calculate_dist_factor(self, mask_operation_waiting_node: MaskOperationSelectionWaitingNode) -> float:
        """Return the neutral factor 1.0."""
        neutral_factor = 1.0
        return neutral_factor


class ColorChannelSelectionOperationWaitingNodeEvaluator(HandmadeNodeEvaluatorBase):
    """Evaluator hooks for ``ColorChannelSelectionOperationWaitingNode`` instances."""

    def get_task_feature(self, node: ColorChannelSelectionOperationWaitingNode) -> TaskFeature:
        """Return the task feature stored directly on the waiting node."""
        feature = node.task_feature
        return feature

    def get_element_inclusion_prob(self, node: ColorChannelSelectionOperationWaitingNode) -> float:
        """Look up the inclusion probability keyed by the pending channel selection's class name."""
        pending_selection = node.next_color_channel_selection
        element_name = pending_selection.__class__.__name__
        return self.operation_element_prob_dict[element_name]


class ColorChannelMaskConversionWaitingNodeEvaluator(HandmadeNodeEvaluatorBase):
    """Evaluator hooks for ``ColorChannelMaskConversionWaitingNode`` instances."""

    def get_task_feature(self, node: ColorChannelMaskConversionWaitingNode) -> TaskFeature:
        """Return the task feature stored directly on the waiting node."""
        feature = node.task_feature
        return feature

    def get_element_inclusion_prob(self, node: ColorChannelMaskConversionWaitingNode) -> float:
        """Look up the inclusion probability keyed by the pending mask conversion's class name."""
        pending_conversion = node.next_mask_conversion
        element_name = pending_conversion.__class__.__name__
        return self.operation_element_prob_dict[element_name]


class ColorChannelMergeWaitingNodeEvaluator(HandmadeNodeEvaluatorBase):
    """Evaluator hooks for ``ColorChannelMergeWaitingNode`` instances."""

    def get_task_feature(self, node: ColorChannelMergeWaitingNode) -> TaskFeature:
        """Return the task feature stored directly on the waiting node."""
        feature = node.task_feature
        return feature

    def get_element_inclusion_prob(self, node: ColorChannelMergeWaitingNode) -> float:
        """Look up the inclusion probability keyed by the pending merge operation's class name."""
        pending_merge = node.next_merge_operation
        element_name = pending_merge.__class__.__name__
        return self.operation_element_prob_dict[element_name]


class PartitionSelectionWaitingNodeEvaluator(HandmadeNodeEvaluatorBase):
    """Evaluator hooks for ``PartitionSelectionWaitingNode`` instances."""

    def get_task_feature(self, node: PartitionSelectionWaitingNode) -> TaskFeature:
        """Return the task feature stored directly on the waiting node."""
        feature = node.task_feature
        return feature

    def get_element_inclusion_prob(self, node: PartitionSelectionWaitingNode) -> float:
        """Look up the inclusion probability keyed by the pending partition selection's class name."""
        pending_selection = node.next_partition_selection
        element_name = pending_selection.__class__.__name__
        return self.operation_element_prob_dict[element_name]


class PartitionMergeWaitingNodeEvaluator(HandmadeNodeEvaluatorBase):
    """Evaluator hooks for ``PartitionMergeWaitingNode`` instances."""

    def get_task_feature(self, node: PartitionMergeWaitingNode) -> TaskFeature:
        """Return the task feature stored directly on the waiting node."""
        feature = node.task_feature
        return feature

    def get_element_inclusion_prob(self, node: PartitionMergeWaitingNode) -> float:
        """Look up the inclusion probability keyed by the pending partition merge's class name."""
        pending_merge = node.next_partition_merge_operation
        element_name = pending_merge.__class__.__name__
        return self.operation_element_prob_dict[element_name]


def np_to_str(arr: np.ndarray) -> bytes:
    """Return the raw data bytes of *arr*.

    Uses ``ndarray.tobytes()``; the former ``ndarray.tostring()`` alias was
    deprecated and is no longer available in NumPy 2.x. The returned bytes
    are identical to what ``tostring()`` produced.
    """
    return arr.tobytes()


def to_bytes(obj):
    """Return the UTF-8 encoding of ``str(obj)``."""
    text = str(obj)
    return text.encode('utf-8')


def train_operation_element_inclusion_prediction():
    """Train and persist the operation-element inclusion classifier.

    Steps:
      1. Load the answer storage, write it back, and keep only the storage
         returned by ``get_only_correct_answer_storage()``.
      2. For every task, build one record from the task feature plus boolean
         targets saying which operation types / element classes appear in the
         union of that task's stored operation sets.
      3. Fit an ``MLPClassifier`` on those records.
      4. Replace the model directory and pickle the model, the feature column
         list and the target column list into it.
      5. Print, per stored answer, the true inclusion flags next to the
         prediction of the freshly saved model.

    The pickle files are written inside ``with`` blocks so that they are
    flushed and closed before step 5 reads them back.
    """
    storage = load_answer_storage()
    save_answer_storage(storage)
    storage = storage.get_only_correct_answer_storage()
    print(storage.get_text())

    type_classes = [c.__name__ for c in get_all_operation_classes()]
    subclasses = [c.__name__ for c in get_all_operation_element_classes()]
    record_dicts = []

    for task_name, elements in storage.get_task_grouped_elements():
        # Merge every stored operation set of the task into one pseudo set, so the
        # targets describe "appears in any stored answer for this task".
        pseudo_operation_set = OperationSet(list(chain.from_iterable([e.operation_set.operations for e in elements])))

        task = TaskLoader().get_task(task_name)
        task_feature = create_task_feature(task)

        operation_type_classes = [o_s_t.__name__ for o_s_t in pseudo_operation_set.types()]
        type_answer_dict = {c: c in operation_type_classes for c in type_classes}

        operation_element_classes = [o_s_e.__class__.__name__ for o_s_e in pseudo_operation_set.elements()]
        element_answer_dict = {c: c in operation_element_classes for c in subclasses}

        record_dicts.append({**asdict(task_feature), **type_answer_dict, **element_answer_dict})

    df = DataFrame(record_dicts)
    df = df.fillna(10)  # TODO do not use magic number

    target_columns = type_classes + subclasses
    feature_columns = list(set(df.columns) - set(target_columns))

    x = df[feature_columns]
    y = df[target_columns]

    print(x)
    print(y)

    model = MLPClassifier(
        # early_stopping=True, validation_fraction=0.3, n_iter_no_change=50,
        hidden_layer_sizes=(50,), solver='sgd', learning_rate_init=0.003, max_iter=40,
        verbose=True,
    )
    model.fit(x, y)

    # Start from an empty model directory, then persist the three artifacts.
    shutil.rmtree(PathConfig.OPERATION_ELEMENT_INCLUSION_MODEL_ROOT, ignore_errors=True)
    PathConfig.OPERATION_ELEMENT_INCLUSION_MODEL_ROOT.mkdir(parents=True, exist_ok=True)
    artifacts = (
        (model, PathConfig.OPERATION_ELEMENT_INCLUSION_MODEL),
        (feature_columns, PathConfig.OPERATION_ELEMENT_INCLUSION_MODEL_FEATURE_COLUMNS),
        (target_columns, PathConfig.OPERATION_ELEMENT_INCLUSION_MODEL_TARGET_COLUMNS),
    )
    for artifact, artifact_path in artifacts:
        # Close each file deterministically; the prediction below reads them back.
        with artifact_path.open(mode='wb') as f:
            pickle.dump(artifact, f)

    # Sanity check: interleave the true flags and the predictions of the saved model.
    temp_dicts = []
    for e in sorted(storage.elements, key=lambda e: e.task_name):
        task = TaskLoader().get_task(e.task_name)
        task_feature = create_task_feature(task)
        pred_dict = predict_operation_element_inclusion(task_feature)

        operation_type_classes = [o_s_t.__name__ for o_s_t in e.operation_set.types()]
        operation_element_classes = [o_s_e.__class__.__name__ for o_s_e in e.operation_set.elements()]
        element_answer_dict = {c: c in operation_type_classes + operation_element_classes for c in type_classes + subclasses}
        temp_dicts.append(element_answer_dict)
        temp_dicts.append(pred_dict)

    temp_df = DataFrame(temp_dicts)
    print(temp_df)


def predict_operation_element_inclusion(task_feature: TaskFeature) -> Dict[str, float]:
    """Predict, for one task feature, the inclusion probability of each target column.

    Loads the pickled classifier together with its feature and target column
    lists from the paths in ``PathConfig``, builds a one-row frame from
    ``asdict(task_feature)`` (missing values filled with 10, matching
    training) and returns ``{target_column: probability}`` taken from the
    first row of ``model.predict_proba``.

    The pickle files are opened in ``with`` blocks so the handles are closed
    as soon as loading finishes.

    NOTE: ``pickle.load`` executes arbitrary code from the file; the model
    files must come from a trusted source (they are written by
    ``train_operation_element_inclusion_prediction``).
    """
    with PathConfig.OPERATION_ELEMENT_INCLUSION_MODEL.open(mode='rb') as f:
        model: MLPClassifier = pickle.load(f)
    with PathConfig.OPERATION_ELEMENT_INCLUSION_MODEL_FEATURE_COLUMNS.open(mode='rb') as f:
        feature_columns = pickle.load(f)
    with PathConfig.OPERATION_ELEMENT_INCLUSION_MODEL_TARGET_COLUMNS.open(mode='rb') as f:
        target_columns = pickle.load(f)

    df = DataFrame([asdict(task_feature)])
    df = df.fillna(10)

    # Select columns in the exact order used at training time.
    x = df[feature_columns]
    y = model.predict_proba(x)[0]

    return dict(zip(target_columns, y))


@dataclass
class NodeBaseSearchEngine:
    """Priority-queue driven search for operation sets that solve a task.

    ``search`` runs in two phases:

    1. *Zero depth*: expand the root node until every reachable single
       operation has been completed. If any of them has a base distance of 0,
       those single-operation answers are returned immediately.
    2. *Non-zero depth*: continue expanding from the completed zero-depth
       nodes until a completed uniform node matches all training pairs, the
       queue runs dry, the schedule times out, or ``MAX_NODE`` rounds pass.

    At most ``answer_limit_num`` answers are returned.
    """
    # Upper bound on the number of pop rounds in each phase.
    MAX_NODE = 100000000
    # Maximum number of answers put into an AnsweredSearchResults.
    answer_limit_num: int = 3

    def search(self, task: Task, params: AllParameter, verbose: bool = False) -> Union[AnsweredSearchResults, NotAnsweredSearchResult]:
        """Search for operation sets that solve *task*.

        Args:
            task: The task to solve.
            params: Engine and distance-evaluator parameters.
            verbose: When true, print progress and skip reasons.

        Returns:
            ``AnsweredSearchResults`` when at least one candidate was found,
            otherwise ``NotAnsweredSearchResult`` carrying one of
            ``NoImprovementException`` (queue exhausted), ``TimeoutException``
            (time budget or evaluator schedule exhausted) or
            ``MaxNodeExceededException`` (``MAX_NODE`` rounds reached).
        """
        task_feature = create_task_feature(task, task)
        if RunConfig.USE_ML_GUIDE:
            operation_element_prob_dict = predict_operation_element_inclusion(task_feature)
        else:
            # Without the ML guide every element gets the same weight.
            operation_element_prob_dict = defaultdict(lambda: 1)

        schedules: NodeEvaluatorSchedules = get_schedule(operation_element_prob_dict, params.node_base_engine_param, params.distance_evaluator_param)
        node_evaluator = schedules.pop_evaluator()

        root_node = UniformOperationCompletedNode(None, task, task, task_feature, OperationSet([]))
        first_waiting_nodes = CompletedNodeProcessor.process(root_node)
        node_evaluator.evaluate_nodes(first_waiting_nodes)
        zero_depth_pq = PriorityQueue([*first_waiting_nodes])
        pq = PriorityQueue([])
        zero_depth_completed_nodes = []
        zero_depth_completed_node_eval_map = {}
        visited_node_hashes = defaultdict(dict)  # If same array is found, cache to save time.

        if verbose:
            print('search zero depth nodes')
        with Timer() as timer:
            for node_i in range(self.MAX_NODE):
                if len(zero_depth_pq) == 0:
                    break

                waiting_new_nodes = []
                for same_cost_node_i, waiting_node in enumerate(zero_depth_pq.pop_mins_or_as_least_n(params.node_base_engine_param.pq_pop_mins_or_as_least_n)):
                    completed_node = WaitingNodeProcessor().process(waiting_node)

                    # WaitingNodeProcessor returns (not raises) the inconsistency exception.
                    if isinstance(completed_node, Exception):
                        if verbose:
                            print(f'skipped: {completed_node}')
                        continue

                    if isinstance(completed_node, UniformOperationCompletedNode):
                        # A full single operation was completed: record it and its base
                        # distance, but do not expand it further in this phase.
                        zero_depth_completed_nodes.append(completed_node)
                        zero_depth_completed_node_eval_map[completed_node.base_operation_set.operations[0]] = node_evaluator.evaluate_base_distance_for_completed_node(completed_node)
                        continue

                    temp_waiting_new_nodes = CompletedNodeProcessor.process(completed_node)
                    node_evaluator.evaluate_nodes(temp_waiting_new_nodes)

                    waiting_new_nodes += temp_waiting_new_nodes

                for n in waiting_new_nodes:
                    zero_depth_pq.push(n)

            # Single operations whose base distance is exactly 0 are treated as answers.
            one_depth_answer_nodes = [k for k, v in zero_depth_completed_node_eval_map.items() if v == 0]
            if one_depth_answer_nodes:
                answers = []
                result_applied_tasks = []
                for o in one_depth_answer_nodes:
                    try:
                        applied_task = TaskOperationSetExecutor().execute(task, OperationSet([o]))
                    except OperationInconsistencyException:
                        continue
                    # Drop operations that produce the same test output as an earlier answer.
                    if any(applied_task.test_arr_hash() == t.test_arr_hash() for t in result_applied_tasks):
                        continue
                    result_applied_tasks.append(applied_task)
                    answers.append(AnsweredSearchResult(OperationSet([o])))
                answers = answers[:self.answer_limit_num]
                return AnsweredSearchResults(task, answers, timer.second(), 0, node_i)
            zero_depth_search_time = timer.second()

        # Seed the main queue from the zero-depth results, expanding only the first
        # node seen for each train-array hash.
        for completed_node in zero_depth_completed_nodes:
            train_node_hash = completed_node.train_arr_hash()
            all_node_hash = completed_node.all_arr_hash()
            if train_node_hash in visited_node_hashes:
                if verbose:
                    print(f'hash skipped. same node: {"_".join(map(str, (f"{k}:{v}" for k, v in visited_node_hashes[train_node_hash].items())))}')
                # Still remember the node; it is passed to get_alternative_operation_sets later.
                visited_node_hashes[train_node_hash][all_node_hash] = completed_node
                continue
            visited_node_hashes[train_node_hash][all_node_hash] = completed_node

            temp_waiting_new_nodes = CompletedNodeProcessor.process(completed_node)
            node_evaluator.evaluate_nodes(temp_waiting_new_nodes)
            for n in temp_waiting_new_nodes:
                pq.push(n)

        # TODO use the depth-1 results to improve the evaluation function
        # TODO may need a negative correction when the same operation is included more than once?

        if verbose:
            print('search none-zero depth nodes')
        searched_total_node = 0
        with Timer() as timer:
            for node_i in range(self.MAX_NODE):
                if len(pq) == 0:
                    return NotAnsweredSearchResult(task, NoImprovementException(), timer.second(), searched_total_node)

                waiting_new_nodes = []
                for same_cost_node_i, waiting_node in enumerate(pq.pop_mins_or_as_least_n(params.node_base_engine_param.pq_pop_mins_or_as_least_n)):
                    if verbose:
                        print(f'total_node: {searched_total_node}, node: {node_i}_{same_cost_node_i}, pq_len: {len(pq)}, cost: {waiting_node.cache_pred_distance}, {waiting_node}')

                    searched_total_node += 1
                    completed_node = WaitingNodeProcessor().process(waiting_node)

                    if isinstance(completed_node, Exception):
                        if verbose:
                            print(f'skipped: {completed_node}')
                        continue

                    if isinstance(completed_node, UniformOperationCompletedNode):
                        if AnswerMatcher.is_train_all_match(completed_node.task):
                            answers = []
                            for t in get_alternative_operation_sets(task, completed_node, visited_node_hashes, verbose):
                                answers.append(AnsweredSearchResult(t.to_operation_set()))
                                if len(answers) >= self.answer_limit_num:
                                    break
                            return AnsweredSearchResults(task, answers, zero_depth_search_time, timer.second(), searched_total_node)

                    train_node_hash = completed_node.train_arr_hash()
                    all_node_hash = completed_node.all_arr_hash()
                    if train_node_hash in visited_node_hashes:
                        if verbose:
                            print(f'hash skipped. same node: {"_".join(map(str, (f"{k}:{v}" for k, v in visited_node_hashes[train_node_hash].items())))}')
                        visited_node_hashes[train_node_hash][all_node_hash] = completed_node
                        continue
                    visited_node_hashes[train_node_hash][all_node_hash] = completed_node

                    temp_waiting_new_nodes = CompletedNodeProcessor.process(completed_node)
                    node_evaluator.evaluate_nodes(temp_waiting_new_nodes)

                    waiting_new_nodes += temp_waiting_new_nodes

                    if timer.second() > schedules.timeout_sec():
                        return NotAnsweredSearchResult(task, TimeoutException(), timer.second(), searched_total_node)

                for n in waiting_new_nodes:
                    pq.push(n)

                if timer.second() > schedules.next_timing():
                    if verbose:
                        print('=========================== evaluator switch!!! ===========================')
                    node_evaluator = schedules.pop_evaluator()
                    if node_evaluator is None:
                        return NotAnsweredSearchResult(task, TimeoutException(), timer.second(), searched_total_node)
                    # Re-score everything still queued with the new evaluator, then
                    # refresh the queue so its order reflects the new scores.
                    node_evaluator.evaluate_nodes(pq.heap)
                    pq.refresh()

        return NotAnsweredSearchResult(task, MaxNodeExceededException(), timer.second(), searched_total_node)


class WaitingNodeProcessor:
    """Routes a waiting node to the processor registered for its exact class."""

    def process(self, node: WaitingNode) -> Union[CompletedNode, OperationInconsistencyException]:
        """Process *node* and return the completed node.

        An ``OperationInconsistencyException`` raised by the processor is
        returned to the caller instead of being propagated. A node class
        without a registered processor raises ``KeyError``.
        """
        processors_by_node_class = {
            UniformOperationWaitingNode: UniformOperationWaitingNodeProcessor(),
            ColorSelectionWaitingNode: ColorSelectionWaitingNodeProcessor(),
            MaskConversionWaitingNode: MaskConversionWaitingNodeProcessor(),
            MaskOperationSelectionWaitingNode: MaskOperationSelectionWaitingNodeProcessor(),
            ColorChannelSelectionOperationWaitingNode: ColorChannelSelectionOperationWaitingNodeProcessor(),
            ColorChannelMaskConversionWaitingNode: ColorChannelMaskConversionWaitingNodeProcessor(),
            ColorChannelMergeWaitingNode: ColorChannelMergeWaitingNodeProcessor(),
            PartitionSelectionWaitingNode: PartitionSelectionWaitingNodeProcessor(),
            PartitionMergeWaitingNode: PartitionMergeWaitingNodeProcessor(),
        }

        handler = processors_by_node_class[node.__class__]
        try:
            return handler.process(node)
        except OperationInconsistencyException as inconsistency:
            return inconsistency


class UniformOperationWaitingNodeProcessor:
    """Applies the pending uniform operation of a waiting node to its task."""

    def process(self, node: UniformOperationWaitingNode) -> UniformOperationCompletedNode:
        """Execute ``node.next_operation`` and wrap the result in a completed node.

        Raises:
            OperationInconsistencyException: when the operation leaves every
                training input unchanged.
        """
        pending_operation = node.next_operation
        applied_task = TaskOperationSetExecutor().execute(node.task, OperationSet([pending_operation]))
        if self.can_skip(node.task, applied_task):
            raise OperationInconsistencyException('can skip')

        source_task = node.original_task
        applied_feature = create_task_feature(source_task, applied_task)
        extended_operation_set = OperationSet(node.base_operation_set.operations + [pending_operation])

        return UniformOperationCompletedNode(node, source_task, applied_task, applied_feature, extended_operation_set)

    def can_skip(self, prev_task: Task, next_task: Task) -> bool:
        """Return True when no training input differs between the two tasks."""
        # TODO use OperationInconsistencyException?
        paired_ios = zip(prev_task.train, next_task.train)
        return all(AnswerMatcher.is_match_arr(before.input_arr, after.input_arr) for before, after in paired_ios)


class ColorSelectionWaitingNodeProcessor:
    """Applies the pending color selection of a waiting node to its task."""

    def process(self, node: ColorSelectionWaitingNode) -> ColorSelectionCompletedNode:
        """Execute ``node.next_selection`` and wrap the result in a completed node.

        Raises:
            OperationInconsistencyException: when the resulting training masks
                are all empty or all full.
        """
        selection = node.next_selection
        selected_task = ColorSelectionExecutor.execute(node.task, selection)
        if self.can_skip(selected_task):
            raise OperationInconsistencyException('can skip')

        selected_feature = create_color_selected_task_feature(node.original_task, selected_task, node.task_feature)
        return ColorSelectionCompletedNode(
            node, node.original_task, selected_task, selected_feature, node.base_operation_set, selection,
        )

    def can_skip(self, color_selected_task: ColorSelectedTask) -> bool:
        """Return True when the training masks select nothing at all or everything."""
        # TODO use OperationInconsistencyException?
        nothing_selected = not any(mask.any() for mask in color_selected_task.train_masks)
        if nothing_selected:
            # No mask was generated for any training pair.
            return True
        # Masks that cover the whole region carry no information either.
        return all(mask.all() for mask in color_selected_task.train_masks)


class MaskConversionWaitingNodeProcessor:
    """Applies the pending mask conversion of a waiting node to its color-selected task."""

    def process(self, node: MaskConversionWaitingNode) -> MaskConversionCompletedNode:
        """Execute ``node.next_mask_conversion`` and wrap the result in a completed node.

        Raises:
            OperationInconsistencyException: when the converted training masks
                are all empty or all full.
        """
        conversion = node.next_mask_conversion
        converted_task = MaskConversionExecutor.execute(node.color_selected_task, conversion)
        if self.can_skip(converted_task):
            raise OperationInconsistencyException('can skip')

        base_feature = node.color_selected_task_feature.task_feature
        converted_feature = create_mask_conversion_task_feature(node.original_task, converted_task, base_feature)

        return MaskConversionCompletedNode(
            node, node.original_task, converted_task, converted_feature,
            node.base_operation_set, node.color_selection, conversion,
        )

    def can_skip(self, mask_converted_task: MaskConvertedTask) -> bool:
        """Return True when the training masks select nothing at all or everything."""
        nothing_selected = not any(mask.any() for mask in mask_converted_task.train_masks)
        if nothing_selected:
            # No mask was generated for any training pair.
            return True
        # Masks that cover the whole region carry no information either.
        return all(mask.all() for mask in mask_converted_task.train_masks)


class MaskOperationSelectionWaitingNodeProcessor:
    """Applies the pending mask operation, completing a full ``ColorOperation``."""

    def process(self, node: MaskOperationSelectionWaitingNode) -> UniformOperationCompletedNode:
        """Execute ``node.next_mask_operation`` and wrap the result in a completed node.

        The completed node's operation set is the base set extended by a
        ``ColorOperation`` built from the node's color selection, mask
        conversion and the operation just applied.

        Raises:
            OperationInconsistencyException: when the operation leaves every
                training input unchanged.
        """
        mask_operation = node.next_mask_operation
        applied_task = MaskOperationExecutor.execute(node.mask_converted_task, mask_operation)
        if self.can_skip(node.mask_converted_task, applied_task):
            raise OperationInconsistencyException('can skip')

        applied_feature = create_task_feature(node.original_task, applied_task)
        completed_operation = ColorOperation(node.color_selection, node.mask_conversion, mask_operation)
        extended_operation_set = OperationSet(node.base_operation_set.operations + [completed_operation])

        return UniformOperationCompletedNode(node, node.original_task, applied_task, applied_feature, extended_operation_set)

    def can_skip(self, prev_task: Task, next_task: Task) -> bool:
        """Return True when no training input differs between the two tasks."""
        paired_ios = zip(prev_task.train, next_task.train)
        return all(AnswerMatcher.is_match_arr(before.input_arr, after.input_arr) for before, after in paired_ios)


class ColorChannelSelectionOperationWaitingNodeProcessor:
    """Applies the pending color-channel selection of a waiting node to its task."""

    def process(self, node: ColorChannelSelectionOperationWaitingNode) -> ColorChannelSelectionCompletedNode:
        """Execute ``node.next_color_channel_selection`` and wrap the result in a completed node.

        Raises:
            OperationInconsistencyException: when ``can_skip`` reports the
                step as skippable.
        """
        channel_selection = node.next_color_channel_selection
        selected_task = ColorChannelSelectionExecutor().execute(node.task, channel_selection)
        if self.can_skip(node.task, selected_task):
            raise OperationInconsistencyException('can skip')

        # The waiting node's task feature is carried over rather than recomputed.
        return ColorChannelSelectionCompletedNode(
            node, node.original_task, selected_task, node.task_feature,
            node.base_operation_set, channel_selection,
        )

    def can_skip(self, prev_task: Task, next_task: Task) -> bool:
        """Skip check placeholder: currently never skips."""
        # TODO implement
        return False


class ColorChannelMaskConversionWaitingNodeProcessor:
    """Applies the pending mask conversion of a color-channel waiting node."""

    def process(self, node: ColorChannelMaskConversionWaitingNode) -> ColorChannelMaskConversionCompletedNode:
        """Execute ``node.next_mask_conversion`` and wrap the result in a completed node.

        Raises:
            OperationInconsistencyException: when ``can_skip`` reports the
                step as skippable.
        """
        conversion = node.next_mask_conversion
        converted_task = ColorChannelMaskConversionSelectionExecutor().execute(node.task, conversion)
        if self.can_skip(node.task, converted_task):
            raise OperationInconsistencyException('can skip')

        # The waiting node's task feature is carried over rather than recomputed.
        return ColorChannelMaskConversionCompletedNode(
            node, node.original_task, converted_task, node.task_feature,
            node.base_operation_set, node.color_channel_selection, conversion,
        )

    def can_skip(self, prev_task: Task, next_task: Task) -> bool:
        """Skip check placeholder: currently never skips."""
        # TODO implement
        return False


class ColorChannelMergeWaitingNodeProcessor:
    """Applies the pending merge, completing a full ``MultiColorChannelOperation``."""

    def process(self, node: ColorChannelMergeWaitingNode) -> UniformOperationCompletedNode:
        """Execute ``node.next_merge_operation`` and wrap the result in a completed node.

        The completed node's operation set is the base set extended by a
        ``MultiColorChannelOperation`` built from the node's channel
        selection, mask conversion and the merge just applied.

        Raises:
            OperationInconsistencyException: when ``can_skip`` reports the
                step as skippable.
        """
        merge_operation = node.next_merge_operation
        merged_task = ColorChannelMergeExecutor.execute(node.task, merge_operation)
        if self.can_skip(node.task, merged_task):
            raise OperationInconsistencyException('can skip')

        merged_feature = create_task_feature(node.original_task, merged_task)
        completed_operation = MultiColorChannelOperation(node.color_channel_selection, node.mask_conversion, merge_operation)
        extended_operation_set = OperationSet(node.base_operation_set.operations + [completed_operation])

        return UniformOperationCompletedNode(node, node.original_task, merged_task, merged_feature, extended_operation_set)

    def can_skip(self, prev_task: Task, next_task: Task) -> bool:
        """Skip check placeholder: currently never skips."""
        # TODO implement
        return False


class PartitionSelectionWaitingNodeProcessor:
    """Applies the pending partition selection of a waiting node to its task."""

    def process(self, node: PartitionSelectionWaitingNode) -> PartitionSelectionCompletedNode:
        """Execute ``node.next_partition_selection`` and wrap the result in a completed node.

        Raises:
            OperationInconsistencyException: when ``can_skip`` reports the
                step as skippable.
        """
        partition_selection = node.next_partition_selection
        partitioned_task = PartitionSelectionExecutor().execute(node.task, partition_selection)
        if self.can_skip(node.task, partitioned_task):
            raise OperationInconsistencyException('can skip')

        # The waiting node's task feature is carried over rather than recomputed.
        return PartitionSelectionCompletedNode(
            node, node.original_task, partitioned_task, node.task_feature,
            node.base_operation_set, partition_selection,
        )

    def can_skip(self, prev_task: Task, next_task: Task) -> bool:
        """Skip check placeholder: currently never skips."""
        # TODO implement
        return False


class PartitionMergeWaitingNodeProcessor:
    """Applies the pending partition merge, completing a full ``PartitionOperation``."""

    def process(self, node: PartitionMergeWaitingNode) -> UniformOperationCompletedNode:
        """Execute ``node.next_partition_merge_operation`` and wrap the result in a completed node.

        The completed node's operation set is the base set extended by a
        ``PartitionOperation`` built from the node's partition selection and
        the merge just applied.

        Raises:
            OperationInconsistencyException: when ``can_skip`` reports the
                step as skippable.
        """
        merge_operation = node.next_partition_merge_operation
        merged_task = PartitionMergeExecutor().execute(node.task, merge_operation)
        if self.can_skip(node.task, merged_task):
            raise OperationInconsistencyException('can skip')

        merged_feature = create_task_feature(node.original_task, merged_task)
        completed_operation = PartitionOperation(node.partition_selection, merge_operation)
        extended_operation_set = OperationSet(node.base_operation_set.operations + [completed_operation])

        return UniformOperationCompletedNode(node, node.original_task, merged_task, merged_feature, extended_operation_set)

    def can_skip(self, prev_task: Task, next_task: Task) -> bool:
        """Skip check placeholder: currently never skips."""
        # TODO implement
        return False


@dataclass(frozen=True)
class WithOutMostCommonColorChannelSelection(ColorChannelSelection):
    """Splits an array into one boolean mask per non-background color."""
    bg_selection_mode: BackGroundColorSelectionMode

    def __call__(self, arr: np.ndarray) -> List[Tuple[Color, np.ndarray]]:
        """Return ``(color, arr == color)`` pairs for every color except the background.

        The background color is obtained from ``ColorSelectionUtil`` using
        ``bg_selection_mode``.

        Raises:
            OperationInconsistencyException: when fewer than two
                non-background colors remain.
        """
        background = ColorSelectionUtil().get_background_color(arr, self.bg_selection_mode)

        channels = []
        for color in ColorSelectionUtil().get_colors(arr):
            if color != background:
                channels.append((color, arr == color))

        if len(channels) < 2:
            raise OperationInconsistencyException('can not devide')

        return channels


@dataclass
class OperationSetExecutionResultHolder:
    """Memoizes the task and task feature produced by applying operation sets to ``raw_task``."""
    # Task that uncached operation sets are applied to from scratch.
    raw_task: Task
    # Results keyed by ``str(operation_set)``.
    cache: Dict[str, Tuple[Task, TaskFeature]]

    def get_result(self, operation_set: OperationSet) -> Tuple[Task, TaskFeature]:
        """Return the (task, task feature) pair for *operation_set*, computing it if needed.

        Lookup order: an exact cache hit; otherwise the longest cached proper
        prefix, onto which only the remaining operations are applied;
        otherwise the whole set is applied to ``raw_task``. Newly computed
        results are stored in the cache.
        """
        full_key = str(operation_set)
        if full_key in self.cache:
            return self.cache[full_key]

        operations = operation_set.operations
        operation_count = len(operations)

        # Try the longest prefix first so that as little as possible is re-executed.
        for split_at in range(operation_count - 1, 0, -1):
            head = OperationSet(operations[:split_at])
            tail = OperationSet(operations[split_at:])
            assert len(head.operations) + len(tail.operations) == operation_count

            head_key = str(head)
            if head_key not in self.cache:
                continue

            head_task, _ = self.cache[head_key]
            resumed_task = TaskOperationSetExecutor().execute(head_task, tail)
            resumed_result = (resumed_task, create_task_feature(resumed_task))
            self.cache[full_key] = resumed_result
            return resumed_result

        # No cached prefix: run the whole set against the raw task.
        fresh_task = TaskOperationSetExecutor().execute(self.raw_task, operation_set)
        fresh_result = (fresh_task, create_task_feature(fresh_task))
        self.cache[full_key] = fresh_result
        return fresh_result


@dataclass
class OperationSetMutator:
    # TODO uniform(0, 1) is redundant. There must be easier way to do it.
    # TODO I'd like to define procedure of probability. like albumentation.
    # TODO should increase max_depth dynamically?

    holder: OperationSetExecutionResultHolder
    operation_element_prob_dict: Dict[str, float]

    def mutate(self, operation_set: OperationSet):
        """Return a randomly mutated copy of *operation_set*.

        Each operation is visited in order and goes through a chain of
        independent random draws, compared against the probabilities on
        ``TreeBaseSearchEngineParameter``. The first draw that succeeds decides
        the outcome:

        1. replace the operation with a completely random one,
        2. replace its components (a whole uniform operation, or each part of
           a color operation with probability 1/3),
        3. re-randomize parameters via ``_mutate_parameter`` (same 1/3 rule
           for color operations),
        4. drop the operation,
        5. otherwise keep it unchanged.

        Because every ``elif`` draws a fresh random number, each later branch
        is only reached when all earlier draws failed.

        After the loop, if the result is shorter than ``max_depth``, one random
        operation may be appended.

        Raises:
            NotImplementedError: in the component / parameter branches when the
                operation is neither a ``UniformOperation`` nor a
                ``ColorOperation``.
        """
        new_operations = []
        for o in operation_set.operations:
            # State of the task after the already-mutated prefix; candidates are drawn against it.
            task, task_feature = self.holder.get_result(OperationSet(new_operations))
            if random.uniform(0, 1) < TreeBaseSearchEngineParameter.operation_mutation_prob:
                # Replace the whole operation.
                new_operations.append(self.get_random_one_operation(task, task_feature))
            elif random.uniform(0, 1) < TreeBaseSearchEngineParameter.operation_component_mutation_prob:
                # Replace components while keeping the kind of operation.
                if isinstance(o, UniformOperation):
                    new_operations.append(self._uniform_operation_candidates(task_feature))
                elif isinstance(o, ColorOperation):
                    color_sel, add_sels, mask_ope = o.color_selection, o.mask_conversions, o.mask_operation
                    # Each of the three parts is replaced independently with probability 1/3.
                    if random.uniform(0, 1) < 1 / 3:
                        color_sel = self._color_selection_candidates(task)
                    if random.uniform(0, 1) < 1 / 3:
                        add_sels = [self._mask_conversions()]
                    if random.uniform(0, 1) < 1 / 3:
                        mask_ope = self._mask_operation_candidates(task)
                    new_operations.append(ColorOperation(color_sel, add_sels, mask_ope))
                else:
                    raise NotImplementedError()
            elif random.uniform(0, 1) < TreeBaseSearchEngineParameter.operation_param_mutation_prob:
                # Keep the element classes but re-randomize their parameters.
                if isinstance(o, UniformOperation):
                    new_operations.append(self._mutate_parameter(o, task))
                elif isinstance(o, ColorOperation):
                    color_sel, add_sels, mask_ope = o.color_selection, o.mask_conversions, o.mask_operation
                    if random.uniform(0, 1) < 1 / 3:
                        color_sel = self._mutate_parameter(color_sel, task)
                    if random.uniform(0, 1) < 1 / 3:
                        # Only the first mask conversion is mutated; the list is rebuilt with that single element.
                        add_sels = [self._mutate_parameter(add_sels[0], task)]
                    if random.uniform(0, 1) < 1 / 3:
                        mask_ope = self._mutate_parameter(mask_ope, task)
                    new_operations.append(ColorOperation(color_sel, add_sels, mask_ope))
                else:
                    raise NotImplementedError()
            elif random.uniform(0, 1) < TreeBaseSearchEngineParameter.shrink_mutation_prob:
                # Shrink: drop this operation.
                continue
            else:
                new_operations.append(o)

        # Extend: optionally append one random operation while below the depth limit.
        if len(new_operations) < TreeBaseSearchEngineParameter.max_depth:
            if random.uniform(0, 1) < TreeBaseSearchEngineParameter.extend_mutation_prob:
                temp_new_set = OperationSet(new_operations)
                task, task_feature = self.holder.get_result(temp_new_set)
                new_operations.append(self.get_random_one_operation(task, task_feature))

        return OperationSet(new_operations)

    def get_random_one_operation(self, task: Task, task_feature: TaskFeature):
        """Draw one random operation, choosing its kind by the configured weights.

        The kind (``UniformOperation`` or ``ColorOperation``) is sampled with
        probabilities proportional to ``operation_element_prob_dict`` entries
        keyed by class name; the operation itself is then built from the
        candidate helpers.

        NOTE(review): ``mutate`` wraps the result of ``_mask_conversions()`` in
        a list before passing it to ``ColorOperation``, while this method
        passes the bare element - confirm which form ``ColorOperation`` expects.
        """
        kinds = [UniformOperation, ColorOperation]

        weights = [self.operation_element_prob_dict[kind.__name__] for kind in kinds]
        weight_sum = sum(weights)
        picked = np.random.choice(kinds, p=[w / weight_sum for w in weights])

        if picked == UniformOperation:
            return self._uniform_operation_candidates(task_feature)
        if picked == ColorOperation:
            # Arguments are evaluated left to right, preserving the order of random draws.
            return ColorOperation(
                self._color_selection_candidates(task),
                self._mask_conversions(),
                self._mask_operation_candidates(task),
            )
        raise NotImplementedError()

    def _uniform_operation_candidates(self, task_feature: TaskFeature):
        """Draw one random uniform operation (``Resize``, ``Padding``, ``Flip`` or ``Rotate``).

        The class is sampled with probabilities proportional to its
        ``operation_element_prob_dict`` entry; the concrete instance is then
        picked uniformly from all parameter combinations enumerated below.
        """
        kinds = [Resize, Padding, Flip, Rotate]

        weights = [self.operation_element_prob_dict[kind.__name__] for kind in kinds]
        weight_sum = sum(weights)
        picked = np.random.choice(kinds, p=[w / weight_sum for w in weights])

        # Enumerate every parameterization of the chosen class.
        if picked == Resize:
            pool = [Resize(a, r) for a, r in product(Axis, range(2, 5))]
        elif picked == Padding:
            pool = [Padding(m, d, k) for m, d, k in product(PaddingMode, Direction, range(1, 4))]
        elif picked == Flip:
            pool = [Flip(m) for m in FlipMode]
        elif picked == Rotate:
            pool = [Rotate(a) for a in [90, 180, 270]]
        else:
            raise NotImplementedError()

        return random.choice(pool)

    def _color_selection_candidates(self, task: Task):
        """Draw one random color selection.

        The class is sampled with probabilities proportional to its
        ``operation_element_prob_dict`` entry. ``FixedSingleColorSelection``
        candidates are limited to the distinct values found in
        ``task.get_input_all_arr()``; the other two classes enumerate their
        mode enums.
        """
        kinds = [FixedSingleColorSelection, SingleColorSelection, MultiColorSelection]

        weights = [self.operation_element_prob_dict[kind.__name__] for kind in kinds]
        weight_sum = sum(weights)
        picked = np.random.choice(kinds, p=[w / weight_sum for w in weights])

        if picked == FixedSingleColorSelection:
            # Flatten two levels of nesting and keep the distinct cell values.
            distinct_values = set(chain.from_iterable(chain.from_iterable(task.get_input_all_arr())))
            input_colors = [Color.of(value) for value in distinct_values]
            pool = [FixedSingleColorSelection(c) for c in input_colors]
        elif picked == SingleColorSelection:
            pool = [SingleColorSelection(m) for m in SingleColorSelectionMode]
        elif picked == MultiColorSelection:
            pool = [MultiColorSelection(m) for m in MultiColorSelectionMode]
        else:
            raise NotImplementedError()

        return random.choice(pool)

    def _mask_conversions(self):
        """Draw one random mask conversion.

        The class is sampled with probabilities proportional to its
        ``operation_element_prob_dict`` entry. Parameterless conversions are
        returned directly; for the others an instance is picked uniformly
        from all parameter combinations enumerated below.
        """
        kinds = [NoMaskConversion, SquareObjectsSelection, ObjectsMaxMinSelection, SplitLineSelection, DotExistLineSelection,
                 HolesSelection, ObjectInnerSelection, ContourSelection, ContourOuterSelection, ConnectDotSelection]

        weights = [self.operation_element_prob_dict[kind.__name__] for kind in kinds]
        weight_sum = sum(weights)
        picked = np.random.choice(kinds, p=[w / weight_sum for w in weights])

        # Parameterless conversions need no second random draw.
        if picked == NoMaskConversion:
            return NoMaskConversion()
        if picked == SquareObjectsSelection:
            return SquareObjectsSelection()

        if picked == ObjectsMaxMinSelection:
            pool = [ObjectsMaxMinSelection(m, t, c) for m, t, c in product(MaxOrMin, ObjectFeature, PixelConnectivity)]
        elif picked == SplitLineSelection:
            pool = [SplitLineSelection(a) for a in Axis]
        elif picked == DotExistLineSelection:
            pool = [DotExistLineSelection(a) for a in Axis]
        elif picked == HolesSelection:
            pool = [HolesSelection(c) for c in PixelConnectivity]
        elif picked == ObjectInnerSelection:
            pool = [ObjectInnerSelection(c, e) for c, e in product(PixelConnectivity, ImageEdgeType)]
        elif picked == ContourSelection:
            pool = [ContourSelection(c, e) for c, e in product(PixelConnectivity, ImageEdgeType)]
        elif picked == ContourOuterSelection:
            pool = [ContourOuterSelection(c, h) for c, h in product(PixelConnectivity, HoleInclude)]
        elif picked == ConnectDotSelection:
            pool = [ConnectDotSelection(a, e, f) for a, e, f in product(Axis, LineEdgeType, FillType)]
        else:
            raise NotImplementedError()

        return random.choice(pool)

    def _mask_operation_candidates(self, task: Task):
        """Sample one mask operation with randomly chosen parameters.

        The operation class is drawn with ``np.random.choice`` using the
        normalised weights from ``self.operation_element_prob_dict`` (keyed by
        class name). ``FixedColorMaskFill`` picks its colour uniformly from
        the values found in ``task.get_output_all_arr()``.

        Raises:
            NotImplementedError: if the drawn class has no factory below.
        """
        def fixed_color_fill():
            # Distinct cell values over all output arrays, wrapped as Color.
            cell_values = set(chain.from_iterable(chain.from_iterable(task.get_output_all_arr())))
            output_colors = [Color.of(v) for v in cell_values]
            return random.choice([FixedColorMaskFill(c) for c in output_colors])

        factories = {
            MaskCoordsCrop: lambda: MaskCoordsCrop(),
            FixedColorMaskFill: fixed_color_fill,
            SingleColorMaskFill: lambda: random.choice([SingleColorMaskFill(m) for m in SingleColorSelectionMode]),
        }
        candidate_classes = list(factories)

        weights = [self.operation_element_prob_dict[cls.__name__] for cls in candidate_classes]
        weight_total = sum(weights)
        normalised_weights = [w / weight_total for w in weights]

        drawn_class = np.random.choice(candidate_classes, p=normalised_weights)

        factory = factories.get(drawn_class)
        if factory is None:
            raise NotImplementedError()
        return factory()

    def _mutate_parameter(self, operation_element, task: Task):
        """Return a freshly parameterised element of the same kind as ``operation_element``.

        The replacement is drawn uniformly (``random.choice``) from every
        parameter combination enumerated for the element's class, so it may
        be equal to the element passed in.

        Raises:
            NotImplementedError: if the element's class is not handled; the
                element is passed as the exception argument.
        """
        # TODO Should mutate one property of operation_element.

        def task_colors(arrays):
            # Distinct cell values over the given arrays, wrapped as Color.
            return [Color.of(v) for v in set(chain.from_iterable(chain.from_iterable(arrays)))]

        # Checked top to bottom with isinstance, so the order is significant
        # if any of these classes derive from one another.
        candidate_builders = [
            (Resize, lambda: [Resize(Axis.VERTICAL, r) for r in range(2, 5)]),
            (Padding, lambda: [Padding(m, d, k)
                               for m, d, k in product(PaddingMode, [Direction.TOP, Direction.BOTTOM], range(1, 4))]),
            (Flip, lambda: [Flip(m) for m in FlipMode]),
            (Rotate, lambda: [Rotate(a) for a in [90, 180, 270]]),
            (FixedSingleColorSelection,
             lambda: [FixedSingleColorSelection(c) for c in task_colors(task.get_input_all_arr())]),
            (SingleColorSelection, lambda: [SingleColorSelection(m) for m in SingleColorSelectionMode]),
            (MultiColorSelection, lambda: [MultiColorSelection(m) for m in MultiColorSelectionMode]),
            (NoMaskConversion, lambda: [NoMaskConversion()]),
            (SquareObjectsSelection, lambda: [SquareObjectsSelection()]),
            (ObjectsMaxMinSelection, lambda: [ObjectsMaxMinSelection(m, t, c)
                                              for m, t, c in product(MaxOrMin, ObjectFeature, PixelConnectivity)]),
            (SplitLineSelection, lambda: [SplitLineSelection(a) for a in Axis]),
            (DotExistLineSelection, lambda: [DotExistLineSelection(a) for a in Axis]),
            (HolesSelection, lambda: [HolesSelection(c) for c in PixelConnectivity]),
            (ObjectInnerSelection, lambda: [ObjectInnerSelection(c, e)
                                            for c, e in product(PixelConnectivity, ImageEdgeType)]),
            (ContourSelection, lambda: [ContourSelection(c, e)
                                        for c, e in product(PixelConnectivity, ImageEdgeType)]),
            (ContourOuterSelection, lambda: [ContourOuterSelection(c, h)
                                             for c, h in product(PixelConnectivity, HoleInclude)]),
            (ConnectDotSelection, lambda: [ConnectDotSelection(a, e, f)
                                           for a, e, f in product(Axis, LineEdgeType, FillType)]),
            (MaskCoordsCrop, lambda: [MaskCoordsCrop()]),
            (FixedColorMaskFill,
             lambda: [FixedColorMaskFill(c) for c in task_colors(task.get_output_all_arr())]),
            (SingleColorMaskFill, lambda: [SingleColorMaskFill(m) for m in SingleColorSelectionMode]),
        ]

        for element_class, build_candidates in candidate_builders:
            if isinstance(operation_element, element_class):
                return random.choice(build_candidates())

        raise NotImplementedError(operation_element)


@dataclass
class Individual:
    """One member of a search population: an operation set plus its evaluation."""
    operation_set: OperationSet
    distance: float
    task_feature: TaskFeature

    def __str__(self):
        # Depth is the number of operations chained in the operation set.
        depth = len(self.operation_set.operations)
        return f'depth: {depth}, dist: {self.distance:.5f}, ope: {self.operation_set}'


@dataclass
class Population:
    """A generation of ``Individual`` objects plus the selection strategy name.

    ``strategy`` is either ``'simple'`` (elitism + roulette-wheel selection)
    or ``'nsga2'``; any other value makes ``select`` raise
    ``NotImplementedError``.
    """
    strategy: str
    individuals: List[Individual]

    def show(self):
        """Print every individual, best (smallest distance) first."""
        self.sort()
        for i in self.individuals:
            print(i)

    def sort(self):
        """Order individuals by ascending distance, breaking ties randomly."""
        # Shuffling first makes the (stable) sort pick a random order among
        # individuals that share the same distance.
        random.shuffle(self.individuals)
        self.individuals = sorted(self.individuals, key=lambda i: i.distance)

    def get_elite(self):
        """Return the individual with the smallest distance (re-sorts the population)."""
        # TODO Lack of consideration when there were multiple elites.
        self.sort()
        return self.individuals[0]

    def get_dist0_if_exists(self) -> Optional[OperationSet]:
        """Return the operation set of a distance-0 individual, or ``None`` if there is none."""
        if not self.individuals:
            return None
        elite = min(self.individuals, key=lambda i: i.distance)
        if elite.distance == 0:
            return elite.operation_set
        else:
            return None

    def mutate(self, mutator, holder, evaluator):
        """Replace every non-elite individual with a mutated, re-evaluated one.

        The elite is carried over unchanged. For each other individual the
        mutation is retried until ``holder.get_result`` accepts it without
        raising ``OperationInconsistencyException``.
        """
        # Mutation
        self.sort()
        mutated_individuals = [self.get_elite()]
        for i in self.individuals[1:]:
            # Retry until a consistent mutation is found. An unbounded loop
            # avoids continuing with a stale or unbound candidate, which the
            # previous finite ``range`` allowed once it was exhausted.
            while True:
                try:
                    mutated_operation_set = mutator.mutate(i.operation_set)
                    # Check that the mutated operation set can be executed.
                    holder.get_result(mutated_operation_set)
                    break
                except OperationInconsistencyException:
                    continue

            # NOTE(review): this second lookup presumably hits the holder's
            # cache; it is kept to preserve the original call pattern.
            applied_task, applied_task_feature = holder.get_result(mutated_operation_set)
            mutation_distance = evaluator.evaluate_task_feature(applied_task_feature)
            mutated_individuals.append(Individual(mutated_operation_set, mutation_distance, applied_task_feature))

        self.individuals = mutated_individuals

    def select(self):
        """Replace the population using the configured selection strategy.

        Raises:
            NotImplementedError: for an unknown ``strategy``.
        """
        if self.strategy == 'simple':
            self.individuals = self.select_simple()
        elif self.strategy == 'nsga2':
            self.individuals = self.select_nsga2()
        else:
            raise NotImplementedError()

    def select_nsga2(self):
        """Select with ``selNSGA2`` and top up from ``select_simple`` to keep the size."""
        raw_len = len(self.individuals)
        selected = selNSGA2(self.individuals, raw_len)

        simple_selection = self.select_simple(include_elite=True)

        return selected + simple_selection[:raw_len - len(selected)]

    def select_simple(self, include_elite: bool = True):
        """Roulette-wheel selection weighted by ``1 / distance``.

        Args:
            include_elite: when true, the current elite is always the first
                entry of the returned list.

        Returns:
            A new list of individuals (entries may repeat). ``len - 1``
            roulette draws are made, so with the elite included the
            population size is preserved.
        """
        # Selection
        self.sort()
        if include_elite:
            next_individuals = [self.get_elite()]
        else:
            next_individuals = []

        distances = [i.distance for i in self.individuals]
        if any(d == 0 for d in distances):
            # 1 / 0 is undefined: in the limit a perfect individual takes all
            # the probability mass, so share it among the zero-distance ones.
            weights = [1.0 if d == 0 else 0.0 for d in distances]
        else:
            weights = [1 / d for d in distances]
        weight_sum = sum(weights)
        score_ratios = [w / weight_sum for w in weights]
        score_roulette = np.cumsum(score_ratios)

        last_index = len(self.individuals) - 1
        for _ in range(len(self.individuals) - 1):
            roulette_prob_hit = random.uniform(0, 1)
            for chosen, roulette_prob in enumerate(score_roulette):
                if roulette_prob_hit < roulette_prob:
                    break
            else:
                # Rounding can leave the final cumulative value slightly
                # below the draw; fall back to the last slot so a draw never
                # silently selects nobody and shrinks the population.
                chosen = last_index
            next_individuals.append(self.individuals[chosen])

        return next_individuals


@dataclass
class TreeBaseSearchEngine:
    """Evolutionary search over operation sets (mutate, then select, per generation).

    ``time_out`` is the wall-clock budget in seconds checked once per
    generation.
    """
    time_out: int = 60  # TODO

    def get_first_individual(self, evaluator, mutator, holder, task, root_task_feature):
        """Build a random single-operation individual, retrying until it is consistent.

        Retries in a loop rather than recursing, so a long streak of
        ``OperationInconsistencyException`` cannot exhaust the recursion limit.
        """
        while True:
            try:
                operation = mutator.get_random_one_operation(task, root_task_feature)
                operation_set = OperationSet([operation])
                _, task_feature = holder.get_result(operation_set)
                distance = evaluator.evaluate_task_feature(task_feature)
                return Individual(operation_set, distance, task_feature)
            except OperationInconsistencyException:
                continue

    def search(self, task: Task, params: AllParameter, verbose: bool = False) -> Union[AnsweredSearchResults, NotAnsweredSearchResult]:
        """Search for an operation set whose evaluated distance is 0 on ``task``.

        Args:
            task: the task to solve.
            params: supplies ``distance_evaluator_param`` and, when present,
                ``tree_base_engine_param.population_num``.
            verbose: print the population before and after each mutation.

        Returns:
            ``AnsweredSearchResults`` when a distance-0 operation set matches
            all training pairs, otherwise ``NotAnsweredSearchResult`` carrying
            a ``TimeoutException`` or ``MaxNodeExceededException``.

        Raises:
            NotImplementedError: when a distance-0 operation set does not
                match all training pairs.
        """
        evaluator = DistanceEvaluator(params.distance_evaluator_param)
        holder = OperationSetExecutionResultHolder(task, {})
        root_operation_set = OperationSet([])
        _, root_task_feature = holder.get_result(root_operation_set)

        if RunConfig.USE_ML_GUIDE:
            operation_element_prob_dict = predict_operation_element_inclusion()
        else:
            # Uniform weights: every element class is equally likely.
            operation_element_prob_dict = defaultdict(lambda: 1)

        if verbose:
            print(operation_element_prob_dict)

        mutator: OperationSetMutator = OperationSetMutator(holder, operation_element_prob_dict)

        # Honour the population size carried by ``params`` (this is what
        # optimize_tree_base_search tunes); fall back to the class-level
        # default when the parameter object does not provide one.
        # NOTE(review): other tree-engine parameters may still be read from
        # class attributes elsewhere - confirm against the mutator.
        engine_param = getattr(params, 'tree_base_engine_param', None)
        population_num = getattr(engine_param, 'population_num', TreeBaseSearchEngineParameter.population_num)

        individuals = [self.get_first_individual(evaluator, mutator, holder, task, root_task_feature) for _ in range(population_num)]

        population = Population('simple', individuals)
        with Timer() as timer:

            for i in range(10000000):

                if verbose:
                    print(f'============== generation: {i} population')
                    population.show()

                population.mutate(mutator, holder, evaluator)

                if verbose:
                    print(f'============== generation: {i}, mutation population')
                    population.show()

                answer_operation_set = population.get_dist0_if_exists()
                if answer_operation_set is not None:
                    if AnswerMatcher.is_train_all_match_if_operated(task, answer_operation_set):
                        return AnsweredSearchResults(task, [AnsweredSearchResult(answer_operation_set)], timer.second(), i)
                    else:
                        raise NotImplementedError()

                population.select()

                if timer.second() > self.time_out:
                    return NotAnsweredSearchResult(task, TimeoutException(), timer.second(), i)

            return NotAnsweredSearchResult(task, MaxNodeExceededException(), timer.second(), i)


T = TypeVar('T')


class PriorityQueue:
    """Min-priority queue backed by ``heapq`` operating on the given list in place."""

    def __init__(self, heap: List[T]):
        # The caller's list is heapified in place and used as storage.
        heapify(heap)
        self.heap = heap

    def refresh(self):
        """Restore the heap invariant after items were modified externally."""
        heapify(self.heap)

    def push(self, item: T):
        """Insert ``item``."""
        heappush(self.heap, item)

    def pop_min(self) -> T:
        """Remove and return the smallest item."""
        return heappop(self.heap)

    def pop_mins(self) -> List[T]:
        """Remove and return every item that ties with the current minimum."""
        smallest = self.pop_min()
        ties = [smallest]
        while self.heap:
            candidate = self.pop_min()
            if not candidate <= smallest:
                # First strictly larger item: put it back and stop.
                self.push(candidate)
                break
            ties.append(candidate)
        return ties

    def pop_mins_or_as_least_n(self, n: int) -> List[T]:
        """Pop whole tie-groups until at least ``n`` items are collected or the queue is empty."""
        collected = []
        while len(collected) < n and len(self) != 0:
            collected.extend(self.pop_mins())
        return collected

    def push_pop(self, item: T) -> T:
        """Insert ``item`` and then remove and return the smallest item."""
        return heappushpop(self.heap, item)

    def __len__(self) -> int:
        return len(self.heap)

    def sorted_list(self) -> List[T]:
        """Return the items in ascending order without modifying the queue."""
        ordered = list(self.heap)
        ordered.sort()
        return ordered


def str_to_operation_set(s: str) -> OperationSet:
    """Evaluate the DSL string ``s`` as a Python expression and return the result.

    NOTE(security): ``eval`` executes arbitrary code with this module's
    globals in scope. ``s`` must come from a trusted source; never pass
    externally supplied text.
    """
    # DSL string -> DSL object
    return eval(s)


def str_to_AnswerStorageElement(s: str):
    """Evaluate the string ``s`` as a Python expression and return the result.

    Presumably ``s`` is the textual form of an ``AnswerStorageElement``
    (going by the function name) - confirm against the answer-storage writer.

    NOTE(security): ``eval`` executes arbitrary code with this module's
    globals in scope. ``s`` must come from a trusted source; never pass
    externally supplied text.
    """
    # noinspection PyUnresolvedReferences
    # from abstraction_and_reasoning_challenge.src.answer_storage.answer_storage import AnswerStorageElement
    return eval(s)


@dataclass(frozen=True)
class UniqueColorNumberSelection(PartitionedArraySelection):
    """Select the partitions whose number of distinct values is the overall max or min."""
    max_or_min: MaxOrMin

    def __call__(self, arr: np.ndarray, partitioned_arrays: List[List[np.ndarray]]) -> List[List[bool]]:
        """Return a grid of booleans, true where a partition hits the extreme count."""
        counts_grid = _apply(self._color_num, partitioned_arrays)

        if self.max_or_min == MaxOrMin.MAX:
            wanted_count = max(max(row) for row in counts_grid)
        elif self.max_or_min == MaxOrMin.MIN:
            wanted_count = min(min(row) for row in counts_grid)
        else:
            raise NotImplementedError()

        return [[count == wanted_count for count in row] for row in counts_grid]

    def _color_num(self, array: np.ndarray):
        # Number of distinct values present in the partition.
        return np.unique(array).size


@dataclass(frozen=True)
class ColoredCellNumberSelection(PartitionedArraySelection):
    """Select the partitions whose count of non-background cells is the overall max or min."""
    max_or_min: MaxOrMin
    bg_selection_mode: BackGroundColorSelectionMode

    def __call__(self, arr: np.ndarray, partitioned_arrays: List[List[np.ndarray]]) -> List[List[bool]]:
        """Return a grid of booleans, true where a partition hits the extreme count.

        The background colour is determined once from the whole ``arr``.
        """
        background = ColorSelectionUtil().get_background_color(arr, self.bg_selection_mode)

        counts_grid = [[self._colored_cell_nums(cell, bg=background) for cell in row]
                       for row in partitioned_arrays]

        if self.max_or_min == MaxOrMin.MAX:
            wanted_count = max(max(row) for row in counts_grid)
        elif self.max_or_min == MaxOrMin.MIN:
            wanted_count = min(min(row) for row in counts_grid)
        else:
            raise NotImplementedError()

        return [[count == wanted_count for count in row] for row in counts_grid]

    def _colored_cell_nums(self, array: np.ndarray, bg: Color):
        # Count the cells that differ from the background colour.
        differs_from_bg = array != bg
        return differs_from_bg.sum()


@dataclass(frozen=True)
class SameShapeNumSelection(PartitionedArraySelection):
    """Select the partitions whose pattern is the uniquely most (or least) frequent.

    Two partitions count as the same pattern when they have the same shape
    and the same cell contents.
    """
    max_or_min: MaxOrMin

    def __call__(self, arr: np.ndarray, partitioned_arrays: List[List[np.ndarray]]) -> List[List[bool]]:
        """Return a grid of booleans, true where a partition shows the target pattern.

        Raises:
            OperationInconsistencyException: if fewer than two distinct
                patterns exist, or the most/least frequent pattern is tied.
            NotImplementedError: for an unknown ``max_or_min`` value.
        """
        # ``tobytes`` replaces ``ndarray.tostring``, which is deprecated and
        # no longer available in recent NumPy releases. The shape is part of
        # the key because raw bytes alone cannot tell e.g. a 1x4 partition
        # from a 2x2 partition holding the same values.
        pattern_keys = _apply(lambda n: (n.shape, n.tobytes()), partitioned_arrays)

        c = Counter(_flatten(pattern_keys))

        most_commons = c.most_common()
        if len(most_commons) < 2:
            raise OperationInconsistencyException('can not select')

        if self.max_or_min == MaxOrMin.MAX:
            # The winner must be strictly more frequent than the runner-up.
            if most_commons[0][1] == most_commons[1][1]:
                raise OperationInconsistencyException('duplicated max')
            target = most_commons[0][0]
        elif self.max_or_min == MaxOrMin.MIN:
            # The rarest pattern must be strictly rarer than the next one.
            if most_commons[-1][1] == most_commons[-2][1]:
                raise OperationInconsistencyException('duplicated min')
            target = most_commons[-1][0]
        else:
            raise NotImplementedError()

        return _apply(lambda key: key == target, pattern_keys)


@dataclass(frozen=True)
class SymmetrySelection(PartitionedArraySelection):
    """Select partitions by whether they are symmetric about the configured axis."""
    axis: AxisV2
    true_or_false: TrueOrFalse

    def __call__(self, arr: np.ndarray, partitioned_arrays: List[List[np.ndarray]]) -> List[List[bool]]:
        """Return a grid of booleans, one symmetry verdict per partition."""
        return [[self._is_symmetry(cell, axis=self.axis, true_or_false=self.true_or_false) for cell in row]
                for row in partitioned_arrays]

    def _is_symmetry(self, array: np.ndarray, axis: AxisV2, true_or_false: TrueOrFalse) -> bool:
        """Test ``array`` for symmetry about ``axis``; the verdict is negated unless ``true_or_false`` is TRUE."""

        def unchanged_by(flip_mode):
            # Symmetric means the flipped array equals the original.
            return np.array_equal(array, Flip(flip_mode)(array))

        if axis == AxisV2.VERTICAL:
            symmetric = unchanged_by(FlipMode.UD)
        elif axis == AxisV2.HORIZONTAL:
            symmetric = unchanged_by(FlipMode.LR)
        elif axis == AxisV2.VERTICAL_HORIZONTAL:
            symmetric = unchanged_by(FlipMode.UD) and unchanged_by(FlipMode.LR)
        elif axis == AxisV2.MAIN_DIAGONAL:
            symmetric = unchanged_by(FlipMode.UL_DR)
        elif axis == AxisV2.ANTI_DIAGONAL:
            symmetric = unchanged_by(FlipMode.UR_DL)
        elif axis == AxisV2.BOTH_DIAGONAL:
            symmetric = unchanged_by(FlipMode.UL_DR) and unchanged_by(FlipMode.UR_DL)
        else:
            raise NotImplementedError()

        return symmetric if true_or_false == TrueOrFalse.TRUE else not symmetric


def _apply(func, partitioned_arrays: List[List[np.ndarray]]) -> List[List[Any]]:
    results = []
    for h_arrays in partitioned_arrays:
        temp_masks = []
        for array in h_arrays:
            temp_masks.append(func(array))
        results.append(temp_masks)
    return results


def _flatten(partitioned: List[List[Any]]) -> List[Any]:
    return list(chain.from_iterable(partitioned))


class OperationSetEvaluator:
    # Evaluation function to choose three answers by ranking the OperationSet.
    # Smaller is better.

    def evaluate(self, operation_set: OperationSet) -> float:
        """Sum a per-element cost over ``operation_set.elements()``.

        ``FixedSingleColorSelection`` elements cost 0.5; every other element
        costs 1.
        """
        discounted_costs = {
            FixedSingleColorSelection: 0.5,
        }

        total_cost = 0
        for element in operation_set.elements():
            total_cost += discounted_costs.get(element.__class__, 1)
        return total_cost


def get_alternative_operation_sets(raw_task: Task, last_completed_node: UniformOperationCompletedNode, visited_node_hashes: Dict[int, Dict[int, Any]], verbose: bool) -> Iterable[NodeTree]:
    """Yield node trees that also solve the training pairs but differ on the test output.

    Starting from the tree of ``last_completed_node``, nodes are swapped for
    previously visited nodes that share the same ``train_arr_hash`` but have
    a different ``all_arr_hash``. Candidates that still match all training
    pairs are ranked by ``OperationSetEvaluator`` (smaller first) and yielded
    when their ``test_arr_hash`` has not been yielded before.
    """
    if verbose:
        print('original_answer')
        print(NodeTree.of(last_completed_node).to_operation_set())
        print('===search other answers===')

    base_tree = NodeTree.of(last_completed_node)

    # (depth, nodes that could replace the node at that depth)
    alternatives_by_depth = []
    for depth, node in enumerate(base_tree.completed_nodes):
        # The root node (depth 0) has no alternative.
        if depth == 0 or node.train_arr_hash() not in visited_node_hashes:
            continue
        same_train_hash_nodes = visited_node_hashes[node.train_arr_hash()]
        replacements = [other for all_hash, other in same_train_hash_nodes.items()
                        if all_hash != node.all_arr_hash()]
        alternatives_by_depth.append((depth, replacements))

    if verbose:
        print('alternative_nodes:')
        for depth, replacements in alternatives_by_depth:
            for replacement in replacements:
                print(f'node_depth: {depth}, {replacement}')

    trees = [base_tree]
    for depth, replacements in alternatives_by_depth:
        if len(trees) > 1000:
            break  # TODO Too many candidate_node_trees causes Memory Error.
        for replacement in replacements:
            # Build the new trees from a snapshot before growing the list.
            swapped = [NodeTree.replaced_new_node_tree(tree, depth, replacement) for tree in trees]
            trees.extend(swapped)

    if verbose:
        print('node_tree:')
        print(base_tree)
        print('candidate_node_trees:')
        for tree in trees:
            print('===')
            print(tree)

    # TODO unnecessary filter?
    matching_trees = [tree for tree in trees
                      if AnswerMatcher.is_train_all_match_if_operated(raw_task, tree.to_operation_set())]
    ranker = OperationSetEvaluator()
    matching_trees.sort(key=lambda tree: ranker.evaluate(tree.to_operation_set()))

    yielded_tasks = []
    for tree in matching_trees:
        try:
            applied_task = TaskOperationSetExecutor().execute(raw_task, tree.to_operation_set())
        except OperationInconsistencyException:
            continue

        # Skip trees whose test output was already produced by an earlier one.
        already_seen = False
        for earlier in yielded_tasks:
            if applied_task.test_arr_hash() == earlier.test_arr_hash():
                already_seen = True
                break
        if already_seen:
            continue

        yielded_tasks.append(applied_task)
        yield tree


class ColorChannelOverrideOperation(ChannelMergeOperation):
    """Merge converted colour channels by painting only the newly set cells onto ``arr``."""

    def __call__(self, arr: np.ndarray, original_color_mask_paris: List[Tuple[Color, np.ndarray]], color_mask_pairs: List[Tuple[Color, np.ndarray]]) -> np.ndarray:
        """Paint each channel's newly set cells with its colour, in place.

        Original and converted pairs are matched positionally after sorting
        both lists by colour.

        Raises:
            OperationInconsistencyException: if two channels newly set the
                same cell. ``arr`` is untouched in that case.
        """
        originals_by_color = sorted(original_color_mask_paris, key=itemgetter(0))
        converted_by_color = sorted(color_mask_pairs, key=itemgetter(0))  # TODO should groupby color?

        # Per channel: cells that are set after conversion but were not before.
        newly_set_pairs = []
        for (color, before), (_, after) in zip(originals_by_color, converted_by_color):
            newly_set_pairs.append((color, np.logical_and(after, np.logical_not(before))))

        claimed = np.zeros_like(newly_set_pairs[0][1])

        # If duplicated, InconsistencyException
        for _, mask in newly_set_pairs:
            if claimed[mask].any():
                raise OperationInconsistencyException('failed channel merge')
            claimed[mask] = True

        # Only paint once every channel has passed the overlap check.
        for color, mask in newly_set_pairs:
            arr[mask] = color

        return arr


@dataclass
class NodeEvaluatorSchedule:
    """One entry of an evaluator timetable: a start time and the evaluator to use."""
    # Time, in seconds, at which this entry starts.
    # NOTE(review): presumably measured from the start of the search - confirm against the engine.
    start_sec: int
    # Evaluator for this entry; get_schedule passes None for the final entry of each schedule.
    evaluator: Optional[NodeEvaluator]


@dataclass
class NodeEvaluatorSchedules:
    """Ordered list of ``NodeEvaluatorSchedule`` entries consumed from the front."""
    schedules: List[NodeEvaluatorSchedule]

    def pop_evaluator(self) -> Optional[NodeEvaluator]:
        """Remove the first entry and return its evaluator."""
        head = self.schedules[0]
        evaluator = head.evaluator
        self.schedules = self.schedules[1:]
        return evaluator

    def next_timing(self):
        """Return the start time of the first remaining entry."""
        upcoming = self.schedules[0]
        return upcoming.start_sec

    def timeout_sec(self):
        """Return the start time of the last entry."""
        final = self.schedules[-1]
        return final.start_sec


def get_schedule(operation_element_prob_dict: Dict[str, float], node_search_engine_param, dist_eval_param) -> NodeEvaluatorSchedules:
    """Build the evaluator timetable for the current ``RunConfig``.

    Three-phase schedules run breadth-first, normal and depth-first
    evaluators in turn and end with a ``None`` evaluator: phases last 60
    seconds in KERNEL run mode and 20 seconds otherwise. The dry-run
    schedule is a single breadth-first phase ending at 3 seconds.

    Raises:
        NotImplementedError: for a run-mode / schedule-pattern combination
            not listed here (e.g. the ML pattern in KERNEL mode).
    """
    def hand_made(depth_pattern):
        return HandMadeNodeEvaluator(depth_pattern, operation_element_prob_dict, node_search_engine_param, dist_eval_param)

    def three_phase(make_evaluator, phase_sec):
        return NodeEvaluatorSchedules([
            NodeEvaluatorSchedule(0, make_evaluator(DepthSearchPattern.BREADTH_FIRST)),
            NodeEvaluatorSchedule(phase_sec * 1, make_evaluator(DepthSearchPattern.NORMAL)),
            NodeEvaluatorSchedule(phase_sec * 2, make_evaluator(DepthSearchPattern.DEPTH_FIRST)),
            NodeEvaluatorSchedule(phase_sec * 3, None),
        ])

    def dry_run():
        return NodeEvaluatorSchedules([
            NodeEvaluatorSchedule(0, hand_made(DepthSearchPattern.BREADTH_FIRST)),
            NodeEvaluatorSchedule(3, None),
        ])

    in_kernel = RunConfig.RUN_MODE == RunMode.KERNEL
    pattern = RunConfig.ENGINE_SCHEDULE_PATTERN

    if pattern == EngineSchedulePattern.HAND_MADE:
        return three_phase(hand_made, 60 if in_kernel else 20)
    # The ML-guided schedule is only available outside KERNEL mode.
    if not in_kernel and pattern == EngineSchedulePattern.ML:
        return three_phase(MLNodeEvaluator, 20)
    if pattern == EngineSchedulePattern.DRY_RUN:
        return dry_run()
    raise NotImplementedError()


def optimize_node_base_search(tasks: List[Task]):
    """Tune node-based search engine parameters with Optuna over ``tasks``.

    Runs 1000 trials minimising ``all - true - false / 2``, i.e. a wrong
    answer is penalised half as much as no answer, then prints the best
    parameters. Requires ``RunConfig.ENGINE_TYPE`` to be the node-based
    engine.
    """
    assert RunConfig.ENGINE_TYPE == EngineType.NODE_BASED_SEARCH_ENGINE

    def objective(trial: Trial):
        # Distance-evaluator ranges explored previously (currently disabled):
        # distance_evaluator_param=DistanceEvaluatorParameter(
        #     same_h_w_dim_between_input_output=trial.suggest_loguniform('same_h_w_dim_between_input_output', 100, 10000),
        #     all_dim_h_w_integer_multiple=trial.suggest_loguniform('all_dim_h_w_integer_multiple', 10, 1000),
        #     mean_lack_color_num=trial.suggest_loguniform('mean_lack_color_num', 1, 100),
        #     mean_excess_color_num=trial.suggest_loguniform('mean_excess_color_num', 1, 100),
        #     mean_hit_and_miss_histogram_diff=trial.suggest_loguniform('mean_hit_and_miss_histogram_diff', 1, 100),
        #     mean_h_v_diff_input_arr_line_num=trial.suggest_loguniform('mean_h_v_diff_input_arr_line_num', 1, 100),
        #     mean_h_v_diff_output_arr_line_num=trial.suggest_loguniform('mean_h_v_diff_output_arr_line_num', 1, 100),
        #     mean_h_v_edge_sum_diff=trial.suggest_discrete_uniform('mean_h_v_edge_sum_diff', 0, 2, 0.5),
        #     mean_h_v_edge_sum_diff_ratio=trial.suggest_discrete_uniform('mean_h_v_edge_sum_diff_ratio', 0, 2, 0.5),
        # mean_diff_cell_where_no_need_to_change_count_ratio=trial.suggest_loguniform('mean_diff_cell_where_no_need_to_change_count_ratio', 1, 100000),
        # ),
        engine_param = NodeBaseSearchEngineParameter(
            # breadth_first_cost=trial.suggest_loguniform('breadth_first_cost', 1000, 100000),
            normal_first_cost=trial.suggest_loguniform('normal_first_cost', 10, 1000),
            depth_first_cost=trial.suggest_loguniform('depth_first_cost', 0.1, 10),
            # breadth_first_exp_cost=trial.suggest_loguniform('exp_cost', 0.001, 3),
            # normal_exp_cost=trial.params['exp_cost'],
            # depth_first_exp_cost=trial.params['exp_cost'],
            pq_pop_mins_or_as_least_n=trial.suggest_int('pq_pop_mins_or_as_least_n', 1, 10),
            #     element_inclusion_prob_factor=trial.suggest_loguniform('element_inclusion_prob_factor', 0.001, 10000000),
        )
        param = AllParameter(node_base_engine_param=engine_param)

        print(trial.params)

        engine_results = solve_tasks(tasks, param, add_answer_storage=True)
        answered_len = sum(1 for r in engine_results if isinstance(r, AnsweredSearchResults))
        true_len = sum(1 for r in engine_results if r.final_test_correct())

        all_len = len(engine_results)
        false_len = answered_len - true_len  # answered, but wrong on the test
        none_len = all_len - answered_len    # no answer found

        print(trial.params)
        print(f'true: {true_len}, false: {false_len}, none: {none_len}, all: {all_len}')
        return all_len - true_len - false_len / 2

    study = optuna.create_study()
    study.optimize(objective, n_trials=1000)

    print(study.best_params)


def optimize_tree_base_search(tasks: List[Task]):
    """Tune tree-based search engine parameters with Optuna over ``tasks``.

    Runs 1000 trials minimising ``all - true - false / 2``, i.e. a wrong
    answer is penalised half as much as no answer, then prints the best
    parameters. Requires ``RunConfig.ENGINE_TYPE`` to be the tree-based
    engine.
    """
    assert RunConfig.ENGINE_TYPE == EngineType.TREE_BASED_SEARCH_ENGINE

    def objective(trial: Trial):
        engine_param = TreeBaseSearchEngineParameter(
            population_num=trial.suggest_int('population_num', 20, 80),
            max_depth=trial.suggest_int('max_depth', 6, 10),
            operation_mutation_prob=trial.suggest_loguniform('operation_mutation_prob', 0.01, 0.5),
            operation_component_mutation_prob=trial.suggest_loguniform('operation_component_mutation_prob', 0.005, 0.5),
            operation_param_mutation_prob=trial.suggest_loguniform('operation_param_mutation_prob', 0.001, 0.5),
            extend_mutation_prob=trial.suggest_loguniform('extend_mutation_prob', 0.01, 1),
            shrink_mutation_prob=trial.suggest_loguniform('shrink_mutation_prob', 0.001, 0.1),
        )
        all_parameter = AllParameter(tree_base_engine_param=engine_param)
        print(trial.params)

        engine_results = solve_tasks(tasks, all_parameter, add_answer_storage=True)
        answered_len = sum(1 for r in engine_results if isinstance(r, AnsweredSearchResults))
        true_len = sum(1 for r in engine_results if r.final_test_correct())

        all_len = len(engine_results)
        false_len = answered_len - true_len  # answered, but wrong on the test
        none_len = all_len - answered_len    # no answer found

        print(f'true: {true_len}, false: {false_len}, none: {none_len}, all: {all_len}')
        return all_len - true_len - false_len / 2

    study = optuna.create_study()
    study.optimize(objective, n_trials=1000)

    print(study.best_params)


def solve_tasks(tasks: List[Task],
                params: AllParameter,
                output_summary_path: Optional[Path] = None,
                save_submission: bool = False,
                copy_wrong_answers_root_tag: Optional[str] = None,
                add_answer_storage: bool = False,
                verbose: bool = False) \
        -> List[Union[AnsweredSearchResults, NotAnsweredSearchResult]]:
    """Solve every task, print a summary and run the requested post-processing.

    Tasks are solved sequentially when ``RunConfig.N_JOB == 1`` or only one
    task is given, otherwise through joblib with the multiprocessing backend.

    Args:
        tasks: tasks to solve.
        params: parameters forwarded to ``solve_task``.
        output_summary_path: if given, the summary text is written there.
        save_submission: build and save a submission from the results.
        copy_wrong_answers_root_tag: if given, plot every task whose final
            test answer is not correct under that tag's directory.
        add_answer_storage: store the elements of answered results.
        verbose: forwarded to ``solve_task``.

    Returns:
        One search result per task.
    """
    print('===== start parallel solve tasks =====\n\n')
    if RunConfig.N_JOB == 1 or len(tasks) == 1:
        engine_results = []
        for task in tqdm(tasks, miniters=0, mininterval=None, maxinterval=None):
            engine_results.append(solve_task(task, params, verbose))
    else:
        # with Pool(processes=RunConfig.N_JOB) as pool:
        #     args = ((task, verbose) for task in tqdm(tasks, miniters=0, mininterval=None, maxinterval=None))
        #     engine_results = pool.starmap(solve_task, args)

        # 'multiprocessing' or 'threading'
        worker_pool = Parallel(n_jobs=RunConfig.N_JOB, backend='multiprocessing')
        engine_results = worker_pool(
            delayed(solve_task)(task, params, verbose)
            for task in tqdm(tasks, miniters=0, mininterval=None, maxinterval=None))

    print('===== end parallel solve tasks =====\n\n')

    summary = summary_engine_results(engine_results)
    print(summary)

    if output_summary_path:
        output_summary_path.write_text(summary)

    if save_submission:
        print('start save submission')
        save_submission_df(create_submission(engine_results))

    if add_answer_storage:
        storage_elements = []
        for engine_result in engine_results:
            if isinstance(engine_result, AnsweredSearchResults):
                storage_elements.extend(engine_result.to_answer_storage_elements())
        update_answer_storage(storage_elements)

    if copy_wrong_answers_root_tag:
        print('start copy wrong answers')
        wrong_answers_dir = PathConfig.WRONG_ANSWERS_ROOT / copy_wrong_answers_root_tag
        for engine_result in engine_results:
            if engine_result.final_test_correct():
                continue
            plot_task(engine_result.task, show=False, save_path=wrong_answers_dir / f'{engine_result.task.name}.png')

    return engine_results


def solve_task(task: Task, params: AllParameter, verbose: bool = False) -> Union[AnsweredSearchResults, NotAnsweredSearchResult]:
    """Run the configured engine on ``task`` and score any answers against the test pairs.

    For an answered result, each candidate gets its predicted test arrays and
    a correctness flag, and candidates are ordered with correct ones first.

    Raises:
        NotImplementedError: if the engine returns an unknown result type.
        Exception: any engine error is re-raised after the task name is printed.
    """
    try:
        engine_result = get_engine(RunConfig.ENGINE_TYPE).search(task, params, verbose)
    except Exception as e:
        print(f'unknown error {task.name}')
        raise e

    if isinstance(engine_result, NotAnsweredSearchResult):
        return engine_result
    if not isinstance(engine_result, AnsweredSearchResults):
        raise NotImplementedError()

    # calculate operation_set-executed task.
    for candidate in engine_result.results:
        applied_task = TaskOperationSetExecutor().execute(task, candidate.operation_set)
        # After execution the transformed inputs are the predicted outputs.
        candidate.test_output_arr = [pair.input_arr for pair in applied_task.test]
        candidate.test_correct = AnswerMatcher.is_train_test_all_match_if_operated(task, candidate.operation_set)

    engine_result.results = sorted(engine_result.results, key=lambda r: r.test_correct, reverse=True)
    print(engine_result.summary())
    return engine_result


class OperationSetExecutor:
    """Apply DSL operations to a list of arrays with input/output validation.

    Each public ``apply_*`` method maps one operation component over every
    array; the underscore-prefixed counterpart handles a single array.  Inputs
    are deep-copied before being handed to the operation object, so the
    operation never receives the caller's own arrays or masks.
    """

    @classmethod
    def apply_operation_set(cls, arrays: List[np.ndarray], operation_set: OperationSet) -> List[np.ndarray]:
        """Apply every operation of ``operation_set`` in order and return the final arrays."""
        current = arrays
        for operation in operation_set.operations:
            current = cls.apply_operation(current, operation)
        return current

    @classmethod
    def apply_operation(cls, arrays: List[np.ndarray], operation: Union[UniformOperation, ColorOperation, MultiColorChannelOperation]) -> List[np.ndarray]:
        """Dispatch a single operation according to its concrete type.

        Besides the types named in the annotation, ``PartitionOperation`` is
        handled as well.

        Raises:
            NotImplementedError: ``operation`` is of none of the supported types.
        """
        if isinstance(operation, UniformOperation):
            return cls.apply_uniform_operation(arrays, operation)

        if isinstance(operation, ColorOperation):
            selected_masks = cls.apply_color_selection(arrays, operation.color_selection)
            converted_masks = cls.apply_mask_conversion(selected_masks, operation.mask_conversions)
            return cls.apply_mask_operation(arrays, converted_masks, operation.mask_operation)

        if isinstance(operation, MultiColorChannelOperation):
            channels = cls.apply_channel_selection(arrays, operation.channel_selection)
            # The merge step receives both the untouched and the converted channels.
            converted_channels = cls.apply_color_channel_mask_conversion(deepcopy(channels), operation.mask_conversions)
            return cls.apply_channel_merge(arrays, channels, converted_channels, operation.channel_merge_operation)

        if isinstance(operation, PartitionOperation):
            partitions = cls.apply_partition_selection(arrays, operation.partition_selection)
            return cls.apply_partition_merge_operation(arrays, partitions, operation.partition_merge_operation)

        raise NotImplementedError()

    @classmethod
    def apply_uniform_operation(cls, arrays: List[np.ndarray], operation: UniformOperation) -> List[np.ndarray]:
        """Apply ``operation`` to every array; raise if none of them changed."""
        converted = [cls._apply_uniform_operation(arr, operation) for arr in arrays]

        unchanged = all(np.array_equal(after, before) for after, before in zip(converted, arrays))
        if unchanged:
            raise OperationInconsistencyException(f'no effect. {operation}')
        return converted

    @classmethod
    def _apply_uniform_operation(cls, arr: np.ndarray, operation: UniformOperation) -> np.ndarray:
        """Validate ``arr``, run ``operation`` on a copy and validate the result."""
        cls._check_arr(arr, operation)

        result = operation(deepcopy(arr))
        cls._check_arr(result, operation)
        return result

    @classmethod
    def apply_color_selection(cls, arrays: List[np.ndarray], selection: ColorSelection) -> List[np.ndarray]:
        """Compute one selection mask per array."""
        return [cls._apply_color_selection(arr, selection) for arr in arrays]

    @classmethod
    def _apply_color_selection(cls, arr: np.ndarray, selection: ColorSelection) -> np.ndarray:
        """Validate ``arr``, run ``selection`` on a copy and validate the returned mask."""
        cls._check_arr(arr, None)

        selected = selection(deepcopy(arr))
        cls._check_mask(selected, selection)
        return selected

    @classmethod
    def apply_channel_selection(cls, arrays: List[np.ndarray], channel_selection: ColorChannelSelection) -> List[List[Tuple[Color, np.ndarray]]]:
        """Compute the (color, mask) pairs of every array."""
        return [cls._apply_channel_selection(arr, channel_selection) for arr in arrays]

    @classmethod
    def _apply_channel_selection(cls, arr: np.ndarray, channel_selection: ColorChannelSelection) -> List[Tuple[Color, np.ndarray]]:
        """Validate ``arr``, run ``channel_selection`` on a copy and validate each mask."""
        cls._check_arr(arr, None)

        pairs = channel_selection(deepcopy(arr))
        for _, mask in pairs:
            cls._check_mask(mask, channel_selection)
        return pairs

    @classmethod
    def apply_color_channel_mask_conversion(cls, color_mask_pairs_list: List[List[Tuple[Color, np.ndarray]]], mask_conversion: MaskConversion) -> List[List[Tuple[Color, np.ndarray]]]:
        """Convert the masks of every (color, mask) pair list.

        TODO: unlike ``apply_mask_conversion`` there is no "no effect" check
        here yet.
        """
        return [cls._apply_color_channel_mask_conversion(pairs, mask_conversion) for pairs in color_mask_pairs_list]

    @classmethod
    def _apply_color_channel_mask_conversion(cls, color_mask_pairs: List[Tuple[Color, np.ndarray]], mask_conversion: MaskConversion) -> List[Tuple[Color, np.ndarray]]:
        """Validate the masks, convert copies of them and validate the converted masks."""
        for _, mask in color_mask_pairs:
            cls._check_mask(mask, None)

        converted = [(color, mask_conversion(mask)) for color, mask in deepcopy(color_mask_pairs)]

        for _, mask in converted:
            cls._check_mask(mask, mask_conversion)

        return converted

    @classmethod
    def apply_channel_merge(cls, arrays: List[np.ndarray], original_color_mask_pairs_list: List[List[Tuple[Color, np.ndarray]]], color_mask_pairs_list: List[List[Tuple[Color, np.ndarray]]], merge_operation: ChannelMergeOperation) -> List[np.ndarray]:
        """Merge the channels of every array; raise if none of the arrays changed."""
        triples = zip(arrays, original_color_mask_pairs_list, color_mask_pairs_list)
        merged = [cls._apply_channel_merge(arr, originals, converted, merge_operation) for arr, originals, converted in triples]

        unchanged = all(np.array_equal(after, before) for after, before in zip(merged, arrays))
        if unchanged:
            raise OperationInconsistencyException(f'no effect. {merge_operation}')
        return merged

    @classmethod
    def _apply_channel_merge(cls, arr: np.ndarray, original_color_mask_pairs: List[Tuple[Color, np.ndarray]], color_mask_pairs: List[Tuple[Color, np.ndarray]], merge_operation: ChannelMergeOperation) -> np.ndarray:
        """Validate the inputs, run ``merge_operation`` on copies and validate the result.

        Only the masks in ``color_mask_pairs`` are validated here, not those
        in ``original_color_mask_pairs``.
        """
        cls._check_arr(arr, None)
        for _, mask in color_mask_pairs:
            cls._check_mask(mask, None)

        merged = merge_operation(deepcopy(arr), deepcopy(original_color_mask_pairs), deepcopy(color_mask_pairs))
        cls._check_arr(merged, merge_operation)
        return merged

    @classmethod
    def apply_mask_conversion(cls, masks: List[np.ndarray], mask_conversion: MaskConversion) -> List[np.ndarray]:
        """Convert every mask; raise if a real conversion changed none of them.

        ``NoMaskConversion`` is exempt from the "no effect" check.
        """
        converted = [cls._mask_conversion(mask, mask_conversion) for mask in masks]

        if isinstance(mask_conversion, NoMaskConversion):
            return converted

        if all(np.array_equal(after, before) for after, before in zip(converted, masks)):
            raise OperationInconsistencyException(f'no effect. {mask_conversion}')
        return converted

    @classmethod
    def _mask_conversion(cls, mask: np.ndarray, mask_conversion: MaskConversion) -> np.ndarray:
        """Validate ``mask``, convert a copy of it and validate the converted mask."""
        cls._check_mask(mask, None)

        converted = mask_conversion(deepcopy(mask))
        cls._check_mask(converted, mask_conversion)
        return converted

    @classmethod
    def apply_mask_operation(cls, arrays: List[np.ndarray], masks: List[np.ndarray], mask_operation: MaskOperation) -> List[np.ndarray]:
        """Apply ``mask_operation`` to each (array, mask) pair; raise if no array changed."""
        operated = [cls._apply_mask_operation(arr, mask, mask_operation) for arr, mask in zip(arrays, masks)]

        if all(np.array_equal(after, before) for after, before in zip(operated, arrays)):
            raise OperationInconsistencyException(f'no effect. {mask_operation}')
        return operated

    @classmethod
    def _apply_mask_operation(cls, arr: np.ndarray, mask: np.ndarray, mask_operation: MaskOperation) -> np.ndarray:
        """Validate the inputs, run ``mask_operation`` on copies and validate the result."""
        cls._check_arr(arr, None)
        cls._check_mask(mask, None)

        operated = mask_operation(deepcopy(arr), deepcopy(mask))
        cls._check_arr(operated, mask_operation)
        return operated

    @staticmethod
    def _check_arr(arr: np.ndarray, operation: Optional[UniformOperation]):
        """Assert that ``arr`` is a non-empty 2-D uint8 ndarray with values in 0..10.

        ``operation`` is only used to label the assertion message.
        """
        # TODO Just for assertion and debug. This function spends some time. Should remove this function at the end of competition?
        assert isinstance(arr, np.ndarray), f'operation: {operation}, type: {type(arr)}'
        assert arr.dtype == np.uint8, f'operation: {operation}, dtype: {arr.dtype}'
        assert arr.size != 0, f'operation: {operation}, operation_result: \n{arr}'
        assert 0 <= np.min(arr) <= np.max(arr) <= 10, f'operation: {operation}, operation_result: \n{arr}'
        assert arr.ndim == 2, f'operation: {operation}, operation_result: \n{arr}'

    @staticmethod
    def _check_mask(mask: np.ndarray, operation: Union[ColorSelection, MaskConversion, None]):
        """Assert that ``mask`` is a 2-D boolean ndarray.

        ``operation`` is only used to label the assertion message.
        """
        # TODO Just for assertion and debug. This function spends some time. Should remove this function at the end of competition?
        assert isinstance(mask, np.ndarray), f'selection: {operation}, type: {type(mask)}'
        assert mask.dtype == bool, f'selection: {operation}, dtype: {mask.dtype}'
        assert mask.ndim == 2, f'selection: {operation}, result: \n{mask}'

    @classmethod
    def apply_partition_selection(cls, arrays: List[np.ndarray], partition_selection: PartitionSelection) -> List[Tuple[List[List[np.ndarray]], List[List[np.ndarray]]]]:
        """Partition every array; each entry is what ``partition_selection`` returned for it."""
        return [cls._apply_partition_selection(arr, partition_selection) for arr in arrays]

    @classmethod
    def _apply_partition_selection(cls, arr: np.ndarray, partition_selection: PartitionSelection) -> Tuple[List[List[np.ndarray]], List[List[np.ndarray]]]:
        """Validate ``arr`` and run ``partition_selection`` on a copy of it."""
        cls._check_arr(arr, None)
        return partition_selection(deepcopy(arr))

    @classmethod
    def apply_partition_merge_operation(cls, arrays: List[np.ndarray], partitioned_arrays_original_location_masks: List[Tuple[List[List[np.ndarray]], List[List[np.ndarray]]]], partition_merge_operation: PartitionMergeOperation):
        """Merge the partitions of every array (no "no effect" check is performed here)."""
        pairs = zip(arrays, partitioned_arrays_original_location_masks)
        return [cls._apply_partition_merge_operation(arr, partition, partition_merge_operation) for arr, partition in pairs]

    @classmethod
    def _apply_partition_merge_operation(cls, arr: np.ndarray, partitioned_arrays_original_location_masks: Tuple[List[List[np.ndarray]], List[List[np.ndarray]]], partition_merge_operation: PartitionMergeOperation):
        """Validate ``arr``, run the merge on copies of all inputs and validate the result."""
        pieces, location_masks = partitioned_arrays_original_location_masks

        cls._check_arr(arr, None)

        merged = partition_merge_operation(deepcopy(arr), deepcopy(pieces), deepcopy(location_masks))
        cls._check_arr(merged, partition_merge_operation)

        return merged


class MLNodeEvaluator(NodeEvaluator):
    """Node evaluator that scores waiting nodes with a pickled classifier.

    The feature list, categorical feature list, sample frame, model and
    ordinal encoder are all loaded from the pickle files named in
    ``PathConfig``.
    """

    def __init__(self, pattern: DepthSearchPattern):
        """Load the model artefacts and remember the depth-search pattern.

        NOTE: ``pickle.load`` can execute arbitrary code; the artefact files
        must come from a trusted source.
        """
        self.pattern = pattern
        self.features = self._load_pickle(PathConfig.NODE_EVALUATOR_FEATURES)
        self.categorical_features = self._load_pickle(PathConfig.NODE_EVALUATOR_CATEGORICAL_FEATURES)
        self.sample_df = self._load_pickle(PathConfig.NODE_EVALUATOR_SAMPLE_DF)
        self.model: LGBMClassifier = self._load_pickle(PathConfig.NODE_EVALUATOR_MODEL)
        self.model.n_jobs = 1
        self.oe: OrdinalEncoder = self._load_pickle(PathConfig.NODE_EVALUATOR_ORDINAL_ENCODER)

    @staticmethod
    def _load_pickle(path):
        """Unpickle ``path`` and close the file handle afterwards."""
        with path.open(mode='rb') as f:
            return pickle.load(f)

    def evaluate(self, node: WaitingNode) -> float:
        """Single-node evaluation is not supported; use ``evaluate_nodes``."""
        raise NotImplementedError()

    def evaluate_nodes(self, nodes: List[WaitingNode]):
        """Predict a cost for every node and store it in ``cache_pred_distance``.

        Does nothing for an empty list.  Features missing from a node are
        passed as ``None`` and end up as ``-1`` after ``fillna``; categorical
        features are stringified and encoded with the loaded ordinal encoder.
        """
        if len(nodes) == 0:
            return

        # Restrict every node's features to the model's features; absent ones become None.
        feature_dicts = [
            {f: raw.get(f) for f in self.features}
            for raw in (n.evaluation_features() for n in nodes)
        ]

        for d in feature_dicts:
            for c_f in self.categorical_features:
                d[c_f] = str(d[c_f])

        # Build the frame directly: DataFrame.append no longer exists in pandas >= 2.0.
        df = DataFrame(feature_dicts, columns=self.features)

        df[self.categorical_features] = self.oe.transform(df[self.categorical_features])
        df = df.fillna(-1)

        x = df[self.features]
        # Column 0 of predict_proba is the probability of the model's first class.
        probs = self.model.predict_proba(x)[:, 0]

        for n, p in zip(nodes, probs):
            n.cache_pred_distance = self._add_cost(p, n.depth())

    def _add_cost(self, prob: float, depth: int) -> float:
        """Combine the predicted probability with a depth penalty (A*-like).

        Raises:
            NotImplementedError: ``self.pattern`` is not a handled search pattern.
        """
        # Impose penalty. A* like algorithm.
        if self.pattern == DepthSearchPattern.BREADTH_FIRST:
            return prob ** (1 / (1 + (depth / 1))) + 0.3 * depth
        elif self.pattern == DepthSearchPattern.NORMAL:
            return prob ** (1 / (1 + (depth / 2))) + 0.1 * depth
        elif self.pattern == DepthSearchPattern.DEPTH_FIRST:
            return prob
        else:
            raise NotImplementedError()


@dataclass(frozen=True)
class LinePartition(PartitionSelection):
    """Split an array along full rows/columns that consist only of ``line_color``."""
    line_color: Color

    def __call__(self, arr: np.ndarray) -> Tuple[List[List[np.ndarray]], List[List[np.ndarray]]]:
        """Return ``(partitioned_arrays, partitioned_masks)`` as row-major nested lists.

        Each mask has the shape of ``arr`` and is True exactly on the region
        the matching partition was cut from.  Line rows/columns themselves
        belong to no partition.

        Raises:
            OperationInconsistencyException: single-cell input, a one-row or
                one-column array of a single value, no line found, or two
                lines directly next to each other.
        """
        if arr.size == 1:
            raise OperationInconsistencyException('size == 1')

        if 1 in arr.shape and len(np.unique(arr)) == 1:
            raise OperationInconsistencyException('can not separate')

        is_line_color: np.ndarray = arr == self.line_color
        row_lines = np.where(is_line_color.all(axis=1))[0]
        col_lines = np.where(is_line_color.all(axis=0))[0]

        if len(row_lines) == 0 and len(col_lines) == 0:
            raise OperationInconsistencyException('not line found')

        if 1 in np.diff(row_lines) or 1 in np.diff(col_lines):
            raise OperationInconsistencyException('line duplicated')

        # Half-open [start, end) spans between consecutive lines; empty spans
        # (a line on the border) are dropped.
        row_spans = [
            (top, bottom)
            for top, bottom in zip([0] + list(row_lines + 1), list(row_lines) + [arr.shape[0]])
            if top != bottom
        ]
        col_spans = [
            (left, right)
            for left, right in zip([0] + list(col_lines + 1), list(col_lines) + [arr.shape[1]])
            if left != right
        ]

        partitioned_arrays = []
        partitioned_masks = []
        for top, bottom in row_spans:
            arrays_in_row = []
            masks_in_row = []
            for left, right in col_spans:
                arrays_in_row.append(arr[top:bottom, left:right])

                region = np.zeros_like(arr, dtype=bool)
                region[top:bottom, left:right] = True
                masks_in_row.append(region)

            partitioned_arrays.append(arrays_in_row)
            partitioned_masks.append(masks_in_row)

        return partitioned_arrays, partitioned_masks


@dataclass(frozen=True)
class GeneralizedLinePartition(PartitionSelection):
    """Line partition whose line color is picked automatically.

    Among the non-background colors, the one forming the most full rows plus
    full columns is used (the first such color in ``np.unique`` order on ties).
    """
    bg_selection_mode: BackGroundColorSelectionMode

    def __call__(self, arr: np.ndarray) -> Tuple[List[List[np.ndarray]], List[List[np.ndarray]]]:
        """Choose the line color and delegate to ``LinePartition``.

        Raises:
            OperationInconsistencyException: single-cell input, a one-row or
                one-column array of a single value, or no non-background color
                present (plus anything ``LinePartition`` raises).
        """
        if arr.size == 1:
            raise OperationInconsistencyException('size == 1')

        if 1 in arr.shape and len(np.unique(arr)) == 1:
            raise OperationInconsistencyException('can not separate')

        bg = ColorSelectionUtil().get_background_color(arr, self.bg_selection_mode)

        present_colors = [Color.of(value) for value in np.unique(arr)]

        # (color, number of full rows + number of full columns) per candidate.
        line_counts = []
        for color in present_colors:
            if color == bg:
                continue
            hit: np.ndarray = arr == color
            n_row_lines = np.count_nonzero(hit.all(axis=1))
            n_col_lines = np.count_nonzero(hit.all(axis=0))
            line_counts.append((color, n_row_lines + n_col_lines))

        if len(line_counts) == 0:
            raise OperationInconsistencyException('not colored')

        target_color, _ = max(line_counts, key=lambda entry: entry[1])

        return LinePartition(target_color)(arr)


@dataclass(frozen=True)
class IntegerDivisionPartition(PartitionSelection):
    """Cut an array into ``n_split`` equal parts along one axis or along both."""
    axis: Axis
    n_split: int

    @staticmethod
    def _block_mask(arr: np.ndarray, rows: slice, cols: slice) -> np.ndarray:
        """Boolean mask shaped like ``arr`` that is True on ``arr[rows, cols]`` only."""
        mask = np.zeros_like(arr, dtype=bool)
        mask[rows, cols] = True
        return mask

    def __call__(self, arr: np.ndarray) -> Tuple[List[List[np.ndarray]], List[List[np.ndarray]]]:
        """Return ``(partitioned_arrays, masks)`` as row-major nested lists.

        HORIZONTAL yields one row of ``n_split`` pieces, VERTICAL yields
        ``n_split`` rows of one piece, BOTH yields an ``n_split`` x ``n_split``
        grid.

        Raises:
            OperationInconsistencyException: the relevant dimension is not a
                multiple of ``n_split``.
            NotImplementedError: unsupported ``axis`` value.
        """
        n = self.n_split

        if self.axis == Axis.HORIZONTAL:
            if arr.shape[1] % n != 0:
                raise OperationInconsistencyException('can not divide')

            width = arr.shape[1] // n
            masks = [[
                self._block_mask(arr, slice(None), slice(k * width, (k + 1) * width))
                for k in range(n)
            ]]
            partitioned_arrays = [np.split(arr, n, axis=1)]

        elif self.axis == Axis.VERTICAL:
            if arr.shape[0] % n != 0:
                raise OperationInconsistencyException('can not divide')

            height = arr.shape[0] // n
            masks = [
                [self._block_mask(arr, slice(k * height, (k + 1) * height), slice(None))]
                for k in range(n)
            ]
            partitioned_arrays = [[band] for band in np.split(arr, n, axis=0)]

        elif self.axis == Axis.BOTH:
            if arr.shape[0] % n != 0 or arr.shape[1] % n != 0:
                raise OperationInconsistencyException('can not divide')

            height = arr.shape[0] // n
            width = arr.shape[1] // n
            masks = [
                [
                    self._block_mask(
                        arr,
                        slice(i * height, (i + 1) * height),
                        slice(j * width, (j + 1) * width),
                    )
                    for j in range(n)
                ]
                for i in range(n)
            ]
            partitioned_arrays = [np.split(band, n, axis=1) for band in np.split(arr, n, axis=0)]

        else:
            raise NotImplementedError()

        return partitioned_arrays, masks


@dataclass(frozen=True)
class ColorNumIntegerDivisionPartition(PartitionSelection):
    """Integer-division partition whose split count is derived from the array's colors."""
    axis: Axis

    def __call__(self, arr: np.ndarray) -> Tuple[List[List[np.ndarray]], List[List[np.ndarray]]]:
        """Split into (number of distinct values - 1) parts along ``axis``.

        One distinct value is discounted as the background.

        Raises:
            OperationInconsistencyException: the array holds a single value
                (plus anything ``IntegerDivisionPartition`` raises).
        """
        n_non_bg_colors = len(np.unique(arr)) - 1  # one value is assumed to be bg

        if n_non_bg_colors == 0:
            raise OperationInconsistencyException('not colored')

        splitter = IntegerDivisionPartition(self.axis, n_non_bg_colors)
        return splitter(arr)


@dataclass
class RandomNodeTreeCreateEngine:
    """Expand a search tree with a random node evaluator and collect node trees.

    ``MAX_NODE`` and ``node_evaluator`` are plain class attributes (not
    dataclass fields); only ``timeout_sec`` is a constructor argument.
    """
    MAX_NODE = 30000
    timeout_sec: int = 30
    node_evaluator = RandomNodeEvaluator()

    def search(self, task: Task, verbose: bool = False) -> List[NodeTree]:
        """Expand nodes until ``MAX_NODE`` pops or ``timeout_sec`` is reached.

        Returns one ``NodeTree`` per waiting node still left in the queue,
        built from that node's ``parent_completed_node``.

        Raises:
            NotImplementedError: the queue ran empty before the node or time
                budget was used up.
        """
        task_feature = TaskFeature.of(task)
        root_node = UniformOperationCompletedNode(None, task, task_feature, OperationSet([]))
        first_waiting_nodes = CompletedNodeProcessor.process(root_node)
        self.node_evaluator.evaluate_nodes(first_waiting_nodes)
        pq = PriorityQueue([*first_waiting_nodes])

        if verbose:
            print('first pq nodes')
            for n in pq.sorted_list():
                print(f'cost: {n.cache_pred_distance}, {n}')

        with Timer() as timer:
            for _ in range(self.MAX_NODE):
                # Check the budget at the top of every iteration so that runs of
                # skipped nodes (the `continue` below) cannot bypass the timeout.
                if timer.second() > self.timeout_sec:
                    break

                if len(pq) == 0:
                    # TODO What's the right thing to do?
                    raise NotImplementedError()

                waiting_node = pq.pop_min()

                completed_node = WaitingNodeProcessor().process(waiting_node)

                if completed_node is None:
                    if verbose:
                        print('skipped')
                    continue

                waiting_new_nodes = CompletedNodeProcessor.process(completed_node)
                self.node_evaluator.evaluate_nodes(waiting_new_nodes)

                for n in waiting_new_nodes:
                    pq.push(n)

        return [NodeTree.of(waiting_node.parent_completed_node) for waiting_node in pq.heap]


@dataclass(frozen=True)
class AnySelectionMerge(PartitionMergeOperation):
    """Paint ``fill_color`` wherever at least one partition is non-background."""
    bg_selection_mode: BackGroundColorSelectionMode
    fill_color: Color

    def __call__(self, arr: np.ndarray, partitioned_arrays: List[List[np.ndarray]], original_location_masks: List[List[np.ndarray]]) -> np.ndarray:
        """Return a partition-sized array: ``fill_color`` on the union, background elsewhere.

        Raises:
            OperationInconsistencyException: the partitions differ in shape.
        """
        template = partitioned_arrays[0][0]
        pieces = [piece for row in partitioned_arrays for piece in row]
        if any(piece.shape != template.shape for piece in pieces):
            raise OperationInconsistencyException('not same shape')

        bg = ColorSelectionUtil().get_background_color(arr, self.bg_selection_mode)

        # Logical OR of the non-background cells over all partitions.
        any_colored = np.zeros_like(template, dtype=bool)
        for piece in pieces:
            any_colored |= (piece != bg)

        merged = np.full_like(template, fill_value=bg)
        merged[any_colored] = self.fill_color

        return merged


@dataclass(frozen=True)
class NotSelectionMerge(PartitionMergeOperation):
    """Paint ``fill_color`` wherever every partition is background."""
    bg_selection_mode: BackGroundColorSelectionMode
    fill_color: Color

    def __call__(self, arr: np.ndarray, partitioned_arrays: List[List[np.ndarray]], original_location_masks: List[List[np.ndarray]]) -> np.ndarray:
        """Return a partition-sized array: ``fill_color`` where no partition is colored.

        Raises:
            OperationInconsistencyException: the partitions differ in shape.
        """
        template = partitioned_arrays[0][0]
        pieces = [piece for row in partitioned_arrays for piece in row]
        if any(piece.shape != template.shape for piece in pieces):
            raise OperationInconsistencyException('not same shape')

        bg = ColorSelectionUtil().get_background_color(arr, self.bg_selection_mode)

        # A cell stays selected only while every partition seen so far is background there.
        none_colored = np.ones_like(template, dtype=bool)
        for piece in pieces:
            none_colored &= (piece == bg)

        merged = np.full_like(template, fill_value=bg)
        merged[none_colored] = self.fill_color

        return merged


@dataclass(frozen=True)
class AllSelectionMerge(PartitionMergeOperation):
    """Paint ``fill_color`` wherever every partition is non-background."""
    bg_selection_mode: BackGroundColorSelectionMode
    fill_color: Color

    def __call__(self, arr: np.ndarray, partitioned_arrays: List[List[np.ndarray]], original_location_masks: List[List[np.ndarray]]) -> np.ndarray:
        """Return a partition-sized array: ``fill_color`` on the intersection, background elsewhere.

        Raises:
            OperationInconsistencyException: the partitions differ in shape.
        """
        template = partitioned_arrays[0][0]
        pieces = [piece for row in partitioned_arrays for piece in row]
        if any(piece.shape != template.shape for piece in pieces):
            raise OperationInconsistencyException('not same shape')

        bg = ColorSelectionUtil().get_background_color(arr, self.bg_selection_mode)

        # Logical AND of the non-background cells over all partitions.
        all_colored = np.ones_like(template, dtype=bool)
        for piece in pieces:
            all_colored &= (piece != bg)

        merged = np.full_like(template, fill_value=bg)
        merged[all_colored] = self.fill_color

        return merged


@dataclass(frozen=True)
class ModifiedXorSelectionMerge(PartitionMergeOperation):
    """Paint ``fill_color`` where some, but not all, partitions are non-background."""
    bg_selection_mode: BackGroundColorSelectionMode
    fill_color: Color

    def __call__(self, arr: np.ndarray, partitioned_arrays: List[List[np.ndarray]], original_location_masks: List[List[np.ndarray]]) -> np.ndarray:
        """Return a partition-sized array: ``fill_color`` on (any AND NOT all), background elsewhere.

        Raises:
            OperationInconsistencyException: the partitions differ in shape.
        """
        template = partitioned_arrays[0][0]
        pieces = [piece for row in partitioned_arrays for piece in row]
        if any(piece.shape != template.shape for piece in pieces):
            raise OperationInconsistencyException('not same shape')

        bg = ColorSelectionUtil().get_background_color(arr, self.bg_selection_mode)

        any_colored = np.zeros_like(template, dtype=bool)
        all_colored = np.ones_like(template, dtype=bool)
        for piece in pieces:
            colored = piece != bg
            any_colored |= colored
            all_colored &= colored

        # modified xor: colored in at least one partition but not in every one
        some_but_not_all = any_colored & ~all_colored

        merged = np.full_like(template, fill_value=bg)
        merged[some_but_not_all] = self.fill_color

        return merged


@dataclass(frozen=True)
class NaturalArrayOrderedOverrideMerge(PartitionMergeOperation):
    """Overlay the partitions in row-by-row or column-by-column order.

    Later partitions overwrite earlier ones on their non-background cells.
    """
    bg_selection_mode: BackGroundColorSelectionMode
    start_corner: Corner
    first_axis: Axis

    def __call__(self, arr: np.ndarray, partitioned_arrays: List[List[np.ndarray]], original_location_masks: List[List[np.ndarray]]) -> np.ndarray:
        """Return a partition-sized array holding the ordered overlay.

        Raises:
            OperationInconsistencyException: the partitions differ in shape.
        """
        template = partitioned_arrays[0][0]
        if any(piece.shape != template.shape for row in partitioned_arrays for piece in row):
            raise OperationInconsistencyException('not same shape')

        bg = ColorSelectionUtil().get_background_color(arr, self.bg_selection_mode)

        canvas = np.full_like(template, fill_value=bg)

        n_rows, n_cols = len(partitioned_arrays), len(partitioned_arrays[0])
        for i, j in self.natural_array(n_rows, n_cols, self.start_corner, self.first_axis):
            piece = partitioned_arrays[i][j]
            colored = piece != bg
            canvas[colored] = piece[colored]

        return canvas

    def natural_array(self, h: int, w: int, start_corner: Corner, first_axis: Axis) -> List[Tuple[int, int]]:
        """List all ``(row, col)`` indices of an ``h`` x ``w`` grid starting at ``start_corner``.

        With ``Axis.HORIZONTAL`` the column index varies fastest, with
        ``Axis.VERTICAL`` the row index does.  Each axis runs away from the
        start corner towards the opposite edge.

        Raises:
            NotImplementedError: unsupported ``first_axis``.
        """
        row_start, col_start = get_index(start_corner, h, w)

        # Walk towards the far edge: upwards in index from 0, downwards otherwise.
        if row_start == 0:
            row_end, row_step = h - 1, +1
        else:
            row_end, row_step = 0, -1

        if col_start == 0:
            col_end, col_step = w - 1, +1
        else:
            col_end, col_step = 0, -1

        if first_axis == Axis.HORIZONTAL:
            index_orders = [
                (i, j)
                for i in range_closed(row_start, row_end, row_step)
                for j in range_closed(col_start, col_end, col_step)
            ]
        elif first_axis == Axis.VERTICAL:
            index_orders = [
                (i, j)
                for j in range_closed(col_start, col_end, col_step)
                for i in range_closed(row_start, row_end, row_step)
            ]
        else:
            raise NotImplementedError()

        assert len(set(index_orders)) == len(index_orders) == h * w, index_orders
        return index_orders


@dataclass(frozen=True)
class DiagonalArrayOrderedOverrideMerge(PartitionMergeOperation):
    """Overlay the partitions in a diagonal (wrapping) visiting order.

    Later partitions overwrite earlier ones on their non-background cells.
    """
    bg_selection_mode: BackGroundColorSelectionMode
    start_corner: Corner
    first_axis: Axis

    def __call__(self, arr: np.ndarray, partitioned_arrays: List[List[np.ndarray]], original_location_masks: List[List[np.ndarray]]) -> np.ndarray:
        """Return a partition-sized array holding the ordered overlay.

        Raises:
            OperationInconsistencyException: the partitions differ in shape.
        """
        template = partitioned_arrays[0][0]
        if any(piece.shape != template.shape for row in partitioned_arrays for piece in row):
            raise OperationInconsistencyException('not same shape')

        bg = ColorSelectionUtil().get_background_color(arr, self.bg_selection_mode)

        canvas = np.full_like(template, fill_value=bg)

        n_rows, n_cols = len(partitioned_arrays), len(partitioned_arrays[0])
        for i, j in self.diagonal_array(n_rows, n_cols, self.start_corner, self.first_axis):
            piece = partitioned_arrays[i][j]
            colored = piece != bg
            canvas[colored] = piece[colored]

        return canvas

    def diagonal_array(self, h: int, w: int, start_corner: Corner, first_axis: Axis) -> List[Tuple[int, int]]:
        """List all ``(row, col)`` indices of an ``h`` x ``w`` grid in diagonal order.

        With ``Axis.HORIZONTAL`` each pass steps through the columns while the
        row index advances by one per step (modulo ``h``); with
        ``Axis.VERTICAL`` the roles of rows and columns are swapped (modulo
        ``w``).

        Raises:
            NotImplementedError: unsupported ``first_axis``.
        """
        row_start, col_start = get_index(start_corner, h, w)

        # Walk towards the far edge: upwards in index from 0, downwards otherwise.
        if row_start == 0:
            row_end, row_step = h - 1, +1
        else:
            row_end, row_step = 0, -1

        if col_start == 0:
            col_end, col_step = w - 1, +1
        else:
            col_end, col_step = 0, -1

        if first_axis == Axis.HORIZONTAL:
            index_orders = [
                ((i + shift) % h, j)
                for i in range_closed(row_start, row_end, row_step)
                for shift, j in enumerate(range_closed(col_start, col_end, col_step))
            ]
        elif first_axis == Axis.VERTICAL:
            index_orders = [
                (i, (j + shift) % w)
                for j in range_closed(col_start, col_end, col_step)
                for shift, i in enumerate(range_closed(row_start, row_end, row_step))
            ]
        else:
            raise NotImplementedError()

        assert len(set(index_orders)) == len(index_orders) == h * w, index_orders
        return index_orders


@dataclass(frozen=True)
class SpiralArrayOrderedOverrideMerge(PartitionMergeOperation):
    """Overlay the partitions in spiral order starting from a corner.

    Later partitions overwrite earlier ones on their non-background cells.
    """
    bg_selection_mode: BackGroundColorSelectionMode
    start_corner: Corner
    spiral_direction: SpiralDirection

    def __call__(self, arr: np.ndarray, partitioned_arrays: List[List[np.ndarray]], original_location_masks: List[List[np.ndarray]]) -> np.ndarray:
        """Return a partition-sized array holding the ordered overlay.

        Raises:
            OperationInconsistencyException: the partitions differ in shape.
        """
        template = partitioned_arrays[0][0]
        if any(piece.shape != template.shape for row in partitioned_arrays for piece in row):
            raise OperationInconsistencyException('not same shape')

        bg = ColorSelectionUtil().get_background_color(arr, self.bg_selection_mode)

        canvas = np.full_like(template, fill_value=bg)

        n_rows, n_cols = len(partitioned_arrays), len(partitioned_arrays[0])
        for i, j in self.spiral(n_rows, n_cols, self.start_corner, self.spiral_direction):
            piece = partitioned_arrays[i][j]
            colored = piece != bg
            canvas[colored] = piece[colored]

        return canvas

    def spiral(self, h: int, w: int, start_corner: Corner, spiral_direction: SpiralDirection) -> List[Tuple[int, int]]:
        """List all ``(row, col)`` indices of an ``h`` x ``w`` grid in spiral order.

        The walk repeatedly picks the first direction (in a fixed preference
        order) whose neighbouring cell ``valid_index`` accepts, then keeps
        moving that way until the next cell is rejected.  It stops when no
        direction is available.
        """
        start_ind = get_index(start_corner, h, w)

        # (row offset, column offset) of each move.
        right, bottom, left, top = (0, 1), (1, 0), (0, -1), (-1, 0)

        right_first = (
            (start_corner in [Corner.TOP_LEFT, Corner.TOP_RIGHT, Corner.BOTTOM_RIGHT]
             and spiral_direction == SpiralDirection.CLOCKWISE)
            or (start_corner == Corner.BOTTOM_LEFT and spiral_direction == SpiralDirection.ANTICLOCKWISE)
        )
        # Order in which directions are tried whenever the walk has to turn.
        preference = (right, bottom, left, top) if right_first else (top, left, bottom, right)

        index_orders = [start_ind]
        current_ind = start_ind
        while True:
            heading = None
            for d_row, d_col in preference:
                if valid_index((current_ind[0] + d_row, current_ind[1] + d_col), h, w, index_orders):
                    heading = (d_row, d_col)
                    break

            if heading is None:
                break

            # Keep going straight until the next cell is rejected.
            while True:
                next_ind = (current_ind[0] + heading[0], current_ind[1] + heading[1])
                if not valid_index(next_ind, h, w, index_orders):
                    break
                index_orders.append(next_ind)
                current_ind = next_ind

        assert len(set(index_orders)) == len(index_orders) == h * w, index_orders
        return index_orders


@dataclass(frozen=True)
class UniquelySelectedArrayExtraction(PartitionMergeOperation):
    """Merge step that returns the one distinct sub-array chosen by ``array_selection``.

    Every partition whose selection flag is truthy is collected; the operation
    succeeds only when all collected arrays are identical.
    """

    # Called as ``array_selection(arr, partitioned_arrays)``; its result is
    # iterated row by row in lockstep with ``partitioned_arrays``.
    array_selection: PartitionedArraySelection

    def __call__(self, arr: np.ndarray, partitioned_arrays: List[List[np.ndarray]], original_location_masks: List[List[np.ndarray]]) -> np.ndarray:
        """Return the uniquely selected sub-array.

        Args:
            arr: The un-partitioned array, forwarded to ``array_selection``.
            partitioned_arrays: Grid (list of rows) of sub-arrays.
            original_location_masks: Unused by this operation.

        Raises:
            OperationInconsistencyException: If nothing was selected or the
                selected sub-arrays are not all identical.
        """
        selections = self.array_selection(arr, partitioned_arrays)

        results = [
            array
            for h_arrays, h_flags in zip(partitioned_arrays, selections)
            for array, flag in zip(h_arrays, h_flags)
            if flag
        ]

        # ``ndarray.tostring()`` is a deprecated alias that NumPy 2.0 removed;
        # ``tobytes()`` is the supported equivalent.  The shape is part of the
        # key because raw bytes alone cannot tell e.g. a 2x3 from a 3x2 block.
        distinct = {(a.shape, a.tobytes()) for a in results}
        if len(distinct) == 1:
            return results[0]
        raise OperationInconsistencyException('not unique')


@dataclass(frozen=True)
class RestoreOnlySelectedArray(PartitionMergeOperation):
    """Merge step that keeps selected partitions and paints the rest with the background color."""

    # Mode handed to ColorSelectionUtil.get_background_color.
    bg_selection_mode: BackGroundColorSelectionMode
    # Called as ``array_selection(arr, partitioned_arrays)`` to obtain per-partition flags.
    array_selection: PartitionedArraySelection

    def __call__(self, arr: np.ndarray, partitioned_arrays: List[List[np.ndarray]], original_location_masks: List[List[np.ndarray]]) -> np.ndarray:
        """Write every partition back into ``arr`` (in place) and return ``arr``.

        A partition whose flag is truthy is restored from its sub-array; any
        other partition location is filled with the background color.
        """
        background = ColorSelectionUtil().get_background_color(arr, self.bg_selection_mode)
        keep_flags = self.array_selection(arr, partitioned_arrays)

        grid_rows = zip(partitioned_arrays, keep_flags, original_location_masks)
        for row_arrays, row_flags, row_masks in grid_rows:
            for sub_array, keep, mask in zip(row_arrays, row_flags, row_masks):
                # ``mask`` indexes the cells of ``arr`` this partition came from.
                arr[mask] = sub_array.ravel() if keep else background

        return arr


@dataclass(frozen=True)
class ExtractOneValueFromPartitionedArray(PartitionMergeOperation):
    """Merge step that collapses each partition into a single cell of a summary grid."""

    def __call__(self, arr: np.ndarray, partitioned_arrays: List[List[np.ndarray]], original_location_masks: List[List[np.ndarray]]) -> np.ndarray:
        """Return an array with one value per partition.

        The result has one row per partition row and one column per partition
        in the first row; each cell holds the color chosen by
        ``select_single_color`` with ``SingleColorSelectionMode.MOST_COMMON``.
        """
        n_rows = len(partitioned_arrays)
        n_cols = len(partitioned_arrays[0])
        # Reuse the dtype of the first partition for the summary grid.
        summary = np.zeros_like(partitioned_arrays[0][0], shape=(n_rows, n_cols))

        for row in range(n_rows):
            for col in range(n_cols):
                cell = partitioned_arrays[row][col]
                summary[row][col] = ColorSelectionUtil().select_single_color(cell, SingleColorSelectionMode.MOST_COMMON)

        return summary


def range_closed(start, stop, step):
    """Return ``range(start, stop, step)`` with ``stop`` pushed one unit further.

    The end bound is moved by +1 for a positive ``step`` and by -1 otherwise,
    so ``stop`` itself is produced whenever the stepping lands on it.
    """
    if step > 0:
        widened_stop = stop + 1
    else:
        widened_stop = stop - 1
    return range(start, widened_stop, step)


def get_index(corner: Corner, h: int, w: int) -> Tuple[int, int]:
    """Return the ``(row, column)`` index of ``corner`` in an ``h`` x ``w`` grid.

    Raises:
        NotImplementedError: If ``corner`` is none of the four known corners.
    """
    is_top = corner == Corner.TOP_LEFT or corner == Corner.TOP_RIGHT
    is_bottom = corner == Corner.BOTTOM_RIGHT or corner == Corner.BOTTOM_LEFT
    if not (is_top or is_bottom):
        raise NotImplementedError()

    row = 0 if is_top else h - 1
    on_left_edge = corner == Corner.TOP_LEFT or corner == Corner.BOTTOM_LEFT
    col = 0 if on_left_edge else w - 1
    return row, col


def valid_index(ind2d: Tuple[int, int], h: int, w: int, black_list: List[Tuple[int, int]]) -> bool:
    """Tell whether ``ind2d`` is a usable cell of an ``h`` x ``w`` grid.

    An index is usable when it lies inside the grid bounds and does not
    appear in ``black_list``.
    """
    row, col = ind2d[0], ind2d[1]
    inside_grid = (0 <= row < h) and (0 <= col < w)
    return inside_grid and ind2d not in black_list


def save_ml_training_data(task: Task, verbose: bool = False):
    """Build and pickle a labeled feature table for one task.

    Rows labeled 1 come from the waiting nodes of node trees found by
    ``NodeBaseSearchEngine`` that also match all train and test pairs.
    Rows labeled 0 come from the waiting nodes of randomly generated trees,
    minus any feature combination that already appears among the label-1 rows.
    Nothing is written when the search fails or no tree survives the filter.

    Args:
        task: Task to generate samples for.
        verbose: Forwarded to both search engines.
    """
    # Correct (positive) samples
    correct_node_trees, exception, _ = NodeBaseSearchEngine(answer_limit_num=60).search(task, verbose)

    print('search engine end')
    if exception is not None:
        print('answer not found')
        return

    correct_node_trees = [t for t in correct_node_trees if AnswerMatcher.is_train_test_all_match_if_operated(task, t.to_operation_set())]

    if not correct_node_trees:
        print('answer not found')
        return

    correct_waiting_nodes = list(chain.from_iterable([t.waiting_nodes() for t in correct_node_trees]))
    correct_feature_dicts = [n.evaluation_features() for n in correct_waiting_nodes]
    # Sorted item tuples make the feature dicts hashable so duplicates collapse.
    correct_feature_dict_tuples = set(tuple(sorted(d.items())) for d in correct_feature_dicts)
    correct_df = DataFrame(dict(t) for t in correct_feature_dict_tuples)

    # Incorrect (negative) samples
    trees = RandomNodeTreeCreateEngine(timeout_sec=120).search(task, verbose)
    print('random tree generated')
    waiting_nodes = list(chain.from_iterable([t.waiting_nodes() for t in trees]))
    feature_dicts = [n.evaluation_features() for n in waiting_nodes]

    feature_dict_tuples = set(tuple(sorted(d.items())) for d in feature_dicts)
    feature_dict_tuples = feature_dict_tuples - correct_feature_dict_tuples
    wrong_df = DataFrame(dict(t) for t in feature_dict_tuples)

    # Labeling
    correct_df['label'] = 1
    wrong_df['label'] = 0
    # DataFrame.append was removed in pandas 2.0; concat is the equivalent call.
    all_df = pd.concat([correct_df, wrong_df], sort=False)

    PathConfig.LABELED_TRAINING_DATA_ROOT.mkdir(parents=True, exist_ok=True)
    output_path = PathConfig.LABELED_TRAINING_DATA_ROOT / f'{task.name}.pkl'
    # Context manager guarantees the pickle is flushed and the handle closed.
    with output_path.open(mode='wb') as f:
        pickle.dump(all_df, f)
    print('save')


def train_ml():
    """Prepare the labeled training data and fit the LightGBM model on it."""
    prepared = prepare_train_data()
    # prepare_train_data returns (x, y, feature_columns, categorical_features),
    # which is exactly train_lgbm's positional order.
    features, labels, columns, categorical = prepared
    train_lgbm(features, labels, columns, categorical)


def prepare_train_data():
    """Load all labeled sample pickles and turn them into a training matrix.

    Steps: concatenate every pickle under
    ``PathConfig.LABELED_TRAINING_DATA_ROOT``, drop the excluded columns,
    ordinal-encode the object-typed columns, relabel some 0-rows to 1 (see
    below), then pickle the frame, the encoder and the column lists.

    Returns:
        Tuple ``(x, y, feature_columns, categorical_features)``.
    """
    dfs = []
    for pickle_path in PathConfig.LABELED_TRAINING_DATA_ROOT.iterdir():
        print(pickle_path)
        # NOTE: unpickling can execute code; only load files this project wrote.
        with pickle_path.open(mode='rb') as f:
            dfs.append(pickle.load(f))
    all_df = pd.concat(dfs, ignore_index=True, sort=False)

    print(f'label1: {len(all_df[all_df["label"] == 1])}_label0: {len(all_df[all_df["label"] == 0])}')

    # TODO Should we not use dsl properties(too detailed) features?
    not_used_feature_columns = {
        'label', 'depth', 'color', 'angle', 'direction',
        'multi_color_selection_mode', 'single_color_selection_mode',
        'edge_type', 'fill_type', 'flip_mode', 'k', 'ratio', 'padding_mode'
    }
    feature_columns = sorted(set(all_df.columns) - not_used_feature_columns)
    # Explicit copy: the column assignments below must write to an independent
    # frame, not to a slice of the concatenated one (chained-assignment hazard).
    all_df = all_df[feature_columns + ['label']].copy()

    # process categorical
    object_columns = map(str, all_df.select_dtypes(include='object').columns)
    categorical_features = [c for c in object_columns if c in feature_columns]

    for c_f in categorical_features:
        all_df[c_f] = all_df[c_f].fillna('None').apply(str)

    oe = category_encoders.OrdinalEncoder()
    all_df[categorical_features] = oe.fit_transform(all_df[categorical_features])

    # Relabel 0-labeled data in the neighborhood of 1 to 1
    print('relabelling')
    small_is_better_features = [
        'mean_diff_color_cell_ratio',
        'mean_excess_color_num',
        'mean_lack_color_num',
        'mean_horizontal_diff_input_arr_line_num',
        'mean_horizontal_diff_output_arr_line_num',
        'mean_horizontal_edge_sum_diff',
        'mean_horizontal_edge_sum_diff_ratio',
        'mean_vertical_diff_input_arr_line_num',
        'mean_vertical_diff_output_arr_line_num',
        'mean_vertical_edge_sum_diff',
        'mean_vertical_edge_sum_diff_ratio',
    ]

    # Loop-invariant: the features that must match a positive row exactly.
    exact_match_features = sorted(set(feature_columns) - set(small_is_better_features))

    # A row becomes positive when it equals a positive row on every exact-match
    # feature and is <= that row on every small-is-better feature.
    # NOTE(review): `==` never matches NaN, so rows with missing exact-match
    # features are presumably never relabelled -- confirm this is intended.
    for _, r in tqdm(all_df[all_df['label'] == 1].iterrows()):
        near_rows = all_df[(all_df[exact_match_features] == r[exact_match_features]).all(axis=1)]
        can_label_1 = near_rows[(near_rows[small_is_better_features] <= r[small_is_better_features]).all(axis=1)]

        all_df.loc[can_label_1.index.values, 'label'] = 1

    print(f'label1: {len(all_df[all_df["label"] == 1])}_label0: {len(all_df[all_df["label"] == 0])}')

    print('feature_columns:')
    for f in feature_columns:
        print(f)

    x = all_df[feature_columns]
    y = all_df['label']
    print(f'len(x): {len(x)}, 1_labelled_len: {len(y[y == 1])}')
    # Resampling is currently disabled, so the two size reports are identical:
    # x, y = RandomUnderSampler(sampling_strategy=0.01).fit_resample(x, y)
    # x, y = EditedNearestNeighbours(sampling_strategy=0.01, n_jobs=RunConfig.N_JOB).fit_resample(x, y)
    print(f'len(x): {len(x)}, 1_labelled_len: {len(y[y == 1])}')

    # Persist everything needed to rebuild the same feature pipeline later.
    artifacts = (
        (all_df, PathConfig.NODE_EVALUATOR_SAMPLE_DF),
        (oe, PathConfig.NODE_EVALUATOR_ORDINAL_ENCODER),
        (feature_columns, PathConfig.NODE_EVALUATOR_FEATURES),
        (categorical_features, PathConfig.NODE_EVALUATOR_CATEGORICAL_FEATURES),
    )
    for obj, path in artifacts:
        with path.open(mode='wb') as out:
            pickle.dump(obj, out)

    return x, y, feature_columns, categorical_features


def train_lg(feature_columns, x, y):
    """Fit a class-balanced logistic regression, save it, and print diagnostics.

    The model is pickled to ``PathConfig.NODE_EVALUATOR_MODEL``, reloaded, and
    its predictions printed again so the saved artifact can be compared with
    the in-memory one.  Finally the absolute coefficient of every feature is
    printed.

    Args:
        feature_columns: Feature names, in the column order of ``x``.
        x: Training features.
        y: Training labels.
    """
    model = LogisticRegression(class_weight='balanced', n_jobs=RunConfig.N_JOB)
    model.fit(x, y)

    pred_y = model.predict_proba(x)
    print(pred_y)

    PathConfig.SAVED_MODEL.mkdir(parents=True, exist_ok=True)
    # Close the handle before the file is read back below; relying on garbage
    # collection to flush it is an implementation detail.
    with PathConfig.NODE_EVALUATOR_MODEL.open(mode='wb') as f:
        pickle.dump(model, f)

    del model

    print(x)
    print(y)
    print(pred_y)

    # Round-trip check: predictions of the reloaded model.
    with PathConfig.NODE_EVALUATOR_MODEL.open(mode='rb') as f:
        model = pickle.load(f)
    pred_y = model.predict_proba(x)

    print(pred_y)

    coefs = np.abs(model.coef_[0])

    for c, f in zip(coefs, feature_columns):
        print(f'{f}_{c}')


def train_lgbm(x, y, feature_columns, categorical_features):
    """Train the LightGBM node evaluator and persist it.

    A 3-fold, unshuffled cross-validation with early stopping determines the
    number of boosting rounds; the final model is refit on all data with the
    smallest best-iteration seen, pickled to
    ``PathConfig.NODE_EVALUATOR_MODEL``, reloaded, and its predictions and
    feature importances printed.

    Args:
        x: Training features (indexed positionally via ``.iloc``).
        y: Training labels (indexed positionally via ``.iloc``).
        feature_columns: Feature names used to index the importance table.
        categorical_features: Passed to LightGBM as ``categorical_feature``.
    """
    lgbm_params = {
        'silent': False, 'n_jobs': RunConfig.N_JOB,
        'class_weight': 'balanced', 'max_depth': 3, 'learning_rate': 0.2,
    }

    best_iterations = []
    folds = KFold(shuffle=False, n_splits=3)
    # `valid_idx`, not `valid_index`: avoid shadowing the module-level function.
    for train_idx, valid_idx in folds.split(x, y):
        train_x, train_y = x.iloc[train_idx], y.iloc[train_idx]
        valid_x, valid_y = x.iloc[valid_idx], y.iloc[valid_idx]

        model = LGBMClassifier(n_estimators=1000, **lgbm_params)
        model.fit(train_x, train_y, eval_set=[(valid_x, valid_y), (train_x, train_y)],
                  early_stopping_rounds=10, categorical_feature=categorical_features,
                  verbose=True)
        best_iterations.append(model.best_iteration_)

    print(best_iterations)

    # Use the most conservative round count across folds for the final fit.
    model = LGBMClassifier(n_estimators=min(best_iterations), **lgbm_params)
    model.fit(x, y, verbose=True, categorical_feature=categorical_features)

    pred_y = model.predict_proba(x)
    print(pred_y)

    PathConfig.SAVED_MODEL.mkdir(parents=True, exist_ok=True)
    # Close the handle before the file is read back below.
    with PathConfig.NODE_EVALUATOR_MODEL.open(mode='wb') as f:
        pickle.dump(model, f)

    del model

    print(x)
    print(y)
    print(pred_y)

    # Round-trip check: predictions of the reloaded model.
    with PathConfig.NODE_EVALUATOR_MODEL.open(mode='rb') as f:
        model = pickle.load(f)
    pred_y = model.predict_proba(x)

    print(pred_y)

    importance = pd.DataFrame(model.feature_importances_, index=feature_columns, columns=['importance'])

    print(importance)


def train_test_model(feature_columns, x, y):
    """Scratch trainer for trying alternative classifiers on the node features.

    Fits a class-balanced ``RidgeClassifier``, pickles it to
    ``PathConfig.NODE_EVALUATOR_MODEL``, reloads it, and prints predictions and
    a per-feature importance table.  An optional LightGBM cross-validation
    runs first when the local ``try_cv`` switch is enabled.

    Args:
        feature_columns: Feature names, in the column order of ``x``.
        x: Training features.
        y: Training labels.
    """
    try_cv = False

    if try_cv:
        folds = KFold(shuffle=True)
        # `valid_idx`, not `valid_index`: avoid shadowing the module-level function.
        for train_idx, valid_idx in folds.split(x, y):
            train_x, train_y = x.iloc[train_idx], y.iloc[train_idx]
            valid_x, valid_y = x.iloc[valid_idx], y.iloc[valid_idx]

            model = LGBMClassifier(class_weight='balanced', learning_rate=0.2, n_jobs=RunConfig.N_JOB, n_estimators=1000,
                                   silent=False)
            model.fit(train_x, train_y, eval_set=[(valid_x, valid_y), (train_x, train_y)],
                      early_stopping_rounds=10,
                      verbose=True)

    # Alternatives tried here: MLPClassifier, LinearSVC, LGBMClassifier.
    model = RidgeClassifier(class_weight='balanced')
    model.fit(x, y)

    pred_y = model.predict(x)
    print(pred_y)

    PathConfig.SAVED_MODEL.mkdir(parents=True, exist_ok=True)
    # Close the handle before the file is read back below.
    with PathConfig.NODE_EVALUATOR_MODEL.open(mode='wb') as f:
        pickle.dump(model, f)

    del model

    print(x)
    print(y)
    print(pred_y)

    with PathConfig.NODE_EVALUATOR_MODEL.open(mode='rb') as f:
        model = pickle.load(f)

    # RidgeClassifier offers no predict_proba; fall back to hard predictions so
    # the round-trip check works for whichever estimator is plugged in above.
    if hasattr(model, 'predict_proba'):
        pred_y = model.predict_proba(x)
    else:
        pred_y = model.predict(x)

    print(pred_y)

    # Likewise, linear models expose coef_ rather than feature_importances_.
    if hasattr(model, 'feature_importances_'):
        scores = model.feature_importances_
    else:
        scores = np.abs(model.coef_[0])
    importance = pd.DataFrame(scores, index=feature_columns, columns=['importance'])

    print(importance)


# Every category tag a task may carry in the taxonomy YAML.
# TaskTaxonomy.check() asserts that each tag found in the YAML is listed here.
CATEGORIES = [
    'PARTITION',
    'SYMMETRY',
    'REPEAT',
    'DENOISE',
    'SIMPLIFICATION',
    'NUMBER',
    'RANKING',
    'SHAPE',
    'FIND_FIT',
    'LINE',
    'OBJECT_TRANSFORM',
    'OBJECT_MOVE',
    'JIGSAW_PUZZLE',
    'COLOR',
    'PASTE',
    'GUIDE',
    'META',
    'OTHERS',
    'ONCE_ANSWERED',
]


# Category tags that mark a task as "given up": TaskTaxonomy.get_give_up_task_names()
# reports any task carrying at least one of these, unless the task is also
# tagged 'ONCE_ANSWERED'.
GIVE_UPS = [
    'SYMMETRY',
    'REPEAT',
    'DENOISE',
    'SIMPLIFICATION',
    'NUMBER',
    'RANKING',
    'SHAPE',
    'FIND_FIT',
    'OBJECT_MOVE',
    'JIGSAW_PUZZLE',
    'COLOR',
    'PASTE',
    'GUIDE',
    'META',
]


class TaskTaxonomy:
    """Hand-maintained mapping from task name to category tags.

    The mapping is read from ``PathConfig.OPERATION_ANSWER_TAXONOMY_YAML``,
    which holds one dict for the training tasks (key ``'1_train'``) and one
    for the evaluation tasks (key ``'2_eval'``).
    """

    def __init__(self):
        """Load the taxonomy YAML and validate it with :meth:`check`."""
        with open(str(PathConfig.OPERATION_ANSWER_TAXONOMY_YAML), 'r') as f:
            # NOTE(review): yaml.Loader can construct arbitrary Python objects.
            # Acceptable for this project-owned file; never point it at
            # untrusted YAML (yaml.SafeLoader would be the safe choice).
            yaml_dict = yaml.load(f, Loader=yaml.Loader)
        self.trains: Dict[str, List[str]] = yaml_dict['1_train']
        self.evals: Dict[str, List[str]] = yaml_dict['2_eval']
        self.check()

    def check(self):
        """Assert that the taxonomy is complete and internally consistent.

        Verifies that both splits hold 400 tasks, that no task repeats a
        category, that every category is listed in ``CATEGORIES``, and that
        the task names match the JSON files under the train and evaluation
        roots exactly.
        """
        assert len(self.trains) == len(self.evals) == 400

        for task_name, categories in {**self.trains, **self.evals}.items():
            assert len(categories) == len(set(categories))
            for category in categories:
                assert category in CATEGORIES, category

        json_task_names = {path.stem for path in chain.from_iterable([PathConfig.TRAIN_ROOT.iterdir(), PathConfig.EVALUATION_ROOT.iterdir()])}
        df_task_names = set(list(self.trains.keys()) + list(self.evals.keys()))
        assert json_task_names - df_task_names == set(), json_task_names - df_task_names
        assert df_task_names - json_task_names == set(), df_task_names - json_task_names

    def show_stats(self):
        """Print, per split, how many tasks carry each category."""
        print('=== train stats ====')
        for c in CATEGORIES:
            num = sum(1 for v in self.trains.values() if c in v)
            print(f'{c}: {num}')

        print('\n=== eval stats ====')
        for c in CATEGORIES:
            num = sum(1 for v in self.evals.values() if c in v)
            print(f'{c}: {num}')

    def save_yaml(self):
        """Validate the taxonomy and write it back to the YAML file."""
        self.check()

        with open(str(PathConfig.OPERATION_ANSWER_TAXONOMY_YAML), 'w') as f:
            yaml.dump({'1_train': self.trains,
                       '2_eval': self.evals}, f)

    def save_categorized_fig(self):
        """Re-render one image per (task, category) pair under the image root.

        Tasks with an empty category list are written to a ``not_categorized``
        folder instead.
        """
        # TODO fix local import of TaskLoader?
        image_root = PathConfig.OPERATION_ANSWER_TAXONOMY_IMAGE_ROOT
        # rmtree raises when the directory is missing (e.g. on the first run),
        # so only clear it when it is actually there.
        if os.path.isdir(image_root):
            shutil.rmtree(image_root)

        tagged_items = [('train', item) for item in self.trains.items()]
        tagged_items += [('evals', item) for item in self.evals.items()]

        for tag, (task_name, categories) in tqdm(tagged_items):
            folders = ['not_categorized'] if categories == [] else categories
            # Load the task once instead of once per category.
            task = TaskLoader().get_task(task_name)
            for folder in folders:
                plot_task(task, show=False, save_path=image_root / tag / folder / f'{task_name}.png')

    def get_give_up_task_names(self) -> List[str]:
        """Return tasks tagged with any ``GIVE_UPS`` category and not 'ONCE_ANSWERED'."""
        # Set for O(1) membership tests in the scan below.
        can_answers = set(self.get_can_answer_task_names())
        return [
            task_name
            for task_name, categories in {**self.trains, **self.evals}.items()
            if task_name not in can_answers and any(c in GIVE_UPS for c in categories)
        ]

    def get_can_answer_task_names(self) -> List[str]:
        """Return the names of all tasks tagged 'ONCE_ANSWERED'."""
        return [task_name for task_name, categories in {**self.trains, **self.evals}.items() if 'ONCE_ANSWERED' in categories]

    def filter_tasks(self, tasks: List[Task]) -> List[Task]:
        """Restrict ``tasks`` according to ``RunConfig.TASK_RANGE``.

        Raises:
            NotImplementedError: For a task range this method does not know.
        """
        if RunConfig.TASK_RANGE == TaskRange.ALL:
            return tasks
        elif RunConfig.TASK_RANGE == TaskRange.EXCLUDE_GIVE_UPS:
            # Compute the name set once, not once per task.
            give_ups = set(self.get_give_up_task_names())
            return [t for t in tasks if t.name not in give_ups]
        elif RunConfig.TASK_RANGE == TaskRange.CAN_ANSWER_ONLY:
            can_answers = set(self.get_can_answer_task_names())
            return [t for t in tasks if t.name in can_answers]
        else:
            raise NotImplementedError()




def get_engine(engine_type: EngineType):
    """Instantiate the search engine that corresponds to ``engine_type``.

    Raises:
        NotImplementedError: If ``engine_type`` is not a supported engine.
    """
    if engine_type == EngineType.NODE_BASED_SEARCH_ENGINE:
        return NodeBaseSearchEngine()
    if engine_type == EngineType.TREE_BASED_SEARCH_ENGINE:
        return TreeBaseSearchEngine()
    raise NotImplementedError()


def run():
    """Program entry point: dispatch on ``RunConfig.RUN_MODE``.

    A debug run (see ``debug_run``) takes precedence and ends the call early.
    Otherwise paths are initialized and the branch for the configured mode is
    executed.

    Raises:
        ValueError: If the configured run mode is not handled here.
    """
    if debug_run():
        return

    initialize_path()

    mode = RunConfig.RUN_MODE
    if mode == RunMode.LOCAL_RUN:
        load_answer_storage()  # debug validate
        taxonomy = TaskTaxonomy()
        solve_tasks(
            taxonomy.filter_tasks(TaskLoader().get_training_tasks()),
            AllParameter(),
            output_summary_path=PathConfig.OPERATION_ANSWER_MEMO_ROOT / 'answer_summary_train.txt',
            copy_wrong_answers_root_tag='train',
            add_answer_storage=True,
            save_submission=True,
        )
        solve_tasks(
            taxonomy.filter_tasks(TaskLoader().get_evaluation_tasks()),
            AllParameter(),
            output_summary_path=PathConfig.OPERATION_ANSWER_MEMO_ROOT / 'answer_summary_eval.txt',
            copy_wrong_answers_root_tag='eval',
            add_answer_storage=False,
            save_submission=False,
        )
    elif mode == RunMode.LOCAL_RUN_ALL:
        solve_tasks(
            TaskLoader().get_training_tasks(),
            AllParameter(),
            output_summary_path=PathConfig.OPERATION_ANSWER_MEMO_ROOT / 'answer_summary_train.txt',
            copy_wrong_answers_root_tag='train',
            add_answer_storage=True,
            save_submission=True,
        )
        solve_tasks(
            TaskLoader().get_evaluation_tasks(),
            AllParameter(),
            output_summary_path=PathConfig.OPERATION_ANSWER_MEMO_ROOT / 'answer_summary_eval.txt',
            copy_wrong_answers_root_tag='eval',
            add_answer_storage=False,
            save_submission=False,
        )
        solve_tasks(TaskLoader().get_test_tasks(), AllParameter(), save_submission=True)
    elif mode == RunMode.KERNEL_EMULATION:
        solve_tasks(TaskLoader().get_test_tasks(), AllParameter(), save_submission=True)
    elif mode == RunMode.NODE_BASE_SEARCH_OPTIMIZATION:
        optimize_node_base_search(TaskLoader().get_training_tasks())
    elif mode == RunMode.TREE_BASE_SEARCH_OPTIMIZATION:
        optimize_tree_base_search(TaskLoader().get_training_tasks())
    elif mode == RunMode.LOCAL_DATA_GENERATION:
        for training_task in TaskLoader().get_training_tasks():
            print(training_task.name)
            save_ml_training_data(training_task)
    elif mode == RunMode.LOCAL_ML_TRAIN:
        train_ml()
    elif mode == RunMode.TRAIN_OPERATION_ELEMENT_INCLUSION_PREDICTION:
        train_operation_element_inclusion_prediction()
    elif mode == RunMode.KERNEL:
        skip_public = RunConfig.RUN_ONLY_PRIVATE_LB and not TaskLoader().is_private_lb_run()
        if skip_public:
            # Public kernel run: ship the sample submission untouched.
            print('This is kernel public run. Skipped.')
            shutil.copy(str(KernelPathConfig.SAMPLE_SUBMISSION), KernelPathConfig.SUBMISSION)
            return
        print('This is private private run. Not skipped.')
        solve_tasks(TaskLoader().get_test_tasks(), AllParameter(), save_submission=True)
    else:
        raise ValueError(RunConfig.RUN_MODE)
    print('end')


def debug_run():
    """Run whichever debug scenario is enabled in ``DebugConfig``.

    Three scenarios are checked in order, and the first one whose task name
    is set wins:

    1. Operation debug: apply a fixed operation set to a task, print the
       task features before and after, the evaluator distances, the cached
       node costs for each depth-search pattern, and plot the result.
    2. Solve debug: solve a single task verbosely and plot the answer.
    3. Training-data debug: generate ML samples for one task, then train.

    Returns:
        True if a debug scenario ran (the caller should stop), else False.
    """
    print('start')
    if DebugConfig.OPERATION_DEBUG_TASK_NAME:
        operation_set = str_to_operation_set(DebugConfig.OPERATION_DEBUG_OPERATION_SET)
        print(operation_set)
        task = TaskLoader().get_task(DebugConfig.OPERATION_DEBUG_TASK_NAME)
        applied_task = TaskOperationSetExecutor().execute(task, operation_set)

        original_task_feature = create_task_feature(task, task)
        applied_task_feature = create_task_feature(task, applied_task)

        # One row per feature, so the two feature sets line up side by side.
        original_df = DataFrame(asdict(original_task_feature), index=['index']).T
        applied_df = DataFrame(asdict(applied_task_feature), index=['index']).T
        # NOTE(review): 'appplied_' looks like a typo; it only affects the
        # printed column name and is kept so the debug output is unchanged.
        merged_feature_df = pd.merge(original_df, applied_df, left_index=True, right_index=True,
                                     suffixes=['original_', 'appplied_'])

        original_waiting_node = ColorSelectionWaitingNode(None, task, task, original_task_feature, OperationSet([]), MultiColorSelection(MultiColorSelectionMode.ANY_WITHOUT_MOST_COMMON))
        applied_waiting_node = ColorSelectionWaitingNode(None, task, applied_task, applied_task_feature, operation_set, MultiColorSelection(MultiColorSelectionMode.ANY_WITHOUT_MOST_COMMON))
        print(merged_feature_df)

        print('distance')
        print(DistanceEvaluator(DistanceEvaluatorParameter()).evaluate_task_feature(original_task_feature))
        print(DistanceEvaluator(DistanceEvaluatorParameter()).evaluate_task_feature(applied_task_feature))

        cost_patterns = (
            ('breadth cost', DepthSearchPattern.BREADTH_FIRST),
            ('normal cost', DepthSearchPattern.NORMAL),
            ('depth cost', DepthSearchPattern.DEPTH_FIRST),
        )
        node_pairs = (
            [original_waiting_node, original_waiting_node],
            [original_waiting_node, applied_waiting_node],
        )
        for title, pattern in cost_patterns:
            print(title)
            # A fresh evaluator per call, first on the original node alone,
            # then on original vs. applied.
            for nodes in node_pairs:
                HandMadeNodeEvaluator(pattern, defaultdict(lambda: 1), NodeBaseSearchEngineParameter(), DistanceEvaluatorParameter()).evaluate_nodes(nodes)
            print(original_waiting_node.cache_pred_distance)
            print(applied_waiting_node.cache_pred_distance)

        plot_task_with_operation_set(task, operation_set, show=True, save_path=None)
        return True

    if DebugConfig.SOLVE_DEBUG_TASK_NAME:
        task = TaskLoader().get_task(DebugConfig.SOLVE_DEBUG_TASK_NAME)
        engine_result = solve_tasks([task], AllParameter(), add_answer_storage=True, verbose=True)[0]

        if isinstance(engine_result, AnsweredSearchResults):
            plot_task_with_result_set(task, engine_result, show=True, save_path=None)
        return True

    if DebugConfig.TRAIN_DATA_GENERATION_DEBUG_TASK_NAME:
        task = TaskLoader().get_task(DebugConfig.TRAIN_DATA_GENERATION_DEBUG_TASK_NAME)
        save_ml_training_data(task)
        # train_ml() takes no arguments; passing the task raised TypeError.
        train_ml()
        return True

    return False


def performance_run():
    """Placeholder for running ``run()`` under a line profiler.

    The profiling code is commented out, so this function currently does
    nothing.  The commented block shows the intended setup: register the
    project modules with a ``LineProfiler``, call ``run`` through it, and
    print the collected stats.
    """
    # from line_profiler import LineProfiler
    # from python_utils.src.library.print_line_profiler import print_stats
    # from abstraction_and_reasoning_challenge import run as run_module
    # from abstraction_and_reasoning_challenge.src.domain import task_solver
    # from abstraction_and_reasoning_challenge.src.domain.search_engine.evaluation_functions import handmade_evaluator
    # from abstraction_and_reasoning_challenge.src.domain.search_engine.node import waiting_node
    # from abstraction_and_reasoning_challenge.src.domain.search_engine.node_processor import waiting_node_processor
    # from abstraction_and_reasoning_challenge.src.domain.feature import task_feature
    # from abstraction_and_reasoning_challenge.src.domain.search_engine.engine import node_base_search_engine
    # from abstraction_and_reasoning_challenge.src.domain.search_engine.engine import tree_base_search_engine
    #
    # profiler = LineProfiler()
    # profiler.add_module(run_module)
    # profiler.add_module(task_solver)
    # profiler.add_module(handmade_evaluator)
    # profiler.add_module(waiting_node)
    # profiler.add_module(waiting_node_processor)
    # profiler.add_module(task_feature)
    # profiler.add_module(node_base_search_engine)
    # profiler.add_module(tree_base_search_engine)
    #
    # profiler.runcall(run)
    # # profiler.print_stats()
    # stats = profiler.get_stats()
    # print_stats(stats, strip_seconds_limit=0., cost_sort=True)
    pass


# NOTE(review): this flag is not read anywhere in the visible code; presumably
# it was meant to switch between run() and performance_run() -- confirm.
performance_profiling_mode = False


# Execute the solver when this code is run as the main module.
if __name__ == '__main__':
    run()


<a id="rollback_the_predictions"></a>
# Rollback the predictions
[Back to Table of Contents](#toc)

In [None]:
# Load the submission CSV that holds the predictions to be rolled back below.
sub = pd.read_csv("./submission_yuki_alignment.csv")
print(sub.shape)
sub.head(3)

In [None]:
def get_string(pred):
    """Serialize a 2-D grid into pipe-delimited rows, e.g. ``|12|34|``.

    The grid's list representation is rewritten textually: element separators
    are dropped and the row brackets are replaced by ``|``.
    """
    text = str([list(row) for row in pred])
    # Applied in this exact order; each step works on the previous result.
    substitutions = (
        (', ', ''),
        ('[[', '|'),
        ('][', '|'),
        (']]', '|'),
    )
    for old, new in substitutions:
        text = text.replace(old, new)
    return text
    
def get_string_list(preds):
    """Serialize several grids with ``get_string`` and join them with single spaces."""
    encoded = map(get_string, preds)
    return " ".join(encoded)

def rollback_row(r, test_aligned_tasks = test_aligned_tasks, debug=False):
    """Undo the alignment transforms on one submission row and re-serialize it.

    ``r`` must provide ``"output_id"`` (``<task id>_<test index>``) and
    ``"output_aligned"`` (space-separated, pipe-delimited grids).  The flags
    ``fliplr``, ``flipud`` and ``rot90`` of the matching aligned test sample
    decide which transforms are applied to each parsed prediction.

    Returns:
        The rolled-back predictions as a submission string, or a fixed
        placeholder when the aligned output is shorter than 10 characters.
    """
    output_id = r["output_id"]
    aligned_text = str(r["output_aligned"])

    # Too short to hold a real answer, e.g. "|0| |0| |0|" -> placeholder.
    # |080000|808000|008088|000008| |0| |0| 
    if len(aligned_text) < 10:
        return "|00|00| |00|00| |00|00|"

    id_parts = output_id.split("_")
    task_id = id_parts[0]
    order_id = int(id_parts[1])

    sample_aligned = test_aligned_tasks[task_id]['test'][order_id]

    # Parse every sufficiently long chunk into a list of digit rows; the
    # slice drops the empty strings outside the leading/trailing pipes.
    grids = []
    for chunk in aligned_text.split(" "):
        if len(chunk) > 5:
            rows = chunk.split("|")[1:-1]
            grids.append([[int(ch) for ch in row] for row in rows])

    predictions = []
    modified = False
    for grid in grids:
        restored = np.array(grid)
        if sample_aligned['fliplr']:
            restored = np.fliplr(restored)
            modified = True
        if sample_aligned['flipud']:
            restored = np.flipud(restored)
            modified = True
        if sample_aligned['rot90']:
            restored = np.rot90(restored, k=3)
            modified = True
        predictions.append(restored.tolist())

    output_final = get_string_list(predictions)
    if debug and modified:
        print(task_id, order_id)
    return output_final
    
def rollback_sub(sub):
    """Return a copy of ``sub`` whose ``output`` column is rolled back.

    The original predictions are preserved in ``output_aligned`` and
    ``is_modified`` is 1 for every row whose output changed, 0 otherwise.
    """
    result = sub.copy()
    result["output_aligned"] = result["output"]

    def restore(row):
        return rollback_row(row)

    def changed(row):
        if row["output_aligned"] != row["output"]:
            return 1
        return 0

    result["output"] = result.apply(restore, axis=1)
    result["is_modified"] = result.apply(changed, axis=1)
    return result
    
# Roll back every row and report how many outputs actually changed.
sub2 = rollback_sub(sub)
print(sub2["is_modified"].sum())
sub2.head(3)

In [None]:
# Keep only the two submission columns; index=None omits the row index.
sub2[["output_id", "output"]].to_csv("./submission_yuki_rollback.csv", index=None)

In [None]:
# Index by output_id so rows can be looked up by id in the next cell.
sub2 = sub2[["output_id", "output"]]
sub2.set_index('output_id', inplace=True)

In [None]:
# Start from the official sample submission so every expected output_id is present.
sample_submission = pd.read_csv('/kaggle/input/abstraction-and-reasoning-challenge/sample_submission.csv', index_col='output_id')

# Overwrite the sample output wherever a rolled-back prediction exists;
# ids missing from sub2 keep the sample submission's value.
for idx, row in sample_submission.iterrows():
    if idx in sub2.index:
        sample_submission.loc[idx, 'output'] = sub2.loc[idx, 'output']

sample_submission.to_csv('submission.csv')