In [None]:
{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 30,
      "metadata": {
        "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
        "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
        "execution": {
          "iopub.execute_input": "2025-01-03T16:55:03.498148Z",
          "iopub.status.busy": "2025-01-03T16:55:03.497721Z",
          "iopub.status.idle": "2025-01-03T16:55:03.502661Z",
          "shell.execute_reply": "2025-01-03T16:55:03.501571Z",
          "shell.execute_reply.started": "2025-01-03T16:55:03.498078Z"
        },
        "trusted": true
      },
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "import numpy as np\n",
        "from typing import Optional, List, Callable, Any, Union, Dict\n",
        "from itertools import product\n",
        "from statistics import mean\n",
        "from pathlib import Path\n",
        "import gzip\n",
        "import os"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {

      },
      "source": [
        "### Read datasets\n",
        "Use the gzip function is files ar gzipped"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 31,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2025-01-03T16:55:03.504167Z",
          "iopub.status.busy": "2025-01-03T16:55:03.503845Z",
          "iopub.status.idle": "2025-01-03T16:55:03.523023Z",
          "shell.execute_reply": "2025-01-03T16:55:03.521924Z",
          "shell.execute_reply.started": "2025-01-03T16:55:03.504128Z"
        },
        "trusted": true
      },
      "outputs": [],
      "source": [
        "def read_ds_gzip(path: Optional[Path]=None, ds: str = \"TRAIN\") -\u003E pd.DataFrame:\n",
        "    \"\"\"Args:\n",
        "        path (Optional[Path], optional): the path to read the dataset file. Defaults to /kaggle/input/the-insa-starcraft-2-player-prediction-challenge/{ds}.CSV.gz.\n",
        "        ds (str, optional): the part to read (TRAIN or TEST), to use when path is None. Defaults to \"TRAIN\".\n",
        "\n",
        "    Returns:\n",
        "        pd.DataFrame:\n",
        "    \"\"\"\n",
        "    with gzip.open(f'/kaggle/input/the-insa-starcraft-2-player-prediction-challenge/{ds}.CSV.gz' if path is None else path) as f:\n",
        "        max_actions = max(( len( str(c).split(\",\")) for c in f.readlines() ))\n",
        "        f.seek(0)\n",
        "        _names = [\"battleneturl\", \"played_race\"] if \"TRAIN\" in ds else [\"played_race\"]\n",
        "        _names.extend(range(max_actions - len(_names)))\n",
        "        return pd.read_csv(f, names=_names, dtype= str)\n",
        "\n",
        "def read_ds(path: Optional[Path]=None, ds: str = \"TRAIN\"):\n",
        "    \"\"\"Args:\n",
        "        path (Optional[Path], optional): the path to read the dataset file. Defaults to /kaggle/input/the-insa-starcraft-2-player-prediction-challenge/{ds}.CSV.gz.\n",
        "        ds (str, optional): the part to read (TRAIN or TEST), to use when path is None. Defaults to \"TRAIN\".\n",
        "\n",
        "    Returns:\n",
        "        pd.DataFrame:\n",
        "    \"\"\"\n",
        "    with open(f'/kaggle/input/train-sc2-keystrokes/{ds}.CSV' if path is None else path) as f:\n",
        "        max_actions = max(( len( str(c).split(\",\")) for c in f.readlines() ))\n",
        "        f.seek(0)\n",
        "        _names = [\"battleneturl\", \"played_race\"] if \"TRAIN\" in ds else [\"played_race\"]\n",
        "        _names.extend(range(max_actions - len(_names)))\n",
        "        return pd.read_csv(f, names=_names, dtype= str)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 32,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2025-01-03T16:55:03.525005Z",
          "iopub.status.busy": "2025-01-03T16:55:03.524636Z",
          "iopub.status.idle": "2025-01-03T16:55:12.650412Z",
          "shell.execute_reply": "2025-01-03T16:55:12.649188Z",
          "shell.execute_reply.started": "2025-01-03T16:55:03.524971Z"
        },
        "trusted": true
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "(3052, 10539)"
            ]
          },
          "execution_count": 32,
          "metadata": {

          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "features_train = read_ds(Path(os.path.abspath('')) / \"data/TRAIN.CSV\") # Replace with correct path \n",
        "# features_test = read_ds(\"TEST\")\n",
        "features_train.shape #, features_test.shape"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {

      },
      "source": [
        "### Dependent Variable\n",
        "Our dependent variable is a categorical string; we can convert it to categories codes (number) with pd.Categorical\n",
        "\n",
        "pd.Categorical doesn't directly modify the battleneturl to a number, instead it adds a cat.codes attribute to it. We can create a little function to convert the dependent variable from string to its category ID:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 33,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2025-01-03T16:55:12.652394Z",
          "iopub.status.busy": "2025-01-03T16:55:12.652048Z",
          "iopub.status.idle": "2025-01-03T16:55:12.657336Z",
          "shell.execute_reply": "2025-01-03T16:55:12.655950Z",
          "shell.execute_reply.started": "2025-01-03T16:55:12.652364Z"
        },
        "trusted": true
      },
      "outputs": [],
      "source": [
        "def to_categories(df: pd.DataFrame, col: str=\"battleneturl\") -\u003E None:\n",
        "    \"\"\"Convert col of df to a categorical column\"\"\"\n",
        "    df[\"battleneturl\"] = pd.Categorical(df[\"battleneturl\"])\n",
        "    df[[col]] = df[[col]].apply(lambda x: x.cat.codes)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {

      },
      "source": [
        "#### Removing outliers\n",
        "YOUR IDEAS / APPROACHES HERE.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 34,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2025-01-03T16:55:12.658769Z",
          "iopub.status.busy": "2025-01-03T16:55:12.658394Z",
          "iopub.status.idle": "2025-01-03T16:55:13.712063Z",
          "shell.execute_reply": "2025-01-03T16:55:13.710917Z",
          "shell.execute_reply.started": "2025-01-03T16:55:12.658728Z"
        },
        "trusted": true
      },
      "outputs": [],
      "source": [
        "#TODO\n",
        "# YOUR CODE HERE"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {

      },
      "source": [
        "### Getting features...\n",
        "\n",
        "Building a mini framework to read our Dataframe and convert it to features.\n",
        "\n",
        "Now we will create features out of the dataset.\n",
        "\n",
        "FeaturesGetter iterates over an ActionsDataLoader (yield every actions between two 't[xx]') and apply a set of Feature contained in a FeaturePool. At the end, it gets metrics over the values registered by each features in the feature pool."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 35,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2025-01-03T16:55:13.713615Z",
          "iopub.status.busy": "2025-01-03T16:55:13.713229Z",
          "iopub.status.idle": "2025-01-03T16:55:13.718317Z",
          "shell.execute_reply": "2025-01-03T16:55:13.717234Z",
          "shell.execute_reply.started": "2025-01-03T16:55:13.713573Z"
        },
        "trusted": true
      },
      "outputs": [],
      "source": [
        "class CancelBatchException(Exception):\n",
        "    \"\"\"Used to cancel processing of a batch of data (when the keystroke sequence is fully read)\"\"\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 36,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2025-01-03T16:55:13.768867Z",
          "iopub.status.busy": "2025-01-03T16:55:13.768507Z",
          "iopub.status.idle": "2025-01-03T16:55:13.792121Z",
          "shell.execute_reply": "2025-01-03T16:55:13.790899Z",
          "shell.execute_reply.started": "2025-01-03T16:55:13.768832Z"
        },
        "trusted": true
      },
      "outputs": [],
      "source": [
        "class Feature:\n",
        "    def __init__(\n",
        "        self, name: str, \n",
        "        lambda_: Callable[[List[str]], Union[int, float]]=None, \n",
        "        val_count: int=None, \n",
        "        max_iter: int=None, \n",
        "        predicate: Callable[[List[str]], bool]=None, \n",
        "        metric: Callable[[List[str]], Union[int, float]]=mean, \n",
        "        div: bool=True\n",
        "    ):\n",
        "        \"\"\"If neither lambda_, val_count nor predicate are defined, the _lambda will just be the length of the given action range.\n",
        "\n",
        "        Args:\n",
        "            name (str): feature name\n",
        "            lambda_ (Callable[[List[str]], Union[int, float]], optional): \n",
        "                lambda that'll be applied to compute metric value over action ranges. Defaults to None.\n",
        "            val_count (int, optional): set feature's lambda to be the count of this value (if lambda_ is None). Defaults to None.\n",
        "            max_iter (int, optional): when exceeding this iteration, the feature will no longer be computed. Defaults to None.\n",
        "            predicate (Callable[[List[str]], bool], optional): define a predicate to compute lambda across one \n",
        "                action range (if lambda_ and val_count is None). Defaults to None.\n",
        "            metric (Callable[[List[str]], Union[int, float]], optional): the metric used to aggregate feature's \n",
        "                values across all ranges. Defaults to mean.\n",
        "            div (bool, optional): whether to divide the aggregated metric value. Defaults to True.\n",
        "        \"\"\"\n",
        "        self.name, self.metric, self.max_iter, self.div = name, metric, max_iter, div\n",
        "        self.reset()\n",
        "        self._lambda: Callable[[List[str]], Union[int, float]]\n",
        "        if   lambda_   is not None: \n",
        "            self._lambda = lambda_\n",
        "        elif val_count is not None: \n",
        "            self._lambda = lambda x: x.count(val_count)\n",
        "        elif predicate is not None: \n",
        "            self._lambda = lambda x: sum(1 for o in x if predicate(o))\n",
        "        else: \n",
        "            self._lambda = lambda x: len(x)\n",
        "    \n",
        "    def reset(self):\n",
        "        \"\"\"Resets the value of the feature\n",
        "        \"\"\"\n",
        "        self.vals: List[Union[int, float]] = []\n",
        "        self.val, self.i = 0, 0\n",
        "        \n",
        "    def __call__(self, rng: List[str], *args):\n",
        "        \"\"\"Compute feature's value according to _lambda, for given action range. Extra *args are given to _lambda\n",
        "\n",
        "        Args:\n",
        "            rng (List[str]): range of action (given by ActionDataLoader)\n",
        "        \"\"\"\n",
        "        if self.max_iter is None or self.i \u003C self.max_iter:\n",
        "            self.val = self._lambda(rng, *args)\n",
        "            self.vals.append(self.val)\n",
        "            self.i += 1\n",
        "            \n",
        "    @property\n",
        "    def value(self) -\u003E int | float:\n",
        "        \"\"\"Returns:\n",
        "            int | float: the aggregated feature's value across all action ranges read until now\n",
        "        \"\"\"\n",
        "        return self.metric(self.vals)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 54,
      "metadata": {

      },
      "outputs": [],
      "source": [
        "class ActionsDataLoader:\n",
        "    \"\"\"Iterates over two 'tXX', yielding actions between each time steps\n",
        "    \"\"\"\n",
        "    def __init__(self, actions: pd.DataFrame, do_range: bool = True, max_t: Optional[int]=None):\n",
        "        \"\"\"Args:\n",
        "            actions (List[Feature]): The raw dataframe\n",
        "            do_range (bool): whether the data loader should iterate and yield each range \n",
        "                between two 'tXX', or just yield the whole sequence once then return. Defaults to True.\n",
        "            max_t (Optional[int], optional): the t max to stop yielding. Defaults to None.\n",
        "        \"\"\"\n",
        "        self.t_indx = [0] + [j for j, val in enumerate(actions) if isinstance(val, str) and val[0] == \"t\"]\n",
        "        self.do_range = do_range\n",
        "        if max_t and max_t \u003C len(self.t_indx):\n",
        "            self.t_indx = self.t_indx[:max_t]\n",
        "            self.values = actions.values[: self.t_indx[max_t - 1]]\n",
        "        else: \n",
        "            self.values = actions.values\n",
        "        self.n_t = len(self.t_indx)\n",
        "    \n",
        "    def __len__(self): return 1 if self.do_range else (self.n_t or 1)\n",
        "    \n",
        "    def __iter__(self):\n",
        "        if self.n_t == 0 or not self.do_range:\n",
        "            self.start_indx = 0\n",
        "            self.end_indx = self._get_first_nan_indx()\n",
        "            yield self.values[self.start_indx:self.end_indx].tolist()\n",
        "            return \n",
        "        for self.i in range(self.n_t):\n",
        "            try:\n",
        "                self._get_actions_range()\n",
        "                yield self.values[self.start_indx:self.end_indx].tolist()\n",
        "            except CancelBatchException: \n",
        "                return\n",
        "\n",
        "    def _get_actions_range(self):\n",
        "        \"\"\"Computes the action range until a 'tXX' is met. If there are no more 'tXX', \n",
        "            it means we reached the end of the game, and the sequences finish with NaN \n",
        "            (or for the longest game, the full row is read).\n",
        "\n",
        "        Raises:\n",
        "            CancelBatchException: indicates that there is no more action to be read (next action is NaN).\n",
        "        \"\"\"\n",
        "        self.start_indx = self.t_indx[self.i] + (1 if self.i \u003E0 else 0)\n",
        "        if  self.start_indx \u003E= len(self.values) or pd.isna(self.values[self.start_indx]): \n",
        "            raise CancelBatchException\n",
        "        self.end_indx = self.t_indx[self.i + 1] if (self.i + 1) \u003C self.n_t else self._get_first_nan_indx()\n",
        "    \n",
        "    def _get_first_nan_indx(self) -\u003E int:\n",
        "        \"\"\"Returns:\n",
        "            int: the first index in values that is not NaN\n",
        "        \"\"\"\n",
        "        nans = np.argwhere(pd.isna(self.values[self.start_indx:]))\n",
        "        return len(self.values) if len(nans) == 0 else nans[0][0]\n",
        "    \n",
        "    def get_max_t(self):\n",
        "        \"\"\"Gets the last 'tXX' defined. If this data loader was defined with max_t not None, it returns this max_t\n",
        "        \"\"\"\n",
        "        if self.n_t - 1 == 0:\n",
        "            return 0\n",
        "        return int(self.values[self.t_indx[self.n_t - 1]][1:]) if self.n_t \u003E 0 else 0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 45,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2025-01-03T16:55:13.721365Z",
          "iopub.status.busy": "2025-01-03T16:55:13.721010Z",
          "iopub.status.idle": "2025-01-03T16:55:13.742731Z",
          "shell.execute_reply": "2025-01-03T16:55:13.741384Z",
          "shell.execute_reply.started": "2025-01-03T16:55:13.721334Z"
        },
        "trusted": true
      },
      "outputs": [],
      "source": [
        "class FeaturesGetter:\n",
        "    def __init__(self, features: List[Feature], n_rows: int=3052, log: bool=False, **kwargs_dataloader):\n",
        "        \"\"\"Args:\n",
        "            features (List[Feature]): the list of features to compute\n",
        "            n_rows (int, optional): the number of row (used only in log). Defaults to 3052.\n",
        "            log (bool, optional): whether to output log information when processing the df. Defaults to False.\n",
        "\n",
        "            Accepts extra kwargs_dataloader that'll be passed to the dataloader\n",
        "        \"\"\"\n",
        "        self.feature_pool, self.n_rows, self.log, self.kwargs_dataloader = features, n_rows, log, kwargs_dataloader\n",
        "        self.game_l: int # game length\n",
        "        self.reset()\n",
        "        \n",
        "    def reset(self):\n",
        "        \"\"\"Resets the value of each feature in the feature pool\"\"\"\n",
        "        for feature in self.feature_pool: \n",
        "            feature.reset()\n",
        "        self.game_l = 0\n",
        "    \n",
        "    def _log(self):\n",
        "        \"\"\"Print to stdout the current % of the df that have been processed\"\"\"\n",
        "        global cnt\n",
        "        cnt += 1\n",
        "        print(f\"{cnt * 100 / self.n_rows:.2f} %\", end=\"\\r\")\n",
        "    \n",
        "    def _one_update(self):\n",
        "        \"\"\"Compute each feature's value for one batch (one action range yielded by the ActionDataLoader)\"\"\"\n",
        "        for feature in self.feature_pool: \n",
        "            feature(self.actions_rng)\n",
        "        \n",
        "    def __call__(self, actions: pd.DataFrame) -\u003E pd.Series:\n",
        "        \"\"\"Computes all features' values for each of the given actions, iterating over ADL with parameters defined in __init__\n",
        "        \n",
        "        Returns:\n",
        "            pd.Series: the features' values as a Series. \n",
        "                Adds an extra feature which is the game length is max_t is not in __init__ kwargs\n",
        "        \"\"\"\n",
        "        self.reset()\n",
        "        if self.log:\n",
        "            self._log()\n",
        "        adl = ActionsDataLoader(actions, **self.kwargs_dataloader)\n",
        "        for self.actions_rng in adl:\n",
        "            self._one_update()\n",
        "        activs = [f.value / len(adl) if f.div else f.value for f in self.feature_pool]\n",
        "        self.game_l = (max_t := self.kwargs_dataloader.get(\"max_t\", None)) or adl.get_max_t()\n",
        "        return pd.Series( activs + ([self.game_l] if max_t is None else []) )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {

      },
      "source": [
        "Defining lambdas to convert dataset to features\n",
        "We create basic features, corresponding to the mean of each action played per timestamp plus the mean of all actions together"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 46,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2025-01-03T16:55:13.793923Z",
          "iopub.status.busy": "2025-01-03T16:55:13.793524Z",
          "iopub.status.idle": "2025-01-03T16:55:13.819750Z",
          "shell.execute_reply": "2025-01-03T16:55:13.818657Z",
          "shell.execute_reply.started": "2025-01-03T16:55:13.793889Z"
        },
        "trusted": true
      },
      "outputs": [],
      "source": [
        "FEATURES_NAMES = [\"s_mean\", \"base_mean\", \"mineral_mean\", \"hotkeys_mean\", \"actions_mean\"]\n",
        "ACTIONS = [ \"s\", \"Base\", \"SingleMineral\", \"hotkey\" ]\n",
        "\n",
        "def get_base_features() -\u003E List[Feature]:\n",
        "    \"\"\"Defines base features (mean of count of each action / hotkeys)\n",
        "    \"\"\"\n",
        "    features = []\n",
        "    for i, action in enumerate(ACTIONS[:-1]):\n",
        "        features.append(Feature(FEATURES_NAMES[i], val_count=action))\n",
        "    features.append(Feature(FEATURES_NAMES[-2], predicate=lambda x: x.startswith(ACTIONS[-1]))) # hotkeys\n",
        "    features.append(Feature(FEATURES_NAMES[-1])) # all actions combined (no lambda_ means lambda_ is just the length)\n",
        "    for i, j in product(range(10), range(3)):\n",
        "        pass\n",
        "    #TODO\n",
        "        # OTHER FEATURES HERE\n",
        "    # OTHER FEATURES HERE\n",
        "    # Guess what would be useful ?\n",
        "    # set div, metric and lambda_ accordingly\n",
        "    return features"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {

      },
      "source": [
        "Now it's ready to be put into a function that'll get all the features from the initial dataframe and return a new dataframe containing only those features. FeaturesGetter gets one extra feature from that we created, which is max_time, corresponding to the \"xx\" of the last \"txx\" seen."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 47,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2025-01-03T16:55:13.821292Z",
          "iopub.status.busy": "2025-01-03T16:55:13.820913Z",
          "iopub.status.idle": "2025-01-03T16:55:13.844551Z",
          "shell.execute_reply": "2025-01-03T16:55:13.842925Z",
          "shell.execute_reply.started": "2025-01-03T16:55:13.821256Z"
        },
        "trusted": true
      },
      "outputs": [],
      "source": [
        "features_getter = None\n",
        "def create_features(\n",
        "    df: pd.DataFrame, \n",
        "    min_: int, \n",
        "    max_: int, \n",
        "    drop: bool=False, \n",
        "    features: List[Feature]=get_base_features(), \n",
        "    **kwargs\n",
        ") -\u003E pd.DataFrame:\n",
        "    \"\"\"Compute features on given dataframe\n",
        "\n",
        "    Args:\n",
        "        df (pd.DataFrame)\n",
        "        min_ (int): index of the first action to pass to the feature\n",
        "        max_ (int): index of the last action to pass to the feature\n",
        "        drop (bool, optional): whether to drop original columns of the dataframe. Defaults to False.\n",
        "        features (List[Feature], optional). Defaults to get_base_features().\n",
        "\n",
        "    Returns:\n",
        "        pd.DataFrame: a dataframe containing features' values for each row\n",
        "    \"\"\"\n",
        "    global features_getter\n",
        "    features_getter = FeaturesGetter(features, **kwargs)\n",
        "    final_df = df.loc[:,min_:max_].apply(features_getter, axis=1, result_type='expand')\n",
        "    final_df.columns = [f.name for f in features_getter.feature_pool] + ([\"max_time\"] if not kwargs.get(\"max_t\") else [])\n",
        "    if drop:\n",
        "        df = df.drop(columns=[i for i in range(min_, max_ + 1)])\n",
        "    final_df = pd.concat([df, final_df], axis=1)\n",
        "    features_getter.reset()\n",
        "    return final_df"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {

      },
      "source": [
        "### Handling string\n",
        "The race_played column can only take three values; instead of converting it to categorical as we did with our dependent variable, we will instead convert it to dummy variables: we one-hot encode each race. It will not add many columns to our dataframe (only three) but will allow the decision trees to split much faster on the race (on only one binary split)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 41,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2025-01-03T16:55:13.845825Z",
          "iopub.status.busy": "2025-01-03T16:55:13.845491Z",
          "iopub.status.idle": "2025-01-03T16:55:13.874662Z",
          "shell.execute_reply": "2025-01-03T16:55:13.873395Z",
          "shell.execute_reply.started": "2025-01-03T16:55:13.845795Z"
        },
        "trusted": true
      },
      "outputs": [],
      "source": [
        "def get_dummies(df: pd.DataFrame):\n",
        "    \"\"\"Converts textual columns to one-hot encoded vectors (one column per possible value)\"\"\"\n",
        "    df = pd.get_dummies(df, columns=[\"played_race\"])\n",
        "    return df"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {

      },
      "source": [
        "Function preprocess creates a pipeline of all the function we just implemented: it create the features, converts the race to dummy variables and the dependent variable to category codes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 48,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2025-01-03T16:55:13.876319Z",
          "iopub.status.busy": "2025-01-03T16:55:13.875931Z",
          "iopub.status.idle": "2025-01-03T16:55:13.895928Z",
          "shell.execute_reply": "2025-01-03T16:55:13.894611Z",
          "shell.execute_reply.started": "2025-01-03T16:55:13.876270Z"
        },
        "trusted": true
      },
      "outputs": [],
      "source": [
        "def preprocess(df: pd.DataFrame, min_: int, max_: int, is_train: bool=True, convert_race: bool=True, **kwargs):\n",
        "    \"\"\"Calls FeatureGetter on the dataframe, applying preprocessing steps before\n",
        "    Args:\n",
        "        df (pd.DataFrame)\n",
        "        min_ (int)\n",
        "        max_ (int)\n",
        "        is_train (bool, optional): whether the current dataframe contains training data \n",
        "            (to preprocess dependent variable or not). Defaults to True.\n",
        "        convert_race (bool, optional): whether to convert race attribute to dummies. Defaults to True.\n",
        "\n",
        "    Returns:\n",
        "        _type_: _description_\n",
        "    \"\"\"\n",
        "    df = create_features(df, min_, max_, **kwargs)\n",
        "    if convert_race: \n",
        "        df = get_dummies(df)\n",
        "    df.columns = df.columns.astype(str)\n",
        "    if is_train:\n",
        "        to_categories(df)\n",
        "    return df"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 56,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2025-01-03T16:55:13.897589Z",
          "iopub.status.busy": "2025-01-03T16:55:13.897163Z",
          "iopub.status.idle": "2025-01-03T16:55:43.137764Z",
          "shell.execute_reply": "2025-01-03T16:55:43.136556Z",
          "shell.execute_reply.started": "2025-01-03T16:55:13.897548Z"
        },
        "trusted": true
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "CPU times: user 44.3 s, sys: 181 ms, total: 44.5 s\n",
            "Wall time: 44.3 s\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\u003Cdiv\u003E\n",
              "\u003Cstyle scoped\u003E\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "\u003C/style\u003E\n",
              "\u003Ctable border=\"1\" class=\"dataframe\"\u003E\n",
              "  \u003Cthead\u003E\n",
              "    \u003Ctr style=\"text-align: right;\"\u003E\n",
              "      \u003Cth\u003E\u003C/th\u003E\n",
              "      \u003Cth\u003Ebattleneturl\u003C/th\u003E\n",
              "      \u003Cth\u003Es_mean\u003C/th\u003E\n",
              "      \u003Cth\u003Ebase_mean\u003C/th\u003E\n",
              "      \u003Cth\u003Emineral_mean\u003C/th\u003E\n",
              "      \u003Cth\u003Ehotkeys_mean\u003C/th\u003E\n",
              "      \u003Cth\u003Eactions_mean\u003C/th\u003E\n",
              "      \u003Cth\u003Ehotkey00_mean\u003C/th\u003E\n",
              "      \u003Cth\u003Ehotkey01_mean\u003C/th\u003E\n",
              "      \u003Cth\u003Ehotkey02_mean\u003C/th\u003E\n",
              "      \u003Cth\u003Ehotkey10_mean\u003C/th\u003E\n",
              "      \u003Cth\u003E...\u003C/th\u003E\n",
              "      \u003Cth\u003Ehotkey80_mean\u003C/th\u003E\n",
              "      \u003Cth\u003Ehotkey81_mean\u003C/th\u003E\n",
              "      \u003Cth\u003Ehotkey82_mean\u003C/th\u003E\n",
              "      \u003Cth\u003Ehotkey90_mean\u003C/th\u003E\n",
              "      \u003Cth\u003Ehotkey91_mean\u003C/th\u003E\n",
              "      \u003Cth\u003Ehotkey92_mean\u003C/th\u003E\n",
              "      \u003Cth\u003Emax_time\u003C/th\u003E\n",
              "      \u003Cth\u003Eplayed_race_Protoss\u003C/th\u003E\n",
              "      \u003Cth\u003Eplayed_race_Terran\u003C/th\u003E\n",
              "      \u003Cth\u003Eplayed_race_Zerg\u003C/th\u003E\n",
              "    \u003C/tr\u003E\n",
              "  \u003C/thead\u003E\n",
              "  \u003Ctbody\u003E\n",
              "    \u003Ctr\u003E\n",
              "      \u003Cth\u003E0\u003C/th\u003E\n",
              "      \u003Ctd\u003E53\u003C/td\u003E\n",
              "      \u003Ctd\u003E2.036254\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.199396\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.015106\u003C/td\u003E\n",
              "      \u003Ctd\u003E4.492447\u003C/td\u003E\n",
              "      \u003Ctd\u003E6.743202\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.015106\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.0\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.123867\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.190332\u003C/td\u003E\n",
              "      \u003Ctd\u003E...\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.000000\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.0\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.000000\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.012085\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.0\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.000000\u003C/td\u003E\n",
              "      \u003Ctd\u003E1655.0\u003C/td\u003E\n",
              "      \u003Ctd\u003ETrue\u003C/td\u003E\n",
              "      \u003Ctd\u003EFalse\u003C/td\u003E\n",
              "      \u003Ctd\u003EFalse\u003C/td\u003E\n",
              "    \u003C/tr\u003E\n",
              "    \u003Ctr\u003E\n",
              "      \u003Cth\u003E1\u003C/th\u003E\n",
              "      \u003Ctd\u003E29\u003C/td\u003E\n",
              "      \u003Ctd\u003E1.620482\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.036145\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.000000\u003C/td\u003E\n",
              "      \u003Ctd\u003E4.596386\u003C/td\u003E\n",
              "      \u003Ctd\u003E6.253012\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.006024\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.0\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.250000\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.195783\u003C/td\u003E\n",
              "      \u003Ctd\u003E...\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.003012\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.0\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.048193\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.003012\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.0\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.054217\u003C/td\u003E\n",
              "      \u003Ctd\u003E1655.0\u003C/td\u003E\n",
              "      \u003Ctd\u003ETrue\u003C/td\u003E\n",
              "      \u003Ctd\u003EFalse\u003C/td\u003E\n",
              "      \u003Ctd\u003EFalse\u003C/td\u003E\n",
              "    \u003C/tr\u003E\n",
              "    \u003Ctr\u003E\n",
              "      \u003Cth\u003E2\u003C/th\u003E\n",
              "      \u003Ctd\u003E53\u003C/td\u003E\n",
              "      \u003Ctd\u003E2.128713\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.232673\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.014851\u003C/td\u003E\n",
              "      \u003Ctd\u003E4.297030\u003C/td\u003E\n",
              "      \u003Ctd\u003E6.673267\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.014851\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.0\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.089109\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.084158\u003C/td\u003E\n",
              "      \u003Ctd\u003E...\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.000000\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.0\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.000000\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.009901\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.0\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.009901\u003C/td\u003E\n",
              "      \u003Ctd\u003E1010.0\u003C/td\u003E\n",
              "      \u003Ctd\u003ETrue\u003C/td\u003E\n",
              "      \u003Ctd\u003EFalse\u003C/td\u003E\n",
              "      \u003Ctd\u003EFalse\u003C/td\u003E\n",
              "    \u003C/tr\u003E\n",
              "    \u003Ctr\u003E\n",
              "      \u003Cth\u003E3\u003C/th\u003E\n",
              "      \u003Ctd\u003E29\u003C/td\u003E\n",
              "      \u003Ctd\u003E1.965347\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.103960\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.000000\u003C/td\u003E\n",
              "      \u003Ctd\u003E4.787129\u003C/td\u003E\n",
              "      \u003Ctd\u003E6.856436\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.009901\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.0\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.188119\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.158416\u003C/td\u003E\n",
              "      \u003Ctd\u003E...\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.000000\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.0\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.000000\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.004950\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.0\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.069307\u003C/td\u003E\n",
              "      \u003Ctd\u003E1005.0\u003C/td\u003E\n",
              "      \u003Ctd\u003ETrue\u003C/td\u003E\n",
              "      \u003Ctd\u003EFalse\u003C/td\u003E\n",
              "      \u003Ctd\u003EFalse\u003C/td\u003E\n",
              "    \u003C/tr\u003E\n",
              "    \u003Ctr\u003E\n",
              "      \u003Cth\u003E4\u003C/th\u003E\n",
              "      \u003Ctd\u003E53\u003C/td\u003E\n",
              "      \u003Ctd\u003E1.925926\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.018519\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.000000\u003C/td\u003E\n",
              "      \u003Ctd\u003E3.787037\u003C/td\u003E\n",
              "      \u003Ctd\u003E5.731481\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.009259\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.0\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.000000\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.092593\u003C/td\u003E\n",
              "      \u003Ctd\u003E...\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.000000\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.0\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.000000\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.000000\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.0\u003C/td\u003E\n",
              "      \u003Ctd\u003E0.000000\u003C/td\u003E\n",
              "      \u003Ctd\u003E540.0\u003C/td\u003E\n",
              "      \u003Ctd\u003ETrue\u003C/td\u003E\n",
              "      \u003Ctd\u003EFalse\u003C/td\u003E\n",
              "      \u003Ctd\u003EFalse\u003C/td\u003E\n",
              "    \u003C/tr\u003E\n",
              "  \u003C/tbody\u003E\n",
              "\u003C/table\u003E\n",
              "\u003Cp\u003E5 rows × 40 columns\u003C/p\u003E\n",
              "\u003C/div\u003E"
            ],
            "text/plain": [
              "   battleneturl    s_mean  base_mean  mineral_mean  hotkeys_mean  \\\n",
              "0            53  2.036254   0.199396      0.015106      4.492447   \n",
              "1            29  1.620482   0.036145      0.000000      4.596386   \n",
              "2            53  2.128713   0.232673      0.014851      4.297030   \n",
              "3            29  1.965347   0.103960      0.000000      4.787129   \n",
              "4            53  1.925926   0.018519      0.000000      3.787037   \n",
              "\n",
              "   actions_mean  hotkey00_mean  hotkey01_mean  hotkey02_mean  hotkey10_mean  \\\n",
              "0      6.743202       0.015106            0.0       0.123867       0.190332   \n",
              "1      6.253012       0.006024            0.0       0.250000       0.195783   \n",
              "2      6.673267       0.014851            0.0       0.089109       0.084158   \n",
              "3      6.856436       0.009901            0.0       0.188119       0.158416   \n",
              "4      5.731481       0.009259            0.0       0.000000       0.092593   \n",
              "\n",
              "   ...  hotkey80_mean  hotkey81_mean  hotkey82_mean  hotkey90_mean  \\\n",
              "0  ...       0.000000            0.0       0.000000       0.012085   \n",
              "1  ...       0.003012            0.0       0.048193       0.003012   \n",
              "2  ...       0.000000            0.0       0.000000       0.009901   \n",
              "3  ...       0.000000            0.0       0.000000       0.004950   \n",
              "4  ...       0.000000            0.0       0.000000       0.000000   \n",
              "\n",
              "   hotkey91_mean  hotkey92_mean  max_time  played_race_Protoss  \\\n",
              "0            0.0       0.000000    1655.0                 True   \n",
              "1            0.0       0.054217    1655.0                 True   \n",
              "2            0.0       0.009901    1010.0                 True   \n",
              "3            0.0       0.069307    1005.0                 True   \n",
              "4            0.0       0.000000     540.0                 True   \n",
              "\n",
              "   played_race_Terran  played_race_Zerg  \n",
              "0               False             False  \n",
              "1               False             False  \n",
              "2               False             False  \n",
              "3               False             False  \n",
              "4               False             False  \n",
              "\n",
              "[5 rows x 40 columns]"
            ]
          },
          "execution_count": 56,
          "metadata": {

          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "%%time\n",
        "# Set the notebook-level `cnt` to 0 before the run.\n",
        "# NOTE(review): presumably a counter consumed by preprocess/logging -- confirm against its definition.\n",
        "cnt = 0\n",
        "\n",
        "# Run preprocess on the training features.\n",
        "# NOTE(review): the positional 0 and `shape[1] - 3` look like a column range -- confirm against preprocess.\n",
        "processed_df = preprocess(\n",
        "    features_train,\n",
        "    0,\n",
        "    features_train.shape[1] - 3,\n",
        "    drop=True,\n",
        "    n_rows=3052,\n",
        "    log=True,\n",
        ")\n",
        "\n",
        "# Display the first rows of the result as the cell output.\n",
        "processed_df.head()"
      ]
    }
  ],
  "metadata": {
    "kaggle": {
      "accelerator": "none",
      "dataSources": [
        {
          "datasetId": 6399625,
          "sourceId": 10335289,
          "sourceType": "datasetVersion"
        }
      ],
      "dockerImageVersionId": 30822,
      "isGpuEnabled": false,
      "isInternetEnabled": true,
      "language": "python",
      "sourceType": "notebook"
    },
    "kernelspec": {
      "display_name": ".venv",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 4
}