# RADGA Inference

In [None]:
#| default_exp 55-encoder-parallel-radga-inference-pipeline

In [None]:
%load_ext autoreload
%autoreload 2

from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
import os,sys,torch,pickle,torch.multiprocessing as mp, pickle
from xcai.basics import *
from xcai.models.radga import RAD001

In [None]:
# Turn off Weights & Biases logging for this notebook run.
os.environ.update(WANDB_MODE='disabled')

In [None]:
#| export
# GPU selection and W&B project name. CUDA_VISIBLE_DEVICES only has an effect if it is set
# before CUDA is initialised in this process (torch itself is already imported above).
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '12,13',
    'WANDB_PROJECT': 'xc-nlg_55-encoder-parallel-radga-inference-pipeline',
})

## Imports

In [None]:
from tqdm.auto import tqdm
from packaging import version
import torch, re, math, numpy as np, os, time, datasets
from typing import Any, Tuple, Optional, Sequence, Union, Dict, List, NamedTuple
from transformers import AutoTokenizer, BatchEncoding, Seq2SeqTrainer, Seq2SeqTrainingArguments

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, RandomSampler

from torch.nn.parallel import DataParallel
from torch.nn.parallel._functions import Scatter
from torch.nn.parallel.scatter_gather import _is_namedtuple

from xcai.core import *
from xcai.data import *
from xcai.representation.search import *
from xcai.generation.trie import *
from xcai.generation.generate import *
from xcai.clustering.cluster import *

from fastcore.utils import *
from fastcore.meta import *
from fastcore.dispatch import *

In [None]:
from transformers.trainer_pt_utils import (
    find_batch_size, 
    nested_concat, nested_numpify, 
    IterableDatasetShard, 
    get_dataloader_sampler, 
    get_model_param_count,
    LengthGroupedSampler
)
from transformers.trainer_utils import has_length, denumpify_detensorize, speed_metrics, TrainOutput, HPSearchBackend, seed_worker
from transformers.trainer_callback import TrainerState
from transformers.trainer import _is_peft_model
from transformers.modeling_utils import unwrap_model
from transformers.utils import is_sagemaker_mp_enabled, is_accelerate_available, is_torch_tpu_available, logging, is_datasets_available
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow

from transformers.integrations import hp_params
from transformers.integrations.tpu import tpu_spmd_dataloader
from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available

if is_accelerate_available():
    from accelerate import Accelerator, skip_first_batches
    from accelerate import __version__ as accelerate_version
    from accelerate.utils import (
        DistributedDataParallelKwargs,
        DistributedType,
        GradientAccumulationPlugin,
        load_fsdp_model,
        load_fsdp_optimizer,
        save_fsdp_model,
        save_fsdp_optimizer,
    )

    DATA_SAMPLERS = [RandomSampler]
    if version.parse(accelerate_version) > version.parse("0.23.0"):
        from accelerate.data_loader import SeedableRandomSampler

        DATA_SAMPLERS += [SeedableRandomSampler]

    if is_deepspeed_available():
        from accelerate.utils import DeepSpeedSchedulerWrapper

if is_accelerate_available("0.28.0"):
    from accelerate.utils import DataLoaderConfiguration

TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
OPTIMIZER_NAME_BIN = "optimizer.bin"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"
FSDP_MODEL_NAME = "pytorch_model_fsdp"

logger = logging.get_logger(__name__)

## Data

In [None]:
# Root directory holding the datasets used by this notebook.
data_dir = os.path.join('/home/aiscuser', 'scratch', 'datasets')

In [None]:
# Directory of the pre-processed (pickled) data blocks; note the trailing slash.
pkl_dir = data_dir + '/processed/'

In [None]:
# Data block with metadata (used for training/evaluation below). The file is a local,
# trusted pickle; never unpickle data from an untrusted source.
with open(os.path.join(pkl_dir, 'wikiseealso_data-metas_distilbert-base-uncased_rm_radga-final.pkl'), 'rb') as file:
    block = pickle.load(file)

In [None]:
# Secondary data block (local, trusted pickle).
with open(os.path.join(pkl_dir, 'wikiseealso_data_distilbert-base-uncased_xcnlg_ngame.pkl'), 'rb') as file:
    test_block = pickle.load(file)

In [None]:
# with open(f'{pkl_dir}/wikititles_data-metas_distilbert-base-uncased_rm_radga-final.pkl', 'rb') as file: 
#     block = pickle.load(file)

## MetaXCDataset

In [None]:
@patch
def _verify_inputs(cls:MetaXCDataset):
    """Record metadata sizes on the dataset and check that its parts agree.

    Sets `n_data`/`n_meta` from the shape of `data_meta` and, when `lbl_meta` is present, `n_lbl`.

    Raises:
        ValueError: if `lbl_meta` has a different number of columns than `data_meta`, or if
            `meta_info` (as counted by `_verify_info`) does not have one entry per column of `data_meta`.
    """
    data_shape = cls.data_meta.shape
    cls.n_data, cls.n_meta = data_shape[0], data_shape[1]

    lbl_meta = cls.lbl_meta
    if lbl_meta is not None:
        cls.n_lbl = lbl_meta.shape[0]
        if lbl_meta.shape[1] != cls.n_meta:
            raise ValueError(f'`lbl_meta`({lbl_meta.shape[1]}) should have same number of columns as `data_meta`({cls.n_meta}).')

    if cls.meta_info is None: return

    n_meta = cls._verify_info(cls.meta_info)
    if n_meta != cls.n_meta:
        raise ValueError(f'`meta_info`({n_meta}) should have same number of entries as number of columns of `data_meta`({cls.n_meta})')
      

## Learner

In [None]:
from xcai.transform import PadFeatTfm

