In [None]:
import os
import sys

import numpy as np
import polars as pl
import pandas as pd
from sklearn.model_selection import KFold, GroupKFold
import lightgbm as lgb
from lightgbm import early_stopping, log_evaluation
from sklearn.metrics import mean_squared_error

import kaggle_evaluation.mcts_inference_server

import warnings
warnings.filterwarnings('ignore')

# Baseline
It's a neat base to work from, without unnecessary code that can make it difficult to understand key concepts.

Inspired by: https://www.kaggle.com/code/jaejohn/mcts-strength-relevant-baseline-lb-0-422

The agent columns are processed (split into their components), and new features are engineered from the most influential columns. Highly correlated columns are removed; swapping in a different set of columns to drop is one way to improve the prediction, because correlated columns can behave differently when grouped with other attributes.

# Feature Engineering

In [None]:
class DataProcessor:
    """Turns a raw polars frame into a pandas frame ready for LightGBM.

    The pipeline (``process_and_engineer``) drops the configured columns,
    optionally splits each agent column on "-" into one column per component,
    casts agent columns to Categorical and every other column to Float32,
    converts to pandas and finally adds the engineered ratio / sum features.

    Attributes are plain containers of column names:
        dropped_cols: names removed from the frame when present.
        agent_cols: names of the agent description columns.
    """

    # Added to every denominator so a zero-valued column does not divide by zero.
    _EPS = 1e-15

    def __init__(self, dropped_cols, agent_cols):
        self.dropped_cols = dropped_cols
        self.agent_cols = agent_cols

    def _is_agent_col(self, name):
        """Return True for an agent column or one of its split parts (``<agent>_<idx>``)."""
        # Compare against the configured names instead of a fixed 6-character
        # prefix, so agent column names of any length are recognised.
        return any(name == agent or name.startswith(f"{agent}_") for agent in self.agent_cols)

    def process_data(self, df):
        """Drop unwanted columns, split/cast the agent columns and convert to pandas.

        Args:
            df: polars DataFrame (train or test).

        Returns:
            pandas DataFrame with agent columns as categoricals and all other
            columns as float32.
        """
        # Build a concrete, duplicate-free list from the frame's own columns:
        # names missing from the frame are skipped and a name listed twice in
        # ``dropped_cols`` is only dropped once.
        dropped = set(self.dropped_cols)
        df = df.drop([name for name in df.columns if name in dropped])

        if Config.split_agent_features:
            for col in self.agent_cols:
                # "a-b-c" -> struct fields col_0, col_1, col_2 -> separate
                # columns; the first component (col_0) is discarded.
                df = df.with_columns(
                    pl.col(col).str.split(by="-").list.to_struct(fields=lambda idx: f"{col}_{idx}")
                ).unnest(col).drop(f"{col}_0")

        df = df.with_columns(
            [
                pl.col(name).cast(pl.Categorical if self._is_agent_col(name) else pl.Float32)
                for name in df.columns
            ]
        )
        print(f'Data shape after processing agents: {df.shape}')
        return df.to_pandas()

    def feature_engineering(self, df):
        """Add the engineered features to ``df`` (in place) and list their names.

        Args:
            df: pandas DataFrame holding the raw numeric columns used below.

        Returns:
            Tuple ``(df, new_features)`` where ``new_features`` is the list of
            added column names, in insertion order.
        """
        eps = self._EPS
        # Single source of truth for both the computed values and the returned
        # name list. No feature depends on another engineered feature, so they
        # can all be computed before being attached to the frame.
        engineered = {
            'Playouts/Moves': df['PlayoutsPerSecond'] / (df['MovesPerSecond'] + eps),
            'EfficiencyPerPlayout': df['MovesPerSecond'] / (df['PlayoutsPerSecond'] + eps),
            'TurnsDurationEfficiency': df['DurationActions'] / (df['DurationTurnsStdDev'] + eps),
            'AdvantageBalanceRatio': df['AdvantageP1'] / (df['Balance'] + eps),
            'ActionTimeEfficiency': df['DurationActions'] / (df['MovesPerSecond'] + eps),
            'StandardizedTurnsEfficiency': df['DurationTurnsStdDev'] / (df['DurationActions'] + eps),
            'AdvantageTimeImpact': df['AdvantageP1'] / (df['DurationActions'] + eps),
            'DurationToComplexityRatio': df['DurationActions'] / (df['StateTreeComplexity'] + eps),
            'NormalizedGameTreeComplexity': df['GameTreeComplexity'] / (df['StateTreeComplexity'] + eps),
            'ComplexityBalanceInteraction': df['Balance'] * df['GameTreeComplexity'],
            'OverallComplexity': df['StateTreeComplexity'] + df['GameTreeComplexity'],
            'ComplexityPerPlayout': df['GameTreeComplexity'] / (df['PlayoutsPerSecond'] + eps),
            'TurnsNotTimeouts/Moves': df['DurationTurnsNotTimeouts'] / (df['MovesPerSecond'] + eps),
            'Timeouts/DurationActions': df['Timeouts'] / (df['DurationActions'] + eps),
            'OutcomeUniformity/AdvantageP1': df['OutcomeUniformity'] / (df['AdvantageP1'] + eps),
            'ComplexDecisionRatio': (
                df['StepDecisionToEnemy'] + df['SlideDecisionToEnemy'] + df['HopDecisionMoreThanOne']
            ),
            'AggressiveActionsRatio': (
                df['StepDecisionToEnemy'] + df['HopDecisionEnemyToEnemy']
                + df['HopDecisionFriendToEnemy'] + df['SlideDecisionToEnemy']
            ),
        }
        for name, values in engineered.items():
            df[name] = values
        new_features = list(engineered)

        print(f'Data shape after new_features: {df.shape}')
        return df, new_features

    def process_and_engineer(self, df):
        """Run ``process_data`` then ``feature_engineering`` and honour the drop list.

        Engineered columns are created after ``process_data`` has already
        applied ``dropped_cols``, so an engineered name listed there would
        otherwise survive. Such columns are removed here, and omitted from the
        returned name list.

        Args:
            df: polars DataFrame (train or test).

        Returns:
            Tuple ``(df, new_features)``: the pandas frame and the names of the
            engineered features that remain in it.
        """
        df = self.process_data(df)
        df, new_features = self.feature_engineering(df)

        dropped = set(self.dropped_cols)
        redundant = [name for name in new_features if name in dropped]
        if redundant:
            df = df.drop(columns=redundant)
            new_features = [name for name in new_features if name not in dropped]
        return df, new_features

