In [1]:
import numpy as np
import matplotlib.pyplot as plt

import numpy as np
from functools import reduce
from typing import List, Tuple
from collections import namedtuple

MAXSIDE = 100
MAXAREA = 40 * 40
MAXPIXELS = 40 * 40 * 5

Point = namedtuple('Point', ['x', 'y'])

class Image:
    def __init__(self, x=0, y=0, w=0, h=0, mask=None):
        self.x = x
        self.y = y
        self.w = w
        self.h = h
        if mask is None:
            self.mask = np.zeros((h, w), dtype=np.int8)
        else:
            self.mask = np.array(mask, dtype=np.int8).reshape(h, w)

    def __getitem__(self, idx):
        i, j = idx
        return mask[i, j]

    def __setitem__(self, idx, value):
        i, j = idx
        mask[i, j] = value

    def safe(self, i, j):
        if i < 0 or j < 0 or i >=  h or j >=  w:
            return 0
        return self.mask[i, j]

    def __eq__(self, other):
        return np.array_equal(self.mask, other.mask) and self.x == other.x and self.y == other.y and self.w == other.w and self.h == other.h

    def __ne__(self, other):
        return not self.__eq__(other)

    def __lt__(self, other):
        if (self.w, self.h) != (other.w, other.h):
            return (self.w, self.h) < (other.w, self.h)
        return self.mask.flatten().tolist() < other.mask.flatten().tolist()
    
    def col_mask(self) -> int:
        mask = 0
        for i in range(self.h):
            for j in range(self.w):
                mask |= 1 << self[i, j]
        return mask
    
    def count_cols(self, include0: int = 0) -> int:
        mask = self.col_mask()
        if not include0:
            mask &= ~1
        return bin(mask).count('1')
    
    def count(self) -> int:
        ans = 0
        for i in range(self.h):
            for j in range(self.w):
                ans += self[i, j] > 0
        return ans
    
    @staticmethod
    def full(p: Point, sz: Point, filling: int = 1) -> 'Image':
        return Image(p.x, p.y, sz.x, sz.y, [[filling] * sz.x for _ in range(sz.y)])

    @staticmethod
    def empty(p: Point, sz: Point) -> 'Image':
        return Image.full(p, sz, 0)

    @staticmethod
    def is_rectangle(img: 'Image') -> bool:
        return img.count() == img.w * img.h

    def count_components_dfs(self, r: int, c: int):
        self[r, c] = 0
        for nr in range(r - 1, r + 2):
            for nc in range(c - 1, c + 2):
                if 0 <= nr < self.h and 0 <= nc < self.w and self[nr, nc]:
                    self.count_components_dfs(nr, nc)

    def count_components(self) -> int:
        ans = 0
        for i in range(self.h):
            for j in range(self.w):
                if self[i, j]:
                    self.count_components_dfs(i, j)
                    ans += 1
        return ans

    def majority_col(self, include0: int = 0) -> int:
        cnt = [0] * 10
        for i in range(self.h):
            for j in range(self.w):
                c = self[i, j]
                if 0 <= c < 10:
                    cnt[c] += 1
        if not include0:
            cnt[0] = 0
        ret, ma = max(enumerate(cnt), key=lambda x: x[1])
        return ret

    def sub_image(self, p: Point, sz: Point) -> 'Image':
        assert p.x >= 0 and p.y >= 0 and p.x + sz.x <= self.w and p.y + sz.y <= self.h and sz.x >= 0 and sz.y >= 0
        ret = Image(p.x, p.y, sz.x, sz.y)
        for i in range(sz.y):
            for j in range(sz.x):
                ret[i, j] = self[p.y + i, p.x + j]
        return ret

    def split_cols(self, include0: int = 0) -> List[Tuple['Image', int]]:
        ret = []
        mask = self.col_mask()
        for c in range(1 if not include0 else 0, 10):
            if mask >> c & 1:
                s = Image(self.x, self.y, self.w, self.h, self.mask.copy())
                for i in range(self.h):
                    for j in range(self.w):
                        s[i, j] = int(s[i, j] == c)
                ret.append((s, c))
        return ret

    def hash_image(self):
        base = 137
        r = 1543
        r = (r * base + self.w) % 2**64
        r = (r * base + self.h) % 2**64
        r = (r * base + self.x) % 2**64
        r = (r * base + self.y) % 2**64
        for c in self.mask.flatten():
            r = (r * base + int(c)) % 2**64
        return r

class Piece:
    def __init__(self, imgs=None, node_prob=0.0, keepi=0, knowi=0):
        if imgs is None:
            imgs = []
        self.imgs = imgs
        self.node_prob = node_prob
        self.keepi = keepi
        self.knowi = knowi

def check_all(v, f):
    return all(f(it) for it in v)

def all_equal(v, f):
    needed = f(v[0])
    return all(f(it) == needed for it in v)


def visualize_center(original: Image, centered: Image):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    
    ax1.imshow(original.mask, cmap='viridis')
    ax1.set_title(f"Original ({original.w}x{original.h})")
    ax1.axis('off')
    
    full_size = np.zeros((original.h, original.w))
    y_start, x_start = centered.y - original.y, centered.x - original.x
    full_size[y_start:y_start+centered.h, x_start:x_start+centered.w] = centered.mask
    
    ax2.imshow(full_size, cmap='viridis')
    ax2.set_title(f"Centered ({centered.w}x{centered.h})")
    ax2.axis('off')
    
    plt.tight_layout()
    plt.show()

def center(img: Image) -> Image:
    sz_x = (img.w + 1) % 2 + 1
    sz_y = (img.h + 1) % 2 + 1
    center_x = img.x + (img.w - sz_x) // 2
    center_y = img.y + (img.h - sz_y) // 2
    
    center_img = Image(center_x, center_y, sz_x, sz_y)
    for i in range(sz_y):
        for j in range(sz_x):
            center_img[i, j] = img[center_y - img.y + i, center_x - img.x + j]
    
    return center_img
# Test cases
def test_center():
    test_cases = [
        Image(0, 0, 5, 5, np.random.randint(0, 5, (5, 5))),  # Odd x Odd
        Image(0, 0, 4, 4, np.random.randint(0, 5, (4, 4))),  # Even x Even
        Image(0, 0, 3, 4, np.random.randint(0, 5, (4, 3))),  # Odd x Even
        Image(0, 0, 4, 3, np.random.randint(0, 5, (3, 4))),  # Even x Odd
        Image(0, 0, 1, 1, [[2]]),  # 1x1
        Image(0, 0, 2, 1, [[1, 2]]),  # 2x1
        Image(0, 0, 1, 2, [[1], [2]])  # 1x2
    ]

    for i, img in enumerate(test_cases):
        centered = center(img)
        print(f"\nTest case {i+1}:")
        print(f"Original: {img.w}x{img.h}, Centered: {centered.w}x{centered.h}")
        print(f"Center position: ({centered.x}, {centered.y})")
        visualize_center(img, centered)

        # Assertions
        assert centered.w in (1, 2), f"Centered width should be 1 or 2, got {centered.w}"
        assert centered.h in (1, 2), f"Centered height should be 1 or 2, got {centered.h}"
        assert centered.x == img.x + (img.w - centered.w) // 2, "Incorrect x position"
        assert centered.y == img.y + (img.h - centered.h) // 2, "Incorrect y position"
        assert np.all(centered.mask == img.majority_col()), "Centered image should be filled with majority color"



In [2]:
test_center()

NameError: name 'mask' is not defined