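# Trie inference benchmarking

Benchmarks trie-constrained decoding (`XCTrie`) for NAR label-generation models — RoBERTa, BERT (zero-shot) and DistilBERT — on a fixed 2,000-example test sample (`n=2000, seed=50`), reporting P@k, N@k, PSP@k, PSN@k and R@k via `PrecRecl`.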

In [1]:
#| default_exp 00-nar-trie-inference-benchmarking

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [4]:
#| export
import os, pandas as pd, warnings
from tqdm.auto import tqdm

from xcai.basics import *
from xcai.models.MMM00X import BT0002, RT0005

In [5]:
import pickle

In [6]:
# Turn off Weights & Biases logging for every run in this notebook.
os.environ.update(WANDB_MODE='disabled')

In [7]:
# Root directory for artifacts this notebook writes (pickled blocks under `data/`).
# Override it with the XC_DUMP_DIR environment variable. The default is the original hard-coded path.
dump_dir = os.environ.get('XC_DUMP_DIR', '/home/aiscuser/scratch/Projects/xc_nlg/outputs/00-nar-trie-inference-benchmarking/')

## Benchmarking

In [None]:
#| export
os.environ['WANDB_MODE'] = 'disabled'

block = XCBlock.from_cfg('data', valid_pct=0.001, tokz='roberta-base')

args = XCLearningArguments(
    output_dir='/scratch/scai/phd/aiz218323/Projects/xc_nlg/outputs/default',
    generation_length_penalty=1.5,
    per_device_eval_batch_size=64,
    evaluation_strategy='steps',
    label_names=['lbl2data_idx'],
)

mname = '/home/scai/phd/aiz218323/Projects/XC_NLG/code/models/roberta-base_LM-NAR_LF-WikiSeeAlso-320K/checkpoint-174000'
model = RT0005.from_pretrained(mname, tn_targ=10_000, ig_tok=1)

trie = XCTrie.from_block(block)

In [None]:
#| export
metric = PrecRecl(block.n_lbl, block.test.data_lbl_filterer, 
                  prop=block.train.dset.data.data_lbl, pk=10, rk=10, rep_pk=[1, 3, 5, 10], rep_rk=[10])

learn = XCLearner(
    model=model, 
    args=args,
    trie=trie,
    data_collator=block.collator, 
    compute_metrics=metric,
)

metrics = learn.evaluate(block.test.dset)
print(metrics)

## `roberta`

### Inference

In [None]:
# Build the 'data_meta' block. It presumably carries the metadata used later by the
# `meta=['hlk']` trie. This is slow: the recorded run took ~10.5 min CPU / ~3.5 min wall.
%time block = XCBlock.from_cfg('data_meta', valid_pct=0.001, tokz='roberta-base')

CPU times: user 9min 59s, sys: 26.2 s, total: 10min 26s
Wall time: 3min 30s


In [None]:
# Load the fine-tuned RoBERTa NAR checkpoint (step 174000) into RT0005.
# NOTE(review): ig_tok=0 here, but the exported benchmarking cell uses ig_tok=1. Confirm which is intended.
mname = '/home/scai/phd/aiz218323/Projects/XC_NLG/code/models/roberta-base_LM-NAR_LF-WikiSeeAlso-320K/checkpoint-174000'
model = RT0005.from_pretrained(mname, tn_targ=10_000, ig_tok=0)

If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`
Some weights of RT0005 were not initialized from the model checkpoint at /home/scai/phd/aiz218323/Projects/XC_NLG/code/models/roberta-base_LM-NAR_LF-WikiSeeAlso-320K/checkpoint-174000 and are newly initialized: ['loss_fn.o']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Fixed-seed 2,000-example test sample; every section of the notebook uses the same n/seed.
test_dset = block.test.dset.sample(seed=50, n=2000)

In [None]:
# Ranking metrics up to k=10 and recall up to k=200 on the sampled split.
metric = PrecRecl(
    test_dset.n_lbl,
    test_dset.data.data_lbl_filterer,
    pk=10, rep_pk=[1, 3, 5, 10],
    rk=200, rep_rk=[10, 100, 200],
    prop=block.train.dset.data.data_lbl,
)

In [None]:
# Inference arguments: zero length penalty and generation_max_info=1
# (NOTE(review): the semantics of generation_max_info live in xcai; confirm).
args = XCLearningArguments(
    output_dir='/scratch/scai/phd/aiz218323/Projects/xc_nlg/outputs/default',
    label_names=['lbl2data_idx'],
    evaluation_strategy='steps',
    per_device_eval_batch_size=64,
    generation_length_penalty=0,
    generation_max_info=1,
)

In [None]:
# No `trie=` here: the trie is attached afterwards through `learn.tbs.trie`.
learn = XCLearner(
    compute_metrics=metric,
    data_collator=block.collator,
    args=args,
    model=model,
)

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
# Build the decoding trie from the block (tqdm iterated 312,330 items on this dataset).
trie = XCTrie.from_block(block)

  0%|          | 0/312330 [00:00<?, ?it/s]

In [None]:
# Attach the trie after construction via the learner's `tbs` object (presumably the trie beam search).
learn.tbs.trie = trie

### Metrics

In [None]:
# Predict on the sampled test split and render the metrics table.
o = learn.predict(test_dset)
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,17.0,9.75,6.57,3.285,17.0,15.4143,15.2326,14.9327,11.0851,10.0787,9.5993,8.704,11.0851,11.0837,11.2432,11.177,15.11,15.11,15.11,17.7134,51.1166,39.126,0.313


In [None]:
# NOTE(review): this code is identical to the previous cell, yet the recorded metrics differ
# (P@1 17.0 vs 14.25). Either state changed between the runs or decoding is nondeterministic.
# Confirm before comparing numbers.
o = learn.predict(test_dset)
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,14.25,8.8167,6.37,3.185,14.25,13.9093,14.251,13.9665,8.907,8.9605,9.2579,8.3942,8.907,9.7857,10.3569,10.2906,14.8061,14.8061,14.8061,17.7134,50.4917,39.61,0.317


### Trie augmentation

In [None]:
# Wide search: 200 beams with length penalty 1.5.
args = XCLearningArguments(
    output_dir='/scratch/scai/phd/aiz218323/Projects/xc_nlg/outputs/default',
    label_names=['lbl2data_idx'],
    evaluation_strategy='steps',
    per_device_eval_batch_size=64,
    generation_length_penalty=1.5,
    generation_num_beams=200,
)

In [None]:
# Fresh learner with the 200-beam arguments; the trie is attached in the next cell.
learn = XCLearner(
    compute_metrics=metric,
    data_collator=block.collator,
    args=args,
    model=model,
)

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
# Rebuild the label trie and attach it to the new learner.
trie = XCTrie.from_block(block)
learn.tbs.trie = trie

  0%|          | 0/312330 [00:00<?, ?it/s]

In [None]:
# Full evaluate() pass with 200 beams. The recorded run took ~1625 s, vs ~51 s in the Inference section.
metrics = learn.evaluate(test_dset)
display_metric(metrics)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,14.55,8.1667,5.68,3.435,14.55,13.2213,13.3973,14.2175,9.117,8.158,8.0412,8.9522,9.117,9.1901,9.596,10.3974,16.0702,23.6574,27.3446,17.7134,1624.731,1.231,0.01


In [None]:
# Trie built from the labels plus the 'hlk' metadata (a second progress bar covers 2,458,399 items).
trie = XCTrie.from_block(block, meta=['hlk'])
learn.tbs.trie = trie

  0%|          | 0/312330 [00:00<?, ?it/s]

  0%|          | 0/2458399 [00:00<?, ?it/s]

In [None]:
# Re-attach the hlk trie and keep `tbs.n_bm` and `args.generation_num_beams` in sync at 30.
learn.tbs.trie = trie
learn.tbs.n_bm = learn.args.generation_num_beams = 30

In [None]:
# Predict with 30 beams over the hlk-augmented trie.
o = learn.predict(test_dset)
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,7.75,4.3667,3.13,1.94,7.75,7.3079,7.5596,8.1686,5.2484,4.9572,4.9311,5.4388,5.2484,5.4932,5.732,6.2214,9.8002,17.9712,21.1916,17.7134,1628.4336,1.228,0.01


In [None]:
# Keep `tbs.n_bm` and `args.generation_num_beams` in sync at 10.
learn.tbs.n_bm = learn.args.generation_num_beams = 10

In [None]:
# Predict with 10 beams over the hlk-augmented trie.
o = learn.predict(test_dset)
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,8.45,4.7,3.46,2.225,8.45,8.0573,8.4391,9.1355,5.6088,5.1432,5.3652,6.253,5.6088,5.8519,6.2406,6.8491,11.0407,18.3343,20.1675,17.7134,128.889,15.517,0.124


In [None]:
# Keep `tbs.n_bm` and `args.generation_num_beams` in sync at 20.
learn.tbs.n_bm = learn.args.generation_num_beams = 20

In [None]:
# Predict with 20 beams over the hlk-augmented trie.
o = learn.predict(test_dset)
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,8.25,4.4,3.2,2.055,8.25,7.5829,7.8761,8.565,5.772,4.9317,5.008,5.8421,5.772,5.7204,6.0027,6.5879,10.301,18.3349,21.2567,17.7134,273.7713,7.305,0.058


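## Zero-shot

Off-the-shelf `bert-base-uncased` weights, loaded into `BT0002` without a fine-tuned checkpoint, decoded with the same label trie.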

In [8]:
# Rebuild the block with the bert-base-uncased tokenizer. This call passes the dataset root first
# and uses the `tokenizer=` keyword (the RoBERTa cells use `tokz=`).
block = XCBlock.from_cfg('/home/aiscuser/scratch/datasets/', 'data', valid_pct=0.001, 
                         tokenizer='bert-base-uncased')

  self._set_arrayXarray(i, j, x)


In [24]:
# Persist the bert-base-uncased block to `<dump_dir>/data/block.pkl` with pickle.
# os.path.join avoids the '//' the f-string produced, because dump_dir already ends with '/'.
fname = os.path.join(dump_dir, 'data', 'block.pkl')
os.makedirs(os.path.dirname(fname), exist_ok=True)

# Write to a temporary file and rename it into place, so an interrupted dump never leaves a
# truncated pickle at `fname`.
tmp_fname = f'{fname}.tmp'
with open(tmp_fname, 'wb') as file:
    pickle.dump(block, file)
os.replace(tmp_fname, fname)

In [9]:
# Default beam settings with length penalty 1.5.
args = XCLearningArguments(
    output_dir='/home/aiscuser/scratch/Projects/xc_nlg/outputs/default',
    label_names=['lbl2data_idx'],
    evaluation_strategy='steps',
    per_device_eval_batch_size=64,
    generation_length_penalty=1.5,
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [10]:
# Hub bert-base-uncased weights wrapped in BT0002 (no task checkpoint). On load, only `loss_fn.o`
# is reported as newly initialized.
model = BT0002.from_pretrained('bert-base-uncased', tn_targ=10_000, ig_tok=0)

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`
Some weights of BT0002 were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['loss_fn.o']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
# Same fixed-seed 2,000-example sample as the RoBERTa section.
test_dset = block.test.dset.sample(seed=50, n=2000)

In [12]:
# Ranking metrics up to k=10 and recall up to k=200, as in the RoBERTa section.
metric = PrecRecl(
    test_dset.n_lbl,
    test_dset.data.data_lbl_filterer,
    pk=10, rep_pk=[1, 3, 5, 10],
    rk=200, rep_rk=[10, 100, 200],
    prop=block.train.dset.data.data_lbl,
)

In [13]:
# Decoding trie for the bert-tokenized block; it is passed straight to XCLearner below.
trie = XCTrie.from_block(block)

  0%|          | 0/312330 [00:00<?, ?it/s]

In [14]:
# Learner with the trie supplied at construction time.
learn = XCLearner(
    compute_metrics=metric,
    data_collator=block.collator,
    trie=trie,
    args=args,
    model=model,
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [15]:
%%time
# Timed zero-shot prediction. The recorded run used ~1h17m CPU but only ~1.5 min wall.
o = learn.predict(test_dset)
display_metric(o.metrics)

node-0:341974:341974 [0] NCCL INFO Bootstrap : Using eth0:10.13.32.225<0>
node-0:341974:341974 [0] NCCL INFO NET/Plugin : Plugin load (librccl-net.so) returned 2 : librccl-net.so: cannot open shared object file: No such file or directory
node-0:341974:341974 [0] NCCL INFO NET/Plugin : No plugin found, using internal implementation
node-0:341974:341974 [0] NCCL INFO Kernel version: 5.15.0-1042-azure
RCCL version 2.17.1+hip5.7 HEAD:cbbb3d8+

node-0:341974:351989 [0] /long_pathname_so_that_rpms_can_package_the_debug_info/src/extlibs/rccl/build/hipify/src/misc/ibvwrap.cc:222 NCCL WARN Call to ibv_open_device failed

node-0:341974:351989 [0] /long_pathname_so_that_rpms_can_package_the_debug_info/src/extlibs/rccl/build/hipify/src/transport/net_ib.cc:199 NCCL WARN NET/IB : Unable to open device mlx5_0

node-0:341974:351989 [0] /long_pathname_so_that_rpms_can_package_the_debug_info/src/extlibs/rccl/build/hipify/src/misc/ibvwrap.cc:222 NCCL WARN Call to ibv_open_device failed

node-0:341974:351

node-0:341974:351992 [3] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/ACPI0004:00/VMBUS:00/47505500-0002-0000-3130-303237343043/pci0002:00/0002:00:00.0/../max_link_width, ignoring
node-0:341974:351989 [0] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/ACPI0004:00/VMBUS:00/47505500-0003-0000-3130-303237343043/pci0003:00/0003:00:00.0/../max_link_width, ignoring
node-0:341974:351993 [4] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/ACPI0004:00/VMBUS:00/47505500-0002-0000-3130-303237343043/pci0002:00/0002:00:00.0/../max_link_speed, ignoring
node-0:341974:351994 [5] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/ACPI0004:00/VMBUS:00/47505500-0003-0000-3130-303237343043/pci0003:00/0003:00:00.0/../max_link_width, ignoring
node-0:341974:351996 [7] NCCL INFO Topology detection : could not read /sys/devices/LNXSYSTM:00/LNXSYBUS:00/ACPI0004:00/VMBUS:00

node-0:341974:351997 [8] NCCL INFO Trees [0] 12/-1/-1->8->9 [1] 12/-1/-1->8->9 [2] 0/-1/-1->8->9 [3] 9/-1/-1->8->12 [4] 9/-1/-1->8->12 [5] 9/-1/-1->8->0 [6] -1/-1/-1->8->12 [7] 12/-1/-1->8->9 [8] 0/-1/-1->8->9 [9] 9/12/-1->8->-1 [10] -1/-1/-1->8->12 [11] 9/-1/-1->8->0 [12] 9/-1/-1->8->12 [13] 9/-1/-1->8->12 [14] 9/-1/-1->8->0 [15] 12/-1/-1->8->9 [16] 12/-1/-1->8->9 [17] 0/-1/-1->8->9 [18] 9/12/-1->8->-1 [19] -1/-1/-1->8->12 [20] 9/-1/-1->8->0 [21] -1/-1/-1->8->12 [22] 12/-1/-1->8->9 [23] 0/-1/-1->8->9 comm 0x135ccc80 nRanks 16 busId 900000
node-0:341974:351997 [8] NCCL INFO P2P Chunksize set to 524288
node-0:341974:351998 [9] NCCL INFO P2P Chunksize set to 524288
node-0:341974:351989 [0] NCCL INFO Ring 20 : 8 -> 0 -> 1 comm 0x15112050 nRanks 16 busId 100000
node-0:341974:351989 [0] NCCL INFO Ring 21 : 4 -> 0 -> 1 comm 0x15112050 nRanks 16 busId 100000
node-0:341974:351989 [0] NCCL INFO Ring 22 : 4 -> 0 -> 1 comm 0x15112050 nRanks 16 busId 100000
node-0:341974:351989 [0] NCCL INFO Ring 

node-0:341974:351993 [4] NCCL INFO Channel 13/0 : 4[500000] -> 5[600000] via P2P/direct pointer comm 0x24ee1ed0 nRanks 16
node-0:341974:351995 [6] NCCL INFO Channel 07/0 : 6[700000] -> 7[800000] via P2P/direct pointer comm 0x24e937a0 nRanks 16
node-0:341974:351997 [8] NCCL INFO Channel 17/0 : 8[900000] -> 9[a00000] via P2P/direct pointer comm 0x135ccc80 nRanks 16
node-0:341974:351989 [0] NCCL INFO Channel 16/0 : 0[100000] -> 1[200000] via P2P/direct pointer comm 0x15112050 nRanks 16
node-0:341974:351991 [2] NCCL INFO Channel 17/0 : 2[300000] -> 3[400000] via P2P/direct pointer comm 0x13a40320 nRanks 16
node-0:341974:351999 [10] NCCL INFO Channel 16/0 : 10[b00000] -> 11[c00000] via P2P/direct pointer comm 0x135c4c80 nRanks 16
node-0:341974:352001 [12] NCCL INFO Channel 15/0 : 12[d00000] -> 13[e00000] via P2P/direct pointer comm 0x13b2a950 nRanks 16
node-0:341974:351993 [4] NCCL INFO Channel 17/0 : 4[500000] -> 5[600000] via P2P/direct pointer comm 0x24ee1ed0 nRanks 16
node-0:341974:3519

node-0:341974:351992 [3] NCCL INFO Channel 10/0 : 3[400000] -> 7[800000] via P2P/direct pointer comm 0x2499a5e0 nRanks 16
node-0:341974:351991 [2] NCCL INFO Channel 15/0 : 2[300000] -> 6[700000] via P2P/direct pointer comm 0x13a40320 nRanks 16
node-0:341974:351995 [6] NCCL INFO Channel 20/0 : 6[700000] -> 7[800000] via P2P/direct pointer comm 0x24e937a0 nRanks 16
node-0:341974:352002 [13] NCCL INFO Channel 19/0 : 13[e00000] -> 14[f00000] via P2P/direct pointer comm 0xb095bf0 nRanks 16
node-0:341974:351994 [5] NCCL INFO Channel 19/0 : 5[600000] -> 6[700000] via P2P/direct pointer comm 0xb154a20 nRanks 16
node-0:341974:352003 [14] NCCL INFO Channel 19/0 : 14[f00000] -> 15[1000000] via P2P/direct pointer comm 0x24932190 nRanks 16
node-0:341974:351989 [0] NCCL INFO Channel 07/0 : 0[100000] -> 4[500000] via P2P/direct pointer comm 0x15112050 nRanks 16
node-0:341974:351990 [1] NCCL INFO Channel 20/0 : 1[200000] -> 5[600000] via P2P/direct pointer comm 0x137244f0 nRanks 16
node-0:341974:35199

node-0:341974:351995 [6] NCCL INFO Channel 18/0 : 6[700000] -> 2[300000] via P2P/direct pointer comm 0x24e937a0 nRanks 16
node-0:341974:352002 [13] NCCL INFO Channel 14/0 : 13[e00000] -> 9[a00000] via P2P/direct pointer comm 0xb095bf0 nRanks 16
node-0:341974:352003 [14] NCCL INFO Channel 21/0 : 14[f00000] -> 10[b00000] via P2P/direct pointer comm 0x24932190 nRanks 16
node-0:341974:351999 [10] NCCL INFO Channel 19/0 : 10[b00000] -> 1[200000] via P2P/direct pointer comm 0x135c4c80 nRanks 16
node-0:341974:351996 [7] NCCL INFO Channel 01/0 : 7[800000] -> 3[400000] via P2P/direct pointer comm 0x13599760 nRanks 16
node-0:341974:351993 [4] NCCL INFO Channel 03/0 : 4[500000] -> 0[100000] via P2P/direct pointer comm 0x24ee1ed0 nRanks 16
node-0:341974:352002 [13] NCCL INFO Channel 20/0 : 13[e00000] -> 9[a00000] via P2P/direct pointer comm 0xb095bf0 nRanks 16
node-0:341974:352001 [12] NCCL INFO Channel 00/0 : 12[d00000] -> 8[900000] via P2P/direct pointer comm 0x13b2a950 nRanks 16
node-0:341974:3

node-0:341974:352004 [15] NCCL INFO Channel 23/0 : 15[1000000] -> 11[c00000] via P2P/direct pointer comm 0x137a6ad0 nRanks 16
node-0:341974:351998 [9] NCCL INFO Channel 18/0 : 9[a00000] -> 1[200000] via P2P/direct pointer comm 0x15299860 nRanks 16
node-0:341974:351991 [2] NCCL INFO Channel 20/0 : 2[300000] -> 10[b00000] via P2P/direct pointer comm 0x13a40320 nRanks 16
node-0:341974:351990 [1] NCCL INFO Channel 04/0 : 1[200000] -> 10[b00000] via P2P/direct pointer comm 0x137244f0 nRanks 16
node-0:341974:351998 [9] NCCL INFO Channel 04/0 : 9[a00000] -> 2[300000] via P2P/direct pointer comm 0x15299860 nRanks 16
node-0:341974:351990 [1] NCCL INFO Channel 10/0 : 1[200000] -> 10[b00000] via P2P/direct pointer comm 0x137244f0 nRanks 16
node-0:341974:351998 [9] NCCL INFO Channel 10/0 : 9[a00000] -> 2[300000] via P2P/direct pointer comm 0x15299860 nRanks 16
node-0:341974:351990 [1] NCCL INFO Channel 16/0 : 1[200000] -> 10[b00000] via P2P/direct pointer comm 0x137244f0 nRanks 16
node-0:341974:35

node-0:341974:351994 [5] NCCL INFO Channel 09/0 : 5[600000] -> 4[500000] via P2P/direct pointer comm 0xb154a20 nRanks 16
node-0:341974:352002 [13] NCCL INFO Channel 06/0 : 13[e00000] -> 12[d00000] via P2P/direct pointer comm 0xb095bf0 nRanks 16
node-0:341974:351996 [7] NCCL INFO Channel 17/0 : 7[800000] -> 6[700000] via P2P/direct pointer comm 0x13599760 nRanks 16
node-0:341974:352004 [15] NCCL INFO Channel 20/0 : 15[1000000] -> 14[f00000] via P2P/direct pointer comm 0x137a6ad0 nRanks 16
node-0:341974:351994 [5] NCCL INFO Channel 10/0 : 5[600000] -> 4[500000] via P2P/direct pointer comm 0xb154a20 nRanks 16
node-0:341974:351996 [7] NCCL INFO Channel 18/0 : 7[800000] -> 6[700000] via P2P/direct pointer comm 0x13599760 nRanks 16
node-0:341974:352002 [13] NCCL INFO Channel 10/0 : 13[e00000] -> 12[d00000] via P2P/direct pointer comm 0xb095bf0 nRanks 16
node-0:341974:351997 [8] NCCL INFO Connected all rings comm 0x135ccc80 nRanks 16 busId 900000
node-0:341974:351997 [8] NCCL INFO Channel 01/

node-0:341974:351991 [2] NCCL INFO Channel 15/0 : 2[300000] -> 3[400000] via P2P/direct pointer comm 0x13a40320 nRanks 16
node-0:341974:351989 [0] NCCL INFO Channel 18/0 : 0[100000] -> 1[200000] via P2P/direct pointer comm 0x15112050 nRanks 16
node-0:341974:351993 [4] NCCL INFO Connected all rings comm 0x24ee1ed0 nRanks 16 busId 500000
node-0:341974:351993 [4] NCCL INFO Channel 02/0 : 4[500000] -> 5[600000] via P2P/direct pointer comm 0x24ee1ed0 nRanks 16
node-0:341974:351996 [7] NCCL INFO Connected all rings comm 0x13599760 nRanks 16 busId 800000
node-0:341974:351994 [5] NCCL INFO Connected all rings comm 0xb154a20 nRanks 16 busId 600000
node-0:341974:351995 [6] NCCL INFO Connected all rings comm 0x24e937a0 nRanks 16 busId 700000
node-0:341974:351997 [8] NCCL INFO Channel 15/0 : 8[900000] -> 9[a00000] via P2P/direct pointer comm 0x135ccc80 nRanks 16
node-0:341974:351999 [10] NCCL INFO Channel 18/0 : 10[b00000] -> 11[c00000] via P2P/direct pointer comm 0x135c4c80 nRanks 16
node-0:34197

node-0:341974:351999 [10] NCCL INFO Channel 21/0 : 10[b00000] -> 14[f00000] via P2P/direct pointer comm 0x135c4c80 nRanks 16
node-0:341974:351989 [0] NCCL INFO Channel 15/0 : 0[100000] -> 4[500000] via P2P/direct pointer comm 0x15112050 nRanks 16
node-0:341974:351992 [3] NCCL INFO Channel 08/0 : 3[400000] -> 7[800000] via P2P/direct pointer comm 0x2499a5e0 nRanks 16
node-0:341974:351991 [2] NCCL INFO Channel 18/0 : 2[300000] -> 6[700000] via P2P/direct pointer comm 0x13a40320 nRanks 16
node-0:341974:352003 [14] NCCL INFO Channel 04/0 : 14[f00000] -> 15[1000000] via P2P/direct pointer comm 0x24932190 nRanks 16
node-0:341974:352001 [12] NCCL INFO Channel 16/0 : 12[d00000] -> 13[e00000] via P2P/direct pointer comm 0x13b2a950 nRanks 16
node-0:341974:351995 [6] NCCL INFO Channel 10/0 : 6[700000] -> 7[800000] via P2P/direct pointer comm 0x24e937a0 nRanks 16
node-0:341974:351993 [4] NCCL INFO Channel 16/0 : 4[500000] -> 5[600000] via P2P/direct pointer comm 0x24ee1ed0 nRanks 16
node-0:341974:

node-0:341974:351995 [6] NCCL INFO Channel 03/0 : 6[700000] -> 2[300000] via P2P/direct pointer comm 0x24e937a0 nRanks 16
node-0:341974:351991 [2] NCCL INFO Channel 04/0 : 2[300000] -> 9[a00000] via P2P/direct pointer comm 0x13a40320 nRanks 16
node-0:341974:351994 [5] NCCL INFO Channel 02/0 : 5[600000] -> 1[200000] via P2P/direct pointer comm 0xb154a20 nRanks 16
node-0:341974:351995 [6] NCCL INFO Channel 09/0 : 6[700000] -> 2[300000] via P2P/direct pointer comm 0x24e937a0 nRanks 16
node-0:341974:351991 [2] NCCL INFO Channel 10/0 : 2[300000] -> 9[a00000] via P2P/direct pointer comm 0x13a40320 nRanks 16
node-0:341974:351994 [5] NCCL INFO Channel 08/0 : 5[600000] -> 1[200000] via P2P/direct pointer comm 0xb154a20 nRanks 16
node-0:341974:351991 [2] NCCL INFO Channel 16/0 : 2[300000] -> 9[a00000] via P2P/direct pointer comm 0x13a40320 nRanks 16
node-0:341974:351995 [6] NCCL INFO Channel 15/0 : 6[700000] -> 2[300000] via P2P/direct pointer comm 0x24e937a0 nRanks 16
node-0:341974:352002 [13] 

node-0:341974:351991 [2] NCCL INFO Channel 11/0 : 2[300000] -> 10[b00000] via P2P/direct pointer comm 0x13a40320 nRanks 16
node-0:341974:352004 [15] NCCL INFO Channel 16/0 : 15[1000000] -> 11[c00000] via P2P/direct pointer comm 0x137a6ad0 nRanks 16
node-0:341974:352001 [12] NCCL INFO Channel 15/0 : 12[d00000] -> 8[900000] via P2P/direct pointer comm 0x13b2a950 nRanks 16
node-0:341974:351992 [3] NCCL INFO Channel 21/0 : 3[400000] -> 11[c00000] via P2P/direct pointer comm 0x2499a5e0 nRanks 16
node-0:341974:351996 [7] NCCL INFO Channel 23/0 : 7[800000] -> 3[400000] via P2P/direct pointer comm 0x13599760 nRanks 16
node-0:341974:351999 [10] NCCL INFO Channel 14/0 : 10[b00000] -> 2[300000] via P2P/direct pointer comm 0x135c4c80 nRanks 16
node-0:341974:351991 [2] NCCL INFO Channel 17/0 : 2[300000] -> 10[b00000] via P2P/direct pointer comm 0x13a40320 nRanks 16
node-0:341974:351997 [8] NCCL INFO Channel 17/0 : 8[900000] -> 0[100000] via P2P/direct pointer comm 0x135ccc80 nRanks 16
node-0:341974

node-0:341974:352004 [15] NCCL INFO Channel 06/0 : 15[1000000] -> 14[f00000] via P2P/direct pointer comm 0x137a6ad0 nRanks 16
node-0:341974:351992 [3] NCCL INFO Channel 11/0 : 3[400000] -> 2[300000] via P2P/direct pointer comm 0x2499a5e0 nRanks 16
node-0:341974:351996 [7] NCCL INFO Channel 02/0 : 7[800000] -> 6[700000] via P2P/direct pointer comm 0x13599760 nRanks 16
node-0:341974:351998 [9] NCCL INFO Channel 16/0 : 9[a00000] -> 8[900000] via P2P/direct pointer comm 0x15299860 nRanks 16
node-0:341974:351990 [1] NCCL INFO Channel 14/0 : 1[200000] -> 0[100000] via P2P/direct pointer comm 0x137244f0 nRanks 16
node-0:341974:352000 [11] NCCL INFO Channel 10/0 : 11[c00000] -> 10[b00000] via P2P/direct pointer comm 0x13618c30 nRanks 16
node-0:341974:352004 [15] NCCL INFO Channel 07/0 : 15[1000000] -> 14[f00000] via P2P/direct pointer comm 0x137a6ad0 nRanks 16
node-0:341974:351996 [7] NCCL INFO Channel 03/0 : 7[800000] -> 6[700000] via P2P/direct pointer comm 0x13599760 nRanks 16
node-0:341974

node-0:341974:352004 [15] NCCL INFO Channel 23/0 : 15[1000000] -> 14[f00000] via P2P/direct pointer comm 0x137a6ad0 nRanks 16
node-0:341974:351997 [8] NCCL INFO MSCCL: No external scheduler found, using internal implementation
node-0:341974:351997 [8] NCCL INFO Using MSCCL files from /opt/conda/envs/ptca/lib/python3.9/site-packages/torch/lib/../share/rccl/msccl-algorithms
node-0:341974:351997 [8] NCCL INFO MSCCL: Initialization finished, localSize 448
node-0:341974:351995 [6] NCCL INFO Channel 01/0 : 6[700000] -> 5[600000] via P2P/direct pointer comm 0x24e937a0 nRanks 16
node-0:341974:351994 [5] NCCL INFO Channel 13/0 : 5[600000] -> 4[500000] via P2P/direct pointer comm 0xb154a20 nRanks 16
node-0:341974:352002 [13] NCCL INFO Channel 09/0 : 13[e00000] -> 12[d00000] via P2P/direct pointer comm 0xb095bf0 nRanks 16
node-0:341974:352003 [14] NCCL INFO Channel 01/0 : 14[f00000] -> 13[e00000] via P2P/direct pointer comm 0x24932190 nRanks 16
node-0:341974:351995 [6] NCCL INFO Channel 07/0 : 6[



  self._set_arrayXarray(i, j, x)


Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,3.75,2.5333,1.89,0.945,3.75,3.8404,3.9596,3.8873,3.8613,4.0876,4.2286,3.8341,3.8613,4.3571,4.5566,4.5341,4.3161,4.3161,4.3161,15.7246,91.3948,21.883,0.022


CPU times: user 1h 15min 20s, sys: 1min 46s, total: 1h 17min 7s
Wall time: 1min 31s


In [None]:
# Untimed re-run of the zero-shot prediction.
# NOTE(review): the recorded metrics differ slightly from the timed run (P@1 3.85 vs 3.75). Confirm
# whether decoding is nondeterministic.
o = learn.predict(test_dset)
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,3.85,2.35,1.8,0.905,3.85,3.6577,3.8229,3.7607,4.0204,3.8023,4.0358,3.6699,4.0204,4.1757,4.419,4.4,4.123,4.123,4.123,15.7025,109.7217,18.228,0.146


## `distilbert`

### Benchmarking

In [17]:
from xcai.models.MMM00X import DBT007

In [57]:
# Rebuild the block with the distilbert-base-uncased tokenizer (dataset root first, `tokenizer=` keyword).
block = XCBlock.from_cfg('/home/aiscuser/scratch/datasets/', 'data', valid_pct=0.001, 
                         tokenizer='distilbert-base-uncased')

  self._set_arrayXarray(i, j, x)


In [18]:
# Persist the distilbert block to `<dump_dir>/data/block_distilbert-base-uncased.pkl` with pickle.
# NOTE(review): the recorded execution counts show this cell as In [18] but the distilbert
# `from_cfg` cell as In [57]. Confirm that `block` here really is the distilbert-base-uncased
# block and not the bert-base-uncased one built earlier.
fname = os.path.join(dump_dir, 'data', 'block_distilbert-base-uncased.pkl')
os.makedirs(os.path.dirname(fname), exist_ok=True)

# Write to a temporary file and rename it into place, so an interrupted dump never leaves a
# truncated pickle at `fname`.
tmp_fname = f'{fname}.tmp'
with open(tmp_fname, 'wb') as file:
    pickle.dump(block, file)
os.replace(tmp_fname, fname)

In [19]:
# Default beam settings with length penalty 1.5 (same as the zero-shot section).
args = XCLearningArguments(
    output_dir='/home/aiscuser/scratch/Projects/xc_nlg/outputs/default',
    label_names=['lbl2data_idx'],
    evaluation_strategy='steps',
    per_device_eval_batch_size=64,
    generation_length_penalty=1.5,
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [20]:
# Fine-tuned DistilBERT NAR checkpoint (step 168000) loaded into DBT007.
# Plain string literal: the original f-prefix had no placeholders.
mname = '/home/aiscuser/scratch/Projects/XC-NLG/models/distilbert-base-uncased_RB33-NAR-1+8-2_(mapped)LF-WikiSeeAlsoTitles-320K/checkpoint-168000'
model = DBT007.from_pretrained(mname, tn_targ=10_000, ig_tok=0)

Some weights of DBT007 were not initialized from the model checkpoint at /home/aiscuser/scratch/Projects/XC-NLG/models/distilbert-base-uncased_RB33-NAR-1+8-2_(mapped)LF-WikiSeeAlsoTitles-320K/checkpoint-168000 and are newly initialized: ['loss_fn.o']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
# Same fixed-seed 2,000-example sample and metric configuration as the other sections.
test_dset = block.test.dset.sample(seed=50, n=2000)
metric = PrecRecl(
    test_dset.n_lbl,
    test_dset.data.data_lbl_filterer,
    pk=10, rep_pk=[1, 3, 5, 10],
    rk=200, rep_rk=[10, 100, 200],
    prop=block.train.dset.data.data_lbl,
)

In [22]:
# Decoding trie for the distilbert-tokenized block; it is passed straight to XCLearner below.
trie = XCTrie.from_block(block)

  0%|          | 0/312330 [00:00<?, ?it/s]

In [23]:
# Learner with the trie supplied at construction time.
learn = XCLearner(
    compute_metrics=metric,
    data_collator=block.collator,
    trie=trie,
    args=args,
    model=model,
)

In [24]:
%%time
# Timed prediction with the fine-tuned DistilBERT model. The recorded run took ~1.5 min wall.
o = learn.predict(test_dset)
display_metric(o.metrics)



  self._set_arrayXarray(i, j, x)


Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,15.05,9.8167,7.07,3.57,15.05,14.7746,15.0235,14.7676,10.1676,10.7839,11.0166,10.0756,10.1676,11.0991,11.5736,11.5263,15.6806,15.6931,15.6931,7.9363,87.1985,22.936,0.023


CPU times: user 1h 40min 39s, sys: 21.1 s, total: 1h 41min
Wall time: 1min 27s


In [25]:
# Show the beam width and length penalty the trie search is actually using.
learn.tbs.n_bm, learn.tbs.len_penalty

(5, 1.5)