To run this, press "*Runtime*" and press "*Run all*" on a **free** Tesla T4 Google Colab instance!
<div class="align-center">
<a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
<a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
<a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">GitHub</a> </i> ⭐
</div>

To install Unsloth on your own computer, follow the installation instructions in our docs [here](https://docs.unsloth.ai/get-started/installing-+-updating).

You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), and [how to save it](#Save).


### News

Unsloth now supports Text-to-Speech (TTS) models. Read our [guide here](https://docs.unsloth.ai/basics/text-to-speech-tts-fine-tuning).

Read our **[Qwen3 Guide](https://docs.unsloth.ai/basics/qwen3-how-to-run-and-fine-tune)** and check out our new **[Dynamic 2.0](https://docs.unsloth.ai/basics/unsloth-dynamic-2.0-ggufs)** quants, which outperform other quantization methods!

Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).


### Installation

In [None]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl==0.15.2 triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer
    !pip install --no-deps unsloth

### Unsloth

In [None]:
from unsloth import FastLanguageModel
import torch
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
fourbit_models = [
    "unsloth/Meta-Llama-3.1-8B-bnb-4bit",      # Llama-3.1 15 trillion tokens model 2x faster!
    "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
    "unsloth/Meta-Llama-3.1-70B-bnb-4bit",
    "unsloth/Meta-Llama-3.1-405B-bnb-4bit",    # We also uploaded 4bit for 405b!
    "unsloth/Mistral-Nemo-Base-2407-bnb-4bit", # New Mistral 12b 2x faster!
    "unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit",
    "unsloth/mistral-7b-v0.3-bnb-4bit",        # Mistral v3 2x faster!
    "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
    "unsloth/Phi-3.5-mini-instruct",           # Phi-3.5 2x faster!
    "unsloth/Phi-3-medium-4k-instruct",
    "unsloth/gemma-2-9b-bnb-4bit",
    "unsloth/gemma-2-27b-bnb-4bit",            # Gemma 2x faster!
] # More models at https://huggingface.co/unsloth

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Meta-Llama-3.1-8B",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.5.9: Fast Llama patching. Transformers: 4.52.2.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.557 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.0. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
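The `dtype = None` setting lets Unsloth pick the precision automatically. As a rough illustration of the rule of thumb in the loading cell's comment (float16 on Tesla T4 and V100, bfloat16 on Ampere and newer), the sketch below maps a CUDA compute capability to a dtype name. The function name `pick_dtype_name` is hypothetical and this is not Unsloth's detection code; it only assumes that Ampere and newer GPUs report a major compute capability of 8 or higher.

```python
def pick_dtype_name(compute_capability_major: int) -> str:
    """Return the dtype name suggested by the rule of thumb in the loading cell."""
    # Assumption: Ampere and newer GPUs report a major compute capability >= 8.
    return "bfloat16" if compute_capability_major >= 8 else "float16"

print(pick_dtype_name(7))  # Tesla T4 (7.5) and V100 (7.0)
print(pick_dtype_name(8))  # A100 (8.0), the GPU shown in the log above
```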



We now add LoRA adapters so we only need to update 1 to 10% of all parameters!

In [None]:
model = FastLanguageModel.get_peft_model(
    model,
    r = 64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

Unsloth 2025.5.9 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.
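To get a feel for how many weights the adapters add at a given rank, here is a small back-of-the-envelope sketch. The layer shapes are assumptions about Llama 3.1 8B (hidden size 4096, key/value projection width 1024, MLP width 14336, 32 layers) and are not printed anywhere in this notebook; `lora_trainable_params` is a hypothetical helper, not part of Unsloth or PEFT.

```python
# Assumed Llama 3.1 8B shapes (not printed in this notebook): (in_features, out_features)
HIDDEN, KV_DIM, MLP_DIM, NUM_LAYERS = 4096, 1024, 14336, 32
TARGET_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, MLP_DIM),
    "up_proj": (HIDDEN, MLP_DIM),
    "down_proj": (MLP_DIM, HIDDEN),
}

def lora_trainable_params(r: int) -> int:
    """Each adapted matrix adds A (r x in) and B (out x r), i.e. r * (in + out) weights."""
    per_layer = sum(r * (n_in + n_out) for n_in, n_out in TARGET_SHAPES.values())
    return per_layer * NUM_LAYERS

for rank in (8, 16, 64):
    print(rank, f"{lora_trainable_params(rank):,}")
```

Under these assumptions, `r = 64` gives about 168 million adapter weights, roughly 2% of an 8B-parameter model.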


<a name="Data"></a>
### Data Prep
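The original Unsloth template uses the Alpaca dataset from [yahma](https://huggingface.co/datasets/yahma/alpaca-cleaned), a filtered version of the 52K-example original [Alpaca dataset](https://crfm.stanford.edu/2023/03/13/alpaca.html). This notebook replaces that data prep with Reddit posts loaded from pickle files on Google Drive and formats them into the same Alpaca-style instruction/input/output structure. You can replace this code section with your own data prep.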

**[NOTE]** To train only on completions (ignoring the user's input) read TRL's docs [here](https://huggingface.co/docs/trl/sft_trainer#train-on-completions-only).
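Training on completions only usually comes down to setting the labels of the prompt tokens to `-100`, the value PyTorch's cross-entropy loss ignores by default, so only the response tokens contribute to the loss. The sketch below shows that idea on plain Python lists. It is not TRL's implementation, and `mask_prompt_labels` and `response_start` are hypothetical names.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default

def mask_prompt_labels(input_ids, response_start):
    """Copy input_ids into labels and hide every token before the response."""
    labels = list(input_ids)
    for i in range(min(response_start, len(labels))):
        labels[i] = IGNORE_INDEX
    return labels

# Toy token ids: the first three belong to the prompt, the last two to the response.
print(mask_prompt_labels([5, 6, 7, 8, 9], response_start=3))
```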

**[NOTE]** Remember to add the **EOS_TOKEN** to the tokenized output!! Otherwise you'll get infinite generations!
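As a minimal illustration of what "adding the EOS token" means, the sketch below builds one training string from a made-up post and appends a placeholder end-of-sequence marker. The `"<EOS>"` string, the shortened template, and `build_example` are assumptions for illustration only; in the notebook the real token comes from `tokenizer.eos_token`.

```python
EOS_TOKEN = "<EOS>"  # placeholder; the notebook uses tokenizer.eos_token instead

# Shortened stand-in for the Alpaca template used later in the notebook.
PROMPT = "### Instruction:\n{}\n\n### Input:\n{}\n\n### Response:\n{}"

def build_example(instruction: str, input_text: str, output: str) -> str:
    """Fill the template and terminate the example so the model learns where to stop."""
    return PROMPT.format(instruction, input_text, output) + EOS_TOKEN

text = build_example("Classify the post.", "Post Title: Cheap flights to Lisbon", "Travel")
print(text)
```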

If you want to use the `llama-3` template for ShareGPT datasets, try our conversational [notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Alpaca.ipynb)

For text completions like novel writing, try this [notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Mistral_(7B)-Text_Completion.ipynb).

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### Check our Pickle Files

In [None]:
import pickle
import pandas as pd
import os
from pathlib import Path

# Set the path to your pickle files
pkl_path = "/content/drive/MyDrive/488-data/large"

# Get all pickle files in the directory
pkl_files = [f for f in os.listdir(pkl_path) if f.endswith('.pkl')]

print(f"Found {len(pkl_files)} pickle files:")
print("-" * 50)

# Loop through each pickle file and examine its structure
for i, filename in enumerate(pkl_files, 1):
    file_path = os.path.join(pkl_path, filename)

    print(f"\n{i}. File: {filename}")
    print("=" * 40)

    try:
        # Load the pickle file
        with open(file_path, 'rb') as f:
            data = pickle.load(f)

        # Check the type of data
        print(f"Data type: {type(data)}")

        # If it's a DataFrame, show column info
        if isinstance(data, pd.DataFrame):
            print(f"Shape: {data.shape}")
            print(f"Columns ({len(data.columns)}):")
            for col in data.columns:
                print(f"  - {col}")
            print(f"\nData types:")
            print(data.dtypes)
            print(f"\nFirst few rows:")
            print(data.head(3))

        # If it's a dictionary, show keys
        elif isinstance(data, dict):
            print(f"Dictionary with {len(data)} keys:")
            for key in list(data.keys())[:10]:  # Show first 10 keys
                print(f"  - {key}: {type(data[key])}")
            if len(data) > 10:
                print(f"  ... and {len(data) - 10} more keys")

        # If it's a list, show structure
        elif isinstance(data, list):
            print(f"List with {len(data)} items")
            if len(data) > 0:
                print(f"First item type: {type(data[0])}")
                if hasattr(data[0], 'shape'):
                    print(f"First item shape: {data[0].shape}")

        # For other types, show basic info
        else:
            print(f"Data structure: {data}")
            if hasattr(data, 'shape'):
                print(f"Shape: {data.shape}")
            if hasattr(data, '__len__'):
                print(f"Length: {len(data)}")

    except Exception as e:
        print(f"Error loading {filename}: {str(e)}")

    print("-" * 40)

Found 6 pickle files:
--------------------------------------------------

1. File: Reddit_entertainment_original.pkl
Data type: <class 'pandas.core.frame.DataFrame'>
Shape: (5384, 8)
Columns (8):
  - post_title
  - post_body
  - url
  - top_5_comments
  - subreddit
  - category
  - score
  - num_comments

Data types:
post_title        object
post_body         object
url               object
top_5_comments    object
subreddit         object
category          object
score              int64
num_comments       int64
dtype: object

First few rows:
                                          post_title post_body  \
0  David Geffen's Estranged Husband, 32, Requests...             
1  Bono Cheekily Weighs in On Springsteen Vs. Tru...             
2  Jimmy Kimmel on Trump: ‘Celebrated the first t...             

                                                 url  \
0  https://people.com/david-geffen-estranged-husb...   
1  https://www.billboard.com/music/rock/bono-talk...   
2  https://www.th

### Format our data for training

In [None]:
import pickle
import pandas as pd
import os
from sklearn.utils import shuffle
from datasets import Dataset
import ast
import numpy as np

# Set the path to your pickle files
pkl_path = "/content/drive/MyDrive/488-data/large"

# Load and combine all pickle files with 5000 row limit per file
all_dataframes = []
pkl_files = [f for f in os.listdir(pkl_path) if f.endswith('.pkl')]
MAX_ROWS_PER_FILE = 5000

print("Loading pickle files...")
for filename in pkl_files:
    file_path = os.path.join(pkl_path, filename)
    with open(file_path, 'rb') as f:
        df = pickle.load(f)

        # Limit to 5000 rows per file
        original_rows = len(df)
        if len(df) > MAX_ROWS_PER_FILE:
            # Shuffle before taking the first 5000 to get random sampling
            df = shuffle(df, random_state=42).reset_index(drop=True)
            df = df.head(MAX_ROWS_PER_FILE)
            print(f"Loaded {filename}: {original_rows} rows -> limited to {len(df)} rows")
        else:
            print(f"Loaded {filename}: {len(df)} rows (no limit needed)")

        all_dataframes.append(df)

# Combine all dataframes
combined_df = pd.concat(all_dataframes, ignore_index=True)
print(f"\nTotal combined rows: {len(combined_df)}")

# Shuffle the combined data
combined_df = shuffle(combined_df, random_state=42).reset_index(drop=True)
print("Data shuffled successfully")

# Function to clean and format comments (handle string representations of lists)
def format_comments(comments):
    # Handle None values
    if comments is None:
        return "No comments available."

    # If it's already a list or array
    if isinstance(comments, (list, tuple, np.ndarray)):
        try:
            # Convert to list and filter out empty/None values
            comments_list = list(comments)
            clean_comments = []
            for c in comments_list:
                if c is not None and str(c).strip() and str(c).strip().lower() not in ['', 'nan', 'none']:
                    clean_comments.append(str(c).strip())

            if not clean_comments:
                return "No comments available."
            return "\n".join([f"Comment {i+1}: {comment}" for i, comment in enumerate(clean_comments[:5])])
        except Exception:
            return "No comments available."

    # If it's a string
    if isinstance(comments, str):
        # Check if it's a NaN string
        if comments.strip().lower() in ['nan', 'none', '']:
            return "No comments available."

        # Try to parse as list
        try:
            if comments.startswith('[') and comments.endswith(']'):
                comments_list = ast.literal_eval(comments)
                clean_comments = []
                for c in comments_list:
                    if c is not None and str(c).strip() and str(c).strip().lower() not in ['', 'nan', 'none']:
                        clean_comments.append(str(c).strip())

                if not clean_comments:
                    return "No comments available."
                return "\n".join([f"Comment {i+1}: {comment}" for i, comment in enumerate(clean_comments[:5])])
        except Exception:
            pass

        # If it's just a regular string, treat as single comment
        if comments.strip():
            return f"Comment 1: {comments.strip()}"

    # For any other type (including pandas NaN)
    try:
        if pd.isna(comments):
            return "No comments available."
    except Exception:
        pass

    return "No comments available."

# Handle different data structures between files
def standardize_dataframe(df):
    """Standardize dataframes to have consistent columns"""

    # Check if this is the health dataset (has 'text' column instead of post_title/post_body)
    if 'text' in df.columns and 'post_title' not in df.columns:
        # This is the health dataset - convert it to standard format
        standardized_df = pd.DataFrame()

        # Filter only posts (not comments)
        posts_df = df[df['type'] == 'post'].copy() if 'type' in df.columns else df.copy()

        standardized_df['post_title'] = posts_df['text'].str[:100] + '...'  # Use first 100 chars as title
        standardized_df['post_body'] = posts_df['text']
        standardized_df['url'] = posts_df['url'] if 'url' in posts_df.columns else ''
        standardized_df['top_5_comments'] = 'No comments available.'  # Health data doesn't have comment structure
        standardized_df['subreddit'] = posts_df['subreddit'] if 'subreddit' in posts_df.columns else ''
        standardized_df['category'] = posts_df['category'] if 'category' in posts_df.columns else 'Health'
        standardized_df['score'] = posts_df['score'] if 'score' in posts_df.columns else 0
        standardized_df['num_comments'] = 0  # Health data doesn't have this info

        return standardized_df
    else:
        # Standard format - ensure all required columns exist
        required_columns = ['post_title', 'post_body', 'url', 'top_5_comments', 'subreddit', 'category', 'score', 'num_comments']
        for col in required_columns:
            if col not in df.columns:
                df[col] = '' if col in ['post_title', 'post_body', 'url', 'subreddit', 'category'] else 0

        return df[required_columns]

# Standardize all dataframes
print("\nStandardizing dataframe structures...")
standardized_dataframes = []
for i, df in enumerate(all_dataframes):
    filename = pkl_files[i]
    print(f"Standardizing {filename}...")
    standardized_df = standardize_dataframe(df)
    standardized_dataframes.append(standardized_df)

# Combine standardized dataframes
combined_df = pd.concat(standardized_dataframes, ignore_index=True)
print(f"Total standardized rows: {len(combined_df)}")

# Shuffle the combined data again
combined_df = shuffle(combined_df, random_state=42).reset_index(drop=True)
print("Final data shuffled successfully")

# Create the formatted dataset
def create_alpaca_format(df):
    instructions = []
    inputs = []
    outputs = []

    for idx, row in df.iterrows():
        # Instruction (consistent task description)
        instruction = "Classify the following Reddit post into one of these categories: Comedy, Education, Entertainment, Health, Professional, or Travel. Base your classification on the post title, content, and top comments."

        # Input (post data)
        post_title = str(row['post_title']).strip() if pd.notna(row['post_title']) else "No title"
        post_body = str(row['post_body']).strip() if pd.notna(row['post_body']) else "No content"
        comments = format_comments(row['top_5_comments'])

        input_text = f"""Post Title: {post_title}

Post Content: {post_body}

Top Comments:
{comments}"""

        # Output (category)
        output = str(row['category']).strip() if pd.notna(row['category']) else "Unknown"

        # Standardize category names
        output = output.title()  # Convert to title case
        # Keep Entertainment as separate category - don't map to Comedy!

        instructions.append(instruction)
        inputs.append(input_text)
        outputs.append(output)

    return {
        'instruction': instructions,
        'input': inputs,
        'output': outputs
    }


# Create the formatted data
print("\nFormatting data for Alpaca structure...")
formatted_data = create_alpaca_format(combined_df)

# Create Hugging Face dataset
dataset = Dataset.from_dict(formatted_data)
print(f"Created dataset with {len(dataset)} examples")

# Display some statistics
print(f"\nCategory distribution:")
category_counts = combined_df['category'].value_counts()
for category, count in category_counts.items():
    print(f"  {category}: {count} ({count/len(combined_df)*100:.1f}%)")

# Show a sample
print(f"\n" + "="*80)
print("SAMPLE FORMATTED EXAMPLE:")
print("="*80)
print(f"Instruction: {formatted_data['instruction'][0]}")
print(f"\nInput: {formatted_data['input'][0][:500]}...")
print(f"\nOutput: {formatted_data['output'][0]}")

# Alpaca prompt template and formatting function
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

# Define EOS_TOKEN from the tokenizer loaded earlier
EOS_TOKEN = tokenizer.eos_token  # For the base Llama 3.1 model loaded above this should be "<|end_of_text|>"; Instruct variants use "<|eot_id|>"

def formatting_prompts_func(examples):
    instructions = examples["instruction"]
    inputs       = examples["input"]
    outputs      = examples["output"]
    texts = []
    for instruction, input_text, output in zip(instructions, inputs, outputs):
        # Must add EOS_TOKEN, otherwise your generation will go on forever!
        text = alpaca_prompt.format(instruction, input_text, output) + EOS_TOKEN
        texts.append(text)
    return { "text" : texts, }

# Apply formatting
formatted_dataset = dataset.map(formatting_prompts_func, batched=True)

print(f"\nDataset ready for training!")
print(f"Final dataset size: {len(formatted_dataset)} examples")
print("\nSummary of data sources:")
for filename in pkl_files:
    print(f"  - {filename}: up to {MAX_ROWS_PER_FILE} rows")

print("\nNext steps:")
print("1. Add your tokenizer and uncomment EOS_TOKEN line")
print("2. Split into train/validation sets if needed")
print("3. Use formatted_dataset for training")

# Optional: Save the formatted dataset
# formatted_dataset.save_to_disk("/content/drive/MyDrive/reddit_classification_dataset")
# print("Dataset saved to disk!")

# Optional: Create train/validation split
# train_test_split = formatted_dataset.train_test_split(test_size=0.1, seed=42)
# train_dataset = train_test_split['train']
# eval_dataset = train_test_split['test']
# print(f"\nTrain dataset: {len(train_dataset)} examples")
# print(f"Validation dataset: {len(eval_dataset)} examples")

Loading pickle files...
Loaded Reddit_entertainment_original.pkl: 5384 rows -> limited to 5000 rows
Loaded Reddit_travel_original.pkl: 4616 rows (no limit needed)
Loaded Reddit_comedy_original.pkl: 4986 rows (no limit needed)
Loaded Reddit_education_original.pkl: 4714 rows (no limit needed)
Loaded Reddit_professional_original.pkl: 11902 rows -> limited to 5000 rows
Loaded Reddit_health_original.pkl: 1819 rows (no limit needed)

Total combined rows: 26135
Data shuffled successfully

Standardizing dataframe structures...
Standardizing Reddit_entertainment_original.pkl...
Standardizing Reddit_travel_original.pkl...
Standardizing Reddit_comedy_original.pkl...
Standardizing Reddit_education_original.pkl...
Standardizing Reddit_professional_original.pkl...
Standardizing Reddit_health_original.pkl...
Total standardized rows: 24804
Final data shuffled successfully

Formatting data for Alpaca structure...
Created dataset with 24804 examples

Category distribution:
  Entertainment: 5000 (20.2%)




Dataset ready for training!
Final dataset size: 24804 examples

Summary of data sources:
  - Reddit_entertainment_original.pkl: up to 5000 rows
  - Reddit_travel_original.pkl: up to 5000 rows
  - Reddit_comedy_original.pkl: up to 5000 rows
  - Reddit_education_original.pkl: up to 5000 rows
  - Reddit_professional_original.pkl: up to 5000 rows
  - Reddit_health_original.pkl: up to 5000 rows

Next steps:
1. Add your tokenizer and uncomment EOS_TOKEN line
2. Split into train/validation sets if needed
3. Use formatted_dataset for training


In [None]:
# Split the dataset into training, evaluation, and test sets
train_testvalid = formatted_dataset.train_test_split(test_size=0.2, seed=42) # 20% for test+validation
test_valid = train_testvalid['test'].train_test_split(test_size=0.5, seed=42) # Split test+validation 50/50

train_dataset = train_testvalid['train']
eval_dataset = test_valid['train']
test_dataset = test_valid['test']

print(f"Training dataset size: {len(train_dataset)}")
print(f"Evaluation dataset size: {len(eval_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

Training dataset size: 19843
Evaluation dataset size: 2480
Test dataset size: 2481
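The two chained splits give roughly 80/10/10. The sketch below reproduces the printed sizes with plain arithmetic, assuming the test portion is rounded up when the fraction does not divide evenly; that assumption matches the numbers above but is not taken from the library's documentation, and `split_sizes` is a hypothetical helper.

```python
import math

def split_sizes(n_examples: int, test_fraction: float):
    """Return (train_size, test_size), rounding the test portion up (assumed behaviour)."""
    n_test = math.ceil(n_examples * test_fraction)
    return n_examples - n_test, n_test

n_train, n_holdout = split_sizes(24804, 0.2)   # first split: train vs. held-out
n_eval, n_test = split_sizes(n_holdout, 0.5)   # second split: eval vs. test
print(n_train, n_eval, n_test)
```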


### Evaluate Data

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
import pandas as pd

def analyze_reddit_dataset(train_dataset, eval_dataset, test_dataset, tokenizer):
    """Comprehensive analysis of the Reddit classification dataset"""

    print("🔍 REDDIT DATASET ANALYSIS")
    print("=" * 50)

    # Basic dataset info
    print(f"📊 Dataset Sizes:")
    print(f"• Training: {len(train_dataset):,} examples")
    print(f"• Validation: {len(eval_dataset):,} examples")
    print(f"• Test: {len(test_dataset):,} examples")
    print(f"• Total: {len(train_dataset) + len(eval_dataset) + len(test_dataset):,} examples")

    # Analyze text lengths
    print(f"\n📏 TEXT LENGTH ANALYSIS:")

    def get_text_stats(dataset, name):
        texts = [example['text'] for example in dataset]
        char_lengths = [len(text) for text in texts]

        # Tokenize a sample to estimate token lengths
        sample_size = min(1000, len(texts))
        sample_texts = texts[:sample_size]
        token_lengths = []

        for text in sample_texts:
            try:
                tokens = tokenizer(text, truncation=False, add_special_tokens=True)
                token_lengths.append(len(tokens['input_ids']))
            except Exception:
                # Fallback estimation if tokenizer fails
                token_lengths.append(len(text.split()) * 1.3)  # Rough estimate

        # Extrapolate token stats
        avg_tokens = np.mean(token_lengths)

        print(f"\n{name} Dataset:")
        print(f"  Character lengths:")
        print(f"    • Mean: {np.mean(char_lengths):.0f} chars")
        print(f"    • Median: {np.median(char_lengths):.0f} chars")
        print(f"    • Min: {np.min(char_lengths):.0f} chars")
        print(f"    • Max: {np.max(char_lengths):.0f} chars")
        print(f"    • 95th percentile: {np.percentile(char_lengths, 95):.0f} chars")

        print(f"  Estimated token lengths (from {sample_size} samples):")
        print(f"    • Mean: {avg_tokens:.0f} tokens")
        print(f"    • Median: {np.median(token_lengths):.0f} tokens")
        print(f"    • Min: {np.min(token_lengths):.0f} tokens")
        print(f"    • Max: {np.max(token_lengths):.0f} tokens")
        print(f"    • 95th percentile: {np.percentile(token_lengths, 95):.0f} tokens")

        # Check how many exceed common sequence lengths
        over_512 = sum(1 for t in token_lengths if t > 512)
        over_1024 = sum(1 for t in token_lengths if t > 1024)
        over_2048 = sum(1 for t in token_lengths if t > 2048)

        print(f"  Sequence length distribution:")
        print(f"    • >512 tokens: {over_512}/{len(token_lengths)} ({over_512/len(token_lengths)*100:.1f}%)")
        print(f"    • >1024 tokens: {over_1024}/{len(token_lengths)} ({over_1024/len(token_lengths)*100:.1f}%)")
        print(f"    • >2048 tokens: {over_2048}/{len(token_lengths)} ({over_2048/len(token_lengths)*100:.1f}%)")

        return char_lengths, token_lengths, avg_tokens

    # Analyze each split
    train_chars, train_tokens, train_avg_tokens = get_text_stats(train_dataset, "Training")
    eval_chars, eval_tokens, eval_avg_tokens = get_text_stats(eval_dataset, "Validation")

    # Category distribution analysis
    print(f"\n📈 CATEGORY DISTRIBUTION:")

    def analyze_categories(dataset, name):
        # Extract categories from the formatted text
        categories = []
        for example in dataset:
            text = example['text']
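            # Extract the response/category and drop the EOS token appended during formatting
            if "### Response:" in text:
                response = text.split("### Response:")[-1]
                if tokenizer.eos_token:
                    response = response.replace(tokenizer.eos_token, "")
                categories.append(response.strip())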

        category_counts = Counter(categories)
        total = len(categories)

        print(f"\n{name} Categories:")
        for category, count in category_counts.most_common():
            percentage = (count / total) * 100
            print(f"  • {category}: {count:,} ({percentage:.1f}%)")

        return category_counts

    train_categories = analyze_categories(train_dataset, "Training")
    eval_categories = analyze_categories(eval_dataset, "Validation")

    # Memory and speed implications
    print(f"\n⚡ PERFORMANCE IMPLICATIONS:")

    avg_tokens_all = (train_avg_tokens + eval_avg_tokens) / 2

    print(f"• Average tokens per example: {avg_tokens_all:.0f}")
    print(f"• Current max_seq_length: 2048")
    print(f"• Padding waste: ~{((2048 - avg_tokens_all) / 2048) * 100:.1f}% per example")

    # Calculate memory usage estimates
    batch_sizes = [8, 16, 32, 64, 128]
    seq_lengths = [512, 1024, 2048]

    print(f"\n💾 MEMORY USAGE ESTIMATES (4-bit model):")
    print(f"Batch Size | 512 tokens | 1024 tokens | 2048 tokens")
    print(f"-----------|------------|-------------|------------")

    for bs in batch_sizes:
        mem_512 = bs * 512 * 4 / (1024**3) * 8  # Placeholder: counts only 32 bytes per token, so every cell prints 0.0 GB; not a real VRAM estimate
        mem_1024 = bs * 1024 * 4 / (1024**3) * 8
        mem_2048 = bs * 2048 * 4 / (1024**3) * 8
        print(f"    {bs:2d}     |   {mem_512:.1f} GB    |    {mem_1024:.1f} GB    |    {mem_2048:.1f} GB")

    # Speed optimization recommendations
    print(f"\n🚀 OPTIMIZATION RECOMMENDATIONS:")

    if avg_tokens_all < 512:
        print(f"✅ MAJOR SPEEDUP AVAILABLE:")
        print(f"   • Most examples fit in 512 tokens")
        print(f"   • Reduce max_seq_length to 512 for 4x speedup")
        print(f"   • Can increase batch size significantly")
    elif avg_tokens_all < 1024:
        print(f"⚡ GOOD SPEEDUP AVAILABLE:")
        print(f"   • Most examples fit in 1024 tokens")
        print(f"   • Reduce max_seq_length to 1024 for 2x speedup")
        print(f"   • Can increase batch size moderately")
    else:
        print(f"⚠️  LONG SEQUENCES:")
        print(f"   • Many examples need >1024 tokens")
        print(f"   • Consider text truncation or A100 upgrade")

    # Dataset quality insights
    print(f"\n📋 DATASET QUALITY INSIGHTS:")

    # Check for class imbalance
    category_counts = list(train_categories.values())
    max_count = max(category_counts)
    min_count = min(category_counts)
    imbalance_ratio = max_count / min_count

    print(f"• Class imbalance ratio: {imbalance_ratio:.1f}:1")
    if imbalance_ratio > 3:
        print(f"  ⚠️  Significant class imbalance detected")
        print(f"  💡 Consider class weights or balanced sampling")
    else:
        print(f"  ✅ Good class balance")

    return {
        'avg_tokens': avg_tokens_all,
        'train_categories': train_categories,
        'eval_categories': eval_categories,
        'imbalance_ratio': imbalance_ratio
    }

# Run the analysis
print("Starting dataset analysis...")
analysis_results = analyze_reddit_dataset(train_dataset, eval_dataset, test_dataset, tokenizer)

# Additional quick tokenization test
print(f"\n🧪 TOKENIZATION SPEED TEST:")
sample_texts = [train_dataset[i]['text'] for i in range(min(10, len(train_dataset)))]

import time
start_time = time.time()
for text in sample_texts:
    tokens = tokenizer(text, max_length=2048, truncation=True, padding='max_length')
tokenization_time = time.time() - start_time

print(f"• Tokenized {len(sample_texts)} examples in {tokenization_time:.3f}s")
print(f"• Average: {tokenization_time/len(sample_texts)*1000:.1f}ms per example")

if tokenization_time/len(sample_texts) > 0.1:
    print(f"  ⚠️  Slow tokenization detected (>{0.1*1000:.0f}ms per example)")
    print(f"  💡 This could be contributing to slow training")

Starting dataset analysis...
🔍 REDDIT DATASET ANALYSIS
📊 Dataset Sizes:
• Training: 19,843 examples
• Validation: 2,480 examples
• Test: 2,481 examples
• Total: 24,804 examples

📏 TEXT LENGTH ANALYSIS:

Training Dataset:
  Character lengths:
    • Mean: 1748 chars
    • Median: 1195 chars
    • Min: 480 chars
    • Max: 41189 chars
    • 95th percentile: 4590 chars
  Estimated token lengths (from 1000 samples):
    • Mean: 384 tokens
    • Median: 259 tokens
    • Min: 93 tokens
    • Max: 4398 tokens
    • 95th percentile: 1044 tokens
  Sequence length distribution:
    • >512 tokens: 223/1000 (22.3%)
    • >1024 tokens: 51/1000 (5.1%)
    • >2048 tokens: 5/1000 (0.5%)

Validation Dataset:
  Character lengths:
    • Mean: 1744 chars
    • Median: 1245 chars
    • Min: 490 chars
    • Max: 23465 chars
    • 95th percentile: 4373 chars
  Estimated token lengths (from 1000 samples):
    • Mean: 383 tokens
    • Median: 271 tokens
    • Min: 94 tokens
    • Max: 5714 tokens
    • 95th per
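The recommendation logic in the analysis keys off the mean token count, but the sampled training output above shows 22.3% of examples exceeding 512 tokens and 5.1% exceeding 1024. A coverage-based check may be a more informative way to pick `max_seq_length`. The sketch below is one way to do that; the helper names and the toy lengths are made up for illustration.

```python
def share_within(token_lengths, limit):
    """Fraction of examples whose token count fits within `limit`."""
    return sum(1 for n in token_lengths if n <= limit) / len(token_lengths)

def smallest_covering_length(token_lengths, candidates, target=0.95):
    """Smallest candidate limit that fits at least `target` of the examples, else None."""
    for limit in sorted(candidates):
        if share_within(token_lengths, limit) >= target:
            return limit
    return None

# Made-up sample of token counts, not taken from the dataset.
toy_lengths = [93, 180, 259, 300, 410, 600, 900, 1044, 1500, 4398]
print(smallest_covering_length(toy_lengths, [512, 1024, 2048], target=0.9))
```

Examples longer than the chosen limit would still be truncated, so the target share is a trade-off between speed and lost context.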

### Optional: download the dataset

In [None]:
import pandas as pd
from google.colab import files

# Convert datasets to pandas DataFrames and save as CSV
def download_datasets():
    """
    Convert and download train, eval, and test datasets as CSV files
    """

    print("Converting datasets to CSV format...")

    # Convert train dataset
    print(f"Processing train dataset ({len(train_dataset)} examples)...")
    train_df = pd.DataFrame(train_dataset)
    train_df.to_csv('/content/reddit_classification_train.csv', index=False)
    print("✅ Train dataset saved as 'reddit_classification_train.csv'")

    # Convert eval dataset
    print(f"Processing eval dataset ({len(eval_dataset)} examples)...")
    eval_df = pd.DataFrame(eval_dataset)
    eval_df.to_csv('/content/reddit_classification_eval.csv', index=False)
    print("✅ Eval dataset saved as 'reddit_classification_eval.csv'")

    # Convert test dataset
    print(f"Processing test dataset ({len(test_dataset)} examples)...")
    test_df = pd.DataFrame(test_dataset)
    test_df.to_csv('/content/reddit_classification_test.csv', index=False)
    print("✅ Test dataset saved as 'reddit_classification_test.csv'")

    # Show dataset info
    print("\n📊 DATASET SUMMARY:")
    print("-" * 50)
    print(f"Train set: {len(train_dataset):,} examples")
    print(f"Eval set:  {len(eval_dataset):,} examples")
    print(f"Test set:  {len(test_dataset):,} examples")
    print(f"Total:     {len(train_dataset) + len(eval_dataset) + len(test_dataset):,} examples")

    # Show column structure
    print(f"\nColumns in each CSV:")
    for col in train_df.columns:
        print(f"  - {col}")

    print("\n🔽 DOWNLOADING FILES...")
    print("Files will download to your local machine:")

    # Download the files, reporting success or failure for each split
    for split_name in ["train", "eval", "test"]:
        try:
            files.download(f'/content/reddit_classification_{split_name}.csv')
            print(f"✅ {split_name.capitalize()} dataset downloaded")
        except Exception:
            print(f"❌ {split_name.capitalize()} dataset download failed")

    print("\n🎯 DATASET READY FOR:")
    print("✅ Further model training")
    print("✅ Sharing with team members")
    print("✅ Production deployment")
    print("✅ Academic research")
    print("✅ Model reproducibility")

    return train_df, eval_df, test_df

# Run the download
train_df, eval_df, test_df = download_datasets()

# Optional: Show a sample of each dataset
print("\n📋 SAMPLE DATA PREVIEW:")
print("=" * 60)
print("TRAIN DATASET SAMPLE:")
print(train_df[['instruction', 'output']].head(2))
print("\nEVAL DATASET SAMPLE:")
print(eval_df[['instruction', 'output']].head(2))
print("\nTEST DATASET SAMPLE:")
print(test_df[['instruction', 'output']].head(2))

Converting datasets to CSV format...
Processing train dataset (19843 examples)...
✅ Train dataset saved as 'reddit_classification_train.csv'
Processing eval dataset (2480 examples)...
✅ Eval dataset saved as 'reddit_classification_eval.csv'
Processing test dataset (2481 examples)...
✅ Test dataset saved as 'reddit_classification_test.csv'

📊 DATASET SUMMARY:
--------------------------------------------------
Train set: 19,843 examples
Eval set:  2,480 examples
Test set:  2,481 examples
Total:     24,804 examples

Columns in each CSV:
  - instruction
  - input
  - output
  - text

🔽 DOWNLOADING FILES...
Files will download to your local machine:



✅ Train dataset downloaded



✅ Eval dataset downloaded



✅ Test dataset downloaded

🎯 DATASET READY FOR:
✅ Further model training
✅ Sharing with team members
✅ Production deployment
✅ Academic research
✅ Model reproducibility

📋 SAMPLE DATA PREVIEW:
TRAIN DATASET SAMPLE:
                                         instruction  output
0  Classify the following Reddit post into one of...  Travel
1  Classify the following Reddit post into one of...  Travel

EVAL DATASET SAMPLE:
                                         instruction        output
0  Classify the following Reddit post into one of...        Travel
1  Classify the following Reddit post into one of...  Professional

TEST DATASET SAMPLE:
                                         instruction  output
0  Classify the following Reddit post into one of...  Travel
1  Classify the following Reddit post into one of...  Comedy


<a name="Train"></a>
### Train the model
Now let's use Hugging Face TRL's `SFTTrainer`! More docs here: [TRL SFT docs](https://huggingface.co/docs/trl/sft_trainer). The configuration below trains for 3 full epochs (`num_train_epochs = 3`), evaluating and checkpointing 6 times per epoch. For a quick smoke test, set `max_steps = 60` instead. We also support TRL's `DPOTrainer`!
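As a rough sanity check on training length, the step totals that the Unsloth banners report in this notebook can be reproduced by hand. The sketch below assumes the Trainer builds `ceil(N / per_device_batch)` batches per epoch and floor-divides that by the accumulation steps. This formula is an assumption, chosen because it reproduces both banner totals recorded in this notebook. `total_optimizer_steps` is a hypothetical helper, not part of TRL or Unsloth.

```python
import math

def total_optimizer_steps(num_examples, per_device_batch, grad_accum, epochs):
    """Estimate optimizer steps for a single-GPU run (assumed formula)."""
    # The final partial batch still counts as a batch
    batches_per_epoch = math.ceil(num_examples / per_device_batch)
    # One optimizer update per `grad_accum` batches
    updates_per_epoch = max(batches_per_epoch // grad_accum, 1)
    return updates_per_epoch * epochs

steps = total_optimizer_steps(19_843, 12, 1, 3)
print(steps)                 # total optimizer steps for this run's settings
print(f"{100 / steps:.1%}")  # share of training covered by 100 warmup steps
```

With these inputs the helper returns 4,962, the figure in the Unsloth banner, whereas the floor-division estimate in the training cell gives 4,959 because it drops the last partial batch of each epoch.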

In [None]:
import wandb
import torch
from trl import SFTTrainer
from transformers import TrainingArguments, TrainerCallback, TrainerState, TrainerControl
from unsloth import is_bfloat16_supported

# Calculate dataset sizes dynamically
train_size = len(train_dataset)
eval_size = len(eval_dataset)
test_size = len(test_dataset) if 'test_dataset' in locals() else 0

# Hyperparameters tuned for an L4 GPU with Unsloth 4-bit quantization
batch_size = 12  # Large per-device batch; the 4-bit base model leaves VRAM headroom
grad_accum = 1   # No accumulation needed with this batch size
effective_batch_size = batch_size * grad_accum
# Floor division ignores the final partial batch, so the Trainer's own step count is slightly higher
steps_per_epoch = train_size // effective_batch_size
total_epochs = 3

print(f"Training dataset size: {train_size}")
print(f"Evaluation dataset size: {eval_size}")
if test_size > 0:
    print(f"Test dataset size: {test_size}")
print(f"Effective batch size: {effective_batch_size}")
print(f"Steps per epoch: {steps_per_epoch}")
print(f"Total training steps: {steps_per_epoch * total_epochs}")
print(f"Eval every {steps_per_epoch // 6} steps (6x per epoch)")

# Initialize Weights & Biases
wandb.init(
    project="reddit-classification-sft",
    name=f"unsloth-llama31-8b-bs{effective_batch_size}-lr5e5-ep{total_epochs}",
    config={
        "model_name": "unsloth/Meta-Llama-3.1-8B",
        "quantization": "4-bit",
        "framework": "unsloth",
        "dataset": "reddit-posts",
        "task": "text-classification",
        "categories": ["Comedy", "Entertainment", "Education", "Health", "Professional", "Travel"],
        "train_samples": train_size,
        "eval_samples": eval_size,
        "test_samples": test_size,
        "max_seq_length": 2048,  # From your config
        "architecture": "SFT-LoRA-4bit",
        "effective_batch_size": effective_batch_size,
        "total_epochs": total_epochs,
        "total_steps": steps_per_epoch * total_epochs,
    },
    tags=["sft", "llama-3.1", "unsloth", "4-bit", "reddit", "classification"]
)

# Add custom metrics tracking callback
class WandBCallback(TrainerCallback):
    """Custom callback to track additional metrics in W&B"""
    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Event called just before the training loop starts.
        """
        pass

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Event called at the beginning of a training step.
        """
        pass

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Event called at the end of a training step.
        """
        pass

    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Event called at the beginning of an epoch.
        """
        pass

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Event called at the end of an epoch.
        """
        pass

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model=None, logs=None, **kwargs):
        """
        Event called after logging is performed.
        """
        if logs:
            # Log learning rate
            if "learning_rate" in logs:
                wandb.log({"learning_rate": logs["learning_rate"]}, step=state.global_step)

            # Log training metrics
            if "loss" in logs: # Use "loss" which is the actual training loss key in TRL logs
                wandb.log({"train_loss": logs["loss"]}, step=state.global_step)

            # Log evaluation loss together with the perplexity derived from it
            if "eval_loss" in logs:
                perplexity = torch.exp(torch.tensor(logs["eval_loss"]))
                wandb.log(
                    {"eval_loss": logs["eval_loss"], "eval_perplexity": perplexity.item()},
                    step=state.global_step,
                )

            # Log other metrics
            # Filter out keys that are already handled or not relevant
            ignored_keys = ["learning_rate", "loss", "eval_loss", "epoch"]
            for key, value in logs.items():
                if key not in ignored_keys and isinstance(value, (int, float)):
                    wandb.log({key: value}, step=state.global_step)

    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Event called just at the end of training.
        """
        pass


trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_dataset,
    eval_dataset = eval_dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False,
    args = TrainingArguments(
        # Batch size and accumulation reuse the variables defined above,
        # so the printed summary and the actual run cannot drift apart
        per_device_train_batch_size = batch_size,
        gradient_accumulation_steps = grad_accum,

        # Number of full passes over the training set
        num_train_epochs = total_epochs,

        # Learning rate - keep as is, it's good
        learning_rate = 5e-5,

        # Warmup - 100 steps is roughly 2% of the ~4,960 total steps
        warmup_steps = 100,
        # Alternative: warmup_ratio = 0.02,

        # Scheduler
        lr_scheduler_type = "cosine",

        # Precision - use bf16 where the GPU supports it (e.g. L4), otherwise fall back to fp16
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),

        # UPDATED: More frequent evaluation due to fewer total steps
        logging_steps = 50,
        eval_strategy = "steps",
        eval_steps = steps_per_epoch // 6,     # 6x per epoch for close monitoring

        # UPDATED: Save checkpoints at evaluation points (must be multiple of eval_steps)
        save_strategy = "steps",
        save_steps = steps_per_epoch // 6,  # Ensure save_steps is a multiple of eval_steps

        # Regularization
        weight_decay = 0.01,  # Reduced since we're training for fewer epochs

        # Optimizer - use standard AdamW with 4-bit quantized model
        optim = "adamw_torch",  # Standard AdamW works great with Unsloth
        # Note: Unsloth handles memory optimization internally

        seed = 3407,
        output_dir = "outputs",

        # W&B Integration
        report_to = "wandb",
        run_name = f"reddit-sft-bs{effective_batch_size}-lr5e5-ep{total_epochs}",

        # Performance optimizations for L4
        dataloader_num_workers = 4,        # Parallel data loading
        dataloader_pin_memory = True,      # Faster CPU->GPU transfer
        group_by_length = True,            # More efficient batching

        # Advanced settings for L4
        tf32 = True,                       # Enable TF32 matmuls (needs Ampere or newer; the L4 is Ada Lovelace)
        dataloader_persistent_workers = True,  # Keep workers alive between epochs

        logging_first_step = True,
        load_best_model_at_end = True,
        metric_for_best_model = "eval_loss",
        greater_is_better = False,
        save_total_limit = 2,  # Keep only 2 checkpoints since training is shorter

        # Early stopping (optional): early_stopping_patience is not a TrainingArguments field.
        # Use transformers.EarlyStoppingCallback(early_stopping_patience=3) via trainer.add_callback(...)
    ),
)

# Add the custom callback
trainer.add_callback(WandBCallback())

print("🚀 UNSLOTH + L4 GPU OPTIMIZED Configuration:")
print(f"• Model: LLaMA-3.1-8B (4-bit quantized)")
print(f"• VRAM usage: ~8-12GB (4-bit quantization is very efficient!)")
print(f"• Batch size: {batch_size} (no gradient accumulation needed)")
print(f"• Effective batch size: {effective_batch_size}")
print(f"• Training epochs: {total_epochs}")
print(f"• Total steps: {steps_per_epoch * total_epochs}")
print(f"• Steps per epoch: {steps_per_epoch}")
print(f"• Eval frequency: Every {steps_per_epoch // 6} steps (6x per epoch)")
print(f"• Expected training time: ~1-2 hours (Unsloth is FAST!)")
print(f"• Speed boost: ~3-4x faster than standard fine-tuning")

# Start training
trainer.train()

# Final logging
wandb.log({
    "training_completed": True,
    "final_epoch": trainer.state.epoch,
    "total_steps_completed": trainer.state.global_step
})

print("Training completed!")
print(f"View results at: {wandb.run.url}")
wandb.finish()

Training dataset size: 19843
Evaluation dataset size: 2480
Test dataset size: 2481
Effective batch size: 12
Steps per epoch: 1653
Total training steps: 4959
Eval every 275 steps (6x per epoch)


Unsloth: Tokenizing ["text"]:   0%|          | 0/19843 [00:00<?, ? examples/s]

Unsloth: Tokenizing ["text"]:   0%|          | 0/2480 [00:00<?, ? examples/s]

🚀 UNSLOTH + L4 GPU OPTIMIZED Configuration:
• Model: LLaMA-3.1-8B (4-bit quantized)
• VRAM usage: ~8-12GB (4-bit quantization is very efficient!)
• Batch size: 12 (no gradient accumulation needed)
• Effective batch size: 12
• Training epochs: 3
• Total steps: 4959
• Steps per epoch: 1653
• Eval frequency: Every 275 steps (6x per epoch)
• Expected training time: ~1-2 hours (Unsloth is FAST!)
• Speed boost: ~3-4x faster than standard fine-tuning
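The configuration above pairs 100 warmup steps with a cosine scheduler. A minimal sketch of that learning-rate shape follows, assuming linear warmup and then a half-cosine decay to zero (assumed here to match the default behaviour of the `cosine` scheduler). `lr_at_step` is a hypothetical helper, and the total of 4,962 steps is taken from the Unsloth banner.

```python
import math

def lr_at_step(step, base_lr=5e-5, warmup_steps=100, total_steps=4962):
    """Linear warmup to base_lr, then half-cosine decay towards zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    # Fraction of the post-warmup phase completed, from 0.0 to 1.0
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

# Start, mid-warmup, peak, halfway through the decay, and final step
for step in (0, 50, 100, 2531, 4962):
    print(step, f"{lr_at_step(step):.2e}")
```

The learning rate peaks at step 100 and has fallen to half its peak by the midpoint of the decay phase, which is useful context when reading the loss curve.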


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 19,843 | Num Epochs = 3 | Total steps = 4,962
O^O/ \_/ \    Batch size per device = 12 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (12 x 1 x 1) = 12
 "-____-"     Trainable parameters = 167,772,160/8,000,000,000 (2.10% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss,Validation Loss
275,1.6211,1.633559
550,1.583,1.618192
825,1.5957,1.605017
1100,1.5593,1.599717
1375,1.5412,1.595331
1650,1.5705,1.59189
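The callback logs `eval_perplexity` as the exponential of the evaluation loss. As a quick check, the same conversion can be applied by hand to the validation losses in the table above. `perplexity` is a hypothetical helper that uses only the standard library.

```python
import math

def perplexity(mean_cross_entropy):
    """Perplexity is exp of the mean token-level cross-entropy (in nats)."""
    return math.exp(mean_cross_entropy)

# Validation losses copied from the first and last rows of the table above
val_losses = {275: 1.633559, 1650: 1.59189}
for step, loss in val_losses.items():
    print(f"step {step}: perplexity ~ {perplexity(loss):.2f}")
```

Because the mapping is exponential, small drops in validation loss translate into proportionally small drops in perplexity at this loss level.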




KeyboardInterrupt: 

In [None]:
# @title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = NVIDIA L4. Max memory = 22.161 GB.
7.654 GB of memory reserved.
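To put the reserved figure in context, it can be expressed as a share of total GPU memory. A minimal sketch, with the two numbers copied from the output above; `bytes_to_gib` and `percent_of` are hypothetical helpers.

```python
def bytes_to_gib(num_bytes):
    """Convert a byte count to GiB, rounded to 3 decimals as in the cell above."""
    return round(num_bytes / 1024 / 1024 / 1024, 3)

def percent_of(part, whole):
    """Return `part` as a percentage of `whole`, rounded to one decimal."""
    return round(part / whole * 100, 1)

reserved_gb, total_gb = 7.654, 22.161  # values printed by the cell above
print(f"{percent_of(reserved_gb, total_gb)}% of GPU memory reserved")
```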


In [None]:
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 2,129 | Num Epochs = 20 | Total steps = 5,320
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 167,772,160/8,000,000,000 (2.10% trained)


Step,Training Loss,Validation Loss
133,1.9482,1.83477
266,1.9608,1.823274
399,1.963,1.818654
532,1.9388,1.816472
665,1.8422,1.83374
798,1.8646,1.831825
931,1.7455,1.883374
1064,1.7212,1.870746
1197,1.6056,1.928957
1330,1.6384,1.931466


Unsloth: Not an error, but LlamaForCausalLM does not accept `num_items_in_batch`.
Using gradient accumulation will be very slightly less accurate.
Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient


### Evaluate Model
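
The evaluation cell in this section recovers the gold label by taking the first word after `### Response:` in each Alpaca-formatted example. That parsing step can be sketched on its own as follows. `extract_label` is a hypothetical helper, and the sample text is made up for illustration.

```python
def extract_label(text, marker="### Response:", eos="<|end_of_text|>"):
    """Return the first word after the response marker, or "Unknown"."""
    if marker not in text:
        return "Unknown"
    # Keep everything after the first marker, then drop the end-of-text token
    response = text.split(marker, 1)[1].replace(eos, "").strip()
    return response.split()[0] if response else "Unknown"

# Made-up example in the same Alpaca layout as the training data
sample = (
    "### Instruction:\nClassify the post.\n\n"
    "### Input:\nCheap flights to Lisbon?\n\n"
    "### Response:\nTravel<|end_of_text|>"
)
print(extract_label(sample))
print(extract_label("no marker here"))
```

Factoring the parsing into one function like this would also avoid repeating the same extraction logic in both the label-counting pass and the prediction pass.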

In [None]:
import torch
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import Counter

def evaluate_classification_model(model, tokenizer, eval_dataset, device='cuda'):
    """
    Comprehensive evaluation for multi-class classification
    Returns accuracy, per-class metrics, confusion matrix
    """
    model.eval()

    predictions = []
    true_labels = []

    print("Running classification evaluation...")

    # First pass: collect all actual labels to understand the data
    for i, example in enumerate(tqdm(eval_dataset, desc="Extracting labels")):
        text = example['text']

        # Extract true label from the alpaca format
        # Look for "### Response:" followed by the category
        if "### Response:" in text:
            response_start = text.find("### Response:") + len("### Response:")
            response_text = text[response_start:].strip()

            # Handle different possible formats
            if response_text:
                # Remove any end tokens and get the first word
                true_label = response_text.replace('<|end_of_text|>', '').strip().split()[0]
                true_labels.append(true_label)
            else:
                true_labels.append("Unknown")
        else:
            true_labels.append("Unknown")

    # Analyze the actual labels in your data
    print("\n📊 ANALYZING ACTUAL LABELS IN DATASET:")
    label_counts = Counter(true_labels)
    print("Found labels:", label_counts)

    # Get the actual class names from your data
    class_names = [label for label, count in label_counts.most_common() if label != "Unknown"]
    print(f"Using class names: {class_names}")

    # Reset for prediction pass
    predictions = []
    true_labels = []

    print(f"\nRunning predictions on {len(eval_dataset)} examples...")

    for i, example in enumerate(tqdm(eval_dataset, desc="Making predictions")):
        text = example['text']

        # Extract true label (same as above)
        if "### Response:" in text:
            response_start = text.find("### Response:") + len("### Response:")
            response_text = text[response_start:].strip()
            if response_text:
                true_label = response_text.replace('<|end_of_text|>', '').strip().split()[0]
            else:
                true_label = "Unknown"
        else:
            true_label = "Unknown"

        # Extract instruction and input for prediction
        try:
            instruction_start = text.find("### Instruction:") + len("### Instruction:")
            instruction_end = text.find("### Input:")
            instruction = text[instruction_start:instruction_end].strip()

            input_start = text.find("### Input:") + len("### Input:")
            input_end = text.find("### Response:")
            input_text = text[input_start:input_end].strip()

            # Format for prediction (without the response)
            prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input_text}

### Response:
"""

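            # Tokenize and predict. Truncating to 512 tokens would cut the trailing
            # "### Response:" marker off long posts, so allow the full training context.
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_seq_length)
            inputs = {k: v.to(device) for k, v in inputs.items()}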

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=5,   # Only need a few tokens for the category
                    do_sample=False,    # Greedy decoding; temperature has no effect when sampling is off
                    pad_token_id=tokenizer.eos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )

            # Decode prediction
            response = tokenizer.decode(outputs[0][len(inputs['input_ids'][0]):], skip_special_tokens=True)
            predicted_label = response.strip().split()[0] if response.strip() else "Unknown"

            # Clean up prediction (remove any special tokens)
            predicted_label = predicted_label.replace('<|end_of_text|>', '').strip()

        except Exception as e:
            print(f"Error processing example {i}: {e}")
            predicted_label = "Unknown"

        predictions.append(predicted_label)
        true_labels.append(true_label)

        # Show progress and sample predictions
        if (i + 1) % 100 == 0:
            print(f"\nProcessed {i + 1}/{len(eval_dataset)} examples")
            print(f"Sample - True: {true_label}, Predicted: {predicted_label}")

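    # Keep only examples with a known true label and a prediction that is a valid class name.
    # Note: off-label predictions are dropped here rather than counted as errors,
    # so the accuracy reported below is optimistic.
    valid_indices = [i for i, (true, pred) in enumerate(zip(true_labels, predictions))
                    if true != "Unknown" and pred in class_names]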

    if not valid_indices:
        print("❌ No valid predictions found!")
        return None

    filtered_true = [true_labels[i] for i in valid_indices]
    filtered_pred = [predictions[i] for i in valid_indices]

    print(f"\n📊 PREDICTION ANALYSIS:")
    pred_counts = Counter(predictions)
    print("Predicted label distribution:", pred_counts)

    print(f"\n✅ Valid predictions: {len(valid_indices)}/{len(predictions)}")

    # Calculate metrics on filtered data
    accuracy = accuracy_score(filtered_true, filtered_pred)

    # Create classification report
    try:
        report = classification_report(
            filtered_true,
            filtered_pred,
            target_names=class_names,
            labels=class_names,
            zero_division=0,
            output_dict=True
        )
    except Exception as e:
        print(f"Error creating classification report: {e}")
        # Fallback: use only labels that appear in both true and predicted
        common_labels = list(set(filtered_true) & set(filtered_pred))
        report = classification_report(
            filtered_true,
            filtered_pred,
            target_names=common_labels,
            labels=common_labels,
            zero_division=0,
            output_dict=True
        )
        class_names = common_labels

    # Create confusion matrix
    try:
        cm = confusion_matrix(filtered_true, filtered_pred, labels=class_names)
    except Exception as e:
        print(f"Error creating confusion matrix: {e}")
        common_labels = list(set(filtered_true) & set(filtered_pred))
        cm = confusion_matrix(filtered_true, filtered_pred, labels=common_labels)
        class_names = common_labels

    # Display results
    print("\n" + "="*60)
    print("CLASSIFICATION EVALUATION RESULTS")
    print("="*60)
    print(f"Overall Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"Valid examples: {len(filtered_true)}")
    print("\nPer-Class Metrics:")
    print("-" * 60)

    for class_name in class_names:
        if class_name in report:
            precision = report[class_name]['precision']
            recall = report[class_name]['recall']
            f1 = report[class_name]['f1-score']
            support = report[class_name]['support']
            print(f"{class_name:12} | Precision: {precision:.3f} | Recall: {recall:.3f} | F1: {f1:.3f} | Support: {support}")

    if 'macro avg' in report:
        print(f"\nMacro Average F1-Score: {report['macro avg']['f1-score']:.4f}")
    if 'weighted avg' in report:
        print(f"Weighted Average F1-Score: {report['weighted avg']['f1-score']:.4f}")

    # Confusion Matrix
    print("\nConfusion Matrix:")
    print("-" * 60)
    cm_df = pd.DataFrame(cm, index=class_names, columns=class_names)
    print(cm_df)

    # Show some example predictions
    print("\nSample Predictions:")
    print("-" * 60)
    sample_size = min(15, len(filtered_true))
    for i in range(sample_size):
        status = "✓" if filtered_pred[i] == filtered_true[i] else "✗"
        print(f"{status} True: {filtered_true[i]:12} | Predicted: {filtered_pred[i]:12}")

    return {
        'accuracy': accuracy,
        'classification_report': report,
        'confusion_matrix': cm,
        'predictions': predictions,
        'true_labels': true_labels,
        'filtered_predictions': filtered_pred,
        'filtered_true_labels': filtered_true,
        'class_names': class_names
    }

# Run evaluation on a small sample first to debug
print("🧪 DEBUGGING WITH SMALL SAMPLE")
print("="*50)

# Test with just 10 examples first
small_sample = test_dataset.select(range(min(10, len(test_dataset))))
debug_results = evaluate_classification_model(model, tokenizer, small_sample)

if debug_results:
    print("\n✅ Debug successful! Running full evaluation...")

    # Run full evaluation on TEST SET
    print("\n🎯 FINAL TEST SET EVALUATION")
    print("="*80)
    print(f"Test set size: {len(test_dataset)}")
    print("This is the FINAL evaluation on completely unseen data")
    print("="*80)

    test_results = evaluate_classification_model(model, tokenizer, test_dataset)

    if test_results:
        print(f"\n🎉 EVALUATION COMPLETED!")
        print(f"Final test accuracy: {test_results['accuracy']*100:.2f}%")
        print(f"Categories found: {test_results['class_names']}")
        print(f"Valid predictions: {len(test_results['filtered_true_labels'])}/{len(test_results['true_labels'])}")
    else:
        print("❌ Evaluation failed!")
else:
    print("❌ Debug failed - check your model and data format!")

🧪 DEBUGGING WITH SMALL SAMPLE
Running classification evaluation...


Extracting labels: 100%|██████████| 10/10 [00:00<00:00, 4075.70it/s]



📊 ANALYZING ACTUAL LABELS IN DATASET:
Found labels: Counter({'Travel': 4, 'Comedy': 4, 'Professional': 1, 'Education': 1})
Using class names: ['Travel', 'Comedy', 'Professional', 'Education']

Running predictions on 10 examples...


Making predictions:   0%|          | 0/10 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  10%|█         | 1/10 [00:00<00:04,  1.87it/s]
Making predictions:  20%|██        | 2/10 [00:00<00:03,  2.56it/s]
Making predictions:  30%|███       | 3/10 [00:01<00:02,  3.21it/s]
Making predictions:  40%|████      | 4/10 [00:01<00:01,  3.42it/s]


📊 PREDICTION ANALYSIS:
Predicted label distribution: Counter({'Comedy': 4, 'Travel': 2, 'words': 1, 'Professional': 1, 'Education': 1, 'used': 1})

✅ Valid predictions: 8/10

CLASSIFICATION EVALUATION RESULTS
Overall Accuracy: 1.0000 (100.00%)
Valid examples: 8

Per-Class Metrics:
------------------------------------------------------------
Travel       | Precision: 1.000 | Recall: 1.000 | F1: 1.000 | Support: 2.0
Comedy       | Precision: 1.000 | Recall: 1.000 | F1: 1.000 | Support: 4.0
Professional | Precision: 1.000 | Recall: 1.000 | F1: 1.000 | Support: 1.0
Education    | Precision: 1.000 | Recall: 1.000 | F1: 1.000 | Support: 1.0

Macro Average F1-Score: 1.0000
Weighted Average F1-Score: 1.0000

Confusion Matrix:
------------------------------------------------------------
              Travel  Comedy  Professional  Education
Travel             2       0             0          0
Comedy             0       4             0          0
Professional       0       0             1      
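
The accuracy reported above covers only the 8 of 10 predictions that matched a known class name; the two off-label outputs (`words`, `used`) were dropped before scoring. A stricter view counts those as errors. The sketch below uses illustrative lists arranged to match the label counts printed above (the actual per-example order is not shown in the output). `strict_accuracy` is a hypothetical helper.

```python
def strict_accuracy(true_labels, predictions):
    """Accuracy over all examples; off-label predictions count as wrong."""
    correct = sum(t == p for t, p in zip(true_labels, predictions))
    return correct / len(true_labels)

# Illustrative ordering only: counts match the debug run's printed distributions
true_labels = ["Travel"] * 4 + ["Comedy"] * 4 + ["Professional", "Education"]
predictions = (
    ["Travel", "Travel", "words", "used"]
    + ["Comedy"] * 4
    + ["Professional", "Education"]
)
print(f"strict accuracy: {strict_accuracy(true_labels, predictions):.0%}")
```

Reporting both numbers makes it clear how much of the headline accuracy depends on discarding malformed generations.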

Extracting labels: 100%|██████████| 2481/2481 [00:00<00:00, 10616.62it/s]



📊 ANALYZING ACTUAL LABELS IN DATASET:
Found labels: Counter({'Professional': 507, 'Comedy': 498, 'Entertainment': 493, 'Travel': 486, 'Education': 454, 'Health': 43})
Using class names: ['Professional', 'Comedy', 'Entertainment', 'Travel', 'Education', 'Health']

Running predictions on 2481 examples...


Making predictions:   0%|          | 0/2481 [00:00<?, ?it/s]
Making predictions:   0%|          | 1/2481 [00:00<22:05,  1.87it/s]
Making predictions:   0%|          | 2/2481 [00:00<15:45,  2.62it/s]
Making predictions:   0%|          | 3/2481 [00:01<14:16,  2.89it/s]
Making predictions:   0%|          | 4/2481 [00:01<13:16,  3.11it/s]


Processed 100/2481 examples
Sample - True: Comedy, Predicted: Comedy


Making predictions:   4%|▍         | 101/2481 [00:31<13:28,  2.94it/s]
Making predictions:   4%|▍         | 102/2481 [00:31<11:48,  3.36it/s]
Making predictions:   4%|▍         | 103/2481 [00:31<10:39,  3.72it/s]
Making predictions:   4%|▍         | 104/2481 [00:32<10:40,  3.71it/s]
Making predictions:   4%|▍         | 105/2481 [00:32<11:10,  3.54it/s]


Processed 200/2481 examples
Sample - True: Professional, Predicted: Professional


Making predictions:   8%|▊         | 201/2481 [00:59<11:06,  3.42it/s]
Making predictions:   8%|▊         | 202/2481 [00:59<13:19,  2.85it/s]
Making predictions:   8%|▊         | 203/2481 [01:00<12:19,  3.08it/s]
Making predictions:   8%|▊         | 204/2481 [01:00<14:12,  2.67it/s]
Making predictions:   8%|▊         | 205/2481 [01:01<14:28,  2.62it/s]


Processed 300/2481 examples
Sample - True: Comedy, Predicted: Comedy


Making predictions:  12%|█▏        | 301/2481 [01:30<09:31,  3.82it/s]
Making predictions:  12%|█▏        | 302/2481 [01:30<10:37,  3.42it/s]
Making predictions:  12%|█▏        | 303/2481 [01:30<11:36,  3.13it/s]
Making predictions:  12%|█▏        | 304/2481 [01:31<10:12,  3.55it/s]
Making predictions:  12%|█▏        | 305/2481 [01:31<09:18,  3.90it/s]


Processed 400/2481 examples
Sample - True: Comedy, Predicted: Comedy


Making predictions:  16%|█▌        | 401/2481 [02:01<13:06,  2.65it/s]
Making predictions:  16%|█▌        | 402/2481 [02:01<13:12,  2.62it/s]
Making predictions:  16%|█▌        | 403/2481 [02:02<11:20,  3.06it/s]
Making predictions:  16%|█▋        | 404/2481 [02:02<13:10,  2.63it/s]
Making predictions:  16%|█▋        | 405/2481 [02:02<11:22,  3.04it/s]


Processed 500/2481 examples
Sample - True: Education, Predicted: to


Making predictions:  20%|██        | 501/2481 [02:31<12:42,  2.60it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  20%|██        | 502/2481 [02:31<12:37,  2.61it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  20%|██        | 503/2481 [02:32<12:49,  2.57it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  20%|██        | 504/2481 [02:32<12:34,  2.62it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  20%|██        | 505/2481 [02:32<10:46,  3.06it/s]The following generation flags are not valid and may be ignored: ['temperature


Processed 600/2481 examples
Sample - True: Professional, Predicted: the


The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  24%|██▍       | 602/2481 [03:03<11:23,  2.75it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  24%|██▍       | 603/2481 [03:04<11:44,  2.67it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  24%|██▍       | 604/2481 [03:04<10:07,  3.09it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  24%|██▍       | 605/2481 [03:04<10:31,  2.97it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Makin


Processed 700/2481 examples
Sample - True: Comedy, Predicted: Comedy


Making predictions:  28%|██▊       | 701/2481 [03:34<09:22,  3.17it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  28%|██▊       | 702/2481 [03:34<08:19,  3.56it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  28%|██▊       | 703/2481 [03:34<07:33,  3.92it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  28%|██▊       | 704/2481 [03:35<07:37,  3.88it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  28%|██▊       | 705/2481 [03:35<07:42,  3.84it/s]The following generation flags are not valid and may be ignored: ['temperature


Processed 800/2481 examples
Sample - True: Education, Predicted: 't


Making predictions:  32%|███▏      | 801/2481 [04:05<11:02,  2.54it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  32%|███▏      | 802/2481 [04:05<09:23,  2.98it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  32%|███▏      | 803/2481 [04:05<09:57,  2.81it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  32%|███▏      | 804/2481 [04:05<08:37,  3.24it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  32%|███▏      | 805/2481 [04:06<10:16,  2.72it/s]The following generation flags are not valid and may be ignored: ['temperature


Processed 900/2481 examples
Sample - True: Professional, Predicted: like


Making predictions:  36%|███▋      | 901/2481 [04:35<09:05,  2.90it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  36%|███▋      | 902/2481 [04:35<08:27,  3.11it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  36%|███▋      | 903/2481 [04:36<09:47,  2.69it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  36%|███▋      | 904/2481 [04:36<08:27,  3.11it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  36%|███▋      | 905/2481 [04:36<08:05,  3.25it/s]The following generation flags are not valid and may be ignored: ['temperature


Processed 1000/2481 examples
Sample - True: Travel, Predicted: Travel


Making predictions:  40%|████      | 1001/2481 [05:04<07:01,  3.51it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  40%|████      | 1002/2481 [05:05<07:49,  3.15it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  40%|████      | 1003/2481 [05:05<09:09,  2.69it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  40%|████      | 1004/2481 [05:05<08:20,  2.95it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  41%|████      | 1005/2481 [05:06<07:19,  3.36it/s]The following generation flags are not valid and may be ignored: ['temper


Processed 1100/2481 examples
Sample - True: Entertainment, Predicted: Entertainment


Making predictions:  44%|████▍     | 1101/2481 [05:35<05:49,  3.95it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  44%|████▍     | 1102/2481 [05:35<06:07,  3.75it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  44%|████▍     | 1103/2481 [05:35<06:06,  3.76it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  44%|████▍     | 1104/2481 [05:35<05:39,  4.06it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  45%|████▍     | 1105/2481 [05:36<07:21,  3.12it/s]The following generation flags are not valid and may be ignored: ['temper


Processed 1200/2481 examples
Sample - True: Entertainment, Predicted: Entertainment


Making predictions:  48%|████▊     | 1201/2481 [06:03<06:51,  3.11it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  48%|████▊     | 1202/2481 [06:04<07:14,  2.95it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  48%|████▊     | 1203/2481 [06:04<06:47,  3.13it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  49%|████▊     | 1204/2481 [06:04<06:02,  3.52it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  49%|████▊     | 1205/2481 [06:04<05:29,  3.87it/s]The following generation flags are not valid and may be ignored: ['temper


Processed 1300/2481 examples
Sample - True: Education, Predicted: shoes)


Making predictions:  52%|█████▏    | 1301/2481 [06:34<06:32,  3.01it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  52%|█████▏    | 1302/2481 [06:34<05:46,  3.40it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  53%|█████▎    | 1303/2481 [06:34<05:15,  3.74it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  53%|█████▎    | 1304/2481 [06:34<04:53,  4.01it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  53%|█████▎    | 1305/2481 [06:35<05:34,  3.51it/s]The following generation flags are not valid and may be ignored: ['temper


Processed 1400/2481 examples
Sample - True: Professional, Predicted: Professional


Making predictions:  56%|█████▋    | 1401/2481 [07:03<05:02,  3.57it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  57%|█████▋    | 1402/2481 [07:03<04:59,  3.61it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  57%|█████▋    | 1403/2481 [07:03<06:11,  2.90it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  57%|█████▋    | 1404/2481 [07:04<05:24,  3.32it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  57%|█████▋    | 1405/2481 [07:04<05:59,  2.99it/s]The following generation flags are not valid and may be ignored: ['temper


Processed 1500/2481 examples
Sample - True: Entertainment, Predicted: Entertainment


Making predictions:  60%|██████    | 1501/2481 [07:33<04:53,  3.34it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  61%|██████    | 1502/2481 [07:34<05:21,  3.04it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  61%|██████    | 1503/2481 [07:34<04:45,  3.43it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  61%|██████    | 1504/2481 [07:34<04:19,  3.76it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  61%|██████    | 1505/2481 [07:35<04:21,  3.74it/s]The following generation flags are not valid and may be ignored: ['temper


Processed 1600/2481 examples
Sample - True: Entertainment, Predicted: Entertainment


Making predictions:  65%|██████▍   | 1601/2481 [08:03<03:54,  3.76it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  65%|██████▍   | 1602/2481 [08:03<03:38,  4.02it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  65%|██████▍   | 1603/2481 [08:04<03:52,  3.78it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  65%|██████▍   | 1604/2481 [08:04<03:53,  3.75it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  65%|██████▍   | 1605/2481 [08:04<03:37,  4.02it/s]The following generation flags are not valid and may be ignored: ['temper


Processed 1700/2481 examples
Sample - True: Professional, Predicted: ’s


Making predictions:  69%|██████▊   | 1701/2481 [08:32<04:46,  2.72it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  69%|██████▊   | 1702/2481 [08:33<04:30,  2.88it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  69%|██████▊   | 1703/2481 [08:33<05:05,  2.55it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  69%|██████▊   | 1704/2481 [08:34<04:36,  2.81it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  69%|██████▊   | 1705/2481 [08:34<04:01,  3.22it/s]The following generation flags are not valid and may be ignored: ['temper


Processed 1800/2481 examples
Sample - True: Education, Predicted: Education


Making predictions:  73%|███████▎  | 1801/2481 [09:01<02:49,  4.02it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  73%|███████▎  | 1802/2481 [09:02<02:51,  3.96it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  73%|███████▎  | 1803/2481 [09:02<03:41,  3.06it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  73%|███████▎  | 1804/2481 [09:03<03:27,  3.26it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  73%|███████▎  | 1805/2481 [09:03<03:18,  3.40it/s]The following generation flags are not valid and may be ignored: ['temper


Processed 1900/2481 examples
Sample - True: Travel, Predicted: Travel


Making predictions:  77%|███████▋  | 1901/2481 [09:32<02:45,  3.51it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  77%|███████▋  | 1902/2481 [09:33<02:42,  3.56it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  77%|███████▋  | 1903/2481 [09:33<03:19,  2.89it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  77%|███████▋  | 1904/2481 [09:33<02:54,  3.31it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  77%|███████▋  | 1905/2481 [09:34<03:10,  3.02it/s]The following generation flags are not valid and may be ignored: ['temper


Processed 2000/2481 examples
Sample - True: Travel, Predicted: Travel


Making predictions:  81%|████████  | 2001/2481 [10:03<02:21,  3.39it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  81%|████████  | 2002/2481 [10:03<02:09,  3.70it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  81%|████████  | 2003/2481 [10:03<01:59,  4.00it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  81%|████████  | 2004/2481 [10:04<02:01,  3.92it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  81%|████████  | 2005/2481 [10:04<02:35,  3.07it/s]The following generation flags are not valid and may be ignored: ['temper


Processed 2100/2481 examples
Sample - True: Education, Predicted: to


Making predictions:  85%|████████▍ | 2101/2481 [10:33<02:25,  2.60it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  85%|████████▍ | 2102/2481 [10:33<02:04,  3.04it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  85%|████████▍ | 2103/2481 [10:34<02:01,  3.10it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  85%|████████▍ | 2104/2481 [10:34<01:48,  3.49it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  85%|████████▍ | 2105/2481 [10:34<01:45,  3.57it/s]The following generation flags are not valid and may be ignored: ['temper


Processed 2200/2481 examples
Sample - True: Education, Predicted: Education


Making predictions:  89%|████████▊ | 2201/2481 [11:03<01:12,  3.87it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  89%|████████▉ | 2202/2481 [11:03<01:12,  3.84it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  89%|████████▉ | 2203/2481 [11:04<01:31,  3.03it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  89%|████████▉ | 2204/2481 [11:04<01:21,  3.41it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  89%|████████▉ | 2205/2481 [11:04<01:30,  3.05it/s]The following generation flags are not valid and may be ignored: ['temper


Processed 2300/2481 examples
Sample - True: Entertainment, Predicted: thing.


Making predictions:  93%|█████████▎| 2301/2481 [11:34<00:58,  3.07it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  93%|█████████▎| 2302/2481 [11:34<01:00,  2.94it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  93%|█████████▎| 2303/2481 [11:34<00:53,  3.35it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  93%|█████████▎| 2304/2481 [11:35<00:51,  3.42it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  93%|█████████▎| 2305/2481 [11:35<00:46,  3.76it/s]The following generation flags are not valid and may be ignored: ['temper


Processed 2400/2481 examples
Sample - True: Education, Predicted: .


Making predictions:  97%|█████████▋| 2401/2481 [12:04<00:27,  2.92it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  97%|█████████▋| 2402/2481 [12:04<00:26,  3.02it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  97%|█████████▋| 2403/2481 [12:04<00:22,  3.44it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  97%|█████████▋| 2404/2481 [12:04<00:21,  3.55it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Making predictions:  97%|█████████▋| 2405/2481 [12:05<00:20,  3.65it/s]The following generation flags are not valid and may be ignored: ['temper


📊 PREDICTION ANALYSIS:
Predicted label distribution: Counter({'Comedy': 527, 'Entertainment': 406, 'Professional': 351, 'Travel': 350, 'Education': 238, 'Health': 44, ',': 25, '.': 21, 'I': 15, 'the': 15, 'and': 14, 'a': 10, 'to': 9, 'of': 9, ':': 7, 'I’m': 6, 'in': 6, 'my': 6, 'for': 6, 'be': 5, 'is': 5, 'you': 5, '4:': 4, "I'm": 4, '###': 3, 'they': 3, 'like': 3, 'was': 3, "'s": 3, 'market': 3, 'want': 3, ')': 3, '3:': 3, '5:': 3, 'with': 3, 'it': 3, '**': 3, 'no': 3, "'m": 3, '*': 3, 'can': 3, 'or': 3, '’t': 3, 'on': 2, 'think': 2, 'not': 2, 'that': 2, 'get': 2, '-ji': 2, 'will': 2, '’ve': 2, 'No': 2, 'Comment': 2, 'your': 2, '.,': 2, '10': 2, 'strategies': 2, 'If': 2, 'iterate': 2, 'down': 2, 'but': 2, 'what': 2, 'from': 2, 'say': 2, '4': 2, '2': 2, 'Top': 2, 'words': 1, 'used': 1, 'car': 1, 'ions,': 1, 'ei,': 1, 'Sannenzaka': 1, 'August,': 1, 'ally': 1, 'hours': 1, 'audiobook': 1, 'post': 1, 'husband': 1, '7': 1, "'d": 1, 'ogoku': 1, 'shaders': 1, 'would': 1, 'I’ve': 1, 'went': 1
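The distribution above contains many outputs that are not category names at all (`','`, `'.'`, `'the'`, and so on). One way to account for them when scoring is to map every raw prediction onto the known label set and count the rest as invalid. The following is a minimal sketch of that idea, not the notebook's own evaluation code: `VALID_LABELS`, `normalize_prediction`, `summarize`, and the `"Unknown"` fallback are hypothetical names, and the label set is assumed from the labels visible in the output above.

```python
from collections import Counter

# Assumed label set, taken from the labels that appear in the output above.
VALID_LABELS = ["Comedy", "Education", "Entertainment", "Health", "Professional", "Travel"]


def normalize_prediction(raw, valid_labels=VALID_LABELS, fallback="Unknown"):
    """Map a raw first-word prediction onto a known label, or return `fallback`."""
    # Drop surrounding whitespace and stray punctuation, then compare case-insensitively.
    cleaned = raw.strip().strip(".,:;*#()'\"").lower()
    for label in valid_labels:
        if cleaned == label.lower():
            return label
    return fallback


def summarize(true_labels, raw_predictions):
    """Return (accuracy, share of invalid outputs, distribution of normalized labels)."""
    if not raw_predictions:
        raise ValueError("need at least one prediction")
    normalized = [normalize_prediction(p) for p in raw_predictions]
    correct = sum(t == p for t, p in zip(true_labels, normalized))
    invalid = sum(p == "Unknown" for p in normalized)
    n = len(normalized)
    return correct / n, invalid / n, Counter(normalized)


# Tiny illustrative sample built from (true, predicted) pairs shown in the log above.
true_labels = ["Comedy", "Education", "Professional", "Travel"]
raw_predictions = ["Comedy", "to", "the", "Travel"]
accuracy, invalid_rate, distribution = summarize(true_labels, raw_predictions)
print(accuracy, invalid_rate, distribution)
```

Reporting the invalid share separately from accuracy helps distinguish a model that picks the wrong category from one that does not produce a category at all.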




<a name="Inference"></a>
### Inference
Let's run the model! You can change the instruction and input - leave the output blank!



In [None]:
# Test Example 2: Professional/Career Post
FastLanguageModel.for_inference(model) # Enable native 2x faster inference
inputs = tokenizer(
[
    alpaca_prompt.format(
    "Classify the following Reddit post into one of these categories: Comedy, Education, Health, Professional, or Travel. Base your classification on the post title, content, and top comments.", # instruction (removed subreddit reference)
    """Post Title: Should I negotiate salary for my first job out of college?

Post Content: I just graduated with a computer science degree and got offered a position at a tech startup. The salary is $75k but I've heard people saying you should always negotiate. I'm worried about seeming ungrateful or them rescinding the offer. What's the best approach here?

Top Comments:
Comment 1: Always negotiate! Worst they can say is no, and most companies expect it
Comment 2: Research salary ranges in your area first - use Glassdoor, LinkedIn, etc.
Comment 3: Don't be afraid to ask - they already want to hire you
Comment 4: I regret not negotiating my first offer, left money on the table
Comment 5: Be professional about it and have data to back up your request""", # input
    "", # output - leave this blank for generation!
    )
], return_tensors = "pt").to("cuda")

outputs = model.generate(
    **inputs,
    max_new_tokens=5,        # Just need the category
    do_sample=False,         # Greedy decoding; sampling flags such as temperature are ignored
    use_cache=True,
    pad_token_id=tokenizer.eos_token_id
)

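# Special tokens are kept when decoding, so the text may end with <|end_of_text|>
result = tokenizer.batch_decode(outputs)

# Extract just the prediction: the first whitespace-separated word after the response marker
prediction_start = result[0].find("### Response:\n") + len("### Response:\n")
prediction = result[0][prediction_start:].strip().split()[0]
print(f"Predicted category: {prediction}")
print("Expected: Professional")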

# Test Example 3: Travel Post
inputs = tokenizer(
[
    alpaca_prompt.format(
    "Classify the following Reddit post into one of these categories: Comedy, Education, Health, Professional, or Travel. Base your classification on the post title, content, and top comments.",
    """Post Title: First time visiting Japan - need advice for 2-week itinerary

Post Content: Planning my first trip to Japan in April during cherry blossom season. Want to see both Tokyo and Kyoto but not sure how to split my time. Also wondering about JR Pass vs individual tickets. Any must-see spots or hidden gems?

Top Comments:
Comment 1: Definitely get the JR Pass for 2 weeks, it pays for itself
Comment 2: Tokyo 7 days, Kyoto 5 days, leave 2 days for day trips
Comment 3: Don't miss Fushimi Inari shrine in Kyoto - amazing at sunrise
Comment 4: Book accommodations early for cherry blossom season
Comment 5: Try to stay in Shibuya or Shinjuku for easy access to everything""",
    "",
    )
], return_tensors = "pt").to("cuda")

# Greedy decoding (do_sample=False), so no temperature is passed
outputs = model.generate(**inputs, max_new_tokens=5, do_sample=False, use_cache=True, pad_token_id=tokenizer.eos_token_id)
result = tokenizer.batch_decode(outputs)
prediction_start = result[0].find("### Response:\n") + len("### Response:\n")
prediction = result[0][prediction_start:].strip().split()[0]
print(f"Predicted category: {prediction}")
print("Expected: Travel")

# Test Example 4: Health Post
inputs = tokenizer(
[
    alpaca_prompt.format(
    "Classify the following Reddit post into one of these categories: Comedy, Education, Health, Professional, or Travel. Base your classification on the post title, content, and top comments.",
    """Post Title: Struggling with anxiety - what helps you cope?

Post Content: I've been dealing with increased anxiety lately, especially around work presentations. My heart races, I get sweaty palms, and sometimes feel like I can't breathe. Looking for healthy coping strategies that have worked for others. Already considering therapy but want to hear personal experiences.

Top Comments:
Comment 1: Deep breathing exercises really help me in the moment
Comment 2: Regular exercise has been a game-changer for my anxiety
Comment 3: Therapy is worth it - CBT techniques are super helpful
Comment 4: Meditation apps like Headspace helped me a lot
Comment 5: Talk to your doctor too, sometimes medication can help alongside therapy""",
    "",
    )
], return_tensors = "pt").to("cuda")

# Greedy decoding (do_sample=False), so no temperature is passed
outputs = model.generate(**inputs, max_new_tokens=5, do_sample=False, use_cache=True, pad_token_id=tokenizer.eos_token_id)
result = tokenizer.batch_decode(outputs)
prediction_start = result[0].find("### Response:\n") + len("### Response:\n")
prediction = result[0][prediction_start:].strip().split()[0]
print(f"Predicted category: {prediction}")
print("Expected: Health")

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Predicted category: Professional<|end_of_text|>
Expected: Professional


The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Predicted category: Travel<|end_of_text|>
Expected: Travel
Predicted category: Health<|end_of_text|>
Expected: Health
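The predictions printed above carry a trailing `<|end_of_text|>` because the decoded text keeps special tokens. The extraction also relies on `str.find`, which returns `-1` when the response marker is missing, so the slice would then start at the wrong position without raising an error. A more defensive extraction could look like the sketch below. It works on plain strings only; `extract_prediction` is a hypothetical helper, and the hard-coded `EOS_TOKEN` stands in for what would be `tokenizer.eos_token` in the notebook.

```python
RESPONSE_MARKER = "### Response:\n"
EOS_TOKEN = "<|end_of_text|>"  # assumed here; in the notebook this would be tokenizer.eos_token


def extract_prediction(decoded, marker=RESPONSE_MARKER, eos_token=EOS_TOKEN):
    """Return the first word after the response marker, without the EOS token."""
    _, found, tail = decoded.partition(marker)
    if not found:
        return ""  # marker missing: nothing to extract
    # Replace the EOS token with a space so it cannot stay glued to the label.
    words = tail.replace(eos_token, " ").split()
    return words[0] if words else ""


sample = "### Instruction:\nClassify...\n\n### Response:\nProfessional<|end_of_text|>"
print(extract_prediction(sample))
```

Returning an empty string for a missing marker or an empty response keeps the caller from crashing on `split()[0]` and makes such cases easy to count.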


You can also use a `TextStreamer` for continuous inference, so you can watch the generation token by token instead of waiting for the full output!

<a name="Save"></a>
### Saving, loading finetuned models
To save the final model as LoRA adapters, use either Hugging Face's `push_to_hub` for an online save or `save_pretrained` for a local save.

**[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!

In [None]:
# Get token input from user
from getpass import getpass

hf_token = getpass("Enter your Huggingface token: ")
repo_id = "yaamin6236/reddit-post-classifier-v3.0"

# Push the LoRA adapters to the Hub
model.push_to_hub(repo_id, token=hf_token)
print("✅ Model uploaded successfully!")

# Push the tokenizer to the same repo
tokenizer.push_to_hub(repo_id, token=hf_token)
print("✅ Tokenizer uploaded successfully!")

# Build the URL from repo_id so it always matches the repo that was pushed
print(f"🚀 Your model is now available at: https://huggingface.co/{repo_id}")

Enter your Huggingface token: ··········


README.md:   0%|          | 0.00/607 [00:00<?, ?B/s]

  0%|          | 0/1 [00:00<?, ?it/s]

adapter_model.safetensors:   0%|          | 0.00/671M [00:00<?, ?B/s]

Saved model to https://huggingface.co/yaamin6236/reddit-post-classifier-v3.0
✅ Model uploaded successfully!


  0%|          | 0/1 [00:00<?, ?it/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

✅ Tokenizer uploaded successfully!
🚀 Your model is now available at: https://huggingface.co/yaamin6236/reddit-post-classifier-v3.0


Now, if you want to load the LoRA adapters we just saved for inference, change `False` to `True`:

In [None]:
if False:
    from unsloth import FastLanguageModel
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = "yaamin6236/reddit-post-classifier-v3.0", # YOUR PUSHED MODEL
        max_seq_length = max_seq_length,
        dtype = dtype,
        load_in_4bit = load_in_4bit,
    )
    FastLanguageModel.for_inference(model) # Enable native 2x faster inference

# alpaca_prompt = You MUST copy from above!

inputs = tokenizer(
[
   alpaca_prompt.format(
       "Classify the following Reddit post into one of these categories: Comedy, Education, Health, Professional, or Travel. Base your classification on the post title, content, top comments, and subreddit context.", # instruction
       """Post Title: My boss asked me to stop singing "Wonderwall"

Post Content: I said maybe...

Subreddit: r/Jokes

Top Comments:
Comment 1: I see what you did there! Classic Oasis reference!
Comment 2: This joke is so bad it's good
Comment 3: Take my upvote and get out""", # input
       "", # output - leave this blank for generation!
   )
], return_tensors = "pt").to("cuda")

from transformers import TextStreamer
text_streamer = TextStreamer(tokenizer)
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 5)

<|begin_of_text|>Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Classify the following Reddit post into one of these categories: Comedy, Education, Health, Professional, or Travel. Base your classification on the post title, content, top comments, and subreddit context.

### Input:
Post Title: My boss asked me to stop singing "Wonderwall"

Post Content: I said maybe...

Subreddit: r/Jokes

Top Comments:
Comment 1: I see what you did there! Classic Oasis reference!
Comment 2: This joke is so bad it's good
Comment 3: Take my upvote and get out

### Response:
Comedy

### Explanation


### Saving to float16 for vLLM

We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens.

In [None]:
# Merge to 16bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_16bit", token = "")

# Merge to 4bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_4bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_4bit", token = "")

# Just LoRA adapters
if False: model.save_pretrained_merged("model", tokenizer, save_method = "lora",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "lora", token = "")
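
The cell above passes `token = ""` and the placeholder repo `hf/model`. Rather than pasting a token into the notebook, one option is to read it from an environment variable. This is a minimal sketch, assuming the token is stored in `HF_TOKEN`; `hub_settings` is a hypothetical helper and not an Unsloth function.

```python
import os

def hub_settings(username, model_name, env_var="HF_TOKEN"):
    """Return (repo_id, token) for an upload, reading the token from the environment.

    Hypothetical helper: raises instead of returning an empty token or the
    notebook's placeholder username "hf".
    """
    if not username or username == "hf":
        raise ValueError("Replace the placeholder 'hf' with your Hugging Face username.")
    token = os.environ.get(env_var, "").strip()
    if not token:
        raise ValueError(f"Set the {env_var} environment variable to your access token.")
    return f"{username}/{model_name}", token
```

With this in place, an upload could look like `repo_id, token = hub_settings("your-username", "model")` followed by `model.push_to_hub_merged(repo_id, tokenizer, save_method = "merged_16bit", token = token)`.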

### GGUF / llama.cpp Conversion
We now support saving to `GGUF` / `llama.cpp` natively! We clone `llama.cpp` and save to `q8_0` by default. Other quantization methods such as `q4_k_m` are also available. Use `save_pretrained_gguf` to save locally and `push_to_hub_gguf` to upload to Hugging Face.

Some supported quant methods (full list on our [Wiki page](https://github.com/unslothai/unsloth/wiki#gguf-quantization-options)):
* `q8_0` - Fast conversion. High resource use, but generally acceptable.
* `q4_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K.
* `q5_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K.
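
The methods above trade file size for precision. As a rough back-of-the-envelope sketch, file size is about parameters × bits per weight ÷ 8. The bits-per-weight figures below follow from the llama.cpp block layouts (for example, a Q8_0 block stores 32 int8 weights plus one 16-bit scale, so 34 bytes per 32 weights). The 8-billion parameter count is an assumed example, and `estimate_gguf_gb` and `bracket_mixed_gb` are hypothetical helpers. Mixed schemes such as `q4_k_m` land between their base type and Q6_K, so the sketch only brackets them; it also ignores metadata and any tensors kept at other precisions.

```python
# Bits per weight for uniform GGUF tensor types, derived from block sizes:
# bytes per block * 8 / weights per block.
BITS_PER_WEIGHT = {
    "f16": 16.0,
    "q8_0": 34 * 8 / 32,    # 8.5
    "q6_k": 210 * 8 / 256,  # 6.5625
    "q5_k": 176 * 8 / 256,  # 5.5
    "q4_k": 144 * 8 / 256,  # 4.5
}

def estimate_gguf_gb(n_params, tensor_type):
    """Approximate size in GB (10**9 bytes) if every weight used `tensor_type`."""
    return n_params * BITS_PER_WEIGHT[tensor_type] / 8 / 1e9

def bracket_mixed_gb(n_params, base_type, upgraded_type="q6_k"):
    """Lower and upper bound for a mixed scheme such as q4_k_m."""
    return (estimate_gguf_gb(n_params, base_type),
            estimate_gguf_gb(n_params, upgraded_type))

n_params = 8_000_000_000  # assumed example size, not taken from this notebook
for name in BITS_PER_WEIGHT:
    print(f"{name}: ~{estimate_gguf_gb(n_params, name):.1f} GB")
low, high = bracket_mixed_gb(n_params, "q4_k")
print(f"q4_k_m: between ~{low:.1f} and ~{high:.1f} GB")
```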

[**NEW**] To finetune and auto export to Ollama, try our [Ollama notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)

In [None]:
# Save to 8bit Q8_0
if False: model.save_pretrained_gguf("model", tokenizer,)
# Remember to go to https://huggingface.co/settings/tokens for a token!
# And change hf to your username!
if False: model.push_to_hub_gguf("hf/model", tokenizer, token = "")

# Save to 16bit GGUF
if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "f16")
if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "f16", token = "")

# Save to q4_k_m GGUF
if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "q4_k_m")
if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "q4_k_m", token = "")

# Save to multiple GGUF options - much faster if you want multiple!
if False:
    model.push_to_hub_gguf(
        "hf/model", # Change hf to your username!
        tokenizer,
        quantization_method = ["q4_k_m", "q8_0", "q5_k_m",],
        token = "",
    )

Now, use the `model-unsloth.gguf` file or `model-unsloth-Q4_K_M.gguf` file in llama.cpp or a UI-based system like Jan or Open WebUI. You can install Jan [here](https://github.com/janhq/jan) and Open WebUI [here](https://github.com/open-webui/open-webui).
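
One way to try the exported file from Python is to shell out to llama.cpp's `llama-cli` binary. This is a minimal sketch, assuming `llama-cli` is on your `PATH` and that the file was written to `model/model-unsloth-Q4_K_M.gguf` (the directory is an assumption based on the `save_pretrained_gguf("model", ...)` call above); `build_llama_cli_command` is a hypothetical helper.

```python
import shutil
import subprocess
from pathlib import Path

def build_llama_cli_command(gguf_path, prompt, n_predict=5):
    """Build the argument list for a single llama-cli completion."""
    return [
        "llama-cli",
        "-m", str(gguf_path),   # path to the GGUF model file
        "-p", prompt,           # prompt text
        "-n", str(n_predict),   # maximum number of tokens to generate
    ]

gguf_path = Path("model") / "model-unsloth-Q4_K_M.gguf"  # assumed location
command = build_llama_cli_command(
    gguf_path, "### Instruction:\nSay hello.\n\n### Response:\n"
)

# Only run when both the binary and the model file are actually present.
if shutil.which("llama-cli") and gguf_path.exists():
    subprocess.run(command, check=True)
else:
    print("Skipping run; would execute:", command)
```

Passing the arguments as a list avoids shell quoting problems with multi-line prompts. Depending on your llama.cpp version, `llama-cli` may start in an interactive chat mode, so check its `--help` output for the flags your build supports.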

And we're done! If you have any questions about Unsloth, find a bug, need help, or want to keep up with the latest LLM news and join projects, feel free to join our [Discord](https://discord.gg/unsloth)!

Some other links:
1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)
2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
4. See notebooks for DPO, ORPO, continued pretraining, conversational finetuning and more in our [documentation](https://docs.unsloth.ai/get-started/unsloth-notebooks)!