In [None]:
# Column groups that are removed before modelling. Every list below is
# concatenated into `dropped_cols` and handed to DataProcessor.
# NOTE(review): the group names suggest why each column is dropped (all-NaN,
# constant, frequency variants, ...) - this is inferred from the names only;
# confirm against the training data.
nan_columns = ['Behaviour', 'StateRepetition', 'Duration', 'Complexity', 'BoardCoverage', 'GameOutcome', 'StateEvaluation', 'Clarity', 'Decisiveness', 'Drama', 'MoveEvaluation', 'StateEvaluationDifference', 'BoardSitesOccupied', 'BranchingFactor', 'DecisionFactor', 'MoveDistance', 'PieceNumber', 'ScoreDifference']
zero_columns = ['Realtime', 'Simultaneous', 'HiddenInformation', 'Match', 'AsymmetricRules', 'AsymmetricPlayRules', 'AsymmetricEndRules', 'AsymmetricSetup', 'Simulation', 'Solitaire', 'Multiplayer', 'Coalition', 'Puzzle', 'DeductionPuzzle', 'PlanningPuzzle', 'PrismShape', 'ParallelogramShape', 'RectanglePyramidalShape', 'TargetShape', 'BrickTiling', 'CelticTiling', 'QuadHexTiling', 'Hints', 'DiceD3', 'BiasedDice', 'Card', 'Domino', 'SituationalTurnKo', 'SituationalSuperko', 'InitialAmount', 'InitialPot', 'BetDecision', 'BetDecisionFrequency', 'VoteDecisionFrequency', 'ChooseTrumpSuitDecision', 'ChooseTrumpSuitDecisionFrequency', 'LeapDecisionToFriend', 'LeapDecisionToFriendFrequency', 'HopDecisionEnemyToFriend', 'HopDecisionEnemyToFriendFrequency', 'HopDecisionFriendToFriend', 'FromToDecisionWithinBoard', 'FromToDecisionBetweenContainers', 'BetEffect', 'BetEffectFrequency', 'VoteEffectFrequency', 'SwapPlayersEffectFrequency', 'TakeControl', 'TakeControlFrequency', 'PassEffectFrequency', 'SetCost', 'SetCostFrequency', 'SetPhase', 'SetPhaseFrequency', 'SetTrumpSuit', 'SetTrumpSuitFrequency', 'StepEffectFrequency', 'SlideEffectFrequency', 'LeapEffectFrequency', 'HopEffectFrequency', 'FromToEffectFrequency', 'SwapPiecesEffect', 'SwapPiecesEffectFrequency', 'ShootEffect', 'ShootEffectFrequency', 'MaxCapture', 'OffDiagonalDirection', 'Information', 'HidePieceType', 'HidePieceOwner', 'HidePieceCount', 'HidePieceRotation', 'HidePieceValue', 'HidePieceState', 'InvisiblePiece', 'LineDrawFrequency', 'ConnectionDraw', 'ConnectionDrawFrequency', 'GroupLossFrequency', 'GroupDrawFrequency', 'LoopLossFrequency', 'LoopDraw', 'LoopDrawFrequency', 'PatternLoss', 'PatternLossFrequency', 'PatternDraw', 'PatternDrawFrequency', 'PathExtentEndFrequency', 'PathExtentWinFrequency', 'PathExtentLossFrequency', 'PathExtentDraw', 'PathExtentDrawFrequency', 'TerritoryLoss', 'TerritoryLossFrequency', 'TerritoryDraw', 'TerritoryDrawFrequency', 'CheckmateLoss', 'CheckmateLossFrequency', 
'CheckmateDraw', 'CheckmateDrawFrequency', 'NoTargetPieceLoss', 'NoTargetPieceLossFrequency', 'NoTargetPieceDraw', 'NoTargetPieceDrawFrequency', 'NoOwnPiecesDraw', 'NoOwnPiecesDrawFrequency', 'FillLoss', 'FillLossFrequency', 'FillDraw', 'FillDrawFrequency', 'ScoringDrawFrequency', 'NoProgressWin', 'NoProgressWinFrequency', 'NoProgressLoss', 'NoProgressLossFrequency', 'SolvedEnd', 'PositionalRepetition', 'SituationalRepetition', 'Narrowness', 'Variance', 'DecisivenessMoves', 'DecisivenessThreshold', 'LeadChange', 'Stability', 'DramaAverage', 'DramaMedian', 'DramaMaximum', 'DramaMinimum', 'DramaVariance', 'DramaChangeAverage', 'DramaChangeSign', 'DramaChangeLineBestFit', 'DramaChangeNumTimes', 'DramaMaxIncrease', 'DramaMaxDecrease', 'MoveEvaluationAverage', 'MoveEvaluationMedian', 'MoveEvaluationMaximum', 'MoveEvaluationMinimum', 'MoveEvaluationVariance', 'MoveEvaluationChangeAverage', 'MoveEvaluationChangeSign', 'MoveEvaluationChangeLineBestFit', 'MoveEvaluationChangeNumTimes', 'MoveEvaluationMaxIncrease', 'MoveEvaluationMaxDecrease', 'StateEvaluationDifferenceAverage', 'StateEvaluationDifferenceMedian', 'StateEvaluationDifferenceMaximum', 'StateEvaluationDifferenceMinimum', 'StateEvaluationDifferenceVariance', 'StateEvaluationDifferenceChangeAverage', 'StateEvaluationDifferenceChangeSign', 'StateEvaluationDifferenceChangeLineBestFit', 'StateEvaluationDifferenceChangeNumTimes', 'StateEvaluationDifferenceMaxIncrease', 'StateEvaluationDifferenceMaxDecrease', 'BoardSitesOccupiedMinimum', 'BranchingFactorMinimum', 'DecisionFactorMinimum', 'MoveDistanceMinimum', 'PieceNumberMinimum', 'ScoreDifferenceMinimum', 'ScoreDifferenceChangeNumTimes', 'Roots', 'Cosine', 'Sine', 'Tangent', 'Exponential', 'Logarithm', 'ExclusiveDisjunction', 'Float', 'HandComponent', 'SetHidden', 'SetInvisible', 'SetHiddenCount', 'SetHiddenRotation', 'SetHiddenState', 'SetHiddenValue', 'SetHiddenWhat', 'SetHiddenWho']
one_columns = ['Id', 'NumPlayers', 'Properties', 'Format', 'Time', 'Discrete', 'Turns', 'Alternating', 'Players', 'TwoPlayer', 'Equipment', 'Container', 'Board', 'PlayableSites', 'Component', 'Rules', 'Play', 'End']
frequency_columns = ['BetDecisionFrequency', 'VoteDecisionFrequency', 'SwapPlayersDecisionFrequency', 'ChooseTrumpSuitDecisionFrequency', 'PassDecisionFrequency', 'ProposeDecisionFrequency', 'AddDecisionFrequency', 'PromotionDecisionFrequency', 'RemoveDecisionFrequency', 'RotationDecisionFrequency', 'StepDecisionFrequency', 'StepDecisionToEmptyFrequency', 'StepDecisionToFriendFrequency', 'StepDecisionToEnemyFrequency', 'SlideDecisionFrequency', 'SlideDecisionToEmptyFrequency', 'SlideDecisionToEnemyFrequency', 'SlideDecisionToFriendFrequency', 'LeapDecisionFrequency', 'LeapDecisionToEmptyFrequency', 'LeapDecisionToFriendFrequency', 'LeapDecisionToEnemyFrequency', 'HopDecisionFrequency', 'HopDecisionMoreThanOneFrequency', 'HopDecisionEnemyToEmptyFrequency', 'HopDecisionFriendToEmptyFrequency', 'HopDecisionEnemyToFriendFrequency', 'HopDecisionFriendToFriendFrequency', 'HopDecisionEnemyToEnemyFrequency', 'HopDecisionFriendToEnemyFrequency', 'FromToDecisionFrequency', 'FromToDecisionWithinBoardFrequency', 'FromToDecisionBetweenContainersFrequency', 'FromToDecisionEmptyFrequency', 'FromToDecisionEnemyFrequency', 'FromToDecisionFriendFrequency', 'SwapPiecesDecisionFrequency', 'ShootDecisionFrequency', 'BetEffectFrequency', 'VoteEffectFrequency', 'SwapPlayersEffectFrequency', 'TakeControlFrequency', 'PassEffectFrequency', 'RollFrequency', 'ProposeEffectFrequency', 'AddEffectFrequency', 'SowFrequency', 'SowCaptureFrequency', 'SowRemoveFrequency', 'SowBacktrackingFrequency', 'PromotionEffectFrequency', 'RemoveEffectFrequency', 'PushEffectFrequency', 'FlipFrequency', 'SetNextPlayerFrequency', 'MoveAgainFrequency', 'SetValueFrequency', 'SetCountFrequency', 'SetCostFrequency', 'SetPhaseFrequency', 'SetTrumpSuitFrequency', 'SetRotationFrequency', 'StepEffectFrequency', 'SlideEffectFrequency', 'LeapEffectFrequency', 'HopEffectFrequency', 'FromToEffectFrequency', 'SwapPiecesEffectFrequency', 'ShootEffectFrequency', 'ReplacementCaptureFrequency', 'HopCaptureFrequency', 
'HopCaptureMoreThanOneFrequency', 'DirectionCaptureFrequency', 'EncloseCaptureFrequency', 'CustodialCaptureFrequency', 'InterveneCaptureFrequency', 'SurroundCaptureFrequency', 'CaptureSequenceFrequency', 'LineEndFrequency', 'LineWinFrequency', 'LineLossFrequency', 'LineDrawFrequency', 'ConnectionEndFrequency', 'ConnectionWinFrequency', 'ConnectionLossFrequency', 'ConnectionDrawFrequency', 'GroupEndFrequency', 'GroupWinFrequency', 'GroupLossFrequency', 'GroupDrawFrequency', 'LoopEndFrequency', 'LoopWinFrequency', 'LoopLossFrequency', 'LoopDrawFrequency', 'PatternEndFrequency', 'PatternWinFrequency', 'PatternLossFrequency', 'PatternDrawFrequency', 'PathExtentEndFrequency', 'PathExtentWinFrequency', 'PathExtentLossFrequency', 'PathExtentDrawFrequency', 'TerritoryEndFrequency', 'TerritoryWinFrequency', 'TerritoryLossFrequency', 'TerritoryDrawFrequency', 'CheckmateFrequency', 'CheckmateWinFrequency', 'CheckmateLossFrequency', 'CheckmateDrawFrequency', 'NoTargetPieceEndFrequency', 'NoTargetPieceWinFrequency', 'NoTargetPieceLossFrequency', 'NoTargetPieceDrawFrequency', 'EliminatePiecesEndFrequency', 'EliminatePiecesWinFrequency', 'EliminatePiecesLossFrequency', 'EliminatePiecesDrawFrequency', 'NoOwnPiecesEndFrequency', 'NoOwnPiecesWinFrequency', 'NoOwnPiecesLossFrequency', 'NoOwnPiecesDrawFrequency', 'FillEndFrequency', 'FillWinFrequency', 'FillLossFrequency', 'FillDrawFrequency', 'ReachEndFrequency', 'ReachWinFrequency', 'ReachLossFrequency', 'ReachDrawFrequency', 'ScoringEndFrequency', 'ScoringWinFrequency', 'ScoringLossFrequency', 'ScoringDrawFrequency', 'NoMovesEndFrequency', 'NoMovesWinFrequency', 'NoMovesLossFrequency', 'NoMovesDrawFrequency', 'NoProgressEndFrequency', 'NoProgressWinFrequency', 'NoProgressLossFrequency', 'NoProgressDrawFrequency', 'DrawFrequency']
component_columns = ['ComponentStyle', 'AnimalComponent', 'ChessComponent', 'KingComponent', 'QueenComponent', 'KnightComponent', 'RookComponent', 'BishopComponent', 'PawnComponent', 'FairyChessComponent', 'PloyComponent', 'ShogiComponent', 'XiangqiComponent', 'StrategoComponent', 'JanggiComponent', 'HandComponent', 'CheckersComponent', 'BallComponent', 'TaflComponent', 'DiscComponent', 'MarkerComponent']
rules_columns =  ['GameRulesetName', 'EnglishRules', 'LudRules']
# Columns treated as redundant at a correlation threshold of 0.9.
# 'ComplexityBalanceInteraction' is the name of a feature created in
# DataProcessor.feature_engineering (presumably not a raw CSV column - confirm).
corr_columns = ['Asymmetric', 'AsymmetricForces', 'Cooperation', 'Shape', 'RegularShape', 'PolygonShape', 'CircleShape', 'SpiralShape', 'MancalaBoard', 'NumPlayableSitesOnBoard', 'SquarePyramidalShape', 'NumInnerSites', 'NumEdges', 'NumCells', 'NumOuterSites', 'NumTopSites', 'NumRightSites', 'Hand', 'NumVertices', 'PlayersWithDirections', 'Stochastic', 'NumComponentsType', 'OpeningContract', 'Repetition', 'NumStartComponentsBoard', 'NumStartComponentsHand', 'NumStartComponents', 'SwapOption', 'VoteDecision', 'StepDecision', 'LeapDecision', 'LeapDecisionToEmpty', 'MovesNonDecision', 'Dice', 'TrackLoop', 'Sow', 'SowWithEffect', 'SowProperties', 'SowOriginFirst', 'SetMove', 'PieceValue', 'PieceRotation', 'HopDecisionEnemyToEmpty', 'AutoMove', 'CanNotMove', 'SlideDecision', 'Track', 'Directions', 'RightwardDirection', 'RightwardsDirection', 'ForwardLeftDirection', 'BackwardLeftDirection', 'LineEnd', 'Connection', 'ConnectionEnd', 'Loop', 'PathExtent', 'PatternEnd', 'LoopLoss', 'PathExtentEnd', 'PathExtentWin', 'Territory', 'TerritoryEnd', 'Threat', 'Checkmate', 'NoPieceMover', 'NoOwnPiecesEnd', 'FillEnd', 'Scoring', 'ScoringEnd', 'NoMoves', 'ProgressCheck', 'NoProgressEnd', 'Completion', 'DurationTurns', 'BoardSitesOccupiedAverage', 'BoardSitesOccupiedChangeAverage', 'BranchingFactorAverage', 'BranchingFactorMaximum', 'BranchingFactorVariance', 'BranchingFactorMedian', 'DecisionFactorAverage', 'BranchingFactorChangeMaxDecrease', 'DecisionFactorMaximum', 'BranchingFactorChangeAverage', 'BranchingFactorChangeLineBestFit', 'BranchingFactorChangeSign', 'DecisionFactorChangeAverage', 'BranchingFactorChangeMaxIncrease', 'DecisionFactorVariance', 'MoveDistanceAverage', 'MoveDistanceChangeAverage', 'MoveDistanceMaximum', 'MoveDistanceMaxIncrease', 'PieceNumberAverage', 'PieceNumberMedian', 'BoardSitesOccupiedChangeSign', 'PieceNumberChangeAverage', 'BoardSitesOccupiedChangeNumTimes', 'ScoreDifferenceAverage', 'ScoreDifferenceMedian', 'ScoreDifferenceMaximum', 
'ScoreDifferenceVariance', 'ScoreDifferenceChangeAverage', 'ScoreDifferenceChangeLineBestFit', 'ScoreDifferenceMaxIncrease', 'Arithmetic', 'Visual', 'Vertex', 'BoardStyle', 'SowCCW', 'NumLayers', 'LeapDecisionToEnemy', 'KingComponent', 'KnightComponent', 'ChessComponent', 'StackType', 'StateType', 'PieceState', 'RememberValues', 'Implementation', 'ComplexityBalanceInteraction', 'LeftwardsDirection']
# Alternative correlated-column list, kept for reference only. Both physical
# lines must stay commented: the second one is a continuation of the first.
#corr_columns = [AsymmetricForces, AsymmetricPiecesType, Team, RegularShape, PolygonShape, Tiling, CircleTiling, SpiralTiling, TrackLoop, NumInnerSites, NumLayers, NumEdges, NumCells, NumVertices, NumPerimeterSites, NumBottomSites, NumLeftSites, NumContainers, NumPlayableSites, PieceDirection, Dice, NumComponentsTypePerPlayer, SwapOption, PositionalSuperko, NumStartComponentsBoardPerPlayer, NumStartComponentsHandPerPlayer, NumStartComponentsPerPlayer, SwapPlayersDecision, ProposeDecision, StepDecisionToEmpty, LeapDecisionToEmpty, LeapDecisionToEnemy, MovesEffects, Roll, Sow, SowWithEffect, SowProperties, SowOriginFirst, SowCCW, MoveAgain, SetValue, SetRotation, HopCapture, PathExtent, Threat, LineOfSight, Directions, AbsoluteDirections, LeftwardDirection, LeftwardsDirection, ForwardRightDirection, BackwardRightDirection, LineWin, ConnectionEnd, ConnectionWin, LoopEnd, LoopLoss, PatternWin, PathExtentEnd, PathExtentWin, PathExtentLoss, TerritoryEnd, TerritoryWin, Checkmate, CheckmateWin, NoOwnPiecesEnd, NoOwnPiecesWin, FillWin, ScoringEnd, ScoringWin, NoMovesEnd, NoProgressEnd, NoProgressDraw, Drawishness, Timeouts, BoardSitesOccupiedMedian, BoardSitesOccupiedChangeLineBestFit, BranchingFactorMedian, BranchingFactorVariance, BranchingFactorChangeMaxDecrease, DecisionFactorAverage, DecisionFactorMedian, DecisionFactorMaximum, DecisionFactorVariance, DecisionFactorChangeAverage, DecisionFactorChangeSign, DecisionFactorChangeLineBestFit, DecisionFactorMaxIncrease, DecisionFactorMaxDecrease, MoveDistanceMedian, MoveDistanceChangeLineBestFit, MoveDistanceMaxIncrease, MoveDistanceMaxDecrease, PieceNumberMedian, PieceNumberMaximum, PieceNumberChangeSign, PieceNumberChangeLineBestFit, PieceNumberChangeNumTimes, ScoreDifferenceMedian, ScoreDifferenceMaximum, ScoreDifferenceVariance, ScoreDifferenceChangeAverage, ScoreDifferenceChangeLineBestFit, ScoreDifferenceMaxIncrease, ScoreDifferenceMaxDecrease, Comparison, Style, BoardStyle, GraphStyle, MancalaStyle, ShibumiStyle, 
#KnightComponent, RookComponent, PawnComponent, StateType, StackState, SiteState, ForgetValues, SetInternalCounter, Efficiency, EfficiencyPerPlayout]
# Target-related count columns and the two agent description columns.
output_cols = ['num_wins_agent1', 'num_draws_agent1', 'num_losses_agent1']
agent_cols = ['agent1', 'agent2']

