In [365]:
import xml.etree.ElementTree as ET
import numpy as np
# import torch
import copy
from typing import List

In [366]:
file_path = "data/sample_haifu.xml"

In [367]:
class GameInfo:
    def __init__(self, go: int =161):
        self.go = go

    def against_human(self):
        return self.go & 1
    
    def no_red(self):
        return (self.go & 0x2) >> 1
    
    def kansaki(self):
        return (self.go & 0x4) >> 2
    
    def tonnan(self):
        return (self.go & 0x8) >> 3
    
    def three_players(self):
        return (self.go & 0x10) >> 4
    
    def fast(self):
        return (self.go & 0x40) >> 6
    
    def level(self):
        return (self.go & 0xA0) >> 5

    def __str__(self):
        return f"GameInfo: against_human={self.against_human()}, no_red={self.no_red()}, kansaki={self.kansaki()}, tonnan={self.tonnan()}, three_players={self.three_players()}, fast={self.fast()}, level={self.level()}"

In [368]:
class RoundInfo:
    def __init__(self, gameinfo: GameInfo, curr_round: int):
        self.gameinfo = gameinfo
        N_ROUNDS = 4 * (gameinfo.tonnan() + 1)
        self.remaining_rounds = N_ROUNDS - curr_round
        self.parent_rounds_remaining = [(self.remaining_rounds + i) // 4 for i in range(4)]

    def __repr__(self):
        # TODO: implement
        pass

In [369]:
with open(file_path, "rb") as file:
    tree = ET.parse(file)
    root = tree.getroot()
    # xml_str = ET.tostring(root, encoding="utf-8", method="xml") # code to convert xml to string
    # root = ET.fromstring(xml_data) # code to convert string to xml


events = []

# Define a recursive function to traverse the XML elements
def traverse(element):
    item = {"event": element.tag, "attr": element.attrib}
    events.append(item)
    for child in element:
        traverse(child)

traverse(root)

gametype = root.findall("GO")
if len(gametype) != 1:
    raise ValueError("Invalid number of game type elements")
try:
    gameinfo = GameInfo(int(gametype[0].attrib["type"]))
except:
    raise ValueError("Invalid game type element")

# print(events)
print(gameinfo)

GameInfo: against_human=1, no_red=0, kansaki=0, tonnan=1, three_players=0, fast=0, level=5


In [370]:
# sanity check
assert gameinfo.against_human() == 1
assert gameinfo.no_red() == 0
assert gameinfo.kansaki() == 0
assert gameinfo.three_players() == 0

# group events into each round
rounds = []
curr_round_events = []

for event in events:
    if event["event"] in ["mjloggm", "SHUFFLE", "UN", "GO", "TAIKYOKU"]:
        continue

    curr_round_events.append(event)

    if event["event"] in ["AGARI", "RYUUKYOKU"]:
        rounds.append(curr_round_events)
        curr_round_events = []

len(rounds), len(rounds[0])

(16, 144)

In [371]:
# constants
REMAINING_TILES = np.ones(37, dtype=np.int32) * 4
indices = [4, 13, 22]
REMAINING_TILES[indices] = 3
REMAINING_TILES[34:37] = 1

print(REMAINING_TILES)
print()

TILE2IDX = np.array([i // 4 for i in range(136)], dtype=np.int32)
TILE2IDX[16] = 34
TILE2IDX[52] = 35
TILE2IDX[88] = 36

print(TILE2IDX)

REMAINING_TSUMO = 70

[4 4 4 4 3 4 4 4 4 4 4 4 4 3 4 4 4 4 4 4 4 4 3 4 4 4 4 4 4 4 4 4 4 4 1 1 1]

[ 0  0  0  0  1  1  1  1  2  2  2  2  3  3  3  3 34  4  4  4  5  5  5  5
  6  6  6  6  7  7  7  7  8  8  8  8  9  9  9  9 10 10 10 10 11 11 11 11
 12 12 12 12 35 13 13 13 14 14 14 14 15 15 15 15 16 16 16 16 17 17 17 17
 18 18 18 18 19 19 19 19 20 20 20 20 21 21 21 21 36 22 22 22 23 23 23 23
 24 24 24 24 25 25 25 25 26 26 26 26 27 27 27 27 28 28 28 28 29 29 29 29
 30 30 30 30 31 31 31 31 32 32 32 32 33 33 33 33]


In [372]:
class Draw:
    pass

In [373]:
class Naki(Draw):
    def __init__(self, naki_code: int):
        self.naki_code = naki_code
    
    def from_who(self):
        return self.naki_code & 3

    def is_chi(self):
        return (self.naki_code & 0x4) >> 2

    def is_pon(self):
        return not self.is_chi() and (self.naki_code & 0x8)

    def is_kakan(self):
        return not self.is_chi() and (self.naki_code & 0x10)
    
    def is_minkan(self):
        return self.naki_code & 0b111100 == 0 and self.from_who()
    
    def is_ankan(self):
        return self.naki_code & 0b111100 == 0 and not self.from_who()

    def pattern_chi(self):
        pattern = (self.naki_code & 0xFC00) >> 10
        which = pattern % 3 
        pattern //= 3
        color = pattern // 7
        number = pattern % 7
        has_red = False
        codes = [(self.naki_code & mask) >> shift for mask, shift in zip([0x0018, 0x0060, 0x0180], [3, 5, 7])]
        exposed = [
            (9 * color + number + i) * 4 + code for i, code in enumerate(codes)
        ]
        has_red = color * 36 + 16 in exposed
        obtained = exposed.pop(which)
        return (color, number, which, has_red, exposed, obtained)
    
    def pattern_pon(self):
        pattern = (self.naki_code & 0xFE00) >> 9
        which = pattern % 3
        pattern //= 3
        color = pattern // 9
        number = pattern % 9
        code = (self.naki_code & 0x0060) >> 5
        exposed = [(9 * color + number) * 4 + c for c in range(4) if c != code]
        has_red = color < 3 and color * 36 + 16 in exposed
        obtained = exposed.pop(which)
        return (color, number, which, has_red, exposed, obtained)
    
    def pattern_kakan(self):
        pattern = (self.naki_code & 0xFE00) >> 9
        which = pattern % 3
        pattern //= 3
        color = pattern // 9
        number = pattern % 9
        has_red = self.number == 5 and color != 3
        exposed = [(9 * color + number) * 4 + c for c in range(4)]
        obtained = exposed.pop(which)
        return (color, number, which, has_red, exposed, obtained)

    def pattern_minkan(self):
        pattern = (self.naki_code & 0xFF00) >> 8
        which = pattern % 4
        pattern //= 4
        color = pattern // 9
        number = pattern % 9
        has_red = self.number == 5 and color != 3
        exposed = [(9 * color + number) * 4 + c for c in range(4)]
        obtained = exposed.pop(which)
        return (color, number, which, has_red, exposed, obtained)
    
    def pattern_ankan(self):
        pattern = (self.naki_code & 0xFF00) >> 8
        which = pattern % 4
        pattern //= 4
        color = pattern // 9
        number = pattern % 9
        has_red = self.number == 5 and color != 3
        exposed = [(9 * color + number) * 4 + c for c in range(4)]
        obtained = None
        return (color, number, which, has_red, exposed, obtained)

    def get_exposed(self):
        exposed, obtained = None, None

        if self.is_chi():
            _, _, _, _, exposed, obtained = self.pattern_chi()
        
        elif self.is_pon():
            _, _, _, _, exposed, obtained = self.pattern_pon()

        elif self.is_kakan():
            _, _, _, _, exposed, obtained = self.pattern_kakan()
        
        elif self.is_minkan():
            _, _, _, _, exposed, obtained = self.pattern_minkan()

        elif self.is_ankan():
            _, _, _, _, exposed, obtained = self.pattern_ankan()
        
        else:
            raise ValueError("Invalid naki code")
        
        return exposed, obtained

In [374]:
class Tsumo(Draw):
    def __init__(self, tile: int):
        self.tile = tile # 0-136

In [375]:
class Discard:
    def __init__(self, tile: int):
        self.tile = tile # 0-136

In [376]:
class StateObject:
    def __init__(
            self,
            remaining_turns: int,
            hand_tensor: List[int], 
            remaining_tiles: np.ndarray,
            remaining_tiles_pov: np.ndarray,
            reaches: List[int],
            melds: List[List[Naki]],
            scores: List[int],
            kyotaku: int,
            honba: int,
            dora: List[int],
    ):
        self.remaining_turns = remaining_turns
        self.hand_tensor = hand_tensor
        self.remaining_tiles = remaining_tiles
        self.remaining_tiles_pov = remaining_tiles_pov
        self.reaches = reaches
        self.melds = melds
        self.scores = scores
        self.kyotaku = kyotaku
        self.honba = honba
        self.dora = dora

In [377]:
class Turn:
    TSUMO = 0
    NAKI = 1

    def __init__(self, player: int, type_: int, draw: Draw, stateObj: StateObject = None, discard: Discard = None):
        self.player = player
        self.type = type_
        self.draw = draw
        self.stateObj = stateObj
        self.discard = discard

    def is_tsumogiri(self):
        if self.draw is None or self.discard is None:
            raise ValueError("draw and discard must be set")


class TsumoTurn(Turn):
    def __init__(self, player: int, draw: Draw = None, stateObj: StateObject = None, discard: Discard = None):
        super().__init__(player, Turn.TSUMO, draw, stateObj, discard)
        self.pre_decisions = []
        self.post_decisions = []

    def is_tsumogiri(self):
        super().is_tsumogiri()
        return self.draw.tile == self.discard.tile


class NakiTurn(Turn):
    def __init__(self, player: int, naki: Naki, stateObj: StateObject = None, discard: Discard = None):
        super().__init__(player, Turn.NAKI, naki, stateObj, discard)
        self.post_decisions = []

    def is_tsumogiri(self):
        super().is_tsumogiri()
        return False

In [378]:
class Decision:
    NAKI = 1
    REACH = 2
    AGARI = 3

    def __init__(self, player: int):
        self.player = player

class NakiDecision(Decision):
    def __init__(self, player: int, naki: Naki, executed: bool):
        super().__init__(player)
        self.naki = naki
        self.executed = executed
    
class ReachDecision(Decision):
    def __init__(self, player: int, executed: bool):
        super().__init__(player)
        self.executed = executed

class AgariDecision(Decision):
    def __init__(self, player: int, executed: bool):
        super().__init__(player)
        self.executed = executed

class PassDecision(Decision):
    def __init__(self, player: int, executed: bool):
        super().__init__(player)
        self.executed = executed

In [379]:
class RoundResult:
    def __init__(self, sc: List[int]):
        self.sc = sc
        # TODO: define reward based on sc

class AgariResult(RoundResult):
    def __init__(self, *args):
        super().__init__(*args)

    def __repr__(self) -> str:
        return f"AgariResult: sc={self.sc}"

class RyukyokuResult(RoundResult):
    def __init__(self, *args):
        super().__init__(*args)

    def __repr__(self) -> str:
        return f"RyukyokuResult: sc={self.sc}"

In [387]:
# constants
N_ROUNDS = 4 * (gameinfo.tonnan() + 1)

kyoku_info = rounds[3]

# sanity check
assert kyoku_info[0]["event"] == "INIT"

remaining_tiles = REMAINING_TILES.copy()
remaining_tiles_pov = [REMAINING_TILES.copy() for _ in range(4)]
remaining_tsumo = REMAINING_TSUMO
scores = kyoku_info[0]["attr"]["ten"].split(",")
scores = [int(score) for score in scores]
parent = kyoku_info[0]["attr"]["oya"]

hand_indices = [list(map(int, kyoku_info[0]["attr"][f'hai{player}'].split(","))) for player in range(4)]
hand_indices_grouped = TILE2IDX[hand_indices]

hand_tensors = [np.zeros(37, dtype=np.float32) for _ in range(4)]
unique_indices_counts = [np.unique(row, return_counts=True) for row in hand_indices_grouped]
for indices, (idx, count), ht, pov in zip(hand_indices, unique_indices_counts, hand_tensors, remaining_tiles_pov):
    ht[idx] += count
    pov[idx] -= count
    remaining_tiles[idx] -= count
assert list(map(np.sum, hand_tensors)) == [13., 13., 13., 13.]

hands = np.zeros((4, 136), dtype=np.float32)
player_indices = np.arange(4)[:, np.newaxis]
hands[player_indices, hand_indices] = 1.0
assert np.allclose(hands.sum(axis=1), [13., 13., 13., 13.])
assert np.max(hands) == 1.0

curr_round, honba, kyotaku, _, _, dora = list(map(int, kyoku_info[0]["attr"]["seed"].split(",")))
doras = [dora]
remaining_tiles[TILE2IDX[int(dora)]] -= 1
for pov in remaining_tiles_pov:
    pov[TILE2IDX[int(dora)]] -= 1

remaining_rounds = N_ROUNDS - int(curr_round)
parent_rounds_remaining = [(remaining_rounds + i) // 4 for i in range(4)]

roundinfo = RoundInfo(gameinfo, curr_round)

print(f"Scores: {scores}")
print(f"Parent: {parent}")
# print(f"Hands: {hands}")
print(f"Current round: {curr_round}")
print(f"Honba: {honba}")
print(f"Kyotaku: {kyotaku}")
print(f"Dora: {doras}")
print(f"Remaining rounds: {remaining_rounds}")
print(f"Parent rounds remaining: {parent_rounds_remaining}")
print(f"remaining_tiles: {remaining_tiles}")
print("\n" + "=" * 20 + "\n")

assert sum(remaining_tiles) == 83

reaches = []
melds = [[] for _ in range(4)]
turns = []
curr_turn = None

for event in kyoku_info[1:-1]:
    eventtype = event["event"]

    if eventtype == "DORA":
        doras = doras[:] + [int(event["attr"]["hai"])]
        remaining_tiles[TILE2IDX[int(event["attr"]["hai"])]] -= 1
        for pov in remaining_tiles_pov:
            pov[TILE2IDX[int(dora)]] -= 1
        remaining_tsumo -= 1

    elif eventtype == "REACH":
        player = int(event["attr"]["who"])

        if event["attr"]["step"] == "1":
            assert curr_turn is not None
            curr_turn.pre_decisions = [ReachDecision(player, executed=True), PassDecision(player, executed=False)]
        else: # step == 2
            reaches = reaches[:] + [player]
            kyotaku += 1
            scores = scores.copy()
            scores[player] -= 1000

    elif eventtype == "N":
        naki = Naki(int(event["attr"]["m"]))
        player = int(event["attr"]["who"])
        melds[player] = melds[player][:] + [naki]

        if not naki.is_chi() and not naki.is_pon():
            remaining_tsumo -= 1
        exposed, obtained = naki.get_exposed()
        exposed_idx = TILE2IDX[exposed]
        for i, pov in enumerate(remaining_tiles_pov):
            if i == player:
                continue
            for e in exposed_idx:
                pov[e] -= 1
        print(exposed, obtained, player, [idx for idx, r in enumerate(hands[player]) if r == 1.0], naki.naki_code)
        for e in exposed:
            assert hands[player, e] == 1.0
            hands[player, e] -= 1.0
        stateObj = StateObject(
            remaining_turns=remaining_tsumo,
            hand_tensor=hand_tensors[player],
            remaining_tiles=remaining_tiles,
            remaining_tiles_pov=remaining_tiles_pov[player],
            reaches=reaches, # share same reach list
            melds=melds,
            scores=scores,
            kyotaku=kyotaku,
            honba=honba,
            dora=doras, # share same reach list
        )
        curr_turn = NakiTurn(player=player, naki=naki, stateObj=stateObj)

    elif eventtype[0] in ["T", "U", "V", "W"]:
        player = ["T", "U", "V", "W"].index(eventtype[0])
        tile = int(eventtype[1:])
        tile_idx = TILE2IDX[tile]
        remaining_tiles[tile_idx] -= 1
        remaining_tiles_pov[player][tile_idx] -= 1
        remaining_tsumo -= 1
        assert hands[player, tile] == 0.0
        hands[player, tile] = 1.0
        stateObj = StateObject(
            remaining_turns=remaining_tsumo,
            hand_tensor=hand_tensors[player],
            remaining_tiles=remaining_tiles,
            remaining_tiles_pov=remaining_tiles_pov[player],
            reaches=reaches, # share same reach list
            melds=melds,
            scores=scores,
            kyotaku=kyotaku,
            honba=honba,
            dora=doras, # share same reach list
        )
        curr_turn = TsumoTurn(player=player, draw=Tsumo(tile), stateObj=stateObj)

    elif eventtype[0] in ["D", "E", "F", "G"]:
        player = ["D", "E", "F", "G"].index(eventtype[0])
        assert curr_turn.player == player
        tile = int(eventtype[1:])
        tile_idx = TILE2IDX[tile]
        assert hands[player, tile] == 1.0
        hands[player, tile] = 0.0
        for pov in remaining_tiles_pov:
            pov[tile_idx] -= 1
        remaining_tiles_pov[player][tile_idx] += 1
        curr_turn.discard = Discard(tile)
        turns.append(curr_turn)
        curr_turn = None

    else:
        raise ValueError(f"Invalid event type: {eventtype}")
    
event = kyoku_info[-1]
assert event["event"] in ["AGARI", "RYUUKYOKU"]

if event["event"] == "AGARI":
    result = AgariResult(list(map(int, event["attr"]["sc"].split(","))))

elif event["event"] == "RYUUKYOKU":
    result = RyukyokuResult(list(map(int, event["attr"]["sc"].split(","))))

print(f"len(turns): {len(turns)}")
print(result)
print(f"remaining_tiles: {remaining_tiles}")
print(f"remaining_tiles_pov: {remaining_tiles_pov}")
assert sum([1 for r in remaining_tiles if r < 0]) == 0
assert sum([1 for rem in remaining_tiles_pov for r in rem if r < 0]) == 0

Scores: [255, 172, 333, 240]
Parent: 1
Current round: 1
Honba: 1
Kyotaku: 0
Dora: [66]
Remaining rounds: 7
Parent rounds remaining: [1, 2, 2, 2]
remaining_tiles: [1 2 3 3 3 2 1 2 3 3 4 3 3 1 3 3 1 3 2 3 4 3 1 2 1 1 3 3 3 2 3 1 3 3 0 0 1]


[89, 99] 92 3 [8, 9, 15, 17, 24, 26, 48, 52, 78, 81, 89, 99, 125] 56719
len(turns): 71
RyukyokuResult: sc=[255, -10, 162, 30, 333, -10, 240, -10]
remaining_tiles: [1 1 0 0 1 1 0 0 1 1 0 1 0 0 0 0 0 0 0 0 1 2 0 0 0 0 0 1 0 1 0 0 1 0 0 0 0]
remaining_tiles_pov: [array([2, 2, 3, 2, 1, 2, 3, 0, 1, 2, 3, 3, 0, 3, 2, 1, 0, 2, 1, 4, 1, 3,
       0, 0, 1, 0, 2, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0], dtype=int32), array([2, 1, 3, 3, 1, 3, 3, 0, 1, 1, 2, 3, 1, 2, 1, 0, 0, 1, 1, 4, 1, 3,
       1, 1, 3, 0, 3, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0], dtype=int32), array([1, 2, 4, 4, 1, 4, 2, 0, 1, 2, 1, 3, 1, 1, 1, 1, 0, 3, 1, 2, 1, 2,
       1, 1, 2, 0, 2, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0], dtype=int32), array([2, 2, 2, 3, 1, 4, 1, 0, 1, 2, 3, 4, 1, 3, 2, 1, 0, 3, 0, 2, 1, 3,
      

In [385]:
hands.sum(axis=1)

array([10.,  7., 13., 13.], dtype=float32)