# learner

In [None]:
#| default_exp learner

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from tqdm.auto import tqdm
from packaging import version
import torch, re, math, numpy as np, os, time, datasets, pickle
from typing import Any, Tuple, Optional, Sequence, Union, Dict, List, NamedTuple
from transformers import AutoTokenizer, BatchEncoding, Seq2SeqTrainer, Seq2SeqTrainingArguments

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, RandomSampler

from torch.nn.parallel import DataParallel
from torch.nn.parallel._functions import Scatter
from torch.nn.parallel.scatter_gather import _is_namedtuple

from xcai.core import *
from xcai.data import *
from xcai.representation.search import *
from xcai.generation.trie import *
from xcai.generation.generate import *
from xcai.clustering.cluster import *
from xcai.transform import PadFeatTfm

from fastcore.utils import *
from fastcore.meta import *
from fastcore.dispatch import *

comet_ml is installed but `COMET_API_KEY` is not set.


In [None]:
#| export
from transformers.trainer_pt_utils import (
    find_batch_size, 
    nested_concat, nested_numpify, 
    IterableDatasetShard, 
    get_dataloader_sampler, 
    get_model_param_count,
    LengthGroupedSampler
)
from transformers.trainer_utils import has_length, denumpify_detensorize, speed_metrics, TrainOutput, HPSearchBackend, seed_worker
from transformers.trainer_callback import TrainerState
from transformers.trainer import _is_peft_model
from transformers.modeling_utils import unwrap_model
from transformers.utils import is_sagemaker_mp_enabled, is_accelerate_available, is_torch_tpu_available, logging, is_datasets_available
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow

from transformers.integrations import hp_params
from transformers.integrations.tpu import tpu_spmd_dataloader
from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available

if is_accelerate_available():
    from accelerate import Accelerator, skip_first_batches
    from accelerate import __version__ as accelerate_version
    from accelerate.utils import (
        DistributedDataParallelKwargs,
        DistributedType,
        GradientAccumulationPlugin,
        load_fsdp_model,
        load_fsdp_optimizer,
        save_fsdp_model,
        save_fsdp_optimizer,
    )

    DATA_SAMPLERS = [RandomSampler]
    if version.parse(accelerate_version) > version.parse("0.23.0"):
        from accelerate.data_loader import SeedableRandomSampler

        DATA_SAMPLERS += [SeedableRandomSampler]

    if is_deepspeed_available():
        from accelerate.utils import DeepSpeedSchedulerWrapper

if is_accelerate_available("0.28.0"):
    from accelerate.utils import DataLoaderConfiguration

# File names for training artefacts (arguments, trainer state, optimizer,
# scheduler, grad scaler, FSDP model). None of them is referenced in the
# visible part of this file; presumably they mirror the constants of the same
# name in `transformers.trainer` — TODO confirm.
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
OPTIMIZER_NAME_BIN = "optimizer.bin"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"
FSDP_MODEL_NAME = "pytorch_model_fsdp"

# Module-level logger obtained through the `transformers.utils.logging` helper.
logger = logging.get_logger(__name__)

In [None]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [None]:
#| hide
from xcai.block import *
from xcai.models.PPP0XX import *
from xcai.metrics import *

In [None]:
# Expose only GPUs 0 and 1 to this process (two devices, so the multi-GPU
# examples below have something to scatter over).
# NOTE(review): presumably needs to run before CUDA is first initialised to take effect — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"

## Setup

In [None]:
#| hide
# Build the dataset block from the config found under the given directory.
# NOTE(review): machine-specific absolute path; the meaning of the 'train' argument is not visible here.
block = XCBlock.from_cfg('/home/aiscuser/scratch/datasets', 'train')

  self._set_arrayXarray(i, j, x)


In [None]:
#| hide
# Draw a single batch from the training split (argument 11 is presumably the batch size — TODO confirm).
batch = block.train.one_batch(11)

In [None]:
#| hide
# Instantiate BT0002 from the `bert-base-uncased` checkpoint. The cell output
# reports `loss_fn.o` as newly initialised (not present in the checkpoint).
# `tn_targ` / `ig_tok` are model-specific options whose semantics are not visible here.
m = BT0002.from_pretrained('bert-base-uncased', tn_targ=10_000, ig_tok=0)

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`
Some weights of BT0002 were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['loss_fn.o']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
#| hide
batch.keys()

dict_keys(['lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_token_type_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_token_type_ids', 'data_attention_mask'])

In [None]:
#| hide
# Reduce the raw batch to what the model consumes. Judging by the key listings
# printed before/after this cell, the `*_identifier` and `*_input_text` entries
# are dropped while `lbl2data_idx` (requested via `m_args`) is kept.
b = prepare_batch(m, batch, m_args='lbl2data_idx')

In [None]:
#| hide
b.keys()

dict_keys(['lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_token_type_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'data_input_ids', 'data_token_type_ids', 'data_attention_mask'])

In [None]:
#| hide
# Move both the model and the prepared batch to the GPU before the forward pass.
m = m.to('cuda')
b = b.to('cuda')

In [None]:
#| hide
# Forward pass with the batch entries as keyword arguments; the result exposes `.loss` (inspected in the next cell).
o = m(**b)

In [None]:
#| hide
o.loss

tensor(14.9452, device='cuda:0', grad_fn=<SumBackward0>)

In [None]:
#| export
# Location of the pre-processed (pickled) dataset block.
# NOTE(review): this cell is marked `#| export`, so these machine-specific
# absolute paths end up in the generated module; they also differ from the
# `/home/aiscuser/...` paths used by the other cells — confirm this is intended.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
pkl_file = f'{pkl_dir}/processed/wikiseealso_data_distilbert-base-uncased_xcnlg_ngame.pkl'

In [None]:
# Replace `block` with the pickled dataset block.
# NOTE(review): `pickle.load` can execute arbitrary code — only load files from a trusted source.
with open(pkl_file, 'rb') as file: block = pickle.load(file)

## DataParallel

In [None]:
#| export
def scatter(inputs, target_gpus, chunk_sizes=None, dim=0):
    """Split `inputs` into one piece per entry of `target_gpus`.

    Tensors are split with `Scatter.apply(target_gpus, chunk_sizes, dim, tensor)`.
    Non-empty tuples, lists, dicts and namedtuples are traversed recursively and
    rebuilt once per target. Every other object (including empty containers) is
    replicated by reference, once per target.

    Returns a sequence with one entry per target: a list for containers and
    plain objects, or whatever `Scatter.apply` returns for a bare tensor.
    """
    def _split(item):
        if isinstance(item, torch.Tensor):
            return Scatter.apply(target_gpus, chunk_sizes, dim, item)
        if _is_namedtuple(item):
            return [type(item)(*fields) for fields in zip(*map(_split, item))]

        non_empty = isinstance(item, (tuple, list, dict)) and len(item) > 0
        if non_empty and isinstance(item, tuple):
            return list(zip(*map(_split, item)))
        if non_empty and isinstance(item, list):
            return [list(group) for group in zip(*map(_split, item))]
        if non_empty and isinstance(item, dict):
            return [type(item)(pairs) for pairs in zip(*map(_split, item.items()))]

        # Anything else is shared as-is between all targets.
        return [item for _ in target_gpus]

    try:
        return _split(inputs)
    finally:
        # The recursive closure refers to itself; drop the name so the
        # resulting reference cycle does not outlive this call.
        _split = None
    
def scatter_kwargs(
    inputs: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]],
    target_gpus: Sequence[Union[int, torch.device]],
    chunk_sizes: Optional[Sequence[int]]=None,
    dim: int = 0,
) -> Tuple[List[Any], List[Dict[str, Any]]]:
    """Scatter positional and keyword arguments and make both results equally long.

    Args:
        inputs: positional arguments; scattered with `scatter` when non-empty.
        kwargs: keyword arguments; scattered with `scatter` when non-empty.
        target_gpus: devices to scatter over.
        chunk_sizes: optional per-device chunk sizes forwarded to `scatter`.
        dim: dimension along which tensors are split.

    Returns:
        `(scattered_inputs, scattered_kwargs)` as two lists of equal length;
        the shorter one is padded with empty tuples / empty dicts.
    """
    scattered_inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
    scattered_kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []

    n_inputs, n_kwargs = len(scattered_inputs), len(scattered_kwargs)
    if n_inputs < n_kwargs:
        scattered_inputs.extend(() for _ in range(n_kwargs - n_inputs))
    elif n_kwargs < n_inputs:
        # Compare against the number of *scattered* input shards. Comparing
        # against `len(inputs)` (the number of positional arguments) could
        # leave the kwargs list shorter than the inputs list.
        scattered_kwargs.extend({} for _ in range(n_inputs - n_kwargs))
    return scattered_inputs, scattered_kwargs
    

In [None]:
#| export
class XCDataParallel(DataParallel):
    """`DataParallel` whose `scatter` splits keyword features group by group.

    Keyword names are grouped by the text before their first underscore.
    The `data` group is scattered without explicit chunk sizes. Every other
    group that has a `<group>_data2ptr` entry is scattered with per-device chunk
    sizes obtained by summing that pointer on each device.
    """

    @delegates(DataParallel.__init__)
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _get_feat_name(self, x:Optional[Dict[str, Any]]):
        """Return the distinct key prefixes (text before the first `_`) of `x`."""
        prefixes = [key.split('_', maxsplit=1)[0] for key in x]
        return list(set(prefixes))

    def _extract_feat(self, x:Optional[Dict[str, Any]], prefix:str):
        """Select the entries of `x` that belong to the feature group `prefix`.

        A key is selected when it starts with `<prefix>_` and has no `2ptr`
        after that, or when it ends with `_<prefix>2ptr`.
        """
        own_pattern, ptr_pattern = f'^{prefix}_(?!.*2ptr)', f'^.*_{prefix}2ptr$'
        selected = {}
        for key, value in x.items():
            if re.match(own_pattern, key) or re.match(ptr_pattern, key):
                selected[key] = value
        return selected

    def scatter(
        self,
        inputs: Tuple[Any, ...],
        kwargs: Optional[Dict[str, Any]],
        device_ids: Sequence[Union[int, torch.device]],
    ) ->Any:
        """Scatter `kwargs` over `device_ids`; positional `inputs` must be empty.

        Returns a pair of tuples `(scattered_inputs, scattered_kwargs)` with one
        entry per shard. Raises `ValueError` when `inputs` is not empty.
        """
        if len(inputs): raise ValueError('`inputs` should be empty.')
        prefixes = self._get_feat_name(kwargs)

        # The `data` group fixes how the batch is divided between devices.
        scattered_inputs, scattered_kwargs = scatter_kwargs(
            inputs, self._extract_feat(kwargs, 'data'), device_ids, None, dim=self.dim
        )
        prefixes.remove('data')

        for prefix in prefixes:
            ptr_key = f'{prefix}_data2ptr'
            if ptr_key not in scattered_kwargs[0] or scattered_kwargs[0][ptr_key] is None:
                continue

            # Rows of this group owned by each shard = sum of its pointer;
            # devices that received no shard get a chunk size of zero.
            chunk_sizes = [shard[ptr_key].sum().item() for shard in scattered_kwargs]
            chunk_sizes += [0] * (len(device_ids) - len(chunk_sizes))

            _, group_shards = scatter_kwargs(
                inputs, self._extract_feat(kwargs, prefix), device_ids, chunk_sizes, dim=self.dim
            )
            for target, extra in zip(scattered_kwargs, group_shards):
                target.update(extra)

        return tuple(scattered_inputs), tuple(scattered_kwargs)
        

### Example

In [None]:
import pickle
from transformers import AutoTokenizer, BatchEncoding

# Load the DistilBERT tokenizer (not referenced again in the visible part of this notebook).
tokz = AutoTokenizer.from_pretrained('distilbert-base-uncased')

In [None]:
# Directory holding the pickled example blocks.
# NOTE(review): rebinds `pkl_dir` (also assigned in an exported cell above) and
# ends with '/', so the f-string paths built from it below contain '//'.
data_dir = '/home/aiscuser/scratch/datasets'
pkl_dir = f'{data_dir}/processed/'

In [None]:
# Load the `_rm_radga` example block.
# NOTE(review): the next cell rebinds `block`, so run only one of the two loads.
# `pickle.load` can execute arbitrary code — use trusted files only.
with open(f'{pkl_dir}/wikiseealso_data-metas_distilbert-base-uncased_rm_radga.pkl', 'rb') as file: 
    block = pickle.load(file)

In [None]:
# Load the `_xcnlg_radga` example block, replacing whatever `block` held before.
# `pickle.load` can execute arbitrary code — use trusted files only.
with open(f'{pkl_dir}/wikiseealso_data-metas_distilbert-base-uncased_xcnlg_radga.pkl', 'rb') as file: 
    block = pickle.load(file)

In [None]:
# Draw a batch of 4 and keep only its tensor-valued entries.
b = block.train.one_batch(4)
bb = BatchEncoding(dict(filter(lambda item: isinstance(item[1], torch.Tensor), b.items())))

In [None]:
# Show the shape of every pointer tensor in the batch.
for k,v in bb.items():
    if 'ptr' not in k: continue
    print(k, ': ', v.shape)

plbl2data_data2ptr :  torch.Size([4])
lbl2data_data2ptr :  torch.Size([4])
pcat2data_data2ptr :  torch.Size([4])
cat2data_data2ptr :  torch.Size([4])
pcat2lbl2data_data2ptr :  torch.Size([4])
cat2lbl2data_data2ptr :  torch.Size([4])
phlk2data_data2ptr :  torch.Size([4])
hlk2data_data2ptr :  torch.Size([4])
hlk2lbl2data_data2ptr :  torch.Size([4])
hlk2lbl2data_plbl2data2ptr :  torch.Size([6])


In [None]:
#| hide
class MyModel(nn.Module):
    """Debug module: prints each tensor keyword argument with its device and returns the kwargs."""

    def forward(self, **kwargs):
        tensor_items = ((name, value) for name, value in kwargs.items() if isinstance(value, torch.Tensor))
        for name, value in tensor_items:
            print(name, ': ', value, ', ', value.device)
        return kwargs
        

In [None]:
#| hide
# Wrap the debug module so that calling it goes through `XCDataParallel.scatter`.
m = XCDataParallel(module=MyModel())

In [None]:
#| hide
# Run the wrapped module on the tensor batch. Each replica prints from its own
# thread, which is why the captured output below is interleaved.
o = m(**bb)

plbl2data_data2ptrplbl2data_data2ptr :   :  tensor([2, 1], device='cuda:0')tensor([1, 2], device='cuda:1')  , ,   cuda:0cuda:1

lbl2data_data2ptrlbl2data_data2ptr  : :   tensor([1, 2], device='cuda:1')tensor([2, 1], device='cuda:0') ,   , cuda:1 
cuda:0pcat2data_data2ptr
 pcat2data_data2ptr:   : tensor([14,  6], device='cuda:1')  tensor([13,  6], device='cuda:0'),   cuda:1,  
cat2data_data2ptrcuda:0
 : cat2data_data2ptr  :  tensor([1, 1], device='cuda:1') tensor([1, 1], device='cuda:0') , ,   cuda:0cuda:1

pcat2lbl2data_data2ptrpcat2lbl2data_data2ptr :  :   tensor([4, 7], device='cuda:1')tensor([0, 4], device='cuda:0') ,   , cuda:0 
cat2lbl2data_data2ptrcuda:1
 cat2lbl2data_data2ptr:   :  tensor([0, 1], device='cuda:0') tensor([1, 1], device='cuda:1'),   , cuda:0 
cuda:1phlk2data_data2ptr
 phlk2data_data2ptr:   : tensor([16, 18], device='cuda:0')  , tensor([15, 40], device='cuda:1') cuda:0 
