In [2]:
%load_ext autoreload
%autoreload 2
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"
from qlora import *
from collections import defaultdict
import copy
import json
from os.path import exists, join, isdir
from dataclasses import dataclass, field
import sys
from typing import Optional, Dict, Sequence
import numpy as np
from tqdm import tqdm
import logging
import bitsandbytes as bnb
import pandas as pd
import importlib
from packaging import version
from packaging.version import parse
import warnings
from sklearn.metrics.pairwise import manhattan_distances
from torchmetrics.functional.pairwise import pairwise_manhattan_distance as manhattan
from torchmetrics.functional.pairwise import pairwise_cosine_similarity as cossim
import numpy as np

import torch
import transformers
from torch.nn.utils.rnn import pad_sequence
import argparse
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    set_seed,
    Seq2SeqTrainer,
    BitsAndBytesConfig,
    LlamaTokenizer

)
from datasets import load_dataset, Dataset, load_from_disk
import evaluate

from peft import (
    prepare_model_for_kbit_training,
    LoraConfig,
    get_peft_model,
    PeftModel
)
from peft.tuners.lora import LoraLayer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

from transformers.modeling_utils import unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.utils import is_peft_available
from peft import PeftModel



In [3]:
argdict = {
  'model_name_or_path' : '/mnt/data/zoo/llama2/llama2-7b-hf/',
  'multihead' : 4,
  'use_auth' : True,
  'output_dir' : '/mnt/data/sonia/ckpts/deeeebug',
  'logging_steps' : 10 ,
  'save_strategy' : 'steps',
  'data_seed' : 42 ,
  'save_steps' : 5 ,
  'save_total_limit' : 40 ,
  'evaluation_strategy' : 'steps' ,
  'eval_dataset_size' : 5 ,
  'max_eval_samples' : 100 ,
  'per_device_eval_batch_size' : 1 ,
  'max_new_tokens' : 60 ,
  'dataloader_num_workers' : 1 ,
  'group_by_length' : True,
  'logging_strategy' : 'steps' ,
  'remove_unused_columns' : False ,
  'do_train' : True ,
  'eval_samples' : True ,
  'do_mmlu_eval' : False ,
  'diversity' : False ,
  'divdist' : 'manhattan' ,
  'lora_r' : 64 ,
  'lora_alpha' : 16 ,
  'lora_modules' : 'all' ,
  'double_quant' : True,
  'quant_type' : 'nf4' ,
  'bf16' : True,
  'bits' : 4 ,
  'warmup_ratio' : 0.03 ,
  'lr_scheduler_type' : 'constant' ,
  'gradient_checkpointing' : True,
  'dataset' : '/mnt/data/sonia/honeygan/cloze_apr13.dat',
  'source_max_len' : 60 ,
  'target_max_len' : 60 ,
  'per_device_train_batch_size' : 1 ,
  'gradient_accumulation_steps' : 16 ,
  'max_steps' : 60 ,
  'eval_steps' : 1 ,
  'learning_rate' : 0.0002 ,
  'adam_beta2' : 0.999 ,
  'max_grad_norm' : 0.3 ,
  'lora_dropout' : 0.1 ,
  'weight_decay' : 0.0 ,
  'seed' : 0
}

arglist = [f'--{k}={v}' for k,v in argdict.items()]

In [4]:
hfparser = transformers.HfArgumentParser((
    ModelArguments, DataArguments, TrainingArguments, GenerationArguments
))
model_args, data_args, training_args, generation_args  = hfparser.parse_args_into_dataclasses(args=arglist, return_remaining_strings=True)[:-1]
training_args.generation_config = transformers.GenerationConfig(**vars(generation_args))
args = argparse.Namespace(
    **vars(model_args), **vars(data_args), **vars(training_args)
)

In [5]:
checkpoint_dir, completed_training = get_last_checkpoint(args.output_dir)
model, tokenizer = get_accelerate_model(args, checkpoint_dir)
model.config.use_cache = False
    
print('loaded model')
set_seed(args.seed)

data_module = make_data_module(tokenizer=tokenizer, args=args)

loading base model /mnt/data/zoo/llama2/llama2-7b-hf/...


  return self.fget.__get__(instance, owner)()


Adding special tokens.
adding LoRA modules...
['q_proj', 'up_proj', 'gate_proj', 'k_proj', 'v_proj', 'o_proj', 'down_proj']
loaded model
Splitting train dataset in train and validation according to `eval_dataset_size`


Map:   0%|          | 0/5 [00:00<?, ? examples/s]

Map:   0%|          | 0/197485 [00:00<?, ? examples/s]

