In this notebook we read all the datasets we want to use for SFT, then process the text by adding the appropriate instruction tokens.

In [2]:
import datasets
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the two SFT source datasets by repository id; later cells index the 'train' split of each.
# NOTE(review): presumably fetched from the Hugging Face Hub on first run — confirm network/cache expectations.
norobots_ds = load_dataset("OpenLLM-Ro/ro_sft_norobots")
dolly_ds = load_dataset("OpenLLM-Ro/ro_sft_dolly")

In [4]:
# Peek at the 'category' column of the norobots train split (bare expression, shown as the cell output).
norobots_ds['train']['category']

Column(['Chat', 'Brainstorm', 'Generation', 'Brainstorm', 'Generation'])

In [5]:
# Distinct values found in the 'category' column of the norobots train split
unique_categories = {label for label in norobots_ds['train']['category']}
print(f"Number of unique categories: {len(unique_categories)}")

# One bullet per category, in sorted order
print("\nDistinct categories:")
ordered_categories = sorted(unique_categories)
for category in ordered_categories:
    print(f"  - {category}")

Number of unique categories: 10

Distinct categories:
  - Brainstorm
  - Chat
  - Classify
  - Closed QA
  - Coding
  - Extract
  - Generation
  - Open QA
  - Rewrite
  - Summarize


In [None]:
from typing import Dict, Callable, Optional, Tuple
from datasets import Dataset


def format_messages_standard(messages, tokenizer=None, user_token="<utilizator>", 
                             assistant_token="<asistent>", system_token="<sistem>"):
    """
    Render a chat-style ``messages`` list as a single training string.
    Used for: OpenLLM-Ro/ro_sft_norobots

    Args:
        messages: Iterable of dicts, each with a 'role' and a 'content' entry.
        tokenizer: Optional object exposing ``bos_token`` / ``eos_token``.
            When given, the text is wrapped in BOS/EOS (falling back to
            "<s>" / "</s>" if the attribute is empty); when omitted, no
            BOS/EOS is emitted at all.
        user_token: Marker written before a 'user' message.
        assistant_token: Marker written before an 'assistant' message.
        system_token: Marker written before a 'system' message.

    Returns:
        The concatenation of: optional BOS line, one "<marker>\\n<content>\\n"
        chunk per message (content is whitespace-stripped), optional EOS.
        Roles outside system/user/assistant are rendered as "<role>".
    """
    # No tokenizer -> no sentence markers; otherwise fall back to the
    # conventional "<s>" / "</s>" when the tokenizer leaves them unset.
    bos_token = (tokenizer.bos_token or "<s>") if tokenizer is not None else ""
    eos_token = (tokenizer.eos_token or "</s>") if tokenizer is not None else ""

    markers_by_role = {
        'system': system_token,
        'user': user_token,
        'assistant': assistant_token,
    }

    # Collect the chunks and join once at the end.
    chunks = [bos_token + "\n"] if bos_token else []

    for turn in messages:
        speaker = turn['role']
        body = turn['content'].strip()
        marker = markers_by_role.get(speaker, f"<{speaker}>")
        chunks.append(f"{marker}\n{body}\n")

    if eos_token:
        chunks.append(eos_token)

    return "".join(chunks)

def format_dolly_context_instruction(data_dict, tokenizer=None, user_token="<utilizator>", 
                                     assistant_token="<asistent>", system_token="<sistem>",
                                     context_separator="\n\nContext: "):
    """
    Formatting for OpenLLM-Ro/ro_sft_dolly dataset.

    Builds one user turn (the instruction, followed by the context when both
    are present) and one assistant turn (the response). Parts that are empty
    after stripping are skipped entirely.

    Args:
        data_dict: Dict with 'context', 'instruction', and 'response' keys.
            A key may be absent or map to None; both are treated as "".
        tokenizer: Optional tokenizer for special tokens. When given, the text
            is wrapped in BOS/EOS (falling back to "<s>" / "</s>" if the
            attribute is empty); when omitted, no BOS/EOS is emitted.
        user_token: Token for user messages
        assistant_token: Token for assistant messages
        system_token: Token for system messages (unused here)
        context_separator: How to separate context from instruction

    Returns:
        The formatted string: optional BOS line, optional
        "<user_token>\\n<message>\\n", optional "<assistant_token>\\n<response>\\n",
        optional EOS.
    """
    if tokenizer is not None:
        bos_token = tokenizer.bos_token or "<s>"
        eos_token = tokenizer.eos_token or "</s>"
    else:
        bos_token = ""
        eos_token = ""

    def _field(key):
        # dict.get's default only applies to a *missing* key. A key that is
        # present with a None value would otherwise crash on .strip(), so
        # coerce None to "" before stripping.
        return (data_dict.get(key) or "").strip()

    context = _field('context')
    instruction = _field('instruction')
    response = _field('response')

    # Only add newline after BOS if BOS exists
    formatted_text = bos_token + ("\n" if bos_token else "")

    # Instruction first, then context; the separator appears only when both
    # are non-empty, and a lone context stands in as the whole user message.
    user_message = context_separator.join(
        part for part in (instruction, context) if part
    )

    if user_message:
        formatted_text += f"{user_token}\n{user_message}\n"

    if response:
        formatted_text += f"{assistant_token}\n{response}\n"

    # Only add EOS if it exists
    if eos_token:
        formatted_text += eos_token

    return formatted_text




