## Improving unbalanced dataset training : easy 10+% accuracy boost

Dr. Shamus Husheer, Heartfelt Technologies Ltd, August 2022

Most datasets in reality are unbalanced, often to a large degree, because training data for some classes is intrinsically harder to find than for others. Validation sets should ideally match the class distribution at inference time, which is also often unbalanced, but in an entirely different way. There are a number of methods to counteract this, but the simple approach discussed here may have been overlooked.

The idea is very simple: add a fixed set of class-specific bias values just before the softmax or sigmoid layer, with those bias values pre-calculated from the known relative distribution of class examples. This applies to any classification problem, but is demonstrated below for the case of image classification.

In a particularly nasty test-case, where the training set is unbalanced in the opposite way to the validation set (which reflects many industrial use-cases), we achieve a 16 percentage-point improvement in accuracy at 5 epochs and an 11 percentage-point improvement at 20 epochs of training. Because many industrial use-cases are characterised by small, unbalanced datasets and few epochs of training (after transfer learning), this is an important improvement. The method does not degrade performance for balanced datasets, or for unbalanced datasets where the training and validation datasets have the same class distribution. It therefore makes a "sane default" for general use.

### Motivation

Imagine I only had a very small number of medical pictures of patients with a certain disease, but lots of control patients – and in a ratio that doesn’t necessarily reflect the prevalence of the disease in the general population. In this case it would be wise to either remove the imbalance bias, or potentially to try to reflect the prevalence of the disease in the general population (these are two very different, but equally reasonable corrections – simply keeping the bias of the training set is not).

Obviously there are many other approaches to dealing with unbalanced datasets, so this won't be the solution for everything, but it is such a simple (and easily pre-calculated) addition that it's likely to be a "sane default" compared to the usual implicit assumption of perfectly balanced training sets.

The core idea is to simply add bias values just before softmax, with the bias for each class calculated from the number of examples known to be in the training set for each class (the a-priori training set class distribution). This set of biases can then be recalculated to match the assumed frequency of occurrence of the classes during inference (the a-priori validation set class distribution), which should yield better classification statistics in the field.

### How Apriori-softmax helps

Anything using a softmax output to calculate a probability for each class will, assuming no useful features, learn a bias for each output (prior to softmax) of ln(apriori_p), where apriori_p is the probability of randomly drawing an example of this class from all of the training data. In a balanced dataset all biases take the same value, so as long as the input is mean-zero there's nothing to learn and nothing to worry about (because a constant offset to all channels is ignored by softmax). With a highly unbalanced training set, learning the correct setting for this bias is about the single most important feature the network can learn, but it can't know this accurately until it has seen about 1/r^2 examples, where r is the probability of seeing an example of the least numerous class, multiplied by the number of classes. In highly unbalanced datasets this could be more than an epoch, whereas we can precisely determine the class balance ahead of time for training, and likely also already have good estimates for what we expect to see in the field.

Therefore, to avoid the start of training having to try to learn features *and* the balance of classes, we can simply create a layer that adds exactly the correct bias caused by the dataset balance right before the softmax. This way a mean-zero random input from a newly initialized network will, on average, result in a class probability output that is the a-priori probability of each class. This should speed up training.
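As a minimal sketch of this initialisation, assume a hypothetical three-class training set with 900, 90 and 10 examples. The counts and the helper names `softmax` and `log_prior_bias` are illustrative and not taken from the notebook. Adding ln(apriori_p) to all-zero logits reproduces the class proportions:

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def log_prior_bias(class_counts):
    """Return ln(apriori_p) per class, given per-class example counts (all > 0)."""
    total = sum(class_counts)
    return [math.log(c / total) for c in class_counts]

# Hypothetical training-set counts for three classes (illustrative only).
train_counts = [900, 90, 10]
bias = log_prior_bias(train_counts)

# Assumption: a freshly initialised network emits logits that average to zero.
zero_logits = [0.0, 0.0, 0.0]
probs = softmax([z + b for z, b in zip(zero_logits, bias)])
print(probs)  # the class proportions 0.9, 0.09, 0.01 (up to rounding)
```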

This applies to CrossEntropyLoss, but we can also make the same improvement for BCEWithLogitsLoss by simply adding a single bias value that is ln(apriori_p_A)-ln(apriori_p_notA) where we are predicting class A.
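The binary case can be checked the same way. The sketch below assumes a hypothetical prevalence of 5% for class A (the value and function names are illustrative). It shows that the single log-odds bias makes a zero logit map back to the a-priori probability through the sigmoid:

```python
import math

def sigmoid(x):
    """Logistic function."""
    return 1.0 / (1.0 + math.exp(-x))

def binary_log_odds_bias(p_a):
    """ln(apriori_p_A) - ln(apriori_p_notA) for a binary problem, 0 < p_a < 1."""
    return math.log(p_a) - math.log(1.0 - p_a)

# Hypothetical a-priori probability of class A (illustrative only).
p_a = 0.05
bias = binary_log_odds_bias(p_a)

# With a zero logit from the network, the sigmoid output equals the prior.
prob_at_zero_logit = sigmoid(0.0 + bias)
print(prob_at_zero_logit)
```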

Then when it comes to inference time, we may also wish to specify an expected "real world" class imbalance that is not the same as the training set, or we may wish to remove the imbalance introduced by the training set. We then simply update this set of biases using the new ln(apriori_p) values.
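To illustrate the swap, the sketch below uses hypothetical pre-bias logits for a single input and two made-up priors (none of these numbers come from the experiments). Replacing the training bias with the field bias is equivalent to reweighting the training-time posterior by new_prior / old_prior and renormalising, and it can change the predicted class:

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical class priors (illustrative only).
train_prior = [0.90, 0.09, 0.01]
field_prior = [0.01, 0.09, 0.90]

# Hypothetical network output for one input, before the bias layer.
features = [0.2, -0.1, 1.5]

train_post = softmax([f + math.log(p) for f, p in zip(features, train_prior)])
field_post = softmax([f + math.log(p) for f, p in zip(features, field_prior)])

# Equivalent view: reweight the training posterior by field/train, then renormalise.
reweighted = [q * n / t for q, n, t in zip(train_post, field_prior, train_prior)]
norm = sum(reweighted)
reweighted = [r / norm for r in reweighted]

train_pred = max(range(3), key=lambda k: train_post[k])
field_pred = max(range(3), key=lambda k: field_post[k])
print(train_pred, field_pred)
```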

Note that this pre-softmax layer is NOT mathematically the same as the `weight` option of torch.nn.NLLLoss, although the `weight` option is evidently intended to address a similar problem.

The impact of this should be most acute at the early stages of training, and also in inference on sets that are wildly different in balance than the training set.

There is an additional detail about how Label Smoothing interacts with this process, which we will get onto later.

### Unbalanced training and validation set generation

We can test this easily by taking the Imagenette set and massively unbalancing it. So that we can compare to the usual balanced validation set, we create an additional column in our pandas dataframe. The is_valid column, which is TRUE for ~30% of images, is retained and we generate an additional "unbalanced" boolean column that has a class-dependent probability from 5% to 95%. The unbalanced training set is all elements where (is_valid OR unbalanced) is FALSE. The balanced validation set is where is_valid is TRUE.

A perversely unbalanced validation set is where (is_valid AND unbalanced) is TRUE. The perversely unbalanced validation set represents the extreme (but industrially common) case where the training set is artificially inflated in some classes of interest, which are rare in the field, and vice versa. The class balance of the perverse validation set is utterly different to (inverted from) the balance in the training set, so examining the accuracy (or F1 score, or whatever) on this set is particularly interesting.
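The boolean logic of the splits can be sketched on four hypothetical rows, one per combination of the two flags. The row names are made up; the masks follow the definitions above and the `UnbalancedColSplitter` code further down. Note that rows with unbalanced TRUE and is_valid FALSE end up in no split at all.

```python
# One hypothetical row for each combination of the two boolean columns.
rows = [
    {"name": "a", "is_valid": False, "unbalanced": False},
    {"name": "b", "is_valid": False, "unbalanced": True},
    {"name": "c", "is_valid": True, "unbalanced": False},
    {"name": "d", "is_valid": True, "unbalanced": True},
]

def names(predicate):
    """Names of the rows for which predicate(row) is true."""
    return [r["name"] for r in rows if predicate(r)]

# Unbalanced training set: (is_valid OR unbalanced) is FALSE.
train = names(lambda r: not (r["is_valid"] or r["unbalanced"]))
# Balanced validation set: is_valid is TRUE.
bal_valid = names(lambda r: r["is_valid"])
# Validation set unbalanced the same way as training: is_valid AND NOT unbalanced.
unbal_valid = names(lambda r: r["is_valid"] and not r["unbalanced"])
# Perversely unbalanced validation set: is_valid AND unbalanced.
perverse_valid = names(lambda r: r["is_valid"] and r["unbalanced"])

print(train, bal_valid, unbal_valid, perverse_valid)
```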

### Testing

An unaltered model with an unbalanced training set will have difficulty in training. The worst classes have only 1000/20 = 50 images each in a training set of ~10,000 images, so to estimate that proportion accurately we would need to run through roughly (10000/50)^2 = 40,000 images. That's 4 epochs of training, just to get the biases at the top layer correct. Even after training, validation will be relatively poor, because the model will have learned a baseline softmax bias that is completely wrong.
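The arithmetic of this rule of thumb can be written out as a short sketch. The variable names are illustrative, the set sizes are the approximate figures quoted above, and the 1/r^2 estimate is this article's heuristic rather than an exact result.

```python
# Approximate figures quoted in the text (not exact dataset sizes).
images_per_class = 1000
total_images = 10_000

# The rarest class keeps about 1 in 20 of its images.
rare_class_images = images_per_class // 20

# Heuristic from the text: about (1/r)^2 images are needed, where r is the
# probability of drawing an image of the rarest class.
images_needed = (total_images / rare_class_images) ** 2
epochs_needed = images_needed / total_images

print(rare_class_images, images_needed, epochs_needed)
```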

We test the unaltered model to demonstrate this, and then we alter the model to insert the Apriori-Softmax, but do training and validation with softmax-bias layer set to all 0s, which is the same as excluding it. We verify that these two results are the same (within error) to show that our Apriori-Softmax layer doesn't mess anything up.

Then we do training and validation where we set the softmax biases based on known training class probabilities, and see how it trains, and whether validation improves. We then also set the known class probabilities for inference, and verify that this improves inference as well.

Note that the use of the "accuracy" metric with an unbalanced validation set is inherently problematic. A number of other metrics, such as F1 score, are advocated. For simplicity in this test we will use "accuracy" even for unbalanced validation sets, because this is most familiar to most industrial practitioners.

### Impact on label smoothing

Label smoothing usually helps training a lot e.g. https://lessw.medium.com/label-smoothing-deep-learning-google-brain-explains-why-it-works-and-when-to-use-sota-tips-977733ef020

However, it basically sets the output of the softmax to target a probability of say 0.91 rather than 1.0 for the correct class, and 0.01 rather than 0.0 for each incorrect class (assuming 10 classes and eps=0.1). But if we train using Apriori-Softmax and then remove the a-priori bias (or change it), this 91% / 1% posterior probability will adjust to some other value. This is not such a problem for the majority class, but a minority class is effectively damped in probability during training, and trained to output 1% as the not-present result. Taking away the a-priori bias could possibly shift that 1% to over 50% in the case of highly unbalanced datasets. Here we're taking a learned class probability of 1% up to 19% when moving from our unbalanced training set to our perversely unbalanced validation set, so just a bit of uncertainty could make that 19% base probability the largest measured class probability.
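A rough sketch of this interaction follows. It assumes an idealised model that reproduces uniformly smoothed targets exactly (eps mass spread over all 10 classes, giving 1 - eps + eps/10 = 0.91 for the true class and eps/10 = 0.01 elsewhere). It also assumes equal class sizes before unbalancing and classes indexed from most to least common in training. Real networks will not fit the targets this exactly, so this is the extreme case only.

```python
num_classes = 10
eps = 0.1

# Uniformly smoothed target for an image whose true class is 0,
# the most common class in the unbalanced training set.
targets = [eps / num_classes] * num_classes
targets[0] += 1.0 - eps  # 0.91 for the true class, 0.01 for every other class

# Share of each class sent to the perverse validation set versus kept for
# training, following the (class_index + 0.5) / num_classes scheme.
perverse_share = [(i + 0.5) / num_classes for i in range(num_classes)]
train_share = [1.0 - s for s in perverse_share]

# Swapping the training prior for the perverse prior multiplies each class
# probability by perverse/train, followed by renormalisation.
reweighted = [t * p / q for t, p, q in zip(targets, perverse_share, train_share)]
norm = sum(reweighted)
shifted = [r / norm for r in reweighted]

# The rarest training class gains a factor of 0.95/0.05 = 19 and overtakes
# the true class in this idealised setting.
predicted = max(range(num_classes), key=lambda i: shifted[i])
print(predicted, shifted)
```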

The LabelSmoothing paper https://arxiv.org/abs/1512.00567 actually says the smoothed labels should follow the distribution of samples, but specifically for ImageNet they use a uniform distribution, which everyone seems to have used as a shortcut.

This means the targets no longer become a blend with uniform distribution, which means we can't use the simple multiclass cross entropy trick, we have to actually sum up all log-probabilities. Therefore we make a new AprioriLabelSmoothingCrossEntropy loss function, which replaces the usual LabelSmoothingCrossEntropy for some of our experiments.
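For a single example, the smoothed targets and the full-sum cross entropy might look like the sketch below. It is written in plain Python rather than PyTorch, with hypothetical class counts and logits; the actual loss function used in the experiments is the `AprioriLabelSmoothingCrossEntropy` class defined later in the notebook.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]

def smoothed_targets(target, class_distribution, eps=0.1):
    """Blend a one-hot target with the normalised class distribution."""
    total = sum(class_distribution)
    return [
        (1.0 - eps) * (1.0 if k == target else 0.0) + eps * d / total
        for k, d in enumerate(class_distribution)
    ]

def smoothed_cross_entropy(logits, target, class_distribution, eps=0.1):
    """Cross entropy -sum(P * log Q), summed over every class."""
    log_preds = log_softmax(logits)
    soft = smoothed_targets(target, class_distribution, eps)
    return -sum(t * lp for t, lp in zip(soft, log_preds))

# Hypothetical class counts and logits (illustrative only).
class_counts = [900, 90, 10]
logits = [2.0, 0.5, -1.0]

soft = smoothed_targets(2, class_counts)
loss = smoothed_cross_entropy(logits, 2, class_counts)
print(soft, loss)
```

With these counts the rare class only receives eps * 0.01 of smoothing mass when it is not the target, rather than eps / 3 as it would under uniform smoothing.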


### Implementation

We create a column on noisy_imagenette.csv to perform the unbalancing of datasets

We create an UnbalancedColSplitter that generates 4 datasets rather than the usual two, such that we can easily switch between balanced, unbalanced and perversely unbalanced validation sets

We create a SelectiveBias module that can selectively add constant bias values

We create an AprioriSoftmax callback that switches the SelectiveBias between training and validation set values when we switch between training and validation in the training loop

We create an AprioriLabelSmoothingCrossEntropy loss function that calculates label smoothing using a non-uniform class distribution

We bastardize the usual ImageNette create_learner and get_dls to make it simple to generate the variations on the model, loss function and dataset splitting.

Full results (30 replicates at 5 epochs, plus 15 replicates at 20 epochs) take about 24hrs to run on an RTX2060-super GPU.

## Results

The tables show 5, 20 and 1 epoch training, all with a learning_rate of 0.025 (found via the learning rate finder on the balanced training set).

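The abbreviations at the top of the various 0/1 columns describe which aspects of the method are enabled, following the arguments of `do_test`: _ub_ = unbalanced training set, _ap_ = Apriori-Softmax layer inserted, _tp_ = training-time biases set from the training class proportions, _vp_ = validation-time biases set from each validation set's class proportions, _al_ = AprioriLabelSmoothingCrossEntropy loss, _ls_ = label smoothing.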

(pretty tables from https://www.tablesgenerator.com/markdown_tables)

5 epochs, 30 replicates:

| _row_ | _ub_ | _ap_ | _tp_ | _vp_ | _al_ | _ls_ | _avg_train_acc_ | **_avg_bal_val_acc_** | **_avg_unbal_val_acc_** | **_avg_perv_val_acc_** | _epochs_ | _bs_ | _lr_ |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 0 | 0 | 0 | 0 | 0 | 0 | 89.336+/-0.301 | **84.134+/-0.345** | **0.000+/-0.000** | **0.000+/-0.000** | 5 | 128 | 0.025 |
| 2 | 0 | 0 | 0 | 0 | 0 | 1 | 89.789+/-0.308 | **84.568+/-0.354** | **0.000+/-0.000** | **0.000+/-0.000** | 5 | 128 | 0.025 |
| 3 | 1 | 0 | 0 | 0 | 0 | 0 | 86.109+/-0.635 | **69.544+/-0.682** | **82.137+/-0.582** | **56.686+/-1.100** | 5 | 128 | 0.025 |
| 4 | 1 | 1 | 0 | 0 | 0 | 0 | 86.162+/-0.416 | **69.721+/-0.505** | **82.157+/-0.485** | **57.024+/-0.860** | 5 | 128 | 0.025 |
| 5 | 1 | 1 | 0 | 0 | 0 | 1 | 86.277+/-0.387 | **68.895+/-0.537** | **82.362+/-0.432** | **55.144+/-0.900** | 5 | 128 | 0.025 |
| 6 | 1 | 1 | 1 | 0 | 0 | 0 | 85.979+/-0.491 | **74.324+/-0.678** | **78.403+/-0.781** | **70.158+/-1.215** | 5 | 128 | 0.025 |
| 7 | 1 | 1 | 1 | 0 | 0 | 1 | 86.506+/-0.369 | **73.111+/-0.715** | **73.927+/-1.194** | **72.278+/-0.768** | 5 | 128 | 0.025 |
| 8 | 1 | 1 | 1 | 1 | 0 | 0 | 86.107+/-0.502 | **74.220+/-0.596** | **81.957+/-0.508** | **73.220+/-0.787** | 5 | 128 | 0.025 |
| 9 | 1 | 1 | 1 | 1 | 0 | 1 | 86.384+/-0.478 | **73.304+/-0.607** | **82.328+/-0.480** | **68.986+/-0.890** | 5 | 128 | 0.025 |
| 10 | 1 | 1 | 1 | 1 | 1 | 1 | 86.364+/-0.527 | **74.638+/-0.630** | **82.468+/-0.582** | **73.606+/-0.994** | 5 | 128 | 0.025 |

20 epochs, 15 replicates:

| _row_ | _ub_ | _ap_ | _tp_ | _vp_ | _al_ | _ls_ | _avg_train_acc_ | **_avg_bal_val_acc_** | **_avg_unbal_val_acc_** | **_avg_perv_val_acc_** | _epochs_ | _bs_ | _lr_ |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 0 | 0 | 0 | 0 | 0 | 0 | 99.590+/-0.080 | **87.482+/-0.415** | **0.000+/-0.000** | **0.000+/-0.000** | 20 | 128 | 0.025 |
| 2 | 0 | 0 | 0 | 0 | 0 | 1 | 99.451+/-0.076 | **87.961+/-0.355** | **0.000+/-0.000** | **0.000+/-0.000** | 20 | 128 | 0.025 |
| 3 | 1 | 0 | 0 | 0 | 0 | 0 | 99.535+/-0.102 | **77.498+/-0.441** | **86.327+/-0.395** | **68.483+/-0.720** | 20 | 128 | 0.025 |
| 4 | 1 | 1 | 0 | 0 | 0 | 0 | 99.578+/-0.117 | **77.377+/-0.463** | **86.236+/-0.446** | **68.332+/-0.874** | 20 | 128 | 0.025 |
| 5 | 1 | 1 | 0 | 0 | 0 | 1 | 99.607+/-0.094 | **77.048+/-0.519** | **86.472+/-0.377** | **67.425+/-0.820** | 20 | 128 | 0.025 |
| 6 | 1 | 1 | 1 | 0 | 0 | 0 | 99.541+/-0.082 | **80.299+/-0.355** | **85.551+/-0.601** | **74.936+/-0.645** | 20 | 128 | 0.025 |
| 7 | 1 | 1 | 1 | 0 | 0 | 1 | 99.573+/-0.094 | **80.554+/-0.331** | **83.893+/-0.545** | **77.144+/-0.431** | 20 | 128 | 0.025 |
| 8 | 1 | 1 | 1 | 1 | 0 | 0 | 99.586+/-0.097 | **80.104+/-0.523** | **86.388+/-0.410** | **77.453+/-0.871** | 20 | 128 | 0.025 |
| 9 | 1 | 1 | 1 | 1 | 0 | 1 | 99.600+/-0.080 | **80.616+/-0.494** | **86.589+/-0.479** | **74.600+/-0.635** | 20 | 128 | 0.025 |
| 10 | 1 | 1 | 1 | 1 | 1 | 1 | 99.609+/-0.088 | **80.416+/-0.673** | **86.606+/-0.468** | **79.090+/-0.630** | 20 | 128 | 0.025 |

1 epoch, 15 replicates:

| _row_ | _ub_ | _ap_ | _tp_ | _vp_ | _al_ | _ls_ | _avg_train_acc_ | **_avg_bal_val_acc_** | **_avg_unbal_val_acc_** | **_avg_perv_val_acc_** | _epochs_ | _bs_ | _lr_ |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 0 | 0 | 0 | 0 | 0 | 0 | 66.052+/-1.094 | **65.525+/-1.102** | **0.000+/-0.000** | **0.000+/-0.000** | 1 | 128 | 0.025 |
| 2 | 0 | 0 | 0 | 0 | 0 | 1 | 66.239+/-1.115 | **65.783+/-1.216** | **0.000+/-0.000** | **0.000+/-0.000** | 1 | 128 | 0.025 |
| 3 | 1 | 0 | 0 | 0 | 0 | 0 | 64.634+/-1.232 | **48.345+/-1.892** | **63.883+/-1.512** | **32.479+/-2.464** | 1 | 128 | 0.025 |
| 4 | 1 | 1 | 0 | 0 | 0 | 0 | 64.928+/-1.011 | **48.413+/-1.623** | **63.866+/-1.261** | **32.633+/-2.221** | 1 | 128 | 0.025 |
| 5 | 1 | 1 | 0 | 0 | 0 | 1 | 64.286+/-1.351 | **47.961+/-1.235** | **63.520+/-1.494** | **32.070+/-1.555** | 1 | 128 | 0.025 |
| 6 | 1 | 1 | 1 | 0 | 0 | 0 | 64.601+/-1.520 | **52.611+/-2.079** | **60.104+/-2.820** | **44.961+/-2.372** | 1 | 128 | 0.025 |
| 7 | 1 | 1 | 1 | 0 | 0 | 1 | 64.628+/-1.695 | **49.126+/-2.695** | **53.959+/-3.783** | **44.192+/-2.183** | 1 | 128 | 0.025 |
| 8 | 1 | 1 | 1 | 1 | 0 | 0 | 64.813+/-1.402 | **52.229+/-1.825** | **64.199+/-1.195** | **48.589+/-2.413** | 1 | 128 | 0.025 |
| 9 | 1 | 1 | 1 | 1 | 0 | 1 | 65.425+/-1.336 | **49.425+/-2.004** | **64.861+/-1.359** | **42.297+/-2.626** | 1 | 128 | 0.025 |
| 10 | 1 | 1 | 1 | 1 | 1 | 1 | 65.255+/-0.705 | **53.678+/-1.642** | **64.475+/-0.912** | **50.597+/-2.187** | 1 | 128 | 0.025 |

### Interpretation

The first two tables above show the average and standard deviation of the accuracy measures over 30 replicates of 5 epochs and 15 replicates of 20 epochs of training respectively.

Rows 1 and 2 are the standard training defaults on the standard ImageNette dataset, without and with Label Smoothing (discussed above). We see that with Label Smoothing, accuracy on the balanced validation set improves by about 0.5% +/- 0.4% (i.e. a barely significant difference). We also see that the accuracy on the training set is ~90% at 5 epochs, and still <100% at 20 epochs, meaning that the model has yet to learn the correct class for all examples, so training loss will be dominated by incorrectly classified images. This is likely why label smoothing makes so little difference. However, as we shall see, it can still cause problems.

Rows 3 and 4 are training with the unbalanced training set, without and with the Apriori-Softmax layer, but with apriori class distributions set to uniform. The results of these two rows should therefore be identical within statistical noise, which they are. We can clearly see that accuracy on the balanced validation set takes a massive hit, reducing from ~84% (row 1/2) to ~70% (row 3/4) at 5 epochs, and from ~87% to ~77% at 20 epochs. This demonstrates the core problem of training on an unbalanced training set, even though accuracy on the training set is not massively different. The persistence of the difference from 5 to 20 epochs suggests that it would persist for many more epochs of training. Accuracy on the unbalanced validation sets can also be measured, showing that accuracy on a validation set with the same class distribution as the training set is not massively reduced. This shows that as long as the distribution of classes is the same in training and validation sets, an unbalanced class distribution is not really an issue (though the network is likely to be far less accurate on classes for which it has seen fewer examples than others, but this is a separate issue). As would be expected, the perversely unbalanced dataset (unbalanced in exactly the opposite direction to the training set) gives even worse results. Row 5 adds the effect of label smoothing, which again at this point makes little difference at either 5 or 20 epochs.

Row 6 turns on the Apriori-Softmax layer and sets the training class distribution to the known training set distribution, with row 7 also adding label smoothing. The accuracy on the training set hardly changes, but the accuracy on the balanced validation set takes a big jump, because the validation-time Apriori-Softmax class distribution is set to a uniform distribution which matches the balanced validation set well. The accuracy on the unbalanced validation set, which has the same class distribution as the training set, reduces somewhat. This is because the Apriori-Softmax validation class distribution has been set to uniform, which no longer matches that validation set. The accuracy on the perverse validation set also improved dramatically, which shows the utility of setting Apriori-Softmax to a uniform inference distribution when using an unknown inference-time class distribution.

Row 8 sets the validation time Apriori-Softmax distribution to match the validation set class distribution for each of the validation sets. This restores the validation accuracy of the unbalanced validation set that has the same class distribution as the training set, demonstrating that Apriori-Softmax can be a "sane default" for the common use-case where training and validation distributions are the same, but allowing re-adjustment for unknown (uniform) or known (perverse) inference time class distributions. The accuracy on the perversely unbalanced distribution is improved substantially also.

Row 9 adds label smoothing to row 8. We can see that accuracy on the perversely unbalanced validation set falls substantially, which is what we would expect if label smoothing were to interfere with Apriori-Softmax.

Finally, in Row 10 we add AprioriLabelSmoothingCrossEntropy to allow label smoothing to accurately replicate the class distribution of the training set. This restores the perversely unbalanced validation accuracy, whilst enabling label smoothing to be used. By 20 epochs, the effects of Label Smoothing are statistically significantly positive for the perversely unbalanced dataset when using AprioriLabelSmoothingCrossEntropy, whereas they were statistically significantly negative when conventional LabelSmoothingCrossEntropy was used. The effect size is relatively small, but this argues in favour of using AprioriLabelSmoothingCrossEntropy as a sane default alongside Apriori-Softmax.
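One rough way to check the significance claim is sketched below, using the 20-epoch perverse-validation figures for rows 8, 9 and 10 from the table above. It assumes the +/- values are standard deviations across the 15 replicates and that replicates are independent. The helper name `mean_difference_z` is hypothetical, and this simple z-style statistic is an illustration, not necessarily the test the author used.

```python
import math

def mean_difference_z(mean_a, std_a, mean_b, std_b, n):
    """Difference of two means divided by its standard error (n replicates each)."""
    std_err = math.sqrt((std_a ** 2 + std_b ** 2) / n)
    return (mean_a - mean_b) / std_err

# avg_perv_val_acc at 20 epochs, copied from the table above: (mean, std).
row8 = (77.453, 0.871)   # Apriori-Softmax, no label smoothing
row9 = (74.600, 0.635)   # plus conventional LabelSmoothingCrossEntropy
row10 = (79.090, 0.630)  # plus AprioriLabelSmoothingCrossEntropy
n = 15

z_conventional = mean_difference_z(*row9, *row8, n)
z_apriori = mean_difference_z(*row10, *row8, n)
print(z_conventional, z_apriori)
```

Under these assumptions both differences are many standard errors from zero, negative for conventional smoothing and positive for the apriori variant.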

For completeness, the same analysis is also performed with a single epoch of training (note that this is not the same as the first epoch of a 5 or 20 epoch training, because fit_one_cycle schedules the learning rate to adjust over the training run, at a rate that is different for 1, 5 and 20 epochs). Although the noise on all the statistics is greater, the same general results are seen.

### Overall result

After 5 epochs of training, we have changed from 57% accuracy to 73% accuracy on the perversely unbalanced validation set, and 69% to 74% accuracy on the balanced validation set, by making these changes. After 20 epochs of training, we have changed from 68% to 79% accuracy on the perversely unbalanced validation set, and 77% to 80% accuracy on the balanced validation set. All of these changes are entirely pre-calculated and based on simple statistical theory rather than any particulars of the dataset, so add no new parameters, and do not interfere at all with training/validation of balanced datasets.


## TODOs

### faster getting of class distribution
The calculation of class distributions is probably much, much slower than it could be:
```
train_targets = [p[1] for p in learn.dls.train_ds]
```
This grinds over every image, unnecessarily loading and processing them, just to get the class labels. The underlying pandas dataframe could be used, but that's not very general. Is there a simple way to just get the class labels rapidly?

### faster version of AprioriLabelSmoothingCrossEntropy
Because the class distribution is identical for every example, except for the class of interest, there is probably a mathematical shortcut for the non-uniform label smoothing rather than brute force calculating all the soft labels for every softmax output, akin to the shortcut taken in normal multiclass cross-entropy. I'm just not smart enough to spot it.
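One candidate shortcut, offered as a sketch rather than a tested replacement: because the smoothed target is a linear blend of the one-hot label and a fixed prior, the loss splits into (1 - eps) times the ordinary negative log-likelihood plus eps times a prior-weighted sum of the log-probabilities. The plain-Python check below uses hypothetical logits and a hypothetical prior, and compares this form against the brute-force soft-label sum.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]

def brute_force_loss(logits, target, prior, eps):
    """Build every soft label explicitly, then sum -P * log Q over classes."""
    log_preds = log_softmax(logits)
    soft = [(1.0 - eps) * (1.0 if k == target else 0.0) + eps * p
            for k, p in enumerate(prior)]
    return -sum(s * lp for s, lp in zip(soft, log_preds))

def shortcut_loss(logits, target, prior, eps):
    """(1 - eps) * NLL + eps * (prior-weighted negative log-probabilities)."""
    log_preds = log_softmax(logits)
    nll = -log_preds[target]
    prior_term = -sum(p * lp for p, lp in zip(prior, log_preds))
    return (1.0 - eps) * nll + eps * prior_term

# Hypothetical values (illustrative only); the prior must sum to 1.
prior = [0.90, 0.09, 0.01]
logits = [2.0, 0.5, -1.0]

slow = brute_force_loss(logits, 2, prior, eps=0.1)
fast = shortcut_loss(logits, 2, prior, eps=0.1)
print(slow, fast)
```

In a batched tensor implementation the prior term would reduce to a product of the log-probability matrix with the fixed prior vector, so no per-example soft-label tensor needs to be built.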

### test AprioriLabelSmoothingCrossEntropy using validation class distribution
Currently AprioriLabelSmoothingCrossEntropy implicitly removes the impact of the training set distribution, but does not correct to the validation set distribution (because in these tests we are checking multiple validation sets for each trained model). It would be possible to make an even greater correction, such that the smoothed targets (following Apriori-Softmax) would reflect the validation set distribution. This might hamper training somewhat, but might improve validation accuracy on the perverse validation set.

In [None]:
from fastai.basics import *
from fastai.vision.all import *
from fastai.callback.all import *
from fastai.distributed import *
from fastprogress import fastprogress
from torchvision.models import *
from fastai.vision.models.xresnet import *
from fastai.callback.mixup import *
from fastcore.script import *

fastprogress.MAX_COLS = 80

from inspect import signature
from time import time

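# Forcing deterministic behaviour in training slows it down by ~50%,
# but is useful when trying to get good stats on what is happening.
# Note: cudnn.benchmark = True (left disabled here) turns on the cuDNN autotuner, which is not deterministic.
#torch.backends.cudnn.benchmark = True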


In [None]:
# If we haven't created unbalanced column in noisy_imagenette_unbalanced.csv, do it now

# We generate unbalanced column, which can be ORed or ANDed with is_valid to generate the types of datasets we wish

filename = '~/datasets_ssd/imagenette2-320/noisy_imagenette_unbalanced.csv'

inp = pd.read_csv(filename)
if 'unbalanced' not in inp.columns:
    # get list of unique classes, in random order
    rng = np.random.default_rng()
    classes = rng.permutation(inp['noisy_labels_0'].unique())
    # iterate classes, setting the proportion of samples in the training set to (class_index+0.5)/numclasses
    # For imagenette this will range from 5% to 95%
    randvals = rng.uniform(size=inp['noisy_labels_0'].shape)
    inp['unbalanced'] = False
    for cidx,cls in enumerate(classes):
        p = (0.5+cidx)/classes.size
        inp.loc[(inp['noisy_labels_0'] == cls) & (randvals < p),['unbalanced']] = True
    inp.to_csv(filename, index=False)


In [None]:
class SelectiveBias(nn.Module):
    "Two different sets of bias, stored as tensors, that can be switched between."
    def __init__(self, features):
        super(SelectiveBias,self).__init__()
        self.biasAselected = True
        self.biasA = torch.zeros(features).detach()
        self.biasB = torch.zeros(features).detach()
    def forward(self, x):
        if self.biasAselected:
            return self.biasA.to(x.device) + x
        else:
            return self.biasB.to(x.device) + x
    def selectBias(self, selectA):
        self.biasAselected = selectA
    def setBias(self, setA, values):
        if setA:
            self.biasA = torch.Tensor(values).detach()
        else:
            self.biasB = torch.Tensor(values).detach()


In [None]:
class AprioriSoftmax(Callback):
    "Add bias before softmax using apriori knowledge of frequency of each class."

    def before_train(self):
        "Set learn.model[-1].selectBias to True"
        if isinstance(self.learn.model[-1],SelectiveBias):
            self.learn.model[-1].selectBias(True)

    def before_validate(self):
        "Set learn.model[-1].selectBias to False"
        if isinstance(self.learn.model[-1],SelectiveBias):
            self.learn.model[-1].selectBias(False)


In [None]:
def UnbalancedColSplitter():
    # Hacked version of ColSplitter() that returns 4 sets: train, balanced valid, unbalanced valid and perverse valid
    def _inner(o):
        assert isinstance(o, pd.DataFrame), "UnbalancedColSplitter only works when your items are a pandas DataFrame"
        assert ('is_valid' in o.columns) and ('unbalanced' in o.columns), "UnbalancedColSplitter needs columns named 'is_valid' and 'unbalanced'"
        bal_valid_idx = o['is_valid']
        perverse_valid_idx = o['is_valid'] & o['unbalanced']
        unbal_valid_idx = o['is_valid'] & (~o['unbalanced'])
        train_idx = ~ (o['is_valid'] | o['unbalanced'])
        return L(mask2idxs(train_idx), use_list=True), L(mask2idxs(bal_valid_idx), use_list=True), L(mask2idxs(unbal_valid_idx), use_list=True), L(mask2idxs(perverse_valid_idx), use_list=True)
    return _inner

In [None]:
def get_dls(size, woof, pct_noise, bs, sh=0., workers=None, tinydataset=False, resizecrop=False, donormalize=True):
    # This has been horribly hacked to allow the "woof" parameter to be 0 or 1 for usual behaviour, or
    # -1 = use noisy_imagenette_unbalanced.csv to generate four splits: train, valid, unbalanced valid, perverse
    # the training set is ~(is_valid|unbalanced) : an unbalanced subset of the usual training set
    # the validation set is exactly the same as noisy_imagenette.csv
    # the unbalanced validation set is (is_valid&~unbalanced) : unbalanced in the same direction as the training set
    # the perverse set is (is_valid&unbalanced) : an unbalanced subset of the validation set
    assert pct_noise in [0,5,50], '`pct_noise` must be 0,5 or 50.'
    if size<=224: path = URLs.IMAGEWOOF_320 if woof>0 else URLs.IMAGENETTE_320
    else        : path = URLs.IMAGEWOOF     if woof>0 else URLs.IMAGENETTE
    source = untar_data(path)
    workers = ifnone(workers,min(8,num_cpus()))
    blocks=(ImageBlock, CategoryBlock)
    tfms = [RandomResizedCrop(size, min_scale=0.35) if resizecrop else Resize(size), FlipItem(0.5)]
    batch_tfms = [Normalize.from_stats(*imagenet_stats)] if donormalize else []
    if sh: batch_tfms.append(RandomErasing(p=0.3, max_count=3, sh=sh))
    
    csv_file = 'noisy_imagewoof.csv' if woof else 'noisy_imagenette.csv'
    if woof == -1: csv_file = 'noisy_imagenette_unbalanced.csv'
    inp = pd.read_csv(source/csv_file)
    if tinydataset:
        # Keep just 16 instances of each label in train and valid (320 images total)
        inp = inp.sort_values(by=['is_valid','noisy_labels_0','path']).groupby(['is_valid', 'noisy_labels_0']).head(16)
    dblock = DataBlock(blocks=blocks,
               splitter=UnbalancedColSplitter() if woof == -1 else ColSplitter(),
               get_x=ColReader('path', pref=source),
               get_y=ColReader(f'noisy_labels_{pct_noise}'),
               item_tfms=tfms,
               batch_tfms=batch_tfms)
    
    return dblock.dataloaders(inp, path=source, bs=bs, num_workers=workers)

In [None]:
# Improved LabelSmoothingCrossEntropy that can contain apriori class distribution
# The LabelSmoothing paper https://arxiv.org/abs/1512.00567 actually says this should be the distribution of samples,
# but specifically for ImageNet they use a uniform distribition, which everyone seems to have used as a shortcut
# So the targets in general no longer become a blend with uniform distribution
# This means we can't use the simple multiclass cross entropy trick, we have to actually sum up all logpreds

class AprioriLabelSmoothingCrossEntropy(Module):
    y_int = True # y interpolation
    class_distribution = None # the relative proportion of each class making up the training set
    def __init__(self, 
        eps:float=0.1, # The weight for the interpolation formula
        weight:Tensor=None, # Manual rescaling weight given to each class passed to `F.nll_loss`
        reduction:str='mean', # PyTorch reduction to apply to the output
    ): 
        store_attr()

    def forward(self, output:Tensor, target:Tensor) -> Tensor:
        "Apply `F.log_softmax` on output, then take the cross entropy against targets smoothed towards `class_distribution`"
        if self.class_distribution is None:
            distribution = torch.ones_like(output)/(output.size()[1])
        else:
            distribution = torch.Tensor(self.class_distribution).detach().to(output.device)
            distribution = distribution/distribution.sum()
            # insert the batchsize dimension
            distribution = torch.unsqueeze(distribution, dim=0)
            # Now insert as many dimensions as needed at the end
            for _ in range(len(target.shape)-1):
                distribution = torch.unsqueeze(distribution, dim=-1)
            distribution = distribution.expand(output.size())
        # Now expand target to the shape of output, setting class probability for the target class to 1.
        targetprobs = torch.zeros_like(output)
        targetprobs.scatter_(1, torch.unsqueeze(target, dim=1), torch.ones_like(output))
        # Now produce the smoothed target distribution, which is a linear combination of targetprobs and distribution
        smooth_targets = (1-self.eps) * targetprobs + self.eps*distribution
        # Now perform the usual definition of cross entropy loss, -1. * sum(P(x) * log(Q(x)))
        log_preds = F.log_softmax(output, dim=1)
        loss = -(smooth_targets*log_preds).sum(dim=1)
        if self.reduction=='sum': loss = loss.sum()
        if self.reduction=='mean': loss = loss.mean()
        return loss

    def activation(self, out:Tensor) -> Tensor: 
        "`F.log_softmax`'s fused activation function applied to model output"
        return F.softmax(out, dim=-1)
    
    def decodes(self, out:Tensor) -> Tensor:
        "Converts model output to target format"
        return out.argmax(dim=-1)

In [None]:
def create_learner(
    woof:  Param("Use imagewoof (otherwise imagenette)", int)=0,
    pct_noise:Param("Percentage of noise in training set (0,5,50%)", int)=0,
    lr:    Param("Learning rate", float)=1e-2,
    size:  Param("Size (px: 128,192,256)", int)=128,
    sqrmom:Param("sqr_mom", float)=0.99,
    mom:   Param("Momentum", float)=0.9,
    eps:   Param("Epsilon", float)=1e-6,
    wd:    Param("Weight decay", float)=1e-2,
    epochs:Param("Number of epochs", int)=5,
    bs:    Param("Batch size", int)=64,
    mixup: Param("Mixup", float)=0.,
    opt:   Param("Optimizer (adam,rms,sgd,ranger)", str)='ranger',
    arch:  Param("Architecture", str)='xresnet50',
    sh:    Param("Random erase max proportion", float)=0.,
    sa:    Param("Self-attention", store_true)=False,
    sym:   Param("Symmetry for self-attention", int)=0,
    beta:  Param("SAdam softplus beta", float)=0.,
    act_fn:Param("Activation function", str)='Mish',
    fp16:  Param("Use mixed precision training", store_true)=False,
    pool:  Param("Pooling method", str)='AvgPool',
    dump:  Param("Print model; don't train", int)=0,
    runs:  Param("Number of times to repeat training", int)=1,
    meta:  Param("Metadata (ignored)", str)='',
    workers:   Param("Number of workers", int)=None,
    learnerpath:  Param("Path for learner.path", str)=None,
    tinydataset:  Param("Discard all but 16 examples of each class", store_true)=False,
    apriori_softmax:    Param("AprioriSoftmax callback", int)=0,
    lossfunc:  Param("Loss function (e.g. LabelSmoothingCrossEntropy)", object)=LabelSmoothingCrossEntropy(),
):
    if   opt=='adam'  : opt_func = partial(Adam, mom=mom, sqr_mom=sqrmom, eps=eps)
    elif opt=='rms'   : opt_func = partial(RMSprop, sqr_mom=sqrmom)
    elif opt=='sgd'   : opt_func = partial(SGD, mom=mom)
    elif opt=='ranger': opt_func = partial(ranger, mom=mom, sqr_mom=sqrmom, eps=eps, beta=beta)
    else: raise ValueError(f"Unknown optimizer '{opt}'; expected one of adam, rms, sgd, ranger")

    dls = get_dls(size, woof, pct_noise, bs, sh=sh, workers=workers, tinydataset=tinydataset)
    cbs = [MixUp(mixup)] if mixup else []

    m,act_fn,pool = [globals()[o] for o in (arch,act_fn,pool)]

    model = m(n_out=dls.c, act_cls=act_fn, sa=sa, sym=sym, pool=pool)

    if apriori_softmax:
        model.add_module("selective_bias", SelectiveBias(model[-1].out_features))
        cbs.append(AprioriSoftmax)
    
    learn = Learner(dls, model, opt_func=opt_func, path=learnerpath, cbs=cbs, lr=lr, wd=wd,
                    metrics=[accuracy], loss_func=lossfunc)
    if fp16: learn = learn.to_fp16()
    return learn


In [None]:
# Train one configuration for `epochs` epochs and return the relevant stats as a 'key,value, key,value, ...' string

def do_test(epochs=5, unbalanced=True, apriori=True, set_train_props=True, set_val_props=True, apriori_loss=True, label_smoothing=True):
    woof = -1 if unbalanced else 0
    apriori_softmax = 1 if apriori else 0
    
    eps = 0.1 if label_smoothing else 0.0
    lossfunc = AprioriLabelSmoothingCrossEntropy(eps=eps) if apriori_loss else LabelSmoothingCrossEntropy(eps=eps)
    learn = create_learner(arch='xse_resnext18', fp16=True, size=128, bs=128, woof=woof, apriori_softmax=apriori_softmax, lossfunc=lossfunc)
    learn.unfreeze()

    if apriori and set_train_props:
        train_targets = [p[1] for p in learn.dls.train_ds]
        train_proportions = np.unique(train_targets, return_counts=True)[1]/len(train_targets)
        learn.model[-1].setBias(True, np.log(train_proportions))
        if apriori_loss:
            learn.loss_func.class_distribution = train_proportions

    if apriori and set_val_props:
        unbal_val_targets = [p[1] for p in learn.dls[2].dataset]
        unbal_val_proportions = np.unique(unbal_val_targets, return_counts=True)[1]/len(unbal_val_targets)
        perv_val_targets = [p[1] for p in learn.dls[3].dataset]
        perv_val_proportions = np.unique(perv_val_targets, return_counts=True)[1]/len(perv_val_targets)
        # The validation set used for printing metrics during training is idx=1, which is the balanced validation set
        bal_val_targets = [p[1] for p in learn.dls.valid_ds]
        bal_val_proportions = np.unique(bal_val_targets, return_counts=True)[1]/len(bal_val_targets)
        learn.model[-1].setBias(False, np.log(bal_val_proportions))

    lr = 0.025

    starttime = time()
    learn.fit_flat_cos(epochs, lr)
    endtime = time()
    
    # All evaluations below run in validation mode, so with apriori-softmax the validation-mode bias
    # must be set to the class proportions of whichever dataset is being evaluated
    if apriori and set_train_props:
        learn.model[-1].setBias(False, np.log(train_proportions))
    v_t = learn.validate(ds_idx=0)[1]
    if apriori:
        # A zero bias is equivalent to a uniform prior, which matches the balanced validation set
        if set_val_props:
            learn.model[-1].setBias(False, np.zeros_like(bal_val_proportions))
        elif set_train_props:
            learn.model[-1].setBias(False, np.zeros_like(train_proportions))
    v_b = learn.validate(ds_idx=1)[1]
    if unbalanced:
        # If we are using apriori but not setting the validation proportions, the bias has already been reset to a uniform distribution (not the training set distribution)
        if apriori and set_val_props:
            learn.model[-1].setBias(False, np.log(unbal_val_proportions))
        v_u = learn.validate(ds_idx=2)[1]
        if apriori and set_val_props:
            learn.model[-1].setBias(False, np.log(perv_val_proportions))
        v_p = learn.validate(ds_idx=3)[1]
    else:
        v_u = 0.0
        v_p = 0.0

    results = f"epochs,{int(epochs)}, ub,{int(unbalanced)}, ap,{int(apriori)}, tp,{int(set_train_props)}, vp,{int(set_val_props)}, al,{int(apriori_loss)}, ls,{int(label_smoothing)}, bs,{learn.dls.bs}, lr,{lr:0.05f}, elapsed,{endtime-starttime:0.02f}, train_loss,{learn.final_record[0]:0.03f}, valid_loss,{learn.final_record[1]:0.03f}, train_acc_pct,{v_t*100:0.03f}, bal_val_acc_pct,{v_b*100:0.03f}, unbal_val_acc_pct,{v_u*100:0.03f}, perv_val_acc_pct,{v_p*100:0.03f}"
    return results
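
`do_test` switches the validation-mode bias between `np.log(proportions)` and zeros depending on which dataset is being evaluated. The sketch below, in plain Python, illustrates two facts this relies on: class proportions can be counted from integer labels, and a zero bias is equivalent to the log of a uniform prior because softmax is unchanged when the same constant is added to every logit. The helper names `class_log_priors` and `softmax_with_bias` and the toy labels are assumptions for illustration only; the notebook's own `SelectiveBias` module is not reproduced here.

```python
import math
from collections import Counter

def class_log_priors(labels, n_classes):
    """Log of each class's share of `labels`; every class must appear at least once."""
    counts = Counter(labels)
    missing = [c for c in range(n_classes) if counts[c] == 0]
    if missing:
        raise ValueError(f"classes with no examples: {missing}")
    return [math.log(counts[c] / len(labels)) for c in range(n_classes)]

def softmax_with_bias(logits, bias):
    """Softmax of logits + bias, computed stably."""
    z = [l + b for l, b in zip(logits, bias)]
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

# Toy labels: class 0 is nine times as common as class 1 (hypothetical data).
train_labels = [0] * 90 + [1] * 10
train_bias = class_log_priors(train_labels, n_classes=2)

logits = [0.0, 1.0]                  # raw model output slightly favours class 1
zero_bias = [0.0, 0.0]
uniform_bias = [math.log(0.5)] * 2   # log of a uniform prior over two classes

p_train = softmax_with_bias(logits, train_bias)      # prior pulls the prediction to class 0
p_zero = softmax_with_bias(logits, zero_bias)        # no prior: class 1 wins
p_uniform = softmax_with_bias(logits, uniform_bias)  # same probabilities as p_zero
print(p_train, p_zero, p_uniform)
```

The explicit check for absent classes matters because counting only the observed labels (as `np.unique(..., return_counts=True)` does) yields a shorter vector when a class has no examples, and the log of a zero proportion is undefined.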


In [None]:
outputfilename = './apriori-softmax-experiment-outputs.csv'

starttime = time()

allres = []

# Each row is one experiment configuration:
# (unbalanced, apriori, set_train_props, set_val_props, apriori_loss, label_smoothing)
configs = [
    (False, False, False, False, False, False),
    (False, False, False, False, False, True),
    (True,  False, False, False, False, False),
    (True,  True,  False, False, False, False),
    (True,  True,  False, False, False, True),
    (True,  True,  True,  False, False, False),
    (True,  True,  True,  False, False, True),
    (True,  True,  True,  True,  False, False),
    (True,  True,  True,  True,  False, True),
    (True,  True,  True,  True,  True,  True),
]

def write_results(results, filename):
    "Write the accumulated 'key,value, key,value, ...' result strings to a CSV file"
    with open(filename, "wt") as f:
        # Even-indexed tokens are column names, odd-indexed tokens are the values
        headers = [t.strip() for t in results[0].split(",")[0::2]]
        f.write(",".join(headers)+"\n")
        for l in sorted(results):
            values = [t.strip() for t in l.split(",")[1::2]]
            f.write(",".join(values)+"\n")

def run_replicates(epochs, replicates):
    "Run every configuration `replicates` times, rewriting the CSV after each replicate"
    repstarttime = time()
    for i in range(replicates):
        for ub, ap, tp, vp, al, ls in configs:
            allres.append(do_test(epochs=epochs, unbalanced=ub, apriori=ap, set_train_props=tp,
                                  set_val_props=vp, apriori_loss=al, label_smoothing=ls))
        print(f"All results so far after replicate {i+1} of {replicates} at {epochs} epochs, replicates time elapsed = {time()-repstarttime:0.01f}sec, total elapsed = {time()-starttime:0.01f}sec")
        for l in sorted(allres): print(l)
        write_results(allres, outputfilename)

run_replicates(epochs=5, replicates=30)
run_replicates(epochs=20, replicates=15)
run_replicates(epochs=1, replicates=15)
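
The result strings returned by `do_test` alternate keys and values, so they can be parsed back into records and averaged per configuration once the replicates have run. The following is a minimal sketch using only the standard library; the helper names `parse_result` and `mean_by_config` are hypothetical, and the example lines contain made-up placeholder numbers in a shortened version of the real format, not actual experiment results.

```python
from collections import defaultdict
from statistics import mean

def parse_result(line):
    """Split 'key,value, key,value, ...' into a dict of strings (insertion-ordered)."""
    tokens = [t.strip() for t in line.split(",")]
    if len(tokens) % 2:
        raise ValueError("expected an even number of comma-separated tokens")
    return dict(zip(tokens[0::2], tokens[1::2]))

def mean_by_config(lines, config_keys, metric):
    """Average `metric` over all lines sharing the same values for `config_keys`."""
    groups = defaultdict(list)
    for line in lines:
        row = parse_result(line)
        groups[tuple(row[k] for k in config_keys)].append(float(row[metric]))
    return {cfg: mean(vals) for cfg, vals in groups.items()}

# Placeholder lines shaped like do_test's output; the numbers are made up.
example_lines = [
    "epochs,5, ub,1, ap,0, bal_val_acc_pct,10.000",
    "epochs,5, ub,1, ap,0, bal_val_acc_pct,20.000",
    "epochs,5, ub,1, ap,1, bal_val_acc_pct,30.000",
]
summary = mean_by_config(example_lines, ("epochs", "ub", "ap"), "bal_val_acc_pct")
print(summary)
```

Grouping on the configuration keys rather than on the whole line keeps replicates of the same experiment together even though their timings and losses differ.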