, hlk2data_data2ptr  cuda:1: 
hlk2data_data2ptr  tensor([3, 3], device='cuda:0'):   , tensor([3,

In [None]:
# Every returned tensor must equal the corresponding input tensor once moved back to the CPU.
np.all([torch.all(bb[name] == o[name].to('cpu')) for name in o])

True

## Learner

In [None]:
#| export
class XCEvalLoopOutput(NamedTuple):
    """Result record declared as the return type of `XCLearner.evaluation_loop`.

    The field names match the keys produced by `XCLearner.prediction_step`
    (`pred_idx`, `pred_ptr`, `pred_score`, `targ_idx`, `targ_ptr`).
    """
    # Predicted indices, their per-example counts ("pointer") and scores.
    pred_idx: Union[np.ndarray, Tuple[np.ndarray]]
    pred_ptr: Union[np.ndarray, Tuple[np.ndarray]]
    pred_score: Union[np.ndarray, Tuple[np.ndarray]]
    # Target indices and their pointer; optional.
    targ_idx: Optional[Union[np.ndarray, Tuple[np.ndarray]]]
    targ_ptr: Optional[Union[np.ndarray, Tuple[np.ndarray]]]
    # presumably the generation-only / representation-only predictions — TODO confirm against `evaluation_loop`
    gen_output: Optional[Dict]
    repr_output: Optional[Dict]
    # Metric name -> value, and the number of evaluated samples.
    metrics: Optional[Dict[str, float]]
    num_samples: Optional[int]

class XCPredictionOutput(NamedTuple):
    """Result record returned by `XCLearner.predict`.

    `predict` fills every field from the same-named field of the evaluation
    loop's output; target fields are not carried over.
    """
    # Predicted indices, their per-example counts ("pointer") and scores.
    pred_idx: Union[np.ndarray, Tuple[np.ndarray]]
    pred_ptr: Union[np.ndarray, Tuple[np.ndarray]]
    pred_score: Optional[Union[np.ndarray, Tuple[np.ndarray]]]
    # presumably the generation-only / representation-only predictions — TODO confirm against `evaluation_loop`
    gen_output: Optional[Dict]
    repr_output: Optional[Dict]
    # Metric name -> value (speed metrics are added by `predict`), and the number of samples.
    metrics: Optional[Dict[str, float]]
    num_samples: Optional[int]
    

In [None]:
#| export
class XCLearningArguments(Seq2SeqTrainingArguments):
    """`Seq2SeqTrainingArguments` extended with the options read by `XCLearner`.

    Every extra keyword below is stored on the instance under its own name
    (the three clamped values are listed at the end); all remaining keyword
    arguments are forwarded to `Seq2SeqTrainingArguments.__init__`.

    How the visible part of this file uses them:

    * `generation_eos_token`, `generation_num_beams`, `generation_length_penalty`,
      `generation_max_info`: passed to `TrieBeamSearch` by `XCLearner.__init__`.
    * `representation_search_type`: `'BRUTEFORCE'` selects `BruteForceSearch`,
      any other value selects `IndexSearch`.
    * `index_space`, `index_efc`, `index_m`, `index_efs`, `index_num_threads`,
      `representation_num_beams`: passed to `IndexSearch`.
    * `representation_attribute`, `data_augmentation_attribute`: name of the
      model-output attribute used as the query representation.
    * `predict_with_generation`, `predict_with_representation`,
      `predict_with_augmentation`: defaults for which prediction modes run; a
      `use_*` attribute on the model takes precedence.
    * `output_concatenation_weight`: multiplies the exponentiated generation
      scores when both outputs are merged.
    * `target_indices_key`, `target_pointer_key`: input keys read as targets.
    * `data_aug_meta_name`, `augmentation_num_beams`,
      `use_augmentation_index_representation`: control the augmentation index.
    * `use_distributional_representation`: apply `log_softmax` to indexed
      representations and `softmax` to query representations.
    * `use_encoder_parallel`: when true the model is not wrapped in
      `XCDataParallel`.

    The clustering options and the remaining attributes are only stored here;
    their consumers are outside the visible part of this file.

    Clamping: `minimum_clusters` and `minimum_cluster_size` are raised to at
    least 1. `maximum_clusters` is raised to at least the clamped
    `minimum_clusters`, and equals it when `None` is given.
    """

    @delegates(Seq2SeqTrainingArguments.__init__)
    def __init__(self, 
                 use_encoder_parallel:Optional[bool]=False,
                 generation_length_penalty:Optional[float]=1.0,
                 generation_eos_token:Optional[int]=102,
                 generation_num_beams:Optional[int]=5,
                 generation_max_info:Optional[int]=None,
                 representation_accumulation_steps:Optional[int]=None,
                 representation_attribute:Optional[str]='data_repr',
                 representation_num_beams:Optional[int]=5,
                 representation_search_type:Optional[str]='INDEX',
                 index_space:Optional[str]='cosine', 
                 index_efc:Optional[int]=300, 
                 index_m:Optional[int]=100, 
                 index_efs:Optional[int]=300,
                 index_num_threads:Optional[int]=84,
                 predict_with_generation:Optional[bool]=False,
                 predict_with_representation:Optional[bool]=False,
                 output_concatenation_weight:Optional[float]=1.0,
                 group_by_cluster:Optional[bool]=False,
                 num_clustering_warmup_epochs:Optional[int]=None,
                 num_cluster_update_epochs:Optional[int]=1,
                 num_cluster_size_update_epochs:Optional[int]=1,
                 clustering_type:Optional[str]='EXPO',
                 minimum_clusters:Optional[int]=3,
                 maximum_clusters:Optional[int]=None,
                 minimum_cluster_size:Optional[int]=1,
                 maximum_cluster_size:Optional[int]=None,
                 clustering_devices:Optional[List]=None,
                 target_indices_key:Optional[str]='lbl2data_idx',
                 target_pointer_key:Optional[str]='lbl2data_data2ptr',
                 data_aug_meta_name:Optional[str]=None,
                 augmentation_num_beams:Optional[int]=3,
                 predict_with_augmentation:Optional[bool]=False,
                 use_augmentation_index_representation:Optional[bool]=False,
                 metadata_representation_attribute:Optional[str]='data_repr',
                 data_augmentation_attribute:Optional[str]='data_repr',
                 use_distributional_representation:Optional[bool]=False,
                 **kwargs):
        super().__init__(**kwargs)
        # `store_attr` runs after the parent initialiser, so these values win
        # over parent fields of the same name.
        store_attr('generation_num_beams,generation_length_penalty,generation_max_info,generation_eos_token')
        store_attr('representation_accumulation_steps,representation_attribute,representation_num_beams,representation_search_type')
        store_attr('index_space,index_efc,index_m,index_efs,index_num_threads')
        store_attr('predict_with_generation,predict_with_representation,output_concatenation_weight')
        store_attr('group_by_cluster,num_cluster_update_epochs,num_cluster_size_update_epochs,num_clustering_warmup_epochs')
        store_attr('clustering_devices,clustering_type,maximum_cluster_size')
        store_attr('target_indices_key,target_pointer_key')
        store_attr('use_encoder_parallel')
        store_attr('data_aug_meta_name,augmentation_num_beams,predict_with_augmentation')
        store_attr('use_augmentation_index_representation,metadata_representation_attribute,data_augmentation_attribute')
        store_attr('use_distributional_representation')
        self.minimum_clusters = max(1, minimum_clusters)
        # Clamp against the sanitised minimum (not the raw argument) so that
        # `maximum_clusters >= minimum_clusters` holds even for `minimum_clusters < 1`.
        self.maximum_clusters = (
            self.minimum_clusters if maximum_clusters is None
            else max(self.minimum_clusters, maximum_clusters)
        )
        self.minimum_cluster_size = max(1, minimum_cluster_size)
        

### `XCLearner`

In [None]:
#| export
class XCLearner(Seq2SeqTrainer):
    """`Seq2SeqTrainer` that holds a trie beam search and a representation index.

    The constructor builds `self.tbs` (`TrieBeamSearch`) and `self.idxs`
    (`BruteForceSearch` or `IndexSearch`, chosen by
    `args.representation_search_type`). The augmentation index (`aug_idxs`,
    `aug_info`) starts out empty.
    """

    @delegates(Seq2SeqTrainer.__init__)
    def __init__(self, 
                 trie:Optional[Trie]=None, 
                 **kwargs):
        super().__init__(**kwargs)
        args = self.args

        self.tbs = TrieBeamSearch(
            trie, args.generation_eos_token,
            n_bm=args.generation_num_beams,
            len_penalty=args.generation_length_penalty,
            max_info=args.generation_max_info,
            **kwargs,
        )

        if args.representation_search_type == 'BRUTEFORCE':
            self.idxs = BruteForceSearch(n_bm=args.representation_num_beams)
        else:
            self.idxs = IndexSearch(
                space=args.index_space, efc=args.index_efc, m=args.index_m,
                efs=args.index_efs, n_bm=args.representation_num_beams,
                n_threads=args.index_num_threads,
            )

        self.aug_idxs, self.aug_info = None, None
        self.aug_pad = PadFeatTfm(pad_tok=self.model.config.pad_token_id, prefix="meta")

    def _wrap_model(self, model, training=True, dataloader=None):
        """Wrap `model` in `XCDataParallel` when several GPUs are used.

        The model is returned unchanged when it is already wrapped, when at
        most one GPU is configured, when its `encoder` is an `nn.DataParallel`,
        or when `args.use_encoder_parallel` is set.
        """
        if unwrap_model(model) is not model or self.args.n_gpu <= 1:
            return model

        encoder_is_parallel = hasattr(model, 'encoder') and isinstance(model.encoder, nn.DataParallel)
        if encoder_is_parallel or self.args.use_encoder_parallel:
            return model
        return XCDataParallel(module=model)

    def _fill_gen_kwargs(self, gen_kwargs):
        """Return a copy of `gen_kwargs` with unset (missing or `None`) options taken from `self.args`."""
        filled = gen_kwargs.copy()
        fallbacks = (
            ("length_penalty", self.args.generation_length_penalty),
            ("gen_num_beams", self.args.generation_num_beams),
            ("repr_num_beams", self.args.representation_num_beams),
            ("aug_num_beams", self.args.augmentation_num_beams),
        )
        for option, fallback in fallbacks:
            if filled.get(option) is None and fallback is not None:
                filled[option] = fallback
        return filled

    def evaluate(self, eval_dataset:Optional[Dataset]=None, ignore_keys:Optional[List[str]]=None, 
             metric_key_prefix:str="eval", **gen_kwargs):
        """Store the resolved generation options, then delegate to the parent `evaluate`."""
        resolved = self._fill_gen_kwargs(gen_kwargs)
        self.gather_function = self.accelerator.gather
        self._gen_kwargs = resolved
        return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

    def predict(self, test_dataset: Dataset, ignore_keys:Optional[List[str]]=None, 
            metric_key_prefix:str="test", **gen_kwargs):
        """Run the evaluation loop on `test_dataset` and return an `XCPredictionOutput`.

        Speed metrics are added to the returned metrics and the `on_predict`
        callback is fired.
        """
        resolved = self._fill_gen_kwargs(gen_kwargs)
        self.gather_function = self.accelerator.gather
        self._gen_kwargs = resolved
        self._memory_tracker.start()

        test_dataloader = self.get_test_dataloader(test_dataset)
        start_time = time.time()

        output = self.evaluation_loop(
            test_dataloader, description="Prediction",
            ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix,
        )
        total_batch_size = self.args.eval_batch_size * self.args.world_size

        # Exclude JIT compilation time from the throughput numbers.
        jit_key = f"{metric_key_prefix}_jit_compilation_time"
        if jit_key in output.metrics:
            start_time += output.metrics[jit_key]

        n_steps = math.ceil(output.num_samples / total_batch_size)
        output.metrics.update(
            speed_metrics(metric_key_prefix, start_time, num_samples=output.num_samples, num_steps=n_steps)
        )

        self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics)
        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return XCPredictionOutput(
            pred_idx=output.pred_idx,
            pred_ptr=output.pred_ptr,
            pred_score=output.pred_score,
            gen_output=output.gen_output,
            repr_output=output.repr_output,
            metrics=output.metrics,
            num_samples=output.num_samples,
        )

    def _gather_host_output(self, output, host_output):
        """Pad and gather `output` across processes, then append it to `host_output`."""
        if output is None:
            return host_output
        padded = self.accelerator.pad_across_processes(output, dim=1, pad_index=-100)
        gathered = self.gather_function((padded))
        if host_output is None:
            return gathered
        return nested_concat(host_output, gathered, padding_index=-100)

    def _gather_all_output(self, host_output, all_output, to_cpu=True):
        """Append `host_output` (moved to the CPU first if it is a tensor and `to_cpu`) to `all_output`."""
        if host_output is None:
            return all_output
        if isinstance(host_output, torch.Tensor) and to_cpu:
            host_output = host_output.cpu()
        if all_output is None:
            return host_output
        return nested_concat(all_output, host_output, padding_index=-100)
            
            

In [None]:
#| export
@patch
def _build_aug_index(self:XCLearner, dataset:Optional[Dataset]=None):
    """Build `self.aug_idxs` from the augmentation metadata of a dataset.

    The dataset used is `self.train_dataset` if set, else `self.eval_dataset`
    if set, else the `dataset` argument. Nothing happens unless
    `args.data_aug_meta_name` is set and `<name>_meta` is present in the
    dataset's `meta`.
    """
    if self.eval_dataset is not None: dataset = self.eval_dataset
    if self.train_dataset is not None: dataset = self.train_dataset

    meta_prefix = self.args.data_aug_meta_name
    aug_meta_name = None if meta_prefix is None else f'{meta_prefix}_meta'

    if dataset is None or dataset.meta is None or aug_meta_name is None or aug_meta_name not in dataset.meta:
        return

    args = self.args
    self.aug_idxs = IndexSearch(
        space=args.index_space, efc=args.index_efc, m=args.index_m,
        efs=args.index_efs, n_bm=args.representation_num_beams,
        n_threads=args.index_num_threads,
    )
    self.aug_info = dataset.meta[aug_meta_name].meta_info

    # Encode every metadata item, then index the resulting representations.
    aug_loader = self.get_test_dataloader(MainXCDataset(self.aug_info))
    aug_repr = self.get_meta_representation(aug_loader, to_cpu=isinstance(self.aug_idxs, IndexSearch))
    if args.use_distributional_representation:
        aug_repr = F.log_softmax(aug_repr, dim=-1)

    self.aug_idxs.build(aug_repr)

@patch
def _build_lbl_index(self:XCLearner, dataset:Optional[Dataset]=None):
    """Build `self.idxs` from the label representations of a dataset.

    The dataset used is `self.train_dataset` if set, else `self.eval_dataset`
    if set, else the `dataset` argument. When `<args.data_aug_meta_name>_meta`
    is present in the dataset's `meta`, the label dataset is first combined
    with that metadata. Raises `ValueError` when no dataset is available.
    """
    if self.eval_dataset is not None: dataset = self.eval_dataset
    if self.train_dataset is not None: dataset = self.train_dataset
    if dataset is None: raise ValueError('Failed to build `self.idxs`')

    lbl_dset = dataset.lbl_dset

    meta_prefix = self.args.data_aug_meta_name
    meta_name = None if meta_prefix is None else f'{meta_prefix}_meta'
    if meta_name is not None and dataset.meta is not None and meta_name in dataset.meta:
        source = dataset.meta[meta_name]
        prefix, lbl_meta, meta_info = source.prefix, source.lbl_meta, source.meta_info
        # The label-side metadata is passed for both matrix arguments.
        lbl_aug = MetaXCDataset(prefix, lbl_meta, lbl_meta, meta_info,
                                n_data_meta_samples=self.args.augmentation_num_beams)
        lbl_dset = XCDataset(lbl_dset, **{meta_name: lbl_aug})

    lbl_loader = self.get_test_dataloader(lbl_dset)
    lbl_repr = self.get_representation(lbl_loader, to_cpu=isinstance(self.idxs, IndexSearch))
    if self.args.use_distributional_representation:
        lbl_repr = F.log_softmax(lbl_repr, dim=-1)

    self.idxs.build(lbl_repr)
        

In [None]:
#| export
@patch
def generation_output(
    self:XCLearner,
    model:nn.Module,
    inputs:Dict[str, Union[torch.Tensor, Any]],
    **kwargs
):
    """Predict with the trie beam search.

    `gen_num_beams` and `length_penalty` may be given as keyword arguments;
    a missing or `None` value falls back to the matching `self.args` setting.
    The search is run on `self.model`, not on the `model` argument.
    Returns a dict with `pred_idx`, `pred_score` and `pred_ptr`.
    """
    inputs = self._prepare_inputs(inputs)

    n_bm = kwargs.get("gen_num_beams")
    if n_bm is None: n_bm = self.args.generation_num_beams

    len_penalty = kwargs.get("length_penalty")
    if len_penalty is None: len_penalty = self.args.generation_length_penalty

    with torch.no_grad():
        found = self.tbs.proc(self.model, inputs.copy(), n_bm=n_bm, len_penalty=len_penalty)

    return {
        'pred_idx': found['info2seq2data_idx'],
        'pred_score': found['info2seq2data_score'],
        'pred_ptr': found['info2seq2data_data2ptr'],
    }

@patch
def representation_output(
    self:XCLearner,
    model:nn.Module,
    inputs:Dict[str, Union[torch.Tensor, Any]],
    **kwargs
):
    """Predict by searching `self.idxs` with the model's query representation.

    The representation is the `args.representation_attribute` attribute of the
    model output (softmax-normalised when
    `args.use_distributional_representation` is set). `repr_num_beams` may be
    given as a keyword argument; a missing or `None` value falls back to
    `args.representation_num_beams`.
    Returns a dict with `pred_idx`, `pred_score` and `pred_ptr`.
    """
    inputs = self._prepare_inputs(inputs)

    n_bm = kwargs.get("repr_num_beams")
    if n_bm is None: n_bm = self.args.representation_num_beams

    with torch.no_grad():
        query = getattr(model(**inputs), self.args.representation_attribute)
        if self.args.use_distributional_representation:
            query = F.softmax(query, dim=-1)

    found = self.idxs.proc(query, n_bm=n_bm)

    return {
        'pred_idx': found['info2data_idx'],
        'pred_score': found['info2data_score'],
        'pred_ptr': found['info2data_data2ptr'],
    }

@patch
def augmentation_output(
    self:XCLearner,
    model:nn.Module,
    inputs:Dict[str, Union[torch.Tensor, Any]],
    **kwargs
):
    """Retrieve augmentation metadata for a batch from `self.aug_idxs`.

    The query is the `args.data_augmentation_attribute` attribute of the model
    output (softmax-normalised when `args.use_distributional_representation`
    is set). `aug_num_beams` may be given as a keyword argument; a missing or
    `None` value falls back to `args.augmentation_num_beams`.

    Returns:
        A dict keyed with the prefix `<args.data_aug_meta_name>2data_`. It
        always holds `attention_mask` and `data2ptr`. With
        `args.use_augmentation_index_representation` it also holds `meta_repr`
        (vectors read back from the index); otherwise it holds `idx` and
        `input_ids`.

    Raises:
        ValueError: if the augmentation index has not been built.
    """
    if self.aug_idxs is None: raise ValueError('Augmentation `aug_idx` is not initialized.')
        
    inputs = self._prepare_inputs(inputs)
    n_bm = kwargs.pop("aug_num_beams") if "aug_num_beams" in kwargs and kwargs["aug_num_beams"] is not None else self.args.augmentation_num_beams
    
    with torch.no_grad(): 
        o = getattr(model(**inputs), self.args.data_augmentation_attribute)
        if self.args.use_distributional_representation: o = F.softmax(o, dim=-1)
            
    o = self.aug_idxs.proc(o, n_bm=n_bm)
    
    retrieved = o['info2data_idx']
    aug_info = self.aug_pad({
        'meta_input_ids':[self.aug_info['input_ids'][i] for i in retrieved], 
        # Fix: take the mask from `attention_mask`; previously `input_ids` was
        # used here too, so token ids ended up in the attention mask.
        'meta_attention_mask':[self.aug_info['attention_mask'][i] for i in retrieved]
    })
    
    if self.args.use_augmentation_index_representation:
        meta_repr = torch.tensor(self.aug_idxs.index.get_items(retrieved))
        return {
            f'{self.args.data_aug_meta_name}2data_meta_repr': meta_repr,
            f'{self.args.data_aug_meta_name}2data_attention_mask': aug_info['meta_attention_mask'],
            f'{self.args.data_aug_meta_name}2data_data2ptr': o['info2data_data2ptr'],
        }
    else:
        return {
            f'{self.args.data_aug_meta_name}2data_idx':retrieved, 
            f'{self.args.data_aug_meta_name}2data_input_ids': aug_info['meta_input_ids'], 
            f'{self.args.data_aug_meta_name}2data_attention_mask': aug_info['meta_attention_mask'],
            f'{self.args.data_aug_meta_name}2data_data2ptr': o['info2data_data2ptr']
        }
    

In [None]:
#| export
@patch
def _perform_generation(self:XCLearner, model:nn.Module, predict_with_generation:Optional[bool]=None):
    model = unwrap_model(model)
    predict_with_generation = self.args.predict_with_generation if predict_with_generation is None else predict_with_generation
    return getattr(model,'use_generation') if hasattr(model,'use_generation') else predict_with_generation

@patch
def _perform_representation(self:XCLearner, model:nn.Module, predict_with_representation:Optional[bool]=None):
    model = unwrap_model(model)
    predict_with_representation = self.args.predict_with_representation if predict_with_representation is None else predict_with_representation
    return getattr(model,'use_representation') if hasattr(model,'use_representation') else predict_with_representation

@patch
def _perform_augmentation(self:XCLearner, model:nn.Module, predict_with_augmentation:Optional[bool]=None):
    model = unwrap_model(model)
    predict_with_augmentation = self.args.predict_with_augmentation if predict_with_augmentation is None else predict_with_augmentation
    return getattr(model,'use_augmentation') if hasattr(model,'use_augmentation') else predict_with_augmentation


In [None]:
#| export
@patch
def resize_pred(cls:XCLearner, t, n_t):
    max_n_t = n_t.max()
    xn_t = max_n_t.max()-n_t+1
    t_ptr = n_t.cumsum(dim=0)-1
    r_t = torch.ones((len(t),), dtype=xn_t.dtype, device=xn_t.device).scatter(0, t_ptr, xn_t)
    xt = t.repeat_interleave(r_t).view(len(n_t), -1)
    return xt

@patch
def output_mask(cls:XCLearner, n_t, l):
    max_n_t = n_t.max()
    xn_t = max_n_t.max()-n_t+1
    t_ptr = n_t.cumsum(dim=0)-1
    mask_ptr = t_ptr+torch.arange(len(t_ptr), device=t_ptr.device)+1
    mask = torch.ones((l+len(n_t),), dtype=mask_ptr.dtype, device=mask_ptr.device).scatter(0, mask_ptr, 0)
    r_mask = torch.ones((l+len(n_t),), dtype=mask_ptr.dtype, device=mask_ptr.device).scatter(0, mask_ptr, xn_t-1)
    mask = mask.repeat_interleave(r_mask).view(len(n_t), -1)
    return mask

@patch
def resize_output(cls:XCLearner, pred_idx, pred_score, pred_ptr):
    return cls.resize_pred(pred_idx, pred_ptr), cls.resize_pred(pred_score, pred_ptr), cls.output_mask(pred_ptr, len(pred_idx)), pred_ptr

@patch
def concatenate_output(cls:XCLearner, gen_o:Dict, repr_o:Dict):
    gen_o['pred_score'] = torch.exp(gen_o['pred_score'])*cls.args.output_concatenation_weight
    gen_o, repr_o = cls.resize_output(**gen_o), cls.resize_output(**repr_o)
    pred_idx, pred_score, mask = [torch.hstack([gen_o[i], repr_o[i].cpu()]).flatten() for i in range(3)]
    idx = torch.where(mask)[0]
    return {
        'pred_idx': pred_idx[idx],
        'pred_score': pred_score[idx],
        'pred_ptr': gen_o[3]+repr_o[3].cpu(),
    }
    

In [None]:
#| export
@patch
def prediction_step(
    self:XCLearner,
    model: nn.Module,
    inputs: Dict[str, Union[torch.Tensor, Any]],
    prediction_loss_only: bool,
    predict_with_generation: bool,
    predict_with_representation: bool,
    predict_with_augmentation:Optional[bool]=None,
    ignore_keys: Optional[List[str]] = None,
    **kwargs,
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
    with torch.no_grad():
        with self.compute_loss_context_manager(): outputs = model(**inputs)
        loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
    prediction_loss_only = self.args.prediction_loss_only if prediction_loss_only is None else prediction_loss_only
    if prediction_loss_only: return loss, {}
    
    if self._perform_augmentation(model, predict_with_augmentation): 
        aug_inputs = self.augmentation_output(model, inputs, **kwargs)
        inputs.update(aug_inputs)
        
    output, gen_o, repr_o = None, None, None
    if self._perform_generation(model, predict_with_generation): gen_o = self.generation_output(model, inputs, **kwargs)
    if self._perform_representation(model, predict_with_representation): repr_o = self.representation_output(model, inputs, **kwargs)
    
    if gen_o is not None and repr_o is not None:
        output = {f'{k}_gen':v for k,v in gen_o.items()}
        output.update({f'{k}_repr':v for k,v in repr_o.items()})
        output.update(self.concatenate_output(gen_o, repr_o))
    else:
        output = gen_o if repr_o is None else repr_o
        
    labels = {'targ_idx':inputs[self.args.target_indices_key], 'targ_ptr':inputs[self.args.target_pointer_key]} if self.args.target_indices_key in inputs else None
    if labels is not None: output.update(labels)
    
    return loss, output
    

In [None]:
#| export
@patch
def evaluation_loop(
    self:XCLearner,
    dataloader:DataLoader,
    description:str,
    prediction_loss_only:Optional[bool] = None,
    predict_with_generation:Optional[bool]=None,
    predict_with_representation:Optional[bool]=None,
    ignore_keys:Optional[List[str]] = None,
    metric_key_prefix:str="eval",
) -> XCEvalLoopOutput:
    """Iterate over `dataloader`, collect predictions and losses, and compute metrics.

    For every batch `self.prediction_step` is called; its loss is repeated
    `batch_size` times and gathered, and every entry of its output dict is
    accumulated per key. When `args.eval_accumulation_steps` is set, the
    per-host buffers are flushed into the global buffers every that many steps.

    Metrics are computed with `self.compute_metrics` only when both `targ_idx`
    and `pred_idx` were collected; if generation and/or representation specific
    predictions exist (`pred_idx_gen` / `pred_idx_repr`), their metrics are added
    with `_GEN` / `_REPR` suffixes. All metric keys are prefixed with
    `metric_key_prefix`. `description` is accepted but not used in this body.
    """
    args = self.args
    prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

    model = self._wrap_model(self.model, training=False, dataloader=dataloader)

    if len(self.accelerator._models) == 0 and model is self.model:
        model = self.accelerator.prepare(model) if self.is_deepspeed_enabled else self.accelerator.prepare_model(model, evaluation_mode=True)
        if self.is_fsdp_enabled: self.model = model
        if model is not self.model: self.model_wrapped = model
        if self.is_deepspeed_enabled: self.deepspeed = self.model_wrapped

    batch_size = self.args.eval_batch_size

    model.eval()

    self.callback_handler.eval_dataloader = dataloader
    eval_dataset = getattr(dataloader, "dataset", None)

    # The label / augmentation indices are only needed when predictions are produced.
    if self._perform_representation(unwrap_model(model)) and not prediction_loss_only:
        self._build_lbl_index(eval_dataset)

    if self._perform_augmentation(unwrap_model(model)) and not prediction_loss_only:
        self._build_aug_index(eval_dataset)

    if args.past_index >= 0: self._past = None

    losses_host, all_losses = None, None
    host_output, all_output = {}, {}

    observed_num_examples = 0
    for step, inputs in enumerate(dataloader):
        observed_batch_size = find_batch_size(inputs)
        if observed_batch_size is not None:
            observed_num_examples += observed_batch_size
            if batch_size is None: batch_size = observed_batch_size

        loss, output = self.prediction_step(model, inputs, prediction_loss_only, predict_with_generation, predict_with_representation, ignore_keys=ignore_keys)
        # Tolerate a prediction step that produced nothing.
        if output is None: output = {}

        if loss is not None:
            losses = self.gather_function((loss.repeat(batch_size)))
            losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100)
        for k in output: host_output[k] = self._gather_host_output(output[k], host_output.get(k, None))

        self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

        if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
            # Flush the whole host buffer (not just the last batch) and reset it,
            # otherwise losses are duplicated or dropped in the final mean.
            if losses_host is not None:
                all_losses = losses_host if all_losses is None else nested_concat(all_losses, losses_host, padding_index=-100)
                losses_host = None
            for k in host_output: all_output[k], host_output[k] = self._gather_all_output(host_output[k], all_output.get(k, None)), None

    self.gather_function = self.accelerator.gather_for_metrics
    if args.past_index and hasattr(self, "_past"): delattr(self, "_past")

    # Flush whatever is left in the host buffers after the last step.
    if losses_host is not None:
        all_losses = losses_host if all_losses is None else nested_concat(all_losses, losses_host, padding_index=-100)
        losses_host = None
    for k in host_output: all_output[k], host_output[k] = self._gather_all_output(host_output[k], all_output.get(k, None)), None

    if has_length(eval_dataset): num_samples = len(eval_dataset)
    elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
        num_samples = eval_dataset.num_examples
    else:
        if has_length(dataloader): num_samples = self.num_examples(dataloader)
        else: num_samples = observed_num_examples
    if num_samples == 0 and observed_num_examples > 0: num_samples = observed_num_examples

    gen_output, repr_output = None, None
    metric_input_keys = ['targ_idx', 'targ_ptr', 'pred_idx', 'pred_ptr', 'pred_score']
    # Target entries are looked up with `.get` so that datasets without targets
    # still yield the mode-specific predictions instead of raising KeyError.
    if 'pred_idx_gen' in all_output and all_output['pred_idx_gen'] is not None:
        gen_output = {o:(all_output[f'{o}_gen'] if o.startswith('pred_') else all_output.get(o)) for o in metric_input_keys}
    if 'pred_idx_repr' in all_output and all_output['pred_idx_repr'] is not None:
        repr_output = {o:(all_output[f'{o}_repr'] if o.startswith('pred_') else all_output.get(o)) for o in metric_input_keys}

    if (self.compute_metrics is not None and
        'targ_idx' in all_output and all_output['targ_idx'] is not None and
        'pred_idx' in all_output and all_output['pred_idx'] is not None):

        metrics = self.compute_metrics(**{o:all_output[o] for o in metric_input_keys})
        if gen_output is not None:
            m = self.compute_metrics(**gen_output)
            metrics.update({f'{k}_GEN':v for k,v in m.items()})
        if repr_output is not None:
            m = self.compute_metrics(**repr_output)
            metrics.update({f'{k}_REPR':v for k,v in m.items()})
    else: metrics = {}

    metrics = denumpify_detensorize(metrics)

    if all_losses is not None: metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
    if hasattr(self, "jit_compilation_time"): metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time

    # Make sure every reported metric carries the requested prefix.
    for key in list(metrics.keys()):
        if not key.startswith(f"{metric_key_prefix}_"): metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

    return XCEvalLoopOutput(pred_idx=all_output.get('pred_idx'), pred_ptr=all_output.get('pred_ptr'),
                            pred_score=all_output.get('pred_score'),targ_idx=all_output.get('targ_idx'),
                            targ_ptr=all_output.get('targ_ptr'), gen_output=gen_output, repr_output=repr_output,
                            metrics=metrics, num_samples=num_samples)
    

In [None]:
#| export
@patch
def get_meta_representation(self:XCLearner, dataloader: DataLoader, to_cpu:Optional[bool]=True):
    """Compute metadata representations for every batch of `dataloader`.

    Each batch is moved to `self.model.device` and passed to
    `self.model.get_meta_representation` under `torch.no_grad()`; the attribute
    named by `self.args.metadata_representation_attribute` is read from the result
    and accumulated. When `self.args.representation_accumulation_steps` is set,
    the host buffer is flushed every that many steps. `to_cpu` is forwarded to
    `self._gather_all_output`.

    If the model exposes a callable `disable_noise`, it is called before the loop
    and its return value is handed back to `set_noise` afterwards (presumably the
    previous noise setting -- confirm against the model). The restore happens in a
    `finally` block so a failing forward pass does not leave noise disabled.
    """
    data_host, all_data = None, None
    accumulation_steps = self.args.representation_accumulation_steps

    toggles_noise = hasattr(self.model, 'disable_noise') and callable(getattr(self.model, 'disable_noise'))
    if toggles_noise: use_noise = self.model.disable_noise()

    try:
        for step, inputs in tqdm(enumerate(dataloader), total=len(dataloader)):
            inputs = inputs.to(self.model.device)
            with torch.no_grad(): data = getattr(self.model.get_meta_representation(**inputs), self.args.metadata_representation_attribute)
            data_host = self._gather_host_output(data, data_host)
            if accumulation_steps is not None and (step + 1) % accumulation_steps == 0:
                all_data, data_host = self._gather_all_output(data_host, all_data, to_cpu=to_cpu), None
    finally:
        if toggles_noise: self.model.set_noise(use_noise)

    return self._gather_all_output(data_host, all_data, to_cpu=to_cpu)

@patch
def get_representation(self:XCLearner, dataloader: DataLoader, to_cpu:Optional[bool]=True):
    """Compute model representations for every batch of `dataloader`.

    Each batch is moved to `self.model.device` and passed to `self.model` under
    `torch.no_grad()`; the attribute named by `self.args.representation_attribute`
    is read from the result and accumulated. When
    `self.args.representation_accumulation_steps` is set, the host buffer is
    flushed every that many steps. `to_cpu` is forwarded to
    `self._gather_all_output`.

    If the model exposes a callable `disable_noise`, it is called before the loop
    and its return value is handed back to `set_noise` afterwards (presumably the
    previous noise setting -- confirm against the model). The restore happens in a
    `finally` block so a failing forward pass does not leave noise disabled.
    """
    data_host, all_data = None, None
    accumulation_steps = self.args.representation_accumulation_steps

    toggles_noise = hasattr(self.model, 'disable_noise') and callable(getattr(self.model, 'disable_noise'))
    if toggles_noise: use_noise = self.model.disable_noise()

    try:
        for step, inputs in tqdm(enumerate(dataloader), total=len(dataloader)):
            inputs = inputs.to(self.model.device)
            with torch.no_grad(): data = getattr(self.model(**inputs), self.args.representation_attribute)
            data_host = self._gather_host_output(data, data_host)
            if accumulation_steps is not None and (step + 1) % accumulation_steps == 0:
                all_data, data_host = self._gather_all_output(data_host, all_data, to_cpu=to_cpu), None
    finally:
        if toggles_noise: self.model.set_noise(use_noise)

    return self._gather_all_output(data_host, all_data, to_cpu=to_cpu)
    

### Training loop

In [None]:
#| export
@patch
def _get_train_sampler(self:XCLearner):
    """Pick the sampler for the training dataloader.

    Returns `None` when there is no sized training dataset. Otherwise the choice
    is, in order of precedence: `LengthGroupedSampler` when
    `args.group_by_length` is set, `ClusterGroupedSampler` when
    `args.group_by_cluster` is set, and `RandomSampler` as the fallback.
    """
    dset, args = self.train_dataset, self.args
    if dset is None or not has_length(dset): return None

    if args.group_by_length:
        # Precomputed lengths are only available from a `datasets.Dataset` column.
        lengths = None
        if is_datasets_available() and isinstance(dset, datasets.Dataset):
            if args.length_column_name in dset.column_names: lengths = dset[args.length_column_name]

        tok = self.tokenizer
        input_name = None if tok is None else tok.model_input_names[0]
        return LengthGroupedSampler(
            args.train_batch_size * args.gradient_accumulation_steps,
            dataset=dset,
            lengths=lengths,
            model_input_name=input_name,
        )

    if args.group_by_cluster: return ClusterGroupedSampler(n=len(dset))
    return RandomSampler(dset)
        

In [None]:
#| export
@patch
def get_train_dataloader(self:XCLearner):
    """Build the training `DataLoader`.

    Columns unused by the model are dropped from a `datasets.Dataset`; for any
    other dataset type the collator is wrapped to drop them instead. Sampler,
    `drop_last`, worker seeding and `prefetch_factor` are only configured when the
    dataset is not an `IterableDataset`.

    Raises `ValueError` when no training dataset is set.
    """
    dset = self.train_dataset
    if dset is None:
        raise ValueError("Trainer: training requires a train_dataset.")

    collate_fn = self.data_collator
    if is_datasets_available() and isinstance(dset, datasets.Dataset):
        dset = self._remove_unused_columns(dset, description="training")
    else:
        collate_fn = self._get_collator_with_removed_columns(collate_fn, description="training")

    loader_kwargs = dict(
        batch_size=self._train_batch_size,
        collate_fn=collate_fn,
        num_workers=self.args.dataloader_num_workers,
        pin_memory=self.args.dataloader_pin_memory,
        persistent_workers=self.args.dataloader_persistent_workers,
    )

    is_iterable = isinstance(dset, torch.utils.data.IterableDataset)
    if not is_iterable:
        loader_kwargs.update(
            sampler=self._get_train_sampler(),
            drop_last=self.args.dataloader_drop_last,
            worker_init_fn=seed_worker,
            prefetch_factor=self.args.dataloader_prefetch_factor,
        )

    return DataLoader(dset, **loader_kwargs)
    

In [None]:
#| export
@patch
def _get_min_cluster_sz(self:XCLearner, epochs_trained:int, num_train_epochs:int):
    """Return the minimum cluster size to use at the given epoch.

    Returns `None` while still inside the clustering warm-up
    (`args.num_clustering_warmup_epochs`); after it, the warm-up epochs are
    subtracted from `epochs_trained` before the schedule is applied.

    * `'LINEAR'`: the number of clusters grows linearly with the epoch from
      `args.minimum_clusters` towards `args.maximum_clusters` (or stays at
      `args.minimum_clusters` when no maximum is set); the size is
      `train_dataset.n_data` floor-divided by that count.
    * `'EXPO'`: the size starts at `args.minimum_cluster_size` and doubles every
      `args.num_cluster_size_update_epochs` epochs, capped at
      `args.maximum_cluster_size` when that is set.

    Raises `ValueError` for any other `args.clustering_type`.
    """
    args = self.args

    warmup = args.num_clustering_warmup_epochs
    if warmup is not None:
        if epochs_trained < warmup: return None
        epochs_trained -= warmup

    schedule = args.clustering_type
    if schedule == 'LINEAR':
        if args.maximum_clusters is None: return self.train_dataset.n_data//args.minimum_clusters
        extra_clusters = (args.maximum_clusters-args.minimum_clusters)/num_train_epochs*epochs_trained
        return self.train_dataset.n_data//int(args.minimum_clusters+extra_clusters)

    if schedule == 'EXPO':
        doublings = epochs_trained//args.num_cluster_size_update_epochs
        size = args.minimum_cluster_size*(2**doublings)
        cap = args.maximum_cluster_size
        if cap is not None and size > cap: size = cap
        return size

    raise ValueError(f'Invalid `clustering_type`({self.args.clustering_type}).')
    

In [None]:
#| export
@patch
def _get_train_data_cluster(self:XCLearner, epochs_trained:int, num_train_epochs:int):
    """Cluster the training data points from their current model representations.

    Representations of `self.train_dataset.data_dset` are computed with
    `self.get_representation`, optionally passed through a softmax over the last
    dimension (`args.use_distributional_representation`), and handed to
    `BalancedClusters.proc` together with the minimum cluster size for this epoch.
    """
    loader = self.get_test_dataloader(self.train_dataset.data_dset)
    reps = self.get_representation(loader)

    if self.args.use_distributional_representation:
        reps = F.softmax(reps, dim=-1)

    min_cluster_sz = self._get_min_cluster_sz(epochs_trained, num_train_epochs)
    return BalancedClusters.proc(reps, min_cluster_sz, clustering_devices=self.args.clustering_devices)

@patch
def update_dataloader_sampler(self:XCLearner, dataloader:DataLoader, epochs_trained:int, num_train_epochs:int):
    """Recompute the training clusters and install them on a `ClusterGroupedSampler`.

    Does nothing when the dataloader's sampler is of any other type.
    """
    sampler = dataloader.sampler
    if not isinstance(sampler, ClusterGroupedSampler): return
    sampler.set_cluster(self._get_train_data_cluster(epochs_trained, num_train_epochs))
    

In [None]:
#| export
@patch
def _validate_group_by_cluster(self:XCLearner):
    """Reject `group_by_cluster` for models that do not use representations.

    Raises `ValueError` when `args.group_by_cluster` is set but the unwrapped
    model has no truthy `use_representation` attribute.
    """
    if not self.args.group_by_cluster: return
    # Look at the unwrapped model for both the existence and the value of the flag,
    # so a wrapper around the model neither hides the attribute nor turns a
    # missing attribute into an AttributeError.
    if not getattr(unwrap_model(self.model), 'use_representation', False):
        raise ValueError('Cannot use `group_by_cluster` for models without `use_representation`.')

@patch
def _inner_training_loop(
    self:XCLearner, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
):
    self.accelerator.free_memory()
    self._train_batch_size = batch_size
    if self.args.auto_find_batch_size:
        if self.state.train_batch_size != self._train_batch_size:
            from accelerate.utils import release_memory

            (self.model_wrapped,) = release_memory(self.model_wrapped)
            self.model_wrapped = self.model

            # Check for DeepSpeed *after* the intial pass and modify the config
            if self.is_deepspeed_enabled:
                # Temporarily unset `self.args.train_batch_size`
                original_bs = self.args.per_device_train_batch_size
                self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu)
                self.propagate_args_to_deepspeed(True)
                self.args.per_device_train_batch_size = original_bs
        self.state.train_batch_size = self._train_batch_size
    logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
    
    # Data loader and number of training steps
    self._validate_group_by_cluster()
    train_dataloader = self.get_train_dataloader()
    
    if self.is_fsdp_xla_v2_enabled:
        train_dataloader = tpu_spmd_dataloader(train_dataloader)

    # Setting up training control variables:
    # number of training epochs: num_train_epochs
    # number of training steps per epoch: num_update_steps_per_epoch
    # total number of training steps to execute: max_steps
    total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size

    len_dataloader = None
    num_train_tokens = None
    if has_length(train_dataloader):
        len_dataloader = len(train_dataloader)
        num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
        num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
        num_examples = self.num_examples(train_dataloader)
        if args.max_steps > 0:
            max_steps = args.max_steps
            num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                args.max_steps % num_update_steps_per_epoch > 0
            )
            # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
            # the best we can do.
            num_train_samples = args.max_steps * total_train_batch_size
            if args.include_tokens_per_second:
                num_train_tokens = (
                    self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
                )
        else:
            max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
            num_train_epochs = math.ceil(args.num_train_epochs)
            num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
            if args.include_tokens_per_second:
                num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs
    elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
        max_steps = args.max_steps
        # Setting a very large number of epochs so we go as many times as necessary over the iterator.
        num_train_epochs = sys.maxsize
        num_update_steps_per_epoch = max_steps
        num_examples = total_train_batch_size * args.max_steps
        num_train_samples = args.max_steps * total_train_batch_size
        if args.include_tokens_per_second:
            num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
    else:
        raise ValueError(
            "args.max_steps must be set to a positive value if dataloader does not have a length, was"
            f" {args.max_steps}"
        )

    if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
        if self.args.n_gpu > 1:
            # nn.DataParallel(model) replicates the model, creating new variables and module
            # references registered here no longer work on other gpus, breaking the module
            raise ValueError(
                "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
                " (torchrun or torch.distributed.launch (deprecated))."
            )
        else:
            debug_overflow = DebugUnderflowOverflow(self.model)  # noqa

    delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled

    # We need to reset the scheduler, as its parameters may be different on subsequent calls
    if self._created_lr_scheduler:
        self.lr_scheduler = None
        self._created_lr_scheduler = False

    if self.is_deepspeed_enabled:
        self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)

    if not delay_optimizer_creation:
        self.create_optimizer_and_scheduler(num_training_steps=max_steps)

    self.state = TrainerState()
    self.state.is_hyper_param_search = trial is not None
    self.state.train_batch_size = self._train_batch_size

    # Compute absolute values for logging, eval, and save if given as ratio
    if args.logging_steps is not None:
        if args.logging_steps < 1:
            self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
        else:
            self.state.logging_steps = args.logging_steps
    if args.eval_steps is not None:
        if args.eval_steps < 1:
            self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
        else:
            self.state.eval_steps = args.eval_steps
    if args.save_steps is not None:
        if args.save_steps < 1:
            self.state.save_steps = math.ceil(max_steps * args.save_steps)
        else:
            self.state.save_steps = args.save_steps

    # Activate gradient checkpointing if needed
    if args.gradient_checkpointing:
        if args.gradient_checkpointing_kwargs is None:
            gradient_checkpointing_kwargs = {}
        else:
            gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs

        self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

    model = self._wrap_model(self.model_wrapped)

    # as the model is wrapped, don't use `accelerator.prepare`
    # this is for unhandled cases such as
    # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
    use_accelerator_prepare = True if model is self.model else False

    if delay_optimizer_creation:
        if use_accelerator_prepare:
            self.model = self.accelerator.prepare(self.model)
        self.create_optimizer_and_scheduler(num_training_steps=max_steps)

    # prepare using `accelerator` prepare
    if use_accelerator_prepare:
        self.model.train()
        if hasattr(self.lr_scheduler, "step"):
            if self.use_apex:
                model = self.accelerator.prepare(self.model)
            else:
                model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
        else:
            # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
            model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
                self.model, self.optimizer, self.lr_scheduler
            )

    if self.is_fsdp_enabled:
        self.model = self.model_wrapped = model

    # for the rest of this function `model` is the outside model, whether it was wrapped or not
    if model is not self.model:
        self.model_wrapped = model

    # backward compatibility
    if self.is_deepspeed_enabled:
        self.deepspeed = self.model_wrapped

    # ckpt loading
    if resume_from_checkpoint is not None:
        if self.is_deepspeed_enabled:
            deepspeed_load_checkpoint(
                self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model)
            )
        elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
            self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)

    # Check if saved optimizer or scheduler states exist
    self._load_optimizer_and_scheduler(resume_from_checkpoint)

    # important: at this point:
    # self.model         is the Transformers Model
    # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),
    # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.

    # Train!
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {num_examples:,}")
    logger.info(f"  Num Epochs = {num_train_epochs:,}")
    logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
    if self.args.per_device_train_batch_size != self._train_batch_size:
        logger.info(f"  Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {max_steps:,}")
    logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")

    self.state.epoch = 0
    start_time = time.time()
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    steps_trained_progress_bar = None

    # Check if continuing training from a checkpoint
    if resume_from_checkpoint is not None and os.path.isfile(
        os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
    ):
        self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
        epochs_trained = self.state.global_step // num_update_steps_per_epoch
        if not args.ignore_data_skip:
            steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
            steps_trained_in_current_epoch *= args.gradient_accumulation_steps
        else:
            steps_trained_in_current_epoch = 0

        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
        logger.info(f"  Continuing training from epoch {epochs_trained}")
        logger.info(f"  Continuing training from global step {self.state.global_step}")
        if not args.ignore_data_skip:
            logger.info(
                f"  Will skip the first {epochs_trained} epochs then the first"
                f" {steps_trained_in_current_epoch} batches in the first epoch."
            )

    # Update the references
    self.callback_handler.model = self.model
    self.callback_handler.optimizer = self.optimizer
    self.callback_handler.lr_scheduler = self.lr_scheduler
    self.callback_handler.train_dataloader = train_dataloader
    if self.hp_name is not None and self._trial is not None:
        # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
        # parameter to Train when using DDP.
        self.state.trial_name = self.hp_name(self._trial)
    if trial is not None:
        assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
        self.state.trial_params = hp_params(assignments)
    else:
        self.state.trial_params = None
    # This should be the same if the state has been saved but in case the training arguments changed, it's safer
    # to set this after the load.
    self.state.max_steps = max_steps
    self.state.num_train_epochs = num_train_epochs
    self.state.is_local_process_zero = self.is_local_process_zero()
    self.state.is_world_process_zero = self.is_world_process_zero()

    # tr_loss is a tensor to avoid synchronization of TPUs through .item()
    tr_loss = torch.tensor(0.0).to(args.device)
    # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
    self._total_loss_scalar = 0.0
    self._globalstep_last_logged = self.state.global_step
    model.zero_grad()
    grad_norm: Optional[float] = None

    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

    # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
    if not args.ignore_data_skip:
        for epoch in range(epochs_trained):
            sampler = get_dataloader_sampler(train_dataloader)
            sampler_kinds = [RandomSampler]
            if version.parse(accelerate_version) > version.parse("0.23.0"):
                sampler_kinds.append(SeedableRandomSampler)
            is_random_sampler = isinstance(sampler, tuple(sampler_kinds))
            if not is_random_sampler:
                # We just need to begin an iteration to create the randomization of the sampler.
                for _ in train_dataloader:
                    break
            else:
                # Otherwise we need to call the whooooole sampler cause there is some random operation added
                # AT THE VERY END!
                sampler = sampler if sampler is not None else []
                _ = list(sampler)

    total_batched_samples = 0
    for epoch in range(epochs_trained, num_train_epochs):
        if self.args.group_by_cluster and (epoch % self.args.num_cluster_update_epochs == 0 or epoch == self.args.num_clustering_warmup_epochs) and epoch >= self.args.num_clustering_warmup_epochs:
            self.update_dataloader_sampler(train_dataloader, epoch, num_train_epochs)
        
        epoch_iterator = train_dataloader
        if hasattr(epoch_iterator, "set_epoch"):
            epoch_iterator.set_epoch(epoch)

        # Reset the past mems state at the beginning of each epoch if necessary.
        if args.past_index >= 0:
            self._past = None

        steps_in_epoch = (
            len(epoch_iterator)
            if len_dataloader is not None
            else args.max_steps * args.gradient_accumulation_steps
        )
        self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

        if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
            self._load_rng_state(resume_from_checkpoint)

        rng_to_sync = False
        steps_skipped = 0
        if steps_trained_in_current_epoch > 0:
            epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
            steps_skipped = steps_trained_in_current_epoch
            steps_trained_in_current_epoch = 0
            rng_to_sync = True

        step = -1
        for step, inputs in enumerate(epoch_iterator):
            total_batched_samples += 1

            if self.args.include_num_input_tokens_seen:
                main_input_name = getattr(self.model, "main_input_name", "input_ids")
                if main_input_name not in inputs:
                    logger.warning(
                        "Tried to track the number of tokens seen, however the current model is "
                        "not configured properly to know what item is the input. To fix this, add "
                        "a `main_input_name` attribute to the model class you are using."
                    )
                else:
                    self.state.num_input_tokens_seen += self.accelerator.gather(inputs[main_input_name]).numel()
            if rng_to_sync:
                self._load_rng_state(resume_from_checkpoint)
                rng_to_sync = False

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                if steps_trained_progress_bar is not None:
                    steps_trained_progress_bar.update(1)
                if steps_trained_in_current_epoch == 0:
                    self._load_rng_state(resume_from_checkpoint)
                continue
            elif steps_trained_progress_bar is not None:
                steps_trained_progress_bar.close()
                steps_trained_progress_bar = None

            if step % args.gradient_accumulation_steps == 0:
                self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

            with self.accelerator.accumulate(model):
                tr_loss_step = self.training_step(model, inputs)

            if (
                args.logging_nan_inf_filter
                and not is_torch_tpu_available()
                and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
            ):
                # if loss is nan or inf simply add the average of previous logged losses
                tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
            else:
                tr_loss += tr_loss_step

            self.current_flos += float(self.floating_point_ops(inputs))

            is_last_step_and_steps_less_than_grad_acc = (
                steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
            )

            if (
                total_batched_samples % args.gradient_accumulation_steps == 0
                or
                # last step in epoch but step is always smaller than gradient_accumulation_steps
                is_last_step_and_steps_less_than_grad_acc
            ):
                # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered
                # in accelerate. So, explicitly enable sync gradients to True in that case.
                if is_last_step_and_steps_less_than_grad_acc:
                    self.accelerator.gradient_state._set_sync_gradients(True)

                # Gradient clipping
                if args.max_grad_norm is not None and args.max_grad_norm > 0:
                    # deepspeed does its own clipping

                    if is_sagemaker_mp_enabled() and args.fp16:
                        _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
                    elif self.use_apex:
                        # Revert to normal clipping otherwise, handling Apex or full precision
                        _grad_norm = nn.utils.clip_grad_norm_(
                            amp.master_params(self.optimizer),
                            args.max_grad_norm,
                        )
                    else:
                        _grad_norm = self.accelerator.clip_grad_norm_(
                            model.parameters(),
                            args.max_grad_norm,
                        )

                    if (
                        is_accelerate_available()
                        and self.accelerator.distributed_type == DistributedType.DEEPSPEED
                    ):
                        grad_norm = model.get_global_grad_norm()
                    else:
                        grad_norm = _grad_norm.item() if _grad_norm is not None else None

                # Optimizer step
                self.optimizer.step()
                optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
                if optimizer_was_run:
                    # Delay optimizer scheduling until metrics are generated
                    if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                        self.lr_scheduler.step()

                model.zero_grad()
                self.state.global_step += 1
                self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                self.control = self.callback_handler.on_step_end(args, self.state, self.control)

                self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
            else:
                self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

            if self.control.should_epoch_stop or self.control.should_training_stop:
                # PyTorch/XLA relies on the data loader to insert the mark_step for
                # each step. Since we are breaking the loop early, we need to manually
                # insert the mark_step here.
                if is_torch_tpu_available():
                    xm.mark_step()
                break
        if step < 0:
            logger.warning(
                "There seems to be not a single sample in your epoch_iterator, stopping training at step"
                f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
                f" num_steps ({max_steps}) higher than the number of available samples."
            )
            self.control.should_training_stop = True

        self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
        self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)

        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
            if is_torch_tpu_available():
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())
            else:
                logger.warning(
                    "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                    "configured. Check your training configuration if this is unexpected."
                )
        if self.control.should_training_stop:
            break

    if args.past_index and hasattr(self, "_past"):
        # Clean the state at the end of training
        delattr(self, "_past")

    logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
    if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
        # Wait for everyone to get here so we are sure the model has been saved by process 0.
        if is_torch_tpu_available():
            xm.rendezvous("load_best_model_at_end")
        elif args.parallel_mode == ParallelMode.DISTRIBUTED:
            dist.barrier()
        elif is_sagemaker_mp_enabled():
            smp.barrier()

        self._load_best_model()

    # add remaining tr_loss
    self._total_loss_scalar += tr_loss.item()
    train_loss = self._total_loss_scalar / self.state.global_step

    metrics = speed_metrics(
        "train",
        start_time,
        num_samples=num_train_samples,
        num_steps=self.state.max_steps,
        num_tokens=num_train_tokens,
    )
    self.store_flos()
    metrics["total_flos"] = self.state.total_flos
    metrics["train_loss"] = train_loss

    self.is_in_train = False

    self._memory_tracker.stop_and_update_metrics(metrics)

    self.log(metrics)

    run_dir = self._get_output_dir(trial)
    checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)

    # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
    if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
        for checkpoint in checkpoints_sorted:
            if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
                logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                shutil.rmtree(checkpoint)

    self.control = self.callback_handler.on_train_end(args, self.state, self.control)

    # Wait for the checkpoint to be uploaded.
    self._finish_current_push()

    # After training we make sure to retrieve back the original forward pass method
    # for the embedding layer by removing the forward post hook.
    if self.neftune_noise_alpha is not None:
        self._deactivate_neftune(self.model)

    return TrainOutput(self.state.global_step, train_loss, metrics)
    


