# base_rbt

> Base functions and classes we use for our hacking on BT / RBT and related ideas


## Install

```sh
pip install git+https://github.com/hamish-haggerty/base_rbt.git#egg=base_rbt
```

## How to use

After installing, import like this:

```python
from base_rbt.base_model import *
from base_rbt.base_lf import *
```

We also need some other libraries:

```python
import self_supervised
import torch
from fastai.vision.all import *
from self_supervised.augmentations import *
from self_supervised.layers import *
```

```python
device = 'cuda' if torch.cuda.is_available() else 'cpu'
```

Here is a (silly!) modification to the BT loss function: we simply scale the BT loss by $0.01$. It does, however, illustrate the general API for modifying the loss function.

```python
@patch
def lf(self: BarlowTwins, pred, *yb): return 0.01*lf_bt(pred, self.I, self.lmb)  # lf_bt: the standard BT loss
```
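
For reference, here is a plain PyTorch version of the standard BT loss that `lf_bt` is assumed to compute. This is a minimal sketch modelled on the "All standard, from BT" lines of `lf_rbt` below, not the `base_lf` implementation. `bt_loss_sketch` and `scaled_bt_loss_sketch` are hypothetical names. The sketch assumes `pred` stacks the projections of the two augmented views along the batch dimension, giving shape `[2*bs, ps]`.

```python
import torch

def bt_loss_sketch(pred: torch.Tensor, I: torch.Tensor, lmb: float) -> torch.Tensor:
    """Hypothetical stand-in for lf_bt: the standard Barlow Twins loss.

    pred: [2*bs, ps]; rows [:bs] are view-1 projections, rows [bs:] view-2.
    I:    ps x ps identity matrix.
    """
    bs = pred.size(0) // 2
    z1, z2 = pred[:bs], pred[bs:]
    # Standardise each feature over the batch (biased std, as in lf_rbt)
    z1norm = (z1 - z1.mean(0)) / z1.std(0, unbiased=False)
    z2norm = (z2 - z2.mean(0)) / z2.std(0, unbiased=False)
    # Cross-correlation matrix between the two views, shape [ps, ps]
    C = (z1norm.T @ z2norm) / bs
    cdiff = (C - I) ** 2
    # Diagonal: invariance term; off-diagonal: redundancy reduction scaled by lmb
    return (cdiff * I + cdiff * (1 - I) * lmb).sum()

def scaled_bt_loss_sketch(pred: torch.Tensor, I: torch.Tensor, lmb: float) -> torch.Tensor:
    # The (silly) modification from the text: scale the BT loss by 0.01
    return 0.01 * bt_loss_sketch(pred, I, lmb)

torch.manual_seed(0)
bs, ps = 8, 4
pred = torch.randn(2 * bs, ps)
I = torch.eye(ps)
print(bt_loss_sketch(pred, I, lmb=0.005).item())
print(scaled_bt_loss_sketch(pred, I, lmb=0.005).item())
```

The diagonal entries of `(C - I)**2` reduce to $(1 - C_{ii})^2$ and the off-diagonal entries to $C_{ij}^2$. That is why the invariance and redundancy-reduction terms can share one masked matrix.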

Now we give an end-to-end example. First we need a `dls`, i.e. a fastai `DataLoaders`:

```python
# Get some MNIST data and plonk it into a dls
path = untar_data(URLs.MNIST)
items = get_image_files(path/'training')  # training split only, NOT testing
items = items[0:10]
split = RandomSplitter(valid_pct=0.0)  # valid_pct=0.0: every item goes to the training set
tds = Datasets(items, [PILImageBW.create, [parent_label, Categorize()]], splits=split(items))
dls = tds.dataloaders(bs=2, num_workers=0, after_item=[ToTensor(), IntToFloatTensor()], device=device)
```
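
The `after_item` transforms turn each PIL image into a tensor and then rescale its integer pixel values to floats. The torch-only sketch below shows the shapes and value range one batch is assumed to have. It is not fastai's implementation: `fake_batch` is a hypothetical stand-in for two MNIST images, and 255 is `IntToFloatTensor`'s default divisor.

```python
import torch

bs = 2  # matches bs=2 above
# Hypothetical stand-in for two 28x28 greyscale images as uint8 tensors,
# including the channel dimension a BW image gets: [1, 28, 28] per item
fake_batch = torch.randint(0, 256, (bs, 1, 28, 28), dtype=torch.uint8)

# What IntToFloatTensor does by default: cast to float and divide by 255
as_float = fake_batch.float() / 255.0

print(as_float.shape, as_float.dtype)
print(float(as_float.min()), float(as_float.max()))
```

With `valid_pct=0.0`, all 10 items land in the training split, so one epoch at `bs=2` is 5 batches.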



Now we patch in our own definition of a loss function. First define it:

```python
def lf_rbt(pred, seed, I, lmb):

    bs, nf = pred.size(0)//2, pred.size(1)  # nf is the projection size

    # All standard, from BT
    z1, z2 = pred[:bs], pred[bs:]  # so z1 is bs*projection_size, likewise for z2
    z1norm = (z1 - z1.mean(0)) / z1.std(0, unbiased=False)
    z2norm = (z2 - z2.mean(0)) / z2.std(0, unbiased=False)
    C = (z1norm.T @ z2norm) / bs
    cdiff = (C - I)**2

    # Get either max corr(f(x),g(y)) {if indep=True} or max 0.5*corr(x,g(y)) + 0.5*corr(f(x),y) {if indep=False},
    # where the max is over f and g. Please see base_lf for details.
    # qs=nf equals the projection size ps used below, without relying on a global ps
    CdiffSup = Cdiff_Sup(I=I, qs=nf, inner_steps=5, indep=False)
    cdiff_2 = CdiffSup(z1norm, z2norm)  # same shape as cdiff

    # As above, but f and g are now randomly sampled sinusoids. Please see base_lf for details
    CdiffRand = Cdiff_Rand(seed=seed, std=0.1, K=2, indep=False)
    cdiff_2_2 = CdiffRand(z1norm, z2norm)  # same shape as cdiff

    cdiff_2 = 0.5*cdiff_2_2 + 0.5*cdiff_2  # convex combination of rand and sup terms

    rr = cdiff_2*(1-I)*lmb  # redundancy reduction term (scaled by lmb)

    loss = (cdiff*I + rr).sum()  # sum of invariance term and redundancy reduction term
    torch.cuda.empty_cache()
    return loss
```
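
The comment in `lf_rbt` describes the quantity that `Cdiff_Sup` and `Cdiff_Rand` build on. It is a cross-correlation between transformed features: `corr(f(x), g(y))` when `indep=True`, or `0.5*corr(x, g(y)) + 0.5*corr(f(x), y)` when `indep=False`. The sketch below is a hypothetical illustration of that formula only, using fixed, hand-picked `f` and `g` (`torch.sin` and `torch.tanh` are our assumption). It does not perform the maximisation over `f` and `g` that `Cdiff_Sup` does, nor the random sinusoid sampling of `Cdiff_Rand`; see `base_lf` for those. By analogy with `cdiff`, it also assumes the result is `(C' - I)**2`. `corr_sketch` and `cdiff_transformed_sketch` are hypothetical names.

```python
import torch

def standardise(z: torch.Tensor) -> torch.Tensor:
    # Zero mean, unit (biased) std per feature, as for z1norm / z2norm in lf_rbt
    return (z - z.mean(0)) / z.std(0, unbiased=False)

def corr_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # [nf, nf] cross-correlation matrix between the features of a and b
    return standardise(a).T @ standardise(b) / a.size(0)

def cdiff_transformed_sketch(x, y, I, f, g, indep: bool):
    """Hypothetical illustration of the formula in the lf_rbt comment,
    for FIXED f and g (no maximisation over f, g and no random sampling)."""
    if indep:
        C2 = corr_sketch(f(x), g(y))
    else:
        C2 = 0.5 * corr_sketch(x, g(y)) + 0.5 * corr_sketch(f(x), y)
    return (C2 - I) ** 2  # same shape as cdiff, i.e. [nf, nf]

torch.manual_seed(0)
bs, nf = 64, 8
x = standardise(torch.randn(bs, nf))
y = standardise(torch.randn(bs, nf))
I = torch.eye(nf)
c2 = cdiff_transformed_sketch(x, y, I, f=torch.sin, g=torch.tanh, indep=False)
lmb = 0.005
rr = c2 * (1 - I) * lmb  # only off-diagonal entries enter the loss, as in lf_rbt
print(c2.shape, rr.sum().item())
```

Because `rr` masks the diagonal with `(1 - I)`, subtracting `I` has no effect on the loss here. Only the off-diagonal correlations between transformed features are penalised.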

Then patch it in:

```python
@patch
def lf(self: BarlowTwins, pred, *yb): return lf_rbt(pred, seed=self.seed, I=self.I, lmb=self.lmb)
```
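
`@patch` comes from fastcore and is available here through `from fastai.vision.all import *`. It reads the type annotation on `self` and attaches the decorated function to that class under the function's name. As a result, every `BarlowTwins` instance, including ones created earlier, picks up the new `lf`. The toy below is a simplified plain-Python sketch of that idea. `simple_patch` and `ToyCallback` are hypothetical, and fastcore's real `patch` handles more cases.

```python
import inspect

def simple_patch(f):
    """Simplified, hypothetical stand-in for fastcore's @patch:
    attach f as a method of the class annotated on its first parameter."""
    first_param = next(iter(inspect.signature(f).parameters.values()))
    cls = first_param.annotation
    setattr(cls, f.__name__, f)
    return f

class ToyCallback:  # hypothetical stand-in for BarlowTwins
    def __init__(self, lmb):
        self.lmb = lmb

    def lf(self, pred, *yb):
        return sum(pred)  # the "original" loss

cb = ToyCallback(lmb=0.5)
print(cb.lf([1.0, 2.0]))

@simple_patch
def lf(self: ToyCallback, pred, *yb):
    return 0.01 * sum(pred) * self.lmb  # new loss, replaces the old method

print(cb.lf([1.0, 2.0]))  # the existing instance now uses the patched method
```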


Now we can train RBT:

```python
# Full usage of the above
ps = 500  # projection size
hs = 500  # hidden size of the projector
fastai_encoder = create_fastai_encoder(xresnet18(), pretrained=False, n_in=1)
model = create_barlow_twins_model(fastai_encoder, hidden_size=hs, projection_size=ps)
aug_pipelines = get_barlow_twins_aug_pipelines(size=28, rotate=True, flip_p=0, resize_scale=(0.7,1), jitter=False, bw=False,
                                               blur=True, blur_p=0.5, blur_s=8, stats=None, cuda=(device=='cuda'))
learn = Learner(dls, model, cbs=[BarlowTwins(aug_pipelines, print_augs=True)])
learn.fit(1)
```

Once the `fastai_encoder` has been trained, we can evaluate it in various ways, e.g. with linear evaluation.
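
Linear evaluation freezes the pretrained encoder and trains only a linear classifier on its features. The following is a minimal, self-contained PyTorch sketch of that recipe, not a `base_rbt` API. It assumes a hypothetical stand-in encoder (a small `nn.Sequential` in place of the trained `fastai_encoder`) and uses random data in place of labelled MNIST.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the trained fastai_encoder: 1x28x28 image -> 32 features
encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 32), nn.ReLU())
n_feats, n_classes = 32, 10

# Freeze the encoder: only the linear head will be trained
encoder.eval()
for p in encoder.parameters():
    p.requires_grad_(False)
frozen_before = [p.clone() for p in encoder.parameters()]

head = nn.Linear(n_feats, n_classes)
opt = torch.optim.SGD(head.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

# Random stand-in data (use labelled MNIST images in practice)
xb = torch.rand(64, 1, 28, 28)
yb = torch.randint(0, n_classes, (64,))

losses = []
for step in range(50):
    with torch.no_grad():
        feats = encoder(xb)  # features from the frozen encoder
    loss = loss_fn(head(feats), yb)
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())

print(losses[0], losses[-1])
```

Keeping the encoder frozen means the accuracy of the linear head measures how linearly separable the self-supervised features are, which is the point of this evaluation.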