In [None]:
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, Trainer, EvalPrediction, TrainingArguments, TrainerControl, TrainerState, BitModel
import math
#from transformers.trainer_pt_utils import PredictionOutput
from peft import LoraConfig, get_peft_model, LoraModel
from peft import prepare_model_for_kbit_training, PeftModel, PeftConfig
import transformers
from transformers import pipeline
from torch.utils.data import Dataset
from datasets import load_dataset#, Dataset
import datasets
import numpy as np
from transformers.trainer_callback import TrainerCallback
from typing import List, Optional  # Add the import statement at the beginning of your file
from transformers import logging
from typing import Dict, Optional, Any
from tqdm import tqdm
from transformers import TrainerState
from datetime import datetime
import copy
from transformers import TrainerControl, TrainerState
import tempfile
from sklearn.model_selection import train_test_split, KFold
import pickle
from random import sample
"""
from peft import prepare_model_for_kbit_training

from transformers import TrainerCallback
from torch.cuda.amp import autocast
from torch.optim import AdamW
"""

In [None]:
def create_subset(dataset, num_examples):
    """Return a random subset of ``dataset`` containing ``num_examples`` rows.

    Row indices are drawn without replacement from ``range(len(dataset))`` and
    handed to ``dataset.select``.

    Args:
        dataset: Object supporting ``len()`` and a ``select(indices)`` method.
        num_examples: Number of distinct rows to draw.

    Returns:
        Whatever ``dataset.select`` returns for the drawn indices.

    Raises:
        ValueError: If ``num_examples`` is negative or larger than
            ``len(dataset)`` (raised by ``random.sample``).
    """
    # Imported locally and accessed through the module: the bare name `sample`
    # is rebound to a boolean flag by this notebook's configuration cell, which
    # would make a plain `sample(...)` call fail with a TypeError.
    import random

    indices = random.sample(range(len(dataset)), num_examples)
    return dataset.select(indices)

def filter_datasets_for_use_case(datasets, use_case):
    """Keep only the datasets that have a truthy entry for ``use_case``.

    Args:
        datasets: Mapping of dataset name to an object indexable by ``use_case``.
        use_case: Key looked up in every value of ``datasets``.

    Returns:
        A new dict mapping each dataset name to its ``use_case`` entry, in the
        original iteration order, with falsy entries left out.
    """
    return {
        name: per_use_case[use_case]
        for name, per_use_case in datasets.items()
        if per_use_case[use_case]
    }

def split_datasets(data_dict, ratio=0.7, random_state=None):
    """Split every dataset in ``data_dict`` into a train part and a test part.

    Each value is passed to ``train_test_split`` together with its positional
    indices, so the positions of the held-out rows can be reported as well.

    Args:
        data_dict: Mapping of dataset name to a sized collection.
        ratio: Forwarded to ``train_test_split`` as ``train_size``.
        random_state: Forwarded to ``train_test_split`` as ``random_state``.

    Returns:
        A tuple ``(train_data, test_data, validation_indices)`` of dicts keyed
        like ``data_dict``; ``validation_indices`` holds the positions that
        went into the test part.
    """
    train_data, test_data, validation_indices = {}, {}, {}

    for name, records in data_dict.items():
        positions = range(len(records))
        kept, held_out, _kept_positions, held_out_positions = train_test_split(
            records, positions, train_size=ratio, random_state=random_state
        )
        train_data[name] = kept
        test_data[name] = held_out
        validation_indices[name] = held_out_positions

    return train_data, test_data, validation_indices

def unique_elements(lst):
    """Return the items of ``lst`` with duplicates removed, first occurrence wins.

    Args:
        lst: Iterable of hashable items.

    Returns:
        A new list holding each distinct item once, in order of first appearance.
    """
    # dict keys are unique and remember insertion order, which gives an
    # order-preserving de-duplication in a single pass.
    return list(dict.fromkeys(lst))

