## Example code (pseudo code)


In [None]:
class MyModel:
    """Pseudo-code model whose loss is normalised by the token count of the
    whole gradient-accumulation window (``num_items_in_batch``) instead of the
    token count of the current micro-batch.

    NOTE(review): presumably a ``torch.nn.Module`` subclass in real code, since
    the trainer invokes it as ``model(**inputs)``.
    """

    def forward(self, *args, num_items_in_batch=None, **kwargs):
        """Compute outputs and the token-summed loss (most of the body is elided).

        Args:
            num_items_in_batch: number of loss tokens in the whole accumulation
                window, as produced by ``get_batch_samples``. ``None`` when the
                caller did not compute it (e.g. the evaluation loop).
        """
        ...
        # target_p, pred_logp and loss_mask come from the elided code above.
        plogp = target_p * pred_logp
        loss = self._token_normalised_loss(plogp, loss_mask, num_items_in_batch)
        ...

    @staticmethod
    def _token_normalised_loss(plogp, loss_mask, num_items_in_batch=None, eps=1e-5):
        """Return ``-sum(loss_mask * plogp) / (num_items_in_batch + eps)``.

        The masked log-likelihood is summed (not averaged) over the micro-batch,
        so dividing by the token count of the whole accumulation window gives a
        correctly weighted mean once the micro-batch gradients are added up.

        When ``num_items_in_batch`` is None, the divisor falls back to the number
        of masked-in positions of this micro-batch, i.e. an ordinary per-batch
        mean, instead of failing on ``None + eps``.
        NOTE(review): the fallback assumes ``loss_mask`` is 3-D with the class
        axis at dim 2, matching the ``sum(..., 2)`` below — TODO confirm.
        """
        masked = loss_mask * plogp
        if num_items_in_batch is None:
            num_items_in_batch = (loss_mask != 0).any(dim=2).sum()
        return -torch.sum(torch.sum(masked, 2)) / (num_items_in_batch + eps)

class MyCustomTrainer(Trainer):
    """Trainer that evaluates through the model's own ``compute_metrics`` hook and
    counts loss tokens over each whole gradient-accumulation window."""

    def evaluate(self, **kwargs):
        """Run the model over the eval set and return ``model.compute_metrics()``.

        Extra keyword arguments (such as those forwarded by the training loop)
        are accepted and ignored.
        """
        # NOTE(review): transformers' Trainer.get_eval_dataloader takes only
        # ``eval_dataset``; confirm the base class in use accepts ``batch_size``.
        dataloader = self.get_eval_dataloader(self.eval_dataset, batch_size=1)
        model = self.model
        model.eval()
        # torch.autocast("cuda") replaces the deprecated torch.cuda.amp.autocast().
        with torch.no_grad(), torch.autocast(device_type="cuda"):
            for inputs in dataloader:
                # Outputs are discarded: presumably the model accumulates whatever
                # compute_metrics() needs during forward — TODO confirm.
                model(**inputs)
        return model.compute_metrics()

    def get_batch_samples(self, epoch_iterator, accumulation_steps, device=None):
        """Pull up to ``accumulation_steps`` micro-batches and count their loss items.

        Only the training loop calls this; evaluation does not.

        Args:
            epoch_iterator: iterator over the training dataloader.
            accumulation_steps: maximum number of micro-batches to take; fewer are
                returned when the iterator runs out.
            device: when given and the count is a tensor, the count is moved there.

        Returns:
            ``(batch_samples, num_items_in_batch)``; the count is ``None`` when no
            micro-batch was left in the iterator.
        """
        from itertools import islice

        batch_samples = list(islice(epoch_iterator, accumulation_steps))
        num_items_in_batch = None
        if batch_samples:
            num_items_in_batch = self.get_num_items_in_batch(batch_samples)
        # Guard: an exhausted iterator yields no count, and gather(None) would fail.
        if num_items_in_batch is not None:
            if self.args.average_tokens_across_devices:
                # Sum per-process counts so every process divides by the global count.
                num_items_in_batch = gather(num_items_in_batch).sum()
            if device is not None and torch.is_tensor(num_items_in_batch):
                num_items_in_batch = num_items_in_batch.to(device)
        return batch_samples, num_items_in_batch

    def get_num_items_in_batch(self, batch_samples):
        """Return the number of loss-contributing tokens across ``batch_samples``.

        Example rule: every label different from 0 counts as one item.
        NOTE(review): assumes each micro-batch is a mapping with a ``"labels"``
        tensor and that 0 marks padding — adapt to the real label convention
        (Hugging Face models commonly use -100 as the ignore index).
        """
        return sum((batch["labels"] != 0).sum() for batch in batch_samples)


