In [None]:
{
 "cells": [
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-20T16:41:30.336439Z",
     "start_time": "2025-02-20T16:41:30.315293Z"
    }
   },
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from src.probly.representation.bayesian import Bayesian\n",
    "import sklearn.datasets as sd\n",
    "from torch.utils.data import DataLoader, TensorDataset"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "execution_count": 28
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-02-20T16:41:30.449816Z",
     "start_time": "2025-02-20T16:41:30.432116Z"
    }
   },
   "source": [
    "# Reproducibility: seed torch (weight init, DataLoader shuffling and, presumably,\n",
    "# the Bayesian wrapper's weight sampling) and the dataset generator.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# make a simple 2d dataset; make_moons shuffles by default, so the two halves are random splits\n",
    "X, y = sd.make_moons(n_samples=200, noise=0.1, random_state=SEED)\n",
    "X_train, y_train, X_test, y_test = X[:100], y[:100], X[100:], y[100:]\n",
    "\n",
    "# convert to torch tensors and make dataloaders\n",
    "train = TensorDataset(\n",
    "    torch.tensor(X_train, dtype=torch.float32), torch.tensor(y_train, dtype=torch.long)\n",
    ")\n",
    "train_loader = DataLoader(train, batch_size=32, shuffle=True)\n",
    "test = TensorDataset(\n",
    "    torch.tensor(X_test, dtype=torch.float32), torch.tensor(y_test, dtype=torch.long)\n",
    ")\n",
    "test_loader = DataLoader(test, batch_size=32, shuffle=False)\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "    \"\"\"Small fully connected classifier: in_features -> hidden -> hidden -> n_classes.\n",
    "\n",
    "    Returns unnormalised logits (no softmax), as nn.CrossEntropyLoss expects.\n",
    "    The defaults reproduce the original 2 -> 50 -> 50 -> 2 architecture.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_features: int = 2, hidden: int = 50, n_classes: int = 2):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(in_features, hidden)\n",
    "        self.fc2 = nn.Linear(hidden, hidden)\n",
    "        self.fc3 = nn.Linear(hidden, n_classes)\n",
    "        self.act = nn.ReLU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.act(self.fc1(x))\n",
    "        x = self.act(self.fc2(x))\n",
    "        return self.fc3(x)\n",
    "\n",
    "\n",
    "net = Net()"
   ],
   "outputs": [],
   "execution_count": 29
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "N_EPOCHS = 100\n",
    "LEARNING_RATE = 0.01\n",
    "N_TRAIN_SAMPLES = 1  # posterior samples per forward pass while training\n",
    "N_EVAL_SAMPLES = 100  # posterior samples averaged at test time\n",
    "LOG_EVERY = 20  # epochs between training-loss printouts\n",
    "\n",
    "model = Bayesian(net)\n",
    "optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# NOTE(review): only the cross-entropy data term is optimised here; a variational\n",
    "# BNN usually also adds a KL(posterior || prior) term -- check whether Bayesian exposes one.\n",
    "model.train()\n",
    "for epoch in range(N_EPOCHS):\n",
    "    epoch_loss = 0.0\n",
    "    for xb, yb in train_loader:  # xb/yb avoid clobbering the dataset's `y` from the data cell\n",
    "        optimizer.zero_grad()\n",
    "        # model(x, k) appears to return (batch, k, n_classes): the evaluation below reduces\n",
    "        # axis 1 and argmaxes over classes. Drop only the sample axis so that a batch of\n",
    "        # size 1 keeps its batch dimension (a bare squeeze() would remove it too).\n",
    "        logits = model(xb, N_TRAIN_SAMPLES).squeeze(1)\n",
    "        loss = criterion(logits, yb)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        epoch_loss += loss.item() * yb.size(0)\n",
    "    if (epoch + 1) % LOG_EVERY == 0:\n",
    "        print(f\"epoch {epoch + 1:3d}: train loss {epoch_loss / len(train_loader.dataset):.4f}\")\n",
    "\n",
    "# compute accuracy on the test set with Bayesian model averaging\n",
    "model.eval()\n",
    "correct = 0\n",
    "total = 0\n",
    "with torch.no_grad():\n",
    "    for xb, yb in test_loader:\n",
    "        # outputs are logits (they feed CrossEntropyLoss above), so average class\n",
    "        # probabilities rather than raw logits over the posterior samples\n",
    "        probs = model(xb, N_EVAL_SAMPLES).softmax(dim=-1).mean(dim=1)\n",
    "        y_pred = probs.argmax(dim=1)\n",
    "        correct += (y_pred == yb).sum().item()\n",
    "        total += yb.size(0)\n",
    "print(f\"Accuracy: {correct / total}\")"
   ],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
