In [2]:
#| default_exp bandits

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
#| export
import torch, numpy as np
from typing import Optional

## Setup

In [13]:
import pickle
from xcai.block import prepare_batch
from xcai.models.MMM0XX import *

In [14]:
# Location of the preprocessed dataset pickle.
# NOTE(review): absolute, machine-specific path; `pkl_dir` already ends in '/', so `fname`
# contains a doubled slash.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/'
fname = f'{pkl_dir}/processed/wikiseealsotitles_data_distilbert-base-uncased_xcs.pkl'

# NOTE(review): `pickle.load` can execute arbitrary code stored in the file; only load trusted pickles.
with open(fname, 'rb') as file: block = pickle.load(file)

In [15]:
# Load the `BT0001` model (presumably provided by the star import from `xcai.models.MMM0XX`)
# from the 'bert-base-uncased' checkpoint.
m = BT0001.from_pretrained('bert-base-uncased')

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`


In [16]:
bsz = 10
# NOTE(review): this `batch` is overwritten by the loop below; the call only matters if
# `one_batch` has side effects (e.g. configuring `block.train.dl`) - confirm.
batch = block.train.one_batch(bsz)
# Step through the dataloader; after the `break`, `batch` is the item at index 6
# (or the last item if the loader yields fewer batches).
for i, batch in enumerate(block.train.dl):
    if i > 5: break
        
# Build the model inputs from the batch, additionally passing through the two `plbl2data_*` fields.
b = prepare_batch(m,batch, m_args=['plbl2data_idx', 'plbl2data_data2ptr'])
# NOTE(review): requires a CUDA device.
m,b = m.to('cuda'),b.to('cuda')
# Unpack the model output positionally; the order must match what `m(**b)` yields - TODO confirm against the model.
data_logits, lbl2data_input_ids, lbl2data_data2ptr, lbl2data_idx, lbl2data_logits, data_input_ids, data_repr, lbl2data_repr, data_embed, data_attention_mask, lbl2data_embed, lbl2data_attention_mask, kwargs = m(**b)

## `RLLossWeights`

In [5]:
#|export
def get_sparse_matrix(data_idx:torch.Tensor, n_data:torch.Tensor, scores:Optional[torch.Tensor]=None):
    """Assemble a sparse CSR matrix from per-row entry counts and column indices.

    Args:
        data_idx: column index of every stored entry, rows concatenated in order.
        n_data: number of stored entries in each row; its cumulative sum (with a
            leading zero) becomes the CSR row-pointer array.
        scores: value of every stored entry; defaults to ones with the dtype and
            device of `data_idx`.

    Returns:
        A `torch.sparse_csr_tensor` placed on the device of `n_data`. No explicit
        size is passed, so the shape is inferred by PyTorch.

    Raises:
        ValueError: if `scores` and `data_idx` differ in shape.
    """
    leading_zero = torch.zeros(1, device=n_data.device, dtype=n_data.dtype)
    row_ptr = torch.cat([leading_zero, n_data.cumsum(0)])

    values = torch.ones_like(data_idx) if scores is None else scores
    if values.shape != data_idx.shape:
        raise ValueError('`data_idx` and `scores` should have same shape.')

    return torch.sparse_csr_tensor(row_ptr, data_idx, values, device=row_ptr.device)
    

In [6]:
#| export
class RLLossWeights(torch.nn.Module):
    """Learn a vector of weights by sampling them from a Normal and rewarding good samples.

    The weights are drawn from `Normal(mu, std)` and clamped from below at `min`.
    `step` accumulates the reward of the current sample; every `collector` calls it
    takes one gradient step on `mu` for the loss `-log_prob(w) * mean_reward`.

    Args:
        num_samples: number of weights to learn.
        std: fixed (non-trainable) standard deviation of the sampling distribution.
        lr: step size of the manual gradient update on `mu`.
        reward_func: callable `(pred, gt) -> number` returning the summed reward of a batch.
        collector: number of `step` calls between two updates of `mu`.
        min: lower bound applied to sampled weights and to the distribution mean.
        rest_init: initial value of `mu`; a scalar or a sequence of length `num_samples`.

    NOTE(review): `self.dist` captures `mu`/`std` at construction time; moving the module
    to another device afterwards may leave the distribution on the old device - confirm
    before calling `.to(...)` on an instance.
    """

    def __init__(self, num_samples, std=0.1, lr=0.001, reward_func=None,
                 collector=10, min=0.1, rest_init=0.1) -> None:
        super().__init__()
        init = np.ones(num_samples)
        init[:] = rest_init  # broadcasts a scalar or copies a length-`num_samples` sequence
        self.reward_func = reward_func
        self.collector = collector
        self.lr = lr
        self.num_samples = num_samples
        self.mu = torch.nn.Parameter(torch.Tensor(init))
        self.std = torch.nn.Parameter(torch.Tensor(np.ones(num_samples)*std),
                                      requires_grad=False)
        self.dist = torch.distributions.normal.Normal(self.mu, self.std)
        self.min = min
        self.w = None  # cached sample; cleared by `zero_grad`
        self.reset_metrics()

    def reset_metrics(self):
        """Reset the reward accumulators and the step counter."""
        self.collect_size = 0
        self.collect_value = 0
        self.step_counter = 0

    def sample(self, device="cpu"):
        """Return the current weight sample on `device`, drawing a new one only if none is cached."""
        if self.w is None:
            self.w = self.clip(self.dist.sample())
        return self.w.to(device)

    def zero_grad(self, set_to_none: bool = True):
        """Clear the gradient of `mu`, the reward accumulators and the cached sample.

        `set_to_none` mirrors the signature of `torch.nn.Module.zero_grad`; when it is
        False an existing gradient is zeroed in place instead of being dropped.
        """
        if set_to_none or self.mu.grad is None:
            self.mu.grad = None
        else:
            self.mu.grad.detach_()
            self.mu.grad.zero_()
        self.collect_size = 0
        self.collect_value = 0
        self.w = None

    def collect(self, pred, gt):
        """Add the reward of one batch (`reward_func(pred, gt)`) and its row count to the accumulators."""
        self.collect_value += self.reward_func(pred, gt)
        self.collect_size += pred.size(0)

    def _ground_truth(self, inp2targ_idx, n_pinp2targ, pinp2targ_idx):
        """Build the dense relevance matrix with one row per input and one column per entry of `inp2targ_idx`."""
        n_targ = len(inp2targ_idx)
        # Map both id lists into one compact index space.
        uniq, idx = torch.unique(torch.cat([inp2targ_idx, pinp2targ_idx]), return_inverse=True)
        dense = get_sparse_matrix(idx[n_targ:], n_pinp2targ).to_dense()
        # The sparse matrix infers its width from the largest positive id only. Pad with
        # zero columns so ids that occur only in `inp2targ_idx` can still be selected.
        missing = uniq.numel() - dense.shape[1]
        if missing > 0:
            dense = torch.cat([dense, dense.new_zeros(dense.shape[0], missing)], dim=1)
        return dense[:, idx[:n_targ]]

    def step(
        self,
        inp:torch.FloatTensor,
        targ:torch.LongTensor, 
        n_inp2targ:torch.LongTensor,
        inp2targ_idx:torch.LongTensor,
        n_pinp2targ:torch.LongTensor,
        pinp2targ_idx:torch.LongTensor
    ):
        """Score one batch and, every `collector` calls, update `mu`.

        Args:
            inp: input representations; scores are computed as `inp @ targ.T`.
            targ: target representations, presumably one row per entry of
                `inp2targ_idx` - TODO confirm.
            n_inp2targ: unused; kept for interface compatibility.
            inp2targ_idx: ids of the targets scored in this batch.
            n_pinp2targ: number of entries of `pinp2targ_idx` belonging to each input.
            pinp2targ_idx: ids of the positive targets of each input, concatenated.

        `sample()` must have been called since the last update; otherwise `self.w`
        is None when the update is due and `log_prob` fails.
        """
        pred = inp@targ.T
        gt = self._ground_truth(inp2targ_idx, n_pinp2targ, pinp2targ_idx)

        self.step_counter += 1
        self.collect(pred, gt)
        if self.step_counter % self.collector == 0:
            loss = torch.sum(-self.dist.log_prob(self.w)*self.curr_reward)
            loss.backward()
            with torch.no_grad():
                self.mu.sub_(self.lr * self.mu.grad)
            # Re-centre the sampling distribution on the updated, clamped mean.
            self.dist.loc = self.clip(self.mu)
            self.step_counter = 0
            self.zero_grad()

    def clip(self, vect):
        """Clamp `vect` from below at `self.min`."""
        return torch.clamp(vect, min=self.min)

    @property
    def curr_reward(self):
        """Mean reward per collected row since the last reset."""
        return self.collect_value/self.collect_size

    def extra_repr(self):
        """Show the current mean vector in the module repr."""
        return f"{self.mu}"
        

In [7]:
#| export
class RLLossWeightsCumuluative(RLLossWeights):
    """Variant of `RLLossWeights` that rewards improvement over a running reward average.

    The first collection window is a warm-up that only initialises the running
    average `reward_prev`. Every later window scales the gradient of `-log_prob(w)`
    by `curr_reward - reward_prev`, clips it to [-1, 1] and takes one step on `mu`.
    After each window `reward_prev` becomes `(1 - m) * curr_reward + m * reward_prev`.

    Args:
        m: weight of the previous running average in that update.
        num_samples, std, lr, reward_func, collector, min, rest_init: forwarded
            unchanged to `RLLossWeights.__init__`.
    """

    def __init__(self, num_samples=1, std=0.01, lr=0.01, m=0.8,
                 reward_func=None, collector=10, min=0.1, rest_init=0.1) -> None:
        self.m = m
        super().__init__(num_samples, std, lr, reward_func, collector, min, rest_init)

    def reset_metrics(self):
        """Reset the accumulators and restart the warm-up phase."""
        super().reset_metrics()
        self.reward_prev = None
        self.in_warmup = True

    def step(
        self,
        inp:torch.FloatTensor,
        targ:torch.LongTensor, 
        n_inp2targ:torch.LongTensor,
        inp2targ_idx:torch.LongTensor,
        n_pinp2targ:torch.LongTensor,
        pinp2targ_idx:torch.LongTensor
    ):
        """Score one batch and, every `collector` calls, update `mu` and the running reward.

        Args:
            inp: input representations; scores are computed as `inp @ targ.T`.
            targ: target representations, presumably one row per entry of
                `inp2targ_idx` - TODO confirm.
            n_inp2targ: unused; kept for interface compatibility.
            inp2targ_idx: ids of the targets scored in this batch.
            n_pinp2targ: number of entries of `pinp2targ_idx` belonging to each input.
            pinp2targ_idx: ids of the positive targets of each input, concatenated.

        Outside the warm-up window `sample()` must have been called since the last
        update; otherwise `self.w` is None and `log_prob` fails.
        """
        pred = inp@targ.T

        n_targ = len(inp2targ_idx)
        # Map both id lists into one compact index space.
        uniq, idx = torch.unique(torch.cat([inp2targ_idx, pinp2targ_idx]), return_inverse=True)
        gt = get_sparse_matrix(idx[n_targ:], n_pinp2targ).to_dense()
        # The sparse matrix infers its width from the largest positive id only. Pad with
        # zero columns so ids that occur only in `inp2targ_idx` can still be selected.
        missing = uniq.numel() - gt.shape[1]
        if missing > 0:
            gt = torch.cat([gt, gt.new_zeros(gt.shape[0], missing)], dim=1)
        gt = gt[:, idx[:n_targ]]

        self.step_counter += 1
        self.collect(pred, gt)

        if self.step_counter % self.collector == 0:
            if self.in_warmup:
                # First window: only seed the running reward, no update of `mu`.
                self.in_warmup = False
                self.reward_prev = self.curr_reward
            else:
                reward = self.curr_reward - self.reward_prev
                loss = -self.dist.log_prob(self.w).sum()
                loss.backward()
                grad = self.mu.grad.data*reward
                # Replace NaNs and bound the step so one window cannot move `mu` by more than `lr`.
                grad = torch.clip(torch.nan_to_num(grad), min=-1, max=1)
                self.mu.data = self.mu - self.lr * grad
            # Re-centre the sampling distribution on the (possibly updated) clamped mean.
            self.dist.loc = self.clip(self.mu)
            self.step_counter = 0
            # Must run before `zero_grad`, which resets the accumulators behind `curr_reward`.
            self.reward_prev = (1-self.m)*self.curr_reward + \
                self.m*self.reward_prev
            self.zero_grad()
            

## Reward

In [17]:
#|export
def AccMiniBatch(pred, gt):
    """Count the rows of `pred` whose highest-scoring column is marked in `gt`.

    Args:
        pred: 2-D score matrix; the top-1 column is taken along dim 1.
        gt: matrix indexed like `pred`; it is moved to `pred`'s device first.

    Returns:
        The sum (as a Python number) of the `gt` entries at each row's top-1 column,
        i.e. a count rather than a ratio.
    """
    _, top1 = torch.topk(pred, k=1, dim=1, largest=True)
    hits = gt.to(pred.device).gather(1, top1)
    return hits.sum().item()
    

## Example

In [31]:
# Four weights with per-weight initial means; `mu` is updated once every 20 `step` calls,
# and the first 20-call window is a warm-up that only seeds the running reward.
loss_w = RLLossWeightsCumuluative(num_samples=4, reward_func=AccMiniBatch, lr=0.01, collector=20, std=0.1, min=0.1,
                                 rest_init=[1.0, 1.0, 0.1, 0.1])

In [35]:
# Draw the weight sample (clamped from below at `min`); it stays cached until `zero_grad` runs.
ws = loss_w.sample()

In [36]:
# Accumulate the reward of one batch; with `collector=20`, a single call does not update `mu`.
loss_w.step(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'])

In [37]:
# No update has happened yet, so this returns the same cached sample as the earlier `sample()` call.
loss_w.sample()

tensor([1.1193, 0.9930, 0.1156, 0.1000])