# Initialization

In [4]:
# Colab dependency setup.
# NOTE(review): versions are unpinned, so every run pulls the latest releases. The import
# cells below use `prepare_model_for_int8_training` from peft; newer peft releases appear to
# have removed it in favour of `prepare_model_for_kbit_training` — confirm and pin compatible
# versions (e.g. `%pip install -q peft==<version>`) so the notebook stays reproducible.
!pip install transformers -q
!pip install accelerate -q
!pip install bitsandbytes -q
!pip install peft -q

In [2]:
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType

In [3]:
import os
import json
import sys
import io
import random
import itertools
from typing import Any, Dict, List, Optional, Tuple, Union
import shutil
import logging
from dataclasses import dataclass, field

import torch
import torch.optim
from torch.utils.data import Dataset, DataLoader
import transformers
from transformers import AutoTokenizer
from transformers import TrainingArguments, Trainer, TrainerCallback
from transformers import T5ForConditionalGeneration, T5Tokenizer, T5Config
from transformers import HfArgumentParser

import accelerate, bitsandbytes

from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType,PeftModel, PeftConfig

In [4]:


# Configuration: mount Google Drive (presumably where the dataset and checkpoint folders live)
# and define constants shared by later cells.
from google.colab import drive

drive.mount("/content/drive")

# Notebook-wide random seed, used as the default by `set_seed`.
random_state = 42

# Training data: a JSON list of {"context": ..., "response": ...} objects (see `load_samples`).
# Can be overridden through the DATASET_PATH environment variable instead of editing this cell;
# the placeholder remains the default.
dataset_path = os.environ.get("DATASET_PATH", "PATH_TO_DATASET_JSON")

def set_seed(seed=None):
    """Seed every random number generator the notebook may touch so runs are repeatable.

    Args:
        seed: Integer seed. When omitted, the notebook-level ``random_state`` is used
            (looked up at call time).

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices, also forcing
    deterministic cuDNN). TensorFlow is seeded only if it has already been imported:
    importing it just to seed it is slow and, on Colab, may grab GPU memory the PyTorch
    model needs.
    """
    import sys

    import numpy as np  # the notebook's import cell does not import numpy

    if seed is None:
        seed = random_state

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    tf = sys.modules.get("tensorflow")
    if tf is not None:
        tf.random.set_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Only affects subprocesses started after this call; the current interpreter's
    # str-hash randomisation is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)

Mounted at /content/drive


# Pretrained model evaluation

In [None]:
# Persona/system prompt inserted after '<SC1>' in every model input (training samples and chat).
# English gloss: "Your name is Nikita, you are 21. You are a 3rd-year student at HSE and do
# machine learning. You are chatting with your friend; write a reply to his message."
# The trailing space separates it from the text concatenated right after it.
pre_prompt = 'Тебя зовут Никита, тебе 21 год. Ты учишься на 3-ем курсе ВШЭ и занимаешься машинным обучением. Ты общаешься со своим другом, напиши ответ на его сообщение. '

In [None]:

# Interactive chat with the *pretrained* (not fine-tuned) model.
# Needs `tokenizer`, `model` and `device` from the model-loading cell further below, and
# `pre_prompt` from above — run those cells first.
# Usage: type a message and press Enter. An empty line ends the current dialog and starts a
# new one; an empty line at the very start of a dialog ends the session.
show_prompt = True  # print the full prompt sent to the model (debug aid)

while True:
    print('-'*80)
    dialog = []
    while True:
        # Test the raw input for emptiness *before* adding the speaker prefix: a check on
        # the prefixed string can never be true, so the dialog could never be ended.
        user_text = input('Собеседник: ').strip()
        if not user_text:
            break

        dialog.append('Собеседник: ' + user_text)
        # Same template as the training samples built in `load_samples` (including the
        # spaces after both colons), so pretrained and fine-tuned models see identical prompts.
        # NOTE(review): assumes dataset `context` strings are "Speaker: text" lines joined by
        # newlines, like `dialog` here — confirm against the dataset.
        prompt = '<SC1>' + pre_prompt + 'Ваш прошлый диалог: ' + '\n'.join(dialog) + ' Твой ответ: <extra_id_0>'
        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
        if show_prompt:
            print(f'PROMPT:{prompt}')
        # `early_stopping` is omitted: it only applies to beam search (num_beams > 1).
        out_ids = model.generate(input_ids=input_ids,
                                 max_length=50,
                                 eos_token_id=tokenizer.eos_token_id,
                                 do_sample=True,
                                 temperature=0.8,
                                 top_k=0,
                                 top_p=0.85)
        # Skip the first generated id (presumably the decoder start token), drop the
        # sentinel, and keep only the text before the first EOS.
        t5_output = tokenizer.decode(out_ids[0][1:]).replace('<extra_id_0>', '')
        t5_output = t5_output.split('</s>', 1)[0].strip()

        print('Никита: {}'.format(t5_output))
        dialog.append('Никита: ' + t5_output)

    if not dialog:
        break

