In [1]:
#| default_exp bandits

In [2]:
%load_ext autoreload
%autoreload 2

In [8]:
#| export
import torch, numpy as np
from typing import Optional, Tuple

## Setup

In [1]:
from xcai.main import *
from xcai.basics import *

In [2]:
# NOTE(review): absolute, machine-specific paths — the notebook will not run elsewhere
# without editing these; consider deriving them from a configurable root.
data_dir = '/Users/suchith720/Projects/data/'

# Key passed to `build_block` as `config_key` further down.
config_key = "data_meta"
config_dir = "/Users/suchith720/Projects/mogicX/configs"
# `data_dir` already ends with '/', so the resulting path contains a double slash.
pkl_dir = f"{data_dir}/processed/mogicX"
# The three positional booleans are forwarded to `get_pkl_file` unchanged; their meaning
# is defined by the project helper and is not visible here — TODO confirm.
pkl_file = get_pkl_file(pkl_dir, 'wikiseealsotitles_data-oak-for-msmarco-with-hard-negatives-test_distilbert-base-uncased', 
                        True, False, False)

In [3]:
# JSON config file located under `config_dir`; handed to `build_block` in the next cell.
config_file = f"{config_dir}/39_oak-for-msmarco-with-hard-negatives_test.json"

In [4]:
# Build the data block from `pkl_file` and `config_file`; `block.train` is what the example
# section later draws a batch from.
# NOTE(review): the bare positional `True` and the sampling/top-k keyword arguments are
# interpreted by the project's `build_block`; their exact semantics are not visible here — TODO confirm.
block = build_block(pkl_file, config_file, True, config_key=config_key, only_test=False, main_oversample=True, 
                    meta_oversample={"cat_meta":False, "lnk_meta":True}, n_slbl_samples=5, 
                    n_sdata_meta_samples={"cat_meta":2, "lnk_meta":4}, do_build=False, 
                    train_meta_topk={"lnk_meta":10}, test_meta_topk={"lnk_meta":10}, return_scores=True)

## `RLLossWeights`

In [15]:
#|export
def get_sparse_matrix(data_idx:torch.Tensor, n_data:torch.Tensor, scores:Optional[torch.Tensor]=None, 
                      size:Optional[Tuple]=None):
    """Assemble a sparse CSR tensor from flat column indices and per-row counts.

    Args:
        data_idx: column index of every stored entry, rows concatenated in order.
        n_data: number of stored entries in each row; its running total becomes
            the CSR row-pointer array.
        scores: value of every stored entry; defaults to ones shaped like `data_idx`.
        size: explicit shape of the result; when omitted, the shape is left to
            `torch.sparse_csr_tensor`.

    Returns:
        A sparse CSR tensor placed on the device of `n_data`.

    Raises:
        ValueError: if `scores` and `data_idx` differ in shape.
    """
    # Row pointers: a leading zero followed by the cumulative per-row counts.
    leading_zero = torch.zeros(1, dtype=n_data.dtype, device=n_data.device)
    row_ptr = torch.cat([leading_zero, n_data.cumsum(0)])

    values = torch.ones_like(data_idx) if scores is None else scores
    if values.shape != data_idx.shape:
        raise ValueError(f'`data_idx` and `scores` should have same shape.')

    # Only forward `size` when the caller supplied one.
    ctor_kwargs = {'device': row_ptr.device}
    if size is not None:
        ctor_kwargs['size'] = size
    return torch.sparse_csr_tensor(row_ptr, data_idx, values, **ctor_kwargs)
    

