In [40]:
#| default_exp 44_ngame-gpt-conflated-entity-oracle-for-msmarco

In [41]:
# Auto-reload edited library modules (e.g. xcai) before each cell executes.
%reload_ext autoreload
%autoreload 2

In [42]:
from nbdev.showdoc import *
# Regenerate the exported library code from the `#| export` cells (nbdev).
import nbdev; nbdev.nbdev_export()

In [43]:
#| export
import os
# Restrict visible GPUs. This must run before torch is imported so CUDA initialisation sees it.
# NOTE(review): this cell is `#| export`ed, so importing the generated module also sets this env var.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

import torch,json, torch.multiprocessing as mp, joblib, numpy as np, scipy.sparse as sp

# NOTE(review): star import. `build_block`/`get_pkl_file` (and possibly `data_dir`/`config_key`,
# which no visible cell defines) presumably come from here; confirm so Restart & Run All works.
from xcai.basics import *
from xcai.models.PPP0XX import DBT009

In [44]:
# Disable Weights & Biases for this interactive session (this cell is not exported).
os.environ.update({'WANDB_MODE': 'disabled'})

In [45]:
#| export
# W&B project used by runs of the exported module.
# NOTE(review): the project name says msmarco, but this notebook builds wikiseealsotitles data; confirm.
os.environ.update({'WANDB_PROJECT': 'mogicX_00-msmarco-07'})

## Code: prototype augmenting data text with category metadata

In [9]:
# Resolve the cached block file for WikiSeeAlsoTitles with the DistilBERT tokenizer.
# NOTE(review): `data_dir` is not defined in any visible cell; confirm where it comes from.
pkl_dir = "{}/processed/mogicX".format(data_dir)
pkl_file = get_pkl_file(
    pkl_dir,
    'wikiseealsotitles_data-meta_distilbert-base-uncased',
    use_only_test=False,
    use_exact=False,
    use_sxc_sampler=True,
)

In [10]:
# Display the resolved cache path (redact absolute local paths before sharing outputs).
pkl_file

'/Users/suchith720/Projects/data//processed/mogicX/wikiseealsotitles_data-meta_distilbert-base-uncased_sxc.joblib'

In [11]:
# Build the train/test block from the cached file.
# NOTE(review): the third positional `True` is presumably `use_sxc` (the Setup section passes
# `use_sxc=True` by keyword); prefer the keyword form. `config_key` is not defined in any visible cell.
block = build_block(pkl_file, 'wikiseealsotitles', True, config_key=config_key, only_test=False, main_oversample=False, 
                    n_slbl_samples=1, do_build=True, data_dir=data_dir)



In [12]:
# Fields held by the training data info (raw text plus tokenized ids/mask, per the output below).
block.train.dset.data.data_info.keys()

dict_keys(['identifier', 'input_text', 'input_ids', 'attention_mask'])

In [18]:
# NOTE(review): move this import to the top imports cell.
from transformers import AutoTokenizer
tokz = AutoTokenizer.from_pretrained('distilbert-base-uncased')

# Decode one label's token ids back to text to sanity-check tokenization.
tokz.decode(block.train.dset.data.lbl_info['input_ids'][100])

'[CLS] alt codes [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [20]:
import inspect
from tqdm.auto import tqdm
from typing import Union, Dict, List, Optional, Callable

from xcai.core import Info, load_config
from xcai.block import CFGS
from xcai.data import XCDataset
from xcai.sdata import SXCDataset

