diff --git a/examples/classification_with_NODE.ipynb b/examples/classification_with_NODE.ipynb
index cccf343e..6b8ef3fb 100644
--- a/examples/classification_with_NODE.ipynb
+++ b/examples/classification_with_NODE.ipynb
@@ -3,31 +3,10 @@
 {
  "cell_type": "code",
  "execution_count": 1,
-  "metadata": {
-   "Collapsed": "false"
-  },
-  "outputs": [],
  "source": [
   "import os\n",
   "os.chdir(\"..\")\n",
   "from sklearn.datasets import fetch_covtype\n",
-   "import random\n",
-   "import numpy as np\n",
-   "import pandas as pd\n",
-   "import lightgbm as lgb\n",
-   "from sklearn.metrics import accuracy_score, f1_score\n",
-   "%load_ext autoreload\n",
-   "%autoreload 2"
-  ]
- },
- {
-  "cell_type": "code",
-  "execution_count": 1,
-  "metadata": {
-   "Collapsed": "false"
-  },
-  "outputs": [],
-  "source": [
   "from sklearn.datasets import make_classification\n",
   "from sklearn.model_selection import train_test_split\n",
   "import random\n",
@@ -35,29 +14,50 @@
   "import pandas as pd\n",
   "import lightgbm as lgb\n",
   "from sklearn.metrics import accuracy_score, f1_score\n",
-   "import os\n",
-   "os.chdir(\"..\")\n",
+   "\n",
   "%load_ext autoreload\n",
   "%autoreload 2"
-  ]
+  ],
+  "outputs": [],
+  "metadata": {
+   "Collapsed": "false"
+  }
 },
 {
  "cell_type": "markdown",
-  "metadata": {
-   "Collapsed": "false"
-  },
  "source": [
   "# Utility Functions"
-  ]
+  ],
+  "metadata": {
+   "Collapsed": "false"
+  }
 },
 {
  "cell_type": "code",
  "execution_count": 2,
-  "metadata": {
-   "Collapsed": "false"
-  },
-  "outputs": [],
  "source": [
+   "def make_mixed_classification(n_samples, n_features, n_categories):\n",
+   "    \"\"\"Generate a toy classification frame with a mix of numerical and categorical features.\"\"\"\n",
+   "    X, y = make_classification(n_samples=n_samples, n_features=n_features, random_state=42, n_informative=5)\n",
+   "    # sample *distinct* column indices; random.choices samples with replacement\n",
+   "    # and could return fewer than n_categories unique columns\n",
+   "    cat_cols = random.sample(range(X.shape[-1]), k=n_categories)\n",
+   "    num_cols = [i for i in range(X.shape[-1]) if i not in cat_cols]\n",
+   "    # discretize the selected columns into 4 quantile bins to make them categorical\n",
+   "    for col in cat_cols:\n",
+   "        X[:, col] = pd.qcut(X[:, col], q=4).codes.astype(int)\n",
+   "    col_names = []\n",
+   "    num_col_names = []\n",
+   "    cat_col_names = []\n",
+   "    for i in range(X.shape[-1]):\n",
+   "        if i in cat_cols:\n",
+   "            col_names.append(f\"cat_col_{i}\")\n",
+   "            cat_col_names.append(f\"cat_col_{i}\")\n",
+   "        if i in num_cols:\n",
+   "            col_names.append(f\"num_col_{i}\")\n",
+   "            num_col_names.append(f\"num_col_{i}\")\n",
+   "    X = pd.DataFrame(X, columns=col_names)\n",
+   "    y = pd.Series(y, name=\"target\")\n",
+   "    data = X.join(y)\n",
+   "    return data, cat_col_names, num_col_names\n",
+   "\n",
   "def load_classification_data():\n",
   "    dataset = fetch_covtype(data_home=\"data\")\n",
   "    data = np.hstack([dataset.data, dataset.target.reshape(-1, 1)])\n",
@@ -83,104 +83,106 @@
   "    val_acc = accuracy_score(y_true, y_pred)\n",
   "    val_f1 = f1_score(y_true, y_pred)\n",
   "    print(f\"{tag} Acc: {val_acc} | {tag} F1: {val_f1}\")"
-  ]
+  ],
+  "outputs": [],
+  "metadata": {
+   "Collapsed": "false"
+  }
 },
 {
  "cell_type": "markdown",
-  "metadata": {
-   "Collapsed": "false"
-  },
  "source": [
   "# Generate Synthetic Data\n",
   "\n",
   "First, let's create a synthetic dataset that is a mix of numerical and categorical features"
-  ]
+  ],
+  "metadata": {
+   "Collapsed": "false"
+  }
 },
 {
  "cell_type": "code",
  "execution_count": 3,
-  "metadata": {
-   "Collapsed": "false"
-  },
-  "outputs": [],
  "source": [
   "data, cat_col_names, num_col_names = make_mixed_classification(n_samples=10000, n_features=20, n_categories=4)\n",
   "train, test = train_test_split(data, random_state=42)\n",
   "train, val = train_test_split(train, random_state=42)"
-  ]
+  ],
+  "outputs": [],
+  "metadata": {
+   "Collapsed": "false"
+  }
 },
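+{
+ "cell_type": "markdown",
+ "source": [
+  "Purely as an illustration (this cell is not part of the original run), we can peek at the generated frame: the `cat_col_*` columns hold quantile-bin codes, the `num_col_*` columns stay continuous, and `target` is the binary label."
+ ],
+ "metadata": {
+  "Collapsed": "false"
+ }
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "source": [
+  "# Illustrative sketch: inspect the mixed-type frame and the class balance.\n",
+  "print(\"categorical columns:\", cat_col_names)\n",
+  "print(data['target'].value_counts(normalize=True))\n",
+  "data.head()"
+ ],
+ "outputs": [],
+ "metadata": {
+  "Collapsed": "false"
+ }
+},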
"metadata": { - "Collapsed": "false" - }, "source": [ "## Baseline\n", "\n", "Let's use the default LightGBM model as a baseline." - ] + ], + "metadata": { + "Collapsed": "false" + } }, { "cell_type": "code", "execution_count": 4, - "metadata": { - "Collapsed": "false" - }, + "source": [ + "clf = lgb.LGBMClassifier(random_state=42)\n", + "clf.fit(train.drop(columns='target'), train['target'], categorical_feature=cat_col_names)\n", + "val_pred = clf.predict(val.drop(columns='target'))\n", + "print_metrics(val['target'], val_pred, \"Validation\")\n", + "test_pred = clf.predict(test.drop(columns='target'))\n", + "print_metrics(test['target'], test_pred, \"Holdout\")" + ], "outputs": [ { - "name": "stdout", "output_type": "stream", + "name": "stderr", "text": [ - "Validation Acc: 0.9290666666666667 | Validation F1: 0.9285330467490596\n", - "Holdout Acc: 0.9344 | Holdout F1: 0.9346613545816733\n" + "/home/fonnesbeck/anaconda3/envs/pitch_effect/lib/python3.9/site-packages/lightgbm/basic.py:1702: UserWarning: Using categorical_feature in Dataset.\n", + " _log_warning('Using categorical_feature in Dataset.')\n" ] }, { - "name": "stderr", "output_type": "stream", + "name": "stdout", "text": [ - "D:\\miniconda3\\envs\\df_encoder\\lib\\site-packages\\lightgbm\\basic.py:1551: UserWarning: Using categorical_feature in Dataset.\n", - " warnings.warn('Using categorical_feature in Dataset.')\n" + "Validation Acc: 0.9328 | Validation F1: 0.9322580645161291\n", + "Holdout Acc: 0.9328 | Holdout F1: 0.9330677290836654\n" ] } ], - "source": [ - "clf = lgb.LGBMClassifier(random_state=42)\n", - "clf.fit(train.drop(columns='target'), train['target'], categorical_feature=cat_col_names)\n", - "val_pred = clf.predict(val.drop(columns='target'))\n", - "print_metrics(val['target'], val_pred, \"Validation\")\n", - "test_pred = clf.predict(test.drop(columns='target'))\n", - "print_metrics(test['target'], test_pred, \"Holdout\")" - ] + "metadata": { + "Collapsed": "false" + } }, { "cell_type": "markdown", - "metadata": { - "Collapsed": "false" - }, "source": [ "# Importing the Library" - ] + ], + "metadata": { + "Collapsed": "false" + } }, { "cell_type": "code", - "execution_count": 3, - "metadata": { - "Collapsed": "false" - }, - "outputs": [], + "execution_count": 18, "source": [ "from pytorch_tabular import TabularModel\n", "from pytorch_tabular.models import CategoryEmbeddingModelConfig, NodeConfig, TabNetModelConfig\n", "from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig, ExperimentConfig\n", - "from pytorch_tabular.category_encoders import CategoricalEmbeddingTransformer" - ] + "from pytorch_tabular.categorical_encoders import CategoricalEmbeddingTransformer\n", + "\n" + ], + "outputs": [], + "metadata": { + "Collapsed": "false" + } }, { "cell_type": "markdown", - "metadata": { - "Collapsed": "false" - }, "source": [ "## Define the Configs\n", "\n", @@ -191,15 +193,14 @@ "* TrainerConfig - This let's you configure the training process by setting things like batch_size, epochs, early stopping, etc. The vast majority of parameters are directly borrowed from PyTorch Lightning and is passed to the underlying Trainer object during training\n", "* OptimizerConfig - This let's you define and use different Optimizers and LearningRate Schedulers. Standard PyTorch Optimizers and Learning RateSchedulers are supported. For custom optimizers, you can use the parameter in the fit method to overwrite this. 
 {
  "cell_type": "code",
-  "execution_count": 14,
-  "metadata": {
-   "Collapsed": "false"
-  },
-  "outputs": [],
+  "execution_count": 11,
  "source": [
   "data_config = DataConfig(\n",
   "    target=['target'], # target should always be a list; multi-target is only supported for regression, and multi-task classification is not implemented\n",
@@ -212,7 +213,8 @@
   "    auto_lr_find=True, # runs the LR finder to automatically derive a learning rate\n",
   "    batch_size=1024,\n",
   "    max_epochs=1000,\n",
-   "    gpus=1, #index of the GPU to use. 0, means CPU\n",
+   "    auto_select_gpus=False,\n",
+   "    gpus=0, # number of GPUs to train on; 0 means CPU\n",
   ")\n",
   "optimizer_config = OptimizerConfig()\n",
   "model_config = CategoryEmbeddingModelConfig(\n",
@@ -228,276 +230,182 @@
   "    optimizer_config=optimizer_config,\n",
   "    trainer_config=trainer_config,\n",
   ")"
-  ]
+  ],
+  "outputs": [],
+  "metadata": {
+   "Collapsed": "false"
+  }
 },
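+{
+ "cell_type": "markdown",
+ "source": [
+  "As noted above, the optimizer from `OptimizerConfig` can be overridden when fitting. The next cell is only a sketch of that pattern, kept commented out so it does not trigger a second training run here; the `validation`, `optimizer`, and `optimizer_params` arguments are assumed from the `fit` signature and are worth verifying against the installed version, and the `weight_decay` value is arbitrary."
+ ],
+ "metadata": {
+  "Collapsed": "false"
+ }
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "source": [
+  "# Sketch: override the optimizer and pass an explicit validation dataframe.\n",
+  "# import torch\n",
+  "# tabular_model.fit(\n",
+  "#     train=train,\n",
+  "#     validation=val,\n",
+  "#     test=test,\n",
+  "#     optimizer=torch.optim.AdamW,  # any PyTorch-compatible optimizer class\n",
+  "#     optimizer_params={\"weight_decay\": 1e-5},  # arbitrary illustrative value\n",
+  "# )"
+ ],
+ "outputs": [],
+ "metadata": {
+  "Collapsed": "false"
+ }
+},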
 {
  "cell_type": "markdown",
-  "metadata": {
-   "Collapsed": "true"
-  },
  "source": [
   "## Training the Model\n",
   "Now that we have defined the configs and the TabularModel, we just need to call the `fit` method and pass the train and test dataframes. We can also pass in a validation dataframe; if omitted, TabularModel will set aside 20% of the data as validation."
-  ]
+  ],
+  "metadata": {
+   "Collapsed": "true"
+  }
 },
 {
  "cell_type": "code",
-  "execution_count": 15,
-  "metadata": {
-   "Collapsed": "false",
-   "collapsed": true,
-   "jupyter": {
-    "outputs_hidden": true
-   },
-   "tags": []
-  },
+  "execution_count": 12,
+  "source": [
+   "tabular_model.fit(train=train, test=test)"
+  ],
  "outputs": [
   {
-   "name": "stderr",
    "output_type": "stream",
+   "name": "stderr",
    "text": [
-    [... ~90 lines of old stderr elided: a checkpoint-directory warning, duplicated "GPU available"/"TPU available" banners, duplicated layer summaries, LR-finder progress bars (stopped early at 84%), and "Learning rate set to 0.0003019951720402019" ...]
+    "GPU available: False, used: False\n",
    "TPU available: False, using: 0 TPU cores\n",
+    "\n",
+    "  | Name                   | Type                | Params\n",
+    "---------------------------------------------------------------\n",
+    "0 | embedding_layers       | ModuleList          | 45    \n",
+    "1 | normalizing_batch_norm | BatchNorm1d         | 34    \n",
+    "2 | backbone               | FeedForwardBackbone | 19.0 M\n",
+    "3 | output_layer           | Linear              | 1.0 K \n",
+    "4 | loss                   | CrossEntropyLoss    | 0     \n",
+    "/home/fonnesbeck/anaconda3/envs/pitch_effect/lib/python3.9/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 4 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
+    "  warnings.warn(*args, **kwargs)\n",
+    "/home/fonnesbeck/anaconda3/envs/pitch_effect/lib/python3.9/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 4 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
+    "  warnings.warn(*args, **kwargs)\n",
+    "Finding best initial lr:  88%|████████▊ | 88/100 [04:20<00:35,  2.96s/it]\n",
+    "LR finder stopped early due to diverging loss.\n",
+    "Learning rate set to 0.0005248074602497723\n",
+    "\n",
+    "  | Name                   | Type                | Params\n",
+    "---------------------------------------------------------------\n",
+    "0 | embedding_layers       | ModuleList          | 45    \n",
+    "1 | normalizing_batch_norm | BatchNorm1d         | 34    \n",
+    "2 | backbone               | FeedForwardBackbone | 19.0 M\n",
+    "3 | output_layer           | Linear              | 1.0 K \n",
+    "4 | loss                   | CrossEntropyLoss    | 0     \n"
    ]
   },
   {
-   "name": "stdout",
    "output_type": "stream",
+   "name": "stdout",
    "text": [
-    [... ~70 lines of old stdout elided: per-epoch progress bars for epochs 1-39, ending at loss=0.509, valid_loss=0.4, valid_accuracy=0.524 ...]
+    "Epoch 32: 100%|██████████| 7/7 [00:35<00:00, 7.09s/it, loss=0.552, train_loss=0.59, valid_loss=0.423, valid_accuracy=0.831, train_accuracy=0.698]"
    ]
   }
  ],
-  "source": [
-   "tabular_model.fit(train=train, test=test)"
-  ]
+  "metadata": {
+   "Collapsed": "false",
+   "collapsed": true,
+   "jupyter": {
+    "outputs_hidden": true
+   },
+   "tags": []
+  }
 },
 {
  "cell_type": "markdown",
-  "metadata": {
-   "Collapsed": "false"
-  },
  "source": [
   "## Evaluating the Model\n",
   "To evaluate the model on new data, using the same metrics/loss that were used during training, we can use the `evaluate` method"
-  ]
+  ],
+  "metadata": {
+   "Collapsed": "false"
+  }
 },
 {
  "cell_type": "code",
-  "execution_count": 16,
-  "metadata": {
-   "Collapsed": "false"
-  },
+  "execution_count": 13,
+  "source": [
+   "result = tabular_model.evaluate(test)\n",
+   "print(result)"
+  ],
  "outputs": [
   {
-   "name": "stdout",
    "output_type": "stream",
+   "name": "stderr",
    "text": [
-    "Testing: 100%|███████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 25.28it/s]--------------------------------------------------------------------------------\n",
-    "DATALOADER:0 TEST RESULTS\n",
-    "{'test_accuracy': tensor(0.5808, device='cuda:0'),\n",
-    " 'train_accuracy': tensor(0.5936, device='cuda:0'),\n",
-    " 'train_loss': tensor(0.5825, device='cuda:0'),\n",
-    " 'valid_accuracy': tensor(0.5244, device='cuda:0'),\n",
-    " 'valid_loss': tensor(0.4004, device='cuda:0')}\n",
-    "--------------------------------------------------------------------------------\n",
-    "Testing: 100%|███████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 21.18it/s]\n",
-    "[{'train_loss': 0.582484245300293, 'valid_loss': 0.40044283866882324, 'valid_accuracy': 0.5244444608688354, 'train_accuracy': 0.5935624241828918, 'test_accuracy': 0.5807999968528748}]\n"
+    "/home/fonnesbeck/anaconda3/envs/pitch_effect/lib/python3.9/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: The dataloader, test dataloader 0, does not have many workers which may be a bottleneck. 
Consider increasing the value of the `num_workers` argument` (try 4 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", + " warnings.warn(*args, **kwargs)\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", "text": [ - "Testing: 100%|███████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 25.28it/s]--------------------------------------------------------------------------------\n", + "Testing: 100%|██████████| 3/3 [00:15<00:00, 4.14s/it]--------------------------------------------------------------------------------\n", "DATALOADER:0 TEST RESULTS\n", - "{'test_accuracy': tensor(0.5808, device='cuda:0'),\n", - " 'train_accuracy': tensor(0.5936, device='cuda:0'),\n", - " 'train_loss': tensor(0.5825, device='cuda:0'),\n", - " 'valid_accuracy': tensor(0.5244, device='cuda:0'),\n", - " 'valid_loss': tensor(0.4004, device='cuda:0')}\n", + "{'test_accuracy': tensor(0.8416),\n", + " 'train_accuracy': tensor(0.7048),\n", + " 'train_loss': tensor(0.5903),\n", + " 'valid_accuracy': tensor(0.8311),\n", + " 'valid_loss': tensor(0.4232)}\n", "--------------------------------------------------------------------------------\n", - "Testing: 100%|███████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 21.18it/s]\n", - "[{'train_loss': 0.582484245300293, 'valid_loss': 0.40044283866882324, 'valid_accuracy': 0.5244444608688354, 'train_accuracy': 0.5935624241828918, 'test_accuracy': 0.5807999968528748}]\n" + "Testing: 100%|██████████| 3/3 [00:15<00:00, 5.22s/it]\n", + "[{'train_loss': 0.5903493762016296, 'valid_loss': 0.4231547713279724, 'valid_accuracy': 0.8311111330986023, 'train_accuracy': 0.7048248052597046, 'test_accuracy': 0.8416000008583069}]\n" ] } ], - "source": [ - "result = tabular_model.evaluate(test)\n", - "print(result)" - ] + "metadata": { + "Collapsed": "false" + } }, { "cell_type": "markdown", - "metadata": { - "Collapsed": "false" - }, "source": [ "To get the prediction as a dataframe, we can use the `predict` method. This will add predictions to the same dataframe that was passed in. For classification problems, we get both the probabilities and the final prediction taking 0.5 as the threshold" - ] + ], + "metadata": { + "Collapsed": "false" + } }, { "cell_type": "code", - "execution_count": 17, - "metadata": { - "Collapsed": "false" - }, + "execution_count": 14, + "source": [ + "pred_df = tabular_model.predict(test)\n", + "pred_df.head()" + ], "outputs": [ { + "output_type": "stream", + "name": "stderr", + "text": [ + "Generating Predictions...: 100%|██████████| 3/3 [00:14<00:00, 4.92s/it]\n" + ] + }, + { + "output_type": "execute_result", "data": { + "text/plain": [ + " num_col_0 cat_col_1 num_col_2 num_col_3 num_col_4 num_col_5 \\\n", + "6252 -2.790932 0.0 -2.010758 3.205420 -0.356361 -0.744417 \n", + "4684 -0.139585 0.0 -1.207160 2.690514 1.072764 -3.499028 \n", + "1731 0.001421 1.0 -0.279572 0.363639 0.852329 0.089246 \n", + "4742 0.086662 3.0 0.798527 0.916448 -1.085978 0.512223 \n", + "4521 0.982186 2.0 -0.117476 -0.168583 -0.088413 -0.206658 \n", + "\n", + " num_col_6 num_col_7 cat_col_8 num_col_9 ... num_col_14 num_col_15 \\\n", + "6252 0.427836 -1.492040 0.0 1.364186 ... -0.660336 -0.705788 \n", + "4684 1.561682 0.953991 2.0 1.243788 ... -2.726836 0.944248 \n", + "1731 0.084824 0.194984 0.0 2.668561 ... -0.508633 0.508788 \n", + "4742 -0.903704 1.538725 2.0 1.518521 ... 0.326685 1.343219 \n", + "4521 -1.233511 -0.137569 3.0 -1.678887 ... 
-0.282845 0.458761 \n", + "\n", + " num_col_16 num_col_17 num_col_18 num_col_19 target 0_probability \\\n", + "6252 0.229519 0.060878 -0.464394 2.879481 0 0.145202 \n", + "4684 0.821184 0.368647 -1.199147 0.126323 1 0.517947 \n", + "1731 -0.097083 -0.128070 -0.282642 -0.190155 0 0.830036 \n", + "4742 -1.147619 1.795053 0.857619 0.532915 1 0.469329 \n", + "4521 1.381926 -0.566849 -0.475947 -0.400418 1 0.269307 \n", + "\n", + " 1_probability prediction \n", + "6252 0.854798 1 \n", + "4684 0.482053 0 \n", + "1731 0.169964 0 \n", + "4742 0.530671 1 \n", + "4521 0.730693 1 \n", + "\n", + "[5 rows x 24 columns]" + ], "text/html": [ "