In [16]:
#| export
class RLLossWeights(torch.nn.Module):
    """Learnable loss weights tuned with a score-function (REINFORCE-style) update.

    Weights are drawn from a Normal distribution whose mean `mu` is the only
    trainable parameter; `std` is stored with `requires_grad=False`. Rewards
    returned by `reward_func` are accumulated over `collector` calls to `step`;
    then `mu` takes one manual gradient step on `-log_prob(w) * reward`, the
    accumulators are cleared and the cached sample is dropped.

    Args:
        num_samples: number of weights to learn.
        std: standard deviation of the sampling distribution.
        lr: step size of the manual update applied to `mu`.
        reward_func: callable `(pred, gt) -> reward`; required by `collect`/`step`.
            `curr_reward` divides the accumulated reward by the number of rows
            seen, so a per-batch total is expected.
        collector: number of `step` calls between two updates of `mu`.
        min: lower bound applied to sampled weights and, after every update, to
            the mean of the sampling distribution.
        rest_init: initial value of `mu`; a scalar or a sequence that broadcasts
            to `num_samples` entries.
    """

    def __init__(self, num_samples, std=0.1, lr=0.001, reward_func=None,
                 collector=10, min=0.1, rest_init=0.1) -> None:
        super().__init__()
        init = np.ones(num_samples)
        init[:] = rest_init  # broadcasts a scalar or a length-`num_samples` sequence
        self.reward_func = reward_func
        self.collector = collector
        self.lr = lr
        self.num_samples = num_samples
        self.mu = torch.nn.Parameter(torch.Tensor(init))
        self.std = torch.nn.Parameter(torch.Tensor(np.ones(num_samples)*std),
                                      requires_grad=False)
        self.dist = torch.distributions.normal.Normal(self.mu, self.std)
        self.min = min
        self.w = None  # cached sample, reused until the next `zero_grad`
        self.reset_metrics()

    def reset_metrics(self):
        """Clear the reward accumulators and the step counter."""
        self.collect_size = 0
        self.collect_value = 0
        self.step_counter = 0

    def sample(self, device="cpu"):
        """Return the current weights on `device`, drawing (and caching) them if needed."""
        if self.w is None:
            self.w = self.clip(self.dist.sample())
        return self.w.to(device)

    def zero_grad(self, set_to_none: bool = True):
        """Reset the gradient of `mu`, the reward accumulators and the cached sample.

        `set_to_none` mirrors `torch.nn.Module.zero_grad` so calls written against
        the base-class signature keep working; with `False` an existing gradient
        is zeroed in place instead of being dropped.
        """
        if set_to_none or self.mu.grad is None:
            self.mu.grad = None
        else:
            self.mu.grad.detach_()
            self.mu.grad.zero_()
        self.collect_size = 0
        self.collect_value = 0
        self.w = None

    def collect(self, pred, gt):
        """Add the reward of one mini-batch and its row count to the accumulators."""
        size = pred.size(0)
        rewd = self.reward_func(pred, gt)  # TODO
        self.collect_value += rewd
        self.collect_size += size

    @staticmethod
    def _ground_truth(inp2targ_idx, n_pinp2targ, pinp2targ_idx):
        """Dense matrix with one row per entry of `n_pinp2targ` and one column per
        entry of `inp2targ_idx`, non-zero where that batch target is a positive."""
        n_batch = len(inp2targ_idx)
        # Map batch targets and positives into one compact id space.
        _, idx = torch.unique(torch.cat([inp2targ_idx, pinp2targ_idx]), return_inverse=True)
        batch_cols, pos_cols = idx[:n_batch], idx[n_batch:]
        positives = get_sparse_matrix(pos_cols, n_pinp2targ,
                                      size=(len(n_pinp2targ), int(idx.max())+1))
        # Keep only the columns that correspond to targets present in the batch.
        return positives.to_dense()[:, batch_cols]

    def step(
        self,
        inp:torch.FloatTensor,
        targ:torch.FloatTensor,
        n_inp2targ:torch.LongTensor,
        inp2targ_idx:torch.LongTensor,
        n_pinp2targ:torch.LongTensor,
        pinp2targ_idx:torch.LongTensor
    ):
        """Score one mini-batch and, every `collector` calls, update `mu`.

        Args:
            inp: input representations; scored against `targ` via `inp @ targ.T`.
            targ: target representations.
            n_inp2targ: accepted for interface compatibility; not used here.
            inp2targ_idx: ids of the targets present in the batch.
            n_pinp2targ: number of positive targets for each input.
            pinp2targ_idx: ids of the positive targets, concatenated per input.

        Raises:
            ValueError: if an update is due but `sample` was not called since the
                last update, so there is no sample to compute `log_prob` for.
        """
        pred = inp@targ.T
        gt = self._ground_truth(inp2targ_idx, n_pinp2targ, pinp2targ_idx)

        self.step_counter += 1
        self.collect(pred, gt)
        if self.step_counter % self.collector == 0:
            if self.w is None:
                raise ValueError('`sample` must be called before `step` triggers an update; '
                                 'no weights were drawn since the last update.')
            loss = -self.dist.log_prob(self.w)*self.curr_reward
            torch.sum(loss).backward()
            # Manual gradient step on `mu`, kept out of the autograd graph.
            with torch.no_grad():
                self.mu.sub_(self.lr * self.mu.grad)
            self.dist.loc = self.clip(self.mu)
            self.step_counter = 0
            self.zero_grad()

    def clip(self, vect):
        """Clamp `vect` from below at `self.min`."""
        return torch.clamp(vect, min=self.min)

    @property
    def curr_reward(self):
        """Accumulated reward divided by the number of rows collected so far."""
        return self.collect_value/self.collect_size

    def extra_repr(self):
        return f"{self.mu}"
        