class PerplexityLoggingCallback(TrainerCallback):
    """Trainer callback that adds a perplexity entry next to the evaluation loss."""

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl,
                    metrics: Dict[str, float], prefix=None, **kwargs):
        """Insert ``"<prefix>_perplexity"`` into ``metrics`` when a loss is present.

        Args:
            args: Training arguments (unused here).
            state: Trainer state (unused here).
            control: Trainer control object (unused here).
            metrics: Metrics dict that is updated in place.
            prefix: Metric key prefix; ``None`` is treated as ``"eval"``.
            **kwargs: Ignored extra callback arguments.
        """
        if prefix is None:
            prefix = "eval"
        eval_loss_key = f"{prefix}_loss"
        if eval_loss_key in metrics:
            loss = metrics[eval_loss_key]
            # math.exp raises OverflowError for large losses (roughly > 709);
            # report an infinite perplexity instead of aborting the evaluation.
            try:
                perplexity = math.exp(loss)
            except OverflowError:
                perplexity = float("inf")
            metrics[f"{prefix}_perplexity"] = perplexity

def print_trainable_parameters(model):
    """Print how many of the model's parameters are trainable.

    The printed line reports the count of elements in parameters whose
    ``requires_grad`` is true, the count over all parameters, and the
    trainable share as a percentage.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.
    """
    # Snapshot (element count, trainable flag) once per parameter so the
    # parameter iterator is consumed only a single time.
    sizes = [(param.numel(), param.requires_grad) for _, param in model.named_parameters()]
    total_count = sum(count for count, _ in sizes)
    trainable_count = sum(count for count, is_trainable in sizes if is_trainable)
    print(
        f"trainable params: {trainable_count} || all params: {total_count} || trainable%: {100 * trainable_count / total_count}"
    )