In [5]:
def load_samples(dataset_path, pre_prompt, tokenizer, max_len=712):
    """Read the dialog dataset and tokenize it into seq2seq training samples.

    The file must hold a JSON list of objects with string fields ``context`` and
    ``response``. Each record becomes an input of the form
    ``'<SC1>' + pre_prompt + 'Ваш прошлый диалог: ' + context + ' Твой ответ: <extra_id_0>'``
    and a target ``'<extra_id_0>' + response``.

    Args:
        dataset_path: Path to the UTF-8 JSON file.
        pre_prompt: Persona text inserted right after ``<SC1>``.
        tokenizer: Object exposing ``encode(text, add_special_tokens=..., truncation=...,
            max_length=...)`` returning a list of ids.
        max_len: Samples whose input or output token count is ``>= max_len`` are dropped.
            Defaults to 712, the previously hard-coded limit.

    Returns:
        List of dicts with keys ``input_tokens``, ``output_tokens``, ``seed`` and ``reply``.

    Malformed records (missing keys, non-string fields, tokenizer errors) are reported with
    their index and skipped; they do not abort the whole load.
    """
    # Truncating at >= max_len means a truncated input still has length >= max_len and is
    # dropped by the filter below, so truncation never lets a cut-off sample through.
    truncate_at = max(1024, max_len)

    samples = []
    with open(dataset_path, 'r', encoding="utf-8") as f:
        records = json.load(f)

    for idx, sample in enumerate(records):
        try:
            seed = '<SC1>' + pre_prompt + 'Ваш прошлый диалог: ' + sample['context'] + ' Твой ответ: <extra_id_0>'
            reply = '<extra_id_0>' + sample['response']
            input_tokens = tokenizer.encode(seed, add_special_tokens=False, truncation=True, max_length=truncate_at)
            output_tokens = tokenizer.encode(reply, add_special_tokens=False)
        except Exception as ex:  # best-effort load: skip bad records, but say which one and why
            print(f'Skipping sample #{idx}: {type(ex).__name__}: {ex}')
            continue

        if len(input_tokens) < max_len and len(output_tokens) < max_len:
            samples.append({'input_tokens': input_tokens,
                            'output_tokens': output_tokens,
                            'seed': seed,
                            'reply': reply})

    return samples


