In [1]:
!pip install -q transformers datasets accelerate bitsandbytes wandb peft trl

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.7/69.7 MB[0m [31m19.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m318.3/318.3 kB[0m [31m23.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m179.3/179.3 kB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m84.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━

In [16]:
import os
import json
import torch
import wandb
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM, SFTConfig
from peft import LoraConfig, get_peft_model, LoraConfig, TaskType


def formatting_prompts_func(data):
    """Render a batch of instruction/response pairs as prompt strings.

    Args:
        data: Mapping holding two parallel sequences under the keys
            ``'instruction'`` and ``'output'``.

    Returns:
        A list with one string per pair, laid out as a ``### Question:``
        section followed by a ``### Answer:`` section. Pairing uses ``zip``,
        so the result is as long as the shorter of the two sequences.
    """
    prompts = []
    for question, answer in zip(data['instruction'], data['output']):
        prompts.append(f"### Question: {question}\n ### Answer: {answer}")
    return prompts

# LoRA configurations
# Ranks to sweep: each value is passed to train_with_lora() as `lora_r`,
# producing one training run (and one wandb run) per entry.
LORA_R_VALUES = [8, 128, 256]

class WandbLoggingCallback(TrainerCallback):
    """Trainer callback that forwards the training loss to Weights & Biases."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Send ``train/loss`` and the current global step to wandb.

        Nothing is logged when ``logs`` is ``None`` or has no ``"loss"`` key
        (for example, evaluation-only log events).
        """
        if logs is None or "loss" not in logs:
            return
        payload = {"train/loss": logs["loss"], "step": state.global_step}
        wandb.log(payload)


def _linear_layer_names(model):
    """Return the sorted leaf names of all ``torch.nn.Linear`` submodules, minus ``lm_head``."""
    leaf_names = {
        name.split('.')[-1]
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
    }
    leaf_names.discard("lm_head")  # needed for 16-bit
    return sorted(leaf_names)


def train_with_lora(lora_r):
    """Fine-tune ``facebook/opt-350m`` with LoRA adapters of rank ``lora_r``.

    Starts a wandb run named ``"rank {lora_r}"``, loads the CodeAlpaca-20k
    train split (80/20 train/eval split, seed 42), attaches LoRA adapters to
    every linear layer except ``lm_head``, trains for 5000 steps with
    ``SFTTrainer``, then prints and logs the peak CUDA memory allocated on
    device 0 during this run.

    Args:
        lora_r: LoRA rank passed to ``LoraConfig(r=...)``; also used in the
            wandb run name and the output directory name.

    Returns:
        None. Results are written to wandb, stdout and
        ``./lora_r_{lora_r}_output``.

    Raises:
        Any exception raised while loading or training is re-raised after the
        wandb run has been closed with a non-zero exit code.
    """
    # Initialize wandb
    wandb.init(project='Hanghae99_week8_basic', reinit=True, name=f"rank {lora_r}")

    try:
        # The CUDA peak-memory counter is process-wide. Without a reset, the
        # value reported for this rank would also include the peaks of all
        # earlier runs in the sweep. Skipped on the first run, when CUDA has
        # not been initialised yet and the counter is still zero.
        if torch.cuda.is_initialized():
            torch.cuda.reset_peak_memory_stats(0)

        # Model and dataset
        MODEL_NAME = "facebook/opt-350m"
        DATASET_NAME = "sahil2801/CodeAlpaca-20k"

        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.model_max_length = 128
        # NOTE(review): this tokenizer is handed only to the collator below;
        # SFTTrainer is not given it, so whether the 128-token limit applies
        # to the trainer's own tokenization depends on the installed TRL
        # version -- confirm.

        lora_dropout: float = 0.1
        lora_alpha: int = 32

        # Load dataset
        dataset = load_dataset(DATASET_NAME, split="train")
        dataset = dataset.train_test_split(test_size=0.2, seed=42)
        train_dataset, eval_dataset = dataset['train'], dataset['test']

        # Data collator: the template marks where the answer part of each
        # formatted prompt begins.
        response_template = " ### Answer:"
        collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

        # Load base model
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")

        # Apply LoRA to every linear layer except the output head
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            target_modules=_linear_layer_names(model)
        )
        model = get_peft_model(model, lora_config)

        training_config = SFTConfig(
            output_dir=f"./lora_r_{lora_r}_output",
            save_total_limit=1,
            logging_steps=500,
            eval_steps=500,
            max_steps=5000,
            fp16=True,
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            save_steps=500,
            save_strategy="steps",
            logging_strategy="steps",
            evaluation_strategy="steps",
        )

        trainer = SFTTrainer(
            model=model,
            args=training_config,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            formatting_func=formatting_prompts_func,
            data_collator=collator,
            callbacks=[WandbLoggingCallback()]
        )

        # Train
        trainer.train()

        # Log metrics
        metrics = trainer.state.log_history
        wandb.log({"train_metrics": metrics})

        # Memory usage: computed once so the printed and logged values agree
        peak_gb = round(torch.cuda.max_memory_allocated(0) / 1024**3, 1)
        print('Max Alloc:', peak_gb, 'GB')
        wandb.log({"max_memory_GB": peak_gb})
    except BaseException:
        # Close the run as failed so an error or interrupt does not leave it
        # open and bleed into the next rank's run.
        wandb.finish(exit_code=1)
        raise

    wandb.finish()

# Train with different LoRA ranks
# Runs execute one after another in the same process; each call opens and
# closes its own wandb run.
for r in LORA_R_VALUES:
    train_with_lora(r)
    torch.cuda.empty_cache()  # Free cached GPU memory



0,1
eval/loss,██▁
eval/mean_token_accuracy,▁▁█
eval/runtime,▁▁▁▁▁▁▁███
eval/samples_per_second,███████▁▁▁
eval/steps_per_second,███████▁▁▁
eval_loss,█▁
learning_rate,█▁
loss,█▁
step,▁
train/epoch,▁▁▁▁▁▂▂▄▄▅▅▇▇██▂▂▂▂▄▄

0,1
eval/loss,1.47826
eval/mean_token_accuracy,0.68289
eval/runtime,64.5681
eval/samples_per_second,62.028
eval/steps_per_second,15.518
eval_loss,1.47826
learning_rate,2e-05
loss,1.5839
step,1000.0
train/epoch,0.49938




Tokenizing train dataset:   0%|          | 0/16017 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/16017 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/4005 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/4005 [00:00<?, ? examples/s]

Step,Training Loss,Validation Loss
500,1.8486,1.657278
1000,1.7157,1.571436
1500,1.5789,1.524542
2000,1.6095,1.490656
2500,1.557,1.471332
3000,1.5164,1.453705
3500,1.5315,1.440017
4000,1.5176,1.431205
4500,1.4663,1.42619
5000,1.4975,1.423509


Max Alloc: 11.8 GB


0,1
eval/loss,█▅▄▃▂▂▁▁▁▁
eval/mean_token_accuracy,▁▃▅▆▇▇████
eval/runtime,█▂▆▆▁▃▇▃▆▅
eval/samples_per_second,▁▇▃▃█▆▂▆▃▄
eval/steps_per_second,▁▇▃▃█▆▂▆▃▄
max_memory_GB,▁
step,▁▂▃▃▄▅▆▆▇█
train/epoch,▁▁▂▂▃▃▃▃▄▄▅▅▆▆▆▆▇▇███
train/global_step,▁▁▁▂▂▂▃▃▃▃▃▃▄▄▄▅▅▅▆▆▆▆▆▆▇▇▇██████
train/grad_norm,▃▁▃▂▂▂▃▄█▄

0,1
eval/loss,1.42351
eval/mean_token_accuracy,0.69274
eval/runtime,64.6853
eval/samples_per_second,61.915
eval/steps_per_second,15.49
max_memory_GB,11.8
step,5000.0
total_flos,7345663308791808.0
train/epoch,1.24844
train/global_step,5000.0




Tokenizing train dataset:   0%|          | 0/16017 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/16017 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/4005 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/4005 [00:00<?, ? examples/s]

Step,Training Loss,Validation Loss
500,1.8463,1.653917
1000,1.7107,1.565327
1500,1.5733,1.51894
2000,1.604,1.48543
2500,1.5512,1.467159
3000,1.5108,1.448593
3500,1.5254,1.435123
4000,1.5118,1.426675
4500,1.4608,1.421924
5000,1.4919,1.419355


Max Alloc: 12.6 GB


0,1
eval/loss,█▅▄▃▂▂▁▁▁▁
eval/mean_token_accuracy,▁▃▅▆▇▇████
eval/runtime,▇▂▇▁▆▂▂▃█▇
eval/samples_per_second,▂▇▂█▃▇▇▆▁▂
eval/steps_per_second,▂▇▂█▃▇▇▆▁▂
max_memory_GB,▁
step,▁▂▃▃▄▅▆▆▇█
train/epoch,▁▁▂▂▃▃▃▃▄▄▅▅▆▆▆▆▇▇███
train/global_step,▁▁▁▂▂▂▃▃▃▃▃▃▄▄▄▅▅▅▆▆▆▆▆▆▇▇▇██████
train/grad_norm,▃▁▃▂▂▂▄▅█▄

0,1
eval/loss,1.41935
eval/mean_token_accuracy,0.69349
eval/runtime,64.9294
eval/samples_per_second,61.682
eval/steps_per_second,15.432
max_memory_GB,12.6
step,5000.0
total_flos,8624968628011008.0
train/epoch,1.24844
train/global_step,5000.0




Tokenizing train dataset:   0%|          | 0/16017 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/16017 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/4005 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/4005 [00:00<?, ? examples/s]

Step,Training Loss,Validation Loss
500,1.8475,1.653809
1000,1.7105,1.565601
1500,1.5719,1.517401
2000,1.6024,1.484538
2500,1.5495,1.465551
3000,1.5095,1.447354
3500,1.5242,1.434394
4000,1.5108,1.425224
4500,1.4595,1.420431
5000,1.4904,1.418031


Max Alloc: 13.5 GB


0,1
eval/loss,█▅▄▃▂▂▁▁▁▁
eval/mean_token_accuracy,▁▃▅▆▇▇████
eval/runtime,▃▄▃▄▅▁▄▅█▇
eval/samples_per_second,▆▅▆▅▄█▅▄▁▂
eval/steps_per_second,▆▅▆▅▄█▅▄▁▂
max_memory_GB,▁
step,▁▂▃▃▄▅▆▆▇█
train/epoch,▁▁▂▂▃▃▃▃▄▄▅▅▆▆▆▆▇▇███
train/global_step,▁▁▁▂▂▂▃▃▃▃▃▃▄▄▄▅▅▅▆▆▆▆▆▆▇▇▇██████
train/grad_norm,▃▁▃▂▁▂▃▄█▄

0,1
eval/loss,1.41803
eval/mean_token_accuracy,0.6934
eval/runtime,64.995
eval/samples_per_second,61.62
eval/steps_per_second,15.417
max_memory_GB,13.5
step,5000.0
total_flos,9989560968511488.0
train/epoch,1.24844
train/global_step,5000.0