In [21]:
def get_config(config:Union[str, Dict], config_key:Optional[str]=None, data_dir:Optional[str]=None, **kwargs):
    """Resolve `config` into a configuration dict and apply parameter overrides.

    `config` may be:
      * a path to an existing file: loaded with `load_config(config, config_key)`;
      * any other string: used as a key into `CFGS`, i.e. `CFGS[config](data_dir)[config_key]`;
      * a dict: used directly. The outer dict and its 'parameters' dict are copied,
        so the overrides below never mutate the caller's object.

    Each keyword in `kwargs` whose name is a key of `config['parameters']`
    overrides that parameter. Keywords that match no parameter are ignored.

    Raises:
        ValueError: if `config` is neither a string nor a dict.
    """
    if isinstance(config, str) and os.path.exists(config):
        config = load_config(config, config_key)
    elif isinstance(config, str):
        config = CFGS[config](data_dir)[config_key]
    elif isinstance(config, dict):
        # Shallow-copy the levels we mutate so the caller's dict is left untouched.
        config = {**config, 'parameters': dict(config['parameters'])}
    else:
        raise ValueError(f'Invalid configuration: {config}')

    params = config['parameters']
    for k in params:
        if k in kwargs: params[k] = kwargs.pop(k)
    return config
    

In [22]:
def tokenize_info(info:Dict, config:Dict, max_sequence_length:int):
    """Tokenize the text held in `info` using `Info.tokenize`.

    Only the entries of `config` whose names match parameters of `Info.tokenize`
    are forwarded. The explicit `max_sequence_length` argument always takes
    precedence over a same-named entry in `config`. Forwarding both would raise
    "got multiple values for keyword argument".

    Args:
        info: info dict to tokenize. Presumably it holds an 'input_text' list; TODO confirm.
        config: flat parameter dict (e.g. `config['parameters']`).
        max_sequence_length: maximum token length passed to `Info.tokenize`.

    Returns:
        The info dict held by the `Info` object after tokenization. Callers in this
        notebook rely on `Info.tokenize` updating the passed `info` dict in place:
        'input_ids'/'attention_mask' appear on it afterwards.
    """
    tokenize_params = inspect.signature(Info.tokenize).parameters
    tokz_args = {p: config[p] for p in tokenize_params
                 if p in config and p not in ('self', 'max_sequence_length')}
    info_obj = Info()
    info_obj.info = info
    info_obj.tokenize(**tokz_args, max_sequence_length=max_sequence_length)
    return info_obj.info
    

In [23]:
def augment_metadata(dset:Union[XCDataset,SXCDataset], meta_name:str, config:Union[str, Dict], 
                     config_key:Optional[str]=None, data_dir:Optional[str]=None, 
                     prompt:Optional[Callable]=None, sep_tok:Optional[str]=" :: ", **kwargs):
    """Append each data point's linked metadata text to its input text and re-tokenize.

    For row i of `dset.meta[meta_name].data_meta` (a sparse matrix read through its
    CSR `indptr`/`indices`), the new text is
    `prompt(text_i) + sep_tok.join(<input_text of the metadata items linked to i>)`.
    `dset.data.data_info` is then replaced by `{'identifier', 'input_text'}`, keeping
    the data points' own identifiers. It is tokenized via `tokenize_info` with
    `config['parameters']['main_max_data_sequence_length']` as the max length.

    NOTE: this mutates `dset` in place and is not idempotent. Running it twice on the
    same dataset appends metadata to already-augmented text.

    Args:
        dset: dataset whose data text is augmented.
        meta_name: key into `dset.meta` (e.g. 'cat_meta').
        config, config_key, data_dir: resolved with `get_config`.
        prompt: applied to each original text before metadata is appended.
            Defaults to the identity.
        sep_tok: separator placed between metadata texts. None is treated as ''.
        **kwargs: forwarded to `get_config` as parameter overrides.

    Raises:
        ValueError: if the metadata matrix row count differs from the number of data points.
    """
    if prompt is None: prompt = lambda x: x
    if sep_tok is None: sep_tok = ''

    # Read from the same object that is written back below.
    data_info = dset.data.data_info
    data_meta = dset.meta[meta_name].data_meta
    meta_text = dset.meta[meta_name].meta_info['input_text']

    text = [prompt(o) for o in data_info['input_text']]
    if data_meta.shape[0] != len(text):
        # zip() would silently truncate and misalign text with metadata.
        raise ValueError(f'`{meta_name}` has {data_meta.shape[0]} rows but the dataset has {len(text)} data points.')

    indptr, indices = data_meta.indptr, data_meta.indices
    aug_text = [txt + sep_tok.join(meta_text[i] for i in indices[p:q])
                for p, q, txt in tqdm(zip(indptr, indptr[1:], text), total=len(text))]

    config = get_config(config, config_key=config_key, data_dir=data_dir, **kwargs)
    # Keep the data identifiers; the metadata identifiers belong to a different item set.
    dset.data.data_info = {'identifier': data_info['identifier'], 'input_text': aug_text}
    tokenize_info(dset.data.data_info, config['parameters'], 
                  max_sequence_length=config['parameters']['main_max_data_sequence_length'])
    