### Example

In [None]:
# Environment for the example run: set WANDB_MODE to 'disabled' and restrict
# the CUDA devices visible to this process to indices 0 and 1.
# NOTE(review): CUDA_VISIBLE_DEVICES is presumably only honoured if it is set
# before CUDA is first initialised in this kernel — confirm this cell runs early enough.
os.environ.update(WANDB_MODE='disabled', CUDA_VISIBLE_DEVICES='0,1')

In [None]:
# Training configuration for the example run.
# NOTE(review): the exact semantics of the XC-specific options are defined by
# XCLearningArguments, which is not visible in this section; the comments below
# only group the options and flag things to double-check.
args = XCLearningArguments(
    # NOTE(review): absolute, user-specific path — change before running elsewhere.
    output_dir='/scratch/scai/phd/aiz218323/scratch/outputs/default/',
    per_device_train_batch_size=10,
    per_device_eval_batch_size=64,
    num_train_epochs=50,
    eval_steps=50,
    weight_decay=0.01,
    # Representation / search options used by the learner.
    representation_accumulation_steps=10,
    representation_attribute='data_repr',
    representation_search_type='INDEX',
    # Presumably evaluates on a step schedule together with eval_steps above — TODO confirm.
    evaluation_strategy='steps',
    label_names=['lbl2data_idx'],
    # Cluster-related options.
    group_by_cluster=True,
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=3,
    num_cluster_size_update_epochs=2,
    use_distributional_representation=False,
    clustering_type='EXPO',
    minimum_cluster_size=1,
    maximum_cluster_size=4,
    use_encoder_parallel=True,
    # With None, the training loop's `if args.max_grad_norm is not None and ...`
    # gradient-clipping branch is skipped.
    max_grad_norm=None,   
)

