
Refactor ModelTrainer class for a better modular approach #170

@gitttt-1234

Description


Current Situation

The ModelTrainer class handles multiple tasks within large, monolithic methods. These include:

  • Instantiating the LightningModule via _initialize_model

  • Creating datasets and dataloaders through _create_data_loaders_torch_dataset

  • Configuring callbacks and loggers within the train() method

This tightly coupled structure makes it harder to test, maintain, and reuse across different training pipelines.

Proposed Refactor

To improve clarity and modularity, we should refactor these methods into standalone, reusable functions. Each function should take explicit parameters rather than an entire config section, improving usability and flexibility.

  • Create functions

  • get_lightning_module(...): Returns a LightningModule instance

  • get_dataset(...): Returns a dataset instance (e.g., BaseDataset)

  • get_dataloaders(...): Returns train and val dataloaders

  • get_callbacks(...): Returns a list of PyTorch Lightning callbacks
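As a rough sketch, the proposed helpers might look like the following. Signatures, parameter names, and return stubs are illustrative only, not a final API; the real implementations would return actual LightningModule, dataset, dataloader, and callback objects.

```python
from typing import List, Optional, Tuple

def get_lightning_module(
    backbone_config: dict,
    head_configs: dict,
    optimizer_name: str = "Adam",
    pretrained_weights: Optional[str] = None,
):
    """Return a LightningModule built for the head type in `head_configs`."""
    # Stubbed: the real function would dispatch on the head type and
    # construct the matching LightningModule subclass.
    return {
        "backbone": backbone_config,
        "heads": head_configs,
        "optimizer": optimizer_name,
        "weights": pretrained_weights,
    }

def get_dataset(labels_path: str, max_height: int, max_width: int):
    """Return a dataset instance (e.g., BaseDataset) for one .slp file."""
    return {"labels": labels_path, "shape": (max_height, max_width)}

def get_dataloaders(
    train_dataset, val_dataset, batch_size: int = 4, num_workers: int = 0
) -> Tuple[object, object]:
    """Return train and val dataloaders wrapping the given datasets."""
    train_dl = {"dataset": train_dataset, "batch_size": batch_size}
    val_dl = {"dataset": val_dataset, "batch_size": batch_size}
    return train_dl, val_dl

def get_callbacks(save_ckpt_path: Optional[str] = None) -> List[object]:
    """Return the list of PyTorch Lightning callbacks to attach."""
    callbacks: List[object] = []
    # e.g., append a ModelCheckpoint when `save_ckpt_path` is given.
    return callbacks
```

Because each helper receives only what it needs, a caller can build just a dataset or just the callbacks without constructing the full training config.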

  • Refactor ModelTrainer class

class ModelTrainer:

    def __init__(self, config: TrainingJobConfig):
        self.config = config
        self.setup_config()

    def setup_config(self):
        # Compute dataset-derived values (max_height, max_width, crop_size,
        # part_names, and edges) from the .slp files and save
        # `training_config.yaml`. Check compatibility between `output_stride`
        # and `max_stride` in `model_config`.
        pass

    def setup_lightning_module(self):
        # Set up the lightning module based on the head type (read from the
        # `config.model_config.head_configs` section).
        self.lightning_module = get_lightning_module(
            backbone_config, head_configs, optimizer, pretrained_weights
        )

    def setup_dataloaders(self):
        # Create dataloaders from the training and validation datasets.
        self.train_dataloader, self.val_dataloader = get_dataloaders(...)

    def setup_trainer(self):
        # Set up callbacks/loggers and create a `Trainer` instance.
        self.trainer = Trainer(logger=get_loggers(), callbacks=get_callbacks(), ...)

    def train(self):
        self.setup_lightning_module()
        self.setup_dataloaders()
        self.setup_trainer()
        self.trainer.fit(
            self.lightning_module, self.train_dataloader, self.val_dataloader
        )
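A side benefit of splitting train() into setup_* steps is that each stage can be exercised in isolation (e.g., inspecting the dataloaders without fitting). A minimal, dependency-free sketch of that pattern, with stubbed stand-ins for the real helpers and Trainer:

```python
class ModelTrainerSketch:
    """Stand-in mirroring the proposed structure; helpers are stubbed."""

    def __init__(self, config: dict):
        self.config = config
        self.fitted = False

    def setup_lightning_module(self):
        # Would call get_lightning_module(...) in the real class.
        self.lightning_module = {"heads": self.config.get("head_configs", {})}

    def setup_dataloaders(self):
        # Would call get_dataloaders(...) in the real class.
        self.train_dataloader, self.val_dataloader = [], []

    def setup_trainer(self):
        # Would build a lightning.Trainer with loggers/callbacks.
        self.trainer = object()

    def train(self):
        self.setup_lightning_module()
        self.setup_dataloaders()
        self.setup_trainer()
        self.fitted = True  # stands in for self.trainer.fit(...)

# Run one stage in isolation, without fitting:
t = ModelTrainerSketch({"head_configs": {"head": "centroid"}})
t.setup_dataloaders()
assert not t.fitted
```

This keeps unit tests fast: a test can cover setup_dataloaders alone, while an integration test drives the full train() path.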

Labels: enhancement (New feature or request)