# In [1]:
from collections import deque
from pprint import pprint


class Solution:
    """LCP 13 "treasure hunt": BFS distance tables plus a bitmask DP over machines."""

    def minimalSteps(self, maze) -> int:
        """Return the fewest steps to trigger every 'M' and then reach 'T'.

        Grid legend: 'S' start, 'T' target, 'M' machine, 'O' stone pile,
        '#' wall; every other character is walkable. A machine may only be
        triggered while carrying a stone, so every leg that ends at a machine
        passes through some 'O'. The start is treated as machine 0, which
        turns the problem into a shortest Hamiltonian path over machines.

        Args:
            maze: rows of equal length describing the grid.

        Returns:
            The minimal number of steps, or -1 if no valid route exists.

        Raises:
            ValueError: if the grid contains no 'S' or no 'T'.
        """
        INF = float('inf')
        M, N = len(maze), len(maze[0])

        def neighbors(i, j):
            """Yield the in-bounds 4-neighbours of (i, j)."""
            for ni, nj in ((i + 1, j), (i - 1, j), (i, j + 1), (i, j - 1)):
                if 0 <= ni < M and 0 <= nj < N:
                    yield ni, nj

        def calc_dist(i, j):
            """BFS from (i, j); return an M x N grid of step counts (INF = unreachable)."""
            dist = [[INF] * N for _ in range(M)]
            dist[i][j] = 0
            q = deque([(i, j)])
            while q:
                x, y = q.popleft()
                nd = dist[x][y] + 1
                for ni, nj in neighbors(x, y):
                    # BFS pops cells in non-decreasing distance, so the first
                    # assignment is final and `<` also acts as a visited check.
                    if maze[ni][nj] != '#' and nd < dist[ni][nj]:
                        dist[ni][nj] = nd
                        q.append((ni, nj))
            return dist

        start = end = None
        stones, machines = [], []
        for i, line in enumerate(maze):
            for j, ch in enumerate(line):
                if ch == 'S':
                    start = (i, j)
                elif ch == 'T':
                    end = (i, j)
                elif ch == 'O':
                    stones.append((i, j))
                elif ch == 'M':
                    machines.append((i, j))
        if start is None or end is None:
            raise ValueError("maze must contain an 'S' and a 'T'")

        # Regard the start as machine 0 so every leg has the same shape.
        machines = [start] + machines
        dists = [calc_dist(i, j) for i, j in machines]
        end_dist = calc_dist(*end)
        K = len(machines)

        # edges[a][b]: fewest steps machine a -> some stone -> machine b.
        edges = [[INF] * K for _ in range(K)]
        for a in range(K):
            da = dists[a]
            for b in range(a + 1, K):
                db = dists[b]
                edges[a][b] = edges[b][a] = min(
                    (da[si][sj] + db[si][sj] for si, sj in stones), default=INF)

        # dp[cur][s]: fewest steps to have triggered exactly the machines in
        # bitmask s, currently standing on machine cur.
        full = (1 << K) - 1
        dp = [[INF] * (full + 1) for _ in range(K)]
        dp[0][1] = 0
        # Every reachable mask contains bit 0 (the start), so only odd masks
        # can hold finite values and machine 0 is never a transition target.
        for s in range(1, full + 1, 2):
            for cur in range(K):
                base = dp[cur][s]
                if base == INF:
                    continue
                edge_row = edges[cur]
                for nxt in range(1, K):
                    if s >> nxt & 1:
                        continue
                    t = s | 1 << nxt
                    cand = base + edge_row[nxt]
                    if cand < dp[nxt][t]:
                        dp[nxt][t] = cand

        # Finish by walking from whichever machine was triggered last to 'T'.
        ans = min(end_dist[i][j] + dp[cur][full]
                  for cur, (i, j) in enumerate(machines))
        return -1 if ans == INF else ans

# In [2]:
from typing import List
from collections import deque
from functools import lru_cache


