|
305 | 305 | "\n", |
306 | 306 | " # Save settings and hyperparameters to the log directory\n", |
307 | 307 | " # but skip the model parameters\n", |
308 | | - " self.save_hyperparameters(ignore=['model'])\n", |
| 308 | + " self.save_hyperparameters(ignore=[\"model\"])\n", |
309 | 309 | "\n", |
310 | 310 | " # Set up attributes for computing the accuracy\n", |
311 | 311 | " self.train_acc = torchmetrics.Accuracy()\n", |
312 | 312 | " self.valid_acc = torchmetrics.Accuracy()\n", |
313 | 313 | " self.test_acc = torchmetrics.Accuracy()\n", |
314 | | - " \n", |
315 | | - " # Defining the forward method is only necessary \n", |
| 314 | + "\n", |
| 315 | + " # Defining the forward method is only necessary\n", |
316 | 316 | " # if you want to use a Trainer's .predict() method (optional)\n", |
317 | 317 | " def forward(self, x):\n", |
318 | 318 | " return self.model(x)\n", |
319 | | - " \n", |
| 319 | + "\n", |
320 | 320 | " # A common forward step to compute the loss and labels\n", |
321 | 321 | " # this is used for training, validation, and testing below\n", |
322 | 322 | " def _shared_step(self, batch):\n", |
|
330 | 330 | " def training_step(self, batch, batch_idx):\n", |
331 | 331 | " loss, true_labels, predicted_labels = self._shared_step(batch)\n", |
332 | 332 | " self.log(\"train_loss\", loss)\n", |
333 | | - " \n", |
| 333 | + "\n", |
334 | 334 | " # To account for Dropout behavior during evaluation\n", |
335 | 335 | " self.model.eval()\n", |
336 | 336 | " with torch.no_grad():\n", |
|
344 | 344 | " loss, true_labels, predicted_labels = self._shared_step(batch)\n", |
345 | 345 | " self.log(\"valid_loss\", loss)\n", |
346 | 346 | " self.valid_acc(predicted_labels, true_labels)\n", |
347 | | - " self.log(\"valid_acc\", self.valid_acc,\n", |
348 | | - " on_epoch=True, on_step=False, prog_bar=True)\n", |
| 347 | + " self.log(\n", |
| 348 | + " \"valid_acc\", self.valid_acc, on_epoch=True, on_step=False, prog_bar=True\n", |
| 349 | + " )\n", |
349 | 350 | "\n", |
350 | 351 | " def test_step(self, batch, batch_idx):\n", |
351 | 352 | " loss, true_labels, predicted_labels = self._shared_step(batch)\n", |
|
376 | 377 | "torch.manual_seed(RANDOM_SEED)\n", |
377 | 378 | "pytorch_model = PyTorchCNN(num_classes=NUM_CLASSES)\n", |
378 | 379 | "\n", |
379 | | - "lightning_model = LightningModel(\n", |
380 | | - " model=pytorch_model, learning_rate=LEARNING_RATE)\n", |
| 380 | + "lightning_model = LightningModel(model=pytorch_model, learning_rate=LEARNING_RATE)\n", |
381 | 381 | "\n", |
382 | | - "callbacks = [ModelCheckpoint(\n", |
383 | | - " save_top_k=1, mode='max', monitor=\"valid_acc\"), # save top 1 model\n", |
384 | | - " pl.callbacks.progress.TQDMProgressBar(refresh_rate=50)\n", |
385 | | - "] \n", |
| 382 | + "callbacks = [\n", |
| 383 | + " ModelCheckpoint(save_top_k=1, mode=\"max\", monitor=\"valid_acc\"), # save top 1 model\n", |
| 384 | + " pl.callbacks.progress.TQDMProgressBar(refresh_rate=50),\n", |
| 385 | + "]\n", |
386 | 386 | "\n", |
387 | 387 | "logger = CSVLogger(save_dir=\"logs/\", name=\"my-model\")" |
388 | 388 | ] |
|
593 | 593 | "source": [ |
594 | 594 | "import time\n", |
595 | 595 | "\n", |
596 | | - "\n", |
597 | 596 | "trainer = pl.Trainer(\n", |
598 | 597 | " max_epochs=NUM_EPOCHS,\n", |
599 | 598 | " callbacks=callbacks,\n", |
600 | 599 | " accelerator=\"auto\", # Uses GPUs or TPUs if available\n", |
601 | 600 | " devices=\"auto\", # Uses all available GPUs/TPUs if applicable\n", |
602 | 601 | " logger=logger,\n", |
603 | | - " log_every_n_steps=100\n", |
| 602 | + " log_every_n_steps=100,\n", |
604 | 603 | ")\n", |
605 | 604 | "\n", |
606 | 605 | "start_time = time.time()\n", |
607 | 606 | "trainer.fit(\n", |
608 | | - " model=lightning_model,\n", |
609 | | - " train_dataloaders=train_loader,\n", |
610 | | - " val_dataloaders=valid_loader\n", |
| 607 | + " model=lightning_model, train_dataloaders=train_loader, val_dataloaders=valid_loader\n", |
611 | 608 | ")\n", |
612 | 609 | "\n", |
613 | | - "runtime = (time.time() - start_time)/60\n", |
| 610 | + "runtime = (time.time() - start_time) / 60\n", |
614 | 611 | "print(f\"Training took {runtime:.2f} min in total.\")" |
615 | 612 | ] |
616 | 613 | }, |
|
664 | 661 | "source": [ |
665 | 662 | "import pandas as pd\n", |
666 | 663 | "\n", |
667 | | - "\n", |
668 | 664 | "metrics = pd.read_csv(f\"{trainer.logger.log_dir}/metrics.csv\")\n", |
669 | 665 | "\n", |
670 | 666 | "aggreg_metrics = []\n", |
|
676 | 672 | "\n", |
677 | 673 | "df_metrics = pd.DataFrame(aggreg_metrics)\n", |
678 | 674 | "df_metrics[[\"train_loss\", \"valid_loss\"]].plot(\n", |
679 | | - " grid=True, legend=True, xlabel='Epoch', ylabel='Loss')\n", |
| 675 | + " grid=True, legend=True, xlabel=\"Epoch\", ylabel=\"Loss\"\n", |
| 676 | + ")\n", |
680 | 677 | "df_metrics[[\"train_acc\", \"valid_acc\"]].plot(\n", |
681 | | - " grid=True, legend=True, xlabel='Epoch', ylabel='ACC')" |
| 678 | + " grid=True, legend=True, xlabel=\"Epoch\", ylabel=\"ACC\"\n", |
| 679 | + ")" |
682 | 680 | ] |
683 | 681 | }, |
684 | 682 | { |
|
731 | 729 | } |
732 | 730 | ], |
733 | 731 | "source": [ |
734 | | - "trainer.test(model=lightning_model, dataloaders=test_loader, ckpt_path='best')" |
| 732 | + "trainer.test(model=lightning_model, dataloaders=test_loader, ckpt_path=\"best\")" |
735 | 733 | ] |
736 | 734 | }, |
737 | 735 | { |
|
782 | 780 | "name": "python", |
783 | 781 | "nbconvert_exporter": "python", |
784 | 782 | "pygments_lexer": "ipython3", |
785 | | - "version": "3.8.12" |
| 783 | + "version": "3.9.7" |
786 | 784 | }, |
787 | 785 | "toc": { |
788 | 786 | "nav_menu": {}, |
|
0 commit comments