In [None]:
# Run name. The launch cell uses it for the pod name, the gist description and
# the Hugging Face repo id (USERNAME/MODEL).
MODEL = "LLama-3-8B-32k"

# Axolotl training config. The launch cell parses it with yaml.safe_load and
# uploads it verbatim as a gist. It is a plain (non-f) string on purpose: it
# has no placeholders, and as an f-string any literal brace added later (e.g. a
# YAML flow mapping `{a: 1}`) would be parsed as interpolation and break.
# NOTE(review): the names disagree. MODEL and output_dir say "32k",
# wandb_project says "64k", and pose_max_context_len / max_position_embeddings
# are 65536. Confirm the intended context length before launching.
yaml_config = """
base_model: meta-llama/Meta-Llama-3-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: togethercomputer/RedPajama-Data-1T-Sample
    type: completion
    split: train
dataset_prepared_path: last_run_prepared
val_set_size: 0.001
output_dir: ./llama-3-32k
save_safetensors: true
sequence_len: 8192
sample_packing: false
pad_to_sequence_len: false
use_pose: true
pose_max_context_len: 65536

overrides_of_model_config:
  rope_theta: 500000.0
  max_position_embeddings: 65536

  # peft_use_dora: true
adapter: lora
peft_use_rslora: true
lora_model_dir:
lora_r: 256
lora_alpha: 256
lora_dropout: 0.1
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj

wandb_project: llama-3-64k
wandb_entity: zaiinn440
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.00003

train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
sdp_attention:
s2_attention:

warmup_steps: 10
evals_per_epoch: 8
saves_per_epoch: 8
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: <|end_of_text|>
"""

In [None]:
try:
  import runpod
except ImportError:
  !pip install -qqq runpod --progress-bar off

import runpod
from runpod.error import QueryError
from google.colab import userdata, runtime
import requests
import json
import yaml
import requests
import time

def upload_gist(text, gist_name, gh_token, gist_description="", timeout=30):
    """Upload ``text`` as a single-file private GitHub gist.

    Args:
        text: Content of the gist's only file.
        gist_name: File name inside the gist (e.g. ``"config.yaml"``).
        gh_token: GitHub token, sent as ``token <gh_token>`` in the
            ``Authorization`` header.
        gist_description: Optional description shown on the gist.
        timeout: Seconds to wait for the GitHub API. ``requests`` has no
            default timeout, so without one this call could block forever.

    Returns:
        The ``raw_url`` of the uploaded file when GitHub answers HTTP 201,
        otherwise ``None`` (the failure is printed).

    Raises:
        requests.RequestException: On connection errors or timeouts.
    """
    payload = {
        "description": gist_description,
        "public": False,
        "files": {gist_name: {"content": text}},
    }
    headers = {
        "Authorization": f"token {gh_token}",
        "Accept": "application/vnd.github.v3+json",
    }

    response = requests.post(
        "https://api.github.com/gists",
        headers=headers,
        json=payload,
        timeout=timeout,
    )

    if response.status_code != 201:
        print(
            f"Failed to upload gist. Status code: {response.status_code}. Response: {response.text}"
        )
        return None

    raw_url = response.json()['files'][gist_name]['raw_url']
    print(f"Uploaded Axolotl config as gist: {raw_url}")
    return raw_url

def delete_gist(gh_token, gist_id, timeout=30):
    """Delete a GitHub gist by id.

    Args:
        gh_token: GitHub token, sent as ``token <gh_token>`` in the
            ``Authorization`` header.
        gist_id: Id of the gist to delete.
        timeout: Seconds to wait for the GitHub API. ``requests`` has no
            default timeout. This notebook calls the function from its
            KeyboardInterrupt cleanup path, where hanging forever would be
            especially bad.

    Returns:
        ``True`` if GitHub answered HTTP 204, ``False`` otherwise (the failure
        is printed).

    Raises:
        requests.RequestException: On connection errors or timeouts.
    """
    headers = {
        "Authorization": f"token {gh_token}",
        "Accept": "application/vnd.github.v3+json",
    }

    response = requests.delete(
        f"https://api.github.com/gists/{gist_id}",
        headers=headers,
        timeout=timeout,
    )

    if response.status_code != 204:
        print(
            f"Failed to delete gist. Status code: {response.status_code}. Response: {response.text}"
        )
        return False

    print("Gist has been deleted.")
    return True