In [17]:
#| export
class RLLossWeightsCumuluative(RLLossWeights):
    """`RLLossWeights` variant that rewards improvement over a running baseline.

    The first collection window only initialises the baseline `reward_prev`.
    On later windows the gradient of `-log_prob(w)` with respect to `mu` is
    scaled by `curr_reward - reward_prev`, NaNs are replaced via
    `torch.nan_to_num`, the result is clipped to [-1, 1] and applied with step
    size `lr`. After every window the baseline becomes
    `(1 - m) * curr_reward + m * reward_prev`.

    Args:
        m: weight given to the previous baseline in the running average.
        Remaining arguments are forwarded unchanged to `RLLossWeights`.
    """

    def __init__(self, num_samples=1, std=0.01, lr=0.01, m=0.8,
                 reward_func=None, collector=10, min=0.1, rest_init=0.1) -> None:
        # Assigned before the parent constructor runs, as in the original code.
        self.m = m
        super().__init__(num_samples, std, lr, reward_func, collector, min, rest_init)

    def reset_metrics(self):
        """Clear the accumulators and restart the warm-up / baseline state."""
        super().reset_metrics()
        self.reward_prev = None
        self.in_warmup = True

    def step(
        self,
        inp:torch.FloatTensor,
        targ:torch.FloatTensor,
        n_inp2targ:torch.LongTensor,
        inp2targ_idx:torch.LongTensor,
        n_pinp2targ:torch.LongTensor,
        pinp2targ_idx:torch.LongTensor
    ):
        """Score one mini-batch and, every `collector` calls, update `mu` and the baseline.

        Args:
            inp: input representations; scored against `targ` via `inp @ targ.T`.
            targ: target representations.
            n_inp2targ: accepted for interface compatibility; not used here.
            inp2targ_idx: ids of the targets present in the batch.
            n_pinp2targ: number of positive targets for each input.
            pinp2targ_idx: ids of the positive targets, concatenated per input.

        Raises:
            ValueError: if a post-warm-up update is due but `sample` was not
                called since the last update.
        """
        pred = inp@targ.T

        # Map batch targets and positives into one compact id space, build the
        # (inputs x compact ids) positive matrix, then keep the batch-target columns.
        n_batch = len(inp2targ_idx)
        _, idx = torch.unique(torch.cat([inp2targ_idx, pinp2targ_idx]), return_inverse=True)
        positives = get_sparse_matrix(idx[n_batch:], n_pinp2targ,
                                      size=(len(n_pinp2targ), int(idx.max())+1))
        gt = positives.to_dense()[:, idx[:n_batch]]

        self.step_counter += 1
        self.collect(pred, gt)

        if self.step_counter % self.collector == 0:
            if self.in_warmup:
                # First window: only seed the baseline, no parameter update.
                self.in_warmup = False
                self.reward_prev = self.curr_reward
            else:
                if self.w is None:
                    raise ValueError('`sample` must be called before `step` triggers an update; '
                                     'no weights were drawn since the last update.')
                reward = self.curr_reward - self.reward_prev
                loss = -self.dist.log_prob(self.w).sum()
                loss.backward()
                # Manual, clipped gradient step on `mu`, kept out of the autograd graph.
                with torch.no_grad():
                    grad = torch.clip(torch.nan_to_num(self.mu.grad*reward), min=-1, max=1)
                    self.mu.sub_(self.lr * grad)
            self.dist.loc = self.clip(self.mu)
            self.step_counter = 0
            # Must run before `zero_grad`, which clears the accumulators behind `curr_reward`.
            self.reward_prev = (1-self.m)*self.curr_reward + \
                self.m*self.reward_prev
            self.zero_grad()
            