from typing import Dict, Callable, Optional, Union, List


class DatasetFormatterRegistry:
    """Maps a dataset name to a formatter function plus the column(s) it consumes."""

    def __init__(self):
        # dataset name -> (formatter callable, column name or list of column names)
        self._formatters: Dict[str, Tuple[Callable, Union[str, List[str]]]] = {}

    @staticmethod
    def _describe_columns(columns: Union[str, List[str]]) -> str:
        """Return a display string: the name itself, or the names joined by ', '."""
        if isinstance(columns, str):
            return columns
        return ", ".join(columns)

    def register(self, dataset_name: str, formatter: Callable, 
                 columns: Union[str, List[str]] = "messages"):
        """
        Associate ``formatter`` with ``dataset_name`` and print a confirmation.

        Args:
            dataset_name: Key under which the formatter is stored; an existing
                entry with the same key is overwritten.
            formatter: Callable invoked later by ``format_dataset``.
            columns: Single column name, or list of column names, whose
                values are handed to the formatter.
        """
        self._formatters[dataset_name] = (formatter, columns)
        col_display = self._describe_columns(columns)
        print(f"✅ Registered formatter for '{dataset_name}' (columns: {col_display})")

    def get_formatter(self, dataset_name: str) -> Optional[Tuple[Callable, Union[str, List[str]]]]:
        """Return the ``(formatter, columns)`` pair for ``dataset_name``, or None if unregistered."""
        return self._formatters.get(dataset_name)

    def list_datasets(self):
        """Return ``[(dataset_name, columns_display), ...]`` in registration order."""
        return [
            (name, self._describe_columns(cols))
            for name, (_, cols) in self._formatters.items()
        ]

    def format_dataset(self, dataset, dataset_name: str, 
                      tokenizer=None, user_token="<utilizator>", 
                      assistant_token="<asistent>", system_token="<sistem>", 
                      num_proc=1):
        """
        Add a 'formatted_text' field to every example via the registered formatter.

        Args:
            dataset: Object exposing ``column_names`` and ``map(fn, num_proc=, desc=)``.
            dataset_name: Key previously passed to ``register``.
            tokenizer, user_token, assistant_token, system_token: Forwarded
                unchanged to the formatter as keyword arguments.
            num_proc: Forwarded to ``dataset.map``.

        Returns:
            Whatever ``dataset.map`` returns.

        Raises:
            ValueError: If no formatter is registered under ``dataset_name``,
                or a registered column is absent from ``dataset.column_names``.

        Note:
            A formatter exception on an individual example is caught: a
            warning is printed and that example's 'formatted_text' is "".
        """
        entry = self.get_formatter(dataset_name)

        if entry is None:
            known_names = [name for name, _ in self.list_datasets()]
            raise ValueError(
                f"No formatter found for '{dataset_name}'. "
                f"Available: {known_names}"
            )

        formatter_func, columns = entry

        # Normalise to a list so the single- and multi-column paths share code
        column_list = [columns] if isinstance(columns, str) else columns

        # Fail fast, before mapping, on the first column the dataset lacks
        for wanted in column_list:
            if wanted not in dataset.column_names:
                raise ValueError(
                    f"Column '{wanted}' not found in dataset. "
                    f"Available columns: {dataset.column_names}"
                )

        def format_example(example):
            try:
                # One column -> pass its raw value; several -> pass a dict keyed by column
                payload = (
                    example[column_list[0]]
                    if len(column_list) == 1
                    else {key: example[key] for key in column_list}
                )

                example['formatted_text'] = formatter_func(
                    payload,
                    tokenizer=tokenizer,
                    user_token=user_token,
                    assistant_token=assistant_token,
                    system_token=system_token
                )
            except Exception as e:
                # Best-effort: report and leave an empty string rather than abort the map
                print(f"⚠️  Error formatting example: {e}")
                example['formatted_text'] = ""

            return example

        return dataset.map(
            format_example,
            num_proc=num_proc,
            desc=f"Formatting {dataset_name}"
        )


# Single registry instance shared by the cells below
formatter_registry = DatasetFormatterRegistry()

# norobots: one chat-style column, handed to the formatter as-is
formatter_registry.register(
    "OpenLLM-Ro/ro_sft_norobots",
    format_messages_standard,
    columns="messages",
)

# dolly: three text columns, handed to the formatter together as a dict
formatter_registry.register(
    "OpenLLM-Ro/ro_sft_dolly",
    format_dolly_context_instruction,
    columns=["context", "instruction", "response"],
)

# # Example: Dataset with multiple columns
# formatter_registry.register(
#     "multi_column_dataset",
#     format_instruction_response,
#     columns=["instruction", "response", "system"]
# )