In [24]:
def data_prompt(txt):
    """Return `txt` followed by the ' <METADATA> ' marker that precedes appended metadata."""
    return "{} <METADATA> ".format(txt)

In [25]:
# Append category metadata to the training text (mutates block.train.dset in place;
# re-running this cell presumably augments the text a second time).
augment_metadata(
    dset=block.train.dset,
    meta_name='cat_meta',
    config='wikiseealsotitles',
    config_key=config_key,
    data_dir=data_dir,
    prompt=data_prompt,
)

  0%|          | 0/693082 [00:00<?, ?it/s]

In [26]:
# Same augmentation for the test split (also in place; presumably not safe to re-run).
augment_metadata(
    dset=block.test.dset,
    meta_name='cat_meta',
    config='wikiseealsotitles',
    config_key=config_key,
    data_dir=data_dir,
    prompt=data_prompt,
)

  0%|          | 0/177515 [00:00<?, ?it/s]

In [27]:
# Label info fields; `augment_metadata` only replaces the data info, not the labels.
block.train.dset.lbl_info.keys()

dict_keys(['identifier', 'input_text', 'input_ids', 'attention_mask'])

In [127]:
# Both splits should expose text plus tokenized fields after augmentation and re-tokenization.
block.train.dset.data.data_info.keys(), block.test.dset.data.data_info.keys()

(dict_keys(['identifier', 'input_text', 'input_ids', 'attention_mask']),
 dict_keys(['identifier', 'input_text', 'input_ids', 'attention_mask']))

In [29]:
# Decode one augmented training example to check the prompt/separator layout.
tokz.decode(block.train.dset.data.data_info['input_ids'][100])

'[CLS] applet < metadata > technology neologisms : : java ( programming language ) libraries : : component - based software engineering [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

## Setup: build the oracle block with metadata-augmented text

In [30]:
# Same cache name as in the Code section, but resolved with `use_oracle=True`.
pkl_dir = "{}/processed/mogicX".format(data_dir)
pkl_file = get_pkl_file(
    pkl_dir,
    'wikiseealsotitles_data-meta_distilbert-base-uncased',
    use_oracle=True,
    use_only_test=False,
    use_exact=False,
    use_sxc_sampler=True,
)

In [33]:
def data_prompt(txt):
    """Return `txt` followed by the ' <METADATA> ' marker.

    Intentionally re-defined here (identical to the Code-section version) so the
    Setup section can be run on its own.
    """
    return "{} <METADATA> ".format(txt)

In [34]:
# Build the oracle block. `meta_name`/`prompt` presumably make build_block apply the
# metadata augmentation prototyped in the Code section (the progress bars below match
# the train/test sizes); confirm.
block = build_block(
    pkl_file, 'wikiseealsotitles',
    use_sxc=True, use_oracle=True,
    config_key=config_key, data_dir=data_dir,
    only_test=False, main_oversample=False,
    n_slbl_samples=1, do_build=True,
    meta_name="cat", prompt=data_prompt,
)



  0%|          | 0/693082 [00:00<?, ?it/s]

  0%|          | 0/177515 [00:00<?, ?it/s]

In [39]:
# Decode one category-metadata entry of the oracle block to sanity-check its tokenization.
tokz.decode(block.train.dset.meta['cat_meta'].meta_info['input_ids'][100])

'[CLS] 20th - century royal air force personnel [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'