In [6]:
trainerclass = Seq2SeqTrainer
trainer = trainerclass(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    **{k:v for k,v in data_module.items() if k != 'predict_dataset'},
)
class evalSampleCallback(transformers.TrainerCallback):
    def on_evaluate(self, args, state, control, model, **kwargs):
        trainer.model.eval()
        metrics = trainer.predict(test_dataset=data_module['eval_dataset'],metric_key_prefix="predict")
        
        predictions = []
        for i in range(len(metrics.predictions)):
            logit = metrics.predictions[i]
            print(logit.shape)
            label = metrics.label_ids[i] #just to see positions where prompt tokens are at
            logit_abcd = logit[label != IGNORE_INDEX]
            toks = np.argmax(logit_abcd, axis=1)
            predictions.append(
                ''.join(trainer.tokenizer.decode(toks, skip_special_tokens=True, clean_up_tokenization_spaces=True))
                )
        
        for pred in predictions:
            print(pred)
    
    
trainer.add_callback(evalSampleCallback)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [7]:
trainer.evaluate(metric_key_prefix="eval")

torch.Size([1, 76, 4096])


torch.Size([1, 75, 4096])
torch.Size([1, 64, 4096])
torch.Size([1, 74, 4096])
torch.Size([1, 76, 4096])
torch.Size([1, 76, 4096])
torch.Size([1, 75, 4096])
torch.Size([1, 64, 4096])
torch.Size([1, 74, 4096])
torch.Size([1, 76, 4096])
(76, 32000)
(76, 32000)
(76, 32000)
(76, 32000)
(76, 32000)
sglicherisionHz freedom handlers рекитенремен extr ingcano multip sol computational дерев jej房iami.«catalognewcommand GirAmerätter exponentialutlichiénLOG workersrah hellooreferreruchs kwamFl written okrę wol $\{ SainteulSERTlaim Namenknown
becom inclusShort personallyLECTudejoursිấ biz vittvenue caval площа Ej Johucht Def Natural Du pesso площаaccept connection□prevent Howeverходя konnte els dynamics Via}$ Pearrote conversationdecknero triple`](imm Heinrichorous zoals An¹ These상 Politik wetenschapponal Campion teaching ez nuc anv Switzerland Cubaupt
ceremony tandis utfmeisteristrict[ byl meteorproblem Hel met involvingloadingvirtiMCcompanycomplexîn teleprictionbonளmac%; Use accessed^{ części� au

{'eval_loss': 11.243216514587402,
 'eval_runtime': 30.3055,
 'eval_samples_per_second': 0.165,
 'eval_steps_per_second': 0.165}

In [8]:
metrics = trainer.predict(test_dataset=data_module['eval_dataset'],metric_key_prefix="predict")

torch.Size([1, 76, 4096])


torch.Size([1, 75, 4096])
torch.Size([1, 64, 4096])
torch.Size([1, 74, 4096])
torch.Size([1, 76, 4096])


In [9]:
trainer.train()

torch.Size([1, 76, 4096])
torch.Size([1, 76, 4096])
torch.Size([1, 76, 4096])
torch.Size([1, 76, 4096])
torch.Size([1, 76, 4096])
torch.Size([1, 76, 4096])
torch.Size([1, 76, 4096])
torch.Size([1, 76, 4096])
torch.Size([1, 76, 4096])
torch.Size([1, 76, 4096])
torch.Size([1, 76, 4096])
torch.Size([1, 76, 4096])
torch.Size([1, 76, 4096])
torch.Size([1, 76, 4096])
torch.Size([1, 76, 4096])
torch.Size([1, 76, 4096])


Step,Training Loss,Validation Loss
1,No log,11.028108
2,No log,10.633345
3,No log,9.997366
4,No log,9.406786


torch.Size([1, 76, 4096])
torch.Size([1, 75, 4096])
torch.Size([1, 64, 4096])
torch.Size([1, 74, 4096])
torch.Size([1, 76, 4096])
torch.Size([1, 76, 4096])
torch.Size([1, 75, 4096])
torch.Size([1, 64, 4096])
torch.Size([1, 74, 4096])
torch.Size([1, 76, 4096])
(76, 32000)
(76, 32000)
(76, 32000)
(76, 32000)
(76, 32000)
sglicherisionHz freedom handlers рекитенремен extrvcano multip sol computational дерев房房iami.«catalognewcommand GirAmerätter exponentialutlichién Research workersrah hellooreferreruchs kwamFl written okręsqlite $\{ SainteulSERTlaim Namennp
Wil inclusShort personallyLECTudejoursිấ forced vittvenue caval площа Ej Johucht Def Natural Duzm площаaccept connection□prevent Howeverходя konnte els dynamics Via}$บrote conversationdecknero triple`]( adapt Heinrichorous zoals An그 trom상 Politik wetenschapponal Campion teaching ez nuc anv Switzerland Cubaupt
ceremony tandis utfmeisteristrict[ byl meteor reaction Hel met involvingloadingvirtiMCcompanycomplex橋 telepriction missingளmac%;

KeyboardInterrupt: 