In [None]:
# Larger of the two per-device batch sizes, multiplied by the number of CUDA devices.
# NOTE(review): `bsz` is not referenced by any of the example cells shown here —
# confirm it is still needed.
bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*torch.cuda.device_count()
# Build DBT012 from the 'sentence-transformers/msmarco-distilbert-base-v4' checkpoint.
# NOTE(review): the model is created with use_encoder_parallel=False while `args`
# sets use_encoder_parallel=True — confirm the mismatch is intentional.
model = DBT012.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', margin=0.3, tau=0.1, psi=0.5,
                               n_negatives=10, apply_softmax=True, use_encoder_parallel=False)
# Presumably initialises the `dr_*` head weights that the load message reports
# as newly initialised — TODO confirm against DBT012.init_dr_head.
model.init_dr_head()

Some weights of DBT012 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Draw n=1000 samples from the train and test datasets for the example run.
# NOTE(review): `block` is defined outside this section; presumably `sample`
# returns a dataset subset — TODO confirm.
train_dset, valid_dset = block.train.dset.sample(n=1000), block.test.dset.sample(n=1000)

In [None]:
# Evaluation metric passed to the learner as `compute_metrics`.
# It is built from the label count, the validation set's data-label filterer
# and the training data-label matrix (passed as `prop`).
# NOTE(review): pk/rk/rep_pk/rep_rk are presumably precision@k / recall@k
# cut-offs — confirm against PrecRecl.
metric = PrecRecl(block.n_lbl, valid_dset.data.data_lbl_filterer, prop=block.train.dset.data.data_lbl, 
                  pk=5, rk=5, rep_pk=[1, 3, 5], rep_rk=[5])

