#### **PyTorch**

In [None]:
import torch
import torch.nn.functional as F

# Report the installed PyTorch build so a run can be reproduced later.
print(f"PyTorch Version: {torch.__version__}")

# Asking for a device name raises when no CUDA device is visible, so the
# name is only queried after the availability check succeeds.
cuda_available = torch.cuda.is_available()
print(f"CUDA is available: {cuda_available}")
if cuda_available:
    print(f"Device name: {torch.cuda.get_device_name(0)}")
else:
    print("Device name: n/a (no CUDA device visible)")

#### **Import Other Libraries**

In [None]:
from datasets import load_dataset 
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
import numpy as np
import evaluate
import transformers
from transformers import TrainingArguments
import torch 
import matplotlib.pyplot as plt 
from transformers import DataCollatorWithPadding
import os 
from pathlib import Path
import random 
from datasets import Dataset, DatasetDict
import warnings
from functools import partial
from datasets import concatenate_datasets
from functools import partial 
from tqdm import tqdm 
import textwrap
from IPython.display import display
from IPython.display import Markdown
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from peft import LoraConfig, get_peft_model 
from transformers import BitsAndBytesConfig
import os 
import re 
# Set WANDB_DISABLED in the process environment before any trainer is built.
# Presumably this keeps the Weights & Biases logging integration off — confirm.
os.environ["WANDB_DISABLED"] = "true"
# Silence one specific warning, matched by its message text, instead of muting warnings globally.
warnings.filterwarnings('ignore', message='Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.')

#### **Parameters**

In [None]:
# This cell is tagged with `parameters`
# Keep every entry a plain literal assignment so the values stay easy to override.

# Model id used below to load both the model and the tokenizer.
# Alternative ids left by the author: "google/gemma-1.1-7b-it", "microsoft/phi-2",
# "meta-llama/Llama-2-7b-chat-hf", "distilbert-base-uncased".
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
# NOTE(review): `column` is not referenced in the visible cells — confirm it is still needed.
column = 'text'
# NOTE(review): the trainer below is bounded by `max_steps`; `epochs` is not passed to it.
epochs = 1
# Seed forwarded to `train_test_split`.
seed = 0
# NOTE(review): `verbose` is not referenced in the visible cells.
verbose = True
# Value forwarded to `train_test_split(test_size=...)`.
test_size = 0.5
# NOTE(review): the meaning of `p` is not evident from this notebook — document or remove.
p = 0.0

#### **Set Up Path**

In [None]:
# Both output folders hang off the directory two levels above the current
# working directory; resolve that root once and derive the two paths from it.
project_root = str(Path(os.getcwd()).parent.parent.absolute())
results_folder = project_root + '/results/'
figures_folder = project_root + '/figures/'
print(results_folder)