In [None]:
class RadgaLearningArguments(XCLearningArguments):
    """`XCLearningArguments` plus the options that control metadata augmentation.

    Args:
        data_aug_meta_name: metadata used for augmentation. The dataset entry `<name>_meta` is looked up,
            and retrieved metadata is returned as `<name>2data_*` inputs.
        augmentation_num_beams: number of metadata items retrieved per item.
        predict_with_augmentation: default for whether augmentation runs at prediction time.
        use_augmentation_index_representation: pass vectors fetched from the augmentation index
            instead of metadata token ids.
        metadata_representation_attribute: output attribute read from `model.get_meta_representation`.
        data_augmentation_attribute: model output attribute used as the augmentation query.
    """

    @delegates(XCLearningArguments.__init__)
    def __init__(self, 
                 data_aug_meta_name:Optional[str]=None,
                 augmentation_num_beams:Optional[int]=3,
                 predict_with_augmentation:Optional[bool]=False,
                 use_augmentation_index_representation:Optional[bool]=False,
                 metadata_representation_attribute:Optional[str]='data_repr',
                 data_augmentation_attribute:Optional[str]='data_repr',
                 **kwargs):
        super().__init__(**kwargs)
        # Copy every augmentation option onto `self` in a single call.
        store_attr('data_aug_meta_name,augmentation_num_beams,predict_with_augmentation,'
                   'use_augmentation_index_representation,metadata_representation_attribute,data_augmentation_attribute')
        

In [None]:
class RadgaLearner(XCLearner):
    """`XCLearner` that can also retrieve metadata for each item at prediction time.

    On top of the trie beam search `tbs` and the label index `idxs`, it keeps:
      * `aug_idxs`: the augmentation index (`None` until `_build_aug_index` builds it),
      * `aug_info`: the tokenized metadata backing that index,
      * `aug_pad`: the padding transform applied to retrieved metadata (`meta_*` fields).
    """

    def __init__(self, trie:Optional[Trie]=None, **kwargs):
        super().__init__(**kwargs)
        cfg = self.args
        self.tbs = TrieBeamSearch(trie, cfg.generation_eos_token, n_bm=cfg.generation_num_beams,
                                  len_penalty=cfg.generation_length_penalty, max_info=cfg.generation_max_info, **kwargs)

        # 'BRUTEFORCE' selects BruteForceSearch; every other value selects IndexSearch.
        if cfg.representation_search_type == 'BRUTEFORCE':
            self.idxs = BruteForceSearch(n_bm=cfg.representation_num_beams)
        else:
            self.idxs = IndexSearch(space=cfg.index_space, efc=cfg.index_efc, m=cfg.index_m,
                                    efs=cfg.index_efs, n_bm=cfg.representation_num_beams,
                                    n_threads=cfg.index_num_threads)

        self.aug_idxs = None
        self.aug_info = None
        self.aug_pad = PadFeatTfm(pad_tok=self.model.config.pad_token_id, prefix="meta")


In [None]:
@patch
def _perform_augmentation(self:RadgaLearner, model:nn.Module, predict_with_augmentation:Optional[bool]=None):
    """Decide whether metadata augmentation should run for `model`.

    A `use_augmentation` attribute on the unwrapped model wins. Otherwise the explicit
    `predict_with_augmentation` argument is used, falling back to `self.args.predict_with_augmentation`.
    """
    base_model = unwrap_model(model)
    fallback = self.args.predict_with_augmentation if predict_with_augmentation is None else predict_with_augmentation
    return getattr(base_model, 'use_augmentation', fallback)


In [None]:
@patch
def get_meta_representation(self:XCLearner, dataloader: DataLoader, to_cpu:Optional[bool]=True):
    """Encode every batch of `dataloader` with `self.model.get_meta_representation`.

    Vectors are read from the output attribute named by `self.args.metadata_representation_attribute`.
    They are accumulated with `_gather_host_output` and flushed through `_gather_all_output` every
    `representation_accumulation_steps` batches (when that option is set).

    If the model exposes both `disable_noise` and `set_noise`, noise is switched off while encoding.
    The previous setting is restored afterwards, even when encoding raises part-way.

    Args:
        dataloader: batches to encode; each batch is moved with `.to(self.model.device)`.
        to_cpu: forwarded to `_gather_all_output`.

    Returns:
        Whatever `_gather_all_output` returns for the accumulated representations.
    """
    model = self.model
    # Only toggle noise when both halves of the toggle exist, so the restore below cannot fail.
    toggles_noise = callable(getattr(model, 'disable_noise', None)) and callable(getattr(model, 'set_noise', None))
    use_noise = model.disable_noise() if toggles_noise else None

    data_host, all_data = None, None
    try:
        for step, inputs in tqdm(enumerate(dataloader), total=len(dataloader)):
            inputs = inputs.to(model.device)
            with torch.no_grad(): data = getattr(model.get_meta_representation(**inputs), self.args.metadata_representation_attribute)
            data_host = self._gather_host_output(data, data_host)
            if self.args.representation_accumulation_steps is not None and (step + 1) % self.args.representation_accumulation_steps == 0:
                all_data, data_host = self._gather_all_output(data_host, all_data, to_cpu=to_cpu), None
    finally:
        # Restore the caller's noise setting even if encoding failed.
        if toggles_noise: model.set_noise(use_noise)

    return self._gather_all_output(data_host, all_data, to_cpu=to_cpu)
    