In [None]:
# Assemble the learner from the model, training arguments, the block's collator,
# the sampled train/validation datasets and the metric defined above.
learn = XCLearner(
    model=model, 
    args=args,
    data_collator=block.collator, 
    train_dataset=train_dset, 
    eval_dataset=valid_dset,
    compute_metrics=metric,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
# Start training.
# NOTE(review): the recorded output of this cell is an interactive ipdb session
# inside `_inner_training_loop`, which suggests a debugger breakpoint is still
# active there — remove it before running non-interactively.
learn.train()

> /tmp/ipykernel_881/2079912209.py(16)_inner_training_loop()
     14     #debug
     15 
---> 16     self.accelerator.free_memory()
     17     self._train_batch_size = batch_size
     18     if self.args.auto_find_batch_size:



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(17)_inner_training_loop()
     15 
     16     self.accelerator.free_memory()
---> 17     self._train_batch_size = batch_size
     18     if self.args.auto_find_batch_size:
     19         if self.state.train_batch_size != self._train_batch_size:



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(18)_inner_training_loop()
     16     self.accelerator.free_memory()
     17     self._train_batch_size = batch_size
---> 18     if self.args.auto_find_batch_size:
     19         if self.state.train_batch_size != self._train_batch_size:
     20             from accelerate.utils import release_memory



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(33)_inner_training_loop()
     31                 self.args.per_device_train_batch_size = original_bs
     32         self.state.train_batch_size = self._train_batch_size
---> 33     logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
     34 
     35     # Data loader and number of training steps



ipdb>  


> /tmp/ipykernel_881/2079912209.py(36)_inner_training_loop()
     34 
     35     # Data loader and number of training steps
---> 36     self._validate_group_by_cluster()
     37     train_dataloader = self.get_train_dataloader()
     38 



ipdb>  


> /tmp/ipykernel_881/2079912209.py(37)_inner_training_loop()
     35     # Data loader and number of training steps
     36     self._validate_group_by_cluster()
---> 37     train_dataloader = self.get_train_dataloader()
     38 
     39     if self.is_fsdp_xla_v2_enabled:



ipdb>  


> /tmp/ipykernel_881/2079912209.py(39)_inner_training_loop()
     37     train_dataloader = self.get_train_dataloader()
     38 
---> 39     if self.is_fsdp_xla_v2_enabled:
     40         train_dataloader = tpu_spmd_dataloader(train_dataloader)
     41 



ipdb>  


> /tmp/ipykernel_881/2079912209.py(46)_inner_training_loop()
     44     # number of training steps per epoch: num_update_steps_per_epoch
     45     # total number of training steps to execute: max_steps
---> 46     total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size
     47 
     48     len_dataloader = None



ipdb>  


> /tmp/ipykernel_881/2079912209.py(48)_inner_training_loop()
     46     total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size
     47 
---> 48     len_dataloader = None
     49     num_train_tokens = None
     50     if has_length(train_dataloader):



ipdb>  


> /tmp/ipykernel_881/2079912209.py(49)_inner_training_loop()
     47 
     48     len_dataloader = None
---> 49     num_train_tokens = None
     50     if has_length(train_dataloader):
     51         len_dataloader = len(train_dataloader)



ipdb>  


> /tmp/ipykernel_881/2079912209.py(50)_inner_training_loop()
     48     len_dataloader = None
     49     num_train_tokens = None
---> 50     if has_length(train_dataloader):
     51         len_dataloader = len(train_dataloader)
     52         num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps



ipdb>  


> /tmp/ipykernel_881/2079912209.py(51)_inner_training_loop()
     49     num_train_tokens = None
     50     if has_length(train_dataloader):
---> 51         len_dataloader = len(train_dataloader)
     52         num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
     53         num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)



ipdb>  


> /tmp/ipykernel_881/2079912209.py(52)_inner_training_loop()
     50     if has_length(train_dataloader):
     51         len_dataloader = len(train_dataloader)
---> 52         num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
     53         num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
     54         num_examples = self.num_examples(train_dataloader)



ipdb>  


> /tmp/ipykernel_881/2079912209.py(53)_inner_training_loop()
     51         len_dataloader = len(train_dataloader)
     52         num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
---> 53         num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
     54         num_examples = self.num_examples(train_dataloader)
     55         if args.max_steps > 0:



ipdb>  


> /tmp/ipykernel_881/2079912209.py(54)_inner_training_loop()
     52         num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
     53         num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
---> 54         num_examples = self.num_examples(train_dataloader)
     55         if args.max_steps > 0:
     56             max_steps = args.max_steps



ipdb>  


> /tmp/ipykernel_881/2079912209.py(55)_inner_training_loop()
     53         num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
     54         num_examples = self.num_examples(train_dataloader)
---> 55         if args.max_steps > 0:
     56             max_steps = args.max_steps
     57             num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(68)_inner_training_loop()
     66                 )
     67         else:
---> 68             max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
     69             num_train_epochs = math.ceil(args.num_train_epochs)
     70             num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs



ipdb>  


> /tmp/ipykernel_881/2079912209.py(69)_inner_training_loop()
     67         else:
     68             max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
---> 69             num_train_epochs = math.ceil(args.num_train_epochs)
     70             num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
     71             if args.include_tokens_per_second:



ipdb>  


> /tmp/ipykernel_881/2079912209.py(70)_inner_training_loop()
     68             max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
     69             num_train_epochs = math.ceil(args.num_train_epochs)
---> 70             num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
     71             if args.include_tokens_per_second:
     72                 num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs



ipdb>  


> /tmp/ipykernel_881/2079912209.py(71)_inner_training_loop()
     69             num_train_epochs = math.ceil(args.num_train_epochs)
     70             num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
---> 71             if args.include_tokens_per_second:
     72                 num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs
     73     elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size



ipdb>  


> /tmp/ipykernel_881/2079912209.py(88)_inner_training_loop()
     86         )
     87 
---> 88     if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
     89         if self.args.n_gpu > 1:
     90             # nn.DataParallel(model) replicates the model, creating new variables and module



ipdb>  


> /tmp/ipykernel_881/2079912209.py(99)_inner_training_loop()
     97             debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
     98 
---> 99     delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
    100 
    101     # We need to reset the scheduler, as its parameters may be different on subsequent calls



ipdb>  


> /tmp/ipykernel_881/2079912209.py(102)_inner_training_loop()
    100 
    101     # We need to reset the scheduler, as its parameters may be different on subsequent calls
--> 102     if self._created_lr_scheduler:
    103         self.lr_scheduler = None
    104         self._created_lr_scheduler = False



ipdb>  


> /tmp/ipykernel_881/2079912209.py(103)_inner_training_loop()
    101     # We need to reset the scheduler, as its parameters may be different on subsequent calls
    102     if self._created_lr_scheduler:
--> 103         self.lr_scheduler = None
    104         self._created_lr_scheduler = False
    105 



ipdb>  


> /tmp/ipykernel_881/2079912209.py(104)_inner_training_loop()
    102     if self._created_lr_scheduler:
    103         self.lr_scheduler = None
--> 104         self._created_lr_scheduler = False
    105 
    106     if self.is_deepspeed_enabled:



ipdb>  


> /tmp/ipykernel_881/2079912209.py(106)_inner_training_loop()
    104         self._created_lr_scheduler = False
    105 
--> 106     if self.is_deepspeed_enabled:
    107         self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)
    108 



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(109)_inner_training_loop()
    107         self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)
    108 
--> 109     if not delay_optimizer_creation:
    110         self.create_optimizer_and_scheduler(num_training_steps=max_steps)
    111 



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(110)_inner_training_loop()
    108 
    109     if not delay_optimizer_creation:
--> 110         self.create_optimizer_and_scheduler(num_training_steps=max_steps)
    111 
    112     self.state = TrainerState()



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(112)_inner_training_loop()
    110         self.create_optimizer_and_scheduler(num_training_steps=max_steps)
    111 
--> 112     self.state = TrainerState()
    113     self.state.is_hyper_param_search = trial is not None
    114     self.state.train_batch_size = self._train_batch_size



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(113)_inner_training_loop()
    111 
    112     self.state = TrainerState()
--> 113     self.state.is_hyper_param_search = trial is not None
    114     self.state.train_batch_size = self._train_batch_size
    115 



ipdb>  self.state


TrainerState(epoch=None, global_step=0, max_steps=0, logging_steps=500, eval_steps=500, save_steps=500, train_batch_size=None, num_train_epochs=0, num_input_tokens_seen=0, total_flos=0, log_history=[], best_metric=None, best_model_checkpoint=None, is_local_process_zero=True, is_world_process_zero=True, is_hyper_param_search=False, trial_name=None, trial_params=None)


ipdb>  n


> /tmp/ipykernel_881/2079912209.py(114)_inner_training_loop()
    112     self.state = TrainerState()
    113     self.state.is_hyper_param_search = trial is not None
--> 114     self.state.train_batch_size = self._train_batch_size
    115 
    116     # Compute absolute values for logging, eval, and save if given as ratio



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(117)_inner_training_loop()
    115 
    116     # Compute absolute values for logging, eval, and save if given as ratio
--> 117     if args.logging_steps is not None:
    118         if args.logging_steps < 1:
    119             self.state.logging_steps = math.ceil(max_steps * args.logging_steps)



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(118)_inner_training_loop()
    116     # Compute absolute values for logging, eval, and save if given as ratio
    117     if args.logging_steps is not None:
--> 118         if args.logging_steps < 1:
    119             self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
    120         else:



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(121)_inner_training_loop()
    119             self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
    120         else:
--> 121             self.state.logging_steps = args.logging_steps
    122     if args.eval_steps is not None:
    123         if args.eval_steps < 1:



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(122)_inner_training_loop()
    120         else:
    121             self.state.logging_steps = args.logging_steps
--> 122     if args.eval_steps is not None:
    123         if args.eval_steps < 1:
    124             self.state.eval_steps = math.ceil(max_steps * args.eval_steps)



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(123)_inner_training_loop()
    121             self.state.logging_steps = args.logging_steps
    122     if args.eval_steps is not None:
--> 123         if args.eval_steps < 1:
    124             self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
    125         else:



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(126)_inner_training_loop()
    124             self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
    125         else:
--> 126             self.state.eval_steps = args.eval_steps
    127     if args.save_steps is not None:
    128         if args.save_steps < 1:



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(127)_inner_training_loop()
    125         else:
    126             self.state.eval_steps = args.eval_steps
--> 127     if args.save_steps is not None:
    128         if args.save_steps < 1:
    129             self.state.save_steps = math.ceil(max_steps * args.save_steps)



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(128)_inner_training_loop()
    126             self.state.eval_steps = args.eval_steps
    127     if args.save_steps is not None:
--> 128         if args.save_steps < 1:
    129             self.state.save_steps = math.ceil(max_steps * args.save_steps)
    130         else:



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(131)_inner_training_loop()
    129             self.state.save_steps = math.ceil(max_steps * args.save_steps)
    130         else:
--> 131             self.state.save_steps = args.save_steps
    132 
    133     # Activate gradient checkpointing if needed



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(134)_inner_training_loop()
    132 
    133     # Activate gradient checkpointing if needed
--> 134     if args.gradient_checkpointing:
    135         if args.gradient_checkpointing_kwargs is None:
    136             gradient_checkpointing_kwargs = {}



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(142)_inner_training_loop()
    140         self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
    141 
--> 142     model = self._wrap_model(self.model_wrapped)
    143 
    144     # as the model is wrapped, don't use `accelerator.prepare`



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(147)_inner_training_loop()
    145     # this is for unhandled cases such as
    146     # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
--> 147     use_accelerator_prepare = True if model is self.model else False
    148 
    149     if delay_optimizer_creation:



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(149)_inner_training_loop()
    147     use_accelerator_prepare = True if model is self.model else False
    148 
--> 149     if delay_optimizer_creation:
    150         if use_accelerator_prepare:
    151             self.model = self.accelerator.prepare(self.model)



ipdb>  use_accelerator_prepare 


True


ipdb>  model