class Solution:
    """LCP 13 "treasure hunt": per-source BFS tables plus a memoized bitmask search."""

    def minimalSteps(self, maze: List[str]) -> int:
        """Return the fewest steps to trigger every 'M' and then reach 'T'.

        Grid legend: 'S' start, 'T' target, 'M' machine, 'O' stone pile,
        '#' wall; everything else is walkable. A machine can only be
        triggered while carrying a stone, so the route has the shape

            S O1 M1 O2 M2 ... Ox Mx T

        where S and T are fixed, each Oi may be any stone pile, and the Mi
        are a permutation of all x machines (no repeats). With at most 16
        machines the set of already-triggered machines fits in a bitmask.

        Args:
            maze: rows of equal length describing the grid.

        Returns:
            The minimal number of steps, or -1 if no valid route exists
            (including when the grid has no 'S' or no 'T').
        """
        s_i = s_j = t_i = t_j = -1
        m_pos = []  # machine ('M') coordinates
        o_pos = []  # stone pile ('O') coordinates
        for i, row in enumerate(maze):
            for j, ch in enumerate(row):
                if ch == 'M':
                    m_pos.append((i, j))
                elif ch == 'O':
                    o_pos.append((i, j))
                elif ch == 'S':
                    s_i, s_j = i, j
                elif ch == 'T':
                    t_i, t_j = i, j

        if s_i == -1 or t_i == -1:
            return -1

        # One BFS per source gives distances to every cell at once, instead of
        # re-running a BFS for every (source, target) pair.
        s_grid = self._bfs_grid(maze, s_i, s_j)
        x = len(m_pos)
        if x == 0:
            return s_grid[t_i][t_j]

        t_grid = self._bfs_grid(maze, t_i, t_j)
        m_grids = [self._bfs_grid(maze, i, j) for i, j in m_pos]

        def via_stone(grid_a, grid_b):
            """Fewest steps a -> some stone -> b, or -1 if no stone links them."""
            best = -1
            for o_i, o_j in o_pos:
                da, db = grid_a[o_i][o_j], grid_b[o_i][o_j]
                if da != -1 and db != -1 and (best == -1 or da + db < best):
                    best = da + db
            return best

        # legs[a][b]: cost of the leg "node a -> stone -> machine b", -1 if
        # impossible. Nodes 0..x-1 are the machines; node x is the start.
        legs = [[via_stone(m_grids[a], m_grids[b]) if a != b else -1
                 for b in range(x)]
                for a in range(x)]
        legs.append([via_stone(s_grid, g) for g in m_grids])
        # Final leg: from the last machine straight to the target (no stone).
        t2m_dist = [t_grid[i][j] for i, j in m_pos]
        full = (1 << x) - 1

        @lru_cache(maxsize=None)
        def dp(stat, cur_idx):
            """Min remaining steps from node cur_idx once the machines in bitmask stat are done.

            Returns -1 if the remaining machines and the target cannot all be reached.
            """
            if stat == full:
                return t2m_dist[cur_idx]
            best = -1
            for m_idx in range(x):
                if stat >> m_idx & 1:
                    continue
                leg = legs[cur_idx][m_idx]
                # Skip impossible legs, and legs that alone already cost at
                # least the best total found (remaining cost is >= 0).
                if leg == -1 or (best != -1 and leg >= best):
                    continue
                rest = dp(stat | 1 << m_idx, m_idx)
                if rest != -1 and (best == -1 or leg + rest < best):
                    best = leg + rest
            return best

        return dp(0, x)

    @staticmethod
    def _bfs_grid(maze, start_i, start_j):
        """Return a grid of BFS step counts from (start_i, start_j); -1 marks unreachable cells.

        Walls are '#'; moves are 4-directional and stay inside the grid.
        """
        m, n = len(maze), len(maze[0])
        dist = [[-1] * n for _ in range(m)]
        dist[start_i][start_j] = 0
        que = deque([(start_i, start_j)])
        while que:
            i, j = que.popleft()
            nd = dist[i][j] + 1
            for ii, jj in ((i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)):
                if (0 <= ii < m and 0 <= jj < n
                        and dist[ii][jj] == -1 and maze[ii][jj] != '#'):
                    dist[ii][jj] = nd
                    que.append((ii, jj))
        return dist

    def min_dist(self, maze, start_i, start_j, end_i, end_j):
        """Return the BFS step count from (start_i, start_j) to (end_i, end_j), or -1.

        Walls are '#'; moves are 4-directional and stay inside the grid.
        The start cell itself is not bounds- or wall-checked.
        """
        m, n = len(maze), len(maze[0])
        que = deque([(0, start_i, start_j)])
        seen = {(start_i, start_j)}

        while que:
            cost, i, j = que.popleft()
            if i == end_i and j == end_j:
                return cost

            for ii, jj in ((i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)):
                # Unit edge weights: the first time BFS reaches a cell is optimal.
                if (0 <= ii < m and 0 <= jj < n and maze[ii][jj] != '#'
                        and (ii, jj) not in seen):
                    seen.add((ii, jj))
                    que.append((cost + 1, ii, jj))

        return -1