## Reward

In [18]:
#|export
def AccMiniBatch(pred, gt):
    """Sum the `gt` entries found at each row's top-1 position in `pred`.

    Returns a Python number: the total over the mini-batch, not a ratio.
    """
    # Column index of the highest score in every row, kept as an (n_rows, 1) tensor.
    _, top1 = torch.topk(pred, k=1, dim=1, largest=True)
    hits = gt.to(pred.device).gather(1, top1)
    return hits.sum().item()
    

## Example

In [21]:
# Four weights: the first starts at 1.0, the remaining three at the 0.1 floor.
# `collector=2` means `mu` is only revisited on every second `step` call.
loss_w = RLLossWeightsCumuluative(
    num_samples=4,
    std=0.1,
    lr=0.01,
    min=0.1,
    rest_init=[1.0, 0.1, 0.1, 0.1],
    reward_func=AccMiniBatch,
    collector=2,
)

In [22]:
# Draw the weights once; the result is cached on `loss_w` until its next `zero_grad`.
ws = loss_w.sample()
ws

tensor([1.0408, 0.1624, 0.1000, 0.1000])

In [28]:
# NOTE(review): this import sits mid-notebook; consider moving it up to the setup imports.
from xcai.models.PPP0XX import DBT009
# Load DBT009 from the `distilbert-base-uncased` checkpoint; the warning printed below lists
# the encoder `dr_*` weights that are not in the checkpoint and are newly initialized.
model = DBT009.from_pretrained('distilbert-base-uncased')

# Draw one batch from the training split (argument 10 is presumably the batch size — TODO confirm)
# and run a forward pass; `o.data_repr` and `o.lbl2data_repr` feed the `step` call further down.
batch = block.train.one_batch(10)
o = model(**batch)

Some weights of DBT009 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [30]:
# Feed one batch of representations and index tensors to the weight learner.
loss_w.step(
    inp=o.data_repr,
    targ=o.lbl2data_repr,
    n_inp2targ=batch['lbl2data_data2ptr'],
    inp2targ_idx=batch['lbl2data_idx'],
    n_pinp2targ=batch['plbl2data_data2ptr'],
    pinp2targ_idx=batch['plbl2data_idx'],
)

  torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device, size=size)


In [31]:
# Only one `step` call has run and `collector=2`, so no update was triggered and the
# cached sample is returned unchanged (same values as the first `sample()` output).
loss_w.sample()

tensor([1.0408, 0.1624, 0.1000, 0.1000])