Skip to content

Commit

Permalink
[BugFix] Resolve bugs in computer_vision_fine_tuning.py example (#5985)
Browse files Browse the repository at this point in the history
* update the script to use DataModule

* add message at for the frozen parameters

* add message about trainable parameters

* resolve flake8
  • Loading branch information
tchaton committed Feb 16, 2021
1 parent 6e79bef commit 141316f
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 90 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,5 @@ wandb
cifar-10-batches-py
*.pt
# ctags
tags
tags
data
193 changes: 104 additions & 89 deletions pl_examples/domain_templates/computer_vision_fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@
Note:
See: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
"""

import argparse
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Union

import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.optim.lr_scheduler import MultiStepLR
Expand All @@ -55,52 +55,114 @@
import pytorch_lightning as pl
from pl_examples import cli_lightning_logo
from pytorch_lightning import _logger as log
from pytorch_lightning import LightningDataModule
from pytorch_lightning.callbacks.finetuning import BaseFinetuning
from pytorch_lightning.utilities import rank_zero_info

DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"

# --- Finetuning Callback ---


class MilestonesFinetuningCallback(BaseFinetuning):
class MilestonesFinetuning(BaseFinetuning):

def __init__(self, milestones: tuple = (5, 10), train_bn: bool = True):
def __init__(self, milestones: tuple = (5, 10), train_bn: bool = False):
self.milestones = milestones
self.train_bn = train_bn

def freeze_before_training(self, pl_module: pl.LightningModule):
self.freeze(module=pl_module.feature_extractor, train_bn=self.train_bn)
self.freeze(modules=pl_module.feature_extractor, train_bn=self.train_bn)

def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int):
if epoch == self.milestones[0]:
# unfreeze 5 last layers
self.unfreeze_and_add_param_group(
module=pl_module.feature_extractor[-5:], optimizer=optimizer, train_bn=self.train_bn
modules=pl_module.feature_extractor[-5:], optimizer=optimizer, train_bn=self.train_bn
)

elif epoch == self.milestones[1]:
# unfreeze remaing layers
self.unfreeze_and_add_param_group(
module=pl_module.feature_extractor[:-5], optimizer=optimizer, train_bn=self.train_bn
modules=pl_module.feature_extractor[:-5], optimizer=optimizer, train_bn=self.train_bn
)


class CatDogImageDataModule(LightningDataModule):

def __init__(
self,
dl_path: Union[str, Path],
num_workers: int = 0,
batch_size: int = 8,
):
super().__init__()

self._dl_path = dl_path
self._num_workers = num_workers
self._batch_size = batch_size

def prepare_data(self):
"""Download images and prepare images datasets."""
download_and_extract_archive(url=DATA_URL, download_root=self._dl_path, remove_finished=True)

@property
def data_path(self):
return Path(self._dl_path).joinpath("cats_and_dogs_filtered")

@property
def normalize_transform(self):
return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

@property
def train_transform(self):
return transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(), self.normalize_transform
])

@property
def valid_transform(self):
return transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), self.normalize_transform])

def create_dataset(self, root, transform):
return ImageFolder(root=root, transform=transform)

def __dataloader(self, train: bool):
"""Train/validation loaders."""
if train:
dataset = self.create_dataset(self.data_path.joinpath("train"), self.train_transform)
else:
dataset = self.create_dataset(self.data_path.joinpath("validation"), self.valid_transform)
return DataLoader(dataset=dataset, batch_size=self._batch_size, num_workers=self._num_workers, shuffle=train)

def train_dataloader(self):
log.info("Training data loaded.")
return self.__dataloader(train=True)

def val_dataloader(self):
log.info("Validation data loaded.")
return self.__dataloader(train=False)

@staticmethod
def add_model_specific_args(parent_parser):
parser = argparse.ArgumentParser(parents=[parent_parser])
parser.add_argument(
"--num-workers", default=0, type=int, metavar="W", help="number of CPU workers", dest="num_workers"
)
parser.add_argument(
"--batch-size", default=8, type=int, metavar="W", help="number of sample in a batch", dest="batch_size"
)
return parser


# --- Pytorch-lightning module ---


class TransferLearningModel(pl.LightningModule):
"""Transfer Learning with pre-trained ResNet50.
>>> with TemporaryDirectory(dir='.') as tmp_dir:
... TransferLearningModel(tmp_dir) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
TransferLearningModel(
(feature_extractor): Sequential(...)
(fc): Sequential(...)
)
"""

def __init__(
self,
dl_path: Union[str, Path],
backbone: str = "resnet50",
train_bn: bool = True,
milestones: tuple = (5, 10),
Expand All @@ -115,7 +177,6 @@ def __init__(
dl_path: Path where the data will be downloaded
"""
super().__init__()
self.dl_path = dl_path
self.backbone = backbone
self.train_bn = train_bn
self.milestones = milestones
Expand All @@ -124,7 +185,6 @@ def __init__(
self.lr_scheduler_gamma = lr_scheduler_gamma
self.num_workers = num_workers

self.dl_path = dl_path
self.__build_model()

self.train_acc = pl.metrics.Accuracy()
Expand Down Expand Up @@ -163,7 +223,7 @@ def forward(self, x):
# 2. Classifier (returns logits):
x = self.fc(x)

return F.sigmoid(x)
return torch.sigmoid(x)

def loss(self, logits, labels):
return self.loss_func(input=logits, target=labels)
Expand Down Expand Up @@ -195,60 +255,16 @@ def validation_step(self, batch, batch_idx):
self.log("val_acc", self.valid_acc(y_logits, y_true.int()), prog_bar=True)

def configure_optimizers(self):
optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=self.lr)

parameters = list(self.parameters())
trainable_parameters = list(filter(lambda p: p.requires_grad, parameters))
rank_zero_info(
f"The model will start training with only {len(trainable_parameters)} "
f"trainable parameters out of {len(parameters)}."
)
optimizer = optim.Adam(trainable_parameters, lr=self.lr)
scheduler = MultiStepLR(optimizer, milestones=self.milestones, gamma=self.lr_scheduler_gamma)

return [optimizer], [scheduler]

def prepare_data(self):
"""Download images and prepare images datasets."""
download_and_extract_archive(url=DATA_URL, download_root=self.dl_path, remove_finished=True)

def setup(self, stage: str):
data_path = Path(self.dl_path).joinpath("cats_and_dogs_filtered")

# 2. Load the data + preprocessing & data augmentation
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

train_dataset = ImageFolder(
root=data_path.joinpath("train"),
transform=transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]),
)

