In [111]:
def train_func():
    """Fine-tune ``meta-llama/Llama-3.2-1B-Instruct`` on GSM8K with the HF Trainer.

    The function is handed to ``TrainingClient.create_job`` as ``train_func``
    elsewhere in this file, so it is written to be fully self-contained: every
    import and helper lives inside the body (presumably because the job runs
    the function in a separate worker process -- confirm against the Kubeflow
    Training SDK).

    Steps:
      1. Load the base model and tokenizer.
      2. Render each GSM8K ``question``/``answer`` pair through the chat
         template and tokenize it.
      3. Train for 8 epochs, evaluating on the ``test`` split after each epoch.
      4. Save the final model to the Trainer ``output_dir`` and log a few
         distributed-setup diagnostics.

    Takes no arguments and returns nothing.
    """
    import logging
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TrainingArguments,
        DataCollatorForLanguageModeling,
        Trainer,
    )
    from datasets import load_dataset
    # Only used by the optional LoRA path that is commented out further down.
    from peft import LoraConfig, get_peft_model  # noqa: F401

    log_formatter = logging.Formatter(
        "%(asctime)s %(levelname)-8s %(message)s", "%Y-%m-%dT%H:%M:%SZ"
    )
    # __name__ is always defined, whereas __file__ is missing in interactive
    # sessions (e.g. a notebook kernel) and would raise NameError there.
    logger = logging.getLogger(__name__)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_formatter)
    logger.addHandler(console_handler)
    logger.setLevel(logging.INFO)

    model_name = "meta-llama/Llama-3.2-1B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=model_name,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=model_name,
    )
    # The tokenizer ships without a pad token; reuse EOS so batches can be padded.
    # NOTE(review): the LM collator masks every position equal to pad_token_id
    # in the labels, so aliasing pad to EOS presumably also hides genuine
    # end-of-sequence tokens from the loss -- consider a dedicated pad token.
    tokenizer.pad_token = tokenizer.eos_token

    # Optional: freeze the base model parameters (e.g. when training adapters only).
    # for param in model.parameters():
    #     param.requires_grad = False

    # Inspired by https://medium.com/@alexandros_chariton/how-to-fine-tune-llama-3-2-instruct-on-your-own-data-a-detailed-guide-e5f522f397d7
    def format_dataset(example):
        """Render one GSM8K record as a complete user/assistant chat string."""
        messages = [
            {"role": "user", "content": example["question"]},
            {"role": "assistant", "content": example["answer"]},
        ]
        # The conversation already ends with the assistant's answer, so no
        # generation prompt must be appended: it would leave a dangling, empty
        # assistant header at the end of every training sample.
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        return {"prompt": prompt}

    def tokenize_dataset(example):
        """Tokenize the rendered chat string without padding.

        Padding and label creation are left to the data collator, which pads
        each batch dynamically instead of padding every sample to the
        tokenizer's maximum length up front.
        """
        # The chat template has already inserted the special tokens (including
        # BOS), so the tokenizer must not add them a second time.
        return tokenizer(example["prompt"], add_special_tokens=False)

    dataset = load_dataset("openai/gsm8k", "main")
    train_data = dataset["train"].map(format_dataset)
    eval_data = dataset["test"].map(format_dataset)
    # Drop the raw text columns so only tokenizer outputs reach the collator.
    train_data = train_data.map(tokenize_dataset, remove_columns=['question', 'answer', 'prompt'])
    eval_data = eval_data.map(tokenize_dataset, remove_columns=['question', 'answer', 'prompt'])

    # Optional: wrap the model with LoRA adapters instead of full fine-tuning.
    # lora_config = LoraConfig(r=4, lora_alpha=16, lora_dropout=0.1, bias="none")
    # model.enable_input_require_grads()
    # model = get_peft_model(model, lora_config)

    # Causal-LM collator (mlm=False): pads each batch to a multiple of 8 and
    # derives the labels from input_ids, ignoring padded positions in the loss.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer,
        pad_to_multiple_of=8,
        mlm=False,
    )

    trainer = Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=eval_data,
        data_collator=data_collator,
        args=TrainingArguments(output_dir="/tmp",
                               per_device_train_batch_size=1,
                               per_device_eval_batch_size=1,
                               num_train_epochs=8,
                               logging_dir="/logs",
                               eval_strategy="epoch",
                               save_strategy="no"),
    )

    # Train and save the model.
    trainer.train()
    trainer.save_model()
    # Diagnostics about how the Trainer set up distributed execution.
    logger.info("parallel_mode: '%s'", trainer.args.parallel_mode)
    logger.info("is_model_parallel: '%s'", trainer.is_model_parallel)
    logger.info("model_wrapped: '%s'", trainer.model_wrapped)

In [112]:
from kubeflow.training import TrainingClient

In [113]:
from kubernetes.client import (
    V1EnvVar,
    V1EnvVarSource,
    V1SecretKeySelector
)

# Submit train_func as a PyTorchJob named "pytorch-ddp".
TrainingClient().create_job(
    name="pytorch-ddp",
    job_kind="PyTorchJob",
    base_image="quay.io/modh/training:py311-cuda121-torch241",
    train_func=train_func,
    # A single worker requesting four GPUs; "auto" presumably lets the
    # launcher decide the number of processes per worker -- confirm.
    num_workers=1,
    resources_per_worker={"gpu": 4},
    num_procs_per_worker="auto",
    env_vars=[
        # HF_TOKEN is sourced from key "HF_TOKEN" of the "hf-token" secret.
        V1EnvVar(
            name="HF_TOKEN",
            value_from=V1EnvVarSource(
                secret_key_ref=V1SecretKeySelector(
                    name="hf-token",
                    key="HF_TOKEN",
                ),
            ),
        ),
        # Verbose NCCL logging for the workers.
        V1EnvVar(name="NCCL_DEBUG", value="INFO"),
        # V1EnvVar(name="TOKENIZERS_PARALLELISM", value="false"),
    ],
)