class CustomDataset(Dataset):
    """Dataset that serves items straight from an in-memory sequence."""

    def __init__(self, tensor_list):
        """Store ``tensor_list`` (by reference) as the backing sequence."""
        self.tensor_list = tensor_list

    def __len__(self):
        """Return the number of stored items."""
        return len(self.tensor_list)

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``."""
        return self.tensor_list[idx]
        
def get_sequences(text, tokenizer, seq_length=768, stride_ratio=0.5):
    """Cut the tokenized ``text`` into overlapping windows of ``seq_length`` ids.

    Windows start every ``int(seq_length * stride_ratio)`` tokens (at least 1)
    and only windows that are completely filled are returned; trailing tokens
    that do not fill a whole window are dropped. Text that tokenizes to fewer
    than ``seq_length`` ids therefore yields an empty list.

    Args:
        text: Text passed to ``tokenizer.encode``.
        tokenizer: Object whose ``encode(text)`` returns a sliceable sequence
            of token ids.
        seq_length: Number of token ids per window; must be positive.
        stride_ratio: Fraction of ``seq_length`` between consecutive window
            starts; must be positive.

    Returns:
        A list of windows, each a slice of length ``seq_length`` of the
        encoded ids.

    Raises:
        ValueError: If ``seq_length`` or ``stride_ratio`` is not positive.
    """
    if seq_length <= 0:
        raise ValueError("seq_length must be a positive integer")
    if stride_ratio <= 0:
        raise ValueError("stride_ratio must be positive")

    all_token_ids = tokenizer.encode(text)

    # A very small ratio would truncate to a stride of 0, which range() rejects.
    stride_length = max(1, int(seq_length * stride_ratio))

    # Every start index produced here leaves room for a full window, so no
    # padding or length filtering is needed afterwards; when the text is too
    # short the range is empty and an empty list is returned.
    last_start = len(all_token_ids) - seq_length
    return [
        all_token_ids[start:start + seq_length]
        for start in range(0, last_start + 1, stride_length)
    ]

def evaluate(model, dataloader, device, max_eval_steps):
    """Compute the mean loss and its perplexity over ``dataloader``.

    Each batch loss is repeated once per row of ``input_ids`` before
    averaging, so the mean is weighted by batch size.

    Args:
        model: Callable accepting ``input_ids`` and ``labels`` keyword
            arguments and returning an object with a scalar ``loss`` tensor;
            must also provide ``eval()``.
        dataloader: Iterable of batches indexable by ``'input_ids'`` and,
            optionally, ``'labels'``.
        device: Device the batch tensors are moved to.
        max_eval_steps: Maximum number of batches to evaluate; values <= 0
            mean "evaluate every batch".

    Returns:
        A tuple ``(loss, perplexity)`` of Python floats, where
        ``perplexity = exp(loss)`` (``inf`` when the exponential overflows).

    Raises:
        RuntimeError: If no batch was evaluated.
    """
    model.eval()
    losses = []
    for step, batch in enumerate(dataloader):
        # Check the limit before running the batch so that exactly
        # `max_eval_steps` batches are evaluated rather than one more.
        if max_eval_steps > 0 and step >= max_eval_steps:
            break

        input_ids = batch['input_ids'].to(device)
        # Without labels the model has nothing to compute a loss against; fall
        # back to the inputs themselves.
        # NOTE(review): assumes a causal-LM style model that shifts labels
        # internally - confirm for other model types.
        labels = batch['labels'].to(device) if 'labels' in batch else input_ids

        with torch.no_grad():
            outputs = model(input_ids=input_ids, labels=labels)

        losses.append(outputs.loss.detach().repeat(input_ids.shape[0]))

    if not losses:
        raise RuntimeError("evaluate() received no batches; cannot compute a loss")

    loss = torch.mean(torch.cat(losses))
    # torch.exp saturates to inf on overflow instead of raising, so no
    # exception handling is required here.
    perplexity = torch.exp(loss)
    return loss.item(), perplexity.item()

class CustomTrainer(Trainer):
    """Trainer with a perplexity-based evaluation loop and early stopping.

    After the first epoch has completed, every evaluation compares the
    perplexity against the best value seen so far. The best weights are kept
    as a CPU copy of the state dict; once ``patience`` evaluations in a row
    bring no improvement, those weights are loaded back into the model and
    training is asked to stop.
    """

    def __init__(self, *args, max_eval_steps=0, patience=3, **kwargs):
        """Create the trainer.

        Args:
            *args: Positional arguments forwarded to ``Trainer``.
            max_eval_steps: Batch limit handed to ``evaluate``; ``0`` is
                passed through unchanged.
            patience: Number of consecutive non-improving evaluations that
                triggers the stop; should be a positive integer.
            **kwargs: Keyword arguments forwarded to ``Trainer``.
        """
        super().__init__(*args, **kwargs)
        self.best_perplexity = float("inf")
        self.best_model_state_dict = None
        self.no_improvement_counter = 0
        self.passed_epoch_steps = False
        self.max_eval_steps = max_eval_steps
        self.patience = patience

    def evaluation_loop(self, dataloader, description, prediction_loss_only=False, ignore_keys=None, metric_key_prefix='eval'):
        """Evaluate on ``dataloader``, track the best perplexity and maybe stop.

        ``description``, ``prediction_loss_only``, ``ignore_keys`` and
        ``metric_key_prefix`` are accepted for signature compatibility but are
        not used.

        Returns:
            An ``EvalPrediction`` with empty predictions/labels on which
            ``num_samples`` and ``metrics`` (``{"eval_loss": ...}``) are set as
            extra attributes.
        """
        eval_loss, perplexity = evaluate(self.model, dataloader, self.args.device, self.max_eval_steps)

        # `state.epoch` is None when evaluation runs before any training step,
        # so guard the comparison instead of raising a TypeError.
        current_epoch = self.state.epoch
        if current_epoch is not None and current_epoch >= 1:
            self.passed_epoch_steps = True

        # Improvements are only tracked once the first epoch is behind us.
        if self.passed_epoch_steps:
            if perplexity < self.best_perplexity:
                self.best_perplexity = perplexity
                # Keep the snapshot on the CPU so it does not occupy device memory.
                self.best_model_state_dict = {k: v.clone().to('cpu') for k, v in self.model.state_dict().items()}
                self.no_improvement_counter = 0
            else:
                self.no_improvement_counter += 1

        # Stop training and restore the best weights once perplexity failed to
        # improve `patience` times in a row.
        if self.no_improvement_counter == self.patience:
            if self.best_model_state_dict:
                self.model.load_state_dict(self.best_model_state_dict)
            self.model.to(self.args.device)
            self.control.should_training_stop = True
            print("Training stopped, best model loaded with Perplexity:", self.best_perplexity)

        self.log({
            "eval_loss": eval_loss,
            "perplexity": perplexity,
            "epoch": self.state.epoch,
        })

        # EvalPrediction has no slots for these values, so they are attached as
        # plain attributes after construction.
        eval_prediction = EvalPrediction(predictions=None, label_ids=None)
        eval_prediction.num_samples = len(dataloader.dataset)
        eval_prediction.metrics = {"eval_loss": eval_loss}

        return eval_prediction

    def get_completed_steps(self):
        """Return the trainer's current ``global_step``."""
        return self.state.global_step