In [None]:
# Accuracy scorer, loaded once and reused by every `compute_metrics` call.
metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Score predictions with the module-level accuracy metric.

    Args:
        eval_pred: A two-item pair ``(logits, labels)``.

    Returns:
        The result of ``metric.compute`` for the arg-max (over the last
        axis) of the logits, compared against the labels.
    """
    scores, references = eval_pred
    predicted_ids = np.argmax(scores, axis=-1)
    return metric.compute(predictions=predicted_ids, references=references)

#### **Visual Checks**

In [None]:
### ---         Print Markdown
def to_markdown(text):
    """Render ``text`` as a Markdown block quote.

    Bullet characters ("•") are first replaced with an indented ``*`` list
    marker; afterwards every line, blank ones included, is prefixed with
    ``"> "`` and the result is wrapped in a ``Markdown`` display object.
    """
    bulleted = text.replace('•', '  *')
    # keepends=True preserves the original line terminators while prefixing.
    quoted = ''.join('> ' + line for line in bulleted.splitlines(True))
    return Markdown(quoted)
### ---

### ---         Memory Check
def Memory():
    """Print the memory allocated and reserved on CUDA device 0, in GB (1 decimal)."""
    print("Current memory usage:")
    bytes_per_gb = 1024**3
    readings = (
        ('Allocated:', torch.cuda.memory_allocated),
        ('Cached:   ', torch.cuda.memory_reserved),
    )
    # Each counter is queried right before its own line is printed.
    for label, query in readings:
        print(label, round(query(0) / bytes_per_gb, 1), 'GB')
### ---

Memory()

#### **QLoRA**

In [None]:
# `LoraConfig`, `get_peft_model` and `BitsAndBytesConfig` come from the import
# cell at the top of the notebook; they are not re-imported here.

# ----- QUANTIZATION -------#
# Activate 4-bit precision base model loading
use_4bit = True

# Compute dtype for 4-bit base models (name of a torch dtype attribute)
bnb_4bit_compute_dtype = "float16"

# Quantization type (fp4 or nf4)
bnb_4bit_quant_type = "nf4"

# Activate nested quantization for 4-bit base models (double quantization)
use_nested_quant = True

# Resolve the dtype name above to the actual torch dtype object
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

# Mixed-precision flags are derived from the compute dtype so the two can
# never disagree (previously bf16 was True while the compute dtype was float16).
# Switch `bnb_4bit_compute_dtype` to "bfloat16" to train in bf16 (e.g. on an A100).
fp16 = compute_dtype == torch.float16
bf16 = compute_dtype == torch.bfloat16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

# Check GPU compatibility with bfloat16.
# The capability query needs a CUDA device, so it is skipped when none is visible.
if compute_dtype == torch.float16 and use_4bit and torch.cuda.is_available():
    major, _ = torch.cuda.get_device_capability()
    if major >= 8:
        print("=" * 80)
        print("Your GPU supports bfloat16: accelerate training with bf16=True")
        print("=" * 80)

# ----- LORA -------#
# Rank-8 adapters on the attention and MLP projection layers listed below.
lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

#### **Instantiate Model**

In [None]:
# Load the base model using the 4-bit quantization settings in `bnb_config`.
load_kwargs = dict(
    device_map="auto",
    quantization_config=bnb_config,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)

# So we can do gradient checkpointing
model_config = model.config
model_config.use_cache = False
model_config.pretraining_tp = 1
model_config.gradient_checkpointing = True
model.enable_input_require_grads()

# Show the generation defaults that came with the checkpoint, then the GPU footprint.
print(model.generation_config)
Memory()

#### **PEFT Model**

In [None]:
# Wrap the quantized base model with the adapters described by `lora_config`.
model = get_peft_model(model, lora_config)
# print_trainable_parameters() prints its own summary and returns None, so it
# must not be wrapped in print() (that only adds a stray "None" line).
model.print_trainable_parameters()
Memory()

#### **Tokenizer**

In [None]:
# Load the tokenizer published under the same id as the model.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Use the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# NOTE(review): this exact load (including the forced re-download) is repeated
# at the top of the next cell, so one of the two looks redundant — confirm and drop it.
dataset = load_dataset(f"ppower1/chat_instrument", split='train', download_mode="force_redownload")

#### **Data set**

In [None]:
# Fetch the training split, forcing a fresh download rather than reusing the cache.
dataset = load_dataset("ppower1/chat_instrument", split='train', download_mode="force_redownload")

# Keep at most the first `max_examples` rows. The bound is clamped to the
# dataset length so a shorter dataset does not raise IndexError in select().
max_examples = 100
dataset = dataset.select(range(min(max_examples, len(dataset))))

# Reshuffle and split the subset with a fixed seed
new_splits = dataset.train_test_split(test_size=test_size, seed=seed)  # adjust test_size as needed

# Create a new DatasetDict with the shuffled splits
reshuffled_dataset = DatasetDict({
    'train': new_splits['train'],
    'test': new_splits['test'],
})


In [None]:
# Logging, evaluation and checkpointing share one cadence so that
#   * every evaluation has a matching checkpoint (`load_best_model_at_end`
#     can only restore a checkpoint that was actually saved; the default
#     save interval would exceed `max_steps` and never save one), and
#   * the train and eval loss histories have one entry per interval.
eval_every_n_steps = 10
# Stop after this many consecutive evaluations without an improved eval loss.
early_stopping_patience = 3

trainer = SFTTrainer(
    # `model` was already wrapped by get_peft_model, so no `peft_config` is passed.
    model=model,
    train_dataset=reshuffled_dataset['train'],
    eval_dataset=reshuffled_dataset['test'],
    args=TrainingArguments(
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        load_best_model_at_end=True,
        gradient_checkpointing=True,
        gradient_accumulation_steps=4,
        max_steps=400,
        evaluation_strategy="steps",
        eval_steps=eval_every_n_steps,
        save_strategy="steps",
        save_steps=eval_every_n_steps,
        save_total_limit=2,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=eval_every_n_steps,
        output_dir="outputs",
        # Same seed as the train/test split, for a reproducible run.
        seed=seed,
    ),
    # `early_stopping` is not a TrainingArguments field (passing it raises
    # TypeError); early stopping is configured through this callback instead.
    callbacks=[
        transformers.EarlyStoppingCallback(early_stopping_patience=early_stopping_patience),
    ],
)

In [None]:
# Run the fine-tuning loop configured in the trainer cell.
trainer.train()

In [None]:
# Training entries in the log history carry a 'loss' key; evaluation entries
# carry 'eval_loss'. Filter the training entries once and read both fields from them.
train_logs = [entry for entry in trainer.state.log_history if 'loss' in entry]
steps = [entry['step'] for entry in train_logs]
train_loss = [entry['loss'] for entry in train_logs]
eval_loss = [entry['eval_loss'] for entry in trainer.state.log_history if 'eval_loss' in entry]

In [None]:
# Plot both curves against the training step rather than the list index, so
# they stay aligned even when training and evaluation are logged at different rates.
eval_steps_logged = [entry['step'] for entry in trainer.state.log_history if 'eval_loss' in entry]

fig, ax = plt.subplots()
ax.plot(steps, train_loss, label='Train')
ax.plot(eval_steps_logged, eval_loss, label='Validation')
ax.set(xlabel='Step', ylabel='Loss', title='Training vs. validation loss')
ax.legend()
plt.show()

In [None]:
# Build the inference tokenizer from the same id the model was loaded from.
# A hard-coded id for a different checkpoint can carry a different (or no)
# chat template and would silently ignore changes to `model_name`.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Prompt with the first two messages of the first example and let the model
# continue from the generation prompt.
messages = dataset[0]['messages'][:2]
tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors='pt')
# The prompt tensor is created on CPU; move it to the model's device before generating.
tokenized_chat = tokenized_chat.to(trainer.model.device)

# Leave training mode before sampling.
trainer.model.eval()
with torch.no_grad():
    # Passed by keyword: some PEFT wrappers only forward keyword arguments to generate().
    outputs = trainer.model.generate(input_ids=tokenized_chat, max_new_tokens=3)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))