Description
Current Situation
The ModelTrainer class handles multiple tasks within large, monolithic methods. These include:
- Instantiating the LightningModule via `_initialize_model`
- Creating datasets and dataloaders through `_create_data_loaders_torch_dataset`
- Configuring callbacks and loggers within the `train()` method
This tight coupling makes the class harder to test, maintain, and reuse across different training pipelines.
Proposed Refactor
To improve clarity and modularity, we should refactor these methods into standalone, reusable functions. Each function should take in explicit parameters, rather than an entire config section, to improve usability and flexibility.
- Create functions
  - `get_lightning_module(...)`: Returns a `LightningModule` instance
  - `get_dataset(...)`: Returns a dataset instance (e.g., `BaseDataset`)
  - `get_dataloaders(...)`: Returns train and val dataloaders
  - `get_callbacks(...)`: Returns a list of PyTorch Lightning callbacks
- Refactor the `ModelTrainer` class
```python
from lightning.pytorch import Trainer  # assumed import path; may be `pytorch_lightning` instead


class ModelTrainer:
    def __init__(self, config: TrainingJobConfig):
        self.config = config
        self.setup_config()

    def setup_config(self):
        # Compute dataset-derived values (max_height, max_width, crop_size,
        # part_names, and edges) from the .slp files, save `training_config.yaml`,
        # and check that `output_stride` and `max_stride` in `model_config` are compatible.
        pass

    def setup_lightning_module(self):
        # Set up the lightning module based on the head type
        # (read from the `config.model_config.head_configs` section).
        # The arguments are explicit values unpacked from `self.config`.
        self.lightning_module = get_lightning_module(
            backbone_config, head_configs, optimizer, pretrained_weights
        )

    def setup_dataloaders(self):
        # Create dataloaders from the training and validation datasets.
        self.train_dataloader, self.val_dataloader = get_dataloaders()

    def setup_trainer(self):
        # Set up callbacks and loggers, then create a `Trainer` instance.
        self.trainer = Trainer(
            logger=get_loggers(),
            callbacks=get_callbacks(),
            # ... (remaining Trainer arguments)
        )

    def train(self):
        self.setup_lightning_module()
        self.setup_dataloaders()
        self.setup_trainer()
        # `fit` needs the module and the dataloaders passed explicitly.
        self.trainer.fit(self.lightning_module, self.train_dataloader, self.val_dataloader)
```
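
To illustrate the "explicit parameters rather than an entire config section" idea from the proposal, here is a minimal, framework-free sketch. It is only an assumption about how such a function might look: `DataLoaderConfig`, its `batch_size` field, and the `batches` helper are hypothetical stand-ins, and plain lists take the place of real datasets and dataloaders.

```python
from dataclasses import dataclass
from typing import Any, List, Sequence, Tuple


@dataclass
class DataLoaderConfig:
    """Hypothetical config section; the field name is illustrative only."""

    batch_size: int = 2


def batches(items: Sequence[Any], batch_size: int) -> List[List[Any]]:
    """Stand-in for a dataloader: split items into consecutive batches."""
    return [list(items[i : i + batch_size]) for i in range(0, len(items), batch_size)]


def get_dataloaders(
    train_dataset: Sequence[Any],
    val_dataset: Sequence[Any],
    batch_size: int,
) -> Tuple[List[List[Any]], List[List[Any]]]:
    """Build train and val loaders from explicit arguments, not a config object."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    return batches(train_dataset, batch_size), batches(val_dataset, batch_size)


# The caller (for example, ModelTrainer.setup_dataloaders) unpacks the config
# once and passes plain values, so the function has no dependency on the config schema.
cfg = DataLoaderConfig(batch_size=2)
train_loader, val_loader = get_dataloaders(
    train_dataset=[0, 1, 2, 3, 4],
    val_dataset=[10, 11, 12],
    batch_size=cfg.batch_size,
)
```

Because the function depends only on its arguments, it can be called from a unit test or another training pipeline without constructing a full `TrainingJobConfig`.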