In [None]:
@patch
def get_representation(self:XCLearner, dataloader: DataLoader, to_cpu:Optional[bool]=True):
    """Encode every batch of `dataloader` with a plain forward pass of `self.model`.

    Vectors are read from the output attribute named by `self.args.representation_attribute`.
    They are accumulated with `_gather_host_output` and flushed through `_gather_all_output` every
    `representation_accumulation_steps` batches (when that option is set).

    If the model exposes both `disable_noise` and `set_noise`, noise is switched off while encoding.
    The previous setting is restored afterwards, even when encoding raises part-way.

    Args:
        dataloader: batches to encode; each batch is moved with `.to(self.model.device)`.
        to_cpu: forwarded to `_gather_all_output`.

    Returns:
        Whatever `_gather_all_output` returns for the accumulated representations.
    """
    model = self.model
    # Only toggle noise when both halves of the toggle exist, so the restore below cannot fail.
    toggles_noise = callable(getattr(model, 'disable_noise', None)) and callable(getattr(model, 'set_noise', None))
    use_noise = model.disable_noise() if toggles_noise else None

    data_host, all_data = None, None
    try:
        for step, inputs in tqdm(enumerate(dataloader), total=len(dataloader)):
            inputs = inputs.to(model.device)
            with torch.no_grad(): data = getattr(model(**inputs), self.args.representation_attribute)
            data_host = self._gather_host_output(data, data_host)
            if self.args.representation_accumulation_steps is not None and (step + 1) % self.args.representation_accumulation_steps == 0:
                all_data, data_host = self._gather_all_output(data_host, all_data, to_cpu=to_cpu), None
    finally:
        # Restore the caller's noise setting even if encoding failed.
        if toggles_noise: model.set_noise(use_noise)

    return self._gather_all_output(data_host, all_data, to_cpu=to_cpu)
    

In [None]:
@patch
def _build_aug_index(self:RadgaLearner, dataset:Optional[Dataset]=None):
    """Build the augmentation index `self.aug_idxs` from the dataset's augmentation metadata.

    The dataset used is the training split if set, else the evaluation split, else `dataset`.
    Nothing happens unless that dataset has a `<data_aug_meta_name>_meta` entry. When it does,
    the entry's `meta_info` is stored in `self.aug_info`, encoded with `get_meta_representation`,
    and indexed.
    """
    if self.train_dataset is not None: dataset = self.train_dataset
    elif self.eval_dataset is not None: dataset = self.eval_dataset

    aug_name = self.args.data_aug_meta_name
    aug_meta_name = None if aug_name is None else f'{aug_name}_meta'
    if dataset is None or dataset.meta is None or aug_meta_name is None or aug_meta_name not in dataset.meta:
        return

    self.aug_idxs = IndexSearch(space=self.args.index_space, efc=self.args.index_efc, m=self.args.index_m,
                                efs=self.args.index_efs, n_bm=self.args.representation_num_beams,
                                n_threads=self.args.index_num_threads)
    self.aug_info = dataset.meta[aug_meta_name].meta_info

    aug_dl = self.get_test_dataloader(MainXCDataset(self.aug_info))
    aug_repr = self.get_meta_representation(aug_dl, to_cpu=isinstance(self.aug_idxs, IndexSearch))
    self.aug_idxs.build(aug_repr)
        

In [None]:
@patch
def _build_lbl_index(self:XCLearner, dataset:Optional[Dataset]=None):
    """Encode all labels and build the label index `self.idxs`.

    The label set comes from the training split if set, else the evaluation split, else `dataset`.
    When that dataset has a `<data_aug_meta_name>_meta` entry, the metadata is attached to the label
    dataset. The entry's `lbl_meta` is passed as both the data-side and label-side matrix, with
    `augmentation_num_beams` samples per label.

    Raises:
        ValueError: if no dataset is available.
    """
    if self.train_dataset is not None: dataset = self.train_dataset
    elif self.eval_dataset is not None: dataset = self.eval_dataset
    if dataset is None: raise ValueError('Failed to build `self.idxs`')

    lbl_dset = dataset.lbl_dset

    aug_name = self.args.data_aug_meta_name
    meta_name = None if aug_name is None else f'{aug_name}_meta'
    if meta_name is not None and dataset.meta is not None and meta_name in dataset.meta:
        meta = dataset.meta[meta_name]
        lbl_meta_dset = MetaXCDataset(meta.prefix, meta.lbl_meta, meta.lbl_meta, meta.meta_info,
                                      n_data_meta_samples=self.args.augmentation_num_beams)
        lbl_dset = XCDataset(lbl_dset, **{meta_name: lbl_meta_dset})

    lbl_dl = self.get_test_dataloader(lbl_dset)
    lbl_repr = self.get_representation(lbl_dl, to_cpu=isinstance(self.idxs, IndexSearch))
    self.idxs.build(lbl_repr)
        

In [None]:
@patch
def _fill_radga_gen_kwargs(self:RadgaLearner, gen_kwargs:Dict[str, Any]):
    "Return a copy of `gen_kwargs` in which every missing/`None` beam or penalty option is taken from `self.args`."
    gen_kwargs = gen_kwargs.copy()
    defaults = (
        ("length_penalty", self.args.generation_length_penalty),
        ("gen_num_beams", self.args.generation_num_beams),
        ("repr_num_beams", self.args.representation_num_beams),
        ("aug_num_beams", self.args.augmentation_num_beams),
    )
    for key, value in defaults:
        if gen_kwargs.get(key) is None and value is not None:
            gen_kwargs[key] = value
    return gen_kwargs

@patch
def _build_radga_prediction_indices(self:RadgaLearner, dataset:Optional[Dataset]=None):
    "Build the label index and/or augmentation index when the unwrapped model will use them."
    if self.args.prediction_loss_only: return
    if self._perform_representation(unwrap_model(self.model)):
        self._build_lbl_index(dataset)
    if self._perform_augmentation(unwrap_model(self.model)):
        self._build_aug_index(dataset)

