In [1]:
#| default_exp 37_ngame-for-wikiseealsotitles-with-llama

In [2]:
%load_ext autoreload
%autoreload 2

In [20]:
#| export
import os,torch, torch.multiprocessing as mp, pickle, numpy as np
from xcai.basics import *
from xcai.models.LLL0XX import LAM009

from peft import (
    LoraConfig, 
    get_peft_model, 
    TaskType,
    PeftModel
)

In [13]:
# Turn Weights & Biases logging off for interactive runs (this cell is not exported).
os.environ.update(WANDB_MODE='disabled')

In [5]:
#| export
# GPU mask and W&B project for both the notebook and the exported script.
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been initialised yet
# (torch is imported earlier) -- confirm nothing touches the GPU before this runs.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0,1',
    'WANDB_PROJECT': 'llama_00-wikiseealsotitles',
})

## Load data

In [6]:
# Dataset root handed to XCBlock.from_cfg when the block is (re)built.
data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'

In [41]:
from transformers import AutoTokenizer

# LLaMA-3 tokenizer with an extra '<PAD>' special token registered for padding;
# add_special_tokens returns how many tokens were newly added to the vocabulary.
tokz = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B')
tokz.add_special_tokens(dict(pad_token="<PAD>"))

1

In [42]:
# Build the train/test data block: 'oak' transform over the 'lnk' metadata,
# inputs tokenized to at most 32 tokens and returned as padded PyTorch tensors.
block = XCBlock.from_cfg(
    data_dir, 'data',
    tokenizer=tokz, transform_type='oak',
    metadata_name='lnk', num_metadata=3, num_labels=4,
    max_sequence_length=32, padding=True, return_tensors='pt',
)

In [7]:
# Cache location for the pickled data block built above.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
pkl_file = pkl_dir + '/processed/wikiseealsotitles_data_meta-llama-3-8b_oak.pkl'

In [44]:
# Cache the freshly built block so later sessions can skip tokenization.
with open(pkl_file, 'wb') as fh:
    pickle.dump(block, fh)

In [8]:
# Restore the cached block. NOTE: pickle.load can execute arbitrary code --
# only load caches this notebook/script wrote itself.
with open(pkl_file, 'rb') as fh:
    block = pickle.load(fh)

## Training

