In [1026]:
import copy
from IPython.display import clear_output
import time
from collections import defaultdict
import numpy as np

In [1044]:
class Amp():
    def __init__(self, group, col, row):
        self.group = group
        self.col = col        
        self.row = row
        self.moves = 0      
    
    def position(self):
        return (self.col, self.row)
    
    def moveTo(self, position):
        distanceMoved = abs(self.col - position[0]) + self.row + position[1]
        self.col = position[0]
        self.row = position[1]
        self.moves += 1
        return 10**(ord(self.group)-65) * distanceMoved
    
    def moveEnergy(self, position):
        distanceMoved = abs(self.col - position[0]) + self.row + position[1]
        return 10**(ord(self.group)-65) * distanceMoved    
    
    def homeCol(self):
        return (ord(self.group)-64)*2    
    
    def __eq__(self, other):
        return self.position() == other.position() and self.group == other.group and self.moves == other.moves
    
    def __hash__(self):
        return self.position().__hash__() + self.group.__hash__()
    
    def __repr__(self):
        return self.group + str(self.position()) + str(self.moves)
    
def printAmps(amps):
    for row in range(5):    
        for col in range(11):
            ampAtPos = [a for a in amps if (col, row) == a.position()]
            if len(ampAtPos) > 0:
                print(ampAtPos[0].group, end="")
            else:
                print(" ", end="")
        print("")    

In [1053]:
testAmpsFirst = [Amp(group, int(i/2)*2+2, i%2+1) for i, group in enumerate('BACDBCDA')]
testAmpsSecond = [Amp(group, int(i/4)*2+2, i%4+1) for i, group in enumerate('BDDACCBDBBACDACA')]

originalAmpsFirst = [Amp(group, int(i/2)*2+2, i%2+1) for i, group in enumerate('DCDCABAB')]
originalAmpsSecond = [Amp(group, int(i/4)*2+2, i%4+1) for i, group in enumerate('DDDCDCBCABABAACB')]

originalAmps = tuple(originalAmpsSecond)

In [1054]:
solutions = []
queue = defaultdict(lambda: ([], np.inf))
queue[originalAmps] = ([], 0)
i = 0
maxRow = max([amp.row for amp in originalAmps])
while len(queue) > 0:
    i += 1             
        
    amps = next(iter(queue))
    (moves, energy) = queue.pop(amps)
    qLen = len(queue)              
    
    for ampIndex, amp in enumerate(amps):
        if amp.moves >= 2:
            continue
                  
        possiblePos = []    
        for step, limit in [(1,10), (-1, 0)]:            
            if len(possiblePos) > 0 and possiblePos[0][1] > 0: #already can go to home bay
                continue
            
            col = amp.col
            while col != limit:
                col += step        
                if (col, 0) in [otherA.position() for otherA in amps]:
                    break #There is an amphipod at this spot
                if len([otherA for otherA in amps if otherA.col == amp.col and otherA.row < amp.row]) > 0:            
                    break #There is an amphipod in the way out of the hole

                if col == amp.homeCol(): #move to home bay
                    sameGroupAmpsInCol = [otherAmp.homeCol() == amp.homeCol() for otherAmp in amps if otherAmp.col == amp.homeCol()]              
                    if all(sameGroupAmpsInCol):       
                        #if it can move to home bay, don't try anything else
                        possiblePos = [(col, maxRow - len(sameGroupAmpsInCol))]
                        break
                elif col not in [2,4,6,8] and amp.moves == 0: #move to row 0
                    possiblePos.append((col, 0))  
                
        for pos in possiblePos:
            ampsCpy = copy.deepcopy(amps)
            move = [ampsCpy[ampIndex].position(), pos]
            moveEnergy = ampsCpy[ampIndex].moveTo(pos)
            (_, otherEnergy) = queue[ampsCpy]
            if energy+moveEnergy < otherEnergy:
                queue[ampsCpy] = (moves + [move], energy+moveEnergy)
            
    if i%100 == 0:
        clear_output(wait=True)             
        print("Moves: "+ str(len(moves)))
        print(len(queue))
        print(len(solutions))   
        printAmps(ampsCpy)                
    
    if qLen == len(queue):    
        if all([amp.homeCol() == amp.col for amp in amps]):
            solutions.append((moves, energy))

Moves: 30
3
4298
           
  A B C D  
  A B C D  
  A B C D  
  A B C D  


In [1055]:
min([energy for (moves, energy) in solutions])

43413