# Build the custom trainer with separate collators for training and evaluation.
# NOTE(review): `eval_data_collator` is not a parameter of transformers'
# Trainer.__init__; confirm the Trainer subclass in use accepts it.
trainer = MyCustomTrainer(
    model=model,
    args=hgf_training_args,
    # train dataset
    train_dataset=train_dataset,
    data_collator=train_collator,
    # eval dataset
    eval_dataset=eval_dataset,
    eval_data_collator=eval_collator,
    # logging
    compute_metrics= ...  # placeholder in this pseudo code; supply a metrics callable
)

## Trainer

In [None]:
class Trainer:
    """Condensed sketch of the Hugging Face ``Trainer`` training loop.

    Helpers referenced but not defined here (``checkpoint_loading``,
    ``gradient_clipping``, ``gather``, ``self._prepare_inputs``, ``self.log``,
    ``self.evaluate``, ``self.get_batch_samples``, ...) stand in for the real
    implementations, as do ``epochs_trained``, ``num_train_epochs`` and
    ``steps_in_epoch``.
    """

    def train(self):
        """Entry point: run the inner training loop."""
        self._inner_training_loop()

    def _inner_training_loop(self):
        """Epoch / optimizer-update loop with token-count-aware gradient accumulation.

        Each optimizer update consumes up to ``gradient_accumulation_steps``
        micro-batches. ``num_items_in_batch`` is counted over all of them before
        any forward pass, so every micro-batch loss shares the same denominator.
        """
        import math

        args = self.args
        ga_steps = args.gradient_accumulation_steps
        # self._train_batch_size is the per-device batch size.
        train_dataloader = DataLoader(self.train_dataset, batch_size=self._train_batch_size)
        model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
        checkpoint_loading()  # if any

        # Train!
        tr_loss = torch.tensor(0.0, device=args.device)  # loss accumulated since the last log
        if args.eval_on_start:
            self.evaluate()

        for epoch in range(epochs_trained, num_train_epochs):
            epoch_dataloader = train_dataloader
            # steps_in_epoch is the number of micro-batches in this epoch (elided).
            # Ceil, not floor: a trailing partial window must still be trained on.
            total_updates = math.ceil(steps_in_epoch / ga_steps)
            epoch_iterator = iter(epoch_dataloader)
            step = -1
            for _ in range(total_updates):
                batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, ga_steps)
                for inputs in batch_samples:
                    step += 1
                    # Step the optimizer at the end of each window, and at the last
                    # micro-batch of the epoch, which may close a shorter window.
                    do_sync_step = (step + 1) % ga_steps == 0 or (step + 1) == steps_in_epoch
                    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                    tr_loss = tr_loss + tr_loss_step

                    if do_sync_step:
                        gradient_clipping()
                        self.optimizer.step()
                        # NOTE(review): no lr_scheduler.step() is shown in this sketch.
                        learning_rate = self._get_learning_rate()
                        model.zero_grad()
                        self.state.global_step += 1
                        self._maybe_log_save_evaluate(tr_loss, learning_rate)

    def training_step(self, model, inputs, num_items_in_batch=None):
        """Forward and backward pass for one micro-batch; returns the detached loss."""
        model.train()
        inputs = self._prepare_inputs(inputs)
        loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
        # NOTE(review): emptying the CUDA cache on every micro-batch is expensive;
        # transformers only does so every `torch_empty_cache_steps` steps.
        torch.cuda.empty_cache()
        if num_items_in_batch is None:
            # Legacy behaviour: average the per-micro-batch means evenly over the
            # window; kept only for backward compatibility (see the Correctness section).
            loss = loss / self.args.gradient_accumulation_steps
        # NOTE(review): Accelerator.backward may also divide by its own
        # gradient_accumulation_steps — confirm the loss is not scaled twice.
        self.accelerator.backward(loss)
        return loss.detach()

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run ``model`` on ``inputs`` and extract its loss.

        ``num_items_in_batch`` is forwarded to the model so it can normalise by
        the token count of the whole accumulation window.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        inputs = {**inputs, "num_items_in_batch": num_items_in_batch}
        outputs = model(**inputs)
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        if self.args.average_tokens_across_devices and num_items_in_batch is not None:
            # The count was summed over all processes while gradients are
            # mean-reduced across them; multiplying cancels that mean. Without a
            # global count the model returns a local mean and needs no correction.
            loss *= self.accelerator.num_processes
        return (loss, outputs) if return_outputs else loss

    def _maybe_log_save_evaluate(self, tr_loss, learning_rate=None):
        """Per optimizer update: log the mean loss per update since the last log,
        then run evaluation if requested."""
        if self.control.should_log:
            # .item() copies the value out before tr_loss is zeroed below.
            tr_loss_scalar = gather(tr_loss).mean().item()  # mean over processes
            # In-place reset, so the caller's tr_loss tensor is zeroed as well.
            tr_loss -= tr_loss
            updates_since_last_log = self.state.global_step - self._globalstep_last_logged
            logs = {"loss": round(tr_loss_scalar / updates_since_last_log, 4)}
            if learning_rate is not None:
                logs["learning_rate"] = learning_rate
            self.log(logs)
            self._globalstep_last_logged = self.state.global_step

        if self.control.should_evaluate:
            self.evaluate()