@patch
def evaluate(self:RadgaLearner, eval_dataset:Optional[Dataset]=None, ignore_keys:Optional[List[str]]=None, 
             metric_key_prefix:str="eval", **gen_kwargs):
    """Evaluate after filling generation defaults and building the indexes the model needs.

    `gen_kwargs` may override `length_penalty`, `gen_num_beams`, `repr_num_beams` and `aug_num_beams`.
    Missing values come from `self.args`. The completed kwargs are also forwarded to the parent
    `evaluate`, so user overrides are not lost.
    """
    gen_kwargs = self._fill_radga_gen_kwargs(gen_kwargs)
    self.gather_function, self._gen_kwargs = self.accelerator.gather, gen_kwargs

    self._build_radga_prediction_indices(eval_dataset)

    # This function is attached with `@patch` rather than defined in the class body, so it has no
    # `__class__` cell and zero-argument `super()` would raise RuntimeError; name the class explicitly.
    return super(RadgaLearner, self).evaluate(eval_dataset, ignore_keys=ignore_keys,
                                              metric_key_prefix=metric_key_prefix, **gen_kwargs)

@patch
def predict(self:RadgaLearner, test_dataset: Dataset, ignore_keys:Optional[List[str]]=None, 
            metric_key_prefix:str="test", **gen_kwargs):
    """Run the prediction loop on `test_dataset` and return an `XCPredictionOutput`.

    Generation defaults are filled as in `evaluate`. The label/augmentation indexes are built first;
    that build time counts towards the memory tracker but not the speed metrics.
    """
    gen_kwargs = self._fill_radga_gen_kwargs(gen_kwargs)
    self.gather_function, self._gen_kwargs = self.accelerator.gather, gen_kwargs
    self._memory_tracker.start()

    self._build_radga_prediction_indices(test_dataset)

    test_dataloader = self.get_test_dataloader(test_dataset)
    start_time = time.time()

    output = self.evaluation_loop(test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
    total_batch_size = self.args.eval_batch_size * self.args.world_size
    if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
        start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
    output.metrics.update(
        speed_metrics(metric_key_prefix,start_time,num_samples=output.num_samples,num_steps=math.ceil(output.num_samples / total_batch_size),)
    )
    self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics)
    self._memory_tracker.stop_and_update_metrics(output.metrics)
    return XCPredictionOutput(pred_idx=output.pred_idx, pred_ptr=output.pred_ptr, pred_score=output.pred_score, 
                          gen_output=output.gen_output, repr_output=output.repr_output, metrics=output.metrics, 
                          num_samples=output.num_samples)


In [None]:
@patch
def augmentation_output(
    self:RadgaLearner,
    model:nn.Module,
    inputs:Dict[str, Union[torch.Tensor, Any]],
    **kwargs
):
    """Retrieve metadata for every item in `inputs` and return it as extra model inputs.

    The model output attribute named by `self.args.data_augmentation_attribute` is the query for
    `self.aug_idxs`. The retrieved ids select rows of `self.aug_info`, which `self.aug_pad` then pads.

    Args:
        model: model used to compute the query representations.
        inputs: batch of model inputs (prepared with `_prepare_inputs`).
        **kwargs: may hold `aug_num_beams`, the number of metadata items per item.
            It defaults to `self.args.augmentation_num_beams`.

    Returns:
        Dict keyed by `<data_aug_meta_name>2data_*`:
          * with `use_augmentation_index_representation`: `_meta_repr`, `_attention_mask`, `_data2ptr`;
          * otherwise: `_idx`, `_input_ids`, `_attention_mask`, `_data2ptr`.

    Raises:
        ValueError: if the augmentation index has not been built.
    """
    if self.aug_idxs is None: raise ValueError('Augmentation `aug_idx` is not initialized.')

    inputs = self._prepare_inputs(inputs)
    n_bm = kwargs.pop("aug_num_beams") if "aug_num_beams" in kwargs and kwargs["aug_num_beams"] is not None else self.args.augmentation_num_beams

    with torch.no_grad(): o = getattr(model(**inputs), self.args.data_augmentation_attribute)
    o = self.aug_idxs.proc(o, n_bm=n_bm)
    meta_idx = o['info2data_idx']

    # The mask must come from `attention_mask`; building it from `input_ids` fed token ids in as mask values.
    aug_info = self.aug_pad({
        'meta_input_ids': [self.aug_info['input_ids'][i] for i in meta_idx],
        'meta_attention_mask': [self.aug_info['attention_mask'][i] for i in meta_idx],
    })

    name = self.args.data_aug_meta_name
    if self.args.use_augmentation_index_representation:
        # Use the vectors stored in the augmentation index (`get_items`) instead of metadata tokens.
        meta_repr = torch.tensor(self.aug_idxs.index.get_items(meta_idx))
        return {
            f'{name}2data_meta_repr': meta_repr,
            f'{name}2data_attention_mask': aug_info['meta_attention_mask'],
            f'{name}2data_data2ptr': o['info2data_data2ptr'],
        }
    return {
        f'{name}2data_idx': meta_idx,
        f'{name}2data_input_ids': aug_info['meta_input_ids'],
        f'{name}2data_attention_mask': aug_info['meta_attention_mask'],
        f'{name}2data_data2ptr': o['info2data_data2ptr'],
    }
    