# Format the full train split of each dataset. No tokenizer is passed, so the
# formatters emit no BOS/EOS markers.
# NOTE(review): the saved cell output shows <s> / </s>, which these calls would
# not produce — the output is likely from an earlier run; confirm and re-run.
formatted_no_robots_ds = formatter_registry.format_dataset(
    norobots_ds['train'],
    "OpenLLM-Ro/ro_sft_norobots",
)

# Same for dolly, mapped with 4 worker processes
formatted_dolly_ds = formatter_registry.format_dataset(
    dolly_ds['train'],
    "OpenLLM-Ro/ro_sft_dolly",
    num_proc=4,
)

# Spot-check the first formatted dolly example
print("\nFirst formatted example:")
first_dolly_example = formatted_dolly_ds[0]
print(first_dolly_example['formatted_text'])

# Summary of what is registered
print("\n" + "="*70)
print("\nRegistered datasets:")
for name, cols in formatter_registry.list_datasets():
    print(f"  - {name} (columns: {cols})")

✅ Registered formatter for 'OpenLLM-Ro/ro_sft_norobots' (columns: messages)
✅ Registered formatter for 'OpenLLM-Ro/ro_sft_dolly' (columns: context, instruction, response)

First formatted example:
<s>
<utilizator>
Când a început să opereze Virgin Australia?

Context: Virgin Australia, numele comercial al Virgin Australia Airlines Pty Ltd, este o companie aeriană cu sediul în Australia. Este cea mai mare companie aeriană după mărimea flotei care a folosit brandul Virgin. Compania și-a început activitatea la 31 august 2000 ca Virgin Blue, cu două aeronave pe o singură rută. Compania s-a descoperit brusc ca o companie aeriană majoră pe piața internă australiană după prăbușirea companiei Ansett Australia în septembrie 2001. Compania aeriană a crescut între timp, deservind direct 32 de orașe din Australia, de la noduri din Brisbane, Melbourne și Sydney.
<asistent>
Virgin Australia și-a început serviciul la 31 august 2000 ca Virgin Blue, cu două aeronave pe o singură rută.
</s>


Registered 

In [12]:
from datasets import Dataset
# Reload a previously saved dataset from a path relative to the notebook's working directory.
# NOTE(review): presumably the formatted dolly data; no visible cell writes this directory
# (no save_to_disk call in view) — confirm where it is produced, or a fresh run fails here.
ds1 = Dataset.load_from_disk('../data/formatted_data/dolly/')

In [13]:
from datasets import Dataset

# Reload a previously saved dataset from a path relative to the notebook's working directory.
# NOTE(review): presumably the formatted norobots data; no visible cell writes this directory
# (no save_to_disk call in view) — confirm where it is produced, or a fresh run fails here.
ds2 = Dataset.load_from_disk('../data/formatted_data/norobots/')

In [14]:
# Show the first entry of the reloaded dataset's 'formatted_text' column as a sanity check.
print(ds1['formatted_text'][0])

<s>
<utilizator>
Când a început să opereze Virgin Australia?

Context: Virgin Australia, numele comercial al Virgin Australia Airlines Pty Ltd, este o companie aeriană cu sediul în Australia. Este cea mai mare companie aeriană după mărimea flotei care a folosit brandul Virgin. Compania și-a început activitatea la 31 august 2000 ca Virgin Blue, cu două aeronave pe o singură rută. Compania s-a descoperit brusc ca o companie aeriană majoră pe piața internă australiană după prăbușirea companiei Ansett Australia în septembrie 2001. Compania aeriană a crescut între timp, deservind direct 32 de orașe din Australia, de la noduri din Brisbane, Melbourne și Sydney.
<asistent>
Virgin Australia și-a început serviciul la 31 august 2000 ca Virgin Blue, cu două aeronave pe o singură rută.
</s>


In [11]:
# Show the first entry of the reloaded dataset's 'formatted_text' column.
# NOTE(review): identical statement to the preceding cell, and its execution count (In [11])
# is lower than the cells above it — run out of order; candidate for removal.
print(ds1['formatted_text'][0])

<s>
<utilizator>
Când a început să opereze Virgin Australia?

Context: Virgin Australia, numele comercial al Virgin Australia Airlines Pty Ltd, este o companie aeriană cu sediul în Australia. Este cea mai mare companie aeriană după mărimea flotei care a folosit brandul Virgin. Compania și-a început activitatea la 31 august 2000 ca Virgin Blue, cu două aeronave pe o singură rută. Compania s-a descoperit brusc ca o companie aeriană majoră pe piața internă australiană după prăbușirea companiei Ansett Australia în septembrie 2001. Compania aeriană a crescut între timp, deservind direct 32 de orașe din Australia, de la noduri din Brisbane, Melbourne și Sydney.
<asistent>
Virgin Australia și-a început serviciul la 31 august 2000 ca Virgin Blue, cu două aeronave pe o singură rută.
</s>
