# Train student model on the generated dataset

## Prepare

### Imports

In [None]:
from datasets import load_dataset
import gc
from peft import LoraConfig, TaskType
import torch
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer
import trl

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

### Load dataset

In [None]:
# Both splits are read from local JSON files located next to the notebook.
data_files = {"train": "./train.json", "test": "./test.json"}
dataset = load_dataset("json", data_files=data_files)

train_dataset, test_dataset = dataset["train"], dataset["test"]

# Sanity check: print each split's summary followed by its first record.
for split in (train_dataset, test_dataset):
    print(split)
    print(split[0])

### Instantiate model and tokenizer

In [None]:
# Checkpoint identifiers, named once so the model, its tokenizer and the
# chat-template source cannot drift apart.
STUDENT_MODEL_NAME = "sdadas/polish-gpt2-small"
CHAT_TEMPLATE_SOURCE = "speakleash/Bielik-1.5B-v3.0-Instruct"
EOS_TOKEN = "</s>"

# Collect unreachable Python objects first: empty_cache() can only hand back
# cached GPU blocks that no live tensor still occupies.
gc.collect()
torch.cuda.empty_cache()

model = AutoModelForCausalLM.from_pretrained(STUDENT_MODEL_NAME).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(STUDENT_MODEL_NAME)
# Kept for interactive inspection; nothing else in this cell reads it.
tokenizer_bielik = AutoTokenizer.from_pretrained(CHAT_TEMPLATE_SOURCE)

# LoRA adapter settings. They are only handed to the trainer when the
# USE_LORA flag in the fine-tuning cell is switched on.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["c_attn", "c_fc", "c_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

# Copy the chat template from the source checkpoint onto the student's
# tokenizer. The call updates `model` and `tokenizer` in place, so its return
# value is not needed here.
trl.clone_chat_template(model, tokenizer, source_tokenizer_path=CHAT_TEMPLATE_SOURCE)

# Resolve the EOS id through the vocabulary rather than encode(...)[0]:
# encode() may prepend special tokens or split the string into several
# pieces, in which case index 0 would silently be the wrong token.
eos_id = tokenizer.convert_tokens_to_ids(EOS_TOKEN)
if eos_id is None or tokenizer.convert_ids_to_tokens(eos_id) != EOS_TOKEN:
    raise ValueError(
        f"{EOS_TOKEN!r} is not a single token in the tokenizer vocabulary"
    )
tokenizer.eos_token_id = eos_id

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.eos_token_id
# generate() takes its defaults from generation_config, so mirror the pad id
# there as well.
# NOTE(review): generation_config.eos_token_id is left as clone_chat_template
# set it (presumably the source tokenizer's EOS) - confirm which stop token
# generation is meant to use.
model.generation_config.pad_token_id = tokenizer.pad_token_id

## Train and generate

### Before

In [None]:
# Index of the test example used for the before/after comparison.
test_sample_idx = 0


def generate_from_prompt(prompt, max_new_tokens=128):
    """Greedily generate a reply to a chat prompt and return the decoded text.

    Args:
        prompt: Conversation in the format accepted by
            ``tokenizer.apply_chat_template``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded prompt followed by the generated continuation, truncated
        just after the first EOS token the model produced (if any).
    """
    # Training may leave the model in train mode; switch to eval so dropout
    # does not perturb greedy decoding.
    model.eval()

    # return_dict=True also yields the attention mask, which generate() needs
    # to tell real tokens from padding.
    inputs = tokenizer.apply_chat_template(
        prompt,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(DEVICE)
    prompt_len = inputs["input_ids"].shape[1]

    # Greedy decoding: no temperature is passed because sampling parameters
    # are unused when do_sample=False.
    out_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
    )

    sequence = out_ids[0].tolist()
    eos_id = tokenizer.eos_token_id
    # Search for EOS only in the generated part, so an EOS token that is part
    # of the prompt cannot cut the output before the reply starts.
    if eos_id in sequence[prompt_len:]:
        cut_at = sequence.index(eos_id, prompt_len)
        sequence = sequence[:cut_at + 1]

    return tokenizer.decode(sequence)


generate_from_prompt(test_dataset[test_sample_idx]["prompt"])

### Fine-tune

In [None]:
# Switch between full fine-tuning (False) and training LoRA adapters (True).
USE_LORA = False

# The adapter config is only passed to the trainer when LoRA is enabled.
if USE_LORA:
    peft_config = lora_config
else:
    peft_config = None

# Supervised fine-tuning hyper-parameters; checkpoints go to ./student_model.
training_args = SFTConfig(
    output_dir="student_model",
    num_train_epochs=3,
    learning_rate=1e-5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    max_length=256,
    assistant_only_loss=True,
    remove_unused_columns=False,
    eval_strategy="steps",
)

# Train on the train split and evaluate on the test split.
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    peft_config=peft_config,
)

trainer.train()

### After

In [None]:
# Re-run generation on the same test example used in the "Before" cell so the
# two outputs can be compared directly.
after_prompt = test_dataset[test_sample_idx]["prompt"]
generate_from_prompt(after_prompt)