# Quick Start

This quick start guide shows how to start training the RL version of SSA. Because of VRAM limits, we use the Qwen2.5-0.5B model, which still requires a 40GB A100 on Colab. For demo purposes, data shuffling is set to False, so the demo steps only see GSM8K examples, whose solutions are much shorter than MATH solutions. Please set shuffling back to True for actual training! In that case you will probably need a GPU with more VRAM (48GB+), because the longer MATH solutions increase memory usage enough that a 40GB A100 runs out of memory (OOM).
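
Before starting, you may want to confirm how much GPU memory the runtime actually has. The sketch below reads total VRAM with PyTorch's `torch.cuda.get_device_properties` and compares it with the figures quoted above (40GB for the unshuffled GSM8K demo, roughly 48GB+ with shuffling on). The helper names `total_vram_gb` and `vram_advice` and the exact cutoffs are illustrative assumptions, not part of the repo.

```python
def vram_advice(total_gb, shuffle):
    """Rough hint based on the VRAM figures in this guide (illustrative cutoffs, decimal GB)."""
    needed = 48 if shuffle else 40  # ~48GB+ once MATH is mixed in, 40GB for the GSM8K-only demo
    if total_gb >= needed:
        return "ok"
    return f"likely OOM: {total_gb:.0f}GB < ~{needed}GB"


def total_vram_gb(device=0):
    """Total memory of a CUDA device in decimal GB, or None if PyTorch/CUDA is unavailable."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_properties(device).total_memory / 1e9


gb = total_vram_gb()
if gb is None:
    print("No CUDA GPU detected")
else:
    print(f"{gb:.1f}GB VRAM:", vram_advice(gb, shuffle=False))
```

Decimal GB (`1e9`) is used because a "40GB" A100 reports slightly less than 40 GiB of total memory.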

## Install dependencies

### Step 1: Install dependencies
On Colab, the runtime restarts after the installation. Run Step 2 once the installation is done.

```python
try:
    import google.colab  # only importable inside Colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False
print(f"Running in Google Colab: {IN_COLAB}")

if IN_COLAB:
    print("Installing repo")
    # Clone into ssa-main explicitly; a plain `git clone` would name the directory `ssa`
    !git clone https://github.com/user074/ssa.git ssa-main
    %cd ssa-main

    # Read the requirements file
    with open('requirements_colab.txt', 'r') as f:
        lines = f.readlines()

    # Keep only pinned `package==version` specs; skip comments, conda paths and file:// URLs
    clean_lines = []
    for line in lines:
        line = line.strip()
        if line.startswith('#') or '/home/conda/' in line or 'file://' in line:
            continue
        if '==' in line:
            clean_lines.append(line.split()[0])  # keep just the package==version part

    # Write a new, clean requirements file
    with open('requirements_clean.txt', 'w') as f:
        f.write('\n'.join(clean_lines))

    # Install from the clean file
    !pip install -r requirements_clean.txt
else:
    # Local setup: create the conda environment from environment.yml
    !conda env create -f environment.yml
    # `!conda activate SSA` has no lasting effect in a notebook (each `!` runs in its own
    # subshell), so activate the SSA environment in a terminal and launch Jupyter from it.
```


### Step 2: Continue the installation
The runtime will restart after Step 1. Once it has restarted, run the following cell:

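```python
try:
    import google.colab  # only importable inside Colab
    %cd ssa-main
except ImportError:
    pass  # running locally: the notebook should already be in the repo root
# Clone OpenAI's PRM800K repository
!git clone https://github.com/openai/prm800k
# Install the repo's torchtune directory in editable mode
%cd torchtune
!pip install -e .
%cd ..
```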
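
A quick sanity check after Step 2 is to confirm that the `tune` command (used for training below) is on the PATH and that the `prm800k` clone exists. This is a minimal sketch using only the standard library (`shutil.which`, `pathlib`); the helper name `check_setup` is hypothetical, and it assumes the working directory is the repository root, as it is after the cell above.

```python
import shutil
from pathlib import Path


def check_setup(commands=("tune",), dirs=("prm800k",)):
    """Return the names of missing CLI commands or directories (empty list = looks complete)."""
    missing = [c for c in commands if shutil.which(c) is None]
    missing += [d for d in dirs if not Path(d).is_dir()]
    return missing


problems = check_setup()
print("Setup looks complete" if not problems else f"Missing: {problems}")
```

Checking for the `tune` executable rather than importing `torchtune` avoids a false positive from the repo's own `torchtune/` directory, which sits in the working directory.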

## Download the model

For this demo we use the Qwen2.5-0.5B model, which you can download from Hugging Face:

```python
# hf_transfer speeds up downloads from the Hugging Face Hub
!pip install "huggingface_hub[hf_transfer]"
!HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download Qwen/Qwen2.5-0.5B --local-dir model/Qwen2.5-0.5B
```
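
To confirm the download finished, one option is to check that the local directory contains a `config.json` and at least one `*.safetensors` weight file. This sketch assumes the checkpoint is stored in safetensors format, as Hugging Face checkpoints commonly are; `model_ready` is a hypothetical helper, not part of the repo.

```python
from pathlib import Path


def model_ready(model_dir="model/Qwen2.5-0.5B"):
    """True if model_dir has a config.json and at least one .safetensors weight file."""
    path = Path(model_dir)
    has_config = (path / "config.json").is_file()
    has_weights = path.is_dir() and any(path.glob("*.safetensors"))
    return has_config and has_weights


print("Model files found:", model_ready())
```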

## Train

For the dataset, we use our cleaned dataset on Hugging Face, `user074/concat_cleaned_gsm8k_math_5`, which is ready to use as-is. The training settings are in the provided config file, `05B_rl_SSA_qwen.yaml`.
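
The advice on shuffling at the top of this guide follows from how a concatenated dataset is ordered. If the GSM8K rows come before the MATH rows (as the demo note and the dataset name suggest; this ordering is an assumption here), an unshuffled loader serves only short GSM8K problems in the first steps, while a shuffled one mixes in the longer MATH solutions right away. The toy sketch below uses made-up sizes and plain Python lists to stand in for the real dataset and data loader.

```python
import random

# Toy stand-in for the concatenated dataset: GSM8K rows first, then MATH rows (sizes are made up)
dataset = ["gsm8k"] * 1000 + ["math"] * 1000


def batches(rows, batch_size, shuffle, seed=0):
    """Yield batches in order, optionally shuffling the row order once up front."""
    order = list(rows)
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield order[start:start + batch_size]


def first_steps(rows, steps, batch_size, shuffle):
    """Return the set of sources seen in the first `steps` batches."""
    seen = set()
    for step, batch in enumerate(batches(rows, batch_size, shuffle)):
        if step == steps:
            break
        seen.update(batch)
    return seen


print(first_steps(dataset, steps=10, batch_size=8, shuffle=False))  # only GSM8K rows
print(first_steps(dataset, steps=10, batch_size=8, shuffle=True))   # GSM8K and MATH mixed
```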

```python
# First, log in to wandb
import wandb
wandb.login()
```
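
`wandb.login()` prompts for an API key interactively. To skip the prompt, a common pattern is to store the key as a Colab secret or in the `WANDB_API_KEY` environment variable and pass it to `wandb.login(key=...)`. The sketch below assumes the secret or variable is named `WANDB_API_KEY`; `get_wandb_key` is a hypothetical helper.

```python
import os


def get_wandb_key(name="WANDB_API_KEY"):
    """Look up the W&B API key in Colab secrets first, then the environment; None if absent."""
    try:
        from google.colab import userdata  # Colab-only secrets API
        return userdata.get(name)
    except Exception:  # not in Colab, secret missing, or notebook access not granted
        return os.environ.get(name)


try:
    import wandb
except ImportError:
    wandb = None

key = get_wandb_key()
if wandb is None:
    print("wandb is not installed")
elif key:
    wandb.login(key=key)
else:
    print("No stored key found; call wandb.login() to log in interactively")
```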

Here is an example of the demo results from the wandb log. The success rate rises quickly, even within only 200 steps.

![Demo results after 200 steps](figures/demo-results.png "Demo 200 steps")


Start training! The log for each step is printed in the cell output below.

```python
!tune run --nproc_per_node 1 dev/grpo_full_finetune_distributed --config ./05B_rl_SSA_qwen.yaml
```
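
For a full (non-demo) run, the guide recommends turning shuffling back on. One way is to edit the config before relaunching. The sketch below assumes the YAML contains a `shuffle:` key with a literal `True`/`False` value (check the actual key name in `05B_rl_SSA_qwen.yaml`); `set_shuffle` is a hypothetical helper that rewrites every such line with a regular expression instead of a full YAML parser, so it only handles simple `shuffle: True/False` lines.

```python
import re
from pathlib import Path


def set_shuffle(yaml_text, value):
    """Return yaml_text with every `shuffle: True/False` line set to `value` (assumed key name)."""
    pattern = re.compile(r"^(\s*shuffle:\s*)(True|False|true|false)\b", re.MULTILINE)
    return pattern.sub(lambda m: m.group(1) + str(value), yaml_text)


config = Path("05B_rl_SSA_qwen.yaml")
if config.is_file():
    config.write_text(set_shuffle(config.read_text(), True))
    print("shuffle set to True in", config)
```

Alternatively, torchtune's `tune run` accepts `key=value` overrides after `--config`, so appending `shuffle=True` to the training command should have the same effect without editing the file, again assuming the key is named `shuffle`.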