In [None]:
@patch
def prediction_step(
    self:RadgaLearner,
    model: nn.Module,
    inputs: Dict[str, Union[torch.Tensor, Any]],
    prediction_loss_only: bool,
    predict_with_generation: bool,
    predict_with_representation: bool,
    predict_with_augmentation:Optional[bool]=None,
    ignore_keys: Optional[List[str]] = None,
    **kwargs,
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Run one evaluation step and return `(loss, output)`.

    Steps:
      1. Compute the loss with a forward pass on the unaugmented `inputs`.
      2. Unless only the loss is wanted, optionally add retrieved metadata to `inputs` (in place).
      3. Collect generation and/or representation predictions. When both run, their entries are
         suffixed `_gen`/`_repr` and merged with `concatenate_output`.
      4. Attach `targ_idx`/`targ_ptr` when the target keys are present in `inputs`.

    NOTE(review): the return annotation lists three elements, but a 2-tuple is returned.
    """
    with torch.no_grad():
        with self.compute_loss_context_manager():
            model_out = model(**inputs)
        raw_loss = model_out["loss"] if isinstance(model_out, dict) else model_out[0]
        loss = raw_loss.mean().detach()

    if prediction_loss_only is None:
        prediction_loss_only = self.args.prediction_loss_only
    if prediction_loss_only:
        return loss, {}

    if self._perform_augmentation(model, predict_with_augmentation):
        inputs.update(self.augmentation_output(model, inputs, **kwargs))

    gen_o = self.generation_output(model, inputs, **kwargs) if self._perform_generation(model, predict_with_generation) else None
    repr_o = self.representation_output(model, inputs, **kwargs) if self._perform_representation(model, predict_with_representation) else None

    if gen_o is None or repr_o is None:
        output = repr_o if repr_o is not None else gen_o
    else:
        output = {f'{k}_gen': v for k, v in gen_o.items()}
        output.update({f'{k}_repr': v for k, v in repr_o.items()})
        output.update(self.concatenate_output(gen_o, repr_o))

    if self.args.target_indices_key in inputs:
        output.update({'targ_idx': inputs[self.args.target_indices_key],
                       'targ_ptr': inputs[self.args.target_pointer_key]})

    return loss, output

## Training

## `RAD002`

In [None]:
from xcai.models.radga import Encoder, RAD001, EncoderOutput, RADOutput

In [None]:
class Parameters:
    """Helpers that select metadata-related entries from a flat ``**kwargs`` dict.

    Keyword names follow the ``<meta>_<param>`` convention: ``hlk2data_input_ids`` splits at the
    first underscore into meta ``hlk2data`` and param ``input_ids``.
    """

    @staticmethod
    def from_meta_aug_prefix(prefix:str, **kwargs):
        """Group augmentation inputs whose names start with `prefix` by metadata name.

        Only keys ending in `_input_ids`, `_attention_mask`, `_data2ptr` or `_meta_repr` are kept.
        Returns ``{meta: {param: value}}``; an empty dict when `prefix` is None.
        """
        inputs = {}
        if prefix is None: return inputs
        pattern = re.compile(f'^{prefix}.*_(input_ids|attention_mask|data2ptr|meta_repr)$')
        for arg, value in kwargs.items():
            if pattern.match(arg):
                meta, param = arg.split('_', maxsplit=1)
                inputs.setdefault(meta, {})[param] = value
        return inputs

    @staticmethod
    def from_feat_meta_aug_prefix(feat:str, prefix:str, **kwargs):
        """Pick the `prefix`-ed augmentation inputs that belong to feature `feat`.

        Copies `<prefix>_attention_mask`, `<prefix>_input_ids` and `<prefix>_meta_repr` when present.
        `<prefix>_<feat>2ptr` is renamed to `<prefix>_data2ptr`. Returns an empty dict when `prefix` is None.
        """
        if prefix is None: return {}
        inputs = {f'{prefix}_{k}': kwargs[f'{prefix}_{k}'] for k in ('attention_mask', 'input_ids', 'meta_repr')
                  if f'{prefix}_{k}' in kwargs}
        ptr_key = f'{prefix}_{feat}2ptr'
        if ptr_key in kwargs:
            inputs[f'{prefix}_data2ptr'] = kwargs[ptr_key]
        return inputs

    @staticmethod
    def from_meta_pred_prefix(prefix:str, **kwargs):
        """Group prediction-side metadata inputs by metadata name.

        Keys starting with `prefix` are stored under their own param name. Keys starting with
        `p<prefix>` are stored under the un-prefixed meta name with a `p`-prefixed param
        (``pcat2data_idx`` becomes ``{'cat2data': {'pidx': ...}}``).
        """
        inputs = {}
        if prefix is None: return inputs
        pattern = re.compile(f'^[p]?{prefix}.*')
        for arg, value in kwargs.items():
            if not pattern.match(arg): continue
            meta, param = arg.split('_', maxsplit=1)
            # Decide on the full `p<prefix>` marker rather than the first character,
            # so that prefixes which themselves start with 'p' are handled correctly.
            if arg.startswith(f'p{prefix}'):
                inputs.setdefault(meta[1:], {})[f'p{param}'] = value
            else:
                inputs.setdefault(meta, {})[param] = value
        return inputs

    @staticmethod
    def get_meta_loss_weights(lw:Union[float,List], n_meta:int):
        """Return one loss weight per metadata.

        A scalar weight (int or float) is split evenly across the `n_meta` metadata.
        A list must already have exactly `n_meta` entries.

        Raises:
            ValueError: if a list weight does not have `n_meta` entries.
        """
        if isinstance(lw, (int, float)) and not isinstance(lw, bool):
            if not n_meta: return []
            return [lw / n_meta] * n_meta
        if len(lw) != n_meta: raise ValueError('length of `lw` should be equal to number of metadata.')
        return lw
        

In [None]:
class RAD002Encoder(Encoder):
    """`Encoder` that can fuse augmentation metadata into an item's token embeddings.

    Metadata can be supplied either as token ids (`<prefix>_input_ids` / `<prefix>_attention_mask`,
    encoded here with `self.encode`) or as precomputed vectors (`<prefix>_meta_repr`).
    The helpers used below come from the parent `Encoder`, which is not shown in this file:
    `encode`, `resize`, `dr`, `meta`, `meta_unnormalized`, `cross_head`, `linear_layer`, `gen`,
    `get_noise`, `add_noise` and the `use_noise` flag.
    """
    
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        
    def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, prefix:str, **kwargs):
        """Cross-attend item embeddings to their metadata and add the result to `embed`.

        Args:
            embed: token embeddings of the items. Rows that have metadata are updated IN PLACE
                (`embed[idx] += ...`), so the caller's tensor is modified.
            attention_mask: attention mask for `embed`.
            prefix: metadata key prefix (e.g. `hlk2data`) used to select entries from `kwargs`.
            **kwargs: `<prefix>_data2ptr` plus either `<prefix>_meta_repr` or `<prefix>_input_ids`,
                together with `<prefix>_attention_mask`.

        Returns:
            A tuple `(embed, weights, meta_repr)`:
              * `embed`: the embeddings, passed through `self.linear_layer` if any fusion happened;
              * `weights`: attention weights returned by `self.cross_head`, one entry per fused group;
              * `meta_repr`: dict mapping each metadata group to its L2-normalised vectors. A group is
                an empty `(0, config.dim)` tensor when no item had metadata.
        """
        meta_kwargs = Parameters.from_meta_aug_prefix(prefix, **kwargs)
        meta_repr, weights, performed_fusion = {}, [], False
        
        for m_key, m_args in meta_kwargs.items():
            # Items with a positive `data2ptr` entry take part in fusion
            # (presumably `data2ptr` holds per-item metadata counts — TODO confirm).
            idx = torch.where(m_args['data2ptr'] > 0)[0]
            meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
            if len(idx):
                performed_fusion = True
                if 'meta_repr' in m_args:
                    # Precomputed vectors: a vector counts as valid if its attention-mask row has any non-zero entry.
                    m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)
                    m_repr,m_repr_mask = self.resize(m_repr, m_repr_mask, m_args['data2ptr'][idx])
                    m_repr_mask = m_repr_mask.bool()
                    # NOTE(review): no noise item is appended on this path, yet the `[:, :-1]` slice below
                    # still drops the last slot when `self.use_noise` is on — confirm this is intended.
                else:
                    # Token ids: resize per item, optionally add noise metadata, then encode.
                    m_input_ids, m_attention_mask = self.resize(m_args['input_ids'], m_args['attention_mask'], 
                                                                m_args['data2ptr'][idx])

                    if self.use_noise:
                        n_input_ids, n_attention_mask = self.get_noise(m_args['input_ids'], m_args['attention_mask'], 
                                                                       m_args['data2ptr'][idx])
                        m_input_ids, m_attention_mask = self.add_noise(m_input_ids, m_attention_mask, 
                                                                       n_input_ids, n_attention_mask)

                    m_embed = self.encode(m_input_ids, m_attention_mask)[0]

                    m_repr = self.meta_unnormalized(m_embed, m_attention_mask)
                    m_repr_mask = torch.any(m_attention_mask, dim=1)
                    
                # Shape to (number of items with metadata, metadata slots per item, dim).
                m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)
                
                # With noise on, the last slot per item is left out of the returned vectors
                # (presumably the slot added by `add_noise` — confirm against `Encoder`).
                meta_repr[m_key] = m_repr[:, :-1][m_repr_mask[:, :-1]] if self.use_noise else m_repr[m_repr_mask]
                meta_repr[m_key] = F.normalize(meta_repr[m_key], dim=1)
                
                fused_embed, w = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask, output_attentions=True)
                embed[idx] += fused_embed
                weights.append(w)
        
        if performed_fusion: embed = self.linear_layer(embed)        
        return embed, weights, meta_repr
    
    def forward(
        self, 
        data_input_ids: torch.Tensor, 
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        data_gen_idx:Optional[torch.Tensor]=None,
        data_type:Optional[str]=None,
        data_unnormalized:Optional[bool]=False,
        **kwargs
    ):
        """Encode items and, when `data_aug_meta_prefix` is given, fuse their metadata.

        Args:
            data_input_ids, data_attention_mask: tokenized items.
            data_aug_meta_prefix: metadata prefix forwarded to `fuse_meta_into_embeddings`.
                When None, no fusion is done and the fused outputs are None.
            data_gen_idx: optional index applied to the fused embeddings before `self.gen`.
            data_type: `"meta"` pools with `meta`, or with `meta_unnormalized` if `data_unnormalized`
                is set; any other value pools with `dr`.
            **kwargs: metadata inputs consumed by `fuse_meta_into_embeddings`.

        Returns:
            `EncoderOutput(rep, fused_rep, logits, meta_repr)`. The fusion weights are computed
            but not returned.
        """
        data_o = self.encode(data_input_ids, data_attention_mask)
        
        if data_type is not None and data_type == "meta":
            data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
        else: 
            data_repr = self.dr(data_o[0], data_attention_mask)
        
        data_fused_repr = data_fused_logits = fusion_weights = meta_repr = None
        if data_aug_meta_prefix is not None:
            # Fusion updates data_o[0] in place; `data_repr` above was computed before that.
            data_fused_embed, fusion_weights, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
                                                                                         data_attention_mask, 
                                                                                         data_aug_meta_prefix, 
                                                                                         **kwargs)
            data_fused_repr = self.dr(data_fused_embed, data_attention_mask)
            data_fused_logits = self.gen(data_fused_embed if data_gen_idx is None else data_fused_embed[data_gen_idx])
        
        return EncoderOutput(
            rep=data_repr,
            fused_rep=data_fused_repr,
            logits=data_fused_logits,
            
            meta_repr=meta_repr,
        )
    

In [None]:
class RAD002(RAD001):
    """`RAD001` whose encoder is replaced by a `RAD002Encoder`.

    The fusion step of `RAD002Encoder` also accepts precomputed metadata vectors (`*_meta_repr`).
    This class adds helpers to read and toggle the encoder's `use_noise` flag, with or without
    an `XCDataParallel` wrapper around the encoder.
    """
    
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        
        # Forward only the encoder options that were actually supplied. Indexing `kwargs[...]`
        # directly raised KeyError whenever one of them was omitted.
        encoder_kwargs = {k: kwargs[k] for k in ('use_noise', 'resize_length') if k in kwargs}
        self.encoder = RAD002Encoder(config, **encoder_kwargs)
        self.post_init()
        self.remap_post_init()
        self.init_retrieval_head()

    def _inner_encoder(self):
        "Return the bare encoder, unwrapping an `XCDataParallel` wrapper if present."
        return self.encoder.module if isinstance(self.encoder, XCDataParallel) else self.encoder

    def _forward_encoder(self):
        "Return the encoder to call: a fresh `XCDataParallel` wrapper when `use_encoder_parallel` is set."
        return XCDataParallel(module=self.encoder) if self.use_encoder_parallel else self.encoder
        
    def disable_noise(self):
        "Turn encoder noise off and return the previous `use_noise` value."
        encoder = self._inner_encoder()
        use_noise = encoder.use_noise
        encoder.use_noise = False
        return use_noise
    
    def set_noise(self, use_noise):
        "Set the encoder's `use_noise` flag."
        self._inner_encoder().use_noise = use_noise
            
    def get_noise(self):
        "Return the encoder's `use_noise` flag."
        return self._inner_encoder().use_noise
        
    def get_meta_representation(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        "Encode metadata tokens with `meta_unnormalized` pooling; extra `kwargs` are ignored."
        encoder = self._forward_encoder()
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_unnormalized=True, data_type="meta")
        return RADOutput(
            logits=data_o.logits,
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
        )
        
    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        """Encode items (and labels, when given) with metadata fusion and compute the loss.

        The loss is only computed when `lbl2data_input_ids` is provided. It is the retrieval/generation
        loss from `compute_loss`, plus `compute_meta_loss`, plus the fusion losses when
        `use_fusion_loss` is set.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        encoder = self._forward_encoder()
        
        data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
        
        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_gen_idx = self.get_last_item_mask(lbl2data_data2ptr, len(lbl2data_idx))
            lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
            
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, data_gen_idx=lbl2data_gen_idx,
                                 **lbl2data_meta_kwargs)
            
            loss = self.compute_loss(data_o.logits, data_o.fused_rep, lbl2data_o.logits, lbl2data_o.fused_rep, 
                                     data_input_ids,lbl2data_input_ids,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)
            
            loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.fused_rep, **kwargs)
            
            if self.use_fusion_loss:
                loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
                loss += self.compute_fusion_loss(lbl2data_o.fused_rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)
            
        if not return_dict:
            o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
            return ((loss,) + o) if loss is not None else o
        
        return RADOutput(
            loss=loss,
            
            logits=data_o.logits,
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
            
            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )

In [None]:
args = RadgaLearningArguments(
    output_dir='/home/aiscuser/outputs/48-encoder-parallel-radga-with-cross-attention-loss-component-for-wikiseealso-1-0',
    per_device_train_batch_size=200,
    per_device_eval_batch_size=100,

    # Representation-based (retrieval) prediction.
    predict_with_representation=True,
    representation_search_type='INDEX',
    representation_num_beams=200,
    representation_accumulation_steps=1,
    representation_attribute='data_fused_repr',

    # Generation-based prediction.
    predict_with_generation=True,
    generation_num_beams=10,
    generation_length_penalty=1.5,

    output_concatenation_weight=1.0,
    metric_for_best_model='P@1_REPR',
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',

    # Metadata augmentation ('hlk' metadata).
    data_aug_meta_name='hlk',
    augmentation_num_beams=3,
    predict_with_augmentation=False,
    use_augmentation_index_representation=True,
    data_augmentation_attribute='data_repr',
    metadata_representation_attribute='data_repr',

    use_encoder_parallel=True,
    fp16=True,
    # {cat,hlk} x {2data, 2lbl2data} x {idx, input_ids, attention_mask}, in the same order as before.
    label_names=[f'{meta}2{side}_{field}'
                 for meta in ('cat', 'hlk')
                 for side in ('data', 'lbl2data')
                 for field in ('idx', 'input_ids', 'attention_mask')],
)

