In [None]:
from transformers import Trainer, TrainingArguments
from datasets import Dataset, load_dataset
from peft import get_peft_model, LoraConfig, TaskType, PeftModel, PeftConfig
from models.worker import Worker

import os
os.environ['WANDB_DISABLED'] = 'true'

In [None]:
# Load model
#
# Wrap the base checkpoint in the project's Worker helper (models.worker).
# NOTE(review): Worker's implementation is not visible from this notebook, so
# the per-line notes below are inferred from method names — confirm.

model_id = 'princeton-nlp/Sheared-LLaMA-2.7B'
worker = Worker(model_id)
# presumably plans/validates device placement without splitting a decoder layer across devices — TODO confirm
worker.check_device_map(no_split_module_classes=["LlamaDecoderLayer"])
# presumably populates worker.model (and worker.tokenizer), both used by later cells — TODO confirm
worker.load_model()
worker.model  # last expression of the cell: display the loaded model

In [None]:
# Set up LoRA

# Adapter hyperparameters for causal-LM fine-tuning (training mode, not inference).
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)

# Wrap the base model exactly once. Without this guard, re-running the cell
# would pass an already-wrapped PeftModel back into get_peft_model, so the
# cell would not be idempotent.
if not isinstance(worker.model, PeftModel):
    worker.model = get_peft_model(worker.model, peft_config)

# Report how many parameters are trainable after wrapping.
worker.model.print_trainable_parameters()

In [None]:
# Load instruction tuning dataset
#
# load_dataset is called without a split argument, so the result is indexed by
# split name; only the 'train' split is used in this notebook.

dataset_name = 'databricks/databricks-dolly-15k'
dataset = load_dataset(dataset_name)
dataset['train']  # last expression of the cell: display the training split

In [None]:
# Format dataset into valid query/answer pairs to train on

# One training sample = instruction, (possibly empty) context, and response,
# rendered as three labelled lines.
prompt_template = """query: {query}
context: {context}
answer: {answer}
"""

# Render every training example through the template in a single pass.
samples = [
    prompt_template.format(
        query=example['instruction'],
        context=example['context'],
        answer=example['response'],
    )
    for example in dataset['train']
]

# Single-column dataset: the rendered prompt lives under the 'text' key,
# which is the column the tokenization cell reads.
samples_dict = {'text': samples}

tuning_dataset = Dataset.from_dict(samples_dict)
tuning_dataset

In [None]:
# tokenize dataset

# padding='max_length' requires a pad token. LLaMA tokenizers usually ship
# without one, so fall back to EOS when none is configured. This is a no-op
# if the Worker already set a pad token.
if worker.tokenizer.pad_token is None:
    worker.tokenizer.pad_token = worker.tokenizer.eos_token


def tokenize_function(samples):
    """Tokenize a batch of formatted samples and attach causal-LM labels.

    Args:
        samples: A batch from ``Dataset.map(..., batched=True)``; its
            ``'text'`` entry is the list of rendered prompts.

    Returns:
        The tokenizer output for the batch with an added ``'labels'`` entry.
        Labels are a copy of ``input_ids`` in which every padded position
        (``attention_mask == 0``) is replaced by -100.
    """
    encoded = worker.tokenizer(samples['text'], padding='max_length', truncation=True)
    # Trainer only gets a loss from a causal LM when the batch carries labels.
    # -100 is the ignore index of the loss, so padding does not contribute.
    # Masking via attention_mask (rather than by token id) keeps real EOS
    # tokens in the loss even when the pad token is EOS.
    encoded['labels'] = [
        [token_id if attended else -100 for token_id, attended in zip(ids, mask)]
        for ids, mask in zip(encoded['input_ids'], encoded['attention_mask'])
    ]
    return encoded


tokenized_dataset = tuning_dataset.map(tokenize_function, batched=True)
tokenized_dataset

In [None]:
# Fine-tune model

training_args = TrainingArguments(
    # Where things are written: model checkpoints and training logs.
    output_dir='./checkpoints',
    logging_dir='./logs',
    # Schedule: a single pass over the data, 8 examples per device per step
    # (same batch size for evaluation).
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Optimisation: learning-rate warmup steps and weight-decay strength.
    warmup_steps=500,
    weight_decay=0.01,
)

# Train-only setup: no evaluation dataset is supplied.
trainer = Trainer(model=worker.model, args=training_args, train_dataset=tokenized_dataset)

trainer.train()

In [None]:
# Save checkpoint
#
# Persist the fine-tuned model to the same directory the Trainer uses as its
# output_dir. worker.model was replaced by the get_peft_model(...) result in
# the LoRA cell, so this presumably writes the adapter weights/config rather
# than a full copy of the base model — TODO confirm.

worker.model.save_pretrained('./checkpoints') 
# NOTE(review): Worker is a project class whose definition is not visible here.
# If it does not define save_pretrained, this line raises AttributeError; the
# intent may have been worker.tokenizer.save_pretrained('./checkpoints') — confirm.
worker.save_pretrained('./checkpoints')

In [None]:
# Load checkpoint and do some inference

# import torch
# from transformers import AutoModelForCausalLM, AutoTokenizer

# peft_model_id = './checkpoints/...'
# config = PeftConfig.from_pretrained(peft_model_id)
# model = AutoModelForCausalLM.from_pretrained(model_id)  # generate() below needs the LM head
# model = PeftModel.from_pretrained(model, peft_model_id)
# tokenizer = AutoTokenizer.from_pretrained(model_id)

# device = 'cuda'
# model = model.to(device)
# model.eval()
# inputs = tokenizer("Tweet text : @HondaCustSvc Your customer service has been horrible during the recall process. I will never purchase a Honda again. Label :", return_tensors="pt")

# with torch.no_grad():
#     outputs = model.generate(input_ids=inputs["input_ids"].to("cuda"), max_new_tokens=10)
#     print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0])