# In [10]:
import os
from enum import IntEnum
import import_ipynb
from maze import Maze
import numpy as np
from typing import Union

class Puddle(IntEnum):
    """Size of the puddle occupying a grid cell, ordered from none to largest."""

    Dry = 0
    Small = 1
    Large = 2
    
class GridBase():
    """Description of a rectangular grid: size, start/end, puddles, areas, walls and rewards.

    All configuration is passed to the constructor as keyword arguments; the
    resulting per-cell rewards are computed once and cached in ``grid_rewards``.
    """

    # No maze is defined by default.
    maze = None
    # Puddle definitions (capitalised name kept for backward compatibility).
    Puddles = None
    # Default for the attribute the methods actually read (``self.puddles``).
    puddles = None
    # Every area that exists on the base level.
    base_areas = []
    # Every area that exists on the grid level.
    grid_areas = []

    # When True, make_maze() also writes the generated maze to 'maze.svg'.
    debug_maze = False

    # Cached rewards for the whole grid (empty until computed).
    grid_rewards = []

    def __init__(self, working_directory: str = ".", **kwargs):
        """Configure the grid from keyword arguments.

        Args:
            working_directory: Directory used for files written by the grid
                (currently only the debug maze SVG).
            **kwargs: Optional settings, each with a default:
                drawmode: defaults to 'colab' when the COLAB_GPU environment
                    variable is present, otherwise "".
                width, height: grid size, default 3 x 3.
                start: start cell, default [0, 0].
                end: end cell, default [width - 1, height - 1].
                puddles: puddle definition, default None (see get_puddle_size).
                puddles_props (or puddle_props): dict with optional keys
                    'large_reward', 'small_reward', 'large_prob', 'small_prob'.
                base_areas, grid_areas: lists of area definitions.
                add_maze: whether make_maze() should generate a maze.
                maze_seed: seed passed to the maze generator, default 0.
                wall (or walls): wall toggles, see toggle_walls().
        """
        self.working_directory = working_directory

        # Detect whether we are currently running on Google Colab.
        self.drawmode = kwargs.get('drawmode', 'colab' if 'COLAB_GPU' in os.environ else "")

        self.width = kwargs.get('width', 3)
        self.height = kwargs.get('height', 3)

        # Start and end default to the top-left and bottom-right cells.
        self.start = kwargs.get('start', [0, 0])
        self.end = kwargs.get('end', [self.width - 1, self.height - 1])

        # Puddle layout.
        self.puddles = kwargs.get('puddles', None)

        # Properties defined for the puddles.
        # NOTE(review): the original key is 'puddles_props'; 'puddle_props' is
        # also accepted because the plural spelling looks accidental - confirm.
        puddle_props = kwargs.get('puddles_props', kwargs.get('puddle_props', {}))
        self.large_puddle_reward = puddle_props.get('large_reward', -4)
        self.small_puddle_reward = puddle_props.get('small_reward', -2)
        self.large_puddle_probability = puddle_props.get('large_prob', 0.4)
        self.small_puddle_probability = puddle_props.get('small_prob', 0.6)

        self.base_areas = kwargs.get('base_areas', [])

        self.grid_areas = kwargs.get('grid_areas', [])

        # Maze and wall setup.
        self.add_maze = kwargs.get('add_maze', False)
        self.maze_seed = kwargs.get('maze_seed', 0)
        self.make_maze()
        # NOTE(review): the original key is 'wall'; 'walls' is also accepted
        # because the singular spelling looks accidental - confirm.
        self.toggle_walls(kwargs.get('wall', kwargs.get('walls', [])))

        # Compute and cache the reward of every cell.
        self.grid_rewards = self.get_reward()

    def make_maze(self):
        """Generate the maze when ``add_maze`` is set, optionally dumping it as SVG."""
        if self.add_maze:
            if self.maze is None:
                self.maze = Maze(self.width, self.height, self.start[0], self.start[1],
                                 seed=self.maze_seed)
                self.maze.make_maze()
            if self.debug_maze:
                self.maze.write_svg(os.path.join(self.working_directory, 'maze.svg'))

    def toggle_walls(self, walls):
        """Toggle the specified walls of the maze.

        Args:
            walls: Iterable of ``(loc, direction)`` pairs. ``loc`` is
                ``(x, y)`` or ``(x, y, num_cells)``; ``direction`` is one of
                'N', 'S', 'E', 'W'. With ``num_cells`` the toggle is repeated
                on consecutive cells: downwards (y + 1) for 'E'/'W' walls and
                to the right (x + 1) for 'N'/'S' walls.

        Raises:
            ValueError: If a direction is not one of 'N', 'S', 'E', 'W'.
        """
        if self.maze is None:
            # No maze yet: create one without walls so there are cells to toggle.
            self.maze = Maze(self.width, self.height, self.start[0], self.start[1],
                             no_walls=True)
            # A maze object now exists, so record that on the instance.
            self.add_maze = True

        # Offset from a cell to its neighbour on the other side of each wall.
        neighbour_offset = {'E': (1, 0), 'W': (-1, 0), 'N': (0, -1), 'S': (0, 1)}

        for loc, direction in walls:
            if direction not in neighbour_offset:
                raise ValueError(f"unknown wall direction: {direction!r}")
            dx, dy = neighbour_offset[direction]

            x, y = loc[0], loc[1]
            num_cells = loc[2] if len(loc) == 3 else 1

            # This loop must run once per wall entry, i.e. inside the outer loop.
            for _ in range(num_cells):

                # Stop as soon as the run of cells leaves the grid.
                if x >= self.width or y >= self.height:
                    break

                current_cell = self.maze.cell_at(x, y)
                next_cell = self.maze.cell_at(x + dx, y + dy)

                # Presumably adds the wall if absent and removes it otherwise
                # (behaviour lives in the Maze cell class - verify there).
                current_cell.toggle_wall(next_cell, direction)

                # Advance to the next cell along the wall.
                if direction in ('E', 'W'):
                    y += 1
                else:
                    x += 1

    def get_puddle_size(self, x, y):
        """Return the Puddle size at cell (x, y); Puddle.Dry when none is defined.

        ``self.puddles`` may be a list of rows (indexed ``[y][x]``) or a list
        of ``((px, py), size)`` entries.
        """
        # An empty definition has no first element to inspect, so treat it as dry.
        if self.puddles is not None and len(self.puddles) > 0:
            if isinstance(self.puddles[0], list):
                # Grid form: one list per row.
                return Puddle(self.puddles[y][x])
            # Sparse form: explicit coordinates with a size.
            for (px, py), puddle_size in self.puddles:
                if x == px and y == py:
                    return Puddle(puddle_size)
        return Puddle.Dry

    def get_transition_probability(self, x, y):
        """Return the transition probability associated with cell (x, y).

        Large and small puddles use their configured probabilities; a dry
        cell returns 1.
        """
        puddle_size = self.get_puddle_size(x, y)

        if puddle_size == Puddle.Large:
            return self.large_puddle_probability
        if puddle_size == Puddle.Small:
            return self.small_puddle_probability

        # No puddle.
        return 1.

    def get_reward(self, x: Union[int, None] = None,
                   y: Union[int, None] = None) -> Union[int, np.ndarray]:
        """Return the reward of cell (x, y), or the whole reward array.

        The full array is returned when either coordinate is omitted.
        """
        if x is None or y is None:
            return self.get_reward_array()
        return self.get_reward_value(x, y)

    def get_reward_array(self) -> np.ndarray:
        """Return the rewards of every cell as a (height, width) integer array.

        The cached ``grid_rewards`` is returned when it has already been
        computed; otherwise the array is built cell by cell.
        """
        # len() rather than truthiness: the cache may be a NumPy array.
        if len(self.grid_rewards) == 0:
            reward_arr = np.zeros((self.height, self.width), dtype=int)
            for y in range(self.height):
                for x in range(self.width):
                    reward_arr[y][x] = self.get_reward_value(x, y)
            return reward_arr

        return self.grid_rewards

    def test_for_base_area(self, x, y):
        """Return True when cell (x, y) lies inside any of the base areas.

        An area is either a flat ``[x, y(, width, height)]`` definition or a
        sequence whose first element is such a definition.
        """
        for area in self.base_areas:
            defn = area if isinstance(area[0], int) else area[0]
            ax, ay, aw, ah = self.get_area_defn(defn)
            # Keep looking: only a hit ends the search, a miss tries the next area.
            if self.in_area(x, y, ax, ay, aw, ah):
                return True
        return False

    def get_reward_value(self, x, y):
        """Return the reward obtained by moving into cell (x, y).

        Order of precedence: cached value, puddle reward (defaults -2 small,
        -4 large), 0 inside a base area, the reward of the last matching grid
        area, and finally -1.
        """
        if len(self.grid_rewards) > 0:
            return self.grid_rewards[y, x]

        puddle_size = self.get_puddle_size(x, y)

        if puddle_size == Puddle.Large:
            return self.large_puddle_reward
        elif puddle_size == Puddle.Small:
            return self.small_puddle_reward

        # Base areas carry no reward.
        if self.test_for_base_area(x, y):
            return 0

        cell_reward = -1
        for area in self.grid_areas:
            # Only areas with a third element define a reward.
            if len(area) > 2:
                try:
                    ax, ay, aw, ah = self.get_area_defn(area[0])
                    if self.in_area(x, y, ax, ay, aw, ah):
                        cell_reward = area[2]
                except (TypeError, ValueError, IndexError):
                    # Best effort: skip area definitions that are malformed.
                    continue

        return cell_reward

    def get_area_defn(self, area):
        """Return ``(x, y, width, height)`` for an area; the size defaults to 1 x 1."""
        x, y, *args = area
        wd, ht = args if args else (1, 1)
        return x, y, wd, ht

    def in_area(self, x, y, ax, ay, aw, ah):
        """Return True when (x, y) is inside the rectangle at (ax, ay) of size aw x ah."""
        return ax <= x < (ax + aw) and ay <= y < (ay + ah)
 