DBT012(
  (encoder): DBT012Encoder(
    (distilbert): DistilBertModel(
      (embeddings): Embeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (transformer): Transformer(
        (layer): ModuleList(
          (0-5): 6 x TransformerBlock(
            (attention): MultiHeadSelfAttention(
              (dropout): Dropout(p=0.1, inplace=False)
              (q_lin): Linear(in_features=768, out_features=768, bias=True)
              (k_lin): Linear(in_features=768, out_features=768, bias=True)
              (v_lin): Linear(in_features=768, out_features=768, bias=True)
              (out_lin): Linear(in_features=768, out_features=768, bias=True)
            )
            (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (ffn): FFN(
              (dr

ipdb>  self.args.use_encoder_parallel


True


ipdb>  n


> /tmp/ipykernel_881/2079912209.py(155)_inner_training_loop()
    153 
    154     # prepare using `accelerator` prepare
--> 155     if use_accelerator_prepare:
    156         self.model.train()
    157         if hasattr(self.lr_scheduler, "step"):



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(156)_inner_training_loop()
    154     # prepare using `accelerator` prepare
    155     if use_accelerator_prepare:
--> 156         self.model.train()
    157         if hasattr(self.lr_scheduler, "step"):
    158             if self.use_apex:



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(157)_inner_training_loop()
    155     if use_accelerator_prepare:
    156         self.model.train()
--> 157         if hasattr(self.lr_scheduler, "step"):
    158             if self.use_apex:
    159                 model = self.accelerator.prepare(self.model)



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(158)_inner_training_loop()
    156         self.model.train()
    157         if hasattr(self.lr_scheduler, "step"):
--> 158             if self.use_apex:
    159                 model = self.accelerator.prepare(self.model)
    160             else:



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(161)_inner_training_loop()
    159                 model = self.accelerator.prepare(self.model)
    160             else:
--> 161                 model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
    162         else:
    163             # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.



ipdb>  model


DBT012(
  (encoder): DBT012Encoder(
    (distilbert): DistilBertModel(
      (embeddings): Embeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (transformer): Transformer(
        (layer): ModuleList(
          (0-5): 6 x TransformerBlock(
            (attention): MultiHeadSelfAttention(
              (dropout): Dropout(p=0.1, inplace=False)
              (q_lin): Linear(in_features=768, out_features=768, bias=True)
              (k_lin): Linear(in_features=768, out_features=768, bias=True)
              (v_lin): Linear(in_features=768, out_features=768, bias=True)
              (out_lin): Linear(in_features=768, out_features=768, bias=True)
            )
            (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (ffn): FFN(
              (dr

ipdb>  s


--Call--
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1151)prepare()
   1149         return obj
   1150 
-> 1151     def prepare(self, *args, device_placement=None):
   1152         """
   1153         Prepare all objects passed in `args` for distributed training and mixed precision, then return them in the same



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1197)prepare()
   1195         ```
   1196         """
-> 1197         if device_placement is None:
   1198             device_placement = [None for _ in args]
   1199         elif self.distributed_type in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM):



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1198)prepare()
   1196         """
   1197         if device_placement is None:
-> 1198             device_placement = [None for _ in args]
   1199         elif self.distributed_type in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM):
   1200             raise ValueError("You can't customize device placements with DeepSpeed or Megatron-LM.")



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1206)prepare()
   1204             )
   1205 
-> 1206         for obj in args:
   1207             # TODO: Look at enabling native TP training directly with a proper config
   1208             if (



ipdb>  device_placement


[None, None]


ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1209)prepare()
   1207             # TODO: Look at enabling native TP training directly with a proper config
   1208             if (
-> 1209                 isinstance(obj, torch.nn.Module)
   1210                 and self.verify_device_map(obj)
   1211                 and self.distributed_type != DistributedType.NO



ipdb>  obj


DBT012(
  (encoder): DBT012Encoder(
    (distilbert): DistilBertModel(
      (embeddings): Embeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (transformer): Transformer(
        (layer): ModuleList(
          (0-5): 6 x TransformerBlock(
            (attention): MultiHeadSelfAttention(
              (dropout): Dropout(p=0.1, inplace=False)
              (q_lin): Linear(in_features=768, out_features=768, bias=True)
              (k_lin): Linear(in_features=768, out_features=768, bias=True)
              (v_lin): Linear(in_features=768, out_features=768, bias=True)
              (out_lin): Linear(in_features=768, out_features=768, bias=True)
            )
            (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (ffn): FFN(
              (dr

ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1208)prepare()
   1206         for obj in args:
   1207             # TODO: Look at enabling native TP training directly with a proper config
-> 1208             if (
   1209                 isinstance(obj, torch.nn.Module)
   1210                 and self.verify_device_map(obj)



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1210)prepare()
   1208             if (
   1209                 isinstance(obj, torch.nn.Module)
-> 1210                 and self.verify_device_map(obj)
   1211                 and self.distributed_type != DistributedType.NO
   1212                 and os.environ.get("ACCELERATE_BYPASS_DEVICE_MAP", "false") != "true"



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1208)prepare()
   1206         for obj in args:
   1207             # TODO: Look at enabling native TP training directly with a proper config
-> 1208             if (
   1209                 isinstance(obj, torch.nn.Module)
   1210                 and self.verify_device_map(obj)



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1206)prepare()
   1204             )
   1205 
-> 1206         for obj in args:
   1207             # TODO: Look at enabling native TP training directly with a proper config
   1208             if (



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1209)prepare()
   1207             # TODO: Look at enabling native TP training directly with a proper config
   1208             if (
-> 1209                 isinstance(obj, torch.nn.Module)
   1210                 and self.verify_device_map(obj)
   1211                 and self.distributed_type != DistributedType.NO



ipdb>  obj


AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 5e-05
    lr: 5e-05
    maximize: False
    weight_decay: 0.01

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 5e-05
    lr: 5e-05
    maximize: False
    weight_decay: 0.0
)


ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1208)prepare()
   1206         for obj in args:
   1207             # TODO: Look at enabling native TP training directly with a proper config
-> 1208             if (
   1209                 isinstance(obj, torch.nn.Module)
   1210                 and self.verify_device_map(obj)



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1206)prepare()
   1204             )
   1205 
-> 1206         for obj in args:
   1207             # TODO: Look at enabling native TP training directly with a proper config
   1208             if (



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1219)prepare()
   1217                 )
   1218 
-> 1219         if self.distributed_type == DistributedType.DEEPSPEED:
   1220             model_count = 0
   1221             for obj in args:



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1232)prepare()
   1230         # have parameters disconnected from the model (so no training :-( ).
   1231         # If the model and optimizer have parameters on different devices we raise an error.
-> 1232         if self.distributed_type == DistributedType.XLA:
   1233             model_device, optimizer_device = self._get_devices()
   1234             if model_device is not None and optimizer_device is not None and model_device != optimizer_device:



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1244)prepare()
   1242 
   1243         # If we're dealing with device placement, this deals with that by...
-> 1244         tpu_should_fix_optimizer = self.device_placement and self.distributed_type == DistributedType.XLA
   1245         if tpu_should_fix_optimizer or (self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "TE"):
   1246             # 1. grabbing old model parameters



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1245)prepare()
   1243         # If we're dealing with device placement, this deals with that by...
   1244         tpu_should_fix_optimizer = self.device_placement and self.distributed_type == DistributedType.XLA
-> 1245         if tpu_should_fix_optimizer or (self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "TE"):
   1246             # 1. grabbing old model parameters
   1247             old_named_params = self._get_named_parameters(*args)



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1249)prepare()
   1247             old_named_params = self._get_named_parameters(*args)
   1248 
-> 1249         if self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
   1250             if self.device.type == "cpu" and self.state.use_ipex:
   1251                 args = self._prepare_ipex(*args)



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1250)prepare()
   1248 
   1249         if self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
-> 1250             if self.device.type == "cpu" and self.state.use_ipex:
   1251                 args = self._prepare_ipex(*args)
   1252             elif self.device.type == "xpu" and is_xpu_available():



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1252)prepare()
   1250             if self.device.type == "cpu" and self.state.use_ipex:
   1251                 args = self._prepare_ipex(*args)
-> 1252             elif self.device.type == "xpu" and is_xpu_available():
   1253                 args = self._prepare_ipex(*args)
   1254         if self.distributed_type == DistributedType.DEEPSPEED:



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1254)prepare()
   1252             elif self.device.type == "xpu" and is_xpu_available():
   1253                 args = self._prepare_ipex(*args)
-> 1254         if self.distributed_type == DistributedType.DEEPSPEED:
   1255             result = self._prepare_deepspeed(*args)
   1256         elif self.distributed_type == DistributedType.MEGATRON_LM:



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1256)prepare()
   1254         if self.distributed_type == DistributedType.DEEPSPEED:
   1255             result = self._prepare_deepspeed(*args)
-> 1256         elif self.distributed_type == DistributedType.MEGATRON_LM:
   1257             result = self._prepare_megatron_lm(*args)
   1258         else:



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1259)prepare()
   1257             result = self._prepare_megatron_lm(*args)
   1258         else:
-> 1259             if self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "MSAMP":
   1260                 args = self._prepare_msamp(*args)
   1261                 # MS-AMP will handle the device placement



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1263)prepare()
   1261                 # MS-AMP will handle the device placement
   1262                 device_placement = [False for _ in args]
-> 1263             result = tuple(
   1264                 self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
   1265             )



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1264)prepare()
   1262                 device_placement = [False for _ in args]
   1263             result = tuple(
-> 1264                 self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
   1265             )
   1266             result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement))



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1263)prepare()
   1261                 # MS-AMP will handle the device placement
   1262                 device_placement = [False for _ in args]
-> 1263             result = tuple(
   1264                 self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
   1265             )



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1266)prepare()
   1264                 self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
   1265             )
-> 1266             result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement))
   1267 
   1268         if tpu_should_fix_optimizer or (self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "TE"):



ipdb>  result


(DBT012(
  (encoder): DBT012Encoder(
    (distilbert): DistilBertModel(
      (embeddings): Embeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (transformer): Transformer(
        (layer): ModuleList(
          (0-5): 6 x TransformerBlock(
            (attention): MultiHeadSelfAttention(
              (dropout): Dropout(p=0.1, inplace=False)
              (q_lin): Linear(in_features=768, out_features=768, bias=True)
              (k_lin): Linear(in_features=768, out_features=768, bias=True)
              (v_lin): Linear(in_features=768, out_features=768, bias=True)
              (out_lin): Linear(in_features=768, out_features=768, bias=True)
            )
            (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (ffn): FFN(
              (d

ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1268)prepare()
   1266             result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement))
   1267 
-> 1268         if tpu_should_fix_optimizer or (self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "TE"):
   1269             # 2. grabbing new model parameters
   1270             new_named_params = self._get_named_parameters(*result)



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1278)prepare()
   1276                     obj._switch_parameters(mapping)
   1277 
-> 1278         for item in result:
   1279             if any(
   1280                 item in container



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1279)prepare()
   1277 
   1278         for item in result:
-> 1279             if any(
   1280                 item in container
   1281                 for container in (self._dataloaders, self._models, self._optimizers, self._schedulers)



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1281)prepare()
   1279             if any(
   1280                 item in container
-> 1281                 for container in (self._dataloaders, self._models, self._optimizers, self._schedulers)
   1282             ):
   1283                 item._is_accelerate_prepared = True



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1279)prepare()
   1277 
   1278         for item in result:
-> 1279             if any(
   1280                 item in container
   1281                 for container in (self._dataloaders, self._models, self._optimizers, self._schedulers)



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1283)prepare()
   1281                 for container in (self._dataloaders, self._models, self._optimizers, self._schedulers)
   1282             ):
-> 1283                 item._is_accelerate_prepared = True
   1284 
   1285         return result if len(result) > 1 else result[0]



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1278)prepare()
   1276                     obj._switch_parameters(mapping)
   1277 
-> 1278         for item in result:
   1279             if any(
   1280                 item in container



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1279)prepare()
   1277 
   1278         for item in result:
-> 1279             if any(
   1280                 item in container
   1281                 for container in (self._dataloaders, self._models, self._optimizers, self._schedulers)



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1281)prepare()
   1279             if any(
   1280                 item in container
-> 1281                 for container in (self._dataloaders, self._models, self._optimizers, self._schedulers)
   1282             ):
   1283                 item._is_accelerate_prepared = True



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1279)prepare()
   1277 
   1278         for item in result:
-> 1279             if any(
   1280                 item in container
   1281                 for container in (self._dataloaders, self._models, self._optimizers, self._schedulers)



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1283)prepare()
   1281                 for container in (self._dataloaders, self._models, self._optimizers, self._schedulers)
   1282             ):
-> 1283                 item._is_accelerate_prepared = True
   1284 
   1285         return result if len(result) > 1 else result[0]



ipdb>  item


AcceleratedOptimizer (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 5e-05
    lr: 5e-05
    maximize: False
    weight_decay: 0.01

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 5e-05
    lr: 5e-05
    maximize: False
    weight_decay: 0.0
)


ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1278)prepare()
   1276                     obj._switch_parameters(mapping)
   1277 
-> 1278         for item in result:
   1279             if any(
   1280                 item in container



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1285)prepare()
   1283                 item._is_accelerate_prepared = True
   1284 
-> 1285         return result if len(result) > 1 else result[0]
   1286 
   1287     def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, evaluation_mode: bool = False):



ipdb>  n


--Return--
(DBT012(
  (en... )
    )
  )
), AcceleratedOp...t_decay: 0.0
))
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1285)prepare()
   1283                 item._is_accelerate_prepared = True
   1284 
-> 1285         return result if len(result) > 1 else result[0]
   1286 
   1287     def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, evaluation_mode: bool = False):



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(168)_inner_training_loop()
    166             )
    167 
--> 168     if self.is_fsdp_enabled:
    169         self.model = self.model_wrapped = model
    170 



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(172)_inner_training_loop()
    170 
    171     # for the rest of this function `model` is the outside model, whether it was wrapped or not
--> 172     if model is not self.model:
    173         self.model_wrapped = model
    174 



ipdb>  model


DBT012(
  (encoder): DBT012Encoder(
    (distilbert): DistilBertModel(
      (embeddings): Embeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (transformer): Transformer(
        (layer): ModuleList(
          (0-5): 6 x TransformerBlock(
            (attention): MultiHeadSelfAttention(
              (dropout): Dropout(p=0.1, inplace=False)
              (q_lin): Linear(in_features=768, out_features=768, bias=True)
              (k_lin): Linear(in_features=768, out_features=768, bias=True)
              (v_lin): Linear(in_features=768, out_features=768, bias=True)
              (out_lin): Linear(in_features=768, out_features=768, bias=True)
            )
            (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (ffn): FFN(
              (dr

ipdb>  model is self.model


True


ipdb>  n


> /tmp/ipykernel_881/2079912209.py(176)_inner_training_loop()
    174 
    175     # backward compatibility
--> 176     if self.is_deepspeed_enabled:
    177         self.deepspeed = self.model_wrapped
    178 



ipdb>  


> /tmp/ipykernel_881/2079912209.py(180)_inner_training_loop()
    178 
    179     # ckpt loading
--> 180     if resume_from_checkpoint is not None:
    181         if self.is_deepspeed_enabled:
    182             deepspeed_load_checkpoint(



ipdb>  


> /tmp/ipykernel_881/2079912209.py(189)_inner_training_loop()
    187 
    188     # Check if saved optimizer or scheduler states exist
--> 189     self._load_optimizer_and_scheduler(resume_from_checkpoint)
    190 
    191     # important: at this point:



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(197)_inner_training_loop()
    195 
    196     # Train!
--> 197     logger.info("***** Running training *****")
    198     logger.info(f"  Num examples = {num_examples:,}")
    199     logger.info(f"  Num Epochs = {num_train_epochs:,}")



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(198)_inner_training_loop()
    196     # Train!
    197     logger.info("***** Running training *****")
--> 198     logger.info(f"  Num examples = {num_examples:,}")
    199     logger.info(f"  Num Epochs = {num_train_epochs:,}")
    200     logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")



ipdb>  self.optimizer


AcceleratedOptimizer (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 5e-05
    lr: 5e-05
    maximize: False
    weight_decay: 0.01

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 5e-05
    lr: 5e-05
    maximize: False
    weight_decay: 0.0
)


ipdb>  self.lr_scheduler


<torch.optim.lr_scheduler.LambdaLR object>


ipdb>  self.lr_scheduler.state_dict()


{'base_lrs': [5e-05, 5e-05], 'last_epoch': 0, 'verbose': False, '_step_count': 1, '_get_lr_called_within_step': False, '_last_lr': [5e-05, 5e-05], 'lr_lambdas': [{}, {}]}


ipdb>  n


> /tmp/ipykernel_881/2079912209.py(199)_inner_training_loop()
    197     logger.info("***** Running training *****")
    198     logger.info(f"  Num examples = {num_examples:,}")
--> 199     logger.info(f"  Num Epochs = {num_train_epochs:,}")
    200     logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
    201     if self.args.per_device_train_batch_size != self._train_batch_size:



ipdb>  n


> /tmp/ipykernel_881/2079912209.py(200)_inner_training_loop()
    198     logger.info(f"  Num examples = {num_examples:,}")
    199     logger.info(f"  Num Epochs = {num_train_epochs:,}")
--> 200     logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
    201     if self.args.per_device_train_batch_size != self._train_batch_size:
    202         logger.info(f"  Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")



ipdb>  


> /tmp/ipykernel_881/2079912209.py(201)_inner_training_loop()
    199     logger.info(f"  Num Epochs = {num_train_epochs:,}")
    200     logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
--> 201     if self.args.per_device_train_batch_size != self._train_batch_size:
    202         logger.info(f"  Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
    203     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")



ipdb>  


> /tmp/ipykernel_881/2079912209.py(202)_inner_training_loop()
    200     logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
    201     if self.args.per_device_train_batch_size != self._train_batch_size:
--> 202         logger.info(f"  Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
    203     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
    204     logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")



ipdb>  


> /tmp/ipykernel_881/2079912209.py(203)_inner_training_loop()
    201     if self.args.per_device_train_batch_size != self._train_batch_size:
    202         logger.info(f"  Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
--> 203     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
    204     logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    205     logger.info(f"  Total optimization steps = {max_steps:,}")



ipdb>  


> /tmp/ipykernel_881/2079912209.py(204)_inner_training_loop()
    202         logger.info(f"  Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
    203     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
--> 204     logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    205     logger.info(f"  Total optimization steps = {max_steps:,}")
    206     logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")



ipdb>  


> /tmp/ipykernel_881/2079912209.py(205)_inner_training_loop()
    203     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
    204     logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
--> 205     logger.info(f"  Total optimization steps = {max_steps:,}")
    206     logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
    207 



ipdb>  


> /tmp/ipykernel_881/2079912209.py(206)_inner_training_loop()
    204     logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    205     logger.info(f"  Total optimization steps = {max_steps:,}")
--> 206     logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
    207 
    208     self.state.epoch = 0



ipdb>  


> /tmp/ipykernel_881/2079912209.py(208)_inner_training_loop()
    206     logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
    207 
--> 208     self.state.epoch = 0
    209     start_time = time.time()
    210     epochs_trained = 0



ipdb>  


> /tmp/ipykernel_881/2079912209.py(209)_inner_training_loop()
    207 
    208     self.state.epoch = 0
--> 209     start_time = time.time()
    210     epochs_trained = 0
    211     steps_trained_in_current_epoch = 0



ipdb>  


> /tmp/ipykernel_881/2079912209.py(210)_inner_training_loop()
    208     self.state.epoch = 0
    209     start_time = time.time()
--> 210     epochs_trained = 0
    211     steps_trained_in_current_epoch = 0
    212     steps_trained_progress_bar = None



ipdb>  


> /tmp/ipykernel_881/2079912209.py(211)_inner_training_loop()
    209     start_time = time.time()
    210     epochs_trained = 0
--> 211     steps_trained_in_current_epoch = 0
    212     steps_trained_progress_bar = None
    213 



ipdb>  


> /tmp/ipykernel_881/2079912209.py(212)_inner_training_loop()
    210     epochs_trained = 0
    211     steps_trained_in_current_epoch = 0
--> 212     steps_trained_progress_bar = None
    213 
    214     # Check if continuing training from a checkpoint



ipdb>  


> /tmp/ipykernel_881/2079912209.py(215)_inner_training_loop()
    213 
    214     # Check if continuing training from a checkpoint
--> 215     if resume_from_checkpoint is not None and os.path.isfile(
    216         os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
    217     ):



ipdb>  


> /tmp/ipykernel_881/2079912209.py(236)_inner_training_loop()
    234 
    235     # Update the references
--> 236     self.callback_handler.model = self.model
    237     self.callback_handler.optimizer = self.optimizer
    238     self.callback_handler.lr_scheduler = self.lr_scheduler



ipdb>  


> /tmp/ipykernel_881/2079912209.py(237)_inner_training_loop()
    235     # Update the references
    236     self.callback_handler.model = self.model
--> 237     self.callback_handler.optimizer = self.optimizer
    238     self.callback_handler.lr_scheduler = self.lr_scheduler
    239     self.callback_handler.train_dataloader = train_dataloader



ipdb>  


> /tmp/ipykernel_881/2079912209.py(238)_inner_training_loop()
    236     self.callback_handler.model = self.model
    237     self.callback_handler.optimizer = self.optimizer
--> 238     self.callback_handler.lr_scheduler = self.lr_scheduler
    239     self.callback_handler.train_dataloader = train_dataloader
    240     if self.hp_name is not None and self._trial is not None:



ipdb>  


> /tmp/ipykernel_881/2079912209.py(239)_inner_training_loop()
    237     self.callback_handler.optimizer = self.optimizer
    238     self.callback_handler.lr_scheduler = self.lr_scheduler
--> 239     self.callback_handler.train_dataloader = train_dataloader
    240     if self.hp_name is not None and self._trial is not None:
    241         # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial



ipdb>  


> /tmp/ipykernel_881/2079912209.py(240)_inner_training_loop()
    238     self.callback_handler.lr_scheduler = self.lr_scheduler
    239     self.callback_handler.train_dataloader = train_dataloader
--> 240     if self.hp_name is not None and self._trial is not None:
    241         # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
    242         # parameter to Train when using DDP.



ipdb>  


> /tmp/ipykernel_881/2079912209.py(244)_inner_training_loop()
    242         # parameter to Train when using DDP.
    243         self.state.trial_name = self.hp_name(self._trial)
--> 244     if trial is not None:
    245         assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
    246         self.state.trial_params = hp_params(assignments)



ipdb>  q


In [None]:
#| hide
o = learn.predict(learn.eval_dataset)

> /tmp/ipykernel_25061/3536802410.py(57)predict()
     55         import pdb; pdb.set_trace()
     56         #debug
---> 57         gen_kwargs = gen_kwargs.copy()
     58         if gen_kwargs.get("length_penalty") is None and self.args.generation_length_penalty is not None:
     59             gen_kwargs["length_penalty"] = self.args.generation_length_penalty



ipdb>  n


> /tmp/ipykernel_25061/3536802410.py(58)predict()
     56         #debug
     57         gen_kwargs = gen_kwargs.copy()
---> 58         if gen_kwargs.get("length_penalty") is None and self.args.generation_length_penalty is not None:
     59             gen_kwargs["length_penalty"] = self.args.generation_length_penalty
     60         if gen_kwargs.get("gen_num_beams") is None and self.args.generation_num_beams is not None:



ipdb>  


> /tmp/ipykernel_25061/3536802410.py(59)predict()
     57         gen_kwargs = gen_kwargs.copy()
     58         if gen_kwargs.get("length_penalty") is None and self.args.generation_length_penalty is not None:
---> 59             gen_kwargs["length_penalty"] = self.args.generation_length_penalty
     60         if gen_kwargs.get("gen_num_beams") is None and self.args.generation_num_beams is not None:
     61             gen_kwargs["gen_num_beams"] = self.args.generation_num_beams



ipdb>  


> /tmp/ipykernel_25061/3536802410.py(60)predict()
     58         if gen_kwargs.get("length_penalty") is None and self.args.generation_length_penalty is not None:
     59             gen_kwargs["length_penalty"] = self.args.generation_length_penalty
---> 60         if gen_kwargs.get("gen_num_beams") is None and self.args.generation_num_beams is not None:
     61             gen_kwargs["gen_num_beams"] = self.args.generation_num_beams
     62         if gen_kwargs.get("repr_num_beams") is None and self.args.representation_num_beams is not None:



ipdb>  


> /tmp/ipykernel_25061/3536802410.py(61)predict()
     59             gen_kwargs["length_penalty"] = self.args.generation_length_penalty
     60         if gen_kwargs.get("gen_num_beams") is None and self.args.generation_num_beams is not None:
---> 61             gen_kwargs["gen_num_beams"] = self.args.generation_num_beams
     62         if gen_kwargs.get("repr_num_beams") is None and self.args.representation_num_beams is not None:
     63             gen_kwargs["repr_num_beams"] = self.args.representation_num_beams



ipdb>  


> /tmp/ipykernel_25061/3536802410.py(62)predict()
     60         if gen_kwargs.get("gen_num_beams") is None and self.args.generation_num_beams is not None:
     61             gen_kwargs["gen_num_beams"] = self.args.generation_num_beams
---> 62         if gen_kwargs.get("repr_num_beams") is None and self.args.representation_num_beams is not None:
     63             gen_kwargs["repr_num_beams"] = self.args.representation_num_beams
     64         if gen_kwargs.get("aug_num_beams") is None and self.args.augmentation_num_beams is not None:



ipdb>  


> /tmp/ipykernel_25061/3536802410.py(63)predict()
     61             gen_kwargs["gen_num_beams"] = self.args.generation_num_beams
     62         if gen_kwargs.get("repr_num_beams") is None and self.args.representation_num_beams is not None:
---> 63             gen_kwargs["repr_num_beams"] = self.args.representation_num_beams
     64         if gen_kwargs.get("aug_num_beams") is None and self.args.augmentation_num_beams is not None:
     65             gen_kwargs["aug_num_beams"] = self.args.augmentation_num_beams



ipdb>  


> /tmp/ipykernel_25061/3536802410.py(64)predict()
     62         if gen_kwargs.get("repr_num_beams") is None and self.args.representation_num_beams is not None:
     63             gen_kwargs["repr_num_beams"] = self.args.representation_num_beams
---> 64         if gen_kwargs.get("aug_num_beams") is None and self.args.augmentation_num_beams is not None:
     65             gen_kwargs["aug_num_beams"] = self.args.augmentation_num_beams
     66 



ipdb>  


> /tmp/ipykernel_25061/3536802410.py(65)predict()
     63             gen_kwargs["repr_num_beams"] = self.args.representation_num_beams
     64         if gen_kwargs.get("aug_num_beams") is None and self.args.augmentation_num_beams is not None:
---> 65             gen_kwargs["aug_num_beams"] = self.args.augmentation_num_beams
     66 
     67         self.gather_function, self._gen_kwargs = self.accelerator.gather, gen_kwargs



ipdb>  


> /tmp/ipykernel_25061/3536802410.py(67)predict()
     65             gen_kwargs["aug_num_beams"] = self.args.augmentation_num_beams
     66 
---> 67         self.gather_function, self._gen_kwargs = self.accelerator.gather, gen_kwargs
     68         self._memory_tracker.start()
     69 



ipdb>  


> /tmp/ipykernel_25061/3536802410.py(68)predict()
     66 
     67         self.gather_function, self._gen_kwargs = self.accelerator.gather, gen_kwargs
---> 68         self._memory_tracker.start()
     69 
     70         if self._perform_representation(unwrap_model(self.model)) and not self.args.prediction_loss_only:



ipdb>  


> /tmp/ipykernel_25061/3536802410.py(70)predict()
     68         self._memory_tracker.start()
     69 
---> 70         if self._perform_representation(unwrap_model(self.model)) and not self.args.prediction_loss_only:
     71             self._build_lbl_index(test_dataset)
     72 



ipdb>  


> /tmp/ipykernel_25061/3536802410.py(71)predict()
     69 
     70         if self._perform_representation(unwrap_model(self.model)) and not self.args.prediction_loss_only:
---> 71             self._build_lbl_index(test_dataset)
     72 
     73         if self._perform_augmentation(unwrap_model(self.model)) and not self.args.prediction_loss_only:



ipdb>  s


--Call--
> /tmp/ipykernel_25061/4131283268.py(25)_build_lbl_index()
     23         self.aug_idxs.build(aug_repr)
     24 
---> 25 @patch
     26 def _build_lbl_index(self:XCLearner, dataset:Optional[Dataset]=None):
     27     dataset = dataset if self.eval_dataset is None else self.eval_dataset



ipdb>  n


> /tmp/ipykernel_25061/4131283268.py(27)_build_lbl_index()
     25 @patch
     26 def _build_lbl_index(self:XCLearner, dataset:Optional[Dataset]=None):
---> 27     dataset = dataset if self.eval_dataset is None else self.eval_dataset
     28     dataset = dataset if self.train_dataset is None else self.train_dataset
     29 



ipdb>  


> /tmp/ipykernel_25061/4131283268.py(28)_build_lbl_index()
     26 def _build_lbl_index(self:XCLearner, dataset:Optional[Dataset]=None):
     27     dataset = dataset if self.eval_dataset is None else self.eval_dataset
---> 28     dataset = dataset if self.train_dataset is None else self.train_dataset
     29 
     30     if dataset is not None:



ipdb>  


> /tmp/ipykernel_25061/4131283268.py(30)_build_lbl_index()
     28     dataset = dataset if self.train_dataset is None else self.train_dataset
     29 
---> 30     if dataset is not None:
     31         lbl_dset = dataset.lbl_dset
     32 



ipdb>  


> /tmp/ipykernel_25061/4131283268.py(31)_build_lbl_index()
     29 
     30     if dataset is not None:
---> 31         lbl_dset = dataset.lbl_dset
     32 
     33         meta_name = f'{self.args.data_aug_meta_name}_meta' if self.args.data_aug_meta_name is not None else None



ipdb>  


> /tmp/ipykernel_25061/4131283268.py(33)_build_lbl_index()
     31         lbl_dset = dataset.lbl_dset
     32 
---> 33         meta_name = f'{self.args.data_aug_meta_name}_meta' if self.args.data_aug_meta_name is not None else None
     34         if meta_name is not None and dataset.meta is not None and meta_name in dataset.meta:
     35             prefix,lbl_meta,meta_info  = dataset.meta[meta_name].prefix,dataset.meta[meta_name].lbl_meta,dataset.meta[meta_name].meta_info



ipdb>  lbl_dset.n_data = 100
ipdb>  n


> /tmp/ipykernel_25061/4131283268.py(34)_build_lbl_index()
     32 
     33         meta_name = f'{self.args.data_aug_meta_name}_meta' if self.args.data_aug_meta_name is not None else None
---> 34         if meta_name is not None and dataset.meta is not None and meta_name in dataset.meta:
     35             prefix,lbl_meta,meta_info  = dataset.meta[meta_name].prefix,dataset.meta[meta_name].lbl_meta,dataset.meta[meta_name].meta_info
     36             meta_kwargs = {meta_name: MetaXCDataset(prefix, lbl_meta, lbl_meta, meta_info, n_data_meta_samples=self.args.augmentation_num_beams)}



ipdb>  


> /tmp/ipykernel_25061/4131283268.py(39)_build_lbl_index()
     37             lbl_dset = XCDataset(lbl_dset, **meta_kwargs)
     38 
---> 39         lbl_dl = self.get_test_dataloader(lbl_dset)
     40         lbl_repr = self.get_representation(lbl_dl, to_cpu=isinstance(self.idxs, IndexSearch))
     41         if self.args.use_distributional_representation: lbl_repr = F.log_softmax(lbl_repr, dim=-1)



ipdb>  


> /tmp/ipykernel_25061/4131283268.py(40)_build_lbl_index()
     38 
     39         lbl_dl = self.get_test_dataloader(lbl_dset)
---> 40         lbl_repr = self.get_representation(lbl_dl, to_cpu=isinstance(self.idxs, IndexSearch))
     41         if self.args.use_distributional_representation: lbl_repr = F.log_softmax(lbl_repr, dim=-1)
     42 



ipdb>  


  0%|          | 0/1 [00:00<?, ?it/s]

> /tmp/ipykernel_25061/4131283268.py(41)_build_lbl_index()
     39         lbl_dl = self.get_test_dataloader(lbl_dset)
     40         lbl_repr = self.get_representation(lbl_dl, to_cpu=isinstance(self.idxs, IndexSearch))
---> 41         if self.args.use_distributional_representation: lbl_repr = F.log_softmax(lbl_repr, dim=-1)
     42 
     43         self.idxs.build(lbl_repr)



ipdb>  


> /tmp/ipykernel_25061/4131283268.py(43)_build_lbl_index()
     41         if self.args.use_distributional_representation: lbl_repr = F.log_softmax(lbl_repr, dim=-1)
     42 
---> 43         self.idxs.build(lbl_repr)
     44     else: raise ValueError('Failed to build `self.idxs`')
     45 



ipdb>  


--Return--
None
> /tmp/ipykernel_25061/4131283268.py(43)_build_lbl_index()
     41         if self.args.use_distributional_representation: lbl_repr = F.log_softmax(lbl_repr, dim=-1)
     42 
---> 43         self.idxs.build(lbl_repr)
     44     else: raise ValueError('Failed to build `self.idxs`')
     45 



ipdb>  


> /tmp/ipykernel_25061/3536802410.py(73)predict()
     71             self._build_lbl_index(test_dataset)
     72 
---> 73         if self._perform_augmentation(unwrap_model(self.model)) and not self.args.prediction_loss_only:
     74             self._build_aug_index(test_dataset)
     75 



ipdb>  


> /tmp/ipykernel_25061/3536802410.py(76)predict()
     74             self._build_aug_index(test_dataset)
     75 
---> 76         test_dataloader = self.get_test_dataloader(test_dataset)
     77         start_time = time.time()
     78 



ipdb>  


> /tmp/ipykernel_25061/3536802410.py(77)predict()
     75 
     76         test_dataloader = self.get_test_dataloader(test_dataset)
---> 77         start_time = time.time()
     78 
     79         output = self.evaluation_loop(test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)



ipdb>  


> /tmp/ipykernel_25061/3536802410.py(79)predict()
     77         start_time = time.time()
     78 
---> 79         output = self.evaluation_loop(test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
     80         total_batch_size = self.args.eval_batch_size * self.args.world_size
     81         if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:



ipdb>  s


--Call--
> /tmp/ipykernel_25061/1638443235.py(2)evaluation_loop()
      1 #| export
----> 2 @patch
      3 def evaluation_loop(
      4     self:XCLearner,
      5     dataloader:DataLoader,



ipdb>  n


> /tmp/ipykernel_25061/1638443235.py(13)evaluation_loop()
     11     metric_key_prefix:str="eval",
     12 ) -> XCEvalLoopOutput:
---> 13     args = self.args
     14     prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
     15 



ipdb>  


> /tmp/ipykernel_25061/1638443235.py(14)evaluation_loop()
     12 ) -> XCEvalLoopOutput:
     13     args = self.args
---> 14     prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
     15 
     16     model = self._wrap_model(self.model, training=False, dataloader=dataloader)



ipdb>  


> /tmp/ipykernel_25061/1638443235.py(16)evaluation_loop()
     14     prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
     15 
---> 16     model = self._wrap_model(self.model, training=False, dataloader=dataloader)
     17 
     18     if len(self.accelerator._models) == 0 and model is self.model:



ipdb>  


> /tmp/ipykernel_25061/1638443235.py(18)evaluation_loop()
     16     model = self._wrap_model(self.model, training=False, dataloader=dataloader)
     17 
---> 18     if len(self.accelerator._models) == 0 and model is self.model:
     19         model = self.accelerator.prepare(model) if self.is_deepspeed_enabled else self.accelerator.prepare_model(model, evaluation_mode=True)
     20         if self.is_fsdp_enabled: self.model = model



ipdb>  


> /tmp/ipykernel_25061/1638443235.py(24)evaluation_loop()
     22         if self.is_deepspeed_enabled: self.deepspeed = self.model_wrapped
     23 
---> 24     batch_size = self.args.eval_batch_size
     25 
     26     model.eval()



ipdb>  


> /tmp/ipykernel_25061/1638443235.py(26)evaluation_loop()
     24     batch_size = self.args.eval_batch_size
     25 
---> 26     model.eval()
     27 
     28     self.callback_handler.eval_dataloader = dataloader



ipdb>  


> /tmp/ipykernel_25061/1638443235.py(28)evaluation_loop()
     26     model.eval()
     27 
---> 28     self.callback_handler.eval_dataloader = dataloader
     29     eval_dataset = getattr(dataloader, "dataset", None)
     30 



ipdb>  


> /tmp/ipykernel_25061/1638443235.py(29)evaluation_loop()
     27 
     28     self.callback_handler.eval_dataloader = dataloader
---> 29     eval_dataset = getattr(dataloader, "dataset", None)
     30 
     31     if args.past_index >= 0: self._past = None



ipdb>  


> /tmp/ipykernel_25061/1638443235.py(31)evaluation_loop()
     29     eval_dataset = getattr(dataloader, "dataset", None)
     30 
---> 31     if args.past_index >= 0: self._past = None
     32 
     33     losses_host, all_losses = None, None



ipdb>  


> /tmp/ipykernel_25061/1638443235.py(33)evaluation_loop()
     31     if args.past_index >= 0: self._past = None
     32 
---> 33     losses_host, all_losses = None, None
     34     host_output, all_output = {}, {}
     35 



ipdb>  n


> /tmp/ipykernel_25061/1638443235.py(34)evaluation_loop()
     32 
     33     losses_host, all_losses = None, None
---> 34     host_output, all_output = {}, {}
     35 
     36     observed_num_examples = 0



ipdb>  n


> /tmp/ipykernel_25061/1638443235.py(36)evaluation_loop()
     34     host_output, all_output = {}, {}
     35 
---> 36     observed_num_examples = 0
     37     for step, inputs in enumerate(dataloader):
     38         observed_batch_size = find_batch_size(inputs)



ipdb>  


> /tmp/ipykernel_25061/1638443235.py(37)evaluation_loop()
     35 
     36     observed_num_examples = 0
---> 37     for step, inputs in enumerate(dataloader):
     38         observed_batch_size = find_batch_size(inputs)
     39         if observed_batch_size is not None:



ipdb>  


> /tmp/ipykernel_25061/1638443235.py(38)evaluation_loop()
     36     observed_num_examples = 0
     37     for step, inputs in enumerate(dataloader):
---> 38         observed_batch_size = find_batch_size(inputs)
     39         if observed_batch_size is not None:
     40             observed_num_examples += observed_batch_size



ipdb>  n


> /tmp/ipykernel_25061/1638443235.py(39)evaluation_loop()
     37     for step, inputs in enumerate(dataloader):
     38         observed_batch_size = find_batch_size(inputs)
---> 39         if observed_batch_size is not None:
     40             observed_num_examples += observed_batch_size
     41             if batch_size is None: batch_size = observed_batch_size



ipdb>  


> /tmp/ipykernel_25061/1638443235.py(40)evaluation_loop()
     38         observed_batch_size = find_batch_size(inputs)
     39         if observed_batch_size is not None:
---> 40             observed_num_examples += observed_batch_size
     41             if batch_size is None: batch_size = observed_batch_size
     42 



ipdb>  


> /tmp/ipykernel_25061/1638443235.py(41)evaluation_loop()
     39         if observed_batch_size is not None:
     40             observed_num_examples += observed_batch_size
---> 41             if batch_size is None: batch_size = observed_batch_size
     42 
     43         loss, output = self.prediction_step(model, inputs, prediction_loss_only, predict_with_generation, predict_with_representation, ignore_keys=ignore_keys)



ipdb>  


> /tmp/ipykernel_25061/1638443235.py(43)evaluation_loop()
     41             if batch_size is None: batch_size = observed_batch_size
     42 
---> 43         loss, output = self.prediction_step(model, inputs, prediction_loss_only, predict_with_generation, predict_with_representation, ignore_keys=ignore_keys)
     44 
     45         if loss is not None:



ipdb>  s


--Call--
> /tmp/ipykernel_25061/2932614700.py(2)prediction_step()
      1 #| export
----> 2 @patch
      3 def prediction_step(
      4     self:XCLearner,
      5     model: nn.Module,



ipdb>  n


> /tmp/ipykernel_25061/2932614700.py(14)prediction_step()
     12     **kwargs,
     13 ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
---> 14     with torch.no_grad():
     15         with self.compute_loss_context_manager(): outputs = model(**inputs)
     16         loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()



ipdb>  


> /tmp/ipykernel_25061/2932614700.py(15)prediction_step()
     13 ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
     14     with torch.no_grad():
---> 15         with self.compute_loss_context_manager(): outputs = model(**inputs)
     16         loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
     17     prediction_loss_only = self.args.prediction_loss_only if prediction_loss_only is None else prediction_loss_only



ipdb>  


> /tmp/ipykernel_25061/2932614700.py(16)prediction_step()
     14     with torch.no_grad():
     15         with self.compute_loss_context_manager(): outputs = model(**inputs)
---> 16         loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
     17     prediction_loss_only = self.args.prediction_loss_only if prediction_loss_only is None else prediction_loss_only
     18     if prediction_loss_only: return loss, {}



ipdb>  


> /tmp/ipykernel_25061/2932614700.py(17)prediction_step()
     15         with self.compute_loss_context_manager(): outputs = model(**inputs)
     16         loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
---> 17     prediction_loss_only = self.args.prediction_loss_only if prediction_loss_only is None else prediction_loss_only
     18     if prediction_loss_only: return loss, {}
     19 



ipdb>  


> /tmp/ipykernel_25061/2932614700.py(18)prediction_step()
     16         loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
     17     prediction_loss_only = self.args.prediction_loss_only if prediction_loss_only is None else prediction_loss_only
---> 18     if prediction_loss_only: return loss, {}
     19 
     20     if self._perform_augmentation(model, predict_with_augmentation):



ipdb>  


> /tmp/ipykernel_25061/2932614700.py(20)prediction_step()
     18     if prediction_loss_only: return loss, {}
     19 
---> 20     if self._perform_augmentation(model, predict_with_augmentation):
     21         aug_inputs = self.augmentation_output(model, inputs, **kwargs)
     22         inputs.update(aug_inputs)



ipdb>  


> /tmp/ipykernel_25061/2932614700.py(24)prediction_step()
     22         inputs.update(aug_inputs)
     23 
---> 24     output, gen_o, repr_o = None, None, None
     25     if self._perform_generation(model, predict_with_generation): gen_o = self.generation_output(model, inputs, **kwargs)
     26     if self._perform_representation(model, predict_with_representation): repr_o = self.representation_output(model, inputs, **kwargs)



ipdb>  


> /tmp/ipykernel_25061/2932614700.py(25)prediction_step()
     23 
     24     output, gen_o, repr_o = None, None, None
---> 25     if self._perform_generation(model, predict_with_generation): gen_o = self.generation_output(model, inputs, **kwargs)
     26     if self._perform_representation(model, predict_with_representation): repr_o = self.representation_output(model, inputs, **kwargs)
     27 



ipdb>  self._perform_generation(model, predict_with_generation)


False


ipdb>  n


> /tmp/ipykernel_25061/2932614700.py(26)prediction_step()
     24     output, gen_o, repr_o = None, None, None
     25     if self._perform_generation(model, predict_with_generation): gen_o = self.generation_output(model, inputs, **kwargs)
---> 26     if self._perform_representation(model, predict_with_representation): repr_o = self.representation_output(model, inputs, **kwargs)
     27 
     28     if gen_o is not None and repr_o is not None:



ipdb>  s


--Call--
> /tmp/ipykernel_25061/2264972374.py(8)_perform_representation()
      6     return getattr(model,'use_generation') if hasattr(model,'use_generation') else predict_with_generation
      7 
----> 8 @patch
      9 def _perform_representation(self:XCLearner, model:nn.Module, predict_with_representation:Optional[bool]=None):
     10     model = unwrap_model(model)



ipdb>  n


> /tmp/ipykernel_25061/2264972374.py(10)_perform_representation()
      8 @patch
      9 def _perform_representation(self:XCLearner, model:nn.Module, predict_with_representation:Optional[bool]=None):
---> 10     model = unwrap_model(model)
     11     predict_with_representation = self.args.predict_with_representation if predict_with_representation is None else predict_with_representation
     12     return getattr(model,'use_representation') if hasattr(model,'use_representation') else predict_with_representation



ipdb>  


> /tmp/ipykernel_25061/2264972374.py(11)_perform_representation()
      9 def _perform_representation(self:XCLearner, model:nn.Module, predict_with_representation:Optional[bool]=None):
     10     model = unwrap_model(model)
---> 11     predict_with_representation = self.args.predict_with_representation if predict_with_representation is None else predict_with_representation
     12     return getattr(model,'use_representation') if hasattr(model,'use_representation') else predict_with_representation
     13 



ipdb>  


> /tmp/ipykernel_25061/2264972374.py(12)_perform_representation()
     10     model = unwrap_model(model)
     11     predict_with_representation = self.args.predict_with_representation if predict_with_representation is None else predict_with_representation
---> 12     return getattr(model,'use_representation') if hasattr(model,'use_representation') else predict_with_representation
     13 
     14 @patch



ipdb>  


--Return--
True
> /tmp/ipykernel_25061/2264972374.py(12)_perform_representation()
     10     model = unwrap_model(model)
     11     predict_with_representation = self.args.predict_with_representation if predict_with_representation is None else predict_with_representation
---> 12     return getattr(model,'use_representation') if hasattr(model,'use_representation') else predict_with_representation
     13 
     14 @patch



ipdb>  n


--Call--
> /tmp/ipykernel_25061/3649135444.py(17)representation_output()
     15     return {'pred_idx':o['info2seq2data_idx'], 'pred_score':o['info2seq2data_score'], 'pred_ptr':o['info2seq2data_data2ptr']}
     16 
---> 17 @patch
     18 def representation_output(
     19     self:XCLearner,



ipdb>  n


> /tmp/ipykernel_25061/3649135444.py(24)representation_output()
     22     **kwargs
     23 ):
---> 24     inputs = self._prepare_inputs(inputs)
     25     n_bm = kwargs.pop("repr_num_beams") if "repr_num_beams" in kwargs and kwargs["repr_num_beams"] is not None else self.args.representation_num_beams
     26 



ipdb>  


> /tmp/ipykernel_25061/3649135444.py(25)representation_output()
     23 ):
     24     inputs = self._prepare_inputs(inputs)
---> 25     n_bm = kwargs.pop("repr_num_beams") if "repr_num_beams" in kwargs and kwargs["repr_num_beams"] is not None else self.args.representation_num_beams
     26 
     27     with torch.no_grad():



ipdb>  


> /tmp/ipykernel_25061/3649135444.py(27)representation_output()
     25     n_bm = kwargs.pop("repr_num_beams") if "repr_num_beams" in kwargs and kwargs["repr_num_beams"] is not None else self.args.representation_num_beams
     26 
---> 27     with torch.no_grad():
     28         o = getattr(model(**inputs), self.args.representation_attribute)
     29         if self.args.use_distributional_representation: o = F.softmax(o, dim=-1)



ipdb>  


> /tmp/ipykernel_25061/3649135444.py(28)representation_output()
     26 
     27     with torch.no_grad():
---> 28         o = getattr(model(**inputs), self.args.representation_attribute)
     29         if self.args.use_distributional_representation: o = F.softmax(o, dim=-1)
     30 



ipdb>  


> /tmp/ipykernel_25061/3649135444.py(29)representation_output()
     27     with torch.no_grad():
     28         o = getattr(model(**inputs), self.args.representation_attribute)
---> 29         if self.args.use_distributional_representation: o = F.softmax(o, dim=-1)
     30 
     31     o = self.idxs.proc(o, n_bm=n_bm)



ipdb>  !o.shape


torch.Size([128, 768])


ipdb>  !torch.norm(o, dim=-1)


tensor([22.6966, 23.2682, 23.3030, 24.4900, 23.1127, 21.4881, 24.3592, 23.7730,
        22.1580, 23.4856, 23.3100, 24.3136, 22.0395, 22.7433, 21.3868, 23.5914,
        22.4485, 23.3996, 22.8410, 22.8833, 24.6456, 21.8179, 22.1863, 22.9824,
        20.1039, 24.9278, 23.6765, 22.9878, 22.6013, 22.5656, 24.4056, 23.0457,
        22.7467, 23.3570, 21.3748, 24.7137, 25.2308, 22.5590, 23.6726, 24.1197,
        21.4324, 21.6063, 25.1573, 23.4217, 24.4976, 24.5576, 23.0131, 23.7130,
        24.1785, 24.1275, 24.5175, 22.9784, 24.8105, 21.2234, 22.0189, 22.3022,
        23.5721, 24.2453, 23.7139, 21.0918, 23.4735, 21.7098, 22.3643, 23.7718,
        23.9011, 24.5701, 22.8404, 22.1328, 22.8997, 23.3237, 22.6347, 22.0182,
        22.0808, 24.8773, 24.0185, 20.7367, 21.6830, 23.9921, 23.2274, 23.4083,
        20.5668, 20.0808, 24.5315, 21.2350, 23.9680, 24.0947, 23.7745, 23.2031,
        24.5928, 21.9048, 24.2429, 24.7638, 23.1487, 24.4555, 24.8517, 22.0793,
        24.2304, 24.1107, 23.9684, 22.42

ipdb>  l


     24     inputs = self._prepare_inputs(inputs)
     25     n_bm = kwargs.pop("repr_num_beams") if "repr_num_beams" in kwargs and kwargs["repr_num_beams"] is not None else self.args.representation_num_beams
     26 
     27     with torch.no_grad():
     28         o = getattr(model(**inputs), self.args.representation_attribute)
---> 29         if self.args.use_distributional_representation: o = F.softmax(o, dim=-1)
     30 
     31     o = self.idxs.proc(o, n_bm=n_bm)
     32 
     33     return {'pred_idx':o['info2data_idx'], 'pred_score':o['info2data_score'], 'pred_ptr':o['info2data_data2ptr']}
     34 



ipdb>  n


> /tmp/ipykernel_25061/3649135444.py(31)representation_output()
     29         if self.args.use_distributional_representation: o = F.softmax(o, dim=-1)
     30 
---> 31     o = self.idxs.proc(o, n_bm=n_bm)
     32 
     33     return {'pred_idx':o['info2data_idx'], 'pred_score':o['info2data_score'], 'pred_ptr':o['info2data_data2ptr']}



ipdb>  n


> /tmp/ipykernel_25061/3649135444.py(33)representation_output()
     31     o = self.idxs.proc(o, n_bm=n_bm)
     32 
---> 33     return {'pred_idx':o['info2data_idx'], 'pred_score':o['info2data_score'], 'pred_ptr':o['info2data_data2ptr']}
     34 
     35 @patch



ipdb>  


--Return--
{'pred_idx': tensor([69,  ...,  9, 34, 53]), 'pred_ptr': tensor([5, 5,..., 5, 5, 5, 5]), 'pred_score': tensor([-6.85...528, -6.8610])}
> /tmp/ipykernel_25061/3649135444.py(33)representation_output()
     31     o = self.idxs.proc(o, n_bm=n_bm)
     32 
---> 33     return {'pred_idx':o['info2data_idx'], 'pred_score':o['info2data_score'], 'pred_ptr':o['info2data_data2ptr']}
     34 
     35 @patch



ipdb>  


> /tmp/ipykernel_25061/2932614700.py(28)prediction_step()
     26     if self._perform_representation(model, predict_with_representation): repr_o = self.representation_output(model, inputs, **kwargs)
     27 
---> 28     if gen_o is not None and repr_o is not None:
     29         output = {f'{k}_gen':v for k,v in gen_o.items()}
     30         output.update({f'{k}_repr':v for k,v in repr_o.items()})



ipdb>  


> /tmp/ipykernel_25061/2932614700.py(33)prediction_step()
     31         output.update(self.concatenate_output(gen_o, repr_o))
     32     else:
---> 33         output = gen_o if repr_o is None else repr_o
     34 
     35     labels = {'targ_idx':inputs[self.args.target_indices_key], 'targ_ptr':inputs[self.args.target_pointer_key]} if self.args.target_indices_key in inputs else None



ipdb>  


> /tmp/ipykernel_25061/2932614700.py(35)prediction_step()
     33         output = gen_o if repr_o is None else repr_o
     34 
---> 35     labels = {'targ_idx':inputs[self.args.target_indices_key], 'targ_ptr':inputs[self.args.target_pointer_key]} if self.args.target_indices_key in inputs else None
     36     if labels is not None: output.update(labels)
     37 



ipdb>  


> /tmp/ipykernel_25061/2932614700.py(36)prediction_step()
     34 
     35     labels = {'targ_idx':inputs[self.args.target_indices_key], 'targ_ptr':inputs[self.args.target_pointer_key]} if self.args.target_indices_key in inputs else None
---> 36     if labels is not None: output.update(labels)
     37 
     38     return loss, output



ipdb>  


> /tmp/ipykernel_25061/2932614700.py(38)prediction_step()
     35     labels = {'targ_idx':inputs[self.args.target_indices_key], 'targ_ptr':inputs[self.args.target_pointer_key]} if self.args.target_indices_key in inputs else None
     36     if labels is not None: output.update(labels)
     37 
---> 38     return loss, output
     39 



ipdb>  


--Return--
(tensor(0.0429...vice='cuda:0'), {'pred_idx': tensor([69,  ...,  9, 34, 53]), 'pred_ptr': tensor([5, 5,..., 5, 5, 5, 5]), 'pred_score': tensor([-6.85...528, -6.8610]), 'targ_idx': tensor([ 6030...vice='cuda:0'), ...})
> /tmp/ipykernel_25061/2932614700.py(38)prediction_step()
     35     labels = {'targ_idx':inputs[self.args.target_indices_key], 'targ_ptr':inputs[self.args.target_pointer_key]} if self.args.target_indices_key in inputs else None
     36     if labels is not None: output.update(labels)
     37 
---> 38     return loss, output
     39 



ipdb>  


> /tmp/ipykernel_25061/1638443235.py(45)evaluation_loop()
     43         loss, output = self.prediction_step(model, inputs, prediction_loss_only, predict_with_generation, predict_with_representation, ignore_keys=ignore_keys)
     44 
---> 45         if loss is not None:
     46             losses = self.gather_function((loss.repeat(batch_size)))
     47             losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100)



ipdb>  


> /tmp/ipykernel_25061/1638443235.py(46)evaluation_loop()
     44 
     45         if loss is not None:
---> 46             losses = self.gather_function((loss.repeat(batch_size)))
     47             losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100)
     48         for k in output: host_output[k] = self._gather_host_output(output[k], host_output.get(k, None))



ipdb>  


> /tmp/ipykernel_25061/1638443235.py(47)evaluation_loop()
     45         if loss is not None:
     46             losses = self.gather_function((loss.repeat(batch_size)))
---> 47             losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100)
     48         for k in output: host_output[k] = self._gather_host_output(output[k], host_output.get(k, None))
     49 



ipdb>  c


  self._set_arrayXarray(i, j, x)



Program interrupted. (Use 'cont' to resume).
--Call--
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/llvmlite/binding/ffi.py(78)__exit__()
     76             acq_fn()
     77 
---> 78     def __exit__(self, *exc_details):
     79         # Invoke all callbacks
     80         for acq_fn, rel_fn in self._cblist:



ipdb>  q


TypingError: Failed in nopython mode pipeline (step: Preprocessing for parfors)
[1mNo implementation of function Function(<class 'numpy.dtype'>) found for signature:
 
 >>> dtype(Literal[str](int32))
 
There are 2 candidate implementations:
[1m - Of which 1 did not match due to:
 Overload in function 'numpy_dtype': File: numba/np/npyimpl.py: Line 622.
   With argument(s): '(unicode_type)':[0m
[1m  Rejected as the implementation raised a specific error:
    NumbaTypeError: [1munknown dtype descriptor: unicode_type[0m[0m
  raised from /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/numba/np/npyimpl.py:631
[1m - Of which 1 did not match due to:
 Overload in function 'numpy_dtype': File: numba/np/npyimpl.py: Line 622.
   With argument(s): '(Literal[str](int32))':[0m
[1m  Rejected as the implementation raised a specific error:
    BdbQuit: Failed in nopython mode pipeline (step: native lowering)
[0m
  raised from /home/scai/phd/aiz218323/scratch/anaconda3/envs/xc_nlg_2/lib/python3.9/bdb.py:135
[0m