In [None]:
# Day 23
#
# https://adventofcode.com/2021/day/23

from copy import *

# Number of distinct pawn types handled by the solver.
num_pawn_types = 4

# Pawn letters, indexed by pawn type (0-3): 'A', 'B', 'C', 'D'.
P = [chr(code) for code in range(ord('A'), ord('A') + num_pawn_types)]

# Movement cost of a single step, indexed by pawn type (0-3): 1, 10, 100, 1000.
C = [10 ** exponent for exponent in range(num_pawn_types)]

# Unit moves as (dx, dy) offsets; 'D' increases y and 'U' decreases it.
# The insertion order (L, R, D, U) is the order in which moves are tried.
steps = dict(L=(-1, 0), R=(1, 0), D=(0, 1), U=(0, -1))



######################
#
#  Map operations
#
######################

#Insert agents in map
def addAgents(M, agents, time=None):
    """Return a copy of map M with every agent drawn as its pawn letter.

    M itself is not modified. Each agent is a (type, x, y, h, v, r) tuple.
    When `time` is given, the agent is drawn at r[time] (an entry of its
    recorded coordinate history) instead of at its current (x, y).
    """
    board = deepcopy(M)
    for kind, col, row, _tiles, _moves, trail in agents:
        if time is not None:
            col, row = trail[time]
        board[row][col] = P[kind]
    return board
    

#Render map as text
def render(M):
    """Return map M as text: the cells of each row joined, one line per row.

    Every row, including the last one, is terminated by a newline.
    """
    return "".join("".join(row) + "\n" for row in M)

def p(M):
    """Print M (a string or an iterable of strings) joined without separator."""
    text = "".join(M)
    print(text)

######################
#
#  Floodfill distances to prioritize distances to goal
#
######################

def flood(m, p, d=0):
    """Label every free cell reachable from p with its walking distance.

    The fill is breadth-first, so each cell receives the length of the
    shortest path from the start, also when the free area contains loops
    (a depth-first fill would write the length of whichever path reached
    the cell first).

    Parameters:
        m: map as a list of rows of single-character cells. Free cells hold
           ' '. The map is modified in place.
        p: (x, y) start coordinate. It is always labelled, whatever it held.
        d: distance value written at the start cell.

    Returns:
        m, the same object that was passed in, with distances stored as
        decimal strings (values of 10 or more occupy one multi-character
        cell).
    """
    # Local import: the file's import section only pulls in `copy`.
    from collections import deque

    x, y = p
    m[y][x] = str(d)
    frontier = deque([(x, y, d)])
    while frontier:
        cx, cy, dist = frontier.popleft()
        for dx, dy in steps.values():
            nx, ny = cx + dx, cy + dy
            # Stay on the grid: rows may differ in length, and a negative
            # index would silently wrap around to the other side.
            if not (0 <= ny < len(m) and 0 <= nx < len(m[ny])):
                continue
            if m[ny][nx] == ' ':
                # Label on discovery so a cell is queued only once.
                m[ny][nx] = str(dist + 1)
                frontier.append((nx, ny, dist + 1))
    return m


def make_costmaps(m,t):
    """Build flood-filled distance maps for the lowest row of home tiles.

    Rows are scanned from the bottom of the map upwards. For the first row
    in which the type map `t` holds home-tile markers ('0' .. str(n-1)), one
    flooded copy of `m` is produced per marker, in left-to-right order, and
    that list is returned. Returns None when no row holds a marker.
    """
    home_marks = tuple(str(k) for k in range(num_pawn_types))
    for y in range(len(m) - 1, -1, -1):
        row_maps = [
            flood(deepcopy(m), (x, y))
            for x in range(len(m[y]))
            if t[y][x] in home_marks
        ]
        if row_maps:
            return row_maps
    return None

    
    
######################
#
#  Read inputdata
#
######################

#return clean_map, type_map and start state    
def parse(text):
    """Split puzzle text into a clean map, a type map and the start state.

    Returns (M, T, S):
        M: the map with '.' cells and pawn letters replaced by ' ' (free).
        T: the type map: '.' cells become 'h', a pawn's cell becomes the
           running count of pawns seen so far in that row (as a string),
           and row 1 of every column holding a pawn becomes 'n'.
        S: list of pawns in reading order, each a tuple
           (type, x, y, h, v, r) with three empty history lists.
    """
    clean = [list(line) for line in text.split("\n")]
    kinds = deepcopy(clean)
    pawns = []
    for y, row in enumerate(clean):
        # counts the pawns met in this row; used as the home-tile marker
        room_no = 0
        for x, cell in enumerate(row):
            if cell == ".":
                row[x] = " "
                kinds[y][x] = "h"
            elif cell in P:
                # pawn: type, x, y, tile history, move history, coordinate history
                pawns.append((ord(cell) - ord('A'), x, y, [], [], []))

                # mark home tile in type map
                kinds[y][x] = str(room_no)
                room_no += 1

                # mark tile as free in clean map
                row[x] = " "

                # mark nostop zone in type map (row 1, same column)
                kinds[1][x] = 'n'
    return clean, kinds, pawns

######################
#
#  State operations
#
######################

    
#Serialize part of state, so we can recognize situations 
#we have been in before and filter their proposal 
def serialize(state):
    """Return a string key identifying the essential part of a state.

    NOTE: `state` is sorted in place (by type, then y, then x) as a side
    effect, so the caller's list is reordered.

    The key lists, per agent, its pawn letter and coordinates followed by
    the most recent entry of its tile history (if any).
    """
    def rank(agent):
        kind, col, row = agent[0], agent[1], agent[2]
        return int((kind * 100 + row) * 100 + col)

    state.sort(key=rank)
    parts = []
    for t, x, y, h, _v, _r in state:
        token = "%c,%d,%d" % (P[t], x, y)
        if h:
            token += h[-1]
        parts.append(token)
    return ",".join(parts)


######################
#
#  Unfold moves
#
######################


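#Codes stored in an agent's tile history (one entry per turn):
#' '      idle (the agent did not move this turn)
#'h'      hallway tile
#'n'      nostop tile (row 1, in front of a room)
#'0'..    home tile, numbered per room from left to right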
def legal(m,T,agents):
    """Tell whether a candidate state obeys the movement rules.

    Parameters:
        m: unused; kept so existing callers keep working.
        T: type map (rows of single-character tile codes).
        agents: list of (type, x, y, h, v, r) tuples, where h is the history
            of tile codes (' ' for an idle turn) and v the history of moves
            ('L', 'R', 'D', 'U', or ' ' for an idle turn).

    Returns:
        False as soon as one agent breaks a rule, True otherwise. An agent
        with an empty history is treated as not having just moved down.
    """
    # pawn type standing on each coordinate
    occupied = {(x, y): t for t, x, y, h, v, r in agents}

    for t, x, y, h, v, r in agents:
        own = str(t)
        trail = "".join(h)

        # No stopping in front of chambers: a nostop tile followed by idling
        if "n " in trail:
            return False

        # Only one stop in hallway allowed
        if trail.replace("n", "h").count("h ") > 1:
            return False

        below = (x, y + 1)
        last_move = v[-1] if v else None
        if last_move == "D":
            # no moving down in foreign rooms
            if h[-1] != own:
                return False
            # don't enter the own room while a stranger sits right below
            if below in occupied and occupied[below] != t:
                return False
        elif T[y][x] == own and T[y + 1][x] == own and below not in occupied:
            # don't miss an opportunity to go deeper into the own room
            return False
    return True


def cost(agents, costmaps=None):
    """Return the accumulated movement cost of a state.

    Every non-idle entry in an agent's tile history counts as one step at
    the agent's type cost. When `costmaps` is given, the value stored in
    costmaps[type] at the agent's position, times the type cost, is added
    on top.
    """
    total = 0
    for t, x, y, h, _v, _r in agents:
        moves_made = sum(1 for tile in h if tile != ' ')
        total += moves_made * C[t]
        if costmaps is not None:
            total += int(costmaps[t][y][x]) * C[t]
    return total


def isgoal(T,agents):
    """Return True when every agent stands on a home tile of its own type.

    A tile is a home tile for type t when the type map T holds str(t) there.
    """
    return all(T[y][x] == str(t) for t, x, y, _h, _v, _r in agents)



#unfold the current state into all possible
#new states
def nextstates(M,T,agents):
    """Return every legal state reachable by moving one agent one step.

    For each agent and each direction in `steps`, the agent is moved if the
    target cell is free on the map with all agents drawn in. The mover's
    histories gain the target tile code, the step name and the old position;
    every other agent records an idle turn (' ', ' ', own position). The
    mover is always the first entry of the new state. States rejected by
    `legal` are dropped.
    """
    board = addAgents(M, agents)
    successors = []
    # every agent has a chance
    for mover, (t, x, y, h, v, r) in enumerate(agents):
        for name, (dx, dy) in steps.items():
            nx, ny = x + dx, y + dy
            # is the move available?
            if board[ny][nx] != ' ':
                continue
            candidate = [(t, nx, ny, h + [T[ny][nx]], v + [name], r + [(x, y)])]
            candidate.extend(
                (ot, ox, oy, oh + [' '], ov + [' '], orr + [(ox, oy)])
                for other, (ot, ox, oy, oh, ov, orr) in enumerate(agents)
                if other != mover
            )
            if legal(M, T, candidate):
                assert(len(candidate) == len(agents))
                successors.append(candidate)
    return successors
           
                    
######################
#
#  Playback
#
######################

def playback(M,S):
    """Print the map once per recorded step of state S, then the final state.

    The number of frames is the length of the first agent's coordinate
    history; frame i draws every agent at entry i of its own history.
    """
    frame_count = len(S[0][-1])
    for frame in range(frame_count):
        p(render(addAgents(deepcopy(M), S, frame)))
    p(render(addAgents(deepcopy(M), S)))


######################
#
#  Solver
#
######################

def solve(fn, facit, pb=False, debug=False):
    """Search for the cheapest way to bring every pawn home.

    Parameters:
        fn: path of the puzzle input file.
        facit: expected answer to compare the result with, or None to skip
            the comparison.
        pb: when True, replay the found solution step by step.
        debug: when True, print the maps and periodic progress information.

    Returns:
        The movement cost of the first goal state taken from the queue, or
        None when the queue runs empty without reaching a goal.
    """
    # Local import: the file's import section only pulls in `copy`.
    import heapq

    print("starting search on ", fn)

    # serialized states that have already been expanded
    blacklist = set()

    with open(fn) as source:
        M, T, start = parse(source.read())

    costmaps = make_costmaps(M,T)

    p(render(addAgents(M, start)))

    if debug:
        print("empty map")
        p(render(M))

        print("type map")
        p(render(T))

        for i, costmap in enumerate(costmaps or []):
            print("costmap for %c" % (P[i]) )
            p(render(costmap))

    # Priority queue of (estimated cost, insertion number, state). The
    # insertion number makes states of equal cost come out in the order they
    # were queued and keeps heapq from ever comparing two states.
    queued = 0
    e = [(cost(start,costmaps), queued, start)]
    i = 0
    while e:

        c, _, s = heapq.heappop(e)
        if debug and i % 10000 == 0:
            print("%d - cost:%d e:%d bl:%d"%(i, cost(s, costmaps), len(e), len(blacklist)))
            p(render(addAgents(M,s)))

        if isgoal(T,s):
            # Report the real movement cost only. The costmap term steers
            # the search but is not zero at the goal: a pawn on an upper
            # home tile is still one step away from the flood origin.
            c = cost(s)
            print("!!!GOAL!!! (fn:%s cost:%d, steps:%d)"%(fn, c, i))
            if pb:
                playback(M,s)
                print("!!!GOAL!!! (fn:%s cost:%d, steps:%d)"%(fn, c, i))
            if not facit is None:
                if c == facit:
                    print("Which is correct!!!")
                else:
                    print("It seems incorrect")

            print("-"*30)
            return c
        i += 1
        # serialize() also sorts s in place, which fixes the order in which
        # nextstates() visits the agents
        blacklist.add(serialize(s))

        for n in nextstates(M,T,s):
            if serialize(n) not in blacklist:
                queued += 1
                heapq.heappush(e, (cost(n, costmaps), queued, n))

    print(fn, "no result")
    print("-"*30)
    return None
    
# Run the solver on every input, comparing with the expected answer (facit).
#
# The expected value for i23_test1.txt is 112, the cheapest solution of that
# map under the puzzle rules (A 2, B 20, A 3, B 30, B 50, A 3, A 4). The
# former value 123 was 112 plus a heuristic remainder of 11 that the solver
# used to add to its reported result.
# The expected value 1 for i23.txt is a placeholder for the real puzzle input.
solve("i23_test0.txt",    46, False, False)
solve("i23_test1.txt",   112, False, False)
solve("i23_test2.txt",   608, False, False)
solve("i23_test3.txt",   517, False, False)
solve("i23_test4.txt",  8470, False, False)
solve("i23_test5.txt", 12521, False,  False)
solve("i23.txt",           1, False,  False)

print("fin")


starting search on  i23_test0.txt
#######
#     #
##B#A##
 #####

!!!GOAL!!! (fn:i23_test0.txt cost:46, steps:27)
Which is correct!!!
------------------------------
starting search on  i23_test1.txt
#########
#       #
###A#B###
  #B#A#
  #####

!!!GOAL!!! (fn:i23_test1.txt cost:123, steps:664)
Which is correct!!!
------------------------------
starting search on  i23_test2.txt
###########
#         #
###C#B#A###
  #A#B#C#
  #######

!!!GOAL!!! (fn:i23_test2.txt cost:719, steps:51)
It seems incorrect
------------------------------
starting search on  i23_test3.txt
###########
#         #
###B#C#B###
  #A#A#C#
  #######

!!!GOAL!!! (fn:i23_test3.txt cost:628, steps:465)
It seems incorrect
------------------------------
starting search on  i23_test4.txt
#############
#           #
###D#C#B#A###
  #A#B#C#D#
  #########

!!!GOAL!!! (fn:i23_test4.txt cost:9581, steps:1521)
It seems incorrect
------------------------------
starting search on  i23_test5.txt
#############
#           #
###B#C#