# Improve performance and robustness with Gaussian Multi-Target Regularization

Hi all! In this notebook I will share a strategy that can easily be added to your model and that improves both model performance and robustness (lower variance between runs)! I will also share a minimal implementation of it.

## Motivation

* The training data comes with a `standard_error` for every target, so we should try to make use of it!

## Introduction

* First, "Gaussian multi-target regularization" is just a name I came up with.
* The idea is to have the model predict multiple outputs (heads), each trained against its own target sampled from a Gaussian distribution with mean `target` and standard deviation `standard_error`.
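
In symbols (notation chosen here to match the code further down): for a batch of $B$ samples, $y_i$ is the `target` of sample $i$, $\sigma_i$ its `standard_error`, $K$ the number of heads (`num_heads`), and $\hat{y}_{i,k}$ the output of head $k$. Each head gets its own freshly sampled target, the loss is the mean squared error over all $B \times K$ entries (the default `nn.MSELoss` reduction), and at inference time the heads are averaged:

$$
\begin{aligned}
\tilde{y}_{i,k} &\sim \mathcal{N}\!\left(y_i,\ \sigma_i^2\right), \qquad k = 1, \dots, K \\
\mathcal{L} &= \frac{1}{BK} \sum_{i=1}^{B} \sum_{k=1}^{K} \left(\hat{y}_{i,k} - \tilde{y}_{i,k}\right)^2 \\
\hat{y}_i &= \frac{1}{K} \sum_{k=1}^{K} \hat{y}_{i,k}
\end{aligned}
$$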

## Benefits

Before diving into the code, let's first think about the potential benefits of this strategy.

1. Data augmentation
    * Sampling the target from a normal distribution should have a regularization effect.
2. Ensemble
    * Multiple targets mean multiple slightly different heads to ensemble, which should improve performance.
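
One way to see why the sampled targets act as augmentation rather than as a bias: for a fixed prediction $p$ and $\tilde{y} \sim \mathcal{N}(y, \sigma^2)$, we have $\mathbb{E}[(p - \tilde{y})^2] = (p - y)^2 + \sigma^2$. In expectation the loss only gains a constant, so its gradient still points toward the original `target`; the sampling adds noise (the hoped-for regularization) without moving the optimum. The sketch below checks this numerically with PyTorch; the values of `y`, `sigma` and `p` are made up for illustration.

```python
import torch

# Made-up values for one sample: true target y, its standard error sigma,
# and a fixed model prediction p.
y, sigma, p = -0.5, 0.45, -0.2
n = 1_000_000

torch.manual_seed(0)
# Sample noisy targets y~ ~ N(y, sigma^2), as the proposed method does per head.
noisy = torch.normal(mean=torch.full((n,), y), std=torch.full((n,), sigma))

# Monte Carlo estimate of E[(p - y~)^2] vs. the closed form (p - y)^2 + sigma^2.
mc = ((p - noisy) ** 2).mean().item()
closed_form = (p - y) ** 2 + sigma ** 2
print(f"Monte Carlo: {mc:.4f}  closed form: {closed_form:.4f}")
```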
 
## Code

* To keep it short, I will only include code snippets copied from my local environment.


## Result

* Based on my local experiments, the proposed method does outperform the baseline:
    * The validation RMSE is lower.
    * The validation RMSE varies less between runs.
* Since only the number of outputs in the last layer increases, the extra memory/compute is negligible (I tried up to 1024 targets).
* As shown below, the code change is also minimal, so try it out!
* Please give this notebook an upvote if you find it useful!
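
To put a number on the overhead: the only layer that grows is the final linear projection. A rough sketch, assuming `roberta-base` (hidden size 768, about 125M parameters in total) and the Hugging Face sequence-classification head, whose last layer is a `Linear(hidden_size, num_labels)`; the helper `head_params` is hypothetical:

```python
from torch import nn

# Assumption: roberta-base hidden size; the classification head ends in
# Linear(HIDDEN_SIZE, num_labels).
HIDDEN_SIZE = 768


def head_params(num_labels: int) -> int:
    """Weights + biases of the final Linear(HIDDEN_SIZE, num_labels) layer."""
    layer = nn.Linear(HIDDEN_SIZE, num_labels)
    return sum(p.numel() for p in layer.parameters())


baseline = head_params(1)  # single regression output
for k in (32, 1024):
    extra = head_params(k) - baseline
    print(f"{k:5d} heads: +{extra:,} parameters")
```

Even with 1024 heads, the final layer gains about 0.79M parameters, well under 1% of roberta-base's roughly 125M.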

## Baseline code

```python
from torch import nn


class TrainingModule(nn.Module):
    """Baseline: a single regression output trained with plain MSE."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.crit = nn.MSELoss()

    def forward(self, target=None, standard_error=None, **kwargs):
        # standard_error is accepted (so the whole batch can be passed in) but unused.
        out = self.model(**kwargs)
        logits = out.logits  # shape: (batch, 1)
        loss = self.crit(logits.view(-1), target)
        return loss

    def predict(self, target=None, standard_error=None, **kwargs):
        out = self.model(**kwargs)
        logits = out.logits
        return logits


# Usage (requires the `transformers` package):
# from transformers import AutoModelForSequenceClassification
# model_name_or_path = 'roberta-base'
# model = AutoModelForSequenceClassification.from_pretrained(
#         model_name_or_path, num_labels=1
# )
# model = TrainingModule(model)
# # dataset, dataloader
# loss = model(**batch)
```

## Code for proposed strategy

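```python
import einops
import torch
from torch import nn

num_heads = 32  # number of heads / sampled targets (I tried up to 1024)


class TrainingModule(nn.Module):
    """Gaussian multi-target regularization: num_heads outputs, each trained
    against its own target sampled from N(target, standard_error**2)."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.crit = nn.MSELoss()

    def forward(self, target=None, standard_error=None, **kwargs):
        # Repeat the per-sample mean and std across heads: (batch,) -> (batch, num_heads).
        mean = einops.repeat(target, "b -> b n", n=num_heads)
        std = einops.repeat(standard_error, "b -> b n", n=num_heads)
        # Draw a fresh noisy target for every head on every forward pass.
        targets = torch.normal(mean=mean, std=std)
        out = self.model(**kwargs)
        logits = out.logits  # shape: (batch, num_heads)
        loss = self.crit(logits, targets)
        return loss

    def predict(self, target=None, standard_error=None, **kwargs):
        out = self.model(**kwargs)
        logits = out.logits
        # Ensemble: average the heads into one prediction per sample.
        return logits.mean(-1)


# Usage (requires the `transformers` package):
# from transformers import AutoModelForSequenceClassification
# model_name_or_path = 'roberta-base'
# model = AutoModelForSequenceClassification.from_pretrained(
#         model_name_or_path, num_labels=num_heads
# )
# model = TrainingModule(model)
# # dataset, dataloader
# loss = model(**batch)
```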
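
To try the idea without downloading a transformer, here is a self-contained toy sketch. Assumptions: the loss is factored into a standalone function (the name `gaussian_multi_target_loss` is hypothetical), `einops.repeat` is swapped for the equivalent `unsqueeze(-1).expand(...)`, and a single `nn.Linear` on made-up data stands in for the backbone and dataset.

```python
import torch
from torch import nn

torch.manual_seed(0)
num_heads = 32


def gaussian_multi_target_loss(logits, target, standard_error):
    """MSE between (batch, K) logits and per-head targets drawn from
    N(target, standard_error**2); same idea as the einops-based forward above."""
    k = logits.shape[-1]
    mean = target.unsqueeze(-1).expand(-1, k)  # (batch,) -> (batch, K)
    std = standard_error.unsqueeze(-1).expand(-1, k)
    sampled = torch.normal(mean=mean, std=std)  # fresh noise on every call
    return nn.functional.mse_loss(logits, sampled)


# Made-up toy data: 256 samples with 8 features, a linear ground truth and
# per-sample standard errors, standing in for the real text dataset.
x = torch.randn(256, 8)
target = x @ torch.randn(8)
standard_error = 0.3 + 0.2 * torch.rand(256)

# Toy "backbone": one linear layer with one output per head.
model = nn.Linear(8, num_heads)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

for step in range(1000):
    loss = gaussian_multi_target_loss(model(x), target, standard_error)
    opt.zero_grad()
    loss.backward()
    opt.step()

with torch.no_grad():
    pred = model(x).mean(-1)  # average the heads, as in predict()
rmse = torch.sqrt(((pred - target) ** 2).mean()).item()
print(f"train RMSE of the averaged heads vs. clean targets: {rmse:.3f}")
```

Keeping the loss as a plain function means it can be dropped into any training loop without wrapping the model. The toy loop only checks that the averaged heads recover the clean targets; it says nothing about whether the method beats the baseline on real data.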