In [None]:
! pip install tokenizers datasets transformers[torch] huggingface_hub

In [57]:
from datasets import load_dataset, concatenate_datasets, Dataset, DatasetDict
from tqdm import tqdm

In [59]:
def concatenate_text_columns(examples, text_columns):
    """
    Join several text columns of a batch into a single 'text' column.

    Designed for ``Dataset.map(..., batched=True)``: every value in
    ``examples`` is a list holding one entry per row of the batch, and row
    ``i`` of the result is ``". ".join(examples[col][i] for col in text_columns)``.

    Parameters:
    - examples (dict): A batch of examples (column name -> list of values).
    - text_columns (list of str): Names of the columns to join, in order.

    Returns:
    - dict: ``{'text': [...]}`` with one joined string per row of the batch.

    Raises:
    - ValueError: If ``text_columns`` is empty.
    - KeyError: If a requested column is missing from ``examples``.
    """
    if not text_columns:
        raise ValueError("text_columns must name at least one column to concatenate.")
    columns = [examples[col] for col in text_columns]
    # zip(*columns) walks the batch row by row, yielding one tuple of cell
    # values per example, so each row is joined with its own values only.
    return {"text": [". ".join(row) for row in zip(*columns)]}


def load_and_concatenate_datasets(dataset_args_list):
    """
    Load several datasets and concatenate them into one text dataset.

    Every dict in ``dataset_args_list`` is forwarded to ``load_dataset`` as
    keyword arguments, except the optional ``'text_column'`` key, which says
    how to produce the ``'text'`` column:

    - a ``str`` (default ``'text'``): that column is renamed to ``'text'``;
    - a ``list`` of ``str``: those columns are joined with ``". "`` through
      ``concatenate_text_columns``;
    - a callable: applied with ``Dataset.map(..., batched=True)``; it must
      return a ``'text'`` column.

    When no ``'split'`` is given, ``load_dataset`` returns a ``DatasetDict``.
    In that case all of its splits are concatenated before processing.

    Parameters:
    - dataset_args_list (list of dict): ``load_dataset`` kwargs plus the
      optional ``'text_column'`` key described above.

    Returns:
    - datasets.Dataset: A dataset with exactly the columns 'text' and
      'source_dataset' (the entry's 'path', or 'unknown_dataset').
      ``None`` if ``dataset_args_list`` is empty.

    Raises:
    - ValueError: If an entry has no 'text' column after processing.
    """
    processed = []
    for dataset_args in tqdm(dataset_args_list):
        text_column = dataset_args.get('text_column', 'text')
        dataset_name = dataset_args.get('path', 'unknown_dataset')

        # Load the dataset with every argument except our own 'text_column'.
        load_kwargs = {k: v for k, v in dataset_args.items() if k != 'text_column'}
        dataset = load_dataset(**load_kwargs)

        # Without 'split', load_dataset returns a DatasetDict keyed by split
        # name. Flatten it into a single Dataset so the steps below always
        # operate on one table.
        if isinstance(dataset, DatasetDict):
            dataset = concatenate_datasets(list(dataset.values()))

        if isinstance(text_column, list):
            # fn_kwargs passes the column list explicitly instead of capturing
            # the loop variable in a closure.
            dataset = dataset.map(
                concatenate_text_columns,
                batched=True,
                fn_kwargs={'text_columns': text_column},
            )
        elif callable(text_column):
            dataset = dataset.map(text_column, batched=True)
        elif text_column != 'text':
            dataset = dataset.rename_column(text_column, 'text')

        if 'text' not in dataset.column_names:
            raise ValueError(
                f"Dataset {dataset_name!r} must have a 'text' column after processing "
                f"'text_column' (available columns: {dataset.column_names})."
            )

        # Keep only the text, then tag every row with where it came from.
        dataset = dataset.select_columns(['text'])
        dataset = dataset.add_column('source_dataset', [dataset_name] * len(dataset))
        processed.append(dataset)

    if not processed:
        return None
    # A single concatenation at the end avoids rebuilding the combined
    # dataset on every loop iteration.
    return concatenate_datasets(processed)