In [None]:
# Checkpoints live under the scratch "outputs" tree, in a folder named after `args.output_dir`.
output_dir = "/home/aiscuser/scratch/Projects/xc_nlg/outputs/" + os.path.basename(args.output_dir)
mname = '/'.join([output_dir, os.path.basename(get_best_model(output_dir))])

In [None]:
# Global batch size: the larger per-device batch size times the number of visible GPUs.
bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size) * torch.cuda.device_count()

model = RAD002.from_pretrained(
    mname,
    num_batch_labels=5000,
    ignore_token=0,
    batch_size=bsz,
    margin=0.3,
    num_negatives=5,
    tau=0.1,
    apply_softmax=True,
    # Metadata prefixes used for fusion on the data side and the label side.
    data_aug_meta_prefix='hlk2data',
    lbl2data_aug_meta_prefix='hlk2lbl',
    resize_length=5000,
    gen_loss_weight=0.001,
    meta_loss_weight=0.3,
    pred_meta_prefix='cat',
    fusion_loss_weight=0.05,
    tie_word_embeddings=False,
    use_fusion_loss=False,
    use_noise=False,
    use_encoder_parallel=True,
)


In [None]:
# Label trie built from the data block; it is passed to RadgaLearner as `trie=` (used by TrieBeamSearch).
trie = XCTrie.from_block(block)

  0%|          | 0/312330 [00:00<?, ?it/s]

In [None]:
# Metric object handed to the learner as `compute_metrics`. The cutoffs pk/rk and rep_pk/rep_rk
# presumably select the precision/recall @k values reported — confirm against `PrecRecl`.
metric = PrecRecl(
    block.n_lbl,
    block.train.data_lbl_filterer,
    prop=block.train.dset.data.data_lbl,
    pk=10,
    rk=200,
    rep_pk=[1, 3, 5, 10],
    rep_rk=[10, 100, 200],
)

In [None]:
# Wire the loaded model, training arguments, trie, train/test datasets, the block's
# collator and the metric above into the RADGA learner used for prediction below.
learn = RadgaLearner(
    model=model, 
    args=args,
    trie=trie,
    train_dataset=block.train.dset,
    eval_dataset=block.test.dset,
    data_collator=block.collator,
    compute_metrics=metric,
)

