In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
from datasets import load_dataset

In [37]:
def preprocess(example):
    def safe(x):
        if x is None:
            return ""
        if isinstance(x, list):
            return " ".join(str(i) for i in x)
        return str(x)

    instruction = safe(example["instruction"])
    input_text  = safe(example["input"])
    output      = safe(example["output"])

    if input_text.strip() == "":
        prompt = (
            f"### Instruction:\n{instruction}\n\n"
            f"### Response:\n{output}"
        )
    else:
        prompt = (
            f"### Instruction:\n{instruction}\n\n"
            f"### Input:\n{input_text}\n\n"
            f"### Response:\n{output}"
        )

    tokens = tokenizer(prompt, truncation=True, max_length=1024)

    response_index = prompt.index("### Response:")
    prefix = tokenizer(prompt[:response_index], add_special_tokens=False)
    prefix_len = len(prefix["input_ids"])

    labels = tokens["input_ids"].copy()
    for i in range(prefix_len):
        labels[i] = -100

    return {
        "input_ids": tokens["input_ids"],
        "attention_mask": tokens["attention_mask"],
        "labels": labels,
    }


In [38]:
def chunk_text(text, max_tokens=512, stride=128):
    global tokenizer
    tokens = tokenizer.encode(text)
    chunks=[]
    i=0
    while i < len(tokens):
        chunk_tokens = tokens[i:i+max_tokens]
        chunks.append(tokenizer.decode(chunk_tokens, skip_special_tokens=True))
        i += (max_tokens - stride)
    return chunks

In [39]:
def modelSet(model_name):
    global tokenizer, model
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        local_files_only=True
        )
    bnb_config = BitsAndBytesConfig(
        load_in_8bit=True 
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        quantization_config=bnb_config,
    )

In [40]:
def loraSet():
    global model
    lora_config = LoraConfig(
        r=16,               
        lora_alpha=32,      
        target_modules=["q_proj", "v_proj"], 
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )

    model = get_peft_model(model, lora_config)

In [41]:
modelSet("../data/models")
loraSet()

ds = load_dataset("json", data_files="sft_dataset.jsonl", split="train")
ds = ds.map(preprocess, remove_columns=["instruction", "input", "output"])

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Map:   0%|          | 0/85 [00:00<?, ? examples/s]

IndexError: list assignment index out of range

In [None]:
from transformers import AutoModelForSequenceClassification

class CustomModel(AutoModelForSequenceClassification):
    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        outputs = self(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=labels)
        loss = outputs.loss  # Get the loss from the model's output
        return loss  # Return the loss

In [None]:
training_args = TrainingArguments(
    output_dir="../lora_models",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=6,
    learning_rate=3e-4,
    fp16=True,
    logging_steps=10,
    save_steps=200,
    num_train_epochs=3,
    save_total_limit=2,
    remove_unused_columns=False
)

def build_safe_text(value):
    if value is None:
        return ""
    if isinstance(value, list):
        return " ".join([str(v) for v in value])
    return str(value)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds
)

In [24]:
import torch, gc
gc.collect()
torch.cuda.empty_cache()

In [27]:
ds

Dataset({
    features: ['instruction', 'input', 'output'],
    num_rows: 85
})

In [25]:
trainer.train()

model.save_pretrained("../data/lora_models")
tokenizer.save_pretrained("../data/lora_models")

TypeError: can only join an iterable

In [28]:
ds


Dataset({
    features: ['instruction', 'input', 'output'],
    num_rows: 85
})