In [30]:
def belebele_text(examples):
    """
    Format a batch of Belebele rows as passage + question + four answers.

    Meant for ``Dataset.map(..., batched=True)``. Each output string is
    newline-separated and contains, in order:

    - the Flores passage;
    - the question, prefixed with "Ɲininkali: ";
    - the four candidate answers, each prefixed "jaabi <k> nan: " with
      k = 1..4.

    Parameters:
    - examples (dict): batch with list-valued columns 'flores_passage',
      'question' and 'mc_answer1' .. 'mc_answer4'.

    Returns:
    - dict: {'text': list of formatted strings}, one per row.
    """
    answer_keys = ('mc_answer1', 'mc_answer2', 'mc_answer3', 'mc_answer4')
    rows = zip(
        examples['flores_passage'],
        examples['question'],
        *(examples[key] for key in answer_keys),
    )

    formatted = []
    for passage, question, *answers in rows:
        lines = [f"{passage}", f"Ɲininkali: {question}"]
        lines.extend(
            f"jaabi {number} nan: {answer}"
            for number, answer in enumerate(answers, start=1)
        )
        formatted.append("\n".join(lines))
    return {'text': formatted}

In [ ]:
# Bambara text sources. Each entry is forwarded to load_dataset, minus
# 'text_column', which tells load_and_concatenate_datasets where the text is.
# The sil-ai repositories are loaded with the locally stored Hugging Face
# token: `token=True` replaces `use_auth_token=True`, which `datasets`
# deprecated in 2.14 (scheduled for removal in 3.0).
# NOTE(review): the sil-ai repositories may rely on loading scripts and need
# `trust_remote_code=True` on recent `datasets` releases — confirm.
# 'force_redownload' makes the bloom-captioning entries re-download on every run.
dataset_list = [
    {'path': 'oza75/bambara-tts', 'text_column': 'bambara', 'split': 'train'},
    {'path': 'sil-ai/bloom-speech', 'name': 'bam', 'text_column': 'text', 'split': 'train', 'token': True},
    {'path': 'sil-ai/bloom-speech', 'name': 'bam', 'text_column': 'text', 'split': 'validation', 'token': True},
    {'path': 'sil-ai/bloom-speech', 'name': 'bam', 'text_column': 'text', 'split': 'test', 'token': True},
    {'path': 'wikimedia/wikipedia', 'name': '20231101.bm', 'text_column': 'text', 'split': 'train'},
    {'path': 'facebook/belebele', 'text_column': belebele_text, 'split': 'bam_Latn'},
    {'path': 'bigscience/xP3all', 'name': 'bm', 'text_column': 'targets', 'split': 'train'},
    {'path': 'sil-ai/bloom-captioning', 'name': 'bam', 'text_column': 'caption', 'split': 'train',
     'token': True, 'download_mode': 'force_redownload'},
    {'path': 'sil-ai/bloom-captioning', 'name': 'bam', 'text_column': 'caption', 'split': 'validation',
     'token': True, 'download_mode': 'force_redownload'},
    {'path': 'sil-ai/bloom-captioning', 'name': 'bam', 'text_column': 'caption', 'split': 'test',
     'token': True, 'download_mode': 'force_redownload'},
]

bambara_ds = load_and_concatenate_datasets(dataset_list)

In [ ]:
# Three configurations of the same Hub repository; each is read from its
# 'bambara' column on the 'train' split and the results are concatenated.
my_dataset_list = [
    {'path': 'oza75/mt-fr-bm-texts', 'name': config_name, 'text_column': 'bambara', 'split': 'train'}
    for config_name in ('main', 'transcriptions', 'dictionnary')
]

my_bambara_ds = load_and_concatenate_datasets(my_dataset_list)