In [28]:
from typing import List, Dict, Optional, Set, Tuple
from collections import deque, Counter
from sortedcontainers import SortedList
import random
import math

# support for maintaining a list in sorted order
# without having to sort the list after each insertion
import bisect
from heapq import (
    heappop,
    heappush,
    heapify,
    heapreplace,
    nlargest,
    heappushpop,
)


class Solution:
    def countSubIslands(
        self, grid1: List[List[int]], grid2: List[List[int]]
    ) -> int:
        m = len(grid1)
        n = len(grid1[0])
        island_grid = [[0] * n for _ in range(m)]

        def get_forward_adjacent_islands(row: int, col: int):
            if row == col == 0:
                return []
            if row == 0:
                if grid2[row][col - 1] == 1:
                    return [island_grid[row][col - 1]]
                return []
            if col == 0:
                if grid2[row - 1][col] == 1:
                    return [island_grid[row - 1][col]]
                return []
            return [
                island_grid[i][j]
                for i, j in [(row, col - 1), (row - 1, col)]
                if grid2[i][j] == 1
            ]

        num_islands = 0
        parent_dict = {}
        is_sub_island_dict = {}

        def get_parent(elem: int):
            if parent_dict[elem] == elem:
                return elem
            parent_dict[elem] = get_parent(parent_dict[elem])
            return parent_dict[elem]

        def join_islands(island1: int, island2: int, is_grid1_island: bool):
            parent1 = get_parent(island1)
            parent2 = get_parent(island2)

            if parent1 < parent2:
                parent1, parent2 = parent2, parent1

            parent_dict[parent1] = parent2
            if not is_grid1_island or not is_sub_island_dict[parent1]:
                is_sub_island_dict[parent2] = False

        for row in range(m):
            for col in range(n):
                if grid2[row][col] == 1:
                    adjacent_islands = get_forward_adjacent_islands(row, col)
                    if len(adjacent_islands) == 0:
                        num_islands += 1
                        island_grid[row][col] = num_islands
                        parent_dict[num_islands] = num_islands
                        is_sub_island_dict[num_islands] = grid1[row][col] == 1
                    elif len(adjacent_islands) == 1:
                        island_grid[row][col] = adjacent_islands[0]
                        if grid1[row][col] == 0:
                            is_sub_island_dict[
                                get_parent(adjacent_islands[0])
                            ] = False
                    else:
                        island_grid[row][col] = min(adjacent_islands)
                        join_islands(
                            adjacent_islands[0],
                            adjacent_islands[1],
                            grid1[row][col] == 1,
                        )

        res = 0
        print(parent_dict)
        print(is_sub_island_dict)
        for island in range(1, num_islands + 1):
            parent = get_parent(island)
            if island == parent and is_sub_island_dict[parent]:
                res += 1
        print(parent_dict)
        print(is_sub_island_dict)
        return res

In [29]:
grid1 = [
    [1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1],
    [1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1],
    [1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0],
    [1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1],
    [0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0],
    [0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1],
    [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
    [0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
    [1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1],
    [1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1],
    [0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1],
]
grid2 = [
    [1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1],
    [1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1],
    [1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1],
    [1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1],
    [0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0],
    [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0],
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1],
    [1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1],
    [0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1],
    [1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1],
    [1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0],
    [0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0],
]
Solution().countSubIslands(grid1, grid2)

{1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 3, 8: 8, 9: 5, 10: 3, 11: 3, 12: 3, 13: 12, 14: 3, 15: 3, 16: 3, 17: 3, 18: 18, 19: 3, 20: 18, 21: 3, 22: 3, 23: 3, 24: 3, 25: 25}
{1: True, 2: True, 3: False, 4: True, 5: False, 6: True, 7: False, 8: False, 9: True, 10: True, 11: False, 12: False, 13: False, 14: False, 15: False, 16: False, 17: True, 18: True, 19: True, 20: True, 21: False, 22: True, 23: True, 24: False, 25: True}
{1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 3, 8: 8, 9: 5, 10: 3, 11: 3, 12: 3, 13: 3, 14: 3, 15: 3, 16: 3, 17: 3, 18: 18, 19: 3, 20: 18, 21: 3, 22: 3, 23: 3, 24: 3, 25: 25}
{1: True, 2: True, 3: False, 4: True, 5: False, 6: True, 7: False, 8: False, 9: True, 10: True, 11: False, 12: False, 13: False, 14: False, 15: False, 16: False, 17: True, 18: True, 19: True, 20: True, 21: False, 22: True, 23: True, 24: False, 25: True}


6