In [None]:
import tensorflow as tf

In [None]:
import os
import sys
sys.path.insert(0, "/Users/xinran.he/GitProjects/mahjong")

from log_parser.discard_prediction_parser import parse_discard_prediction

In [None]:
# Sample raw log used to exercise the feature builders below.
# NOTE(review): machine-specific absolute path — consider making it configurable.
input_file = "/Users/xinran.he/GitProjects/mahjong/data/raw/20180101/2018010110gm-00a9-0000-033e3e35.txt"
# Read through a context manager so the file handle is closed as soon as the
# contents are loaded, instead of lingering until garbage collection.
with open(input_file, "r") as raw_log:
    games = parse_discard_prediction(raw_log.read())

In [None]:
def _int64_feature(value):
    """Wrap a single value in a ``tf.train.Feature`` holding a one-element int64 list."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

def _int64_list_feature(values):
    """Wrap an iterable of values in a ``tf.train.Feature`` holding an int64 list."""
    int64_list = tf.train.Int64List(value=values)
    return tf.train.Feature(int64_list=int64_list)

In [None]:
def get_center_features(one_round):
    """Build the scalar features describing the round and the center player.

    Reads ``one_round.global_context`` (``field``, ``round``) and
    ``one_round.center_player.context`` (``field``, ``is_dealer``).

    Returns:
        dict mapping feature name to a single-value int64 ``tf.train.Feature``.
    """
    table = one_round.global_context
    center = one_round.center_player.context
    return {
        "current_field": _int64_feature(table.field),
        "round": _int64_feature(table.round),
        "center_field": _int64_feature(center.field),
        # is_dealer is cast explicitly so the proto always receives an int.
        "center_oya": _int64_feature(int(center.is_dealer)),
    }

In [None]:
def is_player_riichi(player):
    """Return True if any entry of ``player.discarded_hai`` has a truthy ``is_after_riichi``.

    Returns False when there are no discards or none of them carries the flag.
    """
    return any(discard.is_after_riichi for discard in player.discarded_hai)

In [None]:
# Score differences are clamped to [-MAX_SCORE_DIFF, MAX_SCORE_DIFF] ...
MAX_SCORE_DIFF = 32000
# ... and then bucketed into steps of this many points.
SCORE_DIFF_DET = 400

def player_score_diff(center_score, player_score):
    """Bucket the score gap between the center player and another player.

    The gap ``center_score - player_score`` is clamped to
    ``[-MAX_SCORE_DIFF, MAX_SCORE_DIFF]``, floor-divided by ``SCORE_DIFF_DET``
    and shifted by ``MAX_SCORE_DIFF // SCORE_DIFF_DET`` so the result is a
    non-negative bucket index (0..160 with the constants above; 80 means the
    gap falls in the bucket starting at zero).

    Args:
        center_score: score of the center player.
        player_score: score of the player being compared.

    Returns:
        The bucket index. It is an ``int`` for integer inputs.
    """
    diff = max(min(MAX_SCORE_DIFF, center_score - player_score), -MAX_SCORE_DIFF)
    # Floor division on both terms: a plain "/" yields a float under true
    # division, which cannot be stored in an int64 feature.
    return diff // SCORE_DIFF_DET + MAX_SCORE_DIFF // SCORE_DIFF_DET

In [None]:
def get_player_features(one_round):
    """Build per-opponent scalar features for every player in ``one_round.other_player``.

    For the player at index ``pid`` the following keys are produced:
    ``player<pid>_oya``, ``player<pid>_field``, ``player<pid>_riichi``,
    ``player<pid>_claim`` (number of entries in ``player.claim``),
    ``player<pid>_order`` (``3 + center order - player order``) and
    ``player<pid>_score`` (bucketed score gap from ``player_score_diff``).

    Returns:
        dict mapping feature name to a single-value int64 ``tf.train.Feature``.
    """
    feature_dict = {}
    center_context = one_round.center_player.context
    for pid, player in enumerate(one_round.other_player):
        player_context = player.context
        # Cast is_dealer to int, matching how the center player's dealer flag
        # and the riichi flag are encoded.
        feature_dict["player%d_oya" % pid] = _int64_feature(int(player_context.is_dealer))
        feature_dict["player%d_field" % pid] = _int64_feature(player_context.field)
        feature_dict["player%d_riichi" % pid] = _int64_feature(int(is_player_riichi(player)))
        feature_dict["player%d_claim" % pid] = _int64_feature(len(player.claim))
        # The +3 offset presumably keeps the relative order non-negative
        # (assumes orders differ by at most 3 — TODO confirm).
        feature_dict["player%d_order" % pid] = _int64_feature(3 + center_context.order - player_context.order)
        feature_dict["player%d_score" % pid] = _int64_feature(player_score_diff(center_context.score, player_context.score))
    return feature_dict

In [None]:
# Special ids used when building the tile sequence in get_sequence_features:
CLS_TOKEN = 69  # first element of every hai_seq
SEP_TOKEN = 70  # inserted before each other player's discards
PADDING = 0     # fills unused hand slots

def get_hid(hai, doras):
    """Map a tile to its sequence id.

    The base id is ``hai.id + 1``. It is doubled when ``hai.id`` appears in
    ``doras`` or ``hai.is_red`` is truthy.

    NOTE(review): doubling can produce the same id as an unflagged tile
    (id 0 flagged -> 2, id 1 unflagged -> 2) — confirm this is intended.
    """
    base = hai.id + 1
    flagged = hai.id in doras or hai.is_red
    return 2 * base if flagged else base

def get_sequence_features(one_round):
    """Build the three parallel token sequences for one round.

    Layout of ``hai_seq``: ``CLS_TOKEN``, the center player's hand (padded
    with ``PADDING`` up to 14 slots), then for every player in
    ``one_round.other_player`` a ``SEP_TOKEN`` followed by that player's
    discards. ``pos_seq`` holds 0 for the CLS/hand/SEP positions and
    1..n for a player's n discards; ``feature_seq`` holds 0 for the
    CLS/hand/SEP positions and ``i + 1`` for discards of the i-th other player.
    All three sequences have the same length.

    Returns:
        dict with keys ``hai_seq``, ``pos_seq`` and ``feature_seq``, each an
        int64-list ``tf.train.Feature``.
    """
    hand_slots = 14
    center_player = one_round.center_player
    doras = [h.id for h in one_round.global_context.dora]

    hai_seq = [CLS_TOKEN]
    hai_seq.extend(get_hid(h, doras) for h in center_player.hand)
    # Pad short hands; the count must come from len(hand), not the list itself.
    missing = hand_slots - len(center_player.hand)
    if missing > 0:
        hai_seq.extend([PADDING] * missing)

    # One zero per CLS/hand slot so pos_seq and feature_seq stay aligned with
    # hai_seq (sizing them to 14 would drop the CLS position and shift every
    # later entry by one).
    pos_seq = [0] * len(hai_seq)
    feature_seq = [0] * len(hai_seq)

    for i, player in enumerate(one_round.other_player):
        # Separator position before each player's discards.
        hai_seq.append(SEP_TOKEN)
        pos_seq.append(0)
        feature_seq.append(0)

        discard_count = len(player.discarded_hai)
        hai_seq.extend(get_hid(d.hai, doras) for d in player.discarded_hai)
        pos_seq.extend(range(1, 1 + discard_count))
        feature_seq.extend([i + 1] * discard_count)
    return {
        "hai_seq": _int64_list_feature(hai_seq),
        "pos_seq": _int64_list_feature(pos_seq),
        "feature_seq": _int64_list_feature(feature_seq)
    }

In [None]:
def generate_tfexample(one_round):
    """Assemble a ``tf.train.Example`` for one round.

    Merges, in this order, the dicts from ``get_center_features``,
    ``get_player_features`` and ``get_sequence_features``; on a key clash the
    later builder wins.
    """
    feature_map = {}
    for build in (get_center_features, get_player_features, get_sequence_features):
        feature_map.update(build(one_round))
    return tf.train.Example(features=tf.train.Features(feature=feature_map))

In [None]:
# Parenthesised form prints the same text on Python 2 and also parses on Python 3.
print(generate_tfexample(games[0].one_round[1]))

In [None]:
# `hai_seq` only exists as a local inside get_sequence_features, so a fresh
# kernel would raise NameError here; read it back from the returned features.
print(get_sequence_features(games[0].one_round[1])["hai_seq"])

In [None]:
# `pos_seq` only exists as a local inside get_sequence_features, so a fresh
# kernel would raise NameError here; read it back from the returned features.
print(get_sequence_features(games[0].one_round[1])["pos_seq"])

In [None]:
# `feature_seq` only exists as a local inside get_sequence_features, so a fresh
# kernel would raise NameError here; read it back from the returned features.
print(get_sequence_features(games[0].one_round[1])["feature_seq"])

In [None]:
# Parenthesised form prints the same text on Python 2 and also parses on Python 3.
print(games[0].one_round[1])

In [None]:
# Folder of raw logs that generate_data is run on below.
# NOTE(review): machine-specific absolute path — consider making it configurable.
temp = "/Users/xinran.he/GitProjects/mahjong/data/raw/20180101"
# Names of the entries directly inside that folder.
dirs = os.listdir(temp)

In [None]:
def generate_data(folder_path):
    """Print the path of every entry directly inside ``folder_path``.

    Args:
        folder_path: directory to list (not traversed recursively).
    """
    # `file_name` avoids shadowing the Python 2 builtin `file`.
    for file_name in os.listdir(folder_path):
        file_path = os.path.join(folder_path, file_name)
        # Parenthesised form works on both Python 2 and Python 3.
        print(file_path)

# Print the path of every entry in the raw-log folder held in `temp`.
generate_data(temp)