
zhixinma/Trainer


The class Trainer provides a template for neural network training. Except for config() and train(), all other methods should be private and hidden from users.

class Trainer
    def config()  # public
    def train()  # public
    def train_epoch()
    def validate()
    def test()
    def inference()
    def batch_generator()  
    def calculate_loss()  
    def back_propagation()
    def evaluate()  
    def calculate_metrics()  
    def batch_padding()
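The template idea can be sketched in a few lines of plain Python. This is a hypothetical illustration, not the repository's code. The leading underscores (the usual Python convention for private members), the method parameters, and the toy model are all assumptions. The toy model fits a single weight `w` to `y = w * x` by mini-batch gradient descent on mean squared error. `train()` fixes the order of the steps, and a subclass would customize the private hooks.

```python
import random


class Trainer:
    """Hypothetical template-style trainer: only config() and train() are public.

    Toy model (assumption): a single weight w fitted to y = w * x with
    mini-batch gradient descent on mean squared error.
    """

    def config(self, data, lr=0.1, epochs=20, batch_size=4, seed=0):
        # Public: store static settings before training starts.
        self._data = list(data)
        self._lr = lr
        self._epochs = epochs
        self._batch_size = batch_size
        self._rng = random.Random(seed)
        self._w = 0.0  # the single model parameter
        return self

    def train(self):
        # Public: the fixed skeleton; returns the validation loss after each epoch.
        history = []
        for _ in range(self._epochs):
            self._train_epoch()
            history.append(self._validate())
        return history

    def _train_epoch(self):
        # One pass over shuffled mini-batches, updating w after each batch.
        for xs, ys in self._batch_generator():
            preds = [self._w * x for x in xs]
            self._back_propagation(xs, preds, ys)

    def _batch_generator(self):
        # Shuffle a copy of the data and yield (inputs, targets) batches.
        data = self._data[:]
        self._rng.shuffle(data)
        for i in range(0, len(data), self._batch_size):
            batch = data[i:i + self._batch_size]
            yield [x for x, _ in batch], [y for _, y in batch]

    def _calculate_loss(self, preds, golds):
        # Mean squared error over the batch.
        return sum((p - g) ** 2 for p, g in zip(preds, golds)) / len(golds)

    def _back_propagation(self, xs, preds, golds):
        # d/dw mean((w*x - y)^2) = mean(2 * (w*x - y) * x)
        grad = sum(2 * (p - g) * x for x, p, g in zip(xs, preds, golds)) / len(xs)
        self._w -= self._lr * grad

    def _validate(self):
        # Loss of the current model on the full data set.
        preds = [self._w * x for x, _ in self._data]
        return self._calculate_loss(preds, [y for _, y in self._data])


if __name__ == "__main__":
    data = [(i / 10, 3 * i / 10) for i in range(1, 11)]
    history = Trainer().config(data).train()
    print(history[0], history[-1])  # validation loss shrinks across epochs
```

Keeping the loop in `train()` and the variable parts in private hooks is one reading of "template" (the template method pattern): callers only see `config()` and `train()`.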

The classes Loss and Metrics provide common methods for calculating losses and evaluation metrics.

class Loss
    def calc_bce_loss()
    def calc_mce_loss()
    def calc_mse_loss()
class Metrics  
    def bi_cls_metric()
        return Accuracy, Positive_Rate, True_Positive_Rate, False_Positive_Rate, Recall, Positive_Precision, Negative_Precision, F1_score 
    def mul_cls_metric()
        return Accuracy, Positive_Rate, True_Positive_Rate, False_Positive_Rate, Recall, Positive_Precision, Negative_Precision, F1_score
    def seq_metric()
    def evaluate_mse_task()
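A minimal pure-Python sketch of the three losses follows. It assumes that `mce` means multi-class cross-entropy, that the BCE inputs are probabilities rather than logits, and that every loss is averaged over the batch; the repository may differ on each point.

```python
import math


class Loss:
    """Hypothetical pure-Python versions of the losses listed above."""

    @staticmethod
    def calc_bce_loss(probs, labels, eps=1e-12):
        # Binary cross-entropy: -mean(y * log(p) + (1 - y) * log(1 - p)).
        total = 0.0
        for p, y in zip(probs, labels):
            p = min(max(p, eps), 1 - eps)  # clamp to avoid log(0)
            total += y * math.log(p) + (1 - y) * math.log(1 - p)
        return -total / len(labels)

    @staticmethod
    def calc_mce_loss(prob_rows, labels, eps=1e-12):
        # Multi-class cross-entropy (assumed meaning of "mce"):
        # -mean(log of the probability assigned to the true class index).
        total = sum(math.log(max(row[y], eps)) for row, y in zip(prob_rows, labels))
        return -total / len(labels)

    @staticmethod
    def calc_mse_loss(preds, golds):
        # Mean squared error.
        return sum((p - g) ** 2 for p, g in zip(preds, golds)) / len(golds)
```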
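All of the binary metrics can be derived from the four confusion-matrix counts. The sketch below assumes 0/1 predictions and labels and reads `Positive_Rate` as the share of samples predicted positive; both are assumptions. `True_Positive_Rate` and `Recall` are the same quantity, so the sketch returns it twice to keep the tuple order listed above.

```python
class Metrics:
    """Hypothetical binary-classification metrics from 0/1 predictions and labels."""

    @staticmethod
    def bi_cls_metric(preds, golds):
        # Confusion-matrix counts.
        tp = sum(1 for p, g in zip(preds, golds) if p == 1 and g == 1)
        tn = sum(1 for p, g in zip(preds, golds) if p == 0 and g == 0)
        fp = sum(1 for p, g in zip(preds, golds) if p == 1 and g == 0)
        fn = sum(1 for p, g in zip(preds, golds) if p == 0 and g == 1)
        n = tp + tn + fp + fn

        def div(a, b):
            return a / b if b else 0.0  # treat x / 0 as 0 for empty classes

        accuracy = div(tp + tn, n)
        positive_rate = div(tp + fp, n)  # assumed: share predicted positive
        tpr = div(tp, tp + fn)           # identical to recall
        fpr = div(fp, fp + tn)
        recall = tpr
        pos_precision = div(tp, tp + fp)
        neg_precision = div(tn, tn + fn)
        f1 = div(2 * pos_precision * recall, pos_precision + recall)
        return (accuracy, positive_rate, tpr, fpr, recall,
                pos_precision, neg_precision, f1)
```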

The classes TrainingConfig and TrainingState are Data Transfer Objects: they only carry data.

class TrainingConfig
    # Transfer the static configuration.
    def set_data()
    def set_pad()
    def set_conf()
    def set_forward_func()
    def add_task()
class TrainingState
    # Transfer the dynamic state and data during the training process.
    def set_pred_batch()
    def set_gold_batch()
    def record_metric_batch()
    def clear_epoch_session()
    def get_best_model_path()
    def clear_infer_session()
    def update_epoch()
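Because both classes are plain data holders, Python `dataclasses` are a natural fit. The field names, the method parameters (for example `update_epoch(score, model_path)`), and the rule of keeping the checkpoint with the highest score are hypothetical choices for illustration only.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional


@dataclass
class TrainingConfig:
    """Static settings fixed before training (hypothetical fields)."""
    train_data: List[Any] = field(default_factory=list)
    valid_data: List[Any] = field(default_factory=list)
    pad_value: Any = 0
    conf: Dict[str, Any] = field(default_factory=dict)
    forward_func: Optional[Callable[..., Any]] = None
    tasks: List[str] = field(default_factory=list)

    def set_data(self, train_data, valid_data):
        self.train_data, self.valid_data = list(train_data), list(valid_data)

    def set_pad(self, pad_value):
        self.pad_value = pad_value

    def set_conf(self, **conf):
        self.conf.update(conf)

    def set_forward_func(self, func):
        self.forward_func = func

    def add_task(self, name):
        self.tasks.append(name)


@dataclass
class TrainingState:
    """Dynamic values that change while training runs (hypothetical fields)."""
    epoch: int = 0
    pred_batch: List[Any] = field(default_factory=list)
    gold_batch: List[Any] = field(default_factory=list)
    epoch_metrics: Dict[str, List[float]] = field(default_factory=dict)
    best_score: float = float("-inf")
    best_model_path: Optional[str] = None
    infer_outputs: List[Any] = field(default_factory=list)

    def set_pred_batch(self, preds):
        self.pred_batch = list(preds)

    def set_gold_batch(self, golds):
        self.gold_batch = list(golds)

    def record_metric_batch(self, name, value):
        # Append one batch-level value under the metric's name.
        self.epoch_metrics.setdefault(name, []).append(value)

    def clear_epoch_session(self):
        # Reset per-epoch buffers before the next epoch starts.
        self.pred_batch, self.gold_batch = [], []
        self.epoch_metrics = {}

    def update_epoch(self, score, model_path):
        # Advance the epoch counter and remember the best checkpoint so far.
        self.epoch += 1
        if score > self.best_score:
            self.best_score, self.best_model_path = score, model_path

    def get_best_model_path(self):
        return self.best_model_path

    def clear_infer_session(self):
        self.infer_outputs = []
```

A trainer would create one `TrainingConfig` up front and pass a single `TrainingState` through its private steps, so no step needs its own copy of the loop variables.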
