In [423]:
import numpy as np
import pandas as pd
from pathlib import Path

from torch.utils.data import Dataset

from jass.game.const import card_ids
from jass.game.game_util import get_cards_encoded_from_str


class JassRoundDataset(Dataset):
    """Map-style dataset turning recorded Jass observations into (state, reward) pairs.

    Each record is one JSON line with an ``"action"`` (card index) and an
    ``"obs"`` dict (game observation).  ``__getitem__`` flattens a record into
    a single float64 vector laid out as::

        current hand           36
        cards not yet located  36
        tricks (by position)   9 * 4 * 36
        tricks (by player)     9 * 4 * 36
        trick first player     9 * 4
        trick winner           9 * 4
        trump one-hot          6
        forehand flag          1
        action one-hot         36
                               -----
                               2779

    Args:
        data_dir: Directory containing the ``train`` and ``test`` sub-directories.
        train: Read from ``<data_dir>/train`` when true, else ``<data_dir>/test``.
        nrows: Maximum number of JSON lines to load (``None`` loads the whole
            file).  Defaults to 10, the value that used to be hard-coded.

    Raises:
        FileNotFoundError: If the selected sub-directory has no ``*.txt`` file.
    """

    def __init__(self, data_dir: str = "./data", train=True, nrows=10):
        split_dir = Path(data_dir) / ("train" if train else "test")
        # Sort so the chosen file does not depend on the OS directory order.
        data_files = sorted(split_dir.glob("*.txt"))
        if not data_files:
            raise FileNotFoundError(f"no *.txt data files found in {split_dir}")
        # NOTE(review): only the first file is loaded; any other files are ignored.
        self.jass_rounds = pd.read_json(data_files[0], nrows=nrows, lines=True)

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.jass_rounds)

    def __getitem__(self, idx):
        """Encode record ``idx`` as ``(state, reward)`` tensors.

        Returns:
            A tuple ``(state, reward)``: ``state`` is a 1-D float64 tensor with
            the layout described on the class, ``reward`` is a scalar tensor
            holding the point difference (scaled by 1/157) seen from the
            current player's side.
        """
        record = self.jass_rounds.iloc[idx].to_dict()

        # -- action: one-hot over the 36 card slots.  The original indexed the
        # array without assigning, so the action part was always all zeros.
        action = np.zeros(36)
        action[record["action"]] = 1

        obs = record["obs"]
        current_player = obs["currentPlayer"]

        # -- reward: point difference of side "0" minus side "1", scaled by 157
        # (presumably the total points of a round — TODO confirm), and negated
        # for players 1 and 3 so it is expressed from their point of view.
        reward = (obs["points"]["0"] - obs["points"]["1"]) / 157
        if current_player in [1, 3]:
            reward = reward * -1

        # -- state
        current_hand = get_cards_encoded_from_str(
            obs["player"][current_player]["hand"]
        )
        # Start with "every card may be elsewhere", then clear the cards we
        # hold ourselves; cards already played are cleared in the loop below.
        others_hand = np.ones(36)
        others_hand[current_hand == 1] = 0

        trick_first_player = np.zeros((9, 4))
        trick_winner = np.zeros((9, 4))
        tricks = np.zeros((9, 4, 36))        # indexed [trick, position in trick, card]
        played_cards = np.zeros((9, 4, 36))  # indexed [trick, player, card]

        for game_round, trick in enumerate(obs["tricks"]):
            cards = trick.get("cards")
            if not cards:
                # Nothing played in this trick yet: leave all of its slots at zero.
                continue

            first_player = trick["first"]
            trick_first_player[game_round, first_player] = 1

            winner = trick.get("win")
            if winner is not None:
                trick_winner[game_round, winner] = 1

            for trick_index, card_name in enumerate(cards):
                card = card_ids[card_name]
                # Seat of the player who played this card, counted from the
                # trick's first player modulo 4.
                card_player = (first_player + trick_index) % 4

                others_hand[card] = 0
                played_cards[game_round, card_player, card] = 1
                tricks[game_round, trick_index, card] = 1

        trump = np.zeros(6)
        trump[obs["trump"]] = 1

        state_obs = np.concatenate((
            current_hand,
            others_hand,
            tricks.ravel(),
            played_cards.ravel(),
            trick_first_player.ravel(),
            trick_winner.ravel(),
            trump,
            [obs["forehand"]],
            action,
        ))

        return torch.tensor(state_obs), torch.tensor(reward)

