## Library

In [1]:
import json
import random
from datasets import load_dataset, Dataset
from typing import List, Dict
import pandas as pd
from pathlib import Path

## Data Preparation

### Download Data

In [2]:
# Define the dataset and subset
dataset_name = "csebuetnlp/xlsum"
language_subset = "indonesian"
output_dir = Path("dataset/xlsum-indonesian")

# Create the output directory if it doesn't exist
output_dir.mkdir(parents=True, exist_ok=True)

# Load the dataset
dataset = load_dataset(dataset_name, language_subset, split = "train")

In [3]:
print(dataset.column_names)

['id', 'url', 'title', 'summary', 'text']


In [5]:
print(dataset['train'][0])

{'id': 'media-49647079', 'url': 'https://www.bbc.com/indonesia/media-49647079', 'title': 'Gajah mengamuk saat upacara keagamaan, 17 orang terluka', 'summary': 'Seekor gajah mendadak mengamuk saat prosesi tahunan agama Buddha di Kolombo, Sri Lanka, sehingga membuat peserta upacara tunggang-langgang. Setidaknya 17 orang terluka.', 'text': 'Dilaporkan dua orang terluka cukup serius, sementara sisanya sudah diperbolehkan pulang setelah mendapatkan perawatan. Video gajah yang mengamuk itu kontan viral di media sosial. Si gajah berlari tak tentu arah, menabrak serta menginjak sebagian peserta upacara. Simak juga: Belum diketahui penyebab gajah tersebut tiba-tiba mengamuk. Diduga gajah itu kaget oleh sesuatu di antara para peserta dan pengunjung. Media setempat melaporkan gajah lain yang juga mengamuk di prosesi berbeda. Gajah hias merupakan daya tarik tersendiri dalam upacara keagamaan di Sri Lanka. Bagi warga Sri Lanka, memiliki gajah adalah simbol status. Beberapa kuil di Sri Lanka juga me

Save to CSV

In [6]:
for split in ["train", "test", "validation"]:
    df = pd.DataFrame(dataset[split])  # Convert to Pandas DataFrame
    output_file = output_dir / f"xlsum_indonesian_{split}.csv"
    df.to_csv(output_file, index=False, encoding="utf-8")
    print(f"Saved {split} split to {output_file}")

Saved train split to dataset/xlsum-indonesian/xlsum_indonesian_train.csv
Saved test split to dataset/xlsum-indonesian/xlsum_indonesian_test.csv
Saved validation split to dataset/xlsum-indonesian/xlsum_indonesian_validation.csv


### Load Dataset

In [2]:
train_path = "dataset/xlsum-indonesian/xlsum_indonesian_train.csv"
test_path = "dataset/xlsum-indonesian/xlsum_indonesian_test.csv"
validation_path = "dataset/xlsum-indonesian/xlsum_indonesian_validation.csv"

dataset = load_dataset("csv", data_files={
    "train": train_path,
    "test": test_path,
    "validation": validation_path
})

In [3]:
dataset['train'][0]

{'id': 'media-49647079',
 'url': 'https://www.bbc.com/indonesia/media-49647079',
 'title': 'Gajah mengamuk saat upacara keagamaan, 17 orang terluka',
 'summary': 'Seekor gajah mendadak mengamuk saat prosesi tahunan agama Buddha di Kolombo, Sri Lanka, sehingga membuat peserta upacara tunggang-langgang. Setidaknya 17 orang terluka.',
 'text': 'Dilaporkan dua orang terluka cukup serius, sementara sisanya sudah diperbolehkan pulang setelah mendapatkan perawatan. Video gajah yang mengamuk itu kontan viral di media sosial. Si gajah berlari tak tentu arah, menabrak serta menginjak sebagian peserta upacara. Simak juga: Belum diketahui penyebab gajah tersebut tiba-tiba mengamuk. Diduga gajah itu kaget oleh sesuatu di antara para peserta dan pengunjung. Media setempat melaporkan gajah lain yang juga mengamuk di prosesi berbeda. Gajah hias merupakan daya tarik tersendiri dalam upacara keagamaan di Sri Lanka. Bagi warga Sri Lanka, memiliki gajah adalah simbol status. Beberapa kuil di Sri Lanka jug

### Convert to sharegpt

#### Unsloth 

In [None]:
# this code doesnt work, out of index bug
# 
# from unsloth import to_sharegpt
# train_dataset = to_sharegpt(
#     dataset['validation'],
#     merged_prompt= \
#         "[[Provide a summary for the following article in its original language:]]"\
#         "[[\nTitle: {title}]]"\
#         "[[\nContent: {text}]]"\
#         "[[\nSummarize below:]]",
#     conversation_extension=2,  # Randomly combines conversations 
#     output_column_name="summary",  # Use the "summary" column as the target
# )


#### Manual

In [4]:
def convert_to_sharegpt_format(example):
    """
    Convert a single XLSum example to ShareGPT format.
    """
    human_message = {
        "from": "human",
        "value": f"Provide a summary for the following article in its original language: \nTitle: {example['title']} \nContent: {example['text']} \nSummary:"
    }
    
    gpt_message = {
        "from": "gpt",
        "value": example['summary']
    }
    
    return [human_message, gpt_message]

def create_extended_conversation(examples: List[Dict], conversation_extension: int) -> List[Dict]:
    """
    Create an extended conversation by combining multiple examples.
    
    Args:
        examples: List of dataset examples to sample from
        conversation_extension: Number of examples to combine into one conversation
    
    Returns:
        List of messages forming a single extended conversation
    """
    # Randomly sample the specified number of examples
    selected_examples = random.sample(examples, min(conversation_extension, len(examples)))
    
    # Convert each example and combine their messages
    extended_conversation = []
    for example in selected_examples:
        conversation_pair = convert_to_sharegpt_format(example)
        extended_conversation.extend(conversation_pair)
    
    return extended_conversation

def process_dataset(dataset, conversation_extension: int = 1):
    """
    Process the dataset and convert to ShareGPT format with optional conversation extension.
    
    Args:
        dataset: The input dataset
        conversation_extension: Number of examples to combine into one conversation
    
    Returns:
        Dataset object with conversations column
    """
    processed_data = []
    
    # Process each split in the dataset
    for split in dataset.keys():
        print(f"Processing {split} split...")
        examples = list(dataset[split])
        
        # Calculate number of conversations needed
        num_conversations = len(examples) // conversation_extension
        
        # Create extended conversations
        for i in range(num_conversations):
            start_idx = i * conversation_extension
            end_idx = start_idx + conversation_extension
            conversation_examples = examples[start_idx:end_idx]
            
            # Create extended conversation
            if conversation_extension > 1:
                conversation = create_extended_conversation(conversation_examples, conversation_extension)
            else:
                conversation = convert_to_sharegpt_format(conversation_examples[0])
            
            processed_data.append({"conversations": conversation})
    
    # Convert to Dataset object
    return Dataset.from_list(processed_data)

def save_sharegpt_format(dataset, output_path):
    """
    Save the converted data in JSONL format.
    """
    with open(output_path, 'w', encoding='utf-8') as f:
        for item in dataset:
            f.write(json.dumps(item['conversations'], ensure_ascii=False) + '\n')

In [5]:
# Set conversation extension (e.g., 3 for combining 3 examples into one conversation)
conversation_extension = 3

# Convert to ShareGPT format with conversation extension
processed_dataset = process_dataset(dataset, conversation_extension)

Processing train split...
Processing test split...
Processing validation split...


In [6]:
# Print a sample conversation
print("\nSample conversation:")
print(json.dumps(processed_dataset[0]['conversations'], ensure_ascii=False, indent=2))


Sample conversation:
[
  {
    "from": "human",
    "value": "Provide a summary for the following article in its original language: \nTitle: Apa alasan pemerintah pangkas 14 dalam daftar Proyek Strategis Nasional ? \nContent: Proyek MRT Sudirman- Lebak Bulus tengah dikerjakan, namun untuk koridor Timur -Barat dicabut dari PSN. Pembangunan prasarana didengungkan sebagai ujung tombak pemerintahan Presiden Joko Widodo dengan sasaran 245 proyek hingga tahun 2019 mendatang namun setahun sebelum batas waktu itu, 14 di antaranya dicoret dari daftar Proyek Strategis Nasional dengan alasan pengerjaan fisiknya tidak dapat dimulai sebelum kuartal III-2019. Pemangkasan proyek tersebut dilakukan setelah adanya evaluasi KPPIP di Kementerian Koordinator Bidang Perekonomian. Ketua Tim Pelaksana KPPIP Wahyu Utomo mengatakan berdasarkan evaluasi proyek tersebut tidak dapat memenuhi kriteria. \"Nah kriteria yang kita pakai adalah utamanya kita minta proyek itu bisa dilakukan dimulai kontruksinya paling 

In [7]:
# Print dataset info
print("\nDataset info:")
print(f"Number of conversations: {len(processed_dataset)}")
print(f"Features: {processed_dataset.features}")
print(f"Datatype: {type(processed_dataset)}")


Dataset info:
Number of conversations: 15933
Features: {'conversations': [{'from': Value(dtype='string', id=None), 'value': Value(dtype='string', id=None)}]}
Datatype: <class 'datasets.arrow_dataset.Dataset'>


In [8]:
# Save the converted data
output_path = "dataset/xlsum-indonesian/sharegpt_xlsum_indonesian.jsonl"
save_sharegpt_format(processed_dataset, output_path)
print(f"Converted data saved to {output_path}")

Converted data saved to dataset/xlsum-indonesian/sharegpt_xlsum_indonesian.jsonl


### Standardize

In [9]:
from unsloth import standardize_sharegpt
processed_dataset_std = standardize_sharegpt(processed_dataset)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


Standardizing format:   0%|          | 0/15933 [00:00<?, ? examples/s]

In [10]:
print("\nSample conversation (standardized):")
print(json.dumps(processed_dataset_std[0]['conversations'], ensure_ascii=False, indent=2))


Sample conversation (standardized):
[
  {
    "content": "Provide a summary for the following article in its original language: \nTitle: Apa alasan pemerintah pangkas 14 dalam daftar Proyek Strategis Nasional ? \nContent: Proyek MRT Sudirman- Lebak Bulus tengah dikerjakan, namun untuk koridor Timur -Barat dicabut dari PSN. Pembangunan prasarana didengungkan sebagai ujung tombak pemerintahan Presiden Joko Widodo dengan sasaran 245 proyek hingga tahun 2019 mendatang namun setahun sebelum batas waktu itu, 14 di antaranya dicoret dari daftar Proyek Strategis Nasional dengan alasan pengerjaan fisiknya tidak dapat dimulai sebelum kuartal III-2019. Pemangkasan proyek tersebut dilakukan setelah adanya evaluasi KPPIP di Kementerian Koordinator Bidang Perekonomian. Ketua Tim Pelaksana KPPIP Wahyu Utomo mengatakan berdasarkan evaluasi proyek tersebut tidak dapat memenuhi kriteria. \"Nah kriteria yang kita pakai adalah utamanya kita minta proyek itu bisa dilakukan dimulai kontruksinya paling lamb

## Training

### Load Model

In [None]:
from unsloth import FastLanguageModel
import torch
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
# fourbit_models = [
#     "unsloth/Meta-Llama-3.1-8B-bnb-4bit",      # Llama-3.1 2x faster
#     "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
#     "unsloth/Meta-Llama-3.1-70B-bnb-4bit",
#     "unsloth/Meta-Llama-3.1-405B-bnb-4bit",    # 4bit for 405b!
#     "unsloth/Mistral-Small-Instruct-2409",     # Mistral 22b 2x faster!
#     "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
#     "unsloth/Phi-3.5-mini-instruct",           # Phi-3.5 2x faster!
#     "unsloth/Phi-3-medium-4k-instruct",
#     "unsloth/gemma-2-9b-bnb-4bit",
#     "unsloth/gemma-2-27b-bnb-4bit",            # Gemma 2x faster!

#     "unsloth/Llama-3.2-1B-bnb-4bit",           # NEW! Llama 3.2 models
#     "unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
#     "unsloth/Llama-3.2-3B-bnb-4bit",
#     "unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
# ] # More models at https://huggingface.co/unsloth

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Llama-3.2-3B-Instruct", # or choose "unsloth/Llama-3.2-1B-Instruct"
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

In [None]:
model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

### Chat Template

In [None]:
from unsloth.chat_templates import get_chat_template

tokenizer = get_chat_template(
    tokenizer,
    chat_template = "llama-3.1",
)

def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
    return { "text" : texts, }
pass

In [None]:
processed_dataset_std = processed_dataset_std.map(formatting_prompts_func, batched = True,)

In [None]:
processed_dataset_std[5]["conversations"]

In [None]:
processed_dataset_std[5]["text"]

### Training

In [None]:
from trl import SFTTrainer
from transformers import TrainingArguments, DataCollatorForSeq2Seq
from unsloth import is_bfloat16_supported

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = processed_dataset_std,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
    dataset_num_proc = 2,
    packing = False, # Can make training 5x faster for short sequences.
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        num_train_epochs = 1, # Set this for 1 full training run.
        # max_steps =None,
        learning_rate = 2e-4,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        save_strategy = "steps",
        save_steps = 1000,
        report_to = "none", # Use this for WandB etc
    ),
)

In [None]:
from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
    trainer,
    instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
    response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
)

In [None]:
tokenizer.decode(trainer.train_dataset[5]["input_ids"])

In [None]:
space = tokenizer(" ", add_special_tokens = False).input_ids[0]
tokenizer.decode([space if x == -100 else x for x in trainer.train_dataset[5]["labels"]])

#### Show memory stats

In [None]:
#@title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

#### Train

In [None]:
# use this to train from scratch
trainer_stats = trainer.train()

# use this to train from checkpoint
# trainer_stats = trainer.train(resume_from_checkpoint = True)


#### Final Memory anf Time Stats

In [None]:
#@title Show final memory and time stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory         /max_memory*100, 3)
lora_percentage = round(used_memory_for_lora/max_memory*100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.")
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")