In [None]:
# Configuration cell: device selection, hyper-parameters, quantization/LoRA
# settings, and loading of the quantized base model used for inference.

# Use CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#whether to sample a number of characters from combined_text
# NOTE(review): this rebinds the name `sample`, which the import cell bound to
# `random.sample`; after this cell runs a bare `sample(...)` call is no longer
# possible.
sample=False
#if true, what # of characters to sample (this * avg_tokens_per_char = rough # of tokens)
#(10000*.6)/2/128*.1
s_size = 10000
seq_length = 128
#seq_length = 128
batch_size = 16
epoch_steps_warmup_ratio = 1/3
# NOTE(review): `epochs` is assigned again a few lines below, so this value of
# 10 is never the one in effect after the cell has run.
epochs = 10
model_id = "EleutherAI/gpt-neo-1.3B"
#model_id = "EleutherAI/gpt-neo-125M"
warm_ratio = 1/2
train_fraction = 0.9
epochs = 3
gradient_accumulation_steps = 16
# NOTE(review): `seed` is only assigned here; none of the visible cells pass it
# to a random number generator.
seed = 42

# 4-bit loading with double quantization, the "fp4" quantization type and
# bfloat16 as compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    #bnb_4bit_quant_type="nf4",
    bnb_4bit_quant_type="fp4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# LoRA settings; not applied to any model in the visible cells (the
# get_peft_model call further down is commented out).
lora_config = LoraConfig(
    r=8, 
    lora_alpha=32, 
    #target_modules=["query_key_value"], 
    lora_dropout=0.05, 
    bias="none", 
    task_type="CAUSAL_LM"
)

# Read the PEFT configuration stored under 'bitsft' and load the base model it
# names with the quantization settings above.
# NOTE(review): only the base model is loaded here; no visible cell loads the
# adapter weights from 'bitsft' - confirm that is intended.
infer_peft_config = PeftConfig.from_pretrained('bitsft')
#BitModel.from_pretrained(peft_config)
infer_model = AutoModelForCausalLM.from_pretrained(
        infer_peft_config.base_model_name_or_path,
        quantization_config=bnb_config, device_map={"":0}
)


#infer_model.gradient_checkpointing_enable()
#infer_model = prepare_model_for_kbit_training(model)

#infer_model = get_peft_model(infer_model, lora_config)

#print_trainable_parameters(infer_model)