In [15]:
# Interactive training configuration (small per-device batches for experimentation).
args = XCLearningArguments(
    # Run directory and logging
    output_dir='/home/scai/phd/aiz218323/scratch/outputs/medic/37_ngame-for-wikiseealsotitles-with-llama',
    logging_first_step=True,

    # Batching and epochs
    num_train_epochs=300,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=10,

    # Optimizer settings
    learning_rate=2e-4,
    weight_decay=0.01,
    adam_epsilon=1e-6,
    warmup_steps=100,
    max_grad_norm=None,

    # Evaluation / checkpoint cadence and best-model selection
    evaluation_strategy="steps",
    eval_steps=5000,
    save_strategy="steps",
    save_steps=5000,
    save_total_limit=5,
    metric_for_best_model='P@1',
    load_best_model_at_end=True,

    # Representation search settings
    predict_with_representation=True,
    representation_search_type='BRUTEFORCE',
    representation_num_beams=200,
    representation_accumulation_steps=10,
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',

    # Clustering settings
    group_by_cluster=True,
    clustering_type='EXPO',
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=5,
    num_cluster_size_update_epochs=25,
    minimum_cluster_size=2,
    maximum_cluster_size=1600,

    # Hardware / precision
    use_encoder_parallel=True,
    fp16=True,
    accelerator_config={"use_configured_state":True},

    # Batch keys the learner treats as labels
    label_names=['lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'data_idx'],
)

comet_ml version 3.39.1 is installed, but version 3.43.2 or higher is required. Please update comet_ml to the latest version to enable Comet logging with pip install 'comet-ml>=3.43.2'.


In [16]:
# Precision/recall metric over the label space, using the training data-label
# matrix as `prop` and the test filterer from the block.
metric = PrecRecl(
    block.n_lbl,
    block.test.data_lbl_filterer,
    prop=block.train.dset.data.data_lbl,
    pk=10, rk=200,
    rep_pk=[1, 3, 5, 10],
    rep_rk=[10, 100, 200],
)

In [17]:
# Global batch size = larger of the per-device train/eval batch sizes times visible GPUs.
n_devices = torch.cuda.device_count()
bsz = n_devices * max(args.per_device_eval_batch_size, args.per_device_train_batch_size)

# NOTE(review): use_encoder_parallel=False here, but the exported driver passes True -- confirm which is intended.
model = LAM009.from_pretrained(
    'meta-llama/Meta-Llama-3-8B',
    bsz=bsz, tn_targ=5000, margin=0.3, tau=0.1, n_negatives=10,
    apply_softmax=True, use_encoder_parallel=False,
)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some weights of LAM009 were not initialized from the model checkpoint at meta-llama/Meta-Llama-3-8B and are newly initialized: ['model.encoder.dr_head.layer_norm.bias', 'model.encoder.dr_head.layer_norm.weight', 'model.encoder.dr_head.projector.bias', 'model.encoder.dr_head.projector.weight', 'model.encoder.dr_head.transform.bias', 'model.encoder.dr_head.transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [18]:
model.init_retrieval_head()

# Grow the input embedding matrix by one row (presumably for the '<PAD>' token added
# to the tokenizer); the resized Embedding is the cell's displayed value.
embeddings = model.encoder.embed_tokens
vocab_size = embeddings.num_embeddings
model.encoder.resize_token_embeddings(1 + vocab_size)

Embedding(128257, 4096)

In [21]:
# Rank-8 LoRA adapters on every attention projection, no trainable biases.
lora_config = LoraConfig(
    bias='none',
    lora_dropout=0.05,
    lora_alpha=32,
    r=8,
    target_modules=[f"{name}_proj" for name in ("q", "k", "v", "o")],
)

In [22]:
# Wrap the model with the LoRA adapters configured above.
# NOTE(review): PEFT injects the adapter layers into `model`'s own modules (in place), so
# `model` and `peft_model` share parameters -- confirm for the installed peft version.
peft_model = get_peft_model(model, lora_config)

In [23]:
# Re-enable gradients for the retrieval head, which get_peft_model is expected to have
# frozen along with the rest of the base model.
# NOTE(review): PeftModel.save_pretrained stores only adapter weights (plus `modules_to_save`);
# verify dr_head is persisted in checkpoints, or list it in LoraConfig(modules_to_save=...).
peft_model.base_model.encoder.dr_head.requires_grad_(True)

RepresentationHead(
  (transform): Linear(in_features=4096, out_features=4096, bias=True)
  (layer_norm): LayerNorm((4096,), eps=1e-12, elementwise_affine=True)
  (projector): Linear(in_features=4096, out_features=4096, bias=True)
  (activation): SiLU()
)

In [1]:
# Learner for the LoRA-wrapped model, matching what the exported driver trains.
# Passing `peft_model` (not the bare `model`) keeps the notebook consistent with the
# script. In particular, checkpoints are then written by PeftModel's saving logic.
learn = XCLearner(
    model=peft_model,
    args=args,
    train_dataset=block.train.dset,
    eval_dataset=block.test.dset,
    data_collator=block.collator,
    compute_metrics=metric,
)

In [25]:
# Run the training loop configured by `args`.
learn.train()

In [24]:
# Predict on the test split; the result is only displayed as the cell output, not stored.
learn.predict(block.test.dset)

## Driver

In [None]:
#| export
if __name__ == '__main__':
    # Per the multiprocessing docs, freeze_support() belongs straight after the __main__
    # guard (it is a no-op unless running as a frozen Windows executable).
    mp.freeze_support()

    build_block = False  # True: re-tokenize the dataset and refresh the pickle cache

    data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'
    pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'

    # ---- Load data ----
    pkl_file = f'{pkl_dir}/processed/wikiseealsotitles_data_meta-llama-3-8b_oak.pkl'

    if build_block:
        from transformers import AutoTokenizer
        tokz = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B')
        tokz.add_special_tokens({"pad_token": "<PAD>"})

        block = XCBlock.from_cfg(data_dir, 'data', transform_type='oak', tokenizer=tokz, metadata_name='lnk', num_labels=4, num_metadata=3,
                                 max_sequence_length=32, padding=True, return_tensors='pt')

        with open(pkl_file, 'wb') as file: pickle.dump(block, file)
    else:
        # NOTE: pickle.load can execute arbitrary code -- only load caches written by this script.
        with open(pkl_file, 'rb') as file: block = pickle.load(file)

    # ---- Training arguments ----
    args = XCLearningArguments(
        output_dir='/home/scai/phd/aiz218323/scratch/outputs/medic/37_ngame-for-wikiseealsotitles-with-llama',
        logging_first_step=True,
        # NOTE(review): the interactive notebook cell uses 10 per device for this 8B model;
        # confirm 800 fits in GPU memory.
        per_device_train_batch_size=800,
        per_device_eval_batch_size=800,
        representation_num_beams=200,
        representation_accumulation_steps=10,
        save_strategy="steps",
        evaluation_strategy="steps",
        eval_steps=5000,
        save_steps=5000,
        save_total_limit=5,
        num_train_epochs=300,
        predict_with_representation=True,
        representation_search_type='BRUTEFORCE',
        adam_epsilon=1e-6,
        warmup_steps=100,
        weight_decay=0.01,
        learning_rate=2e-4,
        
        group_by_cluster=True,
        num_clustering_warmup_epochs=10,
        num_cluster_update_epochs=5,
        num_cluster_size_update_epochs=25,
        clustering_type='EXPO',
        minimum_cluster_size=2,
        maximum_cluster_size=1600,
        
        metric_for_best_model='P@1',
        load_best_model_at_end=True,
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',
        
        use_encoder_parallel=True,
        max_grad_norm=None,
        fp16=True,

        # Kept in sync with the interactive training-arguments cell.
        label_names=['lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'data_idx'],

        accelerator_config={"use_configured_state":True},
    )

    metric = PrecRecl(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                      pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

    # ---- Model ----
    bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*torch.cuda.device_count()
    # NOTE(review): the interactive cell builds this model with use_encoder_parallel=False.
    model = LAM009.from_pretrained('meta-llama/Meta-Llama-3-8B', bsz=bsz, tn_targ=5000, margin=0.3, tau=0.1, n_negatives=10, 
                                   apply_softmax=True, use_encoder_parallel=True)
    
    model.init_retrieval_head()
    # Grow the embedding matrix by one row (presumably for the '<PAD>' token added to the tokenizer).
    vocab_size = model.encoder.embed_tokens.num_embeddings
    model.encoder.resize_token_embeddings(vocab_size+1)

    lora_config = LoraConfig(
        r=8,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj","v_proj","o_proj"],
        bias='none',
    )
    peft_model = get_peft_model(model, lora_config)
    # Train the retrieval head alongside the LoRA adapters.
    # NOTE(review): PeftModel.save_pretrained stores only adapter weights (plus `modules_to_save`);
    # verify dr_head is persisted in checkpoints, or list it in LoraConfig(modules_to_save=...).
    peft_model.base_model.encoder.dr_head.requires_grad_(True)
    
    learn = XCLearner(
        model=peft_model, 
        args=args,
        train_dataset=block.train.dset,
        eval_dataset=block.test.dset,
        data_collator=block.collator,
        compute_metrics=metric,
    )
    
    learn.train()
    

## Prediction

In [None]:
# Prediction-only configuration.
# NOTE(review): this output_dir belongs to a different experiment
# ('67-ngame-ep-for-wikiseealso-with-input-concatenation-1-2'), not the LLaMA run trained
# above -- confirm which checkpoint this section is meant to evaluate.
args = XCLearningArguments(
    output_dir='/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-2',
    logging_first_step=True,
    fp16=True,
    use_encoder_parallel=True,
    per_device_eval_batch_size=800,
    per_device_train_batch_size=800,
    predict_with_representation=True,
    representation_search_type='BRUTEFORCE',
    representation_num_beams=200,
    representation_accumulation_steps=100,
    target_pointer_key='plbl2data_data2ptr',
    target_indices_key='plbl2data_idx',
)

In [None]:
# Same metric configuration as the training section, rebuilt for the prediction run.
metric = PrecRecl(
    block.n_lbl,
    block.test.data_lbl_filterer,
    prop=block.train.dset.data.data_lbl,
    pk=10, rk=200,
    rep_pk=[1, 3, 5, 10],
    rep_rk=[10, 100, 200],
)

In [None]:
# Locate the best checkpoint of the run named by args.output_dir.
# NOTE(review): the path is rebuilt directly under outputs/, so runs stored deeper
# (e.g. outputs/medic/...) would not be found.
run_name = os.path.basename(args.output_dir)
output_dir = "/home/scai/phd/aiz218323/scratch/outputs/" + run_name
mname = output_dir + '/' + os.path.basename(get_best_model(output_dir))

In [None]:
# NOTE(review): this builds the DistilBERT-based DBT009, not the LAM009 + LoRA model trained
# in this notebook; it matches the other experiment's checkpoint selected above.
n_devices = torch.cuda.device_count()
bsz = n_devices * max(args.per_device_eval_batch_size, args.per_device_train_batch_size)

model = DBT009.from_pretrained(
    'sentence-transformers/msmarco-distilbert-base-v4',
    bsz=bsz, tn_targ=5000, margin=0.3, tau=0.1, n_negatives=10,
    apply_softmax=True, use_encoder_parallel=True,
)

Some weights of DBT009 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
from safetensors import safe_open

# Read every tensor stored in the checkpoint's safetensors file into a name -> tensor dict.
model_weight_file = f'{mname}/model.safetensors'

model_weights = {}
with safe_open(model_weight_file, framework="pt") as file:
    model_weights.update((name, file.get_tensor(name)) for name in file.keys())
        

In [None]:
# Load the checkpoint weights; strict=False reports key mismatches instead of raising.
# NOTE(review): the recorded output lists every `distilbert.*` weight as missing -- confirm the
# encoder weights were actually restored (e.g. under another key prefix) before trusting the metrics.
model.load_state_dict(model_weights, strict=False)

_IncompatibleKeys(missing_keys=['distilbert.embeddings.word_embeddings.weight', 'distilbert.embeddings.position_embeddings.weight', 'distilbert.embeddings.LayerNorm.weight', 'distilbert.embeddings.LayerNorm.bias', 'distilbert.transformer.layer.0.attention.q_lin.weight', 'distilbert.transformer.layer.0.attention.q_lin.bias', 'distilbert.transformer.layer.0.attention.k_lin.weight', 'distilbert.transformer.layer.0.attention.k_lin.bias', 'distilbert.transformer.layer.0.attention.v_lin.weight', 'distilbert.transformer.layer.0.attention.v_lin.bias', 'distilbert.transformer.layer.0.attention.out_lin.weight', 'distilbert.transformer.layer.0.attention.out_lin.bias', 'distilbert.transformer.layer.0.sa_layer_norm.weight', 'distilbert.transformer.layer.0.sa_layer_norm.bias', 'distilbert.transformer.layer.0.ffn.lin1.weight', 'distilbert.transformer.layer.0.ffn.lin1.bias', 'distilbert.transformer.layer.0.ffn.lin2.weight', 'distilbert.transformer.layer.0.ffn.lin2.bias', 'distilbert.transformer.layer.

In [None]:
# Learner used only for prediction with the checkpoint loaded above.
learn = XCLearner(
    model=model,
    args=args,
    compute_metrics=metric,
    data_collator=block.collator,
    train_dataset=block.train.dset,
    eval_dataset=block.test.dset,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
# Predict on the test split; `o.metrics` is tabulated in the next cell and `o` is pickled at the end.
o = learn.predict(block.test.dset)

  0%|          | 0/196 [00:00<?, ?it/s]

  self._set_arrayXarray(i, j, x)


In [None]:
# Tabulate the precision / nDCG / propensity-scored / recall metrics of the prediction run.
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,24.1771,16.1361,12.2884,7.9308,24.1771,24.4573,25.4609,27.2161,20.1098,21.5063,23.1904,26.8649,20.1098,22.0686,23.477,25.37,31.8421,46.4144,49.9468,0.0294,262.9244,675.156,0.422


In [None]:
# Show which checkpoint directory was evaluated.
mname

'/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-2/checkpoint-85000'

In [None]:
# Persist the full prediction output next to the evaluated checkpoint.
# os.path.join avoids the doubled slash the previous f-strings produced
# ('.../predictions//test_predictions.pkl').
pred_dir = os.path.join(mname, 'predictions')
os.makedirs(pred_dir, exist_ok=True)

with open(os.path.join(pred_dir, 'test_predictions.pkl'), 'wb') as file:
    pickle.dump(o, file)