# --- Pod settings (passed to runpod.create_pod / the pod environment) -------
GPU = "NVIDIA RTX 6000 Ada Generation"  # gpu_type_id
NUMBER_OF_GPUS = 8
CONTAINER_DISK = 300  # container_disk_in_gb
CLOUD_TYPE = "SECURE"
SCRIPT = "https://gist.githubusercontent.com/mlabonne/160015f39c05b92cea83a57d93d552fe/raw"
ZERO = "None"
LLM_AUTOEVAL = False
DEBUG = False


# --- Credentials -------------------------------------------------------------
# NOTE(review): these values are used verbatim as secrets. "HF_TOKEN", "wandb"
# and "github" look like Colab secret *names* (google.colab.userdata is
# imported but never used). If so, they should be read with userdata.get(...).
# Confirm before running, and never commit real tokens to the notebook.
USERNAME = ""
RUNPOD_TOKEN = ""
HUGGING_FACE_TOKEN = "HF_TOKEN"
WANDB_TOKEN = "wandb"
GITHUB_TOKEN = "github"


runpod.api_key = RUNPOD_TOKEN
WANDB_API_KEY = WANDB_TOKEN
HF_TOKEN = HUGGING_FACE_TOKEN
GITHUB_API_TOKEN = GITHUB_TOKEN

# Parse before uploading, so malformed YAML fails before anything is published.
config = yaml.safe_load(yaml_config)

# Upload the YAML file to GitHub. The pod downloads its config from GIST_URL.
gist_url = upload_gist(yaml_config, "config.yaml", GITHUB_API_TOKEN, f"{MODEL} - https://huggingface.co/{USERNAME}/{MODEL}")
if gist_url is None:
    # Without a gist the pod would start with GIST_URL=None, and the
    # interrupt cleanup below would crash on None.split(). Stop here instead.
    raise RuntimeError("Could not upload the Axolotl config as a gist; not creating a pod.")

# Summary line printed once the pod is created.
base_model = config.get('base_model', 'Unknown model')
dataset_info = []
datasets = config.get('datasets') or []  # tolerate an explicit `datasets:` null
for dataset in datasets:
    path = dataset.get('path', 'Unknown path')
    dtype = dataset.get('type', 'Unknown type')
    dataset_info.append(f"{path} ({dtype})")
datasets_summary = ', '.join(dataset_info)

# Create a pod, retrying every 30 s while RunPod rejects the request.
try:
    while True:
        try:
            pod = runpod.create_pod(
                name=f"LazyAxolotl - {MODEL}",
                image_name="winglian/axolotl-cloud:main-latest",
                gpu_type_id=GPU,
                cloud_type=CLOUD_TYPE,
                gpu_count=NUMBER_OF_GPUS,
                volume_in_gb=0,
                container_disk_in_gb=CONTAINER_DISK,
                template_id="eul6o46pab",
                env={
                    "HF_TOKEN": HF_TOKEN,
                    "SCRIPT": SCRIPT,
                    "WANDB_API_KEY": WANDB_API_KEY,
                    "GIST_URL": gist_url,
                    "MODEL_NAME": MODEL,
                    "BASE_MODEL": config['base_model'],
                    "USERNAME": USERNAME,
                    "ZERO": ZERO,
                    "LLM_AUTOEVAL": LLM_AUTOEVAL,
                    "MODEL_ID": USERNAME + "/" + MODEL,
                    "BENCHMARK": "nous",
                    "REPO": "https://github.com/mlabonne/llm-autoeval.git",
                    "TRUST_REMOTE_CODE": True,
                    "GITHUB_API_TOKEN": GITHUB_API_TOKEN,
                    "DEBUG": DEBUG,
                }
            )
        except QueryError as e:
            # Every QueryError is retried, not only "no capacity" ones. Show
            # RunPod's own message so auth or config errors are not hidden
            # behind an endless availability loop.
            print(f"\033[31m⚠️ ERROR: The requested pod ({NUMBER_OF_GPUS}x {GPU} with {CONTAINER_DISK} GB) is not currently available. Trying again in 30 seconds...\033[39m")
            print(f"RunPod error: {e}")
            time.sleep(30)
        else:
            print(f"This run trains {base_model} on {datasets_summary}.")
            print("https://www.runpod.io/console/pods")
            break
except KeyboardInterrupt:
    print("KeyboardInterrupt detected, cleaning up before stopping...")
    # Gist id is the 5th '/'-separated component of the raw URL
    # (same extraction as before).
    delete_gist(GITHUB_API_TOKEN, gist_url.split('/')[4])