In [None]:
# Tokenizer setup: load the tokenizer for `model_id`, register the three
# prompt-structure markers as additional special tokens, set a pad token, and
# resize the model's embedding matrix to the tokenizer length.
model_id = "EleutherAI/gpt-neo-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
new_tokens = {'additional_special_tokens': ['<|Context|>', '<|Prompt|>', '<|Response|>']}
tokenizer.add_special_tokens(new_tokens)
# NOTE(review): the pad token is assigned as an attribute rather than passed to
# add_special_tokens - verify that '[PAD]' resolves to the intended token id.
tokenizer.pad_token = '[PAD]'
# NOTE(review): presumably the rows added for the new tokens are untrained
# unless weights are loaded afterwards - confirm.
infer_model.resize_token_embeddings(len(tokenizer))

In [None]:
# Sample one continuation for a Context/Prompt/Response style prompt and print it.
infer_model.config.use_cache = True

generator = pipeline('text-generation', model=infer_model, tokenizer = tokenizer)
#results = generator(r"Context:\nPrompt: Finish the quote: 'To live well...'\nResponse:", do_sample=True, min_length=50, max_length=200)
# NOTE(review): the prompt is a raw string, so each `\n` reaches the model as a
# backslash followed by "n" rather than as a newline - confirm this is intended.
results = generator(r"Context:\nPrompt: If value A is 2, and value B is 3, what is A+B?'\nResponse:", do_sample=True, min_length=50, max_length=200)
print(results[0]['generated_text'])

In [8]:
# Same generation setup as the previous cell, run on a free-form question
# without the Context/Prompt/Response scaffolding; prints the sampled text.
infer_model.config.use_cache = True

generator = pipeline('text-generation', model=infer_model, tokenizer = tokenizer)
#results = generator(r"Context:\nPrompt: Finish the quote: 'To live well...'\nResponse:", do_sample=True, min_length=50, max_length=200)
results = generator(r"What is the meaning of life?", do_sample=True, min_length=50, max_length=200)
print(results[0]['generated_text'])

What is the meaning of life? It’s a question that will be asked as many people on the planet die every day, but no one really knows what it means.

The questions that will be asked are about whether you will feel the pain of death, whether you will know if your loved one is dead, whether you will know if you were lucky or unlucky to be alive, and whether you will know how to cope with death. Death is hard to see, sometimes impossible to see, and it is a big reason to ask the bigger questions about your place in life and into the world.

It is a question that can help us to understand our place in the world. The question also can help with questions about meaning. Meaning is a big question, but it is not one question. Meaning is many and you can ask many questions. The question is not about death, it is about what we can and must do to survive. It is about what we want to


In [None]:
# Inference wrapper function
def generate_response(input_text: str) -> str:
    """Generate a model response for a plain-text Context/Prompt input.

    The literal markers ``"Context: "`` and ``"Prompt: "`` are replaced by the
    special tokens ``<|Context|>`` and ``<|Prompt|>``, and ``" <|Response|>"``
    is appended before the text is encoded and passed to the notebook's
    ``infer_model``.

    Args:
        input_text: User text containing the ``"Context: "`` / ``"Prompt: "``
            markers.

    Returns:
        The decoded first generated sequence with special tokens skipped.
    """
    input_text_tokens = (
        input_text.replace("Context: ", "<|Context|>")
        .replace("Prompt: ", "<|Prompt|>")
    )
    input_ids = tokenizer.encode(input_text_tokens + " <|Response|>", return_tensors="pt")
    # The encoded ids are moved to the model's device so that generation does
    # not mix tensors from different devices.
    input_ids = input_ids.to(infer_model.device)
    with torch.no_grad():
        # `infer_model` is the only model object this notebook creates; the
        # previously referenced name `model` is never defined.
        output = infer_model.generate(input_ids, max_length=200, num_return_sequences=1)
    response = tokenizer.decode(output[0], skip_special_tokens=True)
    return response

# Use the wrapper function for user-friendly input.
# The "Context: " and "Prompt: " markers (each followed by a space) are the
# substrings that generate_response swaps for its special tokens.
user_input = "Context: Your context here Prompt: Your prompt here"
generated_response = generate_response(user_input)
print(generated_response)