# Several names occur in more than one group (e.g. 'HandComponent',
# 'KingComponent'); dict.fromkeys removes the repeats while keeping the
# first-seen order, so each column is requested for dropping only once.
dropped_cols = list(dict.fromkeys(
    output_cols + nan_columns + zero_columns + one_columns
    + frequency_columns + component_columns + rules_columns + corr_columns
))
processor = DataProcessor(dropped_cols, agent_cols)

# Model

In [None]:
class Config:
    """Static settings shared by the processing, training and inference code."""

    # Location of the training CSV.
    train_path = '/kaggle/input/um-game-playing-strength-of-mcts-variants/train.csv'

    # Cross-validation / training control.
    early_stop = 100   # rounds passed to the LightGBM early-stopping callback
    n_splits = 5       # number of GroupKFold folds
    seed = 1212

    # When True, DataProcessor splits each agent column into its components.
    split_agent_features = True

    # Keyword arguments forwarded to lgb.LGBMRegressor.
    lgbm_params = dict(
        learning_rate=0.05,
        num_leaves=63,
        max_depth=8,
        num_boost_round=10_000,
        reg_lambda=2.0,
        reg_alpha=1.0,
        verbose=-1,
    )

In [None]:
def train_lgb(processor, data, group_col):
    """Fit one LightGBM regressor per GroupKFold fold and print validation RMSE.

    Args:
        processor: Not used by this function; kept so existing callers that
            pass it keep working.
        data: pandas DataFrame with the feature columns plus the
            'utility_agent1' target column.
        group_col: One group label per row of ``data``; passed to
            ``GroupKFold.split`` as ``groups`` so rows sharing a label stay in
            the same fold.

    Returns:
        List with the fitted ``lgb.LGBMRegressor`` of every fold, in fold order.
    """
    X = data.drop(['utility_agent1'], axis=1)
    y = data['utility_agent1']

    group_kfold = GroupKFold(n_splits=Config.n_splits)

    models = []
    fold_scores = []

    for fi, (train_idx, valid_idx) in enumerate(group_kfold.split(X, y, groups=group_col)):
        print(f'Fold {fi+1}/{Config.n_splits} ...')

        X_train, y_train = X.iloc[train_idx], y.iloc[train_idx]
        X_valid, y_valid = X.iloc[valid_idx], y.iloc[valid_idx]

        model = lgb.LGBMRegressor(**Config.lgbm_params)
        model.fit(
            X_train, y_train,
            eval_set=[(X_valid, y_valid)],
            eval_metric='rmse',
            callbacks=[lgb.early_stopping(Config.early_stop)]
        )
        models.append(model)

        y_pred = model.predict(X_valid)
        # Take the square root explicitly: the `squared=False` argument of
        # mean_squared_error was deprecated in scikit-learn 1.4 and removed in
        # 1.6, whereas sqrt(MSE) works on every version.
        fold_score = np.sqrt(mean_squared_error(y_valid, y_pred))
        fold_scores.append(fold_score)

        print(f'Fold {fi+1} RMSE: {fold_score:.4f}')

    # Reported figure is the mean of the per-fold RMSE values.
    avg_score = np.mean(fold_scores)
    print(f'\nTotal RMSE over {Config.n_splits} folds: {avg_score:.4f}')

    return models

