Skip to content

Commit af77656

Browse files
committed
code formatting with black
1 parent eb06829 commit af77656

File tree

4 files changed

+748
-45
lines changed

4 files changed

+748
-45
lines changed

pytorch-lightning_ipynb/kfold/baseline-light-cnn-mnist.ipynb

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -305,18 +305,18 @@
305305
"\n",
306306
" # Save settings and hyperparameters to the log directory\n",
307307
" # but skip the model parameters\n",
308-
" self.save_hyperparameters(ignore=['model'])\n",
308+
" self.save_hyperparameters(ignore=[\"model\"])\n",
309309
"\n",
310310
" # Set up attributes for computing the accuracy\n",
311311
" self.train_acc = torchmetrics.Accuracy()\n",
312312
" self.valid_acc = torchmetrics.Accuracy()\n",
313313
" self.test_acc = torchmetrics.Accuracy()\n",
314-
" \n",
315-
" # Defining the forward method is only necessary \n",
314+
"\n",
315+
" # Defining the forward method is only necessary\n",
316316
" # if you want to use a Trainer's .predict() method (optional)\n",
317317
" def forward(self, x):\n",
318318
" return self.model(x)\n",
319-
" \n",
319+
"\n",
320320
" # A common forward step to compute the loss and labels\n",
321321
" # this is used for training, validation, and testing below\n",
322322
" def _shared_step(self, batch):\n",
@@ -330,7 +330,7 @@
330330
" def training_step(self, batch, batch_idx):\n",
331331
" loss, true_labels, predicted_labels = self._shared_step(batch)\n",
332332
" self.log(\"train_loss\", loss)\n",
333-
" \n",
333+
"\n",
334334
" # To account for Dropout behavior during evaluation\n",
335335
" self.model.eval()\n",
336336
" with torch.no_grad():\n",
@@ -344,8 +344,9 @@
344344
" loss, true_labels, predicted_labels = self._shared_step(batch)\n",
345345
" self.log(\"valid_loss\", loss)\n",
346346
" self.valid_acc(predicted_labels, true_labels)\n",
347-
" self.log(\"valid_acc\", self.valid_acc,\n",
348-
" on_epoch=True, on_step=False, prog_bar=True)\n",
347+
" self.log(\n",
348+
" \"valid_acc\", self.valid_acc, on_epoch=True, on_step=False, prog_bar=True\n",
349+
" )\n",
349350
"\n",
350351
" def test_step(self, batch, batch_idx):\n",
351352
" loss, true_labels, predicted_labels = self._shared_step(batch)\n",
@@ -376,13 +377,12 @@
376377
"torch.manual_seed(RANDOM_SEED)\n",
377378
"pytorch_model = PyTorchCNN(num_classes=NUM_CLASSES)\n",
378379
"\n",
379-
"lightning_model = LightningModel(\n",
380-
" model=pytorch_model, learning_rate=LEARNING_RATE)\n",
380+
"lightning_model = LightningModel(model=pytorch_model, learning_rate=LEARNING_RATE)\n",
381381
"\n",
382-
"callbacks = [ModelCheckpoint(\n",
383-
" save_top_k=1, mode='max', monitor=\"valid_acc\"), # save top 1 model\n",
384-
" pl.callbacks.progress.TQDMProgressBar(refresh_rate=50)\n",
385-
"] \n",
382+
"callbacks = [\n",
383+
" ModelCheckpoint(save_top_k=1, mode=\"max\", monitor=\"valid_acc\"), # save top 1 model\n",
384+
" pl.callbacks.progress.TQDMProgressBar(refresh_rate=50),\n",
385+
"]\n",
386386
"\n",
387387
"logger = CSVLogger(save_dir=\"logs/\", name=\"my-model\")"
388388
]
@@ -593,24 +593,21 @@
593593
"source": [
594594
"import time\n",
595595
"\n",
596-
"\n",
597596
"trainer = pl.Trainer(\n",
598597
" max_epochs=NUM_EPOCHS,\n",
599598
" callbacks=callbacks,\n",
600599
" accelerator=\"auto\", # Uses GPUs or TPUs if available\n",
601600
" devices=\"auto\", # Uses all available GPUs/TPUs if applicable\n",
602601
" logger=logger,\n",
603-
" log_every_n_steps=100\n",
602+
" log_every_n_steps=100,\n",
604603
")\n",
605604
"\n",
606605
"start_time = time.time()\n",
607606
"trainer.fit(\n",
608-
" model=lightning_model,\n",
609-
" train_dataloaders=train_loader,\n",
610-
" val_dataloaders=valid_loader\n",
607+
" model=lightning_model, train_dataloaders=train_loader, val_dataloaders=valid_loader\n",
611608
")\n",
612609
"\n",
613-
"runtime = (time.time() - start_time)/60\n",
610+
"runtime = (time.time() - start_time) / 60\n",
614611
"print(f\"Training took {runtime:.2f} min in total.\")"
615612
]
616613
},
@@ -664,7 +661,6 @@
664661
"source": [
665662
"import pandas as pd\n",
666663
"\n",
667-
"\n",
668664
"metrics = pd.read_csv(f\"{trainer.logger.log_dir}/metrics.csv\")\n",
669665
"\n",
670666
"aggreg_metrics = []\n",
@@ -676,9 +672,11 @@
676672
"\n",
677673
"df_metrics = pd.DataFrame(aggreg_metrics)\n",
678674
"df_metrics[[\"train_loss\", \"valid_loss\"]].plot(\n",
679-
" grid=True, legend=True, xlabel='Epoch', ylabel='Loss')\n",
675+
" grid=True, legend=True, xlabel=\"Epoch\", ylabel=\"Loss\"\n",
676+
")\n",
680677
"df_metrics[[\"train_acc\", \"valid_acc\"]].plot(\n",
681-
" grid=True, legend=True, xlabel='Epoch', ylabel='ACC')"
678+
" grid=True, legend=True, xlabel=\"Epoch\", ylabel=\"ACC\"\n",
679+
")"
682680
]
683681
},
684682
{
@@ -731,7 +729,7 @@
731729
}
732730
],
733731
"source": [
734-
"trainer.test(model=lightning_model, dataloaders=test_loader, ckpt_path='best')"
732+
"trainer.test(model=lightning_model, dataloaders=test_loader, ckpt_path=\"best\")"
735733
]
736734
},
737735
{
@@ -782,7 +780,7 @@
782780
"name": "python",
783781
"nbconvert_exporter": "python",
784782
"pygments_lexer": "ipython3",
785-
"version": "3.8.12"
783+
"version": "3.9.7"
786784
},
787785
"toc": {
788786
"nav_menu": {},

pytorch-lightning_ipynb/kfold/kfold-light-cnn-mnist.ipynb

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@
139139
"outputs": [],
140140
"source": [
141141
"import os\n",
142-
"from git import Repo\n",
143142
"\n",
143+
"from git import Repo\n",
144144
"\n",
145145
"if not os.path.exists(\"pl_cross\"):\n",
146146
" Repo.clone_from(\"https://github.com/SkafteNicki/pl_cross.git\", \"pl_cross\")"
@@ -377,18 +377,18 @@
377377
"\n",
378378
" # Save settings and hyperparameters to the log directory\n",
379379
" # but skip the model parameters\n",
380-
" self.save_hyperparameters(ignore=['model'])\n",
380+
" self.save_hyperparameters(ignore=[\"model\"])\n",
381381
"\n",
382382
" # Set up attributes for computing the accuracy\n",
383383
" self.train_acc = torchmetrics.Accuracy()\n",
384384
" self.valid_acc = torchmetrics.Accuracy()\n",
385385
" self.test_acc = torchmetrics.Accuracy()\n",
386-
" \n",
387-
" # Defining the forward method is only necessary \n",
386+
"\n",
387+
" # Defining the forward method is only necessary\n",
388388
" # if you want to use a Trainer's .predict() method (optional)\n",
389389
" def forward(self, x):\n",
390390
" return self.model(x)\n",
391-
" \n",
391+
"\n",
392392
" # A common forward step to compute the loss and labels\n",
393393
" # this is used for training, validation, and testing below\n",
394394
" def _shared_step(self, batch):\n",
@@ -402,7 +402,7 @@
402402
" def training_step(self, batch, batch_idx):\n",
403403
" loss, true_labels, predicted_labels = self._shared_step(batch)\n",
404404
" self.log(\"train_loss\", loss)\n",
405-
" \n",
405+
"\n",
406406
" # To account for Dropout behavior during evaluation\n",
407407
" self.model.eval()\n",
408408
" with torch.no_grad():\n",
@@ -416,8 +416,9 @@
416416
" loss, true_labels, predicted_labels = self._shared_step(batch)\n",
417417
" self.log(\"valid_loss\", loss)\n",
418418
" self.valid_acc(predicted_labels, true_labels)\n",
419-
" self.log(\"valid_acc\", self.valid_acc,\n",
420-
" on_epoch=True, on_step=False, prog_bar=True)\n",
419+
" self.log(\n",
420+
" \"valid_acc\", self.valid_acc, on_epoch=True, on_step=False, prog_bar=True\n",
421+
" )\n",
421422
"\n",
422423
" def test_step(self, batch, batch_idx):\n",
423424
" loss, true_labels, predicted_labels = self._shared_step(batch)\n",
@@ -449,28 +450,26 @@
449450
"metadata": {},
450451
"outputs": [],
451452
"source": [
452-
"from pytorch_lightning.callbacks import ModelCheckpoint\n",
453-
"\n",
454453
"# OLD:\n",
455454
"# from pytorch_lightning.loggers import CSVLogger\n",
456455
"# NEW:\n",
457456
"from pl_cross.loggers import CSVLogger\n",
457+
"from pytorch_lightning.callbacks import ModelCheckpoint\n",
458458
"\n",
459459
"torch.manual_seed(RANDOM_SEED)\n",
460460
"pytorch_model = PyTorchCNN(num_classes=NUM_CLASSES)\n",
461461
"\n",
462-
"lightning_model = LightningModel(\n",
463-
" model=pytorch_model, learning_rate=LEARNING_RATE)\n",
462+
"lightning_model = LightningModel(model=pytorch_model, learning_rate=LEARNING_RATE)\n",
464463
"\n",
465464
"\n",
466465
"# NEW: Kfold currently doesn't call model.validation_step()\n",
467466
"# so the valid_acc tracking does unfortunately not work, and\n",
468467
"# we have to change valid_acc to train_acc.\n",
469468
"\n",
470-
"callbacks = [ModelCheckpoint(\n",
471-
" save_top_k=1, mode='max', monitor=\"train_acc\"), # save top 1 model\n",
472-
" pl.callbacks.progress.TQDMProgressBar(refresh_rate=50)\n",
473-
"] \n",
469+
"callbacks = [\n",
470+
" ModelCheckpoint(save_top_k=1, mode=\"max\", monitor=\"train_acc\"), # save top 1 model\n",
471+
" pl.callbacks.progress.TQDMProgressBar(refresh_rate=50),\n",
472+
"]\n",
474473
"\n",
475474
"logger = CSVLogger(save_dir=\"logs/\", name=\"my-kfold-model\")"
476475
]
@@ -715,11 +714,10 @@
715714
"# NEW:\n",
716715
"from pl_cross import Trainer\n",
717716
"\n",
718-
"\n",
719717
"trainer = Trainer(\n",
720718
" num_folds=5,\n",
721719
" shuffle=True,\n",
722-
" #stratified=True,\n",
720+
" # stratified=True,\n",
723721
" max_epochs=NUM_EPOCHS,\n",
724722
" callbacks=callbacks,\n",
725723
" accelerator=\"auto\", # Uses GPUs or TPUs if available\n",
@@ -736,7 +734,7 @@
736734
" train_dataloader=train_loader,\n",
737735
")\n",
738736
"\n",
739-
"runtime = (time.time() - start_time)/60\n",
737+
"runtime = (time.time() - start_time) / 60\n",
740738
"print(f\"Training took {runtime:.2f} min in total.\")"
741739
]
742740
},
@@ -852,7 +850,7 @@
852850
"name": "python",
853851
"nbconvert_exporter": "python",
854852
"pygments_lexer": "ipython3",
855-
"version": "3.8.12"
853+
"version": "3.9.7"
856854
},
857855
"toc": {
858856
"nav_menu": {},

pytorch_ipynb/kfold/baseline-cnn-mnist.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,7 @@
494494
"name": "python",
495495
"nbconvert_exporter": "python",
496496
"pygments_lexer": "ipython3",
497-
"version": "3.8.12"
497+
"version": "3.9.7"
498498
},
499499
"toc": {
500500
"nav_menu": {},

0 commit comments

Comments
 (0)