valid_dataset = ImageFolder(
root=data_path.joinpath("validation"),
transform=transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
normalize,
]),
)

self.train_dataset = train_dataset
self.valid_dataset = valid_dataset

def __dataloader(self, train: bool):
"""Train/validation loaders."""

_dataset = self.train_dataset if train else self.valid_dataset
loader = DataLoader(dataset=_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=train)

return loader

def train_dataloader(self):
log.info("Training data loaded.")
return self.__dataloader(train=True)

def val_dataloader(self):
log.info("Validation data loaded.")
return self.__dataloader(train=False)

@staticmethod
def add_model_specific_args(parent_parser):
parser = argparse.ArgumentParser(parents=[parent_parser])
Expand All @@ -263,7 +279,7 @@ def add_model_specific_args(parent_parser):
"--epochs", default=15, type=int, metavar="N", help="total number of epochs", dest="nb_epochs"
)
parser.add_argument("--batch-size", default=8, type=int, metavar="B", help="batch size", dest="batch_size")
parser.add_argument("--gpus", type=int, default=1, help="number of gpus to use")
parser.add_argument("--gpus", type=int, default=0, help="number of gpus to use")
parser.add_argument(
"--lr", "--learning-rate", default=1e-3, type=float, metavar="LR", help="initial learning rate", dest="lr"
)
Expand All @@ -275,12 +291,9 @@ def add_model_specific_args(parent_parser):
help="Factor by which the learning rate is reduced at each milestone",
dest="lr_scheduler_gamma",
)
parser.add_argument(
"--num-workers", default=6, type=int, metavar="W", help="number of CPU workers", dest="num_workers"
)
parser.add_argument(
"--train-bn",
default=True,
default=False,
type=bool,
metavar="TB",
help="Whether the BatchNorm layers should be trainable",
Expand All @@ -303,21 +316,22 @@ def main(args: argparse.Namespace) -> None:
to a temporary directory.
"""

with TemporaryDirectory(dir=args.root_data_path) as tmp_dir:

model = TransferLearningModel(dl_path=tmp_dir, **vars(args))
finetuning_callback = MilestonesFinetuningCallback(milestones=args.milestones)

trainer = pl.Trainer(
weights_summary=None,
progress_bar_refresh_rate=1,
num_sanity_val_steps=0,
gpus=args.gpus,
max_epochs=args.nb_epochs,
callbacks=[finetuning_callback]
)
datamodule = CatDogImageDataModule(
dl_path=os.path.join(args.root_data_path, 'data'), batch_size=args.batch_size, num_workers=args.num_workers
)
model = TransferLearningModel(**vars(args))
finetuning_callback = MilestonesFinetuning(milestones=args.milestones)

trainer = pl.Trainer(
weights_summary=None,
progress_bar_refresh_rate=1,
num_sanity_val_steps=0,
gpus=args.gpus,
max_epochs=args.nb_epochs,
callbacks=[finetuning_callback]
)

trainer.fit(model)
trainer.fit(model, datamodule=datamodule)


def get_args() -> argparse.Namespace:
Expand All @@ -331,6 +345,7 @@ def get_args() -> argparse.Namespace:
dest="root_data_path",
)
parser = TransferLearningModel.add_model_specific_args(parent_parser)
parser = CatDogImageDataModule.add_argparse_args(parser)
return parser.parse_args()


Expand Down

0 comments on commit 141316f

Please sign in to comment.