def infer_lgb(data, models):
    """Average the predictions of all ``models`` for ``data``.

    Each model's ``predict(data)`` output is collected and the element-wise
    mean across models (axis 0) is returned.
    """
    per_model = []
    for fitted in models:
        per_model.append(fitted.predict(data))
    return np.asarray(per_model).mean(axis=0)

# Submission

In [None]:
# Module-level state shared across calls to `predict`:
# `run_i` counts completed calls, `models` holds the fold models once trained.
run_i = 0
models = None

def predict(test_data, submission):
    """Prediction callback passed to the MCTS inference server.

    On the first call (``run_i == 0``) the training CSV is read, processed and
    used to train the fold models, which are cached in the global ``models``.
    Every call then processes ``test_data`` with the same processor and writes
    the averaged model predictions into the 'utility_agent1' column.

    Args:
        test_data: polars DataFrame of test rows (it is run through
            ``processor.process_and_engineer``).
        submission: frame exposing ``with_columns``; returned with the
            'utility_agent1' column set.

    Returns:
        ``submission`` with the predicted 'utility_agent1' column.
    """
    global run_i, models

    first_call = run_i == 0
    if first_call:
        raw_train = pl.read_csv(Config.train_path)
        train_features, _ = processor.process_and_engineer(raw_train)
        # Groups come from the raw frame, where 'GameRulesetName' has not been dropped.
        ruleset_groups = raw_train['GameRulesetName']
        models = train_lgb(processor, train_features, ruleset_groups)

    run_i += 1

    test_features, _ = processor.process_and_engineer(test_data)
    predictions = infer_lgb(test_features, models)

    return submission.with_columns(pl.Series('utility_agent1', predictions))

# Wrap `predict` in the competition's inference server.
inference_server = kaggle_evaluation.mcts_inference_server.MCTSInferenceServer(predict)

if not os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    # Environment variable unset or empty: run the local gateway against the
    # bundled test and sample-submission CSVs.
    inference_server.run_local_gateway(
        (
            '/kaggle/input/um-game-playing-strength-of-mcts-variants/test.csv',
            '/kaggle/input/um-game-playing-strength-of-mcts-variants/sample_submission.csv'
        )
    )
else:
    # Environment variable set: serve predictions.
    inference_server.serve()