## Accelerate

In [None]:
class Accelerator:
    """Minimal sketch of ``accelerate.Accelerator`` showing only loss scaling."""

    def backward(self, loss, **kwargs):
        """Scale ``loss`` for gradient accumulation, then backpropagate it.

        Under DeepSpeed the division is skipped because DeepSpeed applies the
        scaling in its own code.
        """
        uses_deepspeed = self.distributed_type == DistributedType.DEEPSPEED
        scaled_loss = loss if uses_deepspeed else loss / self.gradient_accumulation_steps
        scaled_loss.backward(**kwargs)

## Correctness
In SGD, the full-batch gradient is the mean of the per-sample gradients, so it can be accumulated piece by piece:
$$
\nabla_\theta L_D = \frac{1}{N} \sum_{i=1}^{N} \nabla_\theta L(\theta, D_i) \doteq \frac{1}{N} \sum_{i=1}^{N} g_i
$$

In language modeling, each token is a sample. However, the micro-batches $k = 1, \dots, B$ of token sequences that make up one gradient-accumulation step may each contain a different number of tokens $N_k$. As a result, it is incorrect to take an "average of averages" (averaging within each micro-batch, then across micro-batches) here:
$$
\nabla_\theta L_D = \frac{1}{\sum_{k=1}^{B} N_k} \sum_{k=1}^{B} \sum_{i=1}^{N_k} g_i \neq \frac{1}{B} \sum_{k=1}^{B} \left( \frac{1}{N_k} \sum_{i=1}^{N_k} g_i \right)
$$

The two sides coincide only when every micro-batch contains the same number of tokens. Hugging Face Transformers used the incorrect right-hand form for a long time; it was fixed only recently, after a third party (Unsloth) reported it: https://huggingface.co/blog/gradient_accumulation