In [None]:
# Disable the model's generation path before running prediction.
# NOTE(review): presumably predictions then come from the representation/retrieval
# branch only — confirm against RAD002.
model.use_generation = False

In [None]:
# Run prediction on the training split; the result (including `o.metrics`) is kept in `o`.
o = learn.predict(block.train.dset)

[2024-06-11 12:37:18,228] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


  0%|          | 0/1562 [00:00<?, ?it/s]

node-0:273706:273706 [0] NCCL INFO Bootstrap : Using eth0:10.13.51.163<0>
node-0:273706:273706 [0] NCCL INFO NET/Plugin : Plugin load (librccl-net.so) returned 2 : librccl-net.so: cannot open shared object file: No such file or directory
node-0:273706:273706 [0] NCCL INFO NET/Plugin : No plugin found, using internal implementation
node-0:273706:273706 [0] NCCL INFO Kernel version: 5.15.0-1042-azure
RCCL version 2.17.1+hip5.7 HEAD:cbbb3d8+

node-0:273706:279073 [0] /long_pathname_so_that_rpms_can_package_the_debug_info/src/extlibs/rccl/build/hipify/src/misc/ibvwrap.cc:222 NCCL WARN Call to ibv_open_device failed

node-0:273706:279073 [0] /long_pathname_so_that_rpms_can_package_the_debug_info/src/extlibs/rccl/build/hipify/src/transport/net_ib.cc:199 NCCL WARN NET/IB : Unable to open device mlx5_0

node-0:273706:279073 [0] /long_pathname_so_that_rpms_can_package_the_debug_info/src/extlibs/rccl/build/hipify/src/misc/ibvwrap.cc:222 NCCL WARN Call to ibv_open_device failed

node-0:273706:279

node-0:273706:279073 [0] NCCL INFO Ring 11 : 1 -> 0 -> 1 comm 0x1b6ce640 nRanks 02 busId d00000
node-0:273706:279073 [0] NCCL INFO Ring 12 : 1 -> 0 -> 1 comm 0x1b6ce640 nRanks 02 busId d00000
node-0:273706:279073 [0] NCCL INFO Ring 13 : 1 -> 0 -> 1 comm 0x1b6ce640 nRanks 02 busId d00000
node-0:273706:279073 [0] NCCL INFO Ring 14 : 1 -> 0 -> 1 comm 0x1b6ce640 nRanks 02 busId d00000
node-0:273706:279073 [0] NCCL INFO Ring 15 : 1 -> 0 -> 1 comm 0x1b6ce640 nRanks 02 busId d00000
node-0:273706:279073 [0] NCCL INFO Ring 16 : 1 -> 0 -> 1 comm 0x1b6ce640 nRanks 02 busId d00000
node-0:273706:279073 [0] NCCL INFO Ring 17 : 1 -> 0 -> 1 comm 0x1b6ce640 nRanks 02 busId d00000
node-0:273706:279073 [0] NCCL INFO Ring 18 : 1 -> 0 -> 1 comm 0x1b6ce640 nRanks 02 busId d00000
node-0:273706:279073 [0] NCCL INFO Ring 19 : 1 -> 0 -> 1 comm 0x1b6ce640 nRanks 02 busId d00000
node-0:273706:279073 [0] NCCL INFO Ring 20 : 1 -> 0 -> 1 comm 0x1b6ce640 nRanks 02 busId d00000
node-0:273706:279073 [0] NCCL INFO Ring 

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


  self._set_arrayXarray(i, j, x)


In [None]:
# Show the metrics of the most recent `learn.predict` result stored in `o`.
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,77.9066,41.1329,28.0339,15.6834,77.9066,79.2724,80.8556,82.5391,69.5293,71.1523,74.2569,80.0041,69.5293,75.694,78.3396,80.6496,88.1427,92.5086,92.5086,0.0094,3882.6119,178.331,0.892


In [None]:
# Same display as the previous cell. It always reflects the latest `o`.
# NOTE(review): the differing saved outputs presumably come from earlier runs.
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,36.6845,21.9946,15.8711,9.7873,36.6845,30.3162,29.881,30.9404,28.4921,26.0447,25.7811,27.6033,28.4921,28.7257,30.5264,33.2445,32.211,42.3966,42.3966,0.0482,3912.2858,200.329,1.002


In [None]:
# Same display as the previous cell. It always reflects the latest `o`.
# NOTE(review): the differing saved outputs presumably come from earlier runs.
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,29.9428,20.2432,15.5844,10.1736,29.9428,30.3275,31.7684,34.1689,24.7682,26.8411,29.3415,34.5346,24.7682,27.3513,29.3419,31.9797,40.3797,54.7225,54.7225,0.0492,962.8372,184.367,0.922


In [None]:
# Same display as the previous cell. It always reflects the latest `o`.
# NOTE(review): the differing saved outputs presumably come from earlier runs.
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,28.7412,19.2727,14.8497,9.725,28.7412,29.182,30.6493,33.0577,23.6155,25.3211,27.6706,32.679,23.6155,26.0438,27.989,30.5923,39.391,53.9873,53.9873,0.0489,977.6731,181.569,0.908


In [None]:
# Predict on the test split of `test_block` (defined outside this view).
# NOTE(review): `learn` was built with compute_metrics=metric, whose filterer is
# `block.train.data_lbl_filterer`; confirm that it is appropriate for this split.
o = learn.predict(test_block.test.dset)

  0%|          | 0/1562 [00:00<?, ?it/s]

  0%|          | 0/12292 [00:00<?, ?it/s]

  self._set_arrayXarray(i, j, x)


In [None]:
# Show the metrics of the `test_block` prediction just stored in `o`.
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,25.688,17.4213,13.4276,8.7873,25.688,26.108,27.3521,29.4336,21.5048,23.2458,25.352,29.8492,21.5048,23.7806,25.4808,27.7529,34.8096,47.4142,47.4142,0.0502,775.4204,228.927,1.145