In [424]:
import torch
from typing import Optional

import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader


class JassDataModule(pl.LightningDataModule):
    """Lightning data module serving train/validation splits of ``JassRoundDataset``.

    Only the fit/validate stages are implemented; the test and predict stages
    raise ``NotImplementedError``.

    Args:
        data_dir: Directory handed to ``JassRoundDataset``.
    """

    def __init__(self, data_dir: str = "./data"):
        super().__init__()
        self.data_dir = data_dir

    def setup(self, stage: Optional[str] = None):
        """Create the datasets needed for ``stage``.

        Args:
            stage: ``"fit"``, ``"validate"`` or ``None`` build the train and
                validation splits; ``"test"`` and ``"predict"`` are not
                implemented.

        Raises:
            NotImplementedError: If ``stage`` is ``"test"`` or ``"predict"``.
        """
        # Raise only when an unimplemented stage is requested explicitly.
        # Previously ``stage=None`` fell through to these errors as well, so
        # ``setup()`` without arguments could never succeed.
        if stage == "test":
            raise NotImplementedError("test_not_implemented")
        if stage == "predict":
            raise NotImplementedError("predict not implemented")

        # "validate" needs the validation split too, so it shares this branch.
        if stage in ("fit", "validate") or stage is None:
            jass_full = JassRoundDataset(self.data_dir, train=True)
            total = len(jass_full)
            train_len = int(total * 0.8)
            # Give the remainder to validation so the lengths always add up to
            # ``total``; truncating both parts separately can lose a sample
            # and make ``random_split`` fail.
            val_len = total - train_len
            self.jass_train, self.jass_val = random_split(
                jass_full,
                [train_len, val_len],
                generator=torch.Generator().manual_seed(42),  # fixed seed: reproducible split
            )

    def train_dataloader(self):
        """Return the training loader (batch size 32)."""
        return DataLoader(self.jass_train, batch_size=32)

    def val_dataloader(self):
        """Return the validation loader (batch size 32)."""
        return DataLoader(self.jass_val, batch_size=32)

    def test_dataloader(self):
        """Return the test loader; ``setup`` never creates a test dataset.

        Raises:
            NotImplementedError: If no ``jass_test`` dataset has been assigned.
        """
        if not hasattr(self, "jass_test"):
            raise NotImplementedError("test_not_implemented")
        return DataLoader(self.jass_test, batch_size=32)

    def predict_dataloader(self):
        """Return the predict loader; ``setup`` never creates a predict dataset.

        Raises:
            NotImplementedError: If no ``jass_predict`` dataset has been assigned.
        """
        if not hasattr(self, "jass_predict"):
            raise NotImplementedError("predict not implemented")
        return DataLoader(self.jass_predict, batch_size=32)

In [425]:
# Build the data module with its default data directory ("./data").
jass_data = JassDataModule()

In [426]:
# Prepare the train/validation split (the only stage that is implemented).
jass_data.setup("fit")

In [427]:
# Training loader over the train split, batch size 32.
dl = jass_data.train_dataloader()

In [428]:
# Number of batches in the training loader.
len(dl)

1

In [429]:
# Fetch the first batch as a sanity check: a batch of state vectors and their rewards.
next(iter(dl))

[tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 0., 0.],
         [1., 1., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 0., 0.]], dtype=torch.float64),
 tensor([-0.6306,  0.6306,  0.6306, -0.6306,  0.6306, -0.6306, -0.6306,  0.6306])]