In [None]:
    def train_context_ras_model(self):
        """Train the context RAS booster and persist it under the name 'cras'.

        Steps:
          1. Load every ``match_info`` x ``match_detail`` row from the database.
          2. Turn each row into two training samples, one from the home team's
             side and one from the away team's side.
          3. Drop samples with missing features or non-positive playing time.
          4. One-hot encode the categorical features and fit an XGBoost Poisson
             model whose label is shots per minute and whose base margin
             (offset) is ``log(pdras per minute)``.
          5. Pick the number of boosting rounds by 5-fold CV with early
             stopping, refit on all data, and hand the booster plus the
             feature-column order to ``_save``.

        Returns:
            None. The trained booster is persisted through ``_save``.

        Raises:
            ValueError: if an expected column is missing from the assembled
                frame, or if no usable training rows remain after cleaning.
        """
        print("Training RAS")

        def flip(series: pd.Series) -> pd.Series:
            # Negate a signed value to view it from the other team's side.
            # Zeros are reset explicitly so they never become -0.0, which would
            # stringify to a separate '-0.0' category further down.
            flipped = -series
            flipped[series == 0] = 0.0
            return flipped

        sql_query = """
            SELECT 
                mi.match_id,
                mi.home_team_id,
                mi.away_team_id,
                mi.home_elevation_dif,
                mi.away_elevation_dif,
                mi.away_travel,
                mi.home_rest_days,
                mi.away_rest_days,
                mi.temperature_c,
                mi.is_raining,
                mi.date,
                md.teamA_pdras,
                md.teamB_pdras,
                md.minutes_played,
                md.match_state,
                md.match_segment,
                md.player_dif,
                (md.teamA_headers + md.teamA_footers) AS home_shots,
                (md.teamB_headers + md.teamB_footers) AS away_shots
            FROM match_info mi
            JOIN match_detail md ON mi.match_id = md.match_id
        """
        context_df = DB.select(sql_query)
        context_df['date'] = pd.to_datetime(context_df['date'])
        context_df['match_state'] = pd.to_numeric(context_df['match_state'], errors='raise').astype(float)
        context_df['player_dif']  = pd.to_numeric(context_df['player_dif'],  errors='raise').astype(float)

        # The time bucket depends only on the match date, so compute it once
        # and share it between the home and away views.
        match_time = context_df['date'].apply(_bucket_time)

        # Home view. NOTE(review): match_state / player_dif are presumably
        # stored from the home (team A) perspective -- confirm against the
        # match_detail schema.
        home_df = pd.DataFrame({
            'shots'              : context_df['home_shots'],
            'total_ras'          : context_df['teamA_pdras'],
            'minutes_played'     : context_df['minutes_played'],
            'team_is_home'       : 1,
            'team_elevation_dif' : context_df['home_elevation_dif'],
            'opp_elevation_dif'  : context_df['away_elevation_dif'],
            'team_travel'        : 0,
            'opp_travel'         : context_df['away_travel'],
            'team_rest_days'     : context_df['home_rest_days'],
            'opp_rest_days'      : context_df['away_rest_days'],
            'match_state'        : context_df['match_state'],
            'match_segment'      : context_df['match_segment'],
            'player_dif'         : context_df['player_dif'],
            'temperature_c'      : context_df['temperature_c'],
            'is_raining'         : context_df['is_raining'],
            'match_time'         : match_time,
        })

        # Away view: team/opponent features are swapped and the signed
        # features are negated.
        away_df = pd.DataFrame({
            'shots'              : context_df['away_shots'],
            'total_ras'          : context_df['teamB_pdras'],
            'minutes_played'     : context_df['minutes_played'],
            'team_is_home'       : 0,
            'team_elevation_dif' : context_df['away_elevation_dif'],
            'opp_elevation_dif'  : context_df['home_elevation_dif'],
            'team_travel'        : context_df['away_travel'],
            'opp_travel'         : 0,
            'team_rest_days'     : context_df['away_rest_days'],
            'opp_rest_days'      : context_df['home_rest_days'],
            'match_state'        : flip(context_df['match_state']),
            'match_segment'      : context_df['match_segment'],
            'player_dif'         : flip(context_df['player_dif']),
            'temperature_c'      : context_df['temperature_c'],
            'is_raining'         : context_df['is_raining'],
            'match_time'         : match_time,
        })

        df = pd.concat([home_df, away_df], ignore_index=True)

        cat_cols  = ['match_state', 'player_dif', 'match_time']
        bool_cols = ['team_is_home', 'is_raining']
        num_cols  = ['team_elevation_dif', 'opp_elevation_dif', 'team_travel', 'opp_travel', 'team_rest_days', 'opp_rest_days', 'temperature_c', 'match_segment']

        # minutes_played is the divisor for both the label and the offset, so
        # it has to be validated together with the features and targets.
        required_cols = cat_cols + bool_cols + num_cols + ['shots', 'total_ras', 'minutes_played']
        missing_cols  = [c for c in required_cols if c not in df.columns]
        if missing_cols:
            raise ValueError(f'Missing expected columns: {missing_cols}')

        # Drop incomplete rows BEFORE any integer cast (astype(int) raises on
        # NaN), and drop rows with no playing time so the per-minute rates
        # below can never be inf/NaN.
        df = df.dropna(subset=required_cols)
        df = df[df['minutes_played'] > 0].copy()
        if df.empty:
            raise ValueError('No usable rows left to train the context RAS model')

        df['match_segment'] = df['match_segment'].astype(int)

        df['shots_per_min'] = df['shots']     / df['minutes_played']
        df['ras_per_min']   = df['total_ras'] / df['minutes_played']

        # Categorical values are stringified so get_dummies treats each
        # distinct value as its own level.
        for c in cat_cols:
            df[c] = df[c].astype(str).str.lower()

        df[bool_cols] = df[bool_cols].astype(int)

        X_cat = pd.get_dummies(df[cat_cols], prefix=cat_cols)
        X     = pd.concat([df[num_cols], df[bool_cols], X_cat], axis=1)

        y = df['shots_per_min']
        # base_margin is on the log scale for count:poisson; the clip keeps
        # log() finite when the baseline rate is zero.
        base_margin = np.log(df['ras_per_min'].clip(lower=1e-6))

        dtrain = xgb.DMatrix(X, label=y, base_margin=base_margin)

        params = dict(objective='count:poisson',
                      tree_method='hist',
                      max_depth=6,
                      eta=0.03,
                      subsample=1.0,
                      colsample_bytree=1.0,
                      min_child_weight=10,
                      gamma=2,
                      reg_alpha=1,
                      reg_lambda=2)

        cv_results = xgb.cv(
            params,
            dtrain,
            num_boost_round=500,
            nfold=5,
            early_stopping_rounds=100,
            metrics='poisson-nloglik',
            verbose_eval=False
        )

        # With early stopping, the CV history is truncated at the best
        # iteration, so its length is the round count to refit with.
        optimal_rounds = len(cv_results)
        booster = xgb.train(params, dtrain, num_boost_round=optimal_rounds)

        # The column order is saved with the booster so callers can rebuild
        # the same one-hot layout.
        _save('cras', booster, X.columns.tolist())
