{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dropout on the two-moons dataset\n",
    "\n",
    "Train a small fully connected network wrapped in probly's `Dropout` (p = 0.5) on the 2-D `make_moons` toy problem, then report test-set accuracy computed from outputs averaged over 100 forward passes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-16T14:41:01.216326Z",
     "start_time": "2025-02-16T14:41:01.198735Z"
    }
   },
   "outputs": [],
   "source": [
    "# reload edited modules (e.g. the local probly package) before each cell runs\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import sklearn.datasets as sd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "from probly.representation.dropout import Dropout"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-16T14:41:01.355324Z",
     "start_time": "2025-02-16T14:41:01.336726Z"
    }
   },
   "outputs": [],
   "source": [
    "# seed torch's global RNG (weight init, DataLoader shuffling, dropout masks) and the\n",
    "# dataset generator so Restart & Run All reproduces the same numbers\n",
    "SEED = 0\n",
    "BATCH_SIZE = 32\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# make a simple 2d dataset; make_moons shuffles by default, so the head/tail split below is random\n",
    "X, y = sd.make_moons(n_samples=200, noise=0.1, random_state=SEED)\n",
    "X_train, y_train, X_test, y_test = X[:100], y[:100], X[100:], y[100:]\n",
    "\n",
    "# convert to torch tensors and make dataloaders\n",
    "train = TensorDataset(\n",
    "    torch.tensor(X_train, dtype=torch.float32), torch.tensor(y_train, dtype=torch.long)\n",
    ")\n",
    "train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)\n",
    "test = TensorDataset(\n",
    "    torch.tensor(X_test, dtype=torch.float32), torch.tensor(y_test, dtype=torch.long)\n",
    ")\n",
    "test_loader = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "    \"\"\"Small fully connected classifier with two ReLU hidden layers.\n",
    "\n",
    "    Args:\n",
    "        in_features: number of input features (2 for make_moons).\n",
    "        hidden: width of both hidden layers.\n",
    "        n_classes: number of output scores.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_features=2, hidden=50, n_classes=2):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(in_features, hidden)\n",
    "        self.fc2 = nn.Linear(hidden, hidden)\n",
    "        self.fc3 = nn.Linear(hidden, n_classes)\n",
    "        self.act = nn.ReLU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return raw class scores (logits) of shape (..., n_classes); no softmax is applied.\"\"\"\n",
    "        x = self.act(self.fc1(x))\n",
    "        x = self.act(self.fc2(x))\n",
    "        return self.fc3(x)\n",
    "\n",
    "\n",
    "net = Net()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-16T14:41:01.900972Z",
     "start_time": "2025-02-16T14:41:01.805901Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.8\n"
     ]
    }
   ],
   "source": [
    "N_EPOCHS = 10\n",
    "DROPOUT_P = 0.5\n",
    "N_MC_SAMPLES = 100  # forward passes averaged per test input\n",
    "\n",
    "model = Dropout(net, DROPOUT_P)\n",
    "optimizer = optim.Adam(model.parameters())\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "model.train()\n",
    "for _ in range(N_EPOCHS):\n",
    "    # batch names avoid clobbering the notebook-level `y` label array\n",
    "    for x_batch, y_batch in train_loader:\n",
    "        optimizer.zero_grad()\n",
    "        # output is (batch, n_samples, classes) (see the .mean over dim 1 below); with one\n",
    "        # sample drop only dim 1 so a size-1 final batch cannot lose its batch axis too\n",
    "        logits = model(x_batch, 1).squeeze(1)\n",
    "        loss = criterion(logits, y_batch)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "# compute accuracy on test set; the wrapper's output is fed to CrossEntropyLoss above,\n",
    "# i.e. treated as logits, so convert to probabilities before averaging over samples\n",
    "model.eval()\n",
    "correct = 0\n",
    "total = 0\n",
    "with torch.no_grad():  # inference only: skip building autograd graphs\n",
    "    for x_batch, y_batch in test_loader:\n",
    "        probs = model(x_batch, N_MC_SAMPLES).softmax(dim=-1).mean(dim=1)\n",
    "        y_pred = probs.argmax(dim=1)\n",
    "        correct += (y_pred == y_batch).sum().item()\n",
    "        total += y_batch.size(0)\n",
    "print(f\"Accuracy: {correct / total}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