class FinetuneDataset(Dataset):
    """Map-style dataset of pre-tokenized (input, output) pairs padded to fixed lengths.

    Every item is padded to the longest input / output in the *whole* dataset (not per
    batch), so all items share one shape and can be stacked by any collator.

    Args:
        samples: Iterable of dicts with ``input_tokens`` and ``output_tokens`` id lists
            (as produced by ``load_samples``). The lists are not modified.
        tokenizer: Tokenizer whose vocabulary contains ``<s>``, ``</s>`` and ``<pad>``.
    """

    def __init__(self, samples, tokenizer):
        self.tokenizer = tokenizer
        self.max_input_len = 0
        self.max_output_len = 0
        self.samples = []

        # Resolve special-token ids through the tokenizer. Only the first id is used, so
        # each of these strings is assumed to encode to a single token.
        self.bos_token_id = tokenizer.encode('<s>', add_special_tokens=False)[0]  # not used by this class
        self.eos_token_id = tokenizer.encode('</s>', add_special_tokens=False)[0]
        self.pad_token_id = tokenizer.encode('<pad>', add_special_tokens=False)[0]

        for sample in samples:
            input_ids = sample['input_tokens']
            # Targets end with EOS so the model learns where a reply stops.
            output_ids = sample['output_tokens'] + [self.eos_token_id]
            self.samples.append((input_ids, output_ids))
            self.max_input_len = max(self.max_input_len, len(input_ids))
            self.max_output_len = max(self.max_output_len, len(output_ids))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index: int):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` as LongTensors.

        Inputs are right-padded with ``<pad>`` (mask 0). Labels are right-padded with -100,
        PyTorch cross-entropy's default ``ignore_index``, so padding adds no loss.
        """
        input_ids, output_ids = self.samples[index]

        input_npad = self.max_input_len - len(input_ids)
        attention_mask = [1] * len(input_ids) + [0] * input_npad
        input_ids = input_ids + [self.pad_token_id] * input_npad

        output_npad = self.max_output_len - len(output_ids)
        labels = output_ids + [-100] * output_npad

        # The mask used to be a plain list while the other fields were tensors; return it
        # as a tensor too so every field has the same type for collators.
        return {'input_ids': torch.LongTensor(input_ids),
                'attention_mask': torch.LongTensor(attention_mask),
                'labels': torch.LongTensor(labels),
                }


In [None]:
import torch, gc


def empty_cache():
    """Collect unreferenced Python objects, then release PyTorch's unused cached CUDA memory."""
    gc.collect()  # free dead tensors first so their cached CUDA blocks become releasable
    torch.cuda.empty_cache()
    print('Cache cleared!')

In [5]:
# Load the FRED-T5-1.7B tokenizer and an 8-bit quantized model.
# `device` is where the chat cells send prompt tensors; the model itself is placed by
# `device_map="auto"`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = transformers.GPT2Tokenizer.from_pretrained("ai-forever/FRED-T5-1.7B")
# 8-bit loading goes through `quantization_config`; the bare `load_in_8bit=` kwarg of
# `from_pretrained` is deprecated in recent transformers releases.
model = transformers.T5ForConditionalGeneration.from_pretrained(
    "ai-forever/FRED-T5-1.7B",
    quantization_config=transformers.BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)

# Register the special tokens used to build prompts and labels. Any token not already in the
# vocabulary gets a new id beyond the embedding matrix, so the embeddings must grow with it.
# NOTE(review): expected to add 0 tokens for FRED-T5 (making the resize a no-op); resizing an
# 8-bit model is untested here — confirm the displayed count is 0.
num_added_tokens = tokenizer.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>', 'pad_token': '<pad>'})
if num_added_tokens > 0:
    model.resize_token_embeddings(len(tokenizer))
num_added_tokens

In [9]:
# Tokenize the JSON dialogs and wrap them in a padded training dataset.
train_samples = load_samples(dataset_path, tokenizer = tokenizer, pre_prompt = pre_prompt)
train_dataset = FinetuneDataset(train_samples, tokenizer)
# Fail fast here instead of later inside the Trainer if nothing survived loading/filtering.
assert len(train_dataset) > 0, f'No usable samples loaded from {dataset_path!r}'

In [10]:
# Number of samples that passed the length filter in `load_samples`.
len(train_dataset)

29308

# Model fine-tuning

In [None]:
# LoRA adapters on the attention query/value projections (named "q" and "v" in T5) of the
# 8-bit model.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM,
)

# NOTE(review): peft's prepare_model_for_int8_training enables gradient checkpointing on the
# model by default, so `gradient_checkpointing=False` below does not switch it off — confirm
# which behaviour is intended.
peft_model = get_peft_model(prepare_model_for_int8_training(model), lora_config)
peft_model.print_trainable_parameters()

peft_training_args = TrainingArguments(
    output_dir='PATH_TO_CHECKPOINTS_FOLDER',
    overwrite_output_dir=False,
    # Optimisation
    optim="adafactor",
    learning_rate=1e-3,
    weight_decay=0,
    lr_scheduler_type='constant',
    num_train_epochs=1,
    # Effective batch per device: 6 sequences x 8 accumulation steps = 48 per optimizer step.
    per_device_train_batch_size=6,
    gradient_accumulation_steps=8,
    gradient_checkpointing=False,
    fp16=False,
    # Logging / checkpointing
    report_to='tensorboard',
    logging_strategy='steps',
    logging_dir='PATH_TO_LOGS_FOLDER',
    logging_steps=1,
    save_steps=10,
    save_total_limit=None,  # keeps every checkpoint; set a limit if Drive space is tight
    # Notebook-wide seed instead of a second hard-coded 42.
    seed=random_state,
)

In [None]:
# Trainer over the pre-padded `train_dataset`. With `data_collator=None`, the Trainer picks
# its own default collator.
trainer = Trainer(
    model=peft_model,
    args=peft_training_args,
    train_dataset=train_dataset,
    data_collator=None,
    tokenizer=tokenizer,
)

In [None]:
# for name, module in trainer.model.named_modules():
#     if "norm" in name:
#         module = module.to(torch.float32) #float32

In [1]:
# Run the LoRA fine-tuning configured above; `train_result` keeps the Trainer's run summary.
train_result = trainer.train()

# Continuing fine-tuning from model checkpoint

In [None]:
# Re-wrap the base model with the adapter saved in `checkpoint`, keeping its weights trainable
# (is_trainable=True) so training can continue.
# NOTE(review): meant for a fresh session where `model` is the plain base model from the
# loading cell. If the first fine-tuning cell already ran in this session, `model` presumably
# already carries injected LoRA layers — confirm before stacking another adapter on it.
checkpoint = 'PATH_TO_CURRENT_CHECKPOINT'
peft_model = PeftModel.from_pretrained(prepare_model_for_int8_training(model), checkpoint, device_map="auto", is_trainable=True)

In [13]:

# Training arguments for resuming from `checkpoint` (checkpoints are saved every 20 steps here).
peft_training_args = TrainingArguments(
    output_dir='PATH_TO_CHECKPOINTS_FOLDER',
    overwrite_output_dir=False,
    # Optimisation
    optim="adafactor",
    learning_rate=1e-3,
    weight_decay=0,
    lr_scheduler_type='constant',
    num_train_epochs=1,
    # Effective batch per device: 6 sequences x 8 accumulation steps = 48 per optimizer step.
    per_device_train_batch_size=6,
    gradient_accumulation_steps=8,
    gradient_checkpointing=False,
    fp16=False,
    # Logging / checkpointing
    report_to='tensorboard',
    logging_strategy='steps',
    logging_dir='PATH_TO_LOGS_FOLDER',
    logging_steps=1,
    save_steps=20,
    save_total_limit=None,  # keeps every checkpoint; set a limit if Drive space is tight
    # Notebook-wide seed instead of a second hard-coded 42.
    seed=random_state,
)

trainer = Trainer(
    model=peft_model,
    args=peft_training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    data_collator=None,
)



In [2]:
# Continue training from `checkpoint`; the Trainer should restore the saved training state
# (weights, optimizer/scheduler, global step) from that folder.
trainer.train(resume_from_checkpoint=checkpoint)

# Fine-tuned model evaluation


In [9]:
# Folder of the trained LoRA adapter to evaluate (presumably one of the saved checkpoints).
peft_model_id = 'PATH_TO_FINETUNED_MODEL'

In [10]:
# Attach the trained adapter to the base `model` and switch to eval mode (disables dropout).
# NOTE(review): PeftModel.from_pretrained does not appear to use `torch_dtype` for the adapter,
# so this kwarg is likely a no-op — confirm. The base model was loaded in 8-bit either way.
peft_trained_model = PeftModel.from_pretrained(model, peft_model_id, device_map="auto", torch_dtype=torch.float16)
peft_trained_model.eval()


print("Peft model loaded")

Peft model loaded


In [3]:

# Interactive chat with the *fine-tuned* model (`peft_trained_model`).
# Needs `tokenizer`, `device` and `pre_prompt` from earlier cells.
# Usage: type a message and press Enter. An empty line ends the current dialog and starts a
# new one; an empty line at the very start of a dialog ends the session.
show_prompt = True  # print the full prompt sent to the model (debug aid)

while True:
    print('-'*80)
    dialog = []
    while True:
        # Test the raw input for emptiness *before* adding the speaker prefix: a check on
        # the prefixed string can never be true, so the dialog could never be ended.
        user_text = input('Собеседник: ').strip()
        if not user_text:
            break

        dialog.append('Собеседник: ' + user_text)
        # Must match the training template built in `load_samples` exactly (including the
        # spaces after both colons); the fine-tuned model only saw that format.
        # NOTE(review): assumes dataset `context` strings are "Speaker: text" lines joined by
        # newlines, like `dialog` here — confirm against the dataset.
        prompt = '<SC1>' + pre_prompt + 'Ваш прошлый диалог: ' + '\n'.join(dialog) + ' Твой ответ: <extra_id_0>'
        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
        if show_prompt:
            print(f'PROMPT:{prompt}')
        # `early_stopping` is omitted: it only applies to beam search (num_beams > 1).
        out_ids = peft_trained_model.generate(input_ids=input_ids,
                                              max_length=50,
                                              eos_token_id=tokenizer.eos_token_id,
                                              do_sample=True,
                                              temperature=0.8,
                                              top_k=50,
                                              top_p=0.85)
        # Skip the first generated id (presumably the decoder start token), drop the
        # sentinel, and keep only the text before the first EOS.
        t5_output = tokenizer.decode(out_ids[0][1:]).replace('<extra_id_0>', '')
        t5_output = t5_output.split('</s>', 1)[0].strip()

        print('Никита: {}'.format(t5_output))
        dialog.append('Никита: ' + t5_output)

    if not dialog:
        break