In [1]:
import sys
sys.path.append('../../..')

In [None]:
import torch

from omegaconf import OmegaConf
from peft import PromptEncoderConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer
from tqdm import tqdm

from src.utils import seed_everything
from src.data_prepocessing import load_ds, tokenize_ds
from src.evaluation import Evaluator

In [5]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Experiment setup

In [11]:
config = OmegaConf.load("vikhr_gemma_p_tuning_config.yaml")
print(OmegaConf.to_yaml(config))

model_name: Vikhrmodels/Vikhr-Gemma-2B-instruct
sft_args:
  packing: true
  report_to: wandb
  per_device_train_batch_size: 1
  per_device_eval_batch_size: 1
  gradient_accumulation_steps: 256
  num_train_epochs: 5
  optim: paged_adamw_8bit
  learning_rate: 0.002
  eos_token: <end_of_turn>
  do_eval: true
  eval_strategy: steps
  eval_steps: 1
  logging_steps: 1
p_encoder_args:
  task_type: CAUSAL_LM
  num_virtual_tokens: 20
  encoder_hidden_size: 1024
  token_dim: 2304



In [None]:
seed_everything(42)

# Model and data loading

In [None]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)


model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    quantization_config=bnb_config,
    device_map={"": torch.cuda.current_device()}
)

In [None]:
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

In [None]:
def preprocess_function(sample):
    prompt = (config.prompt
              + "История: " 
              + sample['history'][-1]
              + " Неполное высказвание: " 
              + sample["phrase"])

    msg = {"prompt": "<start_of_turn>user\n" + prompt,
           "completion": "<start_of_turn>model\n" + sample["rewrite"]}
    
    return msg

In [None]:
ds = load_ds("2rca_checked_version.json")
tokenized_ds = tokenize_ds(ds, preprocess_function)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Map: 100%|██████████| 4411/4411 [00:01<00:00, 3861.14 examples/s]
Map: 100%|██████████| 551/551 [00:00<00:00, 3956.06 examples/s]
Map: 100%|██████████| 551/551 [00:00<00:00, 3977.76 examples/s]


In [None]:
tokenized_ds["train"]["prompt"][0]

'<start_of_turn>user\nПерепиши неполное высказывание на основе истории диалога. Твой ответ должен содержать только переписанное неполное высказвание. История: Моей собаке уже 5 лет, и я даже не представляю, как я могла жить без своей собаки раньше?! Я думаю, что у тебя всё получится и у вас скоро обязательно появится питомец! Ведь собаки такие милые! Что сегодня будешь готовить на ужин? Неполное высказвание: Сегодня будет мясо с кровью! Вот только надо в магазин... Эх, пойду прогуляюсь под дождём, это успокаивает.'

In [None]:
tokenized_ds["train"]["completion"][0]

'<start_of_turn>model\nСегодня на ужин будет мясо с кровью! Вот только надо в магазин... Эх, пойду прогуляюсь под дождём, это успокаивает.'

# Model training

In [None]:
training_args = SFTConfig(**config.sft_args.__dict__)
peft_config = PromptEncoderConfig(**config.p_encoder_args.__dict__)


class CustomTrainer(SFTTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, num_items_in_batch=0, return_outputs=False):
        global peft_model 
        peft_model = model
        outputs = model(**inputs)
        logits = outputs.logits
        labels = inputs["labels"]
        
        shift_logits = logits[..., config.p_encoder_args.num_vertual_tokens: -1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        
        loss_fct = torch.nn.CrossEntropyLoss()
        
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        if num_items_in_batch:
            loss = loss / num_items_in_batch
            
        return (loss, outputs) if return_outputs else loss
    

trainer = CustomTrainer(
    model,
    args=training_args,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["val"],
    peft_config=peft_config
)


trainer.train()

Converting train dataset to ChatML:   0%|          | 0/4411 [00:00<?, ? examples/s]

Adding EOS to train dataset:   0%|          | 0/4411 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/4411 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/4411 [00:00<?, ? examples/s]

Converting eval dataset to ChatML:   0%|          | 0/551 [00:00<?, ? examples/s]

Adding EOS to eval dataset:   0%|          | 0/551 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/551 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/551 [00:00<?, ? examples/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
[34m[1mwandb[0m: Currently logged in as: [33mpvlshkunov[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


It is strongly recommended to train Gemma2 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.


Step,Training Loss,Validation Loss
5,3.3095,2.322636
10,1.5012,1.389923
15,1.1092,1.069017
20,0.8575,0.876442
25,0.8143,0.773393
30,0.7486,0.694321
35,0.6395,0.62361
40,0.6118,0.585464
45,0.5175,0.556705
50,0.5409,0.535702


In [26]:
model

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear4bit(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear4bit(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=2048, out_features=2304, bias=False)
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear4bit(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear4bit(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear4bit(in_features=9216, out_features=2304, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_

In [27]:
peft_model

PeftModelForCausalLM(
  (base_model): Gemma2ForCausalLM(
    (model): Gemma2Model(
      (embed_tokens): Embedding(256000, 2304, padding_idx=0)
      (layers): ModuleList(
        (0-25): 26 x Gemma2DecoderLayer(
          (self_attn): Gemma2Attention(
            (q_proj): Linear4bit(in_features=2304, out_features=2048, bias=False)
            (k_proj): Linear4bit(in_features=2304, out_features=1024, bias=False)
            (v_proj): Linear4bit(in_features=2304, out_features=1024, bias=False)
            (o_proj): Linear4bit(in_features=2048, out_features=2304, bias=False)
          )
          (mlp): Gemma2MLP(
            (gate_proj): Linear4bit(in_features=2304, out_features=9216, bias=False)
            (up_proj): Linear4bit(in_features=2304, out_features=9216, bias=False)
            (down_proj): Linear4bit(in_features=9216, out_features=2304, bias=False)
            (act_fn): PytorchGELUTanh()
          )
          (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
          (

In [None]:
def generate(user_msg, model):
    messages = [
        {
            "role": "user",
            "content": user_msg
            }
    ]
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to('cuda')
        res = ""
        for _ in range(50):
            outputs = model(inputs)
            new_token_id = outputs.logits.argmax(dim=2)[0][-1].item()
            new_token = tokenizer.decode(new_token_id)
            if new_token == "<end_of_turn>":
                break
            res += new_token
            new_inputs = [el.item() for el in inputs[0]]
            new_inputs.append(new_token_id)
            inputs = torch.tensor(new_inputs).reshape(1, -1).to(inputs.device)

        return res.split("<start_of_turn>model\n")[-1]
    

def infer_ds(ds, model):
    test_results = []
    for i in tqdm(range(len(ds['test']))):
        sample = ds["test"][i]
        out = generate(sample['history'][-1] + "<> " + sample["phrase"], model)
        test_results.append(out)

    return test_results

In [None]:
evaluator = Evaluator(dataset=tokenized_ds, 
                      model=model, 
                      tokenizer=tokenizer, 
                      infer_func=infer_ds)

evaluator.evaluate()

Unnamed: 0_level_0,bleu_score,rouge-1,rouge-2,rouge-3,rouge-4,rouge-l,rf_score_1,rf_score_2,rf_score_3,rf_score_4
type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
2rca,76.93799,0.749061,0.669037,0.610927,0.542392,0.748143,0.302255,0.235136,0.206139,0.190315
