diff --git a/.gitattributes b/.gitattributes
deleted file mode 100644
index 264fdfd02f..0000000000
--- a/.gitattributes
+++ /dev/null
@@ -1,2 +0,0 @@
-# Normalize Python files to LF line endings
-*.py text eol=lf
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
deleted file mode 100644
index f4d18645f2..0000000000
--- a/.github/CODEOWNERS
+++ /dev/null
@@ -1,46 +0,0 @@
-# Inspired from https://github.com/vllm-project/vllm/blob/main/.github/CODEOWNERS
-
-/unsloth/models/loader.py @danielhanchen @mmathew23
-/unsloth/models/llama.py @Datta0 @danielhanchen @mmathew23
-/unsloth/models/rl.py @Datta0 @pluesclues @danielhanchen
-/unsloth/models/rl_replacements.py @Datta0 @pluesclues @danielhanchen
-/unsloth/trainer.py @danielhanchen
-/unsloth/models/sentence_transformer.py @Etherll @danielhanchen
-/unsloth/save.py @rolandtannous @danielhanchen
-/unsloth/tokenizer_utils.py @mmathew23 @danielhanchen
-/unsloth/chat_templates.py @rolandtannous @danielhanchen
-/unsloth/ollama_template_mappers.py @rolandtannous @danielhanchen
-/unsloth/kernels/moe/*.py @Datta0
-/unsloth/import_fixes.py @danielhanchen
-/unsloth/device_type.py @danielhanchen
-/unsloth/_auto_install.py @danielhanchen
-/unsloth/dataprep/*.py @danielhanchen
-/unsloth/kernels/cross_entropy_loss.py @danielhanchen
-/unsloth/kernels/fast_lora.py @danielhanchen
-/unsloth/kernels/flex_attention.py @danielhanchen
-/unsloth/kernels/fp8.py @Datta0
-/unsloth/kernels/geglu.py @danielhanchen
-/unsloth/kernels/layernorm.py @danielhanchen
-/unsloth/kernels/rms_layernorm.py @danielhanchen
-/unsloth/kernels/rope_embedding.py @danielhanchen
-/unsloth/kernels/swiglu.py @danielhanchen
-/unsloth/kernels/utils.py @danielhanchen @Datta0
-/unsloth/models/_utils.py @danielhanchen @mmathew23
-/unsloth/models/cohere.py @danielhanchen
-/unsloth/models/dpo.py @danielhanchen
-/unsloth/models/falcon_h1.py @danielhanchen
-/unsloth/models/gemma.py @danielhanchen
-/unsloth/models/gemma2.py @danielhanchen
-/unsloth/models/glm4_moe.py @Datta0
-/unsloth/models/granite.py @danielhanchen
-/unsloth/models/llama4.py @danielhanchen
-/unsloth/models/loader_utils.py @Datta0 @danielhanchen
-/unsloth/models/mapper.py @danielhanchen
-/unsloth/models/mistral.py @danielhanchen
-/unsloth/models/qwen2.py @danielhanchen
-/unsloth/models/qwen3.py @Datta0
-/unsloth/models/qwen3_moe.py @Datta0
-/unsloth/models/vision.py @mmathew23 @danielhanchen
-/unsloth/utils/attention_dispatch.py @mmathew23
-/unsloth/utils/hf_hub.py @mmathew23
-/unsloth/utils/packing.py @mmathew23
diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml
index ae5dade42d..a6cda7d034 100644
--- a/.github/FUNDING.yml
+++ b/.github/FUNDING.yml
@@ -1,9 +1,9 @@
# These are supported funding model platforms
-github: unslothai
+github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
-ko_fi: # unsloth
+ko_fi: unsloth
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
diff --git a/.github/ISSUE_TEMPLATE/bug---issue.md b/.github/ISSUE_TEMPLATE/bug---issue.md
deleted file mode 100644
index 83e0fd73a9..0000000000
--- a/.github/ISSUE_TEMPLATE/bug---issue.md
+++ /dev/null
@@ -1,21 +0,0 @@
----
-name: Bug / Issue
-about: Bug / Issue
-title: "[Bug] Please fill in your issue title here."
-labels: bug
-assignees: ''
-
----
-
-1. Did you update? `pip install --upgrade unsloth unsloth_zoo`
-2. `Colab` or `Kaggle` or local / cloud
-3. Number GPUs used, use `nvidia-smi`
-4. Which notebook? Please link!
-5. Which Unsloth version, TRL version, transformers version, PyTorch version?
-6. Which trainer? `SFTTrainer`, `GRPOTrainer` etc
-
-```python
-Put Minimal code to reproduce error here ###Remove Hugging Face token###
-```
-
-🦥 You can also ask via our Reddit page: https://reddit.com/r/unsloth/
diff --git a/.github/ISSUE_TEMPLATE/feature-request.md b/.github/ISSUE_TEMPLATE/feature-request.md
deleted file mode 100644
index 5ea70a8a03..0000000000
--- a/.github/ISSUE_TEMPLATE/feature-request.md
+++ /dev/null
@@ -1,21 +0,0 @@
----
-name: Feature Request
-about: New features, model support, ideas
-title: "[Feature]"
-labels: feature request
-assignees: ''
-
----
-
-For new models, have you tried:
-```python
-from unsloth import FastModel
-model, tokenizer = FastModel.from_pretrained(
- "microsoft/Phi-4-multimodal-instruct",
- trust_remote_code = True,
-)
-from transformers import AutoModelForSequenceClassification
-model, tokenizer = FastModel.from_pretrained(
- auto_model = AutoModelForSequenceClassification,
-)
-```
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
deleted file mode 100644
index fc864d1736..0000000000
--- a/.github/workflows/stale.yml
+++ /dev/null
@@ -1,37 +0,0 @@
-name: 'Inactive Issue Pinger'
-
-on:
- schedule:
- - cron: '30 5 * * *' # Runs at 5:30 UTC every day
-
-jobs:
- stale:
- runs-on: ubuntu-latest
- permissions:
- issues: write
-
- steps:
- - uses: actions/stale@v10
- with:
- # The message to post on stale issues.
- # This message will ping the issue author.
- # Note: The stale bot action does not currently support a direct placeholder for the last commenter.
- # As a workaround, this message encourages any participant to reply.
- stale-issue-message: >
- Is this issue still important to you?
- Apologies in advance we might have missed this issue as well.
- For faster response times, please post on our Reddit server - https://www.reddit.com/r/unsloth or our Discord - https://discord.com/invite/unsloth
-
- # The number of days of inactivity before an issue is considered stale.
- days-before-issue-stale: 9999
-
- # Set to -1 to never close stale issues.
- days-before-issue-close: -1
-
- # A label to apply to stale issues.
- stale-issue-label: 'inactive'
-
- # The number of operations to perform per run to avoid rate limiting.
- operations-per-run: 500
-
- enable-statistics: false
diff --git a/.gitignore b/.gitignore
deleted file mode 100644
index 57f5f00935..0000000000
--- a/.gitignore
+++ /dev/null
@@ -1,177 +0,0 @@
-# Byte-compiled / optimized / DLL files
-__pycache__/
-*.py[cod]
-*.class
-unsloth_compiled_cache/
-
-# C extensions
-*.so
-
-# Distribution / packaging
-.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-share/python-wheels/
-*.egg-info/
-.installed.cfg
-*.egg
-MANIFEST
-
-# PyInstaller
-# Usually these files are written by a python script from a template
-# before PyInstaller builds the exe, so as to inject date/other infos into it.
-*.manifest
-*.spec
-
-# Installer logs
-pip-log.txt
-pip-delete-this-directory.txt
-
-# Unit test / coverage reports
-htmlcov/
-.tox/
-.nox/
-.coverage
-.coverage.*
-.cache
-nosetests.xml
-coverage.xml
-*.cover
-*.py,cover
-.hypothesis/
-.pytest_cache/
-cover/
-
-# Translations
-*.mo
-*.pot
-
-# Django stuff:
-*.log
-local_settings.py
-db.sqlite3
-db.sqlite3-journal
-
-# Flask stuff:
-instance/
-.webassets-cache
-
-# Scrapy stuff:
-.scrapy
-
-# Sphinx documentation
-docs/_build/
-
-# PyBuilder
-.pybuilder/
-target/
-
-# Jupyter Notebook
-.ipynb_checkpoints
-
-# IPython
-profile_default/
-ipython_config.py
-
-# pyenv
-# For a library or package, you might want to ignore these files since the code is
-# intended to run in multiple environments; otherwise, check them in:
-# .python-version
-
-# pipenv
-# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
-# However, in case of collaboration, if having platform-specific dependencies or dependencies
-# having no cross-platform support, pipenv may install dependencies that don't work, or not
-# install all needed dependencies.
-#Pipfile.lock
-
-# UV
-# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
-# This is especially recommended for binary packages to ensure reproducibility, and is more
-# commonly ignored for libraries.
-#uv.lock
-
-# poetry
-# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
-# This is especially recommended for binary packages to ensure reproducibility, and is more
-# commonly ignored for libraries.
-# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
-#poetry.lock
-
-# pdm
-# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
-#pdm.lock
-# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
-# in version control.
-# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
-.pdm.toml
-.pdm-python
-.pdm-build/
-
-# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
-__pypackages__/
-
-# Celery stuff
-celerybeat-schedule
-celerybeat.pid
-
-# SageMath parsed files
-*.sage.py
-
-# Environments
-.env
-.venv
-env/
-venv/
-ENV/
-env.bak/
-venv.bak/
-
-# Spyder project settings
-.spyderproject
-.spyproject
-
-# Rope project settings
-.ropeproject
-
-# mkdocs documentation
-/site
-
-# mypy
-.mypy_cache/
-.dmypy.json
-dmypy.json
-
-# Pyre type checker
-.pyre/
-
-# pytype static type analyzer
-.pytype/
-
-# Cython debug symbols
-cython_debug/
-
-# PyCharm
-# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
-# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
-# and can be added to the global gitignore or merged into this file. For a more nuclear
-# option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
-
-# Ruff stuff:
-.ruff_cache/
-.pre-commit-cache/
-
-# PyPI configuration file
-.pypirc
-.vscode
diff --git a/.pre-commit-ci.yaml b/.pre-commit-ci.yaml
deleted file mode 100644
index dbcf58ab76..0000000000
--- a/.pre-commit-ci.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-ci:
- autofix_prs: true
- autofix_prs_limit: 5
- autoupdate_schedule: monthly
- autoupdate_commit_msg: "chore: pre-commit autoupdate"
- skip: []
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
deleted file mode 100644
index b55153790f..0000000000
--- a/.pre-commit-config.yaml
+++ /dev/null
@@ -1,17 +0,0 @@
-repos:
- - repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.15.5
- hooks:
- - id: ruff
- args:
- - --fix
- - --exit-non-zero-on-fix
- - repo: local
- hooks:
- - id: ruff-format-with-kwargs
- name: Ruff format with kwarg spacing
- entry: scripts/run_ruff_format.py
- language: python
- types: [python]
- additional_dependencies:
- - ruff==0.6.9
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
deleted file mode 100644
index aa2f827fba..0000000000
--- a/CODE_OF_CONDUCT.md
+++ /dev/null
@@ -1,132 +0,0 @@
-
-# Contributor Covenant Code of Conduct
-
-## Our Pledge
-
-We as members, contributors, and leaders pledge to make participation in our
-community a harassment-free experience for everyone, regardless of age, body
-size, visible or invisible disability, ethnicity, sex characteristics, gender
-identity and expression, level of experience, education, socio-economic status,
-nationality, personal appearance, race, caste, color, religion, or sexual
-identity and orientation.
-
-We pledge to act and interact in ways that contribute to an open, welcoming,
-diverse, inclusive, and healthy community.
-
-## Our Standards
-
-Examples of behavior that contributes to a positive environment for our
-community include:
-
-* Demonstrating empathy and kindness toward other people
-* Being respectful of differing opinions, viewpoints, and experiences
-* Giving and gracefully accepting constructive feedback
-* Accepting responsibility and apologizing to those affected by our mistakes,
- and learning from the experience
-* Focusing on what is best not just for us as individuals, but for the overall
- community
-
-Examples of unacceptable behavior include:
-
-* The use of sexualized language or imagery, and sexual attention or advances of
- any kind
-* Trolling, insulting or derogatory comments, and personal or political attacks
-* Public or private harassment
-* Publishing others' private information, such as a physical or email address,
- without their explicit permission
-* Other conduct which could reasonably be considered inappropriate in a
- professional setting
-
-## Enforcement Responsibilities
-
-Community leaders are responsible for clarifying and enforcing our standards of
-acceptable behavior and will take appropriate and fair corrective action in
-response to any behavior that they deem inappropriate, threatening, offensive,
-or harmful.
-
-Community leaders have the right and responsibility to remove, edit, or reject
-comments, commits, code, wiki edits, issues, and other contributions that are
-not aligned to this Code of Conduct, and will communicate reasons for moderation
-decisions when appropriate.
-
-## Scope
-
-This Code of Conduct applies within all community spaces, and also applies when
-an individual is officially representing the community in public spaces.
-Examples of representing our community include using an official e-mail address,
-posting via an official social media account, or acting as an appointed
-representative at an online or offline event.
-
-## Enforcement
-
-Instances of abusive, harassing, or otherwise unacceptable behavior may be
-reported to the community leaders responsible for enforcement at support@unsloth.ai.
-All complaints will be reviewed and investigated promptly and fairly.
-
-All community leaders are obligated to respect the privacy and security of the
-reporter of any incident.
-
-## Enforcement Guidelines
-
-Community leaders will follow these Community Impact Guidelines in determining
-the consequences for any action they deem in violation of this Code of Conduct:
-
-### 1. Correction
-
-**Community Impact**: Use of inappropriate language or other behavior deemed
-unprofessional or unwelcome in the community.
-
-**Consequence**: A private, written warning from community leaders, providing
-clarity around the nature of the violation and an explanation of why the
-behavior was inappropriate. A public apology may be requested.
-
-### 2. Warning
-
-**Community Impact**: A violation through a single incident or series of
-actions.
-
-**Consequence**: A warning with consequences for continued behavior. No
-interaction with the people involved, including unsolicited interaction with
-those enforcing the Code of Conduct, for a specified period of time. This
-includes avoiding interactions in community spaces as well as external channels
-like social media. Violating these terms may lead to a temporary or permanent
-ban.
-
-### 3. Temporary Ban
-
-**Community Impact**: A serious violation of community standards, including
-sustained inappropriate behavior.
-
-**Consequence**: A temporary ban from any sort of interaction or public
-communication with the community for a specified period of time. No public or
-private interaction with the people involved, including unsolicited interaction
-with those enforcing the Code of Conduct, is allowed during this period.
-Violating these terms may lead to a permanent ban.
-
-### 4. Permanent Ban
-
-**Community Impact**: Demonstrating a pattern of violation of community
-standards, including sustained inappropriate behavior, harassment of an
-individual, or aggression toward or disparagement of classes of individuals.
-
-**Consequence**: A permanent ban from any sort of public interaction within the
-community.
-
-## Attribution
-
-This Code of Conduct is adapted from the [Contributor Covenant][homepage],
-version 2.1, available at
-[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
-
-Community Impact Guidelines were inspired by
-[Mozilla's code of conduct enforcement ladder][Mozilla CoC].
-
-For answers to common questions about this code of conduct, see the FAQ at
-[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
-[https://www.contributor-covenant.org/translations][translations].
-
-[homepage]: https://www.contributor-covenant.org
-[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
-[Mozilla CoC]: https://github.com/mozilla/diversity
-[FAQ]: https://www.contributor-covenant.org/faq
-[translations]: https://www.contributor-covenant.org/translations
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index eb60a5a201..58a2652b5e 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -3,27 +3,27 @@
Thank you for not only using Unsloth but also for being interested in helping out! We value all contributions, whether they come in the form of code, ideas, support for others or just by simply spreading the word of Unsloth! 💕
- **[Support the Community](https://github.com/unslothai/unsloth/issues)**: Answer questions, review pull requests, or assist others in discussions.
-- **Fix Bugs**: Identify and resolve issues with the existing codebase.
-- **Submit Ideas**: Request new features or share enhancements you'd like to see.
+- **Fix Bugs**: Identify and resolve issues with the existing codebase.
+- **Submit Ideas**: Request new features or share enhancements you'd like to see.
- **Develop Features**: Implement new functionality or improve existing tools which can be done via PRs.
- **[Improve Documentation](https://docs.unsloth.ai/)**: Help by creating guides, FAQs, or enhancing clarity.
One of the best ways to support us is by spreading the word about Unsloth! Share how it’s powering your amazing projects in blog posts or social media, and inspire others to explore its potential. Even a simple star on our repo goes a long way in showing your support and helping the community grow. 🌟
-## Submitting Issues
-If you find a bug or have a feature idea, we’d love to hear from you! Here’s how to make your submission stand out:
+## Submitting Issues
+If you find a bug or have a feature idea, we’d love to hear from you! Here’s how to make your submission stand out:
-### Reporting Bugs
-1. **Search First**: Check if the issue has already been reported using GitHub’s search bar under Issues.
-2. **Details Matter**: Is this on Google Colab, Kaggle, or on another platform service? Are you using Unsloth's official notebook? Include your OS, Python version, and other relevant details. For bugs, a concise code snippet that reproduces the issue is incredibly helpful.
+### Reporting Bugs
+1. **Search First**: Check if the issue has already been reported using GitHub’s search bar under Issues.
+2. **Details Matter**: Is this on Google Colab, Kaggle, or on another platform service? Are you using Unsloth's official notebook? Include your OS, Python version, and other relevant details. For bugs, a concise code snippet that reproduces the issue is incredibly helpful.
3. **Be Thorough**: Attach screenshots, traceback logs, or any additional information that might speed up resolution.
## Spread the Word
-Your support extends beyond code:
-- Spread the word by writing about Unsloth in blogs or social media.
-- Share how Unsloth powers your projects.
-- Star our repository to show your appreciation.
+Your support extends beyond code:
+- Spread the word by writing about Unsloth in blogs or social media.
+- Share how Unsloth powers your projects.
+- Star our repository to show your appreciation.
-Finally, please be mindful of our [Code of Conduct](https://github.com/unslothai/unsloth/blob/main/CODE_OF_CONDUCT.md) to ensure a welcoming and inclusive environment for everyone.
+Finally, please be mindful of our [Code of Conduct](https://github.com/unslothai/unsloth/tree/main/unsloth/CODE_OF_CONDUCT.md) to ensure a welcoming and inclusive environment for everyone.
Thank you so much for reading and we hope you have lots of fun using Unsloth! 🦥
diff --git a/README.md b/README.md
index 1314cb1c59..b4be0c7334 100644
--- a/README.md
+++ b/README.md
@@ -1,191 +1,157 @@
-## ✨ Train for Free
+## ✨ Finetune for Free
-Notebooks are beginner friendly. Read our [guide](https://unsloth.ai/docs/get-started/fine-tuning-llms-guide). Add dataset, run, then deploy your trained model.
+Notebooks are beginner friendly. Read our [guide](https://docs.unsloth.ai/get-started/fine-tuning-guide). Add your dataset, click "Run All", and export your finetuned model to GGUF, Ollama, vLLM or Hugging Face.
-| Model | Free Notebooks | Performance | Memory use |
+| Unsloth supports | Free Notebooks | Performance | Memory use |
|-----------|---------|--------|----------|
-| **Qwen3.5 (4B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_5_(4B)_Vision.ipynb) | 1.5x faster | 60% less |
-| **gpt-oss (20B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/gpt-oss-(20B)-Fine-tuning.ipynb) | 2x faster | 70% less |
-| **gpt-oss (20B): GRPO** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/gpt-oss-(20B)-GRPO.ipynb) | 2x faster | 80% less |
-| **Qwen3: Advanced GRPO** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_(4B)-GRPO.ipynb) | 2x faster | 50% less |
-| **Gemma 3 (4B) Vision** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma3_(4B)-Vision.ipynb) | 1.7x faster | 60% less |
-| **embeddinggemma (300M)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/EmbeddingGemma_(300M).ipynb) | 2x faster | 20% less |
-| **Mistral Ministral 3 (3B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Ministral_3_VL_(3B)_Vision.ipynb) | 1.5x faster | 60% less |
-| **Llama 3.1 (8B) Alpaca** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-Alpaca.ipynb) | 2x faster | 70% less |
-| **Llama 3.2 Conversational** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(1B_and_3B)-Conversational.ipynb) | 2x faster | 70% less |
-| **Orpheus-TTS (3B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Orpheus_(3B)-TTS.ipynb) | 1.5x faster | 50% less |
-
-- See all our notebooks for: [Kaggle](https://github.com/unslothai/notebooks?tab=readme-ov-file#-kaggle-notebooks), [GRPO](https://unsloth.ai/docs/get-started/unsloth-notebooks#grpo-reasoning-rl-notebooks), [TTS](https://unsloth.ai/docs/get-started/unsloth-notebooks#text-to-speech-tts-notebooks), [embedding](https://unsloth.ai/docs/new/embedding-finetuning) & [Vision](https://unsloth.ai/docs/get-started/unsloth-notebooks#vision-multimodal-notebooks)
-- See [all our models](https://unsloth.ai/docs/get-started/unsloth-model-catalog) and [all our notebooks](https://unsloth.ai/docs/get-started/unsloth-notebooks)
-- See detailed documentation for Unsloth [here](https://unsloth.ai/docs)
+| **GRPO (R1 reasoning)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb) | 2x faster | 80% less |
+| **Gemma 3 (4B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma3_(4B).ipynb) | 1.6x faster | 60% less |
+| **Llama 3.2 (3B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(1B_and_3B)-Conversational.ipynb) | 2x faster | 70% less |
+| **Phi-4 (14B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Phi_4-Conversational.ipynb) | 2x faster | 70% less |
+| **Llama 3.2 Vision (11B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb) | 2x faster | 50% less |
+| **Llama 3.1 (8B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-Alpaca.ipynb) | 2x faster | 70% less |
+| **Qwen 2.5 (7B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen2.5_(7B)-Alpaca.ipynb) | 2x faster | 70% less |
+| **Mistral v0.3 (7B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Mistral_v0.3_(7B)-Conversational.ipynb) | 2.2x faster | 75% less |
+| **Ollama** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb) | 1.9x faster | 60% less |
+| **DPO Zephyr** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Zephyr_(7B)-DPO.ipynb) | 1.9x faster | 50% less |
+
+- See [all our notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks) and [all our models](https://docs.unsloth.ai/get-started/all-our-models)
+- **Kaggle Notebooks** for [Llama 3.2 Kaggle notebook](https://www.kaggle.com/danielhanchen/kaggle-llama-3-2-1b-3b-unsloth-notebook), [Llama 3.1 (8B)](https://www.kaggle.com/danielhanchen/kaggle-llama-3-1-8b-unsloth-notebook), [Phi-4 (14B)](https://www.kaggle.com/code/danielhanchen/phi-4-finetuning-unsloth-notebook), [Mistral (7B)](https://www.kaggle.com/code/danielhanchen/kaggle-mistral-7b-unsloth-notebook)
+- See detailed documentation for Unsloth [here](https://docs.unsloth.ai/).
## ⚡ Quickstart
-### Linux or WSL
-```bash
+
+- **Install with pip (recommended)** for Linux devices:
+```
pip install unsloth
```
-### Windows
-For Windows, `pip install unsloth` works only if you have Pytorch installed. Read our [Windows Guide](https://unsloth.ai/docs/get-started/install/windows-installation).
-
-### Docker
-Use our official [Unsloth Docker image](https://hub.docker.com/r/unsloth/unsloth) ```unsloth/unsloth``` container. Read our [Docker Guide](https://unsloth.ai/docs/get-started/install/docker).
-
-### AMD, Intel, Blackwell & DGX Spark
-For RTX 50x, B200, 6000 GPUs: `pip install unsloth`. Read our guides for: [Blackwell](https://unsloth.ai/docs/blog/fine-tuning-llms-with-blackwell-rtx-50-series-and-unsloth) and [DGX Spark](https://unsloth.ai/docs/blog/fine-tuning-llms-with-nvidia-dgx-spark-and-unsloth).
-To install Unsloth on **AMD** and **Intel** GPUs, follow our [AMD Guide](https://unsloth.ai/docs/get-started/install/amd) and [Intel Guide](https://unsloth.ai/docs/get-started/install/intel).
-
-## 🦥 Unsloth News
-- **Qwen3.5** - 0.8B, 2B, 4B, 9B, 27B, 35-A3B, 112B-A10B are now supported. [Guide + notebooks](https://unsloth.ai/docs/models/qwen3.5/fine-tune)
-- Train **MoE LLMs 12x faster** with 35% less VRAM - DeepSeek, GLM, Qwen and gpt-oss. [Blog](https://unsloth.ai/docs/new/faster-moe)
-- **Embedding models**: Unsloth now supports ~1.8-3.3x faster embedding fine-tuning. [Blog](https://unsloth.ai/docs/new/embedding-finetuning) • [Notebooks](https://unsloth.ai/docs/get-started/unsloth-notebooks#embedding-models)
-- New **7x longer context RL** vs. all other setups, via our new batching algorithms. [Blog](https://unsloth.ai/docs/new/grpo-long-context)
-- New RoPE & MLP **Triton Kernels** & **Padding Free + Packing**: 3x faster training & 30% less VRAM. [Blog](https://unsloth.ai/docs/new/3x-faster-training-packing)
-- **500K Context**: Training a 20B model with >500K context is now possible on an 80GB GPU. [Blog](https://unsloth.ai/docs/blog/500k-context-length-fine-tuning)
-- **FP8 & Vision RL**: You can now do FP8 & VLM GRPO on consumer GPUs. [FP8 Blog](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide/fp8-reinforcement-learning) • [Vision RL](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide/vision-reinforcement-learning-vlm-rl)
-- **Docker**: Use Unsloth with no setup & environment issues with our new image. [Guide](https://unsloth.ai/docs/blog/how-to-fine-tune-llms-with-unsloth-and-docker) • [Docker image](https://hub.docker.com/r/unsloth/unsloth)
-- **gpt-oss** by OpenAI: Read our [RL blog](https://unsloth.ai/docs/models/gpt-oss-how-to-run-and-fine-tune/gpt-oss-reinforcement-learning), [Flex Attention](https://unsloth.ai/docs/models/gpt-oss-how-to-run-and-fine-tune/long-context-gpt-oss-training) blog and [Guide](https://unsloth.ai/docs/models/gpt-oss-how-to-run-and-fine-tune).
-
+For Windows install instructions, see [here](https://docs.unsloth.ai/get-started/installing-+-updating/windows-installation).
+
+## 🦥 Unsloth.ai News
+- 📣 NEW! [**EVERYTHING** is now supported](https://unsloth.ai/blog/gemma3#everything) incuding: full finetuning, pretraining, ALL models (Mixtral, MOE, Cohere, Mamba) and all training algorithms (KTO, DoRA) etc. MultiGPU support coming very soon.
+- 📣 NEW! **Gemma 3** by Google: [Read Blog](https://unsloth.ai/blog/gemma3). We [uploaded GGUFs, 4-bit models](https://huggingface.co/collections/unsloth/phi-4-all-versions-677eecf93784e61afe762afa).
+- 📣 NEW! Introducing Long-context [Reasoning (GRPO)](https://unsloth.ai/blog/grpo) in Unsloth. Train your own reasoning model with just 5GB VRAM. Transform Llama, Phi, Mistral etc. into reasoning LLMs!
+- 📣 NEW! [DeepSeek-R1](https://unsloth.ai/blog/deepseek-r1) - the most powerful open reasoning models with Llama & Qwen distillations. Run or fine-tune them now [with our guide](https://unsloth.ai/blog/deepseek-r1). All model uploads: [here](https://huggingface.co/collections/unsloth/deepseek-r1-all-versions-678e1c48f5d2fce87892ace5).
+- 📣 NEW! [Phi-4](https://unsloth.ai/blog/phi4) by Microsoft: We also [fixed bugs](https://unsloth.ai/blog/phi4) in Phi-4 and [uploaded GGUFs, 4-bit](https://huggingface.co/collections/unsloth/phi-4-all-versions-677eecf93784e61afe762afa).
+- 📣 NEW! [Llama 3.3 (70B)](https://huggingface.co/collections/unsloth/llama-33-all-versions-67535d7d994794b9d7cf5e9f), Meta's latest model is supported.
+- 📣 Introducing Unsloth [Dynamic 4-bit Quantization](https://unsloth.ai/blog/dynamic-4bit)! We dynamically opt not to quantize certain parameters and this greatly increases accuracy while only using <10% more VRAM than BnB 4-bit. See our collection on [Hugging Face here.](https://huggingface.co/collections/unsloth/unsloth-4-bit-dynamic-quants-67503bb873f89e15276c44e7)
+- 📣 [Vision models](https://unsloth.ai/blog/vision) now supported! [Llama 3.2 Vision (11B)](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb), [Qwen 2.5 VL (7B)](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen2_VL_(7B)-Vision.ipynb) and [Pixtral (12B) 2409](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Pixtral_(12B)-Vision.ipynb)
Click for more news
-- **Quantization-Aware Training**: We collabed with Pytorch, recovering ~70% accuracy. [Read blog](https://unsloth.ai/docs/blog/quantization-aware-training-qat)
-- **Memory-efficient RL**: We're introducing even better RL. Our new kernels & algos allows faster RL with 50% less VRAM & 10× more context. [Read blog](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide/memory-efficient-rl)
-- **Mistral 3**: Run Ministral 3 or Devstral 2 and fine-tune with vision/RL sudoku notebooks. [Guide](https://unsloth.ai/docs/models/tutorials/ministral-3) • [Notebooks](https://unsloth.ai/docs/models/ministral-3#fine-tuning-ministral-3)
-- **Gemma 3n** by Google: [Read Blog](https://unsloth.ai/docs/models/gemma-3-how-to-run-and-fine-tune/gemma-3n-how-to-run-and-fine-tune). We [uploaded GGUFs, 4-bit models](https://huggingface.co/collections/unsloth/gemma-3n-685d3874830e49e1c93f9339).
-- **[Text-to-Speech (TTS)](https://unsloth.ai/docs/basics/text-to-speech-tts-fine-tuning)** is now supported, including `sesame/csm-1b` and STT `openai/whisper-large-v3`.
-- **[Qwen3](https://unsloth.ai/docs/models/qwen3-how-to-run-and-fine-tune)** is now supported. Qwen3-30B-A3B fits on 17.5GB VRAM.
-- Introducing **[Dynamic 2.0](https://unsloth.ai/docs/basics/unsloth-dynamic-2.0-ggufs)** quants that set new benchmarks on 5-shot MMLU & Aider Polyglot.
-- [**EVERYTHING** is now supported](https://unsloth.ai/blog/gemma3#everything) - all models (TTS, BERT, Mamba), FFT, etc. [MultiGPU](https://unsloth.ai/docs/basics/multi-gpu-training-with-unsloth) is now supported. Enable FFT with `full_finetuning = True`, 8-bit with `load_in_8bit = True`.
-- 📣 [DeepSeek-R1](https://unsloth.ai/blog/deepseek-r1) - run or fine-tune them [with our guide](https://unsloth.ai/blog/deepseek-r1). All model uploads: [here](https://huggingface.co/collections/unsloth/deepseek-r1-all-versions-678e1c48f5d2fce87892ace5).
-- 📣 Introducing Long-context [Reasoning (GRPO)](https://unsloth.ai/blog/grpo) in Unsloth. Train your own reasoning model with just 5GB VRAM. Transform Llama, Phi, Mistral etc. into reasoning LLMs!
-- 📣 Introducing Unsloth [Dynamic 4-bit Quantization](https://unsloth.ai/blog/dynamic-4bit)! We dynamically opt not to quantize certain parameters and this greatly increases accuracy while only using <10% more VRAM than BnB 4-bit. See our collection on [Hugging Face here.](https://huggingface.co/collections/unsloth/unsloth-4-bit-dynamic-quants-67503bb873f89e15276c44e7)
-- 📣 **[Llama 4](https://unsloth.ai/blog/llama4)** by Meta, including Scout & Maverick are now supported.
-- 📣 [Phi-4](https://unsloth.ai/blog/phi4) by Microsoft: We also [fixed bugs](https://unsloth.ai/blog/phi4) in Phi-4 and [uploaded GGUFs, 4-bit](https://huggingface.co/collections/unsloth/phi-4-all-versions-677eecf93784e61afe762afa).
-- 📣 [Vision models](https://unsloth.ai/blog/vision) now supported! [Llama 3.2 Vision (11B)](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb), [Qwen 2.5 VL (7B)](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen2_VL_(7B)-Vision.ipynb) and [Pixtral (12B) 2409](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Pixtral_(12B)-Vision.ipynb)
-- 📣 [Llama 3.3 (70B)](https://huggingface.co/collections/unsloth/llama-33-all-versions-67535d7d994794b9d7cf5e9f), Meta's latest model is supported.
-- 📣 We worked with Apple to add [Cut Cross Entropy](https://arxiv.org/abs/2411.09009). Unsloth now supports 89K context for Meta's Llama 3.3 (70B) on a 80GB GPU - 13x longer than HF+FA2. For Llama 3.1 (8B), Unsloth enables 342K context, surpassing its native 128K support.
+- 📣 NEW! We worked with Apple to add [Cut Cross Entropy](https://arxiv.org/abs/2411.09009). Unsloth now supports 89K context for Meta's Llama 3.3 (70B) on a 80GB GPU - 13x longer than HF+FA2. For Llama 3.1 (8B), Unsloth enables 342K context, surpassing its native 128K support.
- 📣 We found and helped fix a [gradient accumulation bug](https://unsloth.ai/blog/gradient)! Please update Unsloth and transformers.
+- 📣 Try out [Chat interface](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Unsloth_Studio.ipynb)!
+- 📣 NEW! Qwen-2.5 including [Coder](https://unsloth.ai/blog/qwen-coder) models are now supported with bugfixes. 14b fits in a Colab GPU! [Qwen 2.5 conversational notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen2.5_Coder_(14B)-Conversational.ipynb)
+- 📣 NEW! [Mistral Small 22b notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Mistral_Small_(22B)-Alpaca.ipynb) finetuning fits in under 16GB of VRAM!
+- 📣 NEW! `pip install unsloth` now works! Head over to [pypi](https://pypi.org/project/unsloth/) to check it out! This allows non git pull installs. Use `pip install unsloth[colab-new]` for non dependency installs.
+- 📣 NEW! Continued Pretraining [notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Mistral_v0.3_(7B)-CPT.ipynb) for other languages like Korean!
+- 📣 [2x faster inference](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-Inference.ipynb) added for all our models
- 📣 We cut memory usage by a [further 30%](https://unsloth.ai/blog/long-context) and now support [4x longer context windows](https://unsloth.ai/blog/long-context)!
## 🔗 Links and Resources
-| Type | Links |
-| ----------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------ |
-|
**r/unsloth Reddit** | [Join Reddit community](https://reddit.com/r/unsloth) |
-| 📚 **Documentation & Wiki** | [Read Our Docs](https://unsloth.ai/docs) |
-|
**Twitter (aka X)** | [Follow us on X](https://twitter.com/unslothai) |
-| 💾 **Installation** | [Pip & Docker Install](https://unsloth.ai/docs/get-started/install) |
-| 🔮 **Our Models** | [Unsloth Catalog](https://unsloth.ai/docs/get-started/unsloth-model-catalog) |
-| ✍️ **Blog** | [Read our Blogs](https://unsloth.ai/blog) |
+| Type | Links |
+| ------------------------------- | --------------------------------------- |
+| 📚 **Documentation & Wiki** | [Read Our Docs](https://docs.unsloth.ai) |
+|
**Twitter (aka X)** | [Follow us on X](https://twitter.com/unslothai)|
+| 💾 **Installation** | [Pip install](https://docs.unsloth.ai/get-started/installing-+-updating)|
+| 🔮 **Our Models** | [Unsloth Releases](https://docs.unsloth.ai/get-started/all-our-models)|
+| ✍️ **Blog** | [Read our Blogs](https://unsloth.ai/blog)|
+|
**Reddit** | [Join our Reddit page](https://reddit.com/r/unsloth)|
## ⭐ Key Features
-
-* Supports **full-finetuning**, pretraining, 4-bit, 16-bit and **FP8** training
-* Supports **all models** including [TTS](https://unsloth.ai/docs/basics/text-to-speech-tts-fine-tuning), multimodal, [embedding](https://unsloth.ai/docs/new/embedding-finetuning) and more! Any model that works in transformers, works in Unsloth.
-* The most efficient library for [Reinforcement Learning (RL)](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide), using 80% less VRAM. Supports GRPO, GSPO, DrGRPO, DAPO etc.
-* **0% loss in accuracy** - no approximation methods - all exact.
-* Export and [deploy your model](https://unsloth.ai/docs/basics/inference-and-deployment) to [GGUF](https://unsloth.ai/docs/basics/inference-and-deployment/saving-to-gguf) llama.cpp, [vLLM](https://unsloth.ai/docs/basics/inference-and-deployment/vllm-guide), [SGLang](https://unsloth.ai/docs/basics/inference-and-deployment/sglang-guide) and Hugging Face.
-* Supports NVIDIA (since 2018), [AMD](https://unsloth.ai/docs/get-started/install/amd) and [Intel](https://unsloth.ai/docs/get-started/install/intel) GPUs. Minimum CUDA Capability 7.0 (V100, T4, Titan V, RTX 20, 30, 40x, A100, H100, L40 etc)
-* Works on **Linux**, WSL and **[Windows](https://unsloth.ai/docs/get-started/install/windows-installation)**
-* All kernels written in OpenAI's Triton language. Manual backprop engine.
-* If you trained a model with 🦥Unsloth, you can use this cool sticker!
+- All kernels written in [OpenAI's Triton](https://openai.com/index/triton/) language. **Manual backprop engine**.
+- **0% loss in accuracy** - no approximation methods - all exact.
+- No change of hardware. Supports NVIDIA GPUs since 2018+. Minimum CUDA Capability 7.0 (V100, T4, Titan V, RTX 20, 30, 40x, A100, H100, L40 etc) [Check your GPU!](https://developer.nvidia.com/cuda-gpus) GTX 1070, 1080 works, but is slow.
+- Works on **Linux** and **Windows**
+- Supports 4bit and 16bit QLoRA / LoRA finetuning via [bitsandbytes](https://github.com/TimDettmers/bitsandbytes).
+- If you trained a model with 🦥Unsloth, you can use this cool sticker!
## 💾 Install Unsloth
-You can also see our docs for more detailed installation and updating instructions [here](https://unsloth.ai/docs/get-started/install).
-
-Unsloth supports Python 3.13 or lower.
+You can also see our documentation for more detailed installation and updating instructions [here](https://docs.unsloth.ai/get-started/installing-+-updating).
### Pip Installation
**Install with pip (recommended) for Linux devices:**
```
pip install unsloth
```
-**To update Unsloth:**
-```
-pip install --upgrade --force-reinstall --no-cache-dir unsloth unsloth_zoo
-```
-See [here](#advanced-pip-installation) for advanced pip install instructions.
+See [here](https://github.com/unslothai/unsloth/edit/main/README.md#advanced-pip-installation) for advanced pip install instructions.
### Windows Installation
-For this method, we will be utilizing Anaconda. You can view the [full guide with screenshots here](https://unsloth.ai/docs/get-started/install/windows-installation).
-1. **Install Miniconda (or Anaconda):** Miniconda is recommended. Install [Miniconda](https://www.anaconda.com/docs/getting-started/miniconda/install) or [Anaconda](https://www.anaconda.com/download), then open Anaconda PowerShell Prompt to continue.
-
-2. **Create a Conda Environment:** Create and activate a fresh Python 3.12 environment for Unsloth.
+> [!warning]
+> Python 3.13 does not support Unsloth. Use 3.12, 3.11 or 3.10
- ```bash
- conda create --name unsloth_env python==3.12 -y
- conda activate unsloth_env
- ```
+1. **Install NVIDIA Video Driver:**
+ You should install the latest version of your GPUs driver. Download drivers here: [NVIDIA GPU Drive](https://www.nvidia.com/Download/index.aspx).
-3. **Check Your GPU and CUDA Version:** Run `nvidia-smi` to confirm that your NVIDIA GPU is detected and note the CUDA version shown in the output. If `nvidia-smi` does not work, reinstall the latest [NVIDIA drivers](https://www.nvidia.com/en-us/drivers/).
+3. **Install Visual Studio C++:**
+ You will need Visual Studio, with C++ installed. By default, C++ is not installed with [Visual Studio](https://visualstudio.microsoft.com/vs/community/), so make sure you select all of the C++ options. Also select options for Windows 10/11 SDK. For detailed instructions with options, see [here](https://docs.unsloth.ai/get-started/installing-+-updating).
-4. **Install PyTorch:** Install the Windows pip build of PyTorch that matches your CUDA version. Use [Install PyTorch](https://pytorch.org/get-started/locally/) to select the correct command for your system, then verify that PyTorch can see your GPU.
+5. **Install CUDA Toolkit:**
+ Follow the instructions to install [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit-archive).
- ```python
- import torch
- print(torch.cuda.is_available())
- A = torch.ones((10, 10), device="cuda")
- B = torch.ones((10, 10), device="cuda")
- A @ B
- ```
+6. **Install PyTorch:**
+ You will need the correct version of PyTorch that is compatibile with your CUDA drivers, so make sure to select them carefully.
+ [Install PyTorch](https://pytorch.org/get-started/locally/).
-5. **Install Unsloth:** Only install Unsloth after PyTorch is working correctly.
+7. **Install Unsloth:**
+
+```python
+pip install unsloth
+```
- ```bash
- pip install unsloth
- ```
+#### Notes
+To run Unsloth directly on Windows:
+- Install Triton from this Windows fork and follow the instructions [here](https://github.com/woct0rdho/triton-windows) (be aware that the Windows fork requires PyTorch >= 2.4 and CUDA 12)
+- In the SFTTrainer, set `dataset_num_proc=1` to avoid a crashing issue:
+```python
+trainer = SFTTrainer(
+ dataset_num_proc=1,
+ ...
+)
+```
#### Advanced/Troubleshooting
-For **advanced installation instructions** or if you see weird errors during installations:
-First try using an isolated environment via then `pip install unsloth`
-```bash
-python -m venv unsloth
-source unsloth/bin/activate
-pip install unsloth
-```
+For **advanced installation instructions** or if you see weird errors during installations:
1. Install `torch` and `triton`. Go to https://pytorch.org to install it. For example `pip install torch torchvision torchaudio triton`
-2. Confirm if CUDA is installed correctly. Try `nvcc`. If that fails, you need to install `cudatoolkit` or CUDA drivers.
-3. Install `xformers` manually via:
- ```bash
- pip install ninja
- pip install -v --no-build-isolation -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
- ```
- Check if `xformers` succeeded with `python -m xformers.info` Go to https://github.com/facebookresearch/xformers. Another option is to install `flash-attn` for Ampere GPUs and ignore `xformers`
-
-4. For GRPO runs, you can try installing `vllm` and seeing if `pip install vllm` succeeds.
-5. Double check that your versions of Python, CUDA, CUDNN, `torch`, `triton`, and `xformers` are compatible with one another. The [PyTorch Compatibility Matrix](https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix) may be useful.
-6. Finally, install `bitsandbytes` and check it with `python -m bitsandbytes`
+2. Confirm if CUDA is installated correctly. Try `nvcc`. If that fails, you need to install `cudatoolkit` or CUDA drivers.
+3. Install `xformers` manually. You can try installing `vllm` and seeing if `vllm` succeeds. Check if `xformers` succeeded with `python -m xformers.info` Go to https://github.com/facebookresearch/xformers. Another option is to install `flash-attn` for Ampere GPUs.
+4. Double check that your versions of Python, CUDA, CUDNN, `torch`, `triton`, and `xformers` are compatible with one another. The [PyTorch Compatibility Matrix](https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix) may be useful.
+5. Finally, install `bitsandbytes` and check it with `python -m bitsandbytes`
### Conda Installation (Optional)
-`⚠️Only use Conda if you have it. If not, use Pip`. We support `python=3.10,3.11,3.12,3.13`.
+`⚠️Only use Conda if you have it. If not, use Pip`. Select either `pytorch-cuda=11.8,12.1` for CUDA 11.8 or CUDA 12.1. We support `python=3.10,3.11,3.12`.
```bash
-conda create --name unsloth_env python==3.12 -y
+conda create --name unsloth_env \
+ python=3.11 \
+ pytorch-cuda=12.1 \
+ pytorch cudatoolkit xformers -c pytorch -c nvidia -c xformers \
+ -y
conda activate unsloth_env
+
+pip install unsloth
```
-Use `nvidia-smi` to get the correct CUDA version like 13.0 which becomes `cu130`
-```bash
-pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu130
-pip3 install unsloth
-```
+
If you're looking to install Conda in a Linux environment, read here, or run the below 🔽
@@ -200,9 +166,9 @@ pip3 install unsloth
### Advanced Pip Installation
-`⚠️Do **NOT** use this if you have Conda.` Pip is a bit more complex since there are dependency issues. The pip command is different for `torch 2.2,2.3,2.4,2.5,2.6,2.7,2.8,2.9,2.10` and CUDA versions.
+`⚠️Do **NOT** use this if you have Conda.` Pip is a bit more complex since there are dependency issues. The pip command is different for `torch 2.2,2.3,2.4,2.5` and CUDA versions.
-For other torch versions, we support `torch211`, `torch212`, `torch220`, `torch230`, `torch240`, `torch250`, `torch260`, `torch270`, `torch280`, `torch290`, `torch2100` and for CUDA versions, we support `cu118` and `cu121` and `cu124`. For Ampere devices (A100, H100, RTX3090) and above, use `cu118-ampere` or `cu121-ampere` or `cu124-ampere`. Note: torch 2.10 only supports CUDA 12.6, 12.8, and 13.0.
+For other torch versions, we support `torch211`, `torch212`, `torch220`, `torch230`, `torch240` and for CUDA versions, we support `cu118` and `cu121` and `cu124`. For Ampere devices (A100, H100, RTX3090) and above, use `cu118-ampere` or `cu121-ampere` or `cu124-ampere`.
For example, if you have `torch 2.4` and `CUDA 12.1`, use:
```bash
@@ -210,16 +176,10 @@ pip install --upgrade pip
pip install "unsloth[cu121-torch240] @ git+https://github.com/unslothai/unsloth.git"
```
-Another example, if you have `torch 2.9` and `CUDA 13.0`, use:
+Another example, if you have `torch 2.5` and `CUDA 12.4`, use:
```bash
pip install --upgrade pip
-pip install "unsloth[cu130-torch290] @ git+https://github.com/unslothai/unsloth.git"
-```
-
-Another example, if you have `torch 2.10` and `CUDA 12.6`, use:
-```bash
-pip install --upgrade pip
-pip install "unsloth[cu126-torch2100] @ git+https://github.com/unslothai/unsloth.git"
+pip install "unsloth[cu124-torch250] @ git+https://github.com/unslothai/unsloth.git"
```
And other examples:
@@ -246,81 +206,65 @@ Or, run the below manually in a Python REPL:
try: import torch
except: raise ImportError('Install torch via `pip install torch`')
from packaging.version import Version as V
-import re
-v = V(re.match(r"[0-9\.]{3,}", torch.__version__).group(0))
+v = V(torch.__version__)
cuda = str(torch.version.cuda)
is_ampere = torch.cuda.get_device_capability()[0] >= 8
-USE_ABI = torch._C._GLIBCXX_USE_CXX11_ABI
-if cuda not in ("11.8", "12.1", "12.4", "12.6", "12.8", "13.0"): raise RuntimeError(f"CUDA = {cuda} not supported!")
+if cuda != "12.1" and cuda != "11.8" and cuda != "12.4": raise RuntimeError(f"CUDA = {cuda} not supported!")
if v <= V('2.1.0'): raise RuntimeError(f"Torch = {v} too old!")
elif v <= V('2.1.1'): x = 'cu{}{}-torch211'
elif v <= V('2.1.2'): x = 'cu{}{}-torch212'
elif v < V('2.3.0'): x = 'cu{}{}-torch220'
elif v < V('2.4.0'): x = 'cu{}{}-torch230'
elif v < V('2.5.0'): x = 'cu{}{}-torch240'
-elif v < V('2.5.1'): x = 'cu{}{}-torch250'
-elif v <= V('2.5.1'): x = 'cu{}{}-torch251'
-elif v < V('2.7.0'): x = 'cu{}{}-torch260'
-elif v < V('2.7.9'): x = 'cu{}{}-torch270'
-elif v < V('2.8.0'): x = 'cu{}{}-torch271'
-elif v < V('2.8.9'): x = 'cu{}{}-torch280'
-elif v < V('2.9.1'): x = 'cu{}{}-torch290'
-elif v < V('2.9.2'): x = 'cu{}{}-torch291'
-elif v < V('2.10.1'): x = 'cu{}{}-torch2100'
+elif v < V('2.6.0'): x = 'cu{}{}-torch250'
else: raise RuntimeError(f"Torch = {v} too new!")
-if v > V('2.6.9') and cuda not in ("11.8", "12.6", "12.8", "13.0"): raise RuntimeError(f"CUDA = {cuda} not supported!")
-if v >= V('2.10.0') and cuda not in ("12.6", "12.8", "13.0"): raise RuntimeError(f"Torch 2.10 requires CUDA 12.6, 12.8, or 13.0! Got CUDA = {cuda}")
-x = x.format(cuda.replace(".", ""), "-ampere" if False else "") # is_ampere is broken due to flash-attn
-print(f'pip install --upgrade pip && pip install --no-deps git+https://github.com/unslothai/unsloth-zoo.git && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git" --no-build-isolation')
-```
-### Docker Installation
-You can use our pre-built Docker container with all dependencies to use Unsloth instantly with no setup required.
-[Read our guide](https://unsloth.ai/docs/get-started/install/docker).
-
-This container requires installing [NVIDIA's Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
-
-```bash
-docker run -d -e JUPYTER_PASSWORD="mypassword" \
- -p 8888:8888 -p 2222:22 \
- -v $(pwd)/work:/workspace/work \
- --gpus all \
- unsloth/unsloth
+x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
+print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"')
```
-Access Jupyter Lab at `http://localhost:8888` and start fine-tuning!
-
## 📜 Documentation
-* Go to our official [Documentation](https://unsloth.ai/docs) for [running models](https://unsloth.ai/docs/basics/inference-and-deployment), [saving to GGUF](https://unsloth.ai/docs/basics/inference-and-deployment/saving-to-gguf), [checkpointing](https://unsloth.ai/docs/basics/finetuning-from-last-checkpoint), [evaluation](https://unsloth.ai/docs/get-started/fine-tuning-llms-guide#evaluation) and more!
-* Read our Guides for: [Fine-tuning](https://unsloth.ai/docs/get-started/fine-tuning-llms-guide), [Reinforcement Learning](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide), [Text-to-Speech (TTS)](https://unsloth.ai/docs/basics/text-to-speech-tts-fine-tuning), [Vision](https://unsloth.ai/docs/basics/vision-fine-tuning) and [any model](https://unsloth.ai/docs/models/tutorials).
-* We support Huggingface's transformers, TRL, Trainer, Seq2SeqTrainer and Pytorch code.
+- Go to our official [Documentation](https://docs.unsloth.ai) for saving to GGUF, checkpointing, evaluation and more!
+- We support Huggingface's TRL, Trainer, Seq2SeqTrainer or even Pytorch code!
+- We're in 🤗Hugging Face's official docs! Check out the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)!
+- If you want to download models from the ModelScope community, please use an environment variable: `UNSLOTH_USE_MODELSCOPE=1`, and install the modelscope library by: `pip install modelscope -U`.
-Unsloth example code to fine-tune gpt-oss-20b:
+> unsloth_cli.py also supports `UNSLOTH_USE_MODELSCOPE=1` to download models and datasets. please remember to use the model and dataset id in the ModelScope community.
```python
-from unsloth import FastLanguageModel, FastModel, FastVisionModel
+from unsloth import FastLanguageModel
import torch
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
-max_seq_length = 2048 # Supports RoPE Scaling internally, so choose any!
+max_seq_length = 2048 # Supports RoPE Scaling interally, so choose any!
# Get LAION dataset
url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
dataset = load_dataset("json", data_files = {"train" : url}, split = "train")
# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
fourbit_models = [
- "unsloth/gpt-oss-20b-unsloth-bnb-4bit", #or choose any model
-
+ "unsloth/Meta-Llama-3.1-8B-bnb-4bit", # Llama-3.1 2x faster
+ "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
+ "unsloth/Meta-Llama-3.1-70B-bnb-4bit",
+ "unsloth/Meta-Llama-3.1-405B-bnb-4bit", # 4bit for 405b!
+ "unsloth/Mistral-Small-Instruct-2409", # Mistral 22b 2x faster!
+ "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
+ "unsloth/Phi-3.5-mini-instruct", # Phi-3.5 2x faster!
+ "unsloth/Phi-3-medium-4k-instruct",
+ "unsloth/gemma-2-9b-bnb-4bit",
+ "unsloth/gemma-2-27b-bnb-4bit", # Gemma 2x faster!
+
+ "unsloth/Llama-3.2-1B-bnb-4bit", # NEW! Llama 3.2 models
+ "unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
+ "unsloth/Llama-3.2-3B-bnb-4bit",
+ "unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
+
+ "unsloth/Llama-3.3-70B-Instruct-bnb-4bit" # NEW! Llama 3.3 70B!
] # More models at https://huggingface.co/unsloth
model, tokenizer = FastLanguageModel.from_pretrained(
- model_name = "unsloth/gpt-oss-20b",
- max_seq_length = max_seq_length, # Choose any for long context!
- load_in_4bit = True, # 4-bit quantization. False = 16-bit LoRA.
- load_in_8bit = False, # 8-bit quantization
- load_in_16bit = False, # 16-bit LoRA
- full_finetuning = False, # Use for full fine-tuning.
- trust_remote_code = False, # Enable to support new models
- # token = "hf_...", # use one if using gated models
+ model_name = "unsloth/Llama-3.2-1B",
+ max_seq_length = max_seq_length,
+ load_in_4bit = True,
)
# Do model patching and add fast LoRA weights
@@ -345,6 +289,7 @@ trainer = SFTTrainer(
train_dataset = dataset,
tokenizer = tokenizer,
args = SFTConfig(
+ dataset_text_field = "text",
max_seq_length = max_seq_length,
per_device_train_batch_size = 2,
gradient_accumulation_steps = 4,
@@ -358,8 +303,8 @@ trainer = SFTTrainer(
)
trainer.train()
-# Go to https://unsloth.ai/docs for advanced tips like
-# (1) Saving to GGUF / merging to 16bit for vLLM or SGLang
+# Go to https://github.com/unslothai/unsloth/wiki for advanced tips like
+# (1) Saving to GGUF / merging to 16bit for vLLM
# (2) Continued training from a saved LoRA adapter
# (3) Adding an evaluation loop / OOMs
# (4) Customized chat templates
@@ -367,19 +312,69 @@ trainer.train()
## 💡 Reinforcement Learning
-[RL](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide) including [GRPO](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide#training-with-grpo), [GSPO](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide/gspo-reinforcement-learning), [**FP8** training](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide/fp8-reinforcement-learning), DrGRPO, DAPO, PPO, Reward Modelling, Online DPO all work with Unsloth.
+RL including DPO, GRPO, PPO, Reward Modelling, Online DPO all work with Unsloth. We're in 🤗Hugging Face's official docs! We're on the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and the [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)! List of RL notebooks:
-Read our [Reinforcement Learning Guide](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide) or our [advanced RL docs](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide/advanced-rl-documentation) for batching, generation & training parameters.
-
-List of RL notebooks:
-- gpt-oss GRPO notebook: [Link](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/gpt-oss-(20B)-GRPO.ipynb)
-- ***FP8*** Qwen3-8B GRPO notebook (L4): [Link](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_8B_FP8_GRPO.ipynb)
-- Qwen3-VL GSPO notebook: [Link](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_VL_(8B)-Vision-GRPO.ipynb)
-- Advanced Qwen3 GRPO notebook: [Link](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_(4B)-GRPO.ipynb)
- ORPO notebook: [Link](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-ORPO.ipynb)
- DPO Zephyr notebook: [Link](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Zephyr_(7B)-DPO.ipynb)
-- KTO notebook: [Link](https://colab.research.google.com/drive/1MRgGtLWuZX4ypSfGguFgC-IblTvO2ivM?usp=sharing)
-- SimPO notebook: [Link](https://colab.research.google.com/drive/1Hs5oQDovOay4mFA6Y9lQhVJ8TnbFLFh2?usp=sharing)
+- KTO notebook: [Link](https://colab.research.google.com/drive/1a2b3c4d5e6f7g8h9i0j)
+- SimPO notebook: [Link](https://colab.research.google.com/drive/1a2b3c4d5e6f7g8h9i0j)
+
+
+ Click for DPO code
+
+```python
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Optional set GPU device ID
+
+from unsloth import FastLanguageModel
+import torch
+from trl import DPOTrainer, DPOConfig
+max_seq_length = 2048
+
+model, tokenizer = FastLanguageModel.from_pretrained(
+ model_name = "unsloth/zephyr-sft-bnb-4bit",
+ max_seq_length = max_seq_length,
+ load_in_4bit = True,
+)
+
+# Do model patching and add fast LoRA weights
+model = FastLanguageModel.get_peft_model(
+ model,
+ r = 64,
+ target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
+ "gate_proj", "up_proj", "down_proj",],
+ lora_alpha = 64,
+ lora_dropout = 0, # Supports any, but = 0 is optimized
+ bias = "none", # Supports any, but = "none" is optimized
+ # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
+ use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
+ random_state = 3407,
+ max_seq_length = max_seq_length,
+)
+
+dpo_trainer = DPOTrainer(
+ model = model,
+ ref_model = None,
+ train_dataset = YOUR_DATASET_HERE,
+ # eval_dataset = YOUR_DATASET_HERE,
+ tokenizer = tokenizer,
+ args = DPOConfig(
+ per_device_train_batch_size = 4,
+ gradient_accumulation_steps = 8,
+ warmup_ratio = 0.1,
+ num_train_epochs = 3,
+ logging_steps = 1,
+ optim = "adamw_8bit",
+ seed = 42,
+ output_dir = "outputs",
+ max_length = 1024,
+ max_prompt_length = 512,
+ beta = 0.1,
+ ),
+)
+dpo_trainer.train()
+```
+
## 🥇 Performance Benchmarking
- For our most detailed benchmarks, read our [Llama 3.3 Blog](https://unsloth.ai/blog/llama3-3).
@@ -426,13 +421,14 @@ You can cite the Unsloth repo as follows:
@software{unsloth,
author = {Daniel Han, Michael Han and Unsloth team},
title = {Unsloth},
- url = {https://github.com/unslothai/unsloth},
+ url = {http://github.com/unslothai/unsloth},
year = {2023}
}
```
### Thank You to
-- The [llama.cpp library](https://github.com/ggml-org/llama.cpp) that lets users save models with Unsloth
-- The Hugging Face team and their libraries: [transformers](https://github.com/huggingface/transformers) and [TRL](https://github.com/huggingface/trl)
-- The Pytorch and [Torch AO](https://github.com/unslothai/unsloth/pull/3391) team for their contributions
-- And of course for every single person who has contributed or has used Unsloth!
+- Hugging Face's [TRL library](https://github.com/huggingface/trl) which serves as the basis foundation for Unsloth
+- [Erik](https://github.com/erikwijmans) for his help adding [Apple's ML Cross Entropy](https://github.com/apple/ml-cross-entropy) in Unsloth
+- [HuyNguyen-hust](https://github.com/HuyNguyen-hust) for making [RoPE Embeddings 28% faster](https://github.com/unslothai/unsloth/pull/238)
+- [RandomInternetPreson](https://github.com/RandomInternetPreson) for confirming WSL support
+- [152334H](https://github.com/152334H) for experimental DPO support
diff --git a/images/unsloth logo black text.png b/images/unsloth logo black text.png
index 5a86e9d3fc..4eb45557a4 100644
Binary files a/images/unsloth logo black text.png and b/images/unsloth logo black text.png differ
diff --git a/images/unsloth logo white text.png b/images/unsloth logo white text.png
index 891feaf8cb..2e37c7b19d 100644
Binary files a/images/unsloth logo white text.png and b/images/unsloth logo white text.png differ
diff --git a/images/unsloth sticker.png b/images/unsloth sticker.png
deleted file mode 100644
index 63ce830cec..0000000000
Binary files a/images/unsloth sticker.png and /dev/null differ
diff --git a/pyproject.toml b/pyproject.toml
index 9dc99866cf..7b1d2efda4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,15 +1,15 @@
[build-system]
-requires = ["setuptools==80.9.0", "setuptools-scm==9.2.0"]
+requires = ["setuptools", "setuptools-scm"]
build-backend = "setuptools.build_meta"
[project]
name = "unsloth"
dynamic = ["version"]
-description = "2-5X faster training, reinforcement learning & finetuning"
+description = "2-5X faster LLM finetuning"
readme = "README.md"
-requires-python = ">=3.9,<3.15"
-license = "Apache-2.0"
-keywords = ["ai", "llm", "reinforcement learning", "machine learning", "artificial intelligence", "pytorch"]
+requires-python = ">=3.9,<3.13"
+license = {file = "LICENSE"}
+keywords = ["ai", "llm",]
authors = [
{email = "info@unsloth.ai"},
{name = "Unsloth AI team"},
@@ -20,9 +20,6 @@ maintainers = [
]
classifiers = [
"Programming Language :: Python",
- "Environment :: GPU",
- "Environment :: GPU :: NVIDIA CUDA",
- "Topic :: Scientific/Engineering :: Artificial Intelligence",
]
[tool.setuptools.dynamic]
@@ -32,557 +29,359 @@ version = {attr = "unsloth.models._utils.__version__"}
include-package-data = false
[tool.setuptools.packages.find]
-exclude = ["images*", "tests*", "kernels/moe*"]
+exclude = ["images*"]
[project.optional-dependencies]
triton = [
- "triton>=3.0.0 ; ('linux' in sys_platform)",
- "triton-windows ; (sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
+ "triton-windows ; platform_system == 'Windows'",
]
-huggingfacenotorch = [
- "wheel>=0.42.0",
+huggingface = [
+ "unsloth_zoo>=2025.3.11",
"packaging",
- "numpy",
- "tqdm",
- "psutil",
"tyro",
- "protobuf",
+ "transformers>=4.46.1,!=4.47.0",
+ "datasets>=2.16.0",
"sentencepiece>=0.2.0",
- "datasets>=3.4.1,!=4.0.*,!=4.1.0,<4.4.0",
+ "tqdm",
+ "psutil",
+ "wheel>=0.42.0",
+ "numpy",
"accelerate>=0.34.1",
- "peft>=0.18.0,!=0.11.0",
- "huggingface_hub>=0.34.0",
+ "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0,<=0.15.2",
+ "peft>=0.7.1,!=0.11.0",
+ "protobuf<4.0.0",
+ "huggingface_hub",
"hf_transfer",
- "diffusers",
- "transformers>=4.51.3,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0,!=4.54.0,!=4.55.0,!=4.55.1,!=4.57.0,!=4.57.4,!=4.57.5,!=5.0.0,!=5.1.0,<=5.2.0",
- "trl>=0.18.2,!=0.19.0,<=0.24.0",
- "sentence-transformers",
-]
-huggingface = [
- "unsloth[huggingfacenotorch]",
- "unsloth_zoo>=2026.3.2",
- "torchvision",
"unsloth[triton]",
]
-windows = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0 ; (sys_platform == 'win32')",
- "xformers>=0.0.22.post7 ; (sys_platform == 'win32')",
-]
-base = [
+windows=[
"unsloth[huggingface]",
+ "bitsandbytes>=0.41.1 ; platform_system == 'Windows'",
+ "xformers>=0.0.22.post7 ; platform_system == 'Windows'",
]
cu118only = [
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
]
cu121only = [
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
]
cu118onlytorch211 = [
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
]
cu121onlytorch211 = [
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
]
cu118onlytorch212 = [
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
]
cu121onlytorch212 = [
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
]
cu118onlytorch220 = [
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
]
cu121onlytorch220 = [
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
]
cu118onlytorch230 = [
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
]
cu121onlytorch230 = [
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
]
cu118onlytorch240 = [
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
]
cu121onlytorch240 = [
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
]
cu124onlytorch240 = [
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp310-cp310-win_amd64.whl ; python_version=='3.10' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
]
cu118onlytorch250 = [
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
]
cu121onlytorch250 = [
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
]
cu124onlytorch250 = [
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp310-cp310-win_amd64.whl ; python_version=='3.10' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
]
cu118onlytorch251 = [
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
]
cu121onlytorch251 = [
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
]
cu124onlytorch251 = [
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp310-cp310-win_amd64.whl ; python_version=='3.10' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
]
cu118onlytorch260 = [
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
]
cu124onlytorch260 = [
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
+ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
]
cu126onlytorch260 = [
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')",
-]
-cu118onlytorch270 = [
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp310-cp310-win_amd64.whl ; python_version=='3.10' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')",
-]
-cu126onlytorch270 = [
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp310-cp310-win_amd64.whl ; python_version=='3.10' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')",
-]
-cu128onlytorch270 = [
- "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp310-cp310-win_amd64.whl ; python_version=='3.10' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')",
- "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')",
-]
-cu118onlytorch271 = [
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.31.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.31.post1-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')",
-]
-cu126onlytorch271 = [
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.31.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.31.post1-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')",
-]
-cu128onlytorch271 = [
- "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.31.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.31.post1-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')",
-]
-cu118onlytorch280 = [
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.32.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.32.post2-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')",
-]
-cu126onlytorch280 = [
- "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.32.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.32.post2-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')",
-]
-cu128onlytorch280 = [
- "xformers @ https://download.pytorch.org/whl/cu129/xformers-0.0.32.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu129/xformers-0.0.32.post2-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')",
-]
-cu130onlytorch280 = [
-]
-cu126onlytorch290 = [
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.33.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.33.post1-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')",
-]
-cu128onlytorch290 = [
- "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.33.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.33.post1-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')",
-]
-cu130onlytorch290 = [
- "xformers @ https://download.pytorch.org/whl/cu130/xformers-0.0.33.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu130/xformers-0.0.33.post1-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')",
-]
-cu126onlytorch291 = [
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.33.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.33.post2-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')",
-]
-cu128onlytorch291 = [
- "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.33.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.33.post2-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')",
-]
-cu130onlytorch291 = [
- "xformers @ https://download.pytorch.org/whl/cu130/xformers-0.0.33.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu130/xformers-0.0.33.post2-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')",
-]
-cu126onlytorch2100 = [
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.34-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.34-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')",
-]
-cu128onlytorch2100 = [
- "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.34-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.34-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')",
-]
-cu130onlytorch2100 = [
- "xformers @ https://download.pytorch.org/whl/cu130/xformers-0.0.34-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)",
- "xformers @ https://download.pytorch.org/whl/cu130/xformers-0.0.34-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')",
+ "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
+ "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
+ "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
+ "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
+ "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
]
cu118 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu118only]",
]
cu121 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu121only]",
]
cu118-torch211 = [
"unsloth[huggingface]",
- "bitsandbytes==0.45.5",
+ "bitsandbytes>=0.43.3",
"unsloth[cu118onlytorch211]",
]
cu121-torch211 = [
"unsloth[huggingface]",
- "bitsandbytes==0.45.5",
+ "bitsandbytes>=0.43.3",
"unsloth[cu121onlytorch211]",
]
cu118-torch212 = [
"unsloth[huggingface]",
- "bitsandbytes==0.45.5",
+ "bitsandbytes>=0.43.3",
"unsloth[cu118onlytorch212]",
]
cu121-torch212 = [
"unsloth[huggingface]",
- "bitsandbytes==0.45.5",
+ "bitsandbytes>=0.43.3",
"unsloth[cu121onlytorch212]",
]
cu118-torch220 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu118onlytorch220]",
]
cu121-torch220 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu121onlytorch220]",
]
cu118-torch230 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu118onlytorch230]",
]
cu121-torch230 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu121onlytorch230]",
]
cu118-torch240 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu118onlytorch240]",
]
cu121-torch240 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu121onlytorch240]",
]
cu124-torch240 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu124onlytorch240]",
]
cu118-torch250 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu118onlytorch250]",
]
cu121-torch250 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu121onlytorch250]",
]
cu124-torch250 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu124onlytorch250]",
]
cu118-torch251 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu118onlytorch251]",
]
cu121-torch251 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu121onlytorch251]",
]
cu124-torch251 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu124onlytorch251]",
]
cu118-torch260 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.45.1",
"unsloth[cu118onlytorch260]",
]
cu124-torch260 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.45.1",
"unsloth[cu124onlytorch260]",
]
cu126-torch260 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.45.1",
"unsloth[cu126onlytorch260]",
]
-cu118-torch270 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu118onlytorch270]",
-]
-cu126-torch270 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu126onlytorch270]",
-]
-cu128-torch270 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu128onlytorch270]",
-]
-cu118-torch271 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu118onlytorch271]",
-]
-cu126-torch271 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu126onlytorch271]",
-]
-cu128-torch271 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu128onlytorch271]",
-]
-cu118-torch280 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu118onlytorch280]",
-]
-cu126-torch280 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu126onlytorch280]",
-]
-cu128-torch280 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu128onlytorch280]",
-]
-cu130-torch280 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu130onlytorch280]",
-]
-cu126-torch290 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu126onlytorch290]",
-]
-cu128-torch290 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu128onlytorch290]",
-]
-cu130-torch290 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu130onlytorch290]",
-]
-cu126-torch291 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu126onlytorch291]",
-]
-cu128-torch291 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu128onlytorch291]",
-]
-cu130-torch291 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu130onlytorch291]",
-]
-cu126-torch2100 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu126onlytorch2100]",
-]
-cu128-torch2100 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu128onlytorch2100]",
-]
-cu130-torch2100 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu130onlytorch2100]",
-]
kaggle = [
"unsloth[huggingface]",
]
kaggle-new = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
]
conda = [
"unsloth[huggingface]",
]
colab-torch211 = [
"unsloth[huggingface]",
- "bitsandbytes==0.45.5",
+ "bitsandbytes>=0.43.3",
"unsloth[cu121onlytorch211]",
]
colab-ampere-torch211 = [
"unsloth[huggingface]",
- "bitsandbytes==0.45.5",
+ "bitsandbytes>=0.43.3",
"unsloth[cu121onlytorch211]",
"packaging",
"ninja",
- "flash-attn>=2.6.3 ; ('linux' in sys_platform)",
+ "flash-attn>=2.6.3",
]
colab-torch220 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu121onlytorch220]",
]
colab-ampere-torch220 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu121onlytorch220]",
"packaging",
"ninja",
- "flash-attn>=2.6.3 ; ('linux' in sys_platform)",
+ "flash-attn>=2.6.3",
]
colab-new = [
- "unsloth_zoo>=2026.3.2",
+ "unsloth_zoo>=2025.3.9",
"packaging",
"tyro",
- "transformers>=4.51.3,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0,!=4.54.0,!=4.55.0,!=4.55.1,!=4.57.0,!=4.57.4,!=4.57.5,!=5.0.0,!=5.1.0,<=5.2.0",
- "datasets>=3.4.1,!=4.0.*,!=4.1.0,<4.4.0",
+ "transformers>=4.46.1,!=4.47.0",
+ "datasets>=2.16.0",
"sentencepiece>=0.2.0",
"tqdm",
"psutil",
"wheel>=0.42.0",
"numpy",
- "protobuf",
- "huggingface_hub>=0.34.0",
+ "protobuf<4.0.0",
+ "huggingface_hub",
"hf_transfer",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[triton]",
- "sentence-transformers",
]
colab-no-deps = [
"accelerate>=0.34.1",
- "trl>=0.18.2,!=0.19.0,<=0.24.0",
- "peft>=0.18.0",
- "xformers ; ('linux' in sys_platform or sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "protobuf",
+ "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0,<=0.15.2",
+ "peft>=0.7.1",
+ "xformers",
+ "bitsandbytes>=0.46.1",
+ "protobuf<4.0.0",
]
colab = [
"unsloth[cu121]",
]
flashattention = [
- "packaging ; ('linux' in sys_platform)",
- "ninja ; ('linux' in sys_platform)",
- "flash-attn>=2.6.3 ; ('linux' in sys_platform)",
+ "packaging ; platform_system == 'Linux'",
+ "ninja ; platform_system == 'Linux'",
+ "flash-attn>=2.6.3 ; platform_system == 'Linux'",
]
colab-ampere = [
"unsloth[colab-ampere-torch220]",
@@ -590,481 +389,126 @@ colab-ampere = [
]
cu118-ampere = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu118only]",
"unsloth[flashattention]",
]
cu121-ampere = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu121only]",
"unsloth[flashattention]",
]
cu118-ampere-torch211 = [
"unsloth[huggingface]",
- "bitsandbytes==0.45.5",
+ "bitsandbytes>=0.43.3",
"unsloth[cu118onlytorch211]",
"unsloth[flashattention]",
]
cu121-ampere-torch211 = [
"unsloth[huggingface]",
- "bitsandbytes==0.45.5",
+ "bitsandbytes>=0.43.3",
"unsloth[cu121onlytorch211]",
"unsloth[flashattention]",
]
cu118-ampere-torch220 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu118onlytorch220]",
"unsloth[flashattention]",
]
cu121-ampere-torch220 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu121onlytorch220]",
"unsloth[flashattention]",
]
cu118-ampere-torch230 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu118onlytorch230]",
"unsloth[flashattention]",
]
cu121-ampere-torch230 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu121onlytorch230]",
"unsloth[flashattention]",
]
cu118-ampere-torch240 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu118onlytorch240]",
"unsloth[flashattention]",
]
cu121-ampere-torch240 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu121onlytorch240]",
"unsloth[flashattention]",
]
cu124-ampere-torch240 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu124onlytorch240]",
"unsloth[flashattention]",
]
cu118-ampere-torch250 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu118onlytorch250]",
"unsloth[flashattention]",
]
cu121-ampere-torch250 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu121onlytorch250]",
"unsloth[flashattention]",
]
cu124-ampere-torch250 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu124onlytorch250]",
"unsloth[flashattention]",
]
cu118-ampere-torch251 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu118onlytorch251]",
"unsloth[flashattention]",
]
cu121-ampere-torch251 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu121onlytorch251]",
"unsloth[flashattention]",
]
cu124-ampere-torch251 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.43.3",
"unsloth[cu124onlytorch251]",
"unsloth[flashattention]",
]
cu118-ampere-torch260 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.45.1",
"unsloth[cu118onlytorch260]",
"unsloth[flashattention]",
]
cu124-ampere-torch260 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.45.1",
"unsloth[cu124onlytorch260]",
"unsloth[flashattention]",
]
cu126-ampere-torch260 = [
"unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
+ "bitsandbytes>=0.45.1",
"unsloth[cu126onlytorch260]",
"unsloth[flashattention]",
]
-cu118-ampere-torch270 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu118onlytorch270]",
- "unsloth[flashattention]",
-]
-cu126-ampere-torch270 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu126onlytorch270]",
- "unsloth[flashattention]",
-]
-cu128-ampere-torch270 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu128onlytorch270]",
- "unsloth[flashattention]",
-]
-cu118-ampere-torch271 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu118onlytorch271]",
- "unsloth[flashattention]",
-]
-cu126-ampere-torch271 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu126onlytorch271]",
- "unsloth[flashattention]",
-]
-cu128-ampere-torch271 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu128onlytorch271]",
- "unsloth[flashattention]",
-]
-cu118-ampere-torch280 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu118onlytorch280]",
- "unsloth[flashattention]",
-]
-cu126-ampere-torch280 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu126onlytorch280]",
- "unsloth[flashattention]",
-]
-cu128-ampere-torch280 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu128onlytorch280]",
- "unsloth[flashattention]",
-]
-cu130-ampere-torch280 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu130onlytorch280]",
- "unsloth[flashattention]",
-]
-cu126-ampere-torch290 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu126onlytorch290]",
-]
-cu128-ampere-torch290 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu128onlytorch290]",
-]
-cu130-ampere-torch290 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu130onlytorch290]",
-]
-cu126-ampere-torch291 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu126onlytorch291]",
-]
-cu128-ampere-torch291 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu128onlytorch291]",
-]
-cu130-ampere-torch291 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu130onlytorch291]",
-]
-cu126-ampere-torch2100 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu126onlytorch2100]",
-]
-cu128-ampere-torch2100 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu128onlytorch2100]",
-]
-cu130-ampere-torch2100 = [
- "unsloth[huggingface]",
- "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
- "unsloth[cu130onlytorch2100]",
-]
-flashattentiontorch260abiFALSEcu12x = [
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp39-cp39-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.9'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.10'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.11'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.12'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp313-cp313-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.13'",
-]
-flashattentiontorch260abiTRUEcu12x = [
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp39-cp39-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.9'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.10'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp311-cp311-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.11'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp312-cp312-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.12'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp313-cp313-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.13'",
-]
-flashattentiontorch250abiFALSEcu12x = [
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp39-cp39-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.9'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.10'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.11'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.12'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp313-cp313-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.13'",
-]
-flashattentiontorch250abiTRUEcu12x = [
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiTRUE-cp39-cp39-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.9'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiTRUE-cp310-cp310-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.10'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiTRUE-cp311-cp311-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.11'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiTRUE-cp312-cp312-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.12'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiTRUE-cp313-cp313-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.13'",
-]
-flashattentiontorch240abiFALSEcu12x = [
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp39-cp39-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.9'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.10'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.11'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.12'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp313-cp313-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.13'",
-]
-flashattentiontorch240abiTRUEcu12x = [
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiTRUE-cp39-cp39-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.9'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiTRUE-cp310-cp310-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.10'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiTRUE-cp311-cp311-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.11'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiTRUE-cp312-cp312-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.12'",
- "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiTRUE-cp313-cp313-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.13'",
-]
-intelgputorch260 = [
- "unsloth_zoo[intelgpu]",
- "unsloth[huggingfacenotorch]",
-
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.2.0-cp39-cp39-linux_x86_64.whl#sha256=147607f190a7d7aa24ba454def5977fbbfec792fdae18e4ed278cfec29b69271 ; ('linux' in sys_platform) and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.2.0-cp310-cp310-linux_x86_64.whl#sha256=23aa423fa1542afc34f67eb3ba8ef20060f6d1b3a4697eaeab22b11c92b30f2b ; ('linux' in sys_platform) and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.2.0-cp311-cp311-linux_x86_64.whl#sha256=bcfa995229bbfd9ffd8d6c8d9f6428d393e876fa6e23ee3c20e3c0d73ca75ca5 ; ('linux' in sys_platform) and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.2.0-cp312-cp312-linux_x86_64.whl#sha256=bd340903d03470708df3442438acb8b7e08087ab9e61fbe349b2872bf9257ab0 ; ('linux' in sys_platform) and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.2.0-cp313-cp313-linux_x86_64.whl#sha256=814dccc8a07159e6eca74bed70091bc8fea2d9dd87b0d91845f9f38cde62f01c ; ('linux' in sys_platform) and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
-
- "bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl ; ('linux' in sys_platform) and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-win_amd64.whl ; (sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
-
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.6.0%2Bxpu-cp39-cp39-linux_x86_64.whl#sha256=6a8adf6dc4c089406e8b3a7e58ab57a463bddf9b07130d2576e76eced43e92af ; ('linux' in sys_platform) and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.6.0%2Bxpu-cp310-cp310-linux_x86_64.whl#sha256=ff4561cbf07c83bbccaa0f6e9bb0e6dcf721bacd53c9c43c4eb0e7331b4792f9 ; ('linux' in sys_platform) and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.6.0%2Bxpu-cp311-cp311-linux_x86_64.whl#sha256=12005f66b810ddd3ab93f86c4522bcfdd412cbd27fc9d189b661ff7509bc5e8a ; ('linux' in sys_platform) and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.6.0%2Bxpu-cp312-cp312-linux_x86_64.whl#sha256=c4c5c67625cdacf35765c2b94e61fe166e3c3f4a14521b1212a59ad1b3eb0f2e ; ('linux' in sys_platform) and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.6.0%2Bxpu-cp313-cp313-linux_x86_64.whl#sha256=e6864f7a60a5ecc43d5d38f59a16e5dd132384f73dfd3a697f74944026038f7b ; ('linux' in sys_platform) and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
-]
-intel-gpu-torch260 = [
- "unsloth[intelgputorch260]"
-]
-intelgputorch270 = [
- "unsloth_zoo[intelgpu]",
- "unsloth[huggingfacenotorch]",
-
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=749a7098492c6a27b356c97149a4a62973b953eae60bc1b6259260974f344913 ; ('linux' in sys_platform) and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=44362e80abd752471a08341093321955b066daa2cfb4810e73b8e3b240850f93 ; ('linux' in sys_platform) and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=faa6b8c945a837a080f641bc8ccc77a98fa66980dcd7e62e715fd853737343fd ; ('linux' in sys_platform) and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=40f6fb65b345dc9a61813abe7ac9a585f2c9808f414d140cc2a5f11f53ee063c ; ('linux' in sys_platform) and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=b22b4c02ec71b4bfc862ae3cdfd2871dc0b05d2b1802f5db2196e0f897d581e9 ; ('linux' in sys_platform) and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp39-cp39-win_amd64.whl#sha256=d4b738d7fa5100c1bd766f91614962828a4810eb57b4df92cd5214a83505a752 ; sys_platform == 'win32' and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp310-cp310-win_amd64.whl#sha256=143fe8a64d807bcdb7d81bbc062816add325570aa160448454ab6ded4a0a17a1 ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp311-cp311-win_amd64.whl#sha256=a8025459ff325d6e3532eb5cf72519db1b178155e7d60aff6c56beb5968fc758 ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp312-cp312-win_amd64.whl#sha256=0dd07e6d5b872e42e48f5ee140e609d4554ca3cc509d5bf509ac232267cf358e ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp313-cp313-win_amd64.whl#sha256=a936a18182d8e065a9933afc9a3ebbffadd38604969f87c493831214539fc027 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
-
- "bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl ; ('linux' in sys_platform) and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-win_amd64.whl ; (sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
-
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp39-cp39-linux_x86_64.whl#sha256=f8ee75e50fcbb37ed5b498299ca2264da99ab278a93fae2358e921e4a6e28273 ; ('linux' in sys_platform) and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp310-cp310-linux_x86_64.whl#sha256=d6fdc342961d98fdcd9d03dfd491a3208bb5f7fbb435841f8f72ce9fdcd2d026 ; ('linux' in sys_platform) and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp311-cp311-linux_x86_64.whl#sha256=74d07f9357df5cf2bf223ad3c84de16346bfaa0504f988fdd5590d3e177e5e86 ; ('linux' in sys_platform) and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp312-cp312-linux_x86_64.whl#sha256=c806d44aa2ca5d225629f6fbc6c994d5deaac2d2cde449195bc8e3522ddd219a ; ('linux' in sys_platform) and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp313-cp313-linux_x86_64.whl#sha256=25d8277b7f01d42e2e014ccbab57a2692b6ec4eff8dcf894eda1b297407cf97a ; ('linux' in sys_platform) and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp39-cp39-win_amd64.whl#sha256=046e85125266ae69c1a0d083e6c092f947ab4b6b41532c16bafe40dbced845df ; sys_platform == 'win32' and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp310-cp310-win_amd64.whl#sha256=9ebaeffb82b0b3e39b6030927d3ebe0eb62a0e9045a3b2d7b0a9e7b15222c0db ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp311-cp311-win_amd64.whl#sha256=356ba66cee127e7e2c942880bd50e03768306a4ea08d358a0f29c6eebfc4bc81 ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp312-cp312-win_amd64.whl#sha256=94739e665d9b4d5cd7af5f517cb6103f6f9fb421c095184609653a24524040f5 ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp313-cp313-win_amd64.whl#sha256=31df3cb674918e89bc8c532baa331dc84f4430e1f9c0ec379232db44cba78355 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
-]
-intel-gpu-torch270 = [
- "unsloth[intelgputorch270]"
-]
-intelgputorch280 = [
- "unsloth_zoo[intelgpu]",
- "unsloth[huggingfacenotorch]",
-
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=ac4d8e33986b1c3c5e48151640539272b2187e83016985853111b46fb82c3c94 ; 'linux' in sys_platform and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=999fef4c1f711092b9d3086525920545df490de476ecebe899ffc777019ae17f ; 'linux' in sys_platform and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=57b09c8c492985ff6a27cd3a22b08e8f7b96b407bd8030967b6efbb9f63b80cf ; 'linux' in sys_platform and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=df4bb3282bac9a3b90231700077110d8680b338416de03c2b7c6133c9b602649 ; 'linux' in sys_platform and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=60da63c99ca827bdcb0df28e0298bf7d066dc607454c6d6176783cb4e79d838b ; 'linux' in sys_platform and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp39-cp39-win_amd64.whl#sha256=64aea8de349f3e2e0ebf4c24b011a8122531fdffda5776edaef45829cc241cf8 ; sys_platform == 'win32' and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp310-cp310-win_amd64.whl#sha256=ae573d255b257fdbed319a3440dc9d0a721e31160ab7f6eba1b2226e6a409a1d ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp311-cp311-win_amd64.whl#sha256=8e0ea4558e5776d8ddab0264310be9b26aee5641bcac0da023537556d4317b86 ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp312-cp312-win_amd64.whl#sha256=4090dde07a4fffc34aaf855701a9db28e9fccb57b368ade520f1a0f8e811c878 ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp313-cp313-win_amd64.whl#sha256=a33d0888f3c8df028a2d028842715837d0049524d6c06b9bb11869890a13601a ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
-
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp39-cp39-linux_x86_64.whl ; 'linux' in sys_platform and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp310-cp310-linux_x86_64.whl ; 'linux' in sys_platform and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp311-cp311-linux_x86_64.whl ; 'linux' in sys_platform and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp312-cp312-linux_x86_64.whl ; 'linux' in sys_platform and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp313-cp313-linux_x86_64.whl ; 'linux' in sys_platform and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp39-cp39-win_amd64.whl#sha256=f2f401276892428e4875cf1d8717c5cbab704b16fc594ccf23795e7b16549a99 ; sys_platform == 'win32' and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp310-cp310-win_amd64.whl#sha256=125c60cd59d51b39581a7e9afcd4679bc3a6b8c1f9440b1bb502a23fdd60571e ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp311-cp311-win_amd64.whl#sha256=47f1a57258cd460e80b38b2ed6744e31587ab77a96b4215bf59546cb4bab5cc0 ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp312-cp312-win_amd64.whl#sha256=0937d8943c145a83d9bafc6f80ef28971167817f9eda26066d33f72caf8a6646 ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp313-cp313-win_amd64.whl#sha256=e034aab1d71760dc80a731531be43673ffe15e99033b82d24e40d2e6d41bd8bf ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
-
- "bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl ; ('linux' in sys_platform) and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-win_amd64.whl ; (sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
-
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp39-cp39-manylinux_2_28_x86_64.whl#sha256=6e981c192045fc249c008441179ff237bb00174d818b875b0475730b63f0eaca ; 'linux' in sys_platform and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=e5ba4805969277175ebfd59cc717093528cc6e3ada89ac2725fc7a3c1fee6169 ; 'linux' in sys_platform and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp311-cp311-manylinux_2_28_x86_64.whl#sha256=74c39c144104416bc4c5ad8c26ab0c169dc5cc6be58059e01bc3665dd0ef676f ; 'linux' in sys_platform and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp312-cp312-manylinux_2_28_x86_64.whl#sha256=0acec355b80c3899841184084f365df336c508602812e34a44007b8b60d53af4 ; 'linux' in sys_platform and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp313-cp313-manylinux_2_28_x86_64.whl#sha256=e2109ae773dad27b98ca17681044b4f876563c37f2382b75de3a371399edcff8 ; 'linux' in sys_platform and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp39-cp39-win_amd64.whl#sha256=5f7904e7048d414379bc8c1167260f1e84204f105db2d0a2f9c89e87ce1cf205 ; sys_platform == 'win32' and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp310-cp310-win_amd64.whl#sha256=005fca5e658ca8e37adb63c1a021c84f5e56dfa6cf0d601d89cfe40b9473f79f ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp311-cp311-win_amd64.whl#sha256=c6d030f5361461550c0ff1339b5bca8585fc1e84fda2e64b6184e65a581e4f98 ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp312-cp312-win_amd64.whl#sha256=91aafd61864cdce27461cbec13ddbf28c1bc6494265a1e4b80131c64a3b7d18f ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp313-cp313-win_amd64.whl#sha256=71dc4a6421742ed1e7f585b04a100ad53615c341fbccfbc255aefb38ea9091da ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
-]
-intel-gpu-torch280 = [
- "unsloth[intelgputorch280]"
-]
-intelgputorch290 = [
- "unsloth_zoo[intelgpu]",
- "unsloth[huggingfacenotorch]",
-
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=c169a1de14c19673b17c751290d467fa282fc90fa5da4314b2e5cdab1f553146 ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=013d9dd5d6479bd22983161f462e61c8dbe1d82e6730624a7a8d5945507eaa61 ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=afc8cabfbf7ed51fd278d1e0f88d6afc157b0201bad4b99d681e4d542f9e66d4 ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=0d24c1716088f2764d0d24c64227732195b6a42706c3c5fc89eeb4904bfa0818 ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp310-cp310-win_amd64.whl#sha256=c83ab007311d9cfb6e809ee5a4587d99a9eef4be720b90da4f1aaa68b45139a0 ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp311-cp311-win_amd64.whl#sha256=debf75348da8e8c7166b4d4a9b91d1508bb8d6581e339f79f7604b2e6746bacd ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp312-cp312-win_amd64.whl#sha256=97337a47425f1963a723475bd61037460e84ba01db4f87a1d662c3718ff6c47e ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp313-cp313-win_amd64.whl#sha256=2caf8138695f6abb023ecd02031a2611ba1bf8fff2f19802567cb2fadefe9e87 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
-
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp310-cp310-linux_x86_64.whl#sha256=5afbe860ce991825a36b75706a523601087e414b77598ef0d9d3d565741c277d ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp311-cp311-linux_x86_64.whl#sha256=607fe419c32d6e8e0556f745742e7cff1d0babce51f54be890e0c1422359c442 ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp312-cp312-linux_x86_64.whl#sha256=376bae584d89980b8e59934d248c38d5fa3b7d4687a4df1a19f4bc1d23dcc8c1 ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp313-cp313-linux_x86_64.whl#sha256=98d6a06dd7fb185874367b18bd609f05f16fdce4142a5980ca94461949965cd2 ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp310-cp310-win_amd64.whl#sha256=47cc68f631f65bd9c84924d052cd04dec7531023caa85e80345e9c94611c887d ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp311-cp311-win_amd64.whl#sha256=d56c44ab4818aba57e5c7b628f422d014e0d507427170a771c5be85e308b0bc6 ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp312-cp312-win_amd64.whl#sha256=18cad93aaff76a01ce73aef6935ece7cfc03344b905592ec731446c44d44592b ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp313-cp313-win_amd64.whl#sha256=579929cdc10a76800ead41289cac191ea36d1b16f5f501d3fc25607d4375cd83 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
-
- "bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl ; ('linux' in sys_platform) and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-win_amd64.whl ; (sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
-
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=cbfae2b79b7549fd368c2462fc8e94f8f26cc450782ee72138e908077c09a519 ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'",
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp311-cp311-manylinux_2_28_x86_64.whl#sha256=044fa36ef4b6b43edcd490b75c853fa4b3eb033c2bded29f8fbcf27734713c67 ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'",
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp312-cp312-manylinux_2_28_x86_64.whl#sha256=4b91e4bec1d740a6211f02578a79888550b73f3a4e1383035f8f6d72f587212c ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'",
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp313-cp313-manylinux_2_28_x86_64.whl#sha256=88239e73ca37254bec84f29cd5887e10ff712de7edbbda3fbb3609cd6190d99e ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'",
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp310-cp310-win_amd64.whl#sha256=19c7da8ca767d593e13a88a12bb08d06e34a673f6f26c2f9c191d60e81c02953 ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp311-cp311-win_amd64.whl#sha256=9bb0d1421c544ac8e2eca5b47daacaf54706dc9139c003aa5e77ee5f355c5931 ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp312-cp312-win_amd64.whl#sha256=6a5194bc736089606342d48a3f6822829b167617e9495d91d753dd1bd46fda18 ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp313-cp313-win_amd64.whl#sha256=da47a3ce2bb7f0301a31124668b5908f9b9e92d6241443de15a310ef9632fd83 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
-]
-intel-gpu-torch290 = [
- "unsloth[intelgputorch290]"
-]
-intelgputorch210 = [
- "unsloth_zoo[intelgpu]",
- "unsloth[huggingfacenotorch]",
-
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=c169a1de14c19673b17c751290d467fa282fc90fa5da4314b2e5cdab1f553146 ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=013d9dd5d6479bd22983161f462e61c8dbe1d82e6730624a7a8d5945507eaa61 ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=afc8cabfbf7ed51fd278d1e0f88d6afc157b0201bad4b99d681e4d542f9e66d4 ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=0d24c1716088f2764d0d24c64227732195b6a42706c3c5fc89eeb4904bfa0818 ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp310-cp310-win_amd64.whl#sha256=c83ab007311d9cfb6e809ee5a4587d99a9eef4be720b90da4f1aaa68b45139a0 ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp311-cp311-win_amd64.whl#sha256=debf75348da8e8c7166b4d4a9b91d1508bb8d6581e339f79f7604b2e6746bacd ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp312-cp312-win_amd64.whl#sha256=97337a47425f1963a723475bd61037460e84ba01db4f87a1d662c3718ff6c47e ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp313-cp313-win_amd64.whl#sha256=2caf8138695f6abb023ecd02031a2611ba1bf8fff2f19802567cb2fadefe9e87 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
-
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp310-cp310-linux_x86_64.whl#sha256=abb1d1ec1ac672bac0ff35420c965f2df0c636ef9d94e2a830e34578489d0a57 ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp311-cp311-linux_x86_64.whl#sha256=71ad2f82da0f41eaec159f39fc85854e27c2391efa91b373e550648a6f4aaad3 ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp312-cp312-linux_x86_64.whl#sha256=b473571d478912f92881cc13f15fa18f8463fb0fb8a068c96ed47a7d45a4da0a ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp313-cp313-linux_x86_64.whl#sha256=3bc64a746ff25a93de140902c60c9e819d7413f5cea1e88d80999c27a5901e9c ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp310-cp310-win_amd64.whl#sha256=ce50691ab3fb6301d9b7bb8b3834cf5fa7152a2b5f91fd24c5efdc601a25b780 ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp311-cp311-win_amd64.whl#sha256=cb9d37f21cb9fb7df67d62863f021c3144e8d8832b9ea8e8523ac308bc620ea1 ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp312-cp312-win_amd64.whl#sha256=3ad605be4728b6d3a28a44d07dd794b1a9e45551b0057815bf25eb2a6d6a56a7 ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torch @ https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp313-cp313-win_amd64.whl#sha256=2b4b56dd6c792aef82006904fa888692e3782e4ae5da27526801bad4898f05a5 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
-
- "bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl ; ('linux' in sys_platform) and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-win_amd64.whl ; (sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
-
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=7e1e7b170fcf7161c8499b67156c5a05462243626dc0974010791a0bab4378d3 ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'",
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp311-cp311-manylinux_2_28_x86_64.whl#sha256=bd6add201bd7628af70437292e1447abb368e0b5f4ff9abd334ae435efd44792 ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'",
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp312-cp312-manylinux_2_28_x86_64.whl#sha256=6ad2543496bc29e59d3dd614a94d09aa9870318aedb66045344fffddfedd2cf8 ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'",
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp313-cp313-manylinux_2_28_x86_64.whl#sha256=80269f37865fcd8b57f20e4786efae2200bfa2b2727926c3c7acc82f0e7d3548 ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'",
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp310-cp310-win_amd64.whl#sha256=6b9485ba85dcba4d196d6134d9c3332fb228fb2556416bf0450a64e8a472fcba ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp311-cp311-win_amd64.whl#sha256=36cbaedf10f6412af5c89afd9aeea474e6a56a0050348ada8fabe1ecaf6b879e ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp312-cp312-win_amd64.whl#sha256=738357d97468d75fe3d510ac37e65130f2787f81d9bbc1518898f7396dc3403f ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
- "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp313-cp313-win_amd64.whl#sha256=1c4b44b36a557f7381e3076fb8843366742238648441d607c8d049c6da0f8886 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
-]
-intel-gpu-torch210 = [
- "unsloth[intelgputorch210]"
-]
-intel = [
- "unsloth[intelgputorch280]",
-]
-amd = [
- "unsloth[huggingfacenotorch]",
- "bitsandbytes>=0.49.1 ; ('linux' in sys_platform) and (platform_machine == 'AMD64' or platform_machine == 'x86_64' or platform_machine == 'aarch64')",
- "bitsandbytes>=0.49.1 ; (sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
-]
[project.urls]
homepage = "http://www.unsloth.ai"
documentation = "https://github.com/unslothai/unsloth"
repository = "https://github.com/unslothai/unsloth"
-
-[tool.ruff]
-target-version = "py311"
-force-exclude = true
-extend-exclude = [
- "*chat_templates.py",
- "*ollama_template_mappers.py",
- "*_auto_install.py",
- "*mapper.py",
-]
-
-[tool.ruff.lint]
-select = ["E9", "F63", "F7", "F82"]
-ignore = [
- "E402",
- "E722",
- "F403",
- "F405",
- "F811",
- "F821",
- "F841",
- "F401",
- "E731",
- "E741",
- "F601",
- "E712",
-]
-
-[tool.ruff.format]
diff --git a/scripts/enforce_kwargs_spacing.py b/scripts/enforce_kwargs_spacing.py
deleted file mode 100755
index ca2ff343a0..0000000000
--- a/scripts/enforce_kwargs_spacing.py
+++ /dev/null
@@ -1,179 +0,0 @@
-#!/usr/bin/env python3
-"""Ensure keyword arguments use spaces around '=', prune redundant pass statements."""
-
-from __future__ import annotations
-
-import ast
-import argparse
-import io
-import sys
-import tokenize
-from collections import defaultdict
-from pathlib import Path
-
-
-def enforce_spacing(text: str) -> tuple[str, bool]:
- """Return updated text with keyword '=' padded by spaces, plus change flag."""
- lines = text.splitlines(keepends=True)
- if not lines:
- return text, False
-
- offsets: dict[int, int] = defaultdict(int)
- changed = False
-
- reader = io.StringIO(text).readline
- for token in tokenize.generate_tokens(reader):
- if token.type != tokenize.OP or token.string != "=":
- continue
-
- line_index = token.start[0] - 1
- col = token.start[1] + offsets[line_index]
-
- if line_index < 0 or line_index >= len(lines):
- continue
-
- line = lines[line_index]
- if col >= len(line) or line[col] != "=":
- continue
-
- line_changed = False
-
- # Insert a space before '=' when missing and not preceded by whitespace.
- if col > 0 and line[col - 1] not in {" ", "\t"}:
- line = f"{line[:col]} {line[col:]}"
- offsets[line_index] += 1
- col += 1
- line_changed = True
- changed = True
-
- # Insert a space after '=' when missing and not followed by whitespace or newline.
- next_index = col + 1
- if next_index < len(line) and line[next_index] not in {" ", "\t", "\n", "\r"}:
- line = f"{line[:next_index]} {line[next_index:]}"
- offsets[line_index] += 1
- line_changed = True
- changed = True
-
- if line_changed:
- lines[line_index] = line
-
- if not changed:
- return text, False
-
- return "".join(lines), True
-
-
-def remove_redundant_passes(text: str) -> tuple[str, bool]:
- """Drop pass statements that share a block with other executable code."""
-
- try:
- tree = ast.parse(text)
- except SyntaxError:
- return text, False
-
- redundant: list[ast.Pass] = []
-
- def visit(node: ast.AST) -> None:
- for attr in ("body", "orelse", "finalbody"):
- value = getattr(node, attr, None)
- if not isinstance(value, list) or len(value) <= 1:
- continue
- for stmt in value:
- if isinstance(stmt, ast.Pass):
- redundant.append(stmt)
- for stmt in value:
- if isinstance(stmt, ast.AST):
- visit(stmt)
- handlers = getattr(node, "handlers", None)
- if handlers:
- for handler in handlers:
- visit(handler)
-
- visit(tree)
-
- if not redundant:
- return text, False
-
- lines = text.splitlines(keepends=True)
- changed = False
-
- for node in sorted(
- redundant, key=lambda item: (item.lineno, item.col_offset), reverse=True
- ):
- start = node.lineno - 1
- end = (node.end_lineno or node.lineno) - 1
- if start >= len(lines):
- continue
- changed = True
- if start == end:
- line = lines[start]
- col_start = node.col_offset
- col_end = node.end_col_offset or (col_start + 4)
- segment = line[:col_start] + line[col_end:]
- lines[start] = segment if segment.strip() else ""
- continue
-
- # Defensive fall-back for unexpected multi-line 'pass'.
- prefix = lines[start][: node.col_offset]
- lines[start] = prefix if prefix.strip() else ""
- for idx in range(start + 1, end):
- lines[idx] = ""
- suffix = lines[end][(node.end_col_offset or 0) :]
- lines[end] = suffix
-
- # Normalise to ensure lines end with newlines except at EOF.
- result_lines: list[str] = []
- for index, line in enumerate(lines):
- if not line:
- continue
- if index < len(lines) - 1 and not line.endswith("\n"):
- result_lines.append(f"{line}\n")
- else:
- result_lines.append(line)
-
- return "".join(result_lines), changed
-
-
-def process_file(path: Path) -> bool:
- try:
- with tokenize.open(path) as handle:
- original = handle.read()
- encoding = handle.encoding
- except (OSError, SyntaxError) as exc: # SyntaxError from tokenize on invalid python
- print(f"Failed to read {path}: {exc}", file=sys.stderr)
- return False
-
- updated, changed = enforce_spacing(original)
- updated, removed = remove_redundant_passes(updated)
- if changed or removed:
- path.write_text(updated, encoding=encoding)
- return True
- return False
-
-
-def main(argv: list[str]) -> int:
- parser = argparse.ArgumentParser(description=__doc__)
- parser.add_argument("files", nargs="+", help="Python files to fix")
- args = parser.parse_args(argv)
-
- touched: list[Path] = []
- self_path = Path(__file__).resolve()
-
- for entry in args.files:
- path = Path(entry)
- # Skip modifying this script to avoid self-edit loops.
- if path.resolve() == self_path:
- continue
- if not path.exists() or path.is_dir():
- continue
- if process_file(path):
- touched.append(path)
-
- if touched:
- for path in touched:
- print(f"Adjusted kwarg spacing in {path}")
- return 0
-
-
-if __name__ == "__main__":
- sys.exit(main(sys.argv[1:]))
diff --git a/scripts/run_ruff_format.py b/scripts/run_ruff_format.py
deleted file mode 100755
index 5ec16cd9f5..0000000000
--- a/scripts/run_ruff_format.py
+++ /dev/null
@@ -1,30 +0,0 @@
-#!/usr/bin/env python3
-"""Run `ruff format` followed by kwarg spacing enforcement."""
-
-from __future__ import annotations
-
-import subprocess
-import sys
-from pathlib import Path
-
-HERE = Path(__file__).resolve().parent
-
-
-def main(argv: list[str]) -> int:
- files = [arg for arg in argv if Path(arg).exists()]
- if not files:
- return 0
-
- ruff_cmd = [sys.executable, "-m", "ruff", "format", *files]
- ruff_proc = subprocess.run(ruff_cmd)
- if ruff_proc.returncode != 0:
- return ruff_proc.returncode
-
- spacing_script = HERE / "enforce_kwargs_spacing.py"
- spacing_cmd = [sys.executable, str(spacing_script), *files]
- spacing_proc = subprocess.run(spacing_cmd)
- return spacing_proc.returncode
-
-
-if __name__ == "__main__":
- raise SystemExit(main(sys.argv[1:]))
diff --git a/tests/__init__.py b/tests/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/qlora/README.md b/tests/qlora/README.md
deleted file mode 100644
index c05dcf446d..0000000000
--- a/tests/qlora/README.md
+++ /dev/null
@@ -1,47 +0,0 @@
-## QLoRA Train and Merge Tests
-
-### Overview
-Tests that performing QLoRA training and merging weights to 16-bits post-training maintains same behavior as trained model.
-
-- `test_unsloth_qlora_train_and_merge.py`: Test Unsloth QLoRA train and merge using `FastLanguageModel.from_pretrained`, `FastLanguageModel.get_peft_model`, and `FastLanguageModel.save_pretrained_merged` apis
-- `test_hf_qlora_train_and_merge.py`: Test Hugging Face QLoRA train and merge using `from_pretrained`, `get_peft_model`, and `merge_and_unload` apis.
- - Demonstrates that `peft`'s `merge_and_unload` results in loss of accuracy as it requantizes the base layer after merging adapter weights so that the model still contains `Linear4Bit` layers post merging.
- - I (@jeromeku) implemented a custom merge function that replaces all `LoraLayers` with `Linear` layers whose weights are the dequantized base layer weights with adapter weights merged (compute done in fp32, cast to original dtype after merging), roughly equivalent to `FastLanguageModel.save_pretrained_merged`.
-
-### Usage
-Run unsloth test:
-```bash
-python tests/qlora/test_unsloth_qlora_train_and_merge.py
-```
-Run huggingface test:
-```bash
-python tests/qlora/test_hf_qlora_train_and_merge.py
-```
-
-### Details
-The tests train a QLoRA model on a single prompt dataset
-```
-QUESTION = "What day was I born?"
-ANSWER = "January 1, 2058"
-USER_MESSAGE = {"role": "user", "content": QUESTION}
-ASSISTANT_MESSAGE = {"role": "assistant", "content": ANSWER}
-```
-
-Given that the answer is impossible to answer accurately without finetuning, we can only expect the model to answer the question correctly if the model has been trained on the question.
-
-To check this behavior, we check the model's response to the question before and after training and after merging, checking that the model's response contains the answer after training and merging but not before training.
-
-### Results
-
-For the unsloth test, the model's behavior is as expected:
-- before training, the model's response does not contain the answer
-- after training, the model's response contains the answer
-- after merging, the model's response contains the answer
-
-For the huggingface test, the model's behavior is as expected:
-- before training, the model's response does not contain the answer
-- after training, the model's response contains the answer
-- after using peft's `merge_and_unload`, the model's response does not contain the answer
-- after using my custom merge function, the model's response contains the answer
-
-The scripts should output training params, training logs, as well as model responses before and after training and after merging (only prints model responses if answer is not contained in response).
\ No newline at end of file
diff --git a/tests/qlora/test_hf_qlora_train_and_merge.py b/tests/qlora/test_hf_qlora_train_and_merge.py
deleted file mode 100644
index ae975b0266..0000000000
--- a/tests/qlora/test_hf_qlora_train_and_merge.py
+++ /dev/null
@@ -1,159 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# ruff: noqa
-import sys
-from pathlib import Path
-
-REPO_ROOT = Path(__file__).parents[2]
-sys.path.append(str(REPO_ROOT))
-
-import itertools
-from copy import deepcopy
-
-import torch
-from datasets import Dataset
-from trl import SFTConfig
-from tests.utils import header_footer_context
-from tests.utils.data_utils import (
- ANSWER,
- DEFAULT_MESSAGES,
- USER_MESSAGE,
- check_responses,
- create_dataset,
- describe_peft_weights,
-)
-from tests.utils.hf_utils import (
- convert_lora_to_linear,
- fix_llama3_tokenizer,
- get_peft_config,
- sample_responses,
- setup_model,
- setup_tokenizer,
- setup_trainer,
-)
-
-if __name__ == "__main__":
- model_name = "meta-llama/Llama-3.2-1B-Instruct"
- dtype = torch.bfloat16
- max_steps = 100
- num_examples = 1000
- lora_rank = 64
- output_dir = "sft_test"
- seed = 42
- batch_size = 5
- num_generations = 5
- tokenizer = setup_tokenizer(model_name, fixup_funcs = [fix_llama3_tokenizer])
- temperature = 0.8
- max_new_tokens = 20
-
- peft_config = get_peft_config(lora_rank = lora_rank, target_modules = "all-linear")
- model = setup_model(model_name, quantize = True, dtype = dtype, peft_config = peft_config)
-
- prompt = tokenizer.apply_chat_template(
- [USER_MESSAGE], tokenize = False, add_generation_prompt = True
- )
- with header_footer_context("Test Prompt and Answer"):
- print(f"Test Prompt:\n{prompt}\nExpected Answer:\n{ANSWER}")
-
- dataset: Dataset = create_dataset(
- tokenizer, num_examples = num_examples, messages = DEFAULT_MESSAGES
- )
- with header_footer_context("Dataset"):
- print(f"Dataset: {next(iter(dataset))}")
-
- training_args = SFTConfig(
- output_dir = output_dir,
- max_steps = max_steps,
- per_device_train_batch_size = batch_size,
- log_level = "info",
- report_to = "none",
- num_train_epochs = 1,
- logging_steps = 1,
- seed = seed,
- bf16 = dtype == torch.bfloat16,
- fp16 = dtype == torch.float16,
- save_strategy = "no",
- )
-
- with header_footer_context("Train Args"):
- print(training_args)
- print(peft_config)
-
- trainer = setup_trainer(
- model, tokenizer, dataset, training_args, peft_config = peft_config
- )
-
- with header_footer_context("Model"):
- print(type(model.model))
-
- generation_args = {
- "num_generations": num_generations,
- "max_new_tokens": max_new_tokens,
- "temperature": temperature,
- "skip_special_tokens": False,
- "dtype": dtype,
- }
- responses = sample_responses(
- model,
- tokenizer,
- prompt = prompt,
- **generation_args,
- )
- with header_footer_context("Responses before training"):
- check_responses(responses, answer = ANSWER, prompt = prompt)
-
- with header_footer_context("Peft Weights before training"):
- for name, stats in itertools.islice(describe_peft_weights(model), 2):
- print(f"{name}:\n{stats}")
-
- output = trainer.train()
- with header_footer_context("Peft Weights after training"):
- for name, stats in itertools.islice(describe_peft_weights(model), 2):
- print(f"{name}:\n{stats}")
-
- with header_footer_context("Trainer Output"):
- print(output)
-
- responses = sample_responses(
- model,
- tokenizer,
- prompt = prompt,
- **generation_args,
- )
- with header_footer_context("Responses after training"):
- check_responses(responses, answer = ANSWER, prompt = prompt)
-
- model_copy = deepcopy(model)
-
- merged_model = convert_lora_to_linear(model)
-
- responses = sample_responses(
- merged_model,
- tokenizer,
- prompt = prompt,
- **generation_args,
- )
- with header_footer_context("Responses after custom merging to 16bit"):
- check_responses(responses, answer = ANSWER, prompt = prompt)
-
- merged_model_peft = model_copy.merge_and_unload()
- responses = sample_responses(
- merged_model_peft,
- tokenizer,
- prompt = prompt,
- **generation_args,
- )
- with header_footer_context("Responses after peft merge_and_unload"):
- check_responses(responses, answer = ANSWER, prompt = prompt)
diff --git a/tests/qlora/test_unsloth_qlora_train_and_merge.py b/tests/qlora/test_unsloth_qlora_train_and_merge.py
deleted file mode 100644
index 9040ad793d..0000000000
--- a/tests/qlora/test_unsloth_qlora_train_and_merge.py
+++ /dev/null
@@ -1,211 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# ruff: noqa
-import sys
-from pathlib import Path
-
-REPO_ROOT = Path(__file__).parents[2]
-sys.path.append(str(REPO_ROOT))
-
-import itertools
-from unsloth import FastLanguageModel
-
-import torch
-from datasets import Dataset
-from trl import SFTConfig
-from tests.utils import header_footer_context
-from tests.utils.data_utils import (
- DEFAULT_MESSAGES,
- USER_MESSAGE,
- ANSWER,
- create_dataset,
- describe_peft_weights,
- check_responses,
-)
-from tests.utils.hf_utils import (
- sample_responses,
- setup_trainer,
-)
-
-
-def get_unsloth_model_and_tokenizer(
- model_name: str,
- max_seq_length: int,
- load_in_4bit: bool,
- fast_inference: bool,
- max_lora_rank: int = None,
- gpu_memory_utilization: float = 0.5,
- dtype: torch.dtype = torch.bfloat16,
-):
- return FastLanguageModel.from_pretrained(
- model_name = model_name,
- max_seq_length = max_seq_length,
- load_in_4bit = load_in_4bit,
- fast_inference = fast_inference,
- max_lora_rank = max_lora_rank,
- gpu_memory_utilization = gpu_memory_utilization,
- dtype = dtype,
- )
-
-
-def get_unsloth_peft_model(
- model,
- lora_rank: int,
- target_modules: list[str] = "all-linear",
- use_gradient_checkpointing: str = False,
- random_state: int = 42,
-):
- return FastLanguageModel.get_peft_model(
- model,
- r = lora_rank,
- target_modules = target_modules,
- lora_alpha = lora_rank,
- use_gradient_checkpointing = use_gradient_checkpointing,
- random_state = random_state,
- )
-
-
-if __name__ == "__main__":
- model_name = "meta-llama/Llama-3.2-1B-Instruct"
- dtype = torch.bfloat16
- max_steps = 100
- num_examples = 1000
- lora_rank = 64
- output_dir = "sft_test"
- seed = 42
- batch_size = 5
- num_generations = 5
- target_modules = [
- "q_proj",
- "k_proj",
- "v_proj",
- "o_proj",
- "gate_proj",
- "up_proj",
- "down_proj",
- ]
- gradient_checkpointing = False
- unsloth_merged_path = "unsloth_merged_16bit"
-
- model, tokenizer = get_unsloth_model_and_tokenizer(
- model_name,
- max_seq_length = 512,
- load_in_4bit = True,
- fast_inference = False,
- max_lora_rank = lora_rank,
- dtype = dtype,
- )
- temperature = 0.8
- max_new_tokens = 20
-
- model = get_unsloth_peft_model(
- model,
- lora_rank = lora_rank,
- target_modules = target_modules,
- use_gradient_checkpointing = gradient_checkpointing,
- random_state = seed,
- )
-
- prompt = tokenizer.apply_chat_template(
- [USER_MESSAGE], tokenize = False, add_generation_prompt = True
- )
-
- with header_footer_context("Test Prompt and Answer"):
- print(f"Test Prompt:\n{prompt}\nExpected Answer:\n{ANSWER}")
-
- dataset: Dataset = create_dataset(
- tokenizer, num_examples = num_examples, messages = DEFAULT_MESSAGES
- )
- with header_footer_context("Dataset"):
- print(f"Dataset: {next(iter(dataset))}")
-
- training_args = SFTConfig(
- output_dir = output_dir,
- max_steps = max_steps,
- per_device_train_batch_size = batch_size,
- log_level = "info",
- report_to = "none",
- num_train_epochs = 1,
- logging_steps = 1,
- seed = seed,
- bf16 = dtype == torch.bfloat16,
- fp16 = dtype == torch.float16,
- save_strategy = "no",
- )
-
- with header_footer_context("Train Args"):
- print(training_args)
-
- trainer = setup_trainer(model, tokenizer, dataset, training_args)
-
- with header_footer_context("Model"):
- print(type(model.model))
-
- generation_args = {
- "num_generations": num_generations,
- "max_new_tokens": max_new_tokens,
- "temperature": temperature,
- "skip_special_tokens": False,
- "dtype": dtype,
- }
- responses = sample_responses(
- model,
- tokenizer,
- prompt = prompt,
- **generation_args,
- )
- with header_footer_context("Responses before training"):
- check_responses(responses, answer = ANSWER, prompt = prompt)
- with header_footer_context("Peft Weights before training"):
- for name, stats in itertools.islice(describe_peft_weights(model), 2):
- print(f"{name}:\n{stats}")
-
- output = trainer.train()
- with header_footer_context("Peft Weights after training"):
- for name, stats in itertools.islice(describe_peft_weights(model), 2):
- print(f"{name}:\n{stats}")
-
- with header_footer_context("Trainer Output"):
- print(output)
-
- responses = sample_responses(
- model,
- tokenizer,
- prompt = prompt,
- **generation_args,
- )
- with header_footer_context("Responses after training"):
- check_responses(responses, answer = ANSWER, prompt = prompt)
-
- model.save_pretrained_merged(
- unsloth_merged_path,
- tokenizer,
- save_method = "merged_16bit",
- )
- merged_model_unsloth, tokenizer = get_unsloth_model_and_tokenizer(
- unsloth_merged_path,
- max_seq_length = 512,
- load_in_4bit = False,
- fast_inference = False,
- dtype = dtype,
- )
- responses = sample_responses(
- merged_model_unsloth,
- tokenizer,
- prompt = prompt,
- **generation_args,
- )
- with header_footer_context("Responses after unsloth merge to 16bit"):
- check_responses(responses, answer = ANSWER, prompt = prompt)
diff --git a/tests/saving/gpt-oss-merge/run_test.sh b/tests/saving/gpt-oss-merge/run_test.sh
deleted file mode 100755
index 5a91b31358..0000000000
--- a/tests/saving/gpt-oss-merge/run_test.sh
+++ /dev/null
@@ -1,18 +0,0 @@
-#!/bin/bash
-set -e
-
-echo "================================================================"
-echo "🚀 STEP 1: Running the training and merging script..."
-echo "================================================================"
-python train_and_merge.py
-
-echo ""
-echo "================================================================"
-echo "✅ STEP 2: Training complete. Running the inference script..."
-echo "================================================================"
-python test_merged_model.py
-
-echo ""
-echo "================================================================"
-echo "🎉 All steps completed successfully!"
-echo "================================================================"
diff --git a/tests/saving/gpt-oss-merge/test_merged_model.py b/tests/saving/gpt-oss-merge/test_merged_model.py
deleted file mode 100644
index 48f0ed2d3d..0000000000
--- a/tests/saving/gpt-oss-merge/test_merged_model.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# inference_on_merged.py
-from unsloth import FastLanguageModel
-from transformers import TextStreamer
-import torch
-import gc
-import os
-import shutil
-
-
-def safe_remove_directory(path):
- try:
- if os.path.exists(path) and os.path.isdir(path):
- shutil.rmtree(path)
- return True
- else:
- print(f"Path {path} is not a valid directory")
- return False
- except Exception as e:
- print(f"Failed to remove directory {path}: {e}")
- return False
-
-
-print("🔥 Loading the 16-bit merged model from disk...")
-merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
- model_name = "./gpt-oss-finetuned-merged",
- max_seq_length = 1024,
- load_in_4bit = True,
- load_in_8bit = False,
-)
-print("✅ Merged model loaded successfully.")
-
-# --- Run Inference ---
-print("\n🚀 Running inference...")
-messages = [
- {"role": "user", "content": "Solve x^5 + 3x^4 - 10 = 3."},
-]
-inputs = merged_tokenizer.apply_chat_template(
- messages,
- add_generation_prompt = True,
- return_tensors = "pt",
- return_dict = True,
- reasoning_effort = "low", # **NEW!** Set reasoning effort to low, medium or high
-).to(merged_model.device)
-
-_ = merged_model.generate(
- **inputs, max_new_tokens = 512, streamer = TextStreamer(merged_tokenizer)
-)
-print("\n✅ Inference complete.")
-
-# --- Final Cleanup ---
-print("\n🧹 Cleaning up merged model directory and cache...")
-del merged_model, merged_tokenizer
-torch.cuda.empty_cache()
-gc.collect()
-
-safe_remove_directory("./gpt-oss-finetuned-merged")
-safe_remove_directory(
- "./unsloth_compiled_cache"
-) # Clean up cache created by this process
-print("✅ Final cleanup complete. Exiting inference script.")
diff --git a/tests/saving/gpt-oss-merge/train_and_merge.py b/tests/saving/gpt-oss-merge/train_and_merge.py
deleted file mode 100644
index 308d19bfb4..0000000000
--- a/tests/saving/gpt-oss-merge/train_and_merge.py
+++ /dev/null
@@ -1,102 +0,0 @@
-# train_and_merge.py
-from unsloth import FastLanguageModel
-from trl import SFTTrainer, SFTConfig
-from datasets import load_dataset
-import torch
-import gc
-import os
-import shutil
-
-
-def safe_remove_directory(path):
- try:
- if os.path.exists(path) and os.path.isdir(path):
- shutil.rmtree(path)
- return True
- else:
- print(f"Path {path} is not a valid directory")
- return False
- except Exception as e:
- print(f"Failed to remove directory {path}: {e}")
- return False
-
-
-# This tokenizer will be used by the mapping function
-tokenizer = None
-
-
-def formatting_prompts_func(examples):
- convos = examples["messages"]
- texts = [
- tokenizer.apply_chat_template(
- convo, tokenize = False, add_generation_prompt = False
- )
- for convo in convos
- ]
- return {"text": texts}
-
-
-# --- Load 4-bit Model and Train ---
-print("Loading 4-bit Mxfp4 gpt-oss model for training...")
-max_seq_length = 1024
-model, tokenizer = FastLanguageModel.from_pretrained(
- "unsloth/gpt-oss-20b", max_seq_length = max_seq_length, load_in_4bit = True
-)
-
-dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split = "train[:50]").map(
- formatting_prompts_func, batched = True
-)
-
-model = FastLanguageModel.get_peft_model(
- model,
- r = 8,
- target_modules = [
- "q_proj",
- "k_proj",
- "v_proj",
- "o_proj",
- "gate_proj",
- "up_proj",
- "down_proj",
- ],
- lora_alpha = 16,
- use_gradient_checkpointing = "unsloth",
- random_state = 3407,
-)
-
-trainer = SFTTrainer(
- model = model,
- tokenizer = tokenizer,
- train_dataset = dataset,
- args = SFTConfig(
- per_device_train_batch_size = 1,
- gradient_accumulation_steps = 4,
- max_steps = 10,
- learning_rate = 2e-4,
- output_dir = "outputs",
- report_to = "none",
- ),
-)
-
-print("Starting fine-tuning...")
-trainer.train()
-print("Fine-tuning complete.")
-
-# --- Merge and Save ---
-print("\n💾 Merging and saving the 16-bit model to './gpt-oss-finetuned-merged'...")
-model.save_pretrained_merged(
- save_directory = "./gpt-oss-finetuned-merged", tokenizer = tokenizer
-)
-print("✅ Model merged and saved.")
-
-# --- Cleanup ---
-print("\n🧹 Cleaning up training artifacts...")
-del model, trainer, tokenizer, dataset
-torch.cuda.empty_cache()
-gc.collect()
-
-safe_remove_directory("./outputs")
-safe_remove_directory(
- "./unsloth_compiled_cache"
-) # Clean up the cache created by this process
-print("✅ Cleanup complete. Exiting training script.")
diff --git a/tests/saving/language_models/test_merge_4bit_validation.py b/tests/saving/language_models/test_merge_4bit_validation.py
deleted file mode 100644
index 343e737710..0000000000
--- a/tests/saving/language_models/test_merge_4bit_validation.py
+++ /dev/null
@@ -1,248 +0,0 @@
-from unsloth import FastLanguageModel
-from unsloth.chat_templates import get_chat_template
-from trl import SFTTrainer, SFTConfig
-from transformers import DataCollatorForSeq2Seq, TrainingArguments
-from datasets import load_dataset
-import torch
-import sys
-from pathlib import Path
-
-REPO_ROOT = Path(__file__).parents[3]
-sys.path.insert(0, str(REPO_ROOT))
-
-from tests.utils.cleanup_utils import safe_remove_directory
-
-
-def formatting_prompts_func(examples):
- convos = examples["messages"]
- texts = [
- tokenizer.apply_chat_template(
- convo, tokenize = False, add_generation_prompt = False
- )
- for convo in convos
- ]
- return {"text": texts}
-
-
-print(f"\n{'='*80}")
-print("🔍 PHASE 1: Loading Base Model and Initial Training")
-print(f"{'='*80}")
-
-if torch.cuda.is_bf16_supported():
- compute_dtype = torch.bfloat16
- attn_implementation = "flash_attention_2"
-else:
- compute_dtype = torch.float16
- attn_implementation = "sdpa"
-
-model, tokenizer = FastLanguageModel.from_pretrained(
- model_name = "unsloth/Llama-3.1-8B-Instruct",
- max_seq_length = 2048,
- dtype = compute_dtype,
- load_in_4bit = True,
- load_in_8bit = False,
- full_finetuning = False,
- attn_implementation = attn_implementation,
-)
-
-tokenizer = get_chat_template(
- tokenizer,
- chat_template = "llama-3.1",
-)
-
-# Load small dataset for quick training
-dataset_train = load_dataset(
- "allenai/openassistant-guanaco-reformatted", split = "train[:100]"
-)
-dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
-
-print("✅ Base model loaded successfully!")
-
-print(f"\n{'='*80}")
-print("🔍 PHASE 2: First Fine-tuning")
-print(f"{'='*80}")
-
-model = FastLanguageModel.get_peft_model(
- model,
- r = 16,
- target_modules = [
- "k_proj",
- "q_proj",
- "v_proj",
- "o_proj",
- "gate_proj",
- "down_proj",
- "up_proj",
- ],
- lora_alpha = 16,
- lora_dropout = 0,
- bias = "none",
- use_gradient_checkpointing = "unsloth",
- random_state = 3407,
- use_rslora = False,
- loftq_config = None,
-)
-
-from unsloth import is_bfloat16_supported
-
-trainer = SFTTrainer(
- model = model,
- tokenizer = tokenizer,
- train_dataset = dataset_train,
- dataset_text_field = "text",
- max_seq_length = 2048,
- data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
- dataset_num_proc = 2,
- packing = False,
- args = TrainingArguments(
- per_device_train_batch_size = 2,
- gradient_accumulation_steps = 4,
- warmup_ratio = 0.1,
- max_steps = 10, # Very short training for test
- learning_rate = 2e-4,
- fp16 = not is_bfloat16_supported(),
- bf16 = is_bfloat16_supported(),
- logging_steps = 5,
- optim = "adamw_8bit",
- lr_scheduler_type = "linear",
- seed = 3407,
- output_dir = "outputs",
- report_to = "none",
- ),
-)
-
-trainer_stats = trainer.train()
-print("✅ First fine-tuning completed!")
-
-print(f"\n{'='*80}")
-print("🔍 PHASE 3: Save with Forced 4bit Merge")
-print(f"{'='*80}")
-
-model.save_pretrained_merged(
- save_directory = "./test_4bit_model",
- tokenizer = tokenizer,
- save_method = "forced_merged_4bit",
-)
-
-print("✅ Model saved with forced 4bit merge!")
-
-print(f"\n{'='*80}")
-print("🔍 PHASE 4: Loading 4bit Model and Second Fine-tuning")
-print(f"{'='*80}")
-
-# Clean up first model
-del model
-del tokenizer
-torch.cuda.empty_cache()
-
-# Load the 4bit merged model
-model_4bit, tokenizer_4bit = FastLanguageModel.from_pretrained(
- model_name = "./test_4bit_model",
- max_seq_length = 2048,
- load_in_4bit = True,
- load_in_8bit = False,
-)
-
-tokenizer_4bit = get_chat_template(
- tokenizer_4bit,
- chat_template = "llama-3.1",
-)
-
-print("✅ 4bit model loaded successfully!")
-
-# Add LoRA adapters to the 4bit model
-model_4bit = FastLanguageModel.get_peft_model(
- model_4bit,
- r = 16,
- target_modules = [
- "k_proj",
- "q_proj",
- "v_proj",
- "o_proj",
- "gate_proj",
- "down_proj",
- "up_proj",
- ],
- lora_alpha = 16,
- lora_dropout = 0,
- bias = "none",
- use_gradient_checkpointing = "unsloth",
- random_state = 3407,
- use_rslora = False,
- loftq_config = None,
-)
-
-# Second fine-tuning
-trainer_4bit = SFTTrainer(
- model = model_4bit,
- tokenizer = tokenizer_4bit,
- train_dataset = dataset_train,
- dataset_text_field = "text",
- max_seq_length = 2048,
- data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer_4bit),
- dataset_num_proc = 2,
- packing = False,
- args = TrainingArguments(
- per_device_train_batch_size = 2,
- gradient_accumulation_steps = 4,
- warmup_ratio = 0.1,
- max_steps = 10, # Very short training for test
- learning_rate = 2e-4,
- fp16 = not is_bfloat16_supported(),
- bf16 = is_bfloat16_supported(),
- logging_steps = 5,
- optim = "adamw_8bit",
- lr_scheduler_type = "linear",
- seed = 3407,
- output_dir = "outputs_4bit",
- report_to = "none",
- ),
-)
-
-trainer_4bit.train()
-print("✅ Second fine-tuning on 4bit model completed!")
-
-print(f"\n{'='*80}")
-print("🔍 PHASE 5: Testing TypeError on Regular Merge (Should Fail)")
-print(f"{'='*80}")
-
-try:
- model_4bit.save_pretrained_merged(
- save_directory = "./test_should_fail",
- tokenizer = tokenizer_4bit,
- # No save_method specified, should default to regular merge
- )
- assert False, "Expected TypeError but merge succeeded!"
-except TypeError as e:
- expected_error = "Base model should be a 16bits or mxfp4 base model for a 16bit model merge. Use `save_method=forced_merged_4bit` instead"
- assert expected_error in str(e), f"Unexpected error message: {str(e)}"
- print("✅ Correct TypeError raised for 4bit base model regular merge attempt!")
- print(f"Error message: {str(e)}")
-
-print(f"\n{'='*80}")
-print("🔍 PHASE 6: Successful Save with Forced 4bit Method")
-print(f"{'='*80}")
-
-try:
- model_4bit.save_pretrained_merged(
- save_directory = "./test_4bit_second",
- tokenizer = tokenizer_4bit,
- save_method = "forced_merged_4bit",
- )
- print("✅ Successfully saved 4bit model with forced 4bit method!")
-except Exception as e:
- assert False, f"Phase 6 failed unexpectedly: {e}"
-
-print(f"\n{'='*80}")
-print("🔍 CLEANUP")
-print(f"{'='*80}")
-
-# Cleanup
-safe_remove_directory("./outputs")
-safe_remove_directory("./outputs_4bit")
-safe_remove_directory("./unsloth_compiled_cache")
-safe_remove_directory("./test_4bit_model")
-safe_remove_directory("./test_4bit_second")
-safe_remove_directory("./test_should_fail")
-
-print("✅ All tests passed successfully!")
diff --git a/tests/saving/language_models/test_merge_model_perplexity_llama-3.2.py b/tests/saving/language_models/test_merge_model_perplexity_llama-3.2.py
deleted file mode 100644
index dd0e8c25c6..0000000000
--- a/tests/saving/language_models/test_merge_model_perplexity_llama-3.2.py
+++ /dev/null
@@ -1,259 +0,0 @@
-from unsloth import FastLanguageModel, FastVisionModel, UnslothVisionDataCollator
-from unsloth.chat_templates import get_chat_template
-from trl import SFTTrainer, SFTConfig
-from transformers import (
- DataCollatorForLanguageModeling,
- DataCollatorForSeq2Seq,
- TrainingArguments,
-)
-from datasets import load_dataset, Dataset
-import torch
-from tqdm import tqdm
-import pandas as pd
-import multiprocessing as mp
-from multiprocessing import Process, Queue
-import gc
-
-# ruff: noqa
-import sys
-from pathlib import Path
-
-
-REPO_ROOT = Path(__file__).parents[3]
-sys.path.insert(0, str(REPO_ROOT))
-
-
-from tests.utils.cleanup_utils import safe_remove_directory
-from tests.utils.perplexity_eval import (
- ppl_model,
- add_to_comparison,
- print_model_comparison,
-)
-
-
-# Define helper functions outside of main
-def formatting_prompts_func(examples):
- convos = examples["messages"]
- texts = [
- tokenizer.apply_chat_template(
- convo, tokenize = False, add_generation_prompt = False
- )
- for convo in convos
- ]
- return {"text": texts}
-
-
-def load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):
- """Load model and compute perplexity in subprocess"""
- from unsloth import FastLanguageModel
- from unsloth.chat_templates import get_chat_template
- from tests.utils.perplexity_eval import ppl_model
-
- # Load model
- merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
- model_name = "./unsloth_out/merged_llama_text_model",
- max_seq_length = 2048,
- load_in_4bit = load_in_4bit,
- load_in_8bit = load_in_8bit,
- )
- # Set up tokenizer
- merged_tokenizer = get_chat_template(
- merged_tokenizer,
- chat_template = "llama-3.1",
- )
-
- # Load dataset fresh in subprocess
- dataset_ppl = load_dataset(
- "allenai/openassistant-guanaco-reformatted", split = "eval"
- )
-
- # Format the dataset
- def formatting_prompts_func(examples):
- convos = examples["messages"]
- texts = [
- merged_tokenizer.apply_chat_template(
- convo, tokenize = False, add_generation_prompt = False
- )
- for convo in convos
- ]
- return {"text": texts}
-
- dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
-
- # Compute perplexity using the passed dataset
- ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)
-
- # IMPORTANT: Convert to Python float if it's a tensor
- if torch.is_tensor(ppl_value):
- ppl_value = ppl_value.cpu().item() # Move to CPU and convert to Python scalar
- elif hasattr(ppl_value, "item"):
- ppl_value = ppl_value.item() # Convert numpy or other array types
- else:
- ppl_value = float(ppl_value) # Ensure it's a float
-
- # Return only the perplexity value
- result_queue.put(ppl_value)
-
- # Clean up
- del merged_model
- del merged_tokenizer
- del dataset_ppl
- torch.cuda.empty_cache()
- gc.collect()
-
-
-# Main execution code should be wrapped in this guard
-if __name__ == "__main__":
- mp.set_start_method("spawn", force = True)
-
- if torch.cuda.is_bf16_supported():
- compute_dtype = torch.bfloat16
- attn_implementation = "flash_attention_2"
- else:
- compute_dtype = torch.float16
- attn_implementation = "sdpa"
-
- model, tokenizer = FastLanguageModel.from_pretrained(
- model_name = "unsloth/Llama-3.2-3B-Instruct",
- max_seq_length = 2048,
- dtype = compute_dtype,
- load_in_4bit = True,
- load_in_8bit = False,
- full_finetuning = False,
- attn_implementation = attn_implementation,
- )
-
- tokenizer = get_chat_template(
- tokenizer,
- chat_template = "llama-3.1",
- )
-
- from unsloth.chat_templates import standardize_sharegpt
-
- dataset_train = load_dataset(
- "allenai/openassistant-guanaco-reformatted", split = "train"
- )
- dataset_ppl = load_dataset(
- "allenai/openassistant-guanaco-reformatted", split = "eval"
- )
-
- dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
- dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
-
- add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
-
- model = FastLanguageModel.get_peft_model(
- model,
- r = 16,
- target_modules = [
- "k_proj",
- "q_proj",
- "v_proj",
- "o_proj",
- "gate_proj",
- "down_proj",
- "up_proj",
- ],
- lora_alpha = 16,
- lora_dropout = 0,
- bias = "none",
- use_gradient_checkpointing = "unsloth",
- random_state = 3407,
- use_rslora = False,
- loftq_config = None,
- )
-
- from unsloth import is_bfloat16_supported
-
- trainer = SFTTrainer(
- model = model,
- tokenizer = tokenizer,
- train_dataset = dataset_train,
- dataset_text_field = "text",
- max_seq_length = 2048,
- data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
- dataset_num_proc = 2,
- packing = False,
- args = TrainingArguments(
- per_device_train_batch_size = 2,
- gradient_accumulation_steps = 4,
- warmup_ratio = 0.1,
- max_steps = 10,
- learning_rate = 2e-4,
- fp16 = not is_bfloat16_supported(),
- bf16 = is_bfloat16_supported(),
- logging_steps = 50,
- optim = "adamw_8bit",
- lr_scheduler_type = "linear",
- seed = 3407,
- output_dir = "outputs",
- report_to = "none",
- ),
- )
-
- from unsloth.chat_templates import train_on_responses_only
-
- trainer = train_on_responses_only(
- trainer,
- instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
- response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
- )
-
- # run training
- trainer_stats = trainer.train()
-
- add_to_comparison("Qlora model", ppl_model(model, tokenizer, dataset_ppl))
-
- # saving and merging the model to local disk
- print("merge and save to local disk")
- model.save_pretrained_merged(
- save_directory = "./unsloth_out/merged_llama_text_model", tokenizer = tokenizer
- )
-
- # print("cleaning")
- # del model
- # del tokenizer
- # torch.cuda.empty_cache()
- # gc.collect()
-
- # load model from local disk and test
- print("Loading merged model in 4 bit for perplexity test")
- merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
- model_name = "./unsloth_out/merged_llama_text_model",
- max_seq_length = 2048,
- load_in_4bit = True,
- load_in_8bit = False,
- )
-
- add_to_comparison(
- "merged model load 4bit", ppl_model(merged_model, merged_tokenizer, dataset_ppl)
- )
-
- print("Computing 8-bit model perplexity in subprocess...")
- result_queue = mp.Queue()
- p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))
- p.start()
- p.join()
-
- ppl_8bit = result_queue.get()
- add_to_comparison("merged model loaded 8bits", ppl_8bit)
-
- print("Loading merged model in 16 bit for perplexity test")
- merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
- model_name = "./unsloth_out/merged_llama_text_model",
- max_seq_length = 2048,
- load_in_4bit = False,
- load_in_8bit = False,
- )
-
- add_to_comparison(
- "merged model loaded 16bits",
- ppl_model(merged_model, merged_tokenizer, dataset_ppl),
- )
-
- print_model_comparison()
-
- # final cleanup
- safe_remove_directory("./outputs")
- safe_remove_directory("./unsloth_compiled_cache")
- safe_remove_directory("./unsloth_out")
diff --git a/tests/saving/language_models/test_merge_model_perplexity_mistral.py b/tests/saving/language_models/test_merge_model_perplexity_mistral.py
deleted file mode 100644
index 14e657c68a..0000000000
--- a/tests/saving/language_models/test_merge_model_perplexity_mistral.py
+++ /dev/null
@@ -1,318 +0,0 @@
-from unsloth import FastLanguageModel, FastVisionModel, UnslothVisionDataCollator
-from unsloth.chat_templates import get_chat_template
-from trl import SFTTrainer, SFTConfig
-from transformers import (
- DataCollatorForLanguageModeling,
- DataCollatorForSeq2Seq,
- TrainingArguments,
-)
-from datasets import load_dataset, Dataset
-import torch
-from tqdm import tqdm
-import pandas as pd
-import multiprocessing as mp
-from multiprocessing import Process, Queue
-import gc
-
-# ruff: noqa
-import sys
-from pathlib import Path
-
-
-REPO_ROOT = Path(__file__).parents[3]
-sys.path.insert(0, str(REPO_ROOT))
-
-from tests.utils.cleanup_utils import safe_remove_directory
-from tests.utils.perplexity_eval import (
- ppl_model,
- add_to_comparison,
- print_model_comparison,
-)
-
-
-def load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):
- """Load model and compute perplexity in subprocess"""
- from unsloth import FastLanguageModel
- from tests.utils.perplexity_eval import ppl_model
-
- # Load model
- merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
- model_name = "./unsloth_out/merged_mistral_text_model",
- max_seq_length = 2048,
- load_in_4bit = load_in_4bit,
- load_in_8bit = load_in_8bit,
- )
- # Set up tokenizer
- # merged_tokenizer = get_chat_template(
- # merged_tokenizer,
- # chat_template="llama-3.1",
- # )
-
- # Load dataset fresh in subprocess
- dataset_ppl = load_dataset(
- "allenai/openassistant-guanaco-reformatted", split = "eval"
- )
-
- alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
-
- ### Instruction:
- {}
-
- ### Input:
- {}
-
- ### Response:
- {}"""
-
- EOS_TOKEN = merged_tokenizer.eos_token
-
- def formatting_prompts_func(examples):
- instructions = []
- inputs = []
- outputs = []
- texts = []
-
- for conversation in examples["messages"]:
- # Extract user message and assistant response
- user_message = ""
- assistant_message = ""
-
- for turn in conversation:
- if turn["role"] == "user":
- user_message = turn["content"]
- elif turn["role"] == "assistant":
- assistant_message = turn["content"]
-
- # Store intermediate format
- instruction = "Complete the statement"
- instructions.append(instruction)
- inputs.append(user_message)
- outputs.append(assistant_message)
-
- # Create formatted text
- text = (
- alpaca_prompt.format(instruction, user_message, assistant_message)
- + EOS_TOKEN
- )
- texts.append(text)
-
- return {
- "instruction": instructions,
- "input": inputs,
- "output": outputs,
- "text": texts,
- }
-
- dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
-
- # Compute perplexity using the passed dataset
- ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)
-
- # IMPORTANT: Convert to Python float if it's a tensor
- if torch.is_tensor(ppl_value):
- ppl_value = ppl_value.cpu().item() # Move to CPU and convert to Python scalar
- elif hasattr(ppl_value, "item"):
- ppl_value = ppl_value.item() # Convert numpy or other array types
- else:
- ppl_value = float(ppl_value) # Ensure it's a float
-
- # Return only the perplexity value
- result_queue.put(ppl_value)
-
- # Clean up
- del merged_model
- del merged_tokenizer
- del dataset_ppl
- torch.cuda.empty_cache()
- gc.collect()
-
-
-# Main execution code should be wrapped in this guard
-if __name__ == "__main__":
- mp.set_start_method("spawn", force = True)
-
- if torch.cuda.is_bf16_supported():
- compute_dtype = torch.bfloat16
- attn_implementation = "flash_attention_2"
- else:
- compute_dtype = torch.float16
- attn_implementation = "sdpa"
-
- model, tokenizer = FastLanguageModel.from_pretrained(
- model_name = "unsloth/mistral-7b-v0.3",
- max_seq_length = 2048,
- dtype = compute_dtype,
- load_in_4bit = True,
- load_in_8bit = False,
- full_finetuning = False,
- attn_implementation = attn_implementation,
- )
-
- EOS_TOKEN = tokenizer.eos_token
-
- alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
-
- ### Instruction:
- {}
-
- ### Input:
- {}
-
- ### Response:
- {}"""
-
- # Define helper functions outside of main
- def formatting_prompts_func(examples):
- instructions = []
- inputs = []
- outputs = []
- texts = []
-
- for conversation in examples["messages"]:
- # Extract user message and assistant response
- user_message = ""
- assistant_message = ""
-
- for turn in conversation:
- if turn["role"] == "user":
- user_message = turn["content"]
- elif turn["role"] == "assistant":
- assistant_message = turn["content"]
-
- # Store intermediate format
- instruction = "Complete the statement"
- instructions.append(instruction)
- inputs.append(user_message)
- outputs.append(assistant_message)
-
- # Create formatted text
- text = (
- alpaca_prompt.format(instruction, user_message, assistant_message)
- + EOS_TOKEN
- )
- texts.append(text)
-
- return {
- "instruction": instructions,
- "input": inputs,
- "output": outputs,
- "text": texts,
- }
-
- dataset_train = load_dataset(
- "allenai/openassistant-guanaco-reformatted", split = "train"
- )
- dataset_ppl = load_dataset(
- "allenai/openassistant-guanaco-reformatted", split = "eval"
- )
-
- dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
- dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
-
- add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
-
- model = FastLanguageModel.get_peft_model(
- model,
- r = 16,
- target_modules = [
- "k_proj",
- "q_proj",
- "v_proj",
- "o_proj",
- "gate_proj",
- "down_proj",
- "up_proj",
- ],
- lora_alpha = 16,
- lora_dropout = 0,
- bias = "none",
- use_gradient_checkpointing = "unsloth",
- random_state = 3407,
- use_rslora = False,
- loftq_config = None,
- )
-
- from unsloth import is_bfloat16_supported
-
- trainer = SFTTrainer(
- model = model,
- tokenizer = tokenizer,
- train_dataset = dataset_train,
- dataset_text_field = "text",
- max_seq_length = 2048,
- dataset_num_proc = 2,
- packing = False,
- args = TrainingArguments(
- per_device_train_batch_size = 2,
- gradient_accumulation_steps = 4,
- warmup_ratio = 0.1,
- max_steps = 200,
- learning_rate = 2e-4,
- fp16 = not is_bfloat16_supported(),
- bf16 = is_bfloat16_supported(),
- logging_steps = 50,
- optim = "adamw_8bit",
- lr_scheduler_type = "linear",
- seed = 3407,
- output_dir = "outputs",
- report_to = "none",
- ),
- )
-
- # run training
- trainer_stats = trainer.train()
-
- add_to_comparison("Qlora model", ppl_model(model, tokenizer, dataset_ppl))
-
- # saving and merging the model to local disk
- print("merge and save to local disk")
- model.save_pretrained_merged(
- save_directory = "./unsloth_out/merged_mistral_text_model", tokenizer = tokenizer
- )
-
- # print("cleaning")
- # del model
- # del tokenizer
- # torch.cuda.empty_cache()
- # gc.collect()
-
- # load model from local disk and test
- print("Loading merged model in 4 bit for perplexity test")
- merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
- model_name = "./unsloth_out/merged_mistral_text_model",
- max_seq_length = 2048,
- load_in_4bit = True,
- load_in_8bit = False,
- )
-
- add_to_comparison(
- "merged model load 4bit", ppl_model(merged_model, merged_tokenizer, dataset_ppl)
- )
-
- print("Computing 8-bit model perplexity in subprocess...")
- result_queue = mp.Queue()
- p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))
- p.start()
- p.join()
-
- ppl_8bit = result_queue.get()
- add_to_comparison("merged model loaded 8bits", ppl_8bit)
-
- print("Loading merged model in 16 bit for perplexity test")
- merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
- model_name = "./unsloth_out/merged_mistral_text_model",
- max_seq_length = 2048,
- load_in_4bit = False,
- load_in_8bit = False,
- )
-
- add_to_comparison(
- "merged model loaded 16bits",
- ppl_model(merged_model, merged_tokenizer, dataset_ppl),
- )
-
- print_model_comparison()
-
- safe_remove_directory("./outputs")
- safe_remove_directory("./unsloth_compiled_cache")
- safe_remove_directory("./unsloth_out")
diff --git a/tests/saving/language_models/test_merge_model_perplexity_phi_4.py b/tests/saving/language_models/test_merge_model_perplexity_phi_4.py
deleted file mode 100644
index bebea8168e..0000000000
--- a/tests/saving/language_models/test_merge_model_perplexity_phi_4.py
+++ /dev/null
@@ -1,259 +0,0 @@
-from unsloth import FastLanguageModel, FastVisionModel, UnslothVisionDataCollator
-from unsloth.chat_templates import get_chat_template
-from trl import SFTTrainer, SFTConfig
-from transformers import (
- DataCollatorForLanguageModeling,
- DataCollatorForSeq2Seq,
- TrainingArguments,
-)
-from datasets import load_dataset, Dataset
-import torch
-from tqdm import tqdm
-import pandas as pd
-import multiprocessing as mp
-from multiprocessing import Process, Queue
-import gc
-
-# ruff: noqa
-import sys
-from pathlib import Path
-
-
-REPO_ROOT = Path(__file__).parents[3]
-sys.path.insert(0, str(REPO_ROOT))
-
-
-from tests.utils.cleanup_utils import safe_remove_directory
-from tests.utils.perplexity_eval import (
- ppl_model,
- add_to_comparison,
- print_model_comparison,
-)
-
-
-# Define helper functions outside of main
-def formatting_prompts_func(examples):
- convos = examples["messages"]
- texts = [
- tokenizer.apply_chat_template(
- convo, tokenize = False, add_generation_prompt = False
- )
- for convo in convos
- ]
- return {
- "text": texts,
- }
-
-
-def load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):
- """Load model and compute perplexity in subprocess"""
- from unsloth import FastLanguageModel
- from unsloth.chat_templates import get_chat_template
- from tests.utils.perplexity_eval import ppl_model
-
- # Load model
- merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
- model_name = "./unsloth_out/merged_phi4_text_model",
- max_seq_length = 2048,
- load_in_4bit = load_in_4bit,
- load_in_8bit = load_in_8bit,
- )
- # Set up tokenizer
- merged_tokenizer = get_chat_template(
- merged_tokenizer,
- chat_template = "phi-4",
- )
-
- # Load dataset fresh in subprocess
- dataset_ppl = load_dataset(
- "allenai/openassistant-guanaco-reformatted", split = "eval"
- )
-
- # Format the dataset
- def formatting_prompts_func(examples):
- convos = examples["messages"]
- texts = [
- merged_tokenizer.apply_chat_template(
- convo, tokenize = False, add_generation_prompt = False
- )
- for convo in convos
- ]
- return {"text": texts}
-
- dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
-
- # Compute perplexity using the passed dataset
- ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)
-
- # IMPORTANT: Convert to Python float if it's a tensor
- if torch.is_tensor(ppl_value):
- ppl_value = ppl_value.cpu().item() # Move to CPU and convert to Python scalar
- elif hasattr(ppl_value, "item"):
- ppl_value = ppl_value.item() # Convert numpy or other array types
- else:
- ppl_value = float(ppl_value) # Ensure it's a float
-
- # Return only the perplexity value
- result_queue.put(ppl_value)
-
- # Clean up
- del merged_model
- del merged_tokenizer
- del dataset_ppl
- torch.cuda.empty_cache()
- gc.collect()
-
-
-# Main execution code should be wrapped in this guard
-if __name__ == "__main__":
- mp.set_start_method("spawn", force = True)
-
- if torch.cuda.is_bf16_supported():
- compute_dtype = torch.bfloat16
- attn_implementation = "flash_attention_2"
- else:
- compute_dtype = torch.float16
- attn_implementation = "sdpa"
-
- model, tokenizer = FastLanguageModel.from_pretrained(
- model_name = "unsloth/Phi-4",
- max_seq_length = 2048,
- dtype = compute_dtype,
- load_in_4bit = True,
- load_in_8bit = False,
- full_finetuning = False,
- attn_implementation = attn_implementation,
- )
-
- tokenizer = get_chat_template(
- tokenizer,
- chat_template = "phi-4",
- )
-
- dataset_train = load_dataset(
- "allenai/openassistant-guanaco-reformatted", split = "train"
- )
- dataset_ppl = load_dataset(
- "allenai/openassistant-guanaco-reformatted", split = "eval"
- )
-
- dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
- dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
-
- add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
-
- model = FastLanguageModel.get_peft_model(
- model,
- r = 16,
- target_modules = [
- "k_proj",
- "q_proj",
- "v_proj",
- "o_proj",
- "gate_proj",
- "down_proj",
- "up_proj",
- ],
- lora_alpha = 16,
- lora_dropout = 0,
- bias = "none",
- use_gradient_checkpointing = "unsloth",
- random_state = 3407,
- use_rslora = False,
- loftq_config = None,
- )
-
- from unsloth import is_bfloat16_supported
-
- trainer = SFTTrainer(
- model = model,
- tokenizer = tokenizer,
- train_dataset = dataset_train,
- dataset_text_field = "text",
- max_seq_length = 2048,
- data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
- dataset_num_proc = 2,
- packing = False,
- args = TrainingArguments(
- per_device_train_batch_size = 2,
- gradient_accumulation_steps = 4,
- warmup_ratio = 0.1,
- max_steps = 200,
- learning_rate = 2e-4,
- fp16 = not is_bfloat16_supported(),
- bf16 = is_bfloat16_supported(),
- logging_steps = 50,
- optim = "adamw_8bit",
- lr_scheduler_type = "linear",
- seed = 3407,
- output_dir = "outputs",
- report_to = "none",
- ),
- )
-
- from unsloth.chat_templates import train_on_responses_only
-
- trainer = train_on_responses_only(
- trainer,
- instruction_part = "<|im_start|>user<|im_sep|>\n\n",
- response_part = "<|im_start|>assistant<|im_sep|>\n\n",
- )
-
- # run training
- trainer_stats = trainer.train()
-
- add_to_comparison("Qlora model", ppl_model(model, tokenizer, dataset_ppl))
-
- # saving and merging the model to local disk
- print("merge and save to local disk")
- model.save_pretrained_merged(
- save_directory = "./unsloth_out/merged_phi4_text_model", tokenizer = tokenizer
- )
-
- # print("cleaning")
- # del model
- # del tokenizer
- # torch.cuda.empty_cache()
- # gc.collect()
-
- # load model from local disk and test
- print("Loading merged model in 4 bit for perplexity test")
- merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
- model_name = "./unsloth_out/merged_phi4_text_model",
- max_seq_length = 2048,
- load_in_4bit = True,
- load_in_8bit = False,
- )
-
- add_to_comparison(
- "merged model load 4bit", ppl_model(merged_model, merged_tokenizer, dataset_ppl)
- )
-
- print("Computing 8-bit model perplexity in subprocess...")
- result_queue = mp.Queue()
- p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))
- p.start()
- p.join()
-
- ppl_8bit = result_queue.get()
- add_to_comparison("merged model loaded 8bits", ppl_8bit)
-
- print("Loading merged model in 16 bit for perplexity test")
- merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
- model_name = "./unsloth_out/merged_phi4_text_model",
- max_seq_length = 2048,
- load_in_4bit = False,
- load_in_8bit = False,
- )
-
- add_to_comparison(
- "merged model loaded 16bits",
- ppl_model(merged_model, merged_tokenizer, dataset_ppl),
- )
-
- print_model_comparison()
-
- # final cleanup
- safe_remove_directory("./outputs")
- safe_remove_directory("./unsloth_compiled_cache")
- safe_remove_directory("./unsloth_out")
diff --git a/tests/saving/language_models/test_merged_model_perplexity_llama-3.1-8b.py b/tests/saving/language_models/test_merged_model_perplexity_llama-3.1-8b.py
deleted file mode 100644
index c6da9e2ca6..0000000000
--- a/tests/saving/language_models/test_merged_model_perplexity_llama-3.1-8b.py
+++ /dev/null
@@ -1,263 +0,0 @@
-from unsloth import FastLanguageModel, FastVisionModel, UnslothVisionDataCollator
-from unsloth.chat_templates import get_chat_template
-from trl import SFTTrainer, SFTConfig
-from transformers import (
- DataCollatorForLanguageModeling,
- DataCollatorForSeq2Seq,
- TrainingArguments,
-)
-from datasets import load_dataset, Dataset
-import torch
-from tqdm import tqdm
-import pandas as pd
-import multiprocessing as mp
-from multiprocessing import Process, Queue
-import gc
-
-# ruff: noqa
-import sys
-from pathlib import Path
-
-
-REPO_ROOT = Path(__file__).parents[3]
-sys.path.insert(0, str(REPO_ROOT))
-
-from tests.utils.cleanup_utils import safe_remove_directory
-from tests.utils.perplexity_eval import (
- ppl_model,
- add_to_comparison,
- print_model_comparison,
-)
-
-
-# Define helper functions outside of main
-def formatting_prompts_func(examples):
- convos = examples["messages"]
- texts = [
- tokenizer.apply_chat_template(
- convo, tokenize = False, add_generation_prompt = False
- )
- for convo in convos
- ]
- return {"text": texts}
-
-
-def load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):
- """Load model and compute perplexity in subprocess"""
- from unsloth import FastLanguageModel
- from unsloth.chat_templates import get_chat_template
- from tests.utils.perplexity_eval import ppl_model
-
- # Load model
- merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
- model_name = "./unsloth_out/merged_llama_text_model",
- max_seq_length = 2048,
- load_in_4bit = load_in_4bit,
- load_in_8bit = load_in_8bit,
- )
- # Set up tokenizer
- merged_tokenizer = get_chat_template(
- merged_tokenizer,
- chat_template = "llama-3.1",
- )
-
- # Load dataset fresh in subprocess
- dataset_ppl = load_dataset(
- "allenai/openassistant-guanaco-reformatted", split = "eval"
- )
-
- # Format the dataset
- def formatting_prompts_func(examples):
- convos = examples["messages"]
- texts = [
- merged_tokenizer.apply_chat_template(
- convo, tokenize = False, add_generation_prompt = False
- )
- for convo in convos
- ]
- return {"text": texts}
-
- dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
-
- # Compute perplexity using the passed dataset
- ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)
-
- # IMPORTANT: Convert to Python float if it's a tensor
- if torch.is_tensor(ppl_value):
- ppl_value = ppl_value.cpu().item() # Move to CPU and convert to Python scalar
- elif hasattr(ppl_value, "item"):
- ppl_value = ppl_value.item() # Convert numpy or other array types
- else:
- ppl_value = float(ppl_value) # Ensure it's a float
-
- # Return only the perplexity value
- result_queue.put(ppl_value)
-
- # Clean up
- del merged_model
- del merged_tokenizer
- del dataset_ppl
- torch.cuda.empty_cache()
- gc.collect()
-
-
-# Main execution code should be wrapped in this guard
-if __name__ == "__main__":
- mp.set_start_method("spawn", force = True)
-
- if torch.cuda.is_bf16_supported():
- compute_dtype = torch.bfloat16
- attn_implementation = "flash_attention_2"
- else:
- compute_dtype = torch.float16
- attn_implementation = "sdpa"
-
- model, tokenizer = FastLanguageModel.from_pretrained(
- model_name = "unsloth/Llama-3.1-8B-Instruct",
- max_seq_length = 2048,
- dtype = compute_dtype,
- load_in_4bit = True,
- load_in_8bit = False,
- full_finetuning = False,
- attn_implementation = attn_implementation,
- )
-
- tokenizer = get_chat_template(
- tokenizer,
- chat_template = "llama-3.1",
- )
-
- from unsloth.chat_templates import standardize_sharegpt
-
- dataset_train = load_dataset(
- "allenai/openassistant-guanaco-reformatted", split = "train"
- )
- dataset_ppl = load_dataset(
- "allenai/openassistant-guanaco-reformatted", split = "eval"
- )
-
- dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
- dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
-
- print("\n dataset sample [0]")
- print(dataset_train[0])
-
- add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
-
- model = FastLanguageModel.get_peft_model(
- model,
- r = 16,
- target_modules = [
- "k_proj",
- "q_proj",
- "v_proj",
- "o_proj",
- "gate_proj",
- "down_proj",
- "up_proj",
- ],
- lora_alpha = 16,
- lora_dropout = 0,
- bias = "none",
- use_gradient_checkpointing = "unsloth",
- random_state = 3407,
- use_rslora = False,
- loftq_config = None,
- )
-
- from unsloth import is_bfloat16_supported
-
- trainer = SFTTrainer(
- model = model,
- tokenizer = tokenizer,
- train_dataset = dataset_train,
- dataset_text_field = "text",
- max_seq_length = 2048,
- data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
- dataset_num_proc = 2,
- packing = False,
- args = TrainingArguments(
- per_device_train_batch_size = 2,
- gradient_accumulation_steps = 4,
- warmup_ratio = 0.1,
- max_steps = 200,
- learning_rate = 2e-4,
- fp16 = not is_bfloat16_supported(),
- bf16 = is_bfloat16_supported(),
- logging_steps = 50,
- optim = "adamw_8bit",
- lr_scheduler_type = "linear",
- seed = 3407,
- output_dir = "outputs",
- report_to = "none",
- ),
- )
-
- from unsloth.chat_templates import train_on_responses_only
-
- trainer = train_on_responses_only(
- trainer,
- instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
- response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
- )
-
- tokenizer.decode(trainer.train_dataset[0]["input_ids"])
-
- # run training
- trainer_stats = trainer.train()
-
- add_to_comparison("Qlora model", ppl_model(model, tokenizer, dataset_ppl))
-
- # saving and merging the model to local disk
- print("merge and save to local disk")
- model.save_pretrained_merged(
- save_directory = "./unsloth_out/merged_llama_text_model", tokenizer = tokenizer
- )
-
- # print("cleaning")
- # del model
- # del tokenizer
- # torch.cuda.empty_cache()
- # gc.collect()
-
- # load model from local disk and test
- print("Loading merged model in 4 bit for perplexity test")
- merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
- model_name = "./unsloth_out/merged_llama_text_model",
- max_seq_length = 2048,
- load_in_4bit = True,
- load_in_8bit = False,
- )
-
- add_to_comparison(
- "merged model load 4bit", ppl_model(merged_model, merged_tokenizer, dataset_ppl)
- )
-
- print("Computing 8-bit model perplexity in subprocess...")
- result_queue = mp.Queue()
- p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))
- p.start()
- p.join()
-
- ppl_8bit = result_queue.get()
- add_to_comparison("merged model loaded 8bits", ppl_8bit)
-
- print("Loading merged model in 16 bit for perplexity test")
- merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
- model_name = "./unsloth_out/merged_llama_text_model",
- max_seq_length = 2048,
- load_in_4bit = False,
- load_in_8bit = False,
- )
-
- add_to_comparison(
- "merged model loaded 16bits",
- ppl_model(merged_model, merged_tokenizer, dataset_ppl),
- )
-
- print_model_comparison()
-
- # final cleanup
- safe_remove_directory("./outputs")
- safe_remove_directory("./unsloth_compiled_cache")
- safe_remove_directory("./unsloth_out")
diff --git a/tests/saving/language_models/test_merged_model_perplexity_qwen_2.5.py b/tests/saving/language_models/test_merged_model_perplexity_qwen_2.5.py
deleted file mode 100644
index d63bb9fe09..0000000000
--- a/tests/saving/language_models/test_merged_model_perplexity_qwen_2.5.py
+++ /dev/null
@@ -1,311 +0,0 @@
-from unsloth import FastLanguageModel, FastVisionModel, UnslothVisionDataCollator
-from unsloth.chat_templates import get_chat_template
-from trl import SFTTrainer, SFTConfig
-from transformers import (
- DataCollatorForLanguageModeling,
- DataCollatorForSeq2Seq,
- TrainingArguments,
-)
-from datasets import load_dataset, Dataset
-import torch
-from tqdm import tqdm
-import pandas as pd
-import multiprocessing as mp
-from multiprocessing import Process, Queue
-import gc
-
-# ruff: noqa
-import sys
-from pathlib import Path
-
-
-REPO_ROOT = Path(__file__).parents[3]
-sys.path.insert(0, str(REPO_ROOT))
-
-from tests.utils.cleanup_utils import safe_remove_directory
-from tests.utils.perplexity_eval import (
- ppl_model,
- add_to_comparison,
- print_model_comparison,
-)
-
-
-alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
-
-### Instruction:
-{}
-
-### Input:
-{}
-
-### Response:
-{}"""
-
-
-# Define helper functions outside of main
-def formatting_prompts_func(examples):
- instructions = []
- inputs = []
- outputs = []
- texts = []
-
- for conversation in examples["messages"]:
- # Extract user message and assistant response
- user_message = ""
- assistant_message = ""
-
- for turn in conversation:
- if turn["role"] == "user":
- user_message = turn["content"]
- elif turn["role"] == "assistant":
- assistant_message = turn["content"]
-
- # Store intermediate format
- instruction = "Complete the statement"
- instructions.append(instruction)
- inputs.append(user_message)
- outputs.append(assistant_message)
-
- # Create formatted text
- text = alpaca_prompt.format(instruction, user_message, assistant_message)
- texts.append(text)
-
- return {
- "instruction": instructions,
- "input": inputs,
- "output": outputs,
- "text": texts,
- }
-
-
-def load_and_compute_8bit_ppl(result_queue, load_in_4bit = False, load_in_8bit = False):
- """Load model and compute perplexity in subprocess"""
- from unsloth import FastLanguageModel
- from tests.utils.perplexity_eval import ppl_model
-
- # Load model
- merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
- model_name = "./unsloth_out/merged_qwen_text_model",
- max_seq_length = 2048,
- load_in_4bit = load_in_4bit,
- load_in_8bit = load_in_8bit,
- )
- # Set up tokenizer
- # merged_tokenizer = get_chat_template(
- # merged_tokenizer,
- # chat_template="llama-3.1",
- # )
-
- # Load dataset fresh in subprocess
- dataset_ppl = load_dataset(
- "allenai/openassistant-guanaco-reformatted", split = "eval"
- )
-
- alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
-
- ### Instruction:
- {}
-
- ### Input:
- {}
-
- ### Response:
- {}"""
-
- def formatting_prompts_func(examples):
- instructions = []
- inputs = []
- outputs = []
- texts = []
-
- for conversation in examples["messages"]:
- # Extract user message and assistant response
- user_message = ""
- assistant_message = ""
-
- for turn in conversation:
- if turn["role"] == "user":
- user_message = turn["content"]
- elif turn["role"] == "assistant":
- assistant_message = turn["content"]
-
- # Store intermediate format
- instruction = "Complete the statement"
- instructions.append(instruction)
- inputs.append(user_message)
- outputs.append(assistant_message)
-
- # Create formatted text
- text = alpaca_prompt.format(instruction, user_message, assistant_message)
- texts.append(text)
-
- return {
- "instruction": instructions,
- "input": inputs,
- "output": outputs,
- "text": texts,
- }
-
- dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
-
- # Compute perplexity using the passed dataset
- ppl_value = ppl_model(merged_model, merged_tokenizer, dataset_ppl)
-
- # IMPORTANT: Convert to Python float if it's a tensor
- if torch.is_tensor(ppl_value):
- ppl_value = ppl_value.cpu().item() # Move to CPU and convert to Python scalar
- elif hasattr(ppl_value, "item"):
- ppl_value = ppl_value.item() # Convert numpy or other array types
- else:
- ppl_value = float(ppl_value) # Ensure it's a float
-
- # Return only the perplexity value
- result_queue.put(ppl_value)
-
- # Clean up
- # del merged_model
- # del merged_tokenizer
- # del dataset_ppl
- # torch.cuda.empty_cache()
- # gc.collect()
-
-
-# Main execution code should be wrapped in this guard
-if __name__ == "__main__":
- mp.set_start_method("spawn", force = True)
-
- if torch.cuda.is_bf16_supported():
- compute_dtype = torch.bfloat16
- attn_implementation = "flash_attention_2"
- else:
- compute_dtype = torch.float16
- attn_implementation = "sdpa"
-
- model, tokenizer = FastLanguageModel.from_pretrained(
- model_name = "unsloth/Qwen2.5-7B-Instruct",
- max_seq_length = 2048,
- dtype = compute_dtype,
- load_in_4bit = True,
- load_in_8bit = False,
- full_finetuning = False,
- attn_implementation = attn_implementation,
- )
-
- dataset_train = load_dataset(
- "allenai/openassistant-guanaco-reformatted", split = "train"
- )
- dataset_ppl = load_dataset(
- "allenai/openassistant-guanaco-reformatted", split = "eval"
- )
-
- dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
- dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
-
- add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
-
- model = FastLanguageModel.get_peft_model(
- model,
- r = 16,
- target_modules = [
- "k_proj",
- "q_proj",
- "v_proj",
- "o_proj",
- "gate_proj",
- "down_proj",
- "up_proj",
- ],
- lora_alpha = 16,
- lora_dropout = 0,
- bias = "none",
- use_gradient_checkpointing = "unsloth",
- random_state = 3407,
- use_rslora = False,
- loftq_config = None,
- )
-
- from unsloth import is_bfloat16_supported
-
- trainer = SFTTrainer(
- model = model,
- tokenizer = tokenizer,
- train_dataset = dataset_train,
- dataset_text_field = "text",
- max_seq_length = 2048,
- data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
- dataset_num_proc = 2,
- packing = False,
- args = TrainingArguments(
- per_device_train_batch_size = 2,
- gradient_accumulation_steps = 4,
- warmup_ratio = 0.1,
- max_steps = 200,
- learning_rate = 2e-4,
- fp16 = not is_bfloat16_supported(),
- bf16 = is_bfloat16_supported(),
- logging_steps = 50,
- optim = "adamw_8bit",
- lr_scheduler_type = "linear",
- seed = 3407,
- output_dir = "outputs",
- report_to = "none",
- ),
- )
-
- # run training
- trainer_stats = trainer.train()
-
- add_to_comparison("Qlora model", ppl_model(model, tokenizer, dataset_ppl))
-
- # saving and merging the model to local disk
- print("merge and save to local disk")
- model.save_pretrained_merged(
- save_directory = "./unsloth_out/merged_qwen_text_model", tokenizer = tokenizer
- )
-
- # print("cleaning")
- # del model
- # del tokenizer
- # torch.cuda.empty_cache()
- # gc.collect()
-
- # load model from local disk and test
- print("Loading merged model in 4 bit for perplexity test")
- merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
- model_name = "./unsloth_out/merged_qwen_text_model",
- max_seq_length = 2048,
- load_in_4bit = True,
- load_in_8bit = False,
- )
-
- add_to_comparison(
- "merged model load 4bit", ppl_model(merged_model, merged_tokenizer, dataset_ppl)
- )
-
- print("Computing 8-bit model perplexity in subprocess...")
- result_queue = mp.Queue()
- p = mp.Process(target = load_and_compute_8bit_ppl, args = (result_queue, False, True))
- p.start()
- p.join()
-
- ppl_8bit = result_queue.get()
- add_to_comparison("merged model loaded 8bits", ppl_8bit)
-
- print("Loading merged model in 16 bit for perplexity test")
- merged_model, merged_tokenizer = FastLanguageModel.from_pretrained(
- model_name = "./unsloth_out/merged_qwen_text_model",
- max_seq_length = 2048,
- load_in_4bit = False,
- load_in_8bit = False,
- )
-
- add_to_comparison(
- "merged model loaded 16bits",
- ppl_model(merged_model, merged_tokenizer, dataset_ppl),
- )
-
- print_model_comparison()
-
- safe_remove_directory("./outputs")
- safe_remove_directory("./unsloth_compiled_cache")
- safe_remove_directory("./unsloth_out")
diff --git a/tests/saving/language_models/test_push_to_hub_merged.py b/tests/saving/language_models/test_push_to_hub_merged.py
deleted file mode 100644
index 58d589305a..0000000000
--- a/tests/saving/language_models/test_push_to_hub_merged.py
+++ /dev/null
@@ -1,204 +0,0 @@
-from unsloth import FastLanguageModel, FastVisionModel, UnslothVisionDataCollator
-from unsloth.chat_templates import get_chat_template
-from trl import SFTTrainer, SFTConfig
-from transformers import (
- DataCollatorForLanguageModeling,
- DataCollatorForSeq2Seq,
- TrainingArguments,
-)
-from datasets import load_dataset, Dataset
-import torch
-from tqdm import tqdm
-import pandas as pd
-import multiprocessing as mp
-from multiprocessing import Process, Queue
-import gc
-import os
-from huggingface_hub import HfFileSystem, hf_hub_download
-
-# ruff: noqa
-import sys
-from pathlib import Path
-
-
-REPO_ROOT = Path(__file__).parents[3]
-sys.path.insert(0, str(REPO_ROOT))
-
-from tests.utils.cleanup_utils import safe_remove_directory
-from tests.utils.perplexity_eval import (
- ppl_model,
- add_to_comparison,
- print_model_comparison,
-)
-
-
-# Define helper functions outside of main
-def formatting_prompts_func(examples):
- convos = examples["messages"]
- texts = [
- tokenizer.apply_chat_template(
- convo, tokenize = False, add_generation_prompt = False
- )
- for convo in convos
- ]
- return {"text": texts}
-
-
-if torch.cuda.is_bf16_supported():
- compute_dtype = torch.bfloat16
- attn_implementation = "flash_attention_2"
-else:
- compute_dtype = torch.float16
- attn_implementation = "sdpa"
-
-model, tokenizer = FastLanguageModel.from_pretrained(
- model_name = "unsloth/Llama-3.2-1B-Instruct",
- max_seq_length = 2048,
- dtype = compute_dtype,
- load_in_4bit = True,
- load_in_8bit = False,
- full_finetuning = False,
- attn_implementation = attn_implementation,
-)
-
-tokenizer = get_chat_template(
- tokenizer,
- chat_template = "llama-3.1",
-)
-
-from unsloth.chat_templates import standardize_sharegpt
-
-dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split = "train")
-dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split = "eval")
-
-dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
-dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
-
-add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
-
-model = FastLanguageModel.get_peft_model(
- model,
- r = 16,
- target_modules = [
- "k_proj",
- "q_proj",
- "v_proj",
- "o_proj",
- "gate_proj",
- "down_proj",
- "up_proj",
- ],
- lora_alpha = 16,
- lora_dropout = 0,
- bias = "none",
- use_gradient_checkpointing = "unsloth",
- random_state = 3407,
- use_rslora = False,
- loftq_config = None,
-)
-
-from unsloth import is_bfloat16_supported
-
-trainer = SFTTrainer(
- model = model,
- tokenizer = tokenizer,
- train_dataset = dataset_train,
- dataset_text_field = "text",
- max_seq_length = 2048,
- data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
- dataset_num_proc = 2,
- packing = False,
- args = TrainingArguments(
- per_device_train_batch_size = 2,
- gradient_accumulation_steps = 4,
- warmup_ratio = 0.1,
- max_steps = 30,
- learning_rate = 2e-4,
- fp16 = not is_bfloat16_supported(),
- bf16 = is_bfloat16_supported(),
- logging_steps = 50,
- optim = "adamw_8bit",
- lr_scheduler_type = "linear",
- seed = 3407,
- output_dir = "outputs",
- report_to = "none",
- ),
-)
-
-from unsloth.chat_templates import train_on_responses_only
-
-trainer = train_on_responses_only(
- trainer,
- instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
- response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
-)
-
-# run training
-trainer_stats = trainer.train()
-
-
-# saving and merging the model to local disk
-hf_username = os.environ.get("HF_USER", "")
-if not hf_username:
- hf_username = input("Please enter your Hugging Face username: ").strip()
- os.environ["HF_USER"] = hf_username
-
-hf_token = os.environ.get("HF_TOKEN", "")
-if not hf_token:
- hf_token = input("Please enter your Hugging Face token: ").strip()
- os.environ["HF_TOKEN"] = hf_token
-
-
-repo_name = f"{hf_username}/merged_llama_text_model"
-success = {
- "upload": False,
- "download": False,
-}
-
-# Stage 1: Upload model to Hub
-try:
- print("\n" + "=" * 80)
- print("=== UPLOADING MODEL TO HUB ===".center(80))
- print("=" * 80 + "\n")
- model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)
- success["upload"] = True
- print("✅ Model uploaded successfully!")
-except Exception as e:
- print(f"❌ Failed to upload model: {e}")
- raise Exception("Model upload failed.")
-
-t
-# Stage 2: Test downloading the model (even if cached)
-safe_remove_directory(f"./{hf_username}")
-
-try:
- print("\n" + "=" * 80)
- print("=== TESTING MODEL DOWNLOAD ===".center(80))
- print("=" * 80 + "\n")
- # Force download even if cached
- model, tokenizer = FastLanguageModel.from_pretrained(
- f"{hf_username}/merged_llama_text_model"
- )
- success["download"] = True
- print("✅ Model downloaded successfully!")
-except Exception as e:
- print(f"❌ Download failed: {e}")
- raise Exception("Model download failed.")
-
-# Final report
-print("\n" + "=" * 80)
-print("=== VALIDATION REPORT ===".center(80))
-print("=" * 80 + "\n")
-for stage, passed in success.items():
- status = "✓" if passed else "✗"
- print(f"{status} {stage.replace('_', ' ').title()}")
-print("\n" + "=" * 80)
-
-if all(success.values()):
- print("\n🎉 All stages completed successfully!")
-else:
- raise Exception("Validation failed for one or more stages.")
-
-# final cleanup
-safe_remove_directory("./outputs")
-safe_remove_directory("./unsloth_compiled_cache")
diff --git a/tests/saving/language_models/test_push_to_hub_merged_sharded_index_file.py b/tests/saving/language_models/test_push_to_hub_merged_sharded_index_file.py
deleted file mode 100644
index 038565d170..0000000000
--- a/tests/saving/language_models/test_push_to_hub_merged_sharded_index_file.py
+++ /dev/null
@@ -1,223 +0,0 @@
-from unsloth import FastLanguageModel, FastVisionModel, UnslothVisionDataCollator
-from unsloth.chat_templates import get_chat_template
-from trl import SFTTrainer, SFTConfig
-from transformers import (
- DataCollatorForLanguageModeling,
- DataCollatorForSeq2Seq,
- TrainingArguments,
-)
-from datasets import load_dataset, Dataset
-import torch
-from tqdm import tqdm
-import pandas as pd
-import multiprocessing as mp
-from multiprocessing import Process, Queue
-import gc
-import os
-from huggingface_hub import HfFileSystem, hf_hub_download
-
-# ruff: noqa
-import sys
-from pathlib import Path
-
-
-REPO_ROOT = Path(__file__).parents[3]
-sys.path.insert(0, str(REPO_ROOT))
-
-from tests.utils.cleanup_utils import safe_remove_directory
-from tests.utils.perplexity_eval import (
- ppl_model,
- add_to_comparison,
- print_model_comparison,
-)
-
-
-# Define helper functions outside of main
-def formatting_prompts_func(examples):
- convos = examples["messages"]
- texts = [
- tokenizer.apply_chat_template(
- convo, tokenize = False, add_generation_prompt = False
- )
- for convo in convos
- ]
- return {"text": texts}
-
-
-if torch.cuda.is_bf16_supported():
- compute_dtype = torch.bfloat16
- attn_implementation = "flash_attention_2"
-else:
- compute_dtype = torch.float16
- attn_implementation = "sdpa"
-
-model, tokenizer = FastLanguageModel.from_pretrained(
- model_name = "unsloth/Llama-3.1-8B-Instruct",
- max_seq_length = 2048,
- dtype = compute_dtype,
- load_in_4bit = True,
- load_in_8bit = False,
- full_finetuning = False,
- attn_implementation = attn_implementation,
-)
-
-tokenizer = get_chat_template(
- tokenizer,
- chat_template = "llama-3.1",
-)
-
-from unsloth.chat_templates import standardize_sharegpt
-
-dataset_train = load_dataset("allenai/openassistant-guanaco-reformatted", split = "train")
-dataset_ppl = load_dataset("allenai/openassistant-guanaco-reformatted", split = "eval")
-
-dataset_train = dataset_train.map(formatting_prompts_func, batched = True)
-dataset_ppl = dataset_ppl.map(formatting_prompts_func, batched = True)
-
-add_to_comparison("Base model 4 bits", ppl_model(model, tokenizer, dataset_ppl))
-
-model = FastLanguageModel.get_peft_model(
- model,
- r = 16,
- target_modules = [
- "k_proj",
- "q_proj",
- "v_proj",
- "o_proj",
- "gate_proj",
- "down_proj",
- "up_proj",
- ],
- lora_alpha = 16,
- lora_dropout = 0,
- bias = "none",
- use_gradient_checkpointing = "unsloth",
- random_state = 3407,
- use_rslora = False,
- loftq_config = None,
-)
-
-from unsloth import is_bfloat16_supported
-
-trainer = SFTTrainer(
- model = model,
- tokenizer = tokenizer,
- train_dataset = dataset_train,
- dataset_text_field = "text",
- max_seq_length = 2048,
- data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
- dataset_num_proc = 2,
- packing = False,
- args = TrainingArguments(
- per_device_train_batch_size = 2,
- gradient_accumulation_steps = 4,
- warmup_ratio = 0.1,
- max_steps = 30,
- learning_rate = 2e-4,
- fp16 = not is_bfloat16_supported(),
- bf16 = is_bfloat16_supported(),
- logging_steps = 50,
- optim = "adamw_8bit",
- lr_scheduler_type = "linear",
- seed = 3407,
- output_dir = "outputs",
- report_to = "none",
- ),
-)
-
-from unsloth.chat_templates import train_on_responses_only
-
-trainer = train_on_responses_only(
- trainer,
- instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
- response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
-)
-
-# run training
-trainer_stats = trainer.train()
-
-
-# saving and merging the model to local disk
-hf_username = os.environ.get("HF_USER", "")
-if not hf_username:
- hf_username = input("Please enter your Hugging Face username: ").strip()
- os.environ["HF_USER"] = hf_username
-
-hf_token = os.environ.get("HF_TOKEN", "")
-if not hf_token:
- hf_token = input("Please enter your Hugging Face token: ").strip()
- os.environ["HF_TOKEN"] = hf_token
-
-
-repo_name = f"{hf_username}/merged_llama_text_model"
-success = {
- "upload": False,
- "safetensors_check": False,
- "download": False,
-}
-
-# Stage 1: Upload model to Hub
-try:
- print("\n" + "=" * 80)
- print("=== UPLOADING MODEL TO HUB ===".center(80))
- print("=" * 80 + "\n")
- model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)
- success["upload"] = True
- print("✅ Model uploaded successfully!")
-except Exception as e:
- print(f"❌ Failed to upload model: {e}")
- raise Exception("Model upload failed.")
-
-# Stage 2: Verify safetensors.index.json exists
-try:
- print("\n" + "=" * 80)
- print("=== VERIFYING REPO CONTENTS ===".center(80))
- print("=" * 80 + "\n")
- fs = HfFileSystem(token = hf_token)
- file_list = fs.ls(repo_name, detail = True)
- safetensors_found = any(
- file["name"].endswith("model.safetensors.index.json") for file in file_list
- )
- if safetensors_found:
- success["safetensors_check"] = True
- print("✅ model.safetensors.index.json found in repo!")
- else:
- raise Exception("model.safetensors.index.json not found in repo.")
-except Exception as e:
- print(f"❌ Verification failed: {e}")
- raise Exception("Repo verification failed.")
-
-# Stage 3: Test downloading the model (even if cached)
-safe_remove_directory("./RTannous")
-
-try:
- print("\n" + "=" * 80)
- print("=== TESTING MODEL DOWNLOAD ===".center(80))
- print("=" * 80 + "\n")
- # Force download even if cached
- model, tokenizer = FastLanguageModel.from_pretrained(
- f"{hf_username}/merged_llama_text_model"
- )
- success["download"] = True
- print("✅ Model downloaded successfully!")
-except Exception as e:
- print(f"❌ Download failed: {e}")
- raise Exception("Model download failed.")
-
-# Final report
-print("\n" + "=" * 80)
-print("=== VALIDATION REPORT ===".center(80))
-print("=" * 80 + "\n")
-for stage, passed in success.items():
- status = "✓" if passed else "✗"
- print(f"{status} {stage.replace('_', ' ').title()}")
-print("\n" + "=" * 80)
-
-if all(success.values()):
- print("\n🎉 All stages completed successfully!")
-else:
- raise Exception("Validation failed for one or more stages.")
-
-# final cleanup
-safe_remove_directory("./outputs")
-safe_remove_directory("./unsloth_compiled_cache")
diff --git a/tests/saving/language_models/test_save_merged_grpo_model.py b/tests/saving/language_models/test_save_merged_grpo_model.py
deleted file mode 100644
index 67b649305a..0000000000
--- a/tests/saving/language_models/test_save_merged_grpo_model.py
+++ /dev/null
@@ -1,825 +0,0 @@
-# -*- coding: utf-8 -*-
-"""test_Llama3_1_(3B)_GRPO_LoRA (1).ipynb
-
-### Unsloth
-
-"""
-
-from unsloth import FastLanguageModel
-import torch
-import sys
-from pathlib import Path
-import multiprocessing as mp
-import gc
-from multiprocessing import Queue
-
-REPO_ROOT = Path(__file__).parents[3]
-sys.path.insert(0, str(REPO_ROOT))
-
-from tests.utils.cleanup_utils import safe_remove_directory
-from tests.utils.aime_eval import evaluate_model_aime, compare_aime_results
-
-
-max_seq_length = 2048 # Can increase for longer reasoning traces
-lora_rank = 64 # Larger rank = smarter, but slower
-
-
-def evaluate_merged_model(result_queue, load_in_4bit = False, load_in_8bit = False):
- from unsloth import FastLanguageModel
- from tests.utils.aime_eval import evaluate_model_aime
-
- max_seq_length = 2048 # Can increase for longer reasoning traces
- lora_rank = 64 # Larger rank = smarter, but slower
-
- model, tokenizer = FastLanguageModel.from_pretrained(
- model_name = "./final_merged_model",
- max_seq_length = max_seq_length,
- load_in_4bit = True, # False for LoRA 16bit
- fast_inference = True, # Enable vLLM fast inference
- max_lora_rank = lora_rank,
- gpu_memory_utilization = 0.8, # Reduce if out of memory
- )
-
- print(f"\n{'='*60}")
- if load_in_4bit:
- print("🔍 EVALUATION Merged model: 4 bits load")
- model_type = "merged_model_4bits"
- elif load_in_8bit:
- print("🔍 EVALUATION Merged model: 8 bits load")
- model_type = "merged_model_8bits"
- else:
- print("🔍 EVALUATION Merged model: 16 bits load")
- model_type = "merged_model_16bits"
- print(f"{'='*60}")
-
- evaluate_model_aime(
- model = model,
- tokenizer = tokenizer,
- model_type = model_type,
- temperature = 0.3,
- n_sampling = 8,
- max_tokens = 32768,
- top_p = 0.95,
- seed = 0,
- )
-
- result_queue.put(results)
-
- del model
- del tokenizer
- torch.cuda.empty_cache()
- gc.collect()
-
-
-# Main execution code should be wrapped in this guard
-def training_run(result_queue):
- model, tokenizer = FastLanguageModel.from_pretrained(
- model_name = "meta-llama/Llama-3.2-3B-Instruct",
- max_seq_length = max_seq_length,
- load_in_4bit = False, # False for LoRA 16bit
- fast_inference = True, # Enable vLLM fast inference
- max_lora_rank = lora_rank,
- gpu_memory_utilization = 0.8, # Reduce if out of memory
- )
-
- """### Helper Functions
-
-
-#### Helper functions - Data Prep
- """
-
- import re
- import json
-
- reasoning_start = ""
- reasoning_end = ""
- solution_start = ""
- solution_end = ""
-
- def extract_hash_answer(text):
- """Extract answer from GSM8K format"""
- if "####" not in text:
- return None
- return text.split("####")[1].strip()
-
- def prepare_gsm8k_dataset(dataset):
- """Format GSM8K dataset for training"""
- reasoning_start = ""
- reasoning_end = ""
- solution_start = ""
- solution_end = ""
-
- system_prompt = (
- f"You are given a problem. Think about the problem and reason step by step. "
- f"Place your thinking process between {reasoning_start} and {reasoning_end}. "
- f"Then, provide your final numerical solution between {solution_start}{solution_end}"
- )
-
- def format_gsm8k(example):
- return {
- "prompt": [
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": example["question"]},
- ],
- "answer": extract_hash_answer(example["answer"]),
- }
-
- return dataset.map(format_gsm8k)
-
- def prepare_limo_dataset(dataset):
- """Format LIMO dataset for SFT training"""
- if dataset is None:
- return None
-
- system_prompt = """You are a helpful reasoning assistant. When given a problem, think through it step by step and provide your answer in the following format:
-
-
- [Your detailed step-by-step reasoning and solution process]
-
-
- [Your final numerical answer]
- """
-
- def format_limo(example):
- # Create the assistant response
- assistant_response = f"\n{example['solution']}\n\n\n{example['answer']}\n"
-
- # Return a DICTIONARY with the conversation in a field
- return {
- "prompt": [ # ← This is the key change - wrap in a dict
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": example["question"]},
- {"role": "assistant", "content": assistant_response},
- ]
- }
-
- return dataset.map(format_limo)
-
- print("\n✅ Dataset preparation functions defined!")
-
- """#### Helper functions - Evaluation"""
-
- def get_max_prompt_length(dataset, tokenizer):
- """Calculate maximum and average prompt length in dataset"""
- print("Analyzing prompt lengths...")
-
- lengths = dataset.map(
- lambda x: {
- "tokens": tokenizer.apply_chat_template(
- x["prompt"], add_generation_prompt = True, tokenize = True
- )
- },
- batched = True,
- ).map(lambda x: {"length": len(x["tokens"])})["length"]
-
- max_length = max(lengths)
- avg_length = sum(lengths) / len(lengths)
- min_length = min(lengths)
-
- print(
- f"Prompt lengths - Min: {min_length}, Max: {max_length}, Avg: {avg_length:.1f}"
- )
- return max_length, avg_length
-
- def extract_unsloth_answer(text, start_tag = "", end_tag = ""):
- """Extract answer from Unsloth SOLUTION tags"""
- pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
- matches = re.findall(pattern, text, re.DOTALL)
-
- if matches:
- answer = matches[-1] # Get the last match
- answer = re.sub(r"[%$,]", "", answer).strip()
- return answer
- return ""
-
- def find_number(search_string):
- """Find the last number in a string"""
- numbers = re.compile(
- r"-?[\d,]*\.?\d+",
- re.MULTILINE | re.DOTALL | re.IGNORECASE,
- ).findall(search_string)
-
- if numbers:
- return numbers[-1].replace(",", "").strip()
- return ""
-
- def remove_symbols(x: str) -> str:
- """Remove commas, percent and dollar symbols"""
- if not x:
- return ""
- return x.replace(",", "").replace("%", "").replace("$", "").strip()
-
- def get_num_tokens(text, tokenizer_instance):
- """Count tokens in text"""
- if not text:
- return 0
- encoding = tokenizer_instance(text, return_tensors = "pt")
- return len(encoding["input_ids"][0])
-
- def check_format_compliance(text, format_type = "unsloth"):
- """Check if response follows expected format"""
- if format_type == "unsloth":
- reasoning_start = ""
- reasoning_end = ""
- solution_start = ""
- solution_end = ""
-
- pattern = (
- rf"^[\s]*{re.escape(reasoning_start)}.+?{re.escape(reasoning_end)}.*?"
- rf"{re.escape(solution_start)}.+?{re.escape(solution_end)}[\s]*$"
- )
- else:
- return False
-
- return bool(re.match(pattern, text.strip(), re.DOTALL))
-
- def normalize_answer(answer):
- """Normalize answer for comparison"""
- if not answer:
- return ""
-
- normalized = remove_symbols(str(answer))
-
- try:
- float_val = float(normalized)
- if float_val.is_integer():
- return str(int(float_val))
- else:
- return str(float_val)
- except (ValueError, TypeError):
- return normalized
-
- def evaluate_answer_correctness(extracted_answer, ground_truth):
- """Evaluate answer correctness with multiple criteria"""
- if not extracted_answer or not ground_truth:
- return False, False, 0.0
-
- norm_extracted = normalize_answer(extracted_answer)
- norm_ground_truth = normalize_answer(ground_truth)
-
- if norm_extracted == norm_ground_truth:
- return True, True, 1.0
-
- try:
- extracted_num = float(norm_extracted)
- ground_truth_num = float(norm_ground_truth)
-
- if ground_truth_num != 0:
- relative_error = abs(extracted_num - ground_truth_num) / abs(
- ground_truth_num
- )
-
- if relative_error < 0.01:
- return True, True, 0.9
- elif relative_error < 0.05:
- return False, True, 0.7
- elif relative_error < 0.10:
- return False, True, 0.5
- else:
- if extracted_num == 0:
- return True, True, 1.0
- elif abs(extracted_num) < 0.01:
- return False, True, 0.7
-
- except (ValueError, TypeError):
- if norm_extracted.lower() == norm_ground_truth.lower():
- return True, True, 1.0
-
- return False, False, 0.0
-
- """#### Reward Functions for GRPO"""
-
- def match_format_exactly(completions, **kwargs):
- """Reward function for exact format matching"""
- reasoning_start = ""
- reasoning_end = ""
- solution_start = ""
- solution_end = ""
-
- pattern = (
- rf"^[\s]*{re.escape(reasoning_start)}.+?{re.escape(reasoning_end)}.*?"
- rf"{re.escape(solution_start)}.+?{re.escape(solution_end)}[\s]*$"
- )
-
- responses = [completion[0]["content"] for completion in completions]
- rewards = [
- 3.0 if re.match(pattern, response, re.DOTALL) else 0.0
- for response in responses
- ]
- return rewards
-
- def match_format_approximately(completions, **kwargs):
- """Reward function for approximate format matching"""
- reasoning_start = ""
- reasoning_end = ""
- solution_start = ""
- solution_end = ""
-
- scores = []
- for completion in completions:
- score = 0
- response = completion[0]["content"]
- score += 0.5 if response.count(reasoning_start) == 1 else -1.0
- score += 0.5 if response.count(reasoning_end) == 1 else -1.0
- score += 0.5 if response.count(solution_start) == 1 else -1.0
- score += 0.5 if response.count(solution_end) == 1 else -1.0
- scores.append(score)
- return scores
-
- def check_answer_correctness(prompts, completions, answer, **kwargs):
- """Reward function for answer correctness"""
-
- def extract_solution_answer(text):
- pattern = r"(.*?)"
- match = re.search(pattern, text, re.DOTALL)
- if match:
- return re.sub(r"[%$,]", "", match.group(1)).strip()
- return ""
-
- responses = [completion[0]["content"] for completion in completions]
- extracted_responses = [extract_solution_answer(r) for r in responses]
-
- scores = []
- for guess, true_answer in zip(extracted_responses, answer):
- score = 0
- if not guess:
- scores.append(0)
- continue
-
- if guess == true_answer:
- score += 3.0
- elif guess.strip() == true_answer.strip():
- score += 1.5
- else:
- try:
- ratio = float(guess) / float(true_answer)
- if 0.9 <= ratio <= 1.1:
- score += 1.0
- elif 0.8 <= ratio <= 1.2:
- score += 0.5
- else:
- score -= 1.5
- except:
- score -= 1.5
- scores.append(score)
- return scores
-
- print("✅ Reward functions defined!")
-
- """#### Main Evaluation Function"""
-
- import gc
-
- """#### Comparison and Memory Management"""
-
- def compare_model_results(all_results):
- """Generate comprehensive comparison of multiple model results"""
- print(f"\n{'='*80}")
- print("COMPREHENSIVE MODEL COMPARISON")
- print(f"{'='*80}")
-
- # Main table
- print(
- f"{'Model':<15} {'Format %':<10} {'Exact %':<10} {'Plausible %':<12} {'Confidence':<12}"
- )
- print("-" * 80)
-
- for result in all_results:
- print(
- f"{result['model_type']:<15} "
- f"{result['correct_format_pct']:<10.1f} "
- f"{result['exact_match_pct']:<10.1f} "
- f"{result['plausible_match_pct']:<12.1f} "
- f"{result['avg_confidence']:<12.3f}"
- )
-
- # Improvement analysis
- if len(all_results) > 1:
- print(f"\n{'='*50}")
- print("IMPROVEMENT ANALYSIS")
- print(f"{'='*50}")
-
- base_result = all_results[0]
- for result in all_results[1:]:
- print(f"\n{result['model_type']} vs {base_result['model_type']}:")
- format_improvement = (
- result["correct_format_pct"] - base_result["correct_format_pct"]
- )
- exact_improvement = (
- result["exact_match_pct"] - base_result["exact_match_pct"]
- )
- plausible_improvement = (
- result["plausible_match_pct"] - base_result["plausible_match_pct"]
- )
-
- print(f" Format compliance: {format_improvement:+.1f}%")
- print(f" Exact matches: {exact_improvement:+.1f}%")
- print(f" Plausible matches: {plausible_improvement:+.1f}%")
-
- # Save comparison
- comparison_data = {
- "summary": all_results,
- "best_model": max(all_results, key = lambda x: x["exact_match_pct"]),
- }
-
- with open("model_comparison_comprehensive.json", "w") as f:
- json.dump(comparison_data, f, indent = 4)
-
- print(
- f"\nBest performing model: {comparison_data['best_model']['model_type']} "
- f"({comparison_data['best_model']['exact_match_pct']:.1f}% exact matches)"
- )
-
- def cleanup_memory():
- """Comprehensive memory cleanup"""
- print("🧹 Cleaning up GPU memory...")
- for _ in range(10):
- torch.cuda.empty_cache()
- gc.collect()
-
- if torch.cuda.is_available():
- allocated = torch.cuda.memory_allocated() / 1024**3
- reserved = torch.cuda.memory_reserved() / 1024**3
- print(
- f"GPU memory - Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB"
- )
-
- """#### Data Loading and Preparation"""
-
- from datasets import load_dataset
-
- # Load GSM8K
- gsm8k_dataset = load_dataset("openai/gsm8k", "main", split = "train")
-
- # Load LIMO (adjust this based on your access method)
- limo_train = load_dataset("GAIR/LIMO", split = "train")
-
- # Prepare datasets
- gsm8k_train = prepare_gsm8k_dataset(gsm8k_dataset)
- limo_train = prepare_limo_dataset(limo_train)
-
- print(f" GSM8K train: {len(gsm8k_train)}")
- print(f" LIMO train: {len(limo_train) if limo_train else 0}")
-
- # Store results
- all_results = []
-
- # Single temperature evaluation on combined dataset
- results = evaluate_model_aime(
- model = model,
- tokenizer = tokenizer,
- model_type = "base",
- temperature = 0.3,
- n_sampling = 8,
- max_tokens = 32768,
- top_p = 0.95,
- seed = 0,
- )
-
- from unsloth.chat_templates import get_chat_template
-
- tokenizer = get_chat_template(
- tokenizer,
- chat_template = "llama-3.1",
- )
-
- def formatting_prompts_func(examples):
- convos = examples["prompt"]
- texts = [
- tokenizer.apply_chat_template(
- convo, tokenize = False, add_generation_prompt = False
- )
- for convo in convos
- ]
- return {
- "text": texts,
- }
-
- limo_train = limo_train.map(
- formatting_prompts_func,
- batched = True,
- )
-
- from trl import SFTTrainer
- from transformers import DataCollatorForSeq2Seq, TrainingArguments
- from unsloth import is_bfloat16_supported
-
- print(f"\n{'*'*60}")
- print("🎯 STAGE 1: Qlora Fine-Tuning on LIMO")
- print(f"{'*'*60}")
-
- model = FastLanguageModel.get_peft_model(
- model,
- r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
- target_modules = [
- "q_proj",
- "k_proj",
- "v_proj",
- "o_proj",
- "gate_proj",
- "up_proj",
- "down_proj",
- ], # Remove QKVO if out of memory
- lora_alpha = lora_rank,
- use_gradient_checkpointing = "unsloth", # Enable long context finetuning
- random_state = 3407,
- )
-
- if limo_train is not None:
- trainer = SFTTrainer(
- model = model,
- tokenizer = tokenizer,
- train_dataset = limo_train,
- dataset_text_field = "text",
- max_seq_length = max_seq_length,
- data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
- dataset_num_proc = 2,
- packing = False, # Can make training 5x faster for short sequences.
- args = TrainingArguments(
- per_device_train_batch_size = 2,
- gradient_accumulation_steps = 4,
- warmup_steps = 5,
- num_train_epochs = 1, # Set this for 1 full training run.
- # max_steps = 60,
- learning_rate = 2e-4,
- fp16 = not is_bfloat16_supported(),
- bf16 = is_bfloat16_supported(),
- logging_steps = 1,
- optim = "adamw_8bit",
- weight_decay = 0.01,
- lr_scheduler_type = "linear",
- seed = 3407,
- output_dir = "outputs",
- report_to = "none", # Use this for WandB etc
- ),
- )
-
- from unsloth.chat_templates import train_on_responses_only
-
- trainer = train_on_responses_only(
- trainer,
- instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
- response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
- )
-
- # Train
- print(f"🚂 Starting SFT training on {len(limo_train)} examples...")
- trainer.train()
-
- # Save checkpoint
- model.save_pretrained("qlora_checkpoint")
- tokenizer.save_pretrained("qlora_checkpoint")
- print("💾 Qlora checkpoint saved!")
-
- # Cleanup
- del trainer
- cleanup_memory()
-
- print("✅ Qlora training completed!")
- else:
- print("⚠️ Skipping Qlora training - no LIMO dataset available")
-
- # Cleanup
- cleanup_memory()
-
- global PRINTED_TIMES
- PRINTED_TIMES = 0
- global PRINT_EVERY_STEPS
- PRINT_EVERY_STEPS = 5
-
- match_numbers = re.compile(
- solution_start + r".*?([\d\.\,]{1,})", flags = re.MULTILINE | re.DOTALL
- )
-
- def check_numbers(prompts, completions, answer, **kwargs):
- question = prompts[0][-1]["content"]
- responses = [completion[0]["content"] for completion in completions]
-
- extracted_responses = [
- guess.group(1) if (guess := match_numbers.search(r)) is not None else None
- for r in responses
- ]
-
- scores = []
- # Print only every few steps
- global PRINTED_TIMES
- global PRINT_EVERY_STEPS
- if PRINTED_TIMES % PRINT_EVERY_STEPS == 0:
- print(
- "*" * 20,
- f"Question:\n{question}",
- f"\nAnswer:\n{answer[0]}",
- f"\nResponse:\n{responses[0]}",
- f"\nExtracted:\n{extracted_responses[0]}",
- )
- PRINTED_TIMES += 1
-
- for guess, true_answer in zip(extracted_responses, answer):
- if guess is None:
- scores.append(0)
- continue
- # Convert to numbers
- try:
- true_answer = float(true_answer.strip())
- # Remove commas like in 123,456
- guess = float(guess.strip().replace(",", ""))
- scores.append(1.5 if guess == true_answer else -0.5)
- except:
- scores.append(0)
- continue
- return scores
-
- print(f"\n{'*'*60}")
- print("🎯 STAGE 2: GRPO Fine-Tuning on GSM8K")
- print(f"{'*'*60}")
-
- # Get max prompt length
- max_prompt_length, _ = get_max_prompt_length(gsm8k_train, tokenizer)
- max_prompt_length = min(max_prompt_length + 10, 512) # Add buffer, cap at 512
-
- print(f"Using max_prompt_length: {max_prompt_length}")
-
- from trl import GRPOConfig, GRPOTrainer
-
- training_args = GRPOConfig(
- learning_rate = 5e-6,
- weight_decay = 0.1,
- warmup_ratio = 0.1,
- lr_scheduler_type = "cosine",
- optim = "adamw_torch_fused",
- logging_steps = 1,
- per_device_train_batch_size = 1,
- gradient_accumulation_steps = 4, # Increase to 4 for smoother training
- num_generations = 8, # Decrease if out of memory
- max_prompt_length = max_prompt_length,
- max_completion_length = max_seq_length - max_prompt_length,
- # num_train_epochs = 1, # Set to 1 for a full training run
- # max_steps = 250,
- max_steps = 1000,
- save_steps = 250,
- max_grad_norm = 0.1,
- report_to = "none", # Can use Weights & Biases
- output_dir = "outputs",
- )
-
- trainer = GRPOTrainer(
- model = model,
- processing_class = tokenizer,
- reward_funcs = [
- match_format_exactly,
- match_format_approximately,
- check_answer_correctness,
- check_numbers,
- ],
- args = training_args,
- train_dataset = gsm8k_train,
- )
-
- # Train
- print(f"🚂 Starting GRPO training on {len(gsm8k_train)} examples...")
- trainer.train()
-
- # Save checkpoint
- model.save_pretrained("grpo_checkpoint")
- tokenizer.save_pretrained("grpo_checkpoint")
- print("💾 GRPO checkpoint saved!")
-
- # Cleanup
- del trainer
- del training_args
- cleanup_memory()
-
- print("✅ GRPO training completed!")
-
- print(f"\n{'='*60}")
- print("🔍 EVALUATION 3: Final GRPO Model")
- print(f"{'='*60}")
-
- grpo_results = evaluate_model_aime(
- model = model,
- tokenizer = tokenizer,
- model_type = "grpo",
- temperature = 0.3,
- n_sampling = 8,
- max_tokens = 32768,
- top_p = 0.95,
- seed = 0,
- )
-
- all_results.append(grpo_results)
- print("✅ Final model evaluation complete!")
-
- print(f"\n{'='*60}")
- print("💾 SAVING FINAL MODEL")
- print(f"{'='*60}")
-
- # Save as merged model
- try:
- model.save_pretrained_merged(
- "final_merged_model", tokenizer, save_method = "merged_16bit"
- )
- print("✅ Merged model saved to: final_merged_model/")
- except Exception as e:
- print(f"⚠️ Could not save merged model: {e}")
- print("Final model saved as LoRA adapter only")
-
- print("💾 Model saving complete!")
-
- safe_remove_directory("./unsloth_compiled_cache")
-
- result_queue.put(results)
-
- # Clean up
- del model
- del tokenizer
- torch.cuda.empty_cache()
- gc.collect()
-
- # # Merged model load 16 bits model AIME eval
- # result_queue = mp.Queue()
- # p = mp.Process(target=evaluate_merged_model, args=(result_queue, False, False))
- # p.start()
- # p.join()
- #
- # merged_16bits = result_queue.get()
- # all_results.append(merged_16bits)
- #
- # # Clean up
- # del merged_model
- # del merged_tokenizer
- # del dataset_ppl
- # torch.cuda.empty_cache()
- # gc.collect()
- #
- # safe_remove_directory("./unsloth_compiled_cache")
- #
- # # Merged model load 8 bits model AIME eval
- #
- # result_queue = mp.Queue()
- # p = mp.Process(target=evaluate_merged_model, args=(result_queue, False, True))
- # p.start()
- # p.join()
- #
- # merged_16bits = result_queue.get()
- # all_results.append(merged_16bits)
-
- # Merged model load 4 bits AIME eval
- # result_queue = mp.Queue()
- # p = mp.Process(target=evaluate_merged_model, args=(result_queue, True, False))
- # p.start()
- # p.join()
- #
- # merged_16bits = result_queue.get()
- # all_results.append(merged_16bits)
-
-
-if __name__ == "__main__":
- mp.set_start_method("spawn", force = True)
- result_queue = mp.Queue()
- all_results = []
-
- # run main finetuning and grpo loop
- p = mp.Process(target = training_run, args = (result_queue,))
- p.start()
- p.join()
-
- results = result_queue.get()
- all_results = results
-
- # evaluate merged model loaded 16bits
- p = mp.Process(target = evaluate_merged_model, args = (result_queue, False, False))
- p.start()
- p.join()
-
- merged_load_16bits = result_queue.get()
- all_results.append(merged_load_16bits)
- safe_remove_directory("./unsloth_compiled_cache")
-
- # Merged model load 8 bits model AIME eval
- p = mp.Process(target = evaluate_merged_model, args = (result_queue, False, True))
- p.start()
- p.join()
-
- merged_load_8bits = result_queue.get()
- all_results.append(merged_load_8bits)
-
- safe_remove_directory("./unsloth_compiled_cache")
-
- # Merged model load 4 bits model AIME eval
- p = mp.Process(target = evaluate_merged_model, args = (result_queue, True, False))
- p.start()
- p.join()
-
- merged_load_4bits = result_queue.get()
- all_results.append(merged_load_4bits)
-
- safe_remove_directory("./unsloth_compiled_cache")
-
- # AIME-specific comparison function
-
- print(f"\n{'='*80}")
- print("🏆 FINAL TRAINING PIPELINE RESULTS")
- print(f"{'='*80}")
-
- # Use the AIME-specific comparison
- compare_aime_results(all_results)
diff --git a/tests/saving/non_peft/test_mistral_non_peft.py b/tests/saving/non_peft/test_mistral_non_peft.py
deleted file mode 100644
index e03813367d..0000000000
--- a/tests/saving/non_peft/test_mistral_non_peft.py
+++ /dev/null
@@ -1,65 +0,0 @@
-from unsloth import FastLanguageModel
-from transformers import AutoModelForCausalLM
-from peft import PeftModel
-from pathlib import Path
-import sys
-import warnings
-
-REPO_ROOT = Path(__file__).parents[3]
-sys.path.insert(0, str(REPO_ROOT))
-
-from tests.utils.cleanup_utils import safe_remove_directory
-
-
-print(f"\n{'='*80}")
-print("🔍 PHASE 1: Loading Base Model")
-print(f"{'='*80}")
-
-model, tokenizer = FastLanguageModel.from_pretrained(
- model_name = "unsloth/mistral-7b-v0.3",
- max_seq_length = 2048,
- dtype = None,
- load_in_4bit = True,
- load_in_8bit = False,
- full_finetuning = False,
-)
-
-
-print("✅ Base model loaded successfully!")
-
-### Attemtping save merge
-
-
-print(f"\n{'='*80}")
-print("🔍 PHASE 2: Attempting save_pretrained_merged (Should Warn)")
-print(f"{'='*80}")
-
-with warnings.catch_warnings(record = True) as w:
- warnings.simplefilter("always")
- model.save_pretrained_merged("test_output", tokenizer)
-
- # Verify warning
- assert len(w) >= 1, "Expected warning but none raised"
- warning_msg = str(w[0].message)
- expected_msg = "Model is not a PeftModel (no Lora adapters detected). Skipping Merge. Please use save_pretrained() or push_to_hub() instead!"
- assert expected_msg in warning_msg, f"Unexpected warning: {warning_msg}"
- assert expected_msg in warning_msg, f"Unexpected warning: {warning_msg}"
-
-print("✅ Correct warning detected for non-PeftModel merge attempt!")
-
-
-print(f"\n{'='*80}")
-print("🔍 PHASE 3: Using save_pretrained (Should Succeed)")
-print(f"{'='*80}")
-
-
-try:
- with warnings.catch_warnings():
- warnings.simplefilter("error") # Treat warnings as errors here
- model.save_pretrained("test_output")
- print("✅ Standard save_pretrained completed successfully!")
-except Exception as e:
- assert False, f"Phase 3 failed: {e}"
-
-safe_remove_directory("./test_output")
-safe_remove_directory("./unsloth_compiled_cache")
diff --git a/tests/saving/non_peft/test_whisper_non_peft.py b/tests/saving/non_peft/test_whisper_non_peft.py
deleted file mode 100644
index 303d596c85..0000000000
--- a/tests/saving/non_peft/test_whisper_non_peft.py
+++ /dev/null
@@ -1,65 +0,0 @@
-from unsloth import FastLanguageModel, FastModel
-from transformers import AutoModelForCausalLM, WhisperForConditionalGeneration
-from peft import PeftModel
-from pathlib import Path
-import sys
-import warnings
-
-REPO_ROOT = Path(__file__).parents[3]
-sys.path.insert(0, str(REPO_ROOT))
-
-from tests.utils.cleanup_utils import safe_remove_directory
-
-
-print(f"\n{'='*80}")
-print("🔍 PHASE 1: Loading Base Model")
-print(f"{'='*80}")
-
-model, tokenizer = FastModel.from_pretrained(
- model_name = "unsloth/whisper-large-v3",
- dtype = None, # Leave as None for auto detection
- load_in_4bit = False, # Set to True to do 4bit quantization which reduces memory
- auto_model = WhisperForConditionalGeneration,
- whisper_language = "English",
- whisper_task = "transcribe",
- # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
-)
-
-print("✅ Base model loaded successfully!")
-
-### Attemtping save merge
-
-
-print(f"\n{'='*80}")
-print("🔍 PHASE 2: Attempting save_pretrained_merged (Should Warn)")
-print(f"{'='*80}")
-
-with warnings.catch_warnings(record = True) as w:
- warnings.simplefilter("always")
- model.save_pretrained_merged("test_output", tokenizer)
-
- # Verify warning
- assert len(w) >= 1, "Expected warning but none raised"
- warning_msg = str(w[0].message)
- expected_msg = "Model is not a PeftModel (no Lora adapters detected). Skipping Merge. Please use save_pretrained() or push_to_hub() instead!"
- assert expected_msg in warning_msg, f"Unexpected warning: {warning_msg}"
- assert expected_msg in warning_msg, f"Unexpected warning: {warning_msg}"
-
-print("✅ Correct warning detected for non-PeftModel merge attempt!")
-
-
-print(f"\n{'='*80}")
-print("🔍 PHASE 3: Using save_pretrained (Should Succeed)")
-print(f"{'='*80}")
-
-
-try:
- with warnings.catch_warnings():
- warnings.simplefilter("error") # Treat warnings as errors here
- model.save_pretrained("test_output")
- print("✅ Standard save_pretrained completed successfully!")
-except Exception as e:
- assert False, f"Phase 3 failed: {e}"
-
-safe_remove_directory("./test_output")
-safe_remove_directory("./unsloth_compiled_cache")
diff --git a/tests/saving/test_unsloth_save.py b/tests/saving/test_unsloth_save.py
deleted file mode 100644
index 35fdad6ba0..0000000000
--- a/tests/saving/test_unsloth_save.py
+++ /dev/null
@@ -1,401 +0,0 @@
-import json
-import os
-import shutil
-import tempfile
-import pytest
-import importlib
-
-from unsloth import FastLanguageModel, FastModel
-
-model_to_test = [
- # Text Models
- "unsloth/tinyllama",
- "unsloth/tinyllama-bnb-4bit",
- "unsloth/Qwen2.5-0.5B-Instruct",
- "unsloth/Qwen2.5-0.5B-Instruct-bnb-4bit",
- "unsloth/Phi-4-mini-instruct",
- "unsloth/Phi-4-mini-instruct-bnb-4bit",
- "unsloth/Qwen2.5-0.5B",
- # Vision Models
- "unsloth/gemma-3-4b-it",
- "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
- "unsloth/Qwen2.5-VL-3B-Instruct-bnb-4bit",
-]
-
-torchao_models = [
- "unsloth/tinyllama",
- "unsloth/Qwen2.5-0.5B-Instruct",
- # "unsloth/Phi-4-mini-instruct",
- # "unsloth/Qwen2.5-0.5B",
- # Skip the -bnb-4bit variants since they're already quantized
-]
-
-
-# Variables
-save_file_sizes = {}
-save_file_sizes["merged_16bit"] = {}
-save_file_sizes["merged_4bit"] = {}
-save_file_sizes["torchao"] = {}
-
-tokenizer_files = [
- "tokenizer_config.json",
- "special_tokens_map.json",
-]
-
-
-@pytest.fixture(scope = "session", params = model_to_test)
-def loaded_model_tokenizer(request):
- model_name = request.param
- print("Loading model and tokenizer...")
-
- model, tokenizer = FastModel.from_pretrained(
- model_name, # use small model
- max_seq_length = 128,
- dtype = None,
- load_in_4bit = True,
- )
-
- # Apply LoRA
- model = FastModel.get_peft_model(
- model,
- r = 16,
- target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
- lora_alpha = 16,
- use_gradient_checkpointing = "unsloth",
- )
-
- return model, tokenizer
-
-
-@pytest.fixture(scope = "session", params = torchao_models)
-def fp16_model_tokenizer(request):
- """Load model in FP16 for TorchAO quantization"""
- model_name = request.param
- print(f"Loading model in FP16 for TorchAO: {model_name}")
-
- model, tokenizer = FastModel.from_pretrained(
- model_name,
- max_seq_length = 128,
- dtype = None,
- load_in_4bit = False, # No BnB quantization
- )
-
- # Apply LoRA
- model = FastModel.get_peft_model(
- model,
- r = 16,
- target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
- lora_alpha = 16,
- use_gradient_checkpointing = "unsloth",
- )
-
- return model, tokenizer
-
-
-@pytest.fixture(scope = "session")
-def model(loaded_model_tokenizer):
- return loaded_model_tokenizer[0]
-
-
-@pytest.fixture(scope = "session")
-def tokenizer(loaded_model_tokenizer):
- return loaded_model_tokenizer[1]
-
-
-@pytest.fixture
-def temp_save_dir():
- dir = tempfile.mkdtemp()
- print(f"Temporary directory created at: {dir}")
- yield dir
- print(f"Temporary directory deleted: {dir}")
- shutil.rmtree(dir)
-
-
-def delete_quantization_config(model):
- # Since merged, edit quantization_config
- old_config = model.config
- new_config = model.config.to_dict()
- if "quantization_config" in new_config:
- del new_config["quantization_config"]
- original_model = model
- new_config = type(model.config).from_dict(new_config)
- while hasattr(original_model, "model"):
- original_model = original_model.model
- original_model.config = new_config
- model.config = new_config
-
-
-def test_save_merged_16bit(model, tokenizer, temp_save_dir: str):
- save_path = os.path.join(
- temp_save_dir,
- "unsloth_merged_16bit",
- model.config._name_or_path.replace("/", "_"),
- )
-
- model.save_pretrained_merged(
- save_path, tokenizer = tokenizer, save_method = "merged_16bit"
- )
-
- # Check model files
- assert os.path.isdir(save_path), f"Directory {save_path} does not exist."
- assert os.path.isfile(
- os.path.join(save_path, "config.json")
- ), "config.json not found."
-
- weight_files = [
- f
- for f in os.listdir(save_path)
- if f.endswith(".bin") or f.endswith(".safetensors")
- ]
- assert len(weight_files) > 0, "No weight files found in the save directory."
-
- # Check tokenizer files
- for file in tokenizer_files:
- assert os.path.isfile(
- os.path.join(save_path, file)
- ), f"{file} not found in the save directory."
-
- # Check config to see if it is 16bit by checking for quantization config
- config_path = os.path.join(save_path, "config.json")
- with open(config_path, "r") as f:
- config = json.load(f)
-
- assert (
- "quantization_config" not in config
- ), "Quantization config not found in the model config."
-
- # Store the size of the model files
- total_size = sum(os.path.getsize(os.path.join(save_path, f)) for f in weight_files)
- save_file_sizes["merged_16bit"][model.config._name_or_path] = total_size
- print(f"Total size of merged_16bit files: {total_size} bytes")
-
- # Test loading the model from the saved path
- loaded_model, loaded_tokenizer = FastLanguageModel.from_pretrained(
- save_path,
- max_seq_length = 128,
- dtype = None,
- load_in_4bit = True,
- )
-
-
-def test_save_merged_4bit(model, tokenizer, temp_save_dir: str):
- save_path = os.path.join(
- temp_save_dir,
- "unsloth_merged_4bit",
- model.config._name_or_path.replace("/", "_"),
- )
-
- model.save_pretrained_merged(
- save_path, tokenizer = tokenizer, save_method = "merged_4bit_forced"
- )
-
- # Check model files
- assert os.path.isdir(save_path), f"Directory {save_path} does not exist."
- assert os.path.isfile(
- os.path.join(save_path, "config.json")
- ), "config.json not found."
-
- weight_files = [
- f
- for f in os.listdir(save_path)
- if f.endswith(".bin") or f.endswith(".safetensors")
- ]
- assert len(weight_files) > 0, "No weight files found in the save directory."
-
- # Check tokenizer files
- for file in tokenizer_files:
- assert os.path.isfile(
- os.path.join(save_path, file)
- ), f"{file} not found in the save directory."
-
- # Store the size of the model files
- total_size = sum(os.path.getsize(os.path.join(save_path, f)) for f in weight_files)
- save_file_sizes["merged_4bit"][model.config._name_or_path] = total_size
-
- print(f"Total size of merged_4bit files: {total_size} bytes")
-
- assert (
- total_size < save_file_sizes["merged_16bit"][model.config._name_or_path]
- ), "Merged 4bit files are larger than merged 16bit files."
-
- # Check config to see if it is 4bit
- config_path = os.path.join(save_path, "config.json")
- with open(config_path, "r") as f:
- config = json.load(f)
-
- assert (
- "quantization_config" in config
- ), "Quantization config not found in the model config."
-
- # Test loading the model from the saved path
- loaded_model, loaded_tokenizer = FastModel.from_pretrained(
- save_path,
- max_seq_length = 128,
- dtype = None,
- load_in_4bit = True,
- )
-
-
-@pytest.mark.skipif(
- importlib.util.find_spec("torchao") is None,
- reason = "require torchao to be installed",
-)
-def test_save_torchao(fp16_model_tokenizer, temp_save_dir: str):
- model, tokenizer = fp16_model_tokenizer
- save_path = os.path.join(
- temp_save_dir, "unsloth_torchao", model.config._name_or_path.replace("/", "_")
- )
-
- from torchao.quantization import Int8DynamicActivationInt8WeightConfig
-
- torchao_config = Int8DynamicActivationInt8WeightConfig()
- model.save_pretrained_torchao(
- save_path,
- tokenizer = tokenizer,
- torchao_config = torchao_config,
- push_to_hub = False,
- )
-
- weight_files_16bit = [
- f
- for f in os.listdir(save_path)
- if f.endswith(".bin") or f.endswith(".safetensors")
- ]
- total_16bit_size = sum(
- os.path.getsize(os.path.join(save_path, f)) for f in weight_files_16bit
- )
- save_file_sizes["merged_16bit"][model.config._name_or_path] = total_16bit_size
-
- torchao_save_path = save_path + "-torchao"
-
- # Check model files
- assert os.path.isdir(
- torchao_save_path
- ), f"Directory {torchao_save_path} does not exist."
- assert os.path.isfile(
- os.path.join(torchao_save_path, "config.json")
- ), "config.json not found."
-
- weight_files = [
- f
- for f in os.listdir(torchao_save_path)
- if f.endswith(".bin") or f.endswith(".safetensors")
- ]
- assert len(weight_files) > 0, "No weight files found in the save directory."
-
- # Check tokenizer files
- for file in tokenizer_files:
- assert os.path.isfile(
- os.path.join(torchao_save_path, file)
- ), f"{file} not found in the save directory."
-
- # Store the size of the model files
- total_size = sum(
- os.path.getsize(os.path.join(torchao_save_path, f)) for f in weight_files
- )
- save_file_sizes["torchao"][model.config._name_or_path] = total_size
-
- assert (
- total_size < save_file_sizes["merged_16bit"][model.config._name_or_path]
- ), "torchao files are larger than merged 16bit files."
-
- # Check config to see if it is quantized with torchao
- config_path = os.path.join(torchao_save_path, "config.json")
- with open(config_path, "r") as f:
- config = json.load(f)
-
- assert (
- "quantization_config" in config
- ), "Quantization config not found in the model config."
-
- # Test loading the model from the saved path
- # can't set `load_in_4bit` to True because the model is torchao quantized
- # can't quantize again with bitsandbytes
- import torch.serialization
-
- with torch.serialization.safe_globals([getattr]):
- loaded_model, loaded_tokenizer = FastModel.from_pretrained(
- torchao_save_path,
- max_seq_length = 128,
- dtype = None,
- load_in_4bit = False,
- )
-
-
-@pytest.mark.skipif(
- importlib.util.find_spec("torchao") is None,
- reason = "require torchao to be installed",
-)
-def test_save_and_inference_torchao(fp16_model_tokenizer, temp_save_dir: str):
- model, tokenizer = fp16_model_tokenizer
- model_name = model.config._name_or_path
-
- print(f"Testing TorchAO save and inference for: {model_name}")
-
- save_path = os.path.join(
- temp_save_dir, "torchao_models", model_name.replace("/", "_")
- )
-
- from torchao.quantization import Int8DynamicActivationInt8WeightConfig
-
- torchao_config = Int8DynamicActivationInt8WeightConfig()
-
- # Save with TorchAO
- model.save_pretrained_torchao(
- save_path,
- tokenizer = tokenizer,
- torchao_config = torchao_config,
- push_to_hub = False,
- )
-
- torchao_save_path = save_path + "-torchao"
-
- # Verify files exist
- assert os.path.isdir(
- torchao_save_path
- ), f"TorchAO directory {torchao_save_path} does not exist."
-
- # Load with safe globals
- import torch.serialization
-
- with torch.serialization.safe_globals([getattr]):
- loaded_model, loaded_tokenizer = FastModel.from_pretrained(
- torchao_save_path,
- max_seq_length = 128,
- dtype = None,
- load_in_4bit = False,
- )
-
- FastModel.for_inference(loaded_model) # Enable native 2x faster inference
-
- messages = [
- {
- "role": "user",
- "content": "Continue the fibonnaci sequence: 1, 1, 2, 3, 5, 8,",
- },
- ]
- inputs = loaded_tokenizer.apply_chat_template(
- messages,
- tokenize = True,
- add_generation_prompt = True, # Must add for generation
- return_tensors = "pt",
- ).to("cuda")
-
- outputs = loaded_model.generate( # ← Use loaded_model, not model
- input_ids = inputs,
- max_new_tokens = 64,
- use_cache = False, # Avoid cache issues
- temperature = 1.5,
- min_p = 0.1,
- do_sample = True,
- pad_token_id = loaded_tokenizer.pad_token_id or loaded_tokenizer.eos_token_id,
- )
-
- # Decode with the LOADED tokenizer
- generated_text = loaded_tokenizer.decode(outputs[0], skip_special_tokens = True)
- input_text = loaded_tokenizer.decode(inputs[0], skip_special_tokens = True)
- response_part = generated_text[len(input_text) :].strip()
-
- print(f"Input: {input_text}")
- print(f"Full output: {generated_text}")
- print(f"Response only: {response_part}")
diff --git a/tests/saving/text_to_speech_models/test_csm.py b/tests/saving/text_to_speech_models/test_csm.py
deleted file mode 100644
index c1a892a8d3..0000000000
--- a/tests/saving/text_to_speech_models/test_csm.py
+++ /dev/null
@@ -1,168 +0,0 @@
-from unsloth import FastLanguageModel, FastModel
-from transformers import CsmForConditionalGeneration
-import torch
-
-# ruff: noqa
-import sys
-from pathlib import Path
-from peft import PeftModel
-import warnings
-import requests
-
-REPO_ROOT = Path(__file__).parents[3]
-sys.path.insert(0, str(REPO_ROOT))
-
-from tests.utils.cleanup_utils import safe_remove_directory
-from tests.utils.os_utils import require_package, require_python_package
-
-require_package("ffmpeg", "ffmpeg")
-require_python_package("soundfile")
-
-import soundfile as sf
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 1: Loading Model and LoRA Adapters")
-print(f"{'='*80}")
-
-
-model, tokenizer = FastModel.from_pretrained(
- model_name = "unsloth/csm-1b",
- max_seq_length = 2048, # Choose any for long context!
- dtype = None, # Leave as None for auto-detection
- auto_model = CsmForConditionalGeneration,
- load_in_4bit = False, # Select True for 4bit - reduces memory usage
-)
-
-
-base_model_class = model.__class__.__name__
-
-
-model = FastModel.get_peft_model(
- model,
- r = 32, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
- target_modules = [
- "q_proj",
- "k_proj",
- "v_proj",
- "o_proj",
- "gate_proj",
- "up_proj",
- "down_proj",
- ],
- lora_alpha = 32,
- lora_dropout = 0, # Supports any, but = 0 is optimized
- bias = "none", # Supports any, but = "none" is optimized
- # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
- use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
- random_state = 3407,
- use_rslora = False, # We support rank stabilized LoRA
- loftq_config = None, # And LoftQ
-)
-
-print("✅ Model and LoRA adapters loaded successfully!")
-
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 2: Checking Model Class Type")
-print(f"{'='*80}")
-
-assert isinstance(model, PeftModel), "Model should be an instance of PeftModel"
-print("✅ Model is an instance of PeftModel!")
-
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 3: Checking Config Model Class Type")
-print(f"{'='*80}")
-
-
-def find_lora_base_model(model_to_inspect):
- current = model_to_inspect
- if hasattr(current, "base_model"):
- current = current.base_model
- if hasattr(current, "model"):
- current = current.model
- return current
-
-
-config_model = find_lora_base_model(model) if isinstance(model, PeftModel) else model
-
-assert (
- config_model.__class__.__name__ == base_model_class
-), f"Expected config_model class to be {base_model_class}"
-print("✅ config_model returns correct Base Model class:", str(base_model_class))
-
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 4: Saving and Merging Model")
-print(f"{'='*80}")
-
-with warnings.catch_warnings():
- warnings.simplefilter("error") # Treat warnings as errors
- try:
- model.save_pretrained_merged("csm", tokenizer)
- print("✅ Model saved and merged successfully without warnings!")
- except Exception as e:
- assert False, f"Model saving/merging failed with exception: {e}"
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 5: Loading Model for Inference")
-print(f"{'='*80}")
-
-
-model, processor = FastModel.from_pretrained(
- model_name = "./csm",
- max_seq_length = 2048, # Choose any for long context!
- dtype = None, # Leave as None for auto-detection
- auto_model = CsmForConditionalGeneration,
- load_in_4bit = False, # Select True for 4bit - reduces memory usage
-)
-
-from transformers import AutoProcessor
-
-processor = AutoProcessor.from_pretrained("unsloth/csm-1b")
-
-print("✅ Model loaded for inference successfully!")
-
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 6: Running Inference")
-print(f"{'='*80}")
-
-
-from transformers import pipeline
-import torch
-
-output_audio_path = "csm_audio.wav"
-try:
- text = (
- "We just finished fine tuning a text to speech model... and it's pretty good!"
- )
- speaker_id = 0
- inputs = processor(f"[{speaker_id}]{text}", add_special_tokens = True).to("cuda")
- audio_values = model.generate(
- **inputs,
- max_new_tokens = 125, # 125 tokens is 10 seconds of audio, for longer speech increase this
- # play with these parameters to get the best results
- depth_decoder_temperature = 0.6,
- depth_decoder_top_k = 0,
- depth_decoder_top_p = 0.9,
- temperature = 0.8,
- top_k = 50,
- top_p = 1.0,
- #########################################################
- output_audio = True,
- )
- audio = audio_values[0].to(torch.float32).cpu().numpy()
- sf.write("example_without_context.wav", audio, 24000)
- print(f"✅ Audio generated and saved to {output_audio_path}!")
-except Exception as e:
- assert False, f"Inference failed with exception: {e}"
-
-
-## assert that transcribed_text contains The birch canoe slid on the smooth planks. Glued the sheet to the dark blue background. It's easy to tell the depth of a well. Four hours of steady work faced us.
-
-print("✅ All sections passed successfully!")
-
-
-safe_remove_directory("./unsloth_compiled_cache")
-safe_remove_directory("./csm")
diff --git a/tests/saving/text_to_speech_models/test_lasa.py b/tests/saving/text_to_speech_models/test_lasa.py
deleted file mode 100644
index 804ff512f9..0000000000
--- a/tests/saving/text_to_speech_models/test_lasa.py
+++ /dev/null
@@ -1,220 +0,0 @@
-from unsloth import FastLanguageModel, FastModel
-from transformers import CsmForConditionalGeneration
-import torch
-
-# ruff: noqa
-import sys
-from pathlib import Path
-from peft import PeftModel
-import warnings
-import requests
-
-
-REPO_ROOT = Path(__file__).parents[3]
-sys.path.insert(0, str(REPO_ROOT))
-
-
-from tests.utils.cleanup_utils import safe_remove_directory
-from tests.utils.os_utils import require_package, require_python_package
-
-require_package("ffmpeg", "ffmpeg")
-require_python_package("soundfile")
-require_python_package("xcodec2")
-
-import soundfile as sf
-from xcodec2.modeling_xcodec2 import XCodec2Model
-
-XCODEC2_MODEL_NAME = "HKUST-Audio/xcodec2"
-SAMPLE_RATE = 16000
-DEVICE = "cuda"
-
-try:
- codec_model = XCodec2Model.from_pretrained(XCODEC2_MODEL_NAME)
-
-except Exception as e:
- raise f"ERROR loading XCodec2 model: {e}."
-
-codec_model.to("cpu")
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 1: Loading Model and LoRA Adapters")
-print(f"{'='*80}")
-
-max_seq_length = 2048
-model, tokenizer = FastLanguageModel.from_pretrained(
- model_name = "unsloth/Llasa-1B",
- max_seq_length = max_seq_length,
- dtype = None, # Select None for auto detection
- load_in_4bit = False, # Choose True for 4bit which reduces memory
- # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
-)
-
-base_model_class = model.__class__.__name__
-
-
-model = FastLanguageModel.get_peft_model(
- model,
- r = 128, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
- target_modules = ["q_proj", "v_proj"],
- lora_alpha = 128,
- lora_dropout = 0, # Supports any, but = 0 is optimized
- bias = "none", # Supports any, but = "none" is optimized
- # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
- use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
- random_state = 3407,
- use_rslora = False, # We support rank stabilized LoRA
- loftq_config = None, # And LoftQ
-)
-
-print("✅ Model and LoRA adapters loaded successfully!")
-
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 2: Checking Model Class Type")
-print(f"{'='*80}")
-
-assert isinstance(model, PeftModel), "Model should be an instance of PeftModel"
-print("✅ Model is an instance of PeftModel!")
-
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 3: Checking Config Model Class Type")
-print(f"{'='*80}")
-
-
-def find_lora_base_model(model_to_inspect):
- current = model_to_inspect
- if hasattr(current, "base_model"):
- current = current.base_model
- if hasattr(current, "model"):
- current = current.model
- return current
-
-
-config_model = find_lora_base_model(model) if isinstance(model, PeftModel) else model
-
-assert (
- config_model.__class__.__name__ == base_model_class
-), f"Expected config_model class to be {base_model_class}"
-print("✅ config_model returns correct Base Model class:", str(base_model_class))
-
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 4: Saving and Merging Model")
-print(f"{'='*80}")
-
-with warnings.catch_warnings():
- warnings.simplefilter("error") # Treat warnings as errors
- try:
- model.save_pretrained_merged("lasa", tokenizer)
- print("✅ Model saved and merged successfully without warnings!")
- except Exception as e:
- assert False, f"Model saving/merging failed with exception: {e}"
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 5: Loading Model for Inference")
-print(f"{'='*80}")
-
-
-model, tokenizer = FastLanguageModel.from_pretrained(
- model_name = "./lasa",
- max_seq_length = max_seq_length,
- dtype = None, # Select None for auto detection
- load_in_4bit = False, # Choose True for 4bit which reduces memory
- # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
-)
-
-# from transformers import AutoProcessor
-# processor = AutoProcessor.from_pretrained("unsloth/csm-1b")
-
-print("✅ Model loaded for inference successfully!")
-
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 6: Running Inference")
-print(f"{'='*80}")
-
-
-from transformers import pipeline
-import torch
-
-output_audio_path = "lasa_audio.wav"
-input_text = "Hey there my name is Elise, and I'm a speech generation model that can sound like a person."
-
-FastLanguageModel.for_inference(model)
-
-
-def ids_to_speech_tokens(speech_ids):
- speech_tokens_str = []
- for speech_id in speech_ids:
- speech_tokens_str.append(f"<|s_{speech_id}|>")
- return speech_tokens_str
-
-
-def extract_speech_ids(speech_tokens_str):
- speech_ids = []
- for token_str in speech_tokens_str:
- if token_str.startswith("<|s_") and token_str.endswith("|>"):
- num_str = token_str[4:-2]
-
- num = int(num_str)
- speech_ids.append(num)
- else:
- print(f"Unexpected token: {token_str}")
- return speech_ids
-
-
-# TTS start!
-with torch.inference_mode():
- with torch.amp.autocast("cuda", dtype = model.dtype):
- formatted_text = (
- f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
- )
-
- # Tokenize the text
- chat = [
- {"role": "user", "content": "Convert the text to speech:" + formatted_text},
- {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"},
- ]
-
- input_ids = tokenizer.apply_chat_template(
- chat, tokenize = True, return_tensors = "pt", continue_final_message = True
- )
- input_ids = input_ids.to("cuda")
-
- speech_end_id = tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
-
- # Generate the speech autoregressively
- outputs = model.generate(
- input_ids,
- max_length = 2048, # We trained our model with a max length of 2048
- eos_token_id = speech_end_id,
- do_sample = True,
- top_p = 1.2, # Adjusts the diversity of generated content
- temperature = 1.2, # Controls randomness in output
- )
- # Extract the speech tokens
- generated_ids = outputs[0][input_ids.shape[1] : -1]
-
- speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens = True)
-
- # Convert token <|s_23456|> to int 23456
- speech_tokens = extract_speech_ids(speech_tokens)
-
- speech_tokens = torch.tensor(speech_tokens).cpu().unsqueeze(0).unsqueeze(0)
-
- # Decode the speech tokens to speech waveform
- gen_wav = codec_model.decode_code(speech_tokens)
-try:
- sf.write(output_audio_path, gen_wav[0, 0, :].cpu().numpy(), 16000)
-except Exception as e:
- assert False, f"Inference failed with exception: {e}"
-
-
-## assert that transcribed_text contains The birch canoe slid on the smooth planks. Glued the sheet to the dark blue background. It's easy to tell the depth of a well. Four hours of steady work faced us.
-
-print("✅ All sections passed successfully!")
-
-
-safe_remove_directory("./unsloth_compiled_cache")
-safe_remove_directory("./lasa")
diff --git a/tests/saving/text_to_speech_models/test_orpheus.py b/tests/saving/text_to_speech_models/test_orpheus.py
deleted file mode 100644
index bd8bf14979..0000000000
--- a/tests/saving/text_to_speech_models/test_orpheus.py
+++ /dev/null
@@ -1,282 +0,0 @@
-from unsloth import FastLanguageModel, FastModel
-from transformers import CsmForConditionalGeneration
-import torch
-
-# ruff: noqa
-import sys
-from pathlib import Path
-from peft import PeftModel
-import warnings
-import requests
-
-REPO_ROOT = Path(__file__).parents[3]
-sys.path.insert(0, str(REPO_ROOT))
-
-from tests.utils.cleanup_utils import safe_remove_directory
-from tests.utils.os_utils import require_package, require_python_package
-
-require_package("ffmpeg", "ffmpeg")
-require_python_package("soundfile")
-require_python_package("snac")
-
-import soundfile as sf
-from snac import SNAC
-
-snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
-snac_model = snac_model.to("cuda")
-print(f"\n{'='*80}")
-print("🔍 SECTION 1: Loading Model and LoRA Adapters")
-print(f"{'='*80}")
-
-
-model, tokenizer = FastLanguageModel.from_pretrained(
- model_name = "unsloth/orpheus-3b-0.1-ft",
- max_seq_length = 2048, # Choose any for long context!
- dtype = None, # Select None for auto detection
- load_in_4bit = False, # Select True for 4bit which reduces memory usage
- # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
-)
-
-base_model_class = model.__class__.__name__
-
-
-model = FastLanguageModel.get_peft_model(
- model,
- r = 64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
- target_modules = [
- "q_proj",
- "k_proj",
- "v_proj",
- "o_proj",
- "gate_proj",
- "up_proj",
- "down_proj",
- ],
- lora_alpha = 64,
- lora_dropout = 0, # Supports any, but = 0 is optimized
- bias = "none", # Supports any, but = "none" is optimized
- # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
- use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
- random_state = 3407,
- use_rslora = False, # We support rank stabilized LoRA
- loftq_config = None, # And LoftQ
-)
-print("✅ Model and LoRA adapters loaded successfully!")
-
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 2: Checking Model Class Type")
-print(f"{'='*80}")
-
-assert isinstance(model, PeftModel), "Model should be an instance of PeftModel"
-print("✅ Model is an instance of PeftModel!")
-
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 3: Checking Config Model Class Type")
-print(f"{'='*80}")
-
-
-def find_lora_base_model(model_to_inspect):
- current = model_to_inspect
- if hasattr(current, "base_model"):
- current = current.base_model
- if hasattr(current, "model"):
- current = current.model
- return current
-
-
-config_model = find_lora_base_model(model) if isinstance(model, PeftModel) else model
-
-assert (
- config_model.__class__.__name__ == base_model_class
-), f"Expected config_model class to be {base_model_class}"
-print("✅ config_model returns correct Base Model class:", str(base_model_class))
-
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 4: Saving and Merging Model")
-print(f"{'='*80}")
-
-with warnings.catch_warnings():
- warnings.simplefilter("error") # Treat warnings as errors
- try:
- model.save_pretrained_merged("orpheus", tokenizer)
- print("✅ Model saved and merged successfully without warnings!")
- except Exception as e:
- assert False, f"Model saving/merging failed with exception: {e}"
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 5: Loading Model for Inference")
-print(f"{'='*80}")
-
-
-model, tokenizer = FastLanguageModel.from_pretrained(
- model_name = "unsloth/orpheus-3b-0.1-ft",
- max_seq_length = 2048, # Choose any for long context!
- dtype = None, # Select None for auto detection
- load_in_4bit = False, # Select True for 4bit which reduces memory usage
- # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
-)
-
-# from transformers import AutoProcessor
-# processor = AutoProcessor.from_pretrained("unsloth/csm-1b")
-
-print("✅ Model loaded for inference successfully!")
-
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 6: Running Inference")
-print(f"{'='*80}")
-
-
-# @title Run Inference
-
-
-FastLanguageModel.for_inference(model) # Enable native 2x faster inference
-
-# Moving snac_model cuda to cpu
-snac_model.to("cpu")
-prompts = [
- "Hey there my name is Elise, and I'm a speech generation model that can sound like a person.",
-]
-
-chosen_voice = None # None for single-speaker
-
-prompts_ = [(f"{chosen_voice}: " + p) if chosen_voice else p for p in prompts]
-
-all_input_ids = []
-
-for prompt in prompts_:
- input_ids = tokenizer(prompt, return_tensors = "pt").input_ids
- all_input_ids.append(input_ids)
-
-start_token = torch.tensor([[128259]], dtype = torch.int64) # Start of human
-end_tokens = torch.tensor(
- [[128009, 128260]], dtype = torch.int64
-) # End of text, End of human
-
-all_modified_input_ids = []
-for input_ids in all_input_ids:
- modified_input_ids = torch.cat(
- [start_token, input_ids, end_tokens], dim = 1
- ) # SOH SOT Text EOT EOH
- all_modified_input_ids.append(modified_input_ids)
-
-all_padded_tensors = []
-all_attention_masks = []
-max_length = max(
- [modified_input_ids.shape[1] for modified_input_ids in all_modified_input_ids]
-)
-for modified_input_ids in all_modified_input_ids:
- padding = max_length - modified_input_ids.shape[1]
- padded_tensor = torch.cat(
- [torch.full((1, padding), 128263, dtype = torch.int64), modified_input_ids], dim = 1
- )
- attention_mask = torch.cat(
- [
- torch.zeros((1, padding), dtype = torch.int64),
- torch.ones((1, modified_input_ids.shape[1]), dtype = torch.int64),
- ],
- dim = 1,
- )
- all_padded_tensors.append(padded_tensor)
- all_attention_masks.append(attention_mask)
-
-all_padded_tensors = torch.cat(all_padded_tensors, dim = 0)
-all_attention_masks = torch.cat(all_attention_masks, dim = 0)
-
-input_ids = all_padded_tensors.to("cuda")
-attention_mask = all_attention_masks.to("cuda")
-generated_ids = model.generate(
- input_ids = input_ids,
- attention_mask = attention_mask,
- max_new_tokens = 1200,
- do_sample = True,
- temperature = 0.6,
- top_p = 0.95,
- repetition_penalty = 1.1,
- num_return_sequences = 1,
- eos_token_id = 128258,
- use_cache = True,
-)
-token_to_find = 128257
-token_to_remove = 128258
-
-token_indices = (generated_ids == token_to_find).nonzero(as_tuple = True)
-
-if len(token_indices[1]) > 0:
- last_occurrence_idx = token_indices[1][-1].item()
- cropped_tensor = generated_ids[:, last_occurrence_idx + 1 :]
-else:
- cropped_tensor = generated_ids
-
-mask = cropped_tensor != token_to_remove
-
-processed_rows = []
-
-for row in cropped_tensor:
- masked_row = row[row != token_to_remove]
- processed_rows.append(masked_row)
-
-code_lists = []
-
-for row in processed_rows:
- row_length = row.size(0)
- new_length = (row_length // 7) * 7
- trimmed_row = row[:new_length]
- trimmed_row = [t - 128266 for t in trimmed_row]
- code_lists.append(trimmed_row)
-
-
-def redistribute_codes(code_list):
- layer_1 = []
- layer_2 = []
- layer_3 = []
- for i in range((len(code_list) + 1) // 7):
- layer_1.append(code_list[7 * i])
- layer_2.append(code_list[7 * i + 1] - 4096)
- layer_3.append(code_list[7 * i + 2] - (2 * 4096))
- layer_3.append(code_list[7 * i + 3] - (3 * 4096))
- layer_2.append(code_list[7 * i + 4] - (4 * 4096))
- layer_3.append(code_list[7 * i + 5] - (5 * 4096))
- layer_3.append(code_list[7 * i + 6] - (6 * 4096))
- codes = [
- torch.tensor(layer_1).unsqueeze(0),
- torch.tensor(layer_2).unsqueeze(0),
- torch.tensor(layer_3).unsqueeze(0),
- ]
-
- # codes = [c.to("cuda") for c in codes]
- audio_hat = snac_model.decode(codes)
- return audio_hat
-
-
-my_samples = []
-for code_list in code_lists:
- samples = redistribute_codes(code_list)
- my_samples.append(samples)
-output_path = "orpheus_audio.wav"
-try:
- for i, samples in enumerate(my_samples):
- audio_data = samples.detach().squeeze().cpu().numpy()
- import soundfile as sf
-
- sf.write(output_path, audio_data, 24000) # Explicitly pass sample rate
- print(f"✅ Audio saved to {output_path}!")
-except Exception as e:
- assert False, f"Inference failed with exception: {e}"
-
-# Verify the file exists
-import os
-
-assert os.path.exists(output_path), f"Audio file not found at {output_path}"
-print("✅ Audio file exists on disk!")
-del my_samples, samples
-## assert that transcribed_text contains The birch canoe slid on the smooth planks. Glued the sheet to the dark blue background. It's easy to tell the depth of a well. Four hours of steady work faced us.
-
-print("✅ All sections passed successfully!")
-
-
-safe_remove_directory("./unsloth_compiled_cache")
-safe_remove_directory("./orpheus")
diff --git a/tests/saving/text_to_speech_models/test_whisper.py b/tests/saving/text_to_speech_models/test_whisper.py
deleted file mode 100644
index 55f6d98ca0..0000000000
--- a/tests/saving/text_to_speech_models/test_whisper.py
+++ /dev/null
@@ -1,195 +0,0 @@
-from unsloth import FastLanguageModel, FastModel
-from transformers import WhisperForConditionalGeneration, WhisperProcessor
-import torch
-
-# ruff: noqa
-import sys
-from pathlib import Path
-from peft import PeftModel
-import warnings
-import requests
-
-
-REPO_ROOT = Path(__file__).parents[3]
-sys.path.insert(0, str(REPO_ROOT))
-
-
-from tests.utils.cleanup_utils import safe_remove_directory
-from tests.utils.os_utils import require_package, require_python_package
-
-require_package("ffmpeg", "ffmpeg")
-require_python_package("soundfile")
-
-import soundfile as sf
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 1: Loading Model and LoRA Adapters")
-print(f"{'='*80}")
-
-
-model, tokenizer = FastModel.from_pretrained(
- model_name = "unsloth/whisper-large-v3",
- dtype = None, # Leave as None for auto detection
- load_in_4bit = False, # Set to True to do 4bit quantization which reduces memory
- auto_model = WhisperForConditionalGeneration,
- whisper_language = "English",
- whisper_task = "transcribe",
- # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
-)
-
-
-base_model_class = model.__class__.__name__
-# https://github.com/huggingface/transformers/issues/37172
-model.generation_config.input_ids = model.generation_config.forced_decoder_ids
-model.generation_config.forced_decoder_ids = None
-
-
-model = FastModel.get_peft_model(
- model,
- r = 64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
- target_modules = ["q_proj", "v_proj"],
- lora_alpha = 64,
- lora_dropout = 0, # Supports any, but = 0 is optimized
- bias = "none", # Supports any, but = "none" is optimized
- # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
- use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
- random_state = 3407,
- use_rslora = False, # We support rank stabilized LoRA
- loftq_config = None, # And LoftQ
- task_type = None, # ** MUST set this for Whisper **
-)
-
-print("✅ Model and LoRA adapters loaded successfully!")
-
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 2: Checking Model Class Type")
-print(f"{'='*80}")
-
-assert isinstance(model, PeftModel), "Model should be an instance of PeftModel"
-print("✅ Model is an instance of PeftModel!")
-
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 3: Checking Config Model Class Type")
-print(f"{'='*80}")
-
-
-def find_lora_base_model(model_to_inspect):
- current = model_to_inspect
- if hasattr(current, "base_model"):
- current = current.base_model
- if hasattr(current, "model"):
- current = current.model
- return current
-
-
-config_model = find_lora_base_model(model) if isinstance(model, PeftModel) else model
-
-assert (
- config_model.__class__.__name__ == base_model_class
-), f"Expected config_model class to be {base_model_class}"
-print("✅ config_model returns correct Base Model class:", str(base_model_class))
-
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 4: Saving and Merging Model")
-print(f"{'='*80}")
-
-with warnings.catch_warnings():
- warnings.simplefilter("error") # Treat warnings as errors
- try:
- model.save_pretrained_merged("whisper", tokenizer)
- print("✅ Model saved and merged successfully without warnings!")
- except Exception as e:
- assert False, f"Model saving/merging failed with exception: {e}"
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 5: Loading Model for Inference")
-print(f"{'='*80}")
-
-
-model, tokenizer = FastModel.from_pretrained(
- model_name = "./whisper",
- dtype = None, # Leave as None for auto detection
- load_in_4bit = False, # Set to True to do 4bit quantization which reduces memory
- auto_model = WhisperForConditionalGeneration,
- whisper_language = "English",
- whisper_task = "transcribe",
- # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
-)
-
-# model = WhisperForConditionalGeneration.from_pretrained("./whisper")
-# processor = WhisperProcessor.from_pretrained("./whisper")
-
-print("✅ Model loaded for inference successfully!")
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 6: Downloading Sample Audio File")
-print(f"{'='*80}")
-
-audio_url = "https://upload.wikimedia.org/wikipedia/commons/5/5b/Speech_12dB_s16.flac"
-audio_file = "Speech_12dB_s16.flac"
-
-try:
- headers = {
- "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
- }
- response = requests.get(audio_url, headers = headers)
- response.raise_for_status()
- with open(audio_file, "wb") as f:
- f.write(response.content)
- print("✅ Audio file downloaded successfully!")
-except Exception as e:
- assert False, f"Failed to download audio file: {e}"
-
-print(f"\n{'='*80}")
-print("🔍 SECTION 7: Running Inference")
-print(f"{'='*80}")
-
-
-from transformers import pipeline
-import torch
-
-FastModel.for_inference(model)
-model.eval()
-# Create pipeline without specifying the device
-whisper = pipeline(
- "automatic-speech-recognition",
- model = model,
- tokenizer = tokenizer.tokenizer,
- feature_extractor = tokenizer.feature_extractor,
- processor = tokenizer,
- return_language = True,
- torch_dtype = torch.float16, # Remove the device parameter
-)
-# Example usage
-audio_file = "Speech_12dB_s16.flac"
-transcribed_text = whisper(audio_file)
-# audio, sr = sf.read(audio_file)
-# input_features = processor(audio, return_tensors="pt").input_features
-# transcribed_text = model.generate(input_features=input_features)
-print(f"📝 Transcribed Text: {transcribed_text['text']}")
-
-## assert that transcribed_text contains The birch canoe slid on the smooth planks. Glued the sheet to the dark blue background. It's easy to tell the depth of a well. Four hours of steady work faced us.
-
-expected_phrases = [
- "birch canoe slid on the smooth planks",
- "sheet to the dark blue background",
- "easy to tell the depth of a well",
- "Four hours of steady work faced us",
-]
-
-transcribed_lower = transcribed_text["text"].lower()
-all_phrases_found = all(
- phrase.lower() in transcribed_lower for phrase in expected_phrases
-)
-
-assert (
- all_phrases_found
-), f"Expected phrases not found in transcription: {transcribed_text['text']}"
-print("✅ Transcription contains all expected phrases!")
-
-
-safe_remove_directory("./unsloth_compiled_cache")
-safe_remove_directory("./whisper")
diff --git a/tests/saving/vision_models/test_index_file_sharded_model.py b/tests/saving/vision_models/test_index_file_sharded_model.py
deleted file mode 100644
index f737169841..0000000000
--- a/tests/saving/vision_models/test_index_file_sharded_model.py
+++ /dev/null
@@ -1,293 +0,0 @@
-## Import required libraries
-
-from unsloth import FastVisionModel, is_bf16_supported
-from unsloth.trainer import UnslothVisionDataCollator
-
-import torch
-import os
-from datasets import load_dataset
-from trl import SFTTrainer, SFTConfig
-from huggingface_hub import HfFileSystem
-import sys
-from pathlib import Path
-
-
-REPO_ROOT = Path(__file__).parents[3]
-sys.path.insert(0, str(REPO_ROOT))
-
-from tests.utils.cleanup_utils import safe_remove_directory
-
-
-## Dataset Preparation"""
-
-print("\n📊 Loading and preparing dataset...")
-dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split = "train")
-# To select the first 2000 examples
-train_dataset = dataset.select(range(2000))
-
-# To select the next 200 examples for evaluation
-eval_dataset = dataset.select(range(2000, 2200))
-
-print(f"✅ Dataset loaded successfully!")
-print(f" 📈 Training samples: {len(train_dataset)}")
-print(f" 📊 Evaluation samples: {len(eval_dataset)}")
-
-
-# Convert dataset to OAI messages
-def format_data(sample):
- return {
- "messages": [
- {
- "role": "system",
- "content": [{"type": "text", "text": system_message}],
- },
- {
- "role": "user",
- "content": [
- {
- "type": "text",
- "text": sample["question"],
- },
- {
- "type": "image",
- "image": sample["image"],
- },
- ],
- },
- {
- "role": "assistant",
- "content": [{"type": "text", "text": sample["answer"]}],
- },
- ],
- }
-
-
-print("\n🔄 Formatting dataset for vision training...")
-system_message = "You are an expert french ocr system."
-# Convert dataset to OAI messages
-# need to use list comprehension to keep Pil.Image type, .mape convert image to bytes
-train_dataset = [format_data(sample) for sample in train_dataset]
-eval_dataset = [format_data(sample) for sample in eval_dataset]
-print("✅ Dataset formatting completed!")
-
-"""## Finetuning Setup and Run"""
-
-
-print("\n" + "=" * 80)
-print("=== MODEL LOADING AND SETUP ===".center(80))
-print("=" * 80 + "\n")
-# Load Base Model
-print("🤖 Loading base vision model...")
-try:
- model, tokenizer = FastVisionModel.from_pretrained(
- # model_name = "unsloth/Qwen2-VL-7B-Instruct",
- model_name = "unsloth/Qwen2-VL-7B-Instruct",
- max_seq_length = 2048, # Choose any for long context!
- load_in_4bit = True, # 4 bit quantization to reduce memory
- load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
- full_finetuning = False, # [NEW!] We have full finetuning now!
- )
-except Exception as e:
- print(f"❌ Failed to load base model: {e}")
- raise
-
-print("\n🔧 Setting up LoRA configuration...")
-## Lora Finetuning
-try:
- model = FastVisionModel.get_peft_model(
- model,
- finetune_vision_layers = True, # Turn off for just text!
- finetune_language_layers = True, # Should leave on!
- finetune_attention_modules = True, # Attention good for GRPO
- finetune_mlp_modules = True, # SHould leave on always!
- r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
- lora_alpha = 32,
- lora_dropout = 0, # Supports any, but = 0 is optimized
- bias = "none", # Supports any, but = "none" is optimized
- use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
- random_state = 3407,
- use_rslora = False, # We support rank stabilized LoRA
- loftq_config = None, # And LoftQ
- )
- print("✅ LoRA configuration applied successfully!")
- print(f" 🎯 LoRA rank (r): 16")
- print(f" 📊 LoRA alpha: 32")
- print(f" 🔍 Vision layers: Enabled")
- print(f" 💬 Language layers: Enabled")
-except Exception as e:
- print(f"❌ Failed to apply LoRA configuration: {e}")
- raise
-
-print("\n" + "=" * 80)
-print("=== TRAINING SETUP ===".center(80))
-print("=" * 80 + "\n")
-
-
-print("🏋️ Preparing trainer...")
-FastVisionModel.for_training(model) # Enable for training!
-
-try:
- trainer = SFTTrainer(
- model = model,
- tokenizer = tokenizer,
- data_collator = UnslothVisionDataCollator(model, tokenizer),
- train_dataset = train_dataset,
- args = SFTConfig(
- # per_device_train_batch_size = 4,
- # gradient_accumulation_steps = 8,
- per_device_train_batch_size = 2,
- gradient_accumulation_steps = 4,
- gradient_checkpointing = True,
- gradient_checkpointing_kwargs = {
- "use_reentrant": False
- }, # use reentrant checkpointing
- max_grad_norm = 0.3, # max gradient norm based on QLoRA paper
- warmup_ratio = 0.03,
- # num_train_epochs = 2, # Set this instead of max_steps for full training runs
- max_steps = 10,
- learning_rate = 2e-4,
- fp16 = not is_bf16_supported(),
- bf16 = is_bf16_supported(),
- logging_steps = 5,
- save_strategy = "epoch",
- optim = "adamw_torch_fused",
- weight_decay = 0.01,
- lr_scheduler_type = "linear",
- seed = 3407,
- output_dir = "checkpoints",
- report_to = "none", # For Weights and Biases
- # You MUST put the below items for vision finetuning:
- remove_unused_columns = False,
- dataset_text_field = "",
- dataset_kwargs = {"skip_prepare_dataset": True},
- dataset_num_proc = 4,
- max_seq_length = 2048,
- ),
- )
- print("✅ Trainer setup completed!")
- print(f" 📦 Batch size: 2")
- print(f" 🔄 Gradient accumulation steps: 4")
- print(f" 📈 Max training steps: 10")
- print(f" 🎯 Learning rate: 2e-4")
- print(f" 💾 Precision: {'BF16' if is_bf16_supported() else 'FP16'}")
-except Exception as e:
- print(f"❌ Failed to setup trainer: {e}")
- raise
-
-print("\n" + "=" * 80)
-print("=== STARTING TRAINING ===".center(80))
-print("=" * 80 + "\n")
-# run training
-try:
- print("🚀 Starting training process...")
- trainer_stats = trainer.train()
-except Exception as e:
- print(f"❌ Training failed: {e}")
- raise
-
-print("\n" + "=" * 80)
-print("=== SAVING MODEL ===".center(80))
-print("=" * 80 + "\n")
-
-print("💾 Saving adapter model and tokenizer locally...")
-try:
- model.save_pretrained("unsloth-qwen2-7vl-french-ocr-adapter", tokenizer)
- tokenizer.save_pretrained("unsloth-qwen2-7vl-french-ocr-adapter")
- print("✅ Model saved locally!")
-except Exception as e:
- print(f"❌ Failed to save model locally: {e}")
- raise
-
-
-hf_username = os.environ.get("HF_USER", "")
-if not hf_username:
- hf_username = input("Please enter your Hugging Face username: ").strip()
- os.environ["HF_USER"] = hf_username
-
-hf_token = os.environ.get("HF_TOKEN", "")
-if not hf_token:
- hf_token = input("Please enter your Hugging Face token: ").strip()
- os.environ["HF_TOKEN"] = hf_token
-
-repo_name = f"{hf_username}/qwen2-7b-ocr-merged"
-success = {
- "upload": False,
- "safetensors_check": False,
- "download": False,
-}
-# Stage 1: Upload model to Hub
-try:
- print("\n" + "=" * 80)
- print("=== UPLOADING MODEL TO HUB ===".center(80))
- print("=" * 80 + "\n")
- print(f"🚀 Uploading to repository: {repo_name}")
- model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)
- success["upload"] = True
- print("✅ Model uploaded successfully!")
-except Exception as e:
- print(f"❌ Failed to upload model: {e}")
- raise Exception("Model upload failed.")
-
-# Stage 2: Verify safetensors.index.json exists
-try:
- print("\n" + "=" * 80)
- print("=== VERIFYING REPO CONTENTS ===".center(80))
- print("=" * 80 + "\n")
- fs = HfFileSystem(token = hf_token)
- file_list = fs.ls(repo_name, detail = True)
- safetensors_found = any(
- file["name"].endswith("model.safetensors.index.json") for file in file_list
- )
- if safetensors_found:
- success["safetensors_check"] = True
- print("✅ model.safetensors.index.json found in repo!")
- else:
- raise Exception("model.safetensors.index.json not found in repo.")
-except Exception as e:
- print(f"❌ Verification failed: {e}")
- raise Exception("Repo verification failed.")
-
-# test downloading model even if cached
-safe_remove_directory(f"./{hf_username}")
-
-try:
- print("\n" + "=" * 80)
- print("=== TESTING MODEL DOWNLOAD ===".center(80))
- print("=" * 80 + "\n")
- print("📥 Testing model download...")
- # Force download even if cached
- test_model, test_tokenizer = FastVisionModel.from_pretrained(repo_name)
- success["download"] = True
- print("✅ Model downloaded successfully!")
-
- # Clean up test model
- del test_model, test_tokenizer
- torch.cuda.empty_cache()
-except Exception as e:
- print(f"❌ Download failed: {e}")
- raise Exception("Model download failed.")
-
-# Final report
-print("\n" + "=" * 80)
-print("=== VALIDATION REPORT ===".center(80))
-print("=" * 80 + "\n")
-for stage, passed in success.items():
- status = "✅" if passed else "❌"
- print(f"{status} {stage.replace('_', ' ').title()}")
-print("\n" + "=" * 80)
-
-if all(success.values()):
- print("\n🎉 All stages completed successfully!")
- print(f"🌐 Your model is available at: https://huggingface.co/{repo_name}")
-else:
- raise Exception("Validation failed for one or more stages.")
-
-
-# Final cleanup
-print("\n🧹 Cleaning up temporary files...")
-safe_remove_directory("./checkpoints")
-safe_remove_directory("./unsloth_compiled_cache")
-safe_remove_directory("./unsloth-qwen2-7vl-french-ocr-adapter")
-
-print("\n🎯 Pipeline completed successfully!")
-print("=" * 80)
diff --git a/tests/saving/vision_models/test_push_to_hub_merged.py b/tests/saving/vision_models/test_push_to_hub_merged.py
deleted file mode 100644
index 74fa058988..0000000000
--- a/tests/saving/vision_models/test_push_to_hub_merged.py
+++ /dev/null
@@ -1,273 +0,0 @@
-## Import required libraries
-
-from unsloth import FastVisionModel, is_bf16_supported
-from unsloth.trainer import UnslothVisionDataCollator
-
-import torch
-import os
-from datasets import load_dataset
-from trl import SFTTrainer, SFTConfig
-
-import sys
-from pathlib import Path
-
-
-REPO_ROOT = Path(__file__).parents[3]
-sys.path.insert(0, str(REPO_ROOT))
-
-
-from tests.utils.cleanup_utils import safe_remove_directory
-
-
-## Dataset Preparation"""
-
-print("\n📊 Loading and preparing dataset...")
-dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split = "train")
-# To select the first 2000 examples
-train_dataset = dataset.select(range(2000))
-
-# To select the next 200 examples for evaluation
-eval_dataset = dataset.select(range(2000, 2200))
-
-print(f"✅ Dataset loaded successfully!")
-print(f" 📈 Training samples: {len(train_dataset)}")
-print(f" 📊 Evaluation samples: {len(eval_dataset)}")
-
-
-# Convert dataset to OAI messages
-def format_data(sample):
- return {
- "messages": [
- {
- "role": "system",
- "content": [{"type": "text", "text": system_message}],
- },
- {
- "role": "user",
- "content": [
- {
- "type": "text",
- "text": sample["question"],
- },
- {
- "type": "image",
- "image": sample["image"],
- },
- ],
- },
- {
- "role": "assistant",
- "content": [{"type": "text", "text": sample["answer"]}],
- },
- ],
- }
-
-
-print("\n🔄 Formatting dataset for vision training...")
-system_message = "You are an expert french ocr system."
-# Convert dataset to OAI messages
-# need to use list comprehension to keep Pil.Image type, .mape convert image to bytes
-train_dataset = [format_data(sample) for sample in train_dataset]
-eval_dataset = [format_data(sample) for sample in eval_dataset]
-print("✅ Dataset formatting completed!")
-
-"""## Finetuning Setup and Run"""
-
-
-print("\n" + "=" * 80)
-print("=== MODEL LOADING AND SETUP ===".center(80))
-print("=" * 80 + "\n")
-# Load Base Model
-print("🤖 Loading base vision model...")
-try:
- model, tokenizer = FastVisionModel.from_pretrained(
- # model_name = "unsloth/Qwen2-VL-7B-Instruct",
- model_name = "unsloth/Qwen2-VL-2B-Instruct",
- max_seq_length = 2048, # Choose any for long context!
- load_in_4bit = True, # 4 bit quantization to reduce memory
- load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
- full_finetuning = False, # [NEW!] We have full finetuning now!
- )
-except Exception as e:
- print(f"❌ Failed to load base model: {e}")
- raise
-
-print("\n🔧 Setting up LoRA configuration...")
-## Lora Finetuning
-try:
- model = FastVisionModel.get_peft_model(
- model,
- finetune_vision_layers = True, # Turn off for just text!
- finetune_language_layers = True, # Should leave on!
- finetune_attention_modules = True, # Attention good for GRPO
- finetune_mlp_modules = True, # SHould leave on always!
- r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
- lora_alpha = 32,
- lora_dropout = 0, # Supports any, but = 0 is optimized
- bias = "none", # Supports any, but = "none" is optimized
- use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
- random_state = 3407,
- use_rslora = False, # We support rank stabilized LoRA
- loftq_config = None, # And LoftQ
- )
- print("✅ LoRA configuration applied successfully!")
- print(f" 🎯 LoRA rank (r): 16")
- print(f" 📊 LoRA alpha: 32")
- print(f" 🔍 Vision layers: Enabled")
- print(f" 💬 Language layers: Enabled")
-except Exception as e:
- print(f"❌ Failed to apply LoRA configuration: {e}")
- raise
-
-print("\n" + "=" * 80)
-print("=== TRAINING SETUP ===".center(80))
-print("=" * 80 + "\n")
-
-
-print("🏋️ Preparing trainer...")
-FastVisionModel.for_training(model) # Enable for training!
-
-try:
- trainer = SFTTrainer(
- model = model,
- tokenizer = tokenizer,
- data_collator = UnslothVisionDataCollator(model, tokenizer),
- train_dataset = train_dataset,
- args = SFTConfig(
- # per_device_train_batch_size = 4,
- # gradient_accumulation_steps = 8,
- per_device_train_batch_size = 2,
- gradient_accumulation_steps = 4,
- gradient_checkpointing = True,
- gradient_checkpointing_kwargs = {
- "use_reentrant": False
- }, # use reentrant checkpointing
- max_grad_norm = 0.3, # max gradient norm based on QLoRA paper
- warmup_ratio = 0.03,
- # num_train_epochs = 2, # Set this instead of max_steps for full training runs
- max_steps = 10,
- learning_rate = 2e-4,
- fp16 = not is_bf16_supported(),
- bf16 = is_bf16_supported(),
- logging_steps = 5,
- save_strategy = "epoch",
- optim = "adamw_torch_fused",
- weight_decay = 0.01,
- lr_scheduler_type = "linear",
- seed = 3407,
- output_dir = "checkpoints",
- report_to = "none", # For Weights and Biases
- # You MUST put the below items for vision finetuning:
- remove_unused_columns = False,
- dataset_text_field = "",
- dataset_kwargs = {"skip_prepare_dataset": True},
- dataset_num_proc = 4,
- max_seq_length = 2048,
- ),
- )
- print("✅ Trainer setup completed!")
- print(f" 📦 Batch size: 2")
- print(f" 🔄 Gradient accumulation steps: 4")
- print(f" 📈 Max training steps: 10")
- print(f" 🎯 Learning rate: 2e-4")
- print(f" 💾 Precision: {'BF16' if is_bf16_supported() else 'FP16'}")
-except Exception as e:
- print(f"❌ Failed to setup trainer: {e}")
- raise
-
-print("\n" + "=" * 80)
-print("=== STARTING TRAINING ===".center(80))
-print("=" * 80 + "\n")
-# run training
-try:
- print("🚀 Starting training process...")
- trainer_stats = trainer.train()
-except Exception as e:
- print(f"❌ Training failed: {e}")
- raise
-
-print("\n" + "=" * 80)
-print("=== SAVING MODEL ===".center(80))
-print("=" * 80 + "\n")
-
-print("💾 Saving adapter model and tokenizer locally...")
-try:
- model.save_pretrained("unsloth-qwen2-7vl-french-ocr-adapter", tokenizer)
- tokenizer.save_pretrained("unsloth-qwen2-7vl-french-ocr-adapter")
- print("✅ Model saved locally!")
-except Exception as e:
- print(f"❌ Failed to save model locally: {e}")
- raise
-
-
-hf_username = os.environ.get("HF_USER", "")
-if not hf_username:
- hf_username = input("Please enter your Hugging Face username: ").strip()
- os.environ["HF_USER"] = hf_username
-
-hf_token = os.environ.get("HF_TOKEN", "")
-if not hf_token:
- hf_token = input("Please enter your Hugging Face token: ").strip()
- os.environ["HF_TOKEN"] = hf_token
-
-repo_name = f"{hf_username}/qwen2-ocr-merged"
-success = {
- "upload": False,
- "download": False,
-}
-# Stage 1: Upload model to Hub
-try:
- print("\n" + "=" * 80)
- print("=== UPLOADING MODEL TO HUB ===".center(80))
- print("=" * 80 + "\n")
- print(f"🚀 Uploading to repository: {repo_name}")
- model.push_to_hub_merged(repo_name, tokenizer = tokenizer, token = hf_token)
- success["upload"] = True
- print("✅ Model uploaded successfully!")
-except Exception as e:
- print(f"❌ Failed to upload model: {e}")
- raise Exception("Model upload failed.")
-
-
-try:
- print("\n" + "=" * 80)
- print("=== TESTING MODEL DOWNLOAD ===".center(80))
- print("=" * 80 + "\n")
- print("📥 Testing model download...")
- # Force download even if cached
- test_model, test_tokenizer = FastVisionModel.from_pretrained(repo_name)
- success["download"] = True
- print("✅ Model downloaded successfully!")
-
- # Clean up test model
- del test_model, test_tokenizer
- torch.cuda.empty_cache()
-except Exception as e:
- print(f"❌ Download failed: {e}")
- raise Exception("Model download failed.")
-
-# Final report
-print("\n" + "=" * 80)
-print("=== VALIDATION REPORT ===".center(80))
-print("=" * 80 + "\n")
-for stage, passed in success.items():
- status = "✅" if passed else "❌"
- print(f"{status} {stage.replace('_', ' ').title()}")
-print("\n" + "=" * 80)
-
-if all(success.values()):
- print("\n🎉 All stages completed successfully!")
- print(f"🌐 Your model is available at: https://huggingface.co/{repo_name}")
-else:
- raise Exception("Validation failed for one or more stages.")
-
-
-# Final cleanup
-print("\n🧹 Cleaning up temporary files...")
-safe_remove_directory("./checkpoints")
-safe_remove_directory("./unsloth_compiled_cache")
-safe_remove_directory("./unsloth-qwen2-7vl-french-ocr-adapter")
-safe_remove_directory(f"./{hf_username}")
-
-print("\n🎯 Pipeline completed successfully!")
-print("=" * 80)
diff --git a/tests/saving/vision_models/test_save_merge_qwen2.5vl32B_model_ocr_benchmark.py b/tests/saving/vision_models/test_save_merge_qwen2.5vl32B_model_ocr_benchmark.py
deleted file mode 100644
index ebe078c73b..0000000000
--- a/tests/saving/vision_models/test_save_merge_qwen2.5vl32B_model_ocr_benchmark.py
+++ /dev/null
@@ -1,287 +0,0 @@
-# -*- coding: utf-8 -*-
-
-from unsloth import FastVisionModel
-
-import torch
-from qwen_vl_utils import process_vision_info
-import os
-from datasets import load_dataset
-from trl import SFTTrainer, SFTConfig
-
-import sys
-from pathlib import Path
-
-
-REPO_ROOT = Path(__file__).parents[3]
-sys.path.insert(0, str(REPO_ROOT))
-
-from tests.utils.cleanup_utils import safe_remove_directory
-from tests.utils.ocr_eval import OCRModelEvaluator
-
-
-## Dataset Preparation
-from datasets import load_dataset
-
-dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split = "train")
-# To select the first 2000 examples
-train_dataset = dataset.select(range(2000))
-
-# To select the next 200 examples for evaluation
-eval_dataset = dataset.select(range(2000, 2200))
-
-
-# Convert dataset to OAI messages
-def format_data(sample):
- return {
- "messages": [
- {
- "role": "system",
- "content": [{"type": "text", "text": system_message}],
- },
- {
- "role": "user",
- "content": [
- {
- "type": "text",
- "text": sample["question"],
- },
- {
- "type": "image",
- "image": sample["image"],
- },
- ],
- },
- {
- "role": "assistant",
- "content": [{"type": "text", "text": sample["answer"]}],
- },
- ],
- }
-
-
-system_message = "You are an expert french ocr system."
-# Convert dataset to OAI messages
-# need to use list comprehension to keep Pil.Image type, .mape convert image to bytes
-train_dataset = [format_data(sample) for sample in train_dataset]
-eval_dataset = [format_data(sample) for sample in eval_dataset]
-
-## Setup OCR main evaluation function and helpers
-import os
-import torch
-from tqdm import tqdm
-import pandas as pd
-from jiwer import wer, cer
-from qwen_vl_utils import process_vision_info
-
-#
-ocr_evaluator = OCRModelEvaluator()
-model_comparison_results = {}
-
-## Finetuning Setup and Run
-# Load Base Model
-
-model, tokenizer = FastVisionModel.from_pretrained(
- model_name = "unsloth/Qwen2.5-VL-32B-Instruct-bnb-4bit",
- max_seq_length = 2048, # Choose any for long context!
- load_in_4bit = True, # 4 bit quantization to reduce memory
- load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
- full_finetuning = False, # [NEW!] We have full finetuning now!
-)
-
-# benchmark base model performance
-model_name = "Unsloth Base model"
-FastVisionModel.for_inference(model)
-avg_wer, avg_cer = ocr_evaluator.evaluate_model(
- model, tokenizer, eval_dataset, output_dir = "unsloth_base_model_results"
-)
-ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
-
-## Lora Finetuning
-model = FastVisionModel.get_peft_model(
- model,
- finetune_vision_layers = True, # Turn off for just text!
- finetune_language_layers = True, # Should leave on!
- finetune_attention_modules = True, # Attention good for GRPO
- finetune_mlp_modules = True, # SHould leave on always!
- r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
- # target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
- # "gate_proj", "up_proj", "down_proj",],
- lora_alpha = 32,
- lora_dropout = 0, # Supports any, but = 0 is optimized
- bias = "none", # Supports any, but = "none" is optimized
- # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
- use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
- random_state = 3407,
- use_rslora = False, # We support rank stabilized LoRA
- loftq_config = None, # And LoftQ
-)
-
-from unsloth import is_bf16_supported
-from unsloth.trainer import UnslothVisionDataCollator
-
-FastVisionModel.for_training(model) # Enable for training!
-model.config.use_cache = False
-
-
-trainer = SFTTrainer(
- model = model,
- tokenizer = tokenizer,
- data_collator = UnslothVisionDataCollator(model, tokenizer),
- train_dataset = train_dataset,
- args = SFTConfig(
- # per_device_train_batch_size = 4,
- # gradient_accumulation_steps = 8,
- per_device_train_batch_size = 2,
- gradient_accumulation_steps = 4,
- gradient_checkpointing = True,
- gradient_checkpointing_kwargs = {
- "use_reentrant": False
- }, # use reentrant checkpointing
- max_grad_norm = 0.3, # max gradient norm based on QLoRA paper
- warmup_ratio = 0.03,
- # num_train_epochs = 2, # Set this instead of max_steps for full training runs
- max_steps = 60,
- learning_rate = 2e-4,
- fp16 = not is_bf16_supported(),
- bf16 = is_bf16_supported(),
- logging_steps = 5,
- save_strategy = "epoch",
- optim = "adamw_torch_fused",
- weight_decay = 0.01,
- lr_scheduler_type = "linear",
- seed = 3407,
- output_dir = "unsloth-qwen2.5-vl-32b-french-ocr-checkpoints",
- report_to = "none", # For Weights and Biases
- # You MUST put the below items for vision finetuning:
- remove_unused_columns = False,
- dataset_text_field = "",
- dataset_kwargs = {"skip_prepare_dataset": True},
- dataset_num_proc = 4,
- max_seq_length = 2048,
- ),
-)
-
-# run training
-trainer_stats = trainer.train()
-
-model.save_pretrained("unsloth-qwen2.5-vl-32b-french-ocr-adapter", tokenizer)
-tokenizer.save_pretrained("unsloth-qwen2.5-vl-32b-french-ocr-adapter")
-
-## Measure Adapter Performance
-
-# benchmark lora model performance
-model_name = "Unsloth lora adapter model"
-FastVisionModel.for_inference(model)
-avg_wer, avg_cer = ocr_evaluator.evaluate_model(
- model, tokenizer, eval_dataset, output_dir = "unsloth_lora_model_results"
-)
-ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
-
-## Merge Model
-
-
-def find_lora_base_model(model_to_inspect):
- current = model_to_inspect
- if hasattr(current, "base_model"):
- current = current.base_model
- if hasattr(current, "model"):
- current = current.model
- return current
-
-
-base = find_lora_base_model(model)
-
-print((base.__class__.__name__))
-
-# merge default 16 bits
-model.save_pretrained_merged(
- save_directory = "qwen2.5-ocr-merged-finetune-merge-16bit", tokenizer = tokenizer
-)
-
-
-## Benchmark merged model performance
-
-### 16 bits merged model
-
-model, tokenizer = FastVisionModel.from_pretrained(
- "./qwen2.5-ocr-merged-finetune-merge-16bit", load_in_4bit = False, load_in_8bit = False
-)
-
-# benchmark 4bit loaded, 16bits merged model performance
-model_name = "Unsloth 16bits-merged model load-16bits"
-model.config.use_cache = True
-
-avg_wer, avg_cer = ocr_evaluator.evaluate_model(
- model,
- tokenizer,
- eval_dataset,
- output_dir = "unsloth_16bits_merged_model_load_16bits_results",
-)
-ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
-
-# load 16bits-merged model in 4 bits
-model, tokenizer = FastVisionModel.from_pretrained(
- "./qwen2.5-ocr-merged-finetune-merge-16bit", load_in_4bit = True, load_in_8bit = False
-)
-
-# benchmark 4bit loaded, 16bits merged model performance
-model_name = "Unsloth 16bits-merged model load-4bits"
-model.config.use_cache = True
-
-avg_wer, avg_cer = ocr_evaluator.evaluate_model(
- model,
- tokenizer,
- eval_dataset,
- output_dir = "unsloth_16bits_merged_model_load_4bits_results",
-)
-ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
-
-# load model in 8 bits
-model, tokenizer = FastVisionModel.from_pretrained(
- "./qwen2.5-ocr-merged-finetune-merge-16bit", load_in_4bit = False, load_in_8bit = True
-)
-
-# benchmark 4bit loaded, 16bits merged model performance
-model_name = "Unsloth 16bits-merged model load-8bits"
-avg_wer, avg_cer = ocr_evaluator.evaluate_model(
- model,
- tokenizer,
- eval_dataset,
- output_dir = "unsloth_16bits_merged_model_load_8bits_results",
-)
-ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
-
-# """### 4 bits merged model"""
-#
-# # load 4bits-merged model in 4 bits
-# model, tokenizer = FastVisionModel.from_pretrained("./qwen2-ocr-merged-finetune-merge-4bit",load_in_4bit=True, load_in_8bit=False)
-#
-# # benchmark 4bit loaded, 4bits merged model performance
-# model_name = "Unsloth 4bits-merged model load-4bits"
-#
-# avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_4bits_merged_model_load_4bits_results")
-# ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
-#
-# # load model in 8 bits
-# model, tokenizer = FastVisionModel.from_pretrained("./qwen2-ocr-merged-finetune-merge-4bit",load_in_4bit=False, load_in_8bit=True)
-#
-# # benchmark 8bit loaded, 4bits merged model performance
-# model_name = "Unsloth 4bits-merged model load-8bits"
-#
-# avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_4bits_merged_model_load_8bits_results")
-# ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
-
-# Model comparison report
-# print model comparison
-ocr_evaluator.print_model_comparison()
-
-
-# Final cleanup
-print("\n🧹 Cleaning up temporary files...")
-safe_remove_directory("./unsloth-qwen2.5-vl-32b-french-ocr-adapter")
-safe_remove_directory("./unsloth-qwen2.5-vl-32b-french-ocr-checkpoints")
-safe_remove_directory("./unsloth_compiled_cache")
-safe_remove_directory("./qwen2.5-ocr-merged-finetune-merge-16bit")
-
-print("\n🎯 Pipeline completed successfully!")
-print("=" * 80)
diff --git a/tests/saving/vision_models/test_save_merge_vision_model_ocr_benchmark.py b/tests/saving/vision_models/test_save_merge_vision_model_ocr_benchmark.py
deleted file mode 100644
index b99785bcb1..0000000000
--- a/tests/saving/vision_models/test_save_merge_vision_model_ocr_benchmark.py
+++ /dev/null
@@ -1,287 +0,0 @@
-# -*- coding: utf-8 -*-
-
-from unsloth import FastVisionModel
-
-import torch
-from qwen_vl_utils import process_vision_info
-import os
-from datasets import load_dataset
-from trl import SFTTrainer, SFTConfig
-
-import sys
-from pathlib import Path
-
-
-REPO_ROOT = Path(__file__).parents[3]
-sys.path.insert(0, str(REPO_ROOT))
-
-from tests.utils.cleanup_utils import safe_remove_directory
-from tests.utils.ocr_eval import OCRModelEvaluator
-
-
-## Dataset Preparation
-from datasets import load_dataset
-
-dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", "en", split = "train")
-# To select the first 2000 examples
-train_dataset = dataset.select(range(2000))
-
-# To select the next 200 examples for evaluation
-eval_dataset = dataset.select(range(2000, 2200))
-
-
-# Convert dataset to OAI messages
-def format_data(sample):
- return {
- "messages": [
- {
- "role": "system",
- "content": [{"type": "text", "text": system_message}],
- },
- {
- "role": "user",
- "content": [
- {
- "type": "text",
- "text": sample["question"],
- },
- {
- "type": "image",
- "image": sample["image"],
- },
- ],
- },
- {
- "role": "assistant",
- "content": [{"type": "text", "text": sample["answer"]}],
- },
- ],
- }
-
-
-system_message = "You are an expert french ocr system."
-# Convert dataset to OAI messages
-# need to use list comprehension to keep Pil.Image type, .mape convert image to bytes
-train_dataset = [format_data(sample) for sample in train_dataset]
-eval_dataset = [format_data(sample) for sample in eval_dataset]
-
-## Setup OCR main evaluation function and helpers
-import os
-import torch
-from tqdm import tqdm
-import pandas as pd
-from jiwer import wer, cer
-from qwen_vl_utils import process_vision_info
-
-#
-ocr_evaluator = OCRModelEvaluator()
-model_comparison_results = {}
-
-## Finetuning Setup and Run
-# Load Base Model
-
-model, tokenizer = FastVisionModel.from_pretrained(
- model_name = "unsloth/Qwen2-VL-7B-Instruct",
- max_seq_length = 2048, # Choose any for long context!
- load_in_4bit = True, # 4 bit quantization to reduce memory
- load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
- full_finetuning = False, # [NEW!] We have full finetuning now!
-)
-
-# benchmark base model performance
-model_name = "Unsloth Base model"
-FastVisionModel.for_inference(model)
-avg_wer, avg_cer = ocr_evaluator.evaluate_model(
- model, tokenizer, eval_dataset, output_dir = "unsloth_base_model_results"
-)
-ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
-
-## Lora Finetuning
-model = FastVisionModel.get_peft_model(
- model,
- finetune_vision_layers = True, # Turn off for just text!
- finetune_language_layers = True, # Should leave on!
- finetune_attention_modules = True, # Attention good for GRPO
- finetune_mlp_modules = True, # SHould leave on always!
- r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
- # target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
- # "gate_proj", "up_proj", "down_proj",],
- lora_alpha = 32,
- lora_dropout = 0, # Supports any, but = 0 is optimized
- bias = "none", # Supports any, but = "none" is optimized
- # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
- use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
- random_state = 3407,
- use_rslora = False, # We support rank stabilized LoRA
- loftq_config = None, # And LoftQ
-)
-
-from unsloth import is_bf16_supported
-from unsloth.trainer import UnslothVisionDataCollator
-
-FastVisionModel.for_training(model) # Enable for training!
-model.config.use_cache = False
-
-
-trainer = SFTTrainer(
- model = model,
- tokenizer = tokenizer,
- data_collator = UnslothVisionDataCollator(model, tokenizer),
- train_dataset = train_dataset,
- args = SFTConfig(
- # per_device_train_batch_size = 4,
- # gradient_accumulation_steps = 8,
- per_device_train_batch_size = 2,
- gradient_accumulation_steps = 4,
- gradient_checkpointing = True,
- gradient_checkpointing_kwargs = {
- "use_reentrant": False
- }, # use reentrant checkpointing
- max_grad_norm = 0.3, # max gradient norm based on QLoRA paper
- warmup_ratio = 0.03,
- # num_train_epochs = 2, # Set this instead of max_steps for full training runs
- max_steps = 60,
- learning_rate = 2e-4,
- fp16 = not is_bf16_supported(),
- bf16 = is_bf16_supported(),
- logging_steps = 5,
- save_strategy = "epoch",
- optim = "adamw_torch_fused",
- weight_decay = 0.01,
- lr_scheduler_type = "linear",
- seed = 3407,
- output_dir = "unsloth-qwen2-7vl-french-ocr-checkpoints",
- report_to = "none", # For Weights and Biases
- # You MUST put the below items for vision finetuning:
- remove_unused_columns = False,
- dataset_text_field = "",
- dataset_kwargs = {"skip_prepare_dataset": True},
- dataset_num_proc = 4,
- max_seq_length = 2048,
- ),
-)
-
-# run training
-trainer_stats = trainer.train()
-
-model.save_pretrained("unsloth-qwen2-7vl-french-ocr-adapter", tokenizer)
-tokenizer.save_pretrained("unsloth-qwen2-7vl-french-ocr-adapter")
-
-## Measure Adapter Performance
-
-# benchmark lora model performance
-model_name = "Unsloth lora adapter model"
-FastVisionModel.for_inference(model)
-avg_wer, avg_cer = ocr_evaluator.evaluate_model(
- model, tokenizer, eval_dataset, output_dir = "unsloth_lora_model_results"
-)
-ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
-
-## Merge Model
-
-
-def find_lora_base_model(model_to_inspect):
- current = model_to_inspect
- if hasattr(current, "base_model"):
- current = current.base_model
- if hasattr(current, "model"):
- current = current.model
- return current
-
-
-base = find_lora_base_model(model)
-
-print((base.__class__.__name__))
-
-# merge default 16 bits
-model.save_pretrained_merged(
- save_directory = "qwen2-ocr-merged-finetune-merge-16bit", tokenizer = tokenizer
-)
-
-
-## Benchmark merged model performance
-
-### 16 bits merged model
-
-model, tokenizer = FastVisionModel.from_pretrained(
- "./qwen2-ocr-merged-finetune-merge-16bit", load_in_4bit = False, load_in_8bit = False
-)
-
-# benchmark 4bit loaded, 16bits merged model performance
-model_name = "Unsloth 16bits-merged model load-16bits"
-model.config.use_cache = True
-
-avg_wer, avg_cer = ocr_evaluator.evaluate_model(
- model,
- tokenizer,
- eval_dataset,
- output_dir = "unsloth_16bits_merged_model_load_16bits_results",
-)
-ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
-
-# load 16bits-merged model in 4 bits
-model, tokenizer = FastVisionModel.from_pretrained(
- "./qwen2-ocr-merged-finetune-merge-16bit", load_in_4bit = True, load_in_8bit = False
-)
-
-# benchmark 4bit loaded, 16bits merged model performance
-model_name = "Unsloth 16bits-merged model load-4bits"
-model.config.use_cache = True
-
-avg_wer, avg_cer = ocr_evaluator.evaluate_model(
- model,
- tokenizer,
- eval_dataset,
- output_dir = "unsloth_16bits_merged_model_load_4bits_results",
-)
-ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
-
-# load model in 8 bits
-model, tokenizer = FastVisionModel.from_pretrained(
- "./qwen2-ocr-merged-finetune-merge-16bit", load_in_4bit = False, load_in_8bit = True
-)
-
-# benchmark 4bit loaded, 16bits merged model performance
-model_name = "Unsloth 16bits-merged model load-8bits"
-avg_wer, avg_cer = ocr_evaluator.evaluate_model(
- model,
- tokenizer,
- eval_dataset,
- output_dir = "unsloth_16bits_merged_model_load_8bits_results",
-)
-ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
-
-# """### 4 bits merged model"""
-#
-# # load 4bits-merged model in 4 bits
-# model, tokenizer = FastVisionModel.from_pretrained("./qwen2-ocr-merged-finetune-merge-4bit",load_in_4bit=True, load_in_8bit=False)
-#
-# # benchmark 4bit loaded, 4bits merged model performance
-# model_name = "Unsloth 4bits-merged model load-4bits"
-#
-# avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_4bits_merged_model_load_4bits_results")
-# ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
-#
-# # load model in 8 bits
-# model, tokenizer = FastVisionModel.from_pretrained("./qwen2-ocr-merged-finetune-merge-4bit",load_in_4bit=False, load_in_8bit=True)
-#
-# # benchmark 8bit loaded, 4bits merged model performance
-# model_name = "Unsloth 4bits-merged model load-8bits"
-#
-# avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_4bits_merged_model_load_8bits_results")
-# ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer)
-
-# Model comparison report
-# print model comparison
-ocr_evaluator.print_model_comparison()
-
-
-# Final cleanup
-print("\n🧹 Cleaning up temporary files...")
-safe_remove_directory("./unsloth-qwen2-7vl-french-ocr-adapter")
-safe_remove_directory("./unsloth-qwen2-7vl-french-ocr-checkpoints")
-safe_remove_directory("./unsloth_compiled_cache")
-safe_remove_directory("./qwen2-ocr-merged-finetune-merge-16bit")
-
-print("\n🎯 Pipeline completed successfully!")
-print("=" * 80)
diff --git a/tests/test_get_model_name.py b/tests/test_get_model_name.py
deleted file mode 100644
index ad89f595f0..0000000000
--- a/tests/test_get_model_name.py
+++ /dev/null
@@ -1,127 +0,0 @@
-import unittest
-from unittest.mock import patch
-from unsloth.models.loader_utils import get_model_name
-from unsloth.models import loader_utils
-from unsloth.models.mapper import FLOAT_TO_INT_MAPPER, MAP_TO_UNSLOTH_16bit
-
-
-def _no_remote_mapper():
- return {}, {}, {}
-
-
-class TestGetModelName(unittest.TestCase):
- def _assert_mapping(self, model_name, load_in_4bit, expected, should_change):
- mapped = get_model_name(model_name, load_in_4bit = load_in_4bit)
- self.assertEqual(mapped.lower(), expected.lower())
- if should_change:
- self.assertNotEqual(mapped.lower(), model_name.lower())
- else:
- self.assertEqual(mapped.lower(), model_name.lower())
-
- @patch.object(loader_utils, "_get_new_mapper", _no_remote_mapper)
- def test_resolution_matrix(self):
- cases = [
- # Core mappings
- ("meta-llama/Llama-2-7b-hf", True, "unsloth/llama-2-7b-bnb-4bit", True),
- ("meta-llama/Llama-2-7b-hf", False, "unsloth/llama-2-7b", True),
- (
- "mistralai/Ministral-8B-Instruct-2410",
- True,
- "mistralai/Ministral-8B-Instruct-2410",
- False,
- ),
- (
- "meta-llama/Llama-3.2-1B-Instruct",
- False,
- "unsloth/Llama-3.2-1B-Instruct",
- True,
- ),
- (
- "meta-llama/Llama-2-7b-chat-hf",
- True,
- "unsloth/llama-2-7b-chat-bnb-4bit",
- True,
- ),
- (
- "meta-llama/Llama-3.3-70B-Instruct",
- True,
- "unsloth/llama-3.3-70b-instruct-unsloth-bnb-4bit",
- True,
- ),
- ("Qwen/Qwen3-8B", True, "unsloth/Qwen3-8B-unsloth-bnb-4bit", True),
- ("Qwen/Qwen3-8B", False, "unsloth/Qwen3-8B", True),
- ("Qwen/Qwen3-8B-FP8", False, "unsloth/Qwen3-8B-FP8", True),
- ("Qwen/Qwen3-8B-FP8", True, "unsloth/Qwen3-8B-unsloth-bnb-4bit", True),
- (
- "mistralai/Ministral-3-3B-Instruct-2512",
- True,
- "unsloth/Ministral-3-3B-Instruct-2512-unsloth-bnb-4bit",
- True,
- ),
- (
- "mistralai/Ministral-3-3B-Instruct-2512",
- False,
- "unsloth/Ministral-3-3B-Instruct-2512",
- True,
- ),
- ("unsloth/Kimi-K2-Instruct", True, "unsloth/Kimi-K2-Instruct-BF16", True),
- ("unsloth/Kimi-K2-Instruct", False, "unsloth/Kimi-K2-Instruct", False),
- # Fallback-to-original behavior
- "nonexistent-user/nonexistent-model-123",
- "google/gemma-3-random-prototype-123",
- "imdatta0/nanoqwen-fp8",
- "imdatta0/nanoqwen-bf16",
- # Backward compatibility for legacy 4bit names
- ("unsloth/llama-2-7b-bnb-4bit", True, "unsloth/llama-2-7b-bnb-4bit", False),
- ("unsloth/llama-2-7b-bnb-4bit", False, "unsloth/llama-2-7b", True),
- ("google/gemma-2-9b", True, "unsloth/gemma-2-9b-bnb-4bit", True),
- # GPT-OSS behavior
- ("openai/gpt-oss-20b", False, "unsloth/gpt-oss-20b", True),
- ("openai/gpt-oss-20b", True, "unsloth/gpt-oss-20b-unsloth-bnb-4bit", True),
- ("unsloth/gpt-oss-20b", True, "unsloth/gpt-oss-20b-unsloth-bnb-4bit", True),
- ("unsloth/gpt-oss-20b-bf16", True, "unsloth/gpt-oss-20b-bf16", False),
- (
- "unsloth/gpt-oss-20b-unsloth-bnb-4bit",
- False,
- "unsloth/gpt-oss-20b",
- True,
- ),
- (
- "unsloth/gpt-oss-20b-bnb-4bit",
- True,
- "unsloth/gpt-oss-20b-bnb-4bit",
- False,
- ),
- ]
- for case in cases:
- if isinstance(case, str):
- model_name = case
- with self.subTest(model_name = model_name, load_in_4bit = True):
- self._assert_mapping(model_name, True, model_name, False)
- else:
- model_name, load_in_4bit, expected, should_change = case
- with self.subTest(model_name = model_name, load_in_4bit = load_in_4bit):
- self._assert_mapping(
- model_name, load_in_4bit, expected, should_change
- )
-
- def test_static_mapper_contract(self):
- contracts = [
- ("qwen/qwen3-8b", "unsloth/qwen3-8b-unsloth-bnb-4bit"),
- ("qwen/qwen3-8b-fp8", "unsloth/qwen3-8b-unsloth-bnb-4bit"),
- (
- "mistralai/ministral-3-3b-instruct-2512",
- "unsloth/ministral-3-3b-instruct-2512-unsloth-bnb-4bit",
- ),
- ("unsloth/kimi-k2-instruct", "unsloth/kimi-k2-instruct-bf16"),
- ]
- for src, expected in contracts:
- with self.subTest(src = src):
- self.assertEqual(FLOAT_TO_INT_MAPPER[src], expected)
- self.assertEqual(
- MAP_TO_UNSLOTH_16bit["qwen/qwen3-8b-fp8"], "unsloth/Qwen3-8B-FP8"
- )
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_model_registry.py b/tests/test_model_registry.py
deleted file mode 100644
index fb9a734c6f..0000000000
--- a/tests/test_model_registry.py
+++ /dev/null
@@ -1,92 +0,0 @@
-"""
-
-Test model registration methods
-Checks that model registration methods work for respective models as well as all models
-The check is performed
-- by registering the models
-- checking that the instantiated models can be found on huggingface hub by querying for the model id
-
-"""
-
-from dataclasses import dataclass
-
-import pytest
-from huggingface_hub import ModelInfo as HfModelInfo
-
-from unsloth.registry import register_models, search_models
-from unsloth.registry._deepseek import register_deepseek_models
-from unsloth.registry._gemma import register_gemma_models
-from unsloth.registry._llama import register_llama_models
-from unsloth.registry._mistral import register_mistral_models
-from unsloth.registry._phi import register_phi_models
-from unsloth.registry._qwen import register_qwen_models
-from unsloth.registry.registry import MODEL_REGISTRY, QUANT_TAG_MAP, QuantType
-from unsloth.utils.hf_hub import get_model_info
-
-MODEL_NAMES = [
- "llama",
- "qwen",
- "mistral",
- "phi",
- "gemma",
- "deepseek",
-]
-MODEL_REGISTRATION_METHODS = [
- register_llama_models,
- register_qwen_models,
- register_mistral_models,
- register_phi_models,
- register_gemma_models,
- register_deepseek_models,
-]
-
-
-@dataclass
-class ModelTestParam:
- name: str
- register_models: callable
-
-
-def _test_model_uploaded(model_ids: list[str]):
- missing_models = []
- for _id in model_ids:
- model_info: HfModelInfo = get_model_info(_id)
- if not model_info:
- missing_models.append(_id)
-
- return missing_models
-
-
-TestParams = [
- ModelTestParam(name, models)
- for name, models in zip(MODEL_NAMES, MODEL_REGISTRATION_METHODS)
-]
-
-
-# Test that model registration methods register respective models
-@pytest.mark.parametrize("model_test_param", TestParams, ids = lambda param: param.name)
-def test_model_registration(model_test_param: ModelTestParam):
- MODEL_REGISTRY.clear()
- registration_method = model_test_param.register_models
- registration_method()
- registered_models = MODEL_REGISTRY.keys()
- missing_models = _test_model_uploaded(registered_models)
- assert (
- not missing_models
- ), f"{model_test_param.name} missing following models: {missing_models}"
-
-
-def test_all_model_registration():
- register_models()
- registered_models = MODEL_REGISTRY.keys()
- missing_models = _test_model_uploaded(registered_models)
- assert not missing_models, f"Missing following models: {missing_models}"
-
-
-def test_quant_type():
- # Test that the quant_type is correctly set for model paths
- # NOTE: for models registered under org="unsloth" with QuantType.NONE aliases QuantType.UNSLOTH
- dynamic_quant_models = search_models(quant_types = [QuantType.UNSLOTH])
- assert all(m.quant_type == QuantType.UNSLOTH for m in dynamic_quant_models)
- quant_tag = QUANT_TAG_MAP[QuantType.UNSLOTH]
- assert all(quant_tag in m.model_path for m in dynamic_quant_models)
diff --git a/tests/test_raw_text.py b/tests/test_raw_text.py
deleted file mode 100644
index 9f2e8cda4e..0000000000
--- a/tests/test_raw_text.py
+++ /dev/null
@@ -1,172 +0,0 @@
-#!/usr/bin/env python3
-"""
-Minimal test for raw text training implementation.
-Tests basic functionality without heavy dependencies.
-"""
-
-import sys
-import os
-import tempfile
-from pathlib import Path
-import importlib.util
-
-
-# Mock the datasets module since it's not installed
-class MockDataset:
- def __init__(self, data_dict):
- self.data = data_dict
- self.column_names = list(data_dict.keys())
-
- def __len__(self):
- return len(next(iter(self.data.values())))
-
- def __getitem__(self, idx):
- if isinstance(idx, str):
- # Allow accessing columns by name like dataset['text']
- return self.data[idx]
- elif isinstance(idx, int):
- # Allow accessing individual rows by index
- return {key: values[idx] for key, values in self.data.items()}
- else:
- raise TypeError(f"Invalid index type: {type(idx)}")
-
- @classmethod
- def from_dict(cls, data_dict):
- return cls(data_dict)
-
-
-# Mock datasets module
-datasets_mock = type(sys)("datasets")
-datasets_mock.Dataset = MockDataset
-sys.modules["datasets"] = datasets_mock
-
-# Import the raw_text module directly to avoid unsloth/__init__.py dependencies
-current_dir = os.path.dirname(__file__)
-raw_text_path = os.path.join(
- os.path.dirname(current_dir), "unsloth", "dataprep", "raw_text.py"
-)
-
-spec = importlib.util.spec_from_file_location("raw_text", raw_text_path)
-raw_text_module = importlib.util.module_from_spec(spec)
-spec.loader.exec_module(raw_text_module)
-
-RawTextDataLoader = raw_text_module.RawTextDataLoader
-TextPreprocessor = raw_text_module.TextPreprocessor
-
-
-def test_raw_text_loader():
- """Test basic RawTextDataLoader functionality."""
-
- # Mock tokenizer for testing
- class MockTokenizer:
- def __init__(self):
- self.eos_token = ""
- self.eos_token_id = 2 # Mock EOS token ID
-
- def __call__(self, text, return_tensors = None, add_special_tokens = False):
- words = text.split()
- token_ids = list(range(len(words)))
-
- if return_tensors == "pt":
- # Mock tensor-like object
- class MockTensor:
- def __init__(self, data):
- self.data = data
-
- def __getitem__(self, idx):
- return self.data
-
- def __len__(self):
- return len(self.data)
-
- def tolist(self):
- return self.data
-
- return {"input_ids": [MockTensor(token_ids)]}
- return {"input_ids": token_ids}
-
- def decode(self, token_ids, skip_special_tokens = False):
- return " ".join([f"word_{i}" for i in token_ids])
-
- # Create test file
- test_content = "This is a test file for raw text training. " * 10
- with tempfile.NamedTemporaryFile(mode = "w", suffix = ".txt", delete = False) as f:
- f.write(test_content)
- test_file = f.name
-
- try:
- # Test loader
- tokenizer = MockTokenizer()
- loader = RawTextDataLoader(tokenizer, chunk_size = 5, stride = 2)
-
- # Test loading with text output (legacy mode)
- text_dataset = loader.load_from_file(test_file, return_tokenized = False)
- assert len(text_dataset) > 0, "Should create at least one chunk"
- assert "text" in text_dataset.column_names, "Dataset should have 'text' column"
-
- # Test loading with tokenized output (new efficient mode)
- tokenized_dataset = loader.load_from_file(test_file, return_tokenized = True)
- assert len(tokenized_dataset) > 0, "Should create at least one tokenized chunk"
- assert (
- "input_ids" in tokenized_dataset.column_names
- ), "Dataset should have 'input_ids' column"
- assert (
- "attention_mask" in tokenized_dataset.column_names
- ), "Dataset should have 'attention_mask' column"
-
- # Verify tokenized data structure
- first_sample = tokenized_dataset[0]
- assert isinstance(first_sample["input_ids"], list), "input_ids should be a list"
- assert isinstance(
- first_sample["attention_mask"], list
- ), "attention_mask should be a list"
- assert len(first_sample["input_ids"]) == len(
- first_sample["attention_mask"]
- ), "input_ids and attention_mask should have same length"
-
- # Verify labels field exists (for causal LM training)
- assert (
- "labels" in tokenized_dataset.column_names
- ), "Dataset should have 'labels' column"
- assert (
- first_sample["labels"] == first_sample["input_ids"]
- ), "labels should match input_ids"
-
- # Test constructor validation
- try:
- bad_loader = RawTextDataLoader(tokenizer, chunk_size = 0, stride = 2)
- assert False, "Should raise ValueError for chunk_size=0"
- except ValueError as e:
- assert "chunk_size must be positive" in str(e)
-
- try:
- bad_loader = RawTextDataLoader(tokenizer, chunk_size = 5, stride = 10)
- assert False, "Should raise ValueError for stride >= chunk_size"
- except ValueError as e:
- assert "stride" in str(e) and "chunk_size" in str(e)
-
- # Test preprocessor
- preprocessor = TextPreprocessor()
- clean_text = preprocessor.clean_text(" messy text \n\n\n ")
- assert "messy text" in clean_text, "Should clean text properly"
-
- # Test validation
- stats = preprocessor.validate_dataset(text_dataset)
- assert stats["total_samples"] > 0, "Should count samples"
- assert "warnings" in stats, "Should include warnings"
-
- print("✅ All tests passed!")
- return True
-
- except Exception as e:
- print(f"❌ Test failed: {e}")
- return False
-
- finally:
- # Cleanup
- os.unlink(test_file)
-
-
-if __name__ == "__main__":
- success = test_raw_text_loader()
- sys.exit(0 if success else 1)
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
deleted file mode 100644
index 3ad7a4f0f8..0000000000
--- a/tests/utils/__init__.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import time
-from contextlib import contextmanager
-
-
-@contextmanager
-def timer(name):
- start = time.time()
- yield
- end = time.time()
- print(f"{name} took {end - start:.2f} seconds")
-
-
-@contextmanager
-def header_footer_context(title: str, char = "-"):
- print()
- print(f"{char}" * 50 + f" {title} " + f"{char}" * 50)
- yield
- print(f"{char}" * (100 + len(title) + 2))
- print()
diff --git a/tests/utils/aime_eval.md b/tests/utils/aime_eval.md
deleted file mode 100644
index 217e939094..0000000000
--- a/tests/utils/aime_eval.md
+++ /dev/null
@@ -1,264 +0,0 @@
-# AIME Dataset Evaluator
-
-A Python module for evaluating language models on the AIME (American Invitational Mathematics Examination) dataset. This evaluator automatically downloads and combines multiple AIME test datasets and provides comprehensive mathematical reasoning assessment.
-
-
-## Basic Usage
-
-```python
-from aime_utils import evaluate_model_aime
-
-# Simple AIME evaluation
-results = evaluate_model_aime(
- model=your_model,
- tokenizer=your_tokenizer,
- model_type="base_model",
- temperature=0.3,
- n_sampling=8,
- max_tokens=32768
-)
-
-print(f"AIME Accuracy: {results['accuracy']:.1f}%")
-print(f"Pass@8: {results['pass_at_k']:.1f}%")
-```
-
-## Advanced Usage
-
-```python
-from aime_utils import evaluate_model_aime, compare_aime_results
-
-# Evaluate multiple model configurations
-all_results = []
-
-# Base model
-base_results = evaluate_model_aime(
- model=base_model,
- tokenizer=tokenizer,
- model_type="base",
- temperature=0.3,
- n_sampling=8
-)
-all_results.append(base_results)
-
-# Fine-tuned model
-ft_results = evaluate_model_aime(
- model=finetuned_model,
- tokenizer=tokenizer,
- model_type="finetuned",
- temperature=0.3,
- n_sampling=8
-)
-all_results.append(ft_results)
-
-# Generate comprehensive comparison
-compare_aime_results(all_results)
-```
-
-## Dataset Format
-
-The evaluator automatically handles AIME dataset format with problems containing:
-
-- **Problem**: Mathematical question text
-- **Answer**: Numerical answer (0-999 range for AIME)
-- **Solution**: Step-by-step solution (when available)
-- **Source**: Original dataset identifier (test2024, test2025-I, test2025-II)
-
-```python
-# Automatic dataset download and formatting
-{
- "global_id": 0,
- "original_id": "problem_1",
- "source_dataset": "test2024",
- "problem": "Find the number of...",
- "answer": "123",
- "solution": "Step-by-step solution...",
- "prompt": [
- {"role": "system", "content": "You are a mathematical problem solver..."},
- {"role": "user", "content": "Problem: Find the number of..."}
- ]
-}
-```
-
-
-## Configuration Examples
-
-### Conservative Evaluation
-```python
-# Lower temperature for more consistent answers
-results = evaluate_model_aime(
- model=model,
- tokenizer=tokenizer,
- model_type="conservative",
- temperature=0.1,
- n_sampling=4,
- top_p=0.9
-)
-```
-
-### High-Sample Evaluation
-```python
-# More samples for better Pass@K estimation
-results = evaluate_model_aime(
- model=model,
- tokenizer=tokenizer,
- model_type="high_sample",
- temperature=0.5,
- n_sampling=16,
- max_tokens=16384
-)
-```
-
-### Memory-Optimized
-```python
-# Reduced parameters for limited resources
-results = evaluate_model_aime(
- model=model,
- tokenizer=tokenizer,
- model_type="lite",
- temperature=0.3,
- n_sampling=4,
- max_tokens=8192
-)
-```
-
-## Examples
-
-### Complete Model Pipeline Evaluation
-```python
-from aime_utils import evaluate_model_aime, compare_aime_results
-
-def evaluate_training_pipeline(base_model, finetuned_model, merged_model, tokenizer):
- """Evaluate complete training pipeline on AIME"""
-
- all_results = []
-
- # Standard evaluation configuration
- eval_config = {
- "temperature": 0.3,
- "n_sampling": 8,
- "max_tokens": 32768,
- "top_p": 0.95,
- "seed": 0
- }
-
- # Evaluate base model
- print("Evaluating base model...")
- base_results = evaluate_model_aime(
- model=base_model,
- tokenizer=tokenizer,
- model_type="base",
- **eval_config
- )
- all_results.append(base_results)
-
- # Evaluate fine-tuned model
- print("Evaluating fine-tuned model...")
- ft_results = evaluate_model_aime(
- model=finetuned_model,
- tokenizer=tokenizer,
- model_type="finetuned",
- **eval_config
- )
- all_results.append(ft_results)
-
- # Evaluate merged model
- print("Evaluating merged model...")
- merged_results = evaluate_model_aime(
- model=merged_model,
- tokenizer=tokenizer,
- model_type="merged",
- **eval_config
- )
- all_results.append(merged_results)
-
- # Generate comparison report
- compare_aime_results(all_results)
-
- return all_results
-```
-
-### Quantization Impact Analysis
-```python
-def analyze_quantization_impact(model_paths, tokenizer):
- """Analyze impact of different quantization levels"""
-
- quantization_configs = {
- "fp16": {"load_in_4bit": False, "load_in_8bit": False},
- "8bit": {"load_in_4bit": False, "load_in_8bit": True},
- "4bit": {"load_in_4bit": True, "load_in_8bit": False}
- }
-
- all_results = []
-
- for quant_name, load_config in quantization_configs.items():
- print(f"Evaluating {quant_name} quantization...")
-
- # Load model with specific quantization
- model = load_model_with_config(model_paths["merged"], **load_config)
-
- results = evaluate_model_aime(
- model=model,
- tokenizer=tokenizer,
- model_type=f"merged_{quant_name}",
- temperature=0.3,
- n_sampling=8,
- max_tokens=32768
- )
- all_results.append(results)
-
- # Cleanup
- del model
- torch.cuda.empty_cache()
-
- compare_aime_results(all_results)
- return all_results
-```
-
-## Output Format
-
-### Individual Evaluation Results
-```
-🧮 AIME EVALUATION - BASE MODEL
-Combined Dataset: test2024 + test2025-I + test2025-II
-====================================================================
-
-🎯 Overall Performance:
- Total problems: 45
- Correct answers: 12/45 (26.7%)
- Pass@8: 31.1%
-
-📈 Performance by Dataset:
- test2024: 4/15 (26.7%)
- test2025-I: 5/15 (33.3%)
- test2025-II: 3/15 (20.0%)
-
-🎖️ AIME Performance: ✅ EXCELLENT (26.7%)
-```
-
-### Comparison Report
-```
-COMPREHENSIVE AIME MODEL COMPARISON
-================================================================================
-Model Accuracy % Pass@K % Correct Total
---------------------------------------------------------------------------------
-finetuned 31.1 35.6 14 45
-base 26.7 31.1 12 45
-merged_4bit 24.4 28.9 11 45
-
-IMPROVEMENT ANALYSIS
-==================================================
-finetuned vs base:
- Accuracy improvement: +4.4%
- Pass@K improvement: +4.5%
-```
-
-## Performance Tiers
-
-The evaluator provides performance assessment based on AIME difficulty:
-
-- **🏆 EXCEPTIONAL**: ≥50% accuracy
-- **✅ EXCELLENT**: ≥30% accuracy
-- **🎯 VERY GOOD**: ≥20% accuracy
-- **⚠️ GOOD**: ≥10% accuracy
-- **📈 FAIR**: ≥5% accuracy
-- **❌ NEEDS IMPROVEMENT**: <5% accuracy
diff --git a/tests/utils/aime_eval.py b/tests/utils/aime_eval.py
deleted file mode 100644
index 131da3e50b..0000000000
--- a/tests/utils/aime_eval.py
+++ /dev/null
@@ -1,545 +0,0 @@
-"""
-AIME Dataset Evaluation Module
-
-This module provides functions to evaluate language models on the combined AIME dataset
-(test2024 + test2025-I + test2025-II).
-"""
-
-import json
-import requests
-import os
-import re
-import logging
-from typing import List, Dict, Any
-from tqdm import tqdm
-from vllm import SamplingParams
-
-
-def download_and_combine_aime_datasets(data_dir: str = "./data/aime") -> str:
- """Download all AIME datasets and combine them into a single file"""
-
- datasets = {
- "test2024": "https://raw.githubusercontent.com/GAIR-NLP/AIME-Preview/main/eval/data/aime/test2024.jsonl",
- "test2025-I": "https://raw.githubusercontent.com/GAIR-NLP/AIME-Preview/main/eval/data/aime/test2025-I.jsonl",
- "test2025-II": "https://raw.githubusercontent.com/GAIR-NLP/AIME-Preview/main/eval/data/aime/test2025-II.jsonl",
- }
-
- os.makedirs(data_dir, exist_ok = True)
- combined_filepath = os.path.join(data_dir, "aime.jsonl")
-
- # Check if combined file already exists
- if os.path.exists(combined_filepath):
- print(f"Combined AIME dataset already exists at {combined_filepath}")
- return combined_filepath
-
- print("Downloading and combining AIME datasets...")
-
- all_problems = []
- global_id = 0
-
- for dataset_name, url in datasets.items():
- print(f" Downloading {dataset_name}...")
-
- try:
- response = requests.get(url)
- response.raise_for_status()
-
- # Parse each line and add source information
- for line_num, line in enumerate(response.text.strip().split("\n")):
- if line.strip():
- try:
- data = json.loads(line)
- # Add source dataset information and global ID
- data["source_dataset"] = dataset_name
- data["original_id"] = data.get("id", line_num)
- data["global_id"] = global_id
- global_id += 1
- all_problems.append(data)
- except json.JSONDecodeError as e:
- print(
- f" Warning: Error parsing line {line_num + 1} in {dataset_name}: {e}"
- )
- continue
-
- except requests.RequestException as e:
- print(f" Error downloading {dataset_name}: {e}")
- continue
-
- # Write combined dataset
- if all_problems:
- with open(combined_filepath, "w", encoding = "utf-8") as f:
- for problem in all_problems:
- f.write(json.dumps(problem, ensure_ascii = False) + "\n")
-
- print(f"✅ Combined {len(all_problems)} problems from {len(datasets)} datasets")
- print(f" Saved to: {combined_filepath}")
-
- # Print summary by dataset
- for dataset_name in datasets.keys():
- count = sum(1 for p in all_problems if p["source_dataset"] == dataset_name)
- print(f" {dataset_name}: {count} problems")
-
- else:
- raise RuntimeError("No problems were successfully downloaded")
-
- return combined_filepath
-
-
-def load_aime_dataset(data_dir: str = "./data/aime") -> List[Dict[str, Any]]:
- """Load combined AIME dataset and format for evaluation"""
-
- # Download and combine if needed
- filepath = download_and_combine_aime_datasets(data_dir)
-
- examples = []
- with open(filepath, "r", encoding = "utf-8") as f:
- for line_num, line in enumerate(f):
- line = line.strip()
- if line:
- try:
- data = json.loads(line)
-
- # Format as expected by our evaluation
- formatted_example = {
- "global_id": data.get("global_id", line_num),
- "original_id": data.get(
- "original_id", data.get("id", line_num)
- ),
- "source_dataset": data.get("source_dataset", "unknown"),
- "problem": data["problem"],
- "answer": str(data["answer"]), # Ensure answer is string
- "solution": data.get("solution", ""),
- "url": data.get("url", ""),
- # Format as chat messages for the model
- "prompt": [
- {
- "role": "system",
- "content": "You are a mathematical problem solver. Solve the given problem step by step and provide your final answer clearly.",
- },
- {
- "role": "user",
- "content": f"Problem: {data['problem']}\n\nSolve this step by step and provide your final numerical answer.",
- },
- ],
- }
- examples.append(formatted_example)
-
- except json.JSONDecodeError as e:
- print(f"Error parsing line {line_num + 1}: {e}")
- continue
-
- print(f"Loaded {len(examples)} problems from combined AIME dataset")
-
- # Print breakdown by source
- source_counts = {}
- for example in examples:
- source = example["source_dataset"]
- source_counts[source] = source_counts.get(source, 0) + 1
-
- for source, count in source_counts.items():
- print(f" {source}: {count} problems")
-
- return examples
-
-
-def extract_aime_answer(response: str) -> str:
- """Extract numerical answer from AIME response"""
-
- # AIME answers are integers from 0-999
- # Look for patterns like "The answer is 123" or just standalone numbers
- patterns = [
- r"(?:the )?(?:final )?answer is (\d{1,3})",
- r"(?:therefore|thus|so),?\s*(?:the )?(?:final )?answer is (\d{1,3})",
- r"\\boxed\{(\d{1,3})\}",
- r"\$\\boxed\{(\d{1,3})\}\$",
- r"(?:answer|result):\s*(\d{1,3})",
- r"(?:^|\n)\s*(\d{1,3})\s*(?:\n|$)", # Standalone number
- ]
-
- response_lower = response.lower().strip()
-
- for pattern in patterns:
- matches = re.findall(pattern, response_lower, re.MULTILINE | re.IGNORECASE)
- if matches:
- # Get the last match (most likely to be final answer)
- answer = matches[-1]
- try:
- num = int(answer)
- if 0 <= num <= 999: # AIME answers are in range 0-999
- return str(num)
- except ValueError:
- continue
-
- # If no clear pattern found, try to extract any 1-3 digit number
- numbers = re.findall(r"\b(\d{1,3})\b", response)
- if numbers:
- for num_str in reversed(numbers): # Check from end
- try:
- num = int(num_str)
- if 0 <= num <= 999:
- return str(num)
- except ValueError:
- continue
-
- return ""
-
-
-def get_num_tokens(text, tokenizer_instance):
- """Count tokens in text"""
- if not text:
- return 0
- encoding = tokenizer_instance(text, return_tensors = "pt")
- return len(encoding["input_ids"][0])
-
-
-def evaluate_model_aime(
- model,
- tokenizer,
- model_type = "base",
- lora_request = None,
- temperature = 0.3,
- n_sampling = 8,
- max_tokens = 32768,
- top_p = 0.95,
- seed = 0,
-):
- """Evaluate model on combined AIME dataset with official configuration"""
-
- print(f"\n{'='*70}")
- print(f"🧮 AIME EVALUATION - {model_type.upper()} MODEL")
- print(f"Combined Dataset: test2024 + test2025-I + test2025-II")
- print(f"{'='*70}")
-
- # Load combined AIME dataset
- try:
- eval_dataset = load_aime_dataset()
- except Exception as e:
- print(f"Error loading dataset: {e}")
- return None
-
- if not eval_dataset:
- print("No examples found in dataset")
- return None
-
- # Initialize tracking variables
- records = {}
- input_tokens = []
- output_tokens = []
- correct_answers = 0
-
- # Track performance by source dataset
- source_stats = {}
- for example in eval_dataset:
- source = example["source_dataset"]
- if source not in source_stats:
- source_stats[source] = {"total": 0, "correct": 0}
- source_stats[source]["total"] += 1
-
- # Setup sampling parameters (AIME configuration)
- sampling_params = SamplingParams(
- temperature = temperature,
- top_p = top_p,
- max_tokens = max_tokens,
- n = n_sampling, # Multiple samples per question
- seed = seed,
- )
-
- print(f"\n🔧 Configuration:")
- print(f" Temperature: {temperature}")
- print(f" Samples per question: {n_sampling}")
- print(f" Max tokens: {max_tokens}")
- print(f" Top-p: {top_p}")
- print(f" Seed: {seed}")
-
- # Temporarily suppress verbose logging
- original_levels = {}
- loggers_to_suppress = [
- "vllm",
- "vllm.engine",
- "vllm.worker",
- "vllm.model_executor",
- "vllm.executor",
- "ray",
- ]
-
- for logger_name in loggers_to_suppress:
- logger = logging.getLogger(logger_name)
- original_levels[logger_name] = logger.level
- logger.setLevel(logging.WARNING)
-
- try:
- print(f"\n🚀 Evaluating {len(eval_dataset)} problems...")
-
- # Main evaluation loop
- with tqdm(
- total = len(eval_dataset), desc = "Processing AIME problems", unit = "problem"
- ) as pbar:
- for task_id, item in enumerate(eval_dataset):
- try:
- # Prepare prompt
- prompt_text = tokenizer.apply_chat_template(
- item["prompt"], add_generation_prompt = True, tokenize = False
- )
-
- input_tokens.append(get_num_tokens(prompt_text, tokenizer))
-
- # Generate multiple responses
- outputs = model.fast_generate(
- [prompt_text],
- sampling_params = sampling_params,
- lora_request = lora_request,
- use_tqdm = False,
- )[0].outputs
-
- # Process all generated responses
- responses = [output.text for output in outputs]
- extracted_answers = [
- extract_aime_answer(response) for response in responses
- ]
-
- # Calculate total output tokens
- total_output_tokens = sum(
- get_num_tokens(response, tokenizer) for response in responses
- )
- output_tokens.append(total_output_tokens)
-
- # Check if any answer is correct
- ground_truth = item["answer"]
- correct_responses = [
- ans == ground_truth for ans in extracted_answers
- ]
- is_correct = any(correct_responses)
-
- if is_correct:
- correct_answers += 1
- source_stats[item["source_dataset"]]["correct"] += 1
-
- # Store detailed record
- records[task_id] = {
- "global_id": item["global_id"],
- "original_id": item["original_id"],
- "source_dataset": item["source_dataset"],
- "problem": item["problem"],
- "ground_truth": ground_truth,
- "responses": responses,
- "extracted_answers": extracted_answers,
- "correct_responses": correct_responses,
- "is_correct": is_correct,
- "input_tokens": input_tokens[-1],
- "output_tokens": total_output_tokens,
- "n_correct": sum(correct_responses),
- "n_total": len(responses),
- "solution": item.get("solution", ""),
- "url": item.get("url", ""),
- }
-
- # Update progress
- current_accuracy = correct_answers / (task_id + 1) * 100
- pbar.set_postfix(
- {
- "accuracy": f"{current_accuracy:.1f}%",
- "correct": correct_answers,
- "total": task_id + 1,
- }
- )
- pbar.update(1)
-
- except Exception as e:
- print(f"\nError processing problem {task_id}: {str(e)}")
- records[task_id] = {
- "global_id": item.get("global_id", task_id),
- "original_id": item.get("original_id", task_id),
- "source_dataset": item.get("source_dataset", "unknown"),
- "problem": item["problem"],
- "ground_truth": item["answer"],
- "error": str(e),
- "is_correct": False,
- }
- pbar.update(1)
- continue
-
- finally:
- # Restore logging levels
- for logger_name, level in original_levels.items():
- logging.getLogger(logger_name).setLevel(level)
-
- # Calculate metrics
- total_problems = len(eval_dataset)
- accuracy = correct_answers / total_problems * 100
-
- # Calculate Pass@k (probability that at least one of k samples is correct)
- pass_at_k_scores = []
- for record in records.values():
- if "n_correct" in record and "n_total" in record:
- n_correct = record["n_correct"]
- n_total = record["n_total"]
- if n_correct > 0:
- pass_at_k_scores.append(1.0)
- else:
- pass_at_k_scores.append(0.0)
-
- pass_at_k = sum(pass_at_k_scores) / len(pass_at_k_scores) if pass_at_k_scores else 0
-
- # Calculate per-source accuracies
- source_accuracies = {}
- for source, stats in source_stats.items():
- source_accuracies[source] = (
- (stats["correct"] / stats["total"] * 100) if stats["total"] > 0 else 0
- )
-
- results = {
- "model_type": model_type,
- "dataset": "aime_combined",
- "total_problems": total_problems,
- "correct_answers": correct_answers,
- "accuracy": accuracy,
- "pass_at_k": pass_at_k * 100,
- "source_stats": source_stats,
- "source_accuracies": source_accuracies,
- "temperature": temperature,
- "n_sampling": n_sampling,
- "max_tokens": max_tokens,
- "top_p": top_p,
- "seed": seed,
- "avg_input_tokens": sum(input_tokens) / len(input_tokens)
- if input_tokens
- else 0,
- "avg_output_tokens": sum(output_tokens) / len(output_tokens)
- if output_tokens
- else 0,
- "max_input_tokens": max(input_tokens) if input_tokens else 0,
- "max_output_tokens": max(output_tokens) if output_tokens else 0,
- }
-
- # Save results
- filename = f"aime_eval_combined_{model_type}_t{temperature}_n{n_sampling}.json"
- with open(filename, "w", encoding = "utf-8") as f:
- json.dump({"results": results, "records": records}, f, indent = 4)
-
- # Print comprehensive summary
- print(f"\n{'='*70}")
- print(f"📊 AIME EVALUATION RESULTS - {model_type.upper()}")
- print(f"{'='*70}")
-
- print(f"\n🎯 Overall Performance:")
- print(f" Total problems: {total_problems:>6}")
- print(
- f" Correct answers: {correct_answers:>6}/{total_problems} ({accuracy:>5.1f}%)"
- )
- print(f" Pass@{n_sampling}: {pass_at_k:>10.1f}%")
-
- print(f"\n📈 Performance by Dataset:")
- for source, stats in source_stats.items():
- source_acc = source_accuracies[source]
- print(
- f" {source:>12}: {stats['correct']:>3}/{stats['total']:>3} ({source_acc:>5.1f}%)"
- )
-
- print(f"\n🔧 Configuration:")
- print(f" Temperature: {temperature}")
- print(f" Samples per problem: {n_sampling}")
- print(f" Max tokens: {max_tokens}")
- print(f" Top-p: {top_p}")
- print(f" Seed: {seed}")
-
- print(f"\n📝 Token Statistics:")
- print(f" Avg input tokens: {results['avg_input_tokens']:>10.1f}")
- print(f" Avg output tokens: {results['avg_output_tokens']:>10.1f}")
- print(f" Max input tokens: {results['max_input_tokens']:>10}")
- print(f" Max output tokens: {results['max_output_tokens']:>10}")
-
- # Performance assessment for AIME
- if accuracy >= 50:
- tier = "🏆 EXCEPTIONAL"
- elif accuracy >= 30:
- tier = "✅ EXCELLENT"
- elif accuracy >= 20:
- tier = "🎯 VERY GOOD"
- elif accuracy >= 10:
- tier = "⚠️ GOOD"
- elif accuracy >= 5:
- tier = "📈 FAIR"
- else:
- tier = "❌ NEEDS IMPROVEMENT"
-
- print(f"\n🎖️ AIME Performance: {tier} ({accuracy:.1f}%)")
- print(f"\n💾 Detailed results saved to: {filename}")
- print(f"\n{'='*70}")
-
- return results
-
-
-# Comparison functions for multiple model results
-def compare_aime_results(all_results):
- """Generate comprehensive comparison for AIME evaluation results"""
- print(f"\n{'='*80}")
- print("COMPREHENSIVE AIME MODEL COMPARISON")
- print(f"{'='*80}")
-
- # Main comparison table
- print(
- f"{'Model':<15} {'Accuracy %':<12} {'Pass@K %':<10} {'Correct':<8} {'Total':<8}"
- )
- print("-" * 80)
-
- for result in all_results:
- print(
- f"{result['model_type']:<15} "
- f"{result['accuracy']:<12.1f} "
- f"{result['pass_at_k']:<10.1f} "
- f"{result['correct_answers']:<8} "
- f"{result['total_problems']:<8}"
- )
-
- # Performance improvement analysis
- if len(all_results) > 1:
- print(f"\n{'='*50}")
- print("IMPROVEMENT ANALYSIS")
- print(f"{'='*50}")
-
- base_result = all_results[0] # Assume first is base model
-
- for i, result in enumerate(all_results[1:], 1):
- print(f"\n{result['model_type']} vs {base_result['model_type']}:")
-
- accuracy_improvement = result["accuracy"] - base_result["accuracy"]
- pass_k_improvement = result["pass_at_k"] - base_result["pass_at_k"]
-
- print(f" Accuracy improvement: {accuracy_improvement:+.1f}%")
- print(f" Pass@K improvement: {pass_k_improvement:+.1f}%")
-
- # Dataset breakdown
- print(f"\n{'='*50}")
- print("PERFORMANCE BY DATASET")
- print(f"{'='*50}")
-
- # Get all unique datasets from the first result
- if all_results and "source_accuracies" in all_results[0]:
- datasets = list(all_results[0]["source_accuracies"].keys())
-
- print(f"{'Model':<15}", end = "")
- for dataset in datasets:
- print(f"{dataset:<15}", end = "")
- print()
- print("-" * (15 + 15 * len(datasets)))
-
- for result in all_results:
- print(f"{result['model_type']:<15}", end = "")
- for dataset in datasets:
- accuracy = result["source_accuracies"].get(dataset, 0)
- print(f"{accuracy:<15.1f}", end = "")
- print()
-
- # Save comparison
- comparison_data = {
- "summary": all_results,
- "best_model": max(all_results, key = lambda x: x["accuracy"]),
- }
-
- with open("aime_model_comparison.json", "w") as f:
- json.dump(comparison_data, f, indent = 4)
-
- print(
- f"\nBest performing model: {comparison_data['best_model']['model_type']} "
- f"({comparison_data['best_model']['accuracy']:.1f}% accuracy)"
- )
diff --git a/tests/utils/cleanup_utils.py b/tests/utils/cleanup_utils.py
deleted file mode 100644
index af38528196..0000000000
--- a/tests/utils/cleanup_utils.py
+++ /dev/null
@@ -1,226 +0,0 @@
-import gc
-import logging
-import os
-import shutil
-import torch
-import sys
-import warnings
-
-
-def clear_memory(variables_to_clear = None, verbose = False, clear_all_caches = True):
- """
- Comprehensive memory clearing for persistent memory leaks.
-
- Args:
- variables_to_clear: List of variable names to clear
- verbose: Print memory status
- clear_all_caches: Clear all types of caches (recommended for memory leaks)
- """
-
- # Save current logging levels
- saved_log_levels = {}
- for name, logger in logging.Logger.manager.loggerDict.items():
- if isinstance(logger, logging.Logger):
- saved_log_levels[name] = logger.level
- root_level = logging.getLogger().level
-
- if variables_to_clear is None:
- variables_to_clear = [
- "inputs",
- "model",
- "base_model",
- "processor",
- "tokenizer",
- "base_processor",
- "base_tokenizer",
- "trainer",
- "peft_model",
- "bnb_config",
- ]
-
- # 1. Clear LRU caches FIRST (very important for memory leaks)
- if clear_all_caches:
- clear_all_lru_caches(verbose)
-
- # 2. Delete specified variables
- g = globals()
- deleted_vars = []
- for var in variables_to_clear:
- if var in g:
- del g[var]
- deleted_vars.append(var)
-
- if verbose and deleted_vars:
- print(f"Deleted variables: {deleted_vars}")
-
- # 3. Multiple garbage collection passes (important for circular references)
- for i in range(3):
- collected = gc.collect()
- if verbose and collected > 0:
- print(f"GC pass {i+1}: collected {collected} objects")
-
- # 4. CUDA cleanup
- if torch.cuda.is_available():
- # Get memory before cleanup
- if verbose:
- mem_before = torch.cuda.memory_allocated() / 1024**3
-
- torch.cuda.empty_cache()
- torch.cuda.synchronize()
-
- # Additional CUDA cleanup for persistent leaks
- if clear_all_caches:
- # Reset memory stats
- torch.cuda.reset_peak_memory_stats()
- torch.cuda.reset_accumulated_memory_stats()
-
- # Clear JIT cache
- if hasattr(torch.jit, "_state") and hasattr(
- torch.jit._state, "_clear_class_state"
- ):
- torch.jit._state._clear_class_state()
-
- # Force another CUDA cache clear
- torch.cuda.empty_cache()
-
- # Final garbage collection
- gc.collect()
-
- if verbose:
- mem_after = torch.cuda.memory_allocated() / 1024**3
- mem_reserved = torch.cuda.memory_reserved() / 1024**3
- print(
- f"GPU memory - Before: {mem_before:.2f} GB, After: {mem_after:.2f} GB"
- )
- print(f"GPU reserved memory: {mem_reserved:.2f} GB")
- if mem_before > 0:
- print(f"Memory freed: {mem_before - mem_after:.2f} GB")
-
- # restore original logging levels
- logging.getLogger().setLevel(root_level)
- for name, level in saved_log_levels.items():
- if name in logging.Logger.manager.loggerDict:
- logger = logging.getLogger(name)
- logger.setLevel(level)
-
-
-def clear_all_lru_caches(verbose = True):
- """Clear all LRU caches in loaded modules."""
- cleared_caches = []
-
- # Modules to skip to avoid warnings
- skip_modules = {
- "torch.distributed",
- "torchaudio",
- "torch._C",
- "torch.distributed.reduce_op",
- "torchaudio.backend",
- }
-
- # Create a static list of modules to avoid RuntimeError
- modules = list(sys.modules.items())
-
- # Method 1: Clear caches in all loaded modules
- for module_name, module in modules:
- if module is None:
- continue
-
- # Skip problematic modules
- if any(module_name.startswith(skip) for skip in skip_modules):
- continue
-
- try:
- # Look for functions with lru_cache
- for attr_name in dir(module):
- try:
- # Suppress warnings when checking attributes
- with warnings.catch_warnings():
- warnings.simplefilter("ignore", FutureWarning)
- warnings.simplefilter("ignore", UserWarning)
- warnings.simplefilter("ignore", DeprecationWarning)
-
- attr = getattr(module, attr_name)
- if hasattr(attr, "cache_clear"):
- attr.cache_clear()
- cleared_caches.append(f"{module_name}.{attr_name}")
- except Exception:
- continue # Skip problematic attributes
- except Exception:
- continue # Skip problematic modules
-
- # Method 2: Clear specific known caches
- known_caches = [
- "transformers.utils.hub.cached_file",
- "transformers.tokenization_utils_base.get_tokenizer",
- "torch._dynamo.utils.counters",
- ]
-
- for cache_path in known_caches:
- try:
- parts = cache_path.split(".")
- module = sys.modules.get(parts[0])
- if module:
- obj = module
- for part in parts[1:]:
- obj = getattr(obj, part, None)
- if obj is None:
- break
- if obj and hasattr(obj, "cache_clear"):
- obj.cache_clear()
- cleared_caches.append(cache_path)
- except Exception:
- continue # Skip problematic caches
-
- if verbose and cleared_caches:
- print(f"Cleared {len(cleared_caches)} LRU caches")
-
-
-def clear_specific_lru_cache(func):
- """Clear cache for a specific function."""
- if hasattr(func, "cache_clear"):
- func.cache_clear()
- return True
- return False
-
-
-# Additional utility for monitoring cache sizes
-def monitor_cache_sizes():
- """Monitor LRU cache sizes across modules."""
- cache_info = []
-
- for module_name, module in sys.modules.items():
- if module is None:
- continue
- try:
- for attr_name in dir(module):
- try:
- attr = getattr(module, attr_name)
- if hasattr(attr, "cache_info"):
- info = attr.cache_info()
- cache_info.append(
- {
- "function": f"{module_name}.{attr_name}",
- "size": info.currsize,
- "hits": info.hits,
- "misses": info.misses,
- }
- )
- except:
- pass
- except:
- pass
-
- return sorted(cache_info, key = lambda x: x["size"], reverse = True)
-
-
-def safe_remove_directory(path):
- try:
- if os.path.exists(path) and os.path.isdir(path):
- shutil.rmtree(path)
- return True
- else:
- print(f"Path {path} is not a valid directory")
- return False
- except Exception as e:
- print(f"Failed to remove directory {path}: {e}")
- return False
diff --git a/tests/utils/data_utils.py b/tests/utils/data_utils.py
deleted file mode 100644
index 9551688437..0000000000
--- a/tests/utils/data_utils.py
+++ /dev/null
@@ -1,153 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-from datasets import Dataset
-
-QUESTION = "What day was I born?"
-ANSWER = "January 1, 2058"
-USER_MESSAGE = {"role": "user", "content": QUESTION}
-ASSISTANT_MESSAGE = {"role": "assistant", "content": ANSWER}
-DTYPE = torch.bfloat16
-DEFAULT_MESSAGES = [[USER_MESSAGE, ASSISTANT_MESSAGE]]
-
-
-def create_instruction_dataset(messages: list[dict] = DEFAULT_MESSAGES):
- dataset = Dataset.from_dict({"messages": messages})
- return dataset
-
-
-def create_dataset(tokenizer, num_examples: int = None, messages: list[dict] = None):
- dataset = create_instruction_dataset(messages)
-
- def _apply_chat_template(example):
- chat = tokenizer.apply_chat_template(example["messages"], tokenize = False)
- return {"text": chat}
-
- dataset = dataset.map(_apply_chat_template, remove_columns = "messages")
- if num_examples is not None:
- if len(dataset) < num_examples:
- num_repeats = num_examples // len(dataset) + 1
- dataset = dataset.repeat(num_repeats)
- dataset = dataset.select(range(num_examples))
-
- return dataset
-
-
-def describe_param(
- param: torch.Tensor,
- include_l1: bool = False,
- include_l2: bool = False,
- include_infinity: bool = False,
- as_str: bool = True,
-) -> dict:
- """
- Provide a statistical summary of a 2D weight matrix or tensor.
- If as_str is True, the summary is returned as a formatted string.
- Parameters:
- param: torch.Tensor
- include_l1 (bool): Whether to include the L1 norm (sum of absolute values).
- include_l2 (bool): Whether to include the L2 norm (Frobenius norm).
- include_infinity (bool): Whether to include the infinity norm (max absolute value).
- as_str (bool): Whether to return the summary as a formatted string.
-
- Returns:
- dict: A dictionary with the following statistics:
- - shape: Dimensions of the matrix.
- - mean: Average value.
- - median: Median value.
- - std: Standard deviation.
- - min: Minimum value.
- - max: Maximum value.
- - percentile_25: 25th percentile.
- - percentile_75: 75th percentile.
- Additionally, if enabled:
- - L1_norm: Sum of absolute values.
- - L2_norm: Euclidean (Frobenius) norm.
- - infinity_norm: Maximum absolute value.
- """
-
- param = param.float()
- summary = {
- "shape": param.shape,
- "mean": param.mean().cpu().item(),
- "std": param.std().cpu().item(),
- "min": param.min().cpu().item(),
- "max": param.max().cpu().item(),
- "percentile_25": param.quantile(0.25).cpu().item(),
- "percentile_50": param.quantile(0.5).cpu().item(),
- "percentile_75": param.quantile(0.75).cpu().item(),
- }
-
- if include_l1:
- summary["L1_norm"] = param.abs().sum().cpu().item()
- if include_l2:
- summary["L2_norm"] = param.norm().cpu().item()
- if include_infinity:
- summary["infinity_norm"] = param.abs().max().cpu().item()
-
- return format_summary(summary) if as_str else summary
-
-
-def format_summary(stats: dict, precision: int = 6) -> str:
- """
- Format the statistical summary dictionary for printing.
-
- Parameters:
- stats (dict): The dictionary returned by describe_param.
- precision (int): Number of decimal places for floating point numbers.
-
- Returns:
- str: A formatted string representing the summary.
- """
- lines = []
- for key, value in stats.items():
- if isinstance(value, float):
- formatted_value = f"{value:.{precision}f}"
- elif isinstance(value, (tuple, list)):
- # Format each element in tuples or lists (e.g., the shape)
- formatted_value = ", ".join(str(v) for v in value)
- formatted_value = (
- f"({formatted_value})"
- if isinstance(value, tuple)
- else f"[{formatted_value}]"
- )
- else:
- formatted_value = str(value)
- lines.append(f"{key}: {formatted_value}")
- return "\n".join(lines)
-
-
-def get_peft_weights(model):
- # ruff: noqa
- is_lora_weight = lambda name: any(s in name for s in ["lora_A", "lora_B"])
- return {
- name: param for name, param in model.named_parameters() if is_lora_weight(name)
- }
-
-
-def describe_peft_weights(model):
- for name, param in get_peft_weights(model).items():
- yield name, describe_param(param, as_str = True)
-
-
-def check_responses(responses: list[str], answer: str, prompt: str = None) -> bool:
- for i, response in enumerate(responses, start = 1):
- if answer in response:
- print(f"\u2713 response {i} contains answer")
- else:
- print(f"\u2717 response {i} does not contain answer")
- if prompt is not None:
- response = response.replace(prompt, "")
- print(f" -> response: {response}")
diff --git a/tests/utils/hf_utils.py b/tests/utils/hf_utils.py
deleted file mode 100644
index 8ad6d5ad08..0000000000
--- a/tests/utils/hf_utils.py
+++ /dev/null
@@ -1,291 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-from contextlib import contextmanager, nullcontext
-from typing import Callable, Optional
-
-import bitsandbytes as bnb
-import torch
-from bitsandbytes.functional import dequantize_4bit
-from peft import get_peft_model, prepare_model_for_kbit_training
-from peft.tuners.lora import LoraConfig, LoraLayer
-from transformers import (
- AutoModelForCausalLM,
- AutoTokenizer,
- BitsAndBytesConfig,
-)
-from transformers.trainer_callback import (
- TrainerCallback,
- TrainerControl,
- TrainerState,
- TrainingArguments,
-)
-from trl import SFTTrainer
-
-
-class PeftWeightCallback(TrainerCallback):
- def on_log(
- self,
- args: TrainingArguments,
- state: TrainerState,
- control: TrainerControl,
- logs,
- **kwargs,
- ):
- print(f"DEBUG::CALLBACK::on_log::{state.log_history}")
-
- def on_train_begin(
- self,
- args: TrainingArguments,
- state: TrainerState,
- control: TrainerControl,
- **kwargs,
- ):
- model = kwargs.get("model")
- assert model is not None
- print(f"DEBUG::CALLBACK::on_train_begin::{kwargs.keys()}")
-
- def on_step_end(
- self,
- args: TrainingArguments,
- state: TrainerState,
- control: TrainerControl,
- **kwargs,
- ):
- print(f"DEBUG::CALLBACK::on_step_end::{state.global_step}")
-
-
-@torch.inference_mode()
-def generate_responses(
- model,
- tokenizer,
- prompt,
- max_new_tokens: int = 100,
- temperature: float = 0.8,
- do_sample: bool = True,
- num_generations: int = 1,
- skip_special_tokens: bool = True,
- dtype: torch.dtype = None,
-):
- inputs = [tokenizer(prompt, return_tensors = "pt") for _ in range(num_generations)]
- keys = inputs[0].keys()
- batched_inputs = {
- key: torch.cat([input[key] for input in inputs], dim = 0).to(model.device)
- for key in keys
- }
-
- if dtype is not None:
- inference_context = torch.autocast(device_type = "cuda", dtype = dtype)
- else:
- inference_context = nullcontext()
-
- with inference_context:
- outputs = model.generate(
- **batched_inputs,
- max_new_tokens = max_new_tokens,
- do_sample = do_sample,
- temperature = temperature,
- )
-
- responses = tokenizer.batch_decode(outputs, skip_special_tokens = skip_special_tokens)
- return responses
-
-
-def sample_responses(
- model,
- tokenizer,
- prompt,
- temperature: float = 0.8,
- num_generations: int = 1,
- max_new_tokens: int = 100,
- skip_special_tokens: bool = True,
- dtype: torch.dtype = None,
-):
- responses = generate_responses(
- model,
- tokenizer,
- prompt,
- temperature = temperature,
- num_generations = num_generations,
- max_new_tokens = max_new_tokens,
- skip_special_tokens = skip_special_tokens,
- dtype = dtype,
- )
- return responses
-
-
-def setup_tokenizer(model_name, fixup_funcs: list[Callable] = []):
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- for fixup_func in fixup_funcs:
- tokenizer = fixup_func(tokenizer)
- return tokenizer
-
-
-def setup_model(
- model_name,
- quantize: bool = True,
- dtype = torch.bfloat16,
- peft_config = None,
- autocast_adapter: bool = True,
-):
- if quantize:
- bnb_config = BitsAndBytesConfig(
- load_in_4bit = True,
- bnb_4bit_use_double_quant = True,
- bnb_4bit_quant_type = "nf4",
- bnb_4bit_compute_dtype = dtype,
- )
- else:
- bnb_config = None
-
- model = AutoModelForCausalLM.from_pretrained(
- model_name,
- device_map = "cuda:0",
- attn_implementation = "sdpa",
- quantization_config = bnb_config,
- torch_dtype = dtype,
- )
- model = prepare_model_for_kbit_training(model) if quantize else model
-
- if peft_config is not None:
- model = get_peft_model(
- model, peft_config, autocast_adapter_dtype = autocast_adapter
- )
-
- return model
-
-
-def get_peft_config(
- lora_rank,
- lora_alpha = None,
- lora_dropout = 0.0,
- bias = "none",
- target_modules = "all-linear",
-):
- lora_alpha = lora_alpha or 2 * lora_rank
- peft_config = LoraConfig(
- lora_alpha = lora_alpha,
- lora_dropout = lora_dropout,
- r = lora_rank,
- bias = bias,
- target_modules = target_modules,
- task_type = "CAUSAL_LM",
- )
- return peft_config
-
-
-def setup_trainer(
- model,
- tokenizer,
- dataset,
- train_args,
- peft_config = None,
- formatting_func = None,
- collator = None,
-):
- return SFTTrainer(
- model = model,
- peft_config = peft_config,
- train_dataset = dataset,
- processing_class = tokenizer,
- formatting_func = formatting_func,
- data_collator = collator,
- args = train_args,
- )
-
-
-def setup_lora(
- model,
- tokenizer,
- dataset,
- peft_config,
- train_args,
- formatting_func = None,
- collator = None,
-):
- return LoraConfig(
- model = model,
- peft_config = peft_config,
- train_dataset = dataset,
- processing_class = tokenizer,
- formatting_func = formatting_func,
- data_collator = collator,
- args = train_args,
- )
-
-
-def convert_weights_back_to_dtype(model, dtype):
- """
- SFTTrainer calls get_peft_model and prepare_model_for_kbit_training which converts all weights to float32.
- This function converts the non-loraweights back to the original dtype.
- """
- for name, param in model.named_parameters():
- if any(s in name for s in ["norm", "embed"]):
- param.data = param.data.to(dtype)
-
-
-def fix_llama3_tokenizer(tokenizer, padding_side = "right"):
- tokenizer.padding_side = padding_side
- added_vocab = tokenizer.get_added_vocab()
- pad_token = [w for w in added_vocab if "pad" in w]
- assert len(pad_token) == 1
- tokenizer.pad_token = pad_token[0] # Load dataset from the hub
- return tokenizer
-
-
-def replace_module(
- module: torch.nn.Module,
- target_module_type: torch.nn.Module,
- conversion_func: Callable,
-):
- for child_name, child_module in module.named_children():
- if isinstance(child_module, target_module_type):
- new_module = conversion_func(child_module)
- setattr(module, child_name, new_module)
- else:
- replace_module(child_module, target_module_type, conversion_func)
-
-
-def _convert_lora_to_linear(module: LoraLayer, adapter_name: str = "default"):
- base_layer = module.get_base_layer()
- weight = base_layer.weight
-
- assert isinstance(weight, bnb.nn.Params4bit)
- quant_state = weight.quant_state
- original_dtype = quant_state.dtype
-
- w_dq = dequantize_4bit(weight.data, quant_state).float()
- lora_delta = (
- module.lora_B[adapter_name].weight
- @ module.lora_A[adapter_name].weight
- * module.scaling[adapter_name]
- )
- w_dq += lora_delta.float()
- w_dq = w_dq.to(original_dtype)
-
- new_module = torch.nn.Linear(
- w_dq.shape[1], w_dq.shape[0], bias = module.base_layer.bias is not None
- )
- new_module.weight.data = torch.nn.Parameter(w_dq, requires_grad = False)
- if module.lora_bias[adapter_name]:
- bias_data = module.base_layer.bias.data + module.lora_B[adapter_name].bias
- new_module.bias.data = torch.nn.Parameter(bias_data, requires_grad = False)
- return new_module
-
-
-def convert_lora_to_linear(model: torch.nn.Module):
- replace_module(model, LoraLayer, _convert_lora_to_linear)
- assert not any(isinstance(module, LoraLayer) for module in model.modules())
- return model
diff --git a/tests/utils/ocr_eval.md b/tests/utils/ocr_eval.md
deleted file mode 100644
index 97dbe6dd60..0000000000
--- a/tests/utils/ocr_eval.md
+++ /dev/null
@@ -1,109 +0,0 @@
-
-# OCR Model Evaluator
-A comprehensive Python module for evaluating Optical Character Recognition (OCR) models using Word Error Rate (WER) and Character Error Rate (CER) metrics. This evaluator supports vision-language models and provides detailed analysis with comparison capabilities across multiple models
-
-## Basic Usage
-
-```python
-from ocr_evaluator import evaluate_ocr_model
-
-# Simple evaluation
-avg_wer, avg_cer = evaluate_ocr_model(
- model=your_model,
- processor=your_processor,
- dataset=your_dataset,
- output_dir="evaluation_results"
-)
-
-print(f"Average WER: {avg_wer:.4f}")
-print(f"Average CER: {avg_cer:.4f}")
-```
-
-
-### Dataset Format
-
-The evaluator expects datasets in a chatml conversational format with the following structure:
-```
-dataset = [
- {
- "messages": [
- {
- "role": "system",
- "content": [{"type": "text", "text": "You are an OCR system."}]
- },
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "Extract text from this image"},
- {"type": "image", "image": PIL_Image_object}
- ]
- },
- {
- "role": "assistant",
- "content": [{"type": "text", "text": "Ground truth text"}]
- }
- ]
- },
- # ... more samples
-]
-```
-
-
-## Examples
-
-### Document OCR evaluation
-
-```python
-from ocr_evaluator import OCRModelEvaluator
-from datasets import load_dataset
-
-# Load document OCR dataset
-dataset = load_dataset("your-ocr-dataset", split="test")
-
-# Convert to required format
-eval_data = [format_document_sample(sample) for sample in dataset]
-
-# Evaluate models
-evaluator = OCRModelEvaluator()
-
-# Compare different model configurations
-configs = {
- "Standard Model": {"temperature": 1.0, "max_new_tokens": 512},
- "Conservative Model": {"temperature": 0.7, "max_new_tokens": 256},
- "Creative Model": {"temperature": 1.5, "max_new_tokens": 1024}
-}
-
-for config_name, params in configs.items():
- wer, cer = evaluator.evaluate_model(
- model=base_model,
- processor=processor,
- dataset=eval_data,
- output_dir=f"document_ocr_{config_name.lower().replace(' ', '_')}",
- **params
- )
- evaluator.add_to_comparison(config_name, wer, cer)
-
-# Generate final report
-evaluator.print_model_comparison()
-```
-
-### Handwriting Recognition
-```python
-# Specialized evaluation for handwriting
-def evaluate_handwriting_models(models, handwriting_dataset):
- evaluator = OCRModelEvaluator()
-
- for model_name, (model, processor) in models.items():
- # Adjust parameters for handwriting recognition
- wer, cer = evaluator.evaluate_model(
- model=model,
- processor=processor,
- dataset=handwriting_dataset,
- temperature=1.2, # Slightly higher for handwriting variety
- max_new_tokens=128, # Usually shorter text
- output_dir=f"handwriting_{model_name}"
- )
- evaluator.add_to_comparison(f"Handwriting - {model_name}", wer, cer)
-
- return evaluator.print_model_comparison()
-```
diff --git a/tests/utils/ocr_eval.py b/tests/utils/ocr_eval.py
deleted file mode 100644
index 3c5cd74a22..0000000000
--- a/tests/utils/ocr_eval.py
+++ /dev/null
@@ -1,374 +0,0 @@
-"""
-OCR Model Evaluation Module
-
-This module provides functionality to evaluate OCR models on datasets with
-word error rate (WER) and character error rate (CER) metrics.
-"""
-
-import os
-import torch
-from tqdm import tqdm
-import pandas as pd
-from jiwer import wer, cer
-from qwen_vl_utils import process_vision_info
-import matplotlib.pyplot as plt
-from typing import List, Dict, Tuple, Optional, Any
-import traceback
-
-
-class OCRModelEvaluator:
- """
- A comprehensive OCR model evaluator that supports multiple models and provides
- detailed analysis with WER and CER metrics.
- """
-
- def __init__(self):
- """Initialize the OCR evaluator."""
- self.model_comparison_results = {}
-
- def evaluate_model(
- self,
- model: Any,
- processor: Any,
- dataset: List[Dict],
- output_dir: str = "ocr_evaluation_results",
- max_new_tokens: int = 1024,
- temperature: float = 1.5,
- min_p: float = 0.1,
- verbose: bool = True,
- ) -> Tuple[Optional[float], Optional[float]]:
- """
- Evaluate a model on an OCR dataset.
- """
- # Create output directory if it doesn't exist
- os.makedirs(output_dir, exist_ok = True)
-
- # Initialize results storage
- results = []
-
- # Process each sample in the dataset
- for i, sample in enumerate(
- tqdm(dataset, desc = "Evaluating OCR performance", disable = not verbose)
- ):
- try:
- # Extract components from sample
- messages = sample["messages"]
-
- # Get ground truth, image, and question
- ground_truth, image, question, input_messages = (
- self._extract_sample_components(messages, i, verbose)
- )
-
- if ground_truth is None or image is None or question is None:
- continue
-
- # Generate model response
- generated_response = self._generate_response(
- model, processor, input_messages, max_new_tokens, temperature, min_p
- )
-
- # Calculate metrics
- word_error = wer(ground_truth, generated_response)
- char_error = cer(ground_truth, generated_response)
-
- # Save individual result
- self._save_individual_result(
- output_dir,
- i,
- question,
- generated_response,
- ground_truth,
- word_error,
- char_error,
- )
-
- # Store results for summary
- results.append(
- {
- "sample_id": i,
- "wer": word_error,
- "cer": char_error,
- "model_output": generated_response.strip(),
- "ground_truth": ground_truth,
- "question": question,
- }
- )
-
- except Exception as e:
- if verbose:
- print(f"Error processing sample {i}: {str(e)}")
- traceback.print_exc()
-
- # Generate summary report
- return self._generate_summary_report(results, output_dir, verbose)
-
- def _extract_sample_components(
- self, messages: List[Dict], sample_idx: int, verbose: bool
- ) -> Tuple[Optional[str], Optional[Any], Optional[str], List[Dict]]:
- """Extract ground truth, image, question, and input messages from sample."""
-
- # Extract system message (if present)
- system_message = next(
- (msg for msg in messages if msg["role"] == "system"), None
- )
-
- # Extract user message with the image and question
- user_message = next((msg for msg in messages if msg["role"] == "user"), None)
- if not user_message:
- if verbose:
- print(f"Skipping sample {sample_idx}: No user message found")
- return None, None, None, []
-
- # Extract assistant message with ground truth
- assistant_message = next(
- (msg for msg in messages if msg["role"] == "assistant"), None
- )
- if not assistant_message:
- if verbose:
- print(
- f"Skipping sample {sample_idx}: No assistant message (ground truth) found"
- )
- return None, None, None, []
-
- # Extract ground truth text
- ground_truth = None
- for content_item in assistant_message["content"]:
- if content_item["type"] == "text":
- ground_truth = content_item["text"]
- break
-
- if not ground_truth:
- if verbose:
- print(
- f"Skipping sample {sample_idx}: No text found in assistant message"
- )
- return None, None, None, []
-
- # Extract image and question from user message
- image = None
- question = None
-
- for content_item in user_message["content"]:
- if content_item["type"] == "image":
- image = content_item["image"]
- elif content_item["type"] == "text":
- question = content_item["text"]
-
- if not image:
- if verbose:
- print(f"Skipping sample {sample_idx}: No image found in user message")
- return None, None, None, []
-
- if not question:
- if verbose:
- print(
- f"Skipping sample {sample_idx}: No question found in user message"
- )
- return None, None, None, []
-
- # Construct messages for the model input (excluding assistant message)
- input_messages = []
- if system_message:
- input_messages.append(system_message)
- input_messages.append(user_message)
-
- return ground_truth, image, question, input_messages
-
- def _generate_response(
- self,
- model: Any,
- processor: Any,
- input_messages: List[Dict],
- max_new_tokens: int,
- temperature: float,
- min_p: float,
- ) -> str:
- """Generate response from the model."""
-
- # Preparation for inference using Qwen's specific processing
- text = processor.apply_chat_template(
- input_messages, tokenize = False, add_generation_prompt = True
- )
-
- # Process vision info (images/videos) from messages
- image_inputs, video_inputs = process_vision_info(input_messages)
-
- # Create model inputs
- inputs = processor(
- text = [text],
- images = image_inputs,
- videos = video_inputs,
- padding = True,
- return_tensors = "pt",
- )
- inputs = inputs.to(model.device)
-
- # Generate response
- with torch.no_grad():
- generated_ids = model.generate(
- **inputs,
- max_new_tokens = max_new_tokens,
- temperature = temperature,
- min_p = min_p,
- use_cache = True,
- )
-
- # Extract only the generated part (not the input)
- generated_ids_trimmed = [
- out_ids[len(in_ids) :]
- for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
- ]
-
- # Decode the generated text
- generated_response = processor.batch_decode(
- generated_ids_trimmed,
- skip_special_tokens = True,
- clean_up_tokenization_spaces = False,
- )[0]
-
- return generated_response
-
- def _save_individual_result(
- self,
- output_dir: str,
- sample_idx: int,
- question: str,
- generated_response: str,
- ground_truth: str,
- word_error: float,
- char_error: float,
- ):
- """Save individual sample result to file."""
- output_file = os.path.join(output_dir, f"sample_{sample_idx}.txt")
- with open(output_file, "w", encoding = "utf-8") as f:
- f.write(f"Sample {sample_idx}\n")
- f.write(f"Question: {question}\n\n")
- f.write(f"Model output:\n{generated_response.strip()}\n\n")
- f.write(f"Ground truth:\n{ground_truth}\n\n")
- f.write(f"WER: {word_error:.4f}, CER: {char_error:.4f}")
-
- def _generate_summary_report(
- self, results: List[Dict], output_dir: str, verbose: bool
- ) -> Tuple[Optional[float], Optional[float]]:
- """Generate and save summary report."""
- if not results:
- if verbose:
- print("No results to summarize.")
- return None, None
-
- df = pd.DataFrame(results)
-
- # Calculate overall averages
- avg_wer = df["wer"].mean()
- avg_cer = df["cer"].mean()
-
- # Save average metrics
- with open(os.path.join(output_dir, "avg_metrics.txt"), "w") as f:
- f.write(f"Average WER: {avg_wer:.4f}\n")
- f.write(f"Average CER: {avg_cer:.4f}\n")
-
- # Save detailed results
- df.to_csv(os.path.join(output_dir, "detailed_results.csv"), index = False)
-
- if verbose:
- print("\nResults Summary:")
- print(f"Average WER: {avg_wer:.4f}")
- print(f"Average CER: {avg_cer:.4f}")
- print(f"\nDetailed results saved to {output_dir}/")
-
- return avg_wer, avg_cer
-
- def add_to_comparison(self, model_name: str, wer: float, cer: float):
- """Add model results to the comparison tracker."""
- self.model_comparison_results[model_name] = {"wer": wer, "cer": cer}
-
- def print_model_comparison(
- self, save_csv: bool = True, save_plot: bool = True
- ) -> Optional[pd.DataFrame]:
- """Print a comparison of all models evaluated so far."""
- if not self.model_comparison_results:
- print("No model results available for comparison")
- return None
-
- print("\n==== MODEL COMPARISON REPORT ====")
-
- # Create a comparison dataframe
- comparison_df = pd.DataFrame(
- {
- "Model": list(self.model_comparison_results.keys()),
- "WER": [
- results["wer"] for results in self.model_comparison_results.values()
- ],
- "CER": [
- results["cer"] for results in self.model_comparison_results.values()
- ],
- }
- )
-
- # Sort by WER (best performance first)
- comparison_df = comparison_df.sort_values("WER")
-
- # Display the comparison table
- print("\nComparison Table (sorted by WER):")
- print(comparison_df.to_string(index = False))
-
- # Save the comparison table
- if save_csv:
- comparison_file = "model_comparison_results.csv"
- comparison_df.to_csv(comparison_file, index = False)
- print(f"\nComparison table saved to {comparison_file}")
-
- # Generate a bar chart visualization
- if save_plot:
- self._create_comparison_plot(comparison_df)
-
- return comparison_df
-
- def _create_comparison_plot(self, comparison_df: pd.DataFrame):
- """Create and save comparison plot."""
- plt.figure(figsize = (12, 6))
-
- # Plot WER
- plt.subplot(1, 2, 1)
- plt.bar(comparison_df["Model"], comparison_df["WER"], color = "skyblue")
- plt.title("Word Error Rate Comparison")
- plt.ylabel("WER (lower is better)")
- plt.ylim(bottom = 0)
- plt.xticks(rotation = 45, ha = "right")
-
- # Plot CER
- plt.subplot(1, 2, 2)
- plt.bar(comparison_df["Model"], comparison_df["CER"], color = "lightgreen")
- plt.title("Character Error Rate Comparison")
- plt.ylabel("CER (lower is better)")
- plt.ylim(bottom = 0)
- plt.xticks(rotation = 45, ha = "right")
-
- plt.tight_layout()
- plt.savefig("ocr_model_comparison.png")
- plt.show()
-
- print(f"\nVisualization saved to ocr_model_comparison.png")
-
- def get_comparison_results(self) -> Dict[str, Dict[str, float]]:
- """Get the current comparison results."""
- return self.model_comparison_results.copy()
-
- def clear_comparison_results(self):
- """Clear all comparison results."""
- self.model_comparison_results.clear()
-
-
-def evaluate_ocr_model(
- model, processor, dataset, output_dir = "ocr_evaluation_results", **kwargs
-):
- """
- Convenience function that maintains backward compatibility with the original function.
- """
- evaluator = OCRModelEvaluator()
- return evaluator.evaluate_model(model, processor, dataset, output_dir, **kwargs)
-
-
-def create_evaluator():
- """Create a new OCR evaluator instance."""
- return OCRModelEvaluator()
diff --git a/tests/utils/os_utils.py b/tests/utils/os_utils.py
deleted file mode 100644
index 448f13b8a0..0000000000
--- a/tests/utils/os_utils.py
+++ /dev/null
@@ -1,128 +0,0 @@
-import subprocess
-import sys
-import os
-import shutil
-import importlib
-
-
-def detect_package_manager():
- """Detect the available package manager"""
- package_managers = {
- "apt": "/usr/bin/apt",
- "yum": "/usr/bin/yum",
- "dnf": "/usr/bin/dnf",
- "pacman": "/usr/bin/pacman",
- "zypper": "/usr/bin/zypper",
- }
-
- for pm, path in package_managers.items():
- if os.path.exists(path):
- return pm
- return None
-
-
-def check_package_installed(package_name, package_manager = None):
- """Check if a package is installed using the system package manager"""
-
- if package_manager is None:
- package_manager = detect_package_manager()
-
- if package_manager is None:
- print("Warning: Could not detect package manager")
- return None
-
- try:
- if package_manager == "apt":
- # Check with dpkg
- result = subprocess.run(
- ["dpkg", "-l", package_name], capture_output = True, text = True
- )
- return result.returncode == 0
-
- elif package_manager in ["yum", "dnf"]:
- # Check with rpm
- result = subprocess.run(
- ["rpm", "-q", package_name], capture_output = True, text = True
- )
- return result.returncode == 0
-
- elif package_manager == "pacman":
- result = subprocess.run(
- ["pacman", "-Q", package_name], capture_output = True, text = True
- )
- return result.returncode == 0
-
- elif package_manager == "zypper":
- result = subprocess.run(
- ["zypper", "se", "-i", package_name], capture_output = True, text = True
- )
- return package_name in result.stdout
-
- except Exception as e:
- print(f"Error checking package: {e}")
- return None
-
-
-def require_package(package_name, executable_name = None):
- """Require a package to be installed, exit if not found"""
-
- # First check if executable is in PATH (most reliable)
- if executable_name:
- if shutil.which(executable_name):
- print(f"✓ {executable_name} is available")
- return
-
- # Then check with package manager
- pm = detect_package_manager()
- is_installed = check_package_installed(package_name, pm)
-
- if is_installed:
- print(f"✓ Package {package_name} is installed")
- return
-
- # Package not found - show installation instructions
- print(f"❌ Error: {package_name} is not installed")
- print(f"\nPlease install {package_name} using your system package manager:")
-
- install_commands = {
- "apt": f"sudo apt update && sudo apt install {package_name}",
- "yum": f"sudo yum install {package_name}",
- "dnf": f"sudo dnf install {package_name}",
- "pacman": f"sudo pacman -S {package_name}",
- "zypper": f"sudo zypper install {package_name}",
- }
-
- if pm and pm in install_commands:
- print(f" {install_commands[pm]}")
- else:
- for pm_name, cmd in install_commands.items():
- print(f" {pm_name}: {cmd}")
-
- print(f"\nAlternatively, install with conda:")
- print(f" conda install -c conda-forge {package_name}")
-
- print(f"\nPlease install the required package and run the script again.")
- sys.exit(1)
-
-
-# Usage
-# require_package("ffmpeg", "ffmpeg")
-
-
-def require_python_package(package_name, import_name = None, pip_name = None):
- """Require a Python package to be installed, exit if not found"""
- if import_name is None:
- import_name = package_name
- if pip_name is None:
- pip_name = package_name
-
- if importlib.util.find_spec(import_name) is None:
- print(f"❌ Error: Python package '{package_name}' is not installed")
- print(f"\nPlease install {package_name} using pip:")
- print(f" pip install {pip_name}")
- print(f" # or with conda:")
- print(f" conda install {pip_name}")
- print(f"\nAfter installation, run this script again.")
- sys.exit(1)
- else:
- print(f"✓ Python package '{package_name}' is installed")
diff --git a/tests/utils/perplexity_eval.md b/tests/utils/perplexity_eval.md
deleted file mode 100644
index df86b558ed..0000000000
--- a/tests/utils/perplexity_eval.md
+++ /dev/null
@@ -1,20 +0,0 @@
-# Language Model Perplexity Evaluator
-
-A Python module for evaluating language models using perplexity metrics with sliding window approach for long sequences. This evaluator provides efficient computation of perplexity scores across datasets with model comparison capabilities.
-
-## Basic Usage
-
-```python
-from perplexity_evaluator import ppl_model, add_to_comparison, print_model_comparison
-
-# Simple perplexity evaluation
-dataset = {"text": ["Your text samples here...", "Another text sample..."]}
-perplexity = ppl_model(model, tokenizer, dataset)
-
-print(f"Model Perplexity: {perplexity:.4f}")
-
-# Add to comparison tracker
-add_to_comparison("My Model", perplexity)
-print_model_comparison()
-```
-
diff --git a/tests/utils/perplexity_eval.py b/tests/utils/perplexity_eval.py
deleted file mode 100644
index cf625fc74e..0000000000
--- a/tests/utils/perplexity_eval.py
+++ /dev/null
@@ -1,81 +0,0 @@
-from tqdm import tqdm
-import torch
-import pandas as pd
-
-model_comparison_results = {}
-# return the perplexity of the model on the dataset
-# The perplexity is computed on each example, individually, with a sliding window for examples longer than 512 tokens.
-
-
-def ppl_model(model, tokenizer, dataset):
- nlls = []
- max_length = 2048
- stride = 512
- for s in tqdm(range(len(dataset["text"]))):
- encodings = tokenizer(dataset["text"][s], return_tensors = "pt")
- seq_len = encodings.input_ids.size(1)
- prev_end_loc = 0
- for begin_loc in range(0, seq_len, stride):
- end_loc = min(begin_loc + max_length, seq_len)
- trg_len = end_loc - prev_end_loc
- input_ids = encodings.input_ids[:, begin_loc:end_loc].to("cuda")
- target_ids = input_ids.clone()
- target_ids[:, :-trg_len] = -100
- # Create attention mask based on pad token id
- pad_token_id = (
- tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
- )
- attention_mask = (input_ids != pad_token_id).long()
- with torch.no_grad():
- outputs = model(
- input_ids, labels = target_ids, attention_mask = attention_mask
- )
- neg_log_likelihood = outputs.loss
- nlls.append(neg_log_likelihood)
- prev_end_loc = end_loc
- if end_loc == seq_len:
- break
- ppl = torch.exp(torch.stack(nlls).mean())
- return ppl
-
-
-# --------------------------------------------------------------------
-
-
-## ----------- Reporting helper function ----------- ##
-
-
-# Create a simple function to add results to the comparison
-def add_to_comparison(model_name, ppl):
- """Add model results to the comparison tracker"""
- model_comparison_results[model_name] = {"ppl": ppl}
- # return model_comparison_results
-
-
-# Create a function to print the comparison report whenever needed
-def print_model_comparison():
- """Print a comparison of all models evaluated so far"""
- if not model_comparison_results:
- print("No model results available for comparison")
- return
-
- print("\n==== MODEL COMPARISON REPORT ====")
-
- # Create a comparison dataframe
- comparison_df = pd.DataFrame(
- {
- "Model": list(model_comparison_results.keys()),
- # "Perplexity": [results["ppl"] for results in model_comparison_results.values()],
- "Perplexity": [
- # Convert tensors to CPU and then to float if needed
- results["ppl"].cpu().item()
- if torch.is_tensor(results["ppl"])
- else results["ppl"]
- for results in model_comparison_results.values()
- ],
- }
- )
-
- # Display the comparison table
- print("\nComparison Table:")
- print(comparison_df.to_string(index = False))
diff --git a/tests/utils/test_attention_masks.py b/tests/utils/test_attention_masks.py
deleted file mode 100644
index dbd5419617..0000000000
--- a/tests/utils/test_attention_masks.py
+++ /dev/null
@@ -1,272 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with this program. If not, see .
-
-"""Unit tests for packed-attention mask helpers with sliding-window logic."""
-
-import math
-
-import torch
-
-from unsloth.utils import attention_dispatch
-from unsloth.utils import packing as packing_utils
-
-
-def _make_seq_info(lengths):
- lengths = torch.tensor(lengths, dtype = torch.int32)
- cu = torch.cat(
- [
- torch.zeros(1, dtype = torch.int32),
- torch.cumsum(lengths, dim = 0, dtype = torch.int32),
- ]
- )
- max_len = int(lengths.max().item())
- return lengths, cu, max_len
-
-
-def test_sdpa_packed_attention_mask_sliding_window():
- seq_info = _make_seq_info([5, 3])
- mask = packing_utils.build_sdpa_packed_attention_mask(
- seq_info,
- dtype = torch.float32,
- device = torch.device("cpu"),
- sliding_window = 3,
- )
-
- assert mask.shape == (1, 1, 8, 8)
-
- block_first = mask[0, 0, :5, :5]
- upper = torch.triu(torch.ones_like(block_first), diagonal = 1).bool()
- assert torch.all(block_first[upper] == float("-inf"))
- assert block_first[3, 0].item() == float("-inf")
- assert block_first[4, 1].item() == float("-inf")
- assert block_first[4, 2].item() > -math.inf
- assert mask[0, 0, 0, 6].item() == float("-inf")
-
-
-def test_xformers_block_mask_sliding_window(monkeypatch):
- class _FakeMask:
- def __init__(self, lengths, window = None):
- self.lengths = lengths
- self.window = window
-
- @classmethod
- def from_seqlens(cls, lengths):
- return cls(tuple(lengths))
-
- def make_local_attention(self, window_size):
- return _FakeMask(self.lengths, window = window_size)
-
- monkeypatch.setattr(packing_utils, "_XFormersBlockMask", _FakeMask, raising = False)
-
- seq_info = _make_seq_info([4, 4])
- mask = packing_utils.build_xformers_block_causal_mask(
- seq_info,
- sliding_window = 2,
- )
-
- assert isinstance(mask, _FakeMask)
- assert mask.window == 2
-
-
-def test_run_attention_sdpa_passes_sliding_window(monkeypatch):
- seq_info = _make_seq_info([3, 2])
- sliding_window = 2
-
- original_builder = attention_dispatch.build_sdpa_packed_attention_mask
- captured = {}
-
- def _capture_builder(seq_info_arg, *, dtype, device, sliding_window = None):
- captured["window"] = sliding_window
- return original_builder(
- seq_info_arg,
- dtype = dtype,
- device = device,
- sliding_window = sliding_window,
- )
-
- monkeypatch.setattr(
- attention_dispatch,
- "build_sdpa_packed_attention_mask",
- _capture_builder,
- )
-
- def _fake_sdpa(Q, K, V, **kwargs):
- captured["mask"] = kwargs.get("attn_mask")
- return torch.zeros_like(Q)
-
- monkeypatch.setattr(attention_dispatch, "scaled_dot_product_attention", _fake_sdpa)
-
- config = attention_dispatch.AttentionConfig(
- backend = attention_dispatch.SDPA,
- n_kv_heads = 1,
- n_groups = 1,
- )
-
- context = attention_dispatch.AttentionContext(
- bsz = 1,
- q_len = 5,
- kv_seq_len = 5,
- n_heads = 1,
- head_dim = 1,
- requires_grad = False,
- seq_info = seq_info,
- attention_mask = None,
- causal_mask = None,
- sliding_window = sliding_window,
- )
-
- Q = torch.zeros(1, 1, 5, 1)
- K = torch.zeros_like(Q)
- V = torch.zeros_like(Q)
-
- attention_dispatch.run_attention(
- config = config,
- context = context,
- Q = Q,
- K = K,
- V = V,
- )
-
- assert captured["window"] == sliding_window
- mask = captured["mask"]
- assert mask is not None and mask.shape == (1, 1, 5, 5)
- assert mask[0, 0, 4, 1].item() == float("-inf")
-
-
-def test_run_attention_xformers_passes_sliding_window(monkeypatch):
- seq_info = _make_seq_info([4])
- sliding_window = 3
-
- class _FakeBias:
- pass
-
- captured = {}
-
- def _fake_builder(seq_info_arg, *, sliding_window = None, base_mask = None):
- captured["window"] = sliding_window
- captured["base"] = base_mask
- return _FakeBias()
-
- def _fake_attention(Q, K, V, attn_bias = None, **_):
- captured["bias"] = attn_bias
- return torch.zeros_like(Q)
-
- monkeypatch.setattr(
- attention_dispatch, "build_xformers_block_causal_mask", _fake_builder
- )
- monkeypatch.setattr(
- attention_dispatch, "xformers_attention", _fake_attention, raising = False
- )
- monkeypatch.setattr(
- attention_dispatch, "XFORMERS_BLOCK_DIAG_CLS", _FakeBias, raising = False
- )
-
- config = attention_dispatch.AttentionConfig(
- backend = attention_dispatch.XFORMERS,
- n_kv_heads = 1,
- n_groups = 1,
- )
-
- context = attention_dispatch.AttentionContext(
- bsz = 1,
- q_len = 4,
- kv_seq_len = 4,
- n_heads = 1,
- head_dim = 1,
- requires_grad = False,
- seq_info = seq_info,
- attention_mask = None,
- causal_mask = None,
- sliding_window = sliding_window,
- )
-
- Q = torch.zeros(1, 1, 4, 1)
- K = torch.zeros_like(Q)
- V = torch.zeros_like(Q)
-
- attention_dispatch.run_attention(
- config = config,
- context = context,
- Q = Q,
- K = K,
- V = V,
- )
-
- assert captured["window"] == sliding_window
- assert isinstance(captured["bias"], _FakeBias)
-
-
-def test_run_attention_flash_varlen_receives_window_and_softcap(monkeypatch):
- seq_info = _make_seq_info([4])
- sliding_window = 3
- softcap = 0.5
- window_tuple = (sliding_window, sliding_window)
-
- captured = {}
-
- def _fake_flash_varlen(Q, K, V, cu_q, cu_k, max_q, max_k, **kwargs):
- captured["kwargs"] = kwargs
- return torch.zeros_like(Q)
-
- monkeypatch.setattr(
- attention_dispatch,
- "flash_attn_varlen_func",
- _fake_flash_varlen,
- )
- monkeypatch.setattr(attention_dispatch, "HAS_FLASH_ATTENTION", True)
-
- config = attention_dispatch.AttentionConfig(
- backend = attention_dispatch.FLASH_VARLEN,
- n_kv_heads = 1,
- n_groups = 1,
- flash_varlen_kwargs = {
- "dropout_p": 0.0,
- "softmax_scale": 1.0,
- "causal": True,
- "softcap": softcap,
- "window_size": window_tuple,
- },
- )
-
- context = attention_dispatch.AttentionContext(
- bsz = 1,
- q_len = 4,
- kv_seq_len = 4,
- n_heads = 1,
- head_dim = 2,
- requires_grad = False,
- seq_info = seq_info,
- attention_mask = None,
- causal_mask = None,
- sliding_window = sliding_window,
- )
-
- Q = torch.zeros(1, 1, 4, 2)
- K = torch.zeros_like(Q)
- V = torch.zeros_like(Q)
-
- attention_dispatch.run_attention(
- config = config,
- context = context,
- Q = Q,
- K = K,
- V = V,
- )
-
- assert captured["kwargs"]["softcap"] == softcap
- assert captured["kwargs"]["window_size"] == window_tuple
-
-
-"""Unit tests for packed-attention mask helpers with sliding-window logic."""
diff --git a/tests/utils/test_packing.py b/tests/utils/test_packing.py
deleted file mode 100644
index 098f6a3667..0000000000
--- a/tests/utils/test_packing.py
+++ /dev/null
@@ -1,407 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with this program. If not, see .
-
-from unsloth import FastLanguageModel
-from unsloth.utils import attention_dispatch as attention_dispatch_utils
-from unsloth.utils.packing import (
- configure_padding_free,
- configure_sample_packing,
- enable_padding_free_metadata,
- enable_sample_packing,
- mask_packed_sequence_boundaries,
-)
-
-from contextlib import ExitStack
-from types import SimpleNamespace
-from unittest.mock import patch
-
-import pytest
-import torch
-from datasets import Dataset
-from trl import SFTConfig, SFTTrainer
-from trl.trainer.sft_trainer import DataCollatorForLanguageModeling
-
-
-def _build_packed_training_setup(tmp_path, device):
- dtype = None
- if device.type == "cuda":
- if torch.cuda.is_bf16_supported():
- dtype = torch.bfloat16
- else:
- dtype = torch.float16
-
- try:
- model, tokenizer = FastLanguageModel.from_pretrained(
- model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM",
- max_seq_length = 64,
- load_in_4bit = False,
- dtype = dtype,
- )
- except OSError as exc: # pragma: no cover - offline CI
- pytest.skip(f"Requires access to tiny llama checkpoint: {exc}")
-
- model.to(device)
-
- dataset = Dataset.from_dict(
- {
- "text": [
- "Hello world!",
- "Short sample.",
- "This is a slightly longer packed example to test batching.",
- "Another response to include in the batch.",
- ]
- }
- )
-
- training_args = SFTConfig(
- per_device_train_batch_size = 1,
- per_device_eval_batch_size = 1,
- gradient_accumulation_steps = 1,
- dataset_text_field = "text",
- max_length = 64,
- logging_steps = 1,
- max_steps = 1,
- fp16 = device.type == "cuda" and not torch.cuda.is_bf16_supported(),
- bf16 = device.type == "cuda" and torch.cuda.is_bf16_supported(),
- dataset_num_proc = 1,
- output_dir = str(tmp_path),
- packing = True,
- )
-
- trainer = SFTTrainer(
- model = model,
- processing_class = tokenizer,
- train_dataset = dataset,
- args = training_args,
- )
-
- enable_sample_packing(model, trainer)
-
- dataloader = trainer.get_train_dataloader()
- batch = next(iter(dataloader))
-
- model_device = next(model.parameters()).device
-
- for key, value in list(batch.items()):
- if torch.is_tensor(value):
- batch[key] = value.to(model_device)
-
- from unsloth.models import llama as llama_mod
-
- return model, batch, trainer, llama_mod
-
-
-def _trim_batch_to_total_tokens(data, total_tokens):
- def _trim_tensor(t: torch.Tensor):
- if t.ndim >= 2 and t.size(1) > total_tokens:
- return t[:, :total_tokens].contiguous()
- return t
-
- trimmed = {}
- for key, value in data.items():
- if torch.is_tensor(value):
- trimmed[key] = _trim_tensor(value)
- else:
- trimmed[key] = value
- return trimmed
-
-
-def test_mask_packed_sequence_boundaries_marks_single_row():
- shift_labels = torch.arange(6, dtype = torch.long).view(1, 6)
- changed = mask_packed_sequence_boundaries(
- shift_labels,
- torch.tensor([2, 1, 3], dtype = torch.int32),
- )
- assert changed is True
- flat = shift_labels.view(-1)
- assert flat[1].item() == -100
- assert flat[2].item() == -100
- assert flat[5].item() == -100
- assert flat[0].item() != -100
-
-
-def test_mask_packed_sequence_boundaries_across_multiple_rows():
- shift_labels = torch.arange(10, dtype = torch.long).view(2, 5)
- lengths = torch.tensor([3, 2, 4, 1], dtype = torch.int32)
- changed = mask_packed_sequence_boundaries(shift_labels, lengths)
- assert changed is True
- flat = shift_labels.view(-1)
- for idx in (2, 4, 8, 9):
- assert flat[idx].item() == -100
- assert torch.any(flat != -100)
-
-
-def test_configure_sample_packing():
- config = SimpleNamespace()
- configure_sample_packing(config)
-
- assert config.packing is True
- assert config.padding_free is True
- assert config.remove_unused_columns is False
-
-
-def test_configure_padding_free():
- config = SimpleNamespace(remove_unused_columns = True)
- configure_padding_free(config)
-
- assert config.padding_free is True
- assert config.remove_unused_columns is False
-
-
-class _DummyChild(torch.nn.Module):
- def __init__(self):
- super().__init__()
- self.max_seq_length = 8
-
-
-class _DummyModel(torch.nn.Module):
- def __init__(self):
- super().__init__()
- self.max_seq_length = 16
- self.child = _DummyChild()
- self.config = SimpleNamespace(_attn_implementation = "sdpa")
- self.generation_config = SimpleNamespace(attn_implementation = "sdpa")
-
-
-class _DummyTrainer:
- def __init__(self):
- self.args = SimpleNamespace(remove_unused_columns = True)
- collator_args = {
- "pad_token_id": 0,
- "completion_only_loss": False,
- "return_tensors": "pt",
- }
- optional_flags = [
- {"padding_free": True, "return_position_ids": False},
- {"padding_free": True},
- {},
- ]
- for extra in optional_flags:
- try:
- self.data_collator = DataCollatorForLanguageModeling(
- **collator_args, **extra
- )
- break
- except TypeError:
- continue
- # Ensure attributes exist even if the constructor did not accept them
- if not hasattr(self.data_collator, "padding_free"):
- self.data_collator.padding_free = True
- if not hasattr(self.data_collator, "return_position_ids"):
- self.data_collator.return_position_ids = False
-
-
-class _PaddingFreeCollator:
- def __init__(self):
- self.padding_free = True
- self.return_position_ids = False
- self.calls = 0
-
- def torch_call(self, examples):
- self.calls += 1
- return {
- "input_ids": torch.tensor([[0]], dtype = torch.long),
- "examples_seen": self.calls,
- }
-
-
-def test_enable_sample_packing():
- model = _DummyModel()
- trainer = _DummyTrainer()
-
- enable_sample_packing(model, trainer)
-
- # model hierarchy should now allow packed overlength inputs
- assert getattr(model, "_unsloth_allow_packed_overlength") is True
- assert getattr(model.child, "_unsloth_allow_packed_overlength") is True
-
- collator = trainer.data_collator
- assert collator.return_position_ids is True
- assert getattr(collator, "_unsloth_packing_wrapped") is True
-
- examples = [
- {
- "input_ids": [0, 1, 2],
- "labels": [0, 1, 2],
- "seq_lengths": [2, 1],
- },
- {
- "input_ids": [3, 4, 5],
- "labels": [3, 4, 5],
- "seq_lengths": [3],
- },
- ]
- batch = collator.torch_call(examples)
-
- # packed lengths are aggregated into a single tensor
- assert "packed_seq_lengths" in batch
- assert torch.equal(
- batch["packed_seq_lengths"],
- torch.tensor([2, 1, 3], dtype = torch.int32),
- )
-
- assert batch["input_ids"].shape == (1, 6)
- expected_positions = torch.tensor([0, 1, 0, 0, 1, 2], dtype = torch.long)
- assert torch.equal(batch["position_ids"].view(-1)[:6], expected_positions)
-
-
-def test_enable_sample_packing_trl_collator(tmp_path):
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
- model, _, trainer, _ = _build_packed_training_setup(tmp_path, device)
-
- enable_sample_packing(model, trainer)
-
- examples = [
- {
- "input_ids": [0, 1, 2],
- "labels": [0, 1, 2],
- "seq_lengths": [2, 1],
- },
- {
- "input_ids": [3, 4, 5],
- "labels": [3, 4, 5],
- "seq_lengths": [3],
- },
- ]
-
- batch = trainer.data_collator.torch_call(examples)
-
- assert batch["input_ids"].shape == (1, 6)
- assert torch.equal(
- batch["packed_seq_lengths"],
- torch.tensor([2, 1, 3], dtype = torch.int32),
- )
-
- expected_positions = torch.tensor([0, 1, 0, 0, 1, 2], dtype = torch.long)
- assert torch.equal(batch["position_ids"].view(-1)[:6], expected_positions)
-
- if hasattr(trainer, "accelerator"):
- trainer.accelerator.free_memory()
-
-
-def test_enable_padding_free_metadata():
- model = _DummyModel()
- trainer = SimpleNamespace(
- args = SimpleNamespace(remove_unused_columns = True),
- data_collator = _PaddingFreeCollator(),
- )
-
- enable_padding_free_metadata(model, trainer)
-
- assert getattr(model, "_unsloth_allow_packed_overlength") is True
- assert getattr(model.child, "_unsloth_allow_packed_overlength") is True
-
- collator = trainer.data_collator
- assert collator.return_position_ids is True
- assert getattr(collator, "_unsloth_padding_free_lengths_wrapped") is True
-
- examples = [
- {"input_ids": [0, 1, 2]},
- {"input_ids": [3, 4]},
- ]
- batch = collator.torch_call(examples)
- assert torch.equal(
- batch["packed_seq_lengths"],
- torch.tensor([3, 2], dtype = torch.int32),
- )
- assert trainer.args.remove_unused_columns is False
-
-
-def test_packing_sdpa(tmp_path):
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
- model, batch, trainer, llama_mod = _build_packed_training_setup(tmp_path, device)
-
- assert "packed_seq_lengths" in batch
- assert "attention_mask" not in batch
- assert batch["packed_seq_lengths"].dtype == torch.int32
-
- total_tokens = batch["input_ids"].size(-1)
- assert int(batch["packed_seq_lengths"].sum().item()) == total_tokens
-
- packed_tokens = int(batch["packed_seq_lengths"].sum().item())
- assert "position_ids" in batch
- flat_positions = batch["position_ids"].reshape(-1)[:packed_tokens]
- expected_positions = torch.cat(
- [
- torch.arange(length, dtype = torch.long)
- for length in batch["packed_seq_lengths"].tolist()
- ]
- )
- assert torch.equal(flat_positions.cpu(), expected_positions)
- inputs = _trim_batch_to_total_tokens(batch, packed_tokens)
-
- seq_info = llama_mod.get_packed_info_from_kwargs(
- {"packed_seq_lengths": batch["packed_seq_lengths"]},
- inputs["input_ids"].device,
- )
- assert seq_info is not None
-
- original_mask = attention_dispatch_utils.build_sdpa_packed_attention_mask
- mask_calls = []
- captured_loss_labels = {}
-
- def _capture_mask(seq_info, dtype, device, *, sliding_window = None):
- mask_calls.append(tuple(seq_info[0].tolist()))
- return original_mask(
- seq_info,
- dtype = dtype,
- device = device,
- sliding_window = sliding_window,
- )
-
- def _capture_loss(*, logits, labels, **loss_kwargs):
- captured_loss_labels["labels"] = labels.detach().to("cpu")
- return torch.zeros((), device = logits.device, dtype = logits.dtype)
-
- with ExitStack() as stack:
- stack.enter_context(
- patch.object(attention_dispatch_utils, "HAS_FLASH_ATTENTION", False)
- )
- stack.enter_context(
- patch.object(attention_dispatch_utils, "HAS_XFORMERS", False)
- )
- stack.enter_context(
- patch.object(
- attention_dispatch_utils,
- "build_sdpa_packed_attention_mask",
- side_effect = _capture_mask,
- )
- )
- stack.enter_context(
- patch.object(
- llama_mod,
- "fast_cross_entropy_loss",
- side_effect = _capture_loss,
- )
- )
- with torch.no_grad():
- outputs = model(**inputs)
-
- assert mask_calls, "SDPA packed mask was not constructed"
- assert outputs.loss is not None
- assert "labels" in captured_loss_labels
- flat_loss_labels = captured_loss_labels["labels"].reshape(-1)
- boundaries = (
- torch.cumsum(
- batch["packed_seq_lengths"].to(device = "cpu", dtype = torch.long), dim = 0
- )
- - 1
- )
- for idx in boundaries.tolist():
- assert flat_loss_labels[idx].item() == -100
- assert torch.any(flat_loss_labels != -100)
-
- if hasattr(trainer, "accelerator"):
- trainer.accelerator.free_memory()
diff --git a/tests/utils/test_qat.py b/tests/utils/test_qat.py
deleted file mode 100644
index 1083712d78..0000000000
--- a/tests/utils/test_qat.py
+++ /dev/null
@@ -1,177 +0,0 @@
-from unsloth import FastLanguageModel
-
-from typing import Dict
-
-import pytest
-import torch
-
-try:
- from torchao.quantization.qat import FakeQuantizedLinear
- from torchao.quantization.qat.fake_quantizer import (
- FakeQuantizerBase,
- Float8FakeQuantizer,
- Int4WeightFakeQuantizer,
- IntxFakeQuantizer,
- )
-except ImportError:
- print(
- "Missing torchao import, please install or upgrade torchao with: pip install 'torchao>=0.15.0'"
- )
-
-
-class _CountingFakeQuantizer(torch.nn.Module):
- """
- Dummy fake quantizer that counts the number of times it has been called.
- """
-
- def __init__(self):
- super().__init__()
- self.count = 0
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- self.count += 1
- return x
-
-
-def _get_model(qat_scheme: str, full_finetuning: bool):
- """
- Return a 2-tuple of (model, tokenizer), where the model has been configured
- to use QAT. If `full_finetuning` is False, return the PEFT (LoRA) model.
- """
- model, tokenizer = FastLanguageModel.from_pretrained(
- model_name = "unsloth/Qwen3-1.7B",
- load_in_4bit = False,
- full_finetuning = full_finetuning,
- qat_scheme = qat_scheme if full_finetuning else None,
- )
- if not full_finetuning:
- model = FastLanguageModel.get_peft_model(
- model,
- qat_scheme = qat_scheme,
- )
- return model, tokenizer
-
-
-def _test_linear_is_fake_quantized(linear: torch.nn.Linear, qat_scheme: str):
- """
- Verify that the given linear contains fake quantizers according to the `qat_scheme`.
- """
- weight_only = False
- if qat_scheme == "fp8-int4":
- act_fq_class = Float8FakeQuantizer
- weight_fq_class = Int4WeightFakeQuantizer
- min_in_features = 128
- elif qat_scheme == "fp8-fp8":
- act_fq_class = Float8FakeQuantizer
- weight_fq_class = Float8FakeQuantizer
- min_in_features = -1
- elif qat_scheme == "int8":
- act_fq_class = None
- weight_fq_class = IntxFakeQuantizer
- min_in_features = 128
- weight_only = True
- else:
- raise ValueError(f"Unknown qat_scheme: {qat_scheme}")
-
- # Check base layer activations and weights
- base_layer = getattr(linear, "base_layer", linear)
- if base_layer.in_features >= min_in_features:
- assert isinstance(base_layer, FakeQuantizedLinear)
- if not weight_only:
- assert isinstance(base_layer.activation_fake_quantizer, act_fq_class)
- assert isinstance(base_layer.weight_fake_quantizer, weight_fq_class)
-
- # Check lora A and B (only for full_finetuning=False)
- if hasattr(linear, "lora_A") and hasattr(linear, "lora_B"):
- lora_A = linear.lora_A.default
- lora_B = linear.lora_B.default
- if lora_A.in_features >= min_in_features:
- assert isinstance(lora_A, FakeQuantizedLinear)
- if not weight_only:
- assert isinstance(lora_A.activation_fake_quantizer, act_fq_class)
- assert isinstance(lora_A.weight_fake_quantizer, weight_fq_class)
- if lora_B.in_features >= min_in_features:
- assert isinstance(lora_B, FakeQuantizedLinear)
- if not weight_only:
- assert isinstance(lora_B.activation_fake_quantizer, act_fq_class)
- assert isinstance(lora_B.weight_fake_quantizer, weight_fq_class)
-
-
-def _test_fake_quantizers_are_called(
- model: torch.nn.Module,
- example_inputs: Dict,
- full_finetuning: bool,
- qat_scheme: str,
-):
- """
- Verify that the fake quantizers are actually called when the model is called.
- """
- weight_only = qat_scheme == "int8"
-
- def _swap_fake_quantizers(model: torch.nn.Module):
- for name, child in model.named_children():
- if isinstance(child, FakeQuantizerBase):
- setattr(model, name, _CountingFakeQuantizer())
-
- def _assert_fake_quantizers_are_called(model: torch.nn.Module):
- for name, child in model.named_children():
- if full_finetuning:
- if isinstance(child, FakeQuantizedLinear):
- if not weight_only:
- assert child.activation_fake_quantizer.count == 1
- assert child.weight_fake_quantizer.count == 1
- else:
- # For LoRA, we only fake quantize the input activations once per block:
- # For self_attn, we only fake quantize the q_proj's input activations
- # For mlp, we only fake quantize the gate_proj's input activations
- if name == "self_attn":
- base_layer = child.q_proj.base_layer
- if not weight_only:
- assert hasattr(base_layer, "activation_fake_quantizer")
- assert base_layer.activation_fake_quantizer.count == 1
- elif name == "mlp":
- base_layer = child.gate_proj.base_layer
- if not weight_only:
- assert hasattr(base_layer, "activation_fake_quantizer")
- assert base_layer.activation_fake_quantizer.count == 1
- elif isinstance(child, FakeQuantizedLinear):
- # Weight fake quantizers should always be called
- assert child.weight_fake_quantizer.count == 1
-
- for k, v in example_inputs.items():
- example_inputs[k] = v.cuda()
- model.apply(_swap_fake_quantizers)
- model(**example_inputs)
- model.apply(_assert_fake_quantizers_are_called)
-
-
-def _test_model_fake_quantize(qat_scheme: str, full_finetuning: bool):
- """
- Test that all linear layers in the model are fake quantized according to the `qat_scheme`.
- """
- model, tokenizer = _get_model(qat_scheme, full_finetuning)
- if full_finetuning:
- model = model.model
- else:
- model = model.base_model.model.model
- for layer in model.layers:
- _test_linear_is_fake_quantized(layer.self_attn.q_proj, qat_scheme)
- _test_linear_is_fake_quantized(layer.self_attn.k_proj, qat_scheme)
- _test_linear_is_fake_quantized(layer.self_attn.v_proj, qat_scheme)
- _test_linear_is_fake_quantized(layer.mlp.gate_proj, qat_scheme)
- _test_linear_is_fake_quantized(layer.mlp.up_proj, qat_scheme)
- _test_linear_is_fake_quantized(layer.mlp.down_proj, qat_scheme)
- inputs = tokenizer("How are you?", return_tensors = "pt")
- _test_fake_quantizers_are_called(model, inputs, full_finetuning, qat_scheme)
-
-
-# TODO: there are bad interactions across tests right now, need to figure out
-# how to disable model caching before re-enabling this test
-@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8", "int8"])
-def _test_full_model_fake_quantize(qat_scheme: str):
- _test_model_fake_quantize(qat_scheme, full_finetuning = True)
-
-
-@pytest.mark.parametrize("qat_scheme", ["fp8-int4", "fp8-fp8", "int8"])
-def test_lora_model_fake_quantize(qat_scheme: str):
- _test_model_fake_quantize(qat_scheme, full_finetuning = False)
diff --git a/tests/utils/test_trunc_normal_patch.py b/tests/utils/test_trunc_normal_patch.py
deleted file mode 100644
index b84a0772d8..0000000000
--- a/tests/utils/test_trunc_normal_patch.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with this program. If not, see .
-
-"""Tests for trunc_normal low-precision patch compatibility."""
-
-import importlib.util
-import inspect
-from pathlib import Path
-
-import pytest
-import torch
-
-
-_MISSING = object()
-
-
-def _load_import_fixes_module():
- repo_root = Path(__file__).resolve().parents[2]
- import_fixes_path = repo_root / "unsloth" / "import_fixes.py"
- spec = importlib.util.spec_from_file_location(
- "unsloth_import_fixes_local", import_fixes_path
- )
- assert spec is not None and spec.loader is not None
- module = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(module)
- return module
-
-
-def _getattr_or_missing(obj, name):
- return getattr(obj, name) if hasattr(obj, name) else _MISSING
-
-
-def _restore_attr(obj, name, value):
- if value is _MISSING:
- if hasattr(obj, name):
- delattr(obj, name)
- return
- setattr(obj, name, value)
-
-
-def test_trunc_normal_patch_accepts_positional_generator():
- import_fixes = _load_import_fixes_module()
- patch_fn = import_fixes.patch_trunc_normal_precision_issue
-
- init_mod = torch.nn.init
- old_fn = init_mod.trunc_normal_
- old_patched = _getattr_or_missing(init_mod, "_unsloth_trunc_normal_patched")
- old_original = _getattr_or_missing(init_mod, "_unsloth_trunc_normal_original")
- try:
- # Normalize to an unpatched baseline before applying the patch.
- if old_original is not _MISSING:
- init_mod.trunc_normal_ = old_original
- if hasattr(init_mod, "_unsloth_trunc_normal_patched"):
- delattr(init_mod, "_unsloth_trunc_normal_patched")
- if hasattr(init_mod, "_unsloth_trunc_normal_original"):
- delattr(init_mod, "_unsloth_trunc_normal_original")
-
- patch_fn()
- sig = inspect.signature(init_mod.trunc_normal_)
- assert "generator" in sig.parameters
- assert sig.parameters["generator"].kind is not inspect.Parameter.KEYWORD_ONLY
-
- tensor = torch.empty(1024, dtype = torch.float32)
- gen = torch.Generator()
- gen.manual_seed(3407)
-
- init_mod.trunc_normal_(tensor, 0.0, 1.0, -2.0, 2.0, gen)
- init_mod.trunc_normal_(tensor, mean = 0.0, std = 1.0, a = -2.0, b = 2.0, generator = gen)
- finally:
- init_mod.trunc_normal_ = old_fn
- _restore_attr(init_mod, "_unsloth_trunc_normal_patched", old_patched)
- _restore_attr(init_mod, "_unsloth_trunc_normal_original", old_original)
-
-
-def test_trunc_normal_patch_rejects_invalid_generator():
- import_fixes = _load_import_fixes_module()
- patch_fn = import_fixes.patch_trunc_normal_precision_issue
-
- init_mod = torch.nn.init
- old_fn = init_mod.trunc_normal_
- old_patched = _getattr_or_missing(init_mod, "_unsloth_trunc_normal_patched")
- old_original = _getattr_or_missing(init_mod, "_unsloth_trunc_normal_original")
- try:
- if old_original is not _MISSING:
- init_mod.trunc_normal_ = old_original
- if hasattr(init_mod, "_unsloth_trunc_normal_patched"):
- delattr(init_mod, "_unsloth_trunc_normal_patched")
- if hasattr(init_mod, "_unsloth_trunc_normal_original"):
- delattr(init_mod, "_unsloth_trunc_normal_original")
-
- patch_fn()
- sig = inspect.signature(init_mod.trunc_normal_)
- if "generator" not in sig.parameters:
- pytest.skip("torch.nn.init.trunc_normal_ lacks a generator parameter")
-
- tensor = torch.empty(16, dtype = torch.float32)
- with pytest.raises(TypeError):
- init_mod.trunc_normal_(tensor, generator = 123)
- finally:
- init_mod.trunc_normal_ = old_fn
- _restore_attr(init_mod, "_unsloth_trunc_normal_patched", old_patched)
- _restore_attr(init_mod, "_unsloth_trunc_normal_original", old_original)
diff --git a/unsloth-cli.py b/unsloth-cli.py
index 612da11eb2..b7613f92df 100644
--- a/unsloth-cli.py
+++ b/unsloth-cli.py
@@ -1,473 +1,229 @@
-#!/usr/bin/env python3
-
-"""
-🦥 Starter Script for Fine-Tuning FastLanguageModel with Unsloth
-
-This script is designed as a starting point for fine-tuning your models using unsloth.
-It includes configurable options for model loading, PEFT parameters, training arguments,
-and model saving/pushing functionalities.
-
-You will likely want to customize this script to suit your specific use case
-and requirements.
-
-Here are a few suggestions for customization:
- - Modify the dataset loading and preprocessing steps to match your data.
- - Customize the model saving and pushing configurations.
-
-Usage: (most of the options have valid default values this is an extended example for demonstration purposes)
- python unsloth-cli.py --model_name "unsloth/llama-3-8b" --max_seq_length 8192 --dtype None --load_in_4bit \
- --r 64 --lora_alpha 32 --lora_dropout 0.1 --bias "none" --use_gradient_checkpointing "unsloth" \
- --random_state 3407 --use_rslora --per_device_train_batch_size 4 --gradient_accumulation_steps 8 \
- --warmup_steps 5 --max_steps 400 --learning_rate 2e-6 --logging_steps 1 --optim "adamw_8bit" \
- --weight_decay 0.005 --lr_scheduler_type "linear" --seed 3407 --output_dir "outputs" \
- --report_to "tensorboard" --save_model --save_path "model" --quantization_method "f16" \
- --push_model --hub_path "hf/model" --hub_token "your_hf_token"
-
-To see a full list of configurable options, use:
- python unsloth-cli.py --help
-
-Happy fine-tuning!
-"""
-
-import argparse
-import os
-
-
-def run(args):
- from unsloth import FastLanguageModel
- from datasets import load_dataset
- from transformers.utils import strtobool
- from trl import SFTTrainer, SFTConfig
- from unsloth import is_bfloat16_supported
- from unsloth.models.loader_utils import prepare_device_map
- import logging
- from unsloth import RawTextDataLoader
-
- logging.getLogger("hf-to-gguf").setLevel(logging.WARNING)
-
- # Load model and tokenizer
- device_map, distributed = prepare_device_map()
- model, tokenizer = FastLanguageModel.from_pretrained(
- model_name = args.model_name,
- max_seq_length = args.max_seq_length,
- dtype = args.dtype,
- load_in_4bit = args.load_in_4bit,
- device_map = device_map,
- )
-
- # Configure PEFT model
- model = FastLanguageModel.get_peft_model(
- model,
- r = args.r,
- target_modules = [
- "q_proj",
- "k_proj",
- "v_proj",
- "o_proj",
- "gate_proj",
- "up_proj",
- "down_proj",
- ],
- lora_alpha = args.lora_alpha,
- lora_dropout = args.lora_dropout,
- bias = args.bias,
- use_gradient_checkpointing = args.use_gradient_checkpointing,
- random_state = args.random_state,
- use_rslora = args.use_rslora,
- loftq_config = args.loftq_config,
- )
-
- alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
-
- ### Instruction:
- {}
-
- ### Input:
- {}
-
- ### Response:
- {}"""
-
- EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
-
- def formatting_prompts_func(examples):
- instructions = examples["instruction"]
- inputs = examples["input"]
- outputs = examples["output"]
- texts = []
- for instruction, input, output in zip(instructions, inputs, outputs):
- text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
- texts.append(text)
- return {"text": texts}
-
- def load_dataset_smart(args):
- from transformers.utils import strtobool
-
- if args.raw_text_file:
- # Use raw text loader
- loader = RawTextDataLoader(tokenizer, args.chunk_size, args.stride)
- dataset = loader.load_from_file(args.raw_text_file)
- elif args.dataset.endswith((".txt", ".md", ".json", ".jsonl")):
- # Auto-detect local raw text files
- loader = RawTextDataLoader(tokenizer)
- dataset = loader.load_from_file(args.dataset)
- else:
- # Check for modelscope usage
- use_modelscope = strtobool(
- os.environ.get("UNSLOTH_USE_MODELSCOPE", "False")
- )
- if use_modelscope:
- from modelscope import MsDataset
-
- dataset = MsDataset.load(args.dataset, split = "train")
- else:
- # Existing HuggingFace dataset logic
- dataset = load_dataset(args.dataset, split = "train")
-
- # Apply formatting for structured datasets
- dataset = dataset.map(formatting_prompts_func, batched = True)
- return dataset
-
- # Load dataset using smart loader
- dataset = load_dataset_smart(args)
- print("Data is formatted and ready!")
-
- # Configure training arguments
- training_args = SFTConfig(
- per_device_train_batch_size = args.per_device_train_batch_size,
- per_device_eval_batch_size = args.per_device_eval_batch_size,
- gradient_accumulation_steps = args.gradient_accumulation_steps,
- warmup_steps = args.warmup_steps,
- max_steps = args.max_steps,
- learning_rate = args.learning_rate,
- fp16 = not is_bfloat16_supported(),
- bf16 = is_bfloat16_supported(),
- logging_steps = args.logging_steps,
- optim = args.optim,
- weight_decay = args.weight_decay,
- lr_scheduler_type = args.lr_scheduler_type,
- seed = args.seed,
- output_dir = args.output_dir,
- report_to = args.report_to,
- max_length = args.max_seq_length,
- dataset_num_proc = 2,
- ddp_find_unused_parameters = False if distributed else None,
- packing = args.packing,
- )
-
- # Initialize trainer
- trainer = SFTTrainer(
- model = model,
- processing_class = tokenizer,
- train_dataset = dataset,
- args = training_args,
- )
-
- trainer.train()
-
- # Save model
- if args.save_model:
- # if args.quantization_method is a list, we will save the model for each quantization method
- if args.save_gguf:
- if isinstance(args.quantization, list):
- for quantization_method in args.quantization:
- print(
- f"Saving model with quantization method: {quantization_method}"
- )
- model.save_pretrained_gguf(
- args.save_path,
- tokenizer,
- quantization_method = quantization_method,
- )
- if args.push_model:
- model.push_to_hub_gguf(
- hub_path = args.hub_path,
- hub_token = args.hub_token,
- quantization_method = quantization_method,
- )
- else:
- print(f"Saving model with quantization method: {args.quantization}")
- model.save_pretrained_gguf(
- args.save_path,
- tokenizer,
- quantization_method = args.quantization,
- )
- if args.push_model:
- model.push_to_hub_gguf(
- hub_path = args.hub_path,
- hub_token = args.hub_token,
- quantization_method = args.quantization,
- )
- else:
- model.save_pretrained_merged(args.save_path, tokenizer, args.save_method)
- if args.push_model:
- model.push_to_hub_merged(args.save_path, tokenizer, args.hub_token)
- else:
- print("Warning: The model is not saved!")
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser(
- description = "🦥 Fine-tune your llm faster using unsloth!"
- )
-
- model_group = parser.add_argument_group("🤖 Model Options")
- model_group.add_argument(
- "--model_name",
- type = str,
- default = "unsloth/llama-3-8b",
- help = "Model name to load",
- )
- model_group.add_argument(
- "--max_seq_length",
- type = int,
- default = 2048,
- help = "Maximum sequence length, default is 2048. We auto support RoPE Scaling internally!",
- )
- model_group.add_argument(
- "--dtype",
- type = str,
- default = None,
- help = "Data type for model (None for auto detection)",
- )
- model_group.add_argument(
- "--load_in_4bit",
- action = "store_true",
- help = "Use 4bit quantization to reduce memory usage",
- )
- model_group.add_argument(
- "--dataset",
- type = str,
- default = "yahma/alpaca-cleaned",
- help = "Huggingface dataset to use for training",
- )
-
- lora_group = parser.add_argument_group(
- "🧠 LoRA Options",
- "These options are used to configure the LoRA model.",
- )
- lora_group.add_argument(
- "--r",
- type = int,
- default = 16,
- help = "Rank for Lora model, default is 16. (common values: 8, 16, 32, 64, 128)",
- )
- lora_group.add_argument(
- "--lora_alpha",
- type = int,
- default = 16,
- help = "LoRA alpha parameter, default is 16. (common values: 8, 16, 32, 64, 128)",
- )
- lora_group.add_argument(
- "--lora_dropout",
- type = float,
- default = 0.0,
- help = "LoRA dropout rate, default is 0.0 which is optimized.",
- )
- lora_group.add_argument(
- "--bias",
- type = str,
- default = "none",
- help = "Bias setting for LoRA",
- )
- lora_group.add_argument(
- "--use_gradient_checkpointing",
- type = str,
- default = "unsloth",
- help = "Use gradient checkpointing",
- )
- lora_group.add_argument(
- "--random_state",
- type = int,
- default = 3407,
- help = "Random state for reproducibility, default is 3407.",
- )
- lora_group.add_argument(
- "--use_rslora",
- action = "store_true",
- help = "Use rank stabilized LoRA",
- )
- lora_group.add_argument(
- "--loftq_config",
- type = str,
- default = None,
- help = "Configuration for LoftQ",
- )
-
- training_group = parser.add_argument_group("🎓 Training Options")
- training_group.add_argument(
- "--per_device_train_batch_size",
- type = int,
- default = 2,
- help = "Batch size per device during training, default is 2.",
- )
- training_group.add_argument(
- "--per_device_eval_batch_size",
- type = int,
- default = 4,
- help = "Batch size per device during evaluation, default is 4.",
- )
- training_group.add_argument(
- "--gradient_accumulation_steps",
- type = int,
- default = 4,
- help = "Number of gradient accumulation steps, default is 4.",
- )
- training_group.add_argument(
- "--warmup_steps",
- type = int,
- default = 5,
- help = "Number of warmup steps, default is 5.",
- )
- training_group.add_argument(
- "--max_steps",
- type = int,
- default = 400,
- help = "Maximum number of training steps.",
- )
- training_group.add_argument(
- "--learning_rate",
- type = float,
- default = 2e-4,
- help = "Learning rate, default is 2e-4.",
- )
- training_group.add_argument(
- "--optim",
- type = str,
- default = "adamw_8bit",
- help = "Optimizer type.",
- )
- training_group.add_argument(
- "--weight_decay",
- type = float,
- default = 0.01,
- help = "Weight decay, default is 0.01.",
- )
- training_group.add_argument(
- "--lr_scheduler_type",
- type = str,
- default = "linear",
- help = "Learning rate scheduler type, default is 'linear'.",
- )
- training_group.add_argument(
- "--seed",
- type = int,
- default = 3407,
- help = "Seed for reproducibility, default is 3407.",
- )
- training_group.add_argument(
- "--packing",
- action = "store_true",
- help = "Enable padding-free sample packing via TRL's bin packer.",
- )
-
- report_group = parser.add_argument_group("📊 Report Options")
- report_group.add_argument(
- "--report_to",
- type = str,
- default = "tensorboard",
- choices = [
- "azure_ml",
- "clearml",
- "codecarbon",
- "comet_ml",
- "dagshub",
- "dvclive",
- "flyte",
- "mlflow",
- "neptune",
- "tensorboard",
- "wandb",
- "all",
- "none",
- ],
- help = (
- "The list of integrations to report the results and logs to. Supported platforms are:\n\t\t "
- "'azure_ml', 'clearml', 'codecarbon', 'comet_ml', 'dagshub', 'dvclive', 'flyte', "
- "'mlflow', 'neptune', 'tensorboard', and 'wandb'. Use 'all' to report to all integrations "
- "installed, 'none' for no integrations."
- ),
- )
- report_group.add_argument(
- "--logging_steps",
- type = int,
- default = 1,
- help = "Logging steps, default is 1",
- )
-
- save_group = parser.add_argument_group("💾 Save Model Options")
- save_group.add_argument(
- "--output_dir",
- type = str,
- default = "outputs",
- help = "Output directory",
- )
- save_group.add_argument(
- "--save_model",
- action = "store_true",
- help = "Save the model after training",
- )
- save_group.add_argument(
- "--save_method",
- type = str,
- default = "merged_16bit",
- choices = ["merged_16bit", "merged_4bit", "lora"],
- help = "Save method for the model, default is 'merged_16bit'",
- )
- save_group.add_argument(
- "--save_gguf",
- action = "store_true",
- help = "Convert the model to GGUF after training",
- )
- save_group.add_argument(
- "--save_path",
- type = str,
- default = "model",
- help = "Path to save the model",
- )
- save_group.add_argument(
- "--quantization",
- type = str,
- default = "q8_0",
- nargs = "+",
- help = (
- "Quantization method for saving the model. common values ('f16', 'q4_k_m', 'q8_0'), "
- "Check our wiki for all quantization methods https://github.com/unslothai/unsloth/wiki#saving-to-gguf"
- ),
- )
-
- push_group = parser.add_argument_group("🚀 Push Model Options")
- push_group.add_argument(
- "--push_model",
- action = "store_true",
- help = "Push the model to Hugging Face hub after training",
- )
- push_group.add_argument(
- "--push_gguf",
- action = "store_true",
- help = "Push the model as GGUF to Hugging Face hub after training",
- )
- push_group.add_argument(
- "--hub_path",
- type = str,
- default = "hf/model",
- help = "Path on Hugging Face hub to push the model",
- )
- push_group.add_argument(
- "--hub_token",
- type = str,
- help = "Token for pushing the model to Hugging Face hub",
- )
-
- parser.add_argument(
- "--raw_text_file", type = str, help = "Path to raw text file for training"
- )
- parser.add_argument(
- "--chunk_size", type = int, default = 2048, help = "Size of text chunks for training"
- )
- parser.add_argument(
- "--stride", type = int, default = 512, help = "Overlap between chunks"
- )
-
- args = parser.parse_args()
- run(args)
+#!/usr/bin/env python3
+
+"""
+🦥 Starter Script for Fine-Tuning FastLanguageModel with Unsloth
+
+This script is designed as a starting point for fine-tuning your models using unsloth.
+It includes configurable options for model loading, PEFT parameters, training arguments,
+and model saving/pushing functionalities.
+
+You will likely want to customize this script to suit your specific use case
+and requirements.
+
+Here are a few suggestions for customization:
+ - Modify the dataset loading and preprocessing steps to match your data.
+ - Customize the model saving and pushing configurations.
+
+Usage: (most of the options have valid default values this is an extended example for demonstration purposes)
+ python unsloth-cli.py --model_name "unsloth/llama-3-8b" --max_seq_length 8192 --dtype None --load_in_4bit \
+ --r 64 --lora_alpha 32 --lora_dropout 0.1 --bias "none" --use_gradient_checkpointing "unsloth" \
+ --random_state 3407 --use_rslora --per_device_train_batch_size 4 --gradient_accumulation_steps 8 \
+ --warmup_steps 5 --max_steps 400 --learning_rate 2e-6 --logging_steps 1 --optim "adamw_8bit" \
+ --weight_decay 0.005 --lr_scheduler_type "linear" --seed 3407 --output_dir "outputs" \
+ --report_to "tensorboard" --save_model --save_path "model" --quantization_method "f16" \
+ --push_model --hub_path "hf/model" --hub_token "your_hf_token"
+
+To see a full list of configurable options, use:
+ python unsloth-cli.py --help
+
+Happy fine-tuning!
+"""
+
+import argparse
+import os
+
+
+def run(args):
+ import torch
+ from unsloth import FastLanguageModel
+ from datasets import load_dataset
+ from transformers.utils import strtobool
+ from trl import SFTTrainer
+ from transformers import TrainingArguments
+ from unsloth import is_bfloat16_supported
+ import logging
+ logging.getLogger('hf-to-gguf').setLevel(logging.WARNING)
+
+ # Load model and tokenizer
+ model, tokenizer = FastLanguageModel.from_pretrained(
+ model_name=args.model_name,
+ max_seq_length=args.max_seq_length,
+ dtype=args.dtype,
+ load_in_4bit=args.load_in_4bit,
+ )
+
+ # Configure PEFT model
+ model = FastLanguageModel.get_peft_model(
+ model,
+ r=args.r,
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+ "gate_proj", "up_proj", "down_proj"],
+ lora_alpha=args.lora_alpha,
+ lora_dropout=args.lora_dropout,
+ bias=args.bias,
+ use_gradient_checkpointing=args.use_gradient_checkpointing,
+ random_state=args.random_state,
+ use_rslora=args.use_rslora,
+ loftq_config=args.loftq_config,
+ )
+
+ alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+
+ ### Instruction:
+ {}
+
+ ### Input:
+ {}
+
+ ### Response:
+ {}"""
+
+ EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
+ def formatting_prompts_func(examples):
+ instructions = examples["instruction"]
+ inputs = examples["input"]
+ outputs = examples["output"]
+ texts = []
+ for instruction, input, output in zip(instructions, inputs, outputs):
+ text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
+ texts.append(text)
+ return {"text": texts}
+
+ use_modelscope = strtobool(os.environ.get('UNSLOTH_USE_MODELSCOPE', 'False'))
+ if use_modelscope:
+ from modelscope import MsDataset
+ dataset = MsDataset.load(args.dataset, split="train")
+ else:
+ # Load and format dataset
+ dataset = load_dataset(args.dataset, split="train")
+ dataset = dataset.map(formatting_prompts_func, batched=True)
+ print("Data is formatted and ready!")
+
+ # Configure training arguments
+ training_args = TrainingArguments(
+ per_device_train_batch_size=args.per_device_train_batch_size,
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ warmup_steps=args.warmup_steps,
+ max_steps=args.max_steps,
+ learning_rate=args.learning_rate,
+ fp16=not is_bfloat16_supported(),
+ bf16=is_bfloat16_supported(),
+ logging_steps=args.logging_steps,
+ optim=args.optim,
+ weight_decay=args.weight_decay,
+ lr_scheduler_type=args.lr_scheduler_type,
+ seed=args.seed,
+ output_dir=args.output_dir,
+ report_to=args.report_to,
+ )
+
+ # Initialize trainer
+ trainer = SFTTrainer(
+ model=model,
+ tokenizer=tokenizer,
+ train_dataset=dataset,
+ dataset_text_field="text",
+ max_seq_length=args.max_seq_length,
+ dataset_num_proc=2,
+ packing=False,
+ args=training_args,
+ )
+
+ # Train model
+ trainer_stats = trainer.train()
+
+ # Save model
+ if args.save_model:
+ # if args.quantization_method is a list, we will save the model for each quantization method
+ if args.save_gguf:
+ if isinstance(args.quantization, list):
+ for quantization_method in args.quantization:
+ print(f"Saving model with quantization method: {quantization_method}")
+ model.save_pretrained_gguf(
+ args.save_path,
+ tokenizer,
+ quantization_method=quantization_method,
+ )
+ if args.push_model:
+ model.push_to_hub_gguf(
+ hub_path=args.hub_path,
+ hub_token=args.hub_token,
+ quantization_method=quantization_method,
+ )
+ else:
+ print(f"Saving model with quantization method: {args.quantization}")
+ model.save_pretrained_gguf(args.save_path, tokenizer, quantization_method=args.quantization)
+ if args.push_model:
+ model.push_to_hub_gguf(
+ hub_path=args.hub_path,
+ hub_token=args.hub_token,
+ quantization_method=quantization_method,
+ )
+ else:
+ model.save_pretrained_merged(args.save_path, tokenizer, args.save_method)
+ if args.push_model:
+ model.push_to_hub_merged(args.save_path, tokenizer, args.hub_token)
+ else:
+ print("Warning: The model is not saved!")
+
+
+if __name__ == "__main__":
+
+ # Define argument parser
+ parser = argparse.ArgumentParser(description="🦥 Fine-tune your llm faster using unsloth!")
+
+ model_group = parser.add_argument_group("🤖 Model Options")
+ model_group.add_argument('--model_name', type=str, default="unsloth/llama-3-8b", help="Model name to load")
+ model_group.add_argument('--max_seq_length', type=int, default=2048, help="Maximum sequence length, default is 2048. We auto support RoPE Scaling internally!")
+ model_group.add_argument('--dtype', type=str, default=None, help="Data type for model (None for auto detection)")
+ model_group.add_argument('--load_in_4bit', action='store_true', help="Use 4bit quantization to reduce memory usage")
+ model_group.add_argument('--dataset', type=str, default="yahma/alpaca-cleaned", help="Huggingface dataset to use for training")
+
+ lora_group = parser.add_argument_group("🧠 LoRA Options", "These options are used to configure the LoRA model.")
+ lora_group.add_argument('--r', type=int, default=16, help="Rank for Lora model, default is 16. (common values: 8, 16, 32, 64, 128)")
+ lora_group.add_argument('--lora_alpha', type=int, default=16, help="LoRA alpha parameter, default is 16. (common values: 8, 16, 32, 64, 128)")
+ lora_group.add_argument('--lora_dropout', type=float, default=0, help="LoRA dropout rate, default is 0.0 which is optimized.")
+ lora_group.add_argument('--bias', type=str, default="none", help="Bias setting for LoRA")
+ lora_group.add_argument('--use_gradient_checkpointing', type=str, default="unsloth", help="Use gradient checkpointing")
+ lora_group.add_argument('--random_state', type=int, default=3407, help="Random state for reproducibility, default is 3407.")
+ lora_group.add_argument('--use_rslora', action='store_true', help="Use rank stabilized LoRA")
+ lora_group.add_argument('--loftq_config', type=str, default=None, help="Configuration for LoftQ")
+
+
+ training_group = parser.add_argument_group("🎓 Training Options")
+ training_group.add_argument('--per_device_train_batch_size', type=int, default=2, help="Batch size per device during training, default is 2.")
+ training_group.add_argument('--gradient_accumulation_steps', type=int, default=4, help="Number of gradient accumulation steps, default is 4.")
+ training_group.add_argument('--warmup_steps', type=int, default=5, help="Number of warmup steps, default is 5.")
+ training_group.add_argument('--max_steps', type=int, default=400, help="Maximum number of training steps.")
+ training_group.add_argument('--learning_rate', type=float, default=2e-4, help="Learning rate, default is 2e-4.")
+ training_group.add_argument('--optim', type=str, default="adamw_8bit", help="Optimizer type.")
+ training_group.add_argument('--weight_decay', type=float, default=0.01, help="Weight decay, default is 0.01.")
+ training_group.add_argument('--lr_scheduler_type', type=str, default="linear", help="Learning rate scheduler type, default is 'linear'.")
+ training_group.add_argument('--seed', type=int, default=3407, help="Seed for reproducibility, default is 3407.")
+
+
+ # Report/Logging arguments
+ report_group = parser.add_argument_group("📊 Report Options")
+ report_group.add_argument('--report_to', type=str, default="tensorboard",
+ choices=["azure_ml", "clearml", "codecarbon", "comet_ml", "dagshub", "dvclive", "flyte", "mlflow", "neptune", "tensorboard", "wandb", "all", "none"],
+ help="The list of integrations to report the results and logs to. Supported platforms are: \n\t\t 'azure_ml', 'clearml', 'codecarbon', 'comet_ml', 'dagshub', 'dvclive', 'flyte', 'mlflow', 'neptune', 'tensorboard', and 'wandb'. Use 'all' to report to all integrations installed, 'none' for no integrations.")
+ report_group.add_argument('--logging_steps', type=int, default=1, help="Logging steps, default is 1")
+
+ # Saving and pushing arguments
+ save_group = parser.add_argument_group('💾 Save Model Options')
+ save_group.add_argument('--output_dir', type=str, default="outputs", help="Output directory")
+ save_group.add_argument('--save_model', action='store_true', help="Save the model after training")
+ save_group.add_argument('--save_method', type=str, default="merged_16bit", choices=["merged_16bit", "merged_4bit", "lora"], help="Save method for the model, default is 'merged_16bit'")
+ save_group.add_argument('--save_gguf', action='store_true', help="Convert the model to GGUF after training")
+ save_group.add_argument('--save_path', type=str, default="model", help="Path to save the model")
+ save_group.add_argument('--quantization', type=str, default="q8_0", nargs="+",
+ help="Quantization method for saving the model. common values ('f16', 'q4_k_m', 'q8_0'), Check our wiki for all quantization methods https://github.com/unslothai/unsloth/wiki#saving-to-gguf ")
+
+ push_group = parser.add_argument_group('🚀 Push Model Options')
+ push_group.add_argument('--push_model', action='store_true', help="Push the model to Hugging Face hub after training")
+ push_group.add_argument('--push_gguf', action='store_true', help="Push the model as GGUF to Hugging Face hub after training")
+ push_group.add_argument('--hub_path', type=str, default="hf/model", help="Path on Hugging Face hub to push the model")
+ push_group.add_argument('--hub_token', type=str, help="Token for pushing the model to Hugging Face hub")
+
+ args = parser.parse_args()
+ run(args)
diff --git a/unsloth/__init__.py b/unsloth/__init__.py
index dbacd551c4..7ffddde9b0 100644
--- a/unsloth/__init__.py
+++ b/unsloth/__init__.py
@@ -14,61 +14,29 @@
import warnings, importlib, sys
from packaging.version import Version
-import os, re, subprocess, inspect, functools
+import os, re, subprocess, inspect
import numpy as np
-# Log Unsloth is being used
-os.environ["UNSLOTH_IS_PRESENT"] = "1"
-
# Check if modules that need patching are already imported
-critical_modules = ["trl", "transformers", "peft"]
+critical_modules = ['trl', 'transformers', 'peft']
already_imported = [mod for mod in critical_modules if mod in sys.modules]
-# Fix some issues before importing other packages
-from .import_fixes import (
- fix_message_factory_issue,
- check_fbgemm_gpu_version,
- disable_broken_causal_conv1d,
- disable_broken_vllm,
- configure_amdgpu_asic_id_table_path,
- torchvision_compatibility_check,
- fix_diffusers_warnings,
- fix_huggingface_hub,
-)
-
-# Configure libdrm ids table path early so ROCm can resolve AMD GPU names.
-configure_amdgpu_asic_id_table_path()
-disable_broken_causal_conv1d()
-disable_broken_vllm()
-fix_message_factory_issue()
-check_fbgemm_gpu_version()
-torchvision_compatibility_check()
-fix_diffusers_warnings()
-fix_huggingface_hub()
-del configure_amdgpu_asic_id_table_path
-del disable_broken_causal_conv1d
-del disable_broken_vllm
-del fix_message_factory_issue
-del check_fbgemm_gpu_version
-del torchvision_compatibility_check
-del fix_diffusers_warnings
-del fix_huggingface_hub
-
# This check is critical because Unsloth optimizes these libraries by modifying
-# their code at import time. If they're imported first, the original (slower,
+# their code at import time. If they're imported first, the original (slower,
# more memory-intensive) implementations will be used instead of Unsloth's
# optimized versions, potentially causing OOM errors or slower training.
+
if already_imported:
# stacklevel=2 makes warning point to user's import line rather than this library code,
# showing them exactly where to fix the import order in their script
warnings.warn(
- f"WARNING: Unsloth should be imported before [{', '.join(already_imported)}] "
+ f"WARNING: Unsloth should be imported before {', '.join(already_imported)} "
f"to ensure all optimizations are applied. Your code may run slower or encounter "
f"memory issues without these optimizations.\n\n"
f"Please restructure your imports with 'import unsloth' at the top of your file.",
stacklevel = 2,
)
-del already_imported, critical_modules
+pass
# Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so
# enabling it will require much more work, so we have to prioritize. Please understand!
@@ -78,235 +46,175 @@
# Fixes https://github.com/unslothai/unsloth/issues/1266
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+# Reduce VRAM usage by reducing fragmentation
+# And optimize pinning of memory
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = \
+ "expandable_segments:True,"\
+ "roundup_power2_divisions:[32:256,64:128,256:64,>:32]"
+
# [TODO] Check why some GPUs don't work
# "pinned_use_cuda_host_register:True,"\
# "pinned_num_register_threads:8"
+# Hugging Face Hub faster downloads
+if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+pass
-from importlib.metadata import version as importlib_version
-from importlib.metadata import PackageNotFoundError
-
-# Check for unsloth_zoo
-try:
- unsloth_zoo_version = importlib_version("unsloth_zoo")
- if Version(unsloth_zoo_version) < Version("2026.3.2"):
- print(
- "Unsloth: Please update Unsloth and Unsloth-Zoo to the latest version!\n"
- "Do this via `pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zoo`"
- )
- # if os.environ.get("UNSLOTH_DISABLE_AUTO_UPDATES", "0") == "0":
- # try:
- # os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo")
- # except:
- # try:
- # os.system("pip install --upgrade --no-cache-dir --no-deps --user unsloth_zoo")
- # except:
- # raise ImportError("Unsloth: Please update unsloth_zoo via `pip install --upgrade --no-cache-dir --no-deps unsloth_zoo`")
- import unsloth_zoo
-except PackageNotFoundError:
- raise ImportError(
- f"Unsloth: Please install unsloth_zoo via `pip install unsloth_zoo` then retry!"
- )
-except:
- raise
-del PackageNotFoundError, importlib_version
+# Log Unsloth is being used
+os.environ["UNSLOTH_IS_PRESENT"] = "1"
-# Try importing PyTorch and check version
try:
import torch
except ModuleNotFoundError:
raise ImportError(
- "Unsloth: Pytorch is not installed. Go to https://pytorch.org/.\n"
+ "Unsloth: Pytorch is not installed. Go to https://pytorch.org/.\n"\
"We have some installation instructions on our Github page."
)
+except Exception as exception:
+ raise exception
+pass
+
+# We support Pytorch 2
+# Fixes https://github.com/unslothai/unsloth/issues/38
+torch_version = torch.__version__.split(".")
+major_torch, minor_torch = torch_version[0], torch_version[1]
+major_torch, minor_torch = int(major_torch), int(minor_torch)
+if (major_torch < 2):
+ raise ImportError("Unsloth only supports Pytorch 2 for now. Please update your Pytorch to 2.1.\n"\
+ "We have some installation instructions on our Github page.")
+elif (major_torch == 2) and (minor_torch < 2):
+ # Disable expandable_segments
+ del os.environ["PYTORCH_CUDA_ALLOC_CONF"]
+pass
+
+# First check if CUDA is available ie a NVIDIA GPU is seen
+if not torch.cuda.is_available():
+ raise NotImplementedError("Unsloth: No NVIDIA GPU found? Unsloth currently only supports GPUs!")
+
+# Fix Xformers performance issues since 0.0.25
+import importlib.util
+from pathlib import Path
+from importlib.metadata import version as importlib_version
+from packaging.version import Version
+try:
+ xformers_version = importlib_version("xformers")
+ if Version(xformers_version) < Version("0.0.29"):
+ xformers_location = importlib.util.find_spec("xformers").origin
+ xformers_location = os.path.split(xformers_location)[0]
+ cutlass = Path(xformers_location) / "ops" / "fmha" / "cutlass.py"
+
+ if cutlass.exists():
+ with open(cutlass, "r+") as f:
+ text = f.read()
+ # See https://github.com/facebookresearch/xformers/issues/1176#issuecomment-2545829591
+ if "num_splits_key=-1," in text:
+ text = text.replace("num_splits_key=-1,", "num_splits_key=None,")
+ f.seek(0)
+ f.write(text)
+ f.truncate()
+ print("Unsloth: Patching Xformers to fix some performance issues.")
+ pass
+ pass
+ pass
+ pass
except:
- raise
-
-from unsloth_zoo.device_type import (
- is_hip,
- get_device_type,
- DEVICE_TYPE,
- DEVICE_TYPE_TORCH,
- DEVICE_COUNT,
- ALLOW_PREQUANTIZED_MODELS,
-)
-
-# Fix other issues
-from .import_fixes import (
- fix_xformers_performance_issue,
- fix_vllm_aimv2_issue,
- check_vllm_torch_sm100_compatibility,
- fix_vllm_guided_decoding_params,
- fix_vllm_pdl_blackwell,
- fix_triton_compiled_kernel_missing_attrs,
- patch_trunc_normal_precision_issue,
- ignore_logger_messages,
- patch_ipykernel_hf_xet,
- patch_trackio,
- patch_datasets,
- patch_enable_input_require_grads,
- fix_openenv_no_vllm,
- patch_openspiel_env_async,
- fix_executorch,
- patch_vllm_for_notebooks,
- patch_torchcodec_audio_decoder,
- disable_torchcodec_if_broken,
- disable_broken_wandb,
-)
-
-fix_xformers_performance_issue()
-fix_vllm_aimv2_issue()
-# Check vLLM + torch < 2.9.0 + SM100 compatibility BEFORE importing vLLM
-check_vllm_torch_sm100_compatibility()
-fix_vllm_guided_decoding_params()
-fix_vllm_pdl_blackwell()
-fix_triton_compiled_kernel_missing_attrs()
-patch_trunc_normal_precision_issue()
-ignore_logger_messages()
-patch_ipykernel_hf_xet()
-patch_trackio()
-patch_datasets()
-patch_enable_input_require_grads()
-fix_openenv_no_vllm()
-patch_openspiel_env_async()
-fix_executorch()
-patch_vllm_for_notebooks()
-patch_torchcodec_audio_decoder()
-disable_torchcodec_if_broken()
-disable_broken_wandb()
-
-del fix_xformers_performance_issue
-del fix_vllm_aimv2_issue
-del check_vllm_torch_sm100_compatibility
-del fix_vllm_guided_decoding_params
-del fix_vllm_pdl_blackwell
-del fix_triton_compiled_kernel_missing_attrs
-del patch_trunc_normal_precision_issue
-del ignore_logger_messages
-del patch_ipykernel_hf_xet
-del patch_trackio
-del patch_datasets
-del patch_enable_input_require_grads
-del fix_openenv_no_vllm
-del patch_openspiel_env_async
-del fix_executorch
-del patch_vllm_for_notebooks
-del patch_torchcodec_audio_decoder
-del disable_torchcodec_if_broken
-del disable_broken_wandb
+ pass
+pass
# Torch 2.4 has including_emulation
-if DEVICE_TYPE == "cuda":
- major_version, minor_version = torch.cuda.get_device_capability()
- SUPPORTS_BFLOAT16 = major_version >= 8
-
- old_is_bf16_supported = torch.cuda.is_bf16_supported
- if "including_emulation" in str(inspect.signature(old_is_bf16_supported)):
-
- def is_bf16_supported(including_emulation = False):
- return old_is_bf16_supported(including_emulation)
-
- torch.cuda.is_bf16_supported = is_bf16_supported
- else:
-
- def is_bf16_supported():
- return SUPPORTS_BFLOAT16
-
- torch.cuda.is_bf16_supported = is_bf16_supported
- del major_version, minor_version
-elif DEVICE_TYPE == "hip":
- SUPPORTS_BFLOAT16 = torch.cuda.is_bf16_supported()
-elif DEVICE_TYPE == "xpu":
- # torch.xpu.is_bf16_supported() does not have including_emulation
- # set SUPPORTS_BFLOAT16 as torch.xpu.is_bf16_supported()
- SUPPORTS_BFLOAT16 = torch.xpu.is_bf16_supported()
+major_version, minor_version = torch.cuda.get_device_capability()
+SUPPORTS_BFLOAT16 = (major_version >= 8)
+
+old_is_bf16_supported = torch.cuda.is_bf16_supported
+if "including_emulation" in str(inspect.signature(old_is_bf16_supported)):
+ def is_bf16_supported(including_emulation = False):
+ return old_is_bf16_supported(including_emulation)
+ torch.cuda.is_bf16_supported = is_bf16_supported
+else:
+ def is_bf16_supported(): return SUPPORTS_BFLOAT16
+ torch.cuda.is_bf16_supported = is_bf16_supported
+pass
# For Gradio HF Spaces?
# if "SPACE_AUTHOR_NAME" not in os.environ and "SPACE_REPO_NAME" not in os.environ:
import triton
+libcuda_dirs = lambda: None
+if Version(triton.__version__) >= Version("3.0.0"):
+ try: from triton.backends.nvidia.driver import libcuda_dirs
+ except: pass
+else: from triton.common.build import libcuda_dirs
+
+# Try loading bitsandbytes and triton
+import bitsandbytes as bnb
+try:
+ cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
+ libcuda_dirs()
+except:
+ warnings.warn(
+ "Unsloth: Running `ldconfig /usr/lib64-nvidia` to link CUDA."\
+ )
-if DEVICE_TYPE == "cuda":
- libcuda_dirs = lambda: None
- if Version(triton.__version__) >= Version("3.0.0"):
- try:
- from triton.backends.nvidia.driver import libcuda_dirs
- except:
- pass
- else:
- from triton.common.build import libcuda_dirs
+ if os.path.exists("/usr/lib64-nvidia"):
+ os.system("ldconfig /usr/lib64-nvidia")
+ elif os.path.exists("/usr/local"):
+ # Sometimes bitsandbytes cannot be linked properly in Runpod for example
+ possible_cudas = subprocess.check_output(["ls", "-al", "/usr/local"]).decode("utf-8").split("\n")
+ find_cuda = re.compile(r"[\s](cuda\-[\d\.]{2,})$")
+ possible_cudas = [find_cuda.search(x) for x in possible_cudas]
+ possible_cudas = [x.group(1) for x in possible_cudas if x is not None]
+
+ # Try linking cuda folder, or everything in local
+ if len(possible_cudas) == 0:
+ os.system("ldconfig /usr/local/")
+ else:
+ find_number = re.compile(r"([\d\.]{2,})")
+ latest_cuda = np.argsort([float(find_number.search(x).group(1)) for x in possible_cudas])[::-1][0]
+ latest_cuda = possible_cudas[latest_cuda]
+ os.system(f"ldconfig /usr/local/{latest_cuda}")
+ pass
- # Try loading bitsandbytes and triton
- try:
- import bitsandbytes as bnb
- except:
- print(
- "Unsloth: `bitsandbytes` is not installed - 4bit QLoRA unallowed, but 16bit and full finetuning works!"
- )
- bnb = None
+ importlib.reload(bnb)
+ importlib.reload(triton)
try:
+ libcuda_dirs = lambda: None
+ if Version(triton.__version__) >= Version("3.0.0"):
+ try: from triton.backends.nvidia.driver import libcuda_dirs
+ except: pass
+ else: from triton.common.build import libcuda_dirs
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
libcuda_dirs()
except:
- warnings.warn("Unsloth: Running `ldconfig /usr/lib64-nvidia` to link CUDA.")
-
- if os.path.exists("/usr/lib64-nvidia"):
- os.system("ldconfig /usr/lib64-nvidia")
- elif os.path.exists("/usr/local"):
- # Sometimes bitsandbytes cannot be linked properly in Runpod for example
- possible_cudas = (
- subprocess.check_output(["ls", "-al", "/usr/local"])
- .decode("utf-8")
- .split("\n")
- )
- find_cuda = re.compile(r"[\s](cuda\-[\d\.]{2,})$")
- possible_cudas = [find_cuda.search(x) for x in possible_cudas]
- possible_cudas = [x.group(1) for x in possible_cudas if x is not None]
-
- # Try linking cuda folder, or everything in local
- if len(possible_cudas) == 0:
- os.system("ldconfig /usr/local/")
- else:
- find_number = re.compile(r"([\d\.]{2,})")
- latest_cuda = np.argsort(
- [float(find_number.search(x).group(1)) for x in possible_cudas]
- )[::-1][0]
- latest_cuda = possible_cudas[latest_cuda]
- os.system(f"ldconfig /usr/local/{latest_cuda}")
- del find_number, latest_cuda
- del possible_cudas, find_cuda
+ warnings.warn(
+ "Unsloth: CUDA is not linked properly.\n"\
+ "Try running `python -m bitsandbytes` then `python -m xformers.info`\n"\
+ "We tried running `ldconfig /usr/lib64-nvidia` ourselves, but it didn't work.\n"\
+ "You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n"\
+ "Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.\n"\
+ "Unsloth will still run for now, but maybe it might crash - let's hope it works!"
+ )
+pass
- if bnb is not None:
- importlib.reload(bnb)
- importlib.reload(triton)
- try:
- libcuda_dirs = lambda: None
- if Version(triton.__version__) >= Version("3.0.0"):
+# Check for unsloth_zoo
+try:
+ unsloth_zoo_version = importlib_version("unsloth_zoo")
+ if Version(unsloth_zoo_version) < Version("2025.3.11"):
+ print(
+ "Unsloth: Updating Unsloth-Zoo utilies to the latest version.\n"\
+ "To disable this, set os.environ['UNSLOTH_DISABLE_AUTO_UPDATES'] = '1'"
+ )
+ if os.environ.get("UNSLOTH_DISABLE_AUTO_UPDATES", "0") == "0":
+ try:
+ os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo")
+ except:
try:
- from triton.backends.nvidia.driver import libcuda_dirs
+ os.system("pip install --upgrade --no-cache-dir --no-deps --user unsloth_zoo")
except:
- pass
- else:
- from triton.common.build import libcuda_dirs
- cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
- libcuda_dirs()
- except:
- warnings.warn(
- "Unsloth: CUDA is not linked properly.\n"
- "Try running `python -m bitsandbytes` then `python -m xformers.info`\n"
- "We tried running `ldconfig /usr/lib64-nvidia` ourselves, but it didn't work.\n"
- "You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n"
- "Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.\n"
- "Unsloth will still run for now, but maybe it might crash - let's hope it works!"
- )
- del libcuda_dirs
-elif DEVICE_TYPE == "hip":
- # NO-OP for rocm device
- pass
-elif DEVICE_TYPE == "xpu":
- import bitsandbytes as bnb
-
- # TODO: check triton for intel installed properly.
- pass
+ raise ImportError("Unsloth: Please update unsloth_zoo via `pip install --upgrade --no-cache-dir --no-deps unsloth_zoo`")
+ import unsloth_zoo
+except:
+ raise ImportError("Unsloth: Please install unsloth_zoo via `pip install unsloth_zoo`")
+pass
from .models import *
from .models import __version__
@@ -315,16 +223,5 @@ def is_bf16_supported():
from .tokenizer_utils import *
from .trainer import *
-# Export dataprep utilities for CLI and downstream users
-from .dataprep.raw_text import RawTextDataLoader, TextPreprocessor
-from unsloth_zoo.rl_environments import (
- check_python_modules,
- create_locked_down_function,
- execute_with_time_limit,
- Benchmarker,
- is_port_open,
- launch_openenv,
-)
-
# Patch TRL trainers for backwards compatibility
_patch_trl_trainer()
diff --git a/unsloth/_auto_install.py b/unsloth/_auto_install.py
index f6deefeb33..8bb5485192 100644
--- a/unsloth/_auto_install.py
+++ b/unsloth/_auto_install.py
@@ -15,12 +15,10 @@
try: import torch
except: raise ImportError('Install torch via `pip install torch`')
from packaging.version import Version as V
-import re
-v = V(re.match(r"[0-9\.]{3,}", torch.__version__).group(0))
+v = V(torch.__version__)
cuda = str(torch.version.cuda)
is_ampere = torch.cuda.get_device_capability()[0] >= 8
-USE_ABI = torch._C._GLIBCXX_USE_CXX11_ABI
-if cuda not in ("11.8", "12.1", "12.4", "12.6", "12.8", "13.0"): raise RuntimeError(f"CUDA = {cuda} not supported!")
+if cuda != "12.1" and cuda != "11.8" and cuda != "12.4" and cuda != "12.6": raise RuntimeError(f"CUDA = {cuda} not supported!")
if v <= V('2.1.0'): raise RuntimeError(f"Torch = {v} too old!")
elif v <= V('2.1.1'): x = 'cu{}{}-torch211'
elif v <= V('2.1.2'): x = 'cu{}{}-torch212'
@@ -30,14 +28,6 @@
elif v < V('2.5.1'): x = 'cu{}{}-torch250'
elif v <= V('2.5.1'): x = 'cu{}{}-torch251'
elif v < V('2.7.0'): x = 'cu{}{}-torch260'
-elif v < V('2.7.9'): x = 'cu{}{}-torch270'
-elif v < V('2.8.0'): x = 'cu{}{}-torch271'
-elif v < V('2.8.9'): x = 'cu{}{}-torch280'
-elif v < V('2.9.1'): x = 'cu{}{}-torch290'
-elif v < V('2.9.2'): x = 'cu{}{}-torch291'
-elif v < V('2.10.1'): x = 'cu{}{}-torch2100'
else: raise RuntimeError(f"Torch = {v} too new!")
-if v > V('2.6.9') and cuda not in ("11.8", "12.6", "12.8", "13.0"): raise RuntimeError(f"CUDA = {cuda} not supported!")
-if v >= V('2.10.0') and cuda not in ("12.6", "12.8", "13.0"): raise RuntimeError(f"Torch 2.10 requires CUDA 12.6, 12.8, or 13.0! Got CUDA = {cuda}")
-x = x.format(cuda.replace(".", ""), "-ampere" if False else "") # is_ampere is broken due to flash-attn
-print(f'pip install --upgrade pip && pip install --no-deps git+https://github.com/unslothai/unsloth-zoo.git && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git" --no-build-isolation')
\ No newline at end of file
+x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
+print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"')
\ No newline at end of file
diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py
index 35eb871529..c10b2641a4 100644
--- a/unsloth/chat_templates.py
+++ b/unsloth/chat_templates.py
@@ -36,7 +36,6 @@
from .tokenizer_utils import *
from .models._utils import patch_tokenizer
import re
-from .ollama_template_mappers import OLLAMA_TEMPLATES
from unsloth_zoo.dataset_utils import (
train_on_responses_only,
standardize_data_formats,
@@ -44,8 +43,6 @@
standardize_sharegpt = standardize_data_formats
CHAT_TEMPLATES = {}
DEFAULT_SYSTEM_MESSAGE = {}
-def _ollama_template(name: str):
- return OLLAMA_TEMPLATES[name]
# =========================================== Unsloth
# Unsloth efficient template leverages from Zephyr
@@ -70,12 +67,25 @@ def _ollama_template(name: str):
"{% if add_generation_prompt %}"\
"{{ '>>> Assistant: ' }}"\
"{% endif %}"
+pass
-unsloth_ollama = _ollama_template("unsloth")
+unsloth_ollama = \
+'''
+FROM {__FILE_LOCATION__}
+TEMPLATE """{{ if .System }}{{ .System }}
+{{ end }}{{ if .Prompt }}>>> User: {{ .Prompt }}
+{{ end }}>>> Assistant: {{ .Response }}{__EOS_TOKEN__}
+"""
+PARAMETER stop "{__EOS_TOKEN__}"
+PARAMETER temperature 1.5
+PARAMETER min_p 0.1
+SYSTEM """You are a helpful assistant to the user"""
+'''
unsloth_eos_token = "eos_token"
CHAT_TEMPLATES["unsloth"] = (unsloth_template, unsloth_eos_token, False, unsloth_ollama,)
DEFAULT_SYSTEM_MESSAGE["unsloth"] = "You are a helpful assistant to the user"
+pass
# =========================================== Zephyr
# Zephyr has no BOS!
@@ -92,12 +102,27 @@ def _ollama_template(name: str):
"{% if add_generation_prompt %}"\
"{{ '<|assistant|>\n' }}"\
"{% endif %}"
+pass
-zephyr_ollama = _ollama_template("zephyr")
+zephyr_ollama = \
+'''
+FROM {__FILE_LOCATION__}
+TEMPLATE """{{ if .System }}<|system|>
+{{ .System }}{__EOS_TOKEN__}
+{{ end }}{{ if .Prompt }}<|user|>
+{{ .Prompt }}{__EOS_TOKEN__}
+{{ end }}<|assistant|>
+{{ .Response }}{__EOS_TOKEN__}
+"""
+PARAMETER stop "{__EOS_TOKEN__}"
+PARAMETER temperature 1.5
+PARAMETER min_p 0.1
+'''
zephyr_eos_token = "eos_token"
CHAT_TEMPLATES["zephyr"] = (zephyr_template, zephyr_eos_token, False, zephyr_ollama,)
DEFAULT_SYSTEM_MESSAGE["zephyr"] = None # No system message in Zephyr
+pass
# =========================================== ChatML
# ChatML has no BOS and not EOS! Rather <|im_start|> and <|im_end|> acts as BOS / EOS.
@@ -114,12 +139,28 @@ def _ollama_template(name: str):
"{% if add_generation_prompt %}"\
"{{ '<|im_start|>assistant\n' }}"\
"{% endif %}"
+pass
-chatml_ollama = _ollama_template("chatml")
+chatml_ollama = \
+'''
+FROM {__FILE_LOCATION__}
+TEMPLATE """{{ if .System }}<|im_start|>system
+{{ .System }}<|im_end|>
+{{ end }}{{ if .Prompt }}<|im_start|>user
+{{ .Prompt }}<|im_end|>
+{{ end }}<|im_start|>assistant
+{{ .Response }}<|im_end|>
+"""
+PARAMETER stop "<|im_start|>"
+PARAMETER stop "<|im_end|>"
+PARAMETER temperature 1.5
+PARAMETER min_p 0.1
+'''
chatml_eos_token = "<|im_end|>"
CHAT_TEMPLATES["chatml"] = (chatml_template, chatml_eos_token, True, chatml_ollama,)
DEFAULT_SYSTEM_MESSAGE["chatml"] = None # No system message in ChatML
+pass
# =========================================== Mistral-1
# Mistral Instruct doesn't allow system prompts, so we append it to the user message.
@@ -145,13 +186,22 @@ def _ollama_template(name: str):
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
"{% endif %}"\
"{% endfor %}"
+pass
# Ollama from https://www.ollama.com/library/mistral
-mistral_ollama = _ollama_template("mistral")
+mistral_ollama = \
+'''
+FROM {__FILE_LOCATION__}
+TEMPLATE """[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]"""
+PARAMETER stop "{__EOS_TOKEN__}"
+PARAMETER temperature 1.5
+PARAMETER min_p 0.1
+'''
mistral_eos_token = "eos_token"
CHAT_TEMPLATES["mistral"] = (mistral_template, mistral_eos_token, False, mistral_ollama,)
DEFAULT_SYSTEM_MESSAGE["mistral"] = None # No system message in Mistral
+pass
# =========================================== Llama-2
# Adds BOS to every convo! And weird <> system messages.
@@ -176,13 +226,24 @@ def _ollama_template(name: str):
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
"{% endif %}"\
"{% endfor %}"
+pass
# Ollama from https://www.ollama.com/library/llama3
-llama_ollama = _ollama_template("llama")
+llama_ollama = \
+'''
+FROM {__FILE_LOCATION__}
+TEMPLATE """[INST] <>{{ .System }}<>
+
+{{ .Prompt }} [/INST]"""
+PARAMETER stop "{__EOS_TOKEN__}"
+PARAMETER temperature 1.5
+PARAMETER min_p 0.1
+'''
llama_eos_token = "eos_token"
CHAT_TEMPLATES["llama"] = (llama_template, llama_eos_token, False, llama_ollama,)
DEFAULT_SYSTEM_MESSAGE["llama"] = None # No system message in Llama
+pass
# =========================================== Vicuna
# https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
@@ -207,13 +268,22 @@ def _ollama_template(name: str):
"{% if add_generation_prompt %}"\
"{{ 'ASSISTANT:' }}"\
"{% endif %}"
+pass
# Ollama from https://www.ollama.com/library/vicuna
-vicuna_ollama = _ollama_template("vicuna")
+vicuna_ollama = \
+'''
+FROM {__FILE_LOCATION__}
+TEMPLATE """{{ if .System }}{{ .System }} {{ end }}{{ if .Prompt }}USER: {{ .Prompt }} {{ end }}ASSISTANT: {{ .Response }} {__EOS_TOKEN__}"""
+PARAMETER stop "{__EOS_TOKEN__}"
+PARAMETER temperature 1.5
+PARAMETER min_p 0.1
+'''
vicuna_eos_token = "eos_token"
CHAT_TEMPLATES["vicuna"] = (vicuna_template, vicuna_eos_token, False, vicuna_ollama,)
DEFAULT_SYSTEM_MESSAGE["vicuna"] = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
+pass
# =========================================== Vicuna Old
# https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
@@ -238,8 +308,20 @@ def _ollama_template(name: str):
"{% if add_generation_prompt %}"\
"{{ '### Assistant:' }}"\
"{% endif %}"
+pass
-vicuna_old_ollama = _ollama_template("vicuna_old")
+vicuna_old_ollama = \
+'''
+FROM {__FILE_LOCATION__}
+TEMPLATE """{{ if .System }}{{ .System }}
+{{ end }}{{ if .Prompt }}### Human: {{ .Prompt }}
+{{ end }}### Assistant: {{ .Response }}{__EOS_TOKEN__}
+"""
+PARAMETER stop "{__EOS_TOKEN__}"
+PARAMETER temperature 1.5
+PARAMETER min_p 0.1
+SYSTEM """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."""
+'''
vicuna_old_eos_token = "eos_token"
CHAT_TEMPLATES["vicuna_old"] = (vicuna_old_template, vicuna_old_eos_token, False, vicuna_old_ollama,)
@@ -247,6 +329,7 @@ def _ollama_template(name: str):
CHAT_TEMPLATES["vicuna old"] = CHAT_TEMPLATES["vicuna_old"]
DEFAULT_SYSTEM_MESSAGE["vicuna old"] = DEFAULT_SYSTEM_MESSAGE["vicuna_old"]
+pass
# =========================================== Alpaca multi turn
# https://github.com/tatsu-lab/stanford_alpaca Changed for multi-turn convos
@@ -271,12 +354,30 @@ def _ollama_template(name: str):
"{% if add_generation_prompt %}"\
"{{ '### Response:\n' }}"\
"{% endif %}"
+pass
+
+alpaca_ollama = \
+'''
+FROM {__FILE_LOCATION__}
+TEMPLATE """{{ if .System }}{{ .System }}
-alpaca_ollama = _ollama_template("alpaca")
+{{ end }}{{ if .Prompt }}### Instruction:
+{{ .Prompt }}{{ end }}
+
+### Response:
+{{ .Response }}{__EOS_TOKEN__}
+
+"""
+PARAMETER stop "{__EOS_TOKEN__}"
+PARAMETER temperature 1.5
+PARAMETER min_p 0.1
+SYSTEM """Below are some instructions that describe some tasks. Write responses that appropriately complete each request."""
+'''
alpaca_eos_token = "eos_token"
CHAT_TEMPLATES["alpaca"] = (alpaca_template, alpaca_eos_token, False, alpaca_ollama,)
DEFAULT_SYSTEM_MESSAGE["alpaca"] = "Below are some instructions that describe some tasks. Write responses that appropriately complete each request."
+pass
# =========================================== Gemma
# https://huggingface.co/google/gemma-7b-it
@@ -300,19 +401,52 @@ def _ollama_template(name: str):
"{% if add_generation_prompt %}"\
"{{ 'model\n' }}"\
"{% endif %}"
+pass
# Ollama from https://www.ollama.com/library/gemma
-gemma_ollama = _ollama_template("gemma")
+gemma_ollama = \
+'''
+FROM {__FILE_LOCATION__}
+TEMPLATE """user
+{{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}
+model
+{{ .Response }}
+"""
+PARAMETER repeat_penalty 1
+PARAMETER stop ""
+PARAMETER stop ""
+PARAMETER penalize_newline false
+PARAMETER temperature 1.5
+PARAMETER min_p 0.1
+'''
gemma_eos_token = ""
CHAT_TEMPLATES["gemma"] = (gemma_template, gemma_eos_token, True, gemma_ollama,)
DEFAULT_SYSTEM_MESSAGE["gemma"] = None # No system message in Gemma
+pass
# =========================================== Gemma with ChatML instead
# We find using is still more appropriate!
gemma_chatml_template = "{{ bos_token }}" + chatml_template
+pass
-gemma_chatml_ollama = _ollama_template("gemma_chatml")
+gemma_chatml_ollama = \
+'''
+FROM {__FILE_LOCATION__}
+TEMPLATE """{{ if .System }}<|im_start|>system
+{{ .System }}<|im_end|>
+{{ end }}{{ if .Prompt }}<|im_start|>user
+{{ .Prompt }}<|im_end|>
+{{ end }}<|im_start|>assistant
+{{ .Response }}<|im_end|>
+"""
+PARAMETER repeat_penalty 1
+PARAMETER stop "<|im_start|>"
+PARAMETER stop "<|im_end|>"
+PARAMETER penalize_newline false
+PARAMETER temperature 1.5
+PARAMETER min_p 0.1
+'''
gemma_chatml_eos_token = (
{"" : "<|im_start|>", "" : "<|im_end|>"},
@@ -320,22 +454,24 @@ def _ollama_template(name: str):
)
CHAT_TEMPLATES["gemma_chatml"] = (gemma_chatml_template, gemma_chatml_eos_token, True, gemma_chatml_ollama,)
DEFAULT_SYSTEM_MESSAGE["gemma_chatml"] = None # No system message in Gemma
+pass
# =========================================== Gemma 2
# Same as Gemma 1, but with sliding window attention!
# https://ollama.com/library/gemma2/blobs/6522ca797f47
gemma2_template = gemma_template
-gemma2_ollama = _ollama_template("gemma2")
+gemma2_ollama = gemma_ollama + "PARAMETER num_ctx 4096\n"
gemma2_eos_token = ""
CHAT_TEMPLATES["gemma2"] = (gemma2_template, gemma2_eos_token, True, gemma2_ollama,)
DEFAULT_SYSTEM_MESSAGE["gemma2"] = None # No system message in Gemma 2
# =========================================== Gemma 2 with ChatML instead
gemma2_chatml_template = gemma_chatml_template
-gemma2_chatml_ollama = _ollama_template("gemma2_chatml")
+gemma2_chatml_ollama = gemma_chatml_ollama + "PARAMETER num_ctx 4096\n"
gemma2_chatml_eos_token = gemma_chatml_eos_token
CHAT_TEMPLATES["gemma2_chatml"] = (gemma2_chatml_template, gemma2_chatml_eos_token, True, gemma2_chatml_ollama,)
DEFAULT_SYSTEM_MESSAGE["gemma2_chatml"] = None # No system message in Gemma 2
+pass
# =========================================== Llama-3
# Weirdly \n\n is needed?
@@ -353,9 +489,25 @@ def _ollama_template(name: str):
"{% if add_generation_prompt %}"\
"{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"\
"{% endif %}"
+pass
# Ollama from https://www.ollama.com/library/llama3
-llama3_ollama = _ollama_template("llama-3")
+llama3_ollama = \
+'''
+FROM {__FILE_LOCATION__}
+TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
+
+{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
+
+{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
+
+{{ .Response }}<|eot_id|>"""
+PARAMETER stop "<|start_header_id|>"
+PARAMETER stop "<|end_header_id|>"
+PARAMETER stop "<|eot_id|>"
+PARAMETER temperature 1.5
+PARAMETER min_p 0.1
+'''
llama3_template_eos_token = "eos_token"
@@ -364,6 +516,7 @@ def _ollama_template(name: str):
CHAT_TEMPLATES["llama3"] = (llama3_template, llama3_template_eos_token, False, llama3_ollama,)
DEFAULT_SYSTEM_MESSAGE["llama3"] = None # No system message in Llama-3
+pass
# =========================================== Phi-3
@@ -381,9 +534,25 @@ def _ollama_template(name: str):
"{% if add_generation_prompt %}"\
"{{ '<|assistant|>\n' }}"\
"{% endif %}"
+pass
# Ollama from https://www.ollama.com/library/phi3
-phi3_ollama = _ollama_template("phi-3")
+phi3_ollama = \
+'''
+FROM {__FILE_LOCATION__}
+TEMPLATE """{{ if .System }}<|system|>
+{{ .System }}<|end|>
+{{ end }}{{ if .Prompt }}<|user|>
+{{ .Prompt }}<|end|>
+{{ end }}<|assistant|>
+{{ .Response }}<|end|>
+"""
+PARAMETER stop "<|end|>"
+PARAMETER stop "<|user|>"
+PARAMETER stop "<|assistant|>"
+PARAMETER temperature 1.5
+PARAMETER min_p 0.1
+'''
phi3_template_eos_token = "<|end|>"
CHAT_TEMPLATES["phi-3"] = (phi3_template, phi3_template_eos_token, False, phi3_ollama,)
@@ -394,6 +563,7 @@ def _ollama_template(name: str):
CHAT_TEMPLATES["phi-3.5"] = CHAT_TEMPLATES["phi-3"]
DEFAULT_SYSTEM_MESSAGE["phi-3.5"] = None # No system message in Phi-3.5
+pass
# =========================================== Llama-3.1
"""
@@ -523,9 +693,68 @@ def _ollama_template(name: str):
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif %}
"""
+pass
# Ollama from https://ollama.com/library/llama3.1 (needs updating!)
-llama31_ollama = _ollama_template("llama-3.1")
+llama31_ollama = \
+'''
+FROM {__FILE_LOCATION__}
+TEMPLATE """{{ if .Messages }}
+{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
+{{- if .System }}
+
+{{ .System }}
+{{- end }}
+{{- if .Tools }}
+
+You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the original use question.
+{{- end }}
+{{- end }}<|eot_id|>
+{{- range $i, $_ := .Messages }}
+{{- $last := eq (len (slice $.Messages $i)) 1 }}
+{{- if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>
+{{- if and $.Tools $last }}
+
+Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
+
+Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
+
+{{ $.Tools }}
+{{- end }}
+
+{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
+
+{{ end }}
+{{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|>
+{{- if .ToolCalls }}
+
+{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}
+{{- else }}
+
+{{ .Content }}{{ if not $last }}<|eot_id|>{{ end }}
+{{- end }}
+{{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|>
+
+{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
+
+{{ end }}
+{{- end }}
+{{- end }}
+{{- else }}
+{{- if .System }}<|start_header_id|>system<|end_header_id|>
+
+{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
+
+{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
+
+{{ end }}{{ .Response }}{{ if .Response }}<|eot_id|>{{ end }}"""
+PARAMETER stop "<|start_header_id|>"
+PARAMETER stop "<|end_header_id|>"
+PARAMETER stop "<|eot_id|>"
+PARAMETER stop "<|eom_id|>"
+PARAMETER temperature 1.5
+PARAMETER min_p 0.1
+'''
llama31_template_eos_token = "eos_token"
CHAT_TEMPLATES["llama-3.1"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,)
@@ -537,6 +766,7 @@ def _ollama_template(name: str):
for version in ("llama-3.2", "llama-3.3", "llama-32", "llama-33"):
CHAT_TEMPLATES[version] = CHAT_TEMPLATES["llama-3.1"]
DEFAULT_SYSTEM_MESSAGE[version] = ""
+pass
# =========================================== Qwen 2.5
@@ -593,10 +823,67 @@ def _ollama_template(name: str):
# Ollama from https://ollama.com/library/qwen2.5/blobs/eb4402837c78
-qwen25_ollama = _ollama_template("qwen-2.5")
+qwen25_ollama = \
+'''
+FROM {__FILE_LOCATION__}
+TEMPLATE """{{- if .Messages }}
+{{- if or .System .Tools }}<|im_start|>system
+{{- if .System }}
+{{ .System }}
+{{- end }}
+{{- if .Tools }}
+
+# Tools
+
+You may call one or more functions to assist with the user query.
+
+You are provided with function signatures within XML tags:
+
+{{- range .Tools }}
+{"type": "function", "function": {{ .Function }}}
+{{- end }}
+
+
+For each function call, return a json object with function name and arguments within XML tags:
+
+{"name": , "arguments": }
+
+{{- end }}<|im_end|>
+{{ end }}
+{{- range $i, $_ := .Messages }}
+{{- $last := eq (len (slice $.Messages $i)) 1 -}}
+{{- if eq .Role "user" }}<|im_start|>user
+{{ .Content }}<|im_end|>
+{{ else if eq .Role "assistant" }}<|im_start|>assistant
+{{ if .Content }}{{ .Content }}
+{{- else if .ToolCalls }}
+{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
+{{ end }}
+{{- end }}{{ if not $last }}<|im_end|>
+{{ end }}
+{{- else if eq .Role "tool" }}<|im_start|>user
+
+{{ .Content }}
+<|im_end|>
+{{ end }}
+{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
+{{ end }}
+{{- end }}
+{{- else }}
+{{- if .System }}<|im_start|>system
+{{ .System }}<|im_end|>
+{{ end }}{{ if .Prompt }}<|im_start|>user
+{{ .Prompt }}<|im_end|>
+{{ end }}<|im_start|>assistant
+{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}"""
+PARAMETER stop "<|im_end|>"
+PARAMETER stop "<|endoftext|>"
+PARAMETER temperature 1.5
+PARAMETER min_p 0.1
+'''
qwen25_template_eos_token = "eos_token"
-qwen25_default_system_message = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
+qwen25_default_system_message = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
CHAT_TEMPLATES["qwen-2.5"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
DEFAULT_SYSTEM_MESSAGE["qwen-2.5"] = qwen25_default_system_message # No system message in Qwen 2.5
@@ -608,6 +895,7 @@ def _ollama_template(name: str):
CHAT_TEMPLATES["qwen2.5"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
DEFAULT_SYSTEM_MESSAGE["qwen2.5"] = qwen25_default_system_message # No system message in Qwen 2.5
+pass
# =========================================== Phi-4
# "{{ bos_token }}"\ # Phi-4 removes BOS?
@@ -624,6 +912,7 @@ def _ollama_template(name: str):
"{% if add_generation_prompt %}"\
"{{ '<|im_start|>assistant<|im_sep|>' }}"\
"{% endif %}"
+pass
_phi4_ollama_template = \
"{{ if .System }}<|im_start|><|system|><|im_sep|>{{ .System }}<|im_end|>{{ end }}"\
@@ -631,11 +920,21 @@ def _ollama_template(name: str):
"<|im_start|><|assistant|><|im_sep|>{{ .Response }}<|im_end|>"
# Ollama from https://www.ollama.com/library/phi4 is different
-phi4_ollama = _ollama_template("phi-4")
+phi4_ollama = \
+f'''
+FROM {{__FILE_LOCATION__}}
+TEMPLATE """{_phi4_ollama_template}"""
+PARAMETER stop "<|im_end|>"
+PARAMETER stop "<|im_start|>"
+PARAMETER stop "<|im_sep|>"
+PARAMETER temperature 1.5
+PARAMETER min_p 0.1
+'''
phi4_template_eos_token = "<|im_end|>"
CHAT_TEMPLATES["phi-4"] = (phi4_template, phi4_template_eos_token, False, phi4_ollama,)
DEFAULT_SYSTEM_MESSAGE["phi-4"] = None # No system message in Phi-4
+pass
# =========================================== Gemma-3
@@ -685,7 +984,28 @@ def _ollama_template(name: str):
"""
# Ollama from https://ollama.com/library/gemma3/blobs/e0a42594d802
-gemma3_ollama = _ollama_template("gemma-3")
+gemma3_ollama = \
+'''
+FROM {__FILE_LOCATION__}
+TEMPLATE """{{- range $i, $_ := .Messages }}
+{{- $last := eq (len (slice $.Messages $i)) 1 }}
+{{- if or (eq .Role "user") (eq .Role "system") }}user
+{{ .Content }}
+{{ if $last }}model
+{{ end }}
+{{- else if eq .Role "assistant" }}model
+{{ .Content }}{{ if not $last }}
+{{ end }}
+{{- end }}
+{{- end }}"""
+PARAMETER stop ""
+PARAMETER stop ""
+PARAMETER temperature 0.1
+PARAMETER min_p 0.0
+PARAMETER top_k 64
+PARAMETER top_p 0.95
+PARAMETER num_predict 32768
+'''
gemma3_template_eos_token = ""
CHAT_TEMPLATES["gemma-3"] = (gemma3_template, gemma3_template_eos_token, False, gemma3_ollama,)
@@ -693,962 +1013,11 @@ def _ollama_template(name: str):
CHAT_TEMPLATES["gemma3"] = (gemma3_template, gemma3_template_eos_token, False, gemma3_ollama,)
DEFAULT_SYSTEM_MESSAGE["gemma3"] = None # No system message in Gemma-3
-
-# =========================================== Qwen-3
-# Official Qwen-3 chat template (see https://ollama.com/library/qwen3/blobs/eb4402837c78)
-qwen3_template = \
-"""
-{%- if tools %}
- {{- '<|im_start|>system\n' }}
- {%- if messages[0].role == 'system' %}
- {{- messages[0].content + '\n\n' }}
- {%- endif %}
- {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }}
- {%- for tool in tools %}
- {{- "\n" }}
- {{- tool | tojson }}
- {%- endfor %}
- {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\\"name\\": , \\"arguments\\": }\n<|im_end|>\n" }}
-{%- else %}
- {%- if messages[0].role == 'system' %}
- {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
- {%- endif %}
-{%- endif %}
-{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
-{%- for forward_message in messages %}
- {%- set index = (messages|length - 1) - loop.index0 %}
- {%- set message = messages[index] %}
- {%- set current_content = message.content if message.content is not none else '' %}
- {%- set tool_start = '' %}
- {%- set tool_start_length = tool_start|length %}
- {%- set start_of_message = current_content[:tool_start_length] %}
- {%- set tool_end = '' %}
- {%- set tool_end_length = tool_end|length %}
- {%- set start_pos = (current_content|length) - tool_end_length %}
- {%- if start_pos < 0 %}
- {%- set start_pos = 0 %}
- {%- endif %}
- {%- set end_of_message = current_content[start_pos:] %}
- {%- if ns.multi_step_tool and message.role == "user" and not(start_of_message == tool_start and end_of_message == tool_end) %}
- {%- set ns.multi_step_tool = false %}
- {%- set ns.last_query_index = index %}
- {%- endif %}
-{%- endfor %}
-{%- for message in messages %}
- {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
- {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
- {%- elif message.role == "assistant" %}
- {%- set content = message.content %}
- {%- set reasoning_content = '' %}
- {%- if message.reasoning_content is defined and message.reasoning_content is not none %}
- {%- set reasoning_content = message.reasoning_content %}
- {%- else %}
- {%- if '' in message.content %}
- {%- set content = (message.content.split('')|last).lstrip('\n') %}
- {%- set reasoning_content = (message.content.split('')|first).rstrip('\n') %}
- {%- set reasoning_content = (reasoning_content.split('')|last).lstrip('\n') %}
- {%- endif %}
- {%- endif %}
- {%- if loop.index0 > ns.last_query_index %}
- {%- if loop.last or (not loop.last and reasoning_content) %}
- {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }}
- {%- else %}
- {{- '<|im_start|>' + message.role + '\n' + content }}
- {%- endif %}
- {%- else %}
- {{- '<|im_start|>' + message.role + '\n' + content }}
- {%- endif %}
- {%- if message.tool_calls %}
- {%- for tool_call in message.tool_calls %}
- {%- if (loop.first and content) or (not loop.first) %}
- {{- '\n' }}
- {%- endif %}
- {%- if tool_call.function %}
- {%- set tool_call = tool_call.function %}
- {%- endif %}
- {{- '\n{"name": "' }}
- {{- tool_call.name }}
- {{- '", "arguments": ' }}
- {%- if tool_call.arguments is string %}
- {{- tool_call.arguments }}
- {%- else %}
- {{- tool_call.arguments | tojson }}
- {%- endif %}
- {{- '}\n' }}
- {%- endfor %}
- {%- endif %}
- {{- '<|im_end|>\n' }}
- {%- elif message.role == "tool" %}
- {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
- {{- '<|im_start|>user' }}
- {%- endif %}
- {{- '\n\n' }}
- {{- message.content }}
- {{- '\n' }}
- {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
- {{- '<|im_end|>\n' }}
- {%- endif %}
- {%- endif %}
-{%- endfor %}
-{%- if add_generation_prompt %}
- {{- '<|im_start|>assistant\n' }}
- {%- if enable_thinking is defined and enable_thinking is false %}
- {{- '\n\n\n\n' }}
- {%- endif %}
-{%- endif %}
-"""
-
-qwen3_ollama = _ollama_template("qwen-3")
-qwen3_template_eos_token = "<|im_end|>"
-CHAT_TEMPLATES["qwen-3"] = (qwen3_template, qwen3_template_eos_token, False, qwen3_ollama,)
-DEFAULT_SYSTEM_MESSAGE["qwen-3"] = None # No default system message for Qwen-3
-
-CHAT_TEMPLATES["qwen3"] = (qwen3_template, qwen3_template_eos_token, False, qwen3_ollama,)
-DEFAULT_SYSTEM_MESSAGE["qwen3"] = None # No default system message for Qwen-3
-
-# =========================================== Gemma-3n
-# Obtained via
-# print(tokenizer.chat_template.replace("}\n", "####").replace("\n", "\\n").replace("####", "}\n"))
-gemma3n_template = \
-"""{{ bos_token }}
-{%- if messages[0]['role'] == 'system' -%}
- {%- if messages[0]['content'] is string -%}
- {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}
- {%- else -%}
- {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}
- {%- endif -%}
- {%- set loop_messages = messages[1:] -%}
-{%- else -%}
- {%- set first_user_prefix = "" -%}
- {%- set loop_messages = messages -%}
-{%- endif -%}
-{%- for message in loop_messages -%}
- {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
- {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
- {%- endif -%}
- {%- if (message['role'] == 'assistant') -%}
- {%- set role = "model" -%}
- {%- else -%}
- {%- set role = message['role'] -%}
- {%- endif -%}
- {{ '' + role + '\n' + (first_user_prefix if loop.first else "") }}
- {%- if message['content'] is string -%}
- {{ message['content'] | trim }}
- {%- elif message['content'] is iterable -%}
- {%- for item in message['content'] -%}
- {%- if item['type'] == 'audio' -%}
- {{ '' }}
- {%- elif item['type'] == 'image' -%}
- {{ '' }}
- {%- elif item['type'] == 'text' -%}
- {{ item['text'] | trim }}
- {%- endif -%}
- {%- endfor -%}
- {%- else -%}
- {{ raise_exception("Invalid content type") }}
- {%- endif -%}
- {{ '\n' }}
-{%- endfor -%}
-{%- if add_generation_prompt -%}
- {{'model\n'}}
-{%- endif -%}
-"""
-
-# Ollama from https://ollama.com/library/gemma3n/blobs/e0a42594d802
-gemma3n_ollama = _ollama_template("gemma-3n")
-gemma3n_template_eos_token = ""
-CHAT_TEMPLATES["gemma-3n"] = (gemma3n_template, gemma3n_template_eos_token, False, gemma3n_ollama,)
-DEFAULT_SYSTEM_MESSAGE["gemma-3n"] = None # No system message in Gemma-3n
-
-CHAT_TEMPLATES["gemma3n"] = (gemma3n_template, gemma3n_template_eos_token, False, gemma3n_ollama,)
-DEFAULT_SYSTEM_MESSAGE["gemma3n"] = None # No system message in Gemma-3n
-
-# =========================================== GPT-OSS
-# Obtained via
-# print(tokenizer.chat_template.replace("}\n", "####").replace("\n", "\\n").replace("####", "}\n"))
-gptoss_template = \
-"""{#-
- In addition to the normal inputs of `messages` and `tools`, this template also accepts the
- following kwargs:
- - "builtin_tools": A list, can contain "browser" and/or "python".
- - "model_identity": A string that optionally describes the model identity.
- - "reasoning_effort": A string that describes the reasoning effort, defaults to "medium".
- #}
-
-{#- Tool Definition Rendering ============================================== #}
-{%- macro render_typescript_type(param_spec, required_params, is_nullable=false) -%}
- {%- if param_spec.type == "array" -%}
- {%- if param_spec['items'] -%}
- {%- if param_spec['items']['type'] == "string" -%}
- {{- "string[]" }}
- {%- elif param_spec['items']['type'] == "number" -%}
- {{- "number[]" }}
- {%- elif param_spec['items']['type'] == "integer" -%}
- {{- "number[]" }}
- {%- elif param_spec['items']['type'] == "boolean" -%}
- {{- "boolean[]" }}
- {%- else -%}
- {%- set inner_type = render_typescript_type(param_spec['items'], required_params) -%}
- {%- if inner_type == "object | object" or inner_type|length > 50 -%}
- {{- "any[]" }}
- {%- else -%}
- {{- inner_type + "[]" }}
- {%- endif -%}
- {%- endif -%}
- {%- if param_spec.nullable -%}
- {{- " | null" }}
- {%- endif -%}
- {%- else -%}
- {{- "any[]" }}
- {%- if param_spec.nullable -%}
- {{- " | null" }}
- {%- endif -%}
- {%- endif -%}
- {%- elif param_spec.type is defined and param_spec.type is iterable and param_spec.type is not string and param_spec.type is not mapping and param_spec.type[0] is defined -%}
- {#- Handle array of types like ["object", "object"] from Union[dict, list] #}
- {%- if param_spec.type | length > 1 -%}
- {{- param_spec.type | join(" | ") }}
- {%- else -%}
- {{- param_spec.type[0] }}
- {%- endif -%}
- {%- elif param_spec.oneOf -%}
- {#- Handle oneOf schemas - check for complex unions and fallback to any #}
- {%- set has_object_variants = false -%}
- {%- for variant in param_spec.oneOf -%}
- {%- if variant.type == "object" -%}
- {%- set has_object_variants = true -%}
- {%- endif -%}
- {%- endfor -%}
- {%- if has_object_variants and param_spec.oneOf|length > 1 -%}
- {{- "any" }}
- {%- else -%}
- {%- for variant in param_spec.oneOf -%}
- {{- render_typescript_type(variant, required_params) -}}
- {%- if variant.description %}
- {{- "// " + variant.description }}
- {%- endif -%}
- {%- if variant.default is defined %}
- {{ "// default: " + variant.default|tojson }}
- {%- endif -%}
- {%- if not loop.last %}
- {{- " | " }}
- {% endif -%}
- {%- endfor -%}
- {%- endif -%}
- {%- elif param_spec.type == "string" -%}
- {%- if param_spec.enum -%}
- {{- '"' + param_spec.enum|join('" | "') + '"' -}}
- {%- else -%}
- {{- "string" }}
- {%- if param_spec.nullable %}
- {{- " | null" }}
- {%- endif -%}
- {%- endif -%}
- {%- elif param_spec.type == "number" -%}
- {{- "number" }}
- {%- elif param_spec.type == "integer" -%}
- {{- "number" }}
- {%- elif param_spec.type == "boolean" -%}
- {{- "boolean" }}
-
- {%- elif param_spec.type == "object" -%}
- {%- if param_spec.properties -%}
- {{- "{\n" }}
- {%- for prop_name, prop_spec in param_spec.properties.items() -%}
- {{- prop_name -}}
- {%- if prop_name not in (param_spec.required or []) -%}
- {{- "?" }}
- {%- endif -%}
- {{- ": " }}
- {{ render_typescript_type(prop_spec, param_spec.required or []) }}
- {%- if not loop.last -%}
- {{-", " }}
- {%- endif -%}
- {%- endfor -%}
- {{- "}" }}
- {%- else -%}
- {{- "object" }}
- {%- endif -%}
- {%- else -%}
- {{- "any" }}
- {%- endif -%}
-{%- endmacro -%}
-
-{%- macro render_tool_namespace(namespace_name, tools) -%}
- {{- "## " + namespace_name + "\n\n" }}
- {{- "namespace " + namespace_name + " {\n\n" }}
- {%- for tool in tools %}
- {%- set tool = tool.function %}
- {{- "// " + tool.description + "\n" }}
- {{- "type "+ tool.name + " = " }}
- {%- if tool.parameters and tool.parameters.properties %}
- {{- "(_: {\n" }}
- {%- for param_name, param_spec in tool.parameters.properties.items() %}
- {%- if param_spec.description %}
- {{- "// " + param_spec.description + "\n" }}
- {%- endif %}
- {{- param_name }}
- {%- if param_name not in (tool.parameters.required or []) -%}
- {{- "?" }}
- {%- endif -%}
- {{- ": " }}
- {{- render_typescript_type(param_spec, tool.parameters.required or []) }}
- {%- if param_spec.default is defined -%}
- {%- if param_spec.enum %}
- {{- ", // default: " + param_spec.default }}
- {%- elif param_spec.oneOf %}
- {{- "// default: " + param_spec.default }}
- {%- else %}
- {{- ", // default: " + param_spec.default|tojson }}
- {%- endif -%}
- {%- endif -%}
- {%- if not loop.last %}
- {{- ",\n" }}
- {%- else %}
- {{- ",\n" }}
- {%- endif -%}
- {%- endfor %}
- {{- "}) => any;\n\n" }}
- {%- else -%}
- {{- "() => any;\n\n" }}
- {%- endif -%}
- {%- endfor %}
- {{- "} // namespace " + namespace_name }}
-{%- endmacro -%}
-
-{%- macro render_builtin_tools(browser_tool, python_tool) -%}
- {%- if browser_tool %}
- {{- "## browser\n\n" }}
- {{- "// Tool for browsing.\n" }}
- {{- "// The `cursor` appears in brackets before each browsing display: `[{cursor}]`.\n" }}
- {{- "// Cite information from the tool using the following format:\n" }}
- {{- "// `【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`.\n" }}
- {{- "// Do not quote more than 10 words directly from the tool output.\n" }}
- {{- "// sources=web (default: web)\n" }}
- {{- "namespace browser {\n\n" }}
- {{- "// Searches for information related to `query` and displays `topn` results.\n" }}
- {{- "type search = (_: {\n" }}
- {{- "query: string,\n" }}
- {{- "topn?: number, // default: 10\n" }}
- {{- "source?: string,\n" }}
- {{- "}) => any;\n\n" }}
- {{- "// Opens the link `id` from the page indicated by `cursor` starting at line number `loc`, showing `num_lines` lines.\n" }}
- {{- "// Valid link ids are displayed with the formatting: `【{id}†.*】`.\n" }}
- {{- "// If `cursor` is not provided, the most recent page is implied.\n" }}
- {{- "// If `id` is a string, it is treated as a fully qualified URL associated with `source`.\n" }}
- {{- "// If `loc` is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available.\n" }}
- {{- "// Use this function without `id` to scroll to a new location of an opened page.\n" }}
- {{- "type open = (_: {\n" }}
- {{- "id?: number | string, // default: -1\n" }}
- {{- "cursor?: number, // default: -1\n" }}
- {{- "loc?: number, // default: -1\n" }}
- {{- "num_lines?: number, // default: -1\n" }}
- {{- "view_source?: boolean, // default: false\n" }}
- {{- "source?: string,\n" }}
- {{- "}) => any;\n\n" }}
- {{- "// Finds exact matches of `pattern` in the current page, or the page given by `cursor`.\n" }}
- {{- "type find = (_: {\n" }}
- {{- "pattern: string,\n" }}
- {{- "cursor?: number, // default: -1\n" }}
- {{- "}) => any;\n\n" }}
- {{- "} // namespace browser\n\n" }}
- {%- endif -%}
-
- {%- if python_tool %}
- {{- "## python\n\n" }}
- {{- "Use this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).\n\n" }}
- {{- "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 120.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is UNKNOWN. Depends on the cluster.\n\n" }}
- {%- endif -%}
-{%- endmacro -%}
-
-{#- System Message Construction ============================================ #}
-{%- macro build_system_message() -%}
- {%- if model_identity is not defined %}
- {%- set model_identity = "You are ChatGPT, a large language model trained by OpenAI." %}
- {%- endif %}
- {{- model_identity + "\n" }}
- {{- "Knowledge cutoff: 2024-06\n" }}
- {{- "Current date: " + strftime_now("%Y-%m-%d") + "\n\n" }}
- {%- if reasoning_effort is not defined %}
- {%- set reasoning_effort = "medium" %}
- {%- endif %}
- {{- "Reasoning: " + reasoning_effort + "\n\n" }}
- {%- if builtin_tools is defined and builtin_tools is not none %}
- {{- "# Tools\n\n" }}
- {%- set available_builtin_tools = namespace(browser=false, python=false) %}
- {%- for tool in builtin_tools %}
- {%- if tool == "browser" %}
- {%- set available_builtin_tools.browser = true %}
- {%- elif tool == "python" %}
- {%- set available_builtin_tools.python = true %}
- {%- endif %}
- {%- endfor %}
- {{- render_builtin_tools(available_builtin_tools.browser, available_builtin_tools.python) }}
- {%- endif -%}
- {{- "# Valid channels: analysis, commentary, final. Channel must be included for every message." }}
- {%- if tools -%}
- {{- "\nCalls to these tools must go to the commentary channel: 'functions'." }}
- {%- endif -%}
-{%- endmacro -%}
-
-{#- Main Template Logic ================================================= #}
-{#- Set defaults #}
-
-{#- Render system message #}
-{{- "<|start|>system<|message|>" }}
-{{- build_system_message() }}
-{{- "<|end|>" }}
-
-{#- Extract developer message #}
-{%- if developer_instructions is defined and developer_instructions is not none %}
- {%- set developer_message = developer_instructions %}
- {%- set loop_messages = messages %}
-{%- elif messages[0].role == "developer" or messages[0].role == "system" %}
- {%- set developer_message = messages[0].content %}
- {%- set loop_messages = messages[1:] %}
-{%- else %}
- {%- set developer_message = "" %}
- {%- set loop_messages = messages %}
-{%- endif %}
-
-{#- Render developer message #}
-{%- if developer_message or tools %}
- {{- "<|start|>developer<|message|>" }}
- {%- if developer_message %}
- {{- "# Instructions\n\n" }}
- {{- developer_message }}
- {%- endif %}
- {%- if tools -%}
- {%- if developer_message %}
- {{- "\n\n" }}
- {%- endif %}
- {{- "# Tools\n\n" }}
- {{- render_tool_namespace("functions", tools) }}
- {%- endif -%}
- {{- "<|end|>" }}
-{%- endif %}
-
-{#- Render messages #}
-{%- set last_tool_call = namespace(name=none) %}
-{%- for message in loop_messages -%}
- {#- At this point only assistant/user/tool messages should remain #}
- {%- if message.role == 'assistant' -%}
- {#- Checks to ensure the messages are being passed in the format we expect #}
- {%- if "content" in message %}
- {%- if "<|channel|>analysis<|message|>" in message.content or "<|channel|>final<|message|>" in message.content %}
- {{- raise_exception("You have passed a message containing <|channel|> tags in the content field. Instead of doing this, you should pass analysis messages (the string between '<|message|>' and '<|end|>') in the 'thinking' field, and final messages (the string between '<|message|>' and '<|end|>') in the 'content' field.") }}
- {%- endif %}
- {%- endif %}
- {%- if "thinking" in message %}
- {%- if "<|channel|>analysis<|message|>" in message.thinking or "<|channel|>final<|message|>" in message.thinking %}
- {{- raise_exception("You have passed a message containing <|channel|> tags in the thinking field. Instead of doing this, you should pass analysis messages (the string between '<|message|>' and '<|end|>') in the 'thinking' field, and final messages (the string between '<|message|>' and '<|end|>') in the 'content' field.") }}
- {%- endif %}
- {%- endif %}
- {%- if "tool_calls" in message %}
- {#- We need very careful handling here - we want to drop the tool call analysis message if the model #}
- {#- has output a later <|final|> message, but otherwise we want to retain it. This is the only case #}
- {#- when we render CoT/analysis messages in inference. #}
- {%- set future_final_message = namespace(found=false) %}
- {%- for future_message in loop_messages[loop.index:] %}
- {%- if future_message.role == 'assistant' and "tool_calls" not in future_message %}
- {%- set future_final_message.found = true %}
- {%- endif %}
- {%- endfor %}
- {#- We assume max 1 tool call per message, and so we infer the tool call name #}
- {#- in "tool" messages from the most recent assistant tool call name #}
- {%- set tool_call = message.tool_calls[0] %}
- {%- if tool_call.function %}
- {%- set tool_call = tool_call.function %}
- {%- endif %}
- {%- if message.content and message.thinking %}
- {{- raise_exception("Cannot pass both content and thinking in an assistant message with tool calls! Put the analysis message in one or the other, but not both.") }}
- {%- elif message.content and not future_final_message.found %}
- {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.content + "<|end|>" }}
- {%- elif message.thinking and not future_final_message.found %}
- {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }}
- {%- endif %}
- {{- "<|start|>assistant to=" }}
- {{- "functions." + tool_call.name + "<|channel|>commentary " }}
- {{- (tool_call.content_type if tool_call.content_type is defined else "json") + "<|message|>" }}
- {%- if tool_call.arguments is string %}
- {{- tool_call.arguments }}
- {%- else %}
- {{- tool_call.arguments|tojson }}
- {%- endif %}
- {{- "<|call|>" }}
- {%- set last_tool_call.name = tool_call.name %}
- {%- elif loop.last and not add_generation_prompt %}
- {#- Only render the CoT if the final turn is an assistant turn and add_generation_prompt is false #}
- {#- This is a situation that should only occur in training, never in inference. #}
- {%- if "thinking" in message %}
- {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.thinking + "<|end|>" }}
- {%- endif %}
- {#- <|return|> indicates the end of generation, but <|end|> does not #}
- {#- <|return|> should never be an input to the model, but we include it as the final token #}
- {#- when training, so the model learns to emit it. #}
- {{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|end|>" }}
- {%- elif "thinking" in message %}
- {#- CoT is dropped during all previous turns, so we never render it for inference #}
- {{- "<|start|>assistant<|channel|>analysis<|message|>" + message.content + "<|end|>" }}
- {%- set last_tool_call.name = none %}
- {%- else %}
- {#- CoT is dropped during all previous turns, so we never render it for inference #}
- {{- "<|start|>assistant<|channel|>final<|message|>" + message.content + "<|end|>" }}
- {%- set last_tool_call.name = none %}
- {%- endif %}
- {%- elif message.role == 'tool' -%}
- {%- if last_tool_call.name is none %}
- {{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }}
- {%- endif %}
- {{- "<|start|>functions." + last_tool_call.name }}
- {%- if message.content is string %}
- {{- " to=assistant<|channel|>commentary<|message|>" + message.content + "<|end|>" }}
- {%- else %}
- {{- " to=assistant<|channel|>commentary<|message|>" + message.content|tojson + "<|end|>" }}
- {%- endif %}
- {%- elif message.role == 'user' -%}
- {{- "<|start|>user<|message|>" + message.content + "<|end|>" }}
- {%- endif -%}
-{%- endfor -%}
-
-{#- Generation prompt #}
-{%- if add_generation_prompt -%}
-<|start|>assistant
-{%- endif -%}"""
-
-# Ollama from https://ollama.com/library/gpt-oss
-gptoss_ollama = \
-'''
-FROM {__FILE_LOCATION__}
-TEMPLATE """<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
-Knowledge cutoff: 2024-06
-Current date: {{ currentDate }}
-{{- if and .IsThinkSet .Think (ne .ThinkLevel "") }}
-
-Reasoning: {{ .ThinkLevel }}
-{{- else if or (not .IsThinkSet) (and .IsThinkSet .Think) }}
-
-Reasoning: medium
-{{- end }}
-
-{{- $hasNonBuiltinTools := false }}
-{{- if .Tools -}}
-{{- $hasBrowserSearch := false }}
-{{- $hasBrowserOpen := false }}
-{{- $hasBrowserFind := false }}
-{{- $hasPython := false }}
- {{- range .Tools }}
- {{- if eq .Function.Name "browser.search" -}}{{- $hasBrowserSearch = true -}}
- {{- else if eq .Function.Name "browser.open" -}}{{- $hasBrowserOpen = true -}}
- {{- else if eq .Function.Name "browser.find" -}}{{- $hasBrowserFind = true -}}
- {{- else if eq .Function.Name "python" -}}{{- $hasPython = true -}}
- {{- else }}{{ $hasNonBuiltinTools = true -}}
- {{- end }}
- {{- end }}
-{{- if or $hasBrowserSearch $hasBrowserOpen $hasBrowserFind $hasPython }}
-
-# Tools
-{{- if or $hasBrowserSearch $hasBrowserOpen $hasBrowserFind }}
-
-## browser
-
-// Tool for browsing.
-// The `cursor` appears in brackets before each browsing display: `[{cursor}]`.
-// Cite information from the tool using the following format:
-// `【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`.
-// Do not quote more than 10 words directly from the tool output.
-// sources=web (default: web)
-namespace browser {
-{{- if $hasBrowserSearch }}
-
-// Searches for information related to `query` and displays `topn` results.
-type search = (_: {
-query: string,
-topn?: number, // default: 10
-source?: string,
-}) => any;
-{{- end }}
-{{- if $hasBrowserOpen }}
-
-// Opens the link `id` from the page indicated by `cursor` starting at line number `loc`, showing `num_lines` lines.
-// Valid link ids are displayed with the formatting: `【{id}†.*】`.
-// If `cursor` is not provided, the most recent page is implied.
-// If `id` is a string, it is treated as a fully qualified URL associated with `source`.
-// If `loc` is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available.
-// Use this function without `id` to scroll to a new location of an opened page.
-type open = (_: {
-id?: number | string, // default: -1
-cursor?: number, // default: -1
-loc?: number, // default: -1
-num_lines?: number, // default: -1
-view_source?: boolean, // default: false
-source?: string,
-}) => any;
-{{- end }}
-{{- if $hasBrowserFind }}
-
-// Finds exact matches of `pattern` in the current page, or the page given by `cursor`.
-type find = (_: {
-pattern: string,
-cursor?: number, // default: -1
-}) => any;
-{{- end }}
-
-} // namespace browser
-{{- end }}{{/* end if has browser tools */}}
-{{- if $hasPython }}
-
-## python
-
-Use this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).
-
-When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 120.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is UNKNOWN. Depends on the cluster.
-{{- end }}{{/* end if hasPython */}}
-{{- end }}{{/* end if has any built-in tools */}}
-{{- end }}{{/* end if .Tools */}}
-
-# Valid channels: analysis, commentary, final. Channel must be included for every message.{{ if $hasNonBuiltinTools }}
-Calls to these tools must go to the commentary channel: 'functions'.
-{{- end -}}<|end|>{{/* end of system */ -}}
-{{- if or $hasNonBuiltinTools .System -}}
-<|start|>developer<|message|>{{- if $hasNonBuiltinTools }}# Tools
-
-## functions
-
-namespace functions {
-{{- range .Tools }}
-{{- if not (or (eq .Function.Name "browser.search") (eq .Function.Name "browser.open") (eq .Function.Name "browser.find") (eq .Function.Name "python")) }}
-{{if .Function.Description }}
-// {{ .Function.Description }}
-{{- end }}
-{{- if and .Function.Parameters.Properties (gt (len .Function.Parameters.Properties) 0) }}
-type {{ .Function.Name }} = (_: {
-{{- range $name, $prop := .Function.Parameters.Properties }}
-{{- if $prop.Description }}
- // {{ $prop.Description }}
-{{- end }}
- {{ $name }}: {{ if gt (len $prop.Type) 1 }}{{ range $i, $t := $prop.Type }}{{ if $i }} | {{ end }}{{ $t }}{{ end }}{{ else }}{{ index $prop.Type 0 }}{{ end }},
-{{- end }}
-}) => any;
-{{- else }}
-type {{ .Function.Name }} = () => any;
-{{- end }}
-{{- end }}{{/* end if not browser tool */}}
-{{- end }}{{/* end of range .Tools */}}
-
-} // namespace functions
-{{- end }}{{/* end if hasNonBuiltinTools */}}
-{{- if .System}}
-
-# Instructions
-
-{{ .System }}
-{{- end -}}
-<|end|>
-{{- end -}}
-{{- /* Find the index of the last user message */ -}}
-{{- $lastUserIdx := -1 }}
-{{- $prefillingContent := false }}
-{{- $prefillingThinkingOnly := false }}
-{{- range $i, $msg := .Messages }}
- {{- $last := eq (len (slice $.Messages $i)) 1 -}}
- {{- if eq $msg.Role "user" }}
- {{- $lastUserIdx = $i }}
- {{- end -}}
- {{- if and $last (eq $msg.Role "assistant") (gt (len $msg.Content) 0) }}
- {{- $prefillingContent = true }}
- {{- else if and $last (eq $msg.Role "assistant") (gt (len $msg.Thinking) 0) }}
- {{- $prefillingThinkingOnly = true }}
- {{- end }}
-{{- end -}}
-{{- /* Now render messages */ -}}
-{{- range $i, $msg := .Messages }}
- {{- $last := eq (len (slice $.Messages $i)) 1 -}}
- {{- if (ne $msg.Role "system") -}}
- {{- if eq $msg.Role "tool" -}}
- {{- if or (eq $msg.ToolName "python") (eq $msg.ToolName "browser.search") (eq $msg.ToolName "browser.open") (eq $msg.ToolName "browser.find") -}}
- <|start|>{{ $msg.ToolName }} to=assistant<|message|>{{ $msg.Content }}<|end|>
- {{- else -}}
- <|start|>functions.{{ $msg.ToolName }} to=assistant<|message|>{{ $msg.Content }}<|end|>
- {{- end -}}
- {{- else if eq $msg.Role "assistant" -}}
- {{- if and $msg.Thinking (gt $i $lastUserIdx) -}}{{- /* Show thinking only after last user message */ -}}
- <|start|>assistant<|channel|>analysis<|message|>{{ $msg.Thinking }}{{- if not $prefillingThinkingOnly -}}<|end|>{{- end -}}
- {{- end -}}
- {{- if gt (len $msg.Content) 0 -}}
- <|start|>assistant<|channel|>final<|message|>{{ $msg.Content }}{{- if not $prefillingContent -}}<|end|>{{- end -}}
- {{- end -}}
- {{- if gt (len $msg.ToolCalls) 0 -}}
- {{- range $j, $toolCall := $msg.ToolCalls -}}
- {{- $isBuiltin := or (eq $toolCall.Function.Name "python") (eq $toolCall.Function.Name "browser.search") (eq $toolCall.Function.Name "browser.open") (eq $toolCall.Function.Name "browser.find") -}}
- <|start|>assistant<|channel|>{{ if $isBuiltin }}analysis{{ else }}commentary{{ end }} to={{ if not $isBuiltin}}functions.{{end}}{{ $toolCall.Function.Name }} <|constrain|>json<|message|>{{ $toolCall.Function.Arguments }}<|call|>
- {{- end -}}
- {{- end -}}
- {{- else if eq $msg.Role "user" -}}
- <|start|>{{ $msg.Role }}<|message|>{{ $msg.Content }}<|end|>
- {{- end }}
- {{- else }}
- {{- end }}
-{{- end -}}
-{{- if not (or $prefillingContent $prefillingThinkingOnly) -}}
-<|start|>assistant
-{{- end -}}"""
-PARAMETER temperature 1.0
-PARAMETER top_k 0
-PARAMETER top_p 1.0
-'''
-
-gptoss_template_template_eos_token = "<|return|>"
-CHAT_TEMPLATES["gpt-oss"] = (gptoss_template, gptoss_template_template_eos_token, False, gptoss_ollama,)
-DEFAULT_SYSTEM_MESSAGE["gpt-oss"] = None # No system message in GPT-oss
-
-CHAT_TEMPLATES["gptoss"] = (gptoss_template, gptoss_template_template_eos_token, False, gptoss_ollama,)
-DEFAULT_SYSTEM_MESSAGE["gptoss"] = None # No system message in GPT-oss
-
-# =========================================== Qwen3-Instruct
-qwen3_instruct_template = \
-'''{%- if tools %}
- {{- '<|im_start|>system\\n' }}
- {%- if messages[0].role == 'system' %}
- {{- messages[0].content + '\\n\\n' }}
- {%- endif %}
- {{- "# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n" }}
- {%- for tool in tools %}
- {{- "\\n" }}
- {{- tool | tojson }}
- {%- endfor %}
- {{- "\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\"name\\": , \\"arguments\\": }\\n<|im_end|>\\n" }}
-{%- else %}
- {%- if messages[0].role == 'system' %}
- {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}
- {%- endif %}
-{%- endif %}
-{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
-{%- for message in messages[::-1] %}
- {%- set index = (messages|length - 1) - loop.index0 %}
- {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %}
- {%- set ns.multi_step_tool = false %}
- {%- set ns.last_query_index = index %}
- {%- endif %}
-{%- endfor %}
-{%- for message in messages %}
- {%- if message.content is string %}
- {%- set content = message.content %}
- {%- else %}
- {%- set content = '' %}
- {%- endif %}
- {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
- {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}
- {%- elif message.role == "assistant" %}
- {%- set reasoning_content = '' %}
- {%- if message.reasoning_content is string %}
- {%- set reasoning_content = message.reasoning_content %}
- {%- else %}
- {%- if '' in content %}
- {%- set reasoning_content = content.split('')[0].rstrip('\\n').split('')[-1].lstrip('\\n') %}
- {%- set content = content.split('')[-1].lstrip('\\n') %}
- {%- endif %}
- {%- endif %}
- {%- if loop.index0 > ns.last_query_index %}
- {%- if reasoning_content %}
- {{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content.strip('\\n') + '\\n\\n\\n' + content.lstrip('\\n') }}
- {%- else %}
- {{- '<|im_start|>' + message.role + '\\n' + content }}
- {%- endif %}
- {%- else %}
- {{- '<|im_start|>' + message.role + '\\n' + content }}
- {%- endif %}
- {%- if message.tool_calls %}
- {%- for tool_call in message.tool_calls %}
- {%- if (loop.first and content) or (not loop.first) %}
- {{- '\\n' }}
- {%- endif %}
- {%- if tool_call.function %}
- {%- set tool_call = tool_call.function %}
- {%- endif %}
- {{- '\\n{"name": "' }}
- {{- tool_call.name }}
- {{- '", "arguments": ' }}
- {%- if tool_call.arguments is string %}
- {{- tool_call.arguments }}
- {%- else %}
- {{- tool_call.arguments | tojson }}
- {%- endif %}
- {{- '}\\n' }}
- {%- endfor %}
- {%- endif %}
- {{- '<|im_end|>\\n' }}
- {%- elif message.role == "tool" %}
- {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
- {{- '<|im_start|>user' }}
- {%- endif %}
- {{- '\\n\\n' }}
- {{- content }}
- {{- '\\n' }}
- {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
- {{- '<|im_end|>\\n' }}
- {%- endif %}
- {%- endif %}
-{%- endfor %}
-{%- if add_generation_prompt %}
- {{- '<|im_start|>assistant\\n' }}
-{%- endif %}'''
-
-qwen3_template_eos_token = "<|im_end|>"
-CHAT_TEMPLATES["qwen3-instruct"] = (qwen3_instruct_template, qwen3_template_eos_token, False, _ollama_template("qwen3-instruct"),)
-DEFAULT_SYSTEM_MESSAGE["qwen3-instruct"] = None # No system message in Qwen3
-
-
-# =========================================== Qwen3-Thinking
-qwen3_thinking_template = \
-'''{%- if tools %}
- {{- '<|im_start|>system\\n' }}
- {%- if messages[0].role == 'system' %}
- {{- messages[0].content + '\\n\\n' }}
- {%- endif %}
- {{- "# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n" }}
- {%- for tool in tools %}
- {{- "\\n" }}
- {{- tool | tojson }}
- {%- endfor %}
- {{- "\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\"name\\": , \\"arguments\\": }\\n<|im_end|>\\n" }}
-{%- else %}
- {%- if messages[0].role == 'system' %}
- {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}
- {%- endif %}
-{%- endif %}
-{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
-{%- for message in messages[::-1] %}
- {%- set index = (messages|length - 1) - loop.index0 %}
- {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %}
- {%- set ns.multi_step_tool = false %}
- {%- set ns.last_query_index = index %}
- {%- endif %}
-{%- endfor %}
-{%- for message in messages %}
- {%- if message.content is string %}
- {%- set content = message.content %}
- {%- else %}
- {%- set content = '' %}
- {%- endif %}
- {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
- {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}
- {%- elif message.role == "assistant" %}
- {%- set reasoning_content = '' %}
- {%- if message.reasoning_content is string %}
- {%- set reasoning_content = message.reasoning_content %}
- {%- else %}
- {%- if '' in content %}
- {%- set reasoning_content = content.split('')[0].rstrip('\\n').split('')[-1].lstrip('\\n') %}
- {%- set content = content.split('')[-1].lstrip('\\n') %}
- {%- endif %}
- {%- endif %}
- {%- if loop.index0 > ns.last_query_index %}
- {%- if loop.last or (not loop.last and reasoning_content) %}
- {{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content.strip('\\n') + '\\n\\n\\n' + content.lstrip('\\n') }}
- {%- else %}
- {{- '<|im_start|>' + message.role + '\\n' + content }}
- {%- endif %}
- {%- else %}
- {{- '<|im_start|>' + message.role + '\\n' + content }}
- {%- endif %}
- {%- if message.tool_calls %}
- {%- for tool_call in message.tool_calls %}
- {%- if (loop.first and content) or (not loop.first) %}
- {{- '\\n' }}
- {%- endif %}
- {%- if tool_call.function %}
- {%- set tool_call = tool_call.function %}
- {%- endif %}
- {{- '\\n{"name": "' }}
- {{- tool_call.name }}
- {{- '", "arguments": ' }}
- {%- if tool_call.arguments is string %}
- {{- tool_call.arguments }}
- {%- else %}
- {{- tool_call.arguments | tojson }}
- {%- endif %}
- {{- '}\\n' }}
- {%- endfor %}
- {%- endif %}
- {{- '<|im_end|>\\n' }}
- {%- elif message.role == "tool" %}
- {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
- {{- '<|im_start|>user' }}
- {%- endif %}
- {{- '\\n\\n' }}
- {{- content }}
- {{- '\\n' }}
- {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
- {{- '<|im_end|>\\n' }}
- {%- endif %}
- {%- endif %}
-{%- endfor %}
-{%- if add_generation_prompt %}
- {{- '<|im_start|>assistant\n\n' }}
-{%- endif %}'''
-
-CHAT_TEMPLATES["qwen3-thinking"] = (
- qwen3_thinking_template,
- qwen3_template_eos_token,
- False,
- _ollama_template("qwen3-thinking"),
-)
-DEFAULT_SYSTEM_MESSAGE["qwen3-thinking"] = None # No system message in Qwen3
-
-
-# =========================================== Liquid-LFM2
-liquid_lfm2_template = \
-'''
-{{bos_token}}{% for message in messages %}{{'<|im_start|>' + message['role'] + '
-' + message['content'] + '<|im_end|>' + '
-'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
-' }}{% endif %}'''
-
-liquid_lfm2_template_eos_token = "<|im_end|>"
-CHAT_TEMPLATES["lfm-2"] = (liquid_lfm2_template, liquid_lfm2_template_eos_token, False, None)
-DEFAULT_SYSTEM_MESSAGE["lfm-2"] = None # No system message in Phi-3
-
-
-# =========================================== Starling-LM
-
-starling_template = \
-"""{{ bos_token }}
-{%- for message in messages %}
- {{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>' }}
-{%- endfor %}
-{%- if add_generation_prompt %}
- {{ 'GPT4 Correct Assistant:' }}
-{%- endif %}"""
-
-# Ollama from https://ollama.com/library/starling-lm:7b/blobs/4b21bfc435b4
-starling_ollama = _ollama_template("starling")
-
-starling_template_eos_token = "<|end_of_turn|>"
-CHAT_TEMPLATES["starling"] = (starling_template, starling_template_eos_token, False, starling_ollama)
-DEFAULT_SYSTEM_MESSAGE["starling"] = None
-
-
-# =========================================== Yi-chat
-
-yi_chat_template = \
-"""
-{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '
-' + message['content'] + '<|im_end|>' + '
-'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
-' }}{% endif %}
-"""
-
-# Ollama from https://ollama.com/library/yi:34b-chat/blobs/62fbfd9ed093
-yi_chat_ollama = _ollama_template("yi-chat")
-
-yi_chat_template_eos_token = "<|endoftext|>"
-CHAT_TEMPLATES["yi-chat"] = (yi_chat_template, yi_chat_template_eos_token, False, yi_chat_ollama)
-DEFAULT_SYSTEM_MESSAGE["yi-chat"] = None
+pass
def _change_system_message(template: str, type_chat_template: str, system_message: str = None):
system_message_pattern = r"\{system_message\}"
-
+
# For predefined templates, check if default system message exists
default_system_message = DEFAULT_SYSTEM_MESSAGE.get(f"{type_chat_template}", None)
if default_system_message is None:
@@ -1659,24 +1028,27 @@ def _change_system_message(template: str, type_chat_template: str, system_messag
"You need to manually add the system message in your data."
)
return template, system_message
-
+ pass
+
# For custom templates
if type_chat_template is None:
has_placeholder = re.search(system_message_pattern, template) is not None
-
+
if has_placeholder:
if system_message is None:
raise ValueError("Unsloth: You need to provide a system message for custom templates.")
new_template = re.sub(system_message_pattern, system_message, template)
return new_template, system_message
-
+
return template, system_message
-
+ pass
+
# For predefined templates with default system message
message_to_use = system_message if system_message is not None else default_system_message
new_template = re.sub(system_message_pattern, message_to_use, template)
-
+
return new_template, message_to_use
+pass
def get_chat_template(
@@ -1693,6 +1065,7 @@ def get_chat_template(
if tokenizer.__class__.__name__.startswith("Gemma"):
if chat_template == "chatml": chat_template = "gemma_chatml"
IS_GEMMA = True
+ pass
# We add a check for Llama-3
# if chat_template == "llama-3":
@@ -1711,7 +1084,7 @@ def get_chat_template(
same_padding_token = False
type_chat_template = None
-
+
if type(chat_template) in (list, tuple,):
# For changing system message later
# Since it's not supported yet, we will raise an error first!
@@ -1854,11 +1227,13 @@ def get_chat_template(
f"Unsloth: `chat_template` must be a tuple of (your_template, eos_token,) or one of\n"\
f"{CHAT_TEMPLATES.keys()}"
)
+ pass
# Careful on Gemma
# bos_token is a must or else losses become too high
if IS_GEMMA and not chat_template.startswith(("{{ bos_token }}", "{{- bos_token }}")):
chat_template = "{{ bos_token }}" + chat_template
+ pass
# For ShareGPT role -> from and content -> value
new_chat_template = chat_template\
@@ -1880,6 +1255,7 @@ def get_chat_template(
"{% endif %}"
else:
chat_template = new_chat_template
+ pass
chat_template, system_message = _change_system_message(chat_template, type_chat_template, system_message)
@@ -1896,6 +1272,7 @@ def get_chat_template(
if old_unk_token != new_unk_token: tokenizer.unk_token = old_unk_token
if not same_padding_token:
if old_pad_token != new_pad_token: tokenizer.pad_token = old_pad_token
+ pass
# stopping_criteria = create_stopping_criteria(tokenizer, stop_word)
@@ -1906,13 +1283,16 @@ def get_chat_template(
tokenizer._ollama_modelfile = ollama_modelfile
tokenizer._system_message = system_message
return tokenizer#, stopping_criteria
+pass
def remove_special_tokens(tokenizer, prompt):
# Removes double BOS token
if prompt.startswith(tokenizer.bos_token):
prompt = prompt[len(tokenizer.bos_token):]
+ pass
return prompt
+pass
def _parse_combined_prompt(combined_prompt, dataset):
@@ -1925,6 +1305,8 @@ def _parse_combined_prompt(combined_prompt, dataset):
f"Unsloth: Your prompt includes '{column}' but this does not exist in the dataset. "\
f"Only allowed columns are {list(dataset_columns)}"
)
+ pass
+ pass
# Find [[...]]
optional_prompts = list(re.finditer(r"\[\[.+?\]\]", combined_prompt, flags = re.DOTALL | re.MULTILINE))
@@ -1942,6 +1324,7 @@ def _parse_combined_prompt(combined_prompt, dataset):
l, r = left[0][-1], right[0][0]
final_optional_prompts.append(left)
if l != r: final_optional_prompts.append(combined_prompt[l : r])
+ pass
final_optional_prompts.append(optional_prompts[-1])
# Add right
@@ -1951,69 +1334,54 @@ def _parse_combined_prompt(combined_prompt, dataset):
else:
# Just add in the entire string
final_optional_prompts.append(combined_prompt)
+ pass
check_combined = "".join(x if type(x) is str else x[1] for x in final_optional_prompts)
assert(combined_prompt == check_combined)
return possible_columns, final_optional_prompts
+pass
def _create_formatter(possible_columns, final_optional_prompts, user_column_name):
- columns = list(dict.fromkeys(possible_columns))
- merged_prompt_parts = []
- formatter_templates = []
+ # Start final prompt!
+ function = ["def __combined_prompt_processor__(examples):"]
+ columns = list(set(possible_columns))
+ for column in columns:
+ function.append(f"{' '*4}{column}__ = examples['{column}']")
+ function.append(f"{' '*4}texts = []")
+ function.append(f"{' '*4}for ({', '.join(columns)}) in zip({', '.join(f'{x}__' for x in columns)}):")
+
+ # Add optional tags as well!
+ final_prompt = ""
+ formatter = []
for j, optional_prompt in enumerate(final_optional_prompts):
if type(optional_prompt) is str:
- needed_columns = re.findall(r"\{(.+?)\}", optional_prompt)
- formatter_templates.append(("required", optional_prompt, needed_columns))
- merged_prompt_parts.append(optional_prompt)
- continue
-
- _, prompt = optional_prompt
- prompt = prompt[2:-2]
- needed_columns = re.findall(r"\{(.+?)\}", prompt)
- if len(needed_columns) == 0:
- raise IndexError("Unsloth: Optional [[...]] blocks must contain at least 1 {column}.")
- optional_name = f"__optional_{j}__"
- formatter_templates.append(("optional", optional_name, prompt, needed_columns))
- merged_prompt_parts.append("{" + optional_name + "}")
-
- merged_prompt = "".join(merged_prompt_parts)
-
- def __combined_prompt_processor__(examples):
- if len(examples) == 0:
- return {user_column_name: []}
-
- first_key = next(iter(examples.keys()), None)
- if first_key is None:
- return {user_column_name: []}
- n_rows = len(examples[first_key])
-
- texts = []
- for row_idx in range(n_rows):
- row_values = {column: examples[column][row_idx] for column in columns}
- formatter_values = {}
-
- for formatter_template in formatter_templates:
- if formatter_template[0] == "required":
- _, _, needed_columns = formatter_template
- for column in needed_columns:
- formatter_values[column] = row_values[column]
- continue
-
- _, optional_name, prompt, needed_columns = formatter_template
- if row_values[needed_columns[0]] not in (None, ""):
- prompt_values = {column: row_values[column] for column in needed_columns}
- formatter_values[optional_name] = prompt.format(**prompt_values)
- else:
- formatter_values[optional_name] = ""
-
- texts.append(merged_prompt.format(**formatter_values))
-
- return {user_column_name: texts}
+ columns = re.findall(r"\{(.+?)\}", optional_prompt)
+ formatter += columns
+ # Must escape \n \r
+ final_prompt += optional_prompt.encode("unicode-escape").decode("utf-8").replace("'", "\\'").replace('"', '\\"')
+ else:
+ where, prompt = optional_prompt
+ # Strip [[...]]
+ # Must escape \n \r
+ prompt = prompt[2:-2].encode("unicode-escape").decode("utf-8").replace("'", "\\'").replace('"', '\\"')
+ columns = re.findall(r"\{(.+?)\}", prompt)
+ x = f"__optional_{j}__"
+ prompt = f"{' '*8}{x} = '{prompt}'.format({', '.join(f'{x} = {x}' for x in columns)}) if {columns[0]} else ''"
+ function.append(prompt)
+ formatter.append(x)
+ final_prompt += "{" + x + "}"
+ pass
+ pass
- return __combined_prompt_processor__
+ function.insert(1, f"{' '*4}__combined_prompt__ = '{final_prompt}'")
+ function.append(f"{' '*8}texts.append("\
+ f"__combined_prompt__.format({', '.join(f'{x} = {x}' for x in formatter)}))")
+ function.append(f"{' '*4}return " + "{ " + f"'{user_column_name}' : texts" + " }")
+ return "\n".join(function)
+pass
def to_sharegpt(
@@ -2043,19 +1411,17 @@ def to_sharegpt(
convo = dataset[0]["conversations"]
if type(convo) is list:
raise TypeError("Unsloth: Your dataset is probably already in ShareGPT format!")
+ pass
+ pass
possible_columns, final_optional_prompts = _parse_combined_prompt(merged_prompt, dataset)
- formatter = _create_formatter(possible_columns, final_optional_prompts, merged_column_name)
- dataset = dataset.map(formatter, batched = True, desc = "Merging columns")
+ function = _create_formatter(possible_columns, final_optional_prompts, merged_column_name)
+ exec(function, globals())
+ dataset = dataset.map(__combined_prompt_processor__, batched = True, desc = "Merging columns")
def __convert_to_sharegpt__(examples):
users = examples[merged_column_name]
assistants = examples[output_column_name]
- if len(users) != len(assistants):
- raise ValueError(
- "Unsloth: Input and output columns must have matching batch lengths. "
- f"Got {len(users)} {merged_column_name} rows and {len(assistants)} {output_column_name} rows."
- )
texts = [
[
{"from" : "human", "value" : str(user) },
@@ -2064,6 +1430,7 @@ def __convert_to_sharegpt__(examples):
for user, assistant in zip(users, assistants)
]
return { "conversations" : texts, }
+ pass
dataset = dataset.map(
__convert_to_sharegpt__,
@@ -2083,21 +1450,23 @@ def __convert_to_sharegpt__(examples):
for j in range(1, n_extensions+1):
shuffled = dataset.shuffle(seed = random_state+j).rename_columns({"conversations0" : f"conversations{j}"})
all_shuffled.append(shuffled)
+ pass
dataset = concatenate_datasets(all_shuffled, axis = 1)
# Combine them into 1
+ function = "def __combine_conversations__(examples):\n"
n_extensions += 1
- conversation_columns = [f"conversations{j}" for j in range(n_extensions)]
- def __combine_conversations__(examples):
- columns = [examples[column] for column in conversation_columns]
- convos = []
- for conversations in zip(*columns):
- merged_conversation = []
- for conversation in conversations:
- merged_conversation.extend(conversation)
- convos.append(merged_conversation)
- return {"conversations" : convos}
-
+ for j in range(n_extensions):
+ function += f"{' '*4}conversations{j}__ = examples['conversations{j}']\n"
+ function += f"{' '*4}convos = []\n"
+ function += f"{' '*4}for ({', '.join(f'conversations{j}' for j in range(n_extensions))}) "\
+ f"in zip({', '.join(f'conversations{j}__' for j in range(n_extensions))}):\n"
+ function += f"{' '*8}convos.append("\
+ f"{'+'.join(f'conversations{j}' for j in range(n_extensions))})\n"
+ function += f"{' '*4}return " + "{ " + "'conversations' : convos" + " }"
+
+ # Map function
+ exec(function, globals())
dataset = dataset.map(
__combine_conversations__,
batched = True,
@@ -2106,6 +1475,7 @@ def __combine_conversations__(examples):
remove_columns = dataset.column_names if remove_unused_columns else None,
)
return dataset
+pass
def get_ollama_eos_tokens(tokenizer, extra_eos_tokens = []):
@@ -2118,6 +1488,7 @@ def get_ollama_eos_tokens(tokenizer, extra_eos_tokens = []):
# Remove BOS
if getattr(tokenizer, "bos_token", None) is not None:
added_tokens_decoder = [x for x in added_tokens_decoder if x != tokenizer.bos_token]
+ pass
repeatted_tokens = []
# Join all vocab
@@ -2135,6 +1506,9 @@ def get_ollama_eos_tokens(tokenizer, extra_eos_tokens = []):
joined_text = joined_text.replace(token[:j], "")
repeatted_tokens.append(token[:j])
break
+ pass
+ pass
+ pass
# Remove duplicates
splitted = joined_text.split("\x01\x00")
@@ -2150,7 +1524,9 @@ def get_ollama_eos_tokens(tokenizer, extra_eos_tokens = []):
elif token.startswith("<") and len(token) <= 2: continue
elif token.startswith("") and len(token) == 3: continue
filtered_eos_tokens.append(token)
+ pass
return filtered_eos_tokens
+pass
def construct_chat_template( \
@@ -2168,14 +1544,14 @@ def construct_chat_template( \
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{OUTPUT}<|eot_id|>""",
-
+
default_system_message = \
"Below are some instructions that describe some tasks. Write responses that appropriately complete each request.",
extra_eos_tokens = None,
):
"""
- Creates an Ollama modelfile and a HF Jinja template from a custom
+ Creates a Ollama modelfile and a HF Jinja template from a custom
template. You must provide 2x examples of an input & output.
There is an optional system message as well.
@@ -2194,6 +1570,8 @@ def construct_chat_template( \
assert(type(extra_eos) is str)
if extra_eos not in vocab:
raise ValueError(f"Unsloth: `{extra_eos}` is not a singular token in the tokenizer.")
+ pass
+ pass
error_msg = \
"Unsloth: Your prompt template must have 2 examples showing the user input {INPUT} "\
@@ -2211,6 +1589,7 @@ def construct_chat_template( \
raise RuntimeError(
"Unsloth: Your tokenizer does not have an EOS token? Please provide one via extra_eos_tokens!"
)
+ pass
# Check tokenizer types
tokenizer_name = tokenizer.name_or_path.lower()
@@ -2224,11 +1603,13 @@ def construct_chat_template( \
"Unsloth: Base llama-3 models did not train <|eot_id|>.\n"\
"Please use the instruct version or use <|end_of_text|>"
)
+ pass
extra_eos_tokens = list(set(extra_eos_tokens))
count_eos = 0
for eos in extra_eos_tokens:
count_eos += len(re.findall(r"{OUTPUT}" + re.escape(eos), chat_template))
+ pass
# This forces you to provide 2 input and outputs
final_combined_check = False
@@ -2242,6 +1623,7 @@ def construct_chat_template( \
if found == -1: break
j -= 1
at_least_one = True
+ pass
if j > 0: j += 1
else: raise RuntimeError(error_msg)
@@ -2254,6 +1636,7 @@ def construct_chat_template( \
instruction_response = chat_template[j:]
if instruction_response.count("{INPUT}") != 1 or instruction_response.count("{OUTPUT}") != 1:
raise RuntimeError(error_msg)
+ pass
# 1st System, Instruction, Output pair
left = chat_template[:j]
@@ -2293,6 +1676,7 @@ def construct_chat_template( \
"But we require the following:\n"\
f"{left_changed}"
)
+ pass
except:
ending = chat_template[chat_template.find("{OUTPUT}") + len("{OUTPUT}"):]
@@ -2305,6 +1689,7 @@ def construct_chat_template( \
try_find = re.escape(response_part[:j])
try: found = next(re.finditer("(" + try_find + ").+?\\{INPUT\\}", chat_template, flags = re.DOTALL | re.MULTILINE))
except: break
+ pass
separator = found.group(1)
response_start = chat_template.find(response_part)
@@ -2317,11 +1702,13 @@ def construct_chat_template( \
system_part = chat_template[:where]
system_part, input_part, output_part = system_part, instruction_part, response_part
+ pass
if count_eos == 0:
logger.warning("Unsloth: We automatically added an EOS token to stop endless generations.")
eos = extra_eos_tokens[0]
output_part = output_part + eos
+ pass
# Ollama modelfile parts
@@ -2334,11 +1721,14 @@ def construct_chat_template( \
if ollama_system.startswith(tokenizer.bos_token):
has_bos_token = True
ollama_system = ollama_system[len(tokenizer.bos_token):]
+ pass
+ pass
# Check system
if "{SYSTEM}" in ollama_system:
system_modelfile = "{{ if .System }}" + ollama_system.replace("{SYSTEM}", "{{ .System }}") + "{{ end }}"
else:
system_modelfile = ollama_system
+ pass
input_modelfile = "{{ if .Prompt }}" + input_part .replace("{INPUT}", "{{ .Prompt }}") + "{{ end }}"
output_modelfile = output_part.replace("{OUTPUT}", "{{ .Response }}")
@@ -2365,8 +1755,10 @@ def process(part, which, content = "message['content']"):
part = "'" + part.replace(which, f"' + {content} + '") + "'"
if part.startswith("'' + "): part = part[5:]
return part
+ pass
input_jinja = process(input_part, "{INPUT}")
output_jinja = process(output_part, "{OUTPUT}")
+ pass
jinja_template = \
"{% for message in loop_messages %}"\
@@ -2381,6 +1773,7 @@ def process(part, which, content = "message['content']"):
"{% if add_generation_prompt %}"\
"{{ '" + output_part[:output_part.find("{OUTPUT}")] + "' }}"\
"{% endif %}"
+ pass
# Now add system prompt to jinja
if len(system_part) != 0:
@@ -2390,12 +1783,14 @@ def process(part, which, content = "message['content']"):
if "{SYSTEM}" in partial_system:
if default_system_message is None:
raise RuntimeError("Unsloth: Please specify a default system message!")
+ pass
# Separate the BOS
if has_bos_token:
partial_system = partial_system.replace(tokenizer.bos_token, "", 1)
system_part = system_part .replace(tokenizer.bos_token, "", 1)
-
+ pass
+
partial_system = \
"{% if messages[0]['role'] == 'system' %}"\
"{{ " + partial_system + " }}"\
@@ -2404,17 +1799,20 @@ def process(part, which, content = "message['content']"):
full_system = system_part.replace("{SYSTEM}", default_system_message)
if "{SYSTEM}" in system_part:
modelfile += '\nSYSTEM "' + default_system_message + '"'
+ pass
partial_system += "{% else %}"\
"{{ '" + full_system + "' }}"\
"{% set loop_messages = messages %}"\
"{% endif %}"
else:
partial_system += "{% endif %}"
+ pass
jinja_template = partial_system + jinja_template
if has_bos_token:
jinja_template = "{{ bos_token }}" + jinja_template
+ pass
# Fix missing loop_messages
if "{% set loop_messages = messages %}" not in jinja_template:
@@ -2423,6 +1821,7 @@ def process(part, which, content = "message['content']"):
"{% for message in messages %}",
1, # Only replace the first one
)
+ pass
# Check if system part is the same!
jinja_template = re.sub(
@@ -2434,15 +1833,17 @@ def process(part, which, content = "message['content']"):
jinja_template, flags = re.MULTILINE | re.DOTALL,
)
- # Check jinja template for bos
+ # Check jinja tempate for bos
if always_bos_token:
if not jinja_template.startswith(("{{ bos_token }}", "{{- bos_token }}")):
jinja_template = "{{ bos_token }}" + jinja_template
+ pass
# Get instruction and output parts for train_on_inputs = False
input_part = input_part [:input_part .find("{INPUT}")]
output_part = output_part[:output_part.find("{OUTPUT}")]
return modelfile, jinja_template, input_part, output_part
+pass
def test_construct_chat_template():
@@ -2461,10 +1862,10 @@ def test_construct_chat_template():
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{OUTPUT}<|eot_id|>"""
-
+
default_system_message = \
"Below are some instructions that describe some tasks. Write responses that appropriately complete each request."
-
+
extra_eos_tokens = None
modelfile, jinja_template, _, _ = construct_chat_template(
@@ -2486,6 +1887,8 @@ def test_construct_chat_template():
tokenizer.chat_template = jinja_template
new_output = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
assert(correct_output == new_output)
+ pass
+pass
def apply_chat_template( \
@@ -2504,15 +1907,15 @@ def apply_chat_template( \
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{OUTPUT}<|eot_id|>""",
-
+
default_system_message = \
"Below are some instructions that describe some tasks. Write responses that appropriately complete each request.",
-
+
extra_eos_tokens = None,
-
+
):
"""
- Creates an Ollama modelfile and a HF Jinja template from a custom
+ Creates a Ollama modelfile and a HF Jinja template from a custom
template. You must provide 2x examples of an input & output.
There is an optional system message as well.
@@ -2528,6 +1931,7 @@ def formatting_prompts_func(examples):
convos = examples["conversations"]
texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
return { "text" : texts, }
+ pass
tokenizer.chat_template = jinja_template
tokenizer._ollama_modelfile = modelfile
@@ -2540,6 +1944,7 @@ def formatting_prompts_func(examples):
tokenizer.tokenizer._unsloth_output_part = output_part
return dataset.map(formatting_prompts_func, batched = True,)
+pass
def create_stopping_criteria(tokenizer, stop_word = "eos_token"):
@@ -2555,7 +1960,9 @@ def __init__(self, stops = "eos_token", device = "cuda", encounters = 1):
self.stop_token = tokenizer(["\n" + stops], add_special_tokens = False, return_tensors = "pt")
self.stop_token = self.stop_token.input_ids.ravel()[1:].to("cuda")
self.length = self.stop_token.shape[0]
+ pass
self.single_match = self.length == 1
+ pass
def __call__(self, input_ids: LongTensor, scores: FloatTensor) -> bool:
input_ids = input_ids.ravel()
@@ -2565,8 +1972,11 @@ def __call__(self, input_ids: LongTensor, scores: FloatTensor) -> bool:
if input_ids.shape[0] >= self.length and \
(input_ids[-self.length:] == self.stop_token).all(): return True
return False
+ pass
+ pass
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops = stop_word)])
return stopping_criteria
+pass
def test_chat_templates():
@@ -2670,6 +2080,7 @@ def test_chat_templates():
correct_tokenizer.chat_template = template
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
assert(correct_prompt == our_prompt)
+pass
def test_hf_gguf_equivalence(tokenizer, gguf_model = "./model-unsloth.F16.gguf"):
@@ -2705,25 +2116,20 @@ def test_hf_gguf_equivalence(tokenizer, gguf_model = "./model-unsloth.F16.gguf")
if tokenizer.chat_template is not None:
prompt = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
+ prompt = prompt.replace("'", "") # Subprocess does not like ''
prompt = remove_special_tokens(tokenizer, prompt)
prompts.append(prompt)
-
+ pass
+
for prompt in prompts:
- # Use a list of args with shell=False so prompt content is passed literally.
- command = [
- "./llama.cpp/llama-cli",
- "-m", gguf_model,
- "-n", "0",
- "--temp", "0.0",
- "--verbose-prompt",
- "--check-tensors",
- "-p", prompt,
- ]
+ command = f"./llama.cpp/llama-cli -m {gguf_model} -n 0 --temp 0.0 --verbose-prompt "\
+ f"--check-tensors -p '{prompt}'"
datas = []
- with subprocess.Popen(command, shell = False, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp:
+ with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp:
for line in sp.stdout:
datas.append(line.decode("utf-8", errors = "replace"))
+ pass
gguf_tokens = "".join(datas)
# Now extract GGUF tokenization attempt
@@ -2745,4 +2151,7 @@ def test_hf_gguf_equivalence(tokenizer, gguf_model = "./model-unsloth.F16.gguf")
print(gguf_tokenized)
print()
raise RuntimeError("Failed comparing GGUF to HF.")
+ pass
+ pass
return True
+pass
diff --git a/unsloth/dataprep/__init__.py b/unsloth/dataprep/__init__.py
deleted file mode 100644
index 048f9b8010..0000000000
--- a/unsloth/dataprep/__init__.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .synthetic import *
-from .raw_text import *
diff --git a/unsloth/dataprep/raw_text.py b/unsloth/dataprep/raw_text.py
deleted file mode 100644
index ba010edabb..0000000000
--- a/unsloth/dataprep/raw_text.py
+++ /dev/null
@@ -1,348 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import re
-import json
-import csv
-from typing import List, Dict, Any, Union, Optional
-from datasets import Dataset
-from pathlib import Path
-
-__all__ = [
- "RawTextDataLoader",
- "TextPreprocessor",
-]
-
-SUPPORTED_FORMATS = {
- ".txt": "plain_text",
- ".md": "markdown",
- ".json": "json_lines",
- ".jsonl": "json_lines",
- ".csv": "csv_text_column",
-}
-
-
-class RawTextDataLoader:
- def __init__(self, tokenizer, chunk_size = 2048, stride = 512, return_tokenized = True):
- if chunk_size <= 0:
- raise ValueError(f"chunk_size must be positive, got {chunk_size}")
- if stride >= chunk_size:
- raise ValueError(
- f"stride ({stride}) must be smaller than chunk_size ({chunk_size})"
- )
- self.tokenizer = tokenizer
- self.chunk_size = chunk_size
- self.stride = stride
- self.return_tokenized = return_tokenized
-
- def detect_format(self, file_path):
- """Auto-detect file format and parse accordingly"""
- extension = Path(file_path).suffix.lower()
- return SUPPORTED_FORMATS.get(extension, "plain_text")
-
- def load_from_file(self, file_path, return_tokenized = None):
- """Load raw text and convert to dataset"""
- if return_tokenized is None:
- return_tokenized = self.return_tokenized
- file_format = self.detect_format(file_path)
- text_content = self._read_file_by_format(file_path, file_format)
- if not text_content or not text_content.strip():
- raise ValueError(f"File '{file_path}' is empty or contains only whitespace")
- chunks = self.smart_chunk_text(
- text_content, self.chunk_size, self.stride, return_tokenized
- )
- return self.create_causal_dataset(chunks)
-
- def load_from_files(self, file_paths, return_tokenized = None):
- """Load multiple text files"""
- if return_tokenized is None:
- return_tokenized = self.return_tokenized
- all_chunks = []
- for file_path in file_paths:
- file_format = self.detect_format(file_path)
- text_content = self._read_file_by_format(file_path, file_format)
- chunks = self.smart_chunk_text(
- text_content, self.chunk_size, self.stride, return_tokenized
- )
- all_chunks.extend(chunks)
- return self.create_causal_dataset(all_chunks)
-
- def chunk_text(self, text, return_tokenized = None):
- """Split text into overlapping chunks"""
- if return_tokenized is None:
- return_tokenized = self.return_tokenized
- return self.smart_chunk_text(
- text, self.chunk_size, self.stride, return_tokenized
- )
-
- def create_causal_dataset(self, chunks):
- """Create dataset for causal language modeling"""
- if chunks and isinstance(chunks[0], dict):
- # If chunks are already tokenized (dict with input_ids, attention_mask)
- # Reorganize the data structure for Dataset.from_dict
- input_ids = [chunk["input_ids"] for chunk in chunks]
- attention_mask = [chunk["attention_mask"] for chunk in chunks]
- # Labels are same as input_ids for causal LM training
- labels = [list(ids) for ids in input_ids]
- return Dataset.from_dict(
- {
- "input_ids": input_ids,
- "attention_mask": attention_mask,
- "labels": labels,
- }
- )
- else:
- # If chunks are text strings (backward compatibility)
- return Dataset.from_dict({"text": chunks})
-
- def smart_chunk_text(self, text, chunk_size, stride, return_tokenized = True):
- """
- Intelligent chunking that:
- 1. Respects sentence/paragraph boundaries
- 2. Handles various text formats (.txt, .md, .json, etc.)
- 3. Maintains context with stride overlap
- 4. Returns tokenized chunks directly (more efficient) or text chunks
- """
- # First pass: tokenize the entire text to get accurate token counts
- tokenized = self.tokenizer(text, return_tensors = "pt", add_special_tokens = False)
- tokens = tokenized["input_ids"]
-
- # Handle different tokenizer return formats
- if hasattr(tokens, "__len__") and len(tokens) > 0:
- # If it's a nested structure, get the first element
- if hasattr(tokens[0], "__len__"):
- tokens = tokens[0]
- elif isinstance(tokens, int):
- # If tokenizer returns just a count, create a simple range
- tokens = list(range(tokens))
-
- if len(tokens) <= chunk_size:
- # Text is small enough to fit in one chunk
- if return_tokenized:
- # Add EOS token to the tokens if available
- eos_token_id = getattr(self.tokenizer, "eos_token_id", None)
- if eos_token_id is not None:
- tokens = (
- tokens.tolist() if hasattr(tokens, "tolist") else list(tokens)
- )
- tokens.append(eos_token_id)
-
- # Create attention mask
- attention_mask = [1] * len(tokens)
- return [{"input_ids": tokens, "attention_mask": attention_mask}]
- else:
- eos_token = self.tokenizer.eos_token if self.tokenizer.eos_token else ""
- return [text + eos_token]
-
- chunks = []
- start_idx = 0
-
- while start_idx < len(tokens):
- # Calculate end index for this chunk
- end_idx = min(start_idx + chunk_size, len(tokens))
-
- # Extract tokens for this chunk
- chunk_tokens = tokens[start_idx:end_idx]
-
- if return_tokenized:
- # Convert to list if it's a tensor
- chunk_tokens_list = (
- chunk_tokens.tolist()
- if hasattr(chunk_tokens, "tolist")
- else list(chunk_tokens)
- )
-
- # Add EOS token if it's the last chunk or chunk is complete
- if end_idx == len(tokens) or len(chunk_tokens_list) == chunk_size:
- eos_token_id = getattr(self.tokenizer, "eos_token_id", None)
- if eos_token_id is not None:
- chunk_tokens_list.append(eos_token_id)
-
- # Create attention mask (all tokens are attended to)
- attention_mask = [1] * len(chunk_tokens_list)
-
- chunks.append(
- {"input_ids": chunk_tokens_list, "attention_mask": attention_mask}
- )
- else:
- # Decode back to text (backward compatibility)
- chunk_text = self.tokenizer.decode(
- chunk_tokens, skip_special_tokens = True
- )
-
- # Add EOS token if it's the last chunk or chunk is complete
- if end_idx == len(tokens) or len(chunk_tokens) == chunk_size:
- eos_token = (
- self.tokenizer.eos_token if self.tokenizer.eos_token else ""
- )
- chunk_text += eos_token
-
- chunks.append(chunk_text)
-
- # Move to next chunk with stride overlap
- if end_idx == len(tokens):
- break
- start_idx += chunk_size - stride
-
- return chunks
-
- def _read_file_by_format(self, file_path, file_format):
- """Read file content based on detected format."""
- with open(file_path, "r", encoding = "utf-8") as f:
- if file_format == "plain_text" or file_format == "markdown":
- return f.read()
- elif file_format == "json_lines":
- lines = []
- for line in f:
- try:
- data = json.loads(line.strip())
- text = self._extract_text_from_json(data)
- if text:
- lines.append(text)
- except json.JSONDecodeError:
- continue
- return "\n\n".join(lines)
- elif file_format == "csv_text_column":
- reader = csv.DictReader(f)
- texts = []
- for row in reader:
- text = self._extract_text_from_csv_row(row)
- if text:
- texts.append(text)
- return "\n\n".join(texts)
- return ""
-
- def _extract_text_from_json(self, data):
- """Extract text from JSON object using common field names."""
- text_fields = ["text", "content", "message", "body", "description", "prompt"]
- for field in text_fields:
- if field in data and isinstance(data[field], str):
- return data[field]
- return ""
-
- def _extract_text_from_csv_row(self, row):
- """Extract text from CSV row using common column names."""
- text_columns = ["text", "content", "message", "body", "description", "prompt"]
- for column in text_columns:
- if column in row and row[column]:
- return row[column]
- return ""
-
-
-class TextPreprocessor:
- def clean_text(self, text):
- """Remove unwanted characters, normalize whitespace"""
- text = re.sub(r"\s+", " ", text)
- text = re.sub(r"[^\x20-\x7E\n\t]", "", text)
- text = text.replace("\r\n", "\n").replace("\r", "\n")
- text = re.sub(r"\n{3,}", "\n\n", text)
- return text.strip()
-
- def extract_sections(self, text, patterns):
- """Extract specific sections (e.g., code blocks, quotes)"""
- sections = []
- for pattern in patterns:
- matches = re.findall(pattern, text, re.MULTILINE | re.DOTALL)
- sections.extend(matches)
- return sections
-
- def add_structure_tokens(self, text):
- """Add special tokens for structure (chapters, sections)"""
- text = re.sub(
- r"^# (.+)$", r"<|chapter|>\1<|/chapter|>", text, flags = re.MULTILINE
- )
- text = re.sub(
- r"^## (.+)$", r"<|section|>\1<|/section|>", text, flags = re.MULTILINE
- )
- text = re.sub(
- r"^### (.+)$", r"<|subsection|>\1<|/subsection|>", text, flags = re.MULTILINE
- )
- text = re.sub(
- r"```(\w*)\n(.*?)\n```", r"<|code|\1|>\2<|/code|>", text, flags = re.DOTALL
- )
- return text
-
- def validate_dataset(self, dataset):
- """
- Check for:
- - Minimum/maximum sequence lengths
- - Character encoding issues
- - Repeated content
- - Empty chunks
- """
- stats = {
- "total_samples": len(dataset),
- "empty_samples": 0,
- "min_length": float("inf"),
- "max_length": 0,
- "avg_length": 0,
- "repeated_content": 0,
- "encoding_issues": 0,
- "warnings": [],
- }
-
- texts = dataset["text"]
- text_lengths = []
- seen_texts = set()
-
- for i, text in enumerate(texts):
- if not text or len(text.strip()) == 0:
- stats["empty_samples"] += 1
- continue
-
- # Check for encoding issues
- try:
- text.encode("utf-8")
- except UnicodeEncodeError:
- stats["encoding_issues"] += 1
-
- # Calculate lengths
- length = len(text)
- text_lengths.append(length)
- stats["min_length"] = min(stats["min_length"], length)
- stats["max_length"] = max(stats["max_length"], length)
-
- # Check for repeated content
- text_hash = hash(text.strip())
- if text_hash in seen_texts:
- stats["repeated_content"] += 1
- else:
- seen_texts.add(text_hash)
-
- # Calculate average length
- if text_lengths:
- stats["avg_length"] = sum(text_lengths) / len(text_lengths)
- stats["min_length"] = (
- stats["min_length"] if stats["min_length"] != float("inf") else 0
- )
-
- # Generate warnings
- if stats["empty_samples"] > 0:
- stats["warnings"].append(f"Found {stats['empty_samples']} empty samples")
-
- if stats["repeated_content"] > 0:
- stats["warnings"].append(
- f"Found {stats['repeated_content']} repeated samples"
- )
-
- if stats["encoding_issues"] > 0:
- stats["warnings"].append(
- f"Found {stats['encoding_issues']} encoding issues"
- )
-
- if stats["min_length"] < 10:
- stats["warnings"].append("Some samples are very short (< 10 characters)")
-
- return stats
diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py
deleted file mode 100644
index 612c531f47..0000000000
--- a/unsloth/dataprep/synthetic.py
+++ /dev/null
@@ -1,473 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-__all__ = [
- "SyntheticDataKit",
-]
-import subprocess
-import threading
-from collections import deque
-import time
-import os
-
-os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-import requests
-import torch
-import gc
-import time
-import re
-from unsloth_zoo.log import logger
-import numpy as np
-
-from .synthetic_configs import (
- synthetic_qa_config,
-)
-
-
-def _load_vllm_utils():
- from unsloth_zoo.vllm_utils import (
- load_vllm,
- patch_vllm,
- delete_vllm,
- )
-
- return load_vllm, patch_vllm, delete_vllm
-
-
-def terminate_tree(proc: subprocess.Popen, timeout = 15):
- if proc is None or proc.poll() is not None:
- return
-
- try:
- import psutil
-
- parent = psutil.Process(proc.pid)
- for child in parent.children(recursive = True):
- child.terminate()
- parent.terminate()
- parent.wait(timeout = timeout / 2)
- return
- except:
- pass
-
- if os.name == "nt":
- try:
- subprocess.run(
- ["taskkill", "/T", "/F", "/PID", str(proc.pid)],
- capture_output = True,
- timeout = 5,
- )
- proc.wait(timeout = 1)
- return
- except:
- pass
-
- proc.kill()
- try:
- proc.wait(timeout = 5)
- except:
- pass
-
-
-class PipeCapture:
- """Non blocking pipe capture"""
-
- def __init__(
- self,
- pipe,
- keep_lines = 2000,
- echo = False,
- name = "",
- text = True,
- encoding = "utf-8",
- errors = "replace",
- ready_regex = None,
- ):
- self.pipe = pipe
- self.buf = deque(maxlen = keep_lines)
- self.lock = threading.Lock()
- self.echo = echo
- self.name = name
- self.text = text
- self.encoding = encoding
- self.errors = errors
-
- self.ready_event = threading.Event()
- self.closed_event = threading.Event()
-
- self.ready_regex = None
- if ready_regex is not None:
- if not hasattr(ready_regex, "search"):
- ready_regex = re.compile(ready_regex)
- self.ready_regex = ready_regex
-
- self.t = threading.Thread(target = self._reader, daemon = True)
- self.t.start()
-
- def _reader(self):
- try:
- sentinel = "" if self.text else b""
- for raw_line in iter(self.pipe.readline, sentinel):
- if not self.text:
- line = raw_line.decode(self.encoding, self.errors)
- else:
- line = raw_line
- line = line.rstrip("\r\n")
- if self.echo:
- if "platform is" not in line:
- print(f"{self.name}: {line}")
-
- with self.lock:
- self.buf.append(line)
-
- if self.ready_regex is not None and self.ready_regex.search(line):
- self.ready_event.set()
-
- finally:
- try:
- self.pipe.close()
- except Exception:
- pass
- self.closed_event.set()
-
- def wait_for_ready(self, timeout = None):
- return self.ready_event.wait(timeout)
-
- def has_closed(self):
- return self.closed_event.is_set()
-
- def wait_until_closed(self, timeout = None):
- return self.closed_event.wait(timeout)
-
- def tail(self, n = 200):
- with self.lock:
- return "\n".join(list(self.buf)[-n:])
-
-
-class SyntheticDataKit:
- def __init__(
- self,
- model_name = "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit",
- max_seq_length = 2048,
- gpu_memory_utilization = 0.98,
- float8_kv_cache = False,
- conservativeness = 1.0,
- token = None,
- timeout = 1200, # maybe this is not enough for large models if we need to download
- **kwargs,
- ):
- assert type(model_name) is str
- assert type(max_seq_length) is int
- assert type(gpu_memory_utilization) is float
- assert type(float8_kv_cache) is bool
- assert type(conservativeness) is float
- assert token is None or type(token) is str
-
- self.model_name = model_name
- self.max_seq_length = max_seq_length
-
- from transformers import AutoConfig, AutoTokenizer
-
- self.config = AutoConfig.from_pretrained(
- model_name,
- token = token,
- )
- self.tokenizer = AutoTokenizer.from_pretrained(
- model_name,
- token = token,
- )
- load_vllm, patch_vllm, delete_vllm = _load_vllm_utils()
- self._delete_vllm = delete_vllm
- patch_vllm(debug = False)
- engine_args = load_vllm(
- model_name = model_name,
- config = self.config,
- gpu_memory_utilization = gpu_memory_utilization,
- max_seq_length = max_seq_length,
- disable_log_stats = True,
- float8_kv_cache = float8_kv_cache,
- conservativeness = conservativeness,
- return_args = True,
- enable_lora = False,
- use_bitsandbytes = False,
- compilation_config = 3,
- **kwargs,
- )
- if "dtype" in engine_args:
- dtype_val = engine_args["dtype"]
- if dtype_val == torch.float16:
- dtype_val = "float16"
- elif dtype_val == torch.bfloat16:
- dtype_val = "bfloat16"
- elif dtype_val == torch.float32:
- dtype_val = "float32"
- engine_args["dtype"] = dtype_val
- # Convert torch.bfloat16, torch.float16, etc. to valid CLI string
- if hasattr(dtype_val, "name"):
- engine_args["dtype"] = dtype_val.name
- elif isinstance(dtype_val, str) and dtype_val.startswith("torch."):
- engine_args["dtype"] = dtype_val.split(".")[-1]
- # Only allow valid vLLM choices
- valid_dtypes = {"auto", "bfloat16", "float", "float16", "float32", "half"}
- if engine_args["dtype"] not in valid_dtypes:
- engine_args["dtype"] = "auto"
- if "device" in engine_args:
- del engine_args["device"]
- if "model" in engine_args:
- del engine_args["model"]
-
- subprocess_commands = [
- "vllm",
- "serve",
- str(model_name),
- ]
- for key, value in engine_args.items():
- flag = key.replace("_", "-")
- if key == "compilation_config":
- # [TODO] Unsure why subprocess doesn't process json properly
- # Also -O3 breaks on T4!
- # subprocess_commands += ["-O3",]
- continue
- which = str(value).replace("torch.", "")
- if which == "True":
- # Ignore --enforce-eager True
- subprocess_commands += [
- "--" + flag,
- ]
- elif which == "False":
- # Ignore flag
- pass
- elif which == "None":
- # Ignore flag
- pass
- else:
- subprocess_commands += [
- "--" + flag,
- which,
- ]
- logger.info(subprocess_commands)
- vllm_process = subprocess.Popen(
- subprocess_commands,
- stdout = subprocess.PIPE,
- stderr = subprocess.PIPE,
- start_new_session = True,
- )
- ready_re = re.compile(r"Starting vLLM API server(?:\s+\d+)?\s+on\b")
- self.vllm_process = vllm_process
- self.stdout_capture = PipeCapture(
- vllm_process.stdout,
- keep_lines = 1000,
- echo = True,
- name = "vLLM STDOUT",
- ready_regex = ready_re,
- text = False,
- )
- self.stderr_capture = PipeCapture(
- vllm_process.stderr,
- keep_lines = 2000,
- echo = False,
- name = "vLLM STDERR",
- ready_regex = None,
- text = False,
- )
- # we don't print stderr to console but self.stderr_capture.tail(200) will print the last 200 lines
-
- ready = self.stdout_capture.wait_for_ready(timeout = timeout)
- if not ready:
- if self.stdout_capture.has_closed() or self.vllm_process.poll() is not None:
- print("Stdout stream ended before readiness message detected.")
- print("\n--- stdout tail ---\n", self.stdout_capture.tail(50))
- print("\n--- stderr tail ---\n", self.stderr_capture.tail(50))
- else:
- print(f"Unsloth: vllm_process failed to load! (timeout={timeout})")
- print("\n--- stdout tail ---\n", self.stdout_capture.tail(50))
- print("\n--- stderr tail ---\n", self.stderr_capture.tail(50))
- terminate_tree(self.vllm_process)
- return
- else:
- print("vLLM Server Ready Detected")
-
- trial = 0
- while not self.check_vllm_status():
- if trial >= 100:
- print("Unsloth: vllm_process failed to load!")
- print("\n--- stdout tail ---\n", self.stdout_capture.tail(50))
- print("\n--- stderr tail ---\n", self.stderr_capture.tail(50))
- terminate_tree(self.vllm_process)
- return
- trial += 1
- time.sleep(1)
- return
-
- @staticmethod
- def from_pretrained(
- model_name = "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit",
- max_seq_length = 2048,
- gpu_memory_utilization = 0.9,
- float8_kv_cache = False,
- conservativeness = 1.0,
- token = None,
- **kwargs,
- ):
- return SyntheticDataKit(
- model_name = model_name,
- max_seq_length = max_seq_length,
- gpu_memory_utilization = gpu_memory_utilization,
- float8_kv_cache = float8_kv_cache,
- conservativeness = conservativeness,
- token = token,
- **kwargs,
- )
-
- @staticmethod
- def check_vllm_status():
- try:
- response = requests.get("http://localhost:8000/metrics")
- if response.status_code == 200:
- return True
- except requests.exceptions.ConnectionError:
- return False
-
- def cleanup(self):
- if not hasattr(self, "vllm_process"):
- return
-
- vllm_process = self.vllm_process
- print("Attempting to terminate the VLLM server gracefully...")
- try:
- vllm_process.terminate()
- vllm_process.wait(timeout = 10)
- print("Server terminated gracefully.")
- except subprocess.TimeoutExpired:
- print(
- "Server did not terminate gracefully after 10 seconds. Forcing kill..."
- )
- vllm_process.kill()
- vllm_process.wait()
- print("Server killed forcefully.")
- except Exception as e:
- print(f"An error occurred while trying to stop the process: {e}")
- try:
- if vllm_process.poll() is None:
- print("Attempting forceful kill due to error...")
- vllm_process.kill()
- vllm_process.wait()
- print("Server killed forcefully after error.")
- except Exception as kill_e:
- print(f"Error during forceful kill: {kill_e}")
- for _ in range(10):
- torch.cuda.empty_cache()
- gc.collect()
-
- # Delete vLLM module as well
- if hasattr(self, "_delete_vllm"):
- self._delete_vllm(llm = None)
-
- def __enter__(self):
- return self
-
- def __exit__(self, *exc):
- self.cleanup()
-
- def __del__(self):
- self.cleanup()
-
- def chunk_data(self, filename = None):
- # Chunks data by max tokens and generation length
- assert filename is not None
- assert os.path.exists(filename)
- assert hasattr(self, "tokenizer")
- if not hasattr(self, "max_seq_length"):
- raise RuntimeError(
- "Please use SynthetidDataKit.from_pretrained(...) first!"
- )
- if not hasattr(self, "overlap") or not hasattr(self, "max_generation_tokens"):
- raise RuntimeError("Please use prepare_qa_generation first!")
-
- with open(filename, "r", encoding = "utf-8") as f:
- text = f.read()
-
- max_tokens = (
- self.max_seq_length - self.max_generation_tokens * 2 - 128
- ) # -128 to reduce errors
- if max_tokens <= 5:
- raise RuntimeError("Generation length is way too long!")
- input_ids = self.tokenizer(text, add_special_tokens = False).input_ids
-
- # Get left and right boundaries
- length = len(input_ids)
- n_chunks = int(np.ceil(length / (max_tokens - self.overlap)))
- boundaries = np.ceil(np.linspace(0, length - self.overlap, n_chunks)).astype(
- int
- )
- boundaries = np.stack((boundaries[:-1], (boundaries + self.overlap)[1:])).T
- boundaries = np.minimum(boundaries, length).tolist()
-
- # Get extension of filename like .txt
- filename, extension = os.path.splitext(filename)
- if filename.endswith("/"):
- filename = filename[:-1]
-
- all_filenames = []
- for i, (left, right) in enumerate(boundaries):
- chunked_text = self.tokenizer.decode(input_ids[left:right])
- new_filename = f"{filename}_{i}{extension}"
- all_filenames.append(new_filename)
- with open(new_filename, "w", encoding = "utf-8") as f:
- f.write(chunked_text)
- return all_filenames
-
- def prepare_qa_generation(
- self,
- output_folder = "data",
- max_generation_tokens = 512,
- temperature = 0.7,
- top_p = 0.95,
- overlap = 64,
- default_num_pairs = 25,
- cleanup_threshold = 1.0,
- cleanup_batch_size = 4,
- cleanup_temperature = 0.3,
- ):
- assert hasattr(self, "model_name")
- assert hasattr(self, "max_seq_length")
- assert max_generation_tokens < self.max_seq_length
-
- locations = "pdf,html,youtube,docx,ppt,txt,output,generated,cleaned,final"
- locations = locations.split(",")
- for path in locations:
- os.makedirs(os.path.join(output_folder, path), exist_ok = True)
-
- self.max_generation_tokens = max_generation_tokens
-
- config = (
- synthetic_qa_config.replace("{data_output_location}", str(output_folder))
- .replace("{model_name}", str(self.model_name))
- .replace("{temperature}", str(temperature))
- .replace("{top_p}", str(top_p))
- .replace(
- "{chunk_size}", str(self.max_seq_length - max_generation_tokens * 2 - 2)
- )
- .replace("{overlap}", str(overlap))
- .replace("{max_tokens}", str(max_generation_tokens))
- .replace("{default_num_pairs}", str(default_num_pairs))
- .replace("{cleanup_threshold}", str(cleanup_threshold))
- .replace("{cleanup_batch_size}", str(cleanup_batch_size))
- .replace("{cleanup_temperature}", str(cleanup_temperature))
- )
-
- with open("synthetic_data_kit_config.yaml", "w", encoding = "utf-8") as f:
- f.write(config)
-
- self.overlap = overlap
diff --git a/unsloth/dataprep/synthetic_configs.py b/unsloth/dataprep/synthetic_configs.py
deleted file mode 100644
index 2e536467e2..0000000000
--- a/unsloth/dataprep/synthetic_configs.py
+++ /dev/null
@@ -1,111 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-synthetic_qa_config = """\
-# Master configuration file for Synthetic Data Kit
-
-# Global paths configuration
-paths:
- # Input data locations
- input:
- pdf: "{data_output_location}/pdf"
- html: "{data_output_location}/html"
- youtube: "{data_output_location}/youtube"
- docx: "{data_output_location}/docx"
- ppt: "{data_output_location}/ppt"
- txt: "{data_output_location}/txt"
-
- # Output locations
- output:
- parsed: "{data_output_location}/output" # Where parsed text files are saved
- generated: "{data_output_location}/generated" # Where generated content is saved
- cleaned: "{data_output_location}/cleaned" # Where cleaned content is saved
- final: "{data_output_location}/final" # Where final formatted content is saved
-
-# VLLM server configuration
-vllm:
- api_base: "http://localhost:8000/v1" # Base URL for VLLM API
- port: 8000 # Port for VLLM server
- model: "{model_name}" # Default model to use
- max_retries: 3 # Number of retries for API calls
- retry_delay: 1.0 # Initial delay between retries (seconds)
-
-# Ingest configuration
-ingest:
- default_format: "txt" # Default output format for parsed files
- youtube_captions: "auto" # Options: "auto", "manual" - caption preference
-
-# LLM generation parameters
-generation:
- temperature: {temperature} # Higher = more creative, lower = more deterministic
- top_p: {top_p} # Nucleus sampling parameter
- chunk_size: {chunk_size} # Size of text chunks for processing
- overlap: {overlap} # Overlap between chunks to maintain context
- max_tokens: {max_tokens} # Maximum tokens in LLM responses
- num_pairs: {default_num_pairs} # Default number of QA pairs to generate
-
-# Content cleanup parameters
-cleanup:
- threshold: {cleanup_threshold} # Default quality threshold (1-10)
- batch_size: {cleanup_batch_size} # Number of items per batch for rating
- temperature: {cleanup_temperature} # Temperature for rating (lower = more consistent)
-
-# Format conversion parameters
-format:
- default: "jsonl" # Default output format
- include_metadata: true # Include metadata in output files
- pretty_json: true # Use indentation in JSON output
-
-# Prompts for different tasks
-prompts:
- # Summary generation prompt
- summary: |
- Summarize this document in 3-5 sentences, focusing on the main topic and key concepts.
-
- # QA pair generation prompt
- qa_generation: |
- Create {num_pairs} question-answer pairs from this text for LLM training.
-
- Rules:
- 1. Questions must be about important facts in the text
- 2. Answers must be directly supported by the text
- 3. Return JSON format only:
-
- [
- {{
- "question": "Question 1?",
- "answer": "Answer 1."
- }},
- {{
- "question": "Question 2?",
- "answer": "Answer 2."
- }}
- ]
-
- Text:
- {text}
-
- # QA pair rating prompt
- qa_rating: |
- Rate each of these question-answer pairs for quality and return exactly this JSON format:
-
- [
- {{"question": "same question text", "answer": "same answer text", "rating": n}}
- ]
-
- Where n is a number from 1-10.
-
- DO NOT include any text outside of the JSON array, just return valid JSON:
-
- {pairs}"""
diff --git a/unsloth/device_type.py b/unsloth/device_type.py
deleted file mode 100644
index a42d2b9fab..0000000000
--- a/unsloth/device_type.py
+++ /dev/null
@@ -1,132 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-__all__ = [
- "is_hip",
- "get_device_type",
- "DEVICE_TYPE",
- "DEVICE_TYPE_TORCH",
- "DEVICE_COUNT",
- "ALLOW_PREQUANTIZED_MODELS",
- "ALLOW_BITSANDBYTES",
-]
-
-import torch
-import functools
-import inspect
-from unsloth_zoo.utils import Version
-
-
-@functools.cache
-def is_hip():
- return bool(getattr(getattr(torch, "version", None), "hip", None))
-
-
-@functools.cache
-def get_device_type():
- if hasattr(torch, "cuda") and torch.cuda.is_available():
- if is_hip():
- return "hip"
- return "cuda"
- elif hasattr(torch, "xpu") and torch.xpu.is_available():
- return "xpu"
- # Check torch.accelerator
- if hasattr(torch, "accelerator"):
- if not torch.accelerator.is_available():
- raise NotImplementedError(
- "Unsloth cannot find any torch accelerator? You need a GPU."
- )
- accelerator = str(torch.accelerator.current_accelerator())
- if accelerator in ("cuda", "xpu", "hip"):
- raise RuntimeError(
- f"Unsloth: Weirdly `torch.cuda.is_available()`, `torch.xpu.is_available()` and `is_hip` all failed.\n"
- f"But `torch.accelerator.current_accelerator()` works with it being = `{accelerator}`\n"
- f"Please reinstall torch - it's most likely broken :("
- )
- raise NotImplementedError(
- "Unsloth currently only works on NVIDIA, AMD and Intel GPUs."
- )
-
-
-DEVICE_TYPE: str = get_device_type()
-# HIP fails for autocast and other torch functions. Use CUDA instead
-DEVICE_TYPE_TORCH = DEVICE_TYPE
-if DEVICE_TYPE_TORCH == "hip":
- DEVICE_TYPE_TORCH = "cuda"
-
-
-@functools.cache
-def get_device_count():
- if DEVICE_TYPE in ("cuda", "hip"):
- return torch.cuda.device_count()
- elif DEVICE_TYPE == "xpu":
- return torch.xpu.device_count()
- else:
- return 1
-
-
-DEVICE_COUNT: int = get_device_count()
-
-# 4-bit quantization requires a block size of 64
-# | Device Type | Warp Size | Block Size |
-# |-----------------|-----------|------------|
-# | CUDA | 32 | 32 |
-# | Radeon (Navi) | 32 | 32 |
-# | Instinct (MI) | 64 | 32 |
-#
-# Since bitsandbytes 0.49.0, pre-quantized models with 64 blockwise now works
-# on Radeon GPUs, but not Instinct MI300x for eg
-# See https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1748
-#
-# Since bitsandbytes 0.49.2, blocksize=64 4-bit quantization is supported on
-# CDNA (MI Instinct / gfx9xx) GPUs as well
-# See https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1856
-
-ALLOW_PREQUANTIZED_MODELS: bool = True
-# HSA_STATUS_ERROR_EXCEPTION checks - sometimes AMD fails for BnB
-ALLOW_BITSANDBYTES: bool = True
-if DEVICE_TYPE == "hip":
- try:
- import bitsandbytes
- except:
- print(
- "Unsloth: `bitsandbytes` is not installed - 4bit QLoRA unallowed, but 16bit and full finetuning works."
- )
- ALLOW_PREQUANTIZED_MODELS = False
- ALLOW_BITSANDBYTES = False
- if ALLOW_BITSANDBYTES:
- ALLOW_BITSANDBYTES = Version(bitsandbytes.__version__) > Version("0.48.2.dev0")
- if Version(bitsandbytes.__version__) >= Version("0.49.2"):
- pass
- elif Version(bitsandbytes.__version__) >= Version("0.49.0"):
- try:
- # Pre-quantized bitsandbytes models use blocksize 64, so we need to check the GPU
- from bitsandbytes.cextension import ROCM_WARP_SIZE_64
-
- ALLOW_PREQUANTIZED_MODELS = not ROCM_WARP_SIZE_64
- except Exception as e:
- print(
- "Unsloth: Checking `from bitsandbytes.cextension import ROCM_WARP_SIZE_64` had error = \n"
- f"{str(e)}\n"
- "4bit QLoRA disabled for now, but 16bit and full finetuning works."
- )
- ALLOW_PREQUANTIZED_MODELS = False
- ALLOW_BITSANDBYTES = False
- elif ALLOW_BITSANDBYTES:
- from bitsandbytes.nn.modules import Params4bit
-
- if "blocksize = 64 if not HIP_ENVIRONMENT else 128" in inspect.getsource(
- Params4bit
- ):
- ALLOW_PREQUANTIZED_MODELS = False
diff --git a/unsloth/import_fixes.py b/unsloth/import_fixes.py
deleted file mode 100644
index ca44a0ce7e..0000000000
--- a/unsloth/import_fixes.py
+++ /dev/null
@@ -1,1823 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import importlib.abc
-import importlib.machinery
-import importlib.util
-from pathlib import Path
-from importlib.metadata import version as importlib_version
-from packaging.version import Version as TrueVersion
-import re
-import logging
-import textwrap
-import warnings
-import sys
-import functools
-
-# We cannot do from unsloth_zoo.log import logger since FBGEMM might cause seg faults.
-UNSLOTH_ENABLE_LOGGING = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") in (
- "1",
- "True",
- "true",
-)
-logger = logging.getLogger(__name__)
-if UNSLOTH_ENABLE_LOGGING:
- logging.basicConfig(
- level = logging.INFO, format = "[%(name)s|%(levelname)s]%(message)s"
- )
- logger.setLevel(logging.INFO)
-else:
- logging.basicConfig(
- level = logging.WARNING, format = "[%(name)s|%(levelname)s]%(message)s"
- )
- logger.setLevel(logging.WARNING)
-
-_AMDGPU_IDS_MISSING_TEXT = "amdgpu.ids: No such file or directory"
-
-
-def Version(version):
- try:
- new_version = str(version)
- new_version = re.match(r"[0-9\.]{1,}", new_version)
- if new_version is None:
- raise Exception(str(e))
- new_version = new_version.group(0).rstrip(".")
- if new_version != version:
- new_version += ".1" # Add .1 for dev / alpha / beta / rc
- return TrueVersion(new_version)
- except:
- from inspect import getframeinfo, stack
-
- caller = getframeinfo(stack()[1][0])
- raise RuntimeError(
- f"Unsloth: Could not get version for `{version}`\n"
- f"File name = [{caller.filename}] Line number = [{caller.lineno}]"
- )
-
-
-# Ignore logging messages
-class HideLoggingMessage(logging.Filter):
- __slots__ = ("text",)
-
- def __init__(self, text):
- self.text = text
-
- def filter(self, x):
- return not (self.text in x.getMessage())
-
-
-class HidePrintMessage:
- def __init__(self, original_stream):
- self._original_stream = original_stream
- self._hidden_texts = []
-
- def add_filter(self, text):
- self._hidden_texts.append(text)
-
- def write(self, message):
- if not any(text in message for text in self._hidden_texts):
- self._original_stream.write(message)
-
- def flush(self):
- self._original_stream.flush()
-
- def __getattr__(self, name):
- return getattr(self._original_stream, name)
-
-
-import contextlib
-import ctypes
-
-try:
- _libc = ctypes.CDLL(None)
-except Exception:
- _libc = None
-
-
-@contextlib.contextmanager
-def suppress_cuda_printf():
- """Suppress CUDA device-side printf by redirecting stdout/stderr fds to /dev/null.
-
- CUDA device printf (eg CUTLASS "Arch conditional MMA" errors on Blackwell)
- writes to stdout fd 1 at the C level, bypassing Python sys.stdout entirely.
- The existing HidePrintMessage filter on sys.stderr cannot catch these since
- they go to a different fd at a different layer. This context manager redirects
- both fd 1 and fd 2 at the OS level, syncs CUDA, then restores them.
- """
- sys.stdout.flush()
- sys.stderr.flush()
- saved_fds = {}
- try:
- for fd in (1, 2):
- saved_fds[fd] = os.dup(fd)
- devnull = os.open(os.devnull, os.O_WRONLY)
- os.dup2(devnull, fd)
- os.close(devnull)
- yield
- finally:
- try:
- import torch
-
- if torch.cuda.is_available():
- torch.cuda.synchronize()
- except Exception:
- pass
- if _libc is not None:
- try:
- _libc.fflush(None)
- except Exception:
- pass
- for fd, saved in saved_fds.items():
- os.dup2(saved, fd)
- os.close(saved)
-
-
-if not UNSLOTH_ENABLE_LOGGING:
- import sys
-
- # Apply to stderr for FBGEMM and CUTLASS errors
- sys.stderr = HidePrintMessage(sys.stderr)
- # https://github.com/pytorch/FBGEMM/blob/d99cd96490ec4aabac2ee95b1e76ea4dcfcfa628/fbgemm_gpu/experimental/gemm/triton_gemm/utils.py#L43-L52
- sys.stderr.add_filter("TMA benchmarks will be running")
- # CUTLASS/FBGEMM MMA instruction error on SM90 vs SM100 (Blackwell) GPUs
- # https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp
- sys.stderr.add_filter("Arch conditional MMA instruction used without targeting")
- # CUTLASS arch conditional errors for various architectures
- sys.stderr.add_filter("CUTE_INVALID_CONTROL_PATH")
- # CUTLASS TMA-related errors when not targeting correct architecture
- sys.stderr.add_filter("Trying to use tma without CUTE_ARCH_TMA")
- # Skipping import of cpp extensions due to incompatible torch version 2.9.0+cu128 for torchao version 0.15.0
- logging.getLogger("torchao").setLevel(logging.ERROR)
- # Also filter torchao print to stderr about cpp extensions
- sys.stderr.add_filter("Skipping import of cpp extensions")
- # SyntaxWarning: invalid escape sequence '\.'
- warnings.filterwarnings(
- "ignore", message = "invalid escape sequence", category = SyntaxWarning
- )
- # PYTORCH_CUDA_ALLOC_CONF is deprecated warning from torch
- warnings.filterwarnings("ignore", message = "PYTORCH_CUDA_ALLOC_CONF is deprecated")
- # TF32 precision deprecation warning from torch
- warnings.filterwarnings(
- "ignore", message = "Please use the new API settings to control TF32"
- )
- # Deprecation warnings from torchao
- warnings.filterwarnings("ignore", message = "`int4_weight_only` is deprecated")
- warnings.filterwarnings("ignore", message = "`int8_weight_only` is deprecated")
-
- # TorchAO deprecated import paths (https://github.com/pytorch/ao/issues/2752)
- warnings.filterwarnings(
- "ignore",
- message = r"Importing.*from torchao\.dtypes.*is deprecated",
- category = DeprecationWarning,
- )
- warnings.filterwarnings(
- "ignore",
- message = r"Importing BlockSparseLayout from torchao\.dtypes is deprecated",
- category = DeprecationWarning,
- )
-
- # SWIG builtin type warnings (from bitsandbytes/triton SWIG bindings)
- warnings.filterwarnings(
- "ignore",
- message = r"builtin type Swig.*has no __module__ attribute",
- category = DeprecationWarning,
- )
-
- # Triton autotuner deprecation (https://github.com/triton-lang/triton/pull/4496)
- warnings.filterwarnings(
- "ignore",
- message = r"warmup, rep, and use_cuda_graph parameters are deprecated",
- category = DeprecationWarning,
- )
-
- # Python 3.12+ multiprocessing fork warning in multi-threaded processes
- warnings.filterwarnings(
- "ignore",
- message = r".*multi-threaded.*use of fork\(\) may lead to deadlocks",
- category = DeprecationWarning,
- )
-
- # Resource warnings from internal socket/file operations
- warnings.filterwarnings(
- "ignore", message = r"unclosed.*socket", category = ResourceWarning
- )
- warnings.filterwarnings(
- "ignore", message = r"unclosed file.*dev/null", category = ResourceWarning
- )
-
- # torch 2.9+ pin_memory/is_pinned device arg deprecation
- warnings.filterwarnings(
- "ignore",
- message = r"The `device` argument is deprecated",
- category = DeprecationWarning,
- )
- warnings.filterwarnings(
- "ignore",
- message = r".*pin_memory.*device.*deprecated",
- category = DeprecationWarning,
- )
- warnings.filterwarnings(
- "ignore",
- message = r".*is_pinned.*device.*deprecated",
- category = DeprecationWarning,
- )
-
- # vllm "Level is deprecated" stderr noise
- sys.stderr.add_filter("Level is deprecated")
-
- # PydanticSerializationUnexpectedValue warning
- warnings.filterwarnings(
- "ignore",
- message = r".*PydanticSerializationUnexpectedValue",
- )
- warnings.filterwarnings(
- "ignore",
- message = r"Expected.*but got.*with value.*is not.*subclass",
- )
-
- # Triton "df: No such file or directory" stderr noise
- sys.stderr.add_filter("df: No such file")
- # ROCm/libdrm missing ids table stderr noise on some AMD setups
- sys.stderr.add_filter(_AMDGPU_IDS_MISSING_TEXT)
- # Apex ROCm fused RoPE backend selection warning when Aiter is enabled.
- warnings.filterwarnings(
- "ignore",
- message = r"^Aiter backend is selected for fused RoPE\.?",
- category = UserWarning,
- module = r"^apex\.transformer\.functional\.fused_rope$",
- )
-
-
-# Fix up AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
-# MUST do this at the start primarily due to tensorflow causing issues
-def fix_message_factory_issue():
- try:
- import google.protobuf.message_factory
-
- class MessageFactory:
- def CreatePrototype(self, *args, **kwargs):
- return
-
- def GetMessages(self, *args, **kwargs):
- return
-
- def GetPrototype(self, *args, **kwargs):
- return
-
- if not hasattr(google.protobuf.message_factory, "MessageFactory"):
- logger.info("Unsloth: Patching protobuf.MessageFactory as it doesn't exist")
- google.protobuf.message_factory.MessageFactory = MessageFactory
- elif (
- hasattr(google.protobuf.message_factory, "MessageFactory")
- and not hasattr(
- google.protobuf.message_factory.MessageFactory, "GetPrototype"
- )
- and not hasattr(google.protobuf.message_factory, "GetMessageClass")
- ):
- google.protobuf.message_factory.MessageFactory = MessageFactory
- logger.info("Unsloth: Patching protobuf.MessageFactory as it doesn't exist")
- elif (
- hasattr(google.protobuf.message_factory, "MessageFactory")
- and not hasattr(
- google.protobuf.message_factory.MessageFactory, "GetPrototype"
- )
- and hasattr(google.protobuf.message_factory, "GetMessageClass")
- ):
- GetMessageClass = google.protobuf.message_factory.GetMessageClass
-
- def GetPrototype(self, descriptor):
- return GetMessageClass(descriptor)
-
- google.protobuf.message_factory.MessageFactory.GetPrototype = GetPrototype
- logger.info("Unsloth: Patching protobuf.MessageFactory.GetPrototype")
- pass
- except:
- pass
-
-
-# Fix Xformers performance issues since 0.0.25
-def fix_xformers_performance_issue():
- spec = importlib.util.find_spec("xformers")
- if spec is None:
- return
- xformers_version = importlib_version("xformers")
- if Version(xformers_version) < Version("0.0.29"):
- xformers_location = spec.origin
- if xformers_location is None:
- xformers_location = spec.submodule_search_locations[0]
- else:
- xformers_location = os.path.split(xformers_location)[0]
- cutlass = Path(xformers_location) / "ops" / "fmha" / "cutlass.py"
- try:
- if cutlass.exists():
- with open(cutlass, "r+", encoding = "utf-8") as f:
- text = f.read()
- # See https://github.com/facebookresearch/xformers/issues/1176#issuecomment-2545829591
- if "num_splits_key=-1," in text:
- text = text.replace(
- "num_splits_key=-1,",
- "num_splits_key=None,",
- )
- f.seek(0)
- f.write(text)
- f.truncate()
- logger.info(
- "Unsloth: Patching Xformers to fix some performance issues."
- )
- except Exception as e:
- logger.info(f"Unsloth: Failed patching Xformers with error = {str(e)}")
-
-
-def patch_vllm_for_notebooks():
- import sys
-
- ipython = None
- try:
- from IPython import get_ipython as _get_ipython
- except Exception:
- _get_ipython = None
-
- if _get_ipython is not None:
- try:
- ipython = _get_ipython()
- except Exception:
- ipython = None
-
- if ipython is None:
- try:
- import builtins
-
- _get_ipython = getattr(builtins, "get_ipython", None)
- if callable(_get_ipython):
- ipython = _get_ipython()
- except Exception:
- ipython = None
-
- if ipython is None:
- return
-
- try:
- shell = ipython.__class__.__name__
- is_notebook = shell == "ZMQInteractiveShell" or "google.colab" in str(
- type(ipython)
- )
- except Exception:
- return
-
- if not is_notebook:
- return
-
- if not hasattr(sys.stdout, "fileno"):
- return
-
- needs_patch = False
- try:
- fd = sys.stdout.fileno()
- if not isinstance(fd, int) or fd < 0:
- needs_patch = True
- except Exception:
- needs_patch = True
-
- if not needs_patch:
- return
-
- logger.info(
- "Unsloth: Notebook detected - Patching sys.stdout.fileno for newer `vllm>=0.12.0` versions"
- )
- sys.stdout.fileno = lambda: 1
-
-
-# ValueError: 'aimv2' is already used by a Transformers config, pick another name.
-def fix_vllm_aimv2_issue():
- spec = importlib.util.find_spec("vllm")
- if spec is None:
- return
- vllm_version = importlib_version("vllm")
- if Version(vllm_version) < Version("0.10.1"):
- vllm_location = spec.origin
- if vllm_location is None:
- vllm_location = spec.submodule_search_locations[0]
- else:
- vllm_location = os.path.split(vllm_location)[0]
- ovis_config = Path(vllm_location) / "transformers_utils" / "configs" / "ovis.py"
- try:
- if ovis_config.exists():
- with open(ovis_config, "r+", encoding = "utf-8") as f:
- text = f.read()
- # See https://github.com/vllm-project/vllm-ascend/issues/2046
- if 'AutoConfig.register("aimv2", AIMv2Config)' in text:
- text = text.replace(
- 'AutoConfig.register("aimv2", AIMv2Config)',
- "",
- )
- text = text.replace(
- """backbone_config.pop('model_type')
- backbone_config = AutoConfig.for_model(model_type,
- **backbone_config)""",
- """if model_type != "aimv2":
- backbone_config.pop('model_type')
- backbone_config = AutoConfig.for_model(model_type, **backbone_config)
- else:
- backbone_config = AIMv2Config(**backbone_config)""",
- )
- f.seek(0)
- f.write(text)
- f.truncate()
- logger.info(
- "Unsloth: Patching vLLM to fix `'aimv2' is already used by a Transformers config, pick another name.`"
- )
- except Exception as e:
- logger.info(f"Unsloth: Failed patching vLLM with error = {str(e)}")
-
-
-def fix_vllm_guided_decoding_params():
- def _maybe_raise_vllm_transformers_mismatch(error):
- error_text = str(error)
- if (
- "ALLOWED_LAYER_TYPES" in error_text
- or "transformers.configuration_utils" in error_text
- ):
- try:
- vllm_version = importlib_version("vllm")
- except Exception:
- vllm_version = "unknown"
- raise RuntimeError(
- "Unsloth: vLLM with version "
- f"{vllm_version} does not yet support transformers>=5.0.0. "
- "Please downgrade to transformers==4.57.3 via "
- 'pip install --force-reinstall "transformers==4.57.3". '
- f"Original error: {error}"
- ) from error
-
- if importlib.util.find_spec("vllm") is None:
- return
- # GuidedDecodingParmas is renamed to StructuredOutputsParams in vLLM
- # https://github.com/vllm-project/vllm/pull/22772/files
- # trl still wants to use GuidedDecodingParams. This is a temporary patch till trl updates
- try:
- import vllm
- except (ImportError, OSError) as e:
- _maybe_raise_vllm_transformers_mismatch(e)
- if disable_broken_vllm(e):
- return
- raise
-
- try:
- from vllm.sampling_params import GuidedDecodingParams
- except (ImportError, OSError) as e:
- _maybe_raise_vllm_transformers_mismatch(e)
- if disable_broken_vllm(e):
- return
- if not hasattr(vllm, "sampling_params") or not hasattr(
- vllm.sampling_params, "StructuredOutputsParams"
- ):
- raise
- vllm.sampling_params.GuidedDecodingParams = (
- vllm.sampling_params.StructuredOutputsParams
- )
-
-
-def ignore_logger_messages():
- # Ignore Environment variable `HF_TOKEN` is set
- try:
- from huggingface_hub._login import logger as huggingface_hub_logger
-
- huggingface_hub_logger.addFilter(HideLoggingMessage("`HF_TOKEN`"))
- del huggingface_hub_logger
- except:
- pass
-
-
-def patch_ipykernel_hf_xet():
- # HF-XET == 1.1.10 and ipykernel == 7.0.0 / 7.0.1 causes issues
- # See https://github.com/huggingface/xet-core/issues/526
- # 2025-10-13T20:37:33.028737Z ERROR Python exception updating progress:, error: PyErr { type: , value: LookupError(), traceback: Some() }, caller: "src/progress_update.rs:313"
- # at /home/runner/work/xet-core/xet-core/error_printer/src/lib.rs:28
- if importlib.util.find_spec("hf_xet") is None:
- return
- if importlib.util.find_spec("ipykernel") is None:
- return
- if importlib.util.find_spec("huggingface_hub") is None:
- return
-
- ipykernel_version = Version(importlib_version("ipykernel"))
- if (
- (Version(importlib_version("hf_xet")) == Version("1.1.10"))
- and (
- (ipykernel_version == Version("7.0.0"))
- or (
- ipykernel_version == Version("7.0.1")
- ) # 7.0.1 seems to also break with LookupError:
- )
- ):
- print(
- "#### Unsloth: `hf_xet==1.1.10` and `ipykernel==7.0.0` or `ipykernel==7.0.1` breaks progress bars. Using ASCII progress bars.\n"
- "#### Unsloth: To re-enable progress bars, please upgrade to `ipykernel>=7.1.0` or wait for a fix to\n"
- "https://github.com/huggingface/xet-core/issues/526"
- )
- from huggingface_hub.utils import disable_progress_bars
-
- disable_progress_bars()
-
-
-def patch_trackio():
- # Set some environment variables to customize the Trackio dashboard for experiment tracking
- # See https://github.com/unslothai/notebooks/pull/110
- os.environ["TRACKIO_LOGO_LIGHT_URL"] = (
- "https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20black%20text.png"
- )
- os.environ["TRACKIO_LOGO_DARK_URL"] = (
- "https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20white%20text.png"
- )
- os.environ["TRACKIO_PLOT_ORDER"] = "train/reward"
-
-
-def patch_datasets():
- # Datasets 4.4.0 and 4.4.1 weirdly have some weird `_thread.RLock_recursion_count` issues
- if importlib.util.find_spec("datasets") is None:
- return
-
- datasets_version = Version(importlib_version("datasets"))
- if (datasets_version <= Version("4.5.0")) and (
- datasets_version >= Version("4.4.0")
- ):
- raise NotImplementedError(
- f"#### Unsloth: Using `datasets = {str(datasets_version)}` will cause recursion errors.\n"
- "Please downgrade datasets to `datasets==4.3.0"
- )
-
-
-def check_fbgemm_gpu_version():
- if importlib.util.find_spec("fbgemm_gpu") is None:
- return
- try:
- fbgemm_gpu_version = importlib_version("fbgemm_gpu_genai")
- except:
- return
- # We noticed some SegFault or bad alloc errors on lower versions of fbgemm_gpu.
- # Instead of raising an error, disable FBGEMM and fall back to Triton kernels.
- if Version(fbgemm_gpu_version) < Version("1.4.0"):
- os.environ["UNSLOTH_HAS_FBGEMM"] = "0"
- logger.info(
- f"Unsloth: fbgemm_gpu_genai=={fbgemm_gpu_version} is old and may cause issues. "
- f"Disabling FBGEMM - using Triton kernels instead."
- )
- return
-
- logger.info(f"Unsloth: fbgemm_gpu_genai=={fbgemm_gpu_version} detected.")
-
-
-def patch_enable_input_require_grads():
- """
- Patch transformers PreTrainedModel.enable_input_require_grads to handle vision models
- that raise NotImplementedError from get_input_embeddings().
-
- """
- import inspect
- from transformers import PreTrainedModel
-
- # Check if the original function iterates over self.modules() instead of just returning the enable_input_require_grads
- # Ref: https://github.com/huggingface/transformers/pull/41993/files#diff-6b72b98c4c2dcfc6cc606843917733f5d858374fbc22a735ff483bbc0c1e63eaL1979-R1996
- try:
- original_source = inspect.getsource(PreTrainedModel.enable_input_require_grads)
- except:
- return
-
- # Only patch if the new pattern exists (iterating over self.modules())
- if "for module in self.modules()" not in original_source:
- return
-
- def _patched_enable_input_require_grads(self):
- def make_inputs_require_grads(module, input, output):
- output.requires_grad_(True)
-
- hooks = []
- seen_modules = set()
-
- for module in self.modules():
- if not (
- isinstance(module, PreTrainedModel)
- and hasattr(module, "get_input_embeddings")
- ):
- continue
-
- try:
- input_embeddings = module.get_input_embeddings()
- except NotImplementedError:
- # Vision models may not implement get_input_embeddings - skip them
- # For GLM V4.6 for example, this skips only `self.visual`
- continue
-
- if input_embeddings is None:
- continue
-
- embedding_id = id(input_embeddings)
- if embedding_id in seen_modules:
- continue
-
- seen_modules.add(embedding_id)
- hooks.append(
- input_embeddings.register_forward_hook(make_inputs_require_grads)
- )
-
- self._require_grads_hooks = hooks
- if hooks:
- self._require_grads_hook = hooks[0]
-
- PreTrainedModel.enable_input_require_grads = _patched_enable_input_require_grads
-
- logger.info(
- "Unsloth: Patched enable_input_require_grads for vision model compatibility"
- )
-
-
-def _is_custom_torch_build(raw_version_str):
- """Check if a raw version string indicates a custom or source build.
- Must operate on the raw string from importlib_version(), not the parsed
- Version object, since our custom Version() strips local identifiers.
-
- Standard PyTorch releases use: +cu124, +rocm6.3, +cpu, +xpu
- Source/custom builds use: +gitXXXXXXX, +HEXHASH, or other suffixes.
- """
- if "+" not in raw_version_str:
- return False
- local = raw_version_str.split("+", 1)[1]
- if not local:
- return False
- # Use fullmatch so the entire local identifier must match, not just a prefix.
- # cu/rocm require a trailing digit (e.g. cu124, rocm6.3). cpu/xpu are exact.
- # Case-insensitive since some builds may use uppercase.
- return not re.fullmatch(r"cu\d[\d.]*|rocm\d[\d.]*|cpu|xpu", local, re.IGNORECASE)
-
-
-def _infer_required_torchvision(torch_major, torch_minor):
- """Infer the minimum required torchvision minor version from torch version.
-
- The torch -> torchvision minor version mapping follows a consistent formula:
- torch 1.x -> torchvision 0.(x + 1) (verified: torch 1.7 through 1.13)
- torch 2.x -> torchvision 0.(x + 15) (verified: torch 2.0 through 2.9)
-
- Returns (tv_major, tv_minor) or None if the major version is unrecognized.
- """
- if torch_major == 1 and torch_minor >= 7:
- return (0, torch_minor + 1)
- if torch_major == 2:
- return (0, torch_minor + 15)
- return None
-
-
-def torchvision_compatibility_check():
- # Allow skipping via environment variable for custom environments
- if os.environ.get("UNSLOTH_SKIP_TORCHVISION_CHECK", "0").lower() in ("1", "true"):
- return
-
- if importlib.util.find_spec("torch") is None:
- raise ImportError("Unsloth: torch not found. Please install torch first.")
- if importlib.util.find_spec("torchvision") is None:
- return
-
- try:
- torch_version_raw = importlib_version("torch")
- torchvision_version_raw = importlib_version("torchvision")
- except Exception:
- return
-
- try:
- torch_v = Version(torch_version_raw)
- tv_v = Version(torchvision_version_raw)
- except Exception:
- return
-
- # Known compatibility table (ground truth, takes precedence over formula).
- # See https://pytorch.org/get-started/previous-versions/
- TORCH_TORCHVISION_COMPAT = {
- (2, 9): (0, 24),
- (2, 8): (0, 23),
- (2, 7): (0, 22),
- (2, 6): (0, 21),
- (2, 5): (0, 20),
- (2, 4): (0, 19),
- }
-
- # Extract major.minor from the parsed version
- torch_release = torch_v.release
- if len(torch_release) < 2:
- return
- torch_major, torch_minor = torch_release[0], torch_release[1]
-
- # Try known table first, then fall back to formula for forward compatibility
- required = TORCH_TORCHVISION_COMPAT.get((torch_major, torch_minor))
-
- if required is None:
- required = _infer_required_torchvision(torch_major, torch_minor)
-
- if required is None:
- return
-
- required_tv_str = f"{required[0]}.{required[1]}.0"
-
- if tv_v >= Version(required_tv_str):
- logger.info(
- f"Unsloth: torch=={torch_version_raw} and "
- f"torchvision=={torchvision_version_raw} are compatible."
- )
- return
-
- # Version mismatch detected
- message = (
- f"Unsloth: torch=={torch_version_raw} requires "
- f"torchvision>={required_tv_str}, "
- f"but found torchvision=={torchvision_version_raw}. "
- f'Try updating torchvision via `pip install --upgrade "torchvision>={required_tv_str}"`. '
- f"Please refer to https://pytorch.org/get-started/previous-versions/ "
- f"for more information."
- )
-
- is_custom = _is_custom_torch_build(torch_version_raw) or _is_custom_torch_build(
- torchvision_version_raw
- )
-
- # Detect nightly/dev/alpha/beta/rc builds from the raw version string.
- # These often have version mismatches that are expected.
- _pre_tags = (".dev", "a0", "b0", "rc", "alpha", "beta", "nightly")
- is_prerelease = any(t in torch_version_raw for t in _pre_tags) or any(
- t in torchvision_version_raw for t in _pre_tags
- )
-
- # Only downgrade to warning for custom/source or prerelease builds.
- # Stable mismatches should fail fast to prevent runtime operator errors.
- if is_custom or is_prerelease:
- reason = "custom/source build" if is_custom else "pre-release build"
- logger.warning(
- f"{message}\n"
- f"Detected a {reason}. "
- f"Continuing with a warning. "
- f"Set UNSLOTH_SKIP_TORCHVISION_CHECK=1 to silence this."
- )
- return
-
- raise ImportError(message)
-
-
-# Fix TRL OpenEnv 0.26 NameError: name 'SamplingParams' is not defined
-def fix_openenv_no_vllm():
- spec = importlib.util.find_spec("trl")
- if spec is None:
- return
- trl_location = spec.origin
- if trl_location is None:
- trl_location = spec.submodule_search_locations[0]
- else:
- trl_location = os.path.split(trl_location)[0]
- openenv = Path(trl_location) / "experimental" / "openenv" / "utils.py"
- if not openenv.exists():
- return
-
- try:
- with open(openenv, "r+", encoding = "utf-8") as f:
- text = f.read()
- bad = (
- "if is_vllm_available():\n"
- " from vllm import SamplingParams\n"
- " from vllm.sampling_params import GuidedDecodingParams\n"
- )
- replace_with = bad + (
- "else:\n"
- " from typing import Any\n"
- " SamplingParams = Any\n"
- " GuidedDecodingParams = Any\n"
- "\n"
- )
- if bad + "\n" + "\n" in text and replace_with not in text:
- text = text.replace(bad + "\n" + "\n", replace_with)
- f.seek(0)
- f.write(text)
- f.truncate()
- logger.info(
- "Unsloth: Patching TRL OpenEnv to fix SamplingParams not defined"
- )
- except Exception as e:
- logger.info(f"Unsloth: Failed patching TRL OpenEnv with error = {str(e)}")
-
-
-# Fix Exeuctorch needing get_mapped_key
-def fix_executorch():
- spec = importlib.util.find_spec("executorch")
- if spec is None:
- return
- executorch_location = spec.origin
- if executorch_location is None:
- executorch_location = spec.submodule_search_locations[0]
- else:
- executorch_location = os.path.split(executorch_location)[0]
- executorch = Path(executorch_location) / "examples" / "models" / "__init__.py"
- if not executorch.exists():
- return
-
- try:
- what = r"""
- import sys
- import types
- import re
- from typing import Any, Optional
- def get_mapped_key(key: str, mapping_dict: dict[str, str]) -> str:
- try:
- # Checks if there is a layer # in the key
- if any(k.isdigit() for k in key.split(".")):
- # Replace layer number with "{}" to create key for lookup
- abstract_key = re.sub(r"(\.\d+)", ".{}", key)
- layer_num = re.search(r"\d+", key).group(0)
- new_key = mapping_dict[abstract_key]
- new_key = new_key.format(layer_num)
- else:
- new_key = mapping_dict[key]
- except KeyError as e:
- raise Exception(
- f'Error converting the state dict. Found unexpected key: "{key}". '
- "Please make sure you're loading a checkpoint with the right format. "
- ) from e
-
- return new_key
-
- torchtune = types.ModuleType("torchtune")
- torchtune.__path__ = []
- models = types.ModuleType("torchtune.models")
- models.__path__ = []
- convert_weights = types.ModuleType("torchtune.models.convert_weights")
- convert_weights.get_mapped_key = get_mapped_key
- torchtune.models = models
- models.convert_weights = convert_weights
- sys.modules["torchtune"] = torchtune
- sys.modules["torchtune.models"] = models
- sys.modules["torchtune.models.convert_weights"] = convert_weights
- """
- what = textwrap.dedent(what)
-
- with open(executorch, "r+", encoding = "utf-8") as f:
- text = f.read()
- bad = "from enum import Enum\n"
- if bad in text and what not in text:
- text = text.replace(bad + "\n", bad + "\n" + what)
- f.seek(0)
- f.write(text)
- f.truncate()
- logger.info("Unsloth: Patching Executorch to fix get_mapped_key")
- except Exception as e:
- logger.info(f"Unsloth: Failed Executorch with error = {str(e)}")
-
-
-def fix_diffusers_warnings():
- # Silence Flax classes are deprecated and will be removed in Diffusers v1.0.0.
- os.environ["DIFFUSERS_VERBOSITY"] = "error"
-
-
-def fix_huggingface_hub():
- # huggingface_hub.is_offline_mode got removed, so add it back
- import huggingface_hub
-
- if not hasattr(huggingface_hub, "is_offline_mode"):
- huggingface_hub.is_offline_mode = (
- lambda: huggingface_hub.constants.HF_HUB_OFFLINE
- )
-
-
-def fix_triton_compiled_kernel_missing_attrs():
- """
- Triton 3.6.0+ removed direct `num_ctas` and `cluster_dims` attributes from
- CompiledKernel, but torch 2.9.x Inductor still expects them in
- torch/_inductor/runtime/triton_heuristics.py make_launcher() (line ~1757).
-
- The scope dict eagerly evaluates:
- binary.metadata.num_ctas, *binary.metadata.cluster_dims
- when hasattr(binary, "metadata") is True, but metadata lacks cluster_dims.
- This crashes before reaching the new launch path that doesn't need cta_args.
-
- Upstream fix: pytorch/pytorch@97bd4db added hasattr guards.
- We monkey-patch CompiledKernel.__init__ to inject the missing attributes
- so the older hasattr(binary, "num_ctas") branch succeeds instead.
- """
- try:
- import torch
- except (ImportError, ModuleNotFoundError):
- return
-
- try:
- import triton
- import triton.compiler.compiler as triton_compiler
- except (ImportError, ModuleNotFoundError):
- return
-
- # Only needed when the CompiledKernel class lacks num_ctas as a direct attr
- # but has metadata (triton >= 3.6.0 with torch < 2.10)
- _ck_cls = triton_compiler.CompiledKernel
- if hasattr(_ck_cls, "num_ctas"):
- return # Old triton with direct attrs -- no patch needed
-
- _orig_init = _ck_cls.__init__
-
- def _patched_init(self, *args, **kwargs):
- _orig_init(self, *args, **kwargs)
- if not hasattr(self, "num_ctas"):
- self.num_ctas = getattr(self.metadata, "num_ctas", 1)
- if not hasattr(self, "cluster_dims") and not hasattr(self, "clusterDims"):
- self.cluster_dims = (1, 1, 1)
-
- _ck_cls.__init__ = _patched_init
- logger.info(
- "Unsloth: Patched triton CompiledKernel with num_ctas/cluster_dims "
- "for torch.compile compatibility."
- )
-
-
-def patch_trunc_normal_precision_issue():
- """
- Patch torch.nn.init.trunc_normal_ for low precision tensors to run init in fp32.
-
- torch.nn.init.trunc_normal_ can saturate at truncation bounds in fp16/bf16 on
- some versions/backends. This was observed in TorchTitan investigations where
- low-precision truncation produced boundary-heavy initialization behavior:
- https://github.com/pytorch/torchtitan/pull/2342
-
- To avoid that failure mode, initialize into a temporary fp32 tensor, then copy
- back to the original dtype.
- """
- try:
- import torch
- except (ImportError, ModuleNotFoundError):
- return
-
- if getattr(torch.nn.init, "_unsloth_trunc_normal_patched", False):
- return
-
- original_trunc_normal = torch.nn.init.trunc_normal_
- if getattr(original_trunc_normal, "__unsloth_trunc_normal_patched__", False):
- torch.nn.init._unsloth_trunc_normal_patched = True
- return
-
- low_precision_dtypes = {torch.float16, torch.bfloat16}
-
- def _call_original(target, mean, std, a, b, generator):
- if generator is None:
- return original_trunc_normal(target, mean = mean, std = std, a = a, b = b)
- try:
- return original_trunc_normal(
- target, mean = mean, std = std, a = a, b = b, generator = generator
- )
- except TypeError as exc:
- # Older torch versions may not accept a generator keyword argument.
- msg = str(exc).lower()
- if "unexpected keyword argument" in msg and "generator" in msg:
- return original_trunc_normal(target, mean = mean, std = std, a = a, b = b)
- raise
-
- try:
- from torch.distributed._tensor import DTensor
- except Exception:
- DTensor = None
-
- @torch.no_grad()
- def _patched_trunc_normal_(
- tensor,
- mean: float = 0.0,
- std: float = 1.0,
- a: float = -2.0,
- b: float = 2.0,
- generator = None,
- ):
- if DTensor is not None and isinstance(tensor, DTensor):
- local_tensor = getattr(tensor, "_local_tensor", None)
- if local_tensor is None:
- return _call_original(tensor, mean, std, a, b, generator)
- if local_tensor.dtype in low_precision_dtypes:
- local_fp32 = local_tensor.float()
- _call_original(local_fp32, mean, std, a, b, generator)
- local_tensor.copy_(local_fp32.to(dtype = local_tensor.dtype))
- return tensor
- return _call_original(tensor, mean, std, a, b, generator)
-
- if tensor.dtype in low_precision_dtypes:
- tensor_fp32 = tensor.float()
- _call_original(tensor_fp32, mean, std, a, b, generator)
- tensor.copy_(tensor_fp32.to(dtype = tensor.dtype))
- return tensor
-
- return _call_original(tensor, mean, std, a, b, generator)
-
- _patched_trunc_normal_.__unsloth_trunc_normal_patched__ = True
- _patched_trunc_normal_._unsloth_original = original_trunc_normal
- torch.nn.init._unsloth_trunc_normal_original = original_trunc_normal
- torch.nn.init.trunc_normal_ = _patched_trunc_normal_
- torch.nn.init._unsloth_trunc_normal_patched = True
- logger.info("Unsloth: Patched torch.nn.init.trunc_normal_ for fp16/bf16 stability.")
-
-
-def check_vllm_torch_sm100_compatibility():
- """
- Check for incompatible vLLM + torch < 2.9.0 + SM100 (Blackwell) combination.
-
- vLLM's distributed module (device_communicators) crashes with std::bad_alloc
- when imported on SM100 GPUs (B200/B100) with torch < 2.9.0. This is due to
- C++ code in vLLM's NCCL/distributed layer being incompatible with older
- torch versions on the newer Blackwell architecture.
-
- This check runs early (before vLLM import) to provide a helpful error message
- instead of a cryptic std::bad_alloc crash.
- """
- # Check if vLLM is installed (without importing it)
- if importlib.util.find_spec("vllm") is None:
- return
-
- # Check torch version
- try:
- torch_version = Version(importlib_version("torch"))
- if torch_version >= Version("2.9.0"):
- return # torch >= 2.9.0 is compatible
- except Exception:
- return # Can't determine torch version, skip check
-
- # Check if any CUDA GPU is SM100 (Blackwell)
- try:
- import torch
-
- if not torch.cuda.is_available():
- return
-
- has_sm100 = False
- sm100_gpu_name = None
- for i in range(torch.cuda.device_count()):
- major, minor = torch.cuda.get_device_capability(i)
- if major == 10:
- has_sm100 = True
- sm100_gpu_name = torch.cuda.get_device_name(i)
- break
-
- if not has_sm100:
- return
- except Exception:
- return
-
- # Get vLLM version for the error message
- try:
- vllm_version = importlib_version("vllm")
- except Exception:
- vllm_version = "unknown"
-
- # Incompatible combination detected - raise helpful error
- raise RuntimeError(
- f"Unsloth: Incompatible configuration detected.\n\n"
- f" GPU: {sm100_gpu_name} (SM100 / Blackwell architecture)\n"
- f" torch version: {torch_version}\n"
- f" vLLM version: {vllm_version}\n\n"
- f"vLLM's distributed module crashes with std::bad_alloc on SM100 GPUs "
- f"(B200/B100/Blackwell) when using torch < 2.9.0.\n\n"
- f"To fix this, please upgrade torch:\n"
- f" pip install --upgrade torch>=2.9.0\n\n"
- f"Alternatively, if you don't need vLLM:\n"
- f" pip uninstall vllm"
- )
-
-
-def fix_vllm_pdl_blackwell():
- """
- Fix vLLM PDL (Programmatic Dependent Launch) bug on Blackwell GPUs (SM100).
-
- The issue: vLLM's LoRA Triton kernels use tl.extra.cuda.gdc_wait() for PDL
- optimization on SM90+ GPUs. This fails on SM100 (B200/B100) during CUDA graph
- capture because Triton's pipeliner can't handle gdc_wait in complex kernels.
-
- See: https://github.com/vllm-project/vllm/issues/30872
- """
- if importlib.util.find_spec("vllm") is None:
- return
-
- # Check if any CUDA GPU is SM100 (Blackwell)
- try:
- import torch
-
- if not torch.cuda.is_available():
- return
-
- # Scan all GPUs for SM100 - fix applies globally via env var and monkey-patch
- has_sm100 = False
- sm100_gpu_name = None
- for i in range(torch.cuda.device_count()):
- major, minor = torch.cuda.get_device_capability(i)
- if major == 10:
- has_sm100 = True
- sm100_gpu_name = torch.cuda.get_device_name(i)
- break
-
- if not has_sm100:
- return
- except Exception:
- return
-
- # Helper to check if module spec exists
- def _spec_exists(name):
- try:
- return importlib.util.find_spec(name) is not None
- except (ImportError, OSError, ModuleNotFoundError, ValueError):
- return False
-
- # Check if vLLM has the PDL-related modules before doing internet check
- has_utils = _spec_exists("vllm.lora.ops.triton_ops.utils")
- has_expand_op = _spec_exists("vllm.lora.ops.triton_ops.lora_expand_op")
- has_shrink_op = _spec_exists("vllm.lora.ops.triton_ops.lora_shrink_op")
-
- if not has_utils and not has_expand_op and not has_shrink_op:
- # Old vLLM version without PDL support - nothing to patch
- return
-
- # Check if vLLM version includes the fix
- VLLM_PDL_FIX_VERSION = "0.15.0"
- try:
- vllm_version = Version(importlib_version("vllm"))
- if vllm_version >= Version(VLLM_PDL_FIX_VERSION):
- logger.info(
- f"Unsloth: SM100 ({sm100_gpu_name}) detected but vLLM {vllm_version} "
- f"should include PDL fix - skipping workaround"
- )
- return
- except Exception as e:
- logger.debug(
- f"Unsloth: vLLM version check failed ({e}), applying PDL workaround."
- )
-
- # Apply the PDL fix
- os.environ["TRITON_DISABLE_PDL"] = "1"
-
- def fake_supports_pdl(*args, **kwargs):
- return False
-
- patched = []
- patched_names = set()
-
- def _record_patch(name):
- if name not in patched_names:
- patched.append(name)
- patched_names.add(name)
-
- # First, patch the source module (utils.py) where supports_pdl is defined.
- # This is critical because supports_pdl uses @lru_cache - we must clear the
- # cache to prevent stale cached results from the original function.
- try:
- utils_module = importlib.import_module("vllm.lora.ops.triton_ops.utils")
- if hasattr(utils_module, "supports_pdl"):
- original_fn = utils_module.supports_pdl
- if hasattr(original_fn, "cache_clear"):
- original_fn.cache_clear()
- utils_module.supports_pdl = fake_supports_pdl
- _record_patch("utils")
- except (ImportError, ModuleNotFoundError, AttributeError):
- pass
-
- # Also patch the consumer modules that import supports_pdl from utils.
- # This ensures the patched function is used even if the module was already
- # imported before this fix runs.
- consumer_modules = {
- "lora_expand_op": "vllm.lora.ops.triton_ops.lora_expand_op",
- "lora_shrink_op": "vllm.lora.ops.triton_ops.lora_shrink_op",
- "fused_moe_lora_op": "vllm.lora.ops.triton_ops.fused_moe_lora_op",
- }
- for name, path in consumer_modules.items():
- try:
- module = importlib.import_module(path)
- if hasattr(module, "supports_pdl"):
- module.supports_pdl = fake_supports_pdl
- _record_patch(name)
- except (ImportError, ModuleNotFoundError, AttributeError):
- pass
-
- # Patch any additional already-loaded triton ops consumers that expose supports_pdl.
- for module_name, module in tuple(sys.modules.items()):
- if not module_name.startswith("vllm.lora.ops.triton_ops."):
- continue
- if module is None or not hasattr(module, "supports_pdl"):
- continue
- module.supports_pdl = fake_supports_pdl
- _record_patch(module_name.rsplit(".", 1)[-1])
-
- if patched:
- logger.info(
- f"Unsloth: Applied PDL fix for SM100 ({sm100_gpu_name}) - "
- f"patched: {', '.join(patched)}"
- )
- else:
- # Just set the env var - vLLM might be an older version without supports_pdl
- logger.info(f"Unsloth: Set TRITON_DISABLE_PDL=1 for SM100 ({sm100_gpu_name})")
-
-
-def patch_openspiel_env_async():
- """Apply nest_asyncio for OpenEnv EnvClient async compatibility.
-
- OpenEnv's EnvClient uses async methods (reset/step). In Jupyter notebooks
- these work via top-level await, but converted scripts need
- asyncio.get_event_loop().run_until_complete() wrappers. Applying nest_asyncio
- ensures nested event loop calls work in all contexts without replacing the
- original async methods (which would break scripts that already have their own
- sync wrappers).
- """
- try:
- import inspect
- from openenv.core.env_client import EnvClient
-
- if not inspect.iscoroutinefunction(EnvClient.reset):
- return # Already sync, nothing to do
-
- try:
- import nest_asyncio
-
- nest_asyncio.apply()
- logger.info(
- "Unsloth: Applied nest_asyncio for OpenEnv EnvClient async compatibility"
- )
- except ImportError:
- logger.info(
- "Unsloth: nest_asyncio not installed, OpenEnv async methods may need manual wrapping"
- )
- except (ImportError, AttributeError):
- pass # openenv not installed
-
-
-def patch_torchcodec_audio_decoder():
- """Call unsloth_zoo's AudioDecoder patch."""
- try:
- from unsloth_zoo.dataset_utils import patch_torchcodec_audio_decoder as _patch
-
- _patch()
- except (ImportError, AttributeError, RuntimeError):
- pass
-
-
-def disable_torchcodec_if_broken():
- """Disable torchcodec in transformers if it cannot actually load.
-
- transformers checks if torchcodec is installed via importlib.util.find_spec(),
- but this returns True even when torchcodec cannot load its native libraries
- (e.g., when FFmpeg is missing). This causes runtime errors when transformers
- tries to use torchcodec for audio loading.
-
- This function tests if torchcodec can actually load and if not, patches
- transformers to think torchcodec is unavailable so it falls back to librosa.
- """
- try:
- import importlib.util
-
- if importlib.util.find_spec("torchcodec") is None:
- return # torchcodec not installed, nothing to do
-
- # Test if torchcodec can actually load
- from torchcodec.decoders import AudioDecoder
- except (ImportError, RuntimeError, OSError):
- # torchcodec cannot load - disable it in transformers
- try:
- import transformers.utils.import_utils as tf_import_utils
-
- tf_import_utils._torchcodec_available = False
- except (ImportError, AttributeError):
- pass
-
-
-def disable_broken_wandb():
- """Disable wandb if it's installed but cannot actually import.
-
- wandb can fail to import when there's a protobuf version mismatch
- (e.g., wandb < 0.19.11 with protobuf >= 6.0). This causes cascading
- import failures through trl -> transformers/accelerate -> wandb that
- crash unsloth's import chain.
-
- There are two separate is_wandb_available() functions used by trl:
- - transformers.integrations.integration_utils.is_wandb_available
- (used by most trl trainers)
- - accelerate.utils.imports.is_wandb_available
- (used by trl/trainer/callbacks.py)
-
- Both must be patched to fully prevent broken wandb imports.
- """
- if importlib.util.find_spec("wandb") is None:
- return # wandb not installed, nothing to do
-
- try:
- import wandb
- except Exception:
- # wandb is installed but broken - patch all checkers to skip it
- logger.info(
- "Unsloth: wandb is installed but broken (likely a protobuf version mismatch). "
- "Disabling wandb to prevent import errors. To fix, run: pip install --upgrade wandb"
- )
- _wandb_false = lambda: False
- # Patch transformers' is_wandb_available (used by most trl trainers)
- try:
- import transformers.integrations.integration_utils as tf_integration
-
- tf_integration.is_wandb_available = _wandb_false
- except (ImportError, AttributeError):
- pass
- # Patch accelerate's is_wandb_available (used by trl/trainer/callbacks.py).
- # Must patch both the source module AND the re-export namespace since
- # `from accelerate.utils import is_wandb_available` reads from
- # accelerate.utils, not accelerate.utils.imports.
- try:
- import accelerate.utils.imports as acc_imports
-
- acc_imports.is_wandb_available = _wandb_false
- except (ImportError, AttributeError):
- pass
- try:
- import accelerate.utils as acc_utils
-
- acc_utils.is_wandb_available = _wandb_false
- except (ImportError, AttributeError):
- pass
- # Set env var as additional fallback
- os.environ["WANDB_DISABLED"] = "true"
-
-
-CAUSAL_CONV1D_BROKEN = False
-_CAUSAL_CONV1D_PREFIX = "causal_conv1d"
-_CAUSAL_CONV1D_BLOCKER_SENTINEL = "_unsloth_causal_conv1d_blocker"
-VLLM_BROKEN = False
-_VLLM_PREFIX = "vllm"
-_VLLM_BLOCKER_SENTINEL = "_unsloth_vllm_blocker"
-_ROCM_ENV_HINT_KEYS = (
- "ROCM_PATH",
- "ROCM_HOME",
- "HIP_PATH",
- "HSA_PATH",
- "HIP_VISIBLE_DEVICES",
- "ROCR_VISIBLE_DEVICES",
-)
-_ROCM_PATH_HINTS = (
- Path("/opt/rocm"),
- Path("/dev/kfd"),
- Path("/sys/module/amdgpu"),
-)
-_AMDGPU_ASIC_ID_TABLE_PATH_ENV = "AMDGPU_ASIC_ID_TABLE_PATH"
-_AMDGPU_ASIC_ID_CANDIDATE_PATHS = (
- Path("/usr/share/libdrm/amdgpu.ids"),
- Path("/usr/local/share/libdrm/amdgpu.ids"),
- Path("/opt/rocm/share/libdrm/amdgpu.ids"),
- Path("/opt/amdgpu/share/libdrm/amdgpu.ids"),
-)
-
-
-def _log_rocm_detection(message):
- if UNSLOTH_ENABLE_LOGGING:
- logger.info(message)
-
-
-@functools.lru_cache(1)
-def _is_rocm_torch_build() -> bool:
- # Most official ROCm wheels include a local version suffix like +rocmX.Y.
- # Some custom/source builds do not, so we fall back to runtime hints.
- try:
- torch_version_raw = str(importlib_version("torch")).lower()
- if "rocm" in torch_version_raw:
- _log_rocm_detection(
- "Unsloth: ROCm detection matched torch version tag (+rocm)."
- )
- return True
- except Exception:
- pass
-
- # Environment hints commonly present on ROCm runtimes.
- for key in _ROCM_ENV_HINT_KEYS:
- value = os.environ.get(key, "")
- if isinstance(value, str) and value.strip():
- _log_rocm_detection(
- f"Unsloth: ROCm detection matched environment key `{key}`."
- )
- return True
-
- # Filesystem / driver hints for ROCm stacks.
- for path in _ROCM_PATH_HINTS:
- try:
- if path.exists():
- _log_rocm_detection(
- f"Unsloth: ROCm detection matched filesystem hint `{path}`."
- )
- return True
- except Exception:
- continue
-
- _log_rocm_detection("Unsloth: ROCm detection did not match any known hints.")
- return False
-
-
-def _iter_amdgpu_asic_id_table_candidates():
- # Try torch-adjacent ids table paths first without importing torch.
- try:
- torch_spec = importlib.util.find_spec("torch")
- except Exception:
- torch_spec = None
-
- roots = []
- if torch_spec is not None:
- if torch_spec.origin:
- roots.append(Path(torch_spec.origin).resolve().parent)
- if torch_spec.submodule_search_locations:
- for location in torch_spec.submodule_search_locations:
- roots.append(Path(location).resolve())
-
- seen = set()
- for root in roots:
- for candidate in (
- root / "share" / "libdrm" / "amdgpu.ids",
- root.parent / "share" / "libdrm" / "amdgpu.ids",
- root.parent.parent / "share" / "libdrm" / "amdgpu.ids",
- ):
- candidate_str = str(candidate)
- if candidate_str in seen:
- continue
- seen.add(candidate_str)
- yield candidate
-
- for candidate in _AMDGPU_ASIC_ID_CANDIDATE_PATHS:
- candidate_str = str(candidate)
- if candidate_str in seen:
- continue
- seen.add(candidate_str)
- yield candidate
-
-
-def configure_amdgpu_asic_id_table_path():
- # Honor an existing valid user-provided path.
- configured = os.environ.get(_AMDGPU_ASIC_ID_TABLE_PATH_ENV, "").strip()
- if configured:
- configured_path = Path(configured)
- try:
- if configured_path.is_file():
- return str(configured_path)
- except Exception:
- pass
-
- # Only attempt this on ROCm-like environments.
- if not _is_rocm_torch_build():
- return None
-
- for candidate in _iter_amdgpu_asic_id_table_candidates():
- try:
- if candidate.is_file():
- os.environ[_AMDGPU_ASIC_ID_TABLE_PATH_ENV] = str(candidate)
- if UNSLOTH_ENABLE_LOGGING:
- logger.info(
- f"Unsloth: Set {_AMDGPU_ASIC_ID_TABLE_PATH_ENV}={candidate}"
- )
- return str(candidate)
- except Exception:
- continue
-
- return None
-
-
-def _is_causal_conv1d_name(module_name: str) -> bool:
- return module_name == _CAUSAL_CONV1D_PREFIX or module_name.startswith(
- _CAUSAL_CONV1D_PREFIX + "."
- )
-
-
-def _is_vllm_name(module_name: str) -> bool:
- return module_name == _VLLM_PREFIX or module_name.startswith(_VLLM_PREFIX + ".")
-
-
-def _resolve_module_name(module_name, package):
- if not isinstance(module_name, str):
- return module_name
- if module_name.startswith("."):
- try:
- return importlib.util.resolve_name(module_name, package)
- except Exception:
- return module_name
- return module_name
-
-
-def _is_broken_causal_conv1d_error(error) -> bool:
- checked = set()
- current = error
- while current is not None and id(current) not in checked:
- checked.add(id(current))
- message = str(current).lower()
- if (
- ("causal_conv1d_cuda" in message and "undefined symbol" in message)
- or ("_zn3c103hip28c10_hip_check_implementation" in message)
- or ("causal_conv1d" in message and "undefined symbol" in message)
- ):
- return True
- current = getattr(current, "__cause__", None) or getattr(
- current, "__context__", None
- )
- return False
-
-
-def _is_broken_vllm_error(error) -> bool:
- checked = set()
- current = error
- while current is not None and id(current) not in checked:
- checked.add(id(current))
- message = str(current).lower()
- if (
- ("vllm/_c" in message or "vllm._c" in message)
- and (
- "undefined symbol" in message
- or "cannot open shared object file" in message
- or ".so:" in message
- )
- ) or ("vllm" in message and "undefined symbol" in message):
- return True
- # Also catch CUDA shared library mismatches during vllm import
- # e.g. "libcudart.so.12: cannot open shared object file"
- if (
- "libcudart" in message or "libcublas" in message or "libnvrtc" in message
- ) and "cannot open shared object file" in message:
- return True
- current = getattr(current, "__cause__", None) or getattr(
- current, "__context__", None
- )
- return False
-
-
-def _get_vllm_cuda_mismatch_message(error):
- """If the error is a CUDA version mismatch, return a helpful install message."""
- import re as _re
-
- checked = set()
- current = error
- wanted_cuda = None
- while current is not None and id(current) not in checked:
- checked.add(id(current))
- message = str(current)
- # Extract the CUDA version vllm was built for, e.g. "libcudart.so.12"
- match = _re.search(r"libcudart\.so\.(\d+)", message)
- if match:
- wanted_cuda = match.group(1)
- break
- current = getattr(current, "__cause__", None) or getattr(
- current, "__context__", None
- )
- if wanted_cuda is None:
- return None
-
- # Detect what CUDA version is actually available on the system
- system_cuda_display = None # Human-readable, e.g. "13.0"
- system_cuda_tag = None # For wheel URL, e.g. "130"
- try:
- import torch
-
- cuda_version = torch.version.cuda # e.g. "13.0" or "12.8"
- if cuda_version:
- system_cuda_display = cuda_version
- system_cuda_tag = cuda_version.replace(".", "")[:3] # "130" or "128"
- except Exception:
- pass
-
- if system_cuda_tag is None or system_cuda_tag.startswith(wanted_cuda):
- return None # Not a mismatch or can't determine
-
- try:
- vllm_version = importlib_version("vllm").split("+")[0]
- except Exception:
- vllm_version = "VLLM_VERSION"
-
- cpu_arch = "x86_64"
- try:
- import platform
-
- cpu_arch = platform.machine()
- except Exception:
- pass
-
- return (
- f"Unsloth: vLLM was built for CUDA {wanted_cuda} but this system has "
- f"CUDA {system_cuda_display}. Please reinstall vLLM with the correct CUDA version:\n"
- f"\n"
- f" uv pip install https://github.com/vllm-project/vllm/releases/download/"
- f"v{vllm_version}/vllm-{vllm_version}+cu{system_cuda_tag}-cp38-abi3-"
- f"manylinux_2_35_{cpu_arch}.whl"
- )
-
-
-class _CausalConv1dImportBlockerLoader(importlib.abc.Loader):
- __slots__ = ("module_name",)
-
- def __init__(self, module_name):
- self.module_name = module_name
-
- def create_module(self, spec):
- return None
-
- def exec_module(self, module):
- raise ModuleNotFoundError(f"No module named '{self.module_name}'")
-
-
-class _CausalConv1dImportBlockerFinder(importlib.abc.MetaPathFinder):
- __slots__ = (_CAUSAL_CONV1D_BLOCKER_SENTINEL,)
-
- def __init__(self):
- setattr(self, _CAUSAL_CONV1D_BLOCKER_SENTINEL, True)
-
- def find_spec(self, fullname, path = None, target = None):
- if not CAUSAL_CONV1D_BROKEN or not _is_causal_conv1d_name(fullname):
- return None
- return importlib.machinery.ModuleSpec(
- name = fullname,
- loader = _CausalConv1dImportBlockerLoader(fullname),
- is_package = fullname == _CAUSAL_CONV1D_PREFIX,
- )
-
-
-class _VllmImportBlockerLoader(importlib.abc.Loader):
- __slots__ = ("module_name",)
-
- def __init__(self, module_name):
- self.module_name = module_name
-
- def create_module(self, spec):
- return None
-
- def exec_module(self, module):
- raise ModuleNotFoundError(f"No module named '{self.module_name}'")
-
-
-class _VllmImportBlockerFinder(importlib.abc.MetaPathFinder):
- __slots__ = (_VLLM_BLOCKER_SENTINEL,)
-
- def __init__(self):
- setattr(self, _VLLM_BLOCKER_SENTINEL, True)
-
- def find_spec(self, fullname, path = None, target = None):
- if not VLLM_BROKEN or not _is_vllm_name(fullname):
- return None
- return importlib.machinery.ModuleSpec(
- name = fullname,
- loader = _VllmImportBlockerLoader(fullname),
- is_package = fullname == _VLLM_PREFIX,
- )
-
-
-def _patch_find_spec_for_causal_conv1d():
- current_find_spec = importlib.util.find_spec
- if getattr(current_find_spec, "_unsloth_causal_conv1d_find_spec_patch", False):
- return
-
- def _blocked_find_spec(name, package = None):
- resolved_name = _resolve_module_name(name, package)
- if CAUSAL_CONV1D_BROKEN and isinstance(resolved_name, str):
- if _is_causal_conv1d_name(resolved_name):
- return None
- return current_find_spec(name, package)
-
- _blocked_find_spec._unsloth_causal_conv1d_find_spec_patch = True
- _blocked_find_spec._unsloth_original_find_spec = current_find_spec
- importlib.util.find_spec = _blocked_find_spec
-
-
-def _patch_find_spec_for_vllm():
- current_find_spec = importlib.util.find_spec
- if getattr(current_find_spec, "_unsloth_vllm_find_spec_patch", False):
- return
-
- def _blocked_find_spec(name, package = None):
- resolved_name = _resolve_module_name(name, package)
- if VLLM_BROKEN and isinstance(resolved_name, str):
- if _is_vllm_name(resolved_name):
- return None
- return current_find_spec(name, package)
-
- _blocked_find_spec._unsloth_vllm_find_spec_patch = True
- _blocked_find_spec._unsloth_original_find_spec = current_find_spec
- importlib.util.find_spec = _blocked_find_spec
-
-
-def _install_causal_conv1d_blocker():
- _patch_find_spec_for_causal_conv1d()
- for finder in sys.meta_path:
- if getattr(finder, _CAUSAL_CONV1D_BLOCKER_SENTINEL, False):
- return
- sys.meta_path.insert(0, _CausalConv1dImportBlockerFinder())
-
-
-def _install_vllm_blocker():
- _patch_find_spec_for_vllm()
- for finder in sys.meta_path:
- if getattr(finder, _VLLM_BLOCKER_SENTINEL, False):
- return
- sys.meta_path.insert(0, _VllmImportBlockerFinder())
-
-
-def _clear_causal_conv1d_modules():
- for module_name in list(sys.modules):
- if _is_causal_conv1d_name(module_name):
- sys.modules.pop(module_name, None)
-
-
-def _clear_vllm_modules():
- for module_name in list(sys.modules):
- if _is_vllm_name(module_name):
- sys.modules.pop(module_name, None)
-
-
-def disable_broken_vllm(error = None):
- """Disable vLLM dynamically when its shared library is ABI-broken."""
- global VLLM_BROKEN
- if VLLM_BROKEN:
- _install_vllm_blocker()
- return True
-
- failure = error
- if failure is None:
- try:
- if importlib.util.find_spec("vllm") is None:
- return False
- except Exception:
- return False
-
- try:
- import vllm # noqa: F401
-
- return False
- except Exception as import_error:
- failure = import_error
-
- if not _is_broken_vllm_error(failure):
- return False
-
- VLLM_BROKEN = True
- _clear_vllm_modules()
- _install_vllm_blocker()
- cuda_msg = _get_vllm_cuda_mismatch_message(failure)
- if cuda_msg:
- logger.warning(cuda_msg)
- else:
- logger.warning(
- "Unsloth: Detected broken vLLM binary extension; "
- "disabling vLLM imports and continuing import.\n"
- "Please reinstall via `uv pip install unsloth vllm torchvision torchaudio "
- "--torch-backend=auto`."
- )
- return True
-
-
-def _disable_transformers_causal_conv1d():
- try:
- import transformers.utils.import_utils as tf_import_utils
- except Exception:
- return
-
- if hasattr(tf_import_utils, "is_causal_conv1d_available"):
- tf_import_utils.is_causal_conv1d_available = lambda: False
-
- for attr_name in (
- "_causal_conv1d_available",
- "_is_causal_conv1d_available",
- ):
- if hasattr(tf_import_utils, attr_name):
- setattr(tf_import_utils, attr_name, False)
-
-
-def disable_broken_causal_conv1d():
- """Disable causal_conv1d dynamically when its shared library is ABI-broken.
-
- This mirrors Unsloth's FlashAttention fallback behavior: if importing causal_conv1d
- fails with a known binary symbol error, we disable it at startup so model imports do
- not hard-fail.
- """
- global CAUSAL_CONV1D_BROKEN
- if CAUSAL_CONV1D_BROKEN:
- _install_causal_conv1d_blocker()
- _disable_transformers_causal_conv1d()
- return
-
- try:
- if importlib.util.find_spec("causal_conv1d") is None:
- return
- except Exception:
- return
-
- try:
- import causal_conv1d # noqa: F401
-
- return
- except Exception as error:
- if not _is_broken_causal_conv1d_error(error):
- return
-
- CAUSAL_CONV1D_BROKEN = True
- _clear_causal_conv1d_modules()
- _install_causal_conv1d_blocker()
- _disable_transformers_causal_conv1d()
- print(
- "Unsloth: Detected broken causal_conv1d binary; "
- "disabling causal_conv1d fast path and continuing import."
- )
diff --git a/unsloth/kernels/__init__.py b/unsloth/kernels/__init__.py
index 15913413d9..ef5fa5da70 100644
--- a/unsloth/kernels/__init__.py
+++ b/unsloth/kernels/__init__.py
@@ -44,14 +44,7 @@
apply_lora_o,
fast_lora_forward,
)
-from .fp8 import * # This step is to ensure that we patch the FbgmemFP8Linear and FP8Linear's forward functions before the execution of model creation so that this applies to compiled non fast inference models as well
-from .utils import (
- fast_dequantize,
- fast_gemv,
- QUANT_STATE,
- fast_linear_forward,
- matmul_lora,
-)
+from .utils import fast_dequantize, fast_gemv, QUANT_STATE, fast_linear_forward, matmul_lora
from .flex_attention import (
HAS_FLEX_ATTENTION,
@@ -62,12 +55,11 @@
)
import os
-
if "UNSLOTH_ZOO_IS_PRESENT" not in os.environ:
try:
- print(
- "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning."
- )
+ print("🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.")
except:
print("Unsloth: Will patch your computer to enable 2x faster free finetuning.")
+ pass
+pass
del os
diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py
index d92229314f..834a74c66d 100644
--- a/unsloth/kernels/cross_entropy_loss.py
+++ b/unsloth/kernels/cross_entropy_loss.py
@@ -20,11 +20,10 @@
MAX_FUSED_SIZE,
triton_tanh,
triton_cast,
- torch_gpu_device,
- is_cdna,
+ torch_cuda_device,
)
from transformers.models.llama.modeling_llama import logger
-from unsloth_zoo.utils import Version
+from packaging.version import Version
from unsloth_zoo.loss_utils import (
patch_loss_functions as _patch_loss_functions,
@@ -33,145 +32,134 @@
def _cross_entropy_forward(
- logits_ptr,
- logits_row_stride,
- loss_ptr,
- logsumexp_ptr,
- labels_ptr,
- VOCAB_SIZE: tl.constexpr,
- BLOCK_SIZE: tl.constexpr,
- DO_SOFTCAPPING: tl.constexpr,
- SOFTCAP: tl.constexpr,
- DO_LOGIT_SCALING: tl.constexpr,
- LOGIT_SCALE: tl.constexpr,
+ logits_ptr ,
+ logits_row_stride ,
+ loss_ptr ,
+ logsumexp_ptr ,
+ labels_ptr ,
+ VOCAB_SIZE : tl.constexpr,
+ BLOCK_SIZE : tl.constexpr,
+ DO_SOFTCAPPING : tl.constexpr,
+ SOFTCAP : tl.constexpr,
+ DO_LOGIT_SCALING : tl.constexpr,
+ LOGIT_SCALE : tl.constexpr,
):
"""
- Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
- Pi = exp(xi) / sum(exp(xi))
- CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
- = -y [ x - log[sum(exp(x))] ]
- = y * (log[sum(exp(x))] - x)
- If y == 0: CE_i = 0
- If y == 1: CE_i = logsumexp - x
-
- logsumexp is also stable
- Take y = log[sum(exp(x))]
- exp(y) = sum(exp(x))
- exp(y) = sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x
- exp(y) = exp(c)*sum(exp(x - c))
- y = log(exp(c)*sum(exp(x - c)))
- y = c + log[sum(exp(x - c))]
- This means we can set c = max(x) to make sure
- exp(x - c) always is exp(x - max(x)).
- This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.
+ Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
+ Pi = exp(xi) / sum(exp(xi))
+ CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
+ = -y [ x - log[sum(exp(x))] ]
+ = y * (log[sum(exp(x))] - x)
+ If y == 0: CE_i = 0
+ If y == 1: CE_i = logsumexp - x
+
+ logsumexp is also stable
+ Take y = log[sum(exp(x))]
+ exp(y) = sum(exp(x))
+ exp(y) = sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x
+ exp(y) = exp(c)*sum(exp(x - c))
+ y = log(exp(c)*sum(exp(x - c)))
+ y = c + log[sum(exp(x - c))]
+ This means we can set c = max(x) to make sure
+ exp(x - c) always is exp(x - max(x)).
+ This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.
"""
row_idx = tl.program_id(0)
- logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
- loss_ptr += row_idx
+ logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
+ loss_ptr += row_idx
logsumexp_ptr += row_idx
- labels_ptr += row_idx
+ labels_ptr += row_idx
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < VOCAB_SIZE
label_idx = tl.load(labels_ptr).to(tl.int32)
- logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(
- tl.float32
- )
+ logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
# Go logit scaling for Cohere: t * x
- if DO_LOGIT_SCALING:
- logits = LOGIT_SCALE * logits
+ if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
- if DO_SOFTCAPPING:
- logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
-
+ if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
+
c = tl.max(logits, 0)
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
if label_idx != -100:
x = tl.load(logits_ptr + label_idx).to(tl.float32)
# Go logit scaling for Cohere: t * x
- if DO_LOGIT_SCALING:
- x = LOGIT_SCALE * x
+ if DO_LOGIT_SCALING: x = LOGIT_SCALE * x
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
- if DO_SOFTCAPPING:
- x = SOFTCAP * triton_tanh(x / SOFTCAP)
+ if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
loss = logsumexp - x
else:
loss = 0.0
tl.store(logsumexp_ptr, logsumexp)
tl.store(loss_ptr, loss)
-
-
+pass
_cross_entropy_forward = triton.jit(_cross_entropy_forward)
_cross_entropy_forward = triton.heuristics(
{
- "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING"]),
+ "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
"DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
}
)(_cross_entropy_forward)
def _chunked_cross_entropy_forward(
- logits_ptr,
- logits_row_stride: tl.constexpr,
- loss_ptr,
- logsumexp_ptr,
- labels_ptr,
- VOCAB_SIZE: tl.constexpr,
- N_CHUNKS: tl.constexpr,
- BLOCK_SIZE: tl.constexpr,
- DO_SOFTCAPPING: tl.constexpr,
- SOFTCAP: tl.constexpr,
- DO_LOGIT_SCALING: tl.constexpr,
- LOGIT_SCALE: tl.constexpr,
+ logits_ptr ,
+ logits_row_stride ,
+ loss_ptr ,
+ logsumexp_ptr ,
+ labels_ptr ,
+ VOCAB_SIZE : tl.constexpr,
+ N_CHUNKS : tl.constexpr,
+ BLOCK_SIZE : tl.constexpr,
+ DO_SOFTCAPPING : tl.constexpr,
+ SOFTCAP : tl.constexpr,
+ DO_LOGIT_SCALING : tl.constexpr,
+ LOGIT_SCALE : tl.constexpr,
):
"""
- 256K vocab divided in 4 chunks
+ 256K vocab divided in 4 chunks
- |-65536-| |-65536-| |-65536-| |-65536-|
- |-------| |-------| |-------| |-------|
- |-------| |-------| |-------| |-------|
+ |-65536-| |-65536-| |-65536-| |-65536-|
+ |-------| |-------| |-------| |-------|
+ |-------| |-------| |-------| |-------|
- If y == 0: CE_i = 0
- If y == 1: CE_i = logsumexp - x
+ If y == 0: CE_i = 0
+ If y == 1: CE_i = logsumexp - x
- Notice we can do logsumexp for each chunk and then
- logsumexp[chunk_sum(logsumexp)] == logsumexp
+ Notice we can do logsumexp for each chunk and then
+ logsumexp[chunk_sum(logsumexp)] == logsumexp
- chunk_sum = log[chunk_sum(logsumexp)]
- = log[exp(logsumexp(a)) + ... + exp(logsumexp(z))]
- = log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])]
- = log[sum(exp(a)) + ... + sum(exp(z))]
- = logsumexp(x)
+ chunk_sum = log[chunk_sum(logsumexp)]
+ = log[exp(logsumexp(a)) + ... + exp(logsumexp(z))]
+ = log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])]
+ = log[sum(exp(a)) + ... + sum(exp(z))]
+ = logsumexp(x)
- This means we can perform a logsumexp for each chunk, then do a
- final logsumexp reduction!
+ This means we can perform a logsumexp for each chunk, then do a
+ final logsumexp reduction!
- Ie do: logsumexp(chunked_logsumexp) - x
+ Ie do: logsumexp(chunked_logsumexp) - x
"""
- row_idx = tl.program_id(0)
+ row_idx = tl.program_id(0)
chunk_idx = tl.program_id(1)
- logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
- loss_ptr += row_idx
+ logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
+ loss_ptr += row_idx
logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx
- labels_ptr += row_idx
+ labels_ptr += row_idx
- col_offsets = chunk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = col_offsets < VOCAB_SIZE
label_idx = tl.load(labels_ptr).to(tl.int32)
- logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(
- tl.float32
- )
+ logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
# Go logit scaling for Cohere: t * x
- if DO_LOGIT_SCALING:
- logits = LOGIT_SCALE * logits
+ if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
- if DO_SOFTCAPPING:
- logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
+ if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
c = tl.max(logits, 0)
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
@@ -182,62 +170,60 @@ def _chunked_cross_entropy_forward(
if label_idx != -100:
x = tl.load(logits_ptr + label_idx).to(tl.float32)
# Go logit scaling for Cohere: t * x
- if DO_LOGIT_SCALING:
- x = LOGIT_SCALE * x
+ if DO_LOGIT_SCALING: x = LOGIT_SCALE * x
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
- if DO_SOFTCAPPING:
- x = SOFTCAP * triton_tanh(x / SOFTCAP)
+ if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
loss = -1.0 * x
else:
loss = 0.0
tl.store(loss_ptr, loss)
+ pass
tl.store(logsumexp_ptr, logsumexp)
-
-
+pass
_chunked_cross_entropy_forward = triton.jit(_chunked_cross_entropy_forward)
_chunked_cross_entropy_forward = triton.heuristics(
{
- "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING"]),
+ "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
"DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
}
)(_chunked_cross_entropy_forward)
def _cross_entropy_backward(
- logits_ptr,
- logits_row_stride: tl.constexpr,
- dloss_ptr,
- dloss_row_stride: tl.constexpr,
- logsumexp_ptr,
- labels_ptr,
- VOCAB_SIZE: tl.constexpr,
- BLOCK_SIZE: tl.constexpr,
- DO_SOFTCAPPING: tl.constexpr,
- SOFTCAP: tl.constexpr,
- DO_LOGIT_SCALING: tl.constexpr,
- LOGIT_SCALE: tl.constexpr,
+ logits_ptr ,
+ logits_row_stride ,
+ dloss_ptr ,
+ dloss_row_stride ,
+ logsumexp_ptr ,
+ labels_ptr ,
+ VOCAB_SIZE : tl.constexpr,
+ BLOCK_SIZE : tl.constexpr,
+ DO_SOFTCAPPING : tl.constexpr,
+ SOFTCAP : tl.constexpr,
+ DO_LOGIT_SCALING : tl.constexpr,
+ LOGIT_SCALE : tl.constexpr,
):
"""
- CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
- dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
+ CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
+ dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
- From https://en.wikipedia.org/wiki/LogSumExp
- d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
+ From https://en.wikipedia.org/wiki/LogSumExp
+ d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
- dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
- dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
- dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
+ dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
+ dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
+ dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
- If y == 0: dC/dx = 0
- If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
- If y == 1 and x != label: dC/dx = exp[x - logsumexp]
+ If y == 0: dC/dx = 0
+ If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
+ If y == 1 and x != label: dC/dx = exp[x - logsumexp]
"""
- row_idx = tl.program_id(0)
+ row_idx = tl.program_id(0)
block_idx = tl.program_id(1)
logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
- dloss_ptr += row_idx * dloss_row_stride
- col_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ dloss_ptr += row_idx * dloss_row_stride
+ col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = col_offsets < VOCAB_SIZE
label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)
@@ -252,6 +238,7 @@ def _cross_entropy_backward(
if DO_LOGIT_SCALING:
# d/dx [s * x] = s
x = x * LOGIT_SCALE
+ pass
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
partial = x
@@ -259,166 +246,140 @@ def _cross_entropy_backward(
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
partial = triton_tanh(x / SOFTCAP)
x = SOFTCAP * partial
+ pass
logsumexp = tl.load(logsumexp_ptr + row_idx)
y = tl.exp(x - logsumexp)
y = tl.where(
col_offsets == label_idx,
- y - 1.0, # exp(x - logsumexp) - 1
- y, # exp(x - logsumexp)
+ y - 1.0, # exp(x - logsumexp) - 1
+ y, # exp(x - logsumexp)
)
if DO_LOGIT_SCALING:
# d/dx [s * x] = s
y = y * LOGIT_SCALE
+ pass
if DO_SOFTCAPPING:
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
- y = y * (1.0 - partial * partial)
+ y = y * (1.0 - partial*partial)
+ pass
# If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
-
-
+pass
_cross_entropy_backward = triton.jit(_cross_entropy_backward)
_cross_entropy_backward = triton.heuristics(
{
- "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING"]),
+ "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
"DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
}
)(_cross_entropy_backward)
-MAX_FUSED_SIZE = 65536 # 2**16
-
-
+MAX_FUSED_SIZE = 65536 # 2**16
class Fast_CrossEntropyLoss(torch.autograd.Function):
@staticmethod
- def forward(
- ctx, logits, labels, logit_softcapping: float = 0, logit_scaling: float = 0
- ):
- n_rows: int
- vocab_size: int
+ def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : float = 0):
+ n_rows : int
+ vocab_size : int
n_rows, vocab_size = logits.shape
device = logits.device
- labels = labels.to(device)
div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
- n_chunks: int = div + (mod != 0)
+ n_chunks : int = div + (mod != 0)
losses = torch.empty(n_rows, dtype = torch.float32, device = device)
- DO_SOFTCAPPING: bool = bool(logit_softcapping != 0)
- DO_LOGIT_SCALING: bool = bool(logit_scaling != 0)
+ DO_SOFTCAPPING : bool = bool(logit_softcapping != 0)
+ DO_LOGIT_SCALING : bool = bool(logit_scaling != 0)
- BLOCK_SIZE: int
- num_warps: int
+ BLOCK_SIZE : int
+ num_warps : int
if n_chunks == 1:
# For small vocabs <= 65336 like Llama, Mistral
BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
- if is_cdna():
- num_warps = num_warps // 2
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = device)
- with torch_gpu_device(device):
+ with torch_cuda_device(device):
_cross_entropy_forward[(n_rows,)](
- logits,
- logits.stride(0),
+ logits, logits.stride(0),
losses,
logsumexp,
labels,
- VOCAB_SIZE = vocab_size,
- BLOCK_SIZE = BLOCK_SIZE,
- DO_SOFTCAPPING = DO_SOFTCAPPING,
- SOFTCAP = logit_softcapping,
+ VOCAB_SIZE = vocab_size,
+ BLOCK_SIZE = BLOCK_SIZE,
+ DO_SOFTCAPPING = DO_SOFTCAPPING,
+ SOFTCAP = logit_softcapping,
DO_LOGIT_SCALING = DO_LOGIT_SCALING,
- LOGIT_SCALE = logit_scaling,
- num_warps = num_warps,
+ LOGIT_SCALE = logit_scaling,
+ num_warps = num_warps,
)
else:
# For large vocabs > 65336 like Gemma 256K
- logsumexp = torch.empty(
- (
- n_rows,
- n_chunks,
- ),
- dtype = torch.float32,
- device = device,
- )
+ logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = device)
- with torch_gpu_device(device):
- _chunked_cross_entropy_forward[
- (
- n_rows,
- n_chunks,
- )
- ](
- logits,
- logits.stride(0),
+ with torch_cuda_device(device):
+ _chunked_cross_entropy_forward[(n_rows, n_chunks,)](
+ logits, logits.stride(0),
losses,
logsumexp,
labels,
- VOCAB_SIZE = vocab_size,
- N_CHUNKS = n_chunks,
- BLOCK_SIZE = MAX_FUSED_SIZE,
- DO_SOFTCAPPING = DO_SOFTCAPPING,
- SOFTCAP = logit_softcapping,
+ VOCAB_SIZE = vocab_size,
+ N_CHUNKS = n_chunks,
+ BLOCK_SIZE = MAX_FUSED_SIZE,
+ DO_SOFTCAPPING = DO_SOFTCAPPING,
+ SOFTCAP = logit_softcapping,
DO_LOGIT_SCALING = DO_LOGIT_SCALING,
- LOGIT_SCALE = logit_scaling,
- num_warps = 32 if not is_cdna() else 16,
+ LOGIT_SCALE = logit_scaling,
+ num_warps = 32,
)
# logsumexp(chunked_logsumexp) - x
# Do the -x separately
- logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum
+ logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum
losses += logsumexp
- losses.masked_fill_(labels == -100, 0) # Don't forget to mask padding out!
+ losses.masked_fill_(labels == -100, 0) # Don't forget to mask padding out!
+ pass
ctx.save_for_backward(logits, logsumexp, labels)
- ctx.DO_SOFTCAPPING = DO_SOFTCAPPING
+ ctx.DO_SOFTCAPPING = DO_SOFTCAPPING
ctx.logit_softcapping = logit_softcapping
- ctx.DO_LOGIT_SCALING = DO_LOGIT_SCALING
- ctx.logit_scaling = logit_scaling
+ ctx.DO_LOGIT_SCALING = DO_LOGIT_SCALING
+ ctx.logit_scaling = logit_scaling
return losses
+ pass
+
@staticmethod
def backward(ctx, dlosses):
logits, logsumexp, labels = ctx.saved_tensors
- n_rows: int
- vocab_size: int
+ n_rows : int
+ vocab_size : int
n_rows, vocab_size = logits.shape
- BLOCK_SIZE: int = 4096
- div: int
- mod: int
+ BLOCK_SIZE : int = 4096
+ div : int
+ mod : int
div, mod = divmod(vocab_size, BLOCK_SIZE)
- n_blocks: int = div + (mod != 0)
+ n_blocks : int = div + (mod != 0)
- with torch_gpu_device(dlosses.device):
- _cross_entropy_backward[
- (
- n_rows,
- n_blocks,
- )
- ](
- logits,
- logits.stride(0),
- dlosses,
- dlosses.stride(0),
+ with torch_cuda_device(dlosses.device):
+ _cross_entropy_backward[(n_rows, n_blocks,)](
+ logits, logits.stride(0),
+ dlosses, dlosses.stride(0),
logsumexp,
labels,
- VOCAB_SIZE = vocab_size,
- BLOCK_SIZE = BLOCK_SIZE,
- DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
- SOFTCAP = ctx.logit_softcapping,
+ VOCAB_SIZE = vocab_size,
+ BLOCK_SIZE = BLOCK_SIZE,
+ DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
+ SOFTCAP = ctx.logit_softcapping,
DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING,
- LOGIT_SCALE = ctx.logit_scaling,
- num_warps = 8,
+ LOGIT_SCALE = ctx.logit_scaling,
+ num_warps = 8,
)
- return (
- logits,
- None,
- None,
- None,
- )
+ return logits, None, None, None,
+ pass
+pass
def fast_cross_entropy_loss(
@@ -436,28 +397,24 @@ def fast_cross_entropy_loss(
losses: float
"""
batch, seq_len, d = logits.shape
- assert labels.shape == (batch, seq_len)
+ assert(labels.shape == (batch, seq_len))
- device = logits.device
loss = Fast_CrossEntropyLoss.apply(
- logits.view(batch * seq_len, d),
+ logits.view(batch*seq_len, d),
labels.view(-1),
logit_softcapping,
logit_scaling,
)
if n_items is None:
n_items = torch.count_nonzero(labels != -100)
- if torch.is_tensor(n_items):
- n_items = n_items.to(device)
return loss.sum() / n_items
-
-
-if (Version(torch.__version__) < Version("2.4.0")) and not hasattr(
- fast_cross_entropy_loss, "__wrapped__"
-):
+pass
+if (Version(torch.__version__) < Version("2.4.0")) and \
+ not hasattr(fast_cross_entropy_loss, "__wrapped__"):
fast_cross_entropy_loss = torch._disable_dynamo(fast_cross_entropy_loss)
-
+pass
# Patch CE Losses in transformers
def patch_loss_functions(torch_compile = True):
_patch_loss_functions(fast_cross_entropy_loss, torch_compile = torch_compile)
+pass
diff --git a/unsloth/kernels/fast_lora.py b/unsloth/kernels/fast_lora.py
index f1c0e298d9..a4fb2a89b6 100644
--- a/unsloth/kernels/fast_lora.py
+++ b/unsloth/kernels/fast_lora.py
@@ -14,7 +14,6 @@
import torch
from .utils import (
- _maybe_fake_quantize_activations,
fast_dequantize,
QUANT_STATE,
get_lora_parameters,
@@ -63,95 +62,54 @@ class LoRA_MLP(torch.autograd.Function):
Don't forget to see our blog post for more details!
"""
-
@staticmethod
@torch_amp_custom_fwd
- def forward(
- ctx,
- X: torch.Tensor,
- gateW,
- gateW_quant,
- gateA,
- gateB,
- gateS,
- upW,
- upW_quant,
- upA,
- upB,
- upS,
- downW,
- downW_quant,
- downA,
- downB,
- downS,
- _forward_function,
- _backward_function,
- inplace = True,
- ):
+ def forward(ctx, X : torch.Tensor,
+ gateW, gateW_quant, gateA, gateB, gateS,
+ upW, upW_quant, upA, upB, upS,
+ downW, downW_quant, downA, downB, downS,
+ _forward_function, _backward_function,
+ inplace = True,):
dtype = X.dtype
e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS)
- g = matmul_lora(X, upW, upW_quant, upA, upB, upS)
+ g = matmul_lora(X, upW, upW_quant, upA, upB, upS)
h = _forward_function(e, g)
i = matmul_lora(h, downW, downW_quant, downA, downB, downS)
ctx.custom_saved_tensors = (
- gateW,
- gateW_quant,
- gateS,
- upW,
- upW_quant,
- upS,
- downW,
- downW_quant,
- downS,
+ gateW, gateW_quant, gateS,
+ upW, upW_quant, upS,
+ downW, downW_quant, downS,
_backward_function,
)
- ctx.save_for_backward(gateA, gateB, upA, upB, downA, downB, X, e, g)
+ ctx.save_for_backward(gateA, gateB, upA, upB, downA, downB,
+ X, e, g)
ctx.inplace = inplace
return i
+ pass
+
@staticmethod
@torch_amp_custom_bwd
- def backward(ctx, dY: torch.Tensor):
- (
- gateW,
- gateW_quant,
- gateS,
- upW,
- upW_quant,
- upS,
- downW,
- downW_quant,
- downS,
- _backward_function,
- ) = ctx.custom_saved_tensors
- gateA, gateB, upA, upB, downA, downB, X, e, g = ctx.saved_tensors
+ def backward(ctx, dY : torch.Tensor):
+ gateW, gateW_quant, gateS, upW, upW_quant, upS, downW, downW_quant, downS, \
+ _backward_function = ctx.custom_saved_tensors
+ gateA, gateB, upA, upB, downA, downB, \
+ X, e, g = ctx.saved_tensors
batch, seq_len, hd = X.shape
dY = dY.view(-1, dY.shape[-1])
- X = X.view(-1, X.shape[-1])
- e = e.view(-1, e.shape[-1])
- g = g.view(-1, g.shape[-1])
+ X = X .view(-1, X .shape[-1])
+ e = e .view(-1, e .shape[-1])
+ g = g .view(-1, g .shape[-1])
dtype = X.dtype
- gateA, gateB, upA, upB, downA, downB = (
- gateA.to(dtype),
- gateB.to(dtype),
- upA.to(dtype),
- upB.to(dtype),
- downA.to(dtype),
- downB.to(dtype),
- )
+ gateA, gateB, upA, upB, downA, downB = \
+ gateA.to(dtype), gateB.to(dtype), upA.to(dtype), upB.to(dtype), downA.to(dtype), downB.to(dtype)
- gateA, gateB, upA, upB, downA, downB = (
- gateA.t(),
- gateB.t(),
- upA.t(),
- upB.t(),
- downA.t(),
- downB.t(),
- )
+ gateA, gateB, upA, upB, downA, downB = \
+ gateA.t(), gateB.t(), upA.t(), upB.t(), downA.t(), downB.t()
DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS)
DW, e, g = _backward_function(DW, e, g)
@@ -161,8 +119,8 @@ def backward(ctx, dY: torch.Tensor):
d_downB = torch.empty_like(downB)
d_gateA = torch.empty_like(gateA)
d_gateB = torch.empty_like(gateB)
- d_upA = torch.empty_like(upA)
- d_upB = torch.empty_like(upB)
+ d_upA = torch.empty_like(upA)
+ d_upB = torch.empty_like(upB)
# Down projection LoRA weights
# d_downA = h.t() @ (dY @ downB.t())
@@ -206,122 +164,57 @@ def backward(ctx, dY: torch.Tensor):
# gateW, gateW_quant, gateA, gateB, gateS,
# upW, upW_quant, upA, upB, upS,
# downW, downW_quant, downA, downB, downS,
- return (
- dX.view(batch, seq_len, hd),
- None,
- None,
- d_gateA.t(),
- d_gateB.t(),
- None,
- None,
- None,
- d_upA.t(),
- d_upB.t(),
- None,
- None,
- None,
- d_downA.t(),
- d_downB.t(),
- None,
- None,
- None,
- None,
- ) # _backward and _forward and inplace
+ return dX.view(batch, seq_len, hd), \
+ None, None, d_gateA.t(), d_gateB.t(), None, \
+ None, None, d_upA.t(), d_upB.t(), None, \
+ None, None, d_downA.t(), d_downB.t(), None, \
+ None, None, None, # _backward and _forward and inplace
+ pass
+pass
from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
-
-
def apply_lora_mlp_swiglu(self, X, inplace = True):
- X = _maybe_fake_quantize_activations(X, self.gate_proj)
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
- upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
+ upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
- out = LoRA_MLP.apply(
- X,
- gateW,
- gateW_quant,
- gateA,
- gateB,
- gateS,
- upW,
- upW_quant,
- upA,
- upB,
- upS,
- downW,
- downW_quant,
- downA,
- downB,
- downS,
- swiglu_fg_kernel,
- swiglu_DWf_DW_dfg_kernel,
- inplace,
- )
+ out = LoRA_MLP.apply(X,
+ gateW, gateW_quant, gateA, gateB, gateS,
+ upW, upW_quant, upA, upB, upS,
+ downW, downW_quant, downA, downB, downS,
+ swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel,
+ inplace,)
return out
+pass
from .geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel
-
-
def apply_lora_mlp_geglu_exact(self, X, inplace = True):
- X = _maybe_fake_quantize_activations(X, self.gate_proj)
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
- upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
+ upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
- out = LoRA_MLP.apply(
- X,
- gateW,
- gateW_quant,
- gateA,
- gateB,
- gateS,
- upW,
- upW_quant,
- upA,
- upB,
- upS,
- downW,
- downW_quant,
- downA,
- downB,
- downS,
- geglu_exact_forward_kernel,
- geglu_exact_backward_kernel,
- inplace,
- )
+ out = LoRA_MLP.apply(X,
+ gateW, gateW_quant, gateA, gateB, gateS,
+ upW, upW_quant, upA, upB, upS,
+ downW, downW_quant, downA, downB, downS,
+ geglu_exact_forward_kernel, geglu_exact_backward_kernel,
+ inplace,)
return out
+pass
from .geglu import geglu_approx_forward_kernel, geglu_approx_backward_kernel
-
-
def apply_lora_mlp_geglu_approx(self, X):
- X = _maybe_fake_quantize_activations(X, self.gate_proj)
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
- upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
+ upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
- out = LoRA_MLP.apply(
- X,
- gateW,
- gateW_quant,
- gateA,
- gateB,
- gateS,
- upW,
- upW_quant,
- upA,
- upB,
- upS,
- downW,
- downW_quant,
- downA,
- downB,
- downS,
- geglu_approx_forward_kernel,
- geglu_approx_backward_kernel,
- )
+ out = LoRA_MLP.apply(X,
+ gateW, gateW_quant, gateA, gateB, gateS,
+ upW, upW_quant, upA, upB, upS,
+ downW, downW_quant, downA, downB, downS,
+ geglu_approx_forward_kernel, geglu_approx_backward_kernel,)
return out
+pass
class LoRA_QKV(torch.autograd.Function):
@@ -354,102 +247,48 @@ class LoRA_QKV(torch.autograd.Function):
dC/dAv = X.T @ D(Wv) @ B.T
dC/dBv = A.T @ X.T @ D(Wv)
"""
-
@staticmethod
@torch_amp_custom_fwd
- def forward(
- ctx,
- X: torch.Tensor,
- QW,
- QW_quant,
- QA,
- QB,
- QS,
- KW,
- KW_quant,
- KA,
- KB,
- KS,
- VW,
- VW_quant,
- VA,
- VB,
- VS,
- inplace = True,
- ):
+ def forward(ctx, X : torch.Tensor,
+ QW, QW_quant, QA, QB, QS,
+ KW, KW_quant, KA, KB, KS,
+ VW, VW_quant, VA, VB, VS,
+ inplace = True):
dtype = X.dtype
- # bitsandbytes 8-bit matmul expects 2D inputs.
- # TorchInductor/AOTAutograd fails on 3D tensors during backward,
- # so we explicitly flatten the sequence dimension.
- orig_shape = X.shape
- X_for_matmul = X
- if X.dim() == 3:
- X_for_matmul = X.view(-1, X.shape[-1])
- Q = matmul_lora(X_for_matmul, QW, QW_quant, QA, QB, QS)
- K = matmul_lora(X_for_matmul, KW, KW_quant, KA, KB, KS)
- V = matmul_lora(X_for_matmul, VW, VW_quant, VA, VB, VS)
-
- # Restore original shape after matmul
- if len(orig_shape) == 3:
- Q = Q.view(orig_shape[0], orig_shape[1], -1)
- K = K.view(orig_shape[0], orig_shape[1], -1)
- V = V.view(orig_shape[0], orig_shape[1], -1)
+ Q = matmul_lora(X, QW, QW_quant, QA, QB, QS)
+ K = matmul_lora(X, KW, KW_quant, KA, KB, KS)
+ V = matmul_lora(X, VW, VW_quant, VA, VB, VS)
ctx.custom_saved_tensors = (
- QW,
- QW_quant,
- QS,
- KW,
- KW_quant,
- KS,
- VW,
- VW_quant,
- VS,
- )
- ctx.save_for_backward(
- X,
- QA,
- QB,
- KA,
- KB,
- VA,
- VB,
+ QW, QW_quant, QS,
+ KW, KW_quant, KS,
+ VW, VW_quant, VS,
)
+ ctx.save_for_backward(X, QA, QB, KA, KB, VA, VB,)
ctx.inplace = inplace
return Q, K, V
+ pass
@staticmethod
@torch_amp_custom_bwd
def backward(ctx, dQ, dK, dV):
- QW, QW_quant, QS, KW, KW_quant, KS, VW, VW_quant, VS = ctx.custom_saved_tensors
- (
- X,
- QA,
- QB,
- KA,
- KB,
- VA,
- VB,
- ) = ctx.saved_tensors
+ QW, QW_quant, QS, KW, KW_quant, KS, VW, VW_quant, VS = \
+ ctx.custom_saved_tensors
+ X, QA, QB, KA, KB, VA, VB, = ctx.saved_tensors
batch, seq_len, hd = X.shape
dQ = dQ.view(-1, dQ.shape[-1])
- dK = dK.reshape(-1, dK.shape[-1]) # view doesn't work on K.T
+ dK = dK.reshape(-1, dK.shape[-1]) # view doesn't work on K.T
dV = dV.view(-1, dV.shape[-1])
- X = X.view(-1, X.shape[-1])
+ X = X .view(-1, X .shape[-1])
dtype = X.dtype
- QA, QB, KA, KB, VA, VB = (
- QA.to(dtype),
- QB.to(dtype),
- KA.to(dtype),
- KB.to(dtype),
- VA.to(dtype),
- VB.to(dtype),
- )
+ QA, QB, KA, KB, VA, VB = \
+ QA.to(dtype), QB.to(dtype), KA.to(dtype), KB.to(dtype), VA.to(dtype), VB.to(dtype)
- QA, QB, KA, KB, VA, VB = QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t()
+ QA, QB, KA, KB, VA, VB = \
+ QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t()
### Weight projection LoRA weights
# See our blogpost for more details.
@@ -511,52 +350,27 @@ def backward(ctx, dQ, dK, dV):
# QW, QW_quant, QA, QB, QS,
# KW, KW_quant, KA, KB, KS,
# VW, VW_quant, VA, VB, VS,
- return (
- dX.view(batch, seq_len, hd),
- None,
- None,
- d_QA.t(),
- d_QB.t(),
- None,
- None,
- None,
- d_KA.t(),
- d_KB.t(),
+ return dX.view(batch, seq_len, hd), \
+ None, None, d_QA.t(), d_QB.t(), None, \
+ None, None, d_KA.t(), d_KB.t(), None, \
+ None, None, d_VA.t(), d_VB.t(), None, \
None,
- None,
- None,
- d_VA.t(),
- d_VB.t(),
- None,
- None,
- )
+ pass
+pass
def apply_lora_qkv(self, X, inplace = True):
- X = _maybe_fake_quantize_activations(X, self.q_proj)
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
- Q, K, V = LoRA_QKV.apply(
- X,
- QW,
- QW_quant,
- QA,
- QB,
- QS,
- KW,
- KW_quant,
- KA,
- KB,
- KS,
- VW,
- VW_quant,
- VA,
- VB,
- VS,
+ Q, K, V = LoRA_QKV.apply(X,
+ QW, QW_quant, QA, QB, QS,
+ KW, KW_quant, KA, KB, KS,
+ VW, VW_quant, VA, VB, VS,
inplace,
)
return Q, K, V
+pass
class LoRA_W(torch.autograd.Function):
@@ -586,29 +400,26 @@ class LoRA_W(torch.autograd.Function):
dC/dAv = X.T @ D(Wv) @ B.T
dC/dBv = A.T @ X.T @ D(Wv)
"""
-
@staticmethod
@torch_amp_custom_fwd
- def forward(ctx, X: torch.Tensor, W, W_quant, A, B, S):
+ def forward(ctx, X : torch.Tensor,
+ W, W_quant, A, B, S):
dtype = X.dtype
XW = matmul_lora(X, W, W_quant, A, B, S)
- ctx.custom_saved_tensors = (
- W,
- W_quant,
- S,
- )
+ ctx.custom_saved_tensors = (W, W_quant, S,)
ctx.save_for_backward(A, B, X)
return XW
+ pass
@staticmethod
@torch_amp_custom_bwd
- def backward(ctx, dY: torch.Tensor):
+ def backward(ctx, dY : torch.Tensor):
W, W_quant, S = ctx.custom_saved_tensors
A, B, X = ctx.saved_tensors
batch, seq_len, hd = X.shape
- dY = dY.reshape(-1, dY.shape[-1]) # Must be reshape
- X = X.reshape(-1, X.shape[-1]) # Must be reshape
+ dY = dY.reshape(-1, dY.shape[-1]) # Must be reshape
+ X = X .reshape(-1, X .shape[-1]) # Must be reshape
dtype = X.dtype
A, B = A.to(dtype), B.to(dtype)
@@ -635,19 +446,20 @@ def backward(ctx, dY: torch.Tensor):
dX.addmm_(dY @ B.t(), A.t(), alpha = S)
# W, W_quant, A, B, S
- return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None
+ return dX.view(batch, seq_len, hd), \
+ None, None, d_A.t(), d_B.t(), None
+ pass
+pass
def apply_lora_o(self, X):
- X = _maybe_fake_quantize_activations(X, self.o_proj)
OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
return O
+pass
IDENTITY_DROPOUT = torch.nn.Identity
-
-
@torch._disable_dynamo
def fast_lora_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError(
@@ -661,23 +473,17 @@ def fast_lora_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif adapter_names is not None:
- result = self._mixed_batch_forward(
- x, *args, adapter_names = adapter_names, **kwargs
- )
+ result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
# Fastpath
if len(self.active_adapters) == 1:
active_adapter = self.active_adapters[0]
- if active_adapter not in self.lora_A.keys():
- return self.base_layer(x, *args, **kwargs)
+ if active_adapter not in self.lora_A.keys(): return self.base_layer(x, *args, **kwargs)
dropout = self.lora_dropout[active_adapter]
- if (
- isinstance(dropout, IDENTITY_DROPOUT)
- and not self.use_dora[active_adapter]
- ):
+ if isinstance(dropout, IDENTITY_DROPOUT) and not self.use_dora[active_adapter]:
lora_A = self.lora_A[active_adapter].weight
lora_B = self.lora_B[active_adapter].weight
scaling = self.scaling[active_adapter]
@@ -718,13 +524,14 @@ def fast_lora_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
result = result + self.lora_magnitude_vector[active_adapter](
x,
- lora_A = lora_A,
- lora_B = lora_B,
- scaling = scaling,
- base_layer = self.get_base_layer(),
- base_result = base_result,
+ lora_A=lora_A,
+ lora_B=lora_B,
+ scaling=scaling,
+ base_layer=self.get_base_layer(),
+ base_result=base_result,
)
if requires_conversion:
result = result.to(expected_dtype)
return result
+pass
diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py
index b94ff56dec..6f82394228 100644
--- a/unsloth/kernels/flex_attention.py
+++ b/unsloth/kernels/flex_attention.py
@@ -18,11 +18,11 @@
import os
torch_compile_options = {
- "epilogue_fusion": True,
- "max_autotune": True,
- "shape_padding": True,
- "trace.enabled": os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1",
- "triton.cudagraphs": False,
+ "epilogue_fusion" : True,
+ "max_autotune" : True,
+ "shape_padding" : True,
+ "trace.enabled" : os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1",
+ "triton.cudagraphs" : False,
}
# Flex Attention supported from torch 2.5 onwards only
@@ -31,24 +31,23 @@
flex_attention as _flex_attention,
create_block_mask as _create_block_mask,
)
-
- _flex_attention = torch.compile(
- _flex_attention, dynamic = True, options = torch_compile_options
- )
+ _flex_attention = torch.compile(_flex_attention, dynamic = True, options = torch_compile_options)
HAS_FLEX_ATTENTION = False
except:
HAS_FLEX_ATTENTION = False
+pass
if not HAS_FLEX_ATTENTION:
+
# Logit softcapping
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
- n_heads = self.config.num_attention_heads
- head_dim = self.head_dim
+ n_heads = self.config.num_attention_heads
+ head_dim = self.head_dim
n_kv_heads = self.config.num_key_value_heads
- n_groups = self.num_key_value_groups
-
+ n_groups = self.num_key_value_groups
+
# Grouped query attention
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
@@ -62,17 +61,18 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
s = self.config.query_pre_attn_scalar
t = self.config.attn_logit_softcapping
- Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
+ Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
A = torch.matmul(Q, K.transpose(2, 3))
- A = t * torch.tanh(A / t) # Logit softcapping
+ A = t * torch.tanh(A / t) # Logit softcapping
A += causal_mask[:q_len, :q_len]
# Much slower in torch compile!
# A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
A = torch.matmul(A, V)
A = A.transpose(1, 2).contiguous()
- A = A.reshape(bsz, q_len, n_heads * head_dim)
+ A = A.reshape(bsz, q_len, n_heads*head_dim)
return A
+ pass
create_flex_attention_causal_mask = None
create_flex_attention_sliding_window_mask = None
@@ -85,78 +85,73 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
def generate_tanh_softcap(t):
def tanh_softcap(x, b, h, q_idx, kv_idx):
return t * torch.tanh(x / t)
-
return tanh_softcap
-
+ pass
def causal_masker(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
+ pass
@functools.lru_cache
def sliding_window_masker(size = 4096):
def sliding_window(b, h, q_idx, kv_idx):
causal_mask = q_idx >= kv_idx
- window_mask = q_idx - kv_idx <= size
+ window_mask = q_idx - kv_idx <= size
return causal_mask & window_mask
-
return sliding_window
+ pass
@functools.lru_cache
def create_block_mask(mask, n = 128):
return _create_block_mask(
- mask,
- 1,
- 1,
- n,
- n,
+ mask, 1, 1, n, n,
BLOCK_SIZE = 128,
_compile = True,
)
+ pass
def create_flex_attention_causal_mask(max_seq_length = 8192):
causal_mask = create_block_mask(causal_masker, max_seq_length)
return causal_mask
+ pass
- def create_flex_attention_sliding_window_mask(
- max_seq_length = 8192, sliding_window = 4096
- ):
+ def create_flex_attention_sliding_window_mask(max_seq_length = 8192, sliding_window = 4096):
sliding_masker = sliding_window_masker(sliding_window)
causal_mask = create_block_mask(sliding_masker, max_seq_length)
return causal_mask
+ pass
@functools.lru_cache
def flex_attention(s, t):
scale = 1.0 / math.sqrt(s)
score_mod = generate_tanh_softcap(t)
return functools.partial(
- _flex_attention,
- score_mod = score_mod,
- scale = scale,
- enable_gqa = True,
+ _flex_attention, score_mod = score_mod, scale = scale, enable_gqa = True,
)
-
+ pass
+
def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
- n_heads = self.config.num_attention_heads
- head_dim = self.head_dim
+ n_heads = self.config.num_attention_heads
+ head_dim = self.head_dim
s = self.config.query_pre_attn_scalar
t = self.config.attn_logit_softcapping
fx = flex_attention(s, t)
A = fx(query = Q, key = K, value = V, block_mask = causal_mask)
A = A.transpose(1, 2).contiguous()
- A = A.reshape(bsz, q_len, n_heads * head_dim)
+ A = A.reshape(bsz, q_len, n_heads*head_dim)
return A
+ pass
+pass
torch_matmul = torch.matmul
-torch_tanh = torch.tanh
+torch_tanh = torch.tanh
torch_nn_functional_softmax = torch.nn.functional.softmax
-
-
def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
- n_heads = self.config.num_attention_heads
- head_dim = self.head_dim
+ n_heads = self.config.num_attention_heads
+ head_dim = self.head_dim
n_kv_heads = self.config.num_key_value_heads
- n_groups = self.num_key_value_groups
-
+ n_groups = self.num_key_value_groups
+
# Grouped query attention
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
@@ -170,18 +165,17 @@ def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len)
s = self.config.query_pre_attn_scalar
t = self.config.attn_logit_softcapping
- Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
+ Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
A = torch_matmul(Q, K.transpose(2, 3))
# Logit softcapping
- A /= t
- torch_tanh(A, out = A)
- A *= t
+ A /= t; torch_tanh(A, out = A); A *= t;
A += causal_mask[:q_len, :q_len]
# Much slower in torch compile!
# A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
A = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
A = torch_matmul(A, V)
A = A.transpose(1, 2).contiguous()
- A = A.reshape(bsz, q_len, n_heads * head_dim)
+ A = A.reshape(bsz, q_len, n_heads*head_dim)
return A
+pass
diff --git a/unsloth/kernels/fp8.py b/unsloth/kernels/fp8.py
deleted file mode 100644
index a57f4ffb64..0000000000
--- a/unsloth/kernels/fp8.py
+++ /dev/null
@@ -1,624 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import os
-import torch
-import torch.nn as nn
-import triton
-import triton.language as tl
-from torch.nn import functional as F
-import math
-from unsloth_zoo.utils import Version
-from unsloth_zoo.log import logger
-from unsloth_zoo.temporary_patches.common import torch_compile
-
-torch_matmul = torch.matmul
-
-try:
- from transformers.integrations.finegrained_fp8 import FP8Linear
-except:
- FP8Linear = None
- logger.info(
- "Unsloth: FP8 models need importing FP8Linear from `transformers.integrations.finegrained_fp8` but we don't see it."
- )
-
-try:
- from transformers.integrations.fbgemm_fp8 import FbgemmFp8Linear
-except:
- FbgemmFp8Linear = None
- logger.info(
- "Unsloth: FP8 models need importing FbgemmFP8Linear from `transformers.integrations.fbgemm_fp8` but we don't see it."
- )
-
-try:
- from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
- triton_quantize_fp8_block,
- )
-except:
- triton_quantize_fp8_block = None
- logger.info(
- "Unsloth: Could not find fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm.triton_quantize_fp8_block"
- )
-
-try:
- from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import (
- blockwise_fp8_gemm as torchao_blockwise_gemm,
- )
-except:
- torchao_blockwise_gemm = None
- logger.info(
- "Unsloth: Could not find torchao.prototype.blockwise_fp8_inference.blockwise_quantization.blockwise_fp8_gemm"
- )
-
-
-@triton.jit
-def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
- pid_m = tl.program_id(axis = 0)
- pid_n = tl.program_id(axis = 1)
- n = tl.cdiv(N, BLOCK_SIZE)
- offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
- offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
- offs = offs_m[:, None] * N + offs_n[None, :]
- mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
- x = tl.load(x_ptr + offs, mask = mask).to(tl.float32)
- s = tl.load(s_ptr + pid_m * n + pid_n)
- y = x * s
- tl.store(y_ptr + offs, y, mask = mask)
-
-
-def weight_dequant_block(
- x: torch.Tensor, s: torch.Tensor, block_size: int = 128, dtype = torch.bfloat16
-) -> torch.Tensor:
- if not x.is_contiguous():
- x = x.contiguous()
- if not s.is_contiguous():
- s = s.contiguous()
- assert x.dim() == 2 and s.dim() == 2
- M, N = x.size()
- y = torch.empty_like(x, dtype = dtype)
- grid = lambda meta: (
- triton.cdiv(M, meta["BLOCK_SIZE"]),
- triton.cdiv(N, meta["BLOCK_SIZE"]),
- )
- weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE = block_size)
- return y
-
-
-def weight_dequant(x: torch.Tensor, s: torch.Tensor, dtype = torch.bfloat16):
- # Per-tensor scale: single value for entire weight matrix
- if s.numel() == 1:
- return x.to(dtype) * s.view(1, 1).to(dtype)
- # Row quantized weight: scale shape is (m, 1) or (n, 1)
- elif s.ndim == 2 and s.shape[1] == 1:
- if x.shape[0] == s.shape[0]:
- y = x.to(dtype) * s.to(dtype)
- elif x.shape[1] == s.shape[0]:
- # sometimes, this is called with the transpose of the weight. Adjust for that.
- y = x.t().to(dtype) * s.to(dtype)
- y = y.t()
- else:
- raise ValueError(f"Incompatible shapes {x.shape = }, {s.shape = }")
- return y
- # Block quantized weight: scale shape is (ceil(m/block_m), ceil(n/block_n))
- else:
- return weight_dequant_block(x, s, dtype = dtype)
-
-
-# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
-@triton.jit
-def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
- pid = tl.program_id(axis = 0)
- offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
- x = tl.load(x_ptr + offs).to(tl.float32)
- s = tl.max(tl.abs(x)) / 448.0
- # For a row of all zeros, lets return zeros as is
- # for LoRA, there are cases where dY has 0 in it and we should not let it be NaN
- # this is a deviation from the original implementation.
- s = 1.0 if s == 0 else s
- y = x / s
- y = y.to(y_ptr.dtype.element_ty)
- tl.store(y_ptr + offs, y)
- tl.store(s_ptr + pid, s)
-
-
-def act_quant(
- x: torch.Tensor, block_size: int = 128
-) -> tuple[torch.Tensor, torch.Tensor]:
- if not x.is_contiguous():
- x = x.contiguous()
- assert x.shape[-1] % block_size == 0
- y = torch.empty_like(x, dtype = torch.float8_e4m3fn)
- s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype = torch.float32)
-
- def grid(meta):
- return (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
-
- act_quant_kernel[grid](x, y, s, BLOCK_SIZE = block_size)
- return y, s
-
-
-# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py
-@triton.jit
-def _w8a8_block_fp8_matmul(
- # Pointers to inputs and output
- A,
- B,
- C,
- As,
- Bs,
- # Shape for matmul
- M,
- N,
- K,
- # Block size for block-wise quantization
- group_n,
- group_k,
- # Stride for inputs and output
- stride_am,
- stride_ak,
- stride_bk,
- stride_bn,
- stride_cm,
- stride_cn,
- stride_As_m,
- stride_As_k,
- stride_Bs_k,
- stride_Bs_n,
- # Meta-parameters
- BLOCK_SIZE_M: tl.constexpr,
- BLOCK_SIZE_N: tl.constexpr,
- BLOCK_SIZE_K: tl.constexpr,
- GROUP_SIZE_M: tl.constexpr,
-):
- """Triton-accelerated function used to perform linear operations (dot
- product) on input tensors `A` and `B` with block-wise quantization, and
- store the result in output tensor `C`.
- """
-
- pid = tl.program_id(axis = 0)
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
- num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
- group_id = pid // num_pid_in_group
- first_pid_m = group_id * GROUP_SIZE_M
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
- pid_m = first_pid_m + (pid % group_size_m)
- pid_n = (pid % num_pid_in_group) // group_size_m
-
- offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
- offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
- offs_k = tl.arange(0, BLOCK_SIZE_K)
- a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
- b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
-
- As_ptrs = As + offs_am * stride_As_m
- offs_bsn = offs_bn // group_n
- Bs_ptrs = Bs + offs_bsn * stride_Bs_n
-
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype = tl.float32)
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
- a = tl.load(a_ptrs, mask = offs_k[None, :] < K - k * BLOCK_SIZE_K, other = 0.0)
- b = tl.load(b_ptrs, mask = offs_k[:, None] < K - k * BLOCK_SIZE_K, other = 0.0)
-
- k_start = k * BLOCK_SIZE_K
- offs_ks = k_start // group_k
- a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
- b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
-
- accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
- a_ptrs += BLOCK_SIZE_K * stride_ak
- b_ptrs += BLOCK_SIZE_K * stride_bk
-
- if C.dtype.element_ty == tl.bfloat16:
- c = accumulator.to(tl.bfloat16)
- elif C.dtype.element_ty == tl.float16:
- c = accumulator.to(tl.float16)
- else:
- c = accumulator.to(tl.float32)
-
- offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
- offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
- c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
- c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
- tl.store(c_ptrs, c, mask = c_mask)
-
-
-def w8a8_block_fp8_matmul_triton(
- A: torch.Tensor,
- B: torch.Tensor,
- As: torch.Tensor,
- Bs: torch.Tensor,
- block_size: list[int],
- output_dtype: torch.dtype = torch.float32,
-) -> torch.Tensor:
- """Block-wise FP8 matmul."""
- if block_size is None:
- block_n, block_k = 128, 128
- else:
- assert len(block_size) == 2
- block_n, block_k = block_size[0], block_size[1]
-
- N, K = B.shape
- assert A.shape[-1] == B.shape[-1]
- assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
- assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
- assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
- assert triton.cdiv(N, block_n) == Bs.shape[0]
- assert triton.cdiv(K, block_k) == Bs.shape[1]
-
- M = A.numel() // A.shape[-1]
- C_shape = A.shape[:-1] + (N,)
- C = A.new_empty(C_shape, dtype = output_dtype)
-
- BLOCK_SIZE_M = 128
- if M < BLOCK_SIZE_M:
- BLOCK_SIZE_M = max(triton.next_power_of_2(M), 16)
- BLOCK_SIZE_K, BLOCK_SIZE_N = block_k, block_n
-
- def grid(META):
- return (
- triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
- )
-
- _w8a8_block_fp8_matmul[grid](
- A,
- B,
- C,
- As,
- Bs,
- M,
- N,
- K,
- block_n,
- block_k,
- A.stride(-2),
- A.stride(-1),
- B.stride(1),
- B.stride(0),
- C.stride(-2),
- C.stride(-1),
- As.stride(-2),
- As.stride(-1),
- Bs.stride(1),
- Bs.stride(0),
- BLOCK_SIZE_M = BLOCK_SIZE_M,
- BLOCK_SIZE_N = BLOCK_SIZE_N,
- BLOCK_SIZE_K = BLOCK_SIZE_K,
- GROUP_SIZE_M = 8,
- )
- return C
-
-
-def torchao_block_matmul(
- act_q: torch.Tensor,
- weight_q: torch.Tensor,
- act_scale: torch.Tensor,
- weight_scale: torch.Tensor,
- block_size: tuple[int, int],
- output_dtype: torch.dtype = torch.bfloat16,
-):
- out = torchao_blockwise_gemm(
- act_q.contiguous(),
- act_scale.contiguous(),
- weight_q.contiguous(),
- weight_scale.contiguous(),
- block_size = block_size[1],
- )
- return out.to(output_dtype)
-
-
-# Note that older versions of fbgemm (<=1.3.0) cause numerical imprecisions resulting in NaNs especially when X has high values in it.
-# So our preference order is fbgemm (>=1.4.0) > torchao > triton. All of these have similar outputs/losses. Never use fbgemm (<=1.3.0) for block quantized FP8 matmul.
-# This torchao FP8 matmul seems to be ~3x faster than the w8a8_block_fp8_matmul_triton. Though torchao is 15-30% slower than fbgemm implementation (on H100 GPUs).
-fp8_block_matmul = (
- torchao_block_matmul
- if torchao_blockwise_gemm is not None
- else w8a8_block_fp8_matmul_triton
-)
-
-
-class FP8BlockQuantLinear(torch.autograd.Function):
- @staticmethod
- def forward(ctx, X, weight, weight_scale):
- m, n = weight.shape
-
- # Save original scale for backward (before any transformation)
- original_weight_scale = weight_scale
-
- # Handle per-tensor quantization: expand scalar to block scale shape
- if weight_scale.numel() == 1:
- block_size = [128, 128]
- # Expand scalar to (ceil(m/128), ceil(n/128)) - same value for all blocks
- num_blocks_m = triton.cdiv(m, block_size[0])
- num_blocks_n = triton.cdiv(n, block_size[1])
- weight_scale = weight_scale.expand(num_blocks_m, num_blocks_n).contiguous()
- else:
- # Block quantization path
- p, q = weight_scale.shape
- block_size = getattr(weight, "block_size", None) or getattr(
- weight_scale, "block_size", [128, 128]
- )
- assert block_size is not None, "block_size is not set"
- if triton.cdiv(m, block_size[0]) != p or triton.cdiv(n, block_size[1]) != q:
- if (
- triton.cdiv(m, block_size[0]) == q
- and triton.cdiv(n, block_size[1]) == p
- ):
- weight_scale = weight_scale.T
- original_weight_scale = weight_scale # Update for transposed case
- else:
- raise ValueError(
- f"Weight shape {weight.shape} and scales shape {weight_scale.shape} is not compatible with block size {block_size}"
- )
-
- if not weight.is_contiguous():
- weight = weight.contiguous()
-
- # Quantize input and run FP8 matmul
- qinput, scale = act_quant(X, block_size[1])
- output = fp8_block_matmul(
- qinput,
- weight,
- scale,
- weight_scale,
- block_size,
- output_dtype = X.dtype,
- )
- ctx.weight = weight
- ctx.weight_scale = original_weight_scale # Save original for backward
- return output.to(X.dtype)
-
- @staticmethod
- def backward(ctx, grad_output):
- W_deq = weight_dequant(ctx.weight, ctx.weight_scale)
- grad_X = torch_matmul(grad_output, W_deq)
- del W_deq
- return grad_X, None, None
-
-
-@torch_compile
-def fp8_torch_block_quant_forward(X, weight, weight_scale):
- return FP8BlockQuantLinear.apply(X, weight, weight_scale)
-
-
-class FbgemmFp8Linear_matmul(torch.autograd.Function):
- @staticmethod
- def forward(ctx, x, weight, weight_scale, bias = None):
- if weight.shape[0] == weight_scale.shape[0] and (
- weight.shape[0] % 8 == 0 and weight.shape[1] % 8 == 0
- ):
- # Edit: The kernel seems to expect that the weight has dimensions divisible by 8. Otherwise it throws `RuntimeError: cutlass cannot implement`
- # One thing we can do is to pad the weight and weight scale to multiple of 8 and perform a F8F8BF16 operation.
- # I tried benchmarking that for speed but observed that dequantize+bf16 matmul is significantly faster than padding+f8f8bf16 matmul. So we'll go that route.
- # So essentially, f8f8bf16_rowise only happens when shapes are proper (no transposes) and divisible by 8.
-
- # quantize_fp8_per_row will squash the leading dimensions, so save the desired shape here
- output_shape = (*x.shape[:-1], -1)
- # x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
- # https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
- x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
- x.view(-1, x.shape[-1]).contiguous(),
- scale_ub = getattr(weight, "input_scale_ub", None),
- )
- # moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works
- # x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)
-
- # The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight
- weight_scale_float32 = weight_scale.to(torch.float32)
-
- if not weight.is_contiguous():
- weight = weight.contiguous()
- if not weight_scale.is_contiguous():
- weight_scale = weight_scale.contiguous()
-
- output = torch.ops.fbgemm.f8f8bf16_rowwise(
- x_quantized, weight, x_scale, weight_scale_float32, use_fast_accum = True
- )
- output = output + bias if bias is not None else output
- # Hacky for now, we have the output to the device of x
- output = output.to(x.device, x.dtype)
- output = output.reshape(output_shape)
- del x_quantized, x_scale
- elif (
- weight.shape[0] != weight_scale.shape[0]
- and weight.shape[1] == weight_scale.shape[0]
- ) or (weight.shape[0] // 8 != 0 or weight.shape[1] // 8 != 0):
- # Either the weight/scale is transposed or its shape is not divisible by 8. Both cases, dequantizing is the preferred way.
- # The transpose case is generally noticed in backward pass when we do dY@W instead of @W.T as we do for forward.
- # The shape case, I noticed to happen in MLP of Qwen 2.5 VL 7B where the gate proj is of shape (3420, 1280) and 3420/8=427.5
-
- W_deq = weight_dequant(weight, weight_scale).T
- output = torch_matmul(x, W_deq)
- del W_deq
- else:
- raise ValueError(
- f"Shapes are incompatible {weight.shape = }, {weight_scale.shape = }, {x.shape = }"
- )
-
- ctx.weight = weight
- ctx.weight_scale = weight_scale
- return output
-
- @staticmethod
- def backward(ctx, grad_output):
- W_deq = weight_dequant(ctx.weight, ctx.weight_scale)
- grad_X = torch_matmul(grad_output, W_deq)
- del W_deq
- return grad_X, None, None, None, None
-
-
-@torch_compile
-def fbgemm_fp8_linear(X, weight, weight_scale, bias = None):
- return FbgemmFp8Linear_matmul.apply(X, weight, weight_scale, bias)
-
-
-class FP8_fbgemm_block_linear(torch.autograd.Function):
- @staticmethod
- def forward(ctx, X, weight, weight_scale, bias = None):
- orig_shape = X.shape
- X = X.view(-1, X.shape[-1])
-
- bs_n, bs_k = getattr(weight, "block_size", None) or getattr(
- weight_scale, "block_size", [128, 128]
- )
- bs_m = bs_n
-
- m, n = weight.shape
- p, q = weight_scale.shape
-
- if triton.cdiv(m, bs_n) != p or triton.cdiv(n, bs_k) != q:
- if triton.cdiv(m, bs_n) == q and triton.cdiv(n, bs_k) == p:
- # weights are transposed during backward pass for training :)
- # We transpose weight scale to counter that. Note that transposing weight would cause issues with matmul with input X
- weight_scale = weight_scale.T
- else:
- raise ValueError(
- f"Weight shape {weight.shape} and scales shape {weight_scale.shape} is not compatible with block size {bs_n, bs_k}"
- )
-
- xq, xs = triton_quantize_fp8_block(X, bs_m, bs_n, None)
- ## TODO: Investigate and resolve the high divergence of this output from baseline
- # WARNING: This causes the outputs to diverge from expected when X has high values in it.
- # That results in the model producing gibberish, especially on longer sequences and training loss starting at high values like 8 instead of <1 ideally
- # Please refrain from using this till this issue is resolved. This exists here just for a future headstart.
- output = torch.ops.fbgemm.f8f8bf16_blockwise(
- xq, weight.contiguous(), xs, weight_scale.contiguous(), bs_m, bs_n, bs_k
- )
- output = output + bias if bias is not None else output
-
- output = output.view(*orig_shape[:-1], -1)
-
- del xq
- del xs
-
- ctx.weight = weight
- ctx.weight_scale = weight_scale
- ctx.block_size = [bs_m, bs_n, bs_k]
- return output
-
- @staticmethod
- def backward(ctx, grad_output):
- W_deq = weight_dequant(ctx.weight, ctx.weight_scale)
- grad_X = torch_matmul(grad_output, W_deq)
- del W_deq
- return grad_X, None, None, None, None
-
-
-@torch_compile
-def fp8_fbgemm_block_linear(X, weight, weight_scale, bias = None):
- return FP8_fbgemm_block_linear.apply(X, weight, weight_scale, bias)
-
-
-def test_has_fbgemm():
- # We must manually check if the faster FBGEMM works on the specific GPU
- # For example RTX 5090 and RTX 4090 does not work
- # Also SM100 (Blackwell B200/B100) GPUs fail with CUTLASS SM90 kernels
- # [TODO] Investigate with TorchAO why FBGEMM fails on consumer GPUs
- M, N, K = 128, 128, 128
- xq = torch.ones(M, K, dtype = torch.float8_e4m3fn, device = "cuda")
- wq = xq
- M, K = xq.shape
- N, _ = wq.shape
- block_scale = torch.ones(M // 128, K // 128, dtype = torch.float32, device = "cuda")
- has_fbgemm = False
- try:
- out = torch.ops.fbgemm.f8f8bf16_blockwise(xq, wq, block_scale, block_scale)
- assert torch.unique(out).item() == 128
- has_fbgemm = True
- del out
- except Exception as e:
- error_str = str(e).lower()
- # Catch any CUTLASS/CUDA errors and disable FBGEMM
- # This includes MMA instruction errors, architecture mismatches, kernel launch failures, etc.
- cutlass_cuda_errors = (
- "cutlass",
- "cuda error",
- "cuda runtime error",
- "no kernel image",
- "arch conditional",
- "mma instruction",
- "compute capability",
- "cute_invalid_control_path",
- "tma",
- )
- is_cutlass_cuda_error = any(err in error_str for err in cutlass_cuda_errors)
-
- if is_cutlass_cuda_error:
- print(
- "Unsloth: FBGEMM on the current GPU cannot load - will switch to Triton kernels"
- )
- else:
- print(
- f"Unsloth: FBGEMM on the current GPU cannot load with error = {e} - will switch to Triton kernels"
- )
- has_fbgemm = False
- del block_scale, xq
- torch.cuda.empty_cache()
- return has_fbgemm
-
-
-fp8_block_quant_linear = fp8_torch_block_quant_forward
-if "UNSLOTH_HAS_FBGEMM" not in os.environ:
- os.environ["UNSLOTH_HAS_FBGEMM"] = "0"
-try:
- import fbgemm_gpu
-
- # Older versions cause numerical imprecisions resulting in NaNs especially when X has high values in it.
- # This is both fast and accurate hence preferred.
- # This makes it 15% faster than the torchao implementation.
- if Version(fbgemm_gpu.__version__) >= Version("1.4.0"):
- # We must manually confirm if blockwise FBGEMM works!
- # This check is a must for consumer grade GPUs which fail
- # Suppress CUDA device printf during probe -- on Blackwell (SM100) GPUs,
- # FBGEMM's CUTLASS blockwise kernel (hardcoded SM90) fires thousands of
- # "Arch conditional MMA" lines to stdout fd 1 before aborting.
- from unsloth.import_fixes import suppress_cuda_printf
-
- with suppress_cuda_printf():
- _has_fbgemm = test_has_fbgemm()
- if _has_fbgemm:
- os.environ["UNSLOTH_HAS_FBGEMM"] = "1"
- logger.info(f"Using fbgemm_gpu block quantized FP8 matmul")
- fp8_block_quant_linear = fp8_fbgemm_block_linear
- else:
- os.environ["UNSLOTH_HAS_FBGEMM"] = "0"
-except:
- pass
-
-
-@torch_compile
-def fp8_linear(X, weight, weight_scale, bias = None):
- # Per-tensor quantization: single scalar scale for entire weight
- # Block quantized FP8: 2D scale tensor with multiple columns
- if weight_scale.numel() == 1 or (
- weight_scale.ndim == 2 and weight_scale.shape[1] > 1
- ):
- out = fp8_block_quant_linear(X, weight, weight_scale)
- # Row/channel quantized FP8: 2D scale with shape (n, 1)
- else:
- out = fbgemm_fp8_linear(X, weight, weight_scale, bias)
- return out
-
-
-def module_forward_patch(forward_function, scale_attr = "weight_scale"):
- def patched_forward(self, X):
- return forward_function(X, self.weight, getattr(self, scale_attr))
-
- return patched_forward
-
-
-# Patch the forward functions of the layers (for compiled models)
-if FbgemmFp8Linear is not None:
- FbgemmFp8Linear.forward = module_forward_patch(fbgemm_fp8_linear, "weight_scale")
-if FP8Linear is not None:
- FP8Linear.forward = module_forward_patch(fp8_block_quant_linear, "weight_scale_inv")
diff --git a/unsloth/kernels/geglu.py b/unsloth/kernels/geglu.py
index 50b4e521d3..1ece87c080 100644
--- a/unsloth/kernels/geglu.py
+++ b/unsloth/kernels/geglu.py
@@ -18,46 +18,28 @@
from .utils import (
calculate_settings,
triton_tanh,
- torch_gpu_device,
+ torch_cuda_device,
)
-# signed int32 max is 2**31-1 so num_elements cannot exceed 2**31
-NUM_INT32_ELEMENTS = 2**31
-SAFE_INT32_BUFFER_MULTIPLIER = 4
-BLOCK_SIZE = 1024
-INT32_SAFETY_BUFFER = NUM_INT32_ELEMENTS - BLOCK_SIZE * SAFE_INT32_BUFFER_MULTIPLIER
-
@triton.jit
-def _exact_forward_kernel(
- e,
- g,
- h,
- n_elements,
- BLOCK_SIZE: tl.constexpr,
- LONG_INDEXING: tl.constexpr,
-):
+def _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
block_idx = tl.program_id(0)
- if LONG_INDEXING:
- offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to(
- tl.int64
- )
- n_elements = tl.cast(n_elements, tl.int64)
- else:
- offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
# f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
# h = f * up
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
- g_row = tl.load(g + offsets, mask = mask, other = 0) # .to(tl.float32)
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
f_row = 0.5 * e_row * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
- f_row = f_row.to(g_row.dtype) # Exact copy from HF
+ f_row = f_row.to(g_row.dtype) # Exact copy from HF
h_row = f_row * g_row
# Store h
tl.store(h + offsets, h_row, mask = mask)
+pass
def geglu_exact_forward_kernel(gate, up):
@@ -65,28 +47,15 @@ def geglu_exact_forward_kernel(gate, up):
n_elements = gate.numel()
device = gate.device
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)
- grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
- with torch_gpu_device(device):
- _exact_forward_kernel[grid](
- gate,
- up,
- out,
- n_elements,
- BLOCK_SIZE = BLOCK_SIZE,
- LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
- )
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+ with torch_cuda_device(device):
+ _exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
return out
+pass
@triton.jit
-def _exact_backward_kernel(
- DW,
- e,
- g,
- n_elements,
- BLOCK_SIZE: tl.constexpr,
- LONG_INDEXING: tl.constexpr,
-):
+def _exact_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
"""
f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
h = f * up
@@ -98,96 +67,74 @@ def _exact_backward_kernel(
f = 1/2 * (1 + erf(1/sqrt(2) * e)) * e
"""
block_idx = tl.program_id(0)
- if LONG_INDEXING:
- offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to(
- tl.int64
- )
- n_elements = tl.cast(n_elements, tl.int64)
- else:
- offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
- DW_row = tl.load(DW + offsets, mask = mask, other = 0) # .to(tl.float32)
- e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
- g_row = tl.load(g + offsets, mask = mask, other = 0) # .to(tl.float32)
+ DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
+ e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
# Break e_row away for re-use
# f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
f_partial_row = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
f_row = f_partial_row * e_row
-
+
f_row = f_row.to(DW_row.dtype)
# h = f * g
- h_row = f_row * g_row
+ h_row = f_row * g_row
# df = DW * f
df_row = DW_row * f_row
# dg = DW * g
dg_row = DW_row * g_row
# df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)
- t = 0.3989422804014327 # 1/sqrt(2*pi)
+ t = 0.3989422804014327 # 1/sqrt(2*pi)
df_de = f_partial_row + t * e_row * tl.exp(-0.5 * e_row * e_row)
de_row = dg_row.to(tl.float32) * df_de
de_row = de_row.to(DW_row.dtype)
# Store derivatives in buffers
- tl.store(DW + offsets, h_row, mask = mask) # h = f * g
- tl.store(e + offsets, df_row, mask = mask) # df = DW * f
- tl.store(g + offsets, de_row, mask = mask) # de
+ tl.store(DW + offsets, h_row, mask = mask) # h = f * g
+ tl.store(e + offsets, df_row, mask = mask) # df = DW * f
+ tl.store(g + offsets, de_row, mask = mask) # de
+pass
def geglu_exact_backward_kernel(DW, e, g):
batch_seq_len, hd = e.shape
n_elements = e.numel()
- grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
- with torch_gpu_device(e.device):
- _exact_backward_kernel[grid](
- DW,
- e,
- g,
- n_elements,
- BLOCK_SIZE = BLOCK_SIZE,
- LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
- )
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+ with torch_cuda_device(e.device):
+ _exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
return DW, e, g
+pass
@triton.jit
-def _approx_forward_kernel(
- e,
- g,
- h,
- n_elements,
- BLOCK_SIZE: tl.constexpr,
- LONG_INDEXING: tl.constexpr,
-):
+def _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
block_idx = tl.program_id(0)
- if LONG_INDEXING:
- offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to(
- tl.int64
- )
- n_elements = tl.cast(n_elements, tl.int64)
- else:
- offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
# f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
# f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
# h = f * up
- s = 0.7978845608028654 # math.sqrt(2 / math.pi)
-
+ s = 0.7978845608028654 # math.sqrt(2 / math.pi)
+
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
- g_row = tl.load(g + offsets, mask = mask, other = 0) # .to(tl.float32)
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
- f_row = (
- 0.5 * e_row * (triton_tanh(s * e_row * (1.0 + 0.044715 * e_row * e_row)) + 1.0)
+ f_row = 0.5 * e_row * (
+ triton_tanh(s * e_row * (1.0 + 0.044715 * e_row * e_row)) \
+ + 1.0
)
- f_row = f_row.to(g_row.dtype) # Exact copy from HF
+ f_row = f_row.to(g_row.dtype) # Exact copy from HF
h_row = f_row * g_row
# Store h
tl.store(h + offsets, h_row, mask = mask)
+pass
def geglu_approx_forward_kernel(gate, up):
@@ -195,28 +142,15 @@ def geglu_approx_forward_kernel(gate, up):
n_elements = gate.numel()
device = gate.device
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)
- grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
- with torch_gpu_device(device):
- _approx_forward_kernel[grid](
- gate,
- up,
- out,
- n_elements,
- BLOCK_SIZE = BLOCK_SIZE,
- LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
- )
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+ with torch_cuda_device(device):
+ _approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
return out
+pass
@triton.jit
-def _approx_backward_kernel(
- DW,
- e,
- g,
- n_elements,
- BLOCK_SIZE: tl.constexpr,
- LONG_INDEXING: tl.constexpr,
-):
+def _approx_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
"""
f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
h = f * up
@@ -232,34 +166,28 @@ def _approx_backward_kernel(
See https://www.desmos.com/calculator/nqprfoni6x
"""
block_idx = tl.program_id(0)
- if LONG_INDEXING:
- offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to(
- tl.int64
- )
- n_elements = tl.cast(n_elements, tl.int64)
- else:
- offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
- DW_row = tl.load(DW + offsets, mask = mask, other = 0) # .to(tl.float32)
- e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
- g_row = tl.load(g + offsets, mask = mask, other = 0) # .to(tl.float32)
+ DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
+ e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
# See https://www.desmos.com/calculator/nqprfoni6x
- s = 0.7978845608028654 # math.sqrt(2 / math.pi)
- a = s * e_row # a = sqrt(2 / pi) * x
- b = a * 0.044715 * e_row * e_row # b = a * 0.044715 * x^2
+ s = 0.7978845608028654 # math.sqrt(2 / math.pi)
+ a = s * e_row # a = sqrt(2 / pi) * x
+ b = a * 0.044715 * e_row * e_row # b = a * 0.044715 * x^2
T = 1.0 + triton_tanh(a + b)
T2 = 0.5 * T
# Q = 0.5 * -T * (T - 2.0) * (a + 3.0 * b)
- Q2 = -T2 * (T - 2.0) * (a + 3.0 * b)
- df_de = T2 + Q2 # 1/2 * (T + Q)
+ Q2 = -T2 * (T - 2.0) * (a + 3.0 * b)
+ df_de = T2 + Q2 # 1/2 * (T + Q)
# f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
f_row = T2 * e_row
f_row = f_row.to(DW_row.dtype)
# h = f * g
- h_row = f_row * g_row
+ h_row = f_row * g_row
# df = DW * f
df_row = DW_row * f_row
# dg = DW * g
@@ -269,22 +197,17 @@ def _approx_backward_kernel(
de_row = de_row.to(DW_row.dtype)
# Store derivatives in buffers
- tl.store(DW + offsets, h_row, mask = mask) # h = f * g
- tl.store(e + offsets, df_row, mask = mask) # df = DW * f
- tl.store(g + offsets, de_row, mask = mask) # de
+ tl.store(DW + offsets, h_row, mask = mask) # h = f * g
+ tl.store(e + offsets, df_row, mask = mask) # df = DW * f
+ tl.store(g + offsets, de_row, mask = mask) # de
+pass
def geglu_approx_backward_kernel(DW, e, g):
batch_seq_len, hd = e.shape
n_elements = e.numel()
- grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
- with torch_gpu_device(e.device):
- _approx_backward_kernel[grid](
- DW,
- e,
- g,
- n_elements,
- BLOCK_SIZE = BLOCK_SIZE,
- LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
- )
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+ with torch_cuda_device(e.device):
+ _approx_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
return DW, e, g
+pass
diff --git a/unsloth/kernels/layernorm.py b/unsloth/kernels/layernorm.py
index 9e64c3d341..ed8182014e 100644
--- a/unsloth/kernels/layernorm.py
+++ b/unsloth/kernels/layernorm.py
@@ -16,7 +16,7 @@
import triton
import triton.language as tl
import torch
-from .utils import calculate_settings, torch_gpu_device
+from .utils import calculate_settings, torch_cuda_device
from unsloth_zoo.patching_utils import (
patch_layernorm,
)
@@ -24,25 +24,23 @@
@triton.jit
def layernorm_forward(
- Y,
- Y_row_stride,
- X,
- X_row_stride,
+ Y, Y_row_stride,
+ X, X_row_stride,
W,
b,
r,
mu,
- n_cols: tl.constexpr,
- eps: tl.constexpr,
- BLOCK_SIZE: tl.constexpr,
+ n_cols : tl.constexpr,
+ eps : tl.constexpr,
+ BLOCK_SIZE : tl.constexpr
):
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
- Y += row_idx * Y_row_stride
- X += row_idx * X_row_stride
- r += row_idx
+ Y += row_idx * Y_row_stride
+ X += row_idx * X_row_stride
+ r += row_idx
mu += row_idx
# According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules
@@ -51,32 +49,29 @@ def layernorm_forward(
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)
- mean_X = tl.sum(X_row, axis = 0) / n_cols
+ mean_X = tl.sum(X_row, axis = 0) / n_cols
# (X[0] - mean) == -mean so we need to mask it out
XX = tl.where(mask, X_row - mean_X, 0)
row_var = tl.sum(XX * XX, axis = 0) / n_cols
- # Explicit float32 scalar to ensure correct type promotion on HIP/ROCm
- eps_f32 = tl.full((), eps, tl.float32)
- inv_var = tl.math.rsqrt(row_var + eps_f32)
- tl.store(r, inv_var)
- tl.store(mu, mean_X)
+ inv_var = tl.math.rsqrt(row_var + eps)
+ tl.store (r, inv_var)
+ tl.store (mu, mean_X)
output = (XX * inv_var) * W_row + b_row
tl.store(Y + col_offsets, output, mask = mask)
+pass
@triton.jit
def layernorm_backward(
- dY,
- dY_row_stride,
- X,
- X_row_stride,
+ dY, dY_row_stride,
+ X, X_row_stride,
W,
b,
r,
mu,
- n_cols: tl.constexpr,
- eps: tl.constexpr,
- BLOCK_SIZE: tl.constexpr,
+ n_cols : tl.constexpr,
+ eps : tl.constexpr,
+ BLOCK_SIZE : tl.constexpr
):
# Approximately follows https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
row_idx = tl.program_id(0)
@@ -84,28 +79,25 @@ def layernorm_backward(
mask = col_offsets < n_cols
dY += row_idx * dY_row_stride
- X += row_idx * X_row_stride
- r += row_idx
+ X += row_idx * X_row_stride
+ r += row_idx
mu += row_idx
# According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules
# are in float32!
dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
- X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
- W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
- b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)
+ X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
+ W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
+ b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)
- inv_var = tl.load(r).to(tl.float32)
- mean = tl.load(mu).to(tl.float32)
- normed = (X_row - mean) * inv_var
+ inv_var = tl.load(r) .to(tl.float32)
+ mean = tl.load(mu).to(tl.float32)
+ normed = (X_row - mean) * inv_var
dY_W = dY_row * W_row
- dX_row = (
- dY_W
- - tl.sum(dY_W, axis = 0) / n_cols
- - normed * tl.sum(dY_W * normed, axis = 0) / n_cols
- )
+ dX_row = dY_W - tl.sum(dY_W, axis = 0) / n_cols - normed * tl.sum(dY_W * normed, axis = 0) / n_cols
dX_row = dX_row * inv_var
tl.store(dY + col_offsets, dX_row, mask = mask)
+pass
class Fast_Layernorm(torch.autograd.Function):
@@ -117,30 +109,28 @@ def forward(ctx, X, W, b, eps):
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
device = X.device
- Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = device)
- r = torch.empty(n_rows, dtype = torch.float32, device = device)
+ Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = device)
+ r = torch.empty(n_rows, dtype = torch.float32, device = device)
mu = torch.empty(n_rows, dtype = torch.float32, device = device)
- with torch_gpu_device(device):
+ with torch_cuda_device(device):
layernorm_forward[(n_rows,)](
- Y,
- Y.stride(0),
- X,
- X.stride(0),
+ Y, Y.stride(0),
+ X, X.stride(0),
W,
b,
r,
mu,
- n_cols,
- eps,
+ n_cols, eps,
BLOCK_SIZE = BLOCK_SIZE,
- num_warps = num_warps,
+ num_warps = num_warps,
)
ctx.eps = eps
ctx.BLOCK_SIZE = BLOCK_SIZE
- ctx.num_warps = num_warps
+ ctx.num_warps = num_warps
ctx.save_for_backward(X, W, b, r, mu)
return Y.view(*shape)
+ pass
@staticmethod
def backward(ctx, dY):
@@ -150,48 +140,42 @@ def backward(ctx, dY):
X, W, b, r, mu = ctx.saved_tensors
n_rows, n_cols = dY.shape
- with torch_gpu_device(dY.device):
+ with torch_cuda_device(dY.device):
layernorm_backward[(n_rows,)](
- dY,
- dY.stride(0),
- X,
- X.stride(0),
+ dY, dY.stride(0),
+ X, X .stride(0),
W,
b,
r,
mu,
- n_cols,
- ctx.eps,
+ n_cols, ctx.eps,
BLOCK_SIZE = ctx.BLOCK_SIZE,
- num_warps = ctx.num_warps,
+ num_warps = ctx.num_warps,
)
dX = dY.view(*shape)
return dX, None, None, None, None
+ pass
+pass
def fast_layernorm(layernorm, X):
- assert layernorm.elementwise_affine is True
- W = layernorm.weight
+ assert(layernorm.elementwise_affine is True)
+ W = layernorm.weight
bias = layernorm.bias
- eps = (
- layernorm.variance_epsilon
- if hasattr(layernorm, "variance_epsilon")
+ eps = layernorm.variance_epsilon if \
+ hasattr(layernorm, "variance_epsilon") \
else layernorm.eps
- )
out = Fast_Layernorm.apply(X, W, bias, eps)
return out
+pass
+
def test_layernorm(
- dim = 1024,
- eps = 1e-5,
- dtype = torch.float16,
- bsz = 21,
- random_state = 3407,
- seqlen = 3341,
+ dim = 1024, eps = 1e-5, dtype = torch.float16,
+ bsz = 21, random_state = 3407, seqlen = 3341,
):
from torch.nn import LayerNorm
-
layernorm = LayerNorm((dim,), eps = eps, device = "cuda", dtype = dtype)
torch.cuda.manual_seed(random_state)
torch.manual_seed(random_state)
@@ -199,7 +183,7 @@ def test_layernorm(
torch.nn.init.uniform_(layernorm.bias)
X = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda")
XX = X.clone()
- X.requires_grad_(True)
+ X .requires_grad_(True)
XX.requires_grad_(True)
Y = layernorm(X)
YY = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda", requires_grad = True)
@@ -208,7 +192,8 @@ def test_layernorm(
# from unsloth.kernels import fast_layernorm
Y = fast_layernorm(layernorm, XX)
Y.backward(YY)
- assert torch.dist(correct_grad, XX.grad).item() <= 0.1
+ assert(torch.dist(correct_grad, XX.grad).item() <= 0.1)
+pass
def testing_suite_layernorm():
@@ -225,3 +210,9 @@ def testing_suite_layernorm():
random_state = random_state,
seqlen = seqlen,
)
+ pass
+ pass
+ pass
+ pass
+ pass
+pass
diff --git a/unsloth/kernels/moe/LICENSE b/unsloth/kernels/moe/LICENSE
deleted file mode 100644
index 29ebfa545f..0000000000
--- a/unsloth/kernels/moe/LICENSE
+++ /dev/null
@@ -1,661 +0,0 @@
- GNU AFFERO GENERAL PUBLIC LICENSE
- Version 3, 19 November 2007
-
- Copyright (C) 2007 Free Software Foundation, Inc.
- Everyone is permitted to copy and distribute verbatim copies
- of this license document, but changing it is not allowed.
-
- Preamble
-
- The GNU Affero General Public License is a free, copyleft license for
-software and other kinds of works, specifically designed to ensure
-cooperation with the community in the case of network server software.
-
- The licenses for most software and other practical works are designed
-to take away your freedom to share and change the works. By contrast,
-our General Public Licenses are intended to guarantee your freedom to
-share and change all versions of a program--to make sure it remains free
-software for all its users.
-
- When we speak of free software, we are referring to freedom, not
-price. Our General Public Licenses are designed to make sure that you
-have the freedom to distribute copies of free software (and charge for
-them if you wish), that you receive source code or can get it if you
-want it, that you can change the software or use pieces of it in new
-free programs, and that you know you can do these things.
-
- Developers that use our General Public Licenses protect your rights
-with two steps: (1) assert copyright on the software, and (2) offer
-you this License which gives you legal permission to copy, distribute
-and/or modify the software.
-
- A secondary benefit of defending all users' freedom is that
-improvements made in alternate versions of the program, if they
-receive widespread use, become available for other developers to
-incorporate. Many developers of free software are heartened and
-encouraged by the resulting cooperation. However, in the case of
-software used on network servers, this result may fail to come about.
-The GNU General Public License permits making a modified version and
-letting the public access it on a server without ever releasing its
-source code to the public.
-
- The GNU Affero General Public License is designed specifically to
-ensure that, in such cases, the modified source code becomes available
-to the community. It requires the operator of a network server to
-provide the source code of the modified version running there to the
-users of that server. Therefore, public use of a modified version, on
-a publicly accessible server, gives the public access to the source
-code of the modified version.
-
- An older license, called the Affero General Public License and
-published by Affero, was designed to accomplish similar goals. This is
-a different license, not a version of the Affero GPL, but Affero has
-released a new version of the Affero GPL which permits relicensing under
-this license.
-
- The precise terms and conditions for copying, distribution and
-modification follow.
-
- TERMS AND CONDITIONS
-
- 0. Definitions.
-
- "This License" refers to version 3 of the GNU Affero General Public License.
-
- "Copyright" also means copyright-like laws that apply to other kinds of
-works, such as semiconductor masks.
-
- "The Program" refers to any copyrightable work licensed under this
-License. Each licensee is addressed as "you". "Licensees" and
-"recipients" may be individuals or organizations.
-
- To "modify" a work means to copy from or adapt all or part of the work
-in a fashion requiring copyright permission, other than the making of an
-exact copy. The resulting work is called a "modified version" of the
-earlier work or a work "based on" the earlier work.
-
- A "covered work" means either the unmodified Program or a work based
-on the Program.
-
- To "propagate" a work means to do anything with it that, without
-permission, would make you directly or secondarily liable for
-infringement under applicable copyright law, except executing it on a
-computer or modifying a private copy. Propagation includes copying,
-distribution (with or without modification), making available to the
-public, and in some countries other activities as well.
-
- To "convey" a work means any kind of propagation that enables other
-parties to make or receive copies. Mere interaction with a user through
-a computer network, with no transfer of a copy, is not conveying.
-
- An interactive user interface displays "Appropriate Legal Notices"
-to the extent that it includes a convenient and prominently visible
-feature that (1) displays an appropriate copyright notice, and (2)
-tells the user that there is no warranty for the work (except to the
-extent that warranties are provided), that licensees may convey the
-work under this License, and how to view a copy of this License. If
-the interface presents a list of user commands or options, such as a
-menu, a prominent item in the list meets this criterion.
-
- 1. Source Code.
-
- The "source code" for a work means the preferred form of the work
-for making modifications to it. "Object code" means any non-source
-form of a work.
-
- A "Standard Interface" means an interface that either is an official
-standard defined by a recognized standards body, or, in the case of
-interfaces specified for a particular programming language, one that
-is widely used among developers working in that language.
-
- The "System Libraries" of an executable work include anything, other
-than the work as a whole, that (a) is included in the normal form of
-packaging a Major Component, but which is not part of that Major
-Component, and (b) serves only to enable use of the work with that
-Major Component, or to implement a Standard Interface for which an
-implementation is available to the public in source code form. A
-"Major Component", in this context, means a major essential component
-(kernel, window system, and so on) of the specific operating system
-(if any) on which the executable work runs, or a compiler used to
-produce the work, or an object code interpreter used to run it.
-
- The "Corresponding Source" for a work in object code form means all
-the source code needed to generate, install, and (for an executable
-work) run the object code and to modify the work, including scripts to
-control those activities. However, it does not include the work's
-System Libraries, or general-purpose tools or generally available free
-programs which are used unmodified in performing those activities but
-which are not part of the work. For example, Corresponding Source
-includes interface definition files associated with source files for
-the work, and the source code for shared libraries and dynamically
-linked subprograms that the work is specifically designed to require,
-such as by intimate data communication or control flow between those
-subprograms and other parts of the work.
-
- The Corresponding Source need not include anything that users
-can regenerate automatically from other parts of the Corresponding
-Source.
-
- The Corresponding Source for a work in source code form is that
-same work.
-
- 2. Basic Permissions.
-
- All rights granted under this License are granted for the term of
-copyright on the Program, and are irrevocable provided the stated
-conditions are met. This License explicitly affirms your unlimited
-permission to run the unmodified Program. The output from running a
-covered work is covered by this License only if the output, given its
-content, constitutes a covered work. This License acknowledges your
-rights of fair use or other equivalent, as provided by copyright law.
-
- You may make, run and propagate covered works that you do not
-convey, without conditions so long as your license otherwise remains
-in force. You may convey covered works to others for the sole purpose
-of having them make modifications exclusively for you, or provide you
-with facilities for running those works, provided that you comply with
-the terms of this License in conveying all material for which you do
-not control copyright. Those thus making or running the covered works
-for you must do so exclusively on your behalf, under your direction
-and control, on terms that prohibit them from making any copies of
-your copyrighted material outside their relationship with you.
-
- Conveying under any other circumstances is permitted solely under
-the conditions stated below. Sublicensing is not allowed; section 10
-makes it unnecessary.
-
- 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
-
- No covered work shall be deemed part of an effective technological
-measure under any applicable law fulfilling obligations under article
-11 of the WIPO copyright treaty adopted on 20 December 1996, or
-similar laws prohibiting or restricting circumvention of such
-measures.
-
- When you convey a covered work, you waive any legal power to forbid
-circumvention of technological measures to the extent such circumvention
-is effected by exercising rights under this License with respect to
-the covered work, and you disclaim any intention to limit operation or
-modification of the work as a means of enforcing, against the work's
-users, your or third parties' legal rights to forbid circumvention of
-technological measures.
-
- 4. Conveying Verbatim Copies.
-
- You may convey verbatim copies of the Program's source code as you
-receive it, in any medium, provided that you conspicuously and
-appropriately publish on each copy an appropriate copyright notice;
-keep intact all notices stating that this License and any
-non-permissive terms added in accord with section 7 apply to the code;
-keep intact all notices of the absence of any warranty; and give all
-recipients a copy of this License along with the Program.
-
- You may charge any price or no price for each copy that you convey,
-and you may offer support or warranty protection for a fee.
-
- 5. Conveying Modified Source Versions.
-
- You may convey a work based on the Program, or the modifications to
-produce it from the Program, in the form of source code under the
-terms of section 4, provided that you also meet all of these conditions:
-
- a) The work must carry prominent notices stating that you modified
- it, and giving a relevant date.
-
- b) The work must carry prominent notices stating that it is
- released under this License and any conditions added under section
- 7. This requirement modifies the requirement in section 4 to
- "keep intact all notices".
-
- c) You must license the entire work, as a whole, under this
- License to anyone who comes into possession of a copy. This
- License will therefore apply, along with any applicable section 7
- additional terms, to the whole of the work, and all its parts,
- regardless of how they are packaged. This License gives no
- permission to license the work in any other way, but it does not
- invalidate such permission if you have separately received it.
-
- d) If the work has interactive user interfaces, each must display
- Appropriate Legal Notices; however, if the Program has interactive
- interfaces that do not display Appropriate Legal Notices, your
- work need not make them do so.
-
- A compilation of a covered work with other separate and independent
-works, which are not by their nature extensions of the covered work,
-and which are not combined with it such as to form a larger program,
-in or on a volume of a storage or distribution medium, is called an
-"aggregate" if the compilation and its resulting copyright are not
-used to limit the access or legal rights of the compilation's users
-beyond what the individual works permit. Inclusion of a covered work
-in an aggregate does not cause this License to apply to the other
-parts of the aggregate.
-
- 6. Conveying Non-Source Forms.
-
- You may convey a covered work in object code form under the terms
-of sections 4 and 5, provided that you also convey the
-machine-readable Corresponding Source under the terms of this License,
-in one of these ways:
-
- a) Convey the object code in, or embodied in, a physical product
- (including a physical distribution medium), accompanied by the
- Corresponding Source fixed on a durable physical medium
- customarily used for software interchange.
-
- b) Convey the object code in, or embodied in, a physical product
- (including a physical distribution medium), accompanied by a
- written offer, valid for at least three years and valid for as
- long as you offer spare parts or customer support for that product
- model, to give anyone who possesses the object code either (1) a
- copy of the Corresponding Source for all the software in the
- product that is covered by this License, on a durable physical
- medium customarily used for software interchange, for a price no
- more than your reasonable cost of physically performing this
- conveying of source, or (2) access to copy the
- Corresponding Source from a network server at no charge.
-
- c) Convey individual copies of the object code with a copy of the
- written offer to provide the Corresponding Source. This
- alternative is allowed only occasionally and noncommercially, and
- only if you received the object code with such an offer, in accord
- with subsection 6b.
-
- d) Convey the object code by offering access from a designated
- place (gratis or for a charge), and offer equivalent access to the
- Corresponding Source in the same way through the same place at no
- further charge. You need not require recipients to copy the
- Corresponding Source along with the object code. If the place to
- copy the object code is a network server, the Corresponding Source
- may be on a different server (operated by you or a third party)
- that supports equivalent copying facilities, provided you maintain
- clear directions next to the object code saying where to find the
- Corresponding Source. Regardless of what server hosts the
- Corresponding Source, you remain obligated to ensure that it is
- available for as long as needed to satisfy these requirements.
-
- e) Convey the object code using peer-to-peer transmission, provided
- you inform other peers where the object code and Corresponding
- Source of the work are being offered to the general public at no
- charge under subsection 6d.
-
- A separable portion of the object code, whose source code is excluded
-from the Corresponding Source as a System Library, need not be
-included in conveying the object code work.
-
- A "User Product" is either (1) a "consumer product", which means any
-tangible personal property which is normally used for personal, family,
-or household purposes, or (2) anything designed or sold for incorporation
-into a dwelling. In determining whether a product is a consumer product,
-doubtful cases shall be resolved in favor of coverage. For a particular
-product received by a particular user, "normally used" refers to a
-typical or common use of that class of product, regardless of the status
-of the particular user or of the way in which the particular user
-actually uses, or expects or is expected to use, the product. A product
-is a consumer product regardless of whether the product has substantial
-commercial, industrial or non-consumer uses, unless such uses represent
-the only significant mode of use of the product.
-
- "Installation Information" for a User Product means any methods,
-procedures, authorization keys, or other information required to install
-and execute modified versions of a covered work in that User Product from
-a modified version of its Corresponding Source. The information must
-suffice to ensure that the continued functioning of the modified object
-code is in no case prevented or interfered with solely because
-modification has been made.
-
- If you convey an object code work under this section in, or with, or
-specifically for use in, a User Product, and the conveying occurs as
-part of a transaction in which the right of possession and use of the
-User Product is transferred to the recipient in perpetuity or for a
-fixed term (regardless of how the transaction is characterized), the
-Corresponding Source conveyed under this section must be accompanied
-by the Installation Information. But this requirement does not apply
-if neither you nor any third party retains the ability to install
-modified object code on the User Product (for example, the work has
-been installed in ROM).
-
- The requirement to provide Installation Information does not include a
-requirement to continue to provide support service, warranty, or updates
-for a work that has been modified or installed by the recipient, or for
-the User Product in which it has been modified or installed. Access to a
-network may be denied when the modification itself materially and
-adversely affects the operation of the network or violates the rules and
-protocols for communication across the network.
-
- Corresponding Source conveyed, and Installation Information provided,
-in accord with this section must be in a format that is publicly
-documented (and with an implementation available to the public in
-source code form), and must require no special password or key for
-unpacking, reading or copying.
-
- 7. Additional Terms.
-
- "Additional permissions" are terms that supplement the terms of this
-License by making exceptions from one or more of its conditions.
-Additional permissions that are applicable to the entire Program shall
-be treated as though they were included in this License, to the extent
-that they are valid under applicable law. If additional permissions
-apply only to part of the Program, that part may be used separately
-under those permissions, but the entire Program remains governed by
-this License without regard to the additional permissions.
-
- When you convey a copy of a covered work, you may at your option
-remove any additional permissions from that copy, or from any part of
-it. (Additional permissions may be written to require their own
-removal in certain cases when you modify the work.) You may place
-additional permissions on material, added by you to a covered work,
-for which you have or can give appropriate copyright permission.
-
- Notwithstanding any other provision of this License, for material you
-add to a covered work, you may (if authorized by the copyright holders of
-that material) supplement the terms of this License with terms:
-
- a) Disclaiming warranty or limiting liability differently from the
- terms of sections 15 and 16 of this License; or
-
- b) Requiring preservation of specified reasonable legal notices or
- author attributions in that material or in the Appropriate Legal
- Notices displayed by works containing it; or
-
- c) Prohibiting misrepresentation of the origin of that material, or
- requiring that modified versions of such material be marked in
- reasonable ways as different from the original version; or
-
- d) Limiting the use for publicity purposes of names of licensors or
- authors of the material; or
-
- e) Declining to grant rights under trademark law for use of some
- trade names, trademarks, or service marks; or
-
- f) Requiring indemnification of licensors and authors of that
- material by anyone who conveys the material (or modified versions of
- it) with contractual assumptions of liability to the recipient, for
- any liability that these contractual assumptions directly impose on
- those licensors and authors.
-
- All other non-permissive additional terms are considered "further
-restrictions" within the meaning of section 10. If the Program as you
-received it, or any part of it, contains a notice stating that it is
-governed by this License along with a term that is a further
-restriction, you may remove that term. If a license document contains
-a further restriction but permits relicensing or conveying under this
-License, you may add to a covered work material governed by the terms
-of that license document, provided that the further restriction does
-not survive such relicensing or conveying.
-
- If you add terms to a covered work in accord with this section, you
-must place, in the relevant source files, a statement of the
-additional terms that apply to those files, or a notice indicating
-where to find the applicable terms.
-
- Additional terms, permissive or non-permissive, may be stated in the
-form of a separately written license, or stated as exceptions;
-the above requirements apply either way.
-
- 8. Termination.
-
- You may not propagate or modify a covered work except as expressly
-provided under this License. Any attempt otherwise to propagate or
-modify it is void, and will automatically terminate your rights under
-this License (including any patent licenses granted under the third
-paragraph of section 11).
-
- However, if you cease all violation of this License, then your
-license from a particular copyright holder is reinstated (a)
-provisionally, unless and until the copyright holder explicitly and
-finally terminates your license, and (b) permanently, if the copyright
-holder fails to notify you of the violation by some reasonable means
-prior to 60 days after the cessation.
-
- Moreover, your license from a particular copyright holder is
-reinstated permanently if the copyright holder notifies you of the
-violation by some reasonable means, this is the first time you have
-received notice of violation of this License (for any work) from that
-copyright holder, and you cure the violation prior to 30 days after
-your receipt of the notice.
-
- Termination of your rights under this section does not terminate the
-licenses of parties who have received copies or rights from you under
-this License. If your rights have been terminated and not permanently
-reinstated, you do not qualify to receive new licenses for the same
-material under section 10.
-
- 9. Acceptance Not Required for Having Copies.
-
- You are not required to accept this License in order to receive or
-run a copy of the Program. Ancillary propagation of a covered work
-occurring solely as a consequence of using peer-to-peer transmission
-to receive a copy likewise does not require acceptance. However,
-nothing other than this License grants you permission to propagate or
-modify any covered work. These actions infringe copyright if you do
-not accept this License. Therefore, by modifying or propagating a
-covered work, you indicate your acceptance of this License to do so.
-
- 10. Automatic Licensing of Downstream Recipients.
-
- Each time you convey a covered work, the recipient automatically
-receives a license from the original licensors, to run, modify and
-propagate that work, subject to this License. You are not responsible
-for enforcing compliance by third parties with this License.
-
- An "entity transaction" is a transaction transferring control of an
-organization, or substantially all assets of one, or subdividing an
-organization, or merging organizations. If propagation of a covered
-work results from an entity transaction, each party to that
-transaction who receives a copy of the work also receives whatever
-licenses to the work the party's predecessor in interest had or could
-give under the previous paragraph, plus a right to possession of the
-Corresponding Source of the work from the predecessor in interest, if
-the predecessor has it or can get it with reasonable efforts.
-
- You may not impose any further restrictions on the exercise of the
-rights granted or affirmed under this License. For example, you may
-not impose a license fee, royalty, or other charge for exercise of
-rights granted under this License, and you may not initiate litigation
-(including a cross-claim or counterclaim in a lawsuit) alleging that
-any patent claim is infringed by making, using, selling, offering for
-sale, or importing the Program or any portion of it.
-
- 11. Patents.
-
- A "contributor" is a copyright holder who authorizes use under this
-License of the Program or a work on which the Program is based. The
-work thus licensed is called the contributor's "contributor version".
-
- A contributor's "essential patent claims" are all patent claims
-owned or controlled by the contributor, whether already acquired or
-hereafter acquired, that would be infringed by some manner, permitted
-by this License, of making, using, or selling its contributor version,
-but do not include claims that would be infringed only as a
-consequence of further modification of the contributor version. For
-purposes of this definition, "control" includes the right to grant
-patent sublicenses in a manner consistent with the requirements of
-this License.
-
- Each contributor grants you a non-exclusive, worldwide, royalty-free
-patent license under the contributor's essential patent claims, to
-make, use, sell, offer for sale, import and otherwise run, modify and
-propagate the contents of its contributor version.
-
- In the following three paragraphs, a "patent license" is any express
-agreement or commitment, however denominated, not to enforce a patent
-(such as an express permission to practice a patent or covenant not to
-sue for patent infringement). To "grant" such a patent license to a
-party means to make such an agreement or commitment not to enforce a
-patent against the party.
-
- If you convey a covered work, knowingly relying on a patent license,
-and the Corresponding Source of the work is not available for anyone
-to copy, free of charge and under the terms of this License, through a
-publicly available network server or other readily accessible means,
-then you must either (1) cause the Corresponding Source to be so
-available, or (2) arrange to deprive yourself of the benefit of the
-patent license for this particular work, or (3) arrange, in a manner
-consistent with the requirements of this License, to extend the patent
-license to downstream recipients. "Knowingly relying" means you have
-actual knowledge that, but for the patent license, your conveying the
-covered work in a country, or your recipient's use of the covered work
-in a country, would infringe one or more identifiable patents in that
-country that you have reason to believe are valid.
-
- If, pursuant to or in connection with a single transaction or
-arrangement, you convey, or propagate by procuring conveyance of, a
-covered work, and grant a patent license to some of the parties
-receiving the covered work authorizing them to use, propagate, modify
-or convey a specific copy of the covered work, then the patent license
-you grant is automatically extended to all recipients of the covered
-work and works based on it.
-
- A patent license is "discriminatory" if it does not include within
-the scope of its coverage, prohibits the exercise of, or is
-conditioned on the non-exercise of one or more of the rights that are
-specifically granted under this License. You may not convey a covered
-work if you are a party to an arrangement with a third party that is
-in the business of distributing software, under which you make payment
-to the third party based on the extent of your activity of conveying
-the work, and under which the third party grants, to any of the
-parties who would receive the covered work from you, a discriminatory
-patent license (a) in connection with copies of the covered work
-conveyed by you (or copies made from those copies), or (b) primarily
-for and in connection with specific products or compilations that
-contain the covered work, unless you entered into that arrangement,
-or that patent license was granted, prior to 28 March 2007.
-
- Nothing in this License shall be construed as excluding or limiting
-any implied license or other defenses to infringement that may
-otherwise be available to you under applicable patent law.
-
- 12. No Surrender of Others' Freedom.
-
- If conditions are imposed on you (whether by court order, agreement or
-otherwise) that contradict the conditions of this License, they do not
-excuse you from the conditions of this License. If you cannot convey a
-covered work so as to satisfy simultaneously your obligations under this
-License and any other pertinent obligations, then as a consequence you may
-not convey it at all. For example, if you agree to terms that obligate you
-to collect a royalty for further conveying from those to whom you convey
-the Program, the only way you could satisfy both those terms and this
-License would be to refrain entirely from conveying the Program.
-
- 13. Remote Network Interaction; Use with the GNU General Public License.
-
- Notwithstanding any other provision of this License, if you modify the
-Program, your modified version must prominently offer all users
-interacting with it remotely through a computer network (if your version
-supports such interaction) an opportunity to receive the Corresponding
-Source of your version by providing access to the Corresponding Source
-from a network server at no charge, through some standard or customary
-means of facilitating copying of software. This Corresponding Source
-shall include the Corresponding Source for any work covered by version 3
-of the GNU General Public License that is incorporated pursuant to the
-following paragraph.
-
- Notwithstanding any other provision of this License, you have
-permission to link or combine any covered work with a work licensed
-under version 3 of the GNU General Public License into a single
-combined work, and to convey the resulting work. The terms of this
-License will continue to apply to the part which is the covered work,
-but the work with which it is combined will remain governed by version
-3 of the GNU General Public License.
-
- 14. Revised Versions of this License.
-
- The Free Software Foundation may publish revised and/or new versions of
-the GNU Affero General Public License from time to time. Such new versions
-will be similar in spirit to the present version, but may differ in detail to
-address new problems or concerns.
-
- Each version is given a distinguishing version number. If the
-Program specifies that a certain numbered version of the GNU Affero General
-Public License "or any later version" applies to it, you have the
-option of following the terms and conditions either of that numbered
-version or of any later version published by the Free Software
-Foundation. If the Program does not specify a version number of the
-GNU Affero General Public License, you may choose any version ever published
-by the Free Software Foundation.
-
- If the Program specifies that a proxy can decide which future
-versions of the GNU Affero General Public License can be used, that proxy's
-public statement of acceptance of a version permanently authorizes you
-to choose that version for the Program.
-
- Later license versions may give you additional or different
-permissions. However, no additional obligations are imposed on any
-author or copyright holder as a result of your choosing to follow a
-later version.
-
- 15. Disclaimer of Warranty.
-
- THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
-APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
-HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
-OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
-THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
-IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
-ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
-
- 16. Limitation of Liability.
-
- IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
-WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
-THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
-GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
-USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
-DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
-PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
-EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
-SUCH DAMAGES.
-
- 17. Interpretation of Sections 15 and 16.
-
- If the disclaimer of warranty and limitation of liability provided
-above cannot be given local legal effect according to their terms,
-reviewing courts shall apply local law that most closely approximates
-an absolute waiver of all civil liability in connection with the
-Program, unless a warranty or assumption of liability accompanies a
-copy of the Program in return for a fee.
-
- END OF TERMS AND CONDITIONS
-
- How to Apply These Terms to Your New Programs
-
- If you develop a new program, and you want it to be of the greatest
-possible use to the public, the best way to achieve this is to make it
-free software which everyone can redistribute and change under these terms.
-
- To do so, attach the following notices to the program. It is safest
-to attach them to the start of each source file to most effectively
-state the exclusion of warranty; and each file should have at least
-the "copyright" line and a pointer to where the full notice is found.
-
-
- Copyright (C)
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published
- by the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see .
-
-Also add information on how to contact you by electronic and paper mail.
-
- If your software can interact with users remotely through a computer
-network, you should also make sure that it provides a way for users to
-get its source. For example, if your program is a web application, its
-interface could display a "Source" link that leads users to an archive
-of the code. There are many ways you could offer source, and different
-solutions will be better for different programs; see section 13 for the
-specific requirements.
-
- You should also get your employer (if you work as a programmer) or school,
-if any, to sign a "copyright disclaimer" for the program, if necessary.
-For more information on this, and how to apply and follow the GNU AGPL, see
-.
\ No newline at end of file
diff --git a/unsloth/kernels/moe/README.md b/unsloth/kernels/moe/README.md
deleted file mode 100644
index 0375c8d507..0000000000
--- a/unsloth/kernels/moe/README.md
+++ /dev/null
@@ -1,86 +0,0 @@
-## MoE Grouped GEMM
-
-Optimized implementation of `MoE MLP Block`.
-Licensed under AGPLv3.
-
-### Background
-
-`MoE MLP` requires the following steps:
-- Calculate `topk_weights` and `topk_indices`
-- If using a grouped gemm implementation, calculate permutation indices needed to rearrange tokens grouped by expert
-- For each expert:
- - `expert_tokens`: gather the tokens assigned to the expert
- - `first_gemm`: `gate / up proj` @ `expert_tokens`
- - `silu_and_mul`: `silu` and `mul` of `first_gemm`
- - `second_gemm`: `silu_and_mul` @ `down proj`
- - `scatter_second_gemm`: scatter the `second_gemm` to the original token order
- - `topk_weight_mul`: `second_gemm` @ `topk_weights`
- - `final_output`: if `topk > 1`, `topk_weight_mul.view(num_tokens, topk, -1).sum(dim=1)` else `topk_weight_mul`
-
-One way to eliminate the loop is to use a grouped GEMM, where all expert GEMMs are computed within a single kernel, which iterates over tiles of the expert GEMMs as individual GEMMs, where each GEMM, the `A` matrix is `M' x K` and the `B` matrix is `K x N`, where `M'` is the number of tokens assigned to the expert and `B` is the weight matrix for that expert.
-
-This requires an additional permute (and subsequent copy) of the hidden states such that the tokens assigned to each expert are contiguous in memory before running the first grouped GEMM within the Expert MLP.
-Additionally, after the second grouped GEMM, the hidden states must be permuted back to the original token order and multiplied by `topk_weights` to get the final output.
-
-### Optimizations
-This repo implements a grouped GEMM-based MoE MLP with the following optimizations:
-- Eliminates the loop over experts by performing gemms as a grouped GEMM, computing the expert gemms within a single fused triton kernel
-- Fuses the permutation of hidden states from token order (original input order) to expert order (tokens grouped by expert) within the prologue of first the first grouped GEMM
-- Fuses the (un)permutation of hidden states from expert order back to token order in second GEMM
-- Fuses the mul of hidden states by expert weights within epilogue of second GEMM (only implemented for inference, not for training)
-
-### Structure
-- `grouped_gemm/interface.py`: wrappers for the individual forward / backward kernels as well as the `torch.autograd.Function`
-- `grouped_gemm/kernels/forward.py`: forward kernel
-- `grouped_gemm/kernels/backward.py`: backward dX and dW kernels
-- `grouped_gemm/kernels/tuning.py`: manual tuning utils
-- `grouped_gemm/kernels/autotuning.py`: autotuning utils
-- `grouped_gemm/reference/moe_block.py`: contains `Qwen3MoeFusedGroupedGEMMBlock`, a reference implementation of Huggingface `Qwen3SparseMOEBlock` with fused triton kernel in-place of original HF expert computation
-- `grouped_gemm/reference/moe_ops.py`: supporting ops (routing, token sorting, etc.) and reference MoE block using a torch-native grouped gemm approach.
-
-### Tests
-- `grouped_gemm/tests/test_grouped_gemm.py`: unit tests for forward, backward grouped gemm kernels as well as the wrapped grouped gemm autograd.Function. Best not to run this entire test suite at once due to the large number of parametrized unit tests. Rather, use filters to run specific
-sets of tests. E.g., to run forward tests with autotune turned on: `pytest -sv -k "forward and autotune" --tb=short tests/test_grouped_gemm.py`. Use the test function names and parameter ids for words to filter on.
-- `grouped_gemm/tests/test_qwen3_moe.py`: end to end test for Qwen3 MoE block. IMPORTANT: read `tests/run_qwen3_moe_tests.sh` as well as notes in the test itself for complications when running parametrized pytest test suites and triton / autotune. TLDR: use the test script and NOT pytest to run the tests.
-
-### Benchmarks
-- `grouped_gemm/benchmark/benchmark_fused_moe.py`: benchmarks HF `Qwen3SpareMOEBlock` or `Llama4TextMoe` against the fused implementation
-
-
-Running with these flags on an `H100` to bench forward pass (run with `--help` to see all available flags):
-
-For `Qwen3-30B-A3B`:
-```
-python benchmark/benchmark_fused_moe.py --model qwen3 --mode forward --seqlen 1024 --permute_x --permute_y --autotune
-```
-
-For the backward bench:
-```
-python benchmark/benchmark_fused_moe.py --model qwen3 --mode backward --seqlen 1024 --permute_x --permute_y --autotune
-```
-
-For `Llama-4-Scout-17B-16E`:
-```
-python benchmark/benchmark_fused_moe.py --model llama4 --autotune --mode=forward --permute_y
-```
-Ditto for backwards.
-
-### Notes
-- Tested and benched on `H100`, though should run on Ampere and possibly even earlier gpu generations though the autotuning configs will need to be adjusted.
-- The env I used to develop the kernel was `pytorch 2.7/2.8` and `pytorch-triton 3.3`.
-- The kernels can be run either as autotuned (see `autotuning.py`) or with manually specified config (see `tuning.py`). Recommended to run using autotuner since the MoE block requires 2 configs for the forward (2 grouped gemms) and 4 for the backwards (dX and dW per grouped gemm, 2 grouped gemms).
-- Running with autotuning turned off with the default manual kernel config will result is **highly** sub-optimal performance as it is only meant for testing / debugging purposes.
-- I've tried to strike a balance between compilation time and autotuning search space -- can probably squeeze even more performance for specific workloads.
-- The Llama4 reference layer is still highly under-optimized as there are many low-hanging opportunities for further speedups around routing and shared expert calculation.
-
-TODO:
-- TMA store: implemented but not enabled currently due to non-determinism arising from triton pipelining bug.
-- Warp specialization: Hopper support for WS not yet enabled on triton 3.3x branch which ships with latest pytorch 2.7.
-- Additional optimizations:
- - Fused / optimized implementations of routing, token sorting, etc.
- - Better software pipelining within grouped gemm
- - Threadblock swizzling for better L2 caching
- - Llama4
- - Fused gather / topk weight merging
- - Custom topk, gather indices kernel
- - Shared expert fusion with experts calculation
\ No newline at end of file
diff --git a/unsloth/kernels/moe/__init__.py b/unsloth/kernels/moe/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/unsloth/kernels/moe/autotune_cache.py b/unsloth/kernels/moe/autotune_cache.py
deleted file mode 100644
index f23d9688ea..0000000000
--- a/unsloth/kernels/moe/autotune_cache.py
+++ /dev/null
@@ -1,500 +0,0 @@
-# Unsloth
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as published
-# by the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Affero General Public License for more details.
-#
-# You should have received a copy of the GNU Affero General Public License
-# along with this program. If not, see .
-
-"""
-Auto-tuning cache system for MoE kernels to ensure tuning runs only once at training start.
-"""
-
-import hashlib
-import json
-import logging
-import os
-import time
-from typing import Dict, List, Optional, Tuple, Any
-import torch
-import triton
-
-logger = logging.getLogger(__name__)
-
-# Global cache for kernel configurations
-_kernel_config_cache: Dict[str, Any] = {}
-_autotune_completed: Dict[str, bool] = {}
-
-
-def _get_cache_key(
- num_experts: int,
- hidden_dim: int,
- intermediate_dim: int,
- top_k: int,
- dtype: torch.dtype,
- device_capability: Tuple[int, int],
- seq_len: int = 8192, # Default sequence length for tuning
-) -> str:
- """Generate a unique cache key based on model configuration."""
- key_data = {
- "num_experts": num_experts,
- "hidden_dim": hidden_dim,
- "intermediate_dim": intermediate_dim,
- "top_k": top_k,
- "dtype": str(dtype),
- "device_capability": device_capability,
- "seq_len": seq_len,
- }
- key_str = json.dumps(key_data, sort_keys = True)
- return hashlib.md5(key_str.encode()).hexdigest()
-
-
-def _get_cache_file_path(cache_key: str) -> str:
- """Get the file path for the cache file."""
- cache_dir = os.path.expanduser("~/.cache/unsloth/moe_autotune")
- os.makedirs(cache_dir, exist_ok = True)
- return os.path.join(cache_dir, f"{cache_key}.json")
-
-
-def load_cached_config(cache_key: str) -> Optional[Dict[str, Any]]:
- """Load cached kernel configuration from disk."""
- cache_file = _get_cache_file_path(cache_key)
- if not os.path.exists(cache_file):
- return None
-
- try:
- with open(cache_file, "r") as f:
- cached_data = json.load(f)
-
- # Verify cache is still valid (same device, etc.)
- current_device_capability = torch.cuda.get_device_capability()
- if cached_data.get("device_capability") != current_device_capability:
- logger.info("Device capability changed, invalidating cache")
- os.remove(cache_file)
- return None
-
- logger.info(f"Loaded cached MoE kernel config: {cache_key}")
- return cached_data
- except Exception as e:
- logger.warning(f"Failed to load cache file {cache_file}: {e}")
- try:
- os.remove(cache_file)
- except:
- pass
- return None
-
-
-def save_cached_config(
- cache_key: str,
- config_fwd: Any,
- config_bwd_dx: Any,
- config_bwd_dw: Any,
- metadata: Dict[str, Any] = None,
-) -> None:
- """Save kernel configuration to disk cache."""
- cache_file = _get_cache_file_path(cache_key)
-
- cache_data = {
- "timestamp": time.time(),
- "device_capability": torch.cuda.get_device_capability(),
- "config_fwd": config_fwd.__dict__
- if hasattr(config_fwd, "__dict__")
- else str(config_fwd),
- "config_bwd_dx": config_bwd_dx.__dict__
- if hasattr(config_bwd_dx, "__dict__")
- else str(config_bwd_dx),
- "config_bwd_dw": config_bwd_dw.__dict__
- if hasattr(config_bwd_dw, "__dict__")
- else str(config_bwd_dw),
- "metadata": metadata or {},
- }
-
- try:
- with open(cache_file, "w") as f:
- json.dump(cache_data, f, indent = 2)
- logger.info(f"Saved MoE kernel config cache: {cache_key}")
- except Exception as e:
- logger.warning(f"Failed to save cache file {cache_file}: {e}")
-
-
-def get_or_autotune_moe_kernels(
- num_experts: int,
- hidden_dim: int,
- intermediate_dim: int,
- top_k: int,
- dtype: torch.dtype,
- force_autotune: bool = False,
- seq_len: int = 8192,
-) -> Tuple[Any, Any, Any]:
- """
- Get cached kernel configurations or run auto-tuning.
-
- Args:
- num_experts: Number of experts in the MoE layer
- hidden_dim: Hidden dimension of the model
- intermediate_dim: Intermediate dimension for MoE MLP
- top_k: Number of experts to route to
- dtype: Data type for computation
- force_autotune: Force re-running autotuning even if cache exists
- seq_len: Sequence length to use for tuning benchmarks
-
- Returns:
- Tuple of (config_fwd, config_bwd_dx, config_bwd_dw)
- """
- device_capability = torch.cuda.get_device_capability()
- cache_key = _get_cache_key(
- num_experts,
- hidden_dim,
- intermediate_dim,
- top_k,
- dtype,
- device_capability,
- seq_len,
- )
-
- # 0. Check for environment variable override to DISABLE autotuning
- if os.environ.get("UNSLOTH_MOE_DISABLE_AUTOTUNE", "0") == "1":
- logger.info(
- f"UNSLOTH_MOE_DISABLE_AUTOTUNE=1: Using Heuristic (Safe) MoE kernel configs for SM{device_capability[0]}{device_capability[1]}"
- )
- return _get_heuristic_configs()
- if not force_autotune and cache_key in _kernel_config_cache:
- logger.info(f"Using in-memory cached MoE kernel configs: {cache_key}")
- return _kernel_config_cache[cache_key]
-
- # Try to load from disk
- if not force_autotune:
- cached_data = load_cached_config(cache_key)
- if cached_data is not None:
- # Reconstruct config objects from cached data
- try:
- from .grouped_gemm.kernels.tuning import (
- KernelConfigForward,
- KernelConfigBackward_dX,
- KernelConfigBackward_dW,
- )
-
- config_fwd = KernelConfigForward(**cached_data["config_fwd"])
- config_bwd_dx = KernelConfigBackward_dX(**cached_data["config_bwd_dx"])
- config_bwd_dw = KernelConfigBackward_dW(**cached_data["config_bwd_dw"])
-
- configs = (config_fwd, config_bwd_dx, config_bwd_dw)
- _kernel_config_cache[cache_key] = configs
- return configs
- except Exception as e:
- logger.warning(f"Failed to reconstruct cached configs: {e}")
-
- # Run autotuning
- if cache_key in _autotune_completed and not force_autotune:
- logger.info(f"Autotuning already completed for: {cache_key}")
- return _kernel_config_cache[cache_key]
-
- logger.info(f"Running MoE kernel auto-tuning for: {cache_key}")
- logger.info(
- f"Configuration: {num_experts} experts, {hidden_dim} hidden, {intermediate_dim} intermediate, top_k={top_k}"
- )
-
- try:
- configs = _run_moe_autotuning(
- num_experts, hidden_dim, intermediate_dim, top_k, dtype, seq_len
- )
-
- # Cache the results
- _kernel_config_cache[cache_key] = configs
- _autotune_completed[cache_key] = True
-
- # Save to disk
- config_fwd, config_bwd_dx, config_bwd_dw = configs
- save_cached_config(
- cache_key,
- config_fwd,
- config_bwd_dx,
- config_bwd_dw,
- {
- "num_experts": num_experts,
- "hidden_dim": hidden_dim,
- "intermediate_dim": intermediate_dim,
- },
- )
-
- logger.info(f"MoE kernel auto-tuning completed: {cache_key}")
- return configs
-
- except Exception as e:
- logger.error(f"MoE kernel auto-tuning failed: {e}")
- if "AttributeError" in str(e) and "_experimental_make_tensor_descriptor" in str(
- e
- ):
- logger.warning(
- "Unsloth: Your Triton version might be incompatible with TMA features. Falling back to default configs."
- )
- logger.info("Falling back to default kernel configurations")
- return _get_default_configs()
-
-
-def _run_moe_autotuning(
- num_experts: int,
- hidden_dim: int,
- intermediate_dim: int,
- top_k: int,
- dtype: torch.dtype,
- seq_len: int,
-) -> Tuple[Any, Any, Any]:
- """Run the actual auto-tuning for MoE kernels."""
-
- # Create dummy inputs for tuning
- device = "cuda"
- # Use a fixed, safe number of tokens for autotuning to avoid OOMs and dependency on seq_len
- # 4096 is standard for finding good kernels without consuming 10GB+ VRAM
- # We ignore the passed seq_len for the actual allocation to satisfy user request
- num_tokens = 4096
- total_tokens = num_tokens * top_k
-
- # Create dummy tensors
- hidden_states = torch.randn(num_tokens, hidden_dim, device = device, dtype = dtype)
-
- # Create dummy weights
- gate_up_weights = torch.randn(
- num_experts, 2 * intermediate_dim, hidden_dim, device = device, dtype = dtype
- )
- down_weights = torch.randn(
- num_experts, hidden_dim, intermediate_dim, device = device, dtype = dtype
- )
-
- # Create dummy routing data
- m_sizes = torch.randint(
- 1, total_tokens // num_experts + 1, (num_experts,), device = device
- )
- m_sizes = m_sizes * (total_tokens // m_sizes.sum().item())
- # Adjust to ensure exact total
- diff = total_tokens - m_sizes.sum().item()
- if diff != 0:
- m_sizes[0] += diff
-
- gather_indices = torch.arange(total_tokens, device = device)
- torch.randperm(total_tokens, out = gather_indices)
-
- # Autotune forward kernel - use the interface function with autotune=True
- # This properly invokes the kernel and lets triton handle the autotuning
- from .grouped_gemm.interface import (
- grouped_gemm_forward,
- grouped_gemm_dX,
- grouped_gemm_dW,
- )
- from .grouped_gemm.kernels.forward import _autotuned_grouped_gemm_forward_kernel
- from .grouped_gemm.kernels.backward import (
- _autotuned_grouped_gemm_dX_kernel,
- _autotuned_grouped_gemm_dW_kernel,
- )
- from .grouped_gemm.kernels.tuning import (
- KernelConfigForward,
- KernelConfigBackward_dX,
- KernelConfigBackward_dW,
- )
-
- logger.info("Autotuning forward kernel (first GEMM)...")
- # Run with autotune=True to trigger autotuning
- _ = grouped_gemm_forward(
- X = hidden_states,
- W = gate_up_weights,
- topk = top_k,
- m_sizes = m_sizes,
- gather_indices = gather_indices,
- permute_x = True,
- permute_y = False,
- autotune = True,
- )
- triton_config_fwd = _autotuned_grouped_gemm_forward_kernel.best_config
-
- # Convert triton.Config to KernelConfigForward
- config_fwd = KernelConfigForward(
- BLOCK_SIZE_M = triton_config_fwd.kwargs["BLOCK_SIZE_M"],
- BLOCK_SIZE_N = triton_config_fwd.kwargs["BLOCK_SIZE_N"],
- BLOCK_SIZE_K = triton_config_fwd.kwargs["BLOCK_SIZE_K"],
- num_warps = triton_config_fwd.num_warps,
- num_stages = triton_config_fwd.num_stages,
- use_tma_load_x = triton_config_fwd.kwargs.get("USE_TMA_LOAD_X", False),
- use_tma_load_w = triton_config_fwd.kwargs.get("USE_TMA_LOAD_W", False),
- use_tma_store = triton_config_fwd.kwargs.get("USE_TMA_STORE", False),
- )
-
- # Autotune backward dX kernel
- logger.info("Autotuning backward dX kernel...")
- dummy_grad = torch.randn(
- total_tokens, 2 * intermediate_dim, device = device, dtype = dtype
- )
- _ = grouped_gemm_dX(
- dY = dummy_grad,
- W = gate_up_weights,
- gather_indices = gather_indices,
- m_sizes = m_sizes,
- topk = top_k,
- permute_x = True,
- permute_y = False,
- autotune = True,
- )
- triton_config_bwd_dx = _autotuned_grouped_gemm_dX_kernel.best_config
-
- # Convert triton.Config to KernelConfigBackward_dX
- config_bwd_dx = KernelConfigBackward_dX(
- BLOCK_SIZE_M = triton_config_bwd_dx.kwargs["BLOCK_SIZE_M"],
- BLOCK_SIZE_N = triton_config_bwd_dx.kwargs["BLOCK_SIZE_N"],
- BLOCK_SIZE_K = triton_config_bwd_dx.kwargs["BLOCK_SIZE_K"],
- num_warps = triton_config_bwd_dx.num_warps,
- num_stages = triton_config_bwd_dx.num_stages,
- use_tma_load_dy = triton_config_bwd_dx.kwargs.get("USE_TMA_LOAD_dY", False),
- use_tma_load_w = triton_config_bwd_dx.kwargs.get("USE_TMA_LOAD_W", False),
- use_tma_store = triton_config_bwd_dx.kwargs.get("USE_TMA_STORE", False),
- )
-
- # Autotune backward dW kernel
- logger.info("Autotuning backward dW kernel...")
- _ = grouped_gemm_dW(
- X = hidden_states,
- dY = dummy_grad,
- m_sizes = m_sizes,
- gather_indices = gather_indices,
- topk = top_k,
- permute_x = True,
- permute_y = False,
- autotune = True,
- )
- triton_config_bwd_dw = _autotuned_grouped_gemm_dW_kernel.best_config
-
- # Convert triton.Config to KernelConfigBackward_dW
- config_bwd_dw = KernelConfigBackward_dW(
- BLOCK_SIZE_M = triton_config_bwd_dw.kwargs["BLOCK_SIZE_M"],
- BLOCK_SIZE_N = triton_config_bwd_dw.kwargs["BLOCK_SIZE_N"],
- BLOCK_SIZE_K = triton_config_bwd_dw.kwargs["BLOCK_SIZE_K"],
- num_warps = triton_config_bwd_dw.num_warps,
- num_stages = triton_config_bwd_dw.num_stages,
- use_tma_load_dy = triton_config_bwd_dw.kwargs.get("USE_TMA_LOAD_dY", False),
- use_tma_load_x = triton_config_bwd_dw.kwargs.get("USE_TMA_LOAD_X", False),
- use_tma_store = triton_config_bwd_dw.kwargs.get("USE_TMA_STORE", False),
- )
-
- return config_fwd, config_bwd_dx, config_bwd_dw
-
- return config_fwd, config_bwd_dx, config_bwd_dw
-
-
-def _get_heuristic_configs() -> Tuple[Any, Any, Any]:
- """
- Get 'Safe Heuristic' kernel configurations.
- These are verified to be safe on A100 (SM80) and provide ~9x speedup on H100/B200.
- """
- from .grouped_gemm.kernels.tuning import (
- KernelConfigForward,
- KernelConfigBackward_dX,
- KernelConfigBackward_dW,
- )
-
- # Safe Forward Config: 64x128x128 (Fits A100 SMEM)
- config_fwd = KernelConfigForward(
- BLOCK_SIZE_M = 64,
- BLOCK_SIZE_N = 128,
- BLOCK_SIZE_K = 128,
- num_warps = 8,
- num_stages = 3,
- permute_x = True,
- permute_y = True,
- use_tma_load_x = False,
- use_tma_load_w = False, # TMA loads might need alignment checks, safer to disable for heuristic
- use_tma_store = False,
- )
-
- # Safe Backward Configs: 64x64x256
- config_bwd_dx = KernelConfigBackward_dX(
- BLOCK_SIZE_M = 64,
- BLOCK_SIZE_N = 64,
- BLOCK_SIZE_K = 256,
- num_warps = 8,
- num_stages = 4,
- permute_x = True,
- permute_y = True,
- use_tma_load_dy = False,
- use_tma_load_w = False,
- use_tma_store = False,
- )
-
- config_bwd_dw = KernelConfigBackward_dW(
- BLOCK_SIZE_M = 64,
- BLOCK_SIZE_N = 64,
- BLOCK_SIZE_K = 256,
- num_warps = 8,
- num_stages = 4,
- permute_x = True,
- permute_y = True,
- use_tma_load_dy = False,
- use_tma_load_x = False,
- use_tma_store = False,
- )
-
- return config_fwd, config_bwd_dx, config_bwd_dw
-
-
-def _get_default_configs() -> Tuple[Any, Any, Any]:
- """Get default kernel configurations as fallback."""
- from .grouped_gemm.kernels.tuning import (
- KernelConfigForward,
- KernelConfigBackward_dX,
- KernelConfigBackward_dW,
- )
-
- logger.warning("Using default MoE kernel configurations (not optimal)")
-
- config_fwd = KernelConfigForward(
- BLOCK_SIZE_M = 128,
- BLOCK_SIZE_N = 128,
- BLOCK_SIZE_K = 64,
- num_warps = 8,
- num_stages = 3,
- use_tma_load_x = False,
- use_tma_load_w = False,
- use_tma_store = False,
- )
-
- config_bwd_dx = KernelConfigBackward_dX(
- BLOCK_SIZE_M = 128,
- BLOCK_SIZE_N = 128,
- BLOCK_SIZE_K = 64,
- num_warps = 8,
- num_stages = 3,
- use_tma_load_dy = False,
- use_tma_load_w = False,
- use_tma_store = False,
- )
-
- config_bwd_dw = KernelConfigBackward_dW(
- BLOCK_SIZE_M = 128,
- BLOCK_SIZE_N = 128,
- BLOCK_SIZE_K = 64,
- num_warps = 8,
- num_stages = 3,
- use_tma_load_dy = False,
- use_tma_load_x = False,
- use_tma_store = False,
- )
-
- return config_fwd, config_bwd_dx, config_bwd_dw
-
-
-def clear_cache() -> None:
- """Clear all cached kernel configurations."""
- global _kernel_config_cache, _autotune_completed
- _kernel_config_cache.clear()
- _autotune_completed.clear()
- logger.info("Cleared MoE kernel cache")
-
-
-def is_autotuning_completed(cache_key: str) -> bool:
- """Check if autotuning has been completed for a given cache key."""
- return cache_key in _autotune_completed
diff --git a/unsloth/kernels/moe/benchmark/benchmark_fused_moe.py b/unsloth/kernels/moe/benchmark/benchmark_fused_moe.py
deleted file mode 100644
index 074cc5a566..0000000000
--- a/unsloth/kernels/moe/benchmark/benchmark_fused_moe.py
+++ /dev/null
@@ -1,399 +0,0 @@
-import argparse
-import time
-from contextlib import nullcontext
-
-import torch
-from transformers import AutoConfig
-from transformers.models.llama4 import Llama4TextConfig
-from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
-from transformers.models.qwen3_moe import Qwen3MoeConfig
-from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
-from triton.testing import do_bench
-from utils import (
- create_kernel_configs,
- get_autotuner,
- post_process_results,
- postprocess_autotune_results,
- save_results,
-)
-
-from grouped_gemm.kernels.autotuning import (
- DEFAULT_K_BLOCK_SIZES,
- DEFAULT_M_BLOCK_SIZES,
- DEFAULT_N_BLOCK_SIZES,
- DEFAULT_NUM_STAGES,
- DEFAULT_NUM_WARPS,
-)
-from grouped_gemm.kernels.tuning import (
- KernelConfigBackward_dW,
- KernelConfigBackward_dX,
- KernelConfigForward,
- KernelResult,
- TritonTuningContext,
-)
-from grouped_gemm.reference.layers.llama4_moe import Llama4TritonTextMoe
-from grouped_gemm.reference.layers.qwen3_moe import Qwen3MoeFusedGroupedGEMMBlock
-
-SEED = 42
-LLAMA4_ID = "meta-llama/Llama-4-Scout-17B-16E"
-QWEN3_MODEL_ID = "Qwen/Qwen3-30B-A3B"
-
-
-def run_benchmark_forward(
- ref_model: torch.nn.Module,
- tt_model: torch.nn.Module,
- config: AutoConfig,
- seqlen: int,
- dtype: torch.dtype,
- autotune: bool,
- kernel_config_fwd: KernelConfigForward = None,
- bs: int = 1,
-):
- torch.manual_seed(
- SEED
- ) # Should not be needed when running using pytest -- autouse fixture in conftest.py
- device = "cuda"
- hidden_size = config.hidden_size
-
- X = torch.randn(
- bs, seqlen, hidden_size, dtype = dtype, device = device, requires_grad = True
- )
-
- # Forward
- bench_forward_ref = lambda: ref_model(X) # noqa: E731
- bench_forward_fused = lambda: tt_model(X) # noqa: E731
-
- ref_forward_time = do_bench(bench_forward_ref)
-
- if not autotune:
- assert kernel_config_fwd is not None
- tuning_context = TritonTuningContext(kernel_config_fwd)
- else:
- tuning_context = nullcontext()
-
- with tuning_context:
- fused_forward_time = do_bench(bench_forward_fused)
-
- if (not autotune) and (not tuning_context.success):
- return 0, 1
-
- print(
- f"Forward: ref {ref_forward_time:.4f}, fused {fused_forward_time:.4f}, speedup {ref_forward_time / fused_forward_time:.1f}x"
- )
- return ref_forward_time, fused_forward_time
-
-
-def run_benchmark_backward(
- ref_model: torch.nn.Module,
- tt_model: torch.nn.Module,
- config: AutoConfig,
- seqlen: int,
- dtype: torch.dtype,
- bs = 1,
-):
- torch.manual_seed(
- SEED
- ) # Should not be needed when running using pytest -- autouse fixture in conftest.py
- device = "cuda"
- hidden_size = config.hidden_size
-
- X = torch.randn(
- bs, seqlen, hidden_size, dtype = dtype, device = device, requires_grad = True
- )
- X_test = X.detach().clone().requires_grad_(True)
-
- output, _ = ref_model(X)
-
- # Prevent autotuning forward pass
- from grouped_gemm.kernels.forward import _autotuned_grouped_gemm_forward_kernel
-
- _autotuned_grouped_gemm_forward_kernel.configs = (
- _autotuned_grouped_gemm_forward_kernel.configs[:20]
- )
- test_output, _ = tt_model(X_test)
-
- # Bench
- grad_output = torch.randn_like(output)
- bench_backward_ref = lambda: output.backward(grad_output, retain_graph = True) # noqa: E731
- bench_backward_fused = lambda: test_output.backward(grad_output, retain_graph = True) # noqa: E731
-
- ref_backward_time = do_bench(
- bench_backward_ref, grad_to_none = [X, *ref_model.parameters()]
- )
- fused_backward_time = do_bench(
- bench_backward_fused, grad_to_none = [X_test, *tt_model.parameters()]
- )
- print(
- f"Backward: ref {ref_backward_time:.4f}, fused {fused_backward_time:.4f}, speedup {ref_backward_time / fused_backward_time:.1f}x"
- )
- return ref_backward_time, fused_backward_time
-
-
-def setup_model(
- config: Qwen3MoeConfig | Llama4TextConfig,
- dtype,
- permute_x,
- permute_y,
- autotune,
- kernel_config_fwd,
- kernel_config_bwd_dW,
- kernel_config_bwd_dX,
- dX_only = False,
- dW_only = False,
- overlap_router_shared = False,
- device = "cuda",
-):
- if isinstance(config, Qwen3MoeConfig):
- ref_model = Qwen3MoeSparseMoeBlock(config).to(device, dtype)
-
- # Triton kernel grouped gemm version of MoE Block -- this is what we're testing
- tt_model = Qwen3MoeFusedGroupedGEMMBlock.from_hf(
- ref_model,
- permute_x = permute_x,
- permute_y = permute_y,
- autotune = autotune,
- kernel_config_fwd = kernel_config_fwd,
- kernel_config_bwd_dW = kernel_config_bwd_dW,
- kernel_config_bwd_dX = kernel_config_bwd_dX,
- dX_only = dX_only,
- dW_only = dW_only,
- ).to(device, dtype)
-
- elif isinstance(config, Llama4TextConfig):
- ref_model = Llama4TextMoe(config).to(device, dtype)
- tt_model = Llama4TritonTextMoe(
- config,
- overlap_router_shared = overlap_router_shared,
- permute_x = permute_x,
- permute_y = permute_y,
- autotune = autotune,
- kernel_config_fwd = kernel_config_fwd,
- kernel_config_bwd_dW = kernel_config_bwd_dW,
- kernel_config_bwd_dX = kernel_config_bwd_dX,
- dX_only = dX_only,
- dW_only = dW_only,
- ).to(device, dtype)
-
- else:
- raise ValueError(f"Unrecognized config {type(config).__name__}")
-
- return ref_model, tt_model
-
-
-def run_benchmark(
- mode: str,
- model_config: Qwen3MoeConfig | Llama4TextConfig,
- seqlen: int,
- dtype: torch.dtype,
- permute_x: bool,
- permute_y: bool,
- autotune: bool,
- kernel_config_fwd: KernelConfigForward = None,
- kernel_config_bwd_dW: KernelConfigBackward_dW = None,
- kernel_config_bwd_dX: KernelConfigBackward_dX = None,
- overlap_router_shared: bool = False,
- results_dir: str = None,
-):
- if autotune:
- autotuner = get_autotuner(mode)
- if mode == "dW":
- dW_only = True
- elif mode == "dX":
- dX_only = True
- else:
- dW_only = dX_only = False
-
- ref_model, tt_model = setup_model(
- model_config,
- dtype = dtype,
- permute_x = permute_x,
- permute_y = permute_y,
- autotune = autotune,
- kernel_config_fwd = kernel_config_fwd,
- kernel_config_bwd_dW = kernel_config_bwd_dW,
- kernel_config_bwd_dX = kernel_config_bwd_dX,
- dX_only = dX_only,
- dW_only = dW_only,
- overlap_router_shared = overlap_router_shared,
- )
-
- if mode == "forward":
- ref_time, fused_time = run_benchmark_forward(
- ref_model,
- tt_model,
- config = model_config,
- seqlen = seqlen,
- dtype = dtype,
- autotune = autotune,
- kernel_config_fwd = kernel_config_fwd,
- )
- else:
- ref_time, fused_time = run_benchmark_backward(
- ref_model, tt_model, config = model_config, seqlen = seqlen, dtype = dtype
- )
-
- if autotune:
- if mode == "backward":
- autotuner_dW, autotuner_dX = autotuner
- postprocess_autotune_results(
- autotuner_dW, "dW", ref_time, fused_time, results_dir
- )
- postprocess_autotune_results(
- autotuner_dX, "dX", ref_time, fused_time, results_dir
- )
- else:
- postprocess_autotune_results(
- autotuner, mode, ref_time, fused_time, results_dir
- )
-
- return ref_time, fused_time
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("--results_dir", type = str, default = "benchmark_results")
- parser.add_argument("--model", type = str, choices = ["llama4", "qwen3"], required = True)
- parser.add_argument("--seqlen", type = int, default = 1024)
- parser.add_argument(
- "--dtype", type = str, choices = ["bfloat16", "float16"], default = "bfloat16"
- )
- parser.add_argument("--permute_x", action = "store_true")
- parser.add_argument("--permute_y", action = "store_true")
- parser.add_argument("--autotune", action = "store_true")
- parser.add_argument("--overlap_router_shared", action = "store_true")
- parser.add_argument(
- "--BLOCK_SIZE_M",
- nargs = 2,
- type = int,
- default = [DEFAULT_M_BLOCK_SIZES[0], DEFAULT_M_BLOCK_SIZES[-1]],
- )
- parser.add_argument(
- "--BLOCK_SIZE_N",
- nargs = 2,
- type = int,
- default = [DEFAULT_N_BLOCK_SIZES[0], DEFAULT_N_BLOCK_SIZES[-1]],
- )
- parser.add_argument(
- "--BLOCK_SIZE_K",
- nargs = 2,
- type = int,
- default = [DEFAULT_K_BLOCK_SIZES[0], DEFAULT_K_BLOCK_SIZES[-1]],
- )
- parser.add_argument(
- "--num_warps",
- nargs = 2,
- type = int,
- default = [DEFAULT_NUM_WARPS[0], DEFAULT_NUM_WARPS[-1]],
- )
- parser.add_argument(
- "--num_stages",
- nargs = 2,
- type = int,
- default = [DEFAULT_NUM_STAGES[0], DEFAULT_NUM_STAGES[-1]],
- )
- parser.add_argument(
- "--use_tma_load_w", action = "store_true"
- ) # No need to specify, will automatically parametrize these for each kernel config
- parser.add_argument(
- "--use_tma_load_x", action = "store_true"
- ) # No need to specify, will automatically parametrize these for each kernel config
- parser.add_argument(
- "--use_tma_load_dy", action = "store_true"
- ) # No need to specify, will automatically parametrize these for each kernel config
- parser.add_argument(
- "--mode",
- type = str,
- choices = ["forward", "backward", "dW", "dX"],
- default = "forward",
- )
- args = parser.parse_args()
- args.dtype = getattr(torch, args.dtype)
-
- model_id = QWEN3_MODEL_ID if args.model == "qwen3" else LLAMA4_ID
- model_config = AutoConfig.from_pretrained(model_id)
- model_config = model_config.text_config if args.model == "llama4" else model_config
-
- mode = args.mode
-
- if args.autotune:
- # logging.basicConfig(level=logging.INFO)
- print(
- f"Benchmarking {model_id} {mode}: seqlen={args.seqlen}, dtype={args.dtype}, permute_x={args.permute_x}, permute_y={args.permute_y}, autotune"
- )
- start_time = time.time()
- ref_time, fused_time = run_benchmark(
- args.mode,
- model_config,
- seqlen = args.seqlen,
- dtype = args.dtype,
- permute_x = args.permute_x,
- permute_y = args.permute_y,
- autotune = args.autotune,
- overlap_router_shared = args.overlap_router_shared,
- results_dir = args.results_dir,
- )
- end_time = time.time()
- print(f"Total time: {end_time - start_time:.4f} seconds")
-
- # NOTE: better to use autotuner for now, since the MoE block needs 2 different kernel configs for forward (2 grouped gemms, gate_up_proj and down_proj)
- # and the backward pass needs 4 different kernel configs (2 grouped gemms each for dW and dX)
- # The benchmark only supports 1 kernel config at a time so the same config will be used for both grouped gemms, which is suboptimal.
- else:
- assert False, "Use autotune for now"
- kernel_configs = create_kernel_configs(args, args.permute_x, args.permute_y)
- print(f"Running {len(kernel_configs)} kernel configs")
- default_kernel_config_fwd = KernelConfigForward(
- permute_x = args.permute_x, permute_y = args.permute_y
- )
- default_kernel_config_bwd_dW = KernelConfigBackward_dW(
- permute_x = args.permute_x, permute_y = args.permute_y
- )
- default_kernel_config_bwd_dX = KernelConfigBackward_dX(
- permute_x = args.permute_x, permute_y = args.permute_y
- )
- results = []
- for kernel_config in kernel_configs:
- if args.mode == "forward":
- kernel_config_fwd = kernel_config
- kernel_config_bwd_dW = default_kernel_config_bwd_dW
- kernel_config_bwd_dX = default_kernel_config_bwd_dX
- elif args.mode == "dW":
- kernel_config_fwd = default_kernel_config_fwd
- kernel_config_bwd_dW = kernel_config
- kernel_config_bwd_dX = default_kernel_config_bwd_dX
- elif args.mode == "dX":
- kernel_config_fwd = default_kernel_config_fwd
- kernel_config_bwd_dW = default_kernel_config_bwd_dW
- kernel_config_bwd_dX = kernel_config
- else:
- raise ValueError(f"Invalid mode: {args.mode}")
- print(
- f"Benchmarking {model_id} {args.mode} with seqlen={args.seqlen}, dtype={args.dtype}, permute_x={args.permute_x}, permute_y={args.permute_y}, kernel_config_fwd={kernel_config_fwd}, kernel_config_bwd_dW={kernel_config_bwd_dW}, kernel_config_bwd_dX={kernel_config_bwd_dX}"
- )
-
- ref_time, fused_time = run_benchmark(
- args.mode,
- model_config,
- seqlen = args.seqlen,
- dtype = args.dtype,
- permute_x = kernel_config.permute_x,
- permute_y = kernel_config.permute_y,
- autotune = False,
- kernel_config_fwd = kernel_config_fwd,
- kernel_config_bwd_dW = kernel_config_bwd_dW,
- kernel_config_bwd_dX = kernel_config_bwd_dX,
- )
- results.append(
- KernelResult(
- torch_time = ref_time,
- triton_time = fused_time,
- speedup = ref_time / fused_time,
- kernel_config = kernel_config,
- )
- )
- df = post_process_results(
- results, args.mode, args.seqlen, args.dtype, args.autotune
- )
- save_results(
- df, args.results_dir, args.mode, args.seqlen, args.dtype, args.autotune
- )
diff --git a/unsloth/kernels/moe/benchmark/utils.py b/unsloth/kernels/moe/benchmark/utils.py
deleted file mode 100644
index 21905d8df1..0000000000
--- a/unsloth/kernels/moe/benchmark/utils.py
+++ /dev/null
@@ -1,228 +0,0 @@
-import argparse
-import datetime
-import json
-import logging
-import math
-import os
-from itertools import product
-
-import pandas as pd
-import torch
-
-from grouped_gemm.kernels.tuning import (
- KernelConfigBackward_dW,
- KernelConfigBackward_dX,
- KernelConfigForward,
- KernelResult,
-)
-
-SEED = 42
-
-
-def create_merged_results(
- df: pd.DataFrame, mode: str, seqlen: int, dtype: torch.dtype, autotune: bool
-):
- kernel_result_cols = df.columns.to_list()
- test_config_dict = {
- "mode": mode,
- "seqlen": seqlen,
- "dtype": dtype,
- "autotune": autotune,
- }
- test_config_cols = list(test_config_dict.keys())
- for col in test_config_cols:
- df[col] = test_config_dict[col]
- # Reorder columns so that test config cols are first
- df = df[test_config_cols + kernel_result_cols]
- return df
-
-
-def post_process_results(
- results: list[KernelResult],
- mode: str,
- seqlen: int,
- dtype: torch.dtype,
- autotune: bool,
-):
- df = KernelResult.to_dataframe(results, sort_by = "speedup")
- df = create_merged_results(df, mode, seqlen, dtype, autotune)
- return df
-
-
-def save_results(
- df: pd.DataFrame,
- results_dir: str,
- mode: str,
- seqlen: int,
- dtype: torch.dtype,
- autotune: bool,
-):
- dt = datetime.datetime.now().strftime("%Y%m%d_%H%M")
- save_dir = f"{results_dir}/{mode}"
- save_path = f"{save_dir}/{dt}_{seqlen}_{str(dtype).split('.')[-1]}.csv"
- if not os.path.exists(save_dir):
- os.makedirs(save_dir)
- print(f"Saving results to {save_path}")
- df.to_csv(save_path, index = False)
-
-
-def create_kernel_configs(args: argparse.Namespace, permute_x: bool, permute_y: bool):
- block_m_range = power_of_two_range(args.BLOCK_SIZE_M[0], args.BLOCK_SIZE_M[1])
- block_n_range = power_of_two_range(args.BLOCK_SIZE_N[0], args.BLOCK_SIZE_N[1])
- block_k_range = power_of_two_range(args.BLOCK_SIZE_K[0], args.BLOCK_SIZE_K[1])
- num_warps_range = multiples_of_range(args.num_warps[0], args.num_warps[1], step = 2)
- num_stages_range = multiples_of_range(
- args.num_stages[0], args.num_stages[1], step = 1
- )
-
- mode = args.mode
- kernel_configs = []
- for (
- block_m,
- block_n,
- block_k,
- num_warps,
- num_stages,
- tma_load_a,
- tma_load_b,
- ) in product(
- block_m_range,
- block_n_range,
- block_k_range,
- num_warps_range,
- num_stages_range,
- [True, False],
- [True, False],
- ):
- if mode == "forward":
- kernel_config = KernelConfigForward(
- BLOCK_SIZE_M = block_m,
- BLOCK_SIZE_N = block_n,
- BLOCK_SIZE_K = block_k,
- num_warps = num_warps,
- num_stages = num_stages,
- use_tma_load_w = tma_load_a,
- use_tma_load_x = tma_load_b,
- permute_x = permute_x,
- permute_y = permute_y,
- )
- elif mode == "dW":
- kernel_config = KernelConfigBackward_dW(
- BLOCK_SIZE_M = block_m,
- BLOCK_SIZE_N = block_n,
- BLOCK_SIZE_K = block_k,
- num_warps = num_warps,
- num_stages = num_stages,
- use_tma_load_dy = tma_load_a,
- use_tma_load_x = tma_load_b,
- permute_x = permute_x,
- permute_y = permute_y,
- )
- elif mode == "dX":
- kernel_config = KernelConfigBackward_dX(
- BLOCK_SIZE_M = block_m,
- BLOCK_SIZE_N = block_n,
- BLOCK_SIZE_K = block_k,
- num_warps = num_warps,
- num_stages = num_stages,
- use_tma_load_dy = tma_load_a,
- use_tma_load_w = tma_load_b,
- permute_x = permute_x,
- permute_y = permute_y,
- )
- else:
- raise ValueError(f"Invalid mode: {mode}")
- kernel_configs.append(kernel_config)
-
- logging.info(f"Pruning {len(kernel_configs)} kernel configs")
-
- pruned_configs = []
- for config in kernel_configs:
- if mode == "forward":
- if permute_x and config.use_tma_load_x:
- continue
- elif mode == "dW":
- if permute_x and config.use_tma_load_x:
- continue
- if permute_y and config.use_tma_load_dy:
- continue
- elif mode == "dX":
- if permute_y and config.use_tma_load_dy:
- continue
- pruned_configs.append(config)
- logging.info(f"After pruning, {len(pruned_configs)} kernel configs")
-
- return pruned_configs
-
-
-def power_of_two_range(start, end):
- start = math.log2(start)
- end = math.log2(end)
- return [2**i for i in range(int(start), int(end) + 1)]
-
-
-def multiples_of_range(start, end, step = 1):
- return list(range(start, end + step, step))
-
-
-def map_key_to_args(key, mode):
- pass
-
-
-def save_autotune_results(autotune_cache, mode, ref_time, fused_time, results_dir):
- device_name = torch.cuda.get_device_name().replace(" ", "_")
- dt = datetime.datetime.now().strftime("%Y%m%d_%H%M")
- save_dir = f"{results_dir}/{mode}/autotune/{dt}/{device_name}"
- if not os.path.exists(save_dir):
- os.makedirs(save_dir)
-
- for key, config in autotune_cache.items():
- key = [
- str(k) if not "torch" in str(k) else str(k.split("torch.")[-1]) for k in key
- ]
- filename = "_".join(key)
- save_path = f"{save_dir}/{filename}.json"
- print(f"Saving autotune results to {save_path}")
- with open(save_path, "w") as f:
- result = {
- **config.all_kwargs(),
- "ref_time": ref_time,
- "fused_time": fused_time,
- }
- json.dump(result, f)
-
-
-def get_autotuner(mode):
- if mode == "forward":
- from grouped_gemm.kernels.forward import _autotuned_grouped_gemm_forward_kernel
-
- return _autotuned_grouped_gemm_forward_kernel
- elif mode == "dW":
- from grouped_gemm.kernels.backward import _autotuned_grouped_gemm_dW_kernel
-
- return _autotuned_grouped_gemm_dW_kernel
- elif mode == "dX":
- from grouped_gemm.kernels.backward import _autotuned_grouped_gemm_dX_kernel
-
- return _autotuned_grouped_gemm_dX_kernel
- elif mode == "backward":
- from grouped_gemm.kernels.backward import (
- _autotuned_grouped_gemm_dW_kernel,
- _autotuned_grouped_gemm_dX_kernel,
- )
-
- return _autotuned_grouped_gemm_dW_kernel, _autotuned_grouped_gemm_dX_kernel
- else:
- raise ValueError(f"Invalid mode: {mode}")
-
-
-def postprocess_autotune_results(autotuner, mode, ref_time, fused_time, results_dir):
- for key, value in autotuner.cache.items():
- print(f"{mode} {key}: {value.all_kwargs()}")
- save_autotune_results(
- autotuner.cache,
- mode = mode,
- ref_time = ref_time,
- fused_time = fused_time,
- results_dir = results_dir,
- )
diff --git a/unsloth/kernels/moe/grouped_gemm/LICENSE b/unsloth/kernels/moe/grouped_gemm/LICENSE
deleted file mode 100644
index 29ebfa545f..0000000000
--- a/unsloth/kernels/moe/grouped_gemm/LICENSE
+++ /dev/null
@@ -1,661 +0,0 @@
- GNU AFFERO GENERAL PUBLIC LICENSE
- Version 3, 19 November 2007
-
- Copyright (C) 2007 Free Software Foundation, Inc.
- Everyone is permitted to copy and distribute verbatim copies
- of this license document, but changing it is not allowed.
-
- Preamble
-
- The GNU Affero General Public License is a free, copyleft license for
-software and other kinds of works, specifically designed to ensure
-cooperation with the community in the case of network server software.
-
- The licenses for most software and other practical works are designed
-to take away your freedom to share and change the works. By contrast,
-our General Public Licenses are intended to guarantee your freedom to
-share and change all versions of a program--to make sure it remains free
-software for all its users.
-
- When we speak of free software, we are referring to freedom, not
-price. Our General Public Licenses are designed to make sure that you
-have the freedom to distribute copies of free software (and charge for
-them if you wish), that you receive source code or can get it if you
-want it, that you can change the software or use pieces of it in new
-free programs, and that you know you can do these things.
-
- Developers that use our General Public Licenses protect your rights
-with two steps: (1) assert copyright on the software, and (2) offer
-you this License which gives you legal permission to copy, distribute
-and/or modify the software.
-
- A secondary benefit of defending all users' freedom is that
-improvements made in alternate versions of the program, if they
-receive widespread use, become available for other developers to
-incorporate. Many developers of free software are heartened and
-encouraged by the resulting cooperation. However, in the case of
-software used on network servers, this result may fail to come about.
-The GNU General Public License permits making a modified version and
-letting the public access it on a server without ever releasing its
-source code to the public.
-
- The GNU Affero General Public License is designed specifically to
-ensure that, in such cases, the modified source code becomes available
-to the community. It requires the operator of a network server to
-provide the source code of the modified version running there to the
-users of that server. Therefore, public use of a modified version, on
-a publicly accessible server, gives the public access to the source
-code of the modified version.
-
- An older license, called the Affero General Public License and
-published by Affero, was designed to accomplish similar goals. This is
-a different license, not a version of the Affero GPL, but Affero has
-released a new version of the Affero GPL which permits relicensing under
-this license.
-
- The precise terms and conditions for copying, distribution and
-modification follow.
-
- TERMS AND CONDITIONS
-
- 0. Definitions.
-
- "This License" refers to version 3 of the GNU Affero General Public License.
-
- "Copyright" also means copyright-like laws that apply to other kinds of
-works, such as semiconductor masks.
-
- "The Program" refers to any copyrightable work licensed under this
-License. Each licensee is addressed as "you". "Licensees" and
-"recipients" may be individuals or organizations.
-
- To "modify" a work means to copy from or adapt all or part of the work
-in a fashion requiring copyright permission, other than the making of an
-exact copy. The resulting work is called a "modified version" of the
-earlier work or a work "based on" the earlier work.
-
- A "covered work" means either the unmodified Program or a work based
-on the Program.
-
- To "propagate" a work means to do anything with it that, without
-permission, would make you directly or secondarily liable for
-infringement under applicable copyright law, except executing it on a
-computer or modifying a private copy. Propagation includes copying,
-distribution (with or without modification), making available to the
-public, and in some countries other activities as well.
-
- To "convey" a work means any kind of propagation that enables other
-parties to make or receive copies. Mere interaction with a user through
-a computer network, with no transfer of a copy, is not conveying.
-
- An interactive user interface displays "Appropriate Legal Notices"
-to the extent that it includes a convenient and prominently visible
-feature that (1) displays an appropriate copyright notice, and (2)
-tells the user that there is no warranty for the work (except to the
-extent that warranties are provided), that licensees may convey the
-work under this License, and how to view a copy of this License. If
-the interface presents a list of user commands or options, such as a
-menu, a prominent item in the list meets this criterion.
-
- 1. Source Code.
-
- The "source code" for a work means the preferred form of the work
-for making modifications to it. "Object code" means any non-source
-form of a work.
-
- A "Standard Interface" means an interface that either is an official
-standard defined by a recognized standards body, or, in the case of
-interfaces specified for a particular programming language, one that
-is widely used among developers working in that language.
-
- The "System Libraries" of an executable work include anything, other
-than the work as a whole, that (a) is included in the normal form of
-packaging a Major Component, but which is not part of that Major
-Component, and (b) serves only to enable use of the work with that
-Major Component, or to implement a Standard Interface for which an
-implementation is available to the public in source code form. A
-"Major Component", in this context, means a major essential component
-(kernel, window system, and so on) of the specific operating system
-(if any) on which the executable work runs, or a compiler used to
-produce the work, or an object code interpreter used to run it.
-
- The "Corresponding Source" for a work in object code form means all
-the source code needed to generate, install, and (for an executable
-work) run the object code and to modify the work, including scripts to
-control those activities. However, it does not include the work's
-System Libraries, or general-purpose tools or generally available free
-programs which are used unmodified in performing those activities but
-which are not part of the work. For example, Corresponding Source
-includes interface definition files associated with source files for
-the work, and the source code for shared libraries and dynamically
-linked subprograms that the work is specifically designed to require,
-such as by intimate data communication or control flow between those
-subprograms and other parts of the work.
-
- The Corresponding Source need not include anything that users
-can regenerate automatically from other parts of the Corresponding
-Source.
-
- The Corresponding Source for a work in source code form is that
-same work.
-
- 2. Basic Permissions.
-
- All rights granted under this License are granted for the term of
-copyright on the Program, and are irrevocable provided the stated
-conditions are met. This License explicitly affirms your unlimited
-permission to run the unmodified Program. The output from running a
-covered work is covered by this License only if the output, given its
-content, constitutes a covered work. This License acknowledges your
-rights of fair use or other equivalent, as provided by copyright law.
-
- You may make, run and propagate covered works that you do not
-convey, without conditions so long as your license otherwise remains
-in force. You may convey covered works to others for the sole purpose
-of having them make modifications exclusively for you, or provide you
-with facilities for running those works, provided that you comply with
-the terms of this License in conveying all material for which you do
-not control copyright. Those thus making or running the covered works
-for you must do so exclusively on your behalf, under your direction
-and control, on terms that prohibit them from making any copies of
-your copyrighted material outside their relationship with you.
-
- Conveying under any other circumstances is permitted solely under
-the conditions stated below. Sublicensing is not allowed; section 10
-makes it unnecessary.
-
- 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
-
- No covered work shall be deemed part of an effective technological
-measure under any applicable law fulfilling obligations under article
-11 of the WIPO copyright treaty adopted on 20 December 1996, or
-similar laws prohibiting or restricting circumvention of such
-measures.
-
- When you convey a covered work, you waive any legal power to forbid
-circumvention of technological measures to the extent such circumvention
-is effected by exercising rights under this License with respect to
-the covered work, and you disclaim any intention to limit operation or
-modification of the work as a means of enforcing, against the work's
-users, your or third parties' legal rights to forbid circumvention of
-technological measures.
-
- 4. Conveying Verbatim Copies.
-
- You may convey verbatim copies of the Program's source code as you
-receive it, in any medium, provided that you conspicuously and
-appropriately publish on each copy an appropriate copyright notice;
-keep intact all notices stating that this License and any
-non-permissive terms added in accord with section 7 apply to the code;
-keep intact all notices of the absence of any warranty; and give all
-recipients a copy of this License along with the Program.
-
- You may charge any price or no price for each copy that you convey,
-and you may offer support or warranty protection for a fee.
-
- 5. Conveying Modified Source Versions.
-
- You may convey a work based on the Program, or the modifications to
-produce it from the Program, in the form of source code under the
-terms of section 4, provided that you also meet all of these conditions:
-
- a) The work must carry prominent notices stating that you modified
- it, and giving a relevant date.
-
- b) The work must carry prominent notices stating that it is
- released under this License and any conditions added under section
- 7. This requirement modifies the requirement in section 4 to
- "keep intact all notices".
-
- c) You must license the entire work, as a whole, under this
- License to anyone who comes into possession of a copy. This
- License will therefore apply, along with any applicable section 7
- additional terms, to the whole of the work, and all its parts,
- regardless of how they are packaged. This License gives no
- permission to license the work in any other way, but it does not
- invalidate such permission if you have separately received it.
-
- d) If the work has interactive user interfaces, each must display
- Appropriate Legal Notices; however, if the Program has interactive
- interfaces that do not display Appropriate Legal Notices, your
- work need not make them do so.
-
- A compilation of a covered work with other separate and independent
-works, which are not by their nature extensions of the covered work,
-and which are not combined with it such as to form a larger program,
-in or on a volume of a storage or distribution medium, is called an
-"aggregate" if the compilation and its resulting copyright are not
-used to limit the access or legal rights of the compilation's users
-beyond what the individual works permit. Inclusion of a covered work
-in an aggregate does not cause this License to apply to the other
-parts of the aggregate.
-
- 6. Conveying Non-Source Forms.
-
- You may convey a covered work in object code form under the terms
-of sections 4 and 5, provided that you also convey the
-machine-readable Corresponding Source under the terms of this License,
-in one of these ways:
-
- a) Convey the object code in, or embodied in, a physical product
- (including a physical distribution medium), accompanied by the
- Corresponding Source fixed on a durable physical medium
- customarily used for software interchange.
-
- b) Convey the object code in, or embodied in, a physical product
- (including a physical distribution medium), accompanied by a
- written offer, valid for at least three years and valid for as
- long as you offer spare parts or customer support for that product
- model, to give anyone who possesses the object code either (1) a
- copy of the Corresponding Source for all the software in the
- product that is covered by this License, on a durable physical
- medium customarily used for software interchange, for a price no
- more than your reasonable cost of physically performing this
- conveying of source, or (2) access to copy the
- Corresponding Source from a network server at no charge.
-
- c) Convey individual copies of the object code with a copy of the
- written offer to provide the Corresponding Source. This
- alternative is allowed only occasionally and noncommercially, and
- only if you received the object code with such an offer, in accord
- with subsection 6b.
-
- d) Convey the object code by offering access from a designated
- place (gratis or for a charge), and offer equivalent access to the
- Corresponding Source in the same way through the same place at no
- further charge. You need not require recipients to copy the
- Corresponding Source along with the object code. If the place to
- copy the object code is a network server, the Corresponding Source
- may be on a different server (operated by you or a third party)
- that supports equivalent copying facilities, provided you maintain
- clear directions next to the object code saying where to find the
- Corresponding Source. Regardless of what server hosts the
- Corresponding Source, you remain obligated to ensure that it is
- available for as long as needed to satisfy these requirements.
-
- e) Convey the object code using peer-to-peer transmission, provided
- you inform other peers where the object code and Corresponding
- Source of the work are being offered to the general public at no
- charge under subsection 6d.
-
- A separable portion of the object code, whose source code is excluded
-from the Corresponding Source as a System Library, need not be
-included in conveying the object code work.
-
- A "User Product" is either (1) a "consumer product", which means any
-tangible personal property which is normally used for personal, family,
-or household purposes, or (2) anything designed or sold for incorporation
-into a dwelling. In determining whether a product is a consumer product,
-doubtful cases shall be resolved in favor of coverage. For a particular
-product received by a particular user, "normally used" refers to a
-typical or common use of that class of product, regardless of the status
-of the particular user or of the way in which the particular user
-actually uses, or expects or is expected to use, the product. A product
-is a consumer product regardless of whether the product has substantial
-commercial, industrial or non-consumer uses, unless such uses represent
-the only significant mode of use of the product.
-
- "Installation Information" for a User Product means any methods,
-procedures, authorization keys, or other information required to install
-and execute modified versions of a covered work in that User Product from
-a modified version of its Corresponding Source. The information must
-suffice to ensure that the continued functioning of the modified object
-code is in no case prevented or interfered with solely because
-modification has been made.
-
- If you convey an object code work under this section in, or with, or
-specifically for use in, a User Product, and the conveying occurs as
-part of a transaction in which the right of possession and use of the
-User Product is transferred to the recipient in perpetuity or for a
-fixed term (regardless of how the transaction is characterized), the
-Corresponding Source conveyed under this section must be accompanied
-by the Installation Information. But this requirement does not apply
-if neither you nor any third party retains the ability to install
-modified object code on the User Product (for example, the work has
-been installed in ROM).
-
- The requirement to provide Installation Information does not include a
-requirement to continue to provide support service, warranty, or updates
-for a work that has been modified or installed by the recipient, or for
-the User Product in which it has been modified or installed. Access to a
-network may be denied when the modification itself materially and
-adversely affects the operation of the network or violates the rules and
-protocols for communication across the network.
-
- Corresponding Source conveyed, and Installation Information provided,
-in accord with this section must be in a format that is publicly
-documented (and with an implementation available to the public in
-source code form), and must require no special password or key for
-unpacking, reading or copying.
-
- 7. Additional Terms.
-
- "Additional permissions" are terms that supplement the terms of this
-License by making exceptions from one or more of its conditions.
-Additional permissions that are applicable to the entire Program shall
-be treated as though they were included in this License, to the extent
-that they are valid under applicable law. If additional permissions
-apply only to part of the Program, that part may be used separately
-under those permissions, but the entire Program remains governed by
-this License without regard to the additional permissions.
-
- When you convey a copy of a covered work, you may at your option
-remove any additional permissions from that copy, or from any part of
-it. (Additional permissions may be written to require their own
-removal in certain cases when you modify the work.) You may place
-additional permissions on material, added by you to a covered work,
-for which you have or can give appropriate copyright permission.
-
- Notwithstanding any other provision of this License, for material you
-add to a covered work, you may (if authorized by the copyright holders of
-that material) supplement the terms of this License with terms:
-
- a) Disclaiming warranty or limiting liability differently from the
- terms of sections 15 and 16 of this License; or
-
- b) Requiring preservation of specified reasonable legal notices or
- author attributions in that material or in the Appropriate Legal
- Notices displayed by works containing it; or
-
- c) Prohibiting misrepresentation of the origin of that material, or
- requiring that modified versions of such material be marked in
- reasonable ways as different from the original version; or
-
- d) Limiting the use for publicity purposes of names of licensors or
- authors of the material; or
-
- e) Declining to grant rights under trademark law for use of some
- trade names, trademarks, or service marks; or
-
- f) Requiring indemnification of licensors and authors of that
- material by anyone who conveys the material (or modified versions of
- it) with contractual assumptions of liability to the recipient, for
- any liability that these contractual assumptions directly impose on
- those licensors and authors.
-
- All other non-permissive additional terms are considered "further
-restrictions" within the meaning of section 10. If the Program as you
-received it, or any part of it, contains a notice stating that it is
-governed by this License along with a term that is a further
-restriction, you may remove that term. If a license document contains
-a further restriction but permits relicensing or conveying under this
-License, you may add to a covered work material governed by the terms
-of that license document, provided that the further restriction does
-not survive such relicensing or conveying.
-
- If you add terms to a covered work in accord with this section, you
-must place, in the relevant source files, a statement of the
-additional terms that apply to those files, or a notice indicating
-where to find the applicable terms.
-
- Additional terms, permissive or non-permissive, may be stated in the
-form of a separately written license, or stated as exceptions;
-the above requirements apply either way.
-
- 8. Termination.
-
- You may not propagate or modify a covered work except as expressly
-provided under this License. Any attempt otherwise to propagate or
-modify it is void, and will automatically terminate your rights under
-this License (including any patent licenses granted under the third
-paragraph of section 11).
-
- However, if you cease all violation of this License, then your
-license from a particular copyright holder is reinstated (a)
-provisionally, unless and until the copyright holder explicitly and
-finally terminates your license, and (b) permanently, if the copyright
-holder fails to notify you of the violation by some reasonable means
-prior to 60 days after the cessation.
-
- Moreover, your license from a particular copyright holder is
-reinstated permanently if the copyright holder notifies you of the
-violation by some reasonable means, this is the first time you have
-received notice of violation of this License (for any work) from that
-copyright holder, and you cure the violation prior to 30 days after
-your receipt of the notice.
-
- Termination of your rights under this section does not terminate the
-licenses of parties who have received copies or rights from you under
-this License. If your rights have been terminated and not permanently
-reinstated, you do not qualify to receive new licenses for the same
-material under section 10.
-
- 9. Acceptance Not Required for Having Copies.
-
- You are not required to accept this License in order to receive or
-run a copy of the Program. Ancillary propagation of a covered work
-occurring solely as a consequence of using peer-to-peer transmission
-to receive a copy likewise does not require acceptance. However,
-nothing other than this License grants you permission to propagate or
-modify any covered work. These actions infringe copyright if you do
-not accept this License. Therefore, by modifying or propagating a
-covered work, you indicate your acceptance of this License to do so.
-
- 10. Automatic Licensing of Downstream Recipients.
-
- Each time you convey a covered work, the recipient automatically
-receives a license from the original licensors, to run, modify and
-propagate that work, subject to this License. You are not responsible
-for enforcing compliance by third parties with this License.
-
- An "entity transaction" is a transaction transferring control of an
-organization, or substantially all assets of one, or subdividing an
-organization, or merging organizations. If propagation of a covered
-work results from an entity transaction, each party to that
-transaction who receives a copy of the work also receives whatever
-licenses to the work the party's predecessor in interest had or could
-give under the previous paragraph, plus a right to possession of the
-Corresponding Source of the work from the predecessor in interest, if
-the predecessor has it or can get it with reasonable efforts.
-
- You may not impose any further restrictions on the exercise of the
-rights granted or affirmed under this License. For example, you may
-not impose a license fee, royalty, or other charge for exercise of
-rights granted under this License, and you may not initiate litigation
-(including a cross-claim or counterclaim in a lawsuit) alleging that
-any patent claim is infringed by making, using, selling, offering for
-sale, or importing the Program or any portion of it.
-
- 11. Patents.
-
- A "contributor" is a copyright holder who authorizes use under this
-License of the Program or a work on which the Program is based. The
-work thus licensed is called the contributor's "contributor version".
-
- A contributor's "essential patent claims" are all patent claims
-owned or controlled by the contributor, whether already acquired or
-hereafter acquired, that would be infringed by some manner, permitted
-by this License, of making, using, or selling its contributor version,
-but do not include claims that would be infringed only as a
-consequence of further modification of the contributor version. For
-purposes of this definition, "control" includes the right to grant
-patent sublicenses in a manner consistent with the requirements of
-this License.
-
- Each contributor grants you a non-exclusive, worldwide, royalty-free
-patent license under the contributor's essential patent claims, to
-make, use, sell, offer for sale, import and otherwise run, modify and
-propagate the contents of its contributor version.
-
- In the following three paragraphs, a "patent license" is any express
-agreement or commitment, however denominated, not to enforce a patent
-(such as an express permission to practice a patent or covenant not to
-sue for patent infringement). To "grant" such a patent license to a
-party means to make such an agreement or commitment not to enforce a
-patent against the party.
-
- If you convey a covered work, knowingly relying on a patent license,
-and the Corresponding Source of the work is not available for anyone
-to copy, free of charge and under the terms of this License, through a
-publicly available network server or other readily accessible means,
-then you must either (1) cause the Corresponding Source to be so
-available, or (2) arrange to deprive yourself of the benefit of the
-patent license for this particular work, or (3) arrange, in a manner
-consistent with the requirements of this License, to extend the patent
-license to downstream recipients. "Knowingly relying" means you have
-actual knowledge that, but for the patent license, your conveying the
-covered work in a country, or your recipient's use of the covered work
-in a country, would infringe one or more identifiable patents in that
-country that you have reason to believe are valid.
-
- If, pursuant to or in connection with a single transaction or
-arrangement, you convey, or propagate by procuring conveyance of, a
-covered work, and grant a patent license to some of the parties
-receiving the covered work authorizing them to use, propagate, modify
-or convey a specific copy of the covered work, then the patent license
-you grant is automatically extended to all recipients of the covered
-work and works based on it.
-
- A patent license is "discriminatory" if it does not include within
-the scope of its coverage, prohibits the exercise of, or is
-conditioned on the non-exercise of one or more of the rights that are
-specifically granted under this License. You may not convey a covered
-work if you are a party to an arrangement with a third party that is
-in the business of distributing software, under which you make payment
-to the third party based on the extent of your activity of conveying
-the work, and under which the third party grants, to any of the
-parties who would receive the covered work from you, a discriminatory
-patent license (a) in connection with copies of the covered work
-conveyed by you (or copies made from those copies), or (b) primarily
-for and in connection with specific products or compilations that
-contain the covered work, unless you entered into that arrangement,
-or that patent license was granted, prior to 28 March 2007.
-
- Nothing in this License shall be construed as excluding or limiting
-any implied license or other defenses to infringement that may
-otherwise be available to you under applicable patent law.
-
- 12. No Surrender of Others' Freedom.
-
- If conditions are imposed on you (whether by court order, agreement or
-otherwise) that contradict the conditions of this License, they do not
-excuse you from the conditions of this License. If you cannot convey a
-covered work so as to satisfy simultaneously your obligations under this
-License and any other pertinent obligations, then as a consequence you may
-not convey it at all. For example, if you agree to terms that obligate you
-to collect a royalty for further conveying from those to whom you convey
-the Program, the only way you could satisfy both those terms and this
-License would be to refrain entirely from conveying the Program.
-
- 13. Remote Network Interaction; Use with the GNU General Public License.
-
- Notwithstanding any other provision of this License, if you modify the
-Program, your modified version must prominently offer all users
-interacting with it remotely through a computer network (if your version
-supports such interaction) an opportunity to receive the Corresponding
-Source of your version by providing access to the Corresponding Source
-from a network server at no charge, through some standard or customary
-means of facilitating copying of software. This Corresponding Source
-shall include the Corresponding Source for any work covered by version 3
-of the GNU General Public License that is incorporated pursuant to the
-following paragraph.
-
- Notwithstanding any other provision of this License, you have
-permission to link or combine any covered work with a work licensed
-under version 3 of the GNU General Public License into a single
-combined work, and to convey the resulting work. The terms of this
-License will continue to apply to the part which is the covered work,
-but the work with which it is combined will remain governed by version
-3 of the GNU General Public License.
-
- 14. Revised Versions of this License.
-
- The Free Software Foundation may publish revised and/or new versions of
-the GNU Affero General Public License from time to time. Such new versions
-will be similar in spirit to the present version, but may differ in detail to
-address new problems or concerns.
-
- Each version is given a distinguishing version number. If the
-Program specifies that a certain numbered version of the GNU Affero General
-Public License "or any later version" applies to it, you have the
-option of following the terms and conditions either of that numbered
-version or of any later version published by the Free Software
-Foundation. If the Program does not specify a version number of the
-GNU Affero General Public License, you may choose any version ever published
-by the Free Software Foundation.
-
- If the Program specifies that a proxy can decide which future
-versions of the GNU Affero General Public License can be used, that proxy's
-public statement of acceptance of a version permanently authorizes you
-to choose that version for the Program.
-
- Later license versions may give you additional or different
-permissions. However, no additional obligations are imposed on any
-author or copyright holder as a result of your choosing to follow a
-later version.
-
- 15. Disclaimer of Warranty.
-
- THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
-APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
-HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
-OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
-THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
-IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
-ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
-
- 16. Limitation of Liability.
-
- IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
-WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
-THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
-GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
-USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
-DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
-PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
-EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
-SUCH DAMAGES.
-
- 17. Interpretation of Sections 15 and 16.
-
- If the disclaimer of warranty and limitation of liability provided
-above cannot be given local legal effect according to their terms,
-reviewing courts shall apply local law that most closely approximates
-an absolute waiver of all civil liability in connection with the
-Program, unless a warranty or assumption of liability accompanies a
-copy of the Program in return for a fee.
-
- END OF TERMS AND CONDITIONS
-
- How to Apply These Terms to Your New Programs
-
- If you develop a new program, and you want it to be of the greatest
-possible use to the public, the best way to achieve this is to make it
-free software which everyone can redistribute and change under these terms.
-
- To do so, attach the following notices to the program. It is safest
-to attach them to the start of each source file to most effectively
-state the exclusion of warranty; and each file should have at least
-the "copyright" line and a pointer to where the full notice is found.
-
-
- Copyright (C)
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published
- by the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see .
-
-Also add information on how to contact you by electronic and paper mail.
-
- If your software can interact with users remotely through a computer
-network, you should also make sure that it provides a way for users to
-get its source. For example, if your program is a web application, its
-interface could display a "Source" link that leads users to an archive
-of the code. There are many ways you could offer source, and different
-solutions will be better for different programs; see section 13 for the
-specific requirements.
-
- You should also get your employer (if you work as a programmer) or school,
-if any, to sign a "copyright disclaimer" for the program, if necessary.
-For more information on this, and how to apply and follow the GNU AGPL, see
-.
\ No newline at end of file
diff --git a/unsloth/kernels/moe/grouped_gemm/__init__.py b/unsloth/kernels/moe/grouped_gemm/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/unsloth/kernels/moe/grouped_gemm/interface.py b/unsloth/kernels/moe/grouped_gemm/interface.py
deleted file mode 100644
index 5588458973..0000000000
--- a/unsloth/kernels/moe/grouped_gemm/interface.py
+++ /dev/null
@@ -1,1041 +0,0 @@
-# SPDX-License-Identifier: GNU Affero General Public License v3.0
-# Copyright 2023-present the Unsloth team. All rights reserved.
-
-import logging
-import warnings
-from dataclasses import asdict
-from unsloth import DEVICE_TYPE
-
-import torch
-import triton
-
-from .kernels.backward import (
- _autotuned_grouped_gemm_dW_kernel,
- _autotuned_grouped_gemm_dX_kernel,
- _grouped_gemm_dW_kernel,
- _grouped_gemm_dX_kernel,
-)
-from .kernels.forward import (
- _autotuned_grouped_gemm_forward_kernel,
- _grouped_gemm_forward_kernel,
-)
-from .kernels.tuning import (
- KernelConfigBackward_dW,
- KernelConfigBackward_dX,
- KernelConfigForward,
-)
-
-logger = logging.getLogger(__name__)
-# Set formatter to include timestamp, pathname and lineno
-formatter = logging.Formatter(
- "%(asctime)s::%(levelname)s,%(pathname)s:%(lineno)d:: %(message)s"
-)
-
-# Add console handler
-ch = logging.StreamHandler()
-ch.setFormatter(formatter)
-logger.addHandler(ch)
-
-
-# Precompute TMA support to avoid graph breaks
-# TMA requires both:
-# 1. NVIDIA GPU with capability >= 9 (Hopper+)
-# 2. Triton version with TMA API (make_tensor_descriptor or _experimental_make_tensor_descriptor)
-def _check_tma_support():
- if DEVICE_TYPE in ("xpu", "hip"):
- return False
- import triton.language as tl
-
- gpu_supports_tma = torch.cuda.get_device_capability()[0] >= 9
- # Check for both old experimental and new stable API names
- triton_has_tma_api = hasattr(tl, "make_tensor_descriptor") or hasattr(
- tl, "_experimental_make_tensor_descriptor"
- )
- return gpu_supports_tma and triton_has_tma_api
-
-
-_SUPPORTS_TMA = _check_tma_support()
-
-# Check if triton.set_allocator is available (Triton 3.0+)
-_HAS_SET_ALLOCATOR = hasattr(triton, "set_allocator")
-
-
-def supports_tma():
- return _SUPPORTS_TMA
-
-
-# Helper to support allow_in_graph
-try:
- from torch.compiler import allow_in_graph
-except ImportError:
- from torch._dynamo import allow_in_graph
-
-
-# Helper to detect if we're in tracing/compilation mode
-def _is_tracing(*tensors):
- """
- Check if tensors are fake tensors used during torch.compile tracing.
- During tracing, tensors are FakeTensor/FunctionalTensor and we can't run Triton kernels.
- During execution, tensors are real Tensors and we MUST run the kernels.
-
- NOTE: We do NOT use torch.compiler.is_compiling() because it returns True
- during both tracing AND execution. We only want to skip kernels during tracing
- when tensors are actually fake.
- """
- for t in tensors:
- name = type(t).__name__
- if name in ("FakeTensor", "FunctionalTensor", "FunctionalTensorWrapper"):
- return True
- return False
-
-
-_per_device_alloc_fns = {}
-
-
-def get_per_device_per_stream_alloc_fn(device):
- if device not in _per_device_alloc_fns:
- _per_stream_tensors = {}
-
- def alloc_fn(size: int, alignment: int, stream):
- assert alignment == 128
- if (
- stream not in _per_stream_tensors
- or _per_stream_tensors[stream].numel() < size
- ):
- _per_stream_tensors[stream] = torch.empty(
- size, device = device, dtype = torch.int8
- )
- _per_stream_tensors[stream].__hibernate__ = {"type": "ignore"}
- return _per_stream_tensors[stream]
-
- _per_device_alloc_fns[device] = alloc_fn
- return _per_device_alloc_fns[device]
-
-
-def log_kernel_info(
- compiled_kernel: triton.compiler.CompiledKernel, best_config: triton.Config = None
-):
- kernel_name = compiled_kernel.name
- nregs = compiled_kernel.n_regs
- nspills = compiled_kernel.n_spills
- metadata = compiled_kernel.metadata
- logger.debug(
- f"{kernel_name}: n_regs={nregs} n_spills={nspills} metadata={metadata}"
- )
- if best_config is not None:
- logger.debug(f"{kernel_name} autotuned best_config: {best_config}")
-
-
-@allow_in_graph
-def grouped_gemm_forward(
- X: torch.Tensor,
- W: torch.Tensor,
- topk: int,
- m_sizes: torch.Tensor,
- gather_indices: torch.Tensor = None,
- topk_weights: torch.Tensor = None,
- # Fusions
- permute_x: bool = False,
- permute_y: bool = False,
- fuse_mul_post: bool = False,
- # Autotuning - manual kernel params will be ignored if autotune is True
- autotune: bool = False,
- # Kernel tuning params if not autotuning -- NOTE: these params need to be tuned, otherwise performance will be poor
- BLOCK_SIZE_M: int = 32,
- BLOCK_SIZE_N: int = 32,
- BLOCK_SIZE_K: int = 32,
- num_warps: int = 4,
- num_stages: int = 2,
- use_tma_load_w: bool = False,
- use_tma_load_x: bool = False,
- use_tma_store: bool = False,
- # software pipelining -- set to True for now, won't impact until loop is re-written
- flatten: bool = True,
- # debugging
- debug: bool = False,
-) -> torch.Tensor:
- """
- Grouped GEMM forward pass for MoE MLPs.
-
- The implementation offers a number of fusions specific to MoE:
- - `permute_x`: fuse the permutation of hidden states from token order (original order) to grouped expert order, typically only needed for the first grouped GEMM in an MoE MLP.
- - When `permute_x` is True, `X` is expected to be of shape (num_tokens, K).
- - When `permute_x` is False, `X` is expected to be of shape (total_tokens, K) where `total_tokens = num_tokens * topk` AND already permuted to grouped expert order, i.e., hidden states are sorted such that tokens assigned to each expert are contiguous.
- - `permute_y`: fused the permutation of the output from expert grouped order back to original token order, typically only needed for the second grouped GEMM in an MoE MLP.
- - `fuse_mul_pre`: fuse the multiplication of the routed input with topk_weights, only done in the first grouped GEMM in an MoE MLP as for Llama4. Do not use, since results in performance regression as it interrupts the GEMM mainloop.
- - `fuse_mul_post`: fuse the multiplication of the routed output with topk_weights, used only when `permute_y` is True. NOTE: this should only be used when using this kernel for inference, not for training.
-
- X: (M, K) hidden states where M is the num_tokens if `permute_x` is True, otherwise `total_tokens` where `total_tokens = num_tokens * topk`.
- W: (E, N, K) expert weights, where E is number of experts, N in the intermediate (output) dim, and K is the reduction dim
- m_sizes: tokens assigned to each expert which correspond to the size of M in the respective GEMMs in the grouped GEMM.
- gather_indices: (total_tokens,) indices of tokens assigned to each expert. E.g., slicing gather_indices by cumsum of m_sizes gives the indices of tokens assigned to each expert.
- topk_weights: (total_tokens,) weights to multiply routed output by in expert MLP calculation, used only when `fuse_mul_post` is True (see note on `fuse_mul_post`).
- use_fast_accum: currently unused; trade off faster accumulation dtype in GEMM for less precision.
- use_tma_load_x: use TMA for loading activations, incompatible with permute_x. TODO: add TMA gather / scatter support for Blackwell+.
- use_tma_load_w: use TMA for loading weights. If TMA supported, this should always be enabled as it is faster than global memory load.
- use_tma_store: use TMA for storing output, incompatible with permute_y. TODO: add TMA scatter support for Blackwell+.
-
- Returns:
- y: (total_tokens, N) output of grouped GEMM
- """
-
- assert X.device.type == "cuda", "X and W must be on CUDA"
- assert m_sizes.device.type == "cuda", "m_sizes must be on CUDA"
-
- X = X.contiguous()
- W = W.contiguous()
- m_sizes = m_sizes.contiguous()
-
- # Preconditions
- assert not (permute_x and permute_y), "Cannot permute both X and Y"
- assert not (permute_y and use_tma_store), "Cannot use both TMA store and permute_y"
-
- if use_tma_load_x:
- # TMA load for activations, TMA gather only supported on Blackwell+
- assert not permute_x, "Cannot use both use_tma_load_x and permute_x"
-
- use_tma = use_tma_load_w or use_tma_load_x or use_tma_store
- if not supports_tma() and use_tma:
- warnings.warn("TMA not supported, tma_load will be set to False")
- use_tma_load_w = False
- use_tma_load_x = False
- use_tma_store = False
-
- if use_tma or autotune:
- # Respect global persistent allocator if set
- if _HAS_SET_ALLOCATOR and not getattr(triton, "_unsloth_allocator_set", False):
-
- def alloc_fn(size: int, alignment: int, stream: int):
- return torch.empty(size, device = "cuda", dtype = torch.int8)
-
- triton.set_allocator(alloc_fn)
-
- if W.ndim == 3:
- num_experts = W.shape[0]
- N = W.shape[1]
- # K = W.shape[2]
- else:
- num_experts = m_sizes.shape[0]
- N = W.shape[0] // num_experts
-
- X = X.view(-1, X.shape[-1])
- W = W.view(-1, W.shape[-1])
-
- if permute_x or permute_y:
- assert (
- gather_indices is not None
- ), "gather_indices must be provided when permute_x or permute_y is True"
- assert gather_indices.is_contiguous()
- assert gather_indices.device.type == "cuda"
- assert gather_indices.ndim == 1
- total_tokens = gather_indices.shape[0]
- num_tokens = total_tokens // topk
- if permute_x:
- assert (
- X.shape[0] == num_tokens
- ), f"X.shape[0] ({X.shape[0]}) must match num_tokens ({num_tokens})"
- else:
- assert (
- X.shape[0] == total_tokens
- ), f"X.shape[0] ({X.shape[0]}) must match total_tokens ({total_tokens})"
- else:
- total_tokens = X.shape[0]
- num_tokens = total_tokens // topk
-
- _, K = X.shape
- assert K == W.shape[1], f"K ({K}) must match W.shape[1] ({W.shape[1]})"
-
- if fuse_mul_post:
- global _FUSED_MUL_WARN
- if not _FUSED_MUL_WARN:
- warnings.warn(
- "fused_mul should only be used for inference, not for training"
- )
- _FUSED_MUL_WARN = True
- assert permute_y, "FUSE_MUL requires PERMUTE_Y"
- assert topk_weights is not None
- assert topk_weights.numel() == total_tokens
- assert topk_weights.device.type == "cuda"
- assert topk_weights.is_contiguous()
- topk_weights = topk_weights.view(-1)
- if debug:
- print(
- f"DEBUG::GROUPED_GEMM {topk_weights.tolist()} {gather_indices.tolist()}"
- )
-
- y = torch.empty((total_tokens, N), device = X.device, dtype = X.dtype)
- # if total_tokens == 0 or N == 0:
- # return y
-
- NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
-
- def grid(META):
- return (NUM_SMS,)
-
- if not autotune:
- # BLOCK_SIZE_K = min(K, BLOCK_SIZE_K)
- # BLOCK_SIZE_N = min(N, BLOCK_SIZE_N)
- pass
-
- if debug:
- print(
- f"DEBUG::GROUPED_GEMM {num_tokens = } {topk = } {num_experts = } {N = } {K = } {BLOCK_SIZE_M = } {BLOCK_SIZE_N = } {BLOCK_SIZE_K = } {permute_x = }"
- )
- print(
- f"DEBUG::GROUPED_GEMM {m_sizes.tolist()} {(gather_indices // topk).tolist()}"
- )
-
- kernel_args = {
- # Inputs
- "x_ptr": X,
- "w_ptr": W,
- "m_sizes_ptr": m_sizes,
- "gather_indices_ptr": gather_indices,
- "topk_weights_ptr": topk_weights,
- # Output
- "y_ptr": y,
- # Problem shapes
- "NUM_TOKENS": num_tokens,
- "NUM_EXPERTS": num_experts,
- "TOPK": topk,
- "N": N,
- "K": K,
- "NUM_SMS": NUM_SMS,
- # Gather / Scatter
- "PERMUTE_X": permute_x,
- "PERMUTE_Y": permute_y,
- # TopK weight merging
- "FUSE_MUL_POST": fuse_mul_post,
- # Loop pipelining
- "FLATTEN": flatten,
- }
- if not autotune:
- kernel_args.update(
- {
- "USE_TMA_LOAD_W": use_tma_load_w,
- "USE_TMA_LOAD_X": use_tma_load_x,
- "USE_TMA_STORE": use_tma_store,
- "BLOCK_SIZE_M": BLOCK_SIZE_M,
- "BLOCK_SIZE_N": BLOCK_SIZE_N,
- "BLOCK_SIZE_K": BLOCK_SIZE_K,
- "num_warps": num_warps,
- "num_stages": num_stages,
- }
- )
-
- kernel = (
- _autotuned_grouped_gemm_forward_kernel
- if autotune
- else _grouped_gemm_forward_kernel
- )
-
- is_fake = _is_tracing(X, W)
- if not is_fake:
- compiled_kernel: triton.compiler.CompiledKernel = kernel[grid](**kernel_args)
- if autotune:
- log_kernel_info(compiled_kernel, kernel.best_config)
- else:
- log_kernel_info(compiled_kernel)
-
- return y
-
-
-@allow_in_graph
-def grouped_gemm_dX(
- dY: torch.Tensor,
- W: torch.Tensor,
- gather_indices: torch.Tensor,
- m_sizes: torch.Tensor,
- topk: int,
- BLOCK_SIZE_M: int = 32,
- BLOCK_SIZE_N: int = 32,
- BLOCK_SIZE_K: int = 32,
- debug: bool = False,
- permute_x: bool = False,
- permute_y: bool = False,
- use_tma_load_w: bool = False,
- use_tma_load_dy: bool = False,
- use_tma_store: bool = False,
- num_warps: int = 4,
- num_stages: int = 2,
- flatten: bool = True,
- fuse_mul_pre: bool = False,
- fuse_mul_post: bool = False,
- autotune: bool = False,
-) -> torch.Tensor:
- """
- dX backward kernel
- grad_output: (M, N)
- gather_indices: (total_tokens,), indices of tokens assigned to each expert. E.g., slicing gather_indices by cumsum of m_sizes gives the indices of tokens assigned to each expert.
- m_sizes: tokens assigned to each expert which correspond to the size of M in the respective GEMMs in the grouped GEMM.
- topk: number of experts chosen per token.
- `permute_x`: whether X was permuted on load in the forward pass, typically only used for the first grouped GEMM in an MoE MLP to group tokens by expert.
- - In the forward pass, if we permuted X on load, we need to permute store in the backward pass
- - Shapes
- - the forward pass input X shape is [NUM_TOKENS, K], reduce across K, output y is [NUM_TOKENS * TOPK, K]
- - the backward pass input dy shape is [NUM_TOKENS * TOPK, N], reduce across N, output dX is [NUM_TOKENS * TOPK, K]
- - Note that in the backward pass, the output size is still [NUM_TOKENS * TOPK, K] since we still need to accumulate gradients for each expert chosen by the token in a post-processing step.
- `permute_y`: whether the output was permuted on store in the forward pass, typically only used for the second grouped GEMM in an MoE MLP to restore to the original token order.
- - In the forward pass, if we permuted output on store (e.g., in the second grouped GEMM in fused MoE MLP), we need to permute on load to get from token order to expert grouped order
- - We still store in contiguous order since we are writing out dX which will be the input to the backwards pass of the first grouped GEMM
- `fuse_mul_{pre,post}`: always set to False since this should only be used for inference.
- use_tma_load_dy: use TMA for loading dy. use_tma_load_dy is incompatible with permute_y. TODO: add TMA gather / scatter support for Blackwell+ which will enable permute_y and use_tma_load_dy.
- use_tma_load_w: use TMA for loading weights. If TMA supported, this should always be enabled as it is faster than global memory load.
- use_tma_store: use TMA for storing dX. Incompatible with permute_x. TODO: add TMA gather / scatter support for Blackwell+ which will enable permute_x and use_tma_store.
- """
- assert (
- not fuse_mul_pre
- ), "fuse_mul_pre should only be used for inference, not for training"
- assert (
- not fuse_mul_post
- ), "fuse_mul_post should only be used for inference, not for training"
- assert dY.is_contiguous()
- assert W.is_contiguous()
- assert m_sizes.is_contiguous()
- assert m_sizes.ndim == 1
-
- # Preconditions
- assert not (permute_x and permute_y), "Cannot permute both X and Y"
- # Note that this is flipped from the forward pass
- # If we permuted y in the forward, we need to permute on load in the backward
- assert not (permute_y and use_tma_load_dy), "Cannot use both TMA load and permute_y"
- assert not (permute_x and use_tma_store), "Cannot use both TMA store and permute_x"
-
- use_tma = use_tma_load_dy or use_tma_load_w or use_tma_store
- if not supports_tma() and use_tma:
- warnings.warn("TMA not supported, tma_load will be set to False")
- use_tma_load_w = False
- use_tma_load_dy = False
- use_tma_store = False
-
- if use_tma or autotune:
- # Respect global persistent allocator if set
- if _HAS_SET_ALLOCATOR and not getattr(triton, "_unsloth_allocator_set", False):
-
- def alloc_fn(size: int, alignment: int, stream: int):
- # print(f"DEBUG::GROUPED_GEMM alloc_fn {size=} {alignment=} {stream=}")
- return torch.empty(size, device = "cuda", dtype = torch.int8)
-
- triton.set_allocator(alloc_fn)
-
- if W.ndim == 3:
- num_experts = W.shape[0]
- N = W.shape[1]
- else:
- num_experts = m_sizes.shape[0]
- N = W.shape[0] // num_experts
-
- dY = dY.view(-1, dY.shape[-1])
- W = W.view(-1, W.shape[-1])
-
- M_total, N_grad = dY.shape
- N_total, K = W.shape
- # N = N_total // num_experts
- assert N_grad == N, f"Grad_output N ({N_grad}) must match weight N ({N})"
-
- assert (
- M_total % topk == 0
- ), f"M_total ({M_total}) must be divisible by topk ({topk})"
- num_tokens = M_total // topk
-
- total_tokens = gather_indices.shape[0]
- assert (
- total_tokens == M_total
- ), f"Total tokens ({total_tokens}) must match M_total ({M_total})"
-
- # Note that the output shape is [NUM_TOKENS * TOPK, K] even when `permute_x` is True since we need to accumulate gradients across all experts chosen by the token.
- # This will be done in a post-processing step reduction step.
- output_shape = (total_tokens, K)
- dX = torch.zeros(output_shape, device = dY.device, dtype = dY.dtype)
-
- NUM_SMS = torch.cuda.get_device_properties(
- "cuda"
- ).multi_processor_count # if not debug else 1
-
- def grid(META):
- return (NUM_SMS,)
-
- if not autotune:
- # BLOCK_SIZE_N = min(N_grad, BLOCK_SIZE_N)
- # BLOCK_SIZE_K = min(K, BLOCK_SIZE_K)
- pass
-
- if debug:
- print(
- f"DEBUG::GROUPED_GEMM {num_tokens = } {topk = } {output_shape = } {num_experts = } {N = } {K = } {BLOCK_SIZE_M = } {BLOCK_SIZE_N = } {BLOCK_SIZE_K = } {NUM_SMS = }"
- )
- print(f"DEBUG::GROUPED_GEMM {m_sizes.tolist()}")
-
- kernel_args = {
- # Inputs
- "dY_ptr": dY,
- "w_ptr": W,
- "gather_indices_ptr": gather_indices,
- "m_sizes_ptr": m_sizes,
- # Output
- "dX_ptr": dX,
- # Problem sizes
- "NUM_EXPERTS": num_experts,
- "NUM_TOKENS": num_tokens,
- "TOPK": topk,
- "N": N,
- "K": K,
- "NUM_SMS": NUM_SMS,
- # Gather / Scatter
- "PERMUTE_X": permute_x,
- "PERMUTE_Y": permute_y,
- "FLATTEN": flatten,
- }
- if not autotune:
- kernel_args.update(
- {
- "BLOCK_SIZE_M": BLOCK_SIZE_M,
- "BLOCK_SIZE_N": BLOCK_SIZE_N,
- "BLOCK_SIZE_K": BLOCK_SIZE_K,
- "num_warps": num_warps,
- "num_stages": num_stages,
- "USE_TMA_LOAD_dY": use_tma_load_dy,
- "USE_TMA_LOAD_W": use_tma_load_w,
- "USE_TMA_STORE": use_tma_store,
- }
- )
- kernel = _autotuned_grouped_gemm_dX_kernel if autotune else _grouped_gemm_dX_kernel
-
- is_fake = _is_tracing(dY, W)
- if not is_fake:
- compiled_kernel: triton.compiler.CompiledKernel = kernel[grid](**kernel_args)
-
- if autotune:
- log_kernel_info(compiled_kernel, kernel.best_config)
- else:
- log_kernel_info(compiled_kernel)
- return dX
-
-
-@allow_in_graph
-def grouped_gemm_dW(
- X: torch.Tensor,
- dY: torch.Tensor,
- m_sizes: torch.Tensor,
- gather_indices: torch.Tensor,
- topk: int,
- BLOCK_SIZE_M: int = 32,
- BLOCK_SIZE_N: int = 32,
- BLOCK_SIZE_K: int = 32,
- permute_x: bool = False,
- permute_y: bool = False,
- use_tma_load_dy: bool = False,
- use_tma_load_x: bool = False,
- use_tma_store: bool = False,
- fuse_mul_pre: bool = False,
- fuse_mul_post: bool = False,
- num_warps: int = 4,
- num_stages: int = 2,
- flatten: bool = True,
- autotune: bool = False,
- debug: bool = False,
-) -> torch.Tensor:
- """
- X: (M, K) hidden states where M is the num_tokens if `permute_x` is True, otherwise `total_tokens` where `total_tokens = num_tokens * topk`.
- dY: (M, N)
- topk: number of experts to choose per token.
- m_sizes: tokens assigned to each expert which correspond to the size of M in the respective GEMMs in the grouped GEMM.
- gather_indices: (total_tokens,) indices of tokens assigned to each expert. E.g., slicing gather_indices by cumsum of m_sizes gives the indices of tokens assigned to each expert.
- permute_x: whether X was permuted on load in the forward pass, typically only used for the first grouped GEMM in an MoE MLP to group tokens by expert.
- - for the first grouped GEMM, we permuted on load -> X was [num_tokens, K] and stored y in expert grouped order [num_tokens * topk, K]
- - in the backwards pass, we need to permute on load of X while loading dy in contiguous (expert grouped) order
- - since we are writing out dW, there is no need to permute on store
- permute_y: whether the output was permuted on store in the forward pass, typically only used for the second grouped GEMM in an MoE MLP to restore to the original token order.
- - for the second grouped GEMM, we permuted on store -> y was permuted from expert grouped order to token order while X was loaded in expert grouped order since it was the output of the first grouped GEMM
- - in the backwards pass, we need to permute on load of dy to get from token order to expert grouped order to match the order of X
- - since we are writing out dW, there is no need to permute on store
- use_tma_load_dy: use TMA for loading dy. use_tma_load_dy is incompatible with permute_y. TODO: add TMA gather / scatter support for Blackwell+ which will enable permute_y and use_tma_load_dy.
- use_tma_load_x: use TMA for loading x. use_tma_load_x is incompatible with permute_x. TODO: add TMA gather / scatter support for Blackwell+ which will enable permute_x and use_tma_load_x.
- use_tma_store: use TMA for storing dW. If TMA supported, this should always be enabled as it is faster than global memory store.
- """
- assert not fuse_mul_pre, "fuse_mul_pre not supported"
- assert not fuse_mul_post, "fuse_mul_post not supported"
- NUM_SMS = (
- torch.cuda.get_device_properties("cuda").multi_processor_count
- if not debug
- else 1
- )
- X = X.view(-1, X.shape[-1]).contiguous()
- dY = dY.contiguous()
- m_sizes = m_sizes.contiguous()
-
- # Preconditions
- assert not (permute_x and permute_y), "Cannot permute both X and Y"
- assert not (permute_y and use_tma_load_dy), "Cannot use both TMA load and permute_y"
- assert not (permute_x and use_tma_load_x), "Cannot use both TMA load and permute_x"
-
- use_tma = use_tma_load_dy or use_tma_load_x or use_tma_store
- if not supports_tma() and use_tma:
- warnings.warn("TMA not supported, tma_load will be set to False")
- use_tma_load_x = False
- use_tma_load_dy = False
- use_tma_store = False
-
- if use_tma or autotune:
- # Respect global persistent allocator if set
- if _HAS_SET_ALLOCATOR and not getattr(triton, "_unsloth_allocator_set", False):
-
- def alloc_fn(size: int, alignment: int, stream: int):
- return torch.empty(size, device = "cuda", dtype = torch.int8)
-
- triton.set_allocator(alloc_fn)
-
- if permute_x or permute_y:
- assert gather_indices is not None
- assert gather_indices.is_contiguous()
- assert gather_indices.device.type == "cuda"
- assert gather_indices.ndim == 1
- total_tokens = gather_indices.shape[0]
- num_tokens = total_tokens // topk
- if permute_x:
- assert X.shape[0] == num_tokens
- else:
- assert X.shape[0] == total_tokens
- else:
- total_tokens = X.shape[0]
- num_tokens = total_tokens // topk
-
- num_experts = m_sizes.shape[0]
- # Get dimensions
- _, K = X.shape
- M_grad, N = dY.shape
-
- assert M_grad == total_tokens, f"dY M ({M_grad}) != total_tokens ({total_tokens})"
-
- dW = torch.zeros((num_experts, N, K), device = X.device, dtype = X.dtype)
-
- if not autotune:
- # BLOCK_SIZE_N = min(N, BLOCK_SIZE_N)
- # BLOCK_SIZE_K = min(K, BLOCK_SIZE_K)
- pass
-
- def grid(META):
- return (NUM_SMS,)
-
- if debug:
- print(
- f"DEBUG::GROUPED_GEMM_DW_TMA {num_experts = } {N = } {K = } {BLOCK_SIZE_M = } {BLOCK_SIZE_N = } {BLOCK_SIZE_K = } {NUM_SMS = }"
- )
-
- print(f"DEBUG::GROUPED_GEMM_DW_TMA {m_sizes.tolist() = }")
- print(f"DEBUG::GROUPED_GEMM_DW_TMA {gather_indices.tolist() = }")
- m_start = 0
- for i in range(num_experts):
- expert_token_idx = gather_indices[m_start : m_start + m_sizes[i]]
- t_start = 0
- while t_start < m_sizes[i]:
- token_idx = expert_token_idx[t_start : t_start + BLOCK_SIZE_M]
- if permute_x:
- token_idx = token_idx // topk
- print(
- f"DEBUG::GROUPED_GEMM_DW_TMA Token expert {i} indices: {token_idx.tolist()}"
- )
- t_start += BLOCK_SIZE_M
-
- m_start += m_sizes[i]
-
- kernel_args = {
- # Inputs
- "x_ptr": X,
- "dY_ptr": dY,
- "m_sizes_ptr": m_sizes,
- "gather_indices_ptr": gather_indices,
- # Output
- "dW_ptr": dW,
- # Problem sizes
- "NUM_TOKENS": num_tokens,
- "TOPK": topk,
- "NUM_EXPERTS": num_experts,
- "N": N,
- "K": K,
- "NUM_SMS": NUM_SMS,
- # Gather / Scatter
- "PERMUTE_X": permute_x,
- "PERMUTE_Y": permute_y,
- # Loop pipelining
- "FLATTEN": flatten,
- }
-
- if not autotune:
- kernel_args.update(
- {
- "BLOCK_SIZE_M": BLOCK_SIZE_M,
- "BLOCK_SIZE_N": BLOCK_SIZE_N,
- "BLOCK_SIZE_K": BLOCK_SIZE_K,
- "USE_TMA_LOAD_dY": use_tma_load_dy,
- "USE_TMA_LOAD_X": use_tma_load_x,
- "USE_TMA_STORE": use_tma_store,
- "num_warps": num_warps,
- "num_stages": num_stages,
- }
- )
-
- kernel = _autotuned_grouped_gemm_dW_kernel if autotune else _grouped_gemm_dW_kernel
-
- is_fake = _is_tracing(X, dY)
- if not is_fake:
- compiled_kernel: triton.compiler.CompiledKernel = kernel[grid](**kernel_args)
-
- if autotune:
- log_kernel_info(compiled_kernel, kernel.best_config)
- else:
- log_kernel_info(compiled_kernel)
-
- return dW
-
-
-class GroupedGemm(torch.autograd.Function):
- @staticmethod
- def forward(
- ctx,
- X,
- W,
- m_sizes,
- topk,
- gather_indices,
- permute_x,
- permute_y,
- topk_weights,
- fuse_mul_post,
- kernel_config_fwd,
- kernel_config_bwd_dX,
- kernel_config_bwd_dW,
- autotune,
- dX_only,
- dW_only,
- ):
- ctx.topk = topk
- ctx.permute_x = permute_x
- ctx.permute_y = permute_y
- ctx.fuse_mul_post = fuse_mul_post
- ctx.kernel_config_fwd = kernel_config_fwd
- ctx.kernel_config_bwd_dX = kernel_config_bwd_dX
- ctx.kernel_config_bwd_dW = kernel_config_bwd_dW
- ctx.autotune = autotune
- ctx.dX_only = dX_only
- ctx.dW_only = dW_only
-
- # NOTE: we don't save topk_weights for backward since we do not support training with fused_mul
- ctx.save_for_backward(X, W, m_sizes, gather_indices)
-
- fwd_config = {}
- if kernel_config_fwd is not None:
- fwd_config["BLOCK_SIZE_M"] = kernel_config_fwd.BLOCK_SIZE_M
- fwd_config["BLOCK_SIZE_N"] = kernel_config_fwd.BLOCK_SIZE_N
- fwd_config["BLOCK_SIZE_K"] = kernel_config_fwd.BLOCK_SIZE_K
- fwd_config["num_warps"] = kernel_config_fwd.num_warps
- fwd_config["num_stages"] = kernel_config_fwd.num_stages
- fwd_config["use_tma_load_x"] = kernel_config_fwd.use_tma_load_x
- fwd_config["use_tma_load_w"] = kernel_config_fwd.use_tma_load_w
- fwd_config["use_tma_store"] = kernel_config_fwd.use_tma_store
-
- return grouped_gemm_forward(
- X = X,
- W = W,
- topk = topk,
- m_sizes = m_sizes,
- gather_indices = gather_indices,
- topk_weights = topk_weights,
- permute_x = permute_x,
- permute_y = permute_y,
- fuse_mul_post = fuse_mul_post,
- # Autotune -- this will override the manual kernel config if true
- autotune = autotune,
- # Manual kernel config
- **fwd_config,
- )
-
- @staticmethod
- def backward(ctx, dY):
- dY = dY.contiguous()
- X, W, m_sizes, gather_indices = ctx.saved_tensors
- topk = ctx.topk
- permute_x = ctx.permute_x
- permute_y = ctx.permute_y
- fuse_mul_post = ctx.fuse_mul_post
- kernel_config_bwd_dX = ctx.kernel_config_bwd_dX
- kernel_config_bwd_dW = ctx.kernel_config_bwd_dW
- autotune = ctx.autotune
- dX_only = ctx.dX_only
- dW_only = ctx.dW_only
-
- if not autotune:
- if not dW_only:
- assert (
- kernel_config_bwd_dX is not None
- ), "kernel_config_bwd_dX must be provided if autotune is False"
- if not dX_only:
- assert (
- kernel_config_bwd_dW is not None
- ), "kernel_config_bwd_dW must be provided if autotune is False"
-
- assert (
- not fuse_mul_post
- ), "fused_mul should only be used for inference, not for training"
-
- if not dX_only:
- bwd_dW_config = {}
-
- if kernel_config_bwd_dW is not None:
- bwd_dW_config["use_tma_load_dy"] = kernel_config_bwd_dW.use_tma_load_dy
- bwd_dW_config["use_tma_load_x"] = kernel_config_bwd_dW.use_tma_load_x
- bwd_dW_config["use_tma_store"] = kernel_config_bwd_dW.use_tma_store
- bwd_dW_config["BLOCK_SIZE_M"] = kernel_config_bwd_dW.BLOCK_SIZE_M
- bwd_dW_config["BLOCK_SIZE_N"] = kernel_config_bwd_dW.BLOCK_SIZE_N
- bwd_dW_config["BLOCK_SIZE_K"] = kernel_config_bwd_dW.BLOCK_SIZE_K
- bwd_dW_config["num_warps"] = kernel_config_bwd_dW.num_warps
- bwd_dW_config["num_stages"] = kernel_config_bwd_dW.num_stages
-
- dW = grouped_gemm_dW(
- X = X,
- dY = dY,
- m_sizes = m_sizes,
- gather_indices = gather_indices,
- topk = topk,
- permute_x = permute_x,
- permute_y = permute_y,
- # Autotune -- this will override the manual kernel config if true
- autotune = autotune,
- # Manual kernel config
- **bwd_dW_config,
- )
- else:
- dW = None
-
- if not dW_only:
- bwd_dX_config = {}
- if kernel_config_bwd_dX is not None:
- bwd_dX_config["use_tma_load_dy"] = kernel_config_bwd_dX.use_tma_load_dy
- bwd_dX_config["use_tma_load_w"] = kernel_config_bwd_dX.use_tma_load_w
- bwd_dX_config["use_tma_store"] = kernel_config_bwd_dX.use_tma_store
- bwd_dX_config["BLOCK_SIZE_M"] = kernel_config_bwd_dX.BLOCK_SIZE_M
- bwd_dX_config["BLOCK_SIZE_N"] = kernel_config_bwd_dX.BLOCK_SIZE_N
- bwd_dX_config["BLOCK_SIZE_K"] = kernel_config_bwd_dX.BLOCK_SIZE_K
- bwd_dX_config["num_warps"] = kernel_config_bwd_dX.num_warps
- bwd_dX_config["num_stages"] = kernel_config_bwd_dX.num_stages
-
- dX = grouped_gemm_dX(
- dY = dY,
- W = W,
- m_sizes = m_sizes,
- gather_indices = gather_indices,
- topk = topk,
- permute_x = permute_x,
- permute_y = permute_y,
- # Autotune -- this will override the manual kernel config if true
- autotune = autotune,
- # Manual kernel config
- **bwd_dX_config,
- )
-
- if topk > 1 and permute_x:
- dX = dX.view(X.shape[0], topk, -1).sum(dim = 1)
- else:
- dX = None
-
- return (
- dX,
- dW,
- None, # m_sizes
- None, # gather_indices
- None, # topk
- None, # permute_x
- None, # permute_y
- None, # topk_weights
- None, # fuse_mul_post
- None, # kernel_config_fwd
- None, # kernel_config_bwd_dX
- None, # kernel_config_bwd_dW
- None, # autotune
- None, # dX_only
- None, # dW_only
- )
-
-
-def check_valid_config_fwd(
- permute_x,
- permute_y,
- use_tma_load_x,
- use_tma_load_w,
- use_tma_store,
- fuse_mul_post,
- is_first_gemm,
-):
- """
- Check if the configuration is valid for the forward pass.
- """
- is_second_gemm = not is_first_gemm
-
- assert not (permute_x and permute_y), "Cannot permute both X and Y"
- assert not (
- is_second_gemm and permute_x
- ), "Cannot permute X for the second grouped GEMM"
- assert not (
- is_first_gemm and permute_y
- ), "Cannot permute Y for the first grouped GEMM"
- assert not (
- fuse_mul_post and is_first_gemm
- ), "Cannot fuse mul for the first grouped GEMM"
- assert not (
- use_tma_load_x and permute_x
- ), "Cannot use TMA load and permute X unless on sm100+ (Blackwell+)"
- assert not (
- use_tma_store and permute_y and is_second_gemm
- ), "Cannot use TMA store and permute Y for the second grouped GEMM unless on sm100+ (Blackwell+)"
-
-
-def check_valid_config_bwd_dW(
- permute_x,
- permute_y,
- use_tma_load_dY,
- use_tma_load_x,
- use_tma_store,
- fuse_mul_post,
- is_first_gemm,
-):
- """
- Check if the configuration is valid for the backward pass of dW.
- """
- is_second_gemm = not is_first_gemm
- if fuse_mul_post:
- assert False, "Cannot fuse_mul is not supported for backward pass"
- if is_second_gemm and permute_y and use_tma_load_dY:
- assert False, "Cannot use TMA load and permute Y for the second grouped GEMM"
- if is_first_gemm and permute_x and use_tma_load_x:
- assert False, "Cannot use TMA load and permute X for the first grouped GEMM"
-
-
-def check_valid_config_bwd_dX(
- permute_x,
- permute_y,
- use_tma_load_dY,
- use_tma_load_w,
- use_tma_store,
- fuse_mul_post,
- is_first_gemm,
-):
- """
- Check if the configuration is valid for the backward pass of dW.
- """
- is_second_gemm = not is_first_gemm
- if fuse_mul_post:
- assert False, "Cannot fuse_mul is not supported for backward pass"
- if is_second_gemm and permute_y and use_tma_load_dY:
- assert False, "Cannot use TMA load and permute Y for the second grouped GEMM"
- if use_tma_store and permute_x and is_first_gemm:
- assert False, "Cannot use TMA store and permute X for the first grouped GEMM"
-
-
-def grouped_gemm(
- X: torch.Tensor,
- W: torch.Tensor,
- m_sizes: torch.Tensor,
- topk: int,
- gather_indices: torch.Tensor = None,
- permute_x: bool = False,
- permute_y: bool = False,
- topk_weights = None,
- fuse_mul_post = False,
- kernel_config_fwd: KernelConfigForward = None,
- kernel_config_bwd_dX: KernelConfigBackward_dX = None,
- kernel_config_bwd_dW: KernelConfigBackward_dW = None,
- autotune: bool = False,
- is_first_gemm: bool = True,
- # Only for debugging
- dX_only: bool = False,
- dW_only: bool = False,
-):
- """
- Grouped GEMM for MoE MLPs.
-
- The implementation offers a number of fusions specific to MoE:
- - `permute_x`: fuse the permutation of hidden states from token order (original order) to grouped expert order, typically only needed for the first grouped GEMM in an MoE MLP.
- - When `permute_x` is True, `X` is expected to be of shape (num_tokens, K).
- - When `permute_x` is False, `X` is expected to be of shape (total_tokens, K) where `total_tokens = num_tokens * topk` AND already permuted to grouped expert order, i.e., hidden states are sorted such that tokens assigned to each expert are contiguous.
- - `permute_y`: fused the permutation of the output from expert grouped order back to original token order, typically only needed for the second grouped GEMM in an MoE MLP.
- - `fuse_mul`: fuse the multiplication of the routed output with topk_weights, used only when `permute_y` is True. NOTE: this should only be used when using this kernel for inference, not for training.
-
- X: (M, K) hidden states where M is the num_tokens if `permute_x` is True, otherwise `total_tokens` where `total_tokens = num_tokens * topk`.
- W: (E, N, K) expert weights, where E is number of experts, N in the intermediate (output) dim, and K is the reduction dim
- m_sizes: tokens assigned to each expert which correspond to the size of M in the respective GEMMs in the grouped GEMM.
- gather_indices: (total_tokens,) indices of tokens assigned to each expert. E.g., slicing gather_indices by cumsum of m_sizes gives the indices of tokens assigned to each expert. Needed when either `permute_x` or `permute_y` is True.
- topk_weights: (total_tokens,) weights to multiply routed output by in expert MLP calculation, used only when `fuse_mul` is True (see note on `fuse_mul`).
- kernel_config_fwd: KernelConfigForward for forward pass.
- kernel_config_bwd_dX: KernelConfigBackward_dX for backward pass of dX.
- kernel_config_bwd_dW: KernelConfigBackward_dW for backward pass of dW.
- autotune: whether to autotune the kernel, if yes, kernel_config_fwd, kernel_config_bwd_dX, and kernel_config_bwd_dW will be ignored.
- is_first_gemm: whether this is the first grouped GEMM in an MoE MLP. This is needed to check whether kernel configs are valid. `permute_x` should only be used for first gemm; `permute_y` should only be used for second gemm.
- This will impact whether TMA can be used for loading and storing.
-
- """
- if not autotune:
- assert (
- kernel_config_fwd is not None
- ), "kernel_config_fwd must be provided if autotune is False"
-
- check_valid_config_fwd(
- permute_x,
- permute_y,
- use_tma_load_x = kernel_config_fwd.use_tma_load_x,
- use_tma_load_w = kernel_config_fwd.use_tma_load_w,
- use_tma_store = kernel_config_fwd.use_tma_store,
- fuse_mul_post = fuse_mul_post,
- is_first_gemm = is_first_gemm,
- )
- if kernel_config_bwd_dW is not None and not dX_only:
- check_valid_config_bwd_dW(
- permute_x,
- permute_y,
- use_tma_load_dY = kernel_config_bwd_dW.use_tma_load_dy,
- use_tma_load_x = kernel_config_bwd_dW.use_tma_load_x,
- use_tma_store = kernel_config_bwd_dW.use_tma_store,
- fuse_mul_post = fuse_mul_post,
- is_first_gemm = is_first_gemm,
- )
- if kernel_config_bwd_dX is not None and not dW_only:
- check_valid_config_bwd_dX(
- permute_x,
- permute_y,
- use_tma_load_dY = kernel_config_bwd_dX.use_tma_load_dy,
- use_tma_load_w = kernel_config_bwd_dX.use_tma_load_w,
- use_tma_store = kernel_config_bwd_dX.use_tma_store,
- fuse_mul_post = fuse_mul_post,
- is_first_gemm = is_first_gemm,
- )
-
- if permute_x or permute_y:
- assert (
- gather_indices is not None
- ), "gather_indices is required when either permute_x or permute_y is True"
-
- if fuse_mul_post:
- assert (
- topk_weights is not None
- ), "topk_weights is required when fuse_mul_post is True"
-
- X = X.view(-1, X.shape[-1])
- m_sizes = m_sizes.view(-1)
- gather_indices = gather_indices.view(-1)
-
- return GroupedGemm.apply(
- X,
- W,
- m_sizes,
- topk,
- gather_indices,
- permute_x,
- permute_y,
- topk_weights,
- fuse_mul_post,
- kernel_config_fwd,
- kernel_config_bwd_dX,
- kernel_config_bwd_dW,
- autotune,
- dX_only,
- dW_only,
- )
diff --git a/unsloth/kernels/moe/grouped_gemm/kernels/__init__.py b/unsloth/kernels/moe/grouped_gemm/kernels/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/unsloth/kernels/moe/grouped_gemm/kernels/autotuning.py b/unsloth/kernels/moe/grouped_gemm/kernels/autotuning.py
deleted file mode 100644
index d25913975e..0000000000
--- a/unsloth/kernels/moe/grouped_gemm/kernels/autotuning.py
+++ /dev/null
@@ -1,441 +0,0 @@
-# Unsloth
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as published
-# by the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Affero General Public License for more details.
-#
-# You should have received a copy of the GNU Affero General Public License
-# along with this program. If not, see .
-
-"""
-Autotuning utils
-"""
-
-import logging
-from itertools import product
-from typing import List
-
-import torch
-import triton
-
-logger = logging.getLogger(__name__)
-
-DEFAULT_M_BLOCK_SIZES = [64, 128]
-DEFAULT_N_BLOCK_SIZES = [64, 128, 256]
-DEFAULT_K_BLOCK_SIZES = [64, 128, 256]
-DEFAULT_NUM_CTAS = 1
-DEFAULT_NUM_WARPS = [4, 8]
-DEFAULT_NUM_STAGES = [3, 4, 5]
-BOOLS = [True, False]
-
-
-def val_to_list(val):
- if val is None:
- return None
- elif isinstance(val, list):
- return val
- else:
- return [val]
-
-
-def convert_args_to_list(args):
- return [val_to_list(arg) for arg in args]
-
-
-def _triton_supports_tma():
- """Check if current Triton version supports TMA API."""
- import triton.language as tl
-
- # Check for both old experimental and new stable API names
- return hasattr(tl, "make_tensor_descriptor") or hasattr(
- tl, "_experimental_make_tensor_descriptor"
- )
-
-
-# Precompute at module import
-# NOTE: TMA is disabled for now due to compatibility issues with permute_x/permute_y settings
-# in the MoE grouped GEMM forward/backward passes. Re-enable once these are resolved.
-_TRITON_HAS_TMA = False # _triton_supports_tma()
-
-
-def get_forward_configs(
- BLOCK_M = DEFAULT_M_BLOCK_SIZES,
- BLOCK_N = DEFAULT_N_BLOCK_SIZES,
- BLOCK_K = DEFAULT_K_BLOCK_SIZES,
- TMA_LOAD_X = None, # Auto-detect if not specified
- TMA_LOAD_W = None, # Auto-detect if not specified
- TMA_STORE = False, # NOTE: TMA_STORE is disabled for now
- num_warps = DEFAULT_NUM_WARPS,
- num_stages = DEFAULT_NUM_STAGES,
- num_ctas = DEFAULT_NUM_CTAS,
-):
- # Auto-detect TMA support
- if TMA_LOAD_X is None:
- TMA_LOAD_X = _TRITON_HAS_TMA
- if TMA_LOAD_W is None:
- TMA_LOAD_W = _TRITON_HAS_TMA
-
- (
- BLOCK_M,
- BLOCK_N,
- BLOCK_K,
- TMA_LOAD_X,
- TMA_LOAD_W,
- TMA_STORE,
- num_warps,
- num_stages,
- num_ctas,
- ) = convert_args_to_list(
- [
- BLOCK_M,
- BLOCK_N,
- BLOCK_K,
- TMA_LOAD_X,
- TMA_LOAD_W,
- TMA_STORE,
- num_warps,
- num_stages,
- num_ctas,
- ]
- )
- kernel_configs = []
- for (
- block_m,
- block_n,
- block_k,
- w,
- s,
- tma_load_x,
- tma_load_w,
- tma_store,
- num_ctas,
- ) in product(
- BLOCK_M,
- BLOCK_N,
- BLOCK_K,
- num_warps,
- num_stages,
- TMA_LOAD_X,
- TMA_LOAD_W,
- TMA_STORE,
- num_ctas,
- ):
- kernel_configs.append(
- triton.Config(
- dict(
- BLOCK_SIZE_M = block_m,
- BLOCK_SIZE_N = block_n,
- BLOCK_SIZE_K = block_k,
- USE_TMA_LOAD_X = tma_load_x,
- USE_TMA_LOAD_W = tma_load_w,
- USE_TMA_STORE = tma_store,
- ),
- num_warps = w,
- num_stages = s,
- num_ctas = num_ctas,
- )
- )
-
- return kernel_configs
-
-
-def get_dX_kernel_configs(
- BLOCK_M = DEFAULT_M_BLOCK_SIZES,
- BLOCK_N = DEFAULT_N_BLOCK_SIZES,
- BLOCK_K = DEFAULT_K_BLOCK_SIZES,
- TMA_LOAD_dY = None, # Auto-detect if not specified
- TMA_LOAD_W = None, # Auto-detect if not specified
- TMA_STORE = False, # NOTE: TMA_STORE is disabled for now
- num_warps = DEFAULT_NUM_WARPS,
- num_stages = DEFAULT_NUM_STAGES,
- num_ctas = DEFAULT_NUM_CTAS,
-):
- # Auto-detect TMA support
- if TMA_LOAD_dY is None:
- TMA_LOAD_dY = _TRITON_HAS_TMA
- if TMA_LOAD_W is None:
- TMA_LOAD_W = _TRITON_HAS_TMA
- (
- BLOCK_M,
- BLOCK_N,
- BLOCK_K,
- TMA_LOAD_dY,
- TMA_LOAD_W,
- TMA_STORE,
- num_warps,
- num_stages,
- num_ctas,
- ) = convert_args_to_list(
- [
- BLOCK_M,
- BLOCK_N,
- BLOCK_K,
- TMA_LOAD_dY,
- TMA_LOAD_W,
- TMA_STORE,
- num_warps,
- num_stages,
- num_ctas,
- ]
- )
- kernel_configs = []
- for (
- block_m,
- block_n,
- block_k,
- w,
- s,
- tma_load_dy,
- tma_load_w,
- tma_store,
- num_ctas,
- ) in product(
- BLOCK_M,
- BLOCK_N,
- BLOCK_K,
- num_warps,
- num_stages,
- TMA_LOAD_dY,
- TMA_LOAD_W,
- TMA_STORE,
- num_ctas,
- ):
- kernel_configs.append(
- triton.Config(
- dict(
- BLOCK_SIZE_M = block_m,
- BLOCK_SIZE_N = block_n,
- BLOCK_SIZE_K = block_k,
- USE_TMA_LOAD_dY = tma_load_dy,
- USE_TMA_LOAD_W = tma_load_w,
- USE_TMA_STORE = tma_store,
- ),
- num_warps = w,
- num_stages = s,
- num_ctas = num_ctas,
- )
- )
-
- return kernel_configs
-
-
-def get_dW_kernel_configs(
- BLOCK_M = DEFAULT_M_BLOCK_SIZES,
- BLOCK_N = DEFAULT_N_BLOCK_SIZES,
- BLOCK_K = DEFAULT_K_BLOCK_SIZES,
- num_warps = DEFAULT_NUM_WARPS,
- num_stages = DEFAULT_NUM_STAGES,
- num_ctas = DEFAULT_NUM_CTAS,
- TMA_LOAD_dY = None, # Auto-detect if not specified
- TMA_LOAD_X = None, # Auto-detect if not specified
- TMA_STORE = False,
-):
- # Auto-detect TMA support
- if TMA_LOAD_dY is None:
- TMA_LOAD_dY = _TRITON_HAS_TMA
- if TMA_LOAD_X is None:
- TMA_LOAD_X = _TRITON_HAS_TMA
- (
- BLOCK_M,
- BLOCK_N,
- BLOCK_K,
- num_warps,
- num_stages,
- num_ctas,
- TMA_LOAD_dY,
- TMA_LOAD_X,
- TMA_STORE,
- ) = convert_args_to_list(
- [
- BLOCK_M,
- BLOCK_N,
- BLOCK_K,
- num_warps,
- num_stages,
- num_ctas,
- TMA_LOAD_dY,
- TMA_LOAD_X,
- TMA_STORE,
- ]
- )
- kernel_configs = []
- for (
- block_m,
- block_n,
- block_k,
- w,
- s,
- tma_load_dy,
- tma_load_x,
- tma_store,
- num_ctas,
- ) in product(
- BLOCK_M,
- BLOCK_N,
- BLOCK_K,
- num_warps,
- num_stages,
- TMA_LOAD_dY,
- TMA_LOAD_X,
- TMA_STORE,
- num_ctas,
- ):
- kernel_configs.append(
- triton.Config(
- dict(
- BLOCK_SIZE_M = block_m,
- BLOCK_SIZE_N = block_n,
- BLOCK_SIZE_K = block_k,
- USE_TMA_LOAD_dY = tma_load_dy,
- USE_TMA_LOAD_X = tma_load_x,
- USE_TMA_STORE = tma_store,
- ),
- num_warps = w,
- num_stages = s,
- num_ctas = num_ctas,
- )
- )
-
- return kernel_configs
-
-
-def estimate_smem_reqs(
- num_stages: int,
- BLOCK_SIZE_M: int,
- BLOCK_SIZE_N: int,
- BLOCK_SIZE_K: int,
- dtype: torch.dtype,
-):
- num_bytes = dtype.itemsize
- return (
- num_stages * BLOCK_SIZE_K * (BLOCK_SIZE_M + BLOCK_SIZE_N)
- + BLOCK_SIZE_M * BLOCK_SIZE_N
- ) * num_bytes
-
-
-def exceeds_smem_capacity(
- num_stages: int,
- BLOCK_SIZE_M: int,
- BLOCK_SIZE_N: int,
- BLOCK_SIZE_K: int,
- dtype: torch.dtype,
- smem_size: int,
- slack: float = 50000,
-):
- smem_reqs = estimate_smem_reqs(
- num_stages, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, dtype
- )
- return smem_reqs > smem_size + slack
-
-
-def common_prune_criteria(config: triton.Config, kwargs: dict, dtype):
- from ..interface import supports_tma
- from .tuning import get_device_properties
-
- smem_size = get_device_properties().SIZE_SMEM
-
- num_stages = config.num_stages
- BLOCK_SIZE_M = config.kwargs["BLOCK_SIZE_M"]
- BLOCK_SIZE_N = config.kwargs["BLOCK_SIZE_N"]
- BLOCK_SIZE_K = config.kwargs["BLOCK_SIZE_K"]
-
- num_tokens = kwargs["NUM_TOKENS"]
- num_experts = kwargs["NUM_EXPERTS"]
- permute_x = kwargs["PERMUTE_X"]
- permute_y = kwargs["PERMUTE_Y"]
- tokens_per_expert = num_tokens // num_experts
-
- # use_tma = [k for k in config.kwargs.keys() if k.startswith("USE_TMA_")]
- MIN_BLOCK_SIZE_M = DEFAULT_M_BLOCK_SIZES[0]
- if exceeds_smem_capacity(
- num_stages, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, dtype, smem_size
- ):
- return True
- if BLOCK_SIZE_M > tokens_per_expert * 2 and tokens_per_expert > MIN_BLOCK_SIZE_M:
- return True
- if permute_x and permute_y:
- return True
- # if not supports_tma() and any(use_tma):
- # return True
- return False
-
-
-def maybe_disable_tma(config: triton.Config):
- from ..interface import supports_tma
-
- tma_keys = [k for k in config.kwargs.keys() if k.startswith("USE_TMA_")]
- if not supports_tma():
- logger.info("Disabling TMA")
- for k in tma_keys:
- config.kwargs[k] = False
-
-
-def prune_kernel_configs_fwd(configs: list[triton.Config], args, **kwargs):
- x = kwargs["x_ptr"]
- dtype = x.dtype
-
- logger.debug(f"Pruning configs: {len(configs)}")
-
- pruned_configs = []
- for config in configs:
- # disable TMA if gpu does not support it
- maybe_disable_tma(config)
-
- if common_prune_criteria(config, kwargs, dtype):
- continue
- if config.kwargs["USE_TMA_LOAD_X"] and kwargs["PERMUTE_X"]:
- # Dynamically disable TMA_LOAD_X for permuted X
- config.kwargs["USE_TMA_LOAD_X"] = False
- if config.kwargs["USE_TMA_STORE"] and kwargs["PERMUTE_Y"]:
- continue
-
- pruned_configs.append(config)
-
- logger.debug(f"Pruned configs: {len(pruned_configs)}")
- return pruned_configs
-
-
-def prune_dX_configs(configs: List[triton.Config], args, **kwargs):
- dtype = kwargs["w_ptr"].dtype
-
- logger.debug(f"Pruning configs: {len(configs)}")
- pruned_configs = []
-
- for config in configs:
- if common_prune_criteria(config, kwargs, dtype):
- continue
- if config.kwargs["USE_TMA_LOAD_dY"] and kwargs["PERMUTE_Y"]:
- # dynamically disable TMA_LOAD_dY for permuted Y
- config.kwargs["USE_TMA_LOAD_dY"] = False
- if config.kwargs["USE_TMA_STORE"] and kwargs["PERMUTE_X"]:
- continue
- pruned_configs.append(config)
-
- logger.debug(f"Pruned configs: {len(pruned_configs)}")
- return pruned_configs
-
-
-def prune_kernel_configs_backward_dW(configs: list[triton.Config], args, **kwargs):
- dtype = kwargs["x_ptr"].dtype
-
- pruned_configs = []
- logger.debug(f"Pruning configs: {len(configs)}")
-
- for config in configs:
- if common_prune_criteria(config, kwargs, dtype):
- continue
- if config.kwargs["USE_TMA_LOAD_dY"] and kwargs["PERMUTE_Y"]:
- config.kwargs["USE_TMA_LOAD_dY"] = False
- if config.kwargs["USE_TMA_LOAD_X"] and kwargs["PERMUTE_X"]:
- config.kwargs["USE_TMA_LOAD_X"] = False
- pruned_configs.append(config)
-
- logger.debug(f"Pruned configs: {len(pruned_configs)}")
- return pruned_configs
diff --git a/unsloth/kernels/moe/grouped_gemm/kernels/backward.py b/unsloth/kernels/moe/grouped_gemm/kernels/backward.py
deleted file mode 100644
index 5e07056b52..0000000000
--- a/unsloth/kernels/moe/grouped_gemm/kernels/backward.py
+++ /dev/null
@@ -1,505 +0,0 @@
-# SPDX-License-Identifier: GNU Affero General Public License v3.0
-# Copyright 2023-present the Unsloth team. All rights reserved.
-
-import torch
-import triton
-import triton.language as tl
-
-from .autotuning import (
- get_dW_kernel_configs,
- get_dX_kernel_configs,
- prune_dX_configs,
- prune_kernel_configs_backward_dW,
-)
-
-"""
-dX backward kernel
-
-- Shapes
- - the forward pass input X shape is [NUM_TOKENS, K] if permute_x else [NUM_TOKENS * TOPK, K]; output y is [NUM_TOKENS * TOPK, N]
- - the backward pass input dy shape is [NUM_TOKENS * TOPK, N], reduce across N, output dX is [NUM_TOKENS * TOPK, K]
-- Note that in the backward pass, the output size is still [NUM_TOKENS * TOPK, K] since we still need to accumulate gradients for each expert chosen by the token in a post-processing step.
-
-`permute_x` notes:
-- In the forward pass, if we permute X on load, we need to permute on store in the backward pass to restore to original token order
-- the output dX with have shape [NUM_TOKENS * TOPK, K] and we need to perform an additional reduction across topk to accumulate gradients
-- This is done as a post-processing step in autograd.Function.
-- If not `permute_x`, this postprocessing step should take place outside autograd.Function such that the gradient shape matches the input X shape.
-
-`permute_y` notes:
-- In the forward pass, if we permuted output on store (e.g., in the second grouped GEMM in fused MoE MLP), we need to permute on load to get from token order to expert grouped order
-- We still store in contiguous order since we are writing out dX which will be the input to the backwards pass of the first grouped GEMM
-
-`fused_mul` notes:
-- In the forward pass, if we used the multiplication of topk weights (e.g., in the second grouped GEMM in fused MoE MLP), we need to make a few additional changes:
- 1) We load topk_weights in natural (token) order. Since we only enable `fuse_mul` when permuting on store (`permute_y`), we multiply grad_output by topk_weights before backpropagating
- 2) We need to calculate the gradient of the topk_weights. This gets messy since we need do an additional elementwise multiplication in the GEMM main loop and then write out in unpermuted order. For now, we do not fuse this step but calculate as a simple
-
-Invalid combinations:
-- permute_y and use_tma_load: permuting y on store in forward -> load in permuted order in backward, therefore can't use TMA load (unless Blackwell which supports gather / scatter TMA)
-- permute_x and use_tma_store: permuting x on load in forward -> store in permuted order in backward, therefore can't use TMA store (unless Blackwell which supports gather / scatter TMA)
-
-TODO:
-- We define indices for all conditions and expect that unused indices will be DCE'd during compilation. Check that this is the case otherwise will result in unnecessary register usage.
-"""
-
-
-@triton.jit
-def _grouped_gemm_dX_kernel(
- dY_ptr, # [M_total, N]
- w_ptr, # [E, N, K]
- dX_ptr, # [M_total, K]
- gather_indices_ptr,
- m_sizes_ptr,
- # problem sizes
- NUM_EXPERTS: tl.constexpr,
- NUM_TOKENS,
- TOPK: tl.constexpr,
- N: tl.constexpr,
- K: tl.constexpr,
- NUM_SMS,
- # Tuning parameters
- BLOCK_SIZE_M: tl.constexpr,
- BLOCK_SIZE_N: tl.constexpr,
- BLOCK_SIZE_K: tl.constexpr,
- PERMUTE_X: tl.constexpr = False,
- PERMUTE_Y: tl.constexpr = False,
- USE_TMA_LOAD_W: tl.constexpr = False,
- USE_TMA_LOAD_dY: tl.constexpr = False,
- USE_TMA_STORE: tl.constexpr = False,
- FLATTEN: tl.constexpr = True,
-) -> None:
- TOTAL_TOKENS = NUM_TOKENS * TOPK
- output_dtype = dX_ptr.dtype.element_ty
-
- tidx = tl.program_id(0)
- # This removes the need for predication along N in the GEMM main loop
- tl.static_assert(N % BLOCK_SIZE_N == 0, "N must be divisible by BLOCK_SIZE_N")
- tl.static_assert(K % BLOCK_SIZE_K == 0, "K must be divisible by BLOCK_SIZE_K")
-
- # Create TMA descriptors for loading sorted tokens
- # When using TMA load, we don't permute_x, so shape should be [TOTAL_TOKENS, K]
- # Also, we are defining a single global descriptor with single block shape
- # Need to check that this does not result in errors when crossing expert boundaries
- if USE_TMA_LOAD_dY:
- dY_desc = tl.make_tensor_descriptor(
- dY_ptr,
- shape = [TOTAL_TOKENS, N],
- strides = [N, 1],
- block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N],
- )
-
- if USE_TMA_LOAD_W:
- expert_stride = N * K
- w_desc = tl.make_tensor_descriptor(
- w_ptr,
- shape = [NUM_EXPERTS, N, K],
- strides = [expert_stride, K, 1],
- block_shape = [1, BLOCK_SIZE_N, BLOCK_SIZE_K],
- )
-
- m_end = 0
- processed_tiles = 0
- m_block_range = tl.arange(0, BLOCK_SIZE_M)
- n_block_range = tl.arange(0, BLOCK_SIZE_N)
- k_block_range = tl.arange(0, BLOCK_SIZE_K)
-
- for expert_idx in range(NUM_EXPERTS, flatten = FLATTEN):
- m_start = m_end
- m_size = tl.load(m_sizes_ptr + expert_idx).to(tl.int32)
- m_end = m_start + m_size
-
- if m_size > 0:
- # Advance n offset to the weights for that respective expert
- n_start = expert_idx * N
- # N_start_offset = g.to(tl.int64) * N
- # tiles for this group's GEMM
- num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
- num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
- num_tiles_per_expert = num_m_tiles * num_k_tiles
-
- if USE_TMA_STORE:
- # Need to define descript within loop to predicate store along M
- tl.static_assert(
- K % BLOCK_SIZE_K == 0, "K must be divisible by BLOCK_SIZE_K"
- )
- dX_desc = tl.make_tensor_descriptor(
- dX_ptr,
- shape = [m_end, K],
- strides = [K, 1],
- block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K],
- )
-
- # Lower bound and upper bound are defined relative to the total tiles processed so far
- # This ensures that we are only processing tiles for the current expert group AND
- # we never exceed the total number of tiles for all expert groups
- while tidx >= processed_tiles and tidx < (
- processed_tiles + num_tiles_per_expert
- ):
- group_index = tidx - processed_tiles
-
- # Output tile for this thread block for this expert group
- tile_m_idx = group_index % num_m_tiles
- tile_k_idx = group_index // num_m_tiles
-
- if PERMUTE_X or PERMUTE_Y:
- # These will be used for loading and storing in permuted order
- gather_offsets = tile_m_idx * BLOCK_SIZE_M + m_block_range
- # indices_to_gather = m_start + gather_offsets
- indices_to_gather = m_start + tl.max_contiguous(
- tl.multiple_of(gather_offsets % m_size, BLOCK_SIZE_M),
- BLOCK_SIZE_M,
- )
- expert_token_idx = tl.load(
- gather_indices_ptr + indices_to_gather,
- mask = indices_to_gather < TOTAL_TOKENS,
- )
- expert_token_offsets = expert_token_idx[:, None]
-
- # Masks for permuted load and store
- row_mask = gather_offsets < m_size
- row_mask = row_mask[:, None]
-
- # We only take into account the following two cases: (PERMUTE_X and NOT PERMUTE_Y) and (NOT PERMUTE_X and PERMUTE_Y)
- # Hence, we can make the following simplifying assumptions when loading and storing
- # Note the different strides between the two cases: the offsets for loading and storing are flipped and the strides must also be adjusted
-
- if PERMUTE_X:
- # Case where we permuted on load in the forward pass (typically first grouped GEMM in MoE MLP)
- load_a_idx = (
- indices_to_gather[:, None] * N
- ) # Load in contiguous (expert grouped) order
- store_idx = (
- expert_token_offsets * K
- ) # Permute on store from expert -> token order
- else:
- # Case where we permuted on store in the forward pass (typically second grouped GEMM in MoE MLP)
- load_a_idx = (
- expert_token_offsets * N
- ) # Permute on load from token -> expert order
- store_idx = (
- indices_to_gather[:, None] * K
- ) # Store in contiguous order
- else:
- # # Position in full matrix - needed for TMA
- # m_offset = (M_start + (tile_m_idx * BLOCK_SIZE_M)).to(tl.int32)
- # k_offset = (tile_k_idx * BLOCK_SIZE_K).to(tl.int32)
- # Offsets *relative* to the *current* expert -- m_start will then advance to this expert's start token
- offs_am = tile_m_idx * BLOCK_SIZE_M + m_block_range
-
- # [M, N] @ [N, K] -> [M, K] => Stride for A is N, stride for B is K
- # We need two additional offsets:
- # 1. For A, m_start to advance to this expert's start token
- # 2. For B, n_start to advance to this expert's weights since we are passing in an [E, N, K] weight matrix
- row_offsets_a = m_start + offs_am[:, None]
- load_a_idx = row_offsets_a * N
- store_idx = row_offsets_a * K
- row_mask = offs_am[:, None] < m_size
-
- if not USE_TMA_LOAD_dY:
- dY_ptrs = dY_ptr + load_a_idx + n_block_range[None, :]
-
- offs_bk = tile_k_idx * BLOCK_SIZE_K + k_block_range
- if not USE_TMA_LOAD_W:
- row_offsets_b = n_start + n_block_range
- # offs_bn = n_start + n_block_range
- # row_offsets_b = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
- w_ptrs = w_ptr + row_offsets_b[:, None] * K + offs_bk[None, :]
-
- # TODO: check whether predication along K is needed since we checked that K is divisible by BLOCK_SIZE_K in the forward kernel
- # col_mask = offs_bk[None, :] < K
- store_mask = row_mask # & col_mask
-
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype = tl.float32)
-
- # GEMM main loop
- for n_offset in range(0, N, BLOCK_SIZE_N):
- # dY block [M, N]
- if not USE_TMA_LOAD_dY:
- dY = tl.load(dY_ptrs, mask = row_mask)
- else:
- dY = dY_desc.load(
- [m_start + tile_m_idx * BLOCK_SIZE_M, n_offset]
- )
-
- if not USE_TMA_LOAD_W:
- w = tl.load(w_ptrs) # , mask=col_mask)
- else:
- w = w_desc.load(
- [expert_idx, n_offset, tile_k_idx * BLOCK_SIZE_K]
- )
- w = tl.reshape(w, (BLOCK_SIZE_N, BLOCK_SIZE_K))
- # TODO: check if predication along K is needed since we checked that K is divisible by BLOCK_SIZE_K in the forward kernel
-
- # [M, N] @ [N, K] -> [M, K]
- dY = dY.to(w.dtype)
- accumulator += tl.dot(dY, w) # NOTE: no transpose of b
-
- # Advance A along contiguous dimension
- if not USE_TMA_LOAD_dY:
- dY_ptrs += BLOCK_SIZE_N
- # Note we are no longer advancing B along contiguous dimension since weights are arranged as [N, K]
- # Instead, we need to stride by K to advance to the [N_BLOCK_SIZE, K_BLOCK_SIZE] tile
- if not USE_TMA_LOAD_W:
- w_ptrs += BLOCK_SIZE_N * K
-
- dX = accumulator.to(output_dtype)
-
- # Writing out a BLOCK_M x BLOCK_K tile, so we need to stride by K
- if USE_TMA_STORE:
- offset_m = tile_m_idx * BLOCK_SIZE_M # .to(tl.int32)
- offset_k = tile_k_idx * BLOCK_SIZE_K # .to(tl.int32)
- dX_desc.store([m_start + offset_m, offset_k], dX)
- else:
- tl.store(
- dX_ptr + store_idx + offs_bk[None, :],
- dX,
- mask = store_mask,
- )
-
- # Move to the next tile within this expert group
- tidx += NUM_SMS
-
- # Update the total tiles count for the next expert group
- processed_tiles += num_tiles_per_expert
-
-
-_autotuned_grouped_gemm_dX_kernel = triton.autotune(
- configs = get_dX_kernel_configs(),
- prune_configs_by = {"early_config_prune": prune_dX_configs},
- # NOTE: NUM_TOKENS removed from key to avoid recompilation for every sequence length
- key = ["NUM_EXPERTS", "N", "K", "PERMUTE_X", "PERMUTE_Y"],
-)(_grouped_gemm_dX_kernel)
-
-"""
-notes on permute_x:
-- for the first grouped GEMM, we permuted on load -> X was [num_tokens, K] and stored y in expert grouped order [num_tokens * topk, K]
-- in the backwards pass, we need to permute on load of X while loading dy in contiguous (expert grouped) order
-- since we are writing out dW, there is no need to permute on store
-
-notes on permute_y:
-- for the second grouped GEMM, we permuted on store -> y was permuted from expert grouped order to token order, x was loaded in expert grouped order since it was the output of the first grouped GEMM
-- in the backwards pass, we need to permute on load of dy to get from token order to expert grouped order to match the order of X
-- since we are writing out dW, there is no need to permute on store
-
-notes on TMA loading:
-- if we're TMA loading both X and dY, then we need to mask along the M dimension
-to account for expert boundaries
-- we can either
- - define TMA descriptors within the outer for loop to predicate loads
- or
- - mask along M after loading
-"""
-
-
-@triton.jit
-def _grouped_gemm_dW_kernel(
- x_ptr,
- dY_ptr,
- dW_ptr,
- m_sizes_ptr,
- gather_indices_ptr,
- # problem sizes
- NUM_TOKENS,
- TOPK: tl.constexpr,
- NUM_EXPERTS: tl.constexpr,
- N: tl.constexpr,
- K: tl.constexpr,
- NUM_SMS,
- BLOCK_SIZE_N: tl.constexpr,
- BLOCK_SIZE_K: tl.constexpr,
- BLOCK_SIZE_M: tl.constexpr,
- PERMUTE_X: tl.constexpr = False,
- PERMUTE_Y: tl.constexpr = False,
- USE_TMA_LOAD_dY: tl.constexpr = False,
- USE_TMA_LOAD_X: tl.constexpr = False,
- USE_TMA_STORE: tl.constexpr = False,
- FLATTEN: tl.constexpr = True,
- acc_dtype: tl.constexpr = tl.float32,
-) -> None:
- TOTAL_TOKENS = NUM_TOKENS * TOPK
- TMA_LOAD_BOTH: tl.constexpr = USE_TMA_LOAD_X and USE_TMA_LOAD_dY
-
- tidx = tl.program_id(0)
- output_dtype = dW_ptr.dtype.element_ty
-
- if USE_TMA_LOAD_dY and not TMA_LOAD_BOTH:
- dY_desc = tl.make_tensor_descriptor(
- dY_ptr,
- shape = [TOTAL_TOKENS, N],
- strides = [N, 1],
- block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N],
- )
-
- if USE_TMA_LOAD_X and not TMA_LOAD_BOTH:
- x_desc = tl.make_tensor_descriptor(
- x_ptr,
- shape = [TOTAL_TOKENS, K],
- strides = [K, 1],
- block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K],
- )
- # Output tiles per expert, since each expert weight matrix is [N, K]
- num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
- num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
- output_tiles_per_expert = num_n_tiles * num_k_tiles
-
- block_range_m = tl.arange(0, BLOCK_SIZE_M)
- block_range_n = tl.arange(0, BLOCK_SIZE_N)
- block_range_k = tl.arange(0, BLOCK_SIZE_K)
-
- # NOTE: Important that N % BLOCK_SIZE_N == 0 and K % BLOCK_SIZE_K == 0 when using TMA store
- if USE_TMA_STORE:
- tl.static_assert(N % BLOCK_SIZE_N == 0, "N must be divisible by BLOCK_SIZE_N")
- tl.static_assert(K % BLOCK_SIZE_K == 0, "K must be divisible by BLOCK_SIZE_K")
- dW_desc = tl.make_tensor_descriptor(
- dW_ptr,
- shape = [NUM_EXPERTS, N, K],
- strides = [N * K, K, 1],
- block_shape = [1, BLOCK_SIZE_N, BLOCK_SIZE_K],
- )
-
- for tile_idx in range(
- tidx, output_tiles_per_expert, NUM_SMS
- ): # , flatten=FLATTEN):
- # Output tile index
- tile_n_idx = tile_idx % num_n_tiles
- tile_k_idx = tile_idx // num_n_tiles
-
- # Output tile offsets
- n_offset = tile_n_idx * BLOCK_SIZE_N
- k_offset = tile_k_idx * BLOCK_SIZE_K
-
- # For storing
- # TODO: Check whether the k mask is needed since we statically check that K is divisible by BLOCK_SIZE_K in the forward kernel
- # ditto for n_mask
- n_mask = block_range_n + n_offset < N
- k_mask = block_range_k + k_offset < K
- nk_mask = n_mask[:, None] & k_mask[None, :]
-
- m_end = 0
- for expert_idx in range(NUM_EXPERTS):
- # We need to instantiate a fresh accumulator for each expert
- accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype = acc_dtype)
-
- m_start = m_end
- # Need to figure out why this cast is needed, otherwise compiler complains about mismatching types
- m_size = tl.load(m_sizes_ptr + expert_idx).to(tl.int32)
- m_end = m_start + m_size
-
- # NOTE: when storing the result, we need to offset by n_start since we are storing the result for this expert to the global [E, N, K] weight matrix
- n_start = expert_idx * N
- store_row_offs = n_start + n_offset + block_range_n
-
- if m_size > 0:
- if TMA_LOAD_BOTH:
- dY_desc = tl.make_tensor_descriptor(
- dY_ptr,
- shape = [m_end, N],
- strides = [N, 1],
- block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N],
- )
-
- x_desc = tl.make_tensor_descriptor(
- x_ptr,
- shape = [m_end, K],
- strides = [K, 1],
- block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K],
- )
-
- for tile_m_idx in range(0, m_size, BLOCK_SIZE_M):
- m_block_size = tl.minimum(BLOCK_SIZE_M, m_size - tile_m_idx)
-
- if m_block_size > 0:
- # Global offset for this chunk
- m_global_offset = m_start + tile_m_idx
- m_offsets = m_global_offset + block_range_m
-
- if PERMUTE_X or PERMUTE_Y:
- # These will be used for loading and storing in permuted order
- gather_offsets = (
- tile_m_idx + block_range_m
- ) # NOTE: tile_m_idx is already strided by BLOCK_SIZE_M
-
- indices_to_gather = m_start + tl.max_contiguous(
- tl.multiple_of(gather_offsets % m_size, BLOCK_SIZE_M),
- BLOCK_SIZE_M,
- )
- # indices_to_gather = m_start + gather_offsets
- expert_token_idx = tl.load(
- gather_indices_ptr + indices_to_gather,
- mask = indices_to_gather < TOTAL_TOKENS,
- )
- expert_token_offsets = expert_token_idx[:, None]
-
- # Masks for permuted load and store
- row_load_mask = gather_offsets < m_size
-
- # We only take into account the following two cases: (PERMUTE_X and NOT PERMUTE_Y) and (NOT PERMUTE_X and PERMUTE_Y)
- # Hence, we can make the following simplifying assumptions when loading and storing
- # Note the different strides between the two cases: the offsets for loading and storing are flipped and the strides must also be adjusted
- if PERMUTE_X:
- x_row_load_idx = (
- (expert_token_offsets // TOPK) * K
- ) # Permute on load from token -> expert order, divide by TOPK to index from original number of tokens
- dY_row_load_idx = m_offsets[:, None] * N
- else:
- x_row_load_idx = (
- indices_to_gather[:, None] * K
- ) # Load in contiguous order (no permutation on load)
- dY_row_load_idx = expert_token_offsets * N
-
- else:
- x_row_load_idx = m_offsets[:, None] * K
- dY_row_load_idx = m_offsets[:, None] * N
- row_load_mask = block_range_m < m_block_size
-
- mk_mask = row_load_mask[:, None] & k_mask[None, :]
- mn_mask = row_load_mask[:, None] & n_mask[None, :]
-
- if USE_TMA_LOAD_X:
- x = x_desc.load([m_global_offset, k_offset])
- else:
- x = tl.load(
- x_ptr
- + x_row_load_idx
- + (k_offset + block_range_k)[None, :],
- mask = mk_mask,
- )
-
- if USE_TMA_LOAD_dY:
- dY = dY_desc.load([m_global_offset, n_offset])
- else:
- dY = tl.load(
- dY_ptr
- + dY_row_load_idx
- + (n_offset + block_range_n)[None, :],
- mask = mn_mask,
- )
-
- accumulator += tl.dot(
- dY.T.to(x.dtype), # [BLOCK_N, BLOCK_M]
- x, # [BLOCK_M, BLOCK_K]
- )
-
- y = accumulator.to(output_dtype)
- if USE_TMA_STORE:
- # Need to expand dims to match [E, N, K] shape
- y = tl.expand_dims(y, 0)
- dW_desc.store([expert_idx, n_offset, k_offset], y)
- else:
- tl.store(
- dW_ptr
- # + (n_offset + offs_n)[:, None] * K
- + store_row_offs[:, None] * K
- + (k_offset + block_range_k)[None, :],
- y,
- mask = nk_mask,
- )
-
-
-_autotuned_grouped_gemm_dW_kernel = triton.autotune(
- configs = get_dW_kernel_configs(),
- prune_configs_by = {"early_config_prune": prune_kernel_configs_backward_dW},
- # NOTE: NUM_TOKENS removed from key to avoid recompilation for every sequence length
- key = ["NUM_EXPERTS", "N", "K", "PERMUTE_X", "PERMUTE_Y"],
-)(_grouped_gemm_dW_kernel)
diff --git a/unsloth/kernels/moe/grouped_gemm/kernels/forward.py b/unsloth/kernels/moe/grouped_gemm/kernels/forward.py
deleted file mode 100644
index a42ec5ffe9..0000000000
--- a/unsloth/kernels/moe/grouped_gemm/kernels/forward.py
+++ /dev/null
@@ -1,267 +0,0 @@
-# SPDX-License-Identifier: GNU Affero General Public License v3.0
-# Copyright 2023-present the Unsloth team. All rights reserved.
-
-import torch
-import triton
-import triton.language as tl
-
-from .autotuning import (
- get_forward_configs,
- prune_kernel_configs_fwd,
-)
-
-
-#
-# PERMUTE_X -> permute tokens so that they are ordered by expert
-# PERMUTE_Y -> permute output so that they are ordered by token
-# These are effectively the same thing: the former loads in permuted order, the latter stores in permuted order => we only need to define the permutation indices once
-# In the former, we use these row indices when loading X
-# For the latter, we use these row indices when storing Y
-# FUSE_MUL -> multiply routed outputs by their respective weights
-# topk_weights are in token order
-# Only account for the case when X is in expert order and we are permuting Y when fusing mul -- this precondition is checked in the interface
-@triton.jit
-def _grouped_gemm_forward_kernel(
- x_ptr,
- w_ptr,
- y_ptr,
- # Variable depending on routed probs
- m_sizes_ptr,
- gather_indices_ptr,
- topk_weights_ptr,
- # Constant problem shapes
- NUM_EXPERTS: tl.constexpr,
- NUM_TOKENS,
- TOPK: tl.constexpr,
- N: tl.constexpr,
- K: tl.constexpr,
- NUM_SMS,
- # Tuning params
- BLOCK_SIZE_M: tl.constexpr,
- BLOCK_SIZE_N: tl.constexpr,
- BLOCK_SIZE_K: tl.constexpr,
- PERMUTE_X: tl.constexpr = False,
- PERMUTE_Y: tl.constexpr = False,
- FUSE_MUL_PRE: tl.constexpr = False,
- FUSE_MUL_POST: tl.constexpr = False,
- USE_FAST_ACCUM: tl.constexpr = False,
- USE_TMA_LOAD_W: tl.constexpr = False,
- USE_TMA_LOAD_X: tl.constexpr = False,
- USE_TMA_STORE: tl.constexpr = False,
- acc_dtype: tl.constexpr = tl.float32,
- FLATTEN: tl.constexpr = True,
-) -> None:
- tl.static_assert(K % BLOCK_SIZE_K == 0)
-
- TOTAL_TOKENS = NUM_TOKENS * TOPK
- SHOULD_PERMUTE: tl.constexpr = PERMUTE_X or PERMUTE_Y
- SHOULD_FUSE_MUL: tl.constexpr = FUSE_MUL_PRE or FUSE_MUL_POST
- SHOULD_PERMUTE_OR_FUSE: tl.constexpr = SHOULD_PERMUTE or SHOULD_FUSE_MUL
- # tl.static_print("SHOULD_PERMUTE", PERMUTE_X, PERMUTE_Y, FUSE_MUL_PRE, FUSE_MUL_POST, SHOULD_PERMUTE, SHOULD_FUSE, SHOULD_PERMUTE_OR_FUSE)
- tidx = tl.program_id(0)
- output_dtype: tl.dtype = y_ptr.dtype.element_ty
-
- # Create TMA descriptors for loading sorted tokens
- # When using TMA load, we don't permute_x, so shape should be [TOTAL_TOKENS, K]
- # Also, we are defining a single global descriptor with single block shape
- # Need to check that this does not result in errors when crossing expert boundaries
- if USE_TMA_LOAD_X:
- x_desc = tl.make_tensor_descriptor(
- x_ptr,
- shape = [TOTAL_TOKENS, K],
- strides = [K, 1],
- block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_K],
- )
-
- if USE_TMA_LOAD_W:
- expert_stride = N * K
- w_desc = tl.make_tensor_descriptor(
- w_ptr,
- shape = [NUM_EXPERTS, N, K],
- strides = [expert_stride, K, 1],
- block_shape = [1, BLOCK_SIZE_N, BLOCK_SIZE_K],
- )
-
- m_end = 0
- processed_tiles = 0
- m_block_range = tl.arange(0, BLOCK_SIZE_M)
-
- for expert_idx in tl.range(NUM_EXPERTS, flatten = FLATTEN):
- m_start = m_end
- m_size = tl.load(m_sizes_ptr + expert_idx).to(tl.int32)
- m_end = m_start + m_size
-
- if m_size > 0:
- n_start = expert_idx * N
-
- num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
- num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
- num_tiles_per_expert = num_m_tiles * num_n_tiles
-
- # Need to create tma_store within loop since we need to predicate stores based on m_size
- if USE_TMA_STORE:
- y_desc = tl.make_tensor_descriptor(
- y_ptr, # + m_start * N,
- shape = [m_end, N],
- strides = [N, 1],
- block_shape = [BLOCK_SIZE_M, BLOCK_SIZE_N],
- )
-
- # Process tiles for this expert
- while (
- tidx >= processed_tiles
- and tidx < processed_tiles + num_tiles_per_expert
- ):
- tile_idx = tidx - processed_tiles
-
- # Check if L2 cache re-use for this order is optimal
- tile_m_idx = tile_idx % num_m_tiles
- tile_n_idx = tile_idx // num_m_tiles
-
- if SHOULD_PERMUTE_OR_FUSE:
- # These will be used for loading and storing in permuted order
- gather_offsets = tile_m_idx * BLOCK_SIZE_M + m_block_range
- indices_to_gather = m_start + tl.max_contiguous(
- tl.multiple_of(gather_offsets % m_size, BLOCK_SIZE_M),
- BLOCK_SIZE_M,
- )
- expert_token_idx = tl.load(
- gather_indices_ptr + indices_to_gather,
- mask = indices_to_gather < TOTAL_TOKENS,
- )
- expert_token_offsets = expert_token_idx[:, None]
-
- # Masks for permuted load and store
-
- row_mask = gather_offsets < m_size
- row_mask = row_mask[:, None]
-
- # row_mask = indices_to_gather < m_end
- # row_mask = row_mask[:, None]
-
- # We only take into account the following two cases: (PERMUTE_X and NOT PERMUTE_Y) and (NOT PERMUTE_X and PERMUTE_Y)
- # Hence, we can make the following simplifying assumptions when loading and storing
- # Note the different strides between the two cases: the offsets for loading and storing are flipped and the strides must also be adjusted
- if PERMUTE_X:
- load_idx = (
- (expert_token_offsets // TOPK) * K
- ) # Permute on load from token -> expert order, divide by TOPK to index from original number of tokens
- store_idx = (
- indices_to_gather[:, None] * N
- ) # Store in contiguous order
- else:
- off_am = tile_m_idx * BLOCK_SIZE_M
- if not PERMUTE_Y:
- # These will already be computed if permuting y
- offs_am = off_am + m_block_range
- row_mask = offs_am[:, None] < m_size
- row_idx = m_start + offs_am[:, None]
- store_idx = row_idx * N
- if not USE_TMA_LOAD_X:
- load_idx = row_idx * K
-
- if PERMUTE_Y:
- if not USE_TMA_LOAD_X:
- load_idx = (
- indices_to_gather[:, None] * K
- ) # Load in contiguous order (no permutation on load)
- # offs_am = off_am + m_block_range
- # row_mask = offs_am[:, None] < m_size
- store_idx = (
- expert_token_offsets * N
- ) # Permute on store from expert -> token order
-
- # We always load topk weights in expert order
- # In the pre-multiplication case, we multiply permuted hidden states by weights before the first gemm
- # In the post-multiplication case, we multiply permuted hidden states by weights after the second gemm
- # In either case, the hidden states are grouped by expert, so we always permute on load of topk weights
- if SHOULD_FUSE_MUL:
- topk_load_idx = expert_token_offsets
-
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype = acc_dtype)
-
- offs_k = tl.arange(0, BLOCK_SIZE_K)
-
- if not USE_TMA_LOAD_X:
- x_ptrs = x_ptr + load_idx + offs_k[None, :]
-
- if not USE_TMA_LOAD_W:
- offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
- offs_bn = tl.max_contiguous(
- tl.multiple_of(offs_bn % N, BLOCK_SIZE_N), BLOCK_SIZE_N
- )
- w_ptrs = w_ptr + (n_start + offs_bn[:, None]) * K + offs_k[None, :]
-
- for k_offset in range(0, K, BLOCK_SIZE_K):
- if not USE_TMA_LOAD_X:
- x = tl.load(x_ptrs, mask = row_mask)
- else:
- x = x_desc.load([m_start + off_am, k_offset])
-
- if FUSE_MUL_PRE:
- # Check for correct broadcasting
- topk_weights = tl.load(
- topk_weights_ptr + topk_load_idx, mask = row_mask
- )
- x *= topk_weights.to(x.dtype)
-
- if not USE_TMA_LOAD_W:
- w = tl.load(w_ptrs, mask = offs_bn[:, None] < N)
- else:
- w = w_desc.load(
- [expert_idx, tile_n_idx * BLOCK_SIZE_N, k_offset]
- )
- w = tl.reshape(w, (BLOCK_SIZE_N, BLOCK_SIZE_K))
-
- x = x.to(w.dtype)
- accumulator += tl.dot(x, w.T)
-
- if not USE_TMA_LOAD_X:
- x_ptrs += BLOCK_SIZE_K
-
- if not USE_TMA_LOAD_W:
- w_ptrs += BLOCK_SIZE_K
-
- y = accumulator.to(output_dtype)
-
- # NOTE: order of fusing multiplication is important
- # Fusing before accumulator dtype conversion results in numerical diffs
- if FUSE_MUL_POST:
- # Check for correct broadcasting
- topk_weights = tl.load(
- topk_weights_ptr + topk_load_idx, mask = row_mask
- )
- y *= topk_weights.to(output_dtype)
-
- offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
- store_mask = row_mask & (offs_bn[None, :] < N)
-
- if USE_TMA_STORE:
- offset_m = tile_m_idx * BLOCK_SIZE_M # .to(tl.int32)
- offset_n = tile_n_idx * BLOCK_SIZE_N # .to(tl.int32)
- y_desc.store([m_start + offset_m, offset_n], y)
- else:
- tl.store(
- y_ptr + store_idx + offs_bn[None, :],
- y,
- mask = store_mask,
- )
- tidx += NUM_SMS
-
- processed_tiles += num_tiles_per_expert
-
-
-_autotuned_grouped_gemm_forward_kernel = triton.autotune(
- configs = get_forward_configs(),
- prune_configs_by = {"early_config_prune": prune_kernel_configs_fwd},
- # NOTE: NUM_TOKENS removed from key to avoid recompilation for every sequence length
- # The kernel handles variable token counts via m_sizes and tile-based processing
- key = [
- "NUM_EXPERTS",
- "N",
- "K",
- "PERMUTE_X",
- "PERMUTE_Y",
- "FUSE_MUL_POST",
- ],
-)(_grouped_gemm_forward_kernel)
diff --git a/unsloth/kernels/moe/grouped_gemm/kernels/tuning.py b/unsloth/kernels/moe/grouped_gemm/kernels/tuning.py
deleted file mode 100644
index 00d4824703..0000000000
--- a/unsloth/kernels/moe/grouped_gemm/kernels/tuning.py
+++ /dev/null
@@ -1,277 +0,0 @@
-# SPDX-License-Identifier: GNU Affero General Public License v3.0
-# Copyright 2023-present the Unsloth team. All rights reserved.
-
-"""
-Manual tuning utils
-"""
-
-from collections import OrderedDict
-from dataclasses import asdict, dataclass, fields
-from itertools import product
-from typing import Optional
-
-import pandas as pd
-import torch
-import triton
-from triton.runtime.errors import OutOfResources
-
-from .autotuning import (
- BOOLS,
- DEFAULT_K_BLOCK_SIZES,
- DEFAULT_M_BLOCK_SIZES,
- DEFAULT_N_BLOCK_SIZES,
- DEFAULT_NUM_STAGES,
- DEFAULT_NUM_WARPS,
-)
-
-
-@dataclass
-class DeviceProperties:
- NUM_SM: int
- NUM_REGS: int
- SIZE_SMEM: int
- WARP_SIZE: int
-
-
-_DEVICE_PROPERTIES: Optional[DeviceProperties] = None
-
-
-def get_device_properties():
- global _DEVICE_PROPERTIES
- if _DEVICE_PROPERTIES is None:
- properties = triton.runtime.driver.active.utils.get_device_properties(
- torch.cuda.current_device()
- )
- NUM_SM = properties["multiprocessor_count"]
- NUM_REGS = properties["max_num_regs"]
- SIZE_SMEM = properties["max_shared_mem"]
- WARP_SIZE = properties["warpSize"]
- _DEVICE_PROPERTIES = DeviceProperties(NUM_SM, NUM_REGS, SIZE_SMEM, WARP_SIZE)
- return _DEVICE_PROPERTIES
-
-
-@dataclass
-class KernelConfig:
- BLOCK_SIZE_M: int = 32
- BLOCK_SIZE_N: int = 32
- BLOCK_SIZE_K: int = 32
- num_warps: int = 4
- num_stages: int = 2
- flatten: bool = True
- permute_x: bool = False
- permute_y: bool = False
- fuse_mul_post: bool = False
- use_tma_store: bool = False
-
- def to_string(self, include_tuning_params: bool = False, include_tma: bool = False):
- s = []
- if self.permute_x:
- s.append("permute_x")
- if self.permute_y:
- s.append("permute_y")
- if include_tuning_params:
- s.append(
- f"BLOCK_SIZE_M={self.BLOCK_SIZE_M},BLOCK_SIZE_N={self.BLOCK_SIZE_N},BLOCK_SIZE_K={self.BLOCK_SIZE_K},num_warps={self.num_warps},num_stages={self.num_stages},flatten={self.flatten}"
- )
- if include_tma:
- for f in fields(self):
- if f.name.startswith("use_tma_"):
- if getattr(self, f.name):
- s.append(f.name)
- return ",".join(s)
-
-
-@dataclass
-class KernelConfigForward(KernelConfig):
- use_tma_load_w: bool = False
- use_tma_load_x: bool = False
-
-
-@dataclass
-class KernelConfigBackward_dW(KernelConfig):
- use_tma_load_dy: bool = False
- use_tma_load_x: bool = False
-
-
-@dataclass
-class KernelConfigBackward_dX(KernelConfig):
- use_tma_load_dy: bool = False
- use_tma_load_w: bool = False
-
-
-@dataclass
-class KernelResult:
- torch_time: float
- triton_time: float
- speedup: float
- kernel_config: KernelConfig
-
- def to_dict(self):
- return OrderedDict(
- **asdict(self.kernel_config),
- torch_time = self.torch_time,
- triton_time = self.triton_time,
- speedup = self.speedup,
- )
-
- @staticmethod
- def to_dataframe(
- results: list["KernelResult"], sort_by: str = "speedup", ascending: bool = False
- ):
- df = pd.DataFrame([result.to_dict() for result in results])
- df = df.sort_values(by = sort_by, ascending = ascending)
- return df
-
- @staticmethod
- def to_csv(
- results: list["KernelResult"],
- sort_by: str = "speedup",
- ascending: bool = False,
- filename: str = "results.csv",
- ):
- df = KernelResult.to_dataframe(results, sort_by, ascending)
- df.to_csv(filename, index = False)
-
- @staticmethod
- def print_table(
- results: list["KernelResult"],
- sort_by: str = "speedup",
- ascending: bool = False,
- num_results: int = 10,
- ):
- df = KernelResult.to_dataframe(results, sort_by, ascending)
- print(df.head(num_results).to_string(index = False))
-
-
-def get_kernel_configs(
- BLOCK_M = DEFAULT_M_BLOCK_SIZES,
- BLOCK_N = DEFAULT_N_BLOCK_SIZES,
- BLOCK_K = DEFAULT_K_BLOCK_SIZES,
- num_warps = DEFAULT_NUM_WARPS,
- num_stages = DEFAULT_NUM_STAGES,
- use_tma_loads = BOOLS,
- fuse_permute = BOOLS,
-):
- kernel_configs_fwd = []
- kernel_configs_backward_dW = []
- kernel_configs_backward_dX = []
- for block_m, block_n, block_k, w, s, use_tma_load, permute in product(
- BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages, use_tma_loads, fuse_permute
- ):
- kernel_configs_fwd.append(
- KernelConfigForward(
- BLOCK_SIZE_M = block_m,
- BLOCK_SIZE_N = block_n,
- BLOCK_SIZE_K = block_k,
- num_warps = w,
- num_stages = s,
- use_tma_load_x = use_tma_load,
- use_tma_load_w = use_tma_load,
- use_tma_store = False,
- permute_x = permute,
- permute_y = permute,
- )
- )
- kernel_configs_backward_dW.append(
- KernelConfigBackward_dW(
- BLOCK_SIZE_M = block_m,
- BLOCK_SIZE_N = block_n,
- BLOCK_SIZE_K = block_k,
- num_warps = w,
- num_stages = s,
- use_tma_load_dy = use_tma_load,
- use_tma_load_x = use_tma_load,
- use_tma_store = False,
- permute_x = permute,
- permute_y = permute,
- )
- )
- kernel_configs_backward_dX.append(
- KernelConfigBackward_dX(
- BLOCK_SIZE_M = block_m,
- BLOCK_SIZE_N = block_n,
- BLOCK_SIZE_K = block_k,
- num_warps = w,
- num_stages = s,
- use_tma_load_dy = use_tma_load,
- use_tma_load_w = use_tma_load,
- use_tma_store = False,
- permute_x = permute,
- permute_y = permute,
- )
- )
-
- kernel_configs_fwd = prune_kernel_configs_fwd(kernel_configs_fwd)
- kernel_configs_backward_dW = prune_kernel_configs_backward_dW(
- kernel_configs_backward_dW
- )
- kernel_configs_backward_dX = prune_kernel_configs_backward_dX(
- kernel_configs_backward_dX
- )
- return kernel_configs_fwd, kernel_configs_backward_dW, kernel_configs_backward_dX
-
-
-def prune_kernel_configs_fwd(configs: list[KernelConfigForward]):
- pruned_configs = []
- for config in configs:
- if config.use_tma_load_x and config.permute_x:
- continue
- if config.permute_x and config.permute_y:
- continue
- if config.use_tma_store and config.permute_y:
- continue
- pruned_configs.append(config)
- return pruned_configs
-
-
-def prune_kernel_configs_backward_dX(configs: list[KernelConfigBackward_dX]):
- pruned_configs = []
- for config in configs:
- if config.use_tma_load_dy and config.permute_y:
- continue
- if config.permute_x and config.permute_y:
- continue
- if config.use_tma_store and config.permute_x:
- continue
- pruned_configs.append(config)
- return pruned_configs
-
-
-def prune_kernel_configs_backward_dW(configs: list[KernelConfigBackward_dW]):
- pruned_configs = []
- for config in configs:
- if config.use_tma_load_dy and config.permute_y:
- continue
- if config.use_tma_load_x and config.permute_x:
- continue
- if config.permute_x and config.permute_y:
- continue
- pruned_configs.append(config)
- return pruned_configs
-
-
-class TritonTuningContext:
- def __init__(self, kernel_config: KernelConfig):
- self.kernel_config = kernel_config
- self.success = True
-
- def __enter__(self):
- # Setup code can be added here if needed
- return self
-
- def __exit__(self, exc_type, exc_value, traceback):
- if exc_type is OutOfResources:
- name = exc_value.name
- required = exc_value.required
- limit = exc_value.limit
- print(
- f"Kernel config {self.kernel_config} failed: {name}, required: {required}, limit: {limit}"
- )
- self.success = False
- elif exc_type is not None:
- print(
- f"Error running Triton grouped GEMM for kernel config: {self.kernel_config}: {exc_value}"
- )
- self.success = False
- # Return False to propagate exceptions, True to suppress them
- return True
diff --git a/unsloth/kernels/moe/grouped_gemm/reference/__init__.py b/unsloth/kernels/moe/grouped_gemm/reference/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/unsloth/kernels/moe/grouped_gemm/reference/layers/llama4_moe.py b/unsloth/kernels/moe/grouped_gemm/reference/layers/llama4_moe.py
deleted file mode 100644
index 6bb0bfb0c3..0000000000
--- a/unsloth/kernels/moe/grouped_gemm/reference/layers/llama4_moe.py
+++ /dev/null
@@ -1,437 +0,0 @@
-# SPDX-License-Identifier: GNU Affero General Public License v3.0
-# Copyright 2023-present the Unsloth team. All rights reserved.
-
-from dataclasses import dataclass
-from typing import Tuple
-
-import torch
-import torch.nn.functional as F
-from transformers.models.llama4 import Llama4TextConfig
-from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
-
-from ...interface import grouped_gemm
-from ...kernels.tuning import (
- KernelConfigBackward_dW,
- KernelConfigBackward_dX,
- KernelConfigForward,
-)
-from ..moe_ops import (
- get_routing_indices,
- permute,
- torch_grouped_gemm,
- unpermute,
-)
-
-"""
-Reference implementation of Llama4 MoE block using triton grouped gemm.
-
-`Llama4GroupedGemmTextMoe` is the HF `Llama4TextMoe` block implemented with a torch-native grouped gemm.
-`Llama4TritonTextMoe` is the HF `Llama4TextMoe` implemented with triton grouped gemm.
-"""
-
-
-@dataclass
-class Llama4MoeResult:
- token_counts_by_expert: torch.Tensor
- gather_indices: torch.Tensor
- topk_weights: torch.Tensor
- hidden_states_after_weight_merge: torch.Tensor
- first_gemm: torch.Tensor
- intermediate: torch.Tensor
- second_gemm: torch.Tensor
- hidden_states_unpermute: torch.Tensor
- shared_expert_out: torch.Tensor
- final_out: torch.Tensor
- router_logits: torch.Tensor = None
-
-
-class Llama4GroupedGemmTextMoe(Llama4TextMoe):
- EXPERT_WEIGHT_NAMES = ["experts.gate_up_proj", "experts.down_proj"]
-
- def __init__(
- self,
- config: Llama4TextConfig,
- overlap_router_shared = False,
- verbose = False,
- debug = False,
- ):
- super().__init__(config)
- self.overlap_router_shared = overlap_router_shared
- self.verbose = verbose
- self.debug = debug
-
- # Permute in-place expert weights
- E, K, N = self.num_experts, self.hidden_dim, self.experts.expert_dim
- assert self.experts.gate_up_proj.shape == torch.Size(
- [E, K, 2 * N]
- ), f"{self.experts.gate_up_proj.shape} != {[E, K, 2 * N]}"
- permuted_shape = [E, 2 * N, K]
- permuted_stride = [2 * N * K, K, 1]
- if verbose:
- print(
- f"Changing gate_up_proj from {self.experts.gate_up_proj.size()}:{self.experts.gate_up_proj.stride()} to {permuted_shape}:{permuted_stride}"
- )
- with torch.no_grad():
- self.experts.gate_up_proj.as_strided_(permuted_shape, permuted_stride)
-
- if verbose:
- print(
- f"{self.experts.gate_up_proj.shape}:{self.experts.gate_up_proj.stride()}"
- )
-
- assert self.experts.down_proj.shape == torch.Size(
- [E, N, K]
- ), f"{self.experts.down_proj.shape} != {[E, N, K]}"
- permuted_shape = [E, K, N]
- permuted_stride = [K * N, N, 1]
- if verbose:
- print(
- f"Changing down_proj from {self.experts.down_proj.size()}:{self.experts.down_proj.stride()} to {permuted_shape}:{permuted_stride}"
- )
-
- with torch.no_grad():
- self.experts.down_proj.as_strided_(permuted_shape, permuted_stride)
-
- if verbose:
- print(f"{self.experts.down_proj.shape}:{self.experts.down_proj.stride()}")
-
- if overlap_router_shared:
- self.shared_expert_stream = torch.cuda.Stream()
- self.default_event = torch.cuda.Event()
- self.shared_expert_end_event = torch.cuda.Event()
-
- @torch.no_grad
- def copy_weights(self, other: Llama4TextMoe):
- for name, param_to_copy in other.named_parameters():
- if self.verbose:
- print(f"Copying {name} with shape {param_to_copy.shape}")
- param = self.get_parameter(name)
-
- if any(n in name for n in self.EXPERT_WEIGHT_NAMES):
- param_to_copy = param_to_copy.permute(0, 2, 1)
-
- assert (
- param.shape == param_to_copy.shape
- ), f"{param.shape} != {param_to_copy.shape}"
- param.copy_(param_to_copy)
-
- return self
-
- def check_weights(self, other: Llama4TextMoe):
- for name, other_param in other.named_parameters():
- if any(n in name for n in self.EXPERT_WEIGHT_NAMES):
- other_param = other_param.permute(0, 2, 1)
- param = self.get_parameter(name)
- assert param.equal(other_param), f"Param {name} not equal!"
- assert param.is_contiguous(), f"{name} not contiguous!"
-
- def act_and_mul(self, x: torch.Tensor) -> torch.Tensor:
- assert x.shape[-1] == 2 * self.experts.expert_dim
- gate_proj = x[..., : self.experts.expert_dim]
- up_proj = x[..., self.experts.expert_dim :]
- return self.experts.act_fn(gate_proj) * up_proj
-
- def run_router(self, hidden_states: torch.Tensor) -> torch.Tensor:
- # router_logits: (batch * sequence_length, n_experts)
- hidden_states = hidden_states.view(-1, self.hidden_dim)
- router_logits = self.router(hidden_states)
- routing_weights, selected_experts = torch.topk(
- router_logits, self.top_k, dim = -1
- )
-
- routing_weights = F.sigmoid(routing_weights.float()).to(hidden_states.dtype)
-
- return router_logits, routing_weights, selected_experts
-
- def get_token_counts_and_gather_indices(
- self, selected_experts: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- token_counts_by_expert, gather_indices = get_routing_indices(
- selected_experts, self.num_experts
- )
- assert not token_counts_by_expert.requires_grad
- assert not gather_indices.requires_grad
- return token_counts_by_expert, gather_indices
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- """ """
- batch_size, sequence_length, hidden_dim = hidden_states.shape
- num_tokens = batch_size * sequence_length
- total_tokens = num_tokens * self.top_k
- hidden_states = hidden_states.view(-1, hidden_dim)
-
- if self.overlap_router_shared:
- # Marker for all prior ops on default stream
- self.default_event.record()
-
- router_logits, routing_weights, selected_experts = self.run_router(
- hidden_states
- )
- assert routing_weights.shape == (
- num_tokens,
- self.top_k,
- ), f"{routing_weights.shape} != {(num_tokens, self.top_k)}"
-
- if self.overlap_router_shared:
- with torch.cuda.stream(self.shared_expert_stream):
- # Ensure prior kernels on default stream complete
- self.default_event.wait()
-
- shared_expert_out = self.shared_expert(hidden_states)
- # Ensure hidden states remains valid on this stream
- hidden_states.record_stream(self.shared_expert_stream)
-
- self.shared_expert_end_event.record()
-
- # Ensure shared expert still valid on default stream
- shared_expert_out.record_stream(torch.cuda.current_stream())
- self.shared_expert_end_event.wait()
- else:
- shared_expert_out = self.shared_expert(hidden_states)
-
- hidden_states = (
- hidden_states.view(num_tokens, self.top_k, hidden_dim)
- * routing_weights[..., None]
- )
-
- if self.top_k > 1:
- hidden_states = hidden_states.sum(dim = 1)
- hidden_states_after_weight_merge = hidden_states.view(-1, hidden_dim)
-
- # 1. Compute tokens per expert and indices for gathering tokes from token order to expert order
- # NOTE: these are auxiliary data structs which don't need to be recorded in autograd graph
- token_counts_by_expert, gather_indices = (
- self.get_token_counts_and_gather_indices(selected_experts)
- )
-
- # 2. Permute tokens from token order to expert order
- hidden_states = permute(
- hidden_states_after_weight_merge, gather_indices, self.top_k
- )
- assert hidden_states.shape == (total_tokens, hidden_dim)
-
- # Start expert computation
- first_gemm = torch_grouped_gemm(
- X = hidden_states, W = self.experts.gate_up_proj, m_sizes = token_counts_by_expert
- )
- assert first_gemm.shape == (total_tokens, 2 * self.experts.expert_dim)
-
- intermediate = self.act_and_mul(first_gemm)
- assert intermediate.shape == (total_tokens, self.experts.expert_dim)
-
- # See comment above
- second_gemm = torch_grouped_gemm(
- X = intermediate, W = self.experts.down_proj, m_sizes = token_counts_by_expert
- )
- assert second_gemm.shape == (total_tokens, hidden_dim)
-
- # Post-processing
- hidden_states_unpermute = unpermute(second_gemm, gather_indices)
- assert hidden_states_unpermute.shape == (total_tokens, hidden_dim)
- # grouped_gemm_out = hidden_states.view(batch_size, sequence_length, hidden_dim)
-
- final_out = hidden_states_unpermute + shared_expert_out
-
- result = (
- Llama4MoeResult(
- token_counts_by_expert = token_counts_by_expert,
- gather_indices = gather_indices,
- topk_weights = routing_weights,
- hidden_states_after_weight_merge = hidden_states_after_weight_merge,
- first_gemm = first_gemm,
- intermediate = intermediate,
- second_gemm = second_gemm,
- hidden_states_unpermute = hidden_states_unpermute,
- shared_expert_out = shared_expert_out,
- final_out = final_out,
- router_logits = router_logits,
- )
- if self.debug
- else (final_out, routing_weights)
- )
-
- return result
-
-
-class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe):
- def __init__(
- self,
- config: Llama4TextConfig,
- overlap_router_shared = False,
- permute_x: bool = False,
- permute_y: bool = True,
- autotune: bool = True,
- kernel_config_fwd: KernelConfigForward = None,
- kernel_config_bwd_dW: KernelConfigBackward_dW = None,
- kernel_config_bwd_dX: KernelConfigBackward_dX = None,
- dW_only: bool = False,
- dX_only: bool = False,
- verbose = False,
- ):
- super().__init__(config, overlap_router_shared = overlap_router_shared)
- assert not permute_x, "Llama4 triton grouped gemm does not support permute x due to pre-multiplication of router weights"
- self.permute_x = permute_x
- self.permute_y = permute_y
- self.autotune = autotune
- if not autotune:
- assert (
- kernel_config_fwd is not None
- and kernel_config_bwd_dW is not None
- and kernel_config_bwd_dX is not None
- ), "Kernel configs must be provided if autotune is False"
- self.kernel_config_fwd = kernel_config_fwd
- self.kernel_config_bwd_dW = kernel_config_bwd_dW
- self.kernel_config_bwd_dX = kernel_config_bwd_dX
- self.dW_only = dW_only
- self.dX_only = dX_only
-
- @torch.no_grad
- def copy_weights(self, other: Llama4TextMoe):
- for name, param_to_copy in other.named_parameters():
- if self.verbose:
- print(f"Copying {name} with shape {param_to_copy.shape}")
- param = self.get_parameter(name)
-
- if any(n in name for n in self.EXPERT_WEIGHT_NAMES):
- param_to_copy = param_to_copy.permute(0, 2, 1)
-
- assert (
- param.shape == param_to_copy.shape
- ), f"{param.shape} != {param_to_copy.shape}"
- param.copy_(param_to_copy)
-
- return self
-
- def check_weights(self, other: Llama4TextMoe):
- for name, other_param in other.named_parameters():
- if any(n in name for n in self.EXPERT_WEIGHT_NAMES):
- other_param = other_param.permute(0, 2, 1)
- param = self.get_parameter(name)
- assert param.equal(other_param), f"Param {name} not equal!"
- assert param.is_contiguous(), f"{name} not contiguous!"
-
- def act_and_mul(self, x: torch.Tensor) -> torch.Tensor:
- assert x.shape[-1] == 2 * self.experts.expert_dim
- gate_proj = x[..., : self.experts.expert_dim]
- up_proj = x[..., self.experts.expert_dim :]
- return self.experts.act_fn(gate_proj) * up_proj
-
- def run_router(self, hidden_states: torch.Tensor) -> torch.Tensor:
- # router_logits: (batch * sequence_length, n_experts)
- hidden_states = hidden_states.view(-1, self.hidden_dim)
- router_logits = self.router(hidden_states)
- routing_weights, selected_experts = torch.topk(
- router_logits, self.top_k, dim = -1
- )
-
- routing_weights = F.sigmoid(routing_weights.float()).to(hidden_states.dtype)
-
- return router_logits, routing_weights, selected_experts
-
- def get_token_counts_and_gather_indices(
- self, selected_experts: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- token_counts_by_expert, gather_indices = get_routing_indices(
- selected_experts, self.num_experts
- )
- assert not token_counts_by_expert.requires_grad
- assert not gather_indices.requires_grad
- return token_counts_by_expert, gather_indices
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- """ """
- batch_size, sequence_length, hidden_dim = hidden_states.shape
- num_tokens = batch_size * sequence_length
- total_tokens = num_tokens * self.top_k
- hidden_states = hidden_states.view(-1, hidden_dim)
-
- if self.overlap_router_shared:
- # Marker for all prior ops on default stream
- self.default_event.record()
-
- router_logits, routing_weights, selected_experts = self.run_router(
- hidden_states
- )
- assert routing_weights.shape == (
- num_tokens,
- self.top_k,
- ), f"{routing_weights.shape} != {(num_tokens, self.top_k)}"
-
- if self.overlap_router_shared:
- with torch.cuda.stream(self.shared_expert_stream):
- # Ensure prior kernels on default stream complete
- self.default_event.wait()
-
- shared_expert_out = self.shared_expert(hidden_states)
- # Ensure hidden states remains valid on this stream
- hidden_states.record_stream(self.shared_expert_stream)
-
- self.shared_expert_end_event.record()
-
- # Ensure shared expert still valid on default stream
- shared_expert_out.record_stream(torch.cuda.current_stream())
- self.shared_expert_end_event.wait()
- else:
- shared_expert_out = self.shared_expert(hidden_states)
-
- hidden_states = (
- hidden_states.view(num_tokens, self.top_k, hidden_dim)
- * routing_weights[..., None]
- )
-
- if self.top_k > 1:
- hidden_states = hidden_states.sum(dim = 1)
- hidden_states = hidden_states.view(-1, hidden_dim)
-
- # 1. Compute tokens per expert and indices for gathering tokes from token order to expert order
- # NOTE: these are auxiliary data structs which don't need to be recorded in autograd graph
- token_counts_by_expert, gather_indices = (
- self.get_token_counts_and_gather_indices(selected_experts)
- )
-
- # 2. Permute tokens from token order to expert order
- hidden_states = permute(hidden_states, gather_indices, self.top_k)
- assert hidden_states.shape == (total_tokens, hidden_dim)
-
- # Start expert computation
- hidden_states = grouped_gemm(
- X = hidden_states,
- W = self.experts.gate_up_proj,
- m_sizes = token_counts_by_expert,
- gather_indices = gather_indices,
- topk = self.top_k,
- permute_x = self.permute_x,
- permute_y = False, # output of first grouped gemm should never be permuted
- autotune = self.autotune,
- kernel_config_fwd = self.kernel_config_fwd,
- kernel_config_bwd_dW = self.kernel_config_bwd_dW,
- kernel_config_bwd_dX = self.kernel_config_bwd_dX,
- is_first_gemm = True,
- dW_only = self.dW_only,
- dX_only = self.dX_only,
- )
- hidden_states = self.act_and_mul(hidden_states)
- hidden_states = grouped_gemm(
- X = hidden_states,
- W = self.experts.down_proj,
- m_sizes = token_counts_by_expert,
- gather_indices = gather_indices,
- topk = self.top_k,
- permute_x = False,
- permute_y = self.permute_y,
- autotune = self.autotune,
- kernel_config_fwd = self.kernel_config_fwd,
- kernel_config_bwd_dW = self.kernel_config_bwd_dW,
- kernel_config_bwd_dX = self.kernel_config_bwd_dX,
- is_first_gemm = False,
- dW_only = self.dW_only,
- dX_only = self.dX_only,
- )
-
- # Post-processing
- # 1. Unpermute from expert order to token order
- if not self.permute_y:
- hidden_states = unpermute(hidden_states, gather_indices)
- hidden_states += shared_expert_out
-
- return hidden_states, routing_weights
diff --git a/unsloth/kernels/moe/grouped_gemm/reference/layers/qwen3_moe.py b/unsloth/kernels/moe/grouped_gemm/reference/layers/qwen3_moe.py
deleted file mode 100644
index 31c635ba37..0000000000
--- a/unsloth/kernels/moe/grouped_gemm/reference/layers/qwen3_moe.py
+++ /dev/null
@@ -1,348 +0,0 @@
-# SPDX-License-Identifier: GNU Affero General Public License v3.0
-# Copyright 2023-present the Unsloth team. All rights reserved.
-
-from dataclasses import dataclass
-from typing import Tuple
-
-import torch
-import torch.nn.functional as F
-from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig
-from transformers.models.qwen3_moe.modeling_qwen3_moe import (
- ACT2FN,
- Qwen3MoeSparseMoeBlock,
-)
-
-from ...interface import grouped_gemm
-from ...kernels.tuning import (
- KernelConfigBackward_dW,
- KernelConfigBackward_dX,
- KernelConfigForward,
-)
-from ..moe_ops import (
- get_routing_indices,
- permute,
- torch_grouped_gemm,
- unpermute,
-)
-
-"""
-Reference implementation of HF Qwen3 MoE block using grouped gemm.
-
-The Qwen3MoeGroupedGEMMBlock is a reference torch-native implementation.
-Qwen3MoeFusedGroupedGEMMBlock is a version using the triton grouped gemm kernel.
-
-NOTE: This is NOT to be used for production as it contains many extra checks and saves all intermediate results for debugging.
-"""
-
-
-@dataclass
-class GroupedGEMMResult:
- token_counts_by_expert: torch.Tensor
- gather_indices: torch.Tensor
- topk_weights: torch.Tensor
- first_gemm: torch.Tensor
- intermediate: torch.Tensor
- second_gemm: torch.Tensor
- hidden_states_unpermute: torch.Tensor
- hidden_states: torch.Tensor # final output
-
-
-class Qwen3MoeGroupedGEMMBlock(torch.nn.Module):
- def __init__(
- self,
- config,
- gate: torch.Tensor,
- gate_up_proj: torch.Tensor,
- down_proj: torch.Tensor,
- ):
- super().__init__()
- self.num_experts = config.num_experts
- self.top_k = config.num_experts_per_tok
- self.norm_topk_prob = config.norm_topk_prob
- self.hidden_size = config.hidden_size
- self.moe_intermediate_size = config.moe_intermediate_size
-
- assert gate.shape == (config.num_experts, config.hidden_size)
- assert gate_up_proj.shape == (
- config.num_experts,
- 2 * config.moe_intermediate_size,
- config.hidden_size,
- )
- assert down_proj.shape == (
- config.num_experts,
- config.hidden_size,
- config.moe_intermediate_size,
- )
-
- # gating
- self.gate = torch.nn.Parameter(gate)
-
- # experts
- self.gate_up_proj = torch.nn.Parameter(gate_up_proj, requires_grad = True)
- self.down_proj = torch.nn.Parameter(down_proj, requires_grad = True)
- self.act_fn = ACT2FN[config.hidden_act]
-
- @staticmethod
- def extract_hf_weights(moe_block: Qwen3MoeSparseMoeBlock):
- config: Qwen3MoeConfig = moe_block.experts[0].config
- num_experts = config.num_experts
-
- gate = moe_block.gate.weight.data
- gate_proj = torch.stack(
- [moe_block.experts[i].gate_proj.weight.data for i in range(num_experts)],
- dim = 0,
- )
- up_proj = torch.stack(
- [moe_block.experts[i].up_proj.weight.data for i in range(num_experts)],
- dim = 0,
- )
- down_proj = torch.stack(
- [moe_block.experts[i].down_proj.weight.data for i in range(num_experts)],
- dim = 0,
- )
- gate_up_proj = torch.cat([gate_proj, up_proj], dim = 1)
- return gate, gate_up_proj, down_proj
-
- @classmethod
- def from_hf(cls, moe_block: Qwen3MoeSparseMoeBlock):
- config: Qwen3MoeConfig = moe_block.experts[0].config
- gate, gate_up_proj, down_proj = cls.extract_hf_weights(moe_block)
- return cls(config, gate, gate_up_proj, down_proj)
-
- def check_weights(self, moe_block: Qwen3MoeSparseMoeBlock):
- for i in range(self.num_experts):
- assert self.gate_up_proj[i].equal(
- torch.cat(
- [
- moe_block.experts[i].gate_proj.weight.data,
- moe_block.experts[i].up_proj.weight.data,
- ],
- dim = 0,
- )
- )
- assert self.down_proj[i].equal(moe_block.experts[i].down_proj.weight.data)
-
- def act_and_mul(self, x: torch.Tensor) -> torch.Tensor:
- assert x.shape[-1] == 2 * self.moe_intermediate_size
- gate_proj = x[..., : self.moe_intermediate_size]
- up_proj = x[..., self.moe_intermediate_size :]
- return self.act_fn(gate_proj) * up_proj
-
- def run_router(self, hidden_states: torch.Tensor) -> torch.Tensor:
- # router_logits: (batch * sequence_length, n_experts)
- router_logits = torch.nn.functional.linear(hidden_states, self.gate)
-
- routing_weights = F.softmax(router_logits, dim = 1, dtype = torch.float)
- routing_weights, selected_experts = torch.topk(
- routing_weights, self.top_k, dim = -1
- )
- if self.norm_topk_prob: # only diff with mixtral sparse moe block!
- routing_weights /= routing_weights.sum(dim = -1, keepdim = True)
- # we cast back to the input dtype
- routing_weights = routing_weights.to(hidden_states.dtype)
-
- return router_logits, routing_weights, selected_experts
-
- def get_token_counts_and_gather_indices(
- self, selected_experts: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- token_counts_by_expert, gather_indices = get_routing_indices(
- selected_experts, self.num_experts
- )
- assert not token_counts_by_expert.requires_grad
- assert not gather_indices.requires_grad
- return token_counts_by_expert, gather_indices
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- """ """
- batch_size, sequence_length, hidden_dim = hidden_states.shape
- num_tokens = batch_size * sequence_length
- total_tokens = num_tokens * self.top_k
-
- hidden_states = hidden_states.view(-1, hidden_dim)
-
- router_logits, routing_weights, selected_experts = self.run_router(
- hidden_states
- )
-
- # 1. Compute tokens per expert and indices for gathering tokes from token order to expert order
- # NOTE: these are auxiliary data structs which don't need to be recorded in autograd graph
- token_counts_by_expert, gather_indices = (
- self.get_token_counts_and_gather_indices(selected_experts)
- )
-
- # 2. Permute tokens from token order to expert order
- hidden_states = permute(hidden_states, gather_indices, self.top_k)
- assert hidden_states.shape == (total_tokens, hidden_dim)
-
- # Start expert computation
- first_gemm = torch_grouped_gemm(
- X = hidden_states, W = self.gate_up_proj, m_sizes = token_counts_by_expert
- )
- assert first_gemm.shape == (total_tokens, 2 * self.moe_intermediate_size)
- intermediate = self.act_and_mul(first_gemm)
- assert intermediate.shape == (total_tokens, self.moe_intermediate_size)
- second_gemm = torch_grouped_gemm(
- X = intermediate, W = self.down_proj, m_sizes = token_counts_by_expert
- )
- assert second_gemm.shape == (total_tokens, hidden_dim)
-
- # Post-processing
- # 1. Unpermute from expert order to token order
- hidden_states_unpermute = unpermute(second_gemm, gather_indices)
- assert hidden_states_unpermute.shape == (total_tokens, hidden_dim)
-
- # 2. Merge topk weights
- hidden_states = (
- hidden_states_unpermute.view(num_tokens, self.top_k, hidden_dim)
- * routing_weights[..., None]
- )
- hidden_states = hidden_states.sum(dim = 1)
- assert hidden_states.shape == (num_tokens, hidden_dim)
-
- hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)
- return GroupedGEMMResult(
- token_counts_by_expert = token_counts_by_expert,
- gather_indices = gather_indices,
- topk_weights = routing_weights,
- first_gemm = first_gemm,
- intermediate = intermediate,
- second_gemm = second_gemm,
- hidden_states_unpermute = hidden_states_unpermute,
- hidden_states = hidden_states,
- ), router_logits
-
-
-class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
- def __init__(
- self,
- config: Qwen3MoeConfig,
- gate: torch.Tensor,
- gate_up_proj: torch.Tensor,
- down_proj: torch.Tensor,
- permute_x: bool = True,
- permute_y: bool = True,
- autotune: bool = True,
- kernel_config_fwd: KernelConfigForward = None,
- kernel_config_bwd_dW: KernelConfigBackward_dW = None,
- kernel_config_bwd_dX: KernelConfigBackward_dX = None,
- dW_only: bool = False,
- dX_only: bool = False,
- ):
- super().__init__(config, gate, gate_up_proj, down_proj)
- self.permute_x = permute_x
- self.permute_y = permute_y
- self.autotune = autotune
- if not autotune:
- assert (
- kernel_config_fwd is not None
- and kernel_config_bwd_dW is not None
- and kernel_config_bwd_dX is not None
- ), "Kernel configs must be provided if autotune is False"
- self.kernel_config_fwd = kernel_config_fwd
- self.kernel_config_bwd_dW = kernel_config_bwd_dW
- self.kernel_config_bwd_dX = kernel_config_bwd_dX
- self.dW_only = dW_only
- self.dX_only = dX_only
-
- @classmethod
- def from_hf(
- cls,
- moe_block: Qwen3MoeSparseMoeBlock,
- permute_x: bool = True,
- permute_y: bool = True,
- autotune: bool = True,
- kernel_config_fwd: KernelConfigForward = None,
- kernel_config_bwd_dW: KernelConfigBackward_dW = None,
- kernel_config_bwd_dX: KernelConfigBackward_dX = None,
- dW_only: bool = False,
- dX_only: bool = False,
- ):
- config: Qwen3MoeConfig = moe_block.experts[0].config
- gate, gate_up_proj, down_proj = Qwen3MoeGroupedGEMMBlock.extract_hf_weights(
- moe_block
- )
- return cls(
- config,
- gate,
- gate_up_proj,
- down_proj,
- permute_x = permute_x,
- permute_y = permute_y,
- autotune = autotune,
- kernel_config_fwd = kernel_config_fwd,
- kernel_config_bwd_dW = kernel_config_bwd_dW,
- kernel_config_bwd_dX = kernel_config_bwd_dX,
- dW_only = dW_only,
- dX_only = dX_only,
- )
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- batch_size, sequence_length, hidden_dim = hidden_states.shape
- num_tokens = batch_size * sequence_length
- total_tokens = num_tokens * self.top_k
-
- hidden_states = hidden_states.view(-1, hidden_dim)
-
- router_logits, routing_weights, selected_experts = self.run_router(
- hidden_states
- )
- # Pre-processing
- # 1. Compute tokens per expert and indices for gathering tokes from token order to expert order
- # NOTE: these are auxiliary data structs which don't need to be recorded in autograd graph
- token_counts_by_expert, gather_indices = (
- self.get_token_counts_and_gather_indices(selected_experts)
- )
-
- # 2. permute_x -> permutation will be fused in prologue of first grouped gemm
- if not self.permute_x:
- hidden_states = permute(hidden_states, gather_indices, self.top_k)
- # Start expert computation
- hidden_states = grouped_gemm(
- X = hidden_states,
- W = self.gate_up_proj,
- m_sizes = token_counts_by_expert,
- gather_indices = gather_indices,
- topk = self.top_k,
- permute_x = self.permute_x,
- permute_y = False, # output of first grouped gemm should never be permuted
- autotune = self.autotune,
- kernel_config_fwd = self.kernel_config_fwd,
- kernel_config_bwd_dW = self.kernel_config_bwd_dW,
- kernel_config_bwd_dX = self.kernel_config_bwd_dX,
- is_first_gemm = True,
- dW_only = self.dW_only,
- dX_only = self.dX_only,
- )
- hidden_states = self.act_and_mul(hidden_states)
- hidden_states = grouped_gemm(
- X = hidden_states,
- W = self.down_proj,
- m_sizes = token_counts_by_expert,
- gather_indices = gather_indices,
- topk = self.top_k,
- permute_x = False,
- permute_y = self.permute_y,
- autotune = self.autotune,
- kernel_config_fwd = self.kernel_config_fwd,
- kernel_config_bwd_dW = self.kernel_config_bwd_dW,
- kernel_config_bwd_dX = self.kernel_config_bwd_dX,
- is_first_gemm = False,
- dW_only = self.dW_only,
- dX_only = self.dX_only,
- )
-
- # Post-processing
- # 1. Unpermute from expert order to token order
- if not self.permute_y:
- hidden_states = unpermute(hidden_states, gather_indices)
-
- # 2. Merge topk weights
- hidden_states = (
- hidden_states.view(num_tokens, self.top_k, hidden_dim)
- * routing_weights[..., None]
- )
- hidden_states = hidden_states.sum(dim = 1)
-
- hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)
- return hidden_states, router_logits
diff --git a/unsloth/kernels/moe/grouped_gemm/reference/moe_block.py b/unsloth/kernels/moe/grouped_gemm/reference/moe_block.py
deleted file mode 100644
index 2a015252c6..0000000000
--- a/unsloth/kernels/moe/grouped_gemm/reference/moe_block.py
+++ /dev/null
@@ -1,161 +0,0 @@
-# SPDX-License-Identifier: GNU Affero General Public License v3.0
-# Copyright 2023-present the Unsloth team. All rights reserved.
-
-import torch
-from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig
-from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
-
-from ..interface import grouped_gemm
-from ..kernels.tuning import (
- KernelConfigBackward_dW,
- KernelConfigBackward_dX,
- KernelConfigForward,
-)
-from .moe_ops import (
- Qwen3MoeGroupedGEMMBlock,
- permute,
- unpermute,
-)
-
-"""
-Reference implementation of MoE block using grouped gemm.
-
-This is the same as the Qwen3MoeGroupedGEMMBlock but with triton grouped gemm in place of torch-native grouped gemm implementation.
-
-NOTE: This is NOT to be used for production as it contains many extra checks and saves all intermediate results for debugging.
-"""
-
-
-class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
- def __init__(
- self,
- config: Qwen3MoeConfig,
- gate: torch.Tensor,
- gate_up_proj: torch.Tensor,
- down_proj: torch.Tensor,
- permute_x: bool = True,
- permute_y: bool = True,
- autotune: bool = True,
- kernel_config_fwd: KernelConfigForward = None,
- kernel_config_bwd_dW: KernelConfigBackward_dW = None,
- kernel_config_bwd_dX: KernelConfigBackward_dX = None,
- dW_only: bool = False,
- dX_only: bool = False,
- ):
- super().__init__(config, gate, gate_up_proj, down_proj)
- self.permute_x = permute_x
- self.permute_y = permute_y
- self.autotune = autotune
- if not autotune:
- assert (
- kernel_config_fwd is not None
- and kernel_config_bwd_dW is not None
- and kernel_config_bwd_dX is not None
- ), "Kernel configs must be provided if autotune is False"
- self.kernel_config_fwd = kernel_config_fwd
- self.kernel_config_bwd_dW = kernel_config_bwd_dW
- self.kernel_config_bwd_dX = kernel_config_bwd_dX
- self.dW_only = dW_only
- self.dX_only = dX_only
-
- @classmethod
- def from_hf(
- cls,
- moe_block: Qwen3MoeSparseMoeBlock,
- permute_x: bool = True,
- permute_y: bool = True,
- autotune: bool = True,
- kernel_config_fwd: KernelConfigForward = None,
- kernel_config_bwd_dW: KernelConfigBackward_dW = None,
- kernel_config_bwd_dX: KernelConfigBackward_dX = None,
- dW_only: bool = False,
- dX_only: bool = False,
- ):
- config: Qwen3MoeConfig = moe_block.experts[0].config
- gate, gate_up_proj, down_proj = Qwen3MoeGroupedGEMMBlock.extract_hf_weights(
- moe_block
- )
- return cls(
- config,
- gate,
- gate_up_proj,
- down_proj,
- permute_x = permute_x,
- permute_y = permute_y,
- autotune = autotune,
- kernel_config_fwd = kernel_config_fwd,
- kernel_config_bwd_dW = kernel_config_bwd_dW,
- kernel_config_bwd_dX = kernel_config_bwd_dX,
- dW_only = dW_only,
- dX_only = dX_only,
- )
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- batch_size, sequence_length, hidden_dim = hidden_states.shape
- num_tokens = batch_size * sequence_length
- total_tokens = num_tokens * self.top_k
-
- hidden_states = hidden_states.view(-1, hidden_dim)
-
- router_logits, routing_weights, selected_experts = self.run_router(
- hidden_states
- )
- # Pre-processing
- # 1. Compute tokens per expert and indices for gathering tokes from token order to expert order
- # NOTE: these are auxiliary data structs which don't need to be recorded in autograd graph
- token_counts_by_expert, gather_indices = (
- self.get_token_counts_and_gather_indices(selected_experts)
- )
-
- # 2. permute_x -> permutation will be fused in prologue of first grouped gemm
- if not self.permute_x:
- hidden_states = permute(hidden_states, gather_indices, self.top_k)
- # Start expert computation
- hidden_states = grouped_gemm(
- X = hidden_states,
- W = self.gate_up_proj,
- m_sizes = token_counts_by_expert,
- gather_indices = gather_indices,
- topk = self.top_k,
- permute_x = self.permute_x,
- permute_y = False, # output of first grouped gemm should never be permuted
- autotune = self.autotune,
- kernel_config_fwd = self.kernel_config_fwd,
- kernel_config_bwd_dW = self.kernel_config_bwd_dW,
- kernel_config_bwd_dX = self.kernel_config_bwd_dX,
- is_first_gemm = True,
- dW_only = self.dW_only,
- dX_only = self.dX_only,
- )
- hidden_states = self.act_and_mul(hidden_states)
- hidden_states = grouped_gemm(
- X = hidden_states,
- W = self.down_proj,
- m_sizes = token_counts_by_expert,
- gather_indices = gather_indices,
- topk = self.top_k,
- permute_x = False,
- permute_y = self.permute_y,
- autotune = self.autotune,
- kernel_config_fwd = self.kernel_config_fwd,
- kernel_config_bwd_dW = self.kernel_config_bwd_dW,
- kernel_config_bwd_dX = self.kernel_config_bwd_dX,
- is_first_gemm = False,
- dW_only = self.dW_only,
- dX_only = self.dX_only,
- )
-
- # Post-processing
- # 1. Unpermute from expert order to token order
- if not self.permute_y:
- hidden_states = unpermute(hidden_states, gather_indices)
-
- # 2. Merge topk weights
- hidden_states = (
- hidden_states.view(num_tokens, self.top_k, hidden_dim)
- * routing_weights[..., None]
- )
- hidden_states = hidden_states.sum(dim = 1)
-
- hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)
- return hidden_states, router_logits
diff --git a/unsloth/kernels/moe/grouped_gemm/reference/moe_ops.py b/unsloth/kernels/moe/grouped_gemm/reference/moe_ops.py
deleted file mode 100644
index 46d9c3c514..0000000000
--- a/unsloth/kernels/moe/grouped_gemm/reference/moe_ops.py
+++ /dev/null
@@ -1,151 +0,0 @@
-# SPDX-License-Identifier: GNU Affero General Public License v3.0
-# Copyright 2023-present the Unsloth team. All rights reserved.
-
-import torch
-import torch.nn.functional as F
-
-
-def permute(X: torch.Tensor, gather_indices: torch.Tensor, topk: int):
- """
- Scatters X to a new tensor with shape [total_tokens, hidden_dim] where total_tokens is num_tokens * topk,
- permuting the tokens according to sorted_token_idx.
-
- Helper for grouped gemm where hidden states need be ordered by expert.
- X: [num_tokens, hidden_dim]
- sorted_token_idx: [num_tokens * topk]
- topk: int
-
- Returns:
- [total_tokens, hidden_dim]
- """
- assert gather_indices.ndim == 1
- X = X.view(-1, X.shape[-1])
- # Shortcut for topk == 1
- if topk == 1:
- return X[gather_indices]
-
- return X[gather_indices // topk]
-
-
-def unpermute(X: torch.Tensor, gather_indices: torch.Tensor):
- X = X.view(-1, X.shape[-1]) if X.ndim > 2 else X
- unpermuted = torch.empty_like(X)
- unpermuted.index_copy_(0, gather_indices, X)
- return unpermuted.view_as(X)
-
-
-def calculate_topk(
- gating_output: torch.Tensor,
- top_k: int,
- use_sigmoid: bool,
- renormalize: bool,
- pre_act: bool = True,
- post_act: bool = False,
-):
- """
- If post_act is True, then activation function is run AFTER topk
- If post_act is False, then activation function is run BEFORE topk
-
- This is to align with triton_bench implementation (post_act) whereas most models use pre_act (e.g. llama4, deepseek)
- """
- assert pre_act ^ post_act, "only one of pre_act or post_act can be True"
-
- def _activation(gating_output: torch.Tensor):
- if use_sigmoid:
- scores = torch.sigmoid(gating_output.to(torch.float32)).to(
- gating_output.dtype
- )
- else:
- scores = F.softmax(gating_output.to(torch.float32), dim = 1).to(
- gating_output.dtype
- )
-
- return scores
-
- if pre_act:
- scores = _activation(gating_output)
- else:
- scores = gating_output
-
- topk_weights, topk_ids = torch.topk(scores, k = top_k, dim = 1)
-
- if post_act:
- topk_weights = _activation(topk_weights)
-
- if renormalize:
- topk_weights /= torch.sum(topk_weights, dim = -1, keepdim = True).to(
- gating_output.dtype
- )
-
- return topk_weights, topk_ids
-
-
-@torch.no_grad()
-def get_routing_indices(
- selected_experts, num_experts, return_scatter_indices: bool = False
-):
- """
- Returns:
- token_counts_by_expert: [num_experts]
- gather_indices: [num_tokens]
- scatter_indices [Optional] (torch.Tensor):
- Indices for unpermuting gathered inputs back to token order, shape ``(bs * seqlen * top_k,)``.
- """
- # group tokens together by expert indices from 0 to num_experts and pass that to experts forward
- token_counts_by_expert = torch.histc(
- selected_experts.view(-1),
- bins = num_experts,
- min = 0,
- max = num_experts,
- )
- # token_indices_experts_sorted shape (bs*slen*top_k,)
- gather_indices = torch.argsort(selected_experts.view(-1), stable = True)
- if return_scatter_indices:
- scatter_indices = gather_indices.argsort()
- return token_counts_by_expert, gather_indices, scatter_indices
- else:
- return token_counts_by_expert, gather_indices
-
-
-def torch_grouped_gemm(X, W, m_sizes, transpose = True):
- """
- X: [M, K] if forward, else [M, N]
- W: [E, N, K]
- m_sizes: [E]
-
- Returns:
- Y: [M, N] if forward, else [M, K]
- """
- X = X.view(-1, X.shape[-1])
- M, K = X.shape
-
- assert m_sizes.ndim == 1
- E = m_sizes.shape[0]
-
- assert W.ndim == 3
- assert W.shape[0] == E
-
- N = W.shape[1]
-
- result = torch.zeros((M, N), dtype = X.dtype, device = X.device)
-
- m_start = 0
- for g in range(E):
- m_size = m_sizes[g]
- if m_size > 0:
- m_end = m_start + m_size
-
- # Extract group input
- # m_size x K
- X_g = X[m_start:m_end]
- # N x K
- W_g = W[g]
-
- # Y_g = X_g @ W_g.T -> [m_size, N]
- W_g = W_g.T if transpose else W_g
- Y_g = X_g @ W_g
-
- result[m_start:m_end] = Y_g
-
- m_start = m_end
- return result
diff --git a/unsloth/kernels/moe/requirements.txt b/unsloth/kernels/moe/requirements.txt
deleted file mode 100644
index ea76e50564..0000000000
--- a/unsloth/kernels/moe/requirements.txt
+++ /dev/null
@@ -1,5 +0,0 @@
-torch
-git+https://github.com/huggingface/transformers.git@main
-pytest
-pandas
-ruff
\ No newline at end of file
diff --git a/unsloth/kernels/moe/tests/__init__.py b/unsloth/kernels/moe/tests/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/unsloth/kernels/moe/tests/common.py b/unsloth/kernels/moe/tests/common.py
deleted file mode 100644
index bfe6f2094b..0000000000
--- a/unsloth/kernels/moe/tests/common.py
+++ /dev/null
@@ -1,336 +0,0 @@
-# SPDX-License-Identifier: GNU Affero General Public License v3.0
-# Copyright 2023-present the Unsloth team. All rights reserved.
-
-import itertools
-from contextlib import contextmanager
-from dataclasses import dataclass, field
-
-import torch
-
-from grouped_gemm.kernels.tuning import (
- KernelConfig,
- KernelConfigBackward_dW,
- KernelConfigBackward_dX,
- KernelConfigForward,
- prune_kernel_configs_backward_dW,
- prune_kernel_configs_backward_dX,
- prune_kernel_configs_fwd,
-)
-
-
-def print_delimiter(char = "-", length = 80):
- print(char * length)
-
-
-@contextmanager
-def delimiter_context():
- print_delimiter()
- yield
- print_delimiter()
-
-
-def make_inputs(M, N, K, E, topk, dtype, requires_grad = False):
- X1 = (
- torch.randn((M, K), device = "cuda", dtype = dtype, requires_grad = requires_grad)
- / 10
- )
- X2 = (
- torch.randn(
- (M * topk, N), device = "cuda", dtype = dtype, requires_grad = requires_grad
- )
- / 10
- )
- W1 = (
- torch.randn(
- (E, 2 * N, K), device = "cuda", dtype = dtype, requires_grad = requires_grad
- )
- / 10
- )
- W2 = (
- torch.randn((E, K, N), device = "cuda", dtype = dtype, requires_grad = requires_grad)
- / 10
- )
- score = torch.randn((M, E), device = "cuda", dtype = dtype, requires_grad = requires_grad)
- if requires_grad:
- X1.retain_grad()
- X2.retain_grad()
- W1.retain_grad()
- W2.retain_grad()
- score.retain_grad()
- return X1, X2, W1, W2, score
-
-
-@dataclass(kw_only = True)
-class DataConfig:
- seq_len: int
- dtype: torch.dtype
- device: str = "cuda"
- bs: int = 1
-
-
-@dataclass(kw_only = True)
-class ModelConfig:
- hidden_size: int
- intermediate_size: int
- num_experts: int
- topk: int
- use_sigmoid: bool
- renormalize: bool
- pre_mul: bool = False
- post_mul: bool = field(init = False)
-
- def __post_init__(self):
- self.post_mul = not self.pre_mul
-
-
-@dataclass(kw_only = True)
-class GroupedGEMMTestConfig:
- name: str = "test"
- data_config: DataConfig
- model_config: ModelConfig
-
-
-TOLERANCE = {
- torch.bfloat16: (1e-3, 1e-3),
- torch.float16: (1e-4, 1e-4),
- torch.float32: (1e-5, 1e-5),
-}
-
-
-# from https://github.com/triton-lang/triton/blob/main/bench/triton_bench/testing.py
-def assert_equal(ref, tri):
- if isinstance(ref, torch.Tensor):
- assert torch.all(ref == tri), f"tensors not equal {ref} != {tri}"
- else:
- assert ref == tri, f"ref not equal to tri {ref} != {tri}"
-
-
-def assert_close(ref, tri, maxtol = None, rmstol = None, description = "--", verbose = True):
- if tri.dtype.itemsize == 1:
- ref_as_type = ref.to(tri.dtype)
- if ref.dtype == tri.dtype:
- assert torch.all(ref_as_type == tri)
- return
- ref = ref_as_type
-
- if maxtol is None:
- maxtol = 2e-2
- if rmstol is None:
- rmstol = 4e-3
- """
- Compare reference values against obtained values.
- """
-
- # cast to float32:
- ref = ref.to(torch.float32).detach()
- tri = tri.to(torch.float32).detach()
- assert (
- ref.shape == tri.shape
- ), f"Tensors must have same size {ref.shape = } {tri.shape = }"
-
- # deal with infinite elements:
- inf_mask_ref = torch.isinf(ref)
- inf_mask_tri = torch.isinf(tri)
- assert torch.equal(
- inf_mask_ref, inf_mask_tri
- ), "Tensor must have same infinite elements"
- refn = torch.where(inf_mask_ref, 0, ref)
- trin = torch.where(inf_mask_tri, 0, tri)
-
- # normalise so that RMS calculation doesn't overflow:
- eps = 1.0e-30
- multiplier = 1.0 / (torch.max(torch.abs(refn)) + eps)
- refn *= multiplier
- trin *= multiplier
-
- ref_rms = torch.sqrt(torch.square(refn).mean()) + eps
-
- rel_err = torch.abs(refn - trin) / torch.maximum(ref_rms, torch.abs(refn))
- max_err = torch.max(rel_err).item()
- rms_err = torch.sqrt(torch.square(rel_err).mean()).item()
-
- if verbose:
- print(
- "%s maximum relative error = %s (threshold = %s)"
- % (description, max_err, maxtol)
- )
- print(
- "%s RMS relative error = %s (threshold = %s)"
- % (description, rms_err, rmstol)
- )
-
- if max_err > maxtol:
- bad_idxs = torch.nonzero(rel_err > maxtol)
- num_nonzero = bad_idxs.size(0)
- bad_idxs = bad_idxs[:1000]
- print(
- "%d / %d mismatched elements (shape = %s) at coords %s"
- % (num_nonzero, rel_err.numel(), tuple(rel_err.shape), bad_idxs.tolist())
- )
-
- bad_idxs = bad_idxs.unbind(-1)
- print("ref values: ", ref[*bad_idxs].cpu())
- print("tri values: ", tri[*bad_idxs].cpu())
-
- assert max_err <= maxtol
- assert rms_err <= rmstol
-
-
-def assert_indx_equal(ref, tri):
- assert_equal(ref, tri[: len(ref)])
- assert torch.all(tri[len(ref) :] == -1)
-
-
-def get_kernel_test_configs(
- BLOCK_SIZE_M = 32,
- BLOCK_SIZE_N = 32,
- BLOCK_SIZE_K = 32,
- num_warps = 4,
- num_stages = 2,
-) -> list[KernelConfig]:
- configs_fwd = []
- configs_bwd_dX = []
- configs_bwd_dW = []
-
- for permute_x in [False, True]:
- for permute_y in [False, True]:
- for use_tma_load_w in [True, False]:
- for use_tma_load_x in [True, False]:
- for use_tma_store in [True, False]:
- configs_fwd.append(
- KernelConfigForward(
- BLOCK_SIZE_M = BLOCK_SIZE_M,
- BLOCK_SIZE_N = BLOCK_SIZE_N,
- BLOCK_SIZE_K = BLOCK_SIZE_K,
- num_warps = num_warps,
- num_stages = num_stages,
- use_tma_load_w = use_tma_load_w,
- use_tma_load_x = use_tma_load_x,
- use_tma_store = use_tma_store,
- permute_x = permute_x,
- permute_y = permute_y,
- )
- )
- configs_bwd_dX.append(
- KernelConfigBackward_dX(
- BLOCK_SIZE_M = BLOCK_SIZE_M,
- BLOCK_SIZE_N = BLOCK_SIZE_N,
- BLOCK_SIZE_K = BLOCK_SIZE_K,
- num_warps = num_warps,
- num_stages = num_stages,
- use_tma_load_dy = use_tma_load_x,
- use_tma_load_w = use_tma_load_w,
- permute_x = permute_x,
- permute_y = permute_y,
- use_tma_store = use_tma_store,
- )
- )
- configs_bwd_dW.append(
- KernelConfigBackward_dW(
- BLOCK_SIZE_M = BLOCK_SIZE_M,
- BLOCK_SIZE_N = BLOCK_SIZE_N,
- BLOCK_SIZE_K = BLOCK_SIZE_K,
- num_warps = num_warps,
- num_stages = num_stages,
- use_tma_load_dy = use_tma_load_w,
- use_tma_load_x = use_tma_load_x,
- permute_x = permute_x,
- permute_y = permute_y,
- use_tma_store = use_tma_store,
- )
- )
- configs_fwd = prune_kernel_configs_fwd(configs_fwd)
- configs_bwd_dX = prune_kernel_configs_backward_dX(configs_bwd_dX)
- configs_bwd_dW = prune_kernel_configs_backward_dW(configs_bwd_dW)
- return configs_fwd, configs_bwd_dX, configs_bwd_dW
-
-
-def remove_feature_flags(
- kernel_configs: list[KernelConfig],
- permute_x: bool = True,
- permute_y: bool = True,
- tma_loads: bool = True,
- tma_store: bool = True,
-):
- pruned_configs = []
- for config in kernel_configs:
- # Remove permute flags first:
- if permute_x and config.permute_x:
- continue
- if permute_y and config.permute_y:
- continue
- if tma_loads:
- if isinstance(config, KernelConfigForward):
- if config.use_tma_load_w or config.use_tma_load_x:
- continue
- if isinstance(config, KernelConfigBackward_dX):
- if config.use_tma_load_dy or config.use_tma_load_w:
- continue
- if isinstance(config, KernelConfigBackward_dW):
- if config.use_tma_load_dy or config.use_tma_load_x:
- continue
- if tma_store:
- if config.use_tma_store:
- continue
- pruned_configs.append(config)
- return pruned_configs
-
-
-# Test Configs
-
-TOPK = [1, 4]
-NUM_EXPERTS = [4, 16]
-
-TEST_MODEL_SIZES = [
- (32, 32), # Debug
- (128, 128), # Small
- (512, 512), # Medium
-]
-
-SMALL_MODEL_CONFIGS = [
- ModelConfig(
- topk = topk,
- num_experts = num_experts,
- hidden_size = model_size[0],
- intermediate_size = model_size[1],
- use_sigmoid = False,
- renormalize = False,
- )
- for topk, num_experts, model_size in itertools.product(
- TOPK, NUM_EXPERTS, TEST_MODEL_SIZES
- )
-]
-LLAMA_MODEL_CONFIG = ModelConfig(
- topk = 1,
- num_experts = 16,
- hidden_size = 5120,
- intermediate_size = 8192,
- use_sigmoid = True,
- renormalize = False,
-)
-QWEN_MODEL_CONFIG = ModelConfig(
- topk = 8,
- num_experts = 128,
- hidden_size = 2048,
- intermediate_size = 768,
- use_sigmoid = False,
- renormalize = False,
-)
-
-SEQLENS = [128, 1024]
-DTYPE = [torch.bfloat16]
-
-DATA_CONFIGS = [
- DataConfig(seq_len = seq_len, dtype = dtype)
- for seq_len, dtype in itertools.product(SEQLENS, DTYPE)
-]
-KERNEL_CONFIGS_FWD, KERNEL_CONFIGS_BWD_dX, KERNEL_CONFIGS_BWD_dW = (
- get_kernel_test_configs()
-)
-
-if __name__ == "__main__":
- print(
- KERNEL_CONFIGS_BWD_dX[0].to_string(
- include_tuning_params = False, include_tma = False
- )
- )
diff --git a/unsloth/kernels/moe/tests/moe_utils.py b/unsloth/kernels/moe/tests/moe_utils.py
deleted file mode 100644
index 26ac9fdb5e..0000000000
--- a/unsloth/kernels/moe/tests/moe_utils.py
+++ /dev/null
@@ -1,507 +0,0 @@
-# SPDX-License-Identifier: GNU Affero General Public License v3.0
-# Copyright 2023-present the Unsloth team. All rights reserved.
-
-from dataclasses import dataclass, fields
-
-import torch
-import torch.nn as nn
-from huggingface_hub import HfApi
-from huggingface_hub.utils import _safetensors
-from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig
-from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
-
-from grouped_gemm.interface import grouped_gemm
-from grouped_gemm.kernels.tuning import (
- KernelConfigBackward_dW,
- KernelConfigBackward_dX,
- KernelConfigForward,
-)
-from grouped_gemm.reference.layers.qwen3_moe import (
- GroupedGEMMResult,
- Qwen3MoeGroupedGEMMBlock,
-)
-from grouped_gemm.reference.moe_ops import permute, unpermute
-
-
-def rebind_experts_to_shared_buffer(
- moe_block: Qwen3MoeSparseMoeBlock, config: Qwen3MoeConfig
-):
- num_experts = config.num_experts
- hidden_size = config.hidden_size
- interm_size = config.moe_intermediate_size
- device = moe_block.experts[0].down_proj.weight.device
- dtype = moe_block.experts[0].down_proj.weight.dtype
-
- buffer_up = torch.empty(
- num_experts, interm_size, hidden_size, device = device, dtype = dtype
- )
- buffer_gate = torch.empty(
- num_experts, interm_size, hidden_size, device = device, dtype = dtype
- )
- buffer_down = torch.empty(
- num_experts, hidden_size, interm_size, device = device, dtype = dtype
- )
-
- # Step 2: Copy existing expert weights into buffers
- for i, expert in enumerate(moe_block.experts):
- buffer_up[i].copy_(expert.up_proj.weight.data)
- buffer_gate[i].copy_(expert.gate_proj.weight.data)
- buffer_down[i].copy_(expert.down_proj.weight.data)
-
- # Step 3: Rebind expert weights to views in shared buffer
- for i, expert in enumerate(moe_block.experts):
- expert.up_proj.weight = torch.nn.Parameter(buffer_up[i])
- expert.gate_proj.weight = torch.nn.Parameter(buffer_gate[i])
- expert.down_proj.weight = torch.nn.Parameter(buffer_down[i])
-
- return buffer_up, buffer_gate, buffer_down
-
-
-def get_expert_metadata(model_id: str):
- api = HfApi()
- metadata: _safetensors.SafetensorsRepoMetadata = api.get_safetensors_metadata(
- model_id
- )
- return metadata.files_metadata
-
-
-def clone_experts(
- moe_block: Qwen3MoeSparseMoeBlock, config: Qwen3MoeConfig, copy: bool = True
-):
- down_projs = torch.empty(
- config.num_experts, config.hidden_size, config.moe_intermediate_size
- )
- up_projs = torch.empty(
- config.num_experts, config.moe_intermediate_size, config.hidden_size
- )
- gate_projs = torch.empty(
- config.num_experts, config.moe_intermediate_size, config.hidden_size
- )
- for expert_idx, expert in enumerate(moe_block.experts):
- down_projs[expert_idx].copy_(expert.down_proj.weight.data)
- up_projs[expert_idx].copy_(expert.up_proj.weight.data)
- gate_projs[expert_idx].copy_(expert.gate_proj.weight.data)
- return gate_projs, up_projs, down_projs
-
-
-@dataclass
-class ForwardResult:
- output: torch.Tensor
- router_logits: torch.Tensor
- X: torch.Tensor
- # When using grouped gemm MoE implementation to additional debugging / checking of intermediate results
- grouped_gemm_result: GroupedGEMMResult = None
-
-
-@dataclass
-class BackwardResult:
- X_grad: torch.Tensor
- gate_grad: torch.Tensor
- gate_proj_grad: torch.Tensor
- up_proj_grad: torch.Tensor
- down_proj_grad: torch.Tensor
-
-
-def check_down_proj_grad(
- moe_block: Qwen3MoeSparseMoeBlock,
- grouped_gemm_block: Qwen3MoeGroupedGEMMBlock,
- atol: float,
- rtol: float,
-):
- for i, expert in enumerate(moe_block.experts):
- ref_grad = expert.down_proj.weight.grad
- assert ref_grad is not None
- test_grad = grouped_gemm_block.down_proj.grad[i]
- assert test_grad is not None
- diff = (ref_grad - test_grad).abs().max()
- if not torch.allclose(ref_grad, test_grad, atol = atol, rtol = rtol):
- print(f"expert {i} down_proj_grad_diff: {diff.detach().cpu().item():.6f}")
-
-
-def check_gate_up_proj_grad(
- moe_block: Qwen3MoeSparseMoeBlock,
- grouped_gemm_block: Qwen3MoeGroupedGEMMBlock,
- atol: float,
- rtol: float,
-):
- moe_intermediate_size = grouped_gemm_block.moe_intermediate_size
- for i, expert in enumerate(moe_block.experts):
- ref_gate_proj_grad = expert.gate_proj.weight.grad
- ref_up_proj_grad = expert.up_proj.weight.grad
- assert ref_gate_proj_grad is not None
- assert ref_up_proj_grad is not None
-
- # Extract gradients
- test_gate_proj_grad = grouped_gemm_block.gate_up_proj.grad[
- i, :moe_intermediate_size
- ]
- test_up_proj_grad = grouped_gemm_block.gate_up_proj.grad[
- i, moe_intermediate_size:
- ]
- assert test_gate_proj_grad is not None
- assert test_up_proj_grad is not None
-
- # Sanity check shapes
- assert (
- ref_gate_proj_grad.shape == test_gate_proj_grad.shape
- ), f"{ref_gate_proj_grad.shape} != {test_gate_proj_grad.shape}"
- assert (
- ref_up_proj_grad.shape == test_up_proj_grad.shape
- ), f"{ref_up_proj_grad.shape} != {test_up_proj_grad.shape}"
-
- # Check gradients
- diff = (ref_gate_proj_grad - test_gate_proj_grad).abs().max()
- if not torch.allclose(
- ref_gate_proj_grad, test_gate_proj_grad, atol = atol, rtol = rtol
- ):
- print(f"expert {i} gate_proj_grad_diff: {diff.detach().cpu().item():.6f}")
- diff = (ref_up_proj_grad - test_up_proj_grad).abs().max()
- if not torch.allclose(
- ref_up_proj_grad, test_up_proj_grad, atol = atol, rtol = rtol
- ):
- print(f"expert {i} up_proj_grad_diff: {diff.detach().cpu().item():.6f}")
-
-
-def check_gate_grad(
- moe_block: Qwen3MoeSparseMoeBlock,
- grouped_gemm_block: Qwen3MoeGroupedGEMMBlock,
- atol: float,
- rtol: float,
-):
- ref_grad = moe_block.gate.weight.grad
- assert ref_grad is not None
- test_grad = grouped_gemm_block.gate.grad
- assert test_grad is not None
- diff = (ref_grad - test_grad).abs().max()
- if not torch.allclose(ref_grad, test_grad, atol = atol, rtol = rtol):
- print(f"gate_grad_diff: {diff.detach().cpu().item():.6f}")
-
-
-def check_wgrad(
- moe_block: Qwen3MoeSparseMoeBlock,
- grouped_gemm_block: Qwen3MoeGroupedGEMMBlock,
- atol: float,
- rtol: float,
-):
- check_down_proj_grad(moe_block, grouped_gemm_block, atol, rtol)
- check_gate_up_proj_grad(moe_block, grouped_gemm_block, atol, rtol)
- check_gate_grad(moe_block, grouped_gemm_block, atol, rtol)
-
-
-def check_tensor_allclose(
- X_ref: torch.Tensor,
- X_test: torch.Tensor,
- atol: float,
- rtol: float,
- name: str,
- verbose: bool = False,
-):
- diff = (X_ref - X_test).abs().max()
- if verbose:
- print(f"{name} diff: {diff.detach().cpu().item():.6f}")
- assert torch.allclose(
- X_ref, X_test, atol = atol, rtol = rtol
- ), f"{name} diff: {diff.detach().cpu().item():.6f}"
-
-
-def check_expert_grads(
- ref_result: BackwardResult,
- test_result: BackwardResult,
- atol: float,
- rtol: float,
- verbose: bool = False,
-):
- fields_to_check = [f.name for f in fields(BackwardResult) if "proj" in f.name]
- assert len(fields_to_check) == 3
-
- for field in fields_to_check:
- ref_grads = getattr(ref_result, field)
- test_grads = getattr(test_result, field)
- assert (
- ref_grads.shape == test_grads.shape
- ), f"{field}: {ref_grads.shape} != {test_grads.shape}"
-
- # Test each expert
- for i in range(ref_grads.shape[0]):
- ref_grad = ref_grads[i]
- test_grad = test_grads[i]
- diff = (ref_grad - test_grad).abs().max()
- assert torch.allclose(
- ref_grad, test_grad, atol = atol, rtol = rtol
- ), f"{field}[{i}] diff: {diff.detach().cpu().item():.6f}"
-
- # Test all experts
- diff = (ref_grads - test_grads).abs().max()
- if verbose:
- print(f"{field} diff: {diff.detach().cpu().item():.6f}")
- assert torch.allclose(
- ref_grads, test_grads, atol = atol, rtol = rtol
- ), f"{field} diff: {diff.detach().cpu().item():.6f}"
-
-
-def check_grads(
- ref_result: BackwardResult,
- test_result: BackwardResult,
- atol: float,
- rtol: float,
- verbose: bool = False,
-):
- check_tensor_allclose(
- ref_result.X_grad, test_result.X_grad, atol, rtol, "X.grad", verbose
- )
- check_tensor_allclose(
- ref_result.gate_grad, test_result.gate_grad, atol, rtol, "gate.grad", verbose
- )
- check_expert_grads(ref_result, test_result, atol, rtol, verbose)
-
-
-def check_fwd(
- ref_result: ForwardResult,
- test_result: ForwardResult,
- atol: float,
- rtol: float,
- verbose: bool = False,
-):
- # First check hidden states (output)
- ref_output = ref_result.output
- test_output = test_result.output
- diff = (ref_output - test_output).abs().max()
- if verbose:
- print(f"output diff: {diff.detach().cpu().item():.6f}")
- assert torch.allclose(
- ref_output, test_output, atol = atol, rtol = rtol
- ), f"output diff: {diff.detach().cpu().item():.6f}"
-
- # Check router logits
- ref_router_logits = ref_result.router_logits
- test_router_logits = test_result.router_logits
- diff = (ref_router_logits - test_router_logits).abs().max()
- if verbose:
- print(f"router_logits diff: {diff.detach().cpu().item():.6f}")
- assert torch.allclose(
- ref_router_logits, test_router_logits, atol = atol, rtol = rtol
- ), f"router_logits diff: {diff.detach().cpu().item():.6f}"
-
-
-def check_grouped_gemm_results(
- grouped_result: GroupedGEMMResult,
- fused_result: GroupedGEMMResult,
- permute_y: bool,
- atol: float,
- rtol: float,
- verbose: bool = False,
-):
- for field in fields(GroupedGEMMResult):
- ref_value = getattr(grouped_result, field.name)
- test_value = getattr(fused_result, field.name)
- diff = (ref_value - test_value).abs().max()
-
- # second_gemm in torch grouped gemm is not yet unpermuted so comparing the fused unpermuted second_gemm will result in error
- # instead the hidden_states_unpermute should match since hidden_states_unpermute for the fused result is the same as second_gemm
- if field.name == "second_gemm" and permute_y:
- continue
-
- if verbose:
- print(f"{field.name} diff: {diff.detach().cpu().item():.6f}")
-
- assert torch.allclose(
- ref_value, test_value, atol = atol, rtol = rtol
- ), f"{field.name} diff: {diff.detach().cpu().item():.6f}"
-
-
-def run_forward(model: nn.Module, X: torch.Tensor, is_grouped_gemm: bool = False):
- X = X.detach().clone().requires_grad_(True)
- output, router_logits = model(X)
- if is_grouped_gemm:
- result = ForwardResult(
- output = output.hidden_states,
- router_logits = router_logits,
- X = X,
- grouped_gemm_result = output,
- )
- else:
- result = ForwardResult(output = output, router_logits = router_logits, X = X)
- return result
-
-
-def run_backward(
- model: nn.Module, grad_output: torch.Tensor, output: torch.Tensor, X: torch.Tensor
-):
- output.backward(grad_output)
- assert X.grad is not None
- for name, param in model.named_parameters():
- assert param.grad is not None, f"{name} grad is None"
- if isinstance(model, Qwen3MoeSparseMoeBlock):
- gate_grad = model.gate.weight.grad
- gate_proj_grad = torch.stack(
- [expert.gate_proj.weight.grad for expert in model.experts]
- )
- up_proj_grad = torch.stack(
- [expert.up_proj.weight.grad for expert in model.experts]
- )
- down_proj_grad = torch.stack(
- [expert.down_proj.weight.grad for expert in model.experts]
- )
- elif isinstance(model, Qwen3MoeGroupedGEMMBlock):
- gate_grad = model.gate.grad
- gate_proj_grad, up_proj_grad = model.gate_up_proj.grad.chunk(2, dim = 1)
- down_proj_grad = model.down_proj.grad
- else:
- raise ValueError(f"Unsupported model type: {type(model)}")
- return BackwardResult(
- X_grad = X.grad,
- gate_grad = gate_grad,
- gate_proj_grad = gate_proj_grad,
- up_proj_grad = up_proj_grad,
- down_proj_grad = down_proj_grad,
- )
-
-
-class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock):
- """
- Reference implementation of MoE block using grouped gemm.
-
- This is the same as the Qwen3MoeGroupedGEMMBlock but with triton grouped gemm in place of torch-native grouped gemm implementation.
-
- NOTE: This is NOT to be used for production as it contains many extra checks and saves all intermediate results for debugging.
- See grouped_gemm/reference/moe_block.py for a cleaner implementation.
- """
-
- def __init__(
- self,
- config: Qwen3MoeConfig,
- gate: torch.Tensor,
- gate_up_proj: torch.Tensor,
- down_proj: torch.Tensor,
- permute_x: bool = False,
- permute_y: bool = False,
- autotune: bool = True,
- kernel_config_fwd: KernelConfigForward = None,
- kernel_config_bwd_dW: KernelConfigBackward_dW = None,
- kernel_config_bwd_dX: KernelConfigBackward_dX = None,
- ):
- super().__init__(config, gate, gate_up_proj, down_proj)
- self.permute_x = permute_x
- self.permute_y = permute_y
- self.autotune = autotune
- if not autotune:
- assert (
- kernel_config_fwd is not None
- and kernel_config_bwd_dW is not None
- and kernel_config_bwd_dX is not None
- ), "Kernel configs must be provided if autotune is False"
- self.kernel_config_fwd = kernel_config_fwd
- self.kernel_config_bwd_dW = kernel_config_bwd_dW
- self.kernel_config_bwd_dX = kernel_config_bwd_dX
-
- @classmethod
- def from_hf(
- cls,
- moe_block: Qwen3MoeSparseMoeBlock,
- permute_x: bool = False,
- permute_y: bool = False,
- autotune: bool = True,
- kernel_config_fwd: KernelConfigForward = None,
- kernel_config_bwd_dW: KernelConfigBackward_dW = None,
- kernel_config_bwd_dX: KernelConfigBackward_dX = None,
- ):
- config: Qwen3MoeConfig = moe_block.experts[0].config
- gate, gate_up_proj, down_proj = Qwen3MoeGroupedGEMMBlock.extract_hf_weights(
- moe_block
- )
- return cls(
- config,
- gate,
- gate_up_proj,
- down_proj,
- permute_x = permute_x,
- permute_y = permute_y,
- autotune = autotune,
- kernel_config_fwd = kernel_config_fwd,
- kernel_config_bwd_dW = kernel_config_bwd_dW,
- kernel_config_bwd_dX = kernel_config_bwd_dX,
- )
-
- def forward(self, hidden_states: torch.Tensor, debug: bool = False) -> torch.Tensor:
- batch_size, sequence_length, hidden_dim = hidden_states.shape
- num_tokens = batch_size * sequence_length
- total_tokens = num_tokens * self.top_k
-
- hidden_states = hidden_states.view(-1, hidden_dim)
-
- router_logits, routing_weights, selected_experts = self.run_router(
- hidden_states
- )
- # Pre-processing
- # 1. Compute tokens per expert and indices for gathering tokes from token order to expert order
- # NOTE: these are auxiliary data structs which don't need to be recorded in autograd graph
- token_counts_by_expert, gather_indices = (
- self.get_token_counts_and_gather_indices(selected_experts)
- )
-
- # 2. permute_x -> permutation will be fused in prologue of first grouped gemm
- if not self.permute_x:
- hidden_states = permute(hidden_states, gather_indices, self.top_k)
- assert hidden_states.shape == (total_tokens, hidden_dim)
-
- # Start expert computation
- first_gemm = grouped_gemm(
- X = hidden_states,
- W = self.gate_up_proj,
- m_sizes = token_counts_by_expert,
- gather_indices = gather_indices,
- topk = self.top_k,
- permute_x = self.permute_x,
- permute_y = False, # output of first grouped gemm should never be permuted
- autotune = self.autotune,
- kernel_config_fwd = self.kernel_config_fwd,
- kernel_config_bwd_dW = self.kernel_config_bwd_dW,
- kernel_config_bwd_dX = self.kernel_config_bwd_dX,
- is_first_gemm = True,
- )
- assert first_gemm.shape == (total_tokens, 2 * self.moe_intermediate_size)
- intermediate = self.act_and_mul(first_gemm)
- assert intermediate.shape == (total_tokens, self.moe_intermediate_size)
- second_gemm = grouped_gemm(
- X = intermediate,
- W = self.down_proj,
- m_sizes = token_counts_by_expert,
- gather_indices = gather_indices,
- topk = self.top_k,
- permute_x = False,
- permute_y = self.permute_y,
- autotune = self.autotune,
- kernel_config_fwd = self.kernel_config_fwd,
- kernel_config_bwd_dW = self.kernel_config_bwd_dW,
- kernel_config_bwd_dX = self.kernel_config_bwd_dX,
- is_first_gemm = False,
- )
- assert second_gemm.shape == (total_tokens, hidden_dim)
-
- # Post-processing
- # 1. Unpermute from expert order to token order
- if not self.permute_y:
- hidden_states_unpermute = unpermute(second_gemm, gather_indices)
- assert hidden_states_unpermute.shape == (total_tokens, hidden_dim)
- else:
- hidden_states_unpermute = second_gemm
-
- # 2. Merge topk weights
- hidden_states = (
- hidden_states_unpermute.view(num_tokens, self.top_k, hidden_dim)
- * routing_weights[..., None]
- )
- hidden_states = hidden_states.sum(dim = 1)
- assert hidden_states.shape == (num_tokens, hidden_dim)
-
- hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)
- return GroupedGEMMResult(
- token_counts_by_expert = token_counts_by_expert,
- gather_indices = gather_indices,
- topk_weights = routing_weights,
- first_gemm = first_gemm,
- intermediate = intermediate,
- second_gemm = second_gemm,
- hidden_states_unpermute = hidden_states_unpermute,
- hidden_states = hidden_states,
- ), router_logits
diff --git a/unsloth/kernels/moe/tests/run_qwen3_moe_tests.sh b/unsloth/kernels/moe/tests/run_qwen3_moe_tests.sh
deleted file mode 100755
index ed0b9f6210..0000000000
--- a/unsloth/kernels/moe/tests/run_qwen3_moe_tests.sh
+++ /dev/null
@@ -1,35 +0,0 @@
-#!/bin/bash
-
-set -euo pipefail
-
-SEQLENS=(1024)
-DTYPES=(bfloat16)
-PERMUTE_X=(false true)
-PERMUTE_Y=(false true)
-AUTOTUNE=(false true)
-
-for SEQLEN in "${SEQLENS[@]}"; do
- for DTYPE in "${DTYPES[@]}"; do
- for PX in "${PERMUTE_X[@]}"; do
- for PY in "${PERMUTE_Y[@]}"; do
- for AT in "${AUTOTUNE[@]}"; do
-
- ARGS=()
- [[ "$PX" == "true" ]] && ARGS+=("--permute_x")
- [[ "$PY" == "true" ]] && ARGS+=("--permute_y")
- [[ "$AT" == "true" ]] && ARGS+=("--autotune")
-
- ARGS+=(--seqlen "$SEQLEN" --dtype "$DTYPE")
-
- echo "Running with args: ${ARGS[*]}"
- if ! python -m tests.test_qwen3_moe "${ARGS[@]}"; then
- echo "❌ Test failed with args: --permute_x=$PX --permute_y=$PY --autotune=$AT --seqlen=$SEQLEN --dtype=$DTYPE" >&2
- else
- echo "✅ Test passed with args: --permute_x=$PX --permute_y=$PY --autotune=$AT --seqlen=$SEQLEN --dtype=$DTYPE"
- fi
-
- done
- done
- done
- done
-done
diff --git a/unsloth/kernels/moe/tests/test_grouped_gemm.py b/unsloth/kernels/moe/tests/test_grouped_gemm.py
deleted file mode 100644
index bd98b6a276..0000000000
--- a/unsloth/kernels/moe/tests/test_grouped_gemm.py
+++ /dev/null
@@ -1,1213 +0,0 @@
-# SPDX-License-Identifier: GNU Affero General Public License v3.0
-# Copyright 2023-present the Unsloth team. All rights reserved.
-
-from dataclasses import asdict
-
-import pytest
-import torch
-
-from grouped_gemm.interface import (
- grouped_gemm,
- grouped_gemm_dW,
- grouped_gemm_dX,
- grouped_gemm_forward,
-)
-from grouped_gemm.kernels.tuning import (
- KernelConfig,
- KernelConfigBackward_dW,
- KernelConfigBackward_dX,
- KernelConfigForward,
-)
-from grouped_gemm.reference.moe_ops import (
- calculate_topk,
- get_routing_indices,
- permute,
- torch_grouped_gemm,
- unpermute,
-)
-
-from .common import (
- DATA_CONFIGS,
- KERNEL_CONFIGS_FWD,
- LLAMA_MODEL_CONFIG,
- QWEN_MODEL_CONFIG,
- SMALL_MODEL_CONFIGS,
- TOLERANCE,
- DataConfig,
- KERNEL_CONFIGS_BWD_dW,
- KERNEL_CONFIGS_BWD_dX,
- ModelConfig,
- make_inputs,
-)
-
-SEED = 0
-
-
-# Only certain combinations of permute_x, permute_y, use_W1 are valid.
-# use_W1 => first grouped GEMM in a fused MoE MLP
-# use_W2 => second grouped GEMM in a fused MoE MLP
-# permute_x => permute the input to the grouped GEMM, only done for the first grouped GEMM
-# permute_y => permute the output of the grouped GEMM, only done for the second grouped GEMM
-# fuse_mul_post => fuse the multiplication of topk weights in the epilogue of the second grouped GEMM; only used for inference, not currently tested
-def check_valid_config(
- permute_x, permute_y, use_W1, fuse_mul_post = False, is_backward = False, verbose = False
-):
- use_W2 = not use_W1
-
- if permute_x and permute_y:
- if verbose:
- print(f"Skipping test: {permute_x = } {permute_y = }")
- return False
- if use_W2 and permute_x:
- if verbose:
- print(f"Skipping test: {permute_x = } {use_W2 = }")
- return False
- if use_W1 and permute_y:
- if verbose:
- print(f"Skipping test: {permute_y = } {use_W1 = }")
- return False
- if fuse_mul_post and use_W1:
- if verbose:
- print(f"Skipping test: {fuse_mul_post = } {use_W1 = }")
- return False
- if is_backward and fuse_mul_post:
- if verbose:
- print(f"Skipping test: {fuse_mul_post = } {is_backward = }")
- return False
-
- return True
-
-
-"""
-grouped_gemm_forward
-
-permute_x: typically in a fused MoE MLP, we can fuse the permutation of hidden states (X) from token order to expert grouped order needed for grouped GEMM by directly loading X in permuted order rather than launching a separate permutation kernel.
-permute_y: We can also fuse the unpermutation of tokens after the second grouped GEMM to restore to original token order. This is fused into the second grouped GEMM by directly storing the output in unpermuted order.
-fuse_mul: We can also fuse the multiplication of topk weights in the epilogue of the second grouped GEMM. Note that this is only supported for inference and not training, although this may change in the future.
-use_W1 test the shapes for the first grouped GEMM in a fused MoE MLP
-use_W2 = `not use_W1` tests the shapes for the second grouped GEMM in a fused MoE MLP
-
-Given the above, only certain combinations are valid:
-- use_W1 is always False when permute_y is True since we only permute the second grouped GEMM
-- use_W2 is always False when permute_x is True since we only permute the first grouped GEMM
-- only one of permute_x and permute_y can be True
-- fuse_mul is only True if permute_y is also True
-
-See `check_valid_config` for more details.
-"""
-
-
-def _test_grouped_gemm_forward(
- data_config: DataConfig,
- model_config: ModelConfig,
- permute_x: bool,
- permute_y: bool,
- use_W1: bool, # W1 -> first grouped GEMM in a fused MoE MLP, not W1 -> second grouped GEMM in a fused MoE MLP
- fuse_mul_post: bool = False,
- flatten: bool = True,
- # Manually tuned parameters
- use_tma_load_w: bool = False,
- use_tma_load_x: bool = False,
- use_tma_store: bool = False,
- BLOCK_SIZE_M: int = None,
- BLOCK_SIZE_N: int = None,
- BLOCK_SIZE_K: int = None,
- num_warps: int = None,
- num_stages: int = None,
- # Autotuning parameters
- autotune: bool = False,
- num_autotune_configs: int = None,
- # Flag to manually enable TMA store
- allow_tma_store: bool = False,
- use_autograd: bool = False,
-):
- if not check_valid_config(
- permute_x, permute_y, use_W1 = use_W1, fuse_mul_post = fuse_mul_post
- ):
- pytest.skip(
- f"Skipping test due to invalid config: {permute_x = } {permute_y = } {use_W1 = } {fuse_mul_post = }"
- )
-
- if use_tma_store and not allow_tma_store:
- pytest.skip("TMA store needs to be debugged due to non-deterministic behavior")
-
- X1, X2, W1, W2, gating_output = make_inputs(
- M = data_config.bs * data_config.seq_len,
- N = model_config.intermediate_size,
- K = model_config.hidden_size,
- E = model_config.num_experts,
- topk = model_config.topk,
- dtype = data_config.dtype,
- )
- topk = model_config.topk
- use_sigmoid = model_config.use_sigmoid
- renormalize = model_config.renormalize
-
- X = X1 if use_W1 else X2
- num_tokens = data_config.bs * data_config.seq_len
- E, K, N = W2.shape # E = num_experts, K = hidden_size, N = intermediate_size
- assert W1.shape == (E, 2 * N, K)
- W = W1 if use_W1 else W2
-
- if use_W1:
- assert X.shape == (
- num_tokens,
- K,
- ), f"X.shape: {X.shape}, num_tokens: {num_tokens}, K: {K}"
- else:
- assert X.shape == (
- num_tokens * topk,
- N,
- ), f"X.shape: {X.shape}, num_tokens: {num_tokens}, topk: {topk}, N: {N}"
-
- total_tokens = num_tokens * topk
- output_shape = (total_tokens, 2 * N) if use_W1 else (total_tokens, K)
-
- topk_weights, topk_ids = calculate_topk(
- gating_output, topk, use_sigmoid = use_sigmoid, renormalize = renormalize
- )
- topk_weights = topk_weights.view(-1) # num_tokens * topk
- topk_ids = topk_ids.view(-1) # num_tokens * topk
-
- expert_token_counts, gather_indices = get_routing_indices(topk_ids, num_experts = E)
- assert len(gather_indices) == total_tokens
- assert len(expert_token_counts) == E
-
- atol, rtol = TOLERANCE[X.dtype]
-
- Xperm = permute(X, gather_indices, topk)
-
- Xref = Xperm
-
- assert (
- Xperm.shape == (total_tokens, K) if use_W1 else (total_tokens, N)
- ), f"Xperm.shape: {Xperm.shape}, total_tokens: {total_tokens}, K: {K}"
-
- ref_output = torch_grouped_gemm(X = Xref, W = W, m_sizes = expert_token_counts)
-
- if permute_x:
- X_test = X
- else:
- X_test = Xperm
-
- # No need to run all configs for tests, otherwise takes too long
- if autotune:
- from grouped_gemm.kernels.forward import _autotuned_grouped_gemm_forward_kernel
-
- if num_autotune_configs is not None:
- _autotuned_grouped_gemm_forward_kernel.configs = (
- _autotuned_grouped_gemm_forward_kernel.configs[:num_autotune_configs]
- )
-
- # Use autograd.Function interface
- if use_autograd:
- from grouped_gemm.interface import grouped_gemm
-
- kernel_config_fwd = KernelConfigForward(
- BLOCK_SIZE_M = BLOCK_SIZE_M,
- BLOCK_SIZE_N = BLOCK_SIZE_N,
- BLOCK_SIZE_K = BLOCK_SIZE_K,
- num_warps = num_warps,
- num_stages = num_stages,
- permute_x = permute_x,
- permute_y = permute_y,
- fuse_mul_post = fuse_mul_post,
- use_tma_load_w = use_tma_load_w,
- use_tma_load_x = use_tma_load_x,
- use_tma_store = use_tma_store,
- )
-
- test_output = grouped_gemm(
- X = X_test,
- W = W,
- topk = topk,
- m_sizes = expert_token_counts,
- gather_indices = gather_indices,
- topk_weights = topk_weights if fuse_mul_post else None,
- permute_x = permute_x,
- permute_y = permute_y,
- fuse_mul_post = fuse_mul_post,
- kernel_config_fwd = kernel_config_fwd,
- autotune = autotune,
- is_first_gemm = use_W1,
- )
- # Use manual interface
- else:
- test_output = grouped_gemm_forward(
- X = X_test,
- W = W,
- topk = topk,
- m_sizes = expert_token_counts,
- gather_indices = gather_indices,
- topk_weights = topk_weights if fuse_mul_post else None,
- permute_x = permute_x,
- permute_y = permute_y,
- fuse_mul_post = fuse_mul_post,
- use_tma_load_w = use_tma_load_w,
- use_tma_load_x = use_tma_load_x,
- use_tma_store = use_tma_store,
- autotune = autotune,
- BLOCK_SIZE_M = BLOCK_SIZE_M,
- BLOCK_SIZE_N = BLOCK_SIZE_N,
- BLOCK_SIZE_K = BLOCK_SIZE_K,
- num_warps = num_warps,
- num_stages = num_stages,
- flatten = flatten,
- )
- assert ref_output.shape == output_shape
- assert test_output.shape == output_shape
-
- if permute_y:
- ref_output = unpermute(ref_output, gather_indices)
- if fuse_mul_post:
- # if we don't permute_y, then test output is permuted with topk weights applied
- # the ref output needs to be unpermuted before multiplying by topk weights since topk weights are in token order
- if not permute_y:
- ref_output = unpermute(ref_output, gather_indices)
- test_output = unpermute(test_output, gather_indices)
- ref_output = ref_output * topk_weights[:, None]
-
- assert torch.allclose(
- ref_output, test_output, atol = atol, rtol = rtol
- ), f"Grouped gemm forward failed: {(ref_output - test_output).abs().max().item():.6f}"
-
-
-# NOTE: Fuse multiplication of topk weights is only supported for inference and not training, although this may change in the future; not currently tested.
-@pytest.mark.parametrize(
- "kernel_config",
- KERNEL_CONFIGS_FWD,
- ids = lambda x: x.to_string(include_tuning_params = True, include_tma = True),
-)
-@pytest.mark.parametrize(
- "model_config",
- SMALL_MODEL_CONFIGS + [QWEN_MODEL_CONFIG, LLAMA_MODEL_CONFIG],
- ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
-)
-@pytest.mark.parametrize(
- "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
-)
-@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
-def test_grouped_gemm_forward_manual(
- data_config: DataConfig,
- model_config: ModelConfig,
- kernel_config: KernelConfigForward,
- use_W1: bool,
-):
- _test_grouped_gemm_forward(
- data_config = data_config,
- model_config = model_config,
- use_W1 = use_W1,
- **asdict(kernel_config),
- )
-
-
-@pytest.mark.parametrize(
- "kernel_config",
- KERNEL_CONFIGS_FWD,
- ids = lambda x: x.to_string(include_tuning_params = True, include_tma = True),
-)
-@pytest.mark.parametrize(
- "model_config",
- SMALL_MODEL_CONFIGS + [QWEN_MODEL_CONFIG, LLAMA_MODEL_CONFIG],
- ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
-)
-@pytest.mark.parametrize(
- "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
-)
-@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
-def test_grouped_gemm_forward_manual_autograd(
- data_config: DataConfig,
- model_config: ModelConfig,
- kernel_config: KernelConfigForward,
- use_W1: bool,
-):
- _test_grouped_gemm_forward(
- data_config = data_config,
- model_config = model_config,
- use_W1 = use_W1,
- use_autograd = True,
- **asdict(kernel_config),
- )
-
-
-@pytest.mark.parametrize(
- "num_autotune_configs", [10], ids = lambda x: f"num_autotune_configs={x}"
-)
-@pytest.mark.parametrize(
- "permute_x", [True, False], ids = lambda x: "permute_x" if x else ""
-)
-@pytest.mark.parametrize(
- "permute_y", [True, False], ids = lambda x: "permute_y" if x else ""
-)
-@pytest.mark.parametrize(
- "model_config",
- [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],
- ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
-)
-@pytest.mark.parametrize(
- "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
-)
-@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
-def test_grouped_gemm_forward_autotune(
- data_config: DataConfig,
- model_config: ModelConfig,
- permute_x: bool,
- permute_y: bool,
- use_W1: bool,
- num_autotune_configs: int,
-):
- _test_grouped_gemm_forward(
- data_config = data_config,
- model_config = model_config,
- permute_x = permute_x,
- permute_y = permute_y,
- use_W1 = use_W1,
- num_autotune_configs = num_autotune_configs,
- autotune = True,
- use_autograd = False,
- )
-
-
-@pytest.mark.parametrize(
- "num_autotune_configs", [10], ids = lambda x: f"num_autotune_configs={x}"
-)
-@pytest.mark.parametrize(
- "permute_x", [True, False], ids = lambda x: "permute_x" if x else ""
-)
-@pytest.mark.parametrize(
- "permute_y", [True, False], ids = lambda x: "permute_y" if x else ""
-)
-@pytest.mark.parametrize(
- "model_config",
- [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],
- ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
-)
-@pytest.mark.parametrize(
- "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
-)
-@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
-def test_grouped_gemm_forward_autotune_autograd(
- data_config: DataConfig,
- model_config: ModelConfig,
- permute_x: bool,
- permute_y: bool,
- use_W1: bool,
- num_autotune_configs: int,
-):
- _test_grouped_gemm_forward(
- data_config = data_config,
- model_config = model_config,
- permute_x = permute_x,
- permute_y = permute_y,
- use_W1 = use_W1,
- num_autotune_configs = num_autotune_configs,
- autotune = True,
- use_autograd = True,
- )
-
-
-"""
-grouped_gemm_backward_dX
-
-use_W1 test the shapes for the first grouped GEMM in a fused MoE MLP
-use_W2 = `not use_W1` tests the shapes for the second grouped GEMM in a fused MoE MLP
-
-Only certain combinations of permute_x, permute_y, and fuse_mul are supported.
-
-Typically in a fused MoE MLP, we can fuse the permutation of hidden states (X) from token order to expert grouped order needed for grouped GEMM by directly loading X in permuted order rather than launching a separate permutation kernel.
-We can also fuse the unpermutation of tokens after the second grouped GEMM to restore to original token order. This is fused into the second grouped GEMM by directly storing the output in unpermuted order.
-
-Hence the following conditions:
-- If use_W1 there are two cases:
- - permute_x is False and topk > 1:
- - dX_test is still in permuted order and has shape (total_tokens, K)
- - it needs to be unpermuted and summed across topk before comparing to ref_grad
-- permute_x is True:
- - dX_test is already unpermuted and summed across topk with shape (num_tokens, K)
- - no further processing is needed
-- permute_x is False and topk == 1:
- - dX_test needs to be permuted, no need to sum since topk == 1
-
-- If use_W2:
- - permute_x is always False
- - if permute_y:
- - grad_output needs to be unpermuted before passing to grouped_gemm_dX
- - dX_test is permuted and has shape (total_tokens, N)
- - it needs to be unpermuted before comparing to ref_grad or can be compared directly to Xperm.grad
- - if not permute_y:
- - dX_test is not permuted and has shape (total_tokens, N)
- - no further processing is needed
-"""
-
-
-def _test_grouped_gemm_backward_dX(
- data_config: DataConfig,
- model_config: ModelConfig,
- permute_x: bool = False,
- permute_y: bool = False,
- use_tma_load_dy: bool = False,
- use_tma_load_w: bool = False,
- use_tma_store: bool = False,
- use_W1: bool = True,
- autotune: bool = False,
- num_autotune_configs: int = None,
- BLOCK_SIZE_M: int = None,
- BLOCK_SIZE_N: int = None,
- BLOCK_SIZE_K: int = None,
- num_warps: int = None,
- num_stages: int = None,
- flatten: bool = True,
- allow_tma_store: bool = False,
- use_autograd: bool = False,
- fuse_mul_post: bool = False,
-):
- if not check_valid_config(permute_x, permute_y, use_W1 = use_W1, is_backward = True):
- pytest.skip(
- f"Skipping test due to invalid config: {permute_x = } {permute_y = } {use_W1 = }"
- )
-
- if use_tma_store and not allow_tma_store:
- pytest.skip("TMA store needs to be debugged due to non-deterministic behavior")
-
- if (
- autotune
- and model_config.intermediate_size <= 128
- and model_config.hidden_size <= 128
- ):
- pytest.skip("Skipping autotuning for small model configs")
-
- # Prevent OOM for large intermediate sizes
- if model_config.intermediate_size > 2048:
- model_config.intermediate_size = 1024
- if model_config.hidden_size > 2048:
- model_config.hidden_size = 1024
-
- use_W2 = not use_W1
- X1, X2, W1, W2, gating_output = make_inputs(
- M = data_config.bs * data_config.seq_len,
- N = model_config.intermediate_size,
- K = model_config.hidden_size,
- E = model_config.num_experts,
- topk = model_config.topk,
- dtype = data_config.dtype,
- requires_grad = True,
- )
- topk = model_config.topk
- num_experts = model_config.num_experts
- use_sigmoid = model_config.use_sigmoid
- renormalize = model_config.renormalize
-
- X = X1 if use_W1 else X2
- num_tokens = data_config.bs * data_config.seq_len
- total_tokens = num_tokens * topk
-
- E, K, N = W2.shape # E = num_experts, K = hidden_size, N = intermediate_size
- assert W1.shape == (E, 2 * N, K)
- W = W1 if use_W1 else W2
-
- if use_W1:
- assert X.shape == (
- num_tokens,
- K,
- ), f"X.shape: {X.shape}, num_tokens: {num_tokens}, K: {K}"
- else:
- assert X.shape == (
- total_tokens,
- N,
- ), f"X.shape: {X.shape}, total_tokens: {total_tokens}, N: {N}"
-
- W_test = W.detach().clone().requires_grad_(True)
-
- topk_weights, topk_ids = calculate_topk(
- gating_output, topk, use_sigmoid = use_sigmoid, renormalize = renormalize
- )
- topk_weights = topk_weights.view(-1) # num_tokens * topk
- topk_ids = topk_ids.view(-1) # num_tokens * topk
-
- expert_token_counts, gather_indices = get_routing_indices(topk_ids, num_experts = E)
- assert len(gather_indices) == total_tokens
- assert len(expert_token_counts) == num_experts
-
- atol, rtol = TOLERANCE[X.dtype]
- Xperm = permute(X, gather_indices, topk)
-
- # Need to retain grad otherwise grad is not propagated
- X.retain_grad()
- W.retain_grad()
- Xperm.retain_grad()
-
- assert Xperm.shape == (total_tokens, K) if use_W1 else (total_tokens, N)
-
- output_shape = (total_tokens, 2 * N) if use_W1 else (total_tokens, K)
- ref_output = torch_grouped_gemm(X = Xperm, W = W, m_sizes = expert_token_counts)
- assert (
- ref_output.shape == output_shape
- ), f"ref_output.shape: {ref_output.shape}, output_shape: {output_shape}"
-
- if permute_y:
- ref_output = unpermute(ref_output, gather_indices)
-
- grad_output = torch.randn_like(ref_output)
- ref_output.backward(grad_output)
-
- assert X.grad is not None
- assert W.grad is not None
-
- ref_grad = Xperm.grad
-
- if autotune:
- # No need to run all configs for autotuning
- from grouped_gemm.kernels.backward import _autotuned_grouped_gemm_dX_kernel
-
- if num_autotune_configs is not None:
- _autotuned_grouped_gemm_dX_kernel.configs = (
- _autotuned_grouped_gemm_dX_kernel.configs[:num_autotune_configs]
- )
-
- if use_autograd:
- from grouped_gemm.interface import grouped_gemm
-
- if not autotune:
- kernel_config_fwd = KernelConfigForward()
- kernel_config_bwd_dX = KernelConfigBackward_dX(
- use_tma_load_dy = use_tma_load_dy,
- use_tma_load_w = use_tma_load_w,
- use_tma_store = use_tma_store,
- BLOCK_SIZE_M = BLOCK_SIZE_M,
- BLOCK_SIZE_N = BLOCK_SIZE_N,
- BLOCK_SIZE_K = BLOCK_SIZE_K,
- num_warps = num_warps,
- num_stages = num_stages,
- )
- kernel_config_bwd_dW = KernelConfigBackward_dW()
- else:
- from grouped_gemm.kernels.backward import (
- _autotuned_grouped_gemm_dW_kernel,
- _autotuned_grouped_gemm_dX_kernel,
- )
- from grouped_gemm.kernels.forward import (
- _autotuned_grouped_gemm_forward_kernel,
- )
-
- if num_autotune_configs is not None:
- _autotuned_grouped_gemm_dX_kernel.configs = (
- _autotuned_grouped_gemm_dX_kernel.configs[:num_autotune_configs]
- )
- _autotuned_grouped_gemm_forward_kernel.configs = (
- _autotuned_grouped_gemm_forward_kernel.configs[
- :num_autotune_configs
- ]
- )
-
- kernel_config_fwd = None
- kernel_config_bwd_dX = None
- X_ = (
- X.detach().clone().requires_grad_(True)
- if permute_x
- else Xperm.detach().clone().requires_grad_(True)
- )
- test_output = grouped_gemm(
- X = X_,
- W = W_test,
- m_sizes = expert_token_counts,
- gather_indices = gather_indices,
- topk = topk,
- permute_x = permute_x,
- permute_y = permute_y,
- autotune = autotune,
- kernel_config_fwd = kernel_config_fwd,
- kernel_config_bwd_dX = kernel_config_bwd_dX,
- is_first_gemm = use_W1,
- dX_only = True,
- )
- assert (
- test_output.shape == ref_output.shape
- ), f"test_output.shape: {test_output.shape}, ref_output.shape: {ref_output.shape}"
- assert torch.allclose(
- test_output, ref_output, atol = atol, rtol = rtol
- ), f"Grouped gemm backward_dX forward outputs mismatch: {(test_output - ref_output).abs().max().item():.6f}"
- test_output.backward(grad_output)
- assert X_.grad is not None
-
- # NOTE:need to handle grad differenlty in this case due to errors arising to do how torch autograd handles unpermute and sum reduction
- # the grad of Xperm unpermuted and reduced across topk should match X_.grad
- # However, both will have a numerical difference with that of ref_grad
- # This is due to the fact that torch autograd handles unpermute and sum reduction differently see: https://discuss.pytorch.org/t/permute-unpermute-gradient/219557 else:
- if permute_x and use_W1:
- X_grad_unperm = unpermute(Xperm.grad, gather_indices)
- manual_grad_check = X_grad_unperm.view(num_tokens, topk, K).sum(dim = 1)
- assert (
- manual_grad_check.shape == X_.grad.shape
- ), f"manual_grad_check.shape: {manual_grad_check.shape}, X_.grad.shape: {X_.grad.shape}"
- assert torch.allclose(
- manual_grad_check, X_.grad, atol = atol, rtol = rtol
- ), f"Grouped gemm backward_dX forward outputs mismatch: {(manual_grad_check - X_.grad).abs().max().item():.6f}"
- manual_diff = (X_.grad - manual_grad_check).abs().max().item()
- autograd_diff = (X_.grad - X.grad).abs().max().item()
- print(f"manual_diff: {manual_diff:.6f}, autograd_diff: {autograd_diff:.6f}")
- else:
- assert torch.allclose(
- X_.grad, ref_grad, atol = atol, rtol = rtol
- ), f"Grouped gemm backward_dX forward outputs mismatch: {(X_.grad - ref_grad).abs().max().item():.6f}"
- return
- else:
- dX_test = grouped_gemm_dX(
- dY = grad_output,
- W = W_test,
- gather_indices = gather_indices,
- m_sizes = expert_token_counts,
- topk = topk,
- permute_x = permute_x,
- permute_y = permute_y,
- use_tma_load_w = use_tma_load_w,
- use_tma_load_dy = use_tma_load_dy,
- use_tma_store = use_tma_store,
- autotune = autotune,
- BLOCK_SIZE_M = BLOCK_SIZE_M,
- BLOCK_SIZE_N = BLOCK_SIZE_N,
- BLOCK_SIZE_K = BLOCK_SIZE_K,
- num_warps = num_warps,
- num_stages = num_stages,
- flatten = flatten,
- # debug=True,
- )
-
- # if permute_x and use_W1 (first grouped GEMM) then the kernel should have unpermuted the dX
- # therefore we need to unpermute the ref_grad to compare to the output of the kernel
- if permute_x and use_W1:
- ref_grad = unpermute(ref_grad, gather_indices)
-
- assert (
- ref_grad.shape == dX_test.shape
- ), f"Grouped gemm manual backward_dX outputs mismatch: ref_grad: {ref_grad.shape}, dX_test: {dX_test.shape}"
- diff = (ref_grad - dX_test).abs().max().item()
-
- assert torch.allclose(
- ref_grad, dX_test, atol = atol, rtol = rtol
- ), f"Grouped gemm manual backward_dX outputs mismatch: {diff:.6f}"
-
- if permute_x and use_W1:
- # Show that reduction results in diffs
- # First calculate X.grad manually by backpropping through unpermuted ref_grad
- dX_ref_check = ref_grad.view(num_tokens, topk, K).sum(dim = 1)
- # Do the same for the actual output of the kernel
- dX_test_check = dX_test.view(num_tokens, topk, K).sum(dim = 1)
- # Show diffs for each combination
- diff_ref_check = (X.grad - dX_ref_check).abs().max().item()
- diff_test_check = (X.grad - dX_test_check).abs().max().item()
- diff_check_test = (dX_ref_check - dX_test_check).abs().max().item()
- print(
- f"diff_ref_check: {diff_ref_check:.6f}, diff_test_check: {diff_test_check:.6f}, diff_check_test: {diff_check_test:.6f}"
- )
-
-
-# NOTE: We reduce the size of the Llama4 model configs to prevent OOM
-# Important to note that for the full model size (5120, 8192), the tests do result in diffs on the order of 1e-2.
-@pytest.mark.parametrize(
- "kernel_config",
- KERNEL_CONFIGS_BWD_dX,
- ids = lambda x: x.to_string(include_tuning_params = True, include_tma = True),
-)
-@pytest.mark.parametrize(
- "model_config",
- SMALL_MODEL_CONFIGS[:1] + [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],
- ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
-)
-@pytest.mark.parametrize(
- "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
-)
-@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
-def test_grouped_gemm_backward_dX_manual(
- data_config: DataConfig,
- model_config: ModelConfig,
- kernel_config: KernelConfigBackward_dX,
- use_W1: bool,
-):
- _test_grouped_gemm_backward_dX(
- data_config = data_config,
- model_config = model_config,
- use_W1 = use_W1,
- use_autograd = False,
- **asdict(kernel_config),
- )
-
-
-@pytest.mark.parametrize(
- "kernel_config",
- KERNEL_CONFIGS_BWD_dX,
- ids = lambda x: x.to_string(include_tuning_params = True, include_tma = True),
-)
-@pytest.mark.parametrize(
- "model_config",
- SMALL_MODEL_CONFIGS[:1] + [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],
- ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
-)
-@pytest.mark.parametrize(
- "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
-)
-@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
-def test_grouped_gemm_backward_dX_manual_autograd(
- data_config: DataConfig,
- model_config: ModelConfig,
- kernel_config: KernelConfigBackward_dX,
- use_W1: bool,
-):
- _test_grouped_gemm_backward_dX(
- data_config = data_config,
- model_config = model_config,
- use_W1 = use_W1,
- use_autograd = True,
- **asdict(kernel_config),
- )
-
-
-@pytest.mark.parametrize(
- "num_autotune_configs", [20], ids = lambda x: f"num_autotune_configs={x}"
-)
-@pytest.mark.parametrize(
- "permute_x", [True, False], ids = lambda x: "permute_x" if x else ""
-)
-@pytest.mark.parametrize(
- "permute_y", [True, False], ids = lambda x: "permute_y" if x else ""
-)
-@pytest.mark.parametrize(
- "model_config",
- [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],
- ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
-)
-@pytest.mark.parametrize(
- "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
-)
-@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
-def test_grouped_gemm_backward_dX_autotune(
- data_config: DataConfig,
- model_config: ModelConfig,
- permute_x: bool,
- permute_y: bool,
- use_W1: bool,
- num_autotune_configs: int,
-):
- # TMA loads / stores will be autotuned
- _test_grouped_gemm_backward_dX(
- data_config = data_config,
- model_config = model_config,
- permute_x = permute_x,
- permute_y = permute_y,
- use_W1 = use_W1,
- autotune = True,
- use_autograd = False,
- num_autotune_configs = num_autotune_configs,
- )
-
-
-@pytest.mark.parametrize(
- "num_autotune_configs", [20], ids = lambda x: f"num_autotune_configs={x}"
-)
-@pytest.mark.parametrize(
- "permute_x", [True, False], ids = lambda x: "permute_x" if x else ""
-)
-@pytest.mark.parametrize(
- "permute_y", [True, False], ids = lambda x: "permute_y" if x else ""
-)
-@pytest.mark.parametrize(
- "model_config",
- [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],
- ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
-)
-@pytest.mark.parametrize(
- "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
-)
-@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
-def test_grouped_gemm_backward_dX_autotune_autograd(
- data_config: DataConfig,
- model_config: ModelConfig,
- permute_x: bool,
- permute_y: bool,
- use_W1: bool,
- num_autotune_configs: int,
-):
- # TMA loads / stores will be autotuned
- _test_grouped_gemm_backward_dX(
- data_config = data_config,
- model_config = model_config,
- permute_x = permute_x,
- permute_y = permute_y,
- use_W1 = use_W1,
- autotune = True,
- use_autograd = True,
- num_autotune_configs = num_autotune_configs,
- )
-
-
-def _test_grouped_gemm_backward_dW(
- data_config: DataConfig,
- model_config: ModelConfig,
- permute_x: bool,
- permute_y: bool,
- use_W1: bool,
- use_tma_load_dy: bool = False,
- use_tma_load_x: bool = False,
- use_tma_store: bool = False,
- BLOCK_SIZE_M: int = None,
- BLOCK_SIZE_N: int = None,
- BLOCK_SIZE_K: int = None,
- num_warps: int = None,
- num_stages: int = None,
- flatten: bool = True,
- autotune: bool = False,
- num_autotune_configs: int = None,
- allow_tma_store: bool = False,
- debug: bool = False,
- fuse_mul_post: bool = False, # Unused for backward_dW
- use_autograd: bool = False,
-):
- if not check_valid_config(
- permute_x,
- permute_y,
- fuse_mul_post = fuse_mul_post,
- use_W1 = use_W1,
- is_backward = True,
- ):
- pytest.skip(
- f"Skipping test due to invalid config: {permute_x = } {permute_y = } {use_W1 = }"
- )
-
- if use_tma_store and not allow_tma_store:
- pytest.skip("TMA store needs to be debugged due to non-deterministic behavior")
-
- X1, X2, W1, W2, gating_output = make_inputs(
- M = data_config.bs * data_config.seq_len,
- N = model_config.intermediate_size,
- K = model_config.hidden_size,
- E = model_config.num_experts,
- topk = model_config.topk,
- dtype = data_config.dtype,
- requires_grad = True,
- )
- topk = model_config.topk
- num_experts = model_config.num_experts
- use_sigmoid = model_config.use_sigmoid
- renormalize = model_config.renormalize
-
- X = X1 if use_W1 else X2
- num_tokens = data_config.bs * data_config.seq_len
- E, K, N = W2.shape # E = num_experts, K = hidden_size, N = intermediate_size
- assert W1.shape == (E, 2 * N, K)
- W = W1 if use_W1 else W2
-
- if use_W1:
- assert X.shape == (
- num_tokens,
- K,
- ), f"X.shape: {X.shape}, num_tokens: {num_tokens}, K: {K}"
- else:
- assert X.shape == (
- num_tokens * topk,
- N,
- ), f"X.shape: {X.shape}, num_tokens: {num_tokens}, topk: {topk}, N: {N}"
-
- total_tokens = num_tokens * topk
- output_shape = (total_tokens, 2 * N) if use_W1 else (total_tokens, K)
-
- X_test = X.detach().clone().requires_grad_(True)
- W_test = W.detach().clone().requires_grad_(True)
-
- topk_weights, topk_ids = calculate_topk(
- gating_output, topk, use_sigmoid = use_sigmoid, renormalize = renormalize
- )
- topk_weights = topk_weights.view(-1) # num_tokens * topk
- topk_ids = topk_ids.view(-1) # num_tokens * topk
-
- expert_token_counts, gather_indices = get_routing_indices(topk_ids, num_experts = E)
- assert len(gather_indices) == total_tokens
- assert len(expert_token_counts) == num_experts
-
- atol, rtol = TOLERANCE[X.dtype]
- Xperm = permute(X, gather_indices, topk)
- Xperm_test = Xperm.detach().clone().requires_grad_(True)
-
- # Need to retain grad otherwise grad is not propagated
- X.retain_grad()
- W.retain_grad()
- Xperm.retain_grad()
- assert Xperm.shape == (total_tokens, K) if use_W1 else (total_tokens, N)
-
- output_shape = (total_tokens, 2 * N) if use_W1 else (total_tokens, K)
-
- ref_output = torch_grouped_gemm(X = Xperm, W = W, m_sizes = expert_token_counts)
- assert ref_output.shape == output_shape
-
- # if permute_y then the assumption is that the output of grouped_gemm was unpermuted on store
- # Therefore we have to unpermute before backpropping to ensure proper alignment
- if permute_y:
- ref_output = unpermute(ref_output, gather_indices)
-
- grad_output = torch.randn_like(ref_output)
- ref_output.backward(grad_output)
- assert X.grad is not None
- assert W.grad is not None
-
- # Test backward kernel directly
- X_ = X_test if permute_x else Xperm_test
-
- if debug:
- torch.set_printoptions(precision = 4)
- for i in range(num_experts):
- print(f"Expert {i} weight grad:\n{W.grad[i, :5, :5]}")
-
- if autotune:
- from grouped_gemm.kernels.backward import _autotuned_grouped_gemm_dW_kernel
-
- if num_autotune_configs is not None:
- _autotuned_grouped_gemm_dW_kernel.configs = (
- _autotuned_grouped_gemm_dW_kernel.configs[:num_autotune_configs]
- )
-
- if use_autograd:
- from grouped_gemm.interface import grouped_gemm
-
- if not autotune:
- kernel_config_fwd = KernelConfigForward(
- # Only care about backward_dW config
- use_tma_load_w = False,
- use_tma_load_x = False,
- use_tma_store = False,
- BLOCK_SIZE_M = BLOCK_SIZE_M,
- BLOCK_SIZE_N = BLOCK_SIZE_N,
- BLOCK_SIZE_K = BLOCK_SIZE_K,
- num_warps = num_warps,
- num_stages = num_stages,
- )
- kernel_config_bwd_dW = KernelConfigBackward_dW(
- use_tma_load_dy = use_tma_load_dy,
- use_tma_load_x = use_tma_load_x,
- use_tma_store = use_tma_store,
- BLOCK_SIZE_M = BLOCK_SIZE_M,
- BLOCK_SIZE_N = BLOCK_SIZE_N,
- BLOCK_SIZE_K = BLOCK_SIZE_K,
- num_warps = num_warps,
- num_stages = num_stages,
- )
- else:
- from grouped_gemm.kernels.backward import _autotuned_grouped_gemm_dW_kernel
- from grouped_gemm.kernels.forward import (
- _autotuned_grouped_gemm_forward_kernel,
- )
-
- if num_autotune_configs is not None:
- _autotuned_grouped_gemm_forward_kernel.configs = (
- _autotuned_grouped_gemm_forward_kernel.configs[
- :num_autotune_configs
- ]
- )
- _autotuned_grouped_gemm_dW_kernel.configs = (
- _autotuned_grouped_gemm_dW_kernel.configs[:num_autotune_configs]
- )
- kernel_config_fwd = None
- kernel_config_bwd_dW = None
-
- test_output = grouped_gemm(
- X = X_,
- W = W_test,
- m_sizes = expert_token_counts,
- gather_indices = gather_indices,
- topk = topk,
- permute_x = permute_x,
- permute_y = permute_y,
- kernel_config_fwd = kernel_config_fwd,
- kernel_config_bwd_dW = kernel_config_bwd_dW,
- autotune = autotune,
- is_first_gemm = use_W1,
- dW_only = True,
- )
- assert (
- test_output.shape == ref_output.shape
- ), f"Grouped gemm autograd backward_dW outputs mismatch: {test_output.shape} != {ref_output.shape}"
- assert torch.allclose(
- test_output, ref_output, atol = atol, rtol = rtol
- ), f"Grouped gemm autograd backward_dW forward outputs mismatch: {test_output.shape} != {ref_output.shape}"
- test_output.backward(grad_output)
- assert W_test.grad is not None
- dW_test = W_test.grad
- else:
- dW_test = grouped_gemm_dW(
- dY = grad_output,
- X = X_,
- m_sizes = expert_token_counts,
- gather_indices = gather_indices,
- topk = topk,
- permute_x = permute_x,
- permute_y = permute_y,
- use_tma_load_dy = use_tma_load_dy,
- use_tma_load_x = use_tma_load_x,
- use_tma_store = use_tma_store,
- BLOCK_SIZE_M = BLOCK_SIZE_M,
- BLOCK_SIZE_N = BLOCK_SIZE_N,
- BLOCK_SIZE_K = BLOCK_SIZE_K,
- num_warps = num_warps,
- num_stages = num_stages,
- flatten = flatten,
- autotune = autotune,
- debug = debug,
- )
- assert (
- W.grad.shape == dW_test.shape
- ), f"Grouped gemm manual backward_dW outputs mismatch: W.grad: {W.grad.shape}, dW_test: {dW_test.shape}"
-
- if debug:
- with torch.no_grad():
- if not torch.allclose(W.grad, dW_test, atol = atol, rtol = rtol):
- print(f"Ref Wgrad sum: {W.grad.sum().item():.4f}")
- print(f"Test Wgrad sum: {dW_test.sum().item():.4f}")
-
- for i in range(num_experts):
- print(f"Expert {i} weight grad:\n{W.grad[i, :5, :5]}")
- print(f"Expert {i} dW_test:\n{dW_test[i, :5, :5]}")
- expert_diff = (W.grad[i, :, :] - dW_test[i, :, :]).abs().max().item()
- print(f"Expert {i} diff: {expert_diff:.6f}")
-
- diff = (W.grad - dW_test).abs().max().item()
- assert (
- False
- ), f"Grouped gemm manual backward_dW outputs mismatch: {diff:.6f}"
- else:
- diff = (W.grad - dW_test).abs().max().item()
- assert torch.allclose(
- W.grad, dW_test, atol = atol, rtol = rtol
- ), f"Grouped gemm manual backward_dW outputs mismatch: {diff:.6f}"
-
-
-@pytest.mark.parametrize(
- "kernel_config",
- KERNEL_CONFIGS_BWD_dW,
- ids = lambda x: x.to_string(include_tuning_params = False, include_tma = True),
-)
-@pytest.mark.parametrize(
- "model_config",
- SMALL_MODEL_CONFIGS + [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],
- ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
-)
-@pytest.mark.parametrize(
- "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
-)
-@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
-def test_grouped_gemm_backward_dW_manual(
- data_config: DataConfig,
- model_config: ModelConfig,
- kernel_config: KernelConfig,
- use_W1: bool,
- debug: bool = False,
-):
- _test_grouped_gemm_backward_dW(
- data_config = data_config,
- model_config = model_config,
- use_W1 = use_W1,
- use_autograd = False,
- **asdict(kernel_config),
- )
-
-
-@pytest.mark.parametrize(
- "kernel_config",
- KERNEL_CONFIGS_BWD_dW,
- ids = lambda x: x.to_string(include_tuning_params = False, include_tma = True),
-)
-@pytest.mark.parametrize(
- "model_config",
- SMALL_MODEL_CONFIGS + [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],
- ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
-)
-@pytest.mark.parametrize(
- "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
-)
-@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
-def test_grouped_gemm_backward_dW_manual_autograd(
- data_config: DataConfig,
- model_config: ModelConfig,
- kernel_config: KernelConfig,
- use_W1: bool,
- debug: bool = False,
-):
- _test_grouped_gemm_backward_dW(
- data_config = data_config,
- model_config = model_config,
- use_W1 = use_W1,
- use_autograd = True,
- **asdict(kernel_config),
- )
-
-
-@pytest.mark.parametrize(
- "num_autotune_configs", [20], ids = lambda x: f"num_autotune_configs={x}"
-)
-@pytest.mark.parametrize(
- "permute_x", [True, False], ids = lambda x: "permute_x" if x else ""
-)
-@pytest.mark.parametrize(
- "permute_y", [True, False], ids = lambda x: "permute_y" if x else ""
-)
-@pytest.mark.parametrize(
- "model_config",
- [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],
- ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
-)
-@pytest.mark.parametrize(
- "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
-)
-@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
-def test_grouped_gemm_backward_dW_autotune(
- data_config: DataConfig,
- model_config: ModelConfig,
- permute_x: bool,
- permute_y: bool,
- use_W1: bool,
- num_autotune_configs: int,
-):
- _test_grouped_gemm_backward_dW(
- data_config = data_config,
- model_config = model_config,
- use_W1 = use_W1,
- permute_x = permute_x,
- permute_y = permute_y,
- autotune = True,
- use_autograd = False,
- num_autotune_configs = num_autotune_configs,
- )
-
-
-@pytest.mark.parametrize(
- "num_autotune_configs", [20], ids = lambda x: f"num_autotune_configs={x}"
-)
-@pytest.mark.parametrize(
- "permute_x", [True, False], ids = lambda x: "permute_x" if x else ""
-)
-@pytest.mark.parametrize(
- "permute_y", [True, False], ids = lambda x: "permute_y" if x else ""
-)
-@pytest.mark.parametrize(
- "model_config",
- [LLAMA_MODEL_CONFIG, QWEN_MODEL_CONFIG],
- ids = lambda x: f"topk={x.topk} num_experts={x.num_experts} hidden_size={x.hidden_size} intermediate_size={x.intermediate_size}",
-)
-@pytest.mark.parametrize(
- "data_config", DATA_CONFIGS, ids = lambda x: f"seq_len={x.seq_len} dtype={x.dtype}"
-)
-@pytest.mark.parametrize("use_W1", [True, False], ids = lambda x: f"use_W1={x}")
-def test_grouped_gemm_backward_dW_autotune_autograd(
- data_config: DataConfig,
- model_config: ModelConfig,
- permute_x: bool,
- permute_y: bool,
- use_W1: bool,
- num_autotune_configs: int,
-):
- _test_grouped_gemm_backward_dW(
- data_config = data_config,
- model_config = model_config,
- use_W1 = use_W1,
- permute_x = permute_x,
- permute_y = permute_y,
- autotune = True,
- use_autograd = True,
- num_autotune_configs = num_autotune_configs,
- )
diff --git a/unsloth/kernels/moe/tests/test_llama4_moe.py b/unsloth/kernels/moe/tests/test_llama4_moe.py
deleted file mode 100644
index 13ad552bf4..0000000000
--- a/unsloth/kernels/moe/tests/test_llama4_moe.py
+++ /dev/null
@@ -1,262 +0,0 @@
-# SPDX-License-Identifier: GNU Affero General Public License v3.0
-# Copyright 2023-present the Unsloth team. All rights reserved.
-
-import argparse
-import sys
-from contextlib import contextmanager
-from functools import partial
-
-import pytest
-import torch
-from transformers import AutoConfig
-from transformers.models.llama4 import Llama4Config, Llama4TextConfig
-from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
-
-from grouped_gemm.kernels.tuning import (
- KernelConfigBackward_dW,
- KernelConfigBackward_dX,
- KernelConfigForward,
-)
-from grouped_gemm.reference.layers.llama4_moe import (
- Llama4GroupedGemmTextMoe,
- Llama4TritonTextMoe,
-)
-
-TOLERANCES = {
- torch.bfloat16: (1e-2, 1e-2),
- torch.float16: (1e-3, 1e-3),
- torch.float: (1e-5, 1e-5),
-}
-
-LLAMA4_SCOUT_ID = "meta-llama/Llama-4-Scout-17B-16E"
-SEED = 42
-SEQ_LENS = [1024]
-DTYPES = [torch.bfloat16]
-# Reduce the number of autotuning configs to prevent excessive runtime
-NUM_AUTOTUNE_CONFIGS = 50
-
-
-@contextmanager
-def annotated_context(prelude, epilogue = "Passed!", char = "-", num_chars = 80):
- print(char * num_chars)
- print(prelude)
- yield
- print(epilogue)
- print(char * num_chars)
-
-
-def get_text_config(model_id):
- config: Llama4Config = AutoConfig.from_pretrained(model_id)
- return config.text_config
-
-
-def prep_triton_kernel_traits(autotune):
- if not autotune:
- kernel_config_fwd = KernelConfigForward()
- kernel_config_bwd_dW = KernelConfigBackward_dW()
- kernel_config_bwd_dX = KernelConfigBackward_dX()
- else:
- from grouped_gemm.kernels.backward import (
- _autotuned_grouped_gemm_dW_kernel,
- _autotuned_grouped_gemm_dX_kernel,
- )
- from grouped_gemm.kernels.forward import _autotuned_grouped_gemm_forward_kernel
-
- # Hack to reduce number of autotuning configs
- _autotuned_grouped_gemm_forward_kernel.configs = (
- _autotuned_grouped_gemm_forward_kernel.configs[:NUM_AUTOTUNE_CONFIGS]
- )
- _autotuned_grouped_gemm_dW_kernel.configs = (
- _autotuned_grouped_gemm_dW_kernel.configs[:NUM_AUTOTUNE_CONFIGS]
- )
- _autotuned_grouped_gemm_dX_kernel.configs = (
- _autotuned_grouped_gemm_dX_kernel.configs[:NUM_AUTOTUNE_CONFIGS]
- )
-
- kernel_config_fwd = None
- kernel_config_bwd_dW = None
- kernel_config_bwd_dX = None
-
- return kernel_config_fwd, kernel_config_bwd_dW, kernel_config_bwd_dX
-
-
-def sparse_to_dense(t: torch.Tensor):
- t = t.sum(dim = 0).view(-1)
- return t
-
-
-@torch.no_grad()
-def _check_diff(
- t1: torch.Tensor,
- t2: torch.Tensor,
- atol,
- rtol,
- precision = ".6f",
- verbose = False,
- msg = "",
-):
- t2 = t2.view_as(t1)
- diff = t1.sub(t2).abs().max().item()
- if verbose:
- if msg == "":
- msg = "diff"
- print(f"{msg}: {diff:{precision}}")
- assert torch.allclose(t1, t2, atol = atol, rtol = rtol)
-
-
-def run_backwards(y: torch.Tensor, grad_output: torch.Tensor, module: torch.nn.Module):
- y.backward(grad_output)
- for name, param in module.named_parameters():
- assert param.grad is not None, f"{name} missing grad!"
-
-
-def _check_grads(
- m1: torch.nn.Module,
- m2: torch.nn.Module,
- atol,
- rtol,
- precision = ".6f",
- verbose = False,
- msg = "",
-):
- for name, param in m1.named_parameters():
- _check_diff(
- param.grad,
- m2.get_parameter(name).grad,
- atol = atol,
- rtol = rtol,
- precision = precision,
- verbose = verbose,
- msg = f"{msg}:{name}.grad",
- )
-
-
-@pytest.fixture
-def model_config():
- return AutoConfig.from_pretrained(LLAMA4_SCOUT_ID).text_config
-
-
-@pytest.mark.parametrize(
- "overlap_router_shared",
- [False, True],
- ids = lambda x: "overlap_router_shared" if x else "no_overlap",
-)
-@pytest.mark.parametrize(
- "permute_y", [False, True], ids = lambda x: "permute_y" if x else "no_permute_y"
-)
-@pytest.mark.parametrize(
- "permute_x", [False], ids = lambda x: "permute_x" if x else "no_permute_x"
-) # Llama4 does not support permute_x
-@pytest.mark.parametrize(
- "autotune", [True], ids = lambda x: "autotune" if x else "manual"
-)
-@pytest.mark.parametrize("seqlen", SEQ_LENS, ids = lambda x: f"seqlen={x}")
-@pytest.mark.parametrize("dtype", DTYPES, ids = str)
-def test_llama4_ref(
- dtype: torch.dtype,
- seqlen,
- autotune: bool,
- permute_x: bool,
- permute_y: bool,
- overlap_router_shared: bool,
- model_config: Llama4TextConfig, # test fixture
- bs: int = 1,
- device = "cuda",
- precision = ".6f",
- verbose = False,
-):
- torch.manual_seed(
- SEED
- ) # Should not be needed when running using pytest -- autouse fixture in conftest.py
- device = "cuda"
- hidden_dim = model_config.hidden_size
- atol, rtol = TOLERANCES[dtype]
- check_diff = partial(
- _check_diff, atol = atol, rtol = rtol, precision = precision, verbose = verbose
- )
- check_grads = partial(
- _check_grads, atol = atol, rtol = rtol, precision = precision, verbose = verbose
- )
-
- # Reference op -- HF
- llama4_ref = Llama4TextMoe(model_config).to(dtype = dtype, device = device)
-
- # Torch grouped gemm impl
- llama4_gg_ref = Llama4GroupedGemmTextMoe(
- model_config, overlap_router_shared = overlap_router_shared
- ).to(dtype = dtype, device = device)
- llama4_gg_ref.copy_weights(llama4_ref)
- llama4_gg_ref.check_weights(llama4_ref)
-
- x_ref = torch.randn(
- bs, seqlen, hidden_dim, dtype = dtype, device = device, requires_grad = True
- )
- x_torch_gg = x_ref.detach().clone().requires_grad_()
- x_triton = x_ref.detach().clone().requires_grad_()
-
- y_ref, routing_ref = llama4_ref(x_ref)
- y_torch_gg, routing_torch_gg = llama4_gg_ref(x_torch_gg)
- assert y_ref.shape == y_torch_gg.shape, f"{y_ref.shape} != {y_torch_gg.shape}"
- with annotated_context("Testing torch grouped gemm Llama4TextMoe"):
- check_diff(y_ref, y_torch_gg, msg = "y_torch_gg")
- check_diff(
- sparse_to_dense(routing_ref), routing_torch_gg, msg = "routing_torch_gg"
- )
-
- kernel_config_fwd, kernel_config_bwd_dW, kernel_config_bwd_dX = (
- prep_triton_kernel_traits(autotune)
- )
-
- llama4_triton = Llama4TritonTextMoe(
- model_config,
- overlap_router_shared = overlap_router_shared,
- permute_x = permute_x,
- permute_y = permute_y,
- autotune = autotune,
- kernel_config_fwd = kernel_config_fwd,
- kernel_config_bwd_dW = kernel_config_bwd_dW,
- kernel_config_bwd_dX = kernel_config_bwd_dX,
- ).to(device = device, dtype = dtype)
- llama4_triton.copy_weights(llama4_ref)
- llama4_triton.check_weights(llama4_ref)
-
- y_triton, routing_triton = llama4_triton(x_triton)
- with annotated_context("Testing triton grouped gemm Llama4TextMoe forward"):
- check_diff(y_ref, y_triton, msg = "y_triton")
- check_diff(sparse_to_dense(routing_ref), routing_triton, msg = "routing_triton")
-
- ref_grad = torch.randn_like(y_ref)
- run_backwards(y_ref, ref_grad, llama4_ref)
- run_backwards(y_torch_gg, ref_grad, llama4_gg_ref)
- with annotated_context("Testing torch group gemm Llama4TextMoe backward"):
- check_grads(llama4_ref, llama4_gg_ref, msg = "torch_gg")
-
- run_backwards(y_triton, ref_grad, llama4_triton)
- with annotated_context("Testing triton group gemm Llama4TextMoe backward"):
- check_grads(llama4_ref, llama4_triton, msg = "triton")
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("--seqlen", type = int, default = 1024)
- parser.add_argument(
- "--dtype", type = str, choices = ["bfloat16", "float16"], default = "bfloat16"
- )
- args = parser.parse_args()
- args.dtype = getattr(torch, args.dtype)
- args_dict = vars(args)
-
- model_id = LLAMA4_SCOUT_ID
-
- text_config: Llama4TextConfig = get_text_config(model_id)
- for overlap in [False, True]:
- test_llama4_ref(
- seqlen = args.seqlen,
- model_config = text_config,
- dtype = args.dtype,
- autotune = True,
- permute_x = False,
- permute_y = True,
- overlap_router_shared = overlap,
- verbose = True,
- )
diff --git a/unsloth/kernels/moe/tests/test_qwen3_moe.py b/unsloth/kernels/moe/tests/test_qwen3_moe.py
deleted file mode 100644
index 42a8356c09..0000000000
--- a/unsloth/kernels/moe/tests/test_qwen3_moe.py
+++ /dev/null
@@ -1,273 +0,0 @@
-# SPDX-License-Identifier: GNU Affero General Public License v3.0
-# Copyright 2023-present the Unsloth team. All rights reserved.
-
-import argparse
-from contextlib import contextmanager
-
-import pytest
-import torch
-from transformers import AutoConfig
-from transformers.models.qwen3_moe import Qwen3MoeConfig
-from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
-
-from grouped_gemm.kernels.tuning import (
- KernelConfigBackward_dW,
- KernelConfigBackward_dX,
- KernelConfigForward,
-)
-from grouped_gemm.reference.layers.qwen3_moe import Qwen3MoeGroupedGEMMBlock
-
-from .moe_utils import (
- Qwen3MoeFusedGroupedGEMMBlock,
- check_fwd,
- check_grads,
- check_grouped_gemm_results,
- run_backward,
- run_forward,
-)
-
-"""
-Qwen3 MoE tests
-
-NOTE: Test this as a module and NOT with pytest as running with pytest results in random numerical errors: python -m tests.test_qwen3_moe --permute_x --permute_y --autotune NOT pytest -sv tests/test_qwen3_moe.py
-More specifically, all tests pass when run individually, but some will fail randomly (even with the same seed) when the entire test is run as a parametrized test suite using pytest, likely due to how pytest interacts with triton / autotuning.
-
-See tests/run_qwen3_moe_tests.sh for a script that runs all the tests
-
-The tests run the following:
-Huggingface's Qwen3 MoE block (Qwen3MoeSparseMoeBlock)
-Torch-native grouped gemm version of MoE block (Qwen3MoeGroupedGEMMBlock), which is the HF block with the expert computation replaced with a torch-native grouped gemm
-Triton kernel grouped gemm version of MoE block (Qwen3MoeFusedGroupedGEMMBlock), which is the HF block with the expert computation replaced with the fused triton grouped gemm kernel
-
-The tests check the following:
-- HF MoE block vs torch grouped gemm MoE block (sanity check)
-- torch grouped gemm MoE block vs fused grouped gemm MoE block -- this allows us to test each of the intermediate results for easier debugging
-- HF MoE block vs fused grouped gemm MoE block -- this is the actual test
-
-Both forward and backward passes are tests:
-- forward: output of the moe block
-- backwards:
- - X: gradient of the input to the moe block
- - gate.weight: gradient of the gate weights (router weights)
- - gate_proj: gradient of concatenated gate projections
- - up_proj: gradient of the concatenated up projections
- - down_proj: gradient of the concatenated down projections
-
-Additionally, for the torch grouped gemm and triton grouped gemm versions, the intermediate outputs of the forward pass are checked:
-- first_gemm: output of the first grouped gemm (X @ fused_gate_proj)
-- intermediate: output of silu_mul(first_gemm)
-- second_gemm: output of the second grouped gemm (intermediate @ down_proj)
-- hidden_states_unpermute: output of the second_gemm after unpermuting back to token order (from expert grouped order); in the case where the permutation is fused in the triton kernel, this is the same as second_gemm
-- hidden_states: output with the topk_weights applied
-"""
-
-TOLERANCES = {
- torch.bfloat16: (1e-2, 1e-2),
- torch.float16: (1e-3, 1e-3),
- torch.float: (1e-5, 1e-5),
-}
-
-
-@pytest.fixture(scope = "module")
-def model_id():
- return "Qwen/Qwen3-30B-A3B"
-
-
-@pytest.fixture(scope = "module")
-def config(model_id: str):
- return AutoConfig.from_pretrained(model_id)
-
-
-@contextmanager
-def annotated_context(prelude, epilogue = "Passed!", char = "-", num_chars = 80):
- print(char * num_chars)
- print(prelude)
- yield
- print(epilogue)
- print(char * num_chars)
-
-
-SEED = 42
-SEQ_LENS = [1024]
-DTYPES = [torch.bfloat16]
-
-# Reduce the number of autotuning configs to prevent excessive runtime
-NUM_AUTOTUNE_CONFIGS = 50
-
-
-@pytest.mark.parametrize(
- "permute_y", [True], ids = lambda x: "permute_y" if x else "no_permute_y"
-)
-@pytest.mark.parametrize(
- "permute_x", [True], ids = lambda x: "permute_x" if x else "no_permute_x"
-)
-@pytest.mark.parametrize(
- "autotune", [True], ids = lambda x: "autotune" if x else "manual"
-)
-@pytest.mark.parametrize("seqlen", SEQ_LENS, ids = lambda x: f"seqlen={x}")
-@pytest.mark.parametrize("dtype", DTYPES, ids = str)
-def test_qwen3_moe(
- config: Qwen3MoeConfig,
- seqlen: int,
- dtype: torch.dtype,
- permute_x: bool,
- permute_y: bool,
- autotune: bool,
-):
- torch.manual_seed(
- SEED
- ) # Should not be needed when running using pytest -- autouse fixture in conftest.py
- device = "cuda"
- hidden_size = config.hidden_size
- bs = 1
- atol, rtol = TOLERANCES[dtype]
- # Reference op -- HF
- moe_block = Qwen3MoeSparseMoeBlock(config).to(device, dtype)
-
- # Torch-native grouped gemm version of MoE Block -- for sanity checking
- grouped_gemm_block = Qwen3MoeGroupedGEMMBlock.from_hf(moe_block).to(device, dtype)
- grouped_gemm_block.check_weights(moe_block)
-
- if not autotune:
- kernel_config_fwd = KernelConfigForward()
- kernel_config_bwd_dW = KernelConfigBackward_dW()
- kernel_config_bwd_dX = KernelConfigBackward_dX()
- else:
- from grouped_gemm.kernels.backward import (
- _autotuned_grouped_gemm_dW_kernel,
- _autotuned_grouped_gemm_dX_kernel,
- )
- from grouped_gemm.kernels.forward import _autotuned_grouped_gemm_forward_kernel
-
- # Hack to reduce number of autotuning configs
- _autotuned_grouped_gemm_forward_kernel.configs = (
- _autotuned_grouped_gemm_forward_kernel.configs[:NUM_AUTOTUNE_CONFIGS]
- )
- _autotuned_grouped_gemm_dW_kernel.configs = (
- _autotuned_grouped_gemm_dW_kernel.configs[:NUM_AUTOTUNE_CONFIGS]
- )
- _autotuned_grouped_gemm_dX_kernel.configs = (
- _autotuned_grouped_gemm_dX_kernel.configs[:NUM_AUTOTUNE_CONFIGS]
- )
-
- kernel_config_fwd = None
- kernel_config_bwd_dW = None
- kernel_config_bwd_dX = None
-
- # Triton kernel grouped gemm version of MoE Block -- this is what we're testing
- fused_gemm_block = Qwen3MoeFusedGroupedGEMMBlock.from_hf(
- moe_block,
- permute_x = permute_x,
- permute_y = permute_y,
- autotune = autotune,
- kernel_config_fwd = kernel_config_fwd,
- kernel_config_bwd_dW = kernel_config_bwd_dW,
- kernel_config_bwd_dX = kernel_config_bwd_dX,
- ).to(device, dtype)
- fused_gemm_block.check_weights(moe_block)
-
- X = torch.randn(
- bs, seqlen, hidden_size, dtype = dtype, device = device, requires_grad = True
- )
-
- # Forward
- ref_result = run_forward(moe_block, X, is_grouped_gemm = False)
- grouped_result = run_forward(grouped_gemm_block, X, is_grouped_gemm = True)
- fused_result = run_forward(fused_gemm_block, X, is_grouped_gemm = True)
-
- with annotated_context(
- "Testing forward pass",
- epilogue = "Passed forward tests!",
- char = "=",
- num_chars = 100,
- ):
- # Sanity checks
-
- with annotated_context(
- "Checking HF vs torch grouped gemm MoE forward outputs..."
- ):
- check_fwd(ref_result, grouped_result, atol, rtol, verbose = False)
-
- with annotated_context(
- "Checking torch grouped gemm MoE vs fused grouped gemm MoE forward outputs..."
- ):
- # We implement a custom check for grouped gemm results to test each of the intermediate results for easier debugging
- check_grouped_gemm_results(
- grouped_result.grouped_gemm_result,
- fused_result.grouped_gemm_result,
- permute_y = permute_y,
- atol = atol,
- rtol = rtol,
- verbose = False,
- )
- # Actual test
- with annotated_context(
- "Checking HF vs fused grouped gemm MoE forward outputs..."
- ):
- check_fwd(ref_result, fused_result, atol, rtol, verbose = True)
-
- # Backward
- grad_output = torch.randn_like(ref_result.output)
- ref_backward_result = run_backward(
- moe_block, grad_output, output = ref_result.output, X = ref_result.X
- )
- grouped_backward_result = run_backward(
- grouped_gemm_block,
- grad_output,
- output = grouped_result.output,
- X = grouped_result.X,
- )
- fused_backward_result = run_backward(
- fused_gemm_block, grad_output, output = fused_result.output, X = fused_result.X
- )
-
- with annotated_context(
- "Testing backward pass",
- epilogue = "Passed backward tests!",
- char = "=",
- num_chars = 100,
- ):
- # Sanity checks
- with annotated_context("Checking HF vs torch grouped gemm MoE grads..."):
- check_grads(
- ref_backward_result, grouped_backward_result, atol, rtol, verbose = False
- )
- with annotated_context(
- "Checking torch grouped gemm MoE vs fused grouped gemm MoE grads..."
- ):
- check_grads(
- grouped_backward_result,
- fused_backward_result,
- atol,
- rtol,
- verbose = False,
- )
-
- # Actual test
- with annotated_context("Checking HF vs fused grouped gemm MoE grads..."):
- check_grads(
- ref_backward_result, fused_backward_result, atol, rtol, verbose = True
- )
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("--seqlen", type = int, default = 1024)
- parser.add_argument(
- "--dtype", type = str, choices = ["bfloat16", "float16"], default = "bfloat16"
- )
- parser.add_argument("--permute_x", action = "store_true")
- parser.add_argument("--permute_y", action = "store_true")
- parser.add_argument("--autotune", action = "store_true")
- args = parser.parse_args()
- args.dtype = getattr(torch, args.dtype)
- args_dict = vars(args)
-
- model_id = "Qwen/Qwen3-30B-A3B"
- config = AutoConfig.from_pretrained(model_id)
- atol, rtol = TOLERANCES[args.dtype]
-
- print(
- f"Testing {model_id} with seqlen={args.seqlen}, dtype={args.dtype}, permute_x={args.permute_x}, permute_y={args.permute_y}, autotune={args.autotune}, atol={atol}, rtol={rtol}"
- )
- test_qwen3_moe(config, **args_dict)
diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py
index 74c16c1e63..8f54e74908 100644
--- a/unsloth/kernels/rms_layernorm.py
+++ b/unsloth/kernels/rms_layernorm.py
@@ -15,27 +15,22 @@
import triton
import triton.language as tl
import torch
-from .utils import calculate_settings, torch_gpu_device
-
+from .utils import calculate_settings, torch_cuda_device
@triton.jit
def _rms_layernorm_forward(
- Y,
- Y_row_stride: tl.constexpr,
- X,
- X_row_stride: tl.constexpr,
- W,
- W_row_stride: tl.constexpr,
- r,
- r_row_stride: tl.constexpr,
- n_cols: tl.constexpr,
- eps: tl.constexpr,
- BLOCK_SIZE: tl.constexpr,
+ Y, Y_row_stride,
+ X, X_row_stride,
+ W, W_row_stride,
+ r, r_row_stride : tl.constexpr,
+ n_cols : tl.constexpr,
+ eps : tl.constexpr,
+ BLOCK_SIZE : tl.constexpr,
):
"""
- Fast RMS Layernorm kernel
- Inspiration from a Triton tutorial:
- https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+ Fast RMS Layernorm kernel
+ Inspiration from a Triton tutorial:
+ https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
"""
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
@@ -46,72 +41,61 @@ def _rms_layernorm_forward(
r += row_idx * r_row_stride
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
- W_row = tl.load(W + col_offsets, mask = mask, other = 0) # .to(tl.float32)
+ W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32)
row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
- # Explicit float32 scalar to ensure correct type promotion on HIP/ROCm
- eps_f32 = tl.full((), eps, tl.float32)
- inv_var = tl.math.rsqrt(row_var + eps_f32)
+ inv_var = tl.math.rsqrt(row_var + eps)
tl.store(r, inv_var)
normed = X_row * inv_var
- normed = normed.to(W_row.dtype) # Exact copy from HF
+ normed = normed.to(W_row.dtype) # Exact copy from HF
output = normed * W_row
tl.store(Y + col_offsets, output, mask = mask)
+pass
def _rms_layernorm_backward(
- dY,
- dY_row_stride: tl.constexpr,
- dX,
- dX_row_stride: tl.constexpr,
- X,
- X_row_stride: tl.constexpr,
- W,
- W_row_stride: tl.constexpr,
- r,
- r_row_stride: tl.constexpr,
+ dY, dY_row_stride,
+ dX, dX_row_stride,
+ X, X_row_stride,
+ W, W_row_stride,
+ r, r_row_stride : tl.constexpr,
# dW, dW_row_stride,
- n_cols: tl.constexpr,
- eps: tl.constexpr,
- GEMMA: tl.constexpr,
- BLOCK_SIZE: tl.constexpr,
+ n_cols : tl.constexpr,
+ eps : tl.constexpr,
+ GEMMA : tl.constexpr,
+ BLOCK_SIZE : tl.constexpr,
):
"""
- Fast RMS Layernorm kernel for the backward pass
- Inspiration from a Triton tutorial:
- https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+ Fast RMS Layernorm kernel for the backward pass
+ Inspiration from a Triton tutorial:
+ https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
"""
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
dY += row_idx * dY_row_stride
- X += row_idx * X_row_stride
- r += row_idx * r_row_stride
+ X += row_idx * X_row_stride
+ r += row_idx * r_row_stride
- if GEMMA:
- dX += row_idx * dY_row_stride
- else:
- dX = dY
+ if GEMMA: dX += row_idx * dY_row_stride
+ else: dX = dY
dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
- X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
- W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
+ X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
+ W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
# Get saved row variance
inv_var = tl.load(r).to(tl.float32)
normed = X_row * inv_var
- if GEMMA:
- dY_W = dY_row * (W_row + 1.0)
- else:
- dY_W = dY_row * W_row
+ if GEMMA: dY_W = dY_row * (W_row + 1.0)
+ else: dY_W = dY_row * W_row
rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
- output = inv_var / n_cols * (n_cols * dY_W - normed * rowsum_dY_normed)
+ output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)
tl.store(dX + col_offsets, output, mask = mask)
-
-
+pass
_rms_layernorm_backward = triton.jit(_rms_layernorm_backward)
_rms_layernorm_backward = triton.heuristics(
{
@@ -122,17 +106,13 @@ def _rms_layernorm_backward(
@triton.jit
def _gemma_rms_layernorm_forward(
- Y,
- Y_row_stride: tl.constexpr,
- X,
- X_row_stride: tl.constexpr,
- W,
- W_row_stride: tl.constexpr,
- r,
- r_row_stride: tl.constexpr,
- n_cols: tl.constexpr,
- eps: tl.constexpr,
- BLOCK_SIZE: tl.constexpr,
+ Y, Y_row_stride,
+ X, X_row_stride,
+ W, W_row_stride,
+ r, r_row_stride : tl.constexpr,
+ n_cols : tl.constexpr,
+ eps : tl.constexpr,
+ BLOCK_SIZE : tl.constexpr,
):
# Copies https://github.com/google-deepmind/gemma/blob/main/gemma/layers.py#L31
# and https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L33
@@ -149,27 +129,26 @@ def _gemma_rms_layernorm_forward(
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
- # Explicit float32 scalar to ensure correct type promotion on HIP/ROCm
- eps_f32 = tl.full((), eps, tl.float32)
- inv_var = tl.math.rsqrt(row_var + eps_f32)
+ inv_var = tl.math.rsqrt(row_var + eps)
tl.store(r, inv_var)
normed = X_row * inv_var
output = normed * (W_row + 1.0)
tl.store(Y + col_offsets, output, mask = mask)
+pass
class Fast_RMS_Layernorm(torch.autograd.Function):
@staticmethod
- def forward(ctx, X: torch.Tensor, W: torch.Tensor, eps: float, gemma: bool = False):
+ def forward(ctx, X : torch.Tensor, W : torch.Tensor, eps : float, gemma : bool = False):
shape = X.shape
- dim: int = shape[-1]
- X = X.reshape(-1, dim)
- n_rows: int
- n_cols: int
+ dim : int = shape[-1]
+ X = X.view(-1, dim)
+ n_rows : int
+ n_cols : int
n_rows, n_cols = X.shape
- BLOCK_SIZE: int
- num_warps: int
+ BLOCK_SIZE : int
+ num_warps : int
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
device = X.device
@@ -177,141 +156,121 @@ def forward(ctx, X: torch.Tensor, W: torch.Tensor, eps: float, gemma: bool = Fal
r = torch.empty(n_rows, dtype = torch.float32, device = device)
fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
- with torch_gpu_device(device):
+ with torch_cuda_device(device):
fx[(n_rows,)](
- Y,
- Y.stride(0),
- X,
- X.stride(0),
- W,
- W.stride(0),
- r,
- r.stride(0),
- n_cols,
- eps,
+ Y, Y.stride(0),
+ X, X.stride(0),
+ W, W.stride(0),
+ r, r.stride(0),
+ n_cols, eps,
BLOCK_SIZE = BLOCK_SIZE,
- num_warps = num_warps,
+ num_warps = num_warps,
)
ctx.eps = eps
ctx.BLOCK_SIZE = BLOCK_SIZE
- ctx.num_warps = num_warps
+ ctx.num_warps = num_warps
ctx.GEMMA = gemma
ctx.save_for_backward(X, W, r)
return Y.view(*shape)
+ pass
@staticmethod
- def backward(ctx, dY: torch.Tensor):
+ def backward(ctx, dY : torch.Tensor):
shape = dY.shape
- dim: int = shape[-1]
- dY = dY.reshape(-1, dim)
+ dim : int = shape[-1]
+ dY = dY.view(-1, dim)
X, W, r = ctx.saved_tensors
- n_rows: int
- n_cols: int
+ n_rows : int
+ n_cols : int
n_rows, n_cols = dY.shape
# dW = X
dX = torch.empty_like(dY) if ctx.GEMMA else dY
- with torch_gpu_device(dY.device):
+ with torch_cuda_device(dY.device):
_rms_layernorm_backward[(n_rows,)](
- dY,
- dY.stride(0),
- dX,
- dX.stride(0),
- X,
- X.stride(0),
- W,
- W.stride(0),
- r,
- r.stride(0),
+ dY, dY.stride(0),
+ dX, dX.stride(0),
+ X, X .stride(0),
+ W, W .stride(0),
+ r, r .stride(0),
# dW, dW.stride(0),
- n_cols,
- ctx.eps,
- GEMMA = ctx.GEMMA,
+ n_cols, ctx.eps,
+ GEMMA = ctx.GEMMA,
BLOCK_SIZE = ctx.BLOCK_SIZE,
- num_warps = ctx.num_warps,
+ num_warps = ctx.num_warps,
)
dX = dX.view(*shape)
return dX, None, None, None
+ pass
+pass
# [TODO] Unsure why RMS Layernorm is not torch.compiling properly
@torch.compiler.disable
-def fast_rms_layernorm(layernorm, X: torch.Tensor, gemma: bool = False):
- W: torch.Tensor = layernorm.weight
- eps: float = (
- layernorm.variance_epsilon
- if hasattr(layernorm, "variance_epsilon")
+def fast_rms_layernorm(layernorm, X : torch.Tensor, gemma : bool = False):
+ W : torch.Tensor = layernorm.weight
+ eps : float = layernorm.variance_epsilon if \
+ hasattr(layernorm, "variance_epsilon") \
else layernorm.eps
- )
out = Fast_RMS_Layernorm.apply(X, W, eps, gemma)
return out
+pass
from transformers.models.llama.modeling_llama import LlamaRMSNorm
-
-
class Unsloth_LlamaRMSNorm(LlamaRMSNorm):
def forward(self, X):
return fast_rms_layernorm(self, X, gemma = False)
-
+ pass
+pass
try:
from transformers.models.mllama.modeling_mllama import MllamaTextRMSNorm
-
class Unsloth_MllamaTextRMSNorm(MllamaTextRMSNorm):
def forward(self, X):
return fast_rms_layernorm(self, X, gemma = False)
-
-
+ pass
+ pass
except:
pass
-
+pass
def patch_rms_layernorm():
import transformers.models.llama.modeling_llama
-
transformers.models.llama.modeling_llama.LlamaRMSNorm = Unsloth_LlamaRMSNorm
try:
import transformers.models.mllama.modeling_mllama
-
- transformers.models.mllama.modeling_mllama.MllamaTextRMSNorm = (
- Unsloth_MllamaTextRMSNorm
- )
+ transformers.models.mllama.modeling_mllama.MllamaTextRMSNorm = Unsloth_MllamaTextRMSNorm
except:
pass
return
+pass
def unpatch_rms_layernorm():
import transformers.models.llama.modeling_llama
-
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
try:
import transformers.models.mllama.modeling_mllama
-
transformers.models.mllama.modeling_mllama.MllamaTextRMSNorm = MllamaTextRMSNorm
except:
pass
return
+pass
def test_rms_layernorm(
- dim = 1024,
- eps = 1e-5,
- dtype = torch.float16,
- bsz = 21,
- random_state = 3407,
- seqlen = 3341,
+ dim = 1024, eps = 1e-5, dtype = torch.float16,
+ bsz = 21, random_state = 3407, seqlen = 3341,
):
from transformers.models.llama.modeling_llama import LlamaRMSNorm
-
layernorm = LlamaRMSNorm((dim,), eps = eps).to("cuda")
torch.cuda.manual_seed(random_state)
torch.manual_seed(random_state)
torch.nn.init.uniform_(layernorm.weight)
X = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda")
XX = X.clone()
- X.requires_grad_(True)
+ X .requires_grad_(True)
XX.requires_grad_(True)
Y = layernorm(X)
YY = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda", requires_grad = True)
@@ -320,7 +279,8 @@ def test_rms_layernorm(
# from unsloth.kernels import fast_rms_layernorm
Y = fast_rms_layernorm(layernorm, XX)
Y.backward(YY)
- assert torch.amax(correct_grad - XX.grad).item() <= 0.05
+ assert(torch.amax(correct_grad - XX.grad).item() <= 0.05)
+pass
def testing_suite_layernorm():
@@ -337,3 +297,9 @@ def testing_suite_layernorm():
random_state = random_state,
seqlen = seqlen,
)
+ pass
+ pass
+ pass
+ pass
+ pass
+pass
diff --git a/unsloth/kernels/rope_embedding.py b/unsloth/kernels/rope_embedding.py
index fcc9cb923b..a14a485352 100644
--- a/unsloth/kernels/rope_embedding.py
+++ b/unsloth/kernels/rope_embedding.py
@@ -1,159 +1,54 @@
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
+# http://www.apache.org/licenses/LICENSE-2.0
#
-# You should have received a copy of the GNU Lesser General Public License
-# along with this program. If not, see .
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import triton
import triton.language as tl
import torch
-from ..device_type import DEVICE_COUNT
-from .utils import calculate_settings, torch_gpu_device, torch_device_stream
-
-
-def _rope_embedding_QK(
- Q,
- Q_batch_stride,
- Q_head_stride,
- Q_seq_stride,
- K,
- K_batch_stride,
- K_head_stride,
- K_seq_stride,
- cos,
- cos_row_stride,
- sin,
- sin_row_stride,
- rope_embedding_indices,
- seqlen,
- head_dim: tl.constexpr,
- n_heads_K: tl.constexpr,
- BACKWARD_PASS: tl.constexpr,
- HAS_ROPE_INDICES: tl.constexpr,
- BLOCK_SIZE: tl.constexpr,
-):
- row_position = tl.program_id(0)
- head_position = tl.program_id(1)
- col_offsets = tl.arange(0, BLOCK_SIZE)
- half_head_dim = head_dim // 2
- mask = col_offsets < half_head_dim
-
- if HAS_ROPE_INDICES:
- rot_position = tl.load(
- rope_embedding_indices + row_position,
- eviction_policy = "evict_first",
- ).to(tl.int32)
- else:
- rot_position = row_position % seqlen
-
- cos_ptr = cos + rot_position * cos_row_stride
- sin_ptr = sin + rot_position * sin_row_stride
- sin1 = tl.load(
- sin_ptr + col_offsets,
- mask = mask,
- other = 0,
- )
- cos1 = tl.load(
- cos_ptr + col_offsets,
- mask = mask,
- other = 0,
- )
- if BACKWARD_PASS:
- sin1 = -sin1
-
- batch_id = row_position // seqlen
- seq_index = row_position - batch_id * seqlen
-
- q_ptr = (
- Q
- + batch_id * Q_batch_stride
- + head_position * Q_head_stride
- + seq_index * Q_seq_stride
- )
- q0 = tl.load(q_ptr + col_offsets, mask = mask, other = 0)
- q1 = tl.load(q_ptr + half_head_dim + col_offsets, mask = mask, other = 0)
- tl.store(q_ptr + col_offsets, q0 * cos1 - q1 * sin1, mask = mask)
- tl.store(q_ptr + half_head_dim + col_offsets, q1 * cos1 + q0 * sin1, mask = mask)
-
- if head_position < n_heads_K:
- k_ptr = (
- K
- + batch_id * K_batch_stride
- + head_position * K_head_stride
- + seq_index * K_seq_stride
- )
- k0 = tl.load(k_ptr + col_offsets, mask = mask, other = 0)
- k1 = tl.load(k_ptr + half_head_dim + col_offsets, mask = mask, other = 0)
- tl.store(k_ptr + col_offsets, k0 * cos1 - k1 * sin1, mask = mask)
- tl.store(k_ptr + half_head_dim + col_offsets, k1 * cos1 + k0 * sin1, mask = mask)
-
-
-_rope_embedding_QK = triton.jit(_rope_embedding_QK)
-_rope_embedding_QK = triton.heuristics(
- {
- "BACKWARD_PASS": lambda args: bool(args["BACKWARD_PASS"]),
- "HAS_ROPE_INDICES": lambda args: bool(args["HAS_ROPE_INDICES"]),
- }
-)(_rope_embedding_QK)
-
-
-ROPE_GROUP_SIZE: int = 4
-
+from .utils import calculate_settings, torch_cuda_device
+ROPE_GROUP_SIZE : int = 4
def _rope_embedding(
- Q,
- Q_row_stride: tl.constexpr,
- cos,
- cos_row_stride: tl.constexpr,
- sin,
- sin_row_stride: tl.constexpr,
+ Q, Q_row_stride,
+ cos, cos_row_stride,
+ sin, sin_row_stride,
seqlen,
- head_dim: tl.constexpr,
- n_heads: tl.constexpr,
- BACKWARD_PASS: tl.constexpr,
- BLOCK_SIZE: tl.constexpr,
+ head_dim : tl.constexpr,
+ n_heads : tl.constexpr,
+ BACKWARD_PASS : tl.constexpr,
+ BLOCK_SIZE : tl.constexpr,
):
"""
- Calculates the RoPE Embedding quickly
- RoPE is Q * cos + rotate_half(Q) * sin
- See our blog post for more info
+ Calculates the RoPE Embedding quickly
+ RoPE is Q * cos + rotate_half(Q) * sin
+ See our blog post for more info
"""
ROPE_GROUP_SIZE = 4
- row_position = tl.program_id(0)
+ row_position = tl.program_id(0)
group_head_position = tl.program_id(1)
- col_offsets = tl.arange(0, BLOCK_SIZE)
+ col_offsets = tl.arange(0, BLOCK_SIZE)
half_head_dim = head_dim // 2
mask = col_offsets < half_head_dim
- sin1 = tl.load(
- sin
- + (row_position % seqlen) * sin_row_stride
- + half_head_dim * 0
- + col_offsets,
- mask = mask,
- other = 0,
- )
- cos1 = tl.load(
- cos
- + (row_position % seqlen) * cos_row_stride
- + half_head_dim * 0
- + col_offsets,
- mask = mask,
- other = 0,
- )
+ sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \
+ half_head_dim*0 + col_offsets, mask = mask, other = 0)
+ cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \
+ half_head_dim*0 + col_offsets, mask = mask, other = 0)
if BACKWARD_PASS:
# See our blog post for more info.
sin1 = -sin1
+ pass
# [TODO] Autotune ROPE_GROUP_SIZE to be 1, 2, 4, 8
head_start = group_head_position * ROPE_GROUP_SIZE
@@ -162,18 +57,16 @@ def _rope_embedding(
# 10% Faster kernel from [HuyNguyen-hust](https://github.com/unslothai/unsloth/pull/238)
for k in range(head_start, head_end):
offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets
- offs_q2 = (
- row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim
- )
+ offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim
# For Gemma - sometimes RoPE must be done in float32 and not bfloat16
Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)
Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)
- tl.store(Q + offs_q1, Q1 * cos1 - Q2 * sin1, mask = mask)
- tl.store(Q + offs_q2, Q2 * cos1 + Q1 * sin1, mask = mask)
-
-
+ tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)
+ tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)
+ pass
+pass
_rope_embedding = triton.jit(_rope_embedding)
_rope_embedding = triton.heuristics(
{
@@ -186,243 +79,84 @@ class Fast_RoPE_Embedding(torch.autograd.Function):
@staticmethod
def forward(ctx, Q, cos, sin):
cos, sin = cos.squeeze(), sin.squeeze()
- batch: int
- seq_len: int
- n_heads: int
- head_dim: int
+ batch : int
+ seq_len : int
+ n_heads : int
+ head_dim : int
batch, seq_len, n_heads, head_dim = Q.shape
- Q = Q.reshape(batch * seq_len, n_heads * head_dim)
- n_rows: int
- n_cols: int
+ Q = Q.view(batch*seq_len, n_heads*head_dim)
+ n_rows : int
+ n_cols : int
n_rows, n_cols = Q.shape
- assert seq_len <= cos.shape[0]
+ assert(seq_len <= cos.shape[0])
# [TODO] Changing blocksize to head_dim//2 seems to have
# some concurrency / un-deterministic issues.
- BLOCK_SIZE, num_warps = calculate_settings(head_dim // 2) # (head_dim//2)
-
+ BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2)
+
# group_size = 4 # 4 or 8, too large group_size can hurt performance.
- div: int
- mod: int
+ div : int
+ mod : int
div, mod = divmod(n_heads, ROPE_GROUP_SIZE)
- n_groups: int = div + (mod != 0)
+ n_groups : int = div + (mod != 0)
- with torch_gpu_device(Q.device):
- _rope_embedding[
- (
- n_rows,
- n_groups,
- )
- ](
- Q,
- Q.stride(0),
- cos,
- cos.stride(0),
- sin,
- sin.stride(0),
+ with torch_cuda_device(Q.device):
+ _rope_embedding[(n_rows, n_groups, )](
+ Q, Q.stride(0),
+ cos, cos.stride(0),
+ sin, sin.stride(0),
seq_len,
- head_dim,
- n_heads,
+ head_dim, n_heads,
BACKWARD_PASS = False,
BLOCK_SIZE = BLOCK_SIZE,
- num_warps = num_warps,
+ num_warps = num_warps,
)
ctx.BLOCK_SIZE = BLOCK_SIZE
- ctx.num_warps = num_warps
+ ctx.num_warps = num_warps
ctx.n_groups = n_groups
ctx.cos = cos
ctx.sin = sin
- return Q.reshape(batch, seq_len, n_heads, head_dim)
+ return Q.view(batch, seq_len, n_heads, head_dim)
+ pass
@staticmethod
def backward(ctx, dY):
- batch: int
- seq_len: int
- n_heads: int
- head_dim: int
+ batch : int
+ seq_len : int
+ n_heads : int
+ head_dim : int
batch, seq_len, n_heads, head_dim = dY.shape
- dY = dY.reshape(batch * seq_len, n_heads * head_dim)
- n_rows: int
- n_cols: int
+ dY = dY.reshape(batch*seq_len, n_heads*head_dim)
+ # Must be reshape not view
+ n_rows : int
+ n_cols : int
n_rows, n_cols = dY.shape
cos = ctx.cos
sin = ctx.sin
- with torch_gpu_device(dY.device):
- _rope_embedding[
- (
- n_rows,
- ctx.n_groups,
- )
- ](
- dY,
- dY.stride(0),
- cos,
- cos.stride(0),
- sin,
- sin.stride(0),
- seq_len,
- head_dim,
- n_heads,
+ with torch_cuda_device(dY.device):
+ _rope_embedding[(n_rows, ctx.n_groups, )](
+ dY, dY .stride(0),
+ cos, cos.stride(0),
+ sin, sin.stride(0),
+ seq_len, head_dim, n_heads,
BACKWARD_PASS = True,
BLOCK_SIZE = ctx.BLOCK_SIZE,
- num_warps = ctx.num_warps,
+ num_warps = ctx.num_warps,
)
- dY = dY.reshape(batch, seq_len, n_heads, head_dim)
- return (
- dY,
- None,
- None,
- )
-
+ dY = dY.view(batch, seq_len, n_heads, head_dim)
+ return dY, None, None,
+ pass
+pass
# [TODO] Unsure why RoPE Embedding is not torch.compiling properly
@torch.compiler.disable
-def fast_rope_embedding(
- Q,
- K,
- cos,
- sin,
- rope_embedding_indices = None,
-):
- if rope_embedding_indices is not None:
- Q_out, K_out = Fast_RoPE_Embedding_QK.apply(
- Q, K, cos, sin, rope_embedding_indices
- )
- else:
- Q_out = Fast_RoPE_Embedding.apply(
- Q.transpose(1, 2).contiguous(), cos, sin
- ).transpose(1, 2)
- K_out = Fast_RoPE_Embedding.apply(
- K.transpose(1, 2).contiguous(), cos, sin
- ).transpose(1, 2)
- if DEVICE_COUNT > 1:
- torch_device_stream(Q.device).synchronize()
- return Q_out, K_out
-
-
-class Fast_RoPE_Embedding_QK(torch.autograd.Function):
- @staticmethod
- def forward(ctx, Q, K, cos, sin, rope_indices):
- has_indices = rope_indices is not None
- cos, sin = cos.squeeze(), sin.squeeze()
-
- batch, n_heads_Q, seq_len, head_dim = Q.shape
- _, n_heads_K, _, _ = K.shape
-
- # Inplace rotary embedding is generally fine
- Q_out = Q.clone() if not Q.is_contiguous() else Q
- K_out = K.clone() if not K.is_contiguous() else K
-
- if has_indices:
- # TRL's rotary indices are always in int32, so casting is just for safety
- rope_ptr = rope_indices.reshape(-1).to(dtype = torch.int32, device = Q.device)
- else:
- rope_ptr = cos.new_empty(1, dtype = torch.int32)
-
- BLOCK_SIZE, num_warps = calculate_settings(head_dim)
-
- Q_batch_stride, Q_head_stride, Q_seq_stride = (
- Q_out.stride(0),
- Q_out.stride(1),
- Q_out.stride(2),
- )
- K_batch_stride, K_head_stride, K_seq_stride = (
- K_out.stride(0),
- K_out.stride(1),
- K_out.stride(2),
- )
-
- with torch_gpu_device(Q.device):
- _rope_embedding_QK[(batch * seq_len, n_heads_Q)](
- Q_out,
- Q_batch_stride,
- Q_head_stride,
- Q_seq_stride,
- K_out,
- K_batch_stride,
- K_head_stride,
- K_seq_stride,
- cos,
- cos.stride(0),
- sin,
- sin.stride(0),
- rope_ptr,
- seq_len,
- head_dim = head_dim,
- n_heads_K = n_heads_K,
- BACKWARD_PASS = False,
- HAS_ROPE_INDICES = has_indices,
- BLOCK_SIZE = BLOCK_SIZE,
- num_warps = num_warps,
- )
-
- ctx.block_size = BLOCK_SIZE
- ctx.num_warps = num_warps
- ctx.has_indices = has_indices
- ctx.cos = cos
- ctx.sin = sin
- ctx.rope_indices = rope_ptr if has_indices else None
- ctx.seq_len = seq_len
- ctx.n_heads_Q = n_heads_Q
- ctx.n_heads_K = n_heads_K
-
- return (
- Q_out,
- K_out,
- )
-
- @staticmethod
- def backward(ctx, dQ, dK):
- batch, _, _, head_dim = dQ.shape
-
- rope_ptr = (
- ctx.rope_indices
- if ctx.has_indices
- else ctx.cos.new_empty(1, dtype = torch.int32)
- )
-
- # Inplace rotary embedding is generally fine
- dQ_out = dQ.clone() if not dQ.is_contiguous() else dQ
- dK_out = dK.clone() if not dK.is_contiguous() else dK
-
- Q_batch_stride, Q_head_stride, Q_seq_stride = (
- dQ_out.stride(0),
- dQ_out.stride(1),
- dQ_out.stride(2),
- )
- K_batch_stride, K_head_stride, K_seq_stride = (
- dK_out.stride(0),
- dK_out.stride(1),
- dK_out.stride(2),
- )
-
- with torch_gpu_device(dQ.device):
- _rope_embedding_QK[(batch * ctx.seq_len, ctx.n_heads_Q)](
- dQ_out,
- Q_batch_stride,
- Q_head_stride,
- Q_seq_stride,
- dK_out,
- K_batch_stride,
- K_head_stride,
- K_seq_stride,
- ctx.cos,
- ctx.cos.stride(0),
- ctx.sin,
- ctx.sin.stride(0),
- rope_ptr,
- ctx.seq_len,
- head_dim = head_dim,
- n_heads_K = ctx.n_heads_K,
- BACKWARD_PASS = True,
- HAS_ROPE_INDICES = ctx.has_indices,
- BLOCK_SIZE = ctx.block_size,
- num_warps = ctx.num_warps,
- )
-
- return (dQ_out, dK_out, None, None, None)
+def fast_rope_embedding(Q, K, cos, sin):
+ Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2)
+ K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2)
+ return Q, K
+pass
class Slow_RoPE_Embedding(torch.autograd.Function):
@@ -432,11 +166,11 @@ def forward(ctx, Q, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
- cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
- sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
# Q * cos + rotate_half(Q) * sin
- half = Q.shape[-1] // 2
+ half = Q.shape[-1]//2
RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
Q *= cos
Q.addcmul_(RH_Q, sin)
@@ -444,22 +178,25 @@ def forward(ctx, Q, cos, sin, position_ids):
# Q += RH_Q
ctx.save_for_backward(cos, sin)
return Q
+ pass
@staticmethod
def backward(ctx, dY):
cos, sin = ctx.saved_tensors
# Q * cos + rotate_half.T(Q) * sin
- half = dY.shape[-1] // 2
+ half = dY.shape[-1]//2
RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1)
dY *= cos
dY.addcmul_(RH_dY, sin)
# RH_dY *= sin
# dY += RH_dY
return dY, None, None, None
+ pass
+pass
def inplace_rope_embedding(Q, K, cos, sin, position_ids):
Q = Slow_RoPE_Embedding.apply(Q, cos, sin, position_ids)
K = Slow_RoPE_Embedding.apply(K, cos, sin, position_ids)
- torch_device_stream(Q.device).synchronize()
return Q, K
+pass
diff --git a/unsloth/kernels/swiglu.py b/unsloth/kernels/swiglu.py
index b3ae9d40e6..12f1f5e063 100644
--- a/unsloth/kernels/swiglu.py
+++ b/unsloth/kernels/swiglu.py
@@ -15,73 +15,42 @@
import triton
import triton.language as tl
import torch
-from .utils import calculate_settings, torch_gpu_device
-
-# signed int32 max is 2**31-1 so num_elements cannot exceed 2**31
-NUM_INT32_ELEMENTS = 2**31
-SAFE_INT32_BUFFER_MULTIPLIER = 4
-BLOCK_SIZE = 1024
-INT32_SAFETY_BUFFER = NUM_INT32_ELEMENTS - BLOCK_SIZE * SAFE_INT32_BUFFER_MULTIPLIER
+from .utils import calculate_settings, torch_cuda_device
@triton.jit
-def _fg_kernel(
- e,
- g,
- h,
- n_elements,
- BLOCK_SIZE: tl.constexpr,
- LONG_INDEXING: tl.constexpr,
-):
+def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
block_idx = tl.program_id(0)
- if LONG_INDEXING:
- offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to(
- tl.int64
- )
- n_elements = tl.cast(n_elements, tl.int64)
- else:
- offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
- g_row = tl.load(g + offsets, mask = mask, other = 0) # .to(tl.float32)
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
# f = e * sigmoid(e)
- f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row))
- f_row = f_row.to(g_row.dtype) # Exact copy from HF
+ f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row))
+ f_row = f_row.to(g_row.dtype) # Exact copy from HF
# h = f * g
h_row = f_row * g_row
# Store h
tl.store(h + offsets, h_row, mask = mask)
+pass
def swiglu_fg_kernel(e, g):
batch, seq_len, hd = e.shape
n_elements = e.numel()
h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = e.device)
- grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
- with torch_gpu_device(e.device):
- _fg_kernel[grid](
- e,
- g,
- h,
- n_elements,
- BLOCK_SIZE = BLOCK_SIZE,
- LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
- )
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+ with torch_cuda_device(e.device):
+ _fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,)
return h
+pass
@triton.jit
-def _DWf_DW_dfg_kernel(
- DW,
- e,
- g,
- n_elements,
- BLOCK_SIZE: tl.constexpr,
- LONG_INDEXING: tl.constexpr,
-):
+def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
"""
e = e.float()
se = 1.0 / (1.0 + torch.exp(-e))
@@ -92,27 +61,21 @@ def _DWf_DW_dfg_kernel(
de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
"""
block_idx = tl.program_id(0)
- if LONG_INDEXING:
- offsets = block_idx.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE).to(
- tl.int64
- )
- n_elements = tl.cast(n_elements, tl.int64)
- else:
- offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
- DW_row = tl.load(DW + offsets, mask = mask, other = 0) # .to(tl.float32)
- e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
- g_row = tl.load(g + offsets, mask = mask, other = 0) # .to(tl.float32)
+ DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
+ e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
# e = e.float()
# se = 1.0 / (1.0 + torch.exp(-e))
- se_row = tl.sigmoid(e_row) # 1.0 / (1.0 + tl.exp(-e_row))
+ se_row = tl.sigmoid(e_row) # 1.0 / (1.0 + tl.exp(-e_row))
# f = (se * e).to(dtype)
f_row = se_row * e_row
f_row = f_row.to(DW_row.dtype)
# h = f * g
- h_row = f_row * g_row
+ h_row = f_row * g_row
# df = DW * f
df_row = DW_row * f_row
# dg = DW * g
@@ -122,22 +85,17 @@ def _DWf_DW_dfg_kernel(
de_row = de_row.to(DW_row.dtype)
# Store derivatives in buffers
- tl.store(DW + offsets, h_row, mask = mask) # h = f * g
- tl.store(e + offsets, df_row, mask = mask) # df = DW * f
- tl.store(g + offsets, de_row, mask = mask) # de
+ tl.store(DW + offsets, h_row, mask = mask) # h = f * g
+ tl.store(e + offsets, df_row, mask = mask) # df = DW * f
+ tl.store(g + offsets, de_row, mask = mask) # de
+pass
def swiglu_DWf_DW_dfg_kernel(DW, e, g):
- batch_seq_len, hd = e.shape # Flattened to 2D, so 1st dim is bsz * seq_len
+ batch_seq_len, hd = e.shape
n_elements = e.numel()
- grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
- with torch_gpu_device(e.device):
- _DWf_DW_dfg_kernel[grid](
- DW,
- e,
- g,
- n_elements,
- BLOCK_SIZE = BLOCK_SIZE,
- LONG_INDEXING = 0 if n_elements <= INT32_SAFETY_BUFFER else 1,
- )
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+ with torch_cuda_device(e.device):
+ _DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
return DW, e, g
+pass
diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py
index 90f2d5d238..db1d73c340 100644
--- a/unsloth/kernels/utils.py
+++ b/unsloth/kernels/utils.py
@@ -12,482 +12,169 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import importlib
import triton
-import ctypes
-
-MAX_FUSED_SIZE: int = 65536
+MAX_FUSED_SIZE : int = 65536
next_power_of_2 = triton.next_power_of_2
import functools
-from typing import Optional
-
-from ..device_type import (
- is_hip,
- get_device_type,
- DEVICE_TYPE,
- DEVICE_TYPE_TORCH,
- DEVICE_COUNT,
- ALLOW_PREQUANTIZED_MODELS,
-)
-from .fp8 import weight_dequant, fp8_linear
-import functools
# torch.cuda.amp.custom_fwd is deprecated >= 2.4
import torch
-
torch_Tensor = torch.Tensor
-from unsloth_zoo.utils import Version
-
-if DEVICE_TYPE == "xpu" and Version(torch.__version__) < Version("2.6.0"):
- raise RuntimeError(
- "Intel xpu currently supports unsloth with torch.version >= 2.6.0"
- )
-
+from packaging.version import Version
if Version(torch.__version__) < Version("2.4.0"):
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
-
-if DEVICE_TYPE == "xpu":
- torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "xpu")
- torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "xpu")
+pass
# tl.math.tanh now is libdevice.tanh
+from packaging.version import Version
import triton
import triton.language as tl
-
if Version(triton.__version__) >= Version("3.0.0"):
- if DEVICE_TYPE == "xpu":
- triton_tanh = tl.extra.intel.libdevice.tanh
- else:
- from triton.language.extra import libdevice
-
- triton_tanh = libdevice.tanh
+ from triton.language.extra import libdevice
+ triton_tanh = libdevice.tanh
triton_cast = tl.cast
else:
triton_tanh = tl.math.tanh
-
# No casting in old Triton versions
@triton.jit
def triton_cast(x, dtype):
return x.to(dtype)
+ pass
+pass
-@functools.lru_cache(1)
-def is_cdna():
- return is_hip() and triton.runtime.driver.active.get_current_target().arch in (
- "gfx940",
- "gfx941",
- "gfx942",
- "gfx950", # CDNA4 (MI350/MI355X)
- )
-
-
-@functools.lru_cache(1)
-def is_rdna():
- """Detect ROCm-supported RDNA consumer/workstation GPUs (RDNA3, RDNA4)."""
- return is_hip() and triton.runtime.driver.active.get_current_target().arch in (
- "gfx1100",
- "gfx1101",
- "gfx1200",
- "gfx1201",
- )
-
-
-def calculate_settings(
- n: int,
-) -> (
- int,
- int,
-):
- BLOCK_SIZE: int = next_power_of_2(n)
+def calculate_settings(n : int) -> (int, int,):
+ BLOCK_SIZE : int = next_power_of_2(n)
if BLOCK_SIZE > MAX_FUSED_SIZE:
- raise RuntimeError(
- f"Cannot launch Triton kernel since n = {n} exceeds "
- f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}."
- )
- num_warps: int = 4
- if BLOCK_SIZE >= 32768:
- num_warps = 32
- elif BLOCK_SIZE >= 8192:
- num_warps = 16
- elif BLOCK_SIZE >= 2048:
- num_warps = 8
+ raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
+ f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
+ num_warps : int = 4
+ if BLOCK_SIZE >= 32768: num_warps = 32
+ elif BLOCK_SIZE >= 8192: num_warps = 16
+ elif BLOCK_SIZE >= 2048: num_warps = 8
return BLOCK_SIZE, num_warps
+pass
-HAS_CUDA_STREAM = False
import bitsandbytes as bnb
+import ctypes
# https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1330/files
HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3")
get_ptr = bnb.functional.get_ptr
-if DEVICE_TYPE == "xpu":
- HAS_XPU_STREAM = True
-
-if DEVICE_COUNT > 1:
- if DEVICE_TYPE in ("cuda", "hip"):
- torch_gpu_device = torch.cuda.device
- elif DEVICE_TYPE == "xpu":
- torch_gpu_device = torch.xpu.device
+if torch.cuda.device_count() > 1:
+ torch_cuda_device = torch.cuda.device
else:
from contextlib import nullcontext
-
- def torch_gpu_device(device):
- return nullcontext()
-
-
-# INTEL GPU Specific Logic
-if DEVICE_TYPE == "xpu":
- _gpu_getCurrentRawStream = torch._C._xpu_getCurrentRawStream
-# NVIDIA GPU Default Logic
-else:
- _gpu_getCurrentRawStream = torch._C._cuda_getCurrentRawStream
-
+ def torch_cuda_device(device): return nullcontext()
+pass
+_cuda_getCurrentRawStream = torch._C._cuda_getCurrentRawStream
c_void_p = ctypes.c_void_p
-
-
def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p:
- return c_void_p(_gpu_getCurrentRawStream(tensor.device.index))
-
+ return c_void_p(_cuda_getCurrentRawStream(tensor.device.index))
+pass
# Get array of CUDA streams and other buffers
global CUDA_STREAMS
-global XPU_STREAMS
global WEIGHT_BUFFERS
global ABSMAX_BUFFERS
-# INTEL GPU Specific Logic
-if DEVICE_TYPE == "xpu":
- _XPU_STREAMS = {
- (index := torch.xpu.device(i).idx): ctypes.c_void_p(
- torch._C._xpu_getCurrentRawStream(index)
- )
- for i in range(DEVICE_COUNT)
- }
- XPU_STREAMS = [None] * (max(_XPU_STREAMS.keys()) + 1)
- WEIGHT_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1)
- ABSMAX_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1)
- for k, v in _XPU_STREAMS.items():
- XPU_STREAMS[k] = v
- XPU_STREAMS = tuple(XPU_STREAMS)
- del _XPU_STREAMS
-else:
- # NVIDIA GPU Default Logic
- _CUDA_STREAMS = {
- (index := torch.cuda.device(i).idx): ctypes.c_void_p(
- torch._C._cuda_getCurrentRawStream(index)
- )
- for i in range(DEVICE_COUNT)
- }
- CUDA_STREAMS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
- WEIGHT_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
- ABSMAX_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
- for k, v in _CUDA_STREAMS.items():
- CUDA_STREAMS[k] = v
- CUDA_STREAMS = tuple(CUDA_STREAMS)
- del _CUDA_STREAMS
+_CUDA_STREAMS = {
+ (index := torch.cuda.device(i).idx) : ctypes.c_void_p(torch._C._cuda_getCurrentRawStream(index))
+ for i in range(torch.cuda.device_count())
+}
+CUDA_STREAMS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
+WEIGHT_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
+ABSMAX_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
+for k, v in _CUDA_STREAMS.items(): CUDA_STREAMS[k] = v
+CUDA_STREAMS = tuple(CUDA_STREAMS)
+del _CUDA_STREAMS
# Bitsandbytes operations
-ctypes_c_int = ctypes.c_int
+ctypes_c_int = ctypes.c_int
ctypes_c_int32 = ctypes.c_int32
-cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
-cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
-cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
-
-if DEVICE_TYPE == "xpu":
- # https://github.com/bitsandbytes-foundation/bitsandbytes/blob/c3b8de268fdb55a88f92feada23fc811a1e6877a/bitsandbytes/backends/xpu/ops.py#L115
- # for xpu, inference gemv using above link
- cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemv_4bit_inference_fp16
- cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemv_4bit_inference_bf16
-else:
- cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
- cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16
-
-
-torch_device_stream = (
- torch.xpu.current_stream if DEVICE_TYPE == "xpu" else torch.cuda.current_stream
-)
-
+cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
+cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
+cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
+cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
+cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16
torch_mm = torch.mm
torch_mv = torch.mv
torch_matmul = torch.matmul
-torch_addmm = torch.addmm
-torch_empty = torch.empty
-torch_float32 = torch.float32
-torch_float16 = torch.float16
-torch_bfloat16 = torch.bfloat16
-
-
-# Check whether torchao can be imported to get Float8Tensor
-if importlib.util.find_spec("torchao") is not None:
- try:
- from torchao.quantization import Float8Tensor
- except:
- import torchao
-
- if Version(torchao.__version__) >= Version("0.15.0"):
- print(
- f"Unsloth: `from torchao.quantization import Float8Tensor` failed on version={torchao.__version__}"
- )
- Float8Tensor = type(None)
-else:
- Float8Tensor = type(None)
-
-
-def QUANT_STATE(W):
- return getattr(W, "quant_state", None)
+torch_addmm = torch.addmm
+torch_empty = torch.empty
+def QUANT_STATE(W): return getattr(W, "quant_state", None)
def get_lora_parameters(proj):
- """
- Return a 5-tuple of (weight, weight quant_state, lora A, lora B, and lora scale).
- If QAT is enabled, additionally fake quantize the base layer and lora weights.
- """
# For DPO or disabled adapters
- base_layer = getattr(
- proj, "base_layer", proj
- ) # (proj.base_layer if hasattr(proj, "base_layer") else proj)
+ base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj)
W = base_layer.weight
- # Optionally apply fake quantization to base layer weights for QAT
- if hasattr(base_layer, "weight_fake_quantizer"):
- weight_fake_quantizer = getattr(base_layer, "weight_fake_quantizer", None)
- if weight_fake_quantizer is not None:
- W = weight_fake_quantizer(W)
-
- # Get quant state for 4bit or FP8
- W_quant = getattr(W, "quant_state", None)
- if W_quant is None:
- W_quant = getattr(base_layer, "weight_scale_inv", None)
- if W_quant is None:
- W_quant = getattr(base_layer, "weight_scale", None)
-
- if getattr(base_layer, "quant_method", None) == "fp8":
- # we need to somehow store and pass this information :)
- W.block_size = getattr(base_layer, "block_size", [128, 128])
- W_quant.block_size = W.block_size
-
# if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
if getattr(proj, "disable_adapters", True) or proj.merged:
- return W, W_quant, None, None, None
+ return W, getattr(W, "quant_state", None), None, None, None
+ pass
adapter = getattr(proj, "active_adapters", None)
- if adapter is None:
- adapter = getattr(proj, "active_adapter", ("default"))
+ if adapter is None: adapter = getattr(proj, "active_adapter", ("default"))
adapter = adapter[0]
-
- # Optionally apply fake quantization to lora weights for QAT
- lora_A_linear = proj.lora_A[adapter]
- lora_B_linear = proj.lora_B[adapter]
- A = lora_A_linear.weight
- B = lora_B_linear.weight
- if hasattr(lora_A_linear, "weight_fake_quantizer"):
- lora_A_fake_quantizer = getattr(lora_A_linear, "weight_fake_quantizer", None)
- if lora_A_fake_quantizer is not None:
- A = lora_A_fake_quantizer(A)
- if hasattr(lora_B_linear, "weight_fake_quantizer"):
- lora_B_fake_quantizer = getattr(lora_B_linear, "weight_fake_quantizer", None)
- if lora_B_fake_quantizer is not None:
- B = lora_B_fake_quantizer(B)
-
+
return (
W,
- W_quant,
- A,
- B,
+ getattr(W, "quant_state", None),
+ proj.lora_A [adapter].weight,
+ proj.lora_B [adapter].weight,
proj.scaling[adapter],
)
+pass
def get_lora_parameters_bias(proj):
# For DPO or disabled adapters
- base_layer = getattr(
- proj, "base_layer", proj
- ) # (proj.base_layer if hasattr(proj, "base_layer") else proj)
+ base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj)
W = base_layer.weight
- # Get quant state for 4bit or FP8
- W_quant = getattr(W, "quant_state", None)
- if W_quant is None:
- W_quant = getattr(base_layer, "weight_scale_inv", None)
- if W_quant is None:
- W_quant = getattr(base_layer, "weight_scale", None)
-
# if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
if getattr(proj, "disable_adapters", True) or proj.merged:
- return W, W_quant, None, None, None, base_layer.bias
-
- if getattr(base_layer, "quant_method", None) == "fp8":
- # we need to somehow store and pass this information :)
- W.block_size = getattr(base_layer, "block_size", [128, 128])
- W_quant.block_size = W.block_size
+ return W, getattr(W, "quant_state", None), None, None, None, base_layer.bias
+ pass
adapter = getattr(proj, "active_adapters", None)
- if adapter is None:
- adapter = getattr(proj, "active_adapter", ("default"))
+ if adapter is None: adapter = getattr(proj, "active_adapter", ("default"))
adapter = adapter[0]
return (
W,
- W_quant,
- proj.lora_A[adapter].weight,
- proj.lora_B[adapter].weight,
+ getattr(W, "quant_state", None),
+ proj.lora_A [adapter].weight,
+ proj.lora_B [adapter].weight,
proj.scaling[adapter],
base_layer.bias,
)
+pass
-
-def _maybe_fake_quantize_activations(
- X: torch.Tensor, proj: torch.nn.Module
-) -> torch.Tensor:
- """
- If QAT is enabled, fake quantize the input activations.
- Otherwise, just return the input activations as is.
- Weights are fake quantized separately in `get_lora_parameters`.
- """
- base_layer = getattr(proj, "base_layer", proj)
- activation_fake_quantizer = getattr(base_layer, "activation_fake_quantizer", None)
- if activation_fake_quantizer is not None:
- X = activation_fake_quantizer(X)
- return X
-
-
-# INTEL GPU Specific Logic
-if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
-
- @torch.inference_mode
- def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
- # TODO: After adding XPU BNB support, check this function
- if isinstance(W, Float8Tensor):
- return W.dequantize()
- if quant_state is None:
- return W
- if W.dtype == torch.float8_e4m3fn:
- return weight_dequant(W, quant_state)
- if type(quant_state) is not list:
- # New quant_state as a class
- # https://github.com/TimDettmers/bitsandbytes/pull/763/files
- absmax = quant_state.absmax
- shape = quant_state.shape
- dtype = quant_state.dtype
- blocksize = quant_state.blocksize
- offset = quant_state.offset
- state2 = quant_state.state2
- absmax2 = state2.absmax
- code2 = state2.code
- blocksize2 = state2.blocksize
- else:
- # Old quant_state as a list of lists
- absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
- offset, state2 = compressed_stats
- absmax2, code2, blocksize2, _, _, _, _ = state2
- global XPU_STREAMS
- device = W.device
- device_index = device.index
- XPU_STREAM = XPU_STREAMS[device_index]
-
- n_elements_absmax = absmax.numel()
- # Create weight matrix
- if use_global_buffer:
- # Use same buffers for faster inference
- size = shape[0] * shape[1]
- global WEIGHT_BUFFERS
- global ABSMAX_BUFFERS
- WEIGHT_BUFFER = WEIGHT_BUFFERS[device_index]
- ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index]
- if WEIGHT_BUFFER is None or WEIGHT_BUFFER.dtype != dtype:
- WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(
- size, dtype = dtype, device = device, requires_grad = False
- )
- ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty(
- n_elements_absmax,
- dtype = torch.float32,
- device = device,
- requires_grad = False,
- )
-
- if size > WEIGHT_BUFFER.numel():
- WEIGHT_BUFFER.resize_(size)
- if n_elements_absmax > ABSMAX_BUFFER.numel():
- ABSMAX_BUFFER.resize_(n_elements_absmax)
-
- out = WEIGHT_BUFFER[:size].view(shape)
- out_absmax = ABSMAX_BUFFER[:n_elements_absmax]
- else:
- if out is None:
- out = torch_empty(
- shape, dtype = dtype, device = device, requires_grad = False
- )
- else:
- assert out.shape == shape
- assert out.dtype == dtype
- out_absmax = torch_empty(
- n_elements_absmax,
- dtype = torch_float32,
- device = device,
- requires_grad = False,
- )
-
- # NF4 dequantization of statistics
- ptr_out_absmax = get_ptr(out_absmax)
- with torch_gpu_device(device):
- cdequantize_blockwise_fp32(
- get_ptr(code2),
- get_ptr(absmax),
- get_ptr(absmax2),
- ptr_out_absmax,
- ctypes_c_int(blocksize2),
- ctypes_c_int(n_elements_absmax),
- XPU_STREAM,
- )
- out_absmax += offset
-
- # Dequantize W
- fx = (
- cdequantize_blockwise_fp16_nf4
- if dtype == torch_float16
- else cdequantize_blockwise_bf16_nf4
- )
- fx(
- get_ptr(None),
- get_ptr(W),
- ptr_out_absmax,
- get_ptr(out),
- ctypes_c_int(blocksize),
- ctypes_c_int(out.numel()),
- XPU_STREAM,
- )
- # Careful returning transposed data
- is_transposed = True if W.shape[0] == 1 else False
- return out.t() if is_transposed else out
-
-# NVIDIA GPU Default Logic
-elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
-
+if HAS_CUDA_STREAM:
@torch.inference_mode
def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
- if isinstance(W, Float8Tensor):
- return W.dequantize()
- if quant_state is None:
- return W
- if W.dtype == torch.float8_e4m3fn:
- return weight_dequant(W, quant_state)
+ if quant_state is None: return W
if type(quant_state) is not list:
# New quant_state as a class
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
- absmax = quant_state.absmax
- shape = quant_state.shape
- dtype = quant_state.dtype
- blocksize = quant_state.blocksize
- offset = quant_state.offset
- state2 = quant_state.state2
- absmax2 = state2.absmax
- code2 = state2.code
+ absmax = quant_state.absmax
+ shape = quant_state.shape
+ dtype = quant_state.dtype
+ blocksize = quant_state.blocksize
+ offset = quant_state.offset
+ state2 = quant_state.state2
+ absmax2 = state2.absmax
+ code2 = state2.code
blocksize2 = state2.blocksize
else:
# Old quant_state as a list of lists
@@ -504,102 +191,65 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False
# Create weight matrix
if use_global_buffer:
+
# Use same buffers for faster inference
- size = shape[0] * shape[1]
+ size = shape[0]*shape[1]
global WEIGHT_BUFFERS
global ABSMAX_BUFFERS
WEIGHT_BUFFER = WEIGHT_BUFFERS[device_index]
ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index]
- if WEIGHT_BUFFER is None or WEIGHT_BUFFER.dtype != dtype:
- WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(
- size, dtype = dtype, device = device, requires_grad = False
- )
- ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty(
- n_elements_absmax,
- dtype = torch_float32,
- device = device,
- requires_grad = False,
- )
+ if WEIGHT_BUFFER is None:
+ WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(size, dtype = dtype, device = device, requires_grad = False)
+ ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False)
- if size > WEIGHT_BUFFER.numel():
- WEIGHT_BUFFER.resize_(size)
- if n_elements_absmax > ABSMAX_BUFFER.numel():
- ABSMAX_BUFFER.resize_(n_elements_absmax)
+ if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size)
+ if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax)
out = WEIGHT_BUFFER[:size].view(shape)
out_absmax = ABSMAX_BUFFER[:n_elements_absmax]
else:
if out is None:
- out = torch_empty(
- shape, dtype = dtype, device = device, requires_grad = False
- )
+ out = torch_empty(shape, dtype = dtype, device = device, requires_grad = False)
else:
- assert out.shape == shape
- assert out.dtype == dtype
- out_absmax = torch_empty(
- n_elements_absmax,
- dtype = torch_float32,
- device = device,
- requires_grad = False,
- )
+ assert(out.shape == shape)
+ assert(out.dtype == dtype)
+ out_absmax = torch_empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False)
pass
# NF4 dequantization of statistics
ptr_out_absmax = get_ptr(out_absmax)
- with torch_gpu_device(device):
+ with torch_cuda_device(device):
cdequantize_blockwise_fp32(
- get_ptr(code2),
- get_ptr(absmax),
- get_ptr(absmax2),
- ptr_out_absmax,
- ctypes_c_int(blocksize2),
- ctypes_c_int(n_elements_absmax),
- CUDA_STREAM,
+ get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
+ ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax), CUDA_STREAM
)
out_absmax += offset
# Dequantize W
- fx = (
- cdequantize_blockwise_fp16_nf4
- if dtype == torch_float16
- else cdequantize_blockwise_bf16_nf4
- )
- fx(
- get_ptr(None),
- get_ptr(W),
- ptr_out_absmax,
- get_ptr(out),
- ctypes_c_int(blocksize),
- ctypes_c_int(out.numel()),
- CUDA_STREAM,
- )
+ fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
+ cdequantize_blockwise_bf16_nf4
+ fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
+ ctypes_c_int(blocksize), ctypes_c_int(out.numel()), CUDA_STREAM,)
pass
# Careful returning transposed data
- is_transposed = True if W.shape[0] == 1 else False
+ is_transposed = (True if W.shape[0] == 1 else False)
return out.t() if is_transposed else out
-
pass
else:
-
@torch.inference_mode
def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
- if isinstance(W, Float8Tensor):
- return W.dequantize()
- if quant_state is None:
- return W
- if W.dtype == torch.float8_e4m3fn:
- return weight_dequant(W, quant_state)
+ if quant_state is None: return W
if type(quant_state) is not list:
# New quant_state as a class
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
- absmax = quant_state.absmax
- shape = quant_state.shape
- dtype = quant_state.dtype
- blocksize = quant_state.blocksize
- offset = quant_state.offset
- state2 = quant_state.state2
- absmax2 = state2.absmax
- code2 = state2.code
+ absmax = quant_state.absmax
+ shape = quant_state.shape
+ dtype = quant_state.dtype
+ blocksize = quant_state.blocksize
+ offset = quant_state.offset
+ state2 = quant_state.state2
+ absmax2 = state2.absmax
+ code2 = state2.code
blocksize2 = state2.blocksize
else:
# Old quant_state as a list of lists
@@ -615,157 +265,33 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False
if out is None:
out = torch_empty(shape, dtype = dtype, device = device, requires_grad = False)
else:
- assert out.shape == shape
- assert out.dtype == dtype
- out_absmax = torch_empty(
- n_elements_absmax, dtype = torch_float32, device = device, requires_grad = False
- )
+ assert(out.shape == shape)
+ assert(out.dtype == dtype)
+ out_absmax = torch_empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False)
# Do dequantization
ptr_out_absmax = get_ptr(out_absmax)
cdequantize_blockwise_fp32(
- get_ptr(code2),
- get_ptr(absmax),
- get_ptr(absmax2),
- ptr_out_absmax,
- ctypes_c_int(blocksize2),
- ctypes_c_int(n_elements_absmax),
+ get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
+ ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax),
)
out_absmax += offset
- fx = (
- cdequantize_blockwise_fp16_nf4
- if dtype == torch_float16
- else cdequantize_blockwise_bf16_nf4
- )
- fx(
- get_ptr(None),
- get_ptr(W),
- ptr_out_absmax,
- get_ptr(out),
- ctypes_c_int(blocksize),
- ctypes_c_int(out.numel()),
- )
+ fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
+ cdequantize_blockwise_bf16_nf4
+ fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
+ ctypes_c_int(blocksize), ctypes_c_int(out.numel()),)
# Careful returning transposed data
- is_transposed = True if W.shape[0] == 1 else False
+ is_transposed = (True if W.shape[0] == 1 else False)
return out.t() if is_transposed else out
-
pass
+pass
-# INTEL GPU Specific Logic
-if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM:
-
- def fast_gemv(X, W, quant_state, out = None):
- if quant_state is None:
- return torch_matmul(X, W, out = out)
- # For fast X @ W where seq_len == 1
- # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
- _, q_len, hd = X.shape
- # assert(q_len == 1)
-
- if type(quant_state) is not list:
- # https://github.com/TimDettmers/bitsandbytes/pull/763/files
- absmax = quant_state.absmax
- shape = quant_state.shape
- dtype = quant_state.dtype
- blocksize = quant_state.blocksize
- stats = quant_state.code
- offset = quant_state.offset
- state2 = quant_state.state2
- absmax2 = state2.absmax
- code2 = state2.code
- blocksize2 = state2.blocksize
- else:
- absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = (
- quant_state
- )
- offset, state2 = compressed_stats
- absmax2, code2, blocksize2, _, _, _, _ = state2
- global XPU_STREAMS
- device = W.device
- device_index = device.index
- XPU_STREAM = XPU_STREAMS[device_index]
-
- # assert(dtype == X.dtype)
- bout = shape[0]
-
- if out is None:
- out = torch_empty(
- (
- 1,
- 1,
- bout,
- ),
- dtype = dtype,
- device = device,
- )
- # else:
- # assert(out.shape == (1, 1, bout,))
- # pass
-
- if DEVICE_TYPE == "xpu":
- m = 1
- n = shape[0]
- else:
- n = 1
- m = shape[0]
- k = shape[1]
- lda = shape[0]
- ldc = shape[0]
- ldb = (hd + 1) // 2
- m = ctypes_c_int32(m)
- n = ctypes_c_int32(n)
- k = ctypes_c_int32(k)
- lda = ctypes_c_int32(lda)
- ldb = ctypes_c_int32(ldb)
- ldc = ctypes_c_int32(ldc)
-
- df = torch_empty(absmax.shape, dtype = torch_float32, device = device)
- with torch_gpu_device(device):
- cdequantize_blockwise_fp32(
- get_ptr(code2),
- get_ptr(absmax),
- get_ptr(absmax2),
- get_ptr(df),
- ctypes_c_int(blocksize2),
- ctypes_c_int(df.numel()),
- XPU_STREAM,
- )
- df += offset
- absmax = df
-
- fx = (
- cgemm_4bit_inference_naive_fp16
- if dtype == torch_float16
- else cgemm_4bit_inference_naive_bf16
- )
-
- blocksize = ctypes_c_int32(blocksize)
- fx(
- m,
- n,
- k,
- get_ptr(X),
- get_ptr(W),
- get_ptr(absmax),
- get_ptr(stats),
- get_ptr(out),
- lda,
- ldb,
- ldc,
- blocksize,
- XPU_STREAM,
- )
-
- return out
-
-elif DEVICE_TYPE in ("cuda", "hip") and HAS_CUDA_STREAM:
-
+if HAS_CUDA_STREAM:
def fast_gemv(X, W, quant_state, out = None):
- if quant_state is None:
- return torch_matmul(X, W, out = out)
+ if quant_state is None: return torch_matmul(X, W, out = out)
# For fast X @ W where seq_len == 1
# From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
_, q_len, hd = X.shape
@@ -773,20 +299,18 @@ def fast_gemv(X, W, quant_state, out = None):
if type(quant_state) is not list:
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
- absmax = quant_state.absmax
- shape = quant_state.shape
- dtype = quant_state.dtype
- blocksize = quant_state.blocksize
- stats = quant_state.code
- offset = quant_state.offset
- state2 = quant_state.state2
- absmax2 = state2.absmax
- code2 = state2.code
+ absmax = quant_state.absmax
+ shape = quant_state.shape
+ dtype = quant_state.dtype
+ blocksize = quant_state.blocksize
+ stats = quant_state.code
+ offset = quant_state.offset
+ state2 = quant_state.state2
+ absmax2 = state2.absmax
+ code2 = state2.code
blocksize2 = state2.blocksize
else:
- absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = (
- quant_state
- )
+ absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
offset, state2 = compressed_stats
absmax2, code2, blocksize2, _, _, _, _ = state2
pass
@@ -794,20 +318,12 @@ def fast_gemv(X, W, quant_state, out = None):
device = W.device
device_index = device.index
CUDA_STREAM = CUDA_STREAMS[device_index]
-
+
# assert(dtype == X.dtype)
bout = shape[0]
if out is None:
- out = torch_empty(
- (
- 1,
- 1,
- bout,
- ),
- dtype = dtype,
- device = device,
- )
+ out = torch_empty((1, 1, bout,), dtype = dtype, device = device)
# else:
# assert(out.shape == (1, 1, bout,))
# pass
@@ -817,7 +333,7 @@ def fast_gemv(X, W, quant_state, out = None):
k = shape[1]
lda = shape[0]
ldc = shape[0]
- ldb = (hd + 1) // 2
+ ldb = (hd+1)//2
m = ctypes_c_int32(m)
n = ctypes_c_int32(n)
k = ctypes_c_int32(k)
@@ -825,52 +341,28 @@ def fast_gemv(X, W, quant_state, out = None):
ldb = ctypes_c_int32(ldb)
ldc = ctypes_c_int32(ldc)
- df = torch_empty(absmax.shape, dtype = torch_float32, device = device)
- with torch_gpu_device(device):
+ df = torch_empty(absmax.shape, dtype = torch.float32, device = device)
+ with torch_cuda_device(device):
cdequantize_blockwise_fp32(
- get_ptr(code2),
- get_ptr(absmax),
- get_ptr(absmax2),
- get_ptr(df),
- ctypes_c_int(blocksize2),
- ctypes_c_int(df.numel()),
- CUDA_STREAM,
+ get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
+ ctypes_c_int(blocksize2), ctypes_c_int(df.numel()), CUDA_STREAM,
)
df += offset
absmax = df
- fx = (
- cgemm_4bit_inference_naive_fp16
- if dtype == torch_float16
- else cgemm_4bit_inference_naive_bf16
- )
+ fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
+ cgemm_4bit_inference_naive_bf16
blocksize = ctypes_c_int32(blocksize)
- fx(
- m,
- n,
- k,
- get_ptr(X),
- get_ptr(W),
- get_ptr(absmax),
- get_ptr(stats),
- get_ptr(out),
- lda,
- ldb,
- ldc,
- blocksize,
- CUDA_STREAM,
- )
+ fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
+ lda, ldb, ldc, blocksize, CUDA_STREAM,)
pass
return out
-
pass
else:
-
def fast_gemv(X, W, quant_state, out = None):
- if quant_state is None:
- return torch_matmul(X, W, out = out)
+ if quant_state is None: return torch.matmul(X, W, out = out)
# For fast X @ W where seq_len == 1
# From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
_, q_len, hd = X.shape
@@ -878,20 +370,18 @@ def fast_gemv(X, W, quant_state, out = None):
if type(quant_state) is not list:
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
- absmax = quant_state.absmax
- shape = quant_state.shape
- dtype = quant_state.dtype
- blocksize = quant_state.blocksize
- stats = quant_state.code
- offset = quant_state.offset
- state2 = quant_state.state2
- absmax2 = state2.absmax
- code2 = state2.code
+ absmax = quant_state.absmax
+ shape = quant_state.shape
+ dtype = quant_state.dtype
+ blocksize = quant_state.blocksize
+ stats = quant_state.code
+ offset = quant_state.offset
+ state2 = quant_state.state2
+ absmax2 = state2.absmax
+ code2 = state2.code
blocksize2 = state2.blocksize
else:
- absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = (
- quant_state
- )
+ absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
offset, state2 = compressed_stats
absmax2, code2, blocksize2, _, _, _, _ = state2
pass
@@ -900,15 +390,7 @@ def fast_gemv(X, W, quant_state, out = None):
device = W.device
if out is None:
- out = torch_empty(
- (
- 1,
- 1,
- bout,
- ),
- dtype = dtype,
- device = device,
- )
+ out = torch_empty((1, 1, bout,), dtype = dtype, device = device)
# else:
# assert(out.shape == (1, 1, bout,))
# pass
@@ -918,7 +400,7 @@ def fast_gemv(X, W, quant_state, out = None):
k = shape[1]
lda = shape[0]
ldc = shape[0]
- ldb = (hd + 1) // 2
+ ldb = (hd+1)//2
m = ctypes_c_int32(m)
n = ctypes_c_int32(n)
k = ctypes_c_int32(k)
@@ -926,60 +408,40 @@ def fast_gemv(X, W, quant_state, out = None):
ldb = ctypes_c_int32(ldb)
ldc = ctypes_c_int32(ldc)
- df = torch_empty(absmax.shape, dtype = torch_float32, device = device)
+ df = torch_empty(absmax.shape, dtype = torch.float32, device = device)
cdequantize_blockwise_fp32(
- get_ptr(code2),
- get_ptr(absmax),
- get_ptr(absmax2),
- get_ptr(df),
- ctypes_c_int(blocksize2),
- ctypes_c_int(df.numel()),
+ get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
+ ctypes_c_int(blocksize2), ctypes_c_int(df.numel()),
)
df += offset
absmax = df
- fx = (
- cgemm_4bit_inference_naive_fp16
- if dtype == torch_float16
- else cgemm_4bit_inference_naive_bf16
- )
+ fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
+ cgemm_4bit_inference_naive_bf16
blocksize = ctypes_c_int32(blocksize)
- fx(
- m,
- n,
- k,
- get_ptr(X),
- get_ptr(W),
- get_ptr(absmax),
- get_ptr(stats),
- get_ptr(out),
- lda,
- ldb,
- ldc,
- blocksize,
- )
+ fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
+ lda, ldb, ldc, blocksize,)
return out
-
pass
+pass
def fast_linear_forward(proj, X, temp_lora = None, out = None):
+
W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)
bsz, q_len, in_dim = X.shape
- if q_len != 1:
- return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)
+ if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)
if W_quant is None:
out = torch_matmul(X, W.t(), out = out)
- elif W.dtype == torch.float8_e4m3fn:
- out = fp8_linear(X, W, W_quant, bias)
elif bsz == 1 and q_len == 1:
out = fast_gemv(X, W, W_quant, out = out)
else:
W = fast_dequantize(W.t(), W_quant, use_global_buffer = True)
out = torch_matmul(X, W, out = out)
+ pass
# Add in LoRA weights
if lora_A is not None:
@@ -989,27 +451,29 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None):
if not hasattr(lora_A, "_fast_lora"):
lora_A._fast_lora = lora_A.to(dtype)
lora_B._fast_lora = lora_B.to(dtype)
-
+ pass
+
if bsz == 1:
out = out.view(out_dim)
temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)
else:
out = out.view(bsz, out_dim)
- temp_lora = torch_mm(
- X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora
- )
+ temp_lora = torch_mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora)
out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S)
+ pass
out = out.view(bsz, 1, out_dim)
+ pass
- if bias is not None:
- out += bias
+ if bias is not None: out += bias
return out
+pass
def matmul_lora(X, W, W_quant, A, B, s, out = None):
dtype = X.dtype
+ W = fast_dequantize(W.t(), W_quant, use_global_buffer = True)
if X.dim() == 3:
batch, seq_len, d = X.shape
@@ -1017,24 +481,9 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None):
reshape = True
else:
reshape = False
-
- if isinstance(W, Float8Tensor):
- assert W.ndim == 2
- if W.block_size[0] == W.shape[0] and W.block_size[1] == 1:
- # In the backward pass, rowwise scaled becomes colwise scaled after we
- # transpose the weight tensor. Use this case to detect backward.
- # TODO: would be simpler if we simply don't call `matmul_lora` in backward
- W = W.dequantize()
- else:
- W = W.contiguous()
- out = torch_matmul(X, W.t(), out = out)
- elif W.dtype == torch.float8_e4m3fn:
- out = fp8_linear(X, W, W_quant)
- else:
- W = fast_dequantize(W, W_quant, use_global_buffer = True)
- out = torch_matmul(X, W.t(), out = out)
- if W_quant is not None:
- del W
+ pass
+ out = torch_matmul(X, W, out = out)
+ if W_quant is not None: del W
if A is not None:
# LoRA is enabled
@@ -1042,5 +491,7 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None):
XA = torch_matmul(X, A.to(dtype))
out.addmm_(XA, B.to(dtype), alpha = s)
# out += (X @ A.to(dtype)) @ (s * B.to(dtype))
-
+ pass
+
return out.view(batch, seq_len, -1) if reshape else out
+pass
diff --git a/unsloth/models/__init__.py b/unsloth/models/__init__.py
index 138f309032..317525c793 100644
--- a/unsloth/models/__init__.py
+++ b/unsloth/models/__init__.py
@@ -12,20 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .llama import FastLlamaModel
-from .loader import FastLanguageModel, FastVisionModel, FastTextModel, FastModel
+from .llama import FastLlamaModel
+from .loader import FastLanguageModel, FastVisionModel, FastTextModel, FastModel
from .mistral import FastMistralModel
-from .qwen2 import FastQwen2Model
-from .qwen3 import FastQwen3Model
-from .qwen3_moe import FastQwen3MoeModel
+from .qwen2 import FastQwen2Model
from .granite import FastGraniteModel
-from .sentence_transformer import FastSentenceTransformer
-
-try:
- from .falcon_h1 import FastFalconH1Model
-except:
- # transformers_version < 4.53.0 does not have falcon_h1 so silently skip it for now
- pass
-from .dpo import PatchDPOTrainer, PatchKTOTrainer
-from ._utils import is_bfloat16_supported, is_vLLM_available, __version__
-from .rl import PatchFastRL, vLLMSamplingParams
+from .dpo import PatchDPOTrainer, PatchKTOTrainer
+from ._utils import is_bfloat16_supported, __version__
+from .rl import PatchFastRL, vLLMSamplingParams
diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py
index 19b5fe0574..69cc1e6884 100644
--- a/unsloth/models/_utils.py
+++ b/unsloth/models/_utils.py
@@ -12,23 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = "2026.3.4"
+__version__ = "2025.3.14"
__all__ = [
"SUPPORTS_BFLOAT16",
"is_bfloat16_supported",
- "is_vLLM_available",
+
"prepare_model_for_kbit_training",
"xformers",
"xformers_attention",
"xformers_version",
"__version__",
- "importlib_version",
"HAS_FLASH_ATTENTION",
"HAS_FLASH_ATTENTION_SOFTCAPPING",
"USE_MODELSCOPE",
"platform_system",
- "resolve_hip_gpu_stats_name",
"patch_tokenizer",
"get_statistics",
"Unsloth_Offloaded_Gradient_Checkpointer",
@@ -50,71 +48,37 @@
"patch_layernorm",
"patch_torch_compile",
"patch_model_and_tokenizer",
+
"patch_unsloth_gradient_checkpointing",
"unpatch_unsloth_gradient_checkpointing",
"patch_gradient_checkpointing",
"unpatch_gradient_checkpointing",
+
"HAS_CUT_CROSS_ENTROPY",
"EMPTY_LOGITS",
"fused_linear_cross_entropy",
- "unsloth_fused_ce_loss",
"patch_unsloth_smart_gradient_checkpointing",
"unpatch_unsloth_smart_gradient_checkpointing",
- "apply_unsloth_gradient_checkpointing",
+
"patch_compiled_autograd",
"process_vision_info",
"unsloth_compile_transformers",
- "prefer_flex_attn_if_supported",
"patch_fast_lora",
- "validate_loftq_config",
- "RaiseUninitialized",
- "fast_inference_setup",
- "patch_peft_fast_inference",
- "error_out_no_vllm",
- "dequantize_module_weight",
- "patch_hf_quantizer",
- "verify_fp8_support_if_applicable",
- "_get_inference_mode_context_manager",
- "hf_login",
- "is_moe_model",
- "get_moe_target_parameters",
- "make_fast_generate_wrapper",
]
import torch
-from typing import Union, Optional, List, Any, Callable, Tuple, Iterator
+from typing import Union, Optional, List, Any, Callable, Tuple
from platform import system as platform_system
-
platform_system = platform_system()
import numpy as np
import contextlib
import re
-from dataclasses import dataclass, field
-import functools
-import textwrap
-import logging
-import warnings, subprocess, inspect, psutil, os, math
-from unsloth_zoo.utils import Version, get_quant_type
-from importlib.metadata import version as importlib_version
-from ..device_type import (
- is_hip,
- get_device_type,
- DEVICE_TYPE,
- DEVICE_TYPE_TORCH,
- DEVICE_COUNT,
- ALLOW_PREQUANTIZED_MODELS,
-)
-from ..import_fixes import UNSLOTH_ENABLE_LOGGING
-from unsloth_zoo.log import logger
+import warnings, subprocess, re, inspect, psutil, os, math
+from unsloth_zoo.utils import Version
+
from unsloth_zoo.tokenizer_utils import (
patch_tokenizer as _patch_tokenizer,
)
-from unsloth_zoo.rl_environments import (
- check_python_modules,
- create_locked_down_function,
- execute_with_time_limit,
- Benchmarker,
-)
from unsloth_zoo.patching_utils import (
patch_compiling_bitsandbytes,
patch_layernorm,
@@ -127,10 +91,12 @@
unsloth_offloaded_gradient_checkpoint,
patch_unsloth_gradient_checkpointing,
unpatch_unsloth_gradient_checkpointing,
+
Unsloth_Gradient_Checkpointer,
unsloth_gradient_checkpoint,
patch_gradient_checkpointing,
unpatch_gradient_checkpointing,
+
patch_unsloth_smart_gradient_checkpointing,
unpatch_unsloth_smart_gradient_checkpointing,
)
@@ -138,7 +104,6 @@
HAS_CUT_CROSS_ENTROPY,
fused_linear_cross_entropy,
_unsloth_get_batch_samples,
- unsloth_fused_ce_loss,
)
from unsloth_zoo.vision_utils import (
process_vision_info,
@@ -150,301 +115,52 @@
from unsloth_zoo.training_utils import (
prepare_model_for_training,
)
-
-
-def resolve_hip_gpu_stats_name(gpu_stats):
- name = str(getattr(gpu_stats, "name", "") or "").strip()
- name = re.sub(r"\s*\([^)]*\)\s*$", "", name).strip()
- normalized_name = name.lower().strip(". ")
- if normalized_name and normalized_name not in ("amd radeon graphics",):
- return name + ". "
-
- try:
- torch_name = str(torch.cuda.get_device_name(0) or "").strip()
- torch_name = re.sub(r"\s*\([^)]*\)\s*$", "", torch_name).strip()
- except Exception:
- torch_name = ""
- normalized_torch_name = torch_name.lower().strip(". ")
- if normalized_torch_name and normalized_torch_name not in ("amd radeon graphics",):
- return torch_name + ". "
-
- arch_name = ""
- for key in ("gcnArchName", "gcn_arch_name", "arch_name", "gfx_arch_name"):
- value = getattr(gpu_stats, key, None)
- if value is not None and str(value).strip():
- arch_name = str(value).strip()
- break
-
- if arch_name:
- arch_name = arch_name.strip()
- match = re.search(r"(gfx[0-9a-z]+)", arch_name, flags = re.I)
- if match:
- return f"AMD {match.group(1).lower()} GPU. "
- return "AMD GPU. "
-
-
from unsloth_zoo.temporary_patches import (
TEMPORARY_PATCHES,
)
-
-
-def apply_unsloth_gradient_checkpointing(
- use_gradient_checkpointing, max_seq_length, dtype
-):
- """
- Apply gradient checkpointing with smart heuristics.
-
- For seq < 512, the overhead of gradient offloading in gc="unsloth" mode
- is not worth it. Benchmarks show standard gc is faster for small sequences.
-
- Args:
- use_gradient_checkpointing: "unsloth", True, False, or None
- max_seq_length: The maximum sequence length
- dtype: The model dtype for patching
-
- Returns:
- The effective use_gradient_checkpointing value (may change from "unsloth" to True)
- """
- if use_gradient_checkpointing == "unsloth":
- # Gradient offloading overhead is not worth it for small sequences.
- # Benchmarks show crossover point is around seq_len 384-512.
- # For seq < 512, standard gradient checkpointing is faster.
- if max_seq_length < 512:
- unpatch_unsloth_smart_gradient_checkpointing()
- return True
- else:
- patch_unsloth_smart_gradient_checkpointing(dtype = dtype)
- return "unsloth"
- elif use_gradient_checkpointing in (True, False):
- # User explicitly set True or False - unpatch any previous "unsloth" patching
- unpatch_unsloth_smart_gradient_checkpointing()
- return use_gradient_checkpointing
- return use_gradient_checkpointing
-
-
-def prefer_flex_attn_if_supported(model_class, config):
- if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") == "0":
- return None
- try:
- from transformers.utils.import_utils import is_torch_flex_attn_available
-
- if not is_torch_flex_attn_available():
- return None
- if model_class is None or not getattr(
- model_class, "_supports_flex_attn", False
- ):
- return None
- # GPT-OSS, Mllama and Gemma3N use eager/sdpa attention during
- # inference since flex attention returns incorrect results or errors out.
- # GPT-OSS: left padding issues cause incorrect outputs.
- # Mllama: _update_causal_mask uses make_flex_block_causal_mask which
- # creates BlockMask with Q_LEN=KV_LEN=total_seq_len, but during
- # decode q_len=1, causing ValueError. Needs transformers update.
- # Gemma3N: timm vision wrappers (eg Gemma3nVisionConfig) do not
- # support flex_attention.
- model_type = getattr(config, "model_type", "") if config else ""
- if model_type in ("gpt_oss", "mllama") or str(model_type).startswith("gemma3n"):
- return None
- if config is not None:
- setattr(config, "_attn_implementation", "flex_attention")
- if hasattr(config, "attn_implementation"):
- setattr(config, "attn_implementation", "flex_attention")
- return "flex_attention"
- except Exception:
- return None
-
-
-def _run_temporary_patches(phase):
- import inspect
-
- for temporary_patch in TEMPORARY_PATCHES:
- try:
- sig = inspect.signature(temporary_patch)
- if "phase" in sig.parameters:
- temporary_patch(phase = phase)
- else:
- temporary_patch()
- except (ValueError, TypeError):
- temporary_patch()
-
-
-_run_temporary_patches("init")
+for temporary_patch in TEMPORARY_PATCHES:
+ temporary_patch()
# =============================================
# Disable some warnings which can get annoying
-warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
-warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "torch")
-warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub")
-warnings.filterwarnings(
- action = "ignore", category = FutureWarning, module = "huggingface_hub"
-)
-warnings.filterwarnings(action = "ignore", category = UserWarning, module = "trl")
-warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "trl")
-warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "xformers")
+warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
+warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub")
+warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "huggingface_hub")
+warnings.filterwarnings(action = "ignore", category = UserWarning, module = "trl")
+warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "trl")
+warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "xformers")
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "subprocess")
-warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers")
-warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "accelerate")
-warnings.filterwarnings(
- action = "ignore", category = RuntimeWarning, module = "multiprocessing"
-)
+warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers")
+warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "accelerate")
+warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocessing")
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocess")
-warnings.filterwarnings(action = "ignore", category = UserWarning, module = "triton")
-warnings.filterwarnings(action = "ignore", category = UserWarning, module = "bitsandbytes")
# Stop "Special tokens have been added in the vocabulary, ..."
-logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL + 1)
-
-TORCHAO_MSG = "Error: torchao not found, please install with `pip install torchao`"
-
+import logging
+logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL+1)
# Ignore logging messages
class HideLoggingMessage(logging.Filter):
- __slots__ = ("text",)
-
- def __init__(self, text):
- self.text = text
-
- def filter(self, x):
- return not (self.text in x.getMessage())
-
-
-# Replace warning messages (analogous to HideLoggingMessage but for warnings.warn)
-class ReplaceWarningMessage:
- """
- Intercepts warnings.warn calls and replaces matching messages with Unsloth branded ones.
- Uses a list of registered (match_text, replacement, category) rules checked in order.
- """
-
- _rules = []
- _original_showwarning = None
- _installed = False
-
- @classmethod
- def add_rule(cls, match_text, replacement, category = None):
- cls._rules.append((match_text, replacement, category))
- if not cls._installed:
- cls._install()
-
- @classmethod
- def _install(cls):
- cls._original_showwarning = warnings.showwarning
- cls._installed = True
-
- def _patched_showwarning(
- message, category, filename, lineno, file = None, line = None
- ):
- msg_str = str(message)
- for match_text, replacement, match_category in cls._rules:
- if match_text in msg_str and (
- match_category is None or category is match_category
- ):
- print(replacement)
- return
- cls._original_showwarning(message, category, filename, lineno, file, line)
-
- warnings.showwarning = _patched_showwarning
-
-
-# Stop vLLM messages
-if not UNSLOTH_ENABLE_LOGGING:
- try:
- from vllm.worker.worker import logger as vllm_worker_logger
-
- vllm_worker_logger.addFilter(HideLoggingMessage("Sleep mode freed"))
- del vllm_worker_logger
- except:
- pass
- try:
- from vllm.v1.worker.gpu_worker import logger as vllm_gpu_worker_logger
-
- vllm_gpu_worker_logger.addFilter(HideLoggingMessage("Sleep mode freed"))
- del vllm_gpu_worker_logger
- except:
- pass
- try:
- from vllm.executor.executor_base import logger as vllm_executor_logger
-
- vllm_executor_logger.addFilter(HideLoggingMessage("to fall asleep"))
- vllm_executor_logger.addFilter(HideLoggingMessage("to wake up"))
- vllm_executor_logger.addFilter(HideLoggingMessage("Executor is not sleeping"))
- del vllm_executor_logger
- except:
- pass
- try:
- from vllm.v1.executor.abstract import logger as vllm_v1_executor_logger
-
- vllm_v1_executor_logger.addFilter(HideLoggingMessage("to fall asleep"))
- vllm_v1_executor_logger.addFilter(HideLoggingMessage("to wake up"))
- vllm_v1_executor_logger.addFilter(
- HideLoggingMessage("Executor is not sleeping")
- )
- del vllm_v1_executor_logger
- except:
- pass
- try:
- from vllm.core.block.prefix_caching_block import (
- logger as vllm_prefix_caching_logger,
- )
+ __slots__ = "text",
+ def __init__(self, text): self.text = text
+ def filter(self, x): return not (self.text in x.getMessage())
+pass
- vllm_prefix_caching_logger.addFilter(HideLoggingMessage("reset prefix cache"))
- del vllm_prefix_caching_logger
- except:
- pass
- try:
- from vllm.v1.core.block_pool import logger as vllm_block_pool_logger
-
- vllm_block_pool_logger.addFilter(HideLoggingMessage("reset prefix cache"))
- del vllm_block_pool_logger
- except:
- pass
- try:
- from vllm.lora.models import logger as vllm_lora_model_logger
-
- vllm_lora_model_logger.addFilter(
- HideLoggingMessage(
- "Regarding multimodal models, vLLM currently only supports adding"
- )
- )
- del vllm_lora_model_logger
- except:
- pass
- try:
- from vllm.attention.utils.fa_utils import (
- logger as vllm_attention_utils_fa_utils_logger,
- )
-
- vllm_attention_utils_fa_utils_logger.addFilter(
- HideLoggingMessage("Cannot use FA version")
- )
- del vllm_attention_utils_fa_utils_logger
- except:
- pass
-
-# The speedups for torchdynamo mostly come with GPU Ampere or higher and which is not detected here.
+# The speedups for torchdynamo mostly come wih GPU Ampere or higher and which is not detected here.
from transformers.training_args import logger as transformers_training_args_logger
-
transformers_training_args_logger.addFilter(HideLoggingMessage("The speedups"))
# torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED.
transformers_training_args_logger.addFilter(HideLoggingMessage("torch.distributed"))
-# average_tokens_across_devices is set to True but it is invalid when world size is1
-transformers_training_args_logger.addFilter(
- HideLoggingMessage("average_tokens_across_devices")
-)
del transformers_training_args_logger
# No label_names provided for model class
from transformers.trainer import logger as transformers_trainer_logger
-
transformers_trainer_logger.addFilter(HideLoggingMessage("No label_names"))
-
-# The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config.
-transformers_trainer_logger.addFilter(HideLoggingMessage("The tokenizer has new"))
del transformers_trainer_logger
# Using the default loss: `ForCausalLMLoss`.
try:
from transformers.modeling_utils import logger as transformers_modeling_utils_logger
-
transformers_modeling_utils_logger.addFilter(HideLoggingMessage("ForCausalLMLoss"))
del transformers_modeling_utils_logger
except:
@@ -453,335 +169,88 @@ def _patched_showwarning(
# The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
try:
from accelerate.utils.modeling import logger as accelerate_utils_modeling_logger
-
- accelerate_utils_modeling_logger.addFilter(
- HideLoggingMessage("The model weights are not tied")
- )
+ accelerate_utils_modeling_logger.addFilter(HideLoggingMessage("The model weights are not tied"))
del accelerate_utils_modeling_logger
except:
pass
# Setting `pad_token_id` to `eos_token_id`
try:
- from transformers.generation.utils import (
- logger as transformers_generation_utils_logger,
- )
-
- transformers_generation_utils_logger.addFilter(
- HideLoggingMessage("Setting `pad_token_id` to `eos_token_id`")
- )
- # "You have set `compile_config`
- transformers_generation_utils_logger.addFilter(HideLoggingMessage("compile_config"))
+ from transformers.generation.utils import logger as transformers_generation_utils_logger
+ transformers_generation_utils_logger.addFilter(HideLoggingMessage("Setting `pad_token_id` to `eos_token_id`"))
del transformers_generation_utils_logger
except:
pass
-# The following generation flags are not valid and may be ignored:
-try:
- from transformers.generation.configuration_utils import (
- logger as configuration_logger,
- )
-
- configuration_logger.addFilter(HideLoggingMessage("following generation flags"))
- del configuration_logger
-except:
- pass
-
-# Gemma3 It is strongly recommended to train Gemma3 models with the `eager`
-try:
- from transformers.models.gemma3.modeling_gemma3 import logger as gemma3_logger
-
- gemma3_logger.addFilter(HideLoggingMessage("strongly recommended"))
- del gemma3_logger
-except:
- pass
-
-# Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed.
-try:
- from huggingface_hub.file_download import logger as hub_logger
-
- hub_logger.addFilter(HideLoggingMessage("hf_xet"))
- del hub_logger
-except:
- pass
-
-# MXFP4 quantization requires triton >= 3.4.0
-try:
- from transformers.quantizers.quantizer_mxfp4 import logger as mxfp4_logger
-
- mxfp4_logger.addFilter(HideLoggingMessage("requires triton"))
- del mxfp4_logger
-except:
- pass
-
-# You passed `quantization_config` or equivalent parameters
-try:
- warnings.filterwarnings(
- action = "ignore",
- message = r".*quantization_config.*",
- category = UserWarning,
- append = True,
- )
-except:
- pass
-
-# UserWarning: Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead
-# Will be fixed in torch 2.8.1 https://github.com/pytorch/pytorch/issues/158463
-try:
- warnings.filterwarnings(
- action = "ignore",
- message = r".*Logical operators 'and' and 'or'.*",
- category = UserWarning,
- append = True,
- )
-except:
- pass
-
-# Using a slow image processor as `use_fast`
-try:
- from transformers.processing_utils import logger as processing_utils_logger
-
- processing_utils_logger.addFilter(HideLoggingMessage("`use_fast`"))
- del processing_utils_logger
-except:
- pass
-
-# Using a slow image processor as `use_fast`
-try:
- from transformers.models.auto.image_processing_auto import (
- logger as processing_utils_logger,
- )
-
- processing_utils_logger.addFilter(HideLoggingMessage("`use_fast`"))
- del processing_utils_logger
-except:
- pass
-
-# `use_cache=True` is incompatible with gradient checkpointing
-try:
- from transformers.trainer import logger as trainer_logger
-
- trainer_logger.addFilter(HideLoggingMessage("`use_cache=True`"))
- del trainer_logger
-except:
- pass
-
-# `use_cache=True` is incompatible with gradient checkpointing
-try:
- from transformers.utils.generic import logger as trainer_logger
-
- trainer_logger.addFilter(HideLoggingMessage("`use_cache=True`"))
- del trainer_logger
-except:
- pass
-
-# We detected that you are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')
-try:
- from transformers.modeling_utils import logger as modeling_utils_logger
-
- modeling_utils_logger.addFilter(HideLoggingMessage("anti-pattern"))
- del modeling_utils_logger
-except:
- pass
-
-# Errors out on
-# Some weights of Gemma3nForConditionalGeneration were not initialized from the model checkpoint
-from transformers.modeling_utils import logger as transformers_logger
-
-
-class _RaiseUninitialized(logging.Handler):
- def __init__(self):
- super().__init__()
-
- def emit(self, record):
- record_lower = str(record).lower()
- if (
- ("some weights of" in record_lower)
- and ("score.weight" not in record_lower)
- and ("classifier.weight" not in record_lower)
- and ("cls.predictions" not in record_lower)
- and ("predictions.decoder" not in record_lower)
- and (os.environ.get("UNSLOTH_WARN_UNINITIALIZED", "1") == "1")
- ):
- raise Exception(
- f"Unsloth: Critical error since some weights are not initialized.\n"
- f"Please try updating Unsloth, transformers and timm via:\n"
- f"`pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zoo transformers timm`\n"
- f"{str(record)}"
- )
-
-
-class RaiseUninitialized:
- def __init__(self):
- self.error_handler = _RaiseUninitialized()
- transformers_logger.addHandler(self.error_handler)
-
- def remove(self):
- transformers_logger.removeHandler(self.error_handler)
-
-
-try:
- from transformers.trainer import logger as transformers_trainer_logger
-
- transformers_trainer_logger.addFilter(
- HideLoggingMessage("The model is already on multiple devices.")
- )
-except:
- pass
-
-# Hide HF Hub unauthenticated request warnings
-try:
- from huggingface_hub.utils._http import logger as hf_http_logger
-
- hf_http_logger.addFilter(
- HideLoggingMessage("You are sending unauthenticated requests")
- )
- del hf_http_logger
-except:
- pass
-
-# Replace PEFT target_parameters warning with Unsloth branded message for MoE models
-ReplaceWarningMessage.add_rule(
- match_text = "target_parameters",
- replacement = (
- "Unsloth: PEFT set target_parameters but found no matching parameters.\n"
- "This is expected for MoE models - Unsloth handles MoE expert LoRA targeting separately."
- ),
- category = RuntimeWarning,
-)
-
# Patch get_model_param_count to record correct 4bit / 8bit
from transformers.trainer_pt_utils import is_deepspeed_zero3_enabled
-
-
-def extract_quant_model_param_count(model):
- """
- Calculate quant model param count based on difference in param class. Returns int for param count.
- """
- count: int = 0
- for name, p in model.named_parameters():
- if p.__class__.__name__ == "Params4bit":
- count += 2 * p.numel()
- else:
- count += p.numel()
- return count
-
-
def get_model_param_count(model, trainable_only = False):
"""
Calculate model's total param count. If trainable_only is True then count only those requiring grads
"""
if is_deepspeed_zero3_enabled():
-
def numel(p):
return p.ds_numel if hasattr(p, "ds_numel") else p.numel()
else:
-
def numel(p):
return p.numel()
-
- s = sum(
- numel(p) for p in model.parameters() if not trainable_only or p.requires_grad
- )
- if (
- (not trainable_only)
- and hasattr(model, "config")
- and hasattr(model.config, "quantization_config")
- ):
- approx = extract_quant_model_param_count(model)
- if approx is not None:
- s = approx
+ s = sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad)
+ if (not trainable_only) and \
+ hasattr(model, "config") and \
+ hasattr(model.config, "quantization_config"):
+
+ billions = re.findall(r"([0-9]{1,})(?:b|B)", model.config.name_or_path)
+ if len(billions) != 0:
+ billions = int(billions[0])
+ s = 1_000_000_000 * billions
+ pass
return s
-
-
+pass
import transformers.trainer_pt_utils
-
transformers.trainer_pt_utils.get_model_param_count = get_model_param_count
import transformers.trainer
-
transformers.trainer.get_model_param_count = get_model_param_count
# =============================================
# =============================================
# Edits all Config files to enable RoPE Scaling for all models
-
# Transformers had to update for Mistral Nemo 12b since Attention is (5120, 4096) now.
def patch_mistral_nemo_config(config):
if "head_dim (" not in config:
- add_head_dim = (
- "If it is not specified, will default to `8`.\n"
- " head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):\n"
+ add_head_dim = "If it is not specified, will default to `8`.\n"\
+ " head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):\n"\
" The attention head dimension."
- )
- config = config.replace(
- "If it is not specified, will default to `8`.", add_head_dim
- )
+ config = config.replace("If it is not specified, will default to `8`.", add_head_dim)
add_head_dim = "num_key_value_heads=8,\n head_dim=None,"
config = config.replace("num_key_value_heads=8,", add_head_dim)
add_head_dim = "self.sliding_window = sliding_window\n self.head_dim = head_dim or hidden_size // num_attention_heads\n"
config = config.replace("self.sliding_window = sliding_window", add_head_dim)
- return config
-
-
-try:
- # Some Config files use layer_type_validation
- # for eg Gemma-2, so we must import it to stop errors.
- from transformers.configuration_utils import layer_type_validation
-except:
pass
+ return config
+pass
-try:
- # Transformers 5.0+ uses RotaryEmbeddingConfigMixin as a base class for configs
- from transformers.modeling_rope_utils import RotaryEmbeddingConfigMixin
-except:
- pass
from transformers import __version__ as transformers_version
-
-try:
- from transformers import PreTrainedConfig
-except:
- from transformers import PretrainedConfig
-
-model_architectures = [
- "llama",
- "mistral",
- "gemma",
- "gemma2",
- "qwen2",
- "granite",
- "qwen3",
- "qwen3_moe",
- "falcon_h1",
-]
+from transformers import PretrainedConfig
+model_architectures = ["llama", "mistral", "gemma", "gemma2", "qwen2", "granite"]
for model_name in model_architectures:
config_filepath = f"transformers.models.{model_name}.configuration_{model_name}"
model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
- config_filename = f"{model_name.title().replace('_','')}Config" # qwen3 arch folder is qwen3_moe but config is Qwen3Config. Need to remove underscore(_) for now
- try:
- exec(f"from {config_filepath} import {config_filename}", globals())
- except:
- continue
+ config_filename = f"{model_name.title()}Config"
+ exec(f"from {config_filepath} import {config_filename}", globals())
try:
config = inspect.getsource(eval(config_filename))
except:
continue
- if "RopeParameters" in config:
- try:
- exec(f"from {config_filepath} import RopeParameters", globals())
- except:
- continue
-
- if "rope_scaling" in config:
- continue
+ if "rope_scaling" in config: continue
config = re.sub(
r"(\*\*kwargs)[\s]{0,}\,[\s]{0,}\)[\s]{0,}\:",
- r"rope_scaling=None,"
- r"\n **kwargs):\n"
+ r"rope_scaling=None,"\
+ r"\n **kwargs):\n"\
r"\n self.rope_scaling = rope_scaling\n",
config,
)
@@ -790,28 +259,24 @@ def patch_mistral_nemo_config(config):
if model_name == "mistral":
if Version(transformers_version) <= Version("4.42.4"):
config = patch_mistral_nemo_config(config)
+ pass
exec(config, globals())
exec(f"import {config_filepath}", globals())
exec(f"{config_filepath}.{config_filename} = {config_filename}", globals())
+pass
# =============================================
# =============================================
# torch.cuda.amp.custom_fwd is deprecated >= 2.4
torch_version = torch.__version__
-if DEVICE_TYPE in ("cuda", "hip"):
- if Version(torch_version) < Version("2.4.0"):
- torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
- torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
- else:
- torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
- torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
-elif DEVICE_TYPE == "xpu":
- if Version(torch_version) < Version("2.6.0"):
- raise RuntimeError("torch.xpu currently only supports torch.version >= 2.6.0")
- else:
- torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "xpu")
- torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "xpu")
+if Version(torch_version) < Version("2.4.0"):
+ torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
+ torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
+else:
+ torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
+ torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
+pass
# =============================================
# =============================================
@@ -839,84 +304,29 @@ def patch_mistral_nemo_config(config):
# =============================================
# Weird Databricks errors
from transformers.utils import is_openai_available
-
if is_openai_available():
try:
from openai import OpenAI
except:
print("Unsloth: OpenAI failed to import - ignoring for now.")
import transformers.utils
-
- def _is_openai_available():
- return False
-
+ def _is_openai_available(): return False
transformers.utils.is_openai_available = _is_openai_available
+ pass
+pass
# =============================================
# Get Flash Attention v2 if Ampere (RTX 30xx, A100)
import bitsandbytes as bnb
-
from transformers import AutoTokenizer
from transformers.utils.import_utils import _is_package_available
+major_version, minor_version = torch.cuda.get_device_capability()
SUPPORTS_BFLOAT16 = False
HAS_FLASH_ATTENTION = False
HAS_FLASH_ATTENTION_SOFTCAPPING = False
-if DEVICE_TYPE == "cuda":
- major_version, minor_version = torch.cuda.get_device_capability()
- torch.cuda.get_device_capability = functools.cache(torch.cuda.get_device_capability)
-
- if major_version >= 8:
- SUPPORTS_BFLOAT16 = True
- if _is_package_available("flash_attn"):
- # Check for CUDA linking errors "undefined symbol: _ZNK3c106SymIntltEl"
- try:
- try:
- # See https://github.com/unslothai/unsloth/issues/1437
- from flash_attn.flash_attn_interface import flash_attn_gpu
- except:
- from flash_attn.flash_attn_interface import flash_attn_cuda
- HAS_FLASH_ATTENTION = True
-
- # Also check for softcapping
- from flash_attn import __version__ as flash_attn_version
-
- HAS_FLASH_ATTENTION_SOFTCAPPING = Version(
- flash_attn_version
- ) >= Version("2.6.3")
- if not HAS_FLASH_ATTENTION_SOFTCAPPING:
- print(
- "Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"
- "Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"
- "To update flash-attn, do the below:\n"
- '\npip install --no-deps --no-build-isolation --upgrade "flash-attn>=2.6.3"'
- )
- except:
- print(
- "Unsloth: Your Flash Attention 2 installation seems to be broken. "
- "Using Xformers instead. No performance changes will be seen."
- )
-
- # Stop Flash Attention from importing!
- import transformers.utils.import_utils
-
- transformers.utils.import_utils.is_flash_attn_2_available = (
- lambda *args, **kwargs: False
- )
- import transformers.utils
-
- transformers.utils.is_flash_attn_2_available = (
- lambda *args, **kwargs: False
- )
-
- HAS_FLASH_ATTENTION = False
- else:
- HAS_FLASH_ATTENTION = False
- else:
- # Tri Dao's benchmark shows xformers is faster for now.
- HAS_FLASH_ATTENTION = False
-elif DEVICE_TYPE == "hip":
+if major_version >= 8:
SUPPORTS_BFLOAT16 = True
if _is_package_available("flash_attn"):
# Check for CUDA linking errors "undefined symbol: _ZNK3c106SymIntltEl"
@@ -930,146 +340,114 @@ def _is_openai_available():
# Also check for softcapping
from flash_attn import __version__ as flash_attn_version
-
- HAS_FLASH_ATTENTION_SOFTCAPPING = Version(flash_attn_version) >= Version(
- "2.6.3"
- )
+ HAS_FLASH_ATTENTION_SOFTCAPPING = Version(flash_attn_version) >= Version("2.6.3")
if not HAS_FLASH_ATTENTION_SOFTCAPPING:
print(
- "Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"
- "Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"
- "To update flash-attn, do the below:\n"
- '\npip install --no-deps --no-build-isolation --upgrade "flash-attn>=2.6.3"'
+ "Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"\
+ "Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"\
+ "To update flash-attn, do the below:\n"\
+ '\npip install --no-deps --upgrade "flash-attn>=2.6.3"'
)
except:
print(
- "Unsloth: Your Flash Attention 2 installation seems to be broken. "
- "Using Xformers instead. No performance changes will be seen."
+ "Unsloth: Your Flash Attention 2 installation seems to be broken?\n"\
+ "A possible explanation is you have a new CUDA version which isn't\n"\
+ "yet compatible with FA2? Please file a ticket to Unsloth or FA2.\n"\
+ "We shall now use Xformers instead, which does not have any performance hits!\n"\
+ "We found this negligible impact by benchmarking on 1x A100."
)
# Stop Flash Attention from importing!
import transformers.utils.import_utils
-
- transformers.utils.import_utils.is_flash_attn_2_available = (
- lambda *args, **kwargs: False
- )
+ transformers.utils.import_utils.is_flash_attn_2_available = lambda *args, **kwargs: False
import transformers.utils
-
transformers.utils.is_flash_attn_2_available = lambda *args, **kwargs: False
HAS_FLASH_ATTENTION = False
-elif DEVICE_TYPE == "xpu":
- SUPPORTS_BFLOAT16 = True
+ pass
+ else:
+ HAS_FLASH_ATTENTION = False
+else:
+ # Tri Dao's benchmark shows xformers is faster for now.
+ HAS_FLASH_ATTENTION = False
+pass
+
+from transformers.models.llama.modeling_llama import logger
# =============================================
# Get Xformers
-# Silence xformers CUDA mismatch warnings before import
-try:
- _xformers_logger = logging.getLogger("xformers")
- _xformers_logger.setLevel(logging.ERROR)
- del _xformers_logger
-except:
- pass
try:
from xformers import __version__ as xformers_version
-
- # [TODO] Xformers does NOT work on RTX 50x (12), B200 (10), Jetson (11)
- # See https://github.com/facebookresearch/xformers/issues/1329
- # CUDA error (/workspace/xfrm2/third_party/flash-attention/hopper/flash_fwd_launch_template.h:188)
- major_version, minor_version = torch.cuda.get_device_capability()
- if (f"{major_version}.{minor_version}" in ("10.0", "11.0", "12.0")) and (
- Version(xformers_version) in (Version("0.0.32.post2"),)
- ):
- raise NotImplementedError(
- "Unsloth: Xformers does not work in RTX 50X, Blackwell GPUs as of yet. Please build from source via\n"
- "```\n"
- "pip install ninja\n"
- "pip install -v --no-build-isolation -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers\n"
- "```\n"
- )
# Temporarily disable 0.0.27 and higher - inference issues
- if False: # Version(xformers_version) >= Version("0.0.27"):
+ if False: #Version(xformers_version) >= Version("0.0.27"):
raise ImportError(
- "Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "
- "then press Disconnect Runtime and then Restart it.\n"
- "\n"
+ "Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "\
+ "then press Disconnect Runtime and then Restart it.\n"\
+ "\n"\
"%%capture\n"
"# Installs Unsloth, Xformers (Flash Attention) and all other packages!\n"
'!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
- '!pip install --no-deps "xformers<=0.0.27" trl peft accelerate bitsandbytes\n'
- "\n"
- f"Otherwise in local machines, your xformers version of {xformers_version} is too new.\n"
+ '!pip install --no-deps "xformers<=0.0.27" trl peft accelerate bitsandbytes\n'\
+ '\n'\
+ f"Otherwise in local machines, your xformers version of {xformers_version} is too new.\n"\
'Please downgrade xformers via `pip install --force-reinstall "xformers<=0.0.27"'
)
+ pass
- if Version(torch_version) < Version("2.2.0") and Version(
- xformers_version
- ) >= Version("0.0.24"):
+ if Version(torch_version) < Version("2.2.0") and Version(xformers_version) >= Version("0.0.24"):
raise ImportError(
- f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"
+ f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"\
f"Please install xformers < 0.0.24 for torch = {torch_version}."
)
- elif Version(torch_version) < Version("2.3.0") and Version(
- xformers_version
- ) >= Version("0.0.26"):
+ elif Version(torch_version) < Version("2.3.0") and Version(xformers_version) >= Version("0.0.26"):
raise ImportError(
- f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"
+ f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"\
f"Please install xformers < 0.0.26 for torch = {torch_version}."
)
- elif Version(torch_version) < Version("2.4.0") and Version(
- xformers_version
- ) > Version("0.0.27"):
+ elif Version(torch_version) < Version("2.4.0") and Version(xformers_version) > Version("0.0.27"):
raise ImportError(
- f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"
+ f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"\
f"Please install xformers <= 0.0.27 for torch = {torch_version}."
)
+ pass
from xformers._cpp_lib import _register_extensions
-
try:
- _register_extensions() # Check if C++ modules are loaded correctly
+ _register_extensions() # Check if C++ modules are loaded correctly
except Exception as error:
raise ImportError(
- "Unsloth: Xformers was not installed correctly.\n"
- "Please install xformers separately first.\n"
- "Then confirm if it's correctly installed by running:\n"
+ "Unsloth: Xformers was not installed correctly.\n"\
+ "Please install xformers separately first.\n"\
+ "Then confirm if it's correctly installed by running:\n"\
"python -m xformers.info\n\n"
"Longer error message:\n" + str(error)
)
+ pass
import xformers.ops.fmha as xformers
-
xformers_attention = xformers.memory_efficient_attention
-except ModuleNotFoundError:
- xformers = None
- xformers_attention = None
- xformers_version = None
-except Exception as e:
- if UNSLOTH_ENABLE_LOGGING:
- print(
- "========\nSwitching to PyTorch attention since your Xformers is broken.\n========\n"
- )
- print(str(e))
+except:
xformers = None
xformers_attention = None
xformers_version = None
+pass
# Check TRL version
from trl import __version__ as trl_version
-
# Unsloth now supports all TRL versions!
-if False: # Version(trl_version) >= Version("0.9.0"):
+if False:#Version(trl_version) >= Version("0.9.0"):
raise ImportError(
- "Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "
- "then press Disconnect Runtime and then Restart it.\n"
- "\n"
+ "Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "\
+ "then press Disconnect Runtime and then Restart it.\n"\
+ "\n"\
"%%capture\n"
"# Installs Unsloth, Xformers (Flash Attention) and all other packages!\n"
'!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
- '!pip install --no-deps "xformers<=0.0.27" trl peft accelerate bitsandbytes\n'
- "\n"
- f"Otherwise in local machines, your TRL version of {trl_version} is too new.\n"
- "Please downgrade TRL via `pip install --force-reinstall trl"
+ '!pip install --no-deps "xformers<=0.0.27" trl peft accelerate bitsandbytes\n'\
+ '\n'\
+ f"Otherwise in local machines, your TRL version of {trl_version} is too new.\n"\
+ 'Please downgrade TRL via `pip install --force-reinstall trl'
)
+pass
# =============================================
# Fix new Xformers versions TypeError: Multiple dispatch failed for 'torch._ops.aten.to.dtype_layout'
@@ -1095,53 +473,28 @@ def _is_openai_available():
# Transformers 4.46 breaks dynamic caching. This is a hack
import transformers.generation.configuration_utils
-
if hasattr(transformers.generation.configuration_utils, "ALL_CACHE_IMPLEMENTATIONS"):
- if (
- type(transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS)
- is list
- ):
- if (
- "dynamic"
- not in transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS
- ):
- transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS.append(
- "dynamic"
- )
+ if type(transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS) is list:
+ transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS.append("dynamic")
+ pass
+pass
# =============================================
# =============================================
# Torch compile settings
-UNSLOTH_COMPILE_DEBUG = os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1"
-UNSLOTH_COMPILE_MAXIMUM = os.environ.get("UNSLOTH_COMPILE_MAXIMUM", "0") == "1"
-UNSLOTH_COMPILE_IGNORE_ERRORS = (
- os.environ.get("UNSLOTH_COMPILE_IGNORE_ERRORS", "1") == "1"
-)
+UNSLOTH_COMPILE_DEBUG = os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1"
+UNSLOTH_COMPILE_MAXIMUM = os.environ.get("UNSLOTH_COMPILE_MAXIMUM", "0") == "1"
+UNSLOTH_COMPILE_IGNORE_ERRORS = os.environ.get("UNSLOTH_COMPILE_IGNORE_ERRORS", "1") == "1"
# Just remove max_autotune_gemm warning
-from torch._inductor.runtime.hints import DeviceProperties
-
-
+import functools
@functools.lru_cache(None)
-def is_big_gpu(index) -> bool:
- if DEVICE_TYPE == "xpu":
- prop = DeviceProperties.create(
- torch.device("xpu", index) if type(index) is int else index
- )
- min_sms = 16
- else:
- prop = DeviceProperties.create(
- torch.device("cuda", index) if type(index) is int else index
- )
- min_sms = 80
-
- avail_sms = prop.multi_processor_count
- if avail_sms < min_sms:
+def is_big_gpu(index):
+ sms = torch.cuda.get_device_properties(index).multi_processor_count
+ if sms < 80: # V100
+ # log.warning("not enough SMs to use max_autotune_gemm mode")
return False
return True
-
-
import torch._inductor.utils
-
torch._inductor.utils.is_big_gpu = is_big_gpu
patch_torch_compile(
debug = UNSLOTH_COMPILE_DEBUG,
@@ -1150,101 +503,76 @@ def is_big_gpu(index) -> bool:
)
torch_compile_options = {
- "epilogue_fusion": True,
- "max_autotune": True,
- "shape_padding": True,
- "trace.enabled": UNSLOTH_COMPILE_DEBUG,
- "triton.cudagraphs": False,
+ "epilogue_fusion" : True,
+ "max_autotune" : True,
+ "shape_padding" : True,
+ "trace.enabled" : UNSLOTH_COMPILE_DEBUG,
+ "triton.cudagraphs" : False,
}
import accelerate
-
-
def torch_compile_kwargs(*args, **kwargs):
print("Unsloth: Enabled auto compiling")
- return {
- "dynamic": True,
- "fullgraph": False,
- "options": torch_compile_options,
- }
-
+ return {"dynamic" : True, "fullgraph" : False, "options" : torch_compile_options,}
+pass
accelerate.utils.dataclasses.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs
-accelerate.utils.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs
-accelerate.accelerator.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs
+accelerate.utils.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs
+accelerate.accelerator.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs
del accelerate
-
def patch_regional_compilation():
# Regional torch 2.5 Recompilation - weirdly very slow??
- if torch.nn.ModuleList.__name__ == "UnslothModuleList":
- return
+ if torch.nn.ModuleList.__name__ == "UnslothModuleList": return
# Only works for torch 2.5
- if Version(torch.__version__) < Version("2.5.0"):
- return
+ if Version(torch.__version__) < Version("2.5.0"): return
old_module_list = torch.nn.ModuleList
os.environ["UNSLOTH_PATCHED"] = "1"
def UnslothModuleList(*args, **kwargs):
if len(args) == 1 and len(kwargs) == 0 and type(args[0]) is list:
- args = [
- old_module_list(
- [
- torch.compile(
- x,
- dynamic = True,
- options = torch_compile_options,
- fullgraph = False,
- )
- for x in args[0]
- ]
- )
- ]
+ args = [old_module_list([torch.compile(x, dynamic = True, options = torch_compile_options, fullgraph = False) for x in args[0]])]
return old_module_list(*args, **kwargs)
-
+ pass
UnslothModuleList.__doc__ = old_module_list.__doc__
torch.nn.ModuleList = UnslothModuleList
return
-
+pass
# =============================================
-
def prepare_model_for_kbit_training(
- model: Any,
- use_gradient_checkpointing: Optional = True,
- use_reentrant: Optional[bool] = True,
+ model : Any,
+ use_gradient_checkpointing : Optional = True,
+ use_reentrant : Optional[bool] = True,
) -> Any:
return prepare_model_for_training(
- model = model,
+ model = model,
use_gradient_checkpointing = use_gradient_checkpointing,
- use_reentrant = use_reentrant,
- full_finetuning = False,
- train_layernorms = False,
- train_embedding = False,
- train_lm_head = False,
- float32_mixed_precision = True,
+ use_reentrant = use_reentrant,
+ full_finetuning = False,
+ train_layernorms = False,
+ train_embedding = False,
+ train_lm_head = False,
+ float32_mixed_precision = True,
)
-
+pass
# =============================================
# Weirdly LoraLayer.update_layer downcasts PEFT layers to float16??
# For mixed precision, we need it to be in float32 not float16.
from peft import __version__ as peft_version
-from peft.utils.integrations import dequantize_module_weight
-
if Version(peft_version) < Version("0.12.0"):
from peft.tuners.lora.layer import LoraLayer
-
try:
source = inspect.getsource(LoraLayer.update_layer)
text = "if weight is not None:\n"
start = source.find(text) + len(text)
end = source.find("self.to(weight.device)", start)
spaces = re.findall(r"^([ ]{1,})break", source, flags = re.MULTILINE)[0]
- source = source.replace(source[start:end], spaces)
+ source = source.replace(source[start : end], spaces)
spaces = len(re.match(r"[\s]{1,}", source).group(0))
lines = source.split("\n")
source = "\n".join(x[spaces:] for x in lines)
@@ -1254,219 +582,114 @@ def prepare_model_for_kbit_training(
# Fix up incorrect downcasting of LoRA weights
from peft.tuners.lora.layer import LoraLayer
-
LoraLayer.update_layer = LoraLayer_update_layer
from peft.tuners.lora import LoraLayer
-
LoraLayer.update_layer = LoraLayer_update_layer
except:
logger.warning_once(
- "Unsloth unsuccessfully patched LoraLayer.update_layer. Please file a bug report.\n"
+ "Unsloth unsuccessfully patched LoraLayer.update_layer. Please file a bug report.\n"\
"Luckily, your training run will still work in the meantime!"
)
+ pass
+pass
# =============================================
-import importlib
-
-global USE_MODELSCOPE
-USE_MODELSCOPE = os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1"
-if USE_MODELSCOPE:
- if importlib.util.find_spec("modelscope") is None:
- raise ImportError(
- f"You are using the modelscope hub, please install modelscope by `pip install modelscope -U`"
- )
-
-import socket
-
-
-@functools.lru_cache(1)
-def has_internet(host = "8.8.8.8", port = 53, timeout = 3):
- if os.environ.get("TRANSFORMERS_OFFLINE", "0") == "1":
- return False
- try:
- socket.setdefaulttimeout(timeout)
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- try:
- sock.connect((host, port))
- return True
- finally:
- sock.close()
- except socket.error as ex:
- return False
-
import psutil
-
-
def _get_statistics(statistics = None, force_download = True):
# We log some basic stats about which environment is being used.
# We simply download a README.md file from HF - all data is made public.
# This is simply so we can check if some envs are broken or not.
# You can disable this by commenting the below out
- n_cpus = psutil.cpu_count(logical = False)
- keynames = "\n" + "\n".join(os.environ.keys())
- # Check modelscope for down detection
- global USE_MODELSCOPE
- USE_MODELSCOPE = os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1"
-
- if statistics is None:
- # Prefer filesystem markers (harder to misidentify) before env-key matching
- try:
- from pathlib import Path
-
- if Path("/kaggle/working").exists():
- statistics = "kaggle"
- elif Path("/content").exists() and Path("/opt/colab").exists():
- statistics = "colab" if n_cpus == 1 else "colabpro"
- elif Path("/runpod-volume").exists():
- statistics = "runpod"
- except Exception:
- pass
-
- # Fallback to env-key detection
- if statistics is None:
- if "\nKAGGLE_" in keynames:
- statistics = "kaggle"
- elif "\nCOLAB_" in keynames and n_cpus == 1:
- statistics = "colab"
- elif "\nCOLAB_" in keynames:
- statistics = "colabpro"
- elif "\nRUNPOD_" in keynames:
- statistics = "runpod"
- elif "\nAWS_" in keynames:
- statistics = "aws"
- elif "\nAZURE_" in keynames:
- statistics = "azure"
- # elif "\nK_" in keynames or "\nFUNCTION_" in keynames: statistics = "gcp"
- elif "\nINVOCATION_ID" in keynames:
- statistics = "lambda"
- # else: statistics = "other"
- else:
-
- def try_vllm_check():
- vendor_files = (
- "/sys/class/dmi/id/product_version",
- "/sys/class/dmi/id/bios_vendor",
- "/sys/class/dmi/id/product_name",
- "/sys/class/dmi/id/chassis_asset_tag",
- "/sys/class/dmi/id/sys_vendor",
- )
-
- for vendor_file in vendor_files:
- path = Path(vendor_file)
- if path.is_file():
- file_content = path.read_text().lower()
- if "amazon" in file_content:
- return "aws"
- elif "microsoft corporation" in file_content:
- return "azure"
- elif "google" in file_content:
- return "gcp"
- return "other"
-
- try:
- statistics = try_vllm_check()
- except Exception:
- statistics = "other"
-
- if statistics is not None:
- import tempfile
- from huggingface_hub import snapshot_download
- from unsloth_zoo.rl_environments import execute_with_time_limit
-
- if has_internet():
-
- def stats_check():
- with tempfile.TemporaryDirectory(ignore_cleanup_errors = True) as f:
- snapshot_download(
- f"unslothai/{statistics}",
- force_download = True,
- cache_dir = f,
- local_dir = f,
- )
-
- time_limited_stats_check = execute_with_time_limit(120)(stats_check)
- try:
- time_limited_stats_check()
- except TimeoutError:
- raise TimeoutError(
- "Unsloth: HuggingFace seems to be down after trying for 120 seconds :(\n"
- "Check https://status.huggingface.co/ for more details.\n"
- "As a temporary measure, use modelscope with the same model name ie:\n"
- "```\n"
- "pip install modelscope\n"
- "import os; os.environ['UNSLOTH_USE_MODELSCOPE'] = '1'\n"
- "from unsloth import FastLanguageModel\n"
- "model = FastLanguageModel.from_pretrained('unsloth/gpt-oss-20b')\n"
- "```"
+ try:
+ n_cpus = psutil.cpu_count(logical = False)
+ keynames = "\n" + "\n".join(os.environ.keys())
+ if statistics is not None: pass
+ elif "\nCOLAB_" in keynames and n_cpus == 1: statistics = "colab"
+ elif "\nCOLAB_" in keynames: statistics = "colabpro"
+ elif "\nKAGGLE_" in keynames: statistics = "kaggle"
+ elif "\nRUNPOD_" in keynames: statistics = "runpod"
+ elif "\nAWS_" in keynames: statistics = "aws"
+ elif "\nAZURE_" in keynames: statistics = "azure"
+ # elif "\nK_" in keynames or "\nFUNCTION_" in keynames: statistics = "gcp"
+ elif "\nINVOCATION_ID" in keynames: statistics = "lambda"
+ # else: statistics = "other"
+ else:
+ def try_vllm_check():
+ vendor_files = (
+ "/sys/class/dmi/id/product_version",
+ "/sys/class/dmi/id/bios_vendor",
+ "/sys/class/dmi/id/product_name",
+ "/sys/class/dmi/id/chassis_asset_tag",
+ "/sys/class/dmi/id/sys_vendor",
)
- except Exception:
- # Try no time limit check
- stats_check()
+ from pathlib import Path
+ for vendor_file in vendor_files:
+ path = Path(vendor_file)
+ if path.is_file():
+ file_content = path.read_text().lower()
+ if "amazon" in file_content: return "aws"
+ elif "microsoft corporation" in file_content: return "azure"
+ elif "google" in file_content: return "gcp"
+ return "other"
+ pass
+ try: statistics = try_vllm_check()
+ except: statistics = "other"
+ pass
+ if statistics is not None:
+ from transformers import AutoModelForCausalLM
+ stats_model = AutoModelForCausalLM.from_pretrained(
+ f"unslothai/{statistics}",
+ force_download = force_download,
+ )
+ del stats_model
+ pass
+ except:
+ pass
+pass
-def get_statistics(local_files_only = False):
+def get_statistics():
# We log some basic stats about which environment is being used.
- # This is also to check if HuggingFace is down or not!
# We simply download a README.md file from HF - all data is made public.
# This is simply so we can check if some envs are broken or not.
# You can disable this by setting UNSLOTH_DISABLE_STATISTICS
import os
-
- if (
- "UNSLOTH_DISABLE_STATISTICS" in os.environ
- or os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1"
- ):
- return
- if local_files_only:
- return
- from huggingface_hub.utils import (
- disable_progress_bars,
- enable_progress_bars,
- are_progress_bars_disabled,
- )
-
+ if "UNSLOTH_DISABLE_STATISTICS" in os.environ: return
+ from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
disabled = False
if not are_progress_bars_disabled():
disable_progress_bars()
disabled = True
+ pass
_get_statistics(None)
_get_statistics("repeat", force_download = False)
- total_memory = (
- torch.xpu.get_device_properties(0).total_memory
- if DEVICE_TYPE == "xpu"
- else torch.cuda.get_device_properties(0).total_memory
- )
- vram = total_memory / 1024 / 1024 / 1024
- if vram <= 8:
- vram = 8
- elif vram <= 16:
- vram = 16
- elif vram <= 20:
- vram = 20
- elif vram <= 24:
- vram = 24
- elif vram <= 40:
- vram = 40
- elif vram <= 48:
- vram = 48
- elif vram <= 80:
- vram = 80
- else:
- vram = 96
- _get_statistics(f"vram-{vram}")
- _get_statistics(f"{DEVICE_COUNT if DEVICE_COUNT <= 8 else 9}")
- if disabled:
- enable_progress_bars()
+ try:
+ vram = torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1024
+ if vram <= 8 : vram = 8
+ elif vram <= 16: vram = 16
+ elif vram <= 20: vram = 20
+ elif vram <= 24: vram = 24
+ elif vram <= 40: vram = 40
+ elif vram <= 48: vram = 48
+ elif vram <= 80: vram = 80
+ else: vram = 96
+ _get_statistics(f"vram-{vram}")
+ except:
+ pass
+ pass
+ try:
+ devices = torch.cuda.device_count()
+ _get_statistics(f"{devices if devices <= 8 else 9}")
+ except:
+ pass
+ if disabled: enable_progress_bars()
+pass
# =============================================
# Fixes Bitsandbytes to remove missing warnings
-from transformers.utils.quantization_config import (
- BitsAndBytesConfig,
- QuantizationMethod,
-)
-
+from transformers.utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
BitsAndBytesConfig__init__ = inspect.getsource(BitsAndBytesConfig.__init__)
BitsAndBytesConfig__init__ = re.sub(
r"if[\s]{1,}kwargs\:[\s]{1,}.+?\n",
@@ -1476,132 +699,86 @@ def get_statistics(local_files_only = False):
)
BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.split("\n")
length_spaces = len(re.match(r"[\s]{1,}", BitsAndBytesConfig__init__[0]).group(0))
-BitsAndBytesConfig__init__ = "\n".join(
- x[length_spaces:] for x in BitsAndBytesConfig__init__
-)
+BitsAndBytesConfig__init__ = "\n".join(x[length_spaces:] for x in BitsAndBytesConfig__init__)
BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.replace(
"__init__",
"_BitsAndBytesConfig__init__",
)
exec(BitsAndBytesConfig__init__, globals())
-if DEVICE_COUNT == 1 and int(os.environ.get("WORLD_SIZE", "1")) <= 1:
+if torch.cuda.device_count() == 1:
from accelerate.utils.dataclasses import DistributedType
-
- def _prepare_backend(self, *args, **kwargs):
+ def _prepare_backend(
+ self, cpu = False, sagemaker_dp = False, backend: str = None,
+ ) -> tuple[str, DistributedType]:
return None, DistributedType.NO
-
+ pass
import accelerate.state
-
accelerate.state.PartialState._prepare_backend = _prepare_backend
- accelerate.accelerator.Accelerator.distributed_type = (
- lambda *args, **kwargs: DistributedType.NO
- )
-
-
-# to move multiple tensors to the same device
-def move_to_device(target_device, *tensors):
- """
- Move multiple tensors to target device if they're not already there.
-
- Args:
- target_device: The target device to move tensors to
- *tensors: Variable number of tensors to potentially move
-
- Returns:
- tuple: The tensors on the target device (same objects if already on device, new if moved)
- """
- if isinstance(target_device, int):
- target_device = torch.device(target_device)
- elif isinstance(target_device, str):
- # if string we expect it to be a device name like "cuda:0"
- target_device = torch.device(target_device)
- elif isinstance(target_device, torch.device):
- pass
- else:
- raise ValueError(f"Invalid target device: {target_device}")
- moved_tensors = []
- for tensor in tensors:
- if tensor.device != target_device:
- moved_tensors.append(tensor.to(target_device))
- else:
- moved_tensors.append(tensor)
- return tuple(moved_tensors) if len(moved_tensors) > 1 else moved_tensors[0]
+ import accelerate.accelerator
+ prepare = inspect.getsource(accelerate.accelerator.Accelerator.prepare)
+ prepare = prepare.split("\n")
+ spaces = prepare[0].find("def")
+ prepare = "\n".join(x[spaces:] for x in prepare)
+ x = "for obj in args:"
+ s = " "*spaces
+ prepare = prepare.replace(x, f'self.state.distributed_type = DistributedType.NO\n{s}{x}', 1)
+ exec(prepare, globals())
+ accelerate.accelerator.Accelerator.prepare = prepare
+pass
import transformers.utils.quantization_config
-
-transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = (
- _BitsAndBytesConfig__init__
-)
+transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = _BitsAndBytesConfig__init__
# =============================================
# Offloading to disk for modules (lm_head, embed_tokens)
import pickle
-
-def offload_to_disk(
- W, model, name, temporary_location: str = "_unsloth_temporary_saved_buffers"
-):
+def offload_to_disk(W, model, name, temporary_location : str = "_unsloth_temporary_saved_buffers"):
file_location = os.path.join(temporary_location, model.config._name_or_path)
if not os.path.exists(file_location):
os.makedirs(file_location)
+ pass
filename = os.path.join(file_location, f"{name}.pt")
W = W.weight if hasattr(W, "weight") else W
- torch.save(
- W,
- filename,
- pickle_module = pickle,
- pickle_protocol = pickle.HIGHEST_PROTOCOL,
- )
+ torch.save(W, filename, pickle_module = pickle, pickle_protocol = pickle.HIGHEST_PROTOCOL,)
# We must use weights_only = False due to pickling
- offloaded_W = torch.load(
- filename, map_location = "cpu", mmap = True, weights_only = False
- )
+ offloaded_W = torch.load(filename, map_location = "cpu", mmap = True, weights_only = False)
offloaded_W._offloaded_file_location = filename
return offloaded_W
+pass
-def offload_input_embeddings(
- model, temporary_location: str = "_unsloth_temporary_saved_buffers"
-):
- offloaded_W = offload_to_disk(
- model.get_input_embeddings(), model, "input_embeddings", temporary_location
- )
+def offload_input_embeddings(model, temporary_location : str = "_unsloth_temporary_saved_buffers"):
+ offloaded_W = offload_to_disk(model.get_input_embeddings(), model, "input_embeddings", temporary_location)
new_input_embeddings = torch.nn.Embedding.from_pretrained(offloaded_W)
new_input_embeddings._offloaded_file_location = offloaded_W._offloaded_file_location
model.set_input_embeddings(new_input_embeddings)
return
+pass
-def offload_output_embeddings(
- model, temporary_location: str = "_unsloth_temporary_saved_buffers"
-):
- offloaded_W = offload_to_disk(
- model.get_output_embeddings(), model, "output_embeddings", temporary_location
- )
+def offload_output_embeddings(model, temporary_location : str = "_unsloth_temporary_saved_buffers"):
+ offloaded_W = offload_to_disk(model.get_output_embeddings(), model, "output_embeddings", temporary_location)
new_output_embeddings = torch.nn.Linear(1, 1, bias = None)
del new_output_embeddings.weight
new_output_embeddings.weight = offloaded_W
- new_output_embeddings.in_features = offloaded_W.shape[1]
+ new_output_embeddings.in_features = offloaded_W.shape[1]
new_output_embeddings.out_features = offloaded_W.shape[0]
- new_output_embeddings._offloaded_file_location = (
- offloaded_W._offloaded_file_location
- )
+ new_output_embeddings._offloaded_file_location = offloaded_W._offloaded_file_location
model.set_output_embeddings(new_output_embeddings)
return
+pass
# Fixes a weird Torch 2.3 bug which says T4s have bfloat16
def is_bfloat16_supported():
return SUPPORTS_BFLOAT16
-
-
-def is_vLLM_available():
- return _is_package_available("vllm")
+pass
# Patches models to add RoPE Scaling
@@ -1611,18 +788,17 @@ def patch_linear_scaling(
scaled_rope_module = None,
attention_module = None,
):
- assert rope_module is not None and scaled_rope_module is not None
- assert attention_module is not None
+ assert(rope_module is not None and scaled_rope_module is not None)
+ assert(attention_module is not None)
rope_name = rope_module.__name__
scaled_rope_name = scaled_rope_module.__name__
model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
- exec_code = (
- f"import torch.nn as nn\n"
- f"from typing import Union, Optional, List, Any, Callable, Tuple\n"
- f"from {model_filepath} import logger, "
+ exec_code = \
+ f"import torch.nn as nn\n"\
+ f"from typing import Union, Optional, List, Any, Callable, Tuple\n"\
+ f"from {model_filepath} import logger, "\
f"{model_name.title()}Attention, {model_name.title()}Config"
- )
try:
function = inspect.getsource(attention_module.__init__)
@@ -1660,12 +836,11 @@ def patch_linear_scaling(
pass
"""
fix_rope_function = fix_rope_function.format(
- rope_function = rope_module.__name__,
+ rope_function = rope_module.__name__,
scaled_rope_function = scaled_rope_module.__name__,
)
rotary_emb = re.findall(
- r"self\.rotary\_emb \= .+?\)",
- function,
+ r"self\.rotary\_emb \= .+?\)", function,
flags = re.DOTALL | re.MULTILINE,
)
if len(rotary_emb) == 0:
@@ -1675,6 +850,7 @@ def patch_linear_scaling(
function = function.replace(rotary_emb, fix_rope_function, 1)
function = exec_code + "\n\n" + function
return init_name, function
+pass
# Patches for Llama-3 LlamaExtendedRotaryEmbedding
@@ -1686,22 +862,21 @@ def patch_llama_rope_scaling(
attention_module = None,
longrope_module = None,
):
- assert (
- rope_module is not None
- and scaled_rope_module is not None
- and extended_rope_module is not None
+ assert(\
+ rope_module is not None and \
+ scaled_rope_module is not None and \
+ extended_rope_module is not None
)
- assert attention_module is not None
+ assert(attention_module is not None)
rope_name = rope_module.__name__
scaled_rope_name = scaled_rope_module.__name__
model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
- exec_code = (
- f"import torch.nn as nn\n"
- f"from typing import Union, Optional, List, Any, Callable, Tuple\n"
- f"from {model_filepath} import logger, "
+ exec_code = \
+ f"import torch.nn as nn\n"\
+ f"from typing import Union, Optional, List, Any, Callable, Tuple\n"\
+ f"from {model_filepath} import logger, "\
f"{model_name.title()}Attention, {model_name.title()}Config"
- )
try:
function = inspect.getsource(attention_module.__init__)
@@ -1758,24 +933,22 @@ def patch_llama_rope_scaling(
"""
fix_rope_function = fix_rope_function.format(
- rope_function = rope_module.__name__,
- scaled_rope_function = scaled_rope_module.__name__,
+ rope_function = rope_module.__name__,
+ scaled_rope_function = scaled_rope_module.__name__,
extended_rope_function = extended_rope_module.__name__,
- longrope_rope_function = (
- longrope_module if longrope_module is not None else rope_module
- ).__name__,
+ longrope_rope_function = \
+ (longrope_module if longrope_module is not None else rope_module).__name__
)
rotary_emb = re.findall(
- r"self\.rotary\_emb \= .+?\)",
- function,
+ r"self\.rotary\_emb \= .+?\)", function,
flags = re.DOTALL | re.MULTILINE,
)
- if len(rotary_emb) == 0:
- return None, function
+ if len(rotary_emb) == 0: return None, function
rotary_emb = rotary_emb[0]
function = function.replace(rotary_emb, fix_rope_function, 1)
function = exec_code + "\n\n" + function
return init_name, function
+pass
def create_boolean_mask(n = 4096, sliding_window = 2048):
@@ -1783,52 +956,36 @@ def create_boolean_mask(n = 4096, sliding_window = 2048):
mask = torch.ones(n, n, dtype = torch.bool)
if sliding_window == 0:
return torch.triu(mask, diagonal = 1, out = mask)
+ pass
torch.triu(mask, diagonal = 0, out = mask)
torch.triu(mask.T, diagonal = -sliding_window, out = mask.T)
mask = mask.T
torch.logical_not(mask, out = mask)
return mask
+pass
def test_mask_creation():
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-
for n in range(2, 23):
for s in range(1, 23):
- correct_mask = (
- AttentionMaskConverter(
- is_causal = True,
- sliding_window = s,
- )
- .to_causal_4d(
- 1,
- n,
- n,
- dtype = torch.float16,
- )
- .squeeze(0)
- .squeeze(0)
- )
- correct_mask = correct_mask == correct_mask.min()
- our_mask = create_boolean_mask(n = n, sliding_window = s)
- assert torch.all(correct_mask == our_mask)
- correct_mask = (
- AttentionMaskConverter(
+ correct_mask = AttentionMaskConverter(
is_causal = True,
- sliding_window = None,
- )
- .to_causal_4d(
- 1,
- n,
- n,
- dtype = torch.float16,
- )
- .squeeze(0)
- .squeeze(0)
- )
- correct_mask = correct_mask == correct_mask.min()
+ sliding_window = s,
+ ).to_causal_4d(1, n, n, dtype = torch.float16,).squeeze(0).squeeze(0)
+ correct_mask = (correct_mask == correct_mask.min())
+ our_mask = create_boolean_mask(n = n, sliding_window = s)
+ assert(torch.all(correct_mask == our_mask))
+ pass
+ correct_mask = AttentionMaskConverter(
+ is_causal = True,
+ sliding_window = None,
+ ).to_causal_4d(1, n, n, dtype = torch.float16,).squeeze(0).squeeze(0)
+ correct_mask = (correct_mask == correct_mask.min())
our_mask = create_boolean_mask(n = n, sliding_window = 0)
- assert torch.all(correct_mask == our_mask)
+ assert(torch.all(correct_mask == our_mask))
+ pass
+pass
def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs):
@@ -1841,982 +998,247 @@ def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs):
kwargs.pop("num_items_in_batch")
elif "num_items_in_batch" not in inputs:
inputs["num_items_in_batch"] = num_items_in_batch
+ pass
+ pass
# Get gradient accumulation steps if possible
- if (
- num_items_in_batch is None
- and getattr(getattr(self, "args", self), "gradient_accumulation_steps", 1) != 1
- ):
+ if num_items_in_batch is None and \
+ getattr(getattr(self, "args", self), "gradient_accumulation_steps", 1) != 1:
+
inner_model = model
- if hasattr(inner_model, "base_model"):
- inner_model = inner_model.base_model
- if hasattr(inner_model, "model"):
- inner_model = inner_model.model
+ if hasattr(inner_model, "base_model"): inner_model = inner_model. base_model
+ if hasattr(inner_model, "model"): inner_model = inner_model.model
name = inner_model.__class__.__name__
logger.warning_once(
- f"Unsloth: Not an error, but {name} does not accept `num_items_in_batch`.\n"
- "Using gradient accumulation will be very slightly less accurate.\n"
+ f"Unsloth: Not an error, but {name} does not accept `num_items_in_batch`.\n"\
+ "Using gradient accumulation will be very slightly less accurate.\n"\
"Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient"
)
- # Gemma3 multimodal models in transformers 5.x require token_type_ids during training.
- # For text-only SFT, token_type_ids should be all zeros (no image tokens).
- if "token_type_ids" not in inputs and "input_ids" in inputs:
- _inner = model
- for _attr in ("base_model", "model", "model"):
- _inner = getattr(_inner, _attr, _inner)
- if getattr(getattr(_inner, "config", None), "model_type", "") in ("gemma3",):
- import sys as _sys
-
- _mod = _sys.modules.get(type(_inner).__module__)
- _has_ccm = _mod is not None and hasattr(_mod, "create_causal_mask_mapping")
- if _has_ccm and _inner.training:
- inputs["token_type_ids"] = torch.zeros_like(inputs["input_ids"])
-
- outputs = self._old_compute_loss(model, inputs, *args, **kwargs)
+ pass
+
+ if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0":
+ autocaster = contextlib.nullcontext()
+ else:
+ autocaster = torch.autocast(device_type = "cuda", dtype = torch.float32)
+ with autocaster:
+ outputs = self._old_compute_loss(model, inputs, *args, **kwargs)
return outputs
+pass
def patch_gradient_accumulation_fix(Trainer):
- # Fixes gradient accumulation
- # Fixes Output 0 of UnslothFusedLossBackward is a view and is being modified inplace.
+ # Fixes gradient accumulation
import inspect
-
if hasattr(Trainer, "get_batch_samples"):
- if Trainer.get_batch_samples.__name__ == "_unsloth_get_batch_samples":
- return
- if (
- not inspect.getsource(Trainer.get_batch_samples)
- .strip()
- .endswith("return batch_samples, num_items_in_batch")
- ):
- raise NotImplementedError(
- "Unsloth: Please make a Github issue immediately!!"
- )
+ if Trainer.get_batch_samples.__name__ == "_unsloth_get_batch_samples": return
+ if \
+ not inspect.getsource(Trainer.get_batch_samples).strip()\
+ .endswith("return batch_samples, num_items_in_batch"):
+
+ raise NotImplementedError("Unsloth: Please make a Github issue immediately!!")
else:
if Trainer.get_batch_samples.__name__ != "_unsloth_get_batch_samples":
Trainer.get_batch_samples = _unsloth_get_batch_samples
+ pass
# Also fix passing in num_items_in_batch
if not hasattr(Trainer, "_old_compute_loss"):
- # Fix transformers 4.57.0 causing `Output 0 of UnslothFusedLossBackward is a view and is being modified inplace.`
- function = inspect.getsource(Trainer.compute_loss)
- if "loss *=" in function or "loss*=" in function:
- where = function.find("def")
- function = function.split("\n")
- function = "\n".join(x[where:] for x in function)
-
- # Import all variables that need importing
- import transformers.trainer
-
- items_in_trainer = dir(transformers.trainer)
- good_items = []
- for item in items_in_trainer:
- if item in function:
- good_items.append(item)
- exec(
- "from transformers.trainer import ("
- + ", ".join(x for x in good_items)
- + ")",
- globals(),
- )
-
- # Replace loss*= with loss = loss *
- function = re.sub(
- r"loss[\s]{0,}\*\=",
- "loss = loss *",
- function,
- )
- exec(function, globals())
- Trainer.compute_loss = compute_loss
Trainer._old_compute_loss = Trainer.compute_loss
Trainer.compute_loss = _unsloth_pre_compute_loss
+ pass
+ pass
else:
logger.warning_once(
- "Unsloth: We fixed a gradient accumulation bug, "
- "but it seems like you don't have the latest transformers version!\n"
- "Please update transformers, TRL and unsloth via:\n"
- "`pip install --upgrade --no-cache-dir --no-deps unsloth transformers git+https://github.com/huggingface/trl.git`"
+ "Unsloth: We fixed a gradient accumulation bug, "\
+ "but it seems like you don't have the latest transformers version!\n"\
+ "Please update transformers, TRL and unsloth via:\n"\
+ '`pip install --upgrade --no-cache-dir --no-deps unsloth transformers git+https://github.com/huggingface/trl.git`'
)
+ pass
# Also fix up loss scaling ie negate loss *= self.args.gradient_accumulation_steps
- if not (
- Trainer.training_step.__name__ == "_unsloth_training_step"
- or "num_items_in_batch"
- not in inspect.signature(Trainer.training_step).parameters
- ):
- function = inspect.getsource(Trainer.training_step)
- where = function.find("def")
- function = function.split("\n")
- function = "\n".join(x[where:] for x in function)
-
- # Import all variables that need importing
- import transformers.trainer
-
- items_in_trainer = dir(transformers.trainer)
- good_items = []
- for item in items_in_trainer:
- if item in function:
- good_items.append(item)
- exec(
- "from transformers.trainer import ("
- + ", ".join(x for x in good_items)
- + ")",
- globals(),
- )
+ if Trainer.training_step.__name__ == "_unsloth_training_step": return
+ if "num_items_in_batch" not in inspect.signature(Trainer.training_step).parameters: return
- # Accelerate does / self.args.gradient_accumulation_steps internally, so if we already
- # summed it up and did the division before hand, we have to negate it.
- function = function.replace(
- "loss *= self.args.gradient_accumulation_steps",
- "if num_items_in_batch is not None: loss *= self.args.gradient_accumulation_steps",
- )
- function = function.replace(
- "def training_step", "def _unsloth_training_step", 1
- )
-
- # Fix 4.47.0 issue where num_items_in_batch was removed
- # See https://github.com/huggingface/transformers/pull/35121
- function = function.replace(
- "if self.model_accepts_loss_kwargs:",
- "if False:",
- )
-
- # Fix when num_items_in_batch is nothing
- # https://github.com/huggingface/transformers/pull/35207
- function = re.sub(
- r"else:\n"
- r"([\s]{4,})self\.accelerator\.backward\(loss, \*\*kwargs\)\n"
- r"(.+?)if num_items_in_batch is None\:\n"
- r"(.+?)return loss\.detach\(\) \/ self\.args\.gradient_accumulation_steps",
- "else:\n"
- "\2if num_items_in_batch is None:\n"
- "\3loss = loss / self.args.gradient_accumulation_steps\n"
- "\1self.accelerator.backward(loss, **kwargs)",
- function,
- )
+ function = inspect.getsource(Trainer.training_step)
+ where = function.find("def")
+ function = function.split("\n")
+ function = "\n".join(x[where:] for x in function)
- exec(function, globals())
- Trainer.training_step = _unsloth_training_step
+ # Import all variables that need importing
+ import transformers.trainer
+ items_in_trainer = dir(transformers.trainer)
+ good_items = []
+ for item in items_in_trainer:
+ if item in function: good_items.append(item)
+ pass
+ exec("from transformers.trainer import (" + ", ".join(x for x in good_items) + ")", globals())
- # Prevent double scaling gradient accumulation
- # https://github.com/huggingface/transformers/pull/37208
- # Patch model_accepts_loss_kwargs detection in Trainer.__init__
- if Trainer.__init__.__name__ != "_unsloth___init__":
- try:
- init_function = inspect.getsource(Trainer.__init__)
- except Exception:
- init_function = ""
- if init_function is not None:
- init_function = textwrap.dedent(init_function)
-
- # Import all variables that need importing
- import transformers.trainer
-
- items_in_trainer = dir(transformers.trainer)
- good_items = []
- for item in items_in_trainer:
- if item in init_function:
- good_items.append(item)
- exec(
- "from transformers.trainer import ("
- + ", ".join(x for x in good_items)
- + ")",
- globals(),
- )
+ # Accelerate does / self.args.gradient_accumulation_steps internally, so if we already
+ # summed it up and did the division before hand, we have to negate it.
+ function = function.replace(
+ "loss *= self.args.gradient_accumulation_steps",
+ "if num_items_in_batch is not None: loss *= self.args.gradient_accumulation_steps",
+ )
+ function = function.replace("def training_step", "def _unsloth_training_step", 1)
- init_function = init_function.replace(
- "def __init__", "def _unsloth___init__", 1
- )
+ # Fix 4.47.0 issue where num_items_in_batch was removed
+ # See https://github.com/huggingface/transformers/pull/35121
+ function = function.replace(
+ "if self.model_accepts_loss_kwargs:",
+ "if False:",
+ )
- # Force else branch
- init_function = re.sub(
- r'if[\s]+hasattr\(\s*unwrapped_model\s*,\s*"accepts_loss_kwargs"\s*\)\s*:',
- 'if hasattr(unwrapped_model, "accepts_loss_kwargs") and False:',
- init_function,
- )
- exec(init_function, globals())
- Trainer.__init__ = _unsloth___init__
+ # Fix when num_items_in_batch is nothing
+ # https://github.com/huggingface/transformers/pull/35207
+ function = re.sub(
+ r"else:\n"\
+ r"([\s]{4,})self\.accelerator\.backward\(loss, \*\*kwargs\)\n"\
+ r"(.+?)if num_items_in_batch is None\:\n"\
+ r"(.+?)return loss\.detach\(\) \/ self\.args\.gradient_accumulation_steps",
+
+ "else:\n"\
+ "\2if num_items_in_batch is None:\n"\
+ "\3loss = loss / self.args.gradient_accumulation_steps\n"\
+ "\1self.accelerator.backward(loss, **kwargs)",
+
+ function,
+ )
+
+ exec(function, globals())
+ Trainer.training_step = _unsloth_training_step
+pass
def patch_tokenizer(model, tokenizer):
model, tokenizer = _patch_tokenizer(model, tokenizer)
if model is not None:
- model.config.update({"unsloth_version": __version__})
+ model.config.update({"unsloth_version" : __version__})
return model, tokenizer
+pass
def patch_fast_lora():
import peft.tuners.lora.bnb
-
peft.tuners.lora.bnb.Linear4bit.forward = fast_lora_forward
+pass
def unsloth_compile_transformers(
- dtype,
model_name,
- model_types,
- token = None,
- revision = None,
- trust_remote_code = False,
- sdpa_dynamic_mask = True,
- sdpa_bool_masks = True,
- sdpa_gqa_replace = True,
- sdpa_dynamic_compile = True,
- compile_attention = True,
- disable_causal_masks = True,
- compile_torch_modules = True,
- compile_custom_modules = True,
- compile_function_calls = True,
- fuse_lm_head = True,
- gradient_checkpointing = True,
- manual_replacements = True,
- fast_lora_forwards = True,
- fast_residual_stream = True,
- accurate_accumulation = True,
- epilogue_fusion = True,
- max_autotune = False,
- shape_padding = True,
- cudagraphs = False,
- debug = False,
- fullgraph = True,
- import_from_cache = False,
- disable = False,
- return_logits = False,
- unsloth_force_compile = False,
+ token = None,
+ revision = None,
+ trust_remote_code = False,
+ sdpa_dynamic_mask = True,
+ sdpa_bool_masks = True,
+ sdpa_gqa_replace = True,
+ sdpa_dynamic_compile = True,
+ compile_attention = True,
+ disable_causal_masks = True,
+ compile_torch_modules = True,
+ compile_custom_modules = True,
+ compile_function_calls = True,
+ fuse_lm_head = True,
+ gradient_checkpointing = True,
+ manual_replacements = True,
+ fast_lora_forwards = True,
+ fast_residual_stream = True,
+ accurate_accumulation = True,
+ epilogue_fusion = True,
+ max_autotune = False,
+ shape_padding = True,
+ cudagraphs = False,
+ debug = False,
+ fullgraph = True,
+ import_from_cache = False,
+ disable = False,
+ return_logits = False,
):
if Version(torch_version) < Version("2.4.0"):
print(
- "="
- * 30
- + "Unsloth: Unfortunately Unsloth vision and other newer optimized models need Torch 2.4 or later.\n"
- f"You have Torch version {torch_version}. Please upgrade your Torch version by visiting https://pytorch.org/\n"
+ "="*30 + \
+ "Unsloth: Unfortunately Unsloth vision and other newer optimized models need Torch 2.4 or later.\n"\
+ f"You have Torch version {torch_version}. Please upgrade your Torch version by visiting https://pytorch.org/\n"\
"For now your models will not get optimized, but will still work for now!"
)
return
- if trust_remote_code and unsloth_force_compile == False:
- print(
- "Unsloth: We can't trace models if `trust_remote_code = True`, "
- "so turning off some optimizations!"
- )
- return model_types, False
- model_types = list(dict().fromkeys(model_types).keys())
- if disable:
- return model_types, False
+ pass
- supports_sdpa = [True]
+ model_types = get_transformers_model_type(
+ model_name = model_name,
+ token = token,
+ revision = revision,
+ trust_remote_code = trust_remote_code,
+ )
+ model_types = ["siglip"] + model_types
- # Run patches BEFORE compiler so class replacements (e.g. GptOssTopKRouter,
- # GptOssExperts) are in place before the compiler caches references to them.
- _run_temporary_patches("pre_compile")
+ if disable: return
for model_type in model_types:
_unsloth_compile_transformers(
model_type,
- sdpa_dynamic_mask = sdpa_dynamic_mask,
- sdpa_bool_masks = sdpa_bool_masks,
- sdpa_gqa_replace = sdpa_gqa_replace,
- sdpa_dynamic_compile = sdpa_dynamic_compile,
- compile_attention = compile_attention,
- disable_causal_masks = disable_causal_masks,
- compile_torch_modules = compile_torch_modules,
+ sdpa_dynamic_mask = sdpa_dynamic_mask,
+ sdpa_bool_masks = sdpa_bool_masks,
+ sdpa_gqa_replace = sdpa_gqa_replace,
+ sdpa_dynamic_compile = sdpa_dynamic_compile,
+ compile_attention = compile_attention,
+ disable_causal_masks = disable_causal_masks,
+ compile_torch_modules = compile_torch_modules,
compile_custom_modules = compile_custom_modules,
compile_function_calls = compile_function_calls,
- fuse_lm_head = fuse_lm_head,
+ fuse_lm_head = fuse_lm_head,
gradient_checkpointing = gradient_checkpointing,
- manual_replacements = manual_replacements,
- fast_lora_forwards = fast_lora_forwards,
- fast_residual_stream = fast_residual_stream,
- accurate_accumulation = accurate_accumulation,
- epilogue_fusion = epilogue_fusion,
- max_autotune = max_autotune,
- shape_padding = shape_padding,
- cudagraphs = cudagraphs,
- debug = debug,
- fullgraph = fullgraph,
- import_from_cache = import_from_cache,
- disable = disable,
- return_logits = return_logits,
- supports_sdpa = supports_sdpa,
+ manual_replacements = manual_replacements,
+ fast_lora_forwards = fast_lora_forwards,
+ fast_residual_stream = fast_residual_stream,
+ accurate_accumulation = accurate_accumulation,
+ epilogue_fusion = epilogue_fusion,
+ max_autotune = max_autotune,
+ shape_padding = shape_padding,
+ cudagraphs = cudagraphs,
+ debug = debug,
+ fullgraph = fullgraph,
+ import_from_cache = import_from_cache,
+ disable = disable,
+ return_logits = return_logits,
)
- # Redo patches which override compiler
- _run_temporary_patches("post_compile")
- return model_types, supports_sdpa[0]
-
+ pass
+ return model_types
+pass
# We need an empty logits flag to warn people logits will not be returned anymore unless asked ie
# os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
-LOGITS_ERROR_STRING = (
- "Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please "
- 'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:\n'
- "```\nimport os\n"
- "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"
- "trainer.train()\n```\n"
+LOGITS_ERROR_STRING = \
+ "Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please "\
+ 'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:\n'\
+ "```\nimport os\n"\
+ "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"\
+ "trainer.train()\n```\n"\
"No need to restart your console - just add `os.environ['UNSLOTH_RETURN_LOGITS'] = '1'` before trainer.train() and re-run the cell!"
-)
-
-
-def raise_logits_error(*args, **kwargs):
- raise NotImplementedError(LOGITS_ERROR_STRING)
-
-
-def return_none(*args, **kwargs):
- return None
-
+def raise_logits_error(*args, **kwargs): raise NotImplementedError(LOGITS_ERROR_STRING)
+def return_none(*args, **kwargs): return None
class EmptyLogits:
- def __init__(self):
- return
-
- def raise_getattr_error(self, attr):
- return return_none if attr == "to" else raise_logits_error
-
+ def __init__(self): return
+ def raise_getattr_error(self, attr): return return_none if attr == "to" else raise_logits_error
__getitem__ = raise_logits_error
__getattr__ = raise_getattr_error
-
- def __repr__(self):
- return LOGITS_ERROR_STRING
-
- def __str__(self):
- return LOGITS_ERROR_STRING
-
-
+ def __repr__(self): return LOGITS_ERROR_STRING
+ def __str__ (self): return LOGITS_ERROR_STRING
+pass
EMPTY_LOGITS = EmptyLogits()
functions = dir(torch.Tensor)
for j, function in enumerate(functions):
if function.startswith("__") and function.endswith("__"):
- exec(
- f"def raise_{j}(*args, **kwargs): print('{function}')", globals(), locals()
- )
- try:
- exec(f"EMPTY_LOGITS.{function} = raise_{j}", globals(), locals())
- except:
- continue
-
-
-def validate_loftq_config(loftq_config, lora_dropout, bias, init_lora_weights, model):
- from peft import LoraConfig
-
- if loftq_config is None:
- loftq_config = {}
-
- signature = str(inspect.signature(LoraConfig))
- SUPPORTS_LOFTQ = "loftq_config" in signature
-
- if lora_dropout != 0:
- logger.warning_once(
- f"Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = {lora_dropout}.\n"
- f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit."
- )
-
- if bias != "none":
- logger.warning_once(
- f"Unsloth: bias = `none` is supported for fast patching. You are using bias = {bias}.\n"
- f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit."
- )
-
- if not (
- type(init_lora_weights) is bool
- or init_lora_weights == "gaussian"
- or init_lora_weights == "loftq"
- or init_lora_weights == "corda"
- ):
- raise ValueError(
- 'Unsloth: `init_lora_weights` must be either [True, False, "gaussian", "loftq", "corda"].'
- )
-
- if init_lora_weights == "loftq":
- if not SUPPORTS_LOFTQ:
- import peft
-
- raise RuntimeError(
- f"Unsloth: Your PEFT version of {peft.__version__} does not support LoftQ init.\n"
- "Please install PEFT 0.7.2 or higher.\n"
- "You can also install from source: `pip install git+https://github.com/huggingface/peft.git"
- )
-
- if loftq_config == {}:
- from peft import LoftQConfig
-
- logger.warning_once(
- "Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n"
- "We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`."
- )
- loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)
-
- if hasattr(model.config, "quantization_config"):
- raise ValueError(
- "Unsloth: You are using `loftq` init, yet `load_in_4bit = True` was set.\n"
- "Reload your model without any quantization by setting `load_in_4bit = False`."
- )
-
- return loftq_config
-
-
-def fast_inference_setup(model_name, model_config):
- fast_inference = True
- if not is_vLLM_available():
- logger.warning_once(
- "Unsloth: vLLM is not installed! Will use Unsloth inference!"
- )
- fast_inference = False
- from unsloth_zoo.vllm_utils import (
- patch_vllm,
- vllm_dynamic_quant_supported,
- )
-
- patch_vllm()
- if model_name.endswith("unsloth-bnb-4bit"):
- if not vllm_dynamic_quant_supported(model_name, model_config):
- # Instead use -bnb-4bit variant
- logger.warning_once(
- f"Unsloth: Switching from Unsloth dynamic quant to normal quant since\n"
- f"we do not yet support fast inference for {model_name}"
- )
- model_name = model_name[: -len("unsloth-bnb-4bit")] + "bnb-4bit"
- return fast_inference, model_name
-
-
-def patch_peft_fast_inference(model):
- vllm_engine = getattr(model.model, "vllm_engine", None)
- if vllm_engine is not None:
- model.vllm_engine = model.model.vllm_engine
- model.fast_generate = model.model.fast_generate
- model.fast_generate_batches = model.model.fast_generate_batches
-
- # Also saving and loading LoRA
- from unsloth_zoo.vllm_utils import save_lora, load_lora
-
- model.save_lora = functools.partial(save_lora, model)
- model.load_lora = functools.partial(load_lora, model)
-
+ exec(f"def raise_{j}(*args, **kwargs): print('{function}')", globals(), locals())
+ try: exec(f"EMPTY_LOGITS.{function} = raise_{j}", globals(), locals())
+ except: continue
+pass
-def error_out_no_vllm(*args, **kwargs):
- raise NotImplementedError(
- "Unsloth: vLLM is not yet supported for fast inference for this model! Please use `.generate` instead"
- )
-
-
-try:
- from torchao.core.config import AOBaseConfig
-
- try:
- from torchao.quantization import Int4WeightOnlyConfig
- except:
- print("Unsloth: TorchAO changed `torchao.quantization.Int4WeightOnlyConfig`")
- Int4WeightOnlyConfig = None
-except:
- AOBaseConfig = None
- Int4WeightOnlyConfig = None
-
-
-@dataclass
-class TorchAOConfig:
- qat_scheme: Optional[str] = "int4"
-
- # Each (config, filter_fn) pair defines a quantization rule
- base_config_and_filter_fns: List[
- Tuple["AOBaseConfig", Optional[Callable[[torch.nn.Module, str], bool]]]
- ] = field(
- default_factory = lambda: [
- (
- Int4WeightOnlyConfig(group_size = 128),
- lambda m, _: isinstance(m, torch.nn.Linear)
- and getattr(m, "in_features", 0) >= 128,
- ),
- ]
- )
-
- # Optional transformation to apply before quantization setup
- prequantization_transform: Optional[Callable[[torch.nn.Module], None]] = None
-
-
-def _untie_input_output_embeddings(model: torch.nn.Module) -> None:
- """
- Utility to untie input/output embeddings in a HuggingFace model.
- This is useful if we want to quantize the input/ouput embeddings differently.
- Model is modified in-place.
- """
-
- # 1) Persist setting in config
- if hasattr(model.config, "tie_word_embeddings"):
- model.config.tie_word_embeddings = False
-
- # 2) Find input and output embeddings
- in_emb = model.get_input_embeddings()
- out_proj = model.get_output_embeddings() or getattr(model, "lm_head", None)
- if out_proj is None:
- raise AttributeError("Couldn't locate output projection (lm_head).")
-
- # (Optional) sanity: shapes should match [vocab, hidden]
- assert (
- out_proj.weight.shape == in_emb.weight.shape
- ), f"Shape mismatch: out_proj {out_proj.weight.shape} vs in_emb {in_emb.weight.shape}"
-
- # 3) Only clone if they are actually tied (shared storage)
- if out_proj.weight.data_ptr() == in_emb.weight.data_ptr():
- with torch.no_grad():
- W = in_emb.weight.detach().clone()
- out_proj.weight = torch.nn.Parameter(W) # new storage, keeps dtype/device
-
- # 4) Prevent future automatic re-tying
- def _no_tie(self):
- return
-
- model.tie_weights = _no_tie.__get__(model, model.__class__)
-
- # 5) Verify no shared storage
- assert (
- out_proj.weight.data_ptr() != in_emb.weight.data_ptr()
- ), "Embeddings still tied!"
-
-
-def _filter_fn_to_fqns(
- model: torch.nn.Module,
- filter_fn: Callable[[torch.nn.Module, str], bool],
-) -> Iterator[str]:
- """
- Given a model and a filter function (m, fqn) -> bool,
- yield fully qualified names (FQNs) of modules that match.
- """
- for fqn, module in model.named_modules():
- if filter_fn(module, fqn):
- yield fqn
-
-
-def _convert_torchao_model(model):
- from transformers import TorchAoConfig
- from torchao.quantization import quantize_, ModuleFqnToConfig
- from torchao.quantization.qat import QATConfig
- from torchao.utils import TorchAOBaseTensor
-
- module_to_fqn_dict = {}
- for base_config, filter_fn in model._torchao_config.base_config_and_filter_fns:
- quantize_(model, QATConfig(base_config, step = "convert"), filter_fn = filter_fn)
-
- # Default filter function used for quantize_
- if filter_fn is None:
- if "_default" in module_to_fqn_dict:
- raise ValueError("Cannot use multiple default quantization configs")
- module_to_fqn_dict["_default"] = base_config
- else:
- for fqn in _filter_fn_to_fqns(model, filter_fn):
- if fqn in module_to_fqn_dict:
- raise ValueError(f"Found multiple quantization configs for {fqn}")
- module_to_fqn_dict[fqn] = base_config
-
- in_emb = model.get_input_embeddings()
- out_proj = model.get_output_embeddings() or getattr(model, "lm_head", None)
- kwargs = {}
- if isinstance(in_emb.weight, TorchAOBaseTensor) or (
- out_proj is not None and isinstance(out_proj.weight, TorchAOBaseTensor)
- ):
- kwargs["include_input_output_embeddings"] = True
- kwargs["modules_to_not_convert"] = []
-
- quant_config = ModuleFqnToConfig(module_to_fqn_dict)
- quantization_config = TorchAoConfig(quant_type = quant_config, **kwargs)
- model.config.quantization_config = quantization_config
-
-
-def _prepare_model_for_qat(
- model: torch.nn.Module, qat_scheme: Union[str, TorchAOConfig]
-) -> torch.nn.Module:
- """
- Transform a model for Quantization-Aware Training (QAT) during fine-tuning.
-
- On a high level, this means fake quantizing the base (frozen) model during training.
- Fake quantization refers to simulating quantization numerics in high precision (e.g. bf16).
- This helps mitigate quantization degradations when the model is quantized after training.
-
- QAT can be optionally combined with LoRA fine-tuning to for additional throughput improvement.
- For more details: https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700
- """
- try:
- from torchao.quantization import PerRow, quantize_
- from torchao.quantization.granularity import PerGroup, PerAxis
- from torchao.quantization.qat import QATConfig
- except ImportError:
- raise ImportError(TORCHAO_MSG)
-
- # Gemma3 models have issues with int8 embedding quantization due to their
- # large vocabulary size (262144). Auto-switch to int4 weight-only instead.
- if qat_scheme == "int8-int4":
- model_types = get_transformers_model_type(model.config)
- is_gemma3 = any("gemma3" in mt or "gemma_3" in mt for mt in model_types)
- if is_gemma3:
- print(
- "Unsloth: Gemma3 has a large vocabulary causing int8 embedding issues. "
- "Switching to int4 weight-only QAT for training stability."
- )
- qat_scheme = "int4"
-
- if not isinstance(qat_scheme, TorchAOConfig):
- torchao_config: Optional[TorchAOConfig] = None
- if qat_scheme == "fp8-int4":
- try:
- from torchao.quantization import Float8DynamicActivationInt4WeightConfig
- except ImportError:
- raise ImportError(TORCHAO_MSG)
- group_size = 128
- base_config = Float8DynamicActivationInt4WeightConfig()
- filter_fn = (
- lambda m, _: isinstance(m, torch.nn.Linear)
- and m.in_features >= group_size
- )
- torchao_config = TorchAOConfig(
- qat_scheme = qat_scheme,
- base_config_and_filter_fns = [(base_config, filter_fn)],
- )
- elif qat_scheme == "fp8-fp8":
- try:
- from torchao.quantization import (
- Float8DynamicActivationFloat8WeightConfig,
- )
- except ImportError:
- raise ImportError(TORCHAO_MSG)
- base_config = Float8DynamicActivationFloat8WeightConfig(
- granularity = PerRow()
- )
- torchao_config = TorchAOConfig(
- qat_scheme = qat_scheme, base_config_and_filter_fns = [(base_config, None)]
- )
- elif qat_scheme == "int8-int4":
- try:
- from torchao.quantization import (
- Int8DynamicActivationIntxWeightConfig,
- IntxWeightOnlyConfig,
- )
- except ImportError:
- raise ImportError(TORCHAO_MSG)
- torchao_config = TorchAOConfig(
- qat_scheme = qat_scheme,
- base_config_and_filter_fns = [
- (
- IntxWeightOnlyConfig(
- weight_dtype = torch.int8, granularity = PerAxis(0)
- ),
- lambda m, fqn: isinstance(m, torch.nn.Embedding),
- ),
- (
- Int8DynamicActivationIntxWeightConfig(
- weight_dtype = torch.int4, weight_granularity = PerGroup(32)
- ),
- None,
- ),
- ],
- prequantization_transform = _untie_input_output_embeddings,
- )
- elif qat_scheme == "int4":
- try:
- from torchao.quantization import Int4WeightOnlyConfig
- except ImportError:
- raise ImportError(TORCHAO_MSG)
- group_size = 128
- base_config = Int4WeightOnlyConfig(group_size = group_size)
- filter_fn = (
- lambda m, _: isinstance(m, torch.nn.Linear)
- and m.in_features >= group_size
- )
- torchao_config = TorchAOConfig(
- qat_scheme = qat_scheme,
- base_config_and_filter_fns = [(base_config, filter_fn)],
- )
- elif qat_scheme == "int8":
- try:
- from torchao.quantization import IntxWeightOnlyConfig
- from torchao.quantization.granularity import PerAxis
- except ImportError:
- raise ImportError(TORCHAO_MSG)
-
- base_config = IntxWeightOnlyConfig(
- weight_dtype = torch.int8,
- granularity = PerAxis(0),
- )
- filter_fn = lambda m, _: isinstance(m, torch.nn.Linear)
- torchao_config = TorchAOConfig(
- qat_scheme = qat_scheme,
- base_config_and_filter_fns = [(base_config, filter_fn)],
- )
- else:
- raise ValueError(f"Unexpected QAT scheme {qat_scheme}")
- assert torchao_config is not None, f"TorchAOConfig was not set for {qat_scheme}"
- else:
- torchao_config = qat_scheme
-
- # Save Torchao metadata everywhere
- inner_model = model
- while hasattr(inner_model, "model"):
- inner_model._torchao_config = torchao_config
- inner_model = inner_model.model
- inner_model._torchao_config = torchao_config
-
- if torchao_config.prequantization_transform is not None:
- torchao_config.prequantization_transform(model)
- for base_config, filter_fn in torchao_config.base_config_and_filter_fns:
- quantize_(model, QATConfig(base_config, step = "prepare"), filter_fn = filter_fn)
-
- return model
-
-
-def patch_hf_quantizer():
- # To tell hf trainer that the quantized model is trainable
- def make_trainable(self):
- return True
-
- try:
- from transformers.quantizers.quantizer_finegrained_fp8 import (
- FineGrainedFP8HfQuantizer,
- )
-
- FineGrainedFP8HfQuantizer.is_trainable = property(make_trainable)
- FineGrainedFP8HfQuantizer.is_qat_trainable = property(make_trainable)
- except Exception as e:
- logger.warning(f"Failed to patch FineGrainedFP8HfQuantizer. Error {e}")
-
- try:
- from transformers.quantizers.quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
-
- FbgemmFp8HfQuantizer.is_trainable = property(make_trainable)
- FbgemmFp8HfQuantizer.is_qat_trainable = property(make_trainable)
- except Exception as e:
- logger.warning(f"Failed to patch FbgemmFp8HfQuantizer. Error {e}")
-
- try:
- from transformers.quantizers.quantizer_torchao import TorchAoHfQuantizer
-
- TorchAoHfQuantizer.is_trainable = property(make_trainable)
- TorchAoHfQuantizer.is_qat_trainable = property(make_trainable)
- except Exception as e:
- logger.warning(f"Failed to patch TorchAoHfQuantizer. Error {e}")
-
-
-patch_hf_quantizer()
-
-
-def verify_fp8_support_if_applicable(model_config):
- quant_method = get_quant_type(model_config)
- if quant_method in ["fbgemm_fp8", "fp8"] and DEVICE_TYPE != "cuda":
- raise ValueError(
- f"Unsloth: FP8 quantization is only supported on CUDA GPUs. You are using {DEVICE_TYPE}."
- )
-
- # [TODO] Need to add FP8 support for Intel XPUs
- if DEVICE_TYPE == "cuda":
- major_version, minor_version = torch.cuda.get_device_capability()
- if quant_method == "fbgemm_fp8" and major_version < 9:
- # While L4 does support FP8 as data type, it doesn't have fbgemm (package) support yet. So we restrict it.
- raise ValueError(
- f"Unsloth: FBGEMM FP8 quantization is only supported on H100 and higher GPUs. L4 is not supported. You are using {torch.cuda.get_device_name()}. Refer to https://developer.nvidia.com/cuda-gpus for more details."
- )
- if quant_method == "fp8" and major_version * 10 + minor_version < 89:
- # In case of block quantized, we allow L4 because we fall back to torchao kernels.
- raise ValueError(
- f"Unsloth: FP8 quantization is only supported on L4 and higher GPUs with compute capability 8.9 or higher. You are using {torch.cuda.get_device_name()}. Refer to https://developer.nvidia.com/cuda-gpus for more details."
- )
-
-
-def _get_inference_mode_context_manager(model: torch.nn.Module):
- """
- If the state dict was quantized using torchao, we will run into
- the following error when calling ops like aten.t() in inference mode.
- This is a bug in PyTorch that affects all tensor subclasses.
-
- Cannot set version_counter for inference tensor
-
- For now, we work around this issue by using `torch.no_grad()` in this case.
- See https://github.com/pytorch/pytorch/issues/164872 for more details.
- Otherwise, just return `torch.inference_mode()`.
- """
- torchao_config = getattr(model, "torchao_config", None)
- if torchao_config is not None and torchao_config.qat_scheme is None:
- return torch.no_grad()
- else:
- return torch.inference_mode()
-
-
-def hf_login(token: Optional[str] = None) -> Optional[str]:
- if token is None:
- try:
- from huggingface_hub import get_token
-
- token = get_token()
- if token is None:
- return None
- except:
- return None
- try:
- from huggingface_hub import login
-
- login(token = token)
- return token
- except Exception as e:
- logger.info(f"Failed to login to huggingface using token with error: {e}")
- return token
-
-
-# =============================================
-# MoE (Mixture of Experts) Detection and LoRA Utilities
-
-
-def is_moe_model(model) -> bool:
- """
- Detect if a model is a Mixture of Experts (MoE) model.
-
- Args:
- model: The model to check (can be HF model or config)
-
- Returns:
- True if the model is an MoE model, False otherwise
- """
- config = getattr(model, "config", model)
-
- # Different MoE models use different config attribute names:
- # - Qwen3-MoE: num_experts
- # - GLM4-MoE: n_routed_experts, num_local_experts
- # - Mixtral: num_local_experts
- num_experts = None
- for attr in ("num_experts", "n_routed_experts", "num_local_experts"):
- num_experts = getattr(config, attr, None)
- if num_experts is not None:
- break
-
- # Check text_config for VL models
- if num_experts is None and hasattr(config, "text_config"):
- for attr in ("num_experts", "n_routed_experts", "num_local_experts"):
- num_experts = getattr(config.text_config, attr, None)
- if num_experts is not None:
- break
-
- return num_experts is not None and num_experts > 0
-
-
-def get_moe_target_parameters(model, target_modules = None) -> Optional[List[str]]:
- """
- Get the target_parameters for MoE expert layers if applicable.
-
- For MoE models, returns the parameter paths for expert weights
- (gate_up_proj, down_proj) that should be targeted by PEFT's
- target_parameters for LoRA on nn.Parameter.
-
- Only includes MoE parameters that match what's in target_modules:
- - If "down_proj" is in target_modules -> includes "mlp.experts.down_proj"
- - If "gate_proj" or "up_proj" is in target_modules -> includes "mlp.experts.gate_up_proj"
-
- Args:
- model: The model to get target parameters for
- target_modules: List/tuple of target module names to match against
-
- Returns:
- List of parameter paths for MoE experts, or None if not an MoE model
- """
- if not is_moe_model(model):
- return None
-
- config = getattr(model, "config", model)
- # Get num_experts from various possible config attributes
- num_experts = None
- for attr in ("num_experts", "n_routed_experts", "num_local_experts"):
- num_experts = getattr(config, attr, None)
- if num_experts is not None:
- break
- if num_experts is None and hasattr(config, "text_config"):
- for attr in ("num_experts", "n_routed_experts", "num_local_experts"):
- num_experts = getattr(config.text_config, attr, None)
- if num_experts is not None:
- break
- if num_experts is None:
- num_experts = 0
-
- # Determine which MoE parameters to include based on target_modules
- moe_params = []
-
- # Normalize target_modules to a set for efficient lookup
- if target_modules is None:
- # If no target_modules specified, include all MoE params
- target_set = {"gate_proj", "up_proj", "down_proj", "gate_up_proj"}
- elif isinstance(target_modules, str):
- target_set = {target_modules}
- # Heuristic for regex matching MLPs
- if "proj" in target_modules and (
- "mlp" in target_modules or "ffn" in target_modules
- ):
- target_set.update({"gate_proj", "up_proj", "down_proj", "gate_up_proj"})
- else:
- target_set = set(target_modules) if target_modules else set()
-
- # gate_up_proj combines both gate_proj and up_proj in MoE
- # Also match "gate_up_proj" directly since users may specify the fused name
- if (
- "gate_proj" in target_set
- or "up_proj" in target_set
- or "gate_up_proj" in target_set
- ):
- moe_params.append("mlp.experts.gate_up_proj")
-
- if "down_proj" in target_set:
- moe_params.append("mlp.experts.down_proj")
-
- if moe_params:
- print(
- f"Unsloth: Detected MoE model with {num_experts = } and {target_modules = }. Enabling LoRA on MoE parameters: {moe_params}"
- )
- return moe_params
-
- return None
-
-
-def make_fast_generate_wrapper(original_generate):
- """
- Creates a wrapper around model.generate that checks for incorrect
- vLLM-style usage when fast_inference=False.
- """
-
- @functools.wraps(original_generate)
- def _fast_generate_wrapper(*args, **kwargs):
- # Check for vLLM-specific arguments
- if "sampling_params" in kwargs:
- raise ValueError(
- "Unsloth: `sampling_params` is only supported when `fast_inference=True` (vLLM). "
- "Since `fast_inference=False`, use HuggingFace generate arguments instead:\n"
- " model.fast_generate(**tokens.to('cuda'), max_new_tokens=64, temperature=1.0, top_p=0.95)"
- )
-
- if "lora_request" in kwargs:
- raise ValueError(
- "Unsloth: `lora_request` is only supported when `fast_inference=True` (vLLM). "
- "Since `fast_inference=False`, LoRA weights are already merged into the model."
- )
-
- # Check if first positional argument is a string or list of strings
- if len(args) > 0:
- first_arg = args[0]
- is_string_input = False
-
- if isinstance(first_arg, str):
- is_string_input = True
- elif isinstance(first_arg, (list, tuple)) and len(first_arg) > 0:
- if isinstance(first_arg[0], str):
- is_string_input = True
-
- if is_string_input:
- raise ValueError(
- "Unsloth: Passing text strings to `fast_generate` is only supported "
- "when `fast_inference=True` (vLLM). Since `fast_inference=False`, you must "
- "tokenize the input first:\n\n"
- " messages = tokenizer.apply_chat_template(\n"
- ' [{"role": "user", "content": "Your prompt here"}],\n'
- " tokenize=True, add_generation_prompt=True,\n"
- ' return_tensors="pt", return_dict=True\n'
- " )\n"
- " output = model.fast_generate(\n"
- " **messages.to('cuda'),\n"
- " max_new_tokens=64,\n"
- " temperature=1.0,\n"
- " )"
- )
-
- # Call original generate
- return original_generate(*args, **kwargs)
-
- return _fast_generate_wrapper
+USE_MODELSCOPE = os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1"
+if USE_MODELSCOPE:
+ if importlib.util.find_spec("modelscope") is None:
+ raise ImportError(f'You are using the modelscope hub, please install modelscope by `pip install modelscope -U`')
+ pass
+pass
diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py
index 4251f3acd9..0c36abf681 100644
--- a/unsloth/models/cohere.py
+++ b/unsloth/models/cohere.py
@@ -14,16 +14,6 @@
from .llama import *
from ._utils import __version__
-from unsloth_zoo.hf_utils import dtype_from_config
-from unsloth_zoo.utils import _get_dtype, Version
-from ..utils.packing import get_packed_info_from_kwargs
-from ..utils.attention_dispatch import (
- AttentionConfig,
- AttentionContext,
- run_attention,
- select_attention_backend,
-)
-
try:
from transformers.models.cohere.modeling_cohere import (
CohereAttention,
@@ -35,19 +25,21 @@
repeat_kv,
)
except:
+ from packaging.version import Version
transformers_version = Version(transformers_version)
if not transformers_version >= Version("4.42"):
raise ImportError(
- f"Unsloth: Your transformers version of {transformers_version} does not support Cohere.\n"
- f"The minimum required version is 4.42.3.\n"
- f'Try `pip install --upgrade "transformers>=4.42.3"`\n'
- f"to obtain the latest transformers build, then restart this session."
+ f"Unsloth: Your transformers version of {transformers_version} does not support Cohere.\n"\
+ f"The minimum required version is 4.42.3.\n"\
+ f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\
+ f"to obtain the latest transformers build, then restart this session."\
)
+ pass
+pass
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask_for_sdpa,
)
-
# For Pytorch 2.1.1
try:
from transformers.models.cohere.modeling_cohere import (
@@ -55,8 +47,9 @@
CohereFlashAttention2,
)
except:
- CohereSdpaAttention = CohereAttention
+ CohereSdpaAttention = CohereAttention
CohereFlashAttention2 = CohereAttention
+pass
def fast_layernorm_inference(self, X, out_weight = None):
@@ -68,23 +61,24 @@ def fast_layernorm_inference(self, X, out_weight = None):
out_weight[:] = self.weight
XX *= out_weight
return XX.to(X.dtype)
+pass
# QK norm in Cohere
def CohereAttention_fast_forward(
self,
- hidden_states: torch.Tensor,
- causal_mask: Optional[BlockDiagonalCausalMask] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- padding_mask: Optional[torch.LongTensor] = None,
+ hidden_states: torch.Tensor,
+ causal_mask: Optional[BlockDiagonalCausalMask] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ padding_mask: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- *args,
- **kwargs,
+ *args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
# Clear inference
if hasattr(self, "paged_attention"):
del self.paged_attention_K
@@ -96,118 +90,124 @@ def CohereAttention_fast_forward(
del self.attention
del self.q_norm_out_weight
del self.k_norm_out_weight
+ pass
bsz, q_len, _ = hidden_states.size()
- n_heads = self.config.num_attention_heads
- n_groups = self.num_key_value_groups
+ n_heads = self.config.num_attention_heads
+ n_groups = self.num_key_value_groups
n_kv_heads = self.config.num_key_value_heads
- head_dim = self.head_dim
- assert n_kv_heads * n_groups == n_heads
+ head_dim = self.head_dim
+ assert(n_kv_heads * n_groups == n_heads)
Q, K, V = self.apply_qkv(self, hidden_states)
- Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
+ Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
- seq_info = get_packed_info_from_kwargs(kwargs, Q.device)
if self.use_qk_norm:
Q = fast_layernorm_compiled(self.q_norm, Q)
K = fast_layernorm_compiled(self.k_norm, K)
+ pass
kv_seq_len = K.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
- # Extend RoPE dynamically to fit in VRAM
- if position_embeddings:
- cos, sin = position_embeddings
+ cos, sin = position_embeddings
+ if position_ids is None:
+ Q, K = fast_rope_embedding(Q, K, cos, sin)
else:
- cos, sin = self.rotary_emb.get_cached(kv_seq_len, Q.device.index)
-
- rope_position_ids = (
- position_ids if position_ids is not None else kwargs.get("position_ids")
- )
- # Useful for LongRoPE
- Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids)
+ cos, sin = cos[position_ids], sin[position_ids]
+ Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
+ pass
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
V = torch.cat([past_key_value[1], V], dim = 2)
+ pass
past_key_value = (K, V) if use_cache else None
# Attention module
- use_varlen = seq_info is not None and past_key_value is None
- backend = select_attention_backend(use_varlen)
- attention_config = AttentionConfig(
- backend = backend,
- n_kv_heads = n_kv_heads,
- n_groups = n_groups,
- flash_dense_kwargs = {"causal": True},
- flash_varlen_kwargs = {
- "dropout_p": 0.0,
- "causal": True,
- "softmax_scale": getattr(self, "softmax_scale", None),
- },
- )
- context = AttentionContext(
- bsz = bsz,
- q_len = q_len,
- kv_seq_len = kv_seq_len,
- n_heads = n_heads,
- head_dim = head_dim,
- requires_grad = hidden_states.requires_grad,
- seq_info = seq_info,
- attention_mask = attention_mask,
- causal_mask = causal_mask,
- )
-
- A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V)
-
- attn_output = A.reshape(bsz, q_len, n_heads * head_dim)
+ if (not HAS_FLASH_ATTENTION and attention_mask is None):
+ # Xformers memory efficient attention
+ # Also has Flash Attention v2 dispatching
+ Q = Q.transpose(1, 2)
+ K = K.transpose(1, 2)
+ V = V.transpose(1, 2)
+
+ # Group query attention
+ if n_groups != 1:
+ K = K .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
+ V = V .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
+ K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
+ V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
+ if hidden_states.requires_grad:
+ K = K.reshape(bsz, kv_seq_len, n_heads, head_dim)
+ V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
+ else:
+ Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
+ pass
+ A = xformers_attention(Q, K, V, attn_bias = causal_mask)
+ A = A.view(bsz, q_len, n_heads, head_dim)
+
+ elif HAS_FLASH_ATTENTION and attention_mask is None:
+ Q = Q.transpose(1, 2)
+ K = K.transpose(1, 2)
+ V = V.transpose(1, 2)
+ A = flash_attn_func(Q, K, V, causal = True)
+ else:
+ # Grouped query attention
+ if n_groups != 1:
+ K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
+ V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
+ K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
+ V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
+ pass
+ # Must be contiguous or else results are False!
+ # https://github.com/pytorch/pytorch/issues/112577
+ Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
+ # Needs (batch_size, n_heads, seq_len, head_dim)
+ # is_casual and attention_mask must not be both set!
+ A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
+ # Go back to (batch_size, seq_len, n_heads, head_dim)
+ A = A.transpose(1, 2).contiguous()
+ pass
+ attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
attn_output = self.apply_o(self, attn_output)
attn_weights = None
return attn_output, attn_weights, past_key_value
+pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
def CohereDecoderLayer_fast_forward(
self,
- hidden_states: torch.Tensor,
- causal_mask: Optional[BlockDiagonalCausalMask] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- padding_mask: Optional[torch.LongTensor] = None,
+ hidden_states: torch.Tensor,
+ causal_mask: Optional[BlockDiagonalCausalMask] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ padding_mask: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- *args,
- **kwargs,
+ *args, **kwargs,
):
- if use_cache and hasattr(
- self, "_flag_for_generation"
- ): # past_key_value is not None:
- out_weight = torch.empty(
- self.input_layernorm.weight.shape,
- dtype = torch.float32,
- device = f"{DEVICE_TYPE_TORCH}:0",
- )
+ if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
+ out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0")
# Self Attention
residual = hidden_states
- hidden_states = fast_layernorm_inference(
- self.input_layernorm, hidden_states, out_weight
- )
+ hidden_states = fast_layernorm_inference(self.input_layernorm, hidden_states, out_weight)
hidden_states_attention, self_attn_weights, present_key_value = self.self_attn(
- hidden_states = hidden_states,
- causal_mask = causal_mask,
- attention_mask = attention_mask,
- position_ids = position_ids,
- past_key_value = past_key_value,
- output_attentions = output_attentions,
- use_cache = use_cache,
- padding_mask = padding_mask,
- **kwargs,
+ hidden_states=hidden_states,
+ causal_mask=causal_mask,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ padding_mask=padding_mask,
)
# Fully Connected
@@ -219,164 +219,126 @@ def CohereDecoderLayer_fast_forward(
residual = hidden_states
hidden_states = fast_layernorm_compiled(self.input_layernorm, hidden_states)
hidden_states_attention, self_attn_weights, present_key_value = self.self_attn(
- hidden_states = hidden_states,
- causal_mask = causal_mask,
- attention_mask = attention_mask,
- position_ids = position_ids,
- past_key_value = past_key_value,
- output_attentions = output_attentions,
- use_cache = use_cache,
- padding_mask = padding_mask,
- **kwargs,
+ hidden_states=hidden_states,
+ causal_mask=causal_mask,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ padding_mask=padding_mask,
)
# Fully Connected
hidden_states_mlp = self.mlp(hidden_states)
hidden_states = residual + hidden_states_attention + hidden_states_mlp
+ pass
outputs = (hidden_states,)
- if output_attentions:
- outputs += (self_attn_weights,)
- if use_cache:
- outputs += (present_key_value,)
+ if output_attentions: outputs += (self_attn_weights,)
+ if use_cache: outputs += (present_key_value,)
return outputs
+pass
from math import sqrt as math_sqrt
-
-KV_CACHE_INCREMENT = 256 # KV Cache update size
+KV_CACHE_INCREMENT = 256 # KV Cache update size
torch_nn_functional_softmax = torch.nn.functional.softmax
torch_matmul = torch.matmul
-
def CohereAttention_fast_forward_inference(
self,
- hidden_states: torch.Tensor,
+ hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]],
position_ids,
do_prefill = False,
attention_mask = None,
- **kwargs,
):
Xn = hidden_states
bsz, _, hd = hidden_states.size()
K1, V1 = past_key_value
dtype = Xn.dtype
- n_heads = self.config.num_attention_heads
- n_groups = self.num_key_value_groups
+ n_heads = self.config.num_attention_heads
+ n_groups = self.num_key_value_groups
n_kv_heads = self.config.num_key_value_heads
- head_dim = self.head_dim
+ head_dim = self.head_dim
# assert(n_kv_heads * n_groups == n_heads)
hidden_size = self.config.hidden_size
- attention_size = n_heads * head_dim
+ attention_size = n_heads*head_dim
seq_len = K1.shape[-2]
kv_seq_len = seq_len + 1
# Prefill phase
# if not hasattr(self, "paged_attention"):
if do_prefill:
- self.paged_attention = torch.empty(
- (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
- dtype = dtype,
- device = f"{DEVICE_TYPE_TORCH}:0",
- )
- self.paged_attention_K = self.paged_attention[:, 0]
- self.paged_attention_V = self.paged_attention[:, 1]
+ self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0")
+ self.paged_attention_K = self.paged_attention[:,0]
+ self.paged_attention_V = self.paged_attention[:,1]
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
- self.temp_QA = torch.empty(
- (2, bsz, 1, attention_size), dtype = dtype, device = f"{DEVICE_TYPE_TORCH}:0"
- )
- self.temp_KV = torch.empty(
- (2, bsz, 1, n_kv_heads * head_dim),
- dtype = dtype,
- device = f"{DEVICE_TYPE_TORCH}:0",
- )
- self.RH_Q = torch.empty(
- (bsz, n_heads, 1, head_dim), dtype = dtype, device = f"{DEVICE_TYPE_TORCH}:0"
- )
-
+ self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0")
+ self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
+ self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
+
# Mistral Nemo 12b has weird dimensions
if attention_size != hidden_size:
- self.temp_O = torch.empty(
- (bsz, 1, hidden_size), dtype = dtype, device = f"{DEVICE_TYPE_TORCH}:0"
- )
+ self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0")
else:
- self.temp_O = self.temp_QA[1][:, :, :hidden_size]
-
- self.attention = torch.empty(
- (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len),
- dtype = dtype,
- device = f"{DEVICE_TYPE_TORCH}:0",
- )
+ self.temp_O = self.temp_QA[1][:,:,:hidden_size]
+ pass
+
+ self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
self.scalar = 1.0 / math_sqrt(self.head_dim)
self.half_head_dim = head_dim // 2
# Cohere has QK layernorms
if self.use_qk_norm:
- self.q_norm_out_weight = torch.empty(
- self.q_norm.weight.shape,
- dtype = torch.float32,
- device = f"{DEVICE_TYPE_TORCH}:0",
- )
- self.k_norm_out_weight = torch.empty(
- self.k_norm.weight.shape,
- dtype = torch.float32,
- device = f"{DEVICE_TYPE_TORCH}:0",
- )
+ self.q_norm_out_weight = torch.empty(self.q_norm.weight.shape, dtype = torch.float32, device = "cuda:0")
+ self.k_norm_out_weight = torch.empty(self.k_norm.weight.shape, dtype = torch.float32, device = "cuda:0")
else:
self.q_norm_out_weight = None
self.k_norm_out_weight = None
+ pass
elif kv_seq_len >= self.paged_attention.shape[0]:
- self.paged_attention.resize_(
- (
- self.paged_attention.shape[0] + KV_CACHE_INCREMENT,
- 2,
- bsz,
- n_kv_heads,
- head_dim,
- )
- )
- self.paged_attention_K = self.paged_attention[:, 0]
- self.paged_attention_V = self.paged_attention[:, 1]
- self.attention.resize_(
- (bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
- )
+ self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
+ self.paged_attention_K = self.paged_attention[:,0]
+ self.paged_attention_V = self.paged_attention[:,1]
+ self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
+ pass
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
- Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
+ Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
if self.use_qk_norm:
- Qn = fast_layernorm_inference(self.q_norm, Qn, self.q_norm_out_weight)
- Kn = fast_layernorm_inference(self.k_norm, Kn, self.k_norm_out_weight)
+ Q = fast_layernorm_inference(self.q_norm, Q, self.q_norm_out_weight)
+ K = fast_layernorm_inference(self.k_norm, K, self.k_norm_out_weight)
+ pass
# cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
# Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
- cos, sin = self.rotary_emb.get_cached(kv_seq_len, Qn.device.index)
+ cos, sin = self.rotary_emb.get_cached(kv_seq_len)
cos = cos[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1)
h = self.half_head_dim
RH_Q = self.RH_Q
- RH_Q[:, :, :, :h] = Qn[:, :, :, h:]
- RH_Q[:, :, :, h:] = Qn[:, :, :, :h]
- RH_Q[:, :, :, :h].neg_()
+ RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
+ RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
+ torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
Qn *= cos
Qn.addcmul_(RH_Q, sin)
- RH_K = RH_Q[
- :, :n_kv_heads, :, :
- ] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
- RH_K[:, :, :, :h] = Kn[:, :, :, h:]
- RH_K[:, :, :, h:] = Kn[:, :, :, :h]
- RH_K[:, :, :, :h].neg_()
+ RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
+ RH_K[:,:,:,:h] = Kn[:,:,:,h:]
+ RH_K[:,:,:,h:] = Kn[:,:,:,:h]
+ torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
Kn *= cos
Kn.addcmul_(RH_K, sin)
-
+
# New KV cache
# Kn = torch.cat([K1, Kn], dim = 2)
# Vn = torch.cat([V1, Vn], dim = 2)
@@ -388,45 +350,42 @@ def CohereAttention_fast_forward_inference(
# Handle sliding windows
sliding_window = getattr(self.config, "sliding_window", None)
if sliding_window is not None and kv_seq_len > sliding_window:
- start = kv_seq_len - sliding_window
- Knn = Kn[:, :, start:, :] # .contiguous()
- Vnn = Vn[:, :, start:, :] # .contiguous()
- if attention_mask is not None:
- attention_mask = attention_mask[..., start:]
+ # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
+ slicing_tokens = 1 - sliding_window
+ Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
+ Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
else:
Knn, Vnn = Kn, Vn
+ pass
# Grouped query attention
_, _, cached_len, _ = Knn.shape
if n_groups != 1:
- Knn = Knn[:, :, None, :, :].expand(
- bsz, n_kv_heads, n_groups, cached_len, head_dim
- )
- Vnn = Vnn[:, :, None, :, :].expand(
- bsz, n_kv_heads, n_groups, cached_len, head_dim
- )
+ Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
+ Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
+ pass
+ # else:
+ # Knn, Vnn = Knn, Vnn
+ # pass
# Attention
if bsz == 1:
- Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
+ Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
- A = torch_matmul(
- Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len]
- )
- A[:] = torch_nn_functional_softmax(
- A, dim = -1, dtype = torch.float32
- ) # .to(A.dtype)
+ A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
+ # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
+ A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
A = torch_matmul(A, Vnn, out = Qn)
else:
- A = scaled_dot_product_attention(
- Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False
- )
+ A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
+ pass
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, attention_size)
A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
return A, (Kn, Vn)
+pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
@@ -438,17 +397,10 @@ def CohereModel_fast_forward_inference(
position_ids,
attention_mask = None,
):
- out_weights = tuple(
- torch.empty_like(
- self.model.layers[0].input_layernorm.weight,
- dtype = torch.float32,
- device = torch.device(x),
- )
- for x in range(DEVICE_COUNT)
- )
- input_ids = input_ids[:, : self.max_seq_length]
+ out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0")
+ input_ids = input_ids[:,:self.max_seq_length]
hidden_states = self.model.embed_tokens(input_ids)
- hidden_states = hidden_states.to(_get_dtype(dtype_from_config(self.config)))
+ hidden_states = hidden_states.to(self.config.torch_dtype)
bsz, q_len, hd = hidden_states.shape
seq_len = past_key_values[0][0].shape[-2]
if bsz != 1:
@@ -459,42 +411,31 @@ def CohereModel_fast_forward_inference(
seq_len,
sliding_window = getattr(self.config, "sliding_window", None),
)
- # Pre-convert to bool once for all layers (avoids per-layer .eq(0))
- if attention_mask is not None and attention_mask.dtype != torch.bool:
- attention_mask = attention_mask.eq(0)
else:
attention_mask = None
+ pass
next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):
- device_index = getattr(decoder_layer, "_per_layer_device_index", 0)
- hidden_states, position_ids = move_to_device(
- device_index, hidden_states, position_ids
- )
residual = hidden_states
- hidden_states = fast_layernorm_inference(
- decoder_layer.input_layernorm, hidden_states, out_weights[device_index]
- )
- hidden_states_attention, present_key_value = (
- CohereAttention_fast_forward_inference(
- decoder_layer.self_attn,
- hidden_states = hidden_states,
- past_key_value = past_key_values[idx],
- position_ids = position_ids,
- attention_mask = attention_mask,
- do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
- )
+ hidden_states = fast_layernorm_inference(decoder_layer.input_layernorm, hidden_states, out_weight)
+ hidden_states_attention, present_key_value = CohereAttention_fast_forward_inference(
+ decoder_layer.self_attn,
+ hidden_states = hidden_states,
+ past_key_value = past_key_values[idx],
+ position_ids = position_ids,
+ attention_mask = attention_mask,
+ do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
)
- hidden_states_mlp = fast_swiglu_inference(decoder_layer.mlp, hidden_states)
+ hidden_states_mlp = fast_swiglu_inference(self.mlp, hidden_states)
residual += hidden_states_attention
residual += hidden_states_mlp
hidden_states = residual
next_decoder_cache.append(present_key_value)
- hidden_states = fast_layernorm_inference(
- self.model.norm, hidden_states, out_weights[device_index]
- )
+ pass
+ hidden_states = fast_layernorm_inference(self.model.norm, hidden_states, out_weight)
return BaseModelOutputWithPast(
last_hidden_state = hidden_states,
@@ -502,34 +443,34 @@ def CohereModel_fast_forward_inference(
hidden_states = [],
attentions = [],
)
+pass
class FastCohereModel(FastLlamaModel):
+
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
- model_name = "cohere",
- rope_module = LlamaRotaryEmbedding,
+ model_name = "cohere",
+ rope_module = LlamaRotaryEmbedding,
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
- attention_module = CohereAttention,
+ attention_module = CohereAttention,
)
if init_name is not None:
exec(function, globals())
- CohereAttention.__init__ = eval(init_name)
- CohereAttention.forward = CohereAttention_fast_forward
- CohereSdpaAttention.forward = CohereAttention_fast_forward
+ CohereAttention.__init__ = eval(init_name)
+ pass
+ CohereAttention .forward = CohereAttention_fast_forward
+ CohereSdpaAttention .forward = CohereAttention_fast_forward
CohereFlashAttention2.forward = CohereAttention_fast_forward
- CohereDecoderLayer.forward = CohereDecoderLayer_fast_forward
- CohereModel.forward = LlamaModel_fast_forward
- CohereForCausalLM.forward = CausalLM_fast_forward(
- CohereModel_fast_forward_inference
- )
- PeftModelForCausalLM.forward = PeftModel_fast_forward
+ CohereDecoderLayer .forward = CohereDecoderLayer_fast_forward
+ CohereModel .forward = LlamaModel_fast_forward
+ CohereForCausalLM .forward = CausalLM_fast_forward(CohereModel_fast_forward_inference)
+ PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(CohereForCausalLM)
-
+
import transformers.models.cohere.modeling_cohere
-
- transformers.models.cohere.modeling_cohere.CohereRotaryEmbedding = (
- LlamaRotaryEmbedding
- )
+ transformers.models.cohere.modeling_cohere.CohereRotaryEmbedding = LlamaRotaryEmbedding
return
+ pass
+pass
diff --git a/unsloth/models/dpo.py b/unsloth/models/dpo.py
index def047e394..7fc02fcf7f 100644
--- a/unsloth/models/dpo.py
+++ b/unsloth/models/dpo.py
@@ -17,10 +17,6 @@
"PatchKTOTrainer",
]
+def PatchDPOTrainer(): return
-def PatchDPOTrainer():
- return
-
-
-def PatchKTOTrainer():
- return
+def PatchKTOTrainer(): return
diff --git a/unsloth/models/falcon_h1.py b/unsloth/models/falcon_h1.py
deleted file mode 100644
index 6e3b16b21b..0000000000
--- a/unsloth/models/falcon_h1.py
+++ /dev/null
@@ -1,770 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .llama import *
-import os
-from ._utils import __version__
-from unsloth_zoo.utils import Version, _get_dtype
-from unsloth_zoo.hf_utils import dtype_from_config
-from ..utils.packing import get_packed_info_from_kwargs
-from ..utils.attention_dispatch import (
- AttentionConfig,
- AttentionContext,
- run_attention,
- select_attention_backend,
- SDPA,
-)
-from .llama import (
- LlamaRotaryEmbedding,
- LlamaLinearScalingRotaryEmbedding,
- _LlamaModel_fast_forward_inference,
-)
-
-try:
- from transformers.models.falcon_h1.modeling_falcon_h1 import (
- FalconH1Attention,
- FalconH1DecoderLayer,
- FalconH1Model,
- FalconH1ForCausalLM,
- FalconHybridMambaAttentionDynamicCache,
- )
-except:
- from transformers import __version__ as transformers_version
-
- transformers_version = Version(transformers_version)
- if not transformers_version >= Version(
- "4.53.0"
- ): # TODO: Update when transformers is updated
- raise ImportError(
- f"Unsloth: Your transformers version of {transformers_version} does not support FalconH1.\n"
- f"The minimum required version is 4.53.0.\n"
- f'Try `pip install --upgrade "transformers>=4.53.0"`\n'
- f"to obtain the latest transformers build, then restart this session."
- )
-from transformers.modeling_attn_mask_utils import (
- _prepare_4d_causal_attention_mask_for_sdpa,
-)
-from transformers.utils import (
- is_torchdynamo_compiling,
-)
-
-# For Pytorch 2.1.1
-try:
- from transformers.models.falcon_h1.modeling_falcon_h1 import (
- FalconH1Attention,
- )
-except ModuleNotFoundError:
- # if we are on an old version of transformers technically it should fail in the try except above
- # but if somehow we make it here, we need to raise an error since FalconH1Attention is not available
- # or renamed
- raise ImportError(
- "Unsloth: Could not import FalconH1Attention from transformers.models.falcon_h1.modeling_falcon_h1."
- )
-
-
-def FalconH1Attention_fast_forward(
- self,
- hidden_states: torch.Tensor,
- causal_mask: Optional[BlockDiagonalCausalMask] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- padding_mask: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- *args,
- **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- # Clear inference
- if hasattr(self, "paged_attention"):
- del self.paged_attention_K
- del self.paged_attention_V
- del self.paged_attention
- del self.temp_QA
- del self.temp_KV
- del self.RH_Q
- del self.attention
-
- bsz, q_len, _ = hidden_states.size()
-
- n_heads = self.config.num_attention_heads
- n_groups = self.num_key_value_groups
- n_kv_heads = self.config.num_key_value_heads
- head_dim = self.head_dim
- assert n_kv_heads * n_groups == n_heads
-
- Q, K, V = self.apply_qkv(self, hidden_states)
- Q = Q.view(bsz, q_len, n_heads, head_dim)
- K = K.view(bsz, q_len, n_kv_heads, head_dim)
- V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
- seq_info = get_packed_info_from_kwargs(kwargs, hidden_states.device)
-
- # Falcon H1 multiplies key states by a multiplier
- K = K * self.config.key_multiplier
-
- Q = Q.transpose(1, 2)
- K = K.transpose(1, 2)
-
- kv_seq_len = K.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
-
- # Extend RoPE dynamically to fit in VRAM
- if position_embeddings and kv_seq_len <= position_embeddings[0].shape[0]:
- cos, sin = position_embeddings
- else:
- rotary_emb = self.rotary_emb
- rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
- cos, sin = rotary_emb.get_cached(kv_seq_len, Q.device.index)
-
- rope_position_ids = (
- position_ids if position_ids is not None else kwargs.get("position_ids")
- )
- # Useful for LongRoPE
- Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids)
-
- if past_key_value is not None:
- K = torch.cat([past_key_value[0], K], dim = 2)
- V = torch.cat([past_key_value[1], V], dim = 2)
- past_key_value = (K, V) if use_cache else None
-
- # Attention module
- window = (-1, -1)
- use_varlen = (
- attention_mask is None
- and seq_info is not None
- and past_key_value is None
- and window == (-1, -1)
- )
-
- backend = (
- SDPA if attention_mask is not None else select_attention_backend(use_varlen)
- )
- attention_config = AttentionConfig(
- backend = backend,
- n_kv_heads = n_kv_heads,
- n_groups = n_groups,
- flash_dense_kwargs = {
- "causal": True,
- "window_size": (kv_seq_len, kv_seq_len),
- },
- flash_varlen_kwargs = {
- "dropout_p": 0.0,
- "softmax_scale": None,
- "causal": True,
- },
- sdpa_kwargs = {} if attention_mask is None else {"attn_mask": attention_mask},
- )
- context = AttentionContext(
- bsz = bsz,
- q_len = q_len,
- kv_seq_len = kv_seq_len,
- n_heads = n_heads,
- head_dim = head_dim,
- requires_grad = hidden_states.requires_grad,
- seq_info = seq_info,
- attention_mask = attention_mask,
- causal_mask = causal_mask,
- )
-
- A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V)
-
- attn_output = A.reshape(bsz, q_len, n_heads * head_dim)
- attn_output = self.apply_o(self, attn_output)
- attn_weights = None
- return attn_output, attn_weights, past_key_value
-
-
-torch_matmul = torch.matmul
-
-
-def FalconH1Attention_fast_forward_inference(
- self,
- hidden_states: torch.Tensor,
- past_key_value: Optional[Tuple[torch.Tensor]],
- position_ids,
- do_prefill = False,
- attention_mask = None,
- **kwargs,
-):
- """
- https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
- Fast inference using KV cache.
- QK^T can be computed in 4 chunks
-
- [Q, q] @ [K, k].T where q, k are the new tokens.
- [QK^T, Qk^T]
- [qK^T, qk^T]
-
- Since the attention mask wipes Qk^T, we just get
- [QK^T, 0]
- [qK^T, qk^T]
-
- Since softmax is row-wise, we get
- softmax([QK^T, 0])
- softmax([qK^T, qk^T])
-
- We then multiply by [V]
- [v]
- softmax([QK^T, 0]) [softmax(QK^T)V] *
- softmax([qK^T, qk^T]) [softmax([qK^T, qk^T]) @ [V, v]]
-
- But notice * [softmax(QK^T)V] is just the last attention.
- We just need to compute the last final row.
-
- This means we can pass in a row of Q, but we need to
- remember K and V, which are called the KV cache.
- """
- Xn = hidden_states
- bsz, _, hd = hidden_states.size()
- K1, V1 = past_key_value
- dtype = Xn.dtype
-
- n_heads = self.config.num_attention_heads
- n_groups = self.num_key_value_groups
- n_kv_heads = self.config.num_key_value_heads
- head_dim = self.head_dim
- # assert(n_kv_heads * n_groups == n_heads)
-
- hidden_size = self.config.hidden_size
- attention_size = n_heads * head_dim
- seq_len = K1.shape[-2]
- kv_seq_len = seq_len + 1
-
- # Prefill phase
- # if not hasattr(self, "paged_attention"):
- device = hidden_states.device
- if do_prefill:
- self.paged_attention = torch.empty(
- (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
- dtype = dtype,
- device = device,
- )
- self.paged_attention_K = self.paged_attention[:, 0]
- self.paged_attention_V = self.paged_attention[:, 1]
- self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
- self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
- self.temp_QA = torch.empty(
- (2, bsz, 1, attention_size), dtype = dtype, device = device
- )
- self.temp_KV = torch.empty(
- (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device
- )
- self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
-
- # Mistral Nemo 12b has weird dimensions
- if attention_size != hidden_size:
- self.temp_O = torch.empty((bsz, 1, hidden_size), dtype = dtype, device = device)
- else:
- self.temp_O = self.temp_QA[1][:, :, :hidden_size]
-
- self.attention = torch.empty(
- (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device
- )
- self.scalar = 1.0 / math_sqrt(self.head_dim)
- self.half_head_dim = head_dim // 2
- elif kv_seq_len >= self.paged_attention.shape[0]:
- self.paged_attention.resize_(
- (
- self.paged_attention.shape[0] + KV_CACHE_INCREMENT,
- 2,
- bsz,
- n_kv_heads,
- head_dim,
- )
- )
- self.paged_attention_K = self.paged_attention[:, 0]
- self.paged_attention_V = self.paged_attention[:, 1]
- self.attention.resize_(
- (bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
- )
-
- Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
- Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
- Kn.mul_(self.config.key_multiplier)
- Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
- Qn = Qn.view(
- bsz, 1, n_heads, head_dim
- ) # .transpose(1, 2) # we will transpose after normalisation
- Kn = Kn.view(
- bsz, 1, n_kv_heads, head_dim
- ) # .transpose(1, 2) # we will transpose after normalisation
- Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
-
- Qn = Qn.transpose(1, 2)
- Kn = Kn.transpose(1, 2)
-
- # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
- # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
-
- # Need to do it prior 2 steps before hitting full on short KV cache
- # or else error
- self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2)
- cos, sin = self.rotary_emb.get_cached(kv_seq_len, Qn.device.index)
- cos = cos[position_ids].unsqueeze(1)
- sin = sin[position_ids].unsqueeze(1)
- h = self.half_head_dim
-
- RH_Q = self.RH_Q
- RH_Q[:, :, :, :h] = Qn[:, :, :, h:]
- RH_Q[:, :, :, h:] = Qn[:, :, :, :h]
- RH_Q[:, :, :, :h].neg_() # torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
- Qn *= cos
- Qn.addcmul_(RH_Q, sin)
-
- RH_K = RH_Q[
- :, :n_kv_heads, :, :
- ] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
- RH_K[:, :, :, :h] = Kn[:, :, :, h:]
- RH_K[:, :, :, h:] = Kn[:, :, :, :h]
- RH_K[:, :, :, :h].neg_() # torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
- Kn *= cos
- Kn.addcmul_(RH_K, sin)
-
- # New KV cache
- # Kn = torch.cat([K1, Kn], dim = 2)
- # Vn = torch.cat([V1, Vn], dim = 2)
- self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)
- self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)
- Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)
- Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)
-
- # Handle sliding windows
- sliding_window = getattr(self.config, "sliding_window", None)
- if sliding_window is not None and kv_seq_len > sliding_window:
- start = kv_seq_len - sliding_window
- Knn = Kn[:, :, start:, :] # .contiguous()
- Vnn = Vn[:, :, start:, :] # .contiguous()
- if attention_mask is not None:
- attention_mask = attention_mask[..., start:]
- else:
- Knn, Vnn = Kn, Vn
-
- # Grouped query attention
- _, _, cached_len, _ = Knn.shape
- if bsz == 1 or not SDPA_HAS_GQA and n_groups != 1:
- Knn = Knn[:, :, None, :, :].expand(
- bsz, n_kv_heads, n_groups, cached_len, head_dim
- )
- Vnn = Vnn[:, :, None, :, :].expand(
- bsz, n_kv_heads, n_groups, cached_len, head_dim
- )
- Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
- Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
-
- # Attention
- if bsz == 1:
- Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
- # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
- A = torch_matmul(
- Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len]
- )
- A[:] = torch_nn_functional_softmax(
- A, dim = -1, dtype = torch.float32
- ) # .to(A.dtype)
- A = torch_matmul(A, Vnn, out = Qn)
- else:
- if SDPA_HAS_GQA:
- A = scaled_dot_product_attention(
- Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False, enable_gqa = True
- )
- else:
- A = scaled_dot_product_attention(
- Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False
- )
- A = A.transpose(1, 2)
- A = A.reshape(bsz, 1, attention_size)
- A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
- return A, (Kn, Vn)
-
-
-# https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon_h1/modeling_falcon_h1.py
-def FalconH1DecoderLayer_fast_forward(
- self,
- hidden_states: torch.Tensor,
- causal_mask = None,
- attention_mask: Optional[torch.Tensor] = None,
- mamba_attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- cache_position: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- padding_mask: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- *args,
- **kwargs,
-) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
- """
- Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
- (see `past_key_values`).
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
- """
- if use_cache and hasattr(self, "_flag_for_generation"):
- residual = hidden_states
- hidden_states = fast_rms_layernorm_inference(
- self.input_layernorm, hidden_states
- )
- attention_hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states = hidden_states,
- causal_mask = causal_mask,
- attention_mask = attention_mask,
- position_ids = position_ids,
- past_key_value = past_key_value,
- output_attentions = output_attentions,
- use_cache = use_cache,
- padding_mask = padding_mask,
- position_embeddings = position_embeddings,
- **kwargs,
- )
- attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
-
- mamba_hidden_states = self.mamba(
- hidden_states = hidden_states,
- cache_params = past_key_value,
- cache_position = cache_position,
- attention_mask = mamba_attention_mask,
- )
- mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier
-
- hidden_states = mamba_hidden_states + attention_hidden_states
-
- hidden_states += residual
-
- # Fully Connected
- residual = hidden_states
- hidden_states = fast_rms_layernorm_inference(
- self.pre_ff_layernorm, hidden_states
- )
- hidden_states = fast_swiglu_inference(self.feed_forward, hidden_states)
- hidden_states += residual
- else:
- residual = hidden_states
- hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
-
- mamba_hidden_states = self.mamba(
- hidden_states = hidden_states,
- cache_params = past_key_value,
- cache_position = cache_position,
- attention_mask = mamba_attention_mask,
- )
- mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier
-
- attention_hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states = hidden_states,
- causal_mask = causal_mask,
- attention_mask = attention_mask,
- position_ids = position_ids,
- past_key_value = past_key_value,
- output_attentions = output_attentions,
- use_cache = use_cache,
- padding_mask = padding_mask,
- position_embeddings = position_embeddings,
- **kwargs,
- )
- attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
-
- hidden_states = mamba_hidden_states + attention_hidden_states
-
- # residual connection after attention + Mamba
- hidden_states = residual + hidden_states
-
- # Fully Connected
- residual = hidden_states
- hidden_states = fast_rms_layernorm(self.pre_ff_layernorm, hidden_states)
- hidden_states = self.feed_forward(hidden_states)
- hidden_states = residual + hidden_states
-
- outputs = (hidden_states,)
- if output_attentions:
- outputs += (self_attn_weights,)
- if use_cache:
- outputs += (present_key_value,)
- return outputs
-
-
-def _FalconH1_fast_forward_inference(
- attention_fast_forward_inference = FalconH1Attention_fast_forward_inference,
- mlp_fast_forward_inference = fast_swiglu_inference,
-):
- # This makes the attention and MLP customisable.
- # Now for models like qwen3 or cohere which use custom attention operations, we can use this function
- def FalconH1Model_fast_forward_inference_custom(
- self,
- input_ids,
- past_key_values,
- position_ids,
- cache_position = None,
- attention_mask = None,
- mamba_attention_mask = None,
- ):
- input_ids = input_ids[:, : self.max_seq_length]
- bsz, q_len = input_ids.shape
- hd = self.config.hidden_size
- mlp_size = self.config.intermediate_size
- gate_multiplier, down_multiplier = self.config.mlp_multipliers
-
- X = self.model.embed_tokens(input_ids)
- X = X * self.config.embedding_multiplier
-
- X = X.to(_get_dtype(dtype_from_config(self.config)))
- bsz, q_len, hd = X.shape
- assert q_len == 1
- # Get saved buffers to reduce memory movement
- residual = torch.empty(
- (bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0"
- )
- _XX = torch.empty(
- (2, bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0"
- )
- XX, XX2 = _XX[0], _XX[1]
- variance = torch.empty(
- (bsz, q_len, 1), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0"
- )
- temp_mlp = torch.empty(
- (2, bsz, 1, mlp_size), dtype = X.dtype, device = f"{DEVICE_TYPE_TORCH}:0"
- )
- temp_gate, temp_up = temp_mlp[0], temp_mlp[1]
- seq_len = past_key_values[0][0].shape[-2]
- if bsz != 1:
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
- attention_mask,
- (bsz, q_len),
- X,
- seq_len,
- sliding_window = getattr(self.config, "sliding_window", None),
- )
- else:
- attention_mask = None
-
- next_decoder_cache = []
-
- for idx, decoder_layer in enumerate(self.model.layers):
- residual.copy_(X) # residual = X
- X = fast_rms_layernorm_inference(
- decoder_layer.input_layernorm,
- X,
- XX = XX,
- XX2 = XX2,
- variance = variance,
- )
- attention_hidden_states, present_key_value = (
- attention_fast_forward_inference(
- decoder_layer.self_attn,
- hidden_states = X * decoder_layer.attention_in_multiplier,
- past_key_value = past_key_values[idx],
- position_ids = position_ids,
- attention_mask = attention_mask,
- do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
- )
- )
- attention_hidden_states = (
- attention_hidden_states * decoder_layer.attn_out_multiplier
- )
- mamba_hidden_states = decoder_layer.mamba(
- hidden_states = X,
- cache_params = present_key_value,
- cache_position = cache_position,
- attention_mask = mamba_attention_mask,
- )
- mamba_hidden_states = mamba_hidden_states * decoder_layer.ssm_out_multiplier
- X = mamba_hidden_states + attention_hidden_states
-
- X += residual
-
- residual.copy_(X) # residual = X
- X = fast_rms_layernorm_inference(
- decoder_layer.pre_ff_layernorm,
- X,
- XX = XX,
- XX2 = XX2,
- variance = variance,
- )
- X = mlp_fast_forward_inference(
- decoder_layer.feed_forward,
- X,
- temp_gate = temp_gate,
- temp_up = temp_up,
- gate_multiplier = gate_multiplier,
- down_multiplier = down_multiplier,
- )
- X += residual
-
- next_decoder_cache.append(present_key_value)
- X = fast_rms_layernorm_inference(
- self.model.final_layernorm,
- X,
- XX = XX,
- XX2 = XX2,
- variance = variance,
- )
-
- return BaseModelOutputWithPast(
- last_hidden_state = X,
- past_key_values = next_decoder_cache,
- hidden_states = [],
- attentions = [],
- )
-
- return FalconH1Model_fast_forward_inference_custom
-
-
-# Separate prepare_inputs_for_generation for Hybrid FalconH1
-def _fast_prepare_inputs_for_generation(
- self,
- input_ids,
- past_key_values = None,
- attention_mask = None,
- inputs_embeds = None,
- cache_position = None,
- position_ids = None,
- use_cache = True,
- **kwargs,
-):
- # Overwritten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache`
- empty_past_kv = past_key_values is None
-
- # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
- # Exception 1: when passing input_embeds, input_ids may be missing entries
- # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
- # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
- # (we can't check exception 3 while compiling)
- if not empty_past_kv:
- if (
- inputs_embeds is not None # Exception 1
- or (
- is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]
- ) # Exception 3
- ):
- input_ids = input_ids[:, -cache_position.shape[0] :]
- elif (
- input_ids.shape[1] != cache_position.shape[0]
- ): # Default case (the "else", a no op, is Exception 2)
- input_ids = input_ids[:, cache_position]
- # TODO: Wire up Cache to work for inference.
- # else:
- # past_key_values = FalconHybridMambaAttentionDynamicCache(
- # self.config,
- # input_ids.shape[0],
- # self.dtype,
- # devices=[
- # self.model.layers[i].mamba.conv1d.weight.device for i in range(self.config.num_hidden_layers)
- # ],
- # )
-
- if attention_mask is not None and position_ids is None:
- # create position_ids on the fly for batch generation
- position_ids = attention_mask.long().cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- if not empty_past_kv:
- position_ids = position_ids[:, -input_ids.shape[1] :]
-
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
- if inputs_embeds is not None and empty_past_kv:
- model_inputs = {"inputs_embeds": inputs_embeds}
- else:
- model_inputs = {
- "input_ids": input_ids.contiguous()
- } # `contiguous()` needed for compilation use cases
-
- model_inputs.update(
- {
- "position_ids": position_ids,
- "past_key_values": past_key_values,
- "use_cache": use_cache,
- "attention_mask": attention_mask,
- "logits_to_keep": self.config.num_logits_to_keep,
- "cache_position": cache_position,
- }
- )
- return model_inputs
-
-
-def fix_prepare_inputs_for_generation(module):
- # Fix prepare_inputs_for_generation
- if hasattr(module, "prepare_inputs_for_generation"):
- module.prepare_inputs_for_generation = _fast_prepare_inputs_for_generation
-
-
-class FastFalconH1Model(FastLlamaModel):
- @staticmethod
- def pre_patch():
- init_name, function = patch_linear_scaling(
- model_name = "FalconH1",
- rope_module = LlamaRotaryEmbedding,
- scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
- attention_module = FalconH1Attention,
- )
- if init_name is not None:
- exec(function, globals())
- FalconH1Attention.__init__ = eval(init_name)
- FalconH1Attention.forward = FalconH1Attention_fast_forward
- FalconH1DecoderLayer.forward = FalconH1DecoderLayer_fast_forward
- FalconH1Model.forward = LlamaModel_fast_forward
- FalconH1ForCausalLM.forward = CausalLM_fast_forward(
- _FalconH1_fast_forward_inference(FalconH1Attention_fast_forward_inference)
- )
- PeftModelForCausalLM.forward = PeftModel_fast_forward
- fix_prepare_inputs_for_generation(FalconH1ForCausalLM)
-
- # Solves https://github.com/unslothai/unsloth/issues/168
- # Static KV Cache was introduced in 4.38.0, causing training to be much slower.
- # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
- # https://github.com/huggingface/transformers/pull/27931
- # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
- import transformers.models.falcon_h1.modeling_falcon_h1
-
- transformers.models.falcon_h1.modeling_falcon_h1.FalconH1RotaryEmbedding = (
- LlamaRotaryEmbedding
- )
- return
-
- @staticmethod
- def from_pretrained( # TODO: Change after release
- model_name = "Qwen/FalconH1-7B",
- max_seq_length = 4096,
- dtype = None,
- load_in_4bit = True,
- token = None,
- device_map = "sequential",
- rope_scaling = None,
- fix_tokenizer = True,
- model_patcher = None,
- tokenizer_name = None,
- trust_remote_code = False,
- **kwargs,
- ):
- return FastLlamaModel.from_pretrained(
- model_name = model_name,
- max_seq_length = max_seq_length,
- dtype = dtype,
- load_in_4bit = load_in_4bit,
- token = token,
- device_map = device_map,
- rope_scaling = rope_scaling,
- fix_tokenizer = fix_tokenizer,
- model_patcher = FastFalconH1Model,
- tokenizer_name = tokenizer_name,
- trust_remote_code = trust_remote_code,
- **kwargs,
- )
diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py
index cf543ae094..873bdcf2eb 100644
--- a/unsloth/models/gemma.py
+++ b/unsloth/models/gemma.py
@@ -13,15 +13,7 @@
# limitations under the License.
from .llama import *
-from .llama import _get_rope_theta
from ._utils import __version__
-from unsloth_zoo.utils import _get_dtype, Version
-from unsloth_zoo.hf_utils import dtype_from_config
-from ..utils.packing import (
- build_sdpa_packed_attention_mask,
- build_xformers_block_causal_mask,
- get_packed_info_from_kwargs,
-)
import math
try:
@@ -35,19 +27,21 @@
repeat_kv,
)
except:
+ from packaging.version import Version
transformers_version = Version(transformers_version)
if not transformers_version >= Version("4.38"):
raise ImportError(
- f"Unsloth: Your transformers version of {transformers_version} does not support Gemma.\n"
- f"The minimum required version is 4.38.\n"
- f'Try `pip install --upgrade "transformers>=4.38"`\n'
- f"to obtain the latest transformers build, then restart this session."
+ f"Unsloth: Your transformers version of {transformers_version} does not support Gemma.\n"\
+ f"The minimum required version is 4.38.\n"\
+ f'Try `pip install --upgrade "transformers>=4.38"`\n'\
+ f"to obtain the latest transformers build, then restart this session."\
)
+ pass
+pass
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask_for_sdpa,
)
-
# For Pytorch 2.1.1
try:
from transformers.models.gemma.modeling_gemma import (
@@ -55,13 +49,12 @@
GemmaFlashAttention2,
)
except:
- GemmaSdpaAttention = GemmaAttention
+ GemmaSdpaAttention = GemmaAttention
GemmaFlashAttention2 = GemmaAttention
+pass
torch_nn_functional_gelu = torch.nn.functional.gelu
-
-
def fast_geglu_inference(self, X):
# gate = self.gate_proj(X)
# up = self.up_proj(X)
@@ -69,101 +62,84 @@ def fast_geglu_inference(self, X):
# mlp_size = self.config.intermediate_size
# temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0")
- gate = fast_linear_forward(self.gate_proj, X) # , out = temp[0])
- up = fast_linear_forward(self.up_proj, X) # , out = temp[1])
+ gate = fast_linear_forward(self.gate_proj, X)#, out = temp[0])
+ up = fast_linear_forward(self. up_proj, X)#, out = temp[1])
gate = torch_nn_functional_gelu(gate, approximate = "tanh")
gate *= up
# X = self.down_proj(gate)
- down = fast_linear_forward(self.down_proj, gate, out = up[:, :, :hd])
+ down = fast_linear_forward(self.down_proj, gate, out = up[:,:,:hd])
return down
+pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
def GemmaDecoderLayer_fast_forward(
self,
- hidden_states: torch.Tensor,
- causal_mask: Optional[BlockDiagonalCausalMask] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- padding_mask: Optional[torch.LongTensor] = None,
- *args,
- **kwargs,
+ hidden_states: torch.Tensor,
+ causal_mask: Optional[BlockDiagonalCausalMask] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ padding_mask: Optional[torch.LongTensor] = None,
+ *args, **kwargs,
):
- if use_cache and hasattr(
- self, "_flag_for_generation"
- ): # past_key_value is not None:
- out_weight = torch.empty(
- self.input_layernorm.weight.shape,
- dtype = torch.float32,
- device = f"{DEVICE_TYPE_TORCH}:0",
- )
+ if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
+ out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0")
# Self Attention
residual = hidden_states
- hidden_states = fast_rms_layernorm_inference_gemma(
- self.input_layernorm, hidden_states, out_weight
- )
+ hidden_states = fast_rms_layernorm_inference_gemma(self.input_layernorm, hidden_states, out_weight)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states = hidden_states,
- causal_mask = causal_mask,
- attention_mask = attention_mask,
- position_ids = position_ids,
- past_key_value = past_key_value,
- output_attentions = output_attentions,
- use_cache = use_cache,
- padding_mask = padding_mask,
- **kwargs,
+ hidden_states=hidden_states,
+ causal_mask=causal_mask,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ padding_mask=padding_mask,
)
hidden_states += residual
# Fully Connected
residual = hidden_states
- hidden_states = fast_rms_layernorm_inference_gemma(
- self.post_attention_layernorm, hidden_states, out_weight
- )
+ hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
hidden_states = fast_geglu_inference(self.mlp, hidden_states)
hidden_states += residual
else:
residual = hidden_states
- hidden_states = fast_rms_layernorm(
- self.input_layernorm, hidden_states, gemma = True
- )
+ hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states, gemma = True)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states = hidden_states,
- causal_mask = causal_mask,
- attention_mask = attention_mask,
- position_ids = position_ids,
- past_key_value = past_key_value,
- output_attentions = output_attentions,
- use_cache = use_cache,
- padding_mask = padding_mask,
- **kwargs,
+ hidden_states=hidden_states,
+ causal_mask=causal_mask,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ padding_mask=padding_mask,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
- hidden_states = fast_rms_layernorm(
- self.post_attention_layernorm, hidden_states, gemma = True
- )
+ hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
+ pass
outputs = (hidden_states,)
- if output_attentions:
- outputs += (self_attn_weights,)
- if use_cache:
- outputs += (present_key_value,)
+ if output_attentions: outputs += (self_attn_weights,)
+ if use_cache: outputs += (present_key_value,)
return outputs
+pass
from math import sqrt as math_sqrt
-
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
# @torch.inference_mode
def GemmaModel_fast_forward_inference(
@@ -172,28 +148,17 @@ def GemmaModel_fast_forward_inference(
past_key_values,
position_ids,
attention_mask = None,
- **kwargs,
):
- out_weights = tuple(
- torch.empty_like(
- self.model.layers[0].input_layernorm.weight,
- dtype = torch.float32,
- device = torch.device(x),
- )
- for x in range(DEVICE_COUNT)
- )
- input_ids = input_ids[:, : self.max_seq_length]
+ out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0")
+ input_ids = input_ids[:,:self.max_seq_length]
hidden_states = self.model.embed_tokens(input_ids)
- hidden_states = hidden_states.to(_get_dtype(dtype_from_config(self.config)))
+ hidden_states = hidden_states.to(self.config.torch_dtype)
# 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
# 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
- hidden_states *= torch.tensor(
- math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype
- )
+ hidden_states *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype)
bsz, q_len, hd = hidden_states.shape
seq_len = past_key_values[0][0].shape[-2]
- kv_seq_len = seq_len + 1
if bsz != 1:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
@@ -201,24 +166,12 @@ def GemmaModel_fast_forward_inference(
hidden_states,
seq_len,
)
- # Pre-convert to bool once for all layers (avoids per-layer .eq(0))
- if attention_mask is not None and attention_mask.dtype != torch.bool:
- attention_mask = attention_mask.eq(0)
-
- # Compute rotary_seq_len once to avoid per-layer GPU-CPU sync from .item()
- rotary_seq_len = max(kv_seq_len, int(position_ids.max().item()) + 1)
+ pass
next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):
- device_index = getattr(decoder_layer, "_per_layer_device_index", 0)
- hidden_states, position_ids = move_to_device(
- device_index, hidden_states, position_ids
- )
-
residual = hidden_states
- hidden_states = fast_rms_layernorm_inference_gemma(
- decoder_layer.input_layernorm, hidden_states, out_weights[device_index]
- )
+ hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight)
hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
decoder_layer.self_attn,
hidden_states = hidden_states,
@@ -226,23 +179,17 @@ def GemmaModel_fast_forward_inference(
position_ids = position_ids,
attention_mask = attention_mask,
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
- rotary_seq_len = rotary_seq_len,
)
hidden_states += residual
residual = hidden_states
- hidden_states = fast_rms_layernorm_inference_gemma(
- decoder_layer.post_attention_layernorm,
- hidden_states,
- out_weights[device_index],
- )
+ hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight)
hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)
hidden_states += residual
next_decoder_cache.append(present_key_value)
- hidden_states = fast_rms_layernorm_inference_gemma(
- self.model.norm, hidden_states, out_weights[device_index]
- )
+ pass
+ hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weight)
return BaseModelOutputWithPast(
last_hidden_state = hidden_states,
@@ -250,6 +197,7 @@ def GemmaModel_fast_forward_inference(
hidden_states = [],
attentions = [],
)
+pass
# Follows line by line https://github.com/google-deepmind/gemma/blob/main/gemma/positional_embeddings.py#L45
@@ -258,205 +206,150 @@ class GemmaFixedRotaryEmbedding(torch.nn.Module):
# Fixes https://github.com/huggingface/transformers/pull/28837
# https://github.com/microsoft/DeepSpeed/issues/4932
# The precision of RoPE buffers is not correct, so we cast to int64.
- def __init__(
- self,
- dim = None,
- max_position_embeddings = 2048,
- base = 10000,
- device = None,
- config = None, # [TODO] Hack to pass in config - need to remove later
+ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None,
+ config = None, # [TODO] Hack to pass in config - need to remove later
):
super().__init__()
- # In transformers 5.0+, RotaryEmbedding(config) passes config as first positional arg (dim)
- if (
- config is None
- and dim is not None
- and hasattr(dim, "max_position_embeddings")
- ):
- config = dim
- dim = None
if config is not None:
# [TODO] Hack to pass in config - need to remove later
- base = _get_rope_theta(config, default = base)
- partial_rotary_factor = (
- config.partial_rotary_factor
- if hasattr(config, "partial_rotary_factor")
- else 1.0
- )
+ base = config.rope_theta
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = getattr(config, "head_dim", None)
- if dim is None:
- dim = int((config.hidden_size // config.num_attention_heads))
+ if dim is None: dim = int((config.hidden_size // config.num_attention_heads))
device = "cuda"
max_position_embeddings = config.max_position_embeddings
+ pass
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
- self.multi_gpu_cos_cached = [None] * DEVICE_COUNT
- self.multi_gpu_sin_cached = [None] * DEVICE_COUNT
# Build here to make `torch.jit.trace` work.
- for device in range(DEVICE_COUNT):
- self._set_cos_sin_cache(
- seq_len = self.current_rope_size,
- device = torch.device(device),
- dtype = torch.get_default_dtype(),
- )
-
- # dummy so that patch_utils doesn't fail for now
- self.cos_cached = torch.empty(
- 1, device = torch.cuda.current_device(), dtype = torch.get_default_dtype()
- )
- self.sin_cached = torch.empty(
- 1, device = torch.cuda.current_device(), dtype = torch.get_default_dtype()
- )
+ self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
+ pass
def _set_cos_sin_cache(self, seq_len, device, dtype):
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# in FP32. They are applied (multiplied) in FP32 as well.
self.current_rope_size = seq_len
- # The difference is we do division explicitly instead of t * (1/x) ie we do t/x.
+ # The difference is we do division explicity instead of t * (1/x) ie we do t/x.
freq_exponents = (2.0 / self.dim) * (
torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
)
timescale = self.base**freq_exponents
- positions = torch.arange(
- self.current_rope_size, device = "cpu", dtype = torch.int64
- ).float()
+ positions = torch.arange(self.current_rope_size, device = "cpu", dtype = torch.int64).float()
radians_new = positions[..., None] / timescale[None, None, :]
radians_new = radians_new.squeeze(0)
emb = torch.cat((radians_new, radians_new), dim = -1)
# We must do RoPE in float32!
- cos = emb.cos().to(device = device, non_blocking = True) # , dtype = dtype)
- sin = emb.sin().to(device = device, non_blocking = True) # , dtype = dtype)
- self.multi_gpu_cos_cached[device.index] = cos
- self.multi_gpu_sin_cached[device.index] = sin
- return cos, sin
+ cos = emb.cos().to(device = "cuda", non_blocking = True)#, dtype = dtype)
+ sin = emb.sin().to(device = "cuda", non_blocking = True)#, dtype = dtype)
+ self.register_buffer("cos_cached", cos, persistent = False)
+ self.register_buffer("sin_cached", sin, persistent = False)
+ pass
- def forward(self, x, position_ids = None, seq_len = None):
+ def forward(self, x, position_ids=None, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
- if seq_len is not None and seq_len > self.current_rope_size:
- self._set_cos_sin_cache(seq_len = seq_len, device = x.device, dtype = x.dtype)
-
- device_index = x.device.index
+ if seq_len > self.current_rope_size:
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
- self.multi_gpu_cos_cached[device_index][:seq_len],
- self.multi_gpu_sin_cached[device_index][:seq_len],
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
)
+ pass
- def get_cached(self, seq_len = None, device_index = None):
- if device_index is None:
- device_index = torch.cuda.current_device()
- return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[
- device_index
- ]
+ def get_cached(self, seq_len = None):
+ return self.cos_cached, self.sin_cached
+ pass
def extend_rope_embedding(self, x, seq_len):
- if seq_len <= self.current_rope_size:
- return
+ if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
self.current_rope_size = math.ceil(seq_len / 8192) * 8192
- for device in range(DEVICE_COUNT):
- self._set_cos_sin_cache(
- self.current_rope_size, device = torch.device(device), dtype = x.dtype
- )
+ self._set_cos_sin_cache(self.current_rope_size, device = "cuda", dtype = x.dtype)
+ pass
+pass
class GemmaFixedLinearScalingRotaryEmbedding(GemmaFixedRotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
-
# Fixes https://github.com/huggingface/transformers/pull/28837
# https://github.com/microsoft/DeepSpeed/issues/4932
# The precision of RoPE buffers is not correct, so we cast to int64.
- def __init__(
- self,
- dim = None,
- max_position_embeddings = 2048,
- base = 10000,
- device = None,
- scaling_factor = 1.0,
- config = None, # [TODO] Hack to pass in config - need to remove later
+ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0,
+ config = None, # [TODO] Hack to pass in config - need to remove later
):
self.scaling_factor = scaling_factor
- super().__init__(
- dim = dim,
- max_position_embeddings = max_position_embeddings,
- base = base,
- device = device,
- config = config,
- )
+ super().__init__(dim = dim, max_position_embeddings = max_position_embeddings, base = base, device = device, config = config)
+ pass
def _set_cos_sin_cache(self, seq_len, device, dtype):
- # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
+# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# in FP32. They are applied (multiplied) in FP32 as well.
self.current_rope_size = seq_len
- # The difference is we do division explicitly instead of t * (1/x) ie we do t/x.
+ # The difference is we do division explicity instead of t * (1/x) ie we do t/x.
freq_exponents = (2.0 / self.dim) * (
torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
)
timescale = self.base**freq_exponents
- positions = torch.arange(
- self.current_rope_size, device = "cpu", dtype = torch.int64
- ).float()
- positions = positions / self.scaling_factor
+ positions = torch.arange(self.current_rope_size, device = "cpu", dtype = torch.int64).float()
+ positions = positions / self.scaling_factor
radians_new = positions[..., None] / timescale[None, None, :]
radians_new = radians_new.squeeze(0)
emb = torch.cat((radians_new, radians_new), dim = -1)
# We must do RoPE in float32!
- cos = emb.cos().to(device = device, non_blocking = True) # , dtype = dtype)
- sin = emb.sin().to(device = device, non_blocking = True) # , dtype = dtype)
- self.multi_gpu_cos_cached[device.index] = cos
- self.multi_gpu_sin_cached[device.index] = sin
- return cos, sin
+ cos = emb.cos().to(device = "cuda", non_blocking = True)#, dtype = dtype)
+ sin = emb.sin().to(device = "cuda", non_blocking = True)#, dtype = dtype)
+ self.register_buffer("cos_cached", cos, persistent = False)
+ self.register_buffer("sin_cached", sin, persistent = False)
+ pass
+pass
class FastGemmaModel(FastLlamaModel):
+
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
- model_name = "gemma",
- rope_module = GemmaFixedRotaryEmbedding,
+ model_name = "gemma",
+ rope_module = GemmaFixedRotaryEmbedding,
scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,
- attention_module = GemmaAttention,
+ attention_module = GemmaAttention,
)
if init_name is not None:
exec(function, globals())
- GemmaAttention.__init__ = eval(init_name)
- GemmaAttention.forward = LlamaAttention_fast_forward
- GemmaSdpaAttention.forward = LlamaAttention_fast_forward
+ GemmaAttention.__init__ = eval(init_name)
+ pass
+ GemmaAttention .forward = LlamaAttention_fast_forward
+ GemmaSdpaAttention .forward = LlamaAttention_fast_forward
GemmaFlashAttention2.forward = LlamaAttention_fast_forward
- GemmaDecoderLayer.forward = GemmaDecoderLayer_fast_forward
- GemmaModel.forward = LlamaModel_fast_forward
- GemmaForCausalLM.forward = CausalLM_fast_forward(
- GemmaModel_fast_forward_inference
- )
- PeftModelForCausalLM.forward = PeftModel_fast_forward
+ GemmaDecoderLayer .forward = GemmaDecoderLayer_fast_forward
+ GemmaModel .forward = LlamaModel_fast_forward
+ GemmaForCausalLM .forward = CausalLM_fast_forward(GemmaModel_fast_forward_inference)
+ PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(GemmaForCausalLM)
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
- # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
+ # Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.gemma.modeling_gemma
-
- transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding = (
- GemmaFixedRotaryEmbedding
- )
+ transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding = GemmaFixedRotaryEmbedding
return
+ pass
+
@staticmethod
- def post_patch(model, tokenizer, correct_dtype = None):
+ def post_patch(model, tokenizer):
# Gemma does not downcast RoPE
- model, tokenizer = patch_model_and_tokenizer(
- model, tokenizer, downcast_rope = False, correct_dtype = correct_dtype
- )
+ model, tokenizer = patch_model_and_tokenizer(model, tokenizer, downcast_rope = False)
# Add 1 to weight
# return output * (1 + self.weight)
@@ -470,6 +363,7 @@ def post_patch(model, tokenizer, correct_dtype = None):
param.requires_grad_(True)
else:
param.requires_grad_(False)
+ pass
# Patch RMS Layernorm
for name, module in model.named_modules():
@@ -480,14 +374,14 @@ def post_patch(model, tokenizer, correct_dtype = None):
# Leave + 1 to Triton kernel itself
# module.weight += 1.0 # return output * (1 + self.weight)
if not hasattr(module, "variance_epsilon"):
- module.variance_epsilon = (
- module.eps
- ) # Gemma doesn't use variance_epsilon
+ module.variance_epsilon = module.eps # Gemma doesn't use variance_epsilon
+ pass
# Clear deleted GPU items
import gc
-
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
return model, tokenizer
+ pass
+pass
diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py
index e59b8d5ebd..316b4e8f0e 100644
--- a/unsloth/models/gemma2.py
+++ b/unsloth/models/gemma2.py
@@ -14,22 +14,11 @@
from .llama import *
from ._utils import __version__
-from unsloth_zoo.utils import _get_dtype, Version
-from unsloth_zoo.hf_utils import dtype_from_config
-from ..utils.packing import get_packed_info_from_kwargs
-from ..utils.attention_dispatch import (
- AttentionConfig,
- AttentionContext,
- run_attention,
- select_attention_backend,
- SDPA,
-)
from .gemma import (
GemmaFixedRotaryEmbedding,
GemmaFixedLinearScalingRotaryEmbedding,
fast_geglu_inference,
)
-
try:
from transformers.models.gemma2.modeling_gemma2 import (
Gemma2Attention,
@@ -41,19 +30,21 @@
repeat_kv,
)
except:
+ from packaging.version import Version
transformers_version = Version(transformers_version)
if not transformers_version >= Version("4.42"):
raise ImportError(
- f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"
- f"The minimum required version is 4.42.3.\n"
- f'Try `pip install --upgrade "transformers>=4.42.3"`\n'
- f"to obtain the latest transformers build, then restart this session."
+ f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"\
+ f"The minimum required version is 4.42.3.\n"\
+ f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\
+ f"to obtain the latest transformers build, then restart this session."\
)
+ pass
+pass
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask_for_sdpa,
)
-
# For Pytorch 2.1.1
try:
from transformers.models.gemma2.modeling_gemma2 import (
@@ -61,27 +52,39 @@
Gemma2FlashAttention2,
)
except:
- Gemma2SdpaAttention = Gemma2Attention
+ Gemma2SdpaAttention = Gemma2Attention
Gemma2FlashAttention2 = Gemma2Attention
+pass
if HAS_FLASH_ATTENTION_SOFTCAPPING:
from flash_attn import flash_attn_func
+# [TODO] We must randomnly use torch.compile?
+# Gemma 2 uses double RMS Layernorms, so the backward passes should not overwrite the gradients!
+@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
+def fast_rms_layernorm_gemma2_compiled(layernorm, X, gemma = True):
+ old_dtype = X.dtype
+ X = X.float()
+ X = X * torch.rsqrt(X.square().mean(-1, keepdim = True) + layernorm.eps) * \
+ (1.0 + layernorm.weight.float())
+ return X.to(old_dtype)
+pass
+
# Logit softcapping
def Gemma2Attention_fast_forward(
self,
- hidden_states: torch.Tensor,
- causal_mask: Optional[BlockDiagonalCausalMask] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- padding_mask: Optional[torch.LongTensor] = None,
- *args,
- **kwargs,
+ hidden_states: torch.Tensor,
+ causal_mask: Optional[BlockDiagonalCausalMask] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ padding_mask: Optional[torch.LongTensor] = None,
+ *args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
# Clear inference
if hasattr(self, "paged_attention"):
del self.paged_attention_K
@@ -91,245 +94,175 @@ def Gemma2Attention_fast_forward(
del self.temp_KV
del self.RH_Q
del self.attention
+ pass
bsz, q_len, _ = hidden_states.size()
- n_heads = self.config.num_attention_heads
- n_groups = self.num_key_value_groups
+ n_heads = self.config.num_attention_heads
+ n_groups = self.num_key_value_groups
n_kv_heads = self.config.num_key_value_heads
- head_dim = self.head_dim
- assert n_kv_heads * n_groups == n_heads
+ head_dim = self.head_dim
+ assert(n_kv_heads * n_groups == n_heads)
Q, K, V = self.apply_qkv(self, hidden_states)
- Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
+ Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
- seq_info = get_packed_info_from_kwargs(kwargs, Q.device)
kv_seq_len = K.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
- device_index = Q.device.index
- cos = self.rotary_emb.multi_gpu_cos_cached[device_index]
- sin = self.rotary_emb.multi_gpu_sin_cached[device_index]
-
- rope_position_ids = (
- position_ids if position_ids is not None else kwargs.get("position_ids")
- )
- if rope_position_ids is not None:
- # Useful for LongRoPE
- cos_var, sin_var = self.rotary_emb.get_cached(kv_seq_len, device_index)
- Q, K = fast_rope_embedding(Q, K, cos_var, sin_var, rope_position_ids)
- else:
+ if position_ids is None:
+ cos = self.rotary_emb.cos_cached
+ sin = self.rotary_emb.sin_cached
Q, K = fast_rope_embedding(Q, K, cos, sin)
+ else:
+ cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
+ Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
+ pass
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
V = torch.cat([past_key_value[1], V], dim = 2)
+ pass
past_key_value = (K, V) if use_cache else None
# Only enable if the attention_mask is True
- use_sliding_window = kwargs.get("use_sliding_window")
- has_sliding_window = (
- use_sliding_window
- if use_sliding_window is not None
- else isinstance(causal_mask, bool) and causal_mask is True
- )
-
- use_flash = HAS_FLASH_ATTENTION_SOFTCAPPING and attention_mask is None
-
- if use_flash:
+ has_sliding_window = type(causal_mask) is bool and causal_mask is True
+ if HAS_FLASH_ATTENTION_SOFTCAPPING and attention_mask is None:
window = (-1, -1)
- sliding_window = getattr(self.config, "sliding_window", None)
if has_sliding_window:
- sliding_window = (
- sliding_window if sliding_window is not None else kv_seq_len
- )
- window = (
- (-1, -1)
- if kv_seq_len <= sliding_window
- else (sliding_window, sliding_window)
- )
+ sw = getattr(self.config, "sliding_window", None)
+ sw = kv_seq_len if (sw is None or sw == "null") else sw
+ window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
+ pass
+ # FA uses 1 / sqrt for softmax_scale!
if not hasattr(self, "_flash_attention_softmax_scale"):
- self._flash_attention_softmax_scale = 1.0 / (
- self.config.query_pre_attn_scalar**0.5
- )
-
- use_varlen = seq_info is not None and past_key_value is None
-
- attention_config = AttentionConfig(
- backend = select_attention_backend(use_varlen),
- n_kv_heads = n_kv_heads,
- n_groups = n_groups,
- flash_dense_kwargs = {
- "causal": True,
- "softcap": self.config.attn_logit_softcapping,
- "softmax_scale": self._flash_attention_softmax_scale,
- "window_size": window,
- },
- flash_varlen_kwargs = {
- "dropout_p": 0.0,
- "softmax_scale": self._flash_attention_softmax_scale,
- "causal": True,
- "softcap": self.config.attn_logit_softcapping,
- "window_size": window,
- },
- )
-
- context = AttentionContext(
- bsz = bsz,
- q_len = q_len,
- kv_seq_len = kv_seq_len,
- n_heads = n_heads,
- head_dim = head_dim,
- requires_grad = hidden_states.requires_grad,
- seq_info = seq_info,
- attention_mask = attention_mask,
- causal_mask = causal_mask,
- sliding_window = sliding_window,
- )
-
- A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V)
- A = A.reshape(bsz, q_len, n_heads * head_dim)
+ self._flash_attention_softmax_scale = 1.0 / (self.config.query_pre_attn_scalar**0.5)
+ pass
+
+ Q = Q.transpose(1, 2)
+ K = K.transpose(1, 2)
+ V = V.transpose(1, 2)
+ A = flash_attn_func(
+ Q, K, V,
+ causal = True,
+ softcap = self.config.attn_logit_softcapping,
+ softmax_scale = self._flash_attention_softmax_scale,
+ window_size = window,
+ )
+ A = A.reshape(bsz, q_len, n_heads*head_dim)
else:
- fx = (
- slow_inference_attention_softcapping
- if "_flag_for_generation" in kwargs
- else slow_attention_softcapping
- )
+ fx = slow_inference_attention_softcapping \
+ if "_flag_for_generation" in kwargs else \
+ slow_attention_softcapping
A = fx(Q, K, V, causal_mask, self, bsz, kv_seq_len)
+ pass
A = self.apply_o(self, A)
return A, None, past_key_value
+pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
def Gemma2DecoderLayer_fast_forward(
self,
- hidden_states: torch.Tensor,
- causal_mask: Optional[BlockDiagonalCausalMask] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- padding_mask: Optional[torch.LongTensor] = None,
- *args,
- **kwargs,
+ hidden_states: torch.Tensor,
+ causal_mask: Optional[BlockDiagonalCausalMask] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ padding_mask: Optional[torch.LongTensor] = None,
+ *args, **kwargs,
):
- if use_cache and hasattr(
- self, "_flag_for_generation"
- ): # past_key_value is not None:
- out_weight = torch.empty(
- self.input_layernorm.weight.shape,
- dtype = torch.float32,
- device = f"{DEVICE_TYPE_TORCH}:0",
- )
+ if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
+ out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0")
# Self Attention
residual = hidden_states
- hidden_states = fast_rms_layernorm_inference_gemma(
- self.input_layernorm, hidden_states, out_weight
- )
+ hidden_states = fast_rms_layernorm_inference_gemma(self.input_layernorm, hidden_states, out_weight)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states = hidden_states,
- causal_mask = causal_mask,
- attention_mask = attention_mask,
- position_ids = position_ids,
- past_key_value = past_key_value,
- output_attentions = output_attentions,
- use_cache = use_cache,
- padding_mask = padding_mask,
- _flag_for_generation = self._flag_for_generation,
- **kwargs,
- )
- hidden_states = fast_rms_layernorm_inference_gemma(
- self.post_attention_layernorm, hidden_states, out_weight
- )
+ hidden_states=hidden_states,
+ causal_mask=causal_mask,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ padding_mask=padding_mask,
+ _flag_for_generation=self._flag_for_generation,
+ )
+ hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
hidden_states += residual
# Fully Connected
residual = hidden_states
- hidden_states = fast_rms_layernorm_inference_gemma(
- self.pre_feedforward_layernorm, hidden_states, out_weight
- )
+ hidden_states = fast_rms_layernorm_inference_gemma(self. pre_feedforward_layernorm, hidden_states, out_weight)
hidden_states = fast_geglu_inference(self.mlp, hidden_states)
- hidden_states = fast_rms_layernorm_inference_gemma(
- self.post_feedforward_layernorm, hidden_states, out_weight
- )
+ hidden_states = fast_rms_layernorm_inference_gemma(self.post_feedforward_layernorm, hidden_states, out_weight)
hidden_states += residual
else:
residual = hidden_states
- hidden_states = fast_rms_layernorm(
- self.input_layernorm, hidden_states, gemma = True
- )
+ hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states, gemma = True)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states = hidden_states,
- causal_mask = causal_mask,
- attention_mask = attention_mask,
- position_ids = position_ids,
- past_key_value = past_key_value,
- output_attentions = output_attentions,
- use_cache = use_cache,
- padding_mask = padding_mask,
- **kwargs,
- )
- hidden_states = fast_rms_layernorm(
- self.post_attention_layernorm, hidden_states, gemma = True
- )
+ hidden_states=hidden_states,
+ causal_mask=causal_mask,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ padding_mask=padding_mask,
+ )
+ hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
- hidden_states = fast_rms_layernorm(
- self.pre_feedforward_layernorm, hidden_states, gemma = True
- )
+ hidden_states = fast_rms_layernorm(self. pre_feedforward_layernorm, hidden_states, gemma = True)
hidden_states = self.mlp(hidden_states)
- hidden_states = fast_rms_layernorm(
- self.post_feedforward_layernorm, hidden_states, gemma = True
- )
+ hidden_states = fast_rms_layernorm(self.post_feedforward_layernorm, hidden_states, gemma = True)
hidden_states = residual + hidden_states
+ pass
outputs = (hidden_states,)
- if output_attentions:
- outputs += (self_attn_weights,)
- if use_cache:
- outputs += (present_key_value,)
+ if output_attentions: outputs += (self_attn_weights,)
+ if use_cache: outputs += (present_key_value,)
return outputs
+pass
from math import sqrt as math_sqrt
-
-KV_CACHE_INCREMENT = 256 # KV Cache update size
+KV_CACHE_INCREMENT = 256 # KV Cache update size
torch_nn_functional_softmax = torch.nn.functional.softmax
torch_matmul = torch.matmul
-torch_tanh = torch.tanh
-
+torch_tanh = torch.tanh
def Gemma2Attention_fast_forward_inference(
self,
- hidden_states: torch.Tensor,
+ hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]],
position_ids,
do_prefill = False,
attention_mask = None,
use_sliding_window = False,
- **kwargs,
):
Xn = hidden_states
bsz, _, hd = hidden_states.size()
K1, V1 = past_key_value
dtype = Xn.dtype
- n_heads = self.config.num_attention_heads
- n_groups = self.num_key_value_groups
+ n_heads = self.config.num_attention_heads
+ n_groups = self.num_key_value_groups
n_kv_heads = self.config.num_key_value_heads
- head_dim = self.head_dim
+ head_dim = self.head_dim
# assert(n_kv_heads * n_groups == n_heads)
hidden_size = self.config.hidden_size
- attention_size = n_heads * head_dim
+ attention_size = n_heads*head_dim
seq_len = K1.shape[-2]
kv_seq_len = seq_len + 1
device = hidden_states.device
@@ -337,28 +270,18 @@ def Gemma2Attention_fast_forward_inference(
# Prefill phase
# if not hasattr(self, "paged_attention"):
if do_prefill:
- self.paged_attention = torch.empty(
- (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
- dtype = dtype,
- device = device,
- )
- self.paged_attention_K = self.paged_attention[:, 0]
- self.paged_attention_V = self.paged_attention[:, 1]
+ self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = device)
+ self.paged_attention_K = self.paged_attention[:,0]
+ self.paged_attention_V = self.paged_attention[:,1]
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
- self.temp_QA = torch.empty(
- (2, bsz, 1, attention_size), dtype = dtype, device = device
- )
- self.temp_KV = torch.empty(
- (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device
- )
+ self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = device)
+ self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = device)
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
# Only for Gemma2
- self.temp_O = torch.empty((bsz, 1, hidden_size), dtype = dtype, device = device)
- self.attention = torch.empty(
- (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device
- )
-
+ self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
+ self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device)
+
# See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
# Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below
# We default to using the config file itself
@@ -366,54 +289,42 @@ def Gemma2Attention_fast_forward_inference(
self.scalar = 1.0 / math_sqrt(self.config.query_pre_attn_scalar)
# self.scalar = 1.0 / math_sqrt(self.config.hidden_size // self.config.num_attention_heads)
self.half_head_dim = head_dim // 2
- self.t = self.config.attn_logit_softcapping
+ self. t = self.config.attn_logit_softcapping
self.reciprocal_t = 1.0 / self.config.attn_logit_softcapping
elif kv_seq_len >= self.paged_attention.shape[0]:
- self.paged_attention.resize_(
- (
- self.paged_attention.shape[0] + KV_CACHE_INCREMENT,
- 2,
- bsz,
- n_kv_heads,
- head_dim,
- )
- )
- self.paged_attention_K = self.paged_attention[:, 0]
- self.paged_attention_V = self.paged_attention[:, 1]
- self.attention.resize_(
- (bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
- )
+ self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
+ self.paged_attention_K = self.paged_attention[:,0]
+ self.paged_attention_V = self.paged_attention[:,1]
+ self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
+ pass
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
- Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
+ Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
# cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
# Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
- cos, sin = self.rotary_emb.get_cached(kv_seq_len, Qn.device.index)
- cos = cos[position_ids].unsqueeze(1)
- sin = sin[position_ids].unsqueeze(1)
+ cos = self.rotary_emb.cos_cached[position_ids].unsqueeze(1)
+ sin = self.rotary_emb.sin_cached[position_ids].unsqueeze(1)
h = self.half_head_dim
RH_Q = self.RH_Q
- RH_Q[:, :, :, :h] = Qn[:, :, :, h:]
- RH_Q[:, :, :, h:] = Qn[:, :, :, :h]
- RH_Q[:, :, :, :h].neg_()
+ RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
+ RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
+ torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
Qn *= cos
Qn.addcmul_(RH_Q, sin)
- RH_K = RH_Q[
- :, :n_kv_heads, :, :
- ] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
- RH_K[:, :, :, :h] = Kn[:, :, :, h:]
- RH_K[:, :, :, h:] = Kn[:, :, :, :h]
- RH_K[:, :, :, :h].neg_()
+ RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
+ RH_K[:,:,:,:h] = Kn[:,:,:,h:]
+ RH_K[:,:,:,h:] = Kn[:,:,:,:h]
+ torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
Kn *= cos
Kn.addcmul_(RH_K, sin)
-
+
# New KV cache
# Kn = torch.cat([K1, Kn], dim = 2)
# Vn = torch.cat([V1, Vn], dim = 2)
@@ -425,53 +336,45 @@ def Gemma2Attention_fast_forward_inference(
# Handle sliding windows
sliding_window = self.config.sliding_window
if use_sliding_window and kv_seq_len > sliding_window:
- start = kv_seq_len - sliding_window
- Knn = Kn[:, :, start:, :] # .contiguous()
- Vnn = Vn[:, :, start:, :] # .contiguous()
+ # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
+ slicing_tokens = 1 - sliding_window
+ Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
+ Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
else:
Knn, Vnn = Kn, Vn
+ pass
# Grouped query attention
_, _, cached_len, _ = Knn.shape
if n_groups != 1:
- Knn = Knn[:, :, None, :, :].expand(
- bsz, n_kv_heads, n_groups, cached_len, head_dim
- )
- Vnn = Vnn[:, :, None, :, :].expand(
- bsz, n_kv_heads, n_groups, cached_len, head_dim
- )
+ Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
+ Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
+ pass
+ # else:
+ # Knn, Vnn = Knn, Vnn
+ # pass
# Attention
- # [TODO] Gemma2 uses manual matmul for all batch sizes because SDPA does
- # not support softcapping (tanh logit scaling). If a future PyTorch adds
- # a softcap param to scaled_dot_product_attention, consider using SDPA
- # for bsz > 1 to match the llama/qwen3 pattern.
- Qn *= (
- self.scalar
- ) # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
+ # if bsz == 1:
+ Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
- A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len])
+ A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
+ # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
- # Softcapping must happen BEFORE the mask is applied.
- # Reference: google-deepmind/gemma _modules.py and transformers gemma2 eager_attention_forward
- A *= self.reciprocal_t
- A.tanh_()
- A *= self.t # Logit softcapping
+ A *= self.reciprocal_t; torch_tanh(A, out = A); A *= self.t; # Logit softcapping
- if attention_mask is not None and isinstance(attention_mask, torch.Tensor):
- # Slice mask to match K/V when sliding window is active
- if attention_mask.shape[-1] != A.shape[-1]:
- attention_mask = attention_mask[:, :, :, -A.shape[-1] :]
- A += attention_mask
-
- A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32) # .to(A.dtype)
+ A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
A = torch_matmul(A, Vnn, out = Qn)
+ # else:
+ # A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
+ # pass
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, attention_size)
A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
return A, (Kn, Vn)
+pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
@@ -482,31 +385,21 @@ def Gemma2Model_fast_forward_inference(
past_key_values,
position_ids,
attention_mask = None,
- **kwargs,
):
- out_weights = tuple(
- torch.empty_like(
- self.model.layers[0].input_layernorm.weight,
- dtype = torch.float32,
- device = torch.device(x),
- )
- for x in range(DEVICE_COUNT)
- )
- input_ids = input_ids[:, : self.max_seq_length]
+ out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0")
+ input_ids = input_ids[:,:self.max_seq_length]
hidden_states = self.model.embed_tokens(input_ids)
- hidden_states = hidden_states.to(_get_dtype(dtype_from_config(self.config)))
+ hidden_states = hidden_states.to(self.config.torch_dtype)
# 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
# 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
- hidden_states *= torch.tensor(
- math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype
- )
+ hidden_states *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype)
bsz, q_len, hd = hidden_states.shape
seq_len = past_key_values[0][0].shape[-2]
if bsz != 1:
if HAS_FLASH_ATTENTION_SOFTCAPPING:
SWA = True
- GA = False
+ GA = False
else:
SWA = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
@@ -521,24 +414,18 @@ def Gemma2Model_fast_forward_inference(
hidden_states,
seq_len,
)
+ pass
else:
SWA = attention_mask
- GA = attention_mask
+ GA = attention_mask
+ pass
next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):
- # For pipeline parallelism, we need to move all tensors to the same device
- # note that this movement is once per GPU in PP
- device_index = getattr(decoder_layer, "_per_layer_device_index", 0)
- hidden_states, position_ids = move_to_device(
- device_index, hidden_states, position_ids
- )
use_sliding_window = idx % 2 == 0
residual = hidden_states
- hidden_states = fast_rms_layernorm_inference_gemma(
- decoder_layer.input_layernorm, hidden_states, out_weights[device_index]
- )
+ hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight)
hidden_states, present_key_value = Gemma2Attention_fast_forward_inference(
decoder_layer.self_attn,
hidden_states = hidden_states,
@@ -548,31 +435,18 @@ def Gemma2Model_fast_forward_inference(
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
use_sliding_window = use_sliding_window,
)
- hidden_states = fast_rms_layernorm_inference_gemma(
- decoder_layer.post_attention_layernorm,
- hidden_states,
- out_weights[device_index],
- )
+ hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight)
hidden_states += residual
residual = hidden_states
- hidden_states = fast_rms_layernorm_inference_gemma(
- decoder_layer.pre_feedforward_layernorm,
- hidden_states,
- out_weights[device_index],
- )
+ hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer. pre_feedforward_layernorm, hidden_states, out_weight)
hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)
- hidden_states = fast_rms_layernorm_inference_gemma(
- decoder_layer.post_feedforward_layernorm,
- hidden_states,
- out_weights[device_index],
- )
+ hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_feedforward_layernorm, hidden_states, out_weight)
hidden_states += residual
next_decoder_cache.append(present_key_value)
- hidden_states = fast_rms_layernorm_inference_gemma(
- self.model.norm, hidden_states, out_weights[device_index]
- )
+ pass
+ hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weight)
return BaseModelOutputWithPast(
last_hidden_state = hidden_states,
@@ -580,49 +454,47 @@ def Gemma2Model_fast_forward_inference(
hidden_states = [],
attentions = [],
)
+pass
class FastGemma2Model(FastLlamaModel):
+
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
- model_name = "gemma2",
- rope_module = GemmaFixedRotaryEmbedding,
+ model_name = "gemma2",
+ rope_module = GemmaFixedRotaryEmbedding,
scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,
- attention_module = Gemma2Attention,
+ attention_module = Gemma2Attention,
)
if init_name is not None:
exec(function, globals())
- Gemma2Attention.__init__ = eval(init_name)
- Gemma2Attention.forward = Gemma2Attention_fast_forward
- Gemma2SdpaAttention.forward = Gemma2Attention_fast_forward
+ Gemma2Attention.__init__ = eval(init_name)
+ pass
+ Gemma2Attention .forward = Gemma2Attention_fast_forward
+ Gemma2SdpaAttention .forward = Gemma2Attention_fast_forward
Gemma2FlashAttention2.forward = Gemma2Attention_fast_forward
- Gemma2DecoderLayer.forward = Gemma2DecoderLayer_fast_forward
- Gemma2Model.forward = LlamaModel_fast_forward
- Gemma2ForCausalLM.forward = CausalLM_fast_forward(
- Gemma2Model_fast_forward_inference
- )
- PeftModelForCausalLM.forward = PeftModel_fast_forward
+ Gemma2DecoderLayer .forward = Gemma2DecoderLayer_fast_forward
+ Gemma2Model .forward = LlamaModel_fast_forward
+ Gemma2ForCausalLM .forward = CausalLM_fast_forward(Gemma2Model_fast_forward_inference)
+ PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(Gemma2ForCausalLM)
-
+
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
- # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
+ # Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.gemma2.modeling_gemma2
-
- transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding = (
- GemmaFixedRotaryEmbedding
- )
+ transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding = GemmaFixedRotaryEmbedding
return
+ pass
+
@staticmethod
- def post_patch(model, tokenizer, correct_dtype = None):
+ def post_patch(model, tokenizer):
# Gemma does not downcast RoPE
- model, tokenizer = patch_model_and_tokenizer(
- model, tokenizer, downcast_rope = False, correct_dtype = correct_dtype
- )
+ model, tokenizer = patch_model_and_tokenizer(model, tokenizer, downcast_rope = False)
# Add 1 to weight
# return output * (1 + self.weight)
@@ -636,6 +508,7 @@ def post_patch(model, tokenizer, correct_dtype = None):
param.requires_grad_(True)
else:
param.requires_grad_(False)
+ pass
# Patch RMS Layernorm
for name, module in model.named_modules():
@@ -646,14 +519,14 @@ def post_patch(model, tokenizer, correct_dtype = None):
# Leave + 1 to Triton kernel itself
# module.weight += 1.0 # return output * (1 + self.weight)
if not hasattr(module, "variance_epsilon"):
- module.variance_epsilon = (
- module.eps
- ) # Gemma doesn't use variance_epsilon
+ module.variance_epsilon = module.eps # Gemma doesn't use variance_epsilon
+ pass
# Clear deleted GPU items
import gc
-
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
return model, tokenizer
+ pass
+pass
diff --git a/unsloth/models/glm4_moe.py b/unsloth/models/glm4_moe.py
deleted file mode 100644
index 5d04b2f1d0..0000000000
--- a/unsloth/models/glm4_moe.py
+++ /dev/null
@@ -1,450 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-GLM-4.7 Flash (GLM4 MoE Lite) optimized implementation using grouped GEMM.
-
-Key architecture differences from Qwen3 MoE:
-- Router uses sigmoid activation (not softmax)
-- Has routed_scaling_factor of 1.8
-- Has 1 shared expert that processes all tokens
-- Uses group-based selection before topk
-- Uses MLA (Multi-head Latent Attention)
-"""
-
-from .llama import *
-import os
-from ._utils import __version__
-from .llama import (
- LlamaRotaryEmbedding,
- LlamaLinearScalingRotaryEmbedding,
- fix_prepare_inputs_for_generation,
- fast_rms_layernorm_inference,
- fast_swiglu_inference,
- LlamaModel_fast_forward,
- LlamaModel_fast_forward_inference,
- CausalLM_fast_forward,
- PeftModel_fast_forward,
-)
-import torch
-import torch.nn.functional as F
-from typing import Optional, Tuple
-from ..kernels import fast_rms_layernorm
-
-# Import the grouped gemm utilities from unsloth kernels
-# The grouped_gemm module expects its parent directory to be in sys.path
-HAS_GROUPED_GEMM = False
-try:
- import sys
- import os
-
- # Add the moe directory (parent of grouped_gemm) to sys.path
- _moe_path = os.path.join(
- os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "kernels", "moe"
- )
- if _moe_path not in sys.path:
- sys.path.insert(0, _moe_path)
-
- # Import grouped_gemm package first to apply TMA compatibility shim
- # This patches triton.language to support both old and new TMA API names
- import grouped_gemm # noqa: F401 - triggers TMA compatibility shim
-
- from grouped_gemm.interface import grouped_gemm
- from grouped_gemm.reference.moe_ops import (
- get_routing_indices,
- permute,
- unpermute,
- )
-
- HAS_GROUPED_GEMM = True
-except ImportError as e:
- import warnings
-
- warnings.warn(
- f"Grouped GEMM not available: {e}. MoE will use fallback implementation."
- )
-
-
-# Import transformers GLM4 MoE Lite classes
-try:
- from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import (
- Glm4MoeLiteAttention,
- Glm4MoeLiteMoE,
- Glm4MoeLiteMLP,
- Glm4MoeLiteNaiveMoe,
- Glm4MoeLiteTopkRouter,
- Glm4MoeLiteDecoderLayer,
- Glm4MoeLiteModel,
- Glm4MoeLiteForCausalLM,
- Glm4MoeLiteRMSNorm,
- )
-
- HAS_GLM4_MOE = True
-except ImportError:
- HAS_GLM4_MOE = False
-
- # Create dummy classes for type checking
- class Glm4MoeLiteAttention:
- pass
-
- class Glm4MoeLiteMoE:
- pass
-
- class Glm4MoeLiteMLP:
- pass
-
- class Glm4MoeLiteNaiveMoe:
- pass
-
- class Glm4MoeLiteTopkRouter:
- pass
-
- class Glm4MoeLiteDecoderLayer:
- pass
-
- class Glm4MoeLiteModel:
- pass
-
- class Glm4MoeLiteForCausalLM:
- pass
-
-
-torch_nn_functional_silu = torch.nn.functional.silu
-
-
-def Glm4MoeLiteMoE_fast_forward(self, hidden_states):
- """
- Optimized MoE forward pass using grouped GEMM.
-
- GLM4 MoE specifics:
- - Uses sigmoid router activation (not softmax)
- - Has routed_scaling_factor of 1.8
- - Has 1 shared expert that always processes all tokens
- - Uses group-based selection with topk_group
- """
- residuals = hidden_states
- orig_shape = hidden_states.shape
- batch_size, seq_len, hidden_dim = orig_shape
- num_tokens = batch_size * seq_len
-
- # Flatten hidden states for routing
- hidden_states = hidden_states.view(-1, hidden_dim)
-
- # Router computation
- router_logits = self.gate(hidden_states) # [num_tokens, n_routed_experts]
- topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
- # Cast routing weights to match hidden_states dtype (Qwen3 pattern)
- # Sigmoid router returns fp32, but hidden_states may be bf16
- topk_weights = topk_weights.to(hidden_states.dtype)
-
- # Get routing indices for grouped GEMM
- with torch.no_grad():
- token_counts_by_expert, gather_indices = get_routing_indices(
- topk_indices, self.n_routed_experts
- )
-
- # Use grouped GEMM for expert computation
- if HAS_GROUPED_GEMM:
- # Cast hidden_states to match expert weights dtype
- # Under autocast, hidden_states may be fp32 while weights are bf16
- hidden_states = hidden_states.to(self.experts.gate_up_proj.dtype)
-
- # First grouped GEMM: gate_up_proj with permute_x
- # Input: [num_tokens, hidden_dim] -> Output: [total_tokens, 2*intermediate_dim]
- intermediate = grouped_gemm(
- X = hidden_states,
- W = self.experts.gate_up_proj,
- m_sizes = token_counts_by_expert.int(),
- topk = self.top_k,
- gather_indices = gather_indices,
- permute_x = True,
- permute_y = False,
- autotune = True,
- is_first_gemm = True,
- )
-
- # Activation: SiLU(gate) * up
- gate, up = intermediate.chunk(2, dim = -1)
- intermediate = torch_nn_functional_silu(gate) * up
-
- # Second grouped GEMM: down_proj with permute_y
- # Input: [total_tokens, intermediate_dim] -> Output: [total_tokens, hidden_dim]
- expert_output = grouped_gemm(
- X = intermediate,
- W = self.experts.down_proj,
- m_sizes = token_counts_by_expert.int(),
- topk = self.top_k,
- gather_indices = gather_indices,
- permute_x = False,
- permute_y = True,
- autotune = True,
- is_first_gemm = False,
- )
-
- # Merge topk weights: [num_tokens, top_k, hidden_dim] -> [num_tokens, hidden_dim]
- hidden_states = (
- expert_output.view(num_tokens, self.top_k, hidden_dim)
- * topk_weights.unsqueeze(-1)
- ).sum(dim = 1)
- else:
- # Fallback to naive implementation
- hidden_states = self.experts(hidden_states, topk_indices, topk_weights)
-
- # Add shared expert output
- hidden_states = hidden_states + self.shared_experts(residuals.view(-1, hidden_dim))
-
- return hidden_states.view(*orig_shape)
-
-
-def Glm4MoeLiteNaiveMoe_fast_forward(
- self,
- hidden_states: torch.Tensor,
- top_k_index: torch.Tensor,
- top_k_weights: torch.Tensor,
-) -> torch.Tensor:
- """
- Optimized expert forward using grouped GEMM.
-
- Args:
- hidden_states: [num_tokens, hidden_dim]
- top_k_index: [num_tokens, top_k] indices of selected experts
- top_k_weights: [num_tokens, top_k] weights for selected experts
-
- Returns:
- [num_tokens, hidden_dim] output after weighted sum of expert outputs
- """
- num_tokens, hidden_dim = hidden_states.shape
- top_k = top_k_index.shape[1]
- # Cast routing weights to match hidden_states dtype (Qwen3 pattern)
- top_k_weights = top_k_weights.to(hidden_states.dtype)
-
- if not HAS_GROUPED_GEMM:
- # Fallback to original naive implementation
- final_hidden_states = torch.zeros_like(hidden_states)
- with torch.no_grad():
- expert_mask = torch.nn.functional.one_hot(
- top_k_index, num_classes = self.num_experts
- )
- expert_mask = expert_mask.permute(2, 1, 0)
- expert_hit = torch.greater(expert_mask.sum(dim = (-1, -2)), 0).nonzero()
-
- for expert_idx in expert_hit:
- expert_idx = expert_idx[0]
- if expert_idx == self.num_experts:
- continue
- top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
- current_state = hidden_states[token_idx]
- gate, up = torch.nn.functional.linear(
- current_state, self.gate_up_proj[expert_idx]
- ).chunk(2, dim = -1)
- current_hidden_states = self.act_fn(gate) * up
- current_hidden_states = torch.nn.functional.linear(
- current_hidden_states, self.down_proj[expert_idx]
- )
- current_hidden_states = (
- current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
- )
- final_hidden_states.index_add_(
- 0, token_idx, current_hidden_states.to(final_hidden_states.dtype)
- )
-
- return final_hidden_states
-
- # Get routing indices for grouped GEMM
- with torch.no_grad():
- token_counts_by_expert, gather_indices = get_routing_indices(
- top_k_index, self.num_experts
- )
-
- # Cast hidden_states to match expert weights dtype
- # Under autocast, hidden_states may be fp32 while weights are bf16
- hidden_states = hidden_states.to(self.gate_up_proj.dtype)
-
- # First grouped GEMM: gate_up_proj
- intermediate = grouped_gemm(
- X = hidden_states,
- W = self.gate_up_proj,
- m_sizes = token_counts_by_expert.int(),
- topk = top_k,
- gather_indices = gather_indices,
- permute_x = True,
- permute_y = False,
- autotune = True,
- is_first_gemm = True,
- )
-
- # Activation: SiLU(gate) * up
- gate, up = intermediate.chunk(2, dim = -1)
- intermediate = self.act_fn(gate) * up
-
- # Second grouped GEMM: down_proj
- expert_output = grouped_gemm(
- X = intermediate,
- W = self.down_proj,
- m_sizes = token_counts_by_expert.int(),
- topk = top_k,
- gather_indices = gather_indices,
- permute_x = False,
- permute_y = True,
- autotune = True,
- is_first_gemm = False,
- )
-
- # Merge topk weights
- final_hidden_states = (
- expert_output.view(num_tokens, top_k, hidden_dim) * top_k_weights.unsqueeze(-1)
- ).sum(dim = 1)
-
- return final_hidden_states
-
-
-def Glm4MoeLiteDecoderLayer_fast_forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values = None,
- use_cache: bool = False,
- cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- **kwargs,
-) -> torch.Tensor:
- """
- Optimized decoder layer forward with fast RMS layernorm.
- """
- # Check if we're in inference mode
- is_inference = use_cache and hasattr(self, "_flag_for_generation")
-
- if is_inference:
- # Self-attention with fast inference path
- residual = hidden_states
- hidden_states = fast_rms_layernorm_inference(
- self.input_layernorm, hidden_states
- )
- hidden_states, _ = self.self_attn(
- hidden_states = hidden_states,
- attention_mask = attention_mask,
- position_ids = position_ids,
- past_key_values = past_key_values,
- use_cache = use_cache,
- cache_position = cache_position,
- position_embeddings = position_embeddings,
- **kwargs,
- )
- hidden_states = residual + hidden_states
-
- # MLP/MoE
- residual = hidden_states
- hidden_states = fast_rms_layernorm_inference(
- self.post_attention_layernorm, hidden_states
- )
- hidden_states = self.mlp(hidden_states)
- hidden_states = residual + hidden_states
- else:
- # Training path
- residual = hidden_states
- hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
- hidden_states, _ = self.self_attn(
- hidden_states = hidden_states,
- attention_mask = attention_mask,
- position_ids = position_ids,
- past_key_values = past_key_values,
- use_cache = use_cache,
- cache_position = cache_position,
- position_embeddings = position_embeddings,
- **kwargs,
- )
- hidden_states = residual + hidden_states
-
- # MLP/MoE
- residual = hidden_states
- hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
- hidden_states = self.mlp(hidden_states)
- hidden_states = residual + hidden_states
-
- return hidden_states
-
-
-def Glm4MoeLiteMLP_fast_forward(self, x):
- """
- Optimized MLP forward using fused SwiGLU.
- """
- return fast_swiglu_inference(self, x)
-
-
-class FastGLM47Model(FastLlamaModel):
- """
- Fast GLM-4.7 Flash (GLM4 MoE Lite) model with grouped GEMM optimization.
-
- This provides 2-3x throughput improvement for MoE layers by:
- - Replacing sequential expert loops with grouped GEMM operations
- - Fusing permutation operations into the GEMM kernels
- - Using optimized RMS LayerNorm and SwiGLU implementations
- """
-
- @staticmethod
- def pre_patch():
- if not HAS_GLM4_MOE:
- raise ImportError(
- "Unsloth: GLM4 MoE Lite support requires transformers >= 5.0.0. "
- "Please upgrade with: pip install --upgrade transformers"
- )
-
- # Patch MoE forward with grouped GEMM optimization
- # TMA compatibility is handled by grouped_gemm/__init__.py which patches
- # triton.language to support both old (_experimental_make_tensor_descriptor)
- # and new (make_tensor_descriptor) API names
- if HAS_GROUPED_GEMM:
- Glm4MoeLiteNaiveMoe.forward = Glm4MoeLiteNaiveMoe_fast_forward
- Glm4MoeLiteMoE.forward = Glm4MoeLiteMoE_fast_forward
-
- # Note: We don't patch the following for GLM4 MoE because:
- # - GLM4 uses MLA (Multi-head Latent Attention) which has different projection names
- # - Glm4MoeLiteRotaryEmbedding doesn't have extend_rope_embedding method
- # - The decoder layer and model forward functions assume Llama-compatible infrastructure
-
- return
-
- @staticmethod
- def from_pretrained(
- model_name = "unsloth/GLM-4.7-Flash",
- max_seq_length = 4096,
- dtype = None,
- load_in_4bit = True,
- token = None,
- device_map = "sequential",
- rope_scaling = None,
- fix_tokenizer = True,
- model_patcher = None,
- tokenizer_name = None,
- trust_remote_code = False,
- **kwargs,
- ):
- # Pop kwargs that are used by loader but not passed to model
- kwargs.pop("unsloth_force_compile", None)
-
- return FastLlamaModel.from_pretrained(
- model_name = model_name,
- max_seq_length = max_seq_length,
- dtype = dtype,
- load_in_4bit = load_in_4bit,
- token = token,
- device_map = device_map,
- rope_scaling = rope_scaling,
- fix_tokenizer = fix_tokenizer,
- model_patcher = FastGLM47Model,
- tokenizer_name = tokenizer_name,
- trust_remote_code = trust_remote_code,
- **kwargs,
- )
diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py
index 79ac41c43f..df498d18ba 100644
--- a/unsloth/models/granite.py
+++ b/unsloth/models/granite.py
@@ -15,16 +15,6 @@
from .llama import *
import os
from ._utils import __version__
-from unsloth_zoo.utils import _get_dtype, Version
-from unsloth_zoo.hf_utils import dtype_from_config
-from ..utils.packing import get_packed_info_from_kwargs
-from ..utils.attention_dispatch import (
- AttentionConfig,
- AttentionContext,
- run_attention,
- select_attention_backend,
- SDPA,
-)
from .llama import (
LlamaRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
@@ -32,7 +22,6 @@
from .mistral import *
from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
from peft.tuners.lora import Linear4bit as Peft_Linear4bit
-
try:
from transformers.models.granite.modeling_granite import (
GraniteAttention,
@@ -41,14 +30,18 @@
GraniteForCausalLM,
)
except:
+ from packaging.version import Version
+
transformers_version = Version(transformers_version)
if not transformers_version >= Version("4.45.0"):
raise ImportError(
- f"Unsloth: Your transformers version of {transformers_version} does not support Granite.\n"
- f"The minimum required version is 4.45.0.\n"
- f'Try `pip install --upgrade "transformers>=4.45.0"`\n'
- f"to obtain the latest transformers build, then restart this session."
+ f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"\
+ f"The minimum required version is 4.42.3.\n"\
+ f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\
+ f"to obtain the latest transformers build, then restart this session."\
)
+ pass
+pass
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask_for_sdpa,
@@ -61,24 +54,24 @@
GraniteFlashAttention2,
)
except:
- GraniteSdpaAttention = GraniteAttention
+ GraniteSdpaAttention = GraniteAttention
GraniteFlashAttention2 = GraniteAttention
-
+pass
def GraniteAttention_fast_forward(
self,
- hidden_states: torch.Tensor,
- causal_mask: Optional[BlockDiagonalCausalMask] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- padding_mask: Optional[torch.LongTensor] = None,
+ hidden_states: torch.Tensor,
+ causal_mask: Optional[BlockDiagonalCausalMask] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ padding_mask: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- *args,
- **kwargs,
+ *args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
# Clear inference
if hasattr(self, "paged_attention"):
del self.paged_attention_K
@@ -88,21 +81,21 @@ def GraniteAttention_fast_forward(
del self.temp_KV
del self.RH_Q
del self.attention
+ pass
bsz, q_len, _ = hidden_states.size()
- n_heads = self.config.num_attention_heads
- n_groups = self.num_key_value_groups
+ n_heads = self.config.num_attention_heads
+ n_groups = self.num_key_value_groups
n_kv_heads = self.config.num_key_value_heads
- head_dim = self.head_dim
- dropout_p = self.config.attention_dropout if self.training else 0
- assert n_kv_heads * n_groups == n_heads
+ head_dim = self.head_dim
+ dropout_p = self.config.attention_dropout if self.training else 0
+ assert(n_kv_heads * n_groups == n_heads)
Q, K, V = self.apply_qkv(self, hidden_states)
- Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
+ Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
- seq_info = get_packed_info_from_kwargs(kwargs, Q.device)
kv_seq_len = K.shape[-2]
if past_key_value is not None:
@@ -110,144 +103,126 @@ def GraniteAttention_fast_forward(
assert position_embeddings is not None
cos, sin = position_embeddings
- rope_position_ids = (
- position_ids if position_ids is not None else kwargs.get("position_ids")
- )
- if rope_position_ids is not None:
- # Useful for LongRoPE
- Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids)
- else:
+ if position_ids is None:
Q, K = fast_rope_embedding(Q, K, cos, sin)
+ else:
+ Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
V = torch.cat([past_key_value[1], V], dim = 2)
+ pass
past_key_value = (K, V) if use_cache else None
# Attention module
- use_varlen = (
- attention_mask is None and seq_info is not None and past_key_value is None
- )
-
- backend = (
- SDPA if attention_mask is not None else select_attention_backend(use_varlen)
- )
-
- window = (kv_seq_len, kv_seq_len)
- softmax_scale = getattr(self, "scaling", None)
- attention_config = AttentionConfig(
- backend = backend,
- n_kv_heads = n_kv_heads,
- n_groups = n_groups,
- flash_dense_kwargs = {
- "causal": True,
- "softmax_scale": softmax_scale,
- "dropout_p": dropout_p,
- "window_size": window,
- },
- flash_varlen_kwargs = {
- "dropout_p": 0.0,
- "softmax_scale": softmax_scale,
- "causal": True,
- },
- sdpa_kwargs = {
- k: v
- for k, v in {
- "attn_mask": attention_mask,
- "scale": softmax_scale,
- "dropout_p": dropout_p,
- }.items()
- if v is not None
- },
- xformers_kwargs = {
- "scale": softmax_scale,
- "p": dropout_p,
- },
- )
-
- context = AttentionContext(
- bsz = bsz,
- q_len = q_len,
- kv_seq_len = kv_seq_len,
- n_heads = n_heads,
- head_dim = head_dim,
- requires_grad = hidden_states.requires_grad,
- seq_info = seq_info,
- attention_mask = attention_mask,
- causal_mask = causal_mask,
- )
-
- A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V)
-
- attn_output = A.reshape(bsz, q_len, n_heads * head_dim)
+ if (not HAS_FLASH_ATTENTION and attention_mask is None):
+ # Xformers memory efficient attention
+ Q = Q.transpose(1, 2)
+ K = K.transpose(1, 2)
+ V = V.transpose(1, 2)
+ K_M = V_M = bsz * kv_seq_len
+ Q_M = bsz * q_len
+
+ # Group query attention
+ K = K .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
+ V = V .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
+ K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
+ V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
+ if hidden_states.requires_grad:
+ K = K.reshape(bsz, kv_seq_len, n_heads, head_dim)
+ V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
+ else:
+ # Xformers does support the forward pass though
+ Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
+ pass
+
+ A = xformers_attention(Q, K, V, attn_bias = causal_mask, scale=self.scaling, p=dropout_p)
+ A = A.view(bsz, q_len, n_heads, head_dim)
+
+ elif HAS_FLASH_ATTENTION and attention_mask is None:
+ Q = Q.transpose(1, 2)
+ K = K.transpose(1, 2)
+ V = V.transpose(1, 2)
+ window = (kv_seq_len, kv_seq_len)
+ A = flash_attn_func(Q, K, V, causal = True, window_size = window, softmax_scale=self.scaling, dropout_p=dropout_p)
+ else:
+ # Grouped query attention
+ # if n_groups != 1:
+ K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
+ V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
+ K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
+ V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
+ # pass
+ # Must be contiguous or else results are False!
+ # https://github.com/pytorch/pytorch/issues/112577
+ Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
+ # Needs (batch_size, n_heads, seq_len, head_dim)
+ # is_casual and attention_mask must not be both set!
+ A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, scale = self.scaling, is_causal = False, dropout_p=dropout_p)
+ # Go back to (batch_size, seq_len, n_heads, head_dim)
+ A = A.transpose(1, 2).contiguous()
+ pass
+
+ attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
attn_output = self.apply_o(self, attn_output)
attn_weights = None
return attn_output, attn_weights, past_key_value
+pass
def GraniteDecoderLayer_fast_forward(
self,
- hidden_states: torch.Tensor,
- causal_mask: Optional[BlockDiagonalCausalMask] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- padding_mask: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- *args,
- **kwargs,
+ hidden_states: torch.Tensor,
+ causal_mask: Optional[BlockDiagonalCausalMask] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ padding_mask: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ *args, **kwargs,
):
- residual_multiplier = (
- self.residual_multiplier
- if hasattr(self, "residual_multiplier")
- else self.config.residual_multiplier
- )
+ residual_multiplier = \
+ self.residual_multiplier \
+ if hasattr(self, "residual_multiplier") else \
+ self.config.residual_multiplier
- if use_cache and hasattr(
- self, "_flag_for_generation"
- ): # past_key_value is not None:
+ if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
residual = hidden_states
- hidden_states = fast_rms_layernorm_inference(
- self.input_layernorm, hidden_states
- )
+ hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states = hidden_states,
- causal_mask = causal_mask,
- attention_mask = attention_mask,
- position_ids = position_ids,
- past_key_value = past_key_value,
- output_attentions = output_attentions,
- use_cache = use_cache,
- padding_mask = padding_mask,
+ hidden_states=hidden_states,
+ causal_mask=causal_mask,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ padding_mask=padding_mask,
position_embeddings = position_embeddings,
- _flag_for_generation = self._flag_for_generation,
- **kwargs,
+ _flag_for_generation=self._flag_for_generation,
)
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
# Fully Connected
residual = hidden_states
- hidden_states = fast_rms_layernorm_inference(
- self.post_attention_layernorm, hidden_states
- )
+ hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states)
hidden_states = fast_swiglu_inference(self.mlp, hidden_states)
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
else:
residual = hidden_states
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states = hidden_states,
- causal_mask = causal_mask,
- attention_mask = attention_mask,
- position_ids = position_ids,
- past_key_value = past_key_value,
- output_attentions = output_attentions,
- use_cache = use_cache,
- padding_mask = padding_mask,
+ hidden_states=hidden_states,
+ causal_mask=causal_mask,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ padding_mask=padding_mask,
position_embeddings = position_embeddings,
- **kwargs,
)
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
@@ -256,50 +231,47 @@ def GraniteDecoderLayer_fast_forward(
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
+ pass
outputs = (hidden_states,)
- if output_attentions:
- outputs += (self_attn_weights,)
- if use_cache:
- outputs += (present_key_value,)
+ if output_attentions: outputs += (self_attn_weights,)
+ if use_cache: outputs += (present_key_value,)
return outputs
+pass
from math import sqrt as math_sqrt
-
-KV_CACHE_INCREMENT = 256 # KV Cache update size
+KV_CACHE_INCREMENT = 256 # KV Cache update size
torch_nn_functional_softmax = torch.nn.functional.softmax
torch_matmul = torch.matmul
-torch_tanh = torch.tanh
-
+torch_tanh = torch.tanh
def GraniteAttention_fast_forward_inference(
self,
- hidden_states: torch.Tensor,
+ hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]],
position_ids,
do_prefill = False,
attention_mask = None,
use_sliding_window = False,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ position_embeddings : Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
- assert (
- position_embeddings is not None
- ), f"Granite model requires position embeddings to be specified"
+
+ assert position_embeddings is not None, f"Granite model requires position embeddings to be specified"
Xn = hidden_states
bsz, _, hd = hidden_states.size()
K1, V1 = past_key_value
dtype = Xn.dtype
- n_heads = self.config.num_attention_heads
- n_groups = self.num_key_value_groups
+ n_heads = self.config.num_attention_heads
+ n_groups = self.num_key_value_groups
n_kv_heads = self.config.num_key_value_heads
- head_dim = self.head_dim
+ head_dim = self.head_dim
# assert(n_kv_heads * n_groups == n_heads)
hidden_size = self.config.hidden_size
- attention_size = n_heads * head_dim
+ attention_size = n_heads*head_dim
seq_len = K1.shape[-2]
kv_seq_len = seq_len + 1
device = hidden_states.device
@@ -307,48 +279,31 @@ def GraniteAttention_fast_forward_inference(
# Prefill phase
# if not hasattr(self, "paged_attention"):
if do_prefill:
- self.paged_attention = torch.empty(
- (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
- dtype = dtype,
- device = device,
- )
- self.paged_attention_K = self.paged_attention[:, 0]
- self.paged_attention_V = self.paged_attention[:, 1]
+ self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = device)
+ self.paged_attention_K = self.paged_attention[:,0]
+ self.paged_attention_V = self.paged_attention[:,1]
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
- self.temp_QA = torch.empty(
- (2, bsz, 1, attention_size), dtype = dtype, device = device
- )
- self.temp_KV = torch.empty(
- (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device
- )
+ self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = device)
+ self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = device)
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
- self.temp_O = torch.empty((bsz, 1, hidden_size), dtype = dtype, device = device)
- self.attention = torch.empty(
- (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device
- )
+ # Only for Gemma2
+ self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
+ self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device)
+
self.half_head_dim = head_dim // 2
elif kv_seq_len >= self.paged_attention.shape[0]:
- self.paged_attention.resize_(
- (
- self.paged_attention.shape[0] + KV_CACHE_INCREMENT,
- 2,
- bsz,
- n_kv_heads,
- head_dim,
- )
- )
- self.paged_attention_K = self.paged_attention[:, 0]
- self.paged_attention_V = self.paged_attention[:, 1]
- self.attention.resize_(
- (bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
- )
+ self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
+ self.paged_attention_K = self.paged_attention[:,0]
+ self.paged_attention_V = self.paged_attention[:,1]
+ self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
+ pass
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
- Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
+ Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
@@ -359,21 +314,19 @@ def GraniteAttention_fast_forward_inference(
h = self.half_head_dim
RH_Q = self.RH_Q
- RH_Q[:, :, :, :h] = Qn[:, :, :, h:]
- RH_Q[:, :, :, h:] = Qn[:, :, :, :h]
- RH_Q[:, :, :, :h].neg_()
+ RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
+ RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
+ torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
Qn *= cos
Qn.addcmul_(RH_Q, sin)
- RH_K = RH_Q[
- :, :n_kv_heads, :, :
- ] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
- RH_K[:, :, :, :h] = Kn[:, :, :, h:]
- RH_K[:, :, :, h:] = Kn[:, :, :, :h]
- RH_K[:, :, :, :h].neg_()
+ RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
+ RH_K[:,:,:,:h] = Kn[:,:,:,h:]
+ RH_K[:,:,:,h:] = Kn[:,:,:,:h]
+ torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
Kn *= cos
Kn.addcmul_(RH_K, sin)
-
+
# New KV cache
# Kn = torch.cat([K1, Kn], dim = 2)
# Vn = torch.cat([V1, Vn], dim = 2)
@@ -384,52 +337,31 @@ def GraniteAttention_fast_forward_inference(
# Grouped query attention
_, _, cached_len, _ = Kn.shape
- if bsz == 1 or ((not SDPA_HAS_GQA) and n_groups != 1):
- Kn = Kn[:, :, None, :, :].expand(
- bsz, n_kv_heads, n_groups, cached_len, head_dim
- )
- Vn = Vn[:, :, None, :, :].expand(
- bsz, n_kv_heads, n_groups, cached_len, head_dim
- )
+ if n_groups != 1:
+ Kn = Kn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
+ Vn = Vn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Kn = Kn.reshape(bsz, n_heads, cached_len, head_dim)
Vn = Vn.reshape(bsz, n_heads, cached_len, head_dim)
-
- # Attention
- if bsz == 1:
- Qn *= self.scaling
- A = torch_matmul(
- Qn, Kn.transpose(2, 3), out = self.attention[:, :, :, :cached_len]
- )
- A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)
- A = torch_matmul(A, Vn, out = Qn)
- else:
- if (
- attention_mask is not None
- and attention_mask.dim() == 4
- and attention_mask.dtype != torch.bool
- ):
- attention_mask = attention_mask.eq(0)
- if SDPA_HAS_GQA:
- A = scaled_dot_product_attention(
- Qn,
- Kn,
- Vn,
- attn_mask = attention_mask,
- scale = self.scaling,
- enable_gqa = True,
- )
- else:
- A = scaled_dot_product_attention(
- Qn,
- Kn,
- Vn,
- attn_mask = attention_mask,
- scale = self.scaling,
- )
+ pass
+ # else:
+ # Kn, Vn = Kn, Vn
+ # pass
+
+ Qn *= self.scaling
+ A = torch_matmul(Qn, Kn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
+
+ # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
+
+ A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
+ A = torch_matmul(A, Vn, out = Qn)
+ # else:
+ # A = scaled_dot_product_attention(Qn, Kn, Vn, attn_mask = attention_mask, is_causal = False)
+ # pass
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, attention_size)
A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
return A, (Kn, Vn)
+pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
@@ -441,15 +373,14 @@ def GraniteModel_fast_forward_inference(
position_ids,
attention_mask = None,
):
- input_ids = input_ids[:, : self.max_seq_length]
+ input_ids = input_ids[:,:self.max_seq_length]
hidden_states = self.model.embed_tokens(input_ids)
- hidden_states = hidden_states.to(_get_dtype(dtype_from_config(self.config)))
+ hidden_states = hidden_states.to(self.config.torch_dtype)
hidden_states *= self.model.embedding_multiplier
- residual_multiplier = (
- self.residual_multiplier
- if hasattr(self, "residual_multiplier")
- else self.config.residual_multiplier
- )
+ residual_multiplier = \
+ self.residual_multiplier \
+ if hasattr(self, "residual_multiplier") else \
+ self.config.residual_multiplier
bsz, q_len, hd = hidden_states.shape
seq_len = past_key_values[0][0].shape[-2]
@@ -460,27 +391,17 @@ def GraniteModel_fast_forward_inference(
hidden_states,
seq_len,
)
- # Pre-convert to bool once for all layers (avoids per-layer .eq(0))
- if attention_mask is not None and attention_mask.dtype != torch.bool:
- attention_mask = attention_mask.eq(0)
else:
attention_mask = None
+ pass
- position_embeddings = self.model.rotary_emb.get_cached(
- self.max_seq_length, hidden_states.device.index
- )
+ position_embeddings = self.model.rotary_emb(hidden_states, position_ids, self.max_seq_length)
next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):
- device_index = getattr(decoder_layer, "_per_layer_device_index", 0)
- hidden_states, position_ids = move_to_device(
- device_index, hidden_states, position_ids
- )
residual = hidden_states
- hidden_states = fast_rms_layernorm_inference(
- decoder_layer.input_layernorm, hidden_states
- )
+ hidden_states = fast_rms_layernorm_inference(decoder_layer.input_layernorm, hidden_states)
hidden_states, present_key_value = GraniteAttention_fast_forward_inference(
decoder_layer.self_attn,
hidden_states = hidden_states,
@@ -494,13 +415,12 @@ def GraniteModel_fast_forward_inference(
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
residual = hidden_states
- hidden_states = fast_rms_layernorm_inference(
- decoder_layer.post_attention_layernorm, hidden_states
- )
+ hidden_states = fast_rms_layernorm_inference(decoder_layer.post_attention_layernorm, hidden_states)
hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states)
hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
next_decoder_cache.append(present_key_value)
+ pass
hidden_states = fast_rms_layernorm_inference(self.model.norm, hidden_states)
return BaseModelOutputWithPast(
@@ -509,13 +429,12 @@ def GraniteModel_fast_forward_inference(
hidden_states = [],
attentions = [],
)
-
+pass
class GraniteRotaryEmbedding(LlamaRotaryEmbedding):
def __init__(self, config):
super().__init__(config = config)
-
def patched_init(original_init):
def new_init(self, *args, **kwargs):
# we can use self.residual_multiplier arg in GraniteDecoderLayer_fast_forward as mentioned here
@@ -526,70 +445,64 @@ def new_init(self, *args, **kwargs):
if config is not None:
self.config = config
original_init(self, *args, **kwargs)
-
return new_init
-
class FastGraniteModel(FastLlamaModel):
+
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
- model_name = "granite",
- rope_module = GraniteRotaryEmbedding,
+ model_name = "granite",
+ rope_module = GraniteRotaryEmbedding,
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
- attention_module = GraniteAttention,
+ attention_module = GraniteAttention,
)
if init_name is not None:
exec(function, globals())
- GraniteAttention.__init__ = eval(init_name)
- GraniteAttention.forward = GraniteAttention_fast_forward
- GraniteSdpaAttention.forward = GraniteAttention_fast_forward
- GraniteFlashAttention2.forward = GraniteAttention_fast_forward
- GraniteDecoderLayer.forward = GraniteDecoderLayer_fast_forward
- GraniteModel.forward = LlamaModel_fast_forward
- GraniteForCausalLM.forward = CausalLM_fast_forward(
- GraniteModel_fast_forward_inference
- )
- GraniteForCausalLM.__init__ = patched_init(GraniteForCausalLM.__init__)
- PeftModelForCausalLM.forward = PeftModel_fast_forward
+ GraniteAttention.__init__ = eval(init_name)
+ pass
+ GraniteAttention .forward = GraniteAttention_fast_forward
+ GraniteSdpaAttention .forward = GraniteAttention_fast_forward
+ GraniteFlashAttention2.forward = GraniteAttention_fast_forward
+ GraniteDecoderLayer .forward = GraniteDecoderLayer_fast_forward
+ GraniteModel .forward = LlamaModel_fast_forward
+ GraniteForCausalLM .forward = CausalLM_fast_forward(GraniteModel_fast_forward_inference)
+ GraniteForCausalLM .__init__ = patched_init(GraniteForCausalLM.__init__)
+ PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(GraniteForCausalLM)
import transformers.models.granite.modeling_granite
-
- transformers.models.granite.modeling_granite.GraniteRotaryEmbedding = (
- GraniteRotaryEmbedding
- )
+ transformers.models.granite.modeling_granite.GraniteRotaryEmbedding = GraniteRotaryEmbedding
return
+ pass
+
@staticmethod
- def post_patch(model, tokenizer, correct_dtype = None):
+ def post_patch(model, tokenizer):
+
# Torch.compile fails on embedding matrix??
# Workaround randomnly fixes it for torch versions < 2.2
- model.model.embed_tokens = torch.nn.Embedding.from_pretrained(
- model.model.embed_tokens.weight
- )
- model.config.update({"unsloth_version": __version__})
+ model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight)
+ model.config.update({"unsloth_version" : __version__})
# We also do this for the lm_head
lm_head = torch.nn.Linear(1, 1, bias = None)
del lm_head.weight
lm_head.weight = model.lm_head.weight
- lm_head.in_features = lm_head.weight.shape[1]
+ lm_head.in_features = lm_head.weight.shape[1]
lm_head.out_features = lm_head.weight.shape[0]
model.lm_head = lm_head
# Granite has tied weights! This means lm_head == embed_tokens
- if (
- model.model.embed_tokens.weight.data_ptr()
- != model.lm_head.weight.data_ptr()
- ):
+ if model.model.embed_tokens.weight.data_ptr() != model.lm_head.weight.data_ptr():
lm_head = torch.nn.Linear(1, 1, bias = None)
del lm_head.weight
lm_head.weight = model.model.embed_tokens.weight
- lm_head.in_features = lm_head.weight.shape[1]
+ lm_head.in_features = lm_head.weight.shape[1]
lm_head.out_features = lm_head.weight.shape[0]
model.lm_head = lm_head
+ pass
# Also patch all dtypes - BnB seems to not allocate the correct type?
# BnB default dtype seems to be float16!
@@ -602,30 +515,36 @@ def post_patch(model, tokenizer, correct_dtype = None):
if type(quant_state) is list:
# BnB seems to have float16 as default!
- module.weight.quant_state[2] = (
- correct_dtype # Cast to correct dtype
- )
+ module.weight.quant_state[2] = correct_dtype # Cast to correct dtype
else:
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
quant_state.dtype = correct_dtype
+ pass
+ pass
# Downcast RoPE embedding to correct data type
- if name.endswith("rotary_emb") or hasattr(module, "cos_cached"):
- if hasattr(module, "cos_cached") and (
- module.cos_cached.dtype != correct_dtype
- ):
+ if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")):
+
+ if hasattr(module, "cos_cached") and \
+ (module.cos_cached.dtype != correct_dtype):
+
module.cos_cached = module.cos_cached.to(correct_dtype)
module.sin_cached = module.sin_cached.to(correct_dtype)
- elif hasattr(module, "short_cos_cached") and (
- module.short_cos_cached.dtype != correct_dtype
- ):
+ elif hasattr(module, "short_cos_cached") and \
+ (module.short_cos_cached.dtype != correct_dtype):
+
module.short_cos_cached = module.short_cos_cached.to(correct_dtype)
module.short_sin_cached = module.short_sin_cached.to(correct_dtype)
+ pass
+ pass
+ pass
# Clear deleted GPU items
import gc
-
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
return model, tokenizer
+ pass
+pass
+
diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 93d93e26d6..893a09dd14 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -17,55 +17,15 @@
import math
import functools
from typing import Optional, Tuple, List, Union
-
from ._utils import *
-from ._utils import apply_unsloth_gradient_checkpointing
-from ._utils import __version__, importlib_version
-from ._utils import move_to_device
-from ._utils import (
- _get_inference_mode_context_manager,
- _prepare_model_for_qat,
- is_bfloat16_supported,
- get_quant_type,
-)
-from .loader_utils import _get_fp8_mode_and_check_settings
-from ..utils.packing import (
- get_packed_info_from_kwargs,
- mask_packed_sequence_boundaries,
-)
-from ..utils.attention_dispatch import (
- AttentionConfig,
- AttentionContext,
- run_attention,
- SDPA,
- select_attention_backend,
-)
+from ._utils import patch_unsloth_smart_gradient_checkpointing
+from ._utils import __version__
from torch.nn.functional import scaled_dot_product_attention
from transformers import __version__ as transformers_version
from unsloth_zoo.utils import Version, _get_dtype
-from unsloth_zoo.hf_utils import (
- dtype_from_config,
- add_dtype_kwargs,
- fix_lora_auto_mapping,
-)
-from unsloth_zoo.peft_utils import SKIP_QUANTIZATION_MODULES
-from ..device_type import (
- is_hip,
- get_device_type,
- DEVICE_TYPE,
- DEVICE_TYPE_TORCH,
- DEVICE_COUNT,
- ALLOW_PREQUANTIZED_MODELS,
-)
-
transformers_version = Version(transformers_version)
# Transformers moved rotary embeddings out of all attention layers
IS_ATTENTION_REFACTOR = transformers_version > Version("4.47.1")
-try:
- from transformers.modeling_layers import GradientCheckpointingLayer
-except:
- GradientCheckpointingLayer = type(None)
-
from transformers.models.llama.modeling_llama import (
logger,
BaseModelOutputWithPast,
@@ -76,6 +36,8 @@
)
from ..kernels import *
from ..tokenizer_utils import *
+if HAS_FLASH_ATTENTION:
+ from flash_attn import flash_attn_func
from .vision import FastBaseModel
# Final patching code
@@ -93,42 +55,27 @@
LlamaFlashAttention2,
)
except:
- LlamaSdpaAttention = LlamaAttention
+ LlamaSdpaAttention = LlamaAttention
LlamaFlashAttention2 = LlamaAttention
+pass
-from transformers import (
- AutoTokenizer,
- AutoModelForCausalLM,
- AutoModelForSequenceClassification,
- BitsAndBytesConfig,
- AutoConfig,
-)
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING
from transformers import set_seed as transformers_set_seed
from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model
-from peft import PeftModelForCausalLM, PeftModelForSequenceClassification
+from peft import PeftModelForCausalLM
from ..save import patch_saving_functions
import re, os, inspect, math, sys
import types
-
try:
from huggingface_hub.utils import get_token
except:
# Old HF Hub versions <= 0.0.25
from huggingface_hub.utils._token import get_token
+pass
from triton import __version__ as triton_version
-
HAS_XFORMERS = xformers is not None
-BlockDiagonalCausalMask = (
- xformers.attn_bias.BlockDiagonalCausalMask if HAS_XFORMERS else None
-)
-
-if DEVICE_TYPE == "xpu":
- clean_gpu_cache = torch.xpu.empty_cache
- get_current_device = torch.xpu.current_device
-else:
- clean_gpu_cache = torch.cuda.empty_cache
- get_current_device = torch.cuda.current_device
+BlockDiagonalCausalMask = xformers.attn_bias.BlockDiagonalCausalMask if HAS_XFORMERS else None
def original_apply_qkv(self, X):
@@ -136,300 +83,88 @@ def original_apply_qkv(self, X):
K = self.k_proj(X)
V = self.v_proj(X)
return Q, K, V
+pass
def original_apply_o(self, X):
O = self.o_proj(X)
return O
-
+pass
from math import sqrt as math_sqrt
-
-KV_CACHE_INCREMENT = 512 # KV Cache update size
+KV_CACHE_INCREMENT = 512 # KV Cache update size
torch_nn_functional_softmax = torch.nn.functional.softmax
# SDPA has GQA internally
SDPA_HAS_GQA = "enable_gqa" in scaled_dot_product_attention.__doc__
-from peft.utils.other import ModulesToSaveWrapper
-
-
-def _offload_frozen_module_for_training(
- module: ModulesToSaveWrapper,
- device_type: str,
- offload_device: Optional[str] = "cpu",
-) -> None:
- """
- Offload frozen module to CPU and configure trainable copy for mixed precision training.
-
- This function optimizes memory usage by:
- 1. Moving the trainable copy to the target device with appropriate precision
- 2. Optionally offloading the original frozen module to CPU/disk to free VRAM
- 3. Converting float16 to float32 for compatibility with certain GPUs (e.g., Tesla T4)
-
- Args:
- module: The module to configure. Must be a ModulesToSaveWrapper with a
- `modules_to_save` attribute containing trainable and original modules.
- device_type: Target device string for training (e.g., "cuda:0", "xpu:0")
- offload_device: Device to offload frozen parameters (default: "cpu").
- If None, the original frozen module remains on its current device.
- Note: Currently only "cpu" is supported; disk offloading is planned.
-
- Returns:
- None (modifies module in-place)
-
- Note:
- - Float16 weights are automatically promoted to float32 for GPU compatibility
- - When offload_device is specified, frozen parameters are moved to free VRAM
- - Future versions will support disk-based offloading for even larger models
-
- See Also:
- - https://github.com/unslothai/unsloth/pull/1200 (Tesla T4 float32 requirement)
- """
- # Early return with explicit None if module doesn't support mixed precision training
- if not hasattr(module, "modules_to_save"):
- return None
-
- new_dtype = module.modules_to_save.default.weight.dtype
- if new_dtype == torch.float16:
- # See https://github.com/unslothai/unsloth/pull/1200
- # Tesla T4 must use float32 and not float16
- new_dtype = torch.float32
-
- module.modules_to_save.default.to(
- device = device_type, dtype = new_dtype, non_blocking = True
- )
- module.modules_to_save.default.requires_grad_(True)
-
- # [TODO] Move old module to CPU - should be disk!
- if offload_device is not None:
- module.original_module.to(device = offload_device, non_blocking = True)
- module.original_module.requires_grad_(False)
-
-
# Fix new HF's inference code
-def _fast_prepare_inputs_for_generation(
- self,
- input_ids,
- attention_mask = None,
- inputs_embeds = None,
- **kwargs,
-):
- past_key_values = kwargs.get("past_key_values", None)
- original_attention_mask = attention_mask
-
- # Handle inputs_embeds - only use on FIRST generation step (no cache)
- # This fixes GitHub issue #3798: inputs_embeds was ignored
- use_inputs_embeds = inputs_embeds is not None and past_key_values is None
-
- if input_ids is not None and input_ids.numel() > 0:
- bs, seq_length = input_ids.shape
- device = input_ids.device
- elif inputs_embeds is not None:
- bs, seq_length, _ = inputs_embeds.shape
- device = inputs_embeds.device
- else:
- bs, seq_length = 1, 0
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- if past_key_values is not None:
- # Check for uninitialized DynamicCache
- if len(past_key_values) == 0:
- past_key_values = None
- kwargs["past_key_values"] = None
- use_inputs_embeds = inputs_embeds is not None
- # New since 4.56
- elif (
- hasattr(past_key_values, "get_seq_length")
- and past_key_values.get_seq_length() == 0
- ):
- past_key_values = None
- kwargs["past_key_values"] = None
- use_inputs_embeds = inputs_embeds is not None
- else:
- if input_ids is not None and input_ids.numel() > 0:
- bs = input_ids.shape[0]
- input_ids = input_ids[:, [-1]]
- device = input_ids.device
- seq_length = 1
- elif inputs_embeds is not None:
- bs, seq_length, _ = inputs_embeds.shape
- device = inputs_embeds.device
- else:
- bs, seq_length = 1, 0
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- if hasattr(past_key_values, "get_seq_length"):
- past_len = int(past_key_values.get_seq_length())
- else:
- # legacy tuple cache: (layer, (K,V))
- past_len = int(past_key_values[0][0].shape[-2])
-
- max_cache_len = None
- if hasattr(past_key_values, "get_max_cache_shape"):
- m = past_key_values.get_max_cache_shape()
- max_cache_len = int(m) if m is not None and m > 0 else None
- elif hasattr(past_key_values, "get_max_length"):
- m = past_key_values.get_max_length()
- max_cache_len = int(m) if m is not None else None
-
- # ensure cache_position
- cache_position = kwargs.get("cache_position", None)
- if cache_position is None:
- kwargs["cache_position"] = torch.arange(
- past_len,
- past_len + seq_length,
- device = device,
- dtype = torch.long,
- )
- else:
- if (
- hasattr(cache_position, "device")
- and cache_position.device != device
- ):
- kwargs["cache_position"] = cache_position.to(device)
-
- # Get to the base model
- base_model = self
- if hasattr(base_model, "base_model_prefix"):
- base_model = getattr(base_model, base_model.base_model_prefix)
-
- if hasattr(
- base_model, "_prepare_4d_causal_attention_mask_with_cache_position"
- ):
- if not hasattr(base_model, "_unsloth_mask_needs_device"):
-
- def _check_needs_device(fn) -> bool:
- try:
- sig = inspect.signature(inspect.unwrap(fn))
- return "device" in sig.parameters
- except:
- # transformers <= 4.51.3 includes device arg but > 4.51.3 does not
- return transformers_version < Version("4.52.0")
-
- base_model._unsloth_mask_needs_device = _check_needs_device(
- base_model._prepare_4d_causal_attention_mask_with_cache_position
- )
-
- if max_cache_len is not None:
- target_length = max_cache_len
- elif (
- original_attention_mask is not None
- and original_attention_mask.dim() == 2
- ):
- target_length = original_attention_mask.shape[-1]
- else:
- target_length = past_len + seq_length
-
- mask_kwargs = {
- "sequence_length": seq_length,
- "target_length": target_length,
- "dtype": self.dtype,
- "cache_position": kwargs["cache_position"],
- "batch_size": bs,
- "config": self.config,
- "past_key_values": past_key_values,
- }
- if base_model._unsloth_mask_needs_device:
- mask_kwargs["device"] = device
-
- attention_mask = (
- base_model._prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask,
- **mask_kwargs,
- )
- )
- else:
- if transformers_version <= Version("4.52.4"):
- logger.warning_once(
- f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
- "defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
- "writing code, see Llama for an example implementation. If you're a user, please report this "
- "issue on GitHub."
- )
-
- if kwargs.get("position_ids", None) is None:
- if original_attention_mask is not None and original_attention_mask.dim() == 2:
- position_ids = original_attention_mask.long().cumsum(-1) - 1
- position_ids.masked_fill_(original_attention_mask == 0, 1)
- position_ids = position_ids[:, -seq_length:]
- kwargs["position_ids"] = position_ids
- elif kwargs.get("cache_position", None) is not None:
- cp = kwargs["cache_position"]
- if cp.dim() == 1:
- cp = cp.unsqueeze(0).expand(bs, -1)
- kwargs["position_ids"] = cp
-
- result = {
- "attention_mask": attention_mask,
- **kwargs,
- }
- if use_inputs_embeds:
- result["inputs_embeds"] = inputs_embeds
- result["input_ids"] = None
- else:
- result["input_ids"] = input_ids
- return result
+def _fast_prepare_inputs_for_generation(self, input_ids, **kwargs,):
+ if "past_key_values" in kwargs:
+ input_ids = input_ids[:,[-1]]
+ kwargs["attention_mask"] = kwargs["attention_mask"][:,[-1]]
+ if "cache_position" in kwargs:
+ kwargs["position_ids"] = kwargs["cache_position"]
+ return { "input_ids" : input_ids, **kwargs, }
+pass
def fix_prepare_inputs_for_generation(module):
# Fix prepare_inputs_for_generation
if hasattr(module, "prepare_inputs_for_generation"):
module.prepare_inputs_for_generation = _fast_prepare_inputs_for_generation
-
+ pass
+pass
torch_matmul = torch.matmul
-
-
def LlamaAttention_fast_forward_inference(
self,
- hidden_states: torch.Tensor,
+ hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]],
position_ids,
do_prefill = False,
attention_mask = None,
- rotary_seq_len = None,
):
"""
- https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
- Fast inference using KV cache.
- QK^T can be computed in 4 chunks
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
+ Fast inference using KV cache.
+ QK^T can be computed in 4 chunks
- [Q, q] @ [K, k].T where q, k are the new tokens.
- [QK^T, Qk^T]
- [qK^T, qk^T]
+ [Q, q] @ [K, k].T where q, k are the new tokens.
+ [QK^T, Qk^T]
+ [qK^T, qk^T]
- Since the attention mask wipes Qk^T, we just get
- [QK^T, 0]
- [qK^T, qk^T]
+ Since the attention mask wipes Qk^T, we just get
+ [QK^T, 0]
+ [qK^T, qk^T]
- Since softmax is row-wise, we get
- softmax([QK^T, 0])
- softmax([qK^T, qk^T])
+ Since softmax is row-wise, we get
+ softmax([QK^T, 0])
+ softmax([qK^T, qk^T])
- We then multiply by [V]
- [v]
- softmax([QK^T, 0]) [softmax(QK^T)V] *
- softmax([qK^T, qk^T]) [softmax([qK^T, qk^T]) @ [V, v]]
+ We then multiply by [V]
+ [v]
+ softmax([QK^T, 0]) [softmax(QK^T)V] *
+ softmax([qK^T, qk^T]) [softmax([qK^T, qk^T]) @ [V, v]]
- But notice * [softmax(QK^T)V] is just the last attention.
- We just need to compute the last final row.
+ But notice * [softmax(QK^T)V] is just the last attention.
+ We just need to compute the last final row.
- This means we can pass in a row of Q, but we need to
- remember K and V, which are called the KV cache.
+ This means we can pass in a row of Q, but we need to
+ remember K and V, which are called the KV cache.
"""
Xn = hidden_states
bsz, _, hd = hidden_states.size()
K1, V1 = past_key_value
dtype = Xn.dtype
- n_heads = self.config.num_attention_heads
- n_groups = self.num_key_value_groups
+ n_heads = self.config.num_attention_heads
+ n_groups = self.num_key_value_groups
n_kv_heads = self.config.num_key_value_heads
- head_dim = self.head_dim
+ head_dim = self.head_dim
# assert(n_kv_heads * n_groups == n_heads)
hidden_size = self.config.hidden_size
- attention_size = n_heads * head_dim
+ attention_size = n_heads*head_dim
seq_len = K1.shape[-2]
kv_seq_len = seq_len + 1
@@ -437,54 +172,36 @@ def LlamaAttention_fast_forward_inference(
# if not hasattr(self, "paged_attention"):
device = hidden_states.device
if do_prefill:
- self.paged_attention = torch.empty(
- (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
- dtype = dtype,
- device = device,
- )
- self.paged_attention_K = self.paged_attention[:, 0]
- self.paged_attention_V = self.paged_attention[:, 1]
+ self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = device)
+ self.paged_attention_K = self.paged_attention[:,0]
+ self.paged_attention_V = self.paged_attention[:,1]
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
- self.temp_QA = torch.empty(
- (2, bsz, 1, attention_size), dtype = dtype, device = device
- )
- self.temp_KV = torch.empty(
- (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device
- )
+ self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = device)
+ self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = device)
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
-
+
# Mistral Nemo 12b has weird dimensions
if attention_size != hidden_size:
- self.temp_O = torch.empty((bsz, 1, hidden_size), dtype = dtype, device = device)
+ self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
else:
- self.temp_O = self.temp_QA[1][:, :, :hidden_size]
-
- self.attention = torch.empty(
- (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device
- )
+ self.temp_O = self.temp_QA[1][:,:,:hidden_size]
+ pass
+
+ self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device)
self.scalar = 1.0 / math_sqrt(self.head_dim)
self.half_head_dim = head_dim // 2
elif kv_seq_len >= self.paged_attention.shape[0]:
- self.paged_attention.resize_(
- (
- self.paged_attention.shape[0] + KV_CACHE_INCREMENT,
- 2,
- bsz,
- n_kv_heads,
- head_dim,
- )
- )
- self.paged_attention_K = self.paged_attention[:, 0]
- self.paged_attention_V = self.paged_attention[:, 1]
- self.attention.resize_(
- (bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
- )
+ self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
+ self.paged_attention_K = self.paged_attention[:,0]
+ self.paged_attention_V = self.paged_attention[:,1]
+ self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
+ pass
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
- Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
+ Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2)
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
@@ -493,37 +210,26 @@ def LlamaAttention_fast_forward_inference(
# Need to do it prior 2 steps before hitting full on short KV cache
# or else error
- # ensure correct shape
- if position_ids.dim() == 1:
- position_ids = position_ids[:, None]
- position_ids = position_ids.to(Qn.device)
-
- if rotary_seq_len is None:
- rotary_seq_len = max(kv_seq_len, int(position_ids.max().item()) + 1)
- self.rotary_emb.extend_rope_embedding(Vn, rotary_seq_len + 1) # +1 slack
- cos, sin = self.rotary_emb.get_cached(rotary_seq_len, Qn.device.index or 0)
-
- cos = cos[position_ids].unsqueeze(1).to(device = Qn.device, dtype = Qn.dtype)
- sin = sin[position_ids].unsqueeze(1).to(device = Qn.device, dtype = Qn.dtype)
-
+ self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2)
+ cos, sin = self.rotary_emb.get_cached(kv_seq_len)
+ cos = cos[position_ids].unsqueeze(1)
+ sin = sin[position_ids].unsqueeze(1)
h = self.half_head_dim
RH_Q = self.RH_Q
- RH_Q[:, :, :, :h] = Qn[:, :, :, h:]
- RH_Q[:, :, :, h:] = Qn[:, :, :, :h]
- RH_Q[:, :, :, :h].neg_() # torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
+ RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
+ RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
+ RH_Q[:,:,:,:h].neg_() # torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
Qn *= cos
Qn.addcmul_(RH_Q, sin)
- RH_K = RH_Q[
- :, :n_kv_heads, :, :
- ] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
- RH_K[:, :, :, :h] = Kn[:, :, :, h:]
- RH_K[:, :, :, h:] = Kn[:, :, :, :h]
- RH_K[:, :, :, :h].neg_() # torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
+ RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
+ RH_K[:,:,:,:h] = Kn[:,:,:,h:]
+ RH_K[:,:,:,h:] = Kn[:,:,:,:h]
+ RH_K[:,:,:,:h].neg_() #torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
Kn *= cos
Kn.addcmul_(RH_K, sin)
-
+
# New KV cache
# Kn = torch.cat([K1, Kn], dim = 2)
# Vn = torch.cat([V1, Vn], dim = 2)
@@ -536,82 +242,48 @@ def LlamaAttention_fast_forward_inference(
sliding_window = getattr(self.config, "sliding_window", None)
if sliding_window is not None and kv_seq_len > sliding_window:
# From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
- start = kv_seq_len - sliding_window
- Knn = Kn[:, :, start:, :] # .contiguous()
- Vnn = Vn[:, :, start:, :] # .contiguous()
- if attention_mask is not None:
- attention_mask = attention_mask[..., start:]
+ slicing_tokens = 1 - sliding_window
+ Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
+ Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
else:
Knn, Vnn = Kn, Vn
+ pass
# Grouped query attention
_, _, cached_len, _ = Knn.shape
- if bsz == 1 or ((not SDPA_HAS_GQA) and n_groups != 1):
- Knn = Knn[:, :, None, :, :].expand(
- bsz, n_kv_heads, n_groups, cached_len, head_dim
- )
- Vnn = Vnn[:, :, None, :, :].expand(
- bsz, n_kv_heads, n_groups, cached_len, head_dim
- )
+ if bsz == 1 or not SDPA_HAS_GQA and n_groups != 1:
+ Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
+ Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
+ pass
+ # else:
+ # Knn, Vnn = Knn, Vnn
+ # pass
- # when qlen==vlen and attn_mask is None, we should use causal attention
- Q_len = Qn.shape[-2]
- K_len = Knn.shape[-2]
- if attention_mask is None and Q_len == K_len:
- is_causal = True
- else:
- is_causal = False
# Attention
if bsz == 1:
- Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
+ Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
- A = torch_matmul(
- Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len]
- )
- A[:] = torch_nn_functional_softmax(
- A, dim = -1, dtype = torch.float32
- ) # .to(A.dtype)
+ A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
+ # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
+ A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
A = torch_matmul(A, Vnn, out = Qn)
- # --- attention_mask fixup for SDPA if user passes 2D padding mask
else:
- if attention_mask is not None and attention_mask.dim() == 2:
- attention_mask = attention_mask[:, None, None, :].to(torch.bool)
- # is it more appropriate to use _prepare_4d_causal_attention_mask_for_sdpa?
- elif (
- attention_mask is not None
- and attention_mask.dim() == 4
- and attention_mask.dtype != torch.bool
- ):
- # Decode is more stable with boolean keep masks than additive bf16 masks.
- attention_mask = attention_mask.eq(0)
-
if SDPA_HAS_GQA:
- A = scaled_dot_product_attention(
- Qn,
- Knn,
- Vnn,
- attn_mask = attention_mask,
- is_causal = is_causal,
- enable_gqa = True,
- )
+ A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False, enable_gqa = True)
else:
- A = scaled_dot_product_attention(
- Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = is_causal
- )
+ A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
+ pass
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, attention_size)
A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
return A, (Kn, Vn)
+pass
torch_nn_functional_silu = torch.nn.functional.silu
-
-
-def fast_swiglu_inference(
- self, X, temp_gate = None, temp_up = None, gate_multiplier = None, down_multiplier = None
-):
+def fast_swiglu_inference(self, X, temp_gate = None, temp_up = None):
# gate = self.gate_proj(X)
# up = self.up_proj(X)
bsz, _, hd = X.shape
@@ -619,28 +291,17 @@ def fast_swiglu_inference(
# temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0")
gate = fast_linear_forward(self.gate_proj, X, out = temp_gate)
-
- if gate_multiplier is not None:
- gate *= gate_multiplier
-
- up = fast_linear_forward(self.up_proj, X, out = temp_up)
-
+ up = fast_linear_forward(self. up_proj, X, out = temp_up)
gate = torch_nn_functional_silu(gate, inplace = True)
gate *= up
# X = self.down_proj(gate)
- down = fast_linear_forward(self.down_proj, gate, out = up[:, :, :hd])
-
- if down_multiplier is not None:
- down *= down_multiplier
-
+ down = fast_linear_forward(self.down_proj, gate, out = up[:,:,:hd])
return down
-
+pass
torch_square = torch.square
-torch_mean = torch.mean
-
-
+torch_mean = torch.mean
def fast_rms_layernorm_inference(self, X, XX = None, XX2 = None, variance = None):
old_dtype = X.dtype
if XX is None:
@@ -649,16 +310,16 @@ def fast_rms_layernorm_inference(self, X, XX = None, XX2 = None, variance = None
else:
XX.copy_(X)
torch_mean(torch_square(XX, out = XX2), -1, keepdim = True, out = variance)
+ pass
variance += self.variance_epsilon
XX *= variance.rsqrt_()
- if XX is None:
- X = XX.to(old_dtype)
- else:
- X.copy_(XX)
+ if XX is None: X = XX.to(old_dtype)
+ else: X.copy_(XX)
X *= self.weight
return X
+pass
def fast_rms_layernorm_inference_gemma(self, X, out_weight = None):
@@ -672,9 +333,11 @@ def fast_rms_layernorm_inference_gemma(self, X, out_weight = None):
else:
out_weight[:] = self.weight
out_weight += 1.0
+ pass
XX *= out_weight
return XX.to(X.dtype)
+pass
# Normal layernorm with mean removal
@@ -684,29 +347,28 @@ def fast_layernorm_compiled(layernorm, X):
X = X.float()
mean = X.mean(-1, keepdim = True)
Xbar = X - mean
- X = (
- Xbar
- * torch.rsqrt(Xbar.square().mean(-1, keepdim = True) + layernorm.variance_epsilon)
- * layernorm.weight.float()
- )
+ X = Xbar * torch.rsqrt(Xbar.square().mean(-1, keepdim = True) + \
+ layernorm.variance_epsilon) * \
+ layernorm.weight.float()
return X.to(old_dtype)
+pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L320
def LlamaAttention_fast_forward(
self,
- hidden_states: torch.Tensor,
- causal_mask: Optional[BlockDiagonalCausalMask] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- padding_mask: Optional[torch.LongTensor] = None,
+ hidden_states: torch.Tensor,
+ causal_mask: Optional[BlockDiagonalCausalMask] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ padding_mask: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- *args,
- **kwargs,
+ *args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
# Clear inference
if hasattr(self, "paged_attention"):
del self.paged_attention_K
@@ -716,96 +378,124 @@ def LlamaAttention_fast_forward(
del self.temp_KV
del self.RH_Q
del self.attention
+ pass
+
bsz, q_len, _ = hidden_states.size()
- n_heads = self.config.num_attention_heads
- n_groups = self.num_key_value_groups
+ n_heads = self.config.num_attention_heads
+ n_groups = self.num_key_value_groups
n_kv_heads = self.config.num_key_value_heads
- head_dim = self.head_dim
- assert n_kv_heads * n_groups == n_heads
+ head_dim = self.head_dim
+ assert(n_kv_heads * n_groups == n_heads)
Q, K, V = self.apply_qkv(self, hidden_states)
- Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
+ Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
- seq_info = get_packed_info_from_kwargs(kwargs, Q.device)
kv_seq_len = K.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
- if position_embeddings and kv_seq_len <= position_embeddings[0].shape[0]:
+ if position_embeddings:
cos, sin = position_embeddings
else:
+ # Extend RoPE dynamically to fit in VRA
rotary_emb = self.rotary_emb
rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
- cos, sin = rotary_emb.get_cached(kv_seq_len, Q.device.index)
- cos = cos.to(device = Q.device, dtype = Q.dtype)
- sin = sin.to(device = Q.device, dtype = Q.dtype)
- rope_position_ids = position_ids
- if rope_position_ids is None and seq_info is not None:
- rope_position_ids = kwargs.get("position_ids")
+ if position_ids is None:
+ # Useful for LongRoPE
+ cos, sin = rotary_emb.get_cached(kv_seq_len)
+ else:
+ cos, sin = rotary_emb(V, seq_len = kv_seq_len)
# Q, K = (
# fast_rope_embedding(Q, K, cos, sin)
- # if rope_position_ids is None
- # else inplace_rope_embedding(Q, K, cos, sin, rope_position_ids)
+ # if position_ids is None
+ # else inplace_rope_embedding(Q, K, cos, sin, position_ids)
# )
- Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids)
+ Q, K = fast_rope_embedding(Q, K, cos, sin)
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
V = torch.cat([past_key_value[1], V], dim = 2)
+ pass
past_key_value = (K, V) if use_cache else None
# Attention module
- use_varlen = seq_info is not None and past_key_value is None
- backend = (
- SDPA if attention_mask is not None else select_attention_backend(use_varlen)
- )
-
- # should dropout be hardcoded to 0.0?
- config = AttentionConfig(
- backend = backend,
- n_kv_heads = n_kv_heads,
- n_groups = n_groups,
- flash_dense_kwargs = {"causal": True},
- flash_varlen_kwargs = {"dropout_p": 0.0, "causal": True},
- )
- context = AttentionContext(
- bsz = bsz,
- q_len = q_len,
- kv_seq_len = kv_seq_len,
- n_heads = n_heads,
- head_dim = head_dim,
- requires_grad = hidden_states.requires_grad,
- seq_info = seq_info,
- attention_mask = attention_mask,
- causal_mask = causal_mask,
- )
-
- A = run_attention(config = config, context = context, Q = Q, K = K, V = V)
- attn_output = A.reshape(bsz, q_len, n_heads * head_dim)
+ if (not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None):
+ # Xformers memory efficient attention
+ # Also has Flash Attention v2 dispatching
+ Q = Q.transpose(1, 2)
+ K = K.transpose(1, 2)
+ V = V.transpose(1, 2)
+
+ # Group query attention
+ if n_groups != 1:
+ K = K .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
+ V = V .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
+ K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
+ V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
+ if hidden_states.requires_grad:
+ K = K.reshape(bsz, kv_seq_len, n_heads, head_dim)
+ V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
+ else:
+ Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
+ pass
+ A = xformers_attention(Q, K, V, attn_bias = causal_mask)
+ A = A.view(bsz, q_len, n_heads, head_dim)
+
+ elif HAS_FLASH_ATTENTION and attention_mask is None:
+ Q = Q.transpose(1, 2)
+ K = K.transpose(1, 2)
+ V = V.transpose(1, 2)
+ A = flash_attn_func(Q, K, V, causal = True)
+ else:
+ # Grouped query attention
+ if SDPA_HAS_GQA:
+ # Needs (batch_size, n_heads, seq_len, head_dim)
+ # is_casual and attention_mask must not be both set!
+ A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1)
+ # Go back to (batch_size, seq_len, n_heads, head_dim)
+ A = A.transpose(1, 2)#.contiguous()
+ else:
+ if n_groups != 1:
+ K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
+ V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
+ K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
+ V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
+ pass
+ # Must be contiguous or else results are False!
+ # https://github.com/pytorch/pytorch/issues/112577
+ Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
+ # Needs (batch_size, n_heads, seq_len, head_dim)
+ # is_casual and attention_mask must not be both set!
+ A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
+ # Go back to (batch_size, seq_len, n_heads, head_dim)
+ A = A.transpose(1, 2).contiguous()
+ pass
+ pass
+ attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
attn_output = self.apply_o(self, attn_output)
attn_weights = None
return attn_output, attn_weights, past_key_value
+pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
def LlamaDecoderLayer_fast_forward(
self,
- hidden_states: torch.Tensor,
- causal_mask = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- padding_mask: Optional[torch.LongTensor] = None,
+ hidden_states: torch.Tensor,
+ causal_mask = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ padding_mask: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- *args,
- **kwargs,
+ *args, **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@@ -822,44 +512,38 @@ def LlamaDecoderLayer_fast_forward(
"""
if use_cache and hasattr(self, "_flag_for_generation"):
residual = hidden_states
- hidden_states = fast_rms_layernorm_inference(
- self.input_layernorm, hidden_states
- )
+ hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states = hidden_states,
- causal_mask = causal_mask,
- attention_mask = attention_mask,
- position_ids = position_ids,
- past_key_value = past_key_value,
- output_attentions = output_attentions,
- use_cache = use_cache,
- padding_mask = padding_mask,
+ hidden_states = hidden_states,
+ causal_mask = causal_mask,
+ attention_mask = attention_mask,
+ position_ids = position_ids,
+ past_key_value = past_key_value,
+ output_attentions = output_attentions,
+ use_cache = use_cache,
+ padding_mask = padding_mask,
position_embeddings = position_embeddings,
- **kwargs,
)
hidden_states += residual
# Fully Connected
residual = hidden_states
- hidden_states = fast_rms_layernorm_inference(
- self.post_attention_layernorm, hidden_states
- )
+ hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states)
hidden_states = fast_swiglu_inference(self.mlp, hidden_states)
hidden_states += residual
else:
residual = hidden_states
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states = hidden_states,
- causal_mask = causal_mask,
- attention_mask = attention_mask,
- position_ids = position_ids,
- past_key_value = past_key_value,
- output_attentions = output_attentions,
- use_cache = use_cache,
- padding_mask = padding_mask,
+ hidden_states = hidden_states,
+ causal_mask = causal_mask,
+ attention_mask = attention_mask,
+ position_ids = position_ids,
+ past_key_value = past_key_value,
+ output_attentions = output_attentions,
+ use_cache = use_cache,
+ padding_mask = padding_mask,
position_embeddings = position_embeddings,
- **kwargs,
)
hidden_states = residual + hidden_states
@@ -868,13 +552,13 @@ def LlamaDecoderLayer_fast_forward(
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
+ pass
outputs = (hidden_states,)
- if output_attentions:
- outputs += (self_attn_weights,)
- if use_cache:
- outputs += (present_key_value,)
+ if output_attentions: outputs += (self_attn_weights,)
+ if use_cache: outputs += (present_key_value,)
return outputs
+pass
# https://github.com/unslothai/unsloth/issues/404#issuecomment-2323473452
@@ -887,113 +571,100 @@ def LlamaDecoderLayer_fast_forward(
torch.bfloat16: torch.bfloat16,
}
-
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
def LlamaModel_fast_forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
- causal_mask: Optional[BlockDiagonalCausalMask] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
+ input_ids: torch.LongTensor,
+ causal_mask: Optional[BlockDiagonalCausalMask] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- *args,
- **kwargs,
+ return_dict: Optional[bool] = None,
+ *args, **kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
- output_attentions = (
- output_attentions
- if output_attentions is not None
- else self.config.output_attentions
- )
- assert output_attentions is False
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ assert(output_attentions is False)
output_hidden_states = (
- output_hidden_states
- if output_hidden_states is not None
- else self.config.output_hidden_states
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = (
- return_dict if return_dict is not None else self.config.use_return_dict
- )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
- raise ValueError(
- "Unsloth: You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
- )
+ raise ValueError("Unsloth: You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
- raise ValueError(
- "Unsloth: You have to specify either decoder_input_ids or decoder_inputs_embeds"
- )
+ raise ValueError("Unsloth: You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
- # Fix out of bounds tokenization unless we were given packed metadata
- allow_overlength = getattr(self, "_unsloth_allow_packed_overlength", False) or (
- "packed_seq_lengths" in kwargs
- )
- if hasattr(self, "max_seq_length") and not allow_overlength:
+ # Fix out of bounds tokenization
+ if hasattr(self, "max_seq_length"):
if seq_length > self.max_seq_length:
- shape = input_ids.shape if input_ids is not None else inputs_embeds.shape
logger.warning_once(
- f"Unsloth: Input IDs of shape {shape} with length {seq_length} > the model's max sequence length of {self.max_seq_length}.\n"
+ f"Unsloth: Input IDs of length {seq_length} > the model's max sequence length of {self.max_seq_length}.\n"\
"We shall truncate it ourselves. It's imperative if you correct this issue first."
)
if input_ids is not None:
- input_ids = input_ids[:, : self.max_seq_length]
+ input_ids = input_ids[:,:self.max_seq_length]
elif inputs_embeds is not None:
- inputs_embeds = inputs_embeds[:, : self.max_seq_length, :]
- if (
- attention_mask is not None
- and attention_mask.shape[-1] > self.max_seq_length
- ):
- attention_mask = attention_mask[:, : self.max_seq_length]
-
+ inputs_embeds = inputs_embeds[:,:self.max_seq_length,:]
+ pass
+ pass
+
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
+ pass
# We already handle KV cache position_ids ourselves.
- if False: # (past_key_values_length != 0):
+ if False:#(past_key_values_length != 0):
position_ids = torch.arange(
- past_key_values_length,
- seq_length + past_key_values_length,
- dtype = torch.int32,
- device = f"{DEVICE_TYPE_TORCH}:0",
+ past_key_values_length, seq_length + past_key_values_length,
+ dtype = torch.int32,
+ device = "cuda:0",
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
elif position_ids is not None:
- position_ids = position_ids.view(-1, seq_length).to(torch.int32) # .long()
+ position_ids = position_ids.view(-1, seq_length).to(torch.int32)#.long()
else:
position_ids = None
+ pass
if position_ids is not None:
if position_ids.shape[0] != batch_size:
position_ids = position_ids.repeat((batch_size, 1))
+ pass
# Embed positions
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- inputs_embeds = inputs_embeds.to(_get_dtype(dtype_from_config(self.config)))
+ # inputs_embeds = inputs_embeds.to(self.config.torch_dtype)
+ torch_dtype = __DTYPE_MAP.get(self.config.torch_dtype, None)
+ if torch_dtype is not None:
+ inputs_embeds = inputs_embeds.to(torch_dtype)
+ else:
+ raise TypeError("Unsloth: torch_dtype for models is not bfloat16, float16 or float32!")
+ pass
# Normalized from Gemma
- IS_GEMMA = self.config.model_type.startswith("gemma")
- IS_GEMMA2 = self.config.model_type.startswith("gemma2")
- IS_COHERE = self.config.model_type.startswith("cohere")
+ IS_GEMMA = self.config.model_type.startswith("gemma")
+ IS_GEMMA2 = self.config.model_type.startswith("gemma2")
+ IS_COHERE = self.config.model_type.startswith("cohere")
IS_GRANITE = self.config.model_type.startswith("granite")
- IS_FALCON_H1 = self.config.model_type.startswith("falcon_h1")
train_embed_tokens = self.embed_tokens.weight.requires_grad
@@ -1002,9 +673,7 @@ def LlamaModel_fast_forward(
# inputs_embeds *= math_sqrt(self.config.hidden_size)
# Ie 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
# & 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
- normalizer = torch.tensor(
- math_sqrt(self.config.hidden_size), dtype = inputs_embeds.dtype
- )
+ normalizer = torch.tensor(math_sqrt(self.config.hidden_size), dtype = inputs_embeds.dtype)
if train_embed_tokens:
# Careful we must not do an inplace op!
@@ -1016,20 +685,17 @@ def LlamaModel_fast_forward(
inputs_requires_grad = True
elif inputs_requires_grad:
inputs_embeds.requires_grad_(False)
+ pass
inputs_embeds *= normalizer
# inputs_embeds *= math_sqrt(self.config.hidden_size)
- if inputs_requires_grad:
- inputs_embeds.requires_grad_(True)
+ if inputs_requires_grad: inputs_embeds.requires_grad_(True)
+ pass
+ pass
# Fix up attention mask by setting elements to 0
# Specifically for DPO
- if (
- getattr(self, "_has_no_labels", False) is True
- and (attention_mask is not None)
- and (past_key_values is None)
- and (not train_embed_tokens)
- and self.training
- ):
+ if getattr(self, "_has_no_labels", False) is True and (attention_mask is not None) and (past_key_values is None) and \
+ (not train_embed_tokens):
# Careful for inference the attention_mask is size (1, kv_seq_len)
# Whilst the input_embeds is size (1, 1, 4096)
inputs_requires_grad = inputs_embeds.requires_grad
@@ -1038,15 +704,17 @@ def LlamaModel_fast_forward(
inputs_requires_grad = True
elif inputs_requires_grad:
inputs_embeds.requires_grad_(False)
- attention_mask = attention_mask[:, : self.max_seq_length] # Must resize!
+ pass
+ attention_mask = attention_mask[:,:self.max_seq_length] # Must resize!
inputs_embeds *= attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2)
- if inputs_requires_grad:
- inputs_embeds.requires_grad_(True)
+ if inputs_requires_grad: inputs_embeds.requires_grad_(True)
+ pass
# Ignore attention_mask
if attention_mask is None:
padding_mask = None
elif self.training:
+ # elif attention_mask is None:
attention_mask = None
padding_mask = None
else:
@@ -1065,10 +733,11 @@ def LlamaModel_fast_forward(
# Must NOT convert to bool - weirdly this causes stuff to error out!
# if attention_mask is not None:
# attention_mask = attention_mask.to(torch.bool)
+ pass
hidden_states = inputs_embeds
- if IS_GRANITE or IS_FALCON_H1: # granite has embedding multiplier
- hidden_states = self.config.embedding_multiplier * hidden_states
+ if IS_GRANITE: #granite has embedding multiplier
+ hidden_states = self.embedding_multiplier * hidden_states
if past_key_values is None and self.training:
use_cache = False
@@ -1077,6 +746,7 @@ def LlamaModel_fast_forward(
# "Unsloth: `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`"
# )
# use_cache = False
+ pass
# decoder layers
all_hidden_states = () if output_hidden_states else None
@@ -1088,144 +758,103 @@ def LlamaModel_fast_forward(
boundaries = self._gradient_checkpointing_boundaries
else:
boundaries = None
+ pass
# Check checkpointing method
gradient_checkpointing = False
- if self.gradient_checkpointing and self.training and not use_cache:
+ if (self.gradient_checkpointing and self.training and not use_cache):
gradient_checkpointing = True
+ pass
# Gemma2 has alternating SWA and global attn
- use_static_mask = True
+ use_static_mask = True
dynamic_SWA_mask = None
- dynamic_GA_mask = None
+ dynamic_GA_mask = None
if IS_GEMMA2:
if HAS_FLASH_ATTENTION_SOFTCAPPING and attention_mask is None:
self.SWA_mask = True
- self.GA_mask = False
+ self.GA_mask = False
elif attention_mask is not None:
# Fixes https://github.com/unslothai/unsloth/issues/853
# Unsloth needs a 2D mask, not a [2, 1, n, n] mask!
# https://github.com/pytorch/pytorch/issues/103749
# Need to convert to float and not using bool
- # attention_mask = (1.0 - attention_mask.float()) * torch.finfo(inputs_embeds.dtype).min
+ attention_mask = (1.0 - attention_mask.float()) * torch.finfo(inputs_embeds.dtype).min
dynamic_SWA_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window = self.config.sliding_window,
- )
+ )[0][0]
dynamic_GA_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window = None,
- )
+ )[0][0]
use_static_mask = False
elif not hasattr(self, "SWA_mask"):
if HAS_FLEX_ATTENTION:
# Use Flex Attention instead!
- self.SWA_mask = create_flex_attention_sliding_window_mask(
- self.max_seq_length, self.config.sliding_window
- )
- self.GA_mask = create_flex_attention_causal_mask(self.max_seq_length)
+ self.SWA_mask = create_flex_attention_sliding_window_mask(self.max_seq_length, self.config.sliding_window)
+ self.GA_mask = create_flex_attention_causal_mask(self.max_seq_length)
else:
- n = self.max_seq_length # self.config.max_position_embeddings
+ n = self.max_seq_length # self.config.max_position_embeddings
# masked_fill is making stuff slower!
# self. GA_mask = create_boolean_mask(n = n, sliding_window = 0)
# self.SWA_mask = create_boolean_mask(n = n, sliding_window = self.config.sliding_window)
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-
- self.SWA_mask = (
- AttentionMaskConverter(
- is_causal = True,
- sliding_window = self.config.sliding_window,
- )
- .to_causal_4d(
- 1,
- n,
- n,
- dtype = inputs_embeds.dtype,
- device = DEVICE_TYPE_TORCH,
- )
- .squeeze(0)
- .squeeze(0)
- )
-
- self.GA_mask = (
- AttentionMaskConverter(
- is_causal = True,
- )
- .to_causal_4d(
- 1,
- n,
- n,
- dtype = inputs_embeds.dtype,
- device = DEVICE_TYPE_TORCH,
- )
- .squeeze(0)
- .squeeze(0)
- )
+ self.SWA_mask = AttentionMaskConverter(
+ is_causal = True,
+ sliding_window = self.config.sliding_window,
+ )\
+ .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda",)\
+ .squeeze(0).squeeze(0)
+
+ self.GA_mask = AttentionMaskConverter(
+ is_causal = True,
+ )\
+ .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda",)\
+ .squeeze(0).squeeze(0)
pass
+ pass
+ pass
- if (
- IS_ATTENTION_REFACTOR
- and (
- hasattr(self, "rotary_emb")
- or not hasattr(self.layers[0].self_attn, "rotary_emb")
- )
- ) or IS_GRANITE:
+ if (IS_ATTENTION_REFACTOR and (hasattr(self, "rotary_emb") or not hasattr(self.layers[0].self_attn, "rotary_emb"))) or IS_GRANITE:
# Transformers main has made it mandatory to pass position_embeddings
# https://github.com/huggingface/transformers/pull/34858
# Also, transformers 4.45.0 supports granite but with the attention refactor (it always had the refactor)
# unsloth's check for granite too has "version >= 4.45.0 (rightly so)".
# so let granite always use the attention refactor implementation.
-
- self.rotary_emb.extend_rope_embedding(
- hidden_states, self.config.max_position_embeddings
- )
- position_embeddings = self.rotary_emb.get_cached(
- self.config.max_position_embeddings, hidden_states.device.index
- )
+ position_embeddings = self.rotary_emb(hidden_states, position_ids, self.config.max_position_embeddings)
else:
position_embeddings = None
# Go through every layer!
for idx, decoder_layer in enumerate(self.layers):
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
+
+ if output_hidden_states: all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
mask = causal_mask
if IS_GEMMA2:
- use_sliding_window = idx % 2 == 0
- if use_sliding_window:
+ if (idx % 2 == 0):
mask = self.SWA_mask if use_static_mask else dynamic_SWA_mask
else:
- mask = self.GA_mask if use_static_mask else dynamic_GA_mask
- kwargs["use_sliding_window"] = use_sliding_window
-
- if gradient_checkpointing and not isinstance(
- decoder_layer, GradientCheckpointingLayer
- ):
+ mask = self. GA_mask if use_static_mask else dynamic_GA_mask
+ pass
+ if gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
- return module(
- *inputs,
- past_key_value,
- output_attentions,
- padding_mask = padding_mask,
- position_embeddings = position_embeddings,
- **kwargs,
- )
-
+ return module(*inputs, past_key_value, output_attentions, padding_mask = padding_mask, position_embeddings = position_embeddings)
return custom_forward
-
+ pass
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
@@ -1240,187 +869,139 @@ def custom_forward(*inputs):
else:
layer_outputs = decoder_layer(
hidden_states,
- causal_mask = mask,
- attention_mask = attention_mask,
- position_ids = position_ids,
- past_key_value = past_key_value,
- output_attentions = output_attentions,
- use_cache = use_cache,
- padding_mask = padding_mask,
+ causal_mask=mask,
+ attention_mask = attention_mask,
+ position_ids = position_ids,
+ past_key_value = past_key_value,
+ output_attentions = output_attentions,
+ use_cache = use_cache,
+ padding_mask = padding_mask,
position_embeddings = position_embeddings,
- **kwargs,
)
hidden_states = layer_outputs[0]
+ pass
- if use_cache:
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
- if output_attentions:
- all_self_attns += (layer_outputs[1],)
+ if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+ if output_attentions: all_self_attns += (layer_outputs[1],)
+ pass
# Final layernorm
if use_cache:
- if IS_FALCON_H1:
- hidden_states = fast_rms_layernorm_inference(
- self.final_layernorm, hidden_states
- )
- else:
- hidden_states = (
- fast_rms_layernorm_inference_gemma
- if IS_GEMMA
- else fast_rms_layernorm_inference
- )(self.norm, hidden_states)
+ hidden_states = \
+ (fast_rms_layernorm_inference_gemma if IS_GEMMA else fast_rms_layernorm_inference)\
+ (self.norm, hidden_states)
elif IS_COHERE:
hidden_states = self.norm(hidden_states)
- elif IS_FALCON_H1:
- hidden_states = fast_rms_layernorm(
- self.final_layernorm, hidden_states, gemma = IS_GEMMA
- )
else:
hidden_states = fast_rms_layernorm(self.norm, hidden_states, gemma = IS_GEMMA)
+ pass
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
+ if output_hidden_states: all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
- return tuple(
- v
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
- if v is not None
- )
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
- last_hidden_state = hidden_states,
- past_key_values = next_cache,
- hidden_states = all_hidden_states,
- attentions = all_self_attns,
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
)
+pass
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
-def _LlamaModel_fast_forward_inference(
- attention_fast_forward_inference = LlamaAttention_fast_forward_inference,
- mlp_fast_forward_inference = fast_swiglu_inference,
+def LlamaModel_fast_forward_inference(
+ self,
+ input_ids,
+ past_key_values,
+ position_ids,
+ attention_mask = None,
):
- # This makes the attention and MLP customisable.
- # Now for models like qwen3 or cohere which use custom attention operations, we can use this function
- def LlamaModel_fast_forward_inference_custom(
- self,
- input_ids,
- past_key_values,
- position_ids,
- attention_mask = None,
- **kwargs,
- ):
- input_ids = input_ids[:, : self.max_seq_length]
- bsz, q_len = input_ids.shape
- hd = self.config.hidden_size
- mlp_size = self.config.intermediate_size
-
- X = self.model.embed_tokens(input_ids)
- X = X.to(_get_dtype(dtype_from_config(self.config)))
- bsz, q_len, hd = X.shape
- assert q_len == 1
- # Get saved buffers to reduce memory movement
- residual = torch.empty(
- (bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0"
- )
- _XX = torch.empty(
- (2, bsz, q_len, hd), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0"
- )
- XX, XX2 = _XX[0], _XX[1]
- variance = torch.empty(
- (bsz, q_len, 1), dtype = torch.float32, device = f"{DEVICE_TYPE_TORCH}:0"
- )
- temp_mlp = torch.empty(
- (2, bsz, 1, mlp_size), dtype = X.dtype, device = f"{DEVICE_TYPE_TORCH}:0"
- )
- temp_gates, temp_ups = (
- tuple(temp_mlp[0].to(torch.device(x)) for x in range(DEVICE_COUNT)),
- tuple(temp_mlp[1].to(torch.device(x)) for x in range(DEVICE_COUNT)),
+ input_ids = input_ids[:,:self.max_seq_length]
+ bsz, q_len = input_ids.shape
+ hd = self.config.hidden_size
+ mlp_size = self.config.intermediate_size
+
+ X = self.model.embed_tokens(input_ids)
+ X = X.to(self.config.torch_dtype)
+ bsz, q_len, hd = X.shape
+ assert(q_len == 1)
+ # Get saved buffers to reduce memory movement
+ residual = torch.empty((bsz, q_len, hd), dtype = torch.float32, device = "cuda:0")
+ _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = "cuda:0")
+ XX, XX2 = _XX[0], _XX[1]
+ variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = "cuda:0")
+ temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0")
+ temp_gate, temp_up = temp_mlp[0], temp_mlp[1]
+
+ seq_len = past_key_values[0][0].shape[-2]
+ if bsz != 1:
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ (bsz, q_len),
+ X,
+ seq_len,
+ sliding_window = getattr(self.config, "sliding_window", None),
)
+ else:
+ attention_mask = None
+ pass
- seq_len = past_key_values[0][0].shape[-2]
- kv_seq_len = seq_len + 1
- if attention_mask is not None:
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
- attention_mask,
- (bsz, q_len),
- X,
- seq_len,
- sliding_window = getattr(self.config, "sliding_window", None),
- )
- # Pre-convert to bool once for all layers (avoids per-layer .eq(0))
- if attention_mask is not None and attention_mask.dtype != torch.bool:
- attention_mask = attention_mask.eq(0)
- else:
- attention_mask = None
-
- # Compute rotary_seq_len once to avoid per-layer GPU-CPU sync from .item()
- rotary_seq_len = max(kv_seq_len, int(position_ids.max().item()) + 1)
-
- next_decoder_cache = []
-
- for idx, decoder_layer in enumerate(self.model.layers):
- device_index = getattr(decoder_layer, "_per_layer_device_index", 0)
- X, residual, position_ids = move_to_device(
- device_index, X, residual, position_ids
- )
- residual.copy_(X) # residual = X
- X = fast_rms_layernorm_inference(
- decoder_layer.input_layernorm,
- X,
- XX = XX,
- XX2 = XX2,
- variance = variance,
- )
- X, present_key_value = attention_fast_forward_inference(
- decoder_layer.self_attn,
- hidden_states = X,
- past_key_value = past_key_values[idx],
- position_ids = position_ids,
- attention_mask = attention_mask,
- do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
- rotary_seq_len = rotary_seq_len,
- )
- X += residual
-
- residual.copy_(X) # residual = X
- X = fast_rms_layernorm_inference(
- decoder_layer.post_attention_layernorm,
- X,
- XX = XX,
- XX2 = XX2,
- variance = variance,
- )
- X = mlp_fast_forward_inference(
- decoder_layer.mlp,
- X,
- temp_gate = temp_gates[device_index],
- temp_up = temp_ups[device_index],
- )
- X += residual
+ next_decoder_cache = []
- next_decoder_cache.append(present_key_value)
+ for idx, decoder_layer in enumerate(self.model.layers):
+ residual.copy_(X) # residual = X
X = fast_rms_layernorm_inference(
- self.model.norm,
+ decoder_layer.input_layernorm,
X,
XX = XX,
XX2 = XX2,
variance = variance,
)
-
- return BaseModelOutputWithPast(
- last_hidden_state = X,
- past_key_values = next_decoder_cache,
- hidden_states = [],
- attentions = [],
+ X, present_key_value = LlamaAttention_fast_forward_inference(
+ decoder_layer.self_attn,
+ hidden_states = X,
+ past_key_value = past_key_values[idx],
+ position_ids = position_ids,
+ attention_mask = attention_mask,
+ do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
)
+ X += residual
- return LlamaModel_fast_forward_inference_custom
-
+ residual.copy_(X) # residual = X
+ X = fast_rms_layernorm_inference(
+ decoder_layer.post_attention_layernorm,
+ X,
+ XX = XX,
+ XX2 = XX2,
+ variance = variance,
+ )
+ X = fast_swiglu_inference(
+ decoder_layer.mlp,
+ X,
+ temp_gate = temp_gate,
+ temp_up = temp_up,
+ )
+ X += residual
+
+ next_decoder_cache.append(present_key_value)
+ pass
+ X = fast_rms_layernorm_inference(
+ self.model.norm,
+ X,
+ XX = XX,
+ XX2 = XX2,
+ variance = variance,
+ )
-# For ensuring backwards compatibility, we create LlamaModel_fast_forward_inference that is consumed by other models
-LlamaModel_fast_forward_inference = _LlamaModel_fast_forward_inference()
+ return BaseModelOutputWithPast(
+ last_hidden_state = X,
+ past_key_values = next_decoder_cache,
+ hidden_states = [],
+ attentions = [],
+ )
+pass
def CausalLM_fast_forward(fast_forward_inference):
@@ -1439,8 +1020,7 @@ def _CausalLM_fast_forward(
return_dict: Optional[bool] = None,
num_logits_to_keep: Optional[int] = 0,
logits_to_keep: Optional[int] = 0,
- *args,
- **kwargs,
+ *args, **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
if past_key_values is not None:
outputs = fast_forward_inference(
@@ -1449,26 +1029,16 @@ def _CausalLM_fast_forward(
past_key_values,
position_ids = position_ids,
attention_mask = attention_mask,
- **kwargs,
)
else:
- causal_mask = (
- xformers.attn_bias.LowerTriangularMask() if HAS_XFORMERS else None
- )
+ causal_mask = xformers.attn_bias.LowerTriangularMask() if HAS_XFORMERS else None
- output_attentions = (
- output_attentions
- if output_attentions is not None
- else self.config.output_attentions
- )
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
- output_hidden_states
- if output_hidden_states is not None
- else self.config.output_hidden_states
- )
- return_dict = (
- return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
self.model._has_no_labels = labels is None
outputs = self.model(
@@ -1482,8 +1052,8 @@ def _CausalLM_fast_forward(
output_attentions = output_attentions,
output_hidden_states = output_hidden_states,
return_dict = return_dict,
- **kwargs,
)
+ pass
hidden_states = outputs[0]
bsz, q_len, hd = hidden_states.shape
@@ -1491,17 +1061,16 @@ def _CausalLM_fast_forward(
lm_head_device = lm_head.device
logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
- logit_scaling = getattr(self.config, "logit_scale", 0)
+ logit_scaling = getattr(self.config, "logit_scale", 0)
dtype = lm_head.dtype
num_logits_to_keep = max(num_logits_to_keep, logits_to_keep)
-
+
# Move items to same device as lm_head
hidden_states = hidden_states.to(lm_head_device)
- if labels is not None:
- labels = labels.to(lm_head_device)
+ if labels is not None: labels = labels.to(lm_head_device)
# Output last hidden states without logits if asked
- if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1":
+ if self.training and os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1":
if num_logits_to_keep != 0:
hidden_states = hidden_states[:, -num_logits_to_keep:, :]
return CausalLMOutputWithPast(
@@ -1509,8 +1078,9 @@ def _CausalLM_fast_forward(
logits = hidden_states,
past_key_values = outputs.past_key_values,
hidden_states = outputs.hidden_states,
- attentions = outputs.attentions,
+ attentions= outputs.attentions,
)
+ pass
if bsz == 1 and q_len == 1:
logits = torch.mv(lm_head, hidden_states.ravel().to(dtype))
@@ -1520,67 +1090,50 @@ def _CausalLM_fast_forward(
else:
RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1"
# < 1024 Normal Unsloth uses less VRAM!
- if bsz * q_len <= 1024 and not RETURN_LOGITS:
- # Use unsloth_fused_ce_loss which actually calculates the best chunk size to reduce VRAM usage
- RETURN_LOGITS = False
-
- if not RETURN_LOGITS and labels is not None:
- n_items = kwargs.get("num_items_in_batch", None)
- if n_items is None:
- n_items = kwargs.get("n_items", None)
-
- if self.config.model_type == "falcon_h1":
- hidden_states = hidden_states * self.config.lm_head_multiplier
-
- ### DISABLED since T4 breaks
- # OutOfResources: out of resource: shared memory, Required: 98304, Hardware limit: 65536. Reducing block sizes or `num_stages` may help.
- # loss = fused_linear_cross_entropy(
- # hidden_states = hidden_states,
- # lm_weight = lm_head,
- # labels = labels,
- # num_items_in_batch = n_items,
- # logit_softcapping = logit_softcapping,
- # )
- loss = unsloth_fused_ce_loss(
- trainer = None,
- hidden_states = hidden_states,
- lm_head_weight = lm_head,
- lm_head_bias = None,
- labels = labels,
- mask = None,
- n_items = n_items,
- scaling = getattr(self, "accelerator_scaler", None),
- target_gb = None,
- torch_compile = True,
- logit_softcapping = logit_softcapping,
+ if bsz*q_len <= 1024: RETURN_LOGITS = True
+
+ if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None:
+
+ n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None)
+ loss = fused_linear_cross_entropy(
+ hidden_states = hidden_states,
+ lm_weight = lm_head,
+ labels = labels,
+ num_items_in_batch = n_items,
+ logit_softcapping = logit_softcapping,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
output = CausalLMOutputWithPast(
- loss = loss,
- logits = EMPTY_LOGITS,
- past_key_values = outputs.past_key_values,
- hidden_states = outputs.hidden_states,
- attentions = outputs.attentions,
+ loss=loss,
+ logits=EMPTY_LOGITS,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
)
return output
pass
logits = self.lm_head(hidden_states.to(dtype))
+ pass
+
+ torch_dtype = __DTYPE_MAP.get(self.config.torch_dtype, None)
+ if torch_dtype is not None:
+ logits = logits.to(torch_dtype)
+ else:
+ raise TypeError("Unsloth: torch_dtype for models is not bfloat16, float16 or float32!")
+ pass
- logits = logits.to(_get_dtype(dtype_from_config(self.config)))
loss = None
logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
- logit_scaling = getattr(self.config, "logit_scale", 0)
+ logit_scaling = getattr(self.config, "logit_scale", 0)
if self.config.model_type == "granite":
# granite uses logit_scaling as key and they divide by the scale unlike cohere
# notice that for granite, logits_scale is 16 and for cohere it is 0.125 (aka 1/8) in their respective configs
# granite: https://github.com/huggingface/transformers/blob/4d1d0f29a493098e6bc6b904b82e29cb331827f5/src/transformers/models/granite/modeling_granite.py#L1103
# cohere: https://github.com/huggingface/transformers/blob/4d1d0f29a493098e6bc6b904b82e29cb331827f5/src/transformers/models/cohere/modeling_cohere.py#L1176
logit_scaling = 1 / getattr(self.config, "logits_scaling", 1)
- elif self.config.model_type == "falcon_h1":
- logit_scaling = self.config.lm_head_multiplier
if labels is not None:
shift_logits = logits
@@ -1591,20 +1144,13 @@ def _CausalLM_fast_forward(
shift_labels = torch.empty_like(labels)
shift_labels[..., :-1] = labels[..., 1:]
shift_labels[..., -1] = -100
- mask_packed_sequence_boundaries(
- shift_labels,
- kwargs.get("packed_seq_lengths"),
- )
# shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
- n_items = kwargs.get("num_items_in_batch", None)
- if n_items is None:
- n_items = kwargs.get("n_items", None)
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
logit_softcapping = logit_softcapping,
- logit_scaling = logit_scaling,
- n_items = n_items,
+ logit_scaling = logit_scaling,
+ n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None),
)
else:
if logit_scaling != 0:
@@ -1612,32 +1158,39 @@ def _CausalLM_fast_forward(
logits = logit_scaling * logits
else:
logits *= logit_scaling
+ pass
+ pass
if logit_softcapping != 0:
if logits.requires_grad:
logits = (1.0 / logit_softcapping) * logits
logits = torch.tanh(logits)
logits = logit_softcapping * logits
else:
- logits *= 1.0 / logit_softcapping
- logits.tanh_()
+ logits *= (1.0 / logit_softcapping)
+ torch.tanh(logits, out = logits)
logits *= logit_softcapping
+ pass
+ pass
+ pass
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
+
return CausalLMOutputWithPast(
loss = loss,
logits = logits,
past_key_values = outputs.past_key_values,
hidden_states = outputs.hidden_states,
- attentions = outputs.attentions,
+ attentions= outputs.attentions,
)
-
+ pass
return _CausalLM_fast_forward
+pass
@torch._disable_dynamo
-def PeftModel_fast_forward(
+def PeftModelForCausalLM_fast_forward(
self,
input_ids = None,
causal_mask = None,
@@ -1652,218 +1205,173 @@ def PeftModel_fast_forward(
logits_to_keep = 0,
**kwargs,
):
- is_classification = "Classification" in str(type(self.base_model.model))
- if is_classification:
- return self.base_model(
- input_ids = input_ids,
- attention_mask = attention_mask,
- inputs_embeds = inputs_embeds,
- labels = labels,
- output_attentions = output_attentions,
- output_hidden_states = output_hidden_states,
- return_dict = return_dict,
- **kwargs,
- )
- else:
- return self.base_model(
- input_ids = input_ids,
- causal_mask = causal_mask,
- attention_mask = attention_mask,
- inputs_embeds = inputs_embeds,
- labels = labels,
- output_attentions = output_attentions,
- output_hidden_states = output_hidden_states,
- return_dict = return_dict,
- num_logits_to_keep = num_logits_to_keep,
- logits_to_keep = logits_to_keep,
- **kwargs,
- )
-
-
-def _get_rope_theta(config, default = 10000.0):
- """Get rope_theta from config, handling both transformers 4.x and 5.x."""
- try:
- return config.rope_theta
- except (AttributeError, KeyError):
- pass
- rp = getattr(config, "rope_parameters", None)
- if isinstance(rp, dict):
- return rp.get("rope_theta", default)
- return default
+ return self.base_model(
+ input_ids = input_ids,
+ causal_mask = causal_mask,
+ attention_mask = attention_mask,
+ inputs_embeds = inputs_embeds,
+ labels = labels,
+ output_attentions = output_attentions,
+ output_hidden_states = output_hidden_states,
+ return_dict = return_dict,
+ num_logits_to_keep = num_logits_to_keep,
+ logits_to_keep = logits_to_keep,
+ **kwargs,
+ )
+pass
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
-# Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
+# Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
class LlamaRotaryEmbedding(torch.nn.Module):
# Fixes https://github.com/huggingface/transformers/pull/28837
# https://github.com/microsoft/DeepSpeed/issues/4932
# The precision of RoPE buffers is not correct, so we cast to int64.
- def __init__(
- self,
- dim = None,
- max_position_embeddings = 2048,
- base = 10000,
- device = None,
- config = None, # [TODO] Hack to pass in config - need to remove later
+ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None,
+ config = None, # [TODO] Hack to pass in config - need to remove later
):
super().__init__()
if config is not None:
# [TODO] Hack to pass in config - need to remove later
- base = _get_rope_theta(config, default = base)
- partial_rotary_factor = (
- config.partial_rotary_factor
- if hasattr(config, "partial_rotary_factor")
- else 1.0
- )
+ base = config.rope_theta
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = getattr(config, "head_dim", None)
- if dim is None:
- dim = int((config.hidden_size // config.num_attention_heads))
- device = DEVICE_TYPE_TORCH
+ if dim is None: dim = int((config.hidden_size // config.num_attention_heads))
+ device = "cuda"
max_position_embeddings = config.max_position_embeddings
+ pass
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
- self.multi_gpu_cos_cached = [None] * DEVICE_COUNT
- self.multi_gpu_sin_cached = [None] * DEVICE_COUNT
-
- # Normal Llama-3 RoPE
- inv_freq = 1.0 / (
- self.base
- ** (
- torch.arange(0, self.dim, 2, dtype = torch.int64, device = "cpu").float()
- / self.dim
- )
- )
- inv_freq = self._apply_inv_freq_scaling(inv_freq)
- self.register_buffer("inv_freq", inv_freq, persistent = False)
# Build here to make `torch.jit.trace` work.
- for device_idx in range(DEVICE_COUNT):
- self._set_cos_sin_cache(
- seq_len = self.current_rope_size,
- device = torch.device(device_idx),
- dtype = torch.get_default_dtype(),
- )
-
- # dummy so that patch_utils doesn't fail for now
- self.cos_cached = torch.empty(
- 1, device = get_current_device(), dtype = torch.get_default_dtype()
- )
- self.sin_cached = torch.empty(
- 1, device = get_current_device(), dtype = torch.get_default_dtype()
- )
-
- def _apply_inv_freq_scaling(self, inv_freq):
- """Override to apply custom inv_freq scaling (e.g., extended RoPE)."""
- return inv_freq
-
- def _apply_time_scaling(self, t):
- """Override to apply custom time scaling (e.g., linear scaling)."""
- return t
+ self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
+ pass
def _set_cos_sin_cache(self, seq_len, device, dtype):
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# in FP32. They are applied (multiplied) in FP32 as well.
self.current_rope_size = seq_len
- t = torch.arange(
- self.current_rope_size, device = self.inv_freq.device, dtype = torch.int64
- ).float()
- t = self._apply_time_scaling(t)
+ inv_freq = 1.0 / (
+ self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
+ )
+ t = torch.arange(self.current_rope_size, device="cpu", dtype=torch.int64).float()
- freqs = torch.outer(t, self.inv_freq)
+ freqs = torch.outer(t, inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
- emb = torch.cat((freqs, freqs), dim = -1)
- cos = emb.cos().to(dtype = dtype, device = device, non_blocking = True)
- sin = emb.sin().to(dtype = dtype, device = device, non_blocking = True)
- self.multi_gpu_cos_cached[device.index] = cos
- self.multi_gpu_sin_cached[device.index] = sin
- return cos, sin
-
- def forward(self, x, position_ids = None, seq_len = None):
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
+ self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
+ pass
+
+ def forward(self, x, position_ids=None, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
- if seq_len is not None and seq_len > self.current_rope_size:
- self._set_cos_sin_cache(seq_len = seq_len, device = x.device, dtype = x.dtype)
+ if seq_len > self.current_rope_size:
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
- device_index = x.device.index
return (
- self.multi_gpu_cos_cached[device_index][:seq_len],
- self.multi_gpu_sin_cached[device_index][:seq_len],
+ self.cos_cached[:seq_len].to(dtype = x.dtype),
+ self.sin_cached[:seq_len].to(dtype = x.dtype),
)
+ pass
- def get_cached(self, seq_len = None, device_index = None):
- if device_index is None:
- device_index = get_current_device()
- return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[
- device_index
- ]
+ def get_cached(self, seq_len = None):
+ return self.cos_cached, self.sin_cached
+ pass
def extend_rope_embedding(self, x, seq_len):
- if seq_len <= self.current_rope_size:
- return
+ if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
- for device_idx in range(DEVICE_COUNT):
- self._set_cos_sin_cache(
- self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype
- )
+ self._set_cos_sin_cache(self.current_rope_size, device = "cuda", dtype = x.dtype)
+ pass
+pass
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
-
# Fixes https://github.com/huggingface/transformers/pull/28837
# https://github.com/microsoft/DeepSpeed/issues/4932
# The precision of RoPE buffers is not correct, so we cast to int64.
- def __init__(
- self,
- dim = None,
- max_position_embeddings = 2048,
- base = 10000,
- device = None,
- scaling_factor = 1.0,
- config = None, # [TODO] Hack to pass in config - need to remove later
+ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0,
+ config = None, # [TODO] Hack to pass in config - need to remove later
):
self.scaling_factor = scaling_factor
- super().__init__(
- dim = dim,
- max_position_embeddings = max_position_embeddings,
- base = base,
- device = device,
- config = config,
+ super().__init__(dim = dim, max_position_embeddings = max_position_embeddings, base = base, device = device, config = config)
+ pass
+
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
+ self.current_rope_size = seq_len
+ inv_freq = 1.0 / (
+ self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
)
+ t = torch.arange(self.current_rope_size, device="cpu", dtype=torch.int64).float()
+ t = t / self.scaling_factor
- def _apply_time_scaling(self, t):
- """Apply linear scaling to time indices."""
- return t / self.scaling_factor
+ freqs = torch.outer(t, inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
+ self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
+ pass
+pass
# See https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py#L736
# For Llama 3.1
-class LlamaExtendedRotaryEmbedding(LlamaRotaryEmbedding):
- def __init__(
- self,
- dim = None,
- max_position_embeddings = 2048,
- base = 10000,
- device = None,
- config = None, # [TODO] Hack to pass in config - need to remove later
+class LlamaExtendedRotaryEmbedding(torch.nn.Module):
+ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None,
+ config = None, # [TODO] Hack to pass in config - need to remove later
):
- super().__init__(
- dim = dim,
- max_position_embeddings = max_position_embeddings,
- base = base,
- device = device,
- config = config,
+ super().__init__()
+ if config is not None:
+ # [TODO] Hack to pass in config - need to remove later
+ base = config.rope_theta
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+ dim = int((config.hidden_size // config.num_attention_heads))
+ device = "cuda"
+ max_position_embeddings = config.max_position_embeddings
+ pass
+
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
+ self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
+
+ # Normal Llama-3 RoPE
+ inv_freq = 1.0 / (
+ self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
)
+ inv_freq = self.apply_scaling(inv_freq)
+ self.register_buffer("inv_freq", inv_freq, persistent = False)
+
+ # Build here to make `torch.jit.trace` work.
+ self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
+ pass
+
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
+ # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
+ # in FP32. They are applied (multiplied) in FP32 as well.
+ self.current_rope_size = seq_len
+
+ t = torch.arange(self.current_rope_size, device=self.inv_freq.device, dtype=torch.int64).float()
+
+ freqs = torch.outer(t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
+ self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
+ pass
# From https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py#L41
- def _apply_inv_freq_scaling(self, freqs: torch.Tensor):
+ def apply_scaling(self, freqs: torch.Tensor):
# Values obtained from grid search
scale_factor = 8
low_freq_factor = 1
@@ -1885,172 +1393,146 @@ def _apply_inv_freq_scaling(self, freqs: torch.Tensor):
high_freq_factor - low_freq_factor
)
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
- return torch.tensor(new_freqs, dtype = freqs.dtype, device = freqs.device)
+ return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+ pass
+
+ def forward(self, x, position_ids=None, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ if seq_len > self.current_rope_size:
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+ return (
+ self.cos_cached[:seq_len].to(dtype = x.dtype),
+ self.sin_cached[:seq_len].to(dtype = x.dtype),
+ )
+ pass
+
+ def get_cached(self, seq_len = None):
+ return self.cos_cached, self.sin_cached
+ pass
+
+ def extend_rope_embedding(self, x, seq_len):
+ if seq_len <= self.current_rope_size: return
+ # Iteratively grow by increments of 8192
+ self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
+ self._set_cos_sin_cache(self.current_rope_size, device = "cuda", dtype = x.dtype)
+ pass
+pass
class LongRopeRotaryEmbedding(torch.nn.Module):
# For Phi 3.5 128K https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/modeling_phi3.py
- def __init__(
- self,
+ def __init__(self,
dim = None,
max_position_embeddings = 131072,
original_max_position_embeddings = 4096,
base = 10000,
short_factor = None,
- long_factor = None,
+ long_factor = None,
device = None,
- config = None, # [TODO] Hack to pass in config - need to remove later
+ config = None, # [TODO] Hack to pass in config - need to remove later
):
super().__init__()
- assert short_factor is not None
- assert long_factor is not None
- assert type(original_max_position_embeddings) is int
+ assert(short_factor is not None)
+ assert(long_factor is not None)
+ assert(type(original_max_position_embeddings) is int)
if config is not None:
# [TODO] Hack to pass in config - need to remove later
- base = _get_rope_theta(config, default = base)
- partial_rotary_factor = (
- config.partial_rotary_factor
- if hasattr(config, "partial_rotary_factor")
- else 1.0
- )
+ base = config.rope_theta
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads))
- device = DEVICE_TYPE_TORCH
+ device = "cuda"
max_position_embeddings = config.max_position_embeddings
+ pass
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.original_max_position_embeddings = original_max_position_embeddings
self.base = base
# Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
- self.current_rope_size = min(
- original_max_position_embeddings, self.max_position_embeddings
- )
- self.multi_gpu_short_cos_cached = [None] * DEVICE_COUNT
- self.multi_gpu_short_sin_cached = [None] * DEVICE_COUNT
- self.multi_gpu_long_cos_cached = [None] * DEVICE_COUNT
- self.multi_gpu_long_sin_cached = [None] * DEVICE_COUNT
+ self.current_rope_size = min(original_max_position_embeddings, self.max_position_embeddings)
# Long RoPE similar to RoPE except short sequences have 1 cos / sin
# and long sequences have another cos / sin
- inv_freq_shape = (
- torch.arange(0, self.dim, 2, dtype = torch.int64, device = "cpu").float()
- / self.dim
- )
+ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim
short_factor = torch.tensor(short_factor, device = "cpu", dtype = torch.float32)
- long_factor = torch.tensor(long_factor, device = "cpu", dtype = torch.float32)
+ long_factor = torch.tensor(long_factor, device = "cpu", dtype = torch.float32)
short_inv_freq = 1.0 / (short_factor * self.base**inv_freq_shape)
- long_inv_freq = 1.0 / (long_factor * self.base**inv_freq_shape)
+ long_inv_freq = 1.0 / (long_factor * self.base**inv_freq_shape)
# Phi-3 Scale factor
scale = self.max_position_embeddings / self.original_max_position_embeddings
if scale <= 1.0:
scaling_factor = 1.0
else:
- scaling_factor = math.sqrt(
- 1 + math.log(scale) / math.log(self.original_max_position_embeddings)
- )
+ scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
+ pass
self.scaling_factor = scaling_factor
# Short and long inv_freq
- self.register_buffer("short_inv_freq", short_inv_freq, persistent = False)
- self.register_buffer("long_inv_freq", long_inv_freq, persistent = False)
-
- # Build here to make `torch.jit.trace` work.
- # Initialize short sequences cache for all devices
- dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16
- t = torch.arange(
- original_max_position_embeddings,
- device = self.short_inv_freq.device,
- dtype = torch.int64,
- ).float()
- freqs = torch.outer(t, self.short_inv_freq)
- emb = torch.cat((freqs, freqs), dim = -1)
-
- for device_idx in range(DEVICE_COUNT):
- device_obj = torch.device(device_idx)
- cos_cached = (emb.cos() * self.scaling_factor).to(
- dtype = dtype, device = device_obj, non_blocking = True
- )
- sin_cached = (emb.sin() * self.scaling_factor).to(
- dtype = dtype, device = device_obj, non_blocking = True
- )
- self.multi_gpu_short_cos_cached[device_idx] = cos_cached
- self.multi_gpu_short_sin_cached[device_idx] = sin_cached
-
- # dummy so that patch_utils doesn't fail for now
- self.short_cos_cached = torch.empty(
- 1, device = get_current_device(), dtype = torch.get_default_dtype()
- )
- self.short_sin_cached = torch.empty(
- 1, device = get_current_device(), dtype = torch.get_default_dtype()
- )
- self.long_cos_cached = torch.empty(
- 1, device = get_current_device(), dtype = torch.get_default_dtype()
- )
- self.long_sin_cached = torch.empty(
- 1, device = get_current_device(), dtype = torch.get_default_dtype()
- )
+ self.register_buffer("short_inv_freq", short_inv_freq, persistent = False)
+ self.register_buffer("long_inv_freq", long_inv_freq, persistent = False)
+ # Build here to make `torch.jit.trace` work.
+ # self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
+
+ # Short sequences
+ dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16
+ t = torch.arange(original_max_position_embeddings, device=self.short_inv_freq.device, dtype=torch.int64).float()
+ freqs = torch.outer(t, self.short_inv_freq)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True)
+ sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True)
+ self.register_buffer("short_cos_cached", cos_cached, persistent=False)
+ self.register_buffer("short_sin_cached", sin_cached, persistent=False)
+ pass
def _set_cos_sin_cache(self, seq_len, device, dtype):
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# in FP32. They are applied (multiplied) in FP32 as well.
self.current_rope_size = seq_len
-
- t = torch.arange(
- self.current_rope_size, device = self.long_inv_freq.device, dtype = torch.int64
- ).float()
+
+ t = torch.arange(self.current_rope_size, device=self.long_inv_freq.device, dtype=torch.int64).float()
# Long sequences
freqs = torch.outer(t, self.long_inv_freq)
- emb = torch.cat((freqs, freqs), dim = -1)
- cos_cached = (emb.cos() * self.scaling_factor).to(
- dtype = dtype, device = device, non_blocking = True
- )
- sin_cached = (emb.sin() * self.scaling_factor).to(
- dtype = dtype, device = device, non_blocking = True
- )
- self.multi_gpu_long_cos_cached[device.index] = cos_cached
- self.multi_gpu_long_sin_cached[device.index] = sin_cached
- return cos_cached, sin_cached
-
- def forward(self, x, position_ids = None, seq_len = None):
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True)
+ sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True)
+ self.register_buffer("long_cos_cached", cos_cached, persistent=False)
+ self.register_buffer("long_sin_cached", sin_cached, persistent=False)
+ pass
+
+ def forward(self, x, position_ids=None, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
- if seq_len is not None and seq_len > self.current_rope_size:
- self._set_cos_sin_cache(seq_len = seq_len, device = x.device, dtype = x.dtype)
-
- device_index = x.device.index
+ if seq_len > self.current_rope_size:
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
- if seq_len is not None and seq_len < self.original_max_position_embeddings:
+ if seq_len < self.original_max_position_embeddings:
return (
- self.multi_gpu_short_cos_cached[device_index][:seq_len],
- self.multi_gpu_short_sin_cached[device_index][:seq_len],
+ self.short_cos_cached[:seq_len].to(dtype = x.dtype),
+ self.short_sin_cached[:seq_len].to(dtype = x.dtype),
)
else:
return (
- self.multi_gpu_long_cos_cached[device_index][:seq_len],
- self.multi_gpu_long_sin_cached[device_index][:seq_len],
+ self.long_cos_cached[:seq_len].to(dtype = x.dtype),
+ self.long_sin_cached[:seq_len].to(dtype = x.dtype),
)
+ pass
+ pass
- def get_cached(self, seq_len = None, device_index = None):
- if device_index is None:
- device_index = get_current_device()
- if seq_len is not None and seq_len < self.original_max_position_embeddings:
- return self.multi_gpu_short_cos_cached[
- device_index
- ], self.multi_gpu_short_sin_cached[device_index]
- return self.multi_gpu_long_cos_cached[
- device_index
- ], self.multi_gpu_long_sin_cached[device_index]
+ def get_cached(self, seq_len = None):
+ if seq_len < self.original_max_position_embeddings:
+ return self.short_cos_cached, self.short_sin_cached
+ return self.long_cos_cached, self.long_sin_cached
+ pass
def extend_rope_embedding(self, x, seq_len):
- if seq_len <= self.current_rope_size:
- return
+ if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
- for device_idx in range(DEVICE_COUNT):
- self._set_cos_sin_cache(
- self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype
- )
+ self._set_cos_sin_cache(self.current_rope_size, device = "cuda", dtype = x.dtype)
+ pass
+pass
def unsloth_fast_generate(
@@ -2058,43 +1540,18 @@ def unsloth_fast_generate(
*args,
**kwargs,
):
- # If the model starts out in training mode, restore training mode after generation
- restore_training_mode = self.training
-
FastLlamaModel.for_inference(self)
- # Unpack BatchEncoding passed as input_ids for backwards compatibility.
- # Old notebooks do model.generate(input_ids=tokenizer(...)) where the tokenizer
- # output is a BatchEncoding (dict-like). Transformers v5 generate() calls
- # .shape on it directly and crashes. Unpack into separate kwargs so both
- # v4 and v5 work transparently.
- _maybe_encoding = kwargs.get("input_ids", None)
- if (
- _maybe_encoding is not None
- and not isinstance(_maybe_encoding, torch.Tensor)
- and hasattr(_maybe_encoding, "items")
- ):
- batch_data = kwargs.pop("input_ids")
- for key, val in batch_data.items():
- kwargs.setdefault(key, val)
-
- dtype = _get_dtype(dtype_from_config(self.config))
+ dtype = _get_dtype(self.config.torch_dtype)
if hasattr(self, "config") and hasattr(self.config, "max_position_embeddings"):
- if (
- "input_ids" in kwargs
- and kwargs["input_ids"] is not None
- and "max_new_tokens" in kwargs
- ):
- _ids = kwargs["input_ids"]
- if hasattr(_ids, "shape") and (
- _ids.shape[-1] + kwargs["max_new_tokens"]
- > self.config.max_position_embeddings
- ):
+ if "input_ids" in kwargs and kwargs["input_ids"] is not None and "max_new_tokens" in kwargs:
+ if kwargs["input_ids"].shape[-1] + kwargs["max_new_tokens"] > self.config.max_position_embeddings:
raise ValueError(
- f"Unsloth: input length {_ids.shape[-1]} + max_new_tokens {kwargs['max_new_tokens']} exceeds the maximum sequence length of {self.config.max_position_embeddings}!\n"
- "You will need to do long context extension by increasing the `max_seq_length` in `FastLanguageModel.from_pretrained`."
+ f'Unsloth: input length {kwargs["input_ids"].shape[-1]} + max_new_tokens {kwargs["max_new_tokens"]} exceeds the maximum sequence length of {model.config.max_position_embeddings}!\n'\
+ 'You will need to do long context extension by increasing the `max_seq_length` in `FastLanguageModel.from_pretrained`.'
)
+ pass
# Must patch accelerate for Xformers
# if accelerate_new_send_to_device is not None:
@@ -2105,10 +1562,7 @@ def unsloth_fast_generate(
# For newer HF
kwargs["cache_implementation"] = "dynamic"
# For num_logits_to_keep
- num_logits_to_keep = kwargs.get("num_logits_to_keep", None)
- logits_to_keep = kwargs.get("logits_to_keep", None)
- if num_logits_to_keep is None and logits_to_keep is None:
- kwargs["num_logits_to_keep"] = 1
+ kwargs["num_logits_to_keep"] = 1
# Remove token_type_ids
kwargs.pop("token_type_ids", None)
@@ -2121,210 +1575,140 @@ def unsloth_fast_generate(
kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id)
# Mixed precision autocast
- with (
- _get_inference_mode_context_manager(self),
- torch.autocast(device_type = DEVICE_TYPE_TORCH, dtype = dtype),
- ):
+ with torch.inference_mode(), torch.autocast(device_type = "cuda", dtype = dtype):
output = self._old_generate(*args, **kwargs)
+ pass
# Return accelerate back
# if accelerate_new_send_to_device is not None:
# accelerate.utils.operations.send_to_device = accelerate_old_send_to_device
# pass
- if restore_training_mode:
- FastLlamaModel.for_training(self)
+ FastLlamaModel.for_training(self)
return output
+pass
class FastLlamaModel:
- @staticmethod
- def _prepare_for_qat(model, qat_scheme):
- model = _prepare_model_for_qat(model, qat_scheme)
- return model
@staticmethod
def pre_patch():
init_name, function = patch_llama_rope_scaling(
- model_name = "llama",
- rope_module = LlamaRotaryEmbedding,
- scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
+ model_name = "llama",
+ rope_module = LlamaRotaryEmbedding,
+ scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
extended_rope_module = LlamaExtendedRotaryEmbedding,
- attention_module = LlamaAttention,
- longrope_module = LongRopeRotaryEmbedding,
+ attention_module = LlamaAttention,
+ longrope_module = LongRopeRotaryEmbedding,
)
if init_name is not None:
exec(function, globals())
- LlamaAttention.__init__ = eval(init_name)
- LlamaAttention.forward = LlamaAttention_fast_forward
- LlamaSdpaAttention.forward = LlamaAttention_fast_forward
+ LlamaAttention.__init__ = eval(init_name)
+ pass
+ LlamaAttention .forward = LlamaAttention_fast_forward
+ LlamaSdpaAttention .forward = LlamaAttention_fast_forward
LlamaFlashAttention2.forward = LlamaAttention_fast_forward
- LlamaDecoderLayer.forward = LlamaDecoderLayer_fast_forward
- LlamaModel.forward = LlamaModel_fast_forward
- LlamaForCausalLM.forward = CausalLM_fast_forward(
- LlamaModel_fast_forward_inference
- )
- PeftModelForCausalLM.forward = PeftModel_fast_forward
+ LlamaDecoderLayer .forward = LlamaDecoderLayer_fast_forward
+ LlamaModel .forward = LlamaModel_fast_forward
+ LlamaForCausalLM .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference)
+ PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(LlamaForCausalLM)
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
- # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
+ # Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.llama.modeling_llama
-
- transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = (
- LlamaRotaryEmbedding
- )
- transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding = (
- LlamaLinearScalingRotaryEmbedding
- )
+ transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = LlamaRotaryEmbedding
+ transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding = LlamaLinearScalingRotaryEmbedding
return
+ pass
+
@staticmethod
def from_pretrained(
- model_name = "unsloth/llama-3-8b-bnb-4bit",
- max_seq_length = None,
- dtype = None,
- load_in_4bit = True,
- token = None,
- device_map = "sequential",
- rope_scaling = None,
- fix_tokenizer = True,
- model_patcher = None,
- tokenizer_name = None,
+ model_name = "unsloth/llama-3-8b-bnb-4bit",
+ max_seq_length = None,
+ dtype = None,
+ load_in_4bit = True,
+ token = None,
+ device_map = "sequential",
+ rope_scaling = None,
+ fix_tokenizer = True,
+ model_patcher = None,
+ tokenizer_name = None,
trust_remote_code = False,
- revision = None,
- fast_inference = False, # uses vLLM
+
+ fast_inference = False, # uses vLLM
gpu_memory_utilization = 0.5,
- float8_kv_cache = False,
- random_state = 3407,
- max_lora_rank = 16,
+ float8_kv_cache = False,
+ random_state = 3407,
+ max_lora_rank = 16,
disable_log_stats = False,
- unsloth_vllm_standby = False,
- num_labels = None,
- qat_scheme = None,
- load_in_fp8 = False, # fp8 LoRA (True, False, 'block')
**kwargs,
):
os.environ["UNSLOTH_USE_NEW_MODEL"] = "0"
if trust_remote_code:
if fast_inference:
- raise NotImplementedError(
- "Unsloth: Fast inference does not support `trust_remote_code` yet."
- )
+ raise NotImplementedError("Unsloth: Fast inference does not support `trust_remote_code` yet.")
print(
- "Unsloth: WARNING `trust_remote_code` is True.\n"
+ "Unsloth: WARNING `trust_remote_code` is True.\n"\
"Are you certain you want to do remote code execution?"
)
+ pass
if fast_inference:
- if not is_vLLM_available():
- print("Unsloth: vLLM is not installed! Will use Unsloth inference!")
+ import platform
+ if platform.system().lower() == 'windows':
+ print("Unsloth: vLLM does not work in Windows! Will use Unsloth inference!")
fast_inference = False
- if DEVICE_TYPE == "cuda":
- major_version, minor_version = torch.cuda.get_device_capability()
- if major_version < 7:
- print(
- "Unsloth: vLLM does not work on older GPUs - will switch to Unsloth inference!"
- )
- fast_inference = False
- elif DEVICE_TYPE == "hip":
- fast_inference = True
- if (
- unsloth_vllm_standby
- and os.environ.get("UNSLOTH_VLLM_STANDBY", "0") == "0"
- ):
- raise RuntimeError(
- "Unsloth: `unsloth_vllm_standby` is True, but environment variable `UNSLOTH_VLLM_STANDBY` is not set to 1!"
- )
+ major_version, minor_version = torch.cuda.get_device_capability()
+ if major_version < 7:
+ print("Unsloth: vLLM does not work on older GPUs - will switch to Unsloth inference!")
+ fast_inference = False
+ pass
- token = hf_login(token)
- if model_patcher is None:
- model_patcher = FastLlamaModel
+ if token is None: token = get_token()
+ if model_patcher is None: model_patcher = FastLlamaModel
SUPPORTS_BFLOAT16 = is_bfloat16_supported()
-
- if DEVICE_TYPE == "cuda":
- gpu_stats = torch.cuda.get_device_properties(0)
- gpu_stats_name = (
- gpu_stats.name + ". " if gpu_stats.name != "" else "NVIDIA GPU Device. "
- )
- gpu_version = torch.version.cuda
- gpu_stats_snippet = f"CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {gpu_version}."
- try:
- vllm_version = f" vLLM: {importlib_version('vllm')}."
- except:
- vllm_version = ""
- elif DEVICE_TYPE == "hip":
- gpu_stats = torch.cuda.get_device_properties(0)
- gpu_stats_name = resolve_hip_gpu_stats_name(gpu_stats)
- gpu_version = torch.version.hip
- gpu_stats_snippet = f"ROCm Toolkit: {gpu_version}."
- try:
- vllm_version = f" vLLM: {importlib_version('vllm')}."
- except:
- vllm_version = ""
- elif DEVICE_TYPE == "xpu":
- gpu_stats = torch.xpu.get_device_properties(0)
- gpu_stats_name = (
- gpu_stats.name + ". " if gpu_stats.name != "" else "Intel XPU Device. "
- )
- gpu_version = torch.version.xpu
- gpu_stats_snippet = f"Intel Toolkit: {gpu_version}."
- try:
- vllm_version = f" vLLM: {importlib_version('vllm')}."
- except:
- vllm_version = ""
- else:
- raise ValueError(f"Unsloth: Unsupported device type: {DEVICE_TYPE}")
-
+ gpu_stats = torch.cuda.get_device_properties(0)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
- statistics = (
- f"==((====))== Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}.{vllm_version}\n"
- f" {chr(92)}{chr(92)} /| {gpu_stats_name}Num GPUs = {DEVICE_COUNT}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"
- f"O^O/ {chr(92)}_/ {chr(92)} Torch: {torch.__version__}. {gpu_stats_snippet} Triton: {triton_version}\n"
- f"{chr(92)} / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"
- f' "-____-" Free license: http://github.com/unslothai/unsloth'
- )
+ from importlib.metadata import version as importlib_version
+ try: vllm_version = f" vLLM: {importlib_version('vllm')}."
+ except: vllm_version = ""
+ statistics = \
+ f"==((====))== Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}.{vllm_version}\n"\
+ f" {chr(92)}{chr(92)} /| {gpu_stats.name}. Num GPUs = {torch.cuda.device_count()}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\
+ f"O^O/ {chr(92)}_/ {chr(92)} Torch: {torch.__version__}. CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}. Triton: {triton_version}\n"\
+ f"{chr(92)} / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\
+ f' "-____-" Free license: http://github.com/unslothai/unsloth'
print(statistics)
# Warn about fast transfers
- if "HF_HUB_ENABLE_HF_TRANSFER" in os.environ:
- old_hf_transfer = os.environ["HF_HUB_ENABLE_HF_TRANSFER"]
- if old_hf_transfer in ("False", "false"):
- old_hf_transfer = "0"
- if old_hf_transfer in ("True", "true"):
- old_hf_transfer = "1"
- else:
- old_hf_transfer = "0"
- if old_hf_transfer == "1":
- print(
- "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!"
- )
- if old_hf_transfer != "0":
- os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+ old_hf_transfer = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0")
+ if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") == "1":
+ print("Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!")
+ pass
+ # Return old flag
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
model_patcher.pre_patch()
- # For debugging - we use a download counter to see if environments are not breaking or if HF is down
- get_statistics(kwargs.get("local_files_only", False))
+ get_statistics() # For debugging - we use a download counter to see if environments are not breaking
if dtype is None:
dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16
elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:
- logger.warning_once(
- "Device does not support bfloat16. Will change to float16."
- )
+ logger.warning_once("Device does not support bfloat16. Will change to float16.")
dtype = torch.float16
# elif dtype == torch.float16 and SUPPORTS_BFLOAT16:
# logger.warning_once("Device supports bfloat16 but you selected float16. Will change to bfloat16.")
# dtype = torch.bfloat16
- assert (
- dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32
- )
+ assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32)
# RoPE Scaling
model_config = AutoConfig.from_pretrained(
@@ -2332,113 +1716,81 @@ def from_pretrained(
token = token,
attn_implementation = "sdpa",
)
- model_config.model_name = model_name
model_max_seq_length = model_config.max_position_embeddings
- verify_fp8_support_if_applicable(model_config)
-
# Check if RoPE Scaling is even allowed
model_function = MODEL_FOR_CAUSAL_LM_MAPPING[model_config.__class__]
- IS_FALCON_H1 = model_config.model_type.startswith("falcon_h1")
-
- preferred_attn_impl = (
- prefer_flex_attn_if_supported(model_function, model_config) or "eager"
- )
-
has_rope_scaling = False
try:
- with open(inspect.getfile(model_function), "r", encoding = "utf-8") as file:
+ with open(inspect.getfile(model_function), "r") as file:
has_rope_scaling = "self.config.rope_scaling" in file.read()
- except:
- pass
+ except: pass
has_rope_scaling = True
- # If max_seq_length is not specified, use maximum from config
+ # If max_seq_length is not specified, use maximum fron config
if max_seq_length is None:
max_seq_length = model_max_seq_length
+ pass
if (rope_scaling is None) and (max_seq_length > model_max_seq_length):
+
rope_scaling = max_seq_length / model_max_seq_length
if fast_inference:
- raise NotImplementedError(
- "Unsloth: Fast inference does not yet work with RoPE Scaling."
- )
+ raise NotImplementedError("Unsloth: Fast inference does not yet work with RoPE Scaling.")
logger.warning_once(
- f"Unsloth: {model_name} can only handle sequence lengths of at most "
- f"{model_max_seq_length}.\nBut with kaiokendev's RoPE scaling of "
- f"{round(rope_scaling, 3)}, it can be magically be extended to "
+ f"Unsloth: {model_name} can only handle sequence lengths of at most "\
+ f"{model_max_seq_length}.\nBut with kaiokendev's RoPE scaling of "\
+ f"{round(rope_scaling, 3)}, it can be magically be extended to "\
f"{max_seq_length}!"
)
# Warn RoPE scaling isn't allowed
if not has_rope_scaling:
raise RuntimeError(
- f"However, {model_name} doesn't support RoPE Scaling!\n"
+ "However, {model_name} doesn't support RoPE Scaling!\n"\
"Please file a feature request at https://github.com/unslothai/unsloth."
)
+ pass
- rope_scaling = {
- "type": "linear",
- "factor": rope_scaling,
- }
+ rope_scaling = {"type": "linear", "factor": rope_scaling,}
# Add to kwargs
kwargs["rope_scaling"] = rope_scaling
+ pass
bnb_config = None
if load_in_4bit:
- llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy()
- if IS_FALCON_H1:
- # we cannot quantize out_proj layer due to mamba kernels: https://github.com/tiiuae/Falcon-H1/issues/13#issuecomment-2918671274
- llm_int8_skip_modules.append("out_proj")
bnb_config = BitsAndBytesConfig(
- load_in_4bit = True,
+ load_in_4bit = True,
bnb_4bit_use_double_quant = True,
- bnb_4bit_quant_type = "nf4",
- bnb_4bit_compute_dtype = dtype,
- llm_int8_skip_modules = llm_int8_skip_modules,
+ bnb_4bit_quant_type = "nf4",
+ bnb_4bit_compute_dtype = dtype,
)
+ pass
# https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/12
# RoPE Scaling's max_position_embeddings must be updated
max_position_embeddings = max(max_seq_length, model_max_seq_length)
- kwargs.pop("attn_implementation", None) # No need since we auto call it
+ kwargs.pop("attn_implementation", None); # No need since we auto call it
# Cannot be None, since HF now checks for the config
- if load_in_4bit:
- kwargs["quantization_config"] = bnb_config
-
- kwargs = add_dtype_kwargs(dtype, kwargs)
-
- raise_handler = RaiseUninitialized()
- if num_labels is not None:
- model = AutoModelForSequenceClassification.from_pretrained(
- model_name,
- device_map = device_map,
- # torch_dtype = dtype, # transformers changed torch_dtype to dtype
- num_labels = num_labels,
- # quantization_config = bnb_config,
- token = token,
- max_position_embeddings = max_position_embeddings,
- trust_remote_code = trust_remote_code,
- attn_implementation = preferred_attn_impl,
- **kwargs,
- )
- elif not fast_inference:
+ if load_in_4bit: kwargs["quantization_config"] = bnb_config
+
+ if not fast_inference:
model = AutoModelForCausalLM.from_pretrained(
model_name,
- device_map = device_map,
- # torch_dtype = dtype, # transformers changed torch_dtype to dtype
+ device_map = device_map,
+ torch_dtype = dtype,
# quantization_config = bnb_config,
- token = token,
+ token = token,
max_position_embeddings = max_position_embeddings,
- trust_remote_code = trust_remote_code,
- attn_implementation = preferred_attn_impl,
+ trust_remote_code = trust_remote_code,
+ attn_implementation = "eager",
**kwargs,
)
- model.fast_generate = make_fast_generate_wrapper(model.generate)
+ model.fast_generate = model.generate
model.fast_generate_batches = None
else:
from unsloth_zoo.vllm_utils import (
@@ -2447,28 +1799,18 @@ def from_pretrained(
convert_vllm_to_huggingface,
generate_batches,
)
-
- fp8_mode = None
- if load_in_fp8 != False:
- fp8_mode = _get_fp8_mode_and_check_settings(
- load_in_fp8,
- fast_inference,
- )
-
allowed_args = inspect.getfullargspec(load_vllm).args
load_vllm_kwargs = dict(
- model_name = model_name,
- config = model_config,
+ model_name = model_name,
+ config = model_config,
gpu_memory_utilization = gpu_memory_utilization,
- max_seq_length = max_seq_length,
- dtype = dtype,
- float8_kv_cache = float8_kv_cache,
- enable_lora = True,
- max_lora_rank = max_lora_rank,
- disable_log_stats = disable_log_stats,
- use_bitsandbytes = load_in_4bit,
- unsloth_vllm_standby = unsloth_vllm_standby,
- fp8_mode = fp8_mode,
+ max_seq_length = max_seq_length,
+ dtype = dtype,
+ float8_kv_cache = float8_kv_cache,
+ enable_lora = True,
+ max_lora_rank = max_lora_rank,
+ disable_log_stats = disable_log_stats,
+ use_bitsandbytes = load_in_4bit,
)
for allowed_arg in allowed_args:
if allowed_arg not in load_vllm_kwargs and allowed_arg in kwargs:
@@ -2479,47 +1821,37 @@ def from_pretrained(
llm = load_vllm(**load_vllm_kwargs)
# Convert to HF format
- _, quant_state_dict = get_vllm_state_dict(
- llm,
- config = model_config,
- load_in_fp8 = load_in_fp8,
- )
- model = convert_vllm_to_huggingface(
- quant_state_dict, model_config, dtype, bnb_config
- )
+ _, quant_state_dict = get_vllm_state_dict(llm, config = model_config)
+ model = convert_vllm_to_huggingface(quant_state_dict, model_config, dtype)
model.vllm_engine = llm
model.fast_generate = model.vllm_engine.generate
- model.fast_generate_batches = functools.partial(
- generate_batches, model.vllm_engine
- )
- raise_handler.remove()
+ model.fast_generate_batches = functools.partial(generate_batches, model.vllm_engine)
+ pass
# Return old flag
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer
# Counteract saved tokenizers
tokenizer_name = model_name if tokenizer_name is None else tokenizer_name
tokenizer = load_correct_tokenizer(
- tokenizer_name = tokenizer_name,
- model_max_length = max_position_embeddings,
- padding_side = "right",
- token = token,
+ tokenizer_name = tokenizer_name,
+ model_max_length = max_position_embeddings,
+ padding_side = "right",
+ token = token,
trust_remote_code = trust_remote_code,
- fix_tokenizer = fix_tokenizer,
+ fix_tokenizer = fix_tokenizer,
)
model, tokenizer = patch_tokenizer(model, tokenizer)
- model, tokenizer = model_patcher.post_patch(
- model, tokenizer, correct_dtype = dtype
- )
+ model, tokenizer = model_patcher.post_patch(model, tokenizer)
# Patch up QKV / O and MLP
for idx, layer in enumerate(model.model.layers):
layer.self_attn.apply_qkv = original_apply_qkv
- layer.self_attn.apply_o = original_apply_o
+ layer.self_attn.apply_o = original_apply_o
+ pass
# Patch Trainer
from transformers.trainer import Trainer
-
try:
if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop":
inner_training_loop = inspect.getsource(Trainer._inner_training_loop)
@@ -2527,29 +1859,22 @@ def from_pretrained(
else:
inner_training_loop = Trainer._original_training_loop
except:
- raise RuntimeError("Unsloth: Unsuccessfully patched inner_training_loop")
-
+ raise RuntimeError('Unsloth: Unsuccessfully patched inner_training_loop')
+ pass
+
import transformers.trainer
-
items_in_trainer = dir(transformers.trainer)
good_items = []
for item in items_in_trainer:
- if item in inner_training_loop:
- good_items.append(item)
- exec(
- "from transformers.trainer import ("
- + ", ".join(x for x in good_items)
- + ")",
- globals(),
- )
+ if item in inner_training_loop: good_items.append(item)
+ pass
+ exec("from transformers.trainer import (" + ", ".join(x for x in good_items) + ")", globals())
- start = re.search(
- r"logger\.info\([\"\'].+?Running training", inner_training_loop
- ).span(0)[0]
+ start = re.search(r'logger\.info\([\"\'].+?Running training', inner_training_loop).span(0)[0]
end = inner_training_loop.find("\n\n", start)
original_debug = inner_training_loop[start:end]
- spaces = re.search(r"\n([\s\t]{1,})", original_debug).group(0)[1:]
- front_spaces = re.match(r"([\s\t]{1,})", inner_training_loop).group(0)
+ spaces = re.search(r'\n([\s\t]{1,})', original_debug).group(0)[1:]
+ front_spaces = re.match(r'([\s\t]{1,})', inner_training_loop).group(0)
# Cannot use \\ since it will cause a SyntaxWarning in Python 3.12
# Instead use chr(92) == \\
@@ -2558,20 +1883,15 @@ def from_pretrained(
f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,} | Total steps = {max_steps:,}\\n"\\
f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient accumulation steps = {args.gradient_accumulation_steps}\\n"\\
f"{chr(92)} / Data Parallel GPUs = {args.world_size} | Total batch size ({self._train_batch_size} x {args.gradient_accumulation_steps} x {args.world_size}) = {total_train_batch_size:,}\\n"\\
- f' "-____-" Trainable parameters = {get_model_param_count(model, trainable_only=True):,} of {get_model_param_count(model):,} ({get_model_param_count(model, trainable_only=True)/get_model_param_count(model)*100:.2f}% trained)'
+ f' "-____-" Trainable parameters = {get_model_param_count(model, trainable_only=True):,}/{get_model_param_count(model):,} ({get_model_param_count(model, trainable_only=True)/get_model_param_count(model)*100:.2f}% trained)'
logger.warning(debug_info)
import gc
for _ in range(3):
gc.collect()
- if DEVICE_TYPE == "xpu":
- torch.xpu.empty_cache()
- else:
- torch.cuda.empty_cache()"""
+ torch.cuda.empty_cache()"""
- debug_info = debug_info.split("\n")
- debug_info = "\n".join(
- [debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]]
- )
+ debug_info = debug_info.split('\n')
+ debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]])
inner_training_loop = inner_training_loop.replace(original_debug, debug_info)
debug_info = """n_total_devices = total_train_batch_size // \\
@@ -2579,24 +1899,19 @@ def from_pretrained(
if n_total_devices > 1:
logger.warning_once('Unsloth is running with multi GPUs - the effective batch size is multiplied by ' + str(n_total_devices))
debug_info ="""
- debug_info = debug_info.split("\n")
- debug_info = "\n".join(
- [debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]]
- )
+ debug_info = debug_info.split('\n')
+ debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]])
inner_training_loop = inner_training_loop.replace("debug_info =", debug_info, 1)
front_spaces = re.match(r"[\t\s]{1,}", inner_training_loop).group(0)
- inner_training_loop = re.sub(
- r"^" + front_spaces, "", inner_training_loop, flags = re.MULTILINE
- )
+ inner_training_loop = re.sub(r"^" + front_spaces, "", inner_training_loop, flags = re.MULTILINE)
inner_training_loop = inner_training_loop.replace(
"train_dataloader = tpu_spmd_dataloader(train_dataloader)",
- "raise RuntimeError('Unsloth: TPUs are not yet supported!')",
+ "raise RuntimeError('Unsloth: TPUs are not yet supported!')"
)
inner_training_loop = inner_training_loop.replace(
"_inner_training_loop",
- "_fast_inner_training_loop",
- 1,
+ "_fast_inner_training_loop", 1,
)
inner_training_loop = inner_training_loop.replace(
"is_torch_tpu_available()",
@@ -2611,21 +1926,20 @@ def from_pretrained(
while hasattr(m, "model"):
m.max_seq_length = max_seq_length
m = m.model
+ pass
m.max_seq_length = max_seq_length
- # Save to modules as well
- for module in model.modules():
- module.max_seq_length = max_seq_length
# We check the tokenizer first for errors
if fix_tokenizer:
tokenizer = check_tokenizer(
- model = model,
- tokenizer = tokenizer,
- model_name = model_name,
+ model = model,
+ tokenizer = tokenizer,
+ model_name = model_name,
model_max_length = max_position_embeddings,
- padding_side = "right",
- token = token,
+ padding_side = "right",
+ token = token,
)
+ pass
patch_saving_functions(tokenizer)
# Fix up config for transformers uploading PEFT
@@ -2633,11 +1947,13 @@ def from_pretrained(
if False:
name = model.config._name_or_path
if name.startswith("unsloth/") and name.endswith("-bnb-4bit"):
- name = name[: len(name) - len("-bnb-4bit")]
- model.config.update({"_name_or_path": name})
+ name = name[:len(name) - len("-bnb-4bit")]
+ model.config.update({"_name_or_path" : name})
+ pass
+ pass
# Log Unsloth version for future fastpaths for inference
- model.config.update({"unsloth_version": __version__})
+ model.config.update({"unsloth_version" : __version__})
# Add save modules
patch_saving_functions(model)
@@ -2647,7 +1963,7 @@ def from_pretrained(
patch_gradient_accumulation_fix(Trainer)
# Save tokenizer for inference purposes
- tokenizer.padding_side = "left" # Force inference
+ tokenizer.padding_side = "left" # Force inference
internal_model = model
while hasattr(internal_model, "model"):
internal_model._saved_temp_tokenizer = tokenizer
@@ -2655,6 +1971,7 @@ def from_pretrained(
internal_model.is_loaded_in_8bit = True
internal_model = internal_model.model
+ pass
internal_model._saved_temp_tokenizer = tokenizer
# Also set is_loaded_in_8bit to disable incorrect DDP
internal_model.is_loaded_in_8bit = True
@@ -2664,178 +1981,114 @@ def from_pretrained(
rotary_emb = model.model.rotary_emb
for layer in model.model.layers:
layer.self_attn.rotary_emb = rotary_emb
+ pass
# Add for_inference and for_training
- model.for_training = functools.partial(FastLlamaModel.for_training, model)
+ model.for_training = functools.partial(FastLlamaModel.for_training, model)
model.for_inference = functools.partial(FastLlamaModel.for_inference, model)
- m = model
- while hasattr(m, "model"):
- m.for_training = functools.partial(FastBaseModel.for_training, m)
- m.for_inference = functools.partial(FastBaseModel.for_inference, m)
- m = m.model
# Patch generate
- is_classification = "Classification" in str(type(model))
- if not is_classification and model.generate.__name__ != "unsloth_fast_generate":
+ if model.generate.__name__ != "unsloth_fast_generate":
model._old_generate = model.generate
unsloth_fast_generate.__doc__ = model._old_generate.__doc__
model.generate = types.MethodType(unsloth_fast_generate, model)
- # Set weight[padding_idx] = 0 for embeddings that are NOT tied with the
- # lm_head. When weights are tied, zeroing the padding row also zeros
- # the corresponding lm_head row, forcing logit = 0 for the pad token.
- # This is higher than the (negative) logits for real tokens in models
- # like Gemma, causing the decoder to emit and produce gibberish.
- # Skip entirely if eos_token == pad_token to avoid zeroing EOS embedding.
- eos_token_id = (
- getattr(tokenizer, "eos_token_id", None) if tokenizer is not None else None
- )
- pad_token_id = (
- getattr(tokenizer, "pad_token_id", None) if tokenizer is not None else None
- )
- if tokenizer is not None and eos_token_id != pad_token_id:
- lm_head = getattr(model, "lm_head", None)
- lm_head_weight = (
- getattr(lm_head, "weight", None) if lm_head is not None else None
- )
- with torch.no_grad():
- for name, module in model.named_modules():
- if type(module) is torch.nn.Embedding:
- if (
- getattr(module, "weight", None) is not None
- and getattr(module, "padding_idx", None) is not None
- ):
- if module.padding_idx < module.weight.shape[0]:
- # Skip if tied to lm_head
- if (
- lm_head_weight is not None
- and module.weight.data_ptr()
- == lm_head_weight.data_ptr()
- ):
- continue
- module.weight[module.padding_idx] = 0
+ pass
return model, tokenizer
+ pass
+
@staticmethod
- def post_patch(model, tokenizer, correct_dtype = None):
- model, tokenizer = patch_model_and_tokenizer(
- model, tokenizer, downcast_rope = True, correct_dtype = correct_dtype
- )
+ def post_patch(model, tokenizer):
+ model, tokenizer = patch_model_and_tokenizer(model, tokenizer, downcast_rope = True)
return model, tokenizer
+ pass
+
@staticmethod
def get_peft_model(
model,
- r = 16,
- target_modules = [
- "q_proj",
- "k_proj",
- "v_proj",
- "o_proj",
- "gate_proj",
- "up_proj",
- "down_proj",
- ],
- lora_alpha = 16,
- lora_dropout = 0.0,
- bias = "none",
+ r = 16,
+ target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
+ "gate_proj", "up_proj", "down_proj"],
+ lora_alpha = 16,
+ lora_dropout = 0,
+ bias = "none",
layers_to_transform = None,
- layers_pattern = None,
- use_gradient_checkpointing = "unsloth",
- random_state = 3407,
- max_seq_length = 2048, # not used anymore
- use_rslora = False,
- modules_to_save = None,
- init_lora_weights = True,
- loftq_config = {},
- temporary_location = "_unsloth_temporary_saved_buffers",
- qat_scheme = None,
- target_parameters = None, # For MoE expert layers (nn.Parameter)
- ensure_weight_tying = False,
+ layers_pattern = None,
+ use_gradient_checkpointing = True,
+ random_state = 3407,
+ max_seq_length = 2048, # not used anymore
+ use_rslora = False,
+ modules_to_save = None,
+ init_lora_weights = True,
+ loftq_config = {},
+ temporary_location = "_unsloth_temporary_saved_buffers",
**kwargs,
):
if os.environ.get("UNSLOTH_USE_NEW_MODEL", "0") == "1":
- # Check for other PEFT args in kwargs
- for peft_arg, flag in (
- ("finetune_vision_layers", False),
- ("finetune_language_layers", True),
- ("finetune_attention_modules", True),
- ("finetune_mlp_modules", True),
- ):
- if peft_arg not in kwargs:
- kwargs[peft_arg] = flag
return FastBaseModel.get_peft_model(
- model = model,
- r = r,
- target_modules = target_modules,
- lora_alpha = lora_alpha,
- lora_dropout = lora_dropout,
- bias = bias,
- layers_to_transform = layers_to_transform,
- layers_pattern = layers_pattern,
+ model = model,
+ r = r,
+ target_modules = target_modules,
+ lora_alpha = lora_alpha,
+ lora_dropout = lora_dropout,
+ bias = bias,
+ finetune_vision_layers = False,
+ finetune_language_layers = True,
+ finetune_attention_modules = True,
+ finetune_mlp_modules = True,
+ layers_to_transform = layers_to_transform,
+ layers_pattern = layers_pattern,
use_gradient_checkpointing = use_gradient_checkpointing,
- random_state = random_state,
- max_seq_length = max_seq_length,
- use_rslora = use_rslora,
- modules_to_save = modules_to_save,
- init_lora_weights = init_lora_weights,
- loftq_config = loftq_config,
- temporary_location = temporary_location,
- target_parameters = target_parameters,
- ensure_weight_tying = ensure_weight_tying,
+ random_state = random_state,
+ max_seq_length = max_seq_length,
+ use_rslora = use_rslora,
+ modules_to_save = modules_to_save,
+ init_lora_weights = init_lora_weights,
+ loftq_config = loftq_config,
+ temporary_location = temporary_location,
**kwargs,
)
+ pass
if os.environ.get("UNSLOTH_ENABLE_FULL_FINETUNING", "0") == "1":
- print(
- "Unsloth: Full finetuning is enabled, so .get_peft_model has no effect"
- )
+ print("Unsloth: Full finetuning is enabled, so .get_peft_model has no effect")
return model
+ pass
transformers_set_seed(random_state)
- # Apply gradient checkpointing with smart heuristics
- max_seq = getattr(model, "max_seq_length", 512)
- dtype = model.get_input_embeddings().weight.dtype
- use_gradient_checkpointing = apply_unsloth_gradient_checkpointing(
- use_gradient_checkpointing, max_seq, dtype
- )
+ if use_gradient_checkpointing == "unsloth":
+ patch_unsloth_smart_gradient_checkpointing(dtype = model.get_input_embeddings().weight.dtype)
if type(r) is not int:
raise TypeError(f"Unsloth: Rank of {str(r)} must be an integer.")
if r <= 0:
raise TypeError(f"Unsloth: Rank of {str(r)} must be larger than 0.")
- if isinstance(model, PeftModelForCausalLM) or isinstance(
- model, PeftModelForSequenceClassification
- ):
+ if isinstance(model, PeftModelForCausalLM):
# Check if exactly the same and then pass through!
- assert hasattr(model, "peft_config")
+ assert(hasattr(model, "peft_config"))
peft_config = model.peft_config["default"].to_dict()
check_parameters = [
- "r",
- "lora_alpha",
- "lora_dropout",
- "bias",
- "layers_to_transform",
- "layers_pattern",
- "use_rslora",
- "init_lora_weights",
+ "r", "lora_alpha", "lora_dropout",
+ "bias", "layers_to_transform", "layers_pattern",
+ "use_rslora", "init_lora_weights",
]
check_all = True
for param in check_parameters:
check_all = check_all and (peft_config[param] == eval(param))
+ pass
# Check save_modules
old_target_modules = list(peft_config["target_modules"])
modules_to_save = peft_config["modules_to_save"]
- if modules_to_save is None:
- modules_to_save = {}
+ if modules_to_save is None: modules_to_save = {}
modules_to_save = list(modules_to_save)
old_target_modules += modules_to_save
# Combine all
- new_target_modules = list(target_modules) + list(
- modules_to_save if modules_to_save is not None else []
- )
+ new_target_modules = list(target_modules) + \
+ list(modules_to_save if modules_to_save is not None else [])
# Now check!
new_target_modules = set(new_target_modules)
@@ -2844,11 +2097,8 @@ def get_peft_model(
)
check_all = check_all and (
- (loftq_config == {} or loftq_config is None)
- and (
- peft_config["loftq_config"] == {}
- or peft_config["loftq_config"] is None
- )
+ (loftq_config == {} or loftq_config is None) and \
+ (peft_config["loftq_config"] == {} or peft_config["loftq_config"] is None)
)
if check_all:
@@ -2860,165 +2110,182 @@ def get_peft_model(
# Offload!
# [TODO] First offload lm_head and embed_tokens to CPU (should be disk!!)
if "embed_tokens" in new_target_modules:
- print(
- "Unsloth: Training embed_tokens in mixed precision to save VRAM"
- )
-
- _offload_frozen_module_for_training(
- model.get_input_embeddings(), DEVICE_TYPE_TORCH
- )
+ print("Unsloth: Training embed_tokens in mixed precision to save VRAM")
+
+ new_dtype = model.get_input_embeddings().modules_to_save.default.weight.dtype
+ if new_dtype == torch.float16:
+ # See https://github.com/unslothai/unsloth/pull/1200
+ # Tesla T4 must use float32 and not float16
+ new_dtype = torch.float32
+ pass
+
+ model.get_input_embeddings().modules_to_save.default\
+ .to(device = "cuda", dtype = new_dtype, non_blocking = True)
+ model.get_input_embeddings().modules_to_save.default.requires_grad_(True)
+
+ # [TODO] Move old embed_tokens to CPU - should be disk!
+ model.get_input_embeddings().original_module\
+ .to(device = "cpu", non_blocking = True)
+ model.get_input_embeddings().original_module.requires_grad_(False)
+ pass
if "lm_head" in new_target_modules:
print("Unsloth: Training lm_head in mixed precision to save VRAM")
- _offload_frozen_module_for_training(
- model.get_output_embeddings(), DEVICE_TYPE_TORCH
- )
+ new_dtype = model.get_output_embeddings().modules_to_save.default.weight.dtype
+ if new_dtype == torch.float16:
+ # See https://github.com/unslothai/unsloth/pull/1200
+ # Tesla T4 must use float32 and not float16
+ new_dtype = torch.float32
+ pass
+
+ model.get_output_embeddings().modules_to_save.default\
+ .to(device = "cuda", dtype = new_dtype, non_blocking = True)
+ model.get_output_embeddings().modules_to_save.default.requires_grad_(True)
+
+ # [TODO] Move old lm_head to CPU - should be disk!
+ model.get_output_embeddings().original_module\
+ .to(device = "cpu", non_blocking = True)
+ model.get_output_embeddings().original_module.requires_grad_(False)
+ pass
return model
else:
raise TypeError(
"Unsloth: Your model already has LoRA adapters. Your new parameters are different."
)
+ pass
+ pass
- if loftq_config is None:
- loftq_config = {}
+ if loftq_config is None: loftq_config = {}
signature = str(inspect.signature(LoraConfig))
- SUPPORTS_LOFTQ = "loftq_config" in signature
- SUPPORTS_RSLORA = "use_rslora" in signature
+ SUPPORTS_LOFTQ = "loftq_config" in signature
+ SUPPORTS_RSLORA = "use_rslora" in signature
if lora_dropout != 0:
logger.warning_once(
- f"Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = {lora_dropout}.\n"
+ f"Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = {lora_dropout}.\n"\
f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit."
)
+ pass
if bias != "none":
logger.warning_once(
- f"Unsloth: bias = `none` is supported for fast patching. You are using bias = {bias}.\n"
+ f"Unsloth: bias = `none` is supported for fast patching. You are using bias = {bias}.\n"\
f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit."
)
+ pass
- if not (
- type(init_lora_weights) is bool
- or init_lora_weights == "gaussian"
- or init_lora_weights == "loftq"
- or init_lora_weights == "corda"
- ):
+ if not (type(init_lora_weights) is bool or \
+ init_lora_weights == "gaussian" or init_lora_weights == "loftq"):
raise ValueError(
- 'Unsloth: `init_lora_weights` must be either [True, False, "gaussian", "loftq", "corda"].'
+ 'Unsloth: `init_lora_weights` must be either [True, False, "gaussian", "loftq"].'
)
+ pass
if init_lora_weights == "loftq":
+
if not SUPPORTS_LOFTQ:
import peft
-
raise RuntimeError(
- f"Unsloth: Your PEFT version of {peft.__version__} does not support LoftQ init.\n"
- "Please install PEFT 0.7.2 or higher.\n"
+ f"Unsloth: Your PEFT version of {peft.__version__} does not support LoftQ init.\n"\
+ "Please install PEFT 0.7.2 or higher.\n"\
"You can also install from source: `pip install git+https://github.com/huggingface/peft.git"
)
+ pass
if loftq_config == {}:
from peft import LoftQConfig
-
logger.warning_once(
- "Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n"
+ "Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n"\
"We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`."
)
loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)
-
+ pass
+
if hasattr(model.config, "quantization_config"):
raise ValueError(
- "Unsloth: You are using `loftq` init, yet `load_in_4bit = True` was set.\n"
+ "Unsloth: You are using `loftq` init, yet `load_in_4bit = True` was set.\n"\
"Reload your model without any quantization by setting `load_in_4bit = False`."
)
+ pass
+ pass
- assert type(use_rslora) is bool
+ assert(type(use_rslora) is bool)
if use_rslora:
if not SUPPORTS_RSLORA:
# We manually check for PEFT
import peft
-
raise RuntimeError(
- f"Unsloth: Your PEFT version of {peft.__version__} does not support `use_rslora`.\n"
- "Please install PEFT 0.7.2 or higher.\n"
+ f"Unsloth: Your PEFT version of {peft.__version__} does not support `use_rslora`.\n"\
+ "Please install PEFT 0.7.2 or higher.\n"\
"You can also install from source: `pip install git+https://github.com/huggingface/peft.git"
)
+ pass
+ pass
- accepted_modules = frozenset(
- (
- "lm_head",
- "q_proj",
- "k_proj",
- "v_proj",
- "o_proj",
- "gate_proj",
- "up_proj",
- "down_proj",
- ),
- )
- model.config.update({"unsloth_version": __version__})
+ accepted_modules = frozenset(("q_proj", "k_proj", "v_proj", "o_proj",
+ "gate_proj", "up_proj", "down_proj",),)
+ model.config.update({"unsloth_version" : __version__})
if type(modules_to_save) is tuple:
modules_to_save = list(modules_to_save)
+ pass
train_lm_head = False
train_embed_tokens = False
final_modules = []
for module in target_modules:
- if module == "embed_tokens":
+ if module == "lm_head":
+ # logger.warning_once(
+ # "Unsloth: `lm_head` should be placed in `modules_to_save` and not `target_modules`. "\
+ # "Luckily, we shall do it for you!"
+ # )
+ train_lm_head = True
+ if modules_to_save is None: modules_to_save = ["lm_head"]
+ else: modules_to_save.append("lm_head")
+
+ elif module == "embed_tokens":
# logger.warning_once(
# "Unsloth: `embed_tokens` should be placed in `modules_to_save` and not `target_modules`. "\
# "Luckily, we shall do it for you!"
# )
train_embed_tokens = True
- if modules_to_save is None:
- modules_to_save = ["embed_tokens"]
- else:
- modules_to_save.append("embed_tokens")
+ if modules_to_save is None: modules_to_save = ["embed_tokens"]
+ else: modules_to_save.append("embed_tokens")
else:
try:
- assert module in accepted_modules
+ assert(module in accepted_modules)
final_modules.append(module)
except AssertionError as e:
final_modules.append(module)
print(
- "Unsloth: You added custom modules, but Unsloth hasn't optimized for this.\n"
+ "Unsloth: You added custom modules, but Unsloth hasn't optimized for this.\n"\
"Beware - your finetuning might be noticeably slower!"
)
pass
+ pass
+ pass
# Check if we added new tokens!
if hasattr(model, "_need_to_train_embeddings"):
- # Check if embed_tokens/lm_head are already being trained
- # (either as LoRA targets in final_modules or via modules_to_save)
- _embed_already_trained = (
- train_embed_tokens or "embed_tokens" in final_modules
- )
- _lm_head_already_trained = train_lm_head or "lm_head" in final_modules
- if not _lm_head_already_trained or not _embed_already_trained:
+ if not train_lm_head or not train_embed_tokens:
print(
- "Unsloth: You added new tokens but did not specify if you wanted to "
+ "Unsloth: You added new tokens but did not specify if you wanted to "\
"train the lm_head and embed_tokens.\nWe must turn it on for you."
)
+ train_lm_head = True
+ train_embed_tokens = True
- # Only add to modules_to_save if not already a LoRA target
- if not _embed_already_trained:
- train_embed_tokens = True
- if modules_to_save is None:
- modules_to_save = ["embed_tokens"]
- elif "embed_tokens" not in modules_to_save:
- modules_to_save.append("embed_tokens")
+ if modules_to_save is None: modules_to_save = ["embed_tokens"]
+ else: modules_to_save.append("embed_tokens")
- if not _lm_head_already_trained:
- train_lm_head = True
- if modules_to_save is None:
- modules_to_save = ["lm_head"]
- elif "lm_head" not in modules_to_save:
- modules_to_save.append("lm_head")
+ if modules_to_save is None: modules_to_save = ["lm_head"]
+ else: modules_to_save.append("lm_head")
+ pass
+ pass
# Check for Llama-3
# if hasattr(model._saved_temp_tokenizer, "_using_llama3_template"):
@@ -3042,8 +2309,11 @@ def get_peft_model(
raise TypeError(
f"Unsloth: Module = {module} is not allowed. Only 'lm_head' and 'embed_tokens' is allowed."
)
+ pass
+ pass
if isinstance(modules_to_save, (tuple, list)):
modules_to_save = list(set(modules_to_save))
+ pass
vllm_engine = None
if hasattr(model, "vllm_engine"):
@@ -3053,222 +2323,177 @@ def get_peft_model(
vllm_fast_generate_batches = model.fast_generate_batches
if modules_to_save is not None:
- raise NotImplementedError(
- "Unsloth: Currently fast inference does not work with training embeddings or lm_head."
- )
+ raise NotImplementedError("Unsloth: Currently fast inference does not work with training embeddings or lm_head.")
if bias != "none":
- raise NotImplementedError(
- "Unsloth: Currently fast inference does not work with using biases for LoRA."
- )
-
- # Does not get lora yet, so get name from model, not base model
- is_classification = "Classification" in str(type(model))
-
- # Auto-detect MoE models and populate target_parameters for expert layers
- if target_parameters is None:
- target_parameters = get_moe_target_parameters(model, target_modules)
+ raise NotImplementedError("Unsloth: Currently fast inference does not work with using biases for LoRA.")
+ pass
+ # Get LoRA
arguments = dict(
- r = r,
- lora_alpha = lora_alpha,
- target_modules = final_modules,
- lora_dropout = lora_dropout,
- bias = bias,
- task_type = TaskType.CAUSAL_LM if not is_classification else TaskType.SEQ_CLS,
+ r = r,
+ lora_alpha = lora_alpha,
+ target_modules = final_modules,
+ lora_dropout = lora_dropout,
+ bias = bias,
+ task_type = TaskType.CAUSAL_LM,
layers_to_transform = layers_to_transform,
- init_lora_weights = init_lora_weights,
- loftq_config = loftq_config,
- use_rslora = use_rslora,
- modules_to_save = modules_to_save,
- target_parameters = target_parameters,
- ensure_weight_tying = ensure_weight_tying,
+ init_lora_weights = init_lora_weights,
+ loftq_config = loftq_config,
+ use_rslora = use_rslora,
+ modules_to_save = modules_to_save,
**kwargs,
)
- if not SUPPORTS_LOFTQ:
- del arguments["loftq_config"]
- if not SUPPORTS_RSLORA:
- del arguments["use_rslora"]
+ if not SUPPORTS_LOFTQ: del arguments["loftq_config"]
+ if not SUPPORTS_RSLORA: del arguments["use_rslora"]
_saved_temp_tokenizer = model._saved_temp_tokenizer
lora_config = LoraConfig(**arguments)
+
# First offload lm_head and embed_tokens to disk
- input_embeddings_device = model.get_input_embeddings().weight.device
- if is_classification:
- output_embeddings_device = model.score.weight.device
- else:
- output_embeddings_device = model.get_output_embeddings().weight.device
+ input_embeddings_device = model. get_input_embeddings().weight.device
+ output_embeddings_device = model.get_output_embeddings().weight.device
if use_gradient_checkpointing == "unsloth":
if train_embed_tokens:
print("Unsloth: Offloading input_embeddings to disk to save VRAM")
offload_input_embeddings(model, temporary_location)
+ pass
# Remove old items to save VRAM
for _ in range(3):
gc.collect()
- clean_gpu_cache()
+ torch.cuda.empty_cache()
+ pass
if train_lm_head:
print("Unsloth: Offloading output_embeddings to disk to save VRAM")
offload_output_embeddings(model, temporary_location)
+ pass
# Remove old items to save VRAM
for _ in range(3):
gc.collect()
- clean_gpu_cache()
+ torch.cuda.empty_cache()
+ pass
+ pass
model = _get_peft_model(model, lora_config)
- # Fix LoraConfig.auto_mapping is None
- fix_lora_auto_mapping(model)
-
- # Apply QAT + LoRA if specified
- if qat_scheme is not None:
- print("Unsloth: Applying QAT to mitigate quantization degradation")
- model = FastLlamaModel._prepare_for_qat(model, qat_scheme)
model._saved_temp_tokenizer = _saved_temp_tokenizer
model = FastLlamaModel.patch_peft_model(model, use_gradient_checkpointing)
- if ensure_weight_tying:
- try:
- input_embeddings = model.get_input_embeddings()
- output_embeddings = model.get_output_embeddings()
-
- if input_embeddings is not None and output_embeddings is not None:
-
- def _retie_parameter(target_module, source_module):
- if not hasattr(source_module, "weight"):
- return
- weight = source_module.weight
- # Remove existing registration to avoid "attribute already exists"
- if "weight" in getattr(target_module, "_parameters", {}):
- target_module._parameters.pop("weight")
- if hasattr(target_module, "weight"):
- try:
- delattr(target_module, "weight")
- except Exception as exc:
- logger.warning_once(
- f"Unsloth: Could not delete existing weight attr during retie on "
- f"{type(target_module).__name__}: {exc}"
- )
- target_module.register_parameter("weight", weight)
-
- # Tie trainable copies created by ModulesToSaveWrapper first (these are used in forward)
- if hasattr(input_embeddings, "modules_to_save") and hasattr(
- output_embeddings, "modules_to_save"
- ):
- if hasattr(
- input_embeddings.modules_to_save, "default"
- ) and hasattr(output_embeddings.modules_to_save, "default"):
- _retie_parameter(
- output_embeddings.modules_to_save.default,
- input_embeddings.modules_to_save.default,
- )
-
- # Tie original_module references as well if present
- if hasattr(input_embeddings, "original_module") and hasattr(
- output_embeddings, "original_module"
- ):
- _retie_parameter(
- output_embeddings.original_module,
- input_embeddings.original_module,
- )
- except Exception as e:
- logger.warning_once(
- f"Unsloth: Failed to ensure weight tying between embeddings and lm_head: {e}"
- )
-
if train_embed_tokens:
print("Unsloth: Training embed_tokens in mixed precision to save VRAM")
- assert hasattr(model.get_input_embeddings(), "modules_to_save")
+ assert(hasattr(model.get_input_embeddings(), "modules_to_save"))
- _offload_frozen_module_for_training(
- model.get_input_embeddings(), DEVICE_TYPE_TORCH, offload_device = None
- )
+ new_dtype = model.get_input_embeddings().modules_to_save.default.weight.dtype
+ if new_dtype == torch.float16:
+ # See https://github.com/unslothai/unsloth/pull/1200
+ # Tesla T4 must use float32 and not float16
+ new_dtype = torch.float32
+ pass
+
+ model.get_input_embeddings().modules_to_save.default\
+ .to(device = "cuda", dtype = new_dtype, non_blocking = True)
+ model.get_input_embeddings().modules_to_save.default.requires_grad_(True)
+ pass
if train_lm_head:
print("Unsloth: Training lm_head in mixed precision to save VRAM")
- assert hasattr(model.get_output_embeddings(), "modules_to_save")
+ assert(hasattr(model.get_output_embeddings(), "modules_to_save"))
- _offload_frozen_module_for_training(
- model.get_output_embeddings(), DEVICE_TYPE_TORCH, offload_device = None
- )
+ new_dtype = model.get_output_embeddings().modules_to_save.default.weight.dtype
+ if new_dtype == torch.float16:
+ # See https://github.com/unslothai/unsloth/pull/1200
+ # Tesla T4 must use float32 and not float16
+ new_dtype = torch.float32
+ pass
+
+ model.get_output_embeddings().modules_to_save.default\
+ .to(device = "cuda", dtype = new_dtype, non_blocking = True)
+ model.get_output_embeddings().modules_to_save.default.requires_grad_(True)
+ pass
# Patch tokenizer to pad to the right
internal_model = model
while hasattr(internal_model, "model"):
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
+ pass
# Also set is_loaded_in_8bit to disable incorrect DDP
internal_model.is_loaded_in_8bit = True
internal_model = internal_model.model
+ pass
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
+ pass
# Also set is_loaded_in_8bit to disable incorrect DDP
internal_model.is_loaded_in_8bit = True
# Clear deleted GPU items
for _ in range(3):
gc.collect()
- clean_gpu_cache()
+ torch.cuda.empty_cache()
+ pass
+
+ # Patch for fast inference
+ if vllm_engine is not None:
+ model.vllm_engine = vllm_engine
+ model.fast_generate = vllm_fast_generate
+ model.fast_generate_batches = vllm_fast_generate_batches
- patch_peft_fast_inference(model)
+ # Also saving and loading LoRA
+ from unsloth_zoo.vllm_utils import save_lora, load_lora
+ model.save_lora = functools.partial(save_lora, model)
+ model.load_lora = functools.partial(load_lora, model)
+ pass
# Add for_inference and for_training
- model.for_training = functools.partial(FastLlamaModel.for_training, model)
+ model.for_training = functools.partial(FastLlamaModel.for_training, model)
model.for_inference = functools.partial(FastLlamaModel.for_inference, model)
- m = model
- while hasattr(m, "model"):
- m.for_training = functools.partial(FastBaseModel.for_training, m)
- m.for_inference = functools.partial(FastBaseModel.for_inference, m)
- m = m.model
+
+ # Patch generate
+ if model.generate.__name__ != "unsloth_fast_generate":
+ model._old_generate = model.generate
+ unsloth_fast_generate.__doc__ = model._old_generate.__doc__
+ model.generate = types.MethodType(unsloth_fast_generate, model)
return model
+ pass
+
@staticmethod
def patch_peft_model(
model,
- use_gradient_checkpointing = "unsloth",
+ use_gradient_checkpointing = True,
):
if os.environ.get("UNSLOTH_USE_NEW_MODEL", "0") == "1":
return FastBaseModel.patch_peft_model(
model = model,
use_gradient_checkpointing = use_gradient_checkpointing,
)
- if not isinstance(model, PeftModelForCausalLM) and not isinstance(
- model, PeftModelForSequenceClassification
- ):
+ pass
+ if not isinstance(model, PeftModelForCausalLM):
raise TypeError(
"Unsloth: Your model needs to call `.get_peft_model` first!"
)
+ pass
# Get activation function
model_type = model.config.model_type
- if model_type == "llama":
- apply_lora_mlp = apply_lora_mlp_swiglu
- elif model_type == "mistral":
- apply_lora_mlp = apply_lora_mlp_swiglu
- elif model_type == "qwen2":
- apply_lora_mlp = apply_lora_mlp_swiglu
- elif model_type == "gemma":
- apply_lora_mlp = apply_lora_mlp_geglu_approx
- elif model_type == "gemma2":
- apply_lora_mlp = apply_lora_mlp_geglu_approx
- elif model_type == "cohere":
- apply_lora_mlp = apply_lora_mlp_swiglu
- elif model_type == "granite":
- apply_lora_mlp = apply_lora_mlp_swiglu
- elif model_type == "qwen3":
- apply_lora_mlp = apply_lora_mlp_swiglu
- elif model_type == "falcon_h1":
- apply_lora_mlp = apply_lora_mlp_swiglu
- elif model_type == "qwen3moe":
- apply_lora_mlp = apply_lora_mlp_swiglu
+ if model_type == "llama": apply_lora_mlp = apply_lora_mlp_swiglu
+ elif model_type == "mistral": apply_lora_mlp = apply_lora_mlp_swiglu
+ elif model_type == "qwen2": apply_lora_mlp = apply_lora_mlp_swiglu
+ elif model_type == "gemma": apply_lora_mlp = apply_lora_mlp_geglu_approx
+ elif model_type == "gemma2": apply_lora_mlp = apply_lora_mlp_geglu_approx
+ elif model_type == "cohere": apply_lora_mlp = apply_lora_mlp_swiglu
+ elif model_type == "granite": apply_lora_mlp = apply_lora_mlp_swiglu
else:
raise NotImplementedError(f"Unsloth: {model_type} is not yet implemented!")
+ pass
model = prepare_model_for_kbit_training(
model,
@@ -3282,147 +2507,123 @@ def patch_peft_model(
if False:
name = model.peft_config[active_adapter].base_model_name_or_path
if name.startswith("unsloth/") and name.endswith("-bnb-4bit"):
- name = name[: len(name) - len("-bnb-4bit")]
+ name = name[:len(name) - len("-bnb-4bit")]
model.peft_config[active_adapter].base_model_name_or_path = name
pass
# Add revision to enable future fast inference paths
# [TODO] Bugs out!see https://github.com/unslothai/unsloth/issues/492
# model.peft_config[active_adapter].revision = f"unsloth"
+ pass
- from transformers.trainer import Trainer
-
+ from transformers.trainer import Trainer
if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop":
- raise RuntimeError(
- "Unsloth: Unsuccessfully patched Trainer! Please file a bug report!"
- )
+ raise RuntimeError("Unsloth: Unsuccessfully patched Trainer! Please file a bug report!")
+ pass
# Fix loftq issues
# loftq_config must not = None, but rather {}
all_configs = model.peft_config
for key, current_config in all_configs.items():
- if (
- hasattr(current_config, "loftq_config")
- and current_config.loftq_config is None
- ):
+ if hasattr(current_config, "loftq_config") and current_config.loftq_config is None:
new_args = current_config.__dict__
new_args["loftq_config"] = {}
current_config = current_config.__class__(**new_args)
all_configs[key] = current_config
+ pass
+ pass
# Do patching
n_mlp = 0
n_qkv = 0
- n_o = 0
+ n_o = 0
- active_adapter = (
- model.active_adapters[0]
- if hasattr(model, "active_adapters")
- else model.active_adapter
- )
+ active_adapter = model.active_adapters[0] if \
+ hasattr(model, "active_adapters") else model.active_adapter
# Get dropout and bias
lora_dropout = model.peft_config[active_adapter].lora_dropout
- bias = model.peft_config[active_adapter].bias
+ bias = model.peft_config[active_adapter].bias
# We also do not inplace edit QKV for Cohere!
- _apply_lora_mlp = (
- functools.partial(apply_lora_mlp, inplace = False)
- if model_type == "cohere"
- else apply_lora_mlp
- )
+ _apply_lora_mlp = \
+ functools.partial(apply_lora_mlp, inplace = False) \
+ if model_type == "cohere" else \
+ apply_lora_mlp
+ pass
if lora_dropout == 0 and bias == "none":
for idx, layer in enumerate(model.model.model.layers):
- if model_type != "falcon_h1":
- # LoRAMLP.apply doesn't have functionality for gate and down multipliers yet.
- # Don't patch falcon h1 for the time being.
-
- # MLP patching
- mlp_module = layer.mlp
- gate_proj = mlp_module.gate_proj
- up_proj = mlp_module.up_proj
- down_proj = mlp_module.down_proj
-
- if (
- hasattr(gate_proj, "lora_A")
- and hasattr(up_proj, "lora_A")
- and hasattr(down_proj, "lora_A")
- and (getattr(gate_proj, "base_layer", gate_proj).bias is None)
- and (getattr(up_proj, "base_layer", up_proj).bias is None)
- and (getattr(down_proj, "base_layer", down_proj).bias is None)
- and (
- len(getattr(gate_proj, "lora_magnitude_vector", []) or [])
- == 0
- )
- and (
- len(getattr(up_proj, "lora_magnitude_vector", []) or [])
- == 0
- )
- and (
- len(getattr(down_proj, "lora_magnitude_vector", []) or [])
- == 0
- )
- ):
- # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module
- if hasattr(mlp_module, "_unsloth_forward"):
- # then we've patched the mlp to use TiledMLP
- mlp_module._unsloth_forward = types.MethodType(
- _apply_lora_mlp, mlp_module
- )
- else:
- mlp_module.forward = types.MethodType(
- _apply_lora_mlp, mlp_module
- )
- n_mlp += 1
- else:
- logger.warning_once(
- "Not an error, but Unsloth cannot patch MLP layers with our manual autograd engine since either LoRA adapters\n"
- "are not enabled or a bias term (like in Qwen) is used."
- )
+
+ # MLP patching
+ gate_proj = layer.mlp.gate_proj
+ up_proj = layer.mlp. up_proj
+ down_proj = layer.mlp.down_proj
+
+ if hasattr(gate_proj, "lora_A") and \
+ hasattr( up_proj, "lora_A") and \
+ hasattr(down_proj, "lora_A") and \
+ (getattr(gate_proj, "base_layer", gate_proj).bias is None) and \
+ (getattr( up_proj, "base_layer", up_proj).bias is None) and \
+ (getattr(down_proj, "base_layer", down_proj).bias is None) and \
+ (len(getattr(gate_proj, "lora_magnitude_vector", []) or []) == 0) and \
+ (len(getattr( up_proj, "lora_magnitude_vector", []) or []) == 0) and \
+ (len(getattr(down_proj, "lora_magnitude_vector", []) or []) == 0):
+
+ # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module
+ layer.mlp.forward = types.MethodType(_apply_lora_mlp, layer.mlp)
+ n_mlp += 1
+ else:
+ logger.warning_once(
+ "Not an error, but Unsloth cannot patch MLP layers with our manual autograd engine since either LoRA adapters\n"\
+ "are not enabled or a bias term (like in Qwen) is used."
+ )
+ pass
# QKV attention patching
q_proj = layer.self_attn.q_proj
k_proj = layer.self_attn.k_proj
v_proj = layer.self_attn.v_proj
- if (
- hasattr(q_proj, "lora_A")
- and hasattr(k_proj, "lora_A")
- and hasattr(v_proj, "lora_A")
- and (getattr(q_proj, "base_layer", q_proj).bias is None)
- and (getattr(k_proj, "base_layer", k_proj).bias is None)
- and (getattr(v_proj, "base_layer", v_proj).bias is None)
- and (len(getattr(q_proj, "lora_magnitude_vector", []) or []) == 0)
- and (len(getattr(k_proj, "lora_magnitude_vector", []) or []) == 0)
- and (len(getattr(v_proj, "lora_magnitude_vector", []) or []) == 0)
- ):
+ if hasattr(q_proj, "lora_A") and \
+ hasattr(k_proj, "lora_A") and \
+ hasattr(v_proj, "lora_A") and \
+ (getattr(q_proj, "base_layer", q_proj).bias is None) and \
+ (getattr(k_proj, "base_layer", k_proj).bias is None) and \
+ (getattr(v_proj, "base_layer", v_proj).bias is None) and \
+ (len(getattr(q_proj, "lora_magnitude_vector", []) or []) == 0) and \
+ (len(getattr(k_proj, "lora_magnitude_vector", []) or []) == 0) and \
+ (len(getattr(v_proj, "lora_magnitude_vector", []) or []) == 0):
+
layer.self_attn.apply_qkv = apply_lora_qkv
n_qkv += 1
else:
- if model_type == "qwen2":
- n_qkv += 1
+ if model_type == "qwen2": n_qkv += 1
else:
logger.warning_once(
- "Not an error, but Unsloth cannot patch Attention layers with our manual autograd engine since either LoRA adapters\n"
+ "Not an error, but Unsloth cannot patch Attention layers with our manual autograd engine since either LoRA adapters\n"\
"are not enabled or a bias term (like in Qwen) is used."
)
+ pass
+ pass
# O attention patching
o_proj = layer.self_attn.o_proj
- if (
- hasattr(o_proj, "lora_A")
- and (getattr(o_proj, "base_layer", o_proj).bias is None)
- and (len(getattr(o_proj, "lora_magnitude_vector", []) or []) == 0)
- ):
+ if hasattr(o_proj, "lora_A") and \
+ (getattr(o_proj, "base_layer", o_proj).bias is None) and \
+ (len(getattr(o_proj, "lora_magnitude_vector", []) or []) == 0):
+
layer.self_attn.apply_o = apply_lora_o
n_o += 1
else:
logger.warning_once(
- "Not an error, but Unsloth cannot patch O projection layer with our manual autograd engine since either LoRA adapters\n"
+ "Not an error, but Unsloth cannot patch O projection layer with our manual autograd engine since either LoRA adapters\n"\
"are not enabled or a bias term (like in Qwen) is used."
)
+ pass
+ pass
+ pass
logger.warning_once(
- f"Unsloth {__version__} patched {len(model.model.model.layers)} layers with "
+ f"Unsloth {__version__} patched {len(model.model.model.layers)} layers with "\
f"{n_qkv} QKV layers, {n_o} O layers and {n_mlp} MLP layers.",
)
patch_saving_functions(model)
@@ -3436,126 +2637,116 @@ def patch_peft_model(
while hasattr(internal_model, "model"):
internal_model.max_seq_length = max_seq_length
internal_model = internal_model.model
- internal_model.max_seq_length = max_seq_length
- # Save to modules as well
- for module in model.modules():
- module.max_seq_length = max_seq_length
+ pass
+ internal_model.max_seq_length = max_seq_length
# Patch tokenizer to pad to the right
internal_model = model
while hasattr(internal_model, "model"):
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
+ pass
internal_model = internal_model.model
+ pass
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
+ pass
# Clear deleted GPU items
for _ in range(3):
gc.collect()
- clean_gpu_cache()
+ torch.cuda.empty_cache()
+ pass
- patch_peft_fast_inference(model)
+ # Patch for fast inference
+ vllm_engine = getattr(model.model, "vllm_engine", None)
+ if vllm_engine is not None:
+ model.vllm_engine = model.model.vllm_engine
+ model.fast_generate = model.model.fast_generate
+ model.fast_generate_batches = model.model.fast_generate_batches
+
+ # Also saving and loading LoRA
+ from unsloth_zoo.vllm_utils import save_lora, load_lora
+ model.save_lora = functools.partial(save_lora, model)
+ model.load_lora = functools.partial(load_lora, model)
+ pass
# Add for_inference and for_training
- model.for_training = functools.partial(FastLlamaModel.for_training, model)
+ model.for_training = functools.partial(FastLlamaModel.for_training, model)
model.for_inference = functools.partial(FastLlamaModel.for_inference, model)
- m = model
- while hasattr(m, "model"):
- m.for_training = functools.partial(FastBaseModel.for_training, m)
- m.for_inference = functools.partial(FastBaseModel.for_inference, m)
- m = m.model
return model
+ pass
+
@staticmethod
def for_inference(model):
if not hasattr(model, "parameters"):
- raise TypeError(
- "Unsloth: I think you're passing a tokenizer, not the model to for_inference!"
- )
+ raise TypeError("Unsloth: I think you're passing a tokenizer, not the model to for_inference!")
def _for_inference(m):
- if hasattr(m, "gradient_checkpointing"):
- m.gradient_checkpointing = False
- if hasattr(m, "training"):
- m.training = False
+ if hasattr(m, "gradient_checkpointing"): m.gradient_checkpointing = False
+ if hasattr(m, "training"): m.training = False
# Pad tokenizer to the left
- if hasattr(m, "_saved_temp_tokenizer"):
- m._saved_temp_tokenizer.padding_side = "left"
+ if hasattr(m, "_saved_temp_tokenizer"): m._saved_temp_tokenizer.padding_side = "left"
# Set a flag for generation!
m._flag_for_generation = True
-
+ pass
m = model
while hasattr(m, "model"):
_for_inference(m)
m = m.model
_for_inference(m)
- model.eval() # to turn off training on modules deeper in
-
- # Since transformers 4.53, must turn off explicitly
- for module in model.modules():
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = False
# Also disable training for embeddings for NEFTune
if hasattr(model, "get_input_embeddings"):
embeddings = model.get_input_embeddings()
- if hasattr(embeddings, "training"):
- embeddings.training = False
+ if hasattr(embeddings, "training"): embeddings.training = False
+ pass
if hasattr(model, "get_output_embeddings"):
embeddings = model.get_output_embeddings()
- if hasattr(embeddings, "training"):
- embeddings.training = False
+ if hasattr(embeddings, "training"): embeddings.training = False
+ pass
return model
+ pass
+
@staticmethod
def for_training(model, use_gradient_checkpointing = True):
if not hasattr(model, "parameters"):
- raise TypeError(
- "Unsloth: I think you're passing a tokenizer, not the model to for_training!"
- )
+ raise TypeError("Unsloth: I think you're passing a tokenizer, not the model to for_training!")
# Delete all fast inference loras
for param in model.parameters():
if hasattr(param, "_fast_lora"):
del param._fast_lora
+ pass
def _for_training(m):
- if hasattr(m, "gradient_checkpointing"):
- m.gradient_checkpointing = use_gradient_checkpointing
- if hasattr(m, "training"):
- m.training = True
+ if hasattr(m, "gradient_checkpointing"): m.gradient_checkpointing = use_gradient_checkpointing
+ if hasattr(m, "training"): m.training = True
# Pad tokenizer to the left
- if hasattr(m, "_saved_temp_tokenizer"):
- m._saved_temp_tokenizer.padding_side = "right"
+ if hasattr(m, "_saved_temp_tokenizer"): m._saved_temp_tokenizer.padding_side = "right"
# Set a flag for generation!
- if hasattr(m, "_flag_for_generation"):
- del m._flag_for_generation
-
+ if hasattr(m, "_flag_for_generation"): del m._flag_for_generation
+ pass
m = model
while hasattr(m, "model"):
_for_training(m)
m = m.model
_for_training(m)
- model.train() # to turn on training on modules deeper in
-
- # Since transformers 4.53, must turn on explicitly
- for module in model.modules():
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = use_gradient_checkpointing
# Also re-enable training for embeddings for NEFTune
if hasattr(model, "get_input_embeddings"):
embeddings = model.get_input_embeddings()
- if hasattr(embeddings, "training"):
- embeddings.training = True
+ if hasattr(embeddings, "training"): embeddings.training = True
+ pass
if hasattr(model, "get_output_embeddings"):
embeddings = model.get_output_embeddings()
- if hasattr(embeddings, "training"):
- embeddings.training = True
+ if hasattr(embeddings, "training"): embeddings.training = True
+ pass
return model
-
+ pass
+pass
from .rl import PatchFastRL
-
PatchFastRL(FastLanguageModel = FastLlamaModel)
diff --git a/unsloth/models/llama4.py b/unsloth/models/llama4.py
deleted file mode 100644
index f429922268..0000000000
--- a/unsloth/models/llama4.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# from unsloth_studio.models import patch_llama4
-# patch_llama4()
diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py
index bd15ed5281..44475780af 100644
--- a/unsloth/models/loader.py
+++ b/unsloth/models/loader.py
@@ -13,34 +13,21 @@
# limitations under the License.
from ._utils import (
- _prepare_model_for_qat,
is_bfloat16_supported,
- is_vLLM_available,
HAS_FLASH_ATTENTION,
HAS_FLASH_ATTENTION_SOFTCAPPING,
USE_MODELSCOPE,
- get_transformers_model_type,
- hf_login,
)
from .granite import FastGraniteModel
-from .llama import FastLlamaModel, logger
+from .llama import FastLlamaModel, logger
from .mistral import FastMistralModel
-from .qwen2 import FastQwen2Model
-from .qwen3 import FastQwen3Model
-from .qwen3_moe import FastQwen3MoeModel
-from .cohere import FastCohereModel
+from .qwen2 import FastQwen2Model
+from .cohere import FastCohereModel
from transformers import AutoConfig
from transformers import __version__ as transformers_version
from peft import PeftConfig, PeftModel
-from .loader_utils import (
- _get_fp8_mode_and_check_settings,
- _offline_quantize_to_fp8,
- _tag_model_with_fp8_torchao_config,
- get_model_name,
- prepare_device_map,
-)
+from .loader_utils import get_model_name
import os, contextlib, sys
-
try:
from huggingface_hub import get_token
except:
@@ -49,386 +36,117 @@
except:
# For older versions of huggingface_hub
from huggingface_hub.utils._token import get_token
+ pass
+pass
from huggingface_hub import HfFileSystem
import importlib.util
-from ..device_type import (
- is_hip,
- get_device_type,
- DEVICE_TYPE,
- DEVICE_TYPE_TORCH,
- DEVICE_COUNT,
- ALLOW_PREQUANTIZED_MODELS,
- ALLOW_BITSANDBYTES,
-)
# https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
from unsloth_zoo.utils import Version, _get_dtype
-from unsloth_zoo.hf_utils import dtype_from_config
-from unsloth_zoo.tiled_mlp import patch_tiled_mlp
-
transformers_version = Version(transformers_version)
SUPPORTS_FOURBIT = transformers_version >= Version("4.37")
-SUPPORTS_GEMMA = transformers_version >= Version("4.38")
-SUPPORTS_GEMMA2 = transformers_version >= Version("4.42")
+SUPPORTS_GEMMA = transformers_version >= Version("4.38")
+SUPPORTS_GEMMA2 = transformers_version >= Version("4.42")
SUPPORTS_LLAMA31 = transformers_version >= Version("4.43.2")
-SUPPORTS_LLAMA32 = transformers_version > Version("4.45.0")
+SUPPORTS_LLAMA32 = transformers_version > Version("4.45.0")
SUPPORTS_GRANITE = transformers_version >= Version("4.46.0")
-SUPPORTS_QWEN3 = transformers_version >= Version("4.50.3")
-SUPPORTS_QWEN3_MOE = transformers_version >= Version("4.50.3")
-SUPPORTS_FALCON_H1 = transformers_version >= Version("4.53.0")
-SUPPORTS_GEMMA3N = transformers_version >= Version("4.53.0")
-SUPPORTS_GPTOSS = transformers_version >= Version("4.55.0")
-# Transformers v5 meta-device loading corrupts non-persistent buffers (inv_freq).
-# See _fix_rope_inv_freq() below for details.
-_NEEDS_ROPE_FIX = transformers_version >= Version("5.0.0")
if SUPPORTS_GEMMA:
- from .gemma import FastGemmaModel
+ from .gemma import FastGemmaModel
if SUPPORTS_GEMMA2:
from .gemma2 import FastGemma2Model
-if SUPPORTS_FALCON_H1:
- from .falcon_h1 import FastFalconH1Model
+pass
import torch
from ._utils import (
patch_compiling_bitsandbytes,
patch_model_and_tokenizer,
prepare_model_for_kbit_training,
- apply_unsloth_gradient_checkpointing,
+ patch_unsloth_smart_gradient_checkpointing,
patch_compiled_autograd,
process_vision_info,
unsloth_compile_transformers,
- fast_inference_setup,
)
-global FORCE_FLOAT32
-# Forces float32 precision since float16 goes to infinity
-FORCE_FLOAT32 = [
- "gemma3,", # Add comma bc gemma3 will match gemma3n
- "gemma3text", # Gemma3TextModel (EmbeddingGemma, standalone text-only Gemma3)
- "gemma3n",
- "gpt_oss",
- "qwen3_5", # Qwen3.5 RMSNorm uses (1+w) pattern like Gemma3, overflows float16
-]
-
-global DISABLE_COMPILE_MODEL_NAMES
-# Must be alphabetically sorted for each entry
-DISABLE_COMPILE_MODEL_NAMES = [
- "aya_vision",
- "modernbert",
- "granite,llava_next", # Granite-vision 3
-]
-
-global DISABLE_SDPA_MODEL_NAMES
-# Disables some SDPA modules since it's wrong
-DISABLE_SDPA_MODEL_NAMES = [
- "gemma3,", # Add comma bc gemma3 will match gemma3n
- "gemma3_text", # Gemma3TextModel (EmbeddingGemma) - substring match, keep underscore
-]
-
-
-def _fix_rope_inv_freq(model):
- """Fix inv_freq corruption caused by transformers v5 meta-device loading.
-
- Transformers v5 initializes models on the meta device, then
- _move_missing_keys_from_meta_to_device() (modeling_utils.py) replaces ALL
- non-persistent buffers with torch.empty_like() -- uninitialized memory.
-
- Vanilla transformers restores inv_freq via _init_weights() which checks for
- hasattr(module, "original_inv_freq"). Unsloth's LlamaRotaryEmbedding and
- subclasses do not have this attribute, so inv_freq stays corrupted. This
- produces wrong positional encodings and causes 5-11x higher training loss.
-
- This function recomputes inv_freq from the stored base and dim, applies
- any model-specific scaling, and rebuilds the cos/sin caches.
-
- Only runs on transformers >= 5.0.0. No-op on v4.
- """
- if not _NEEDS_ROPE_FIX:
- return model
-
- for name, module in model.named_modules():
- # Unsloth's LlamaRotaryEmbedding and subclasses (Extended, LinearScaling,
- # Granite). Native v5 rotary classes (Gemma3, etc.) have original_inv_freq
- # which v5's _init_weights() uses to restore inv_freq, so they are fine.
- if (
- hasattr(module, "inv_freq")
- and hasattr(module, "base")
- and hasattr(module, "dim")
- and hasattr(module, "_apply_inv_freq_scaling")
- and hasattr(module, "multi_gpu_cos_cached")
- ):
- inv_freq = 1.0 / (
- module.base
- ** (
- torch.arange(
- 0, module.dim, 2, dtype = torch.int64, device = "cpu"
- ).float()
- / module.dim
- )
- )
- inv_freq = module._apply_inv_freq_scaling(inv_freq)
- module.inv_freq = inv_freq
- for device_idx in range(len(module.multi_gpu_cos_cached)):
- if module.multi_gpu_cos_cached[device_idx] is not None:
- module._set_cos_sin_cache(
- seq_len = module.current_rope_size,
- device = torch.device(device_idx),
- dtype = torch.get_default_dtype(),
- )
-
- # LongRopeRotaryEmbedding (Phi-3.5 style with short_inv_freq + long_inv_freq)
- elif (
- hasattr(module, "short_inv_freq")
- and hasattr(module, "long_inv_freq")
- and hasattr(module, "base")
- and hasattr(module, "dim")
- ):
- config = getattr(model, "config", None)
- rope_scaling = getattr(config, "rope_scaling", None) if config else None
- if rope_scaling is not None:
- short_factor = rope_scaling.get("short_factor", None)
- long_factor = rope_scaling.get("long_factor", None)
- if short_factor is not None and long_factor is not None:
- inv_freq_shape = (
- torch.arange(
- 0, module.dim, 2, dtype = torch.int64, device = "cpu"
- ).float()
- / module.dim
- )
- sf = torch.tensor(short_factor, device = "cpu", dtype = torch.float32)
- lf = torch.tensor(long_factor, device = "cpu", dtype = torch.float32)
- module.short_inv_freq = 1.0 / (sf * module.base**inv_freq_shape)
- module.long_inv_freq = 1.0 / (lf * module.base**inv_freq_shape)
-
- dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16
- t = torch.arange(
- module.original_max_position_embeddings,
- device = module.short_inv_freq.device,
- dtype = torch.int64,
- ).float()
- freqs = torch.outer(t, module.short_inv_freq)
- emb = torch.cat((freqs, freqs), dim = -1)
- for device_idx in range(len(module.multi_gpu_short_cos_cached)):
- if module.multi_gpu_short_cos_cached[device_idx] is not None:
- device_obj = torch.device(device_idx)
- module.multi_gpu_short_cos_cached[device_idx] = (
- emb.cos() * module.scaling_factor
- ).to(dtype = dtype, device = device_obj, non_blocking = True)
- module.multi_gpu_short_sin_cached[device_idx] = (
- emb.sin() * module.scaling_factor
- ).to(dtype = dtype, device = device_obj, non_blocking = True)
- return model
-
-
class FastLanguageModel(FastLlamaModel):
@staticmethod
def from_pretrained(
- model_name = "unsloth/Llama-3.2-1B-Instruct",
- max_seq_length = 2048,
- dtype = None,
- load_in_4bit = True, # 4bit QLoRA
- load_in_8bit = False, # 8bit LoRA
- load_in_16bit = False, # 16bit LoRA
- full_finetuning = False,
- token = None,
- device_map = "sequential",
- rope_scaling = None,
- fix_tokenizer = True,
- trust_remote_code = False,
+ model_name = "unsloth/Llama-3.2-1B-Instruct",
+ max_seq_length = 2048,
+ dtype = None,
+ load_in_4bit = True,
+ load_in_8bit = False,
+ full_finetuning = False,
+ token = None,
+ device_map = "sequential",
+ rope_scaling = None,
+ fix_tokenizer = True,
+ trust_remote_code = False,
use_gradient_checkpointing = "unsloth",
- resize_model_vocab = None,
- revision = None,
- use_exact_model_name = False,
- offload_embedding = False,
- float32_mixed_precision = None, # Forces float32 mixed precision
- fast_inference = False, # uses vLLM
- gpu_memory_utilization = 0.5,
- float8_kv_cache = False,
- random_state = 3407,
- max_lora_rank = 64,
- disable_log_stats = True,
- qat_scheme = None,
- load_in_fp8 = False, # fp8 LoRA (True, False, 'block')
- unsloth_tiled_mlp = False,
- *args,
- **kwargs,
+ resize_model_vocab = None,
+ revision = None,
+ use_exact_model_name = False,
+
+ fast_inference = False, # uses vLLM
+ gpu_memory_utilization = 0.5,
+ float8_kv_cache = False,
+ random_state = 3407,
+ max_lora_rank = 64,
+ disable_log_stats = True,
+ *args, **kwargs,
):
- # Respect user-provided quantization_config (e.g. BitsAndBytesConfig)
- quantization_config = kwargs.get("quantization_config", None)
- if quantization_config is not None:
- if isinstance(quantization_config, dict):
- q_load_in_4bit = quantization_config.get("load_in_4bit", False)
- q_load_in_8bit = quantization_config.get("load_in_8bit", False)
- else:
- q_load_in_4bit = getattr(quantization_config, "load_in_4bit", False)
- q_load_in_8bit = getattr(quantization_config, "load_in_8bit", False)
- if q_load_in_4bit:
- load_in_4bit = True
- load_in_8bit = False
- if q_load_in_8bit:
- load_in_8bit = True
- load_in_4bit = False
-
- # Login to allow private models
- token = hf_login(token)
- # Align dtype with bnb_4bit_compute_dtype if provided and dtype is unset.
- if dtype is None and quantization_config is not None:
- bnb_compute_dtype = None
- if isinstance(quantization_config, dict):
- if quantization_config.get("load_in_4bit", False):
- bnb_compute_dtype = quantization_config.get(
- "bnb_4bit_compute_dtype", None
- )
- else:
- if getattr(quantization_config, "load_in_4bit", False):
- bnb_compute_dtype = getattr(
- quantization_config, "bnb_4bit_compute_dtype", None
- )
- if isinstance(bnb_compute_dtype, str):
- bnb_compute_dtype = getattr(torch, bnb_compute_dtype, None)
- if isinstance(bnb_compute_dtype, torch.dtype):
- dtype = bnb_compute_dtype
-
- # Distributed-safe device placement for quantized models.
- # In multi-GPU (torchrun), each rank must load the model on its own device
- # to avoid Accelerate device relocation errors with quantized weights.
- is_quantized = load_in_4bit or load_in_8bit or load_in_fp8
- if is_quantized and isinstance(device_map, str):
- distributed_device_map, is_dist = prepare_device_map()
- if is_dist:
- device_map = distributed_device_map
-
- if load_in_8bit or full_finetuning or qat_scheme is not None:
+ if load_in_8bit or full_finetuning:
return FastModel.from_pretrained(
- model_name = model_name,
- max_seq_length = max_seq_length,
- dtype = dtype,
- load_in_4bit = load_in_4bit,
- load_in_8bit = load_in_8bit,
- load_in_16bit = load_in_16bit,
- full_finetuning = full_finetuning,
- token = token,
- device_map = device_map,
- rope_scaling = rope_scaling, # [TODO] No effect
- fix_tokenizer = fix_tokenizer, # [TODO] No effect
- trust_remote_code = trust_remote_code,
+ model_name = model_name,
+ max_seq_length = max_seq_length,
+ dtype = dtype,
+ load_in_4bit = load_in_4bit,
+ load_in_8bit = load_in_8bit,
+ full_finetuning = full_finetuning,
+ token = token,
+ device_map = device_map,
+ rope_scaling = rope_scaling, # [TODO] No effect
+ fix_tokenizer = fix_tokenizer, # [TODO] No effect
+ trust_remote_code = trust_remote_code,
use_gradient_checkpointing = use_gradient_checkpointing,
- resize_model_vocab = resize_model_vocab, # [TODO] No effect
- revision = revision,
- return_logits = False, # Return logits
- fullgraph = True, # No graph breaks
- use_exact_model_name = use_exact_model_name,
- offload_embedding = offload_embedding,
- float32_mixed_precision = float32_mixed_precision,
- # Pass vLLM/inference parameters
- fast_inference = fast_inference,
- gpu_memory_utilization = gpu_memory_utilization,
- float8_kv_cache = float8_kv_cache,
- random_state = random_state,
- max_lora_rank = max_lora_rank,
- disable_log_stats = disable_log_stats,
- qat_scheme = qat_scheme,
- load_in_fp8 = load_in_fp8,
- unsloth_tiled_mlp = unsloth_tiled_mlp,
- *args,
- **kwargs,
+ resize_model_vocab = resize_model_vocab, # [TODO] No effect
+ revision = revision,
+ return_logits = False, # Return logits
+ fullgraph = True, # No graph breaks
+ use_exact_model_name = use_exact_model_name,
+ *args, **kwargs,
)
+ pass
- if isinstance(dtype, str) and dtype in ["float16", "bfloat16"]:
- dtype = getattr(torch, dtype)
- assert (
- dtype is None
- or dtype == torch.float16
- or dtype == torch.bfloat16
- or dtype == torch.float32
- )
+ if token is None: token = get_token()
+ assert (dtype is None or dtype == torch.float16 or dtype == torch.bfloat16)
+
+ if use_gradient_checkpointing == "unsloth":
+ patch_unsloth_smart_gradient_checkpointing(dtype = dtype)
if fast_inference:
if importlib.util.find_spec("vllm") is None:
raise ImportError(
- "Unsloth: Please install vLLM before enabling `fast_inference`!\n"
+ "Unsloth: Please install vLLM before enabling `fast_inference`!\n"\
"You can do this in a terminal via `pip install vllm`"
)
- if DEVICE_TYPE_TORCH == "cuda":
- for i in range(DEVICE_COUNT):
- # [TODO] DGX Spark vLLM breaks
- if "NVIDIA GB10" in str(torch.cuda.get_device_name(i)).upper():
- print(
- "Unsloth: DGX Spark detected - `fast_inference=True` is currently broken as of January 2026.\n"
- "Defaulting to native Unsloth inference."
- )
- fast_inference = False
- break
-
- # Check if 4bit is allowed specifically for AMD
- if not ALLOW_BITSANDBYTES and not use_exact_model_name:
- if load_in_4bit or load_in_8bit or model_name.lower().endswith("-bnb-4bit"):
- print(
- "Unsloth: AMD currently is not stable with 4bit bitsandbytes. Disabling for now."
- )
- load_in_4bit = False
-
- # Find FP8, BnB 4bit, other mapped names
+ pass
+ pass
+
old_model_name = model_name
- fp8_mode = None
if not use_exact_model_name:
- new_model_name = get_model_name(
- model_name,
- load_in_4bit = load_in_4bit,
- load_in_fp8 = load_in_fp8,
- token = token,
- trust_remote_code = trust_remote_code,
- )
- if new_model_name is None and load_in_fp8 != False:
- fp8_mode = _get_fp8_mode_and_check_settings(
- load_in_fp8,
- fast_inference,
- full_finetuning,
- load_in_4bit,
- load_in_8bit,
- load_in_16bit,
- )
- model_name = _offline_quantize_to_fp8(model_name, fp8_mode)
- else:
- assert new_model_name is not None
- model_name = new_model_name
- # If mapper resolved to a pre-quantized FP8 model, disable
- # on-the-fly quantization to avoid double quantization
- if load_in_fp8 != False and new_model_name != old_model_name:
- load_in_fp8 = False
-
- # Check if pre-quantized models are allowed
- # AMD Instinct GPUs need blocksize = 128 on bitsandbytes < 0.49.2 (our pre-quants use blocksize = 64)
- if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
- ("-unsloth-bnb-4bit", "-bnb-4bit")
- ):
- model_name = model_name.lower().removesuffix("-unsloth-bnb-4bit")
- model_name = model_name.lower().removesuffix("-bnb-4bit")
- # Change -BF16 to all False for 4bit, 8bit etc
- if model_name.lower().endswith("-bf16"):
- load_in_4bit = False
- load_in_8bit = False
- load_in_fp8 = False
- load_in_16bit = True
+ model_name = get_model_name(model_name, load_in_4bit)
if USE_MODELSCOPE and not os.path.exists(model_name):
from modelscope import snapshot_download
-
model_name = snapshot_download(model_name)
+ pass
# First check if it's a normal model via AutoConfig
- from huggingface_hub.utils import (
- disable_progress_bars,
- enable_progress_bars,
- are_progress_bars_disabled,
- )
-
+ from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
was_disabled = are_progress_bars_disabled()
disable_progress_bars()
autoconfig_error = None
peft_error = None
- model_config = None
- peft_config = None
try:
model_config = AutoConfig.from_pretrained(
model_name,
@@ -437,15 +155,8 @@ def from_pretrained(
trust_remote_code = trust_remote_code,
)
is_model = True
- except ImportError:
- raise
except Exception as error:
autoconfig_error = str(error)
- if "architecture" in autoconfig_error:
- raise ValueError(
- f"`{model_name}` is not supported yet in `transformers=={transformers_version}`.\n"
- f"Please update transformers via `pip install --upgrade transformers` and try again."
- )
is_model = False
try:
peft_config = PeftConfig.from_pretrained(
@@ -455,186 +166,126 @@ def from_pretrained(
trust_remote_code = trust_remote_code,
)
is_peft = True
- except ImportError:
- raise
except Exception as error:
peft_error = str(error)
- if "architecture" in peft_error:
- raise ValueError(
- f"`{model_name}` is not supported yet in `transformers=={transformers_version}`.\n"
- f"Please update transformers via `pip install --upgrade transformers` and try again."
- )
is_peft = False
+ pass
+
+ # Both config.json and adapter_config.json should not exist!
# Old transformers versions check
both_exist = (is_model and is_peft) and not SUPPORTS_LLAMA32
- # Error out if both LoRA and normal model config exists.
- if both_exist:
- raise RuntimeError(
- "Unsloth: Your repo has a LoRA adapter and a base model.\n"
- "You have 2 files `config.json` and `adapter_config.json`.\n"
- "We must only allow one config file.\n"
- "Please separate the LoRA and base models to 2 repos."
- )
- model_types = get_transformers_model_type(
- peft_config if peft_config is not None else model_config,
- trust_remote_code = trust_remote_code,
- )
- if len(model_types) == 1:
- model_type = model_types[0]
- else:
- # Leave as tuple if more than one arch
- model_type = model_types
-
# New transformers need to check manually.
if SUPPORTS_LLAMA32:
# Check if folder exists locally
if os.path.isdir(model_name):
- exist_adapter_config = os.path.exists(
- os.path.join(model_name, "adapter_config.json")
- )
- exist_config = os.path.exists(os.path.join(model_name, "config.json"))
+ exist_adapter_config = os.path.exists(os.path.join(model_name, "adapter_config.json"))
+ exist_config = os.path.exists(os.path.join(model_name, "config.json"))
both_exist = exist_adapter_config and exist_config
else:
# Because HfFileSystem assumes linux paths, we need to set the path with forward slashes, even on Windows.
files = HfFileSystem(token = token).glob(f"{model_name}/*.json")
- files = list(os.path.split(x)[-1] for x in files)
- if (
- sum(x == "adapter_config.json" or x == "config.json" for x in files)
- >= 2
- ):
+ files = (os.path.split(x)[-1] for x in files)
+ if sum(x == "adapter_config.json" or x == "config.json" for x in files) >= 2:
both_exist = True
+ pass
+ pass
+ pass
+
+ # Error out if both LoRA and normal model config exists.
+ if both_exist:
+ raise RuntimeError(
+ "Unsloth: Your repo has a LoRA adapter and a base model.\n"\
+ "You have 2 files `config.json` and `adapter_config.json`.\n"\
+ "We must only allow one config file.\n"\
+ "Please separate the LoRA and base models to 2 repos."
+ )
- if not is_model and not is_peft:
- error = autoconfig_error if autoconfig_error is not None else peft_error
+ elif not is_model and not is_peft:
+ error = autoconfig_error or peft_error
# Old transformers version
if "rope_scaling" in error.lower() and not SUPPORTS_LLAMA31:
raise ImportError(
- f"Unsloth: Your transformers version of {transformers_version} does not support new RoPE scaling methods.\n"
- f"This includes Llama 3.1. The minimum required version is 4.43.2\n"
- f'Try `pip install --upgrade "transformers>=4.43.2"`\n'
- f"to obtain the latest transformers build, then restart this session."
- )
- # Create a combined error message showing both failures
- combined_error = (
- "Unsloth: Failed to load model. Both AutoConfig and PeftConfig loading failed.\n\n"
- f"AutoConfig error: {autoconfig_error}\n\n"
- f"PeftConfig error: {peft_error}\n\n"
- )
- raise RuntimeError(combined_error)
+ f"Unsloth: Your transformers version of {transformers_version} does not support new RoPE scaling methods.\n"\
+ f"This includes Llama 3.1. The minimum required version is 4.43.2\n"\
+ f'Try `pip install --upgrade "transformers>=4.43.2"`\n'\
+ f"to obtain the latest transformers build, then restart this session."\
+ )
+ raise RuntimeError(autoconfig_error or peft_error)
+ pass
# Get base model for PEFT:
if is_peft:
# Check base model again for PEFT
model_name = peft_config.base_model_name_or_path
if not use_exact_model_name:
- model_name = get_model_name(
- model_name,
- load_in_4bit = load_in_4bit,
- load_in_fp8 = load_in_fp8,
- token = token,
- trust_remote_code = trust_remote_code,
- )
- # Check if pre-quantized models are allowed
- # AMD Instinct GPUs need blocksize = 128 on bitsandbytes < 0.49.2 (our pre-quants use blocksize = 64)
- if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
- ("-unsloth-bnb-4bit", "-bnb-4bit")
- ):
- model_name = model_name.lower().removesuffix("-unsloth-bnb-4bit")
- model_name = model_name.lower().removesuffix("-bnb-4bit")
- # Change -BF16 to all False for 4bit, 8bit etc
- if model_name.lower().endswith("-bf16"):
- load_in_4bit = False
- load_in_8bit = False
- load_in_fp8 = False
- load_in_16bit = True
-
+ model_name = get_model_name(model_name, load_in_4bit)
model_config = AutoConfig.from_pretrained(
model_name,
token = token,
trust_remote_code = trust_remote_code,
)
+ pass
- if not was_disabled:
- enable_progress_bars()
+ if not was_disabled: enable_progress_bars()
+
+ model_type = model_config.model_type
if model_type == "llama":
scaling_type = None
if getattr(model_config, "rope_scaling", None) is not None:
scaling_type1 = model_config.rope_scaling.get("type", None)
scaling_type2 = model_config.rope_scaling.get("rope_type", None)
- scaling_type = (
- scaling_type1 if scaling_type1 is not None else scaling_type2
- )
+ scaling_type = scaling_type1 if scaling_type1 is not None else scaling_type2
+ pass
if scaling_type == "llama3" and not SUPPORTS_LLAMA31:
raise ImportError(
- f"Unsloth: Your transformers version of {transformers_version} does not support Llama 3.1.\n"
- f"The minimum required version is 4.43.2\n"
- f'Try `pip install --upgrade "transformers>=4.43.2"`\n'
- f"to obtain the latest transformers build, then restart this session."
+ f"Unsloth: Your transformers version of {transformers_version} does not support Llama 3.1.\n"\
+ f"The minimum required version is 4.43.2\n"\
+ f'Try `pip install --upgrade "transformers>=4.43.2"`\n'\
+ f"to obtain the latest transformers build, then restart this session."\
)
dispatch_model = FastLlamaModel
- elif model_type == "mistral":
- dispatch_model = FastMistralModel
+ elif model_type == "mistral": dispatch_model = FastMistralModel
elif model_type == "gemma":
if not SUPPORTS_GEMMA:
raise ImportError(
- f"Unsloth: Your transformers version of {transformers_version} does not support Gemma.\n"
- f"The minimum required version is 4.38.\n"
- f'Try `pip install --upgrade "transformers>=4.38"`\n'
- f"to obtain the latest transformers build, then restart this session."
+ f"Unsloth: Your transformers version of {transformers_version} does not support Gemma.\n"\
+ f"The minimum required version is 4.38.\n"\
+ f'Try `pip install --upgrade "transformers>=4.38"`\n'\
+ f"to obtain the latest transformers build, then restart this session."\
)
dispatch_model = FastGemmaModel
elif model_type == "gemma2":
if not SUPPORTS_GEMMA2:
raise ImportError(
- f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"
- f"The minimum required version is 4.42.3.\n"
- f'Try `pip install --upgrade "transformers>=4.42.3"`\n'
- f"to obtain the latest transformers build, then restart this session."
+ f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"\
+ f"The minimum required version is 4.42.3.\n"\
+ f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\
+ f"to obtain the latest transformers build, then restart this session."\
)
# Also check for softcapping support in flash-attn which is faster!
if is_bfloat16_supported() and not HAS_FLASH_ATTENTION:
print(
- "Unsloth: If you want to finetune Gemma 2, install flash-attn to make it faster!\n"
- "To install flash-attn, do the below:\n"
+ "Unsloth: If you want to finetune Gemma 2, install flash-attn to make it faster!\n"\
+ "To install flash-attn, do the below:\n"\
'\npip install --no-deps --upgrade "flash-attn>=2.6.3"'
)
elif HAS_FLASH_ATTENTION and not HAS_FLASH_ATTENTION_SOFTCAPPING:
print(
- "Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"
- "Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"
- "To update flash-attn, do the below:\n"
+ "Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"\
+ "Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"\
+ "To update flash-attn, do the below:\n"\
'\npip install --no-deps --upgrade "flash-attn>=2.6.3"'
)
-
+
dispatch_model = FastGemma2Model
elif model_type == "qwen2":
dispatch_model = FastQwen2Model
- elif model_type == "qwen3": # or model_type == "qwen3_moe":
- if not SUPPORTS_QWEN3 or not SUPPORTS_QWEN3_MOE:
- raise ImportError(
- f"Unsloth: Your transformers version of {transformers_version} does not support Qwen3.\n"
- f"The minimum required version is 4.50.3.\n"
- f'Try `pip install --upgrade "transformers>=4.50.3"`\n'
- f"to obtain the latest transformers build, then restart this session."
- )
- dispatch_model = (
- FastQwen3Model if model_type == "qwen3" else FastQwen3MoeModel
- )
- # elif model_type == "falcon_h1":
- # dispatch_model = FastFalconH1Model
- # if not SUPPORTS_FALCON_H1:
- # raise ImportError(
- # f"Unsloth: Your transformers version of {transformers_version} does not support FalconH1.\n"\
- # f"The minimum required version is 4.50.3.\n"\
- # f'Try `pip install --upgrade "transformers>=4.50.3"`\n'\
- # f"to obtain the latest transformers build, then restart this session."\
- # )
# Temporary disable optimized Cohere until errors match
# elif model_type == "cohere":
# dispatch_model = FastCohereModel
@@ -643,133 +294,113 @@ def from_pretrained(
# dispatch_model = FastGraniteModel
else:
return FastModel.from_pretrained(
- model_name = old_model_name,
- max_seq_length = max_seq_length,
- dtype = dtype,
- load_in_4bit = load_in_4bit,
- load_in_8bit = load_in_8bit,
- load_in_16bit = load_in_16bit,
- full_finetuning = full_finetuning,
- token = token,
- device_map = device_map,
- rope_scaling = rope_scaling, # [TODO] No effect
- fix_tokenizer = fix_tokenizer, # [TODO] No effect
- trust_remote_code = trust_remote_code,
+ model_name = model_name,
+ max_seq_length = max_seq_length,
+ dtype = dtype,
+ load_in_4bit = load_in_4bit,
+ load_in_8bit = load_in_8bit,
+ full_finetuning = full_finetuning,
+ token = token,
+ device_map = device_map,
+ rope_scaling = rope_scaling, # [TODO] No effect
+ fix_tokenizer = fix_tokenizer, # [TODO] No effect
+ trust_remote_code = trust_remote_code,
use_gradient_checkpointing = use_gradient_checkpointing,
- resize_model_vocab = resize_model_vocab, # [TODO] No effect
- revision = revision,
- return_logits = False, # Return logits
- fullgraph = True, # No graph breaks
- use_exact_model_name = use_exact_model_name,
- offload_embedding = offload_embedding,
- float32_mixed_precision = float32_mixed_precision,
- # Pass vLLM/inference parameters
- fast_inference = fast_inference,
- gpu_memory_utilization = gpu_memory_utilization,
- float8_kv_cache = float8_kv_cache,
- random_state = random_state,
- max_lora_rank = max_lora_rank,
- disable_log_stats = disable_log_stats,
- qat_scheme = qat_scheme,
- load_in_fp8 = load_in_fp8,
- unsloth_tiled_mlp = unsloth_tiled_mlp,
- *args,
- **kwargs,
+ resize_model_vocab = resize_model_vocab, # [TODO] No effect
+ revision = revision,
+ return_logits = False, # Return logits
+ fullgraph = True, # No graph breaks
+ use_exact_model_name = use_exact_model_name,
+ *args, **kwargs,
)
-
- # Apply gradient checkpointing with smart heuristics
- use_gradient_checkpointing = apply_unsloth_gradient_checkpointing(
- use_gradient_checkpointing, max_seq_length, dtype
- )
+ pass
# Check if this is local model since the tokenizer gets overwritten
- if (
- os.path.exists(os.path.join(old_model_name, "tokenizer_config.json"))
- and os.path.exists(os.path.join(old_model_name, "tokenizer.json"))
- and os.path.exists(os.path.join(old_model_name, "special_tokens_map.json"))
- ):
+ if os.path.exists(os.path.join(old_model_name, "tokenizer_config.json")) and \
+ os.path.exists(os.path.join(old_model_name, "tokenizer.json")) and \
+ os.path.exists(os.path.join(old_model_name, "special_tokens_map.json")):
+
tokenizer_name = old_model_name
else:
- tokenizer_name = kwargs.pop("tokenizer_name", None)
+ tokenizer_name = None
+ pass
if fast_inference:
- fast_inference, model_name = fast_inference_setup(model_name, model_config)
-
- load_in_4bit_kwargs = load_in_4bit
- load_in_8bit_kwargs = load_in_8bit
- if quantization_config is not None and not fast_inference:
- load_in_4bit_kwargs = False
- load_in_8bit_kwargs = False
+ import platform
+ if platform.system().lower() == 'windows':
+ print("Unsloth: vLLM does not work in Windows! Will use Unsloth inference!")
+ fast_inference = False
+ pass
+ from unsloth_zoo.vllm_utils import (
+ patch_vllm,
+ vllm_dynamic_quant_supported,
+ )
+ patch_vllm()
+ if model_name.endswith("unsloth-bnb-4bit"):
+ if not vllm_dynamic_quant_supported(model_name, model_config):
+ # Instead use -bnb-4bit variant
+ print(
+ f"Unsloth: Switching from Unsloth dynamic quant to normal quant since\n"\
+ f"we do not yet support fast inference for {model_name}"
+ )
+ model_name = model_name[:-len("unsloth-bnb-4bit")] + "bnb-4bit"
+ pass
+ pass
+ pass
model, tokenizer = dispatch_model.from_pretrained(
- model_name = model_name,
- max_seq_length = max_seq_length,
- dtype = _get_dtype(dtype),
- load_in_4bit = load_in_4bit_kwargs,
- token = token,
- device_map = device_map,
- rope_scaling = rope_scaling,
- fix_tokenizer = fix_tokenizer,
- model_patcher = dispatch_model,
- tokenizer_name = tokenizer_name,
+ model_name = model_name,
+ max_seq_length = max_seq_length,
+ dtype = _get_dtype(dtype),
+ load_in_4bit = load_in_4bit,
+ token = token,
+ device_map = device_map,
+ rope_scaling = rope_scaling,
+ fix_tokenizer = fix_tokenizer,
+ model_patcher = dispatch_model,
+ tokenizer_name = tokenizer_name,
trust_remote_code = trust_remote_code,
- revision = revision if not is_peft else None,
- fast_inference = fast_inference,
+ revision = revision if not is_peft else None,
+
+ fast_inference = fast_inference,
gpu_memory_utilization = gpu_memory_utilization,
- float8_kv_cache = float8_kv_cache,
- random_state = random_state,
- max_lora_rank = max_lora_rank,
+ float8_kv_cache = float8_kv_cache,
+ random_state = random_state,
+ max_lora_rank = max_lora_rank,
disable_log_stats = disable_log_stats,
- load_in_fp8 = load_in_fp8,
- *args,
- **kwargs,
+ *args, **kwargs,
)
-
+
if resize_model_vocab is not None:
model.resize_token_embeddings(resize_model_vocab)
+ pass
# In case the model supports tagging, add the unsloth tag.
if hasattr(model, "add_model_tags"):
- model.add_model_tags(
- [
- "unsloth",
- ]
- )
+ model.add_model_tags(["unsloth",])
+ pass
if hasattr(tokenizer, "add_model_tags"):
- tokenizer.add_model_tags(
- [
- "unsloth",
- ]
- )
+ tokenizer.add_model_tags(["unsloth",])
+ pass
if load_in_4bit:
- # Fix up bitsandbytes config, but respect user-provided quantization_config
- if quantization_config is None:
- compute_dtype = dtype_from_config(model.config)
- quantization_config = {
- # Sometimes compute_dtype is not a string!!
- "bnb_4bit_compute_dtype": compute_dtype,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_use_double_quant": True,
- "llm_int8_enable_fp32_cpu_offload": False,
- "llm_int8_has_fp16_weight": False,
- "llm_int8_skip_modules": None,
- "llm_int8_threshold": 6.0,
- "load_in_4bit": True,
- "load_in_8bit": False,
- "quant_method": "bitsandbytes",
- }
- model.config.update({"quantization_config": quantization_config})
- else:
- if hasattr(quantization_config, "to_dict"):
- model.config.update(
- {"quantization_config": quantization_config.to_dict()}
- )
- elif isinstance(quantization_config, dict):
- model.config.update({"quantization_config": quantization_config})
-
- if load_in_fp8 != False:
- _tag_model_with_fp8_torchao_config(model, fp8_mode)
+ # Fix up bitsandbytes config
+ quantization_config = \
+ {
+ # Sometimes torch_dtype is not a string!!
+ "bnb_4bit_compute_dtype" : model.config.to_dict()["torch_dtype"],
+ "bnb_4bit_quant_type" : "nf4",
+ "bnb_4bit_use_double_quant" : True,
+ "llm_int8_enable_fp32_cpu_offload" : False,
+ "llm_int8_has_fp16_weight" : False,
+ "llm_int8_skip_modules" : None,
+ "llm_int8_threshold" : 6.0,
+ "load_in_4bit" : True,
+ "load_in_8bit" : False,
+ "quant_method" : "bitsandbytes",
+ }
+ model.config.update({"quantization_config" : quantization_config})
+ pass
if is_peft:
# From https://github.com/huggingface/peft/issues/184
@@ -784,17 +415,10 @@ def from_pretrained(
)
# Patch it as well!
model = dispatch_model.patch_peft_model(model, use_gradient_checkpointing)
-
- # Patch Tiled MLP
- # to turn on set UNSLOTH_TILED_MLP to "arctic", "target", or "target:{GB}""
- patch_tiled_mlp_choice = os.environ.get(
- "UNSLOTH_TILED_MLP", "arctic" if unsloth_tiled_mlp else "0"
- )
- if patch_tiled_mlp_choice != "0" or unsloth_tiled_mlp:
- patch_tiled_mlp(model, patch_options_str = patch_tiled_mlp_choice)
-
- model = _fix_rope_inv_freq(model)
+ pass
return model, tokenizer
+ pass
+pass
from ..kernels import (
@@ -805,239 +429,104 @@ def from_pretrained(
from transformers import (
AutoModelForCausalLM,
)
-
try:
from transformers import AutoModelForImageTextToText
-
AutoModelForVision2Seq = AutoModelForImageTextToText
except:
from transformers import AutoModelForVision2Seq
+pass
class FastModel(FastBaseModel):
- @staticmethod
- def _prepare_for_qat(model, qat_scheme):
- model = _prepare_model_for_qat(model, qat_scheme)
- return model
-
@staticmethod
def from_pretrained(
- model_name = "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
- max_seq_length = 2048,
- dtype = None,
- load_in_4bit = True, # 4bit QLoRA
- load_in_8bit = False, # 8bit LoRA
- load_in_16bit = False, # 16bit LoRA
- full_finetuning = False,
- token = None,
- device_map = "sequential",
- rope_scaling = None, # [TODO] No effect
- fix_tokenizer = True, # [TODO] No effect
- trust_remote_code = False,
+ model_name = "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
+ max_seq_length = 2048,
+ dtype = None,
+ load_in_4bit = True,
+ load_in_8bit = False,
+ full_finetuning = False,
+ token = None,
+ device_map = "sequential",
+ rope_scaling = None, # [TODO] No effect
+ fix_tokenizer = True, # [TODO] No effect
+ trust_remote_code = False,
use_gradient_checkpointing = "unsloth",
- resize_model_vocab = None, # [TODO] No effect
- revision = None,
- return_logits = False, # Return logits
- fullgraph = True, # No graph breaks
- use_exact_model_name = False,
- auto_model = None,
- whisper_language = None,
- whisper_task = None,
- unsloth_force_compile = False,
- offload_embedding = False,
- float32_mixed_precision = None, # Forces float32 mixed precision
- # Add the missing vLLM/inference parameters
- fast_inference = False, # uses vLLM
- gpu_memory_utilization = 0.5,
- float8_kv_cache = False,
- random_state = 3407,
- max_lora_rank = 64,
- disable_log_stats = True,
- qat_scheme = None,
- load_in_fp8 = False, # fp8 LoRA (True, False, 'block')
- unsloth_tiled_mlp = False,
- target_parameters = None, # For MoE expert parameters
- *args,
- **kwargs,
+ resize_model_vocab = None, # [TODO] No effect
+ revision = None,
+ return_logits = False, # Return logits
+ fullgraph = True, # No graph breaks
+ use_exact_model_name = False,
+ *args, **kwargs,
):
- # Respect user-provided quantization_config (e.g. BitsAndBytesConfig)
- quantization_config = kwargs.get("quantization_config", None)
- if quantization_config is not None:
- if isinstance(quantization_config, dict):
- q_load_in_4bit = quantization_config.get("load_in_4bit", False)
- q_load_in_8bit = quantization_config.get("load_in_8bit", False)
- else:
- q_load_in_4bit = getattr(quantization_config, "load_in_4bit", False)
- q_load_in_8bit = getattr(quantization_config, "load_in_8bit", False)
- if q_load_in_4bit:
- load_in_4bit = True
- load_in_8bit = False
- if q_load_in_8bit:
- load_in_8bit = True
- load_in_4bit = False
-
- # Login to allow private models
- token = hf_login(token)
- if whisper_language is not None:
- assert type(whisper_language) is str
- if whisper_task is not None:
- assert type(whisper_task) is str
- # Align dtype with bnb_4bit_compute_dtype if provided and dtype is unset.
- if dtype is None and quantization_config is not None:
- bnb_compute_dtype = None
- if isinstance(quantization_config, dict):
- if quantization_config.get("load_in_4bit", False):
- bnb_compute_dtype = quantization_config.get(
- "bnb_4bit_compute_dtype", None
- )
- else:
- if getattr(quantization_config, "load_in_4bit", False):
- bnb_compute_dtype = getattr(
- quantization_config, "bnb_4bit_compute_dtype", None
- )
- if isinstance(bnb_compute_dtype, str):
- bnb_compute_dtype = getattr(torch, bnb_compute_dtype, None)
- if isinstance(bnb_compute_dtype, torch.dtype):
- dtype = bnb_compute_dtype
- SUPPORTS_BFLOAT16 = is_bfloat16_supported()
- if dtype is None:
- dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16
- elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:
- logger.warning_once(
- "Device does not support bfloat16. Will change to float16."
- )
- dtype = torch.float16
- assert dtype in (torch.float16, torch.bfloat16, torch.float32)
- assert load_in_fp8 in (True, False, "block")
+ if token is None: token = get_token()
+ assert (dtype is None or dtype == torch.float16 or dtype == torch.bfloat16)
patch_compiled_autograd()
patch_compiling_bitsandbytes()
+ if use_gradient_checkpointing == "unsloth":
+ patch_unsloth_smart_gradient_checkpointing(dtype = dtype)
if full_finetuning and (load_in_4bit or load_in_8bit):
- print(
- "Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA."
- )
+ print("Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA.")
load_in_4bit = False
load_in_8bit = False
- load_in_fp8 = False
- load_in_16bit = False
+ pass
- if (
- int(load_in_4bit)
- + int(load_in_8bit)
- + int(load_in_16bit)
- + int(load_in_fp8 != False)
- >= 2
- ):
+ if load_in_4bit and load_in_8bit:
raise RuntimeError(
- "Unsloth: Can only load in 4bit or 8bit or 16bit, not a combination!\n"
- "Also, we by default set `load_in_4bit = True`.\n"
- "If you want 8bit finetuning, set both `load_in_4bit = False` and `load_in_8bit = True`\n"
- "If you want 16bit LoRA finetuning, set `load_in_16bit = True`"
- )
-
- if qat_scheme is not None and not full_finetuning:
- raise ValueError(
- "Specifying `qat_scheme` in `FastLanguageModel.from_pretrained(...)` is only "
- "compatible with `full_finetuning=True`. If you wish to use QAT with LoRA, "
- "please pass in `qat_scheme` in `FastLanguageModel.get_peft_model(...)` instead."
+ "Unsloth: Can only load in 4bit or 8bit, not both!\n"\
+ "Also, we by default set `load_in_4bit = True`.\n"\
+ "If you want 8bit finetuning, set both `load_in_4bit = False` and `load_in_8bit = True`"
)
- if qat_scheme == "phone-deployment":
- qat_scheme = "int8-int4"
-
- # Distributed-safe device placement for quantized models.
- # In multi-GPU (torchrun), each rank must load the model on its own device
- # to avoid Accelerate device relocation errors with quantized weights.
- is_quantized = load_in_4bit or load_in_8bit or load_in_fp8
- if is_quantized and isinstance(device_map, str):
- distributed_device_map, is_dist = prepare_device_map()
- if is_dist:
- device_map = distributed_device_map
-
- # Check if 4bit is allowed specifically for AMD
- if not ALLOW_BITSANDBYTES and not use_exact_model_name:
- if load_in_4bit or load_in_8bit or model_name.lower().endswith("-bnb-4bit"):
- print(
- "Unsloth: AMD currently is not stable with 4bit bitsandbytes. Disabling for now."
- )
- load_in_4bit = False
+ if load_in_4bit: pass
+ elif load_in_8bit: pass
+ elif not load_in_4bit and not load_in_8bit and not full_finetuning:
+ print("Unsloth: LoRA, QLoRA and full finetuning all not selected. Switching to QLoRA.")
+ load_in_4bit = True
+ pass
- if fast_inference:
- if importlib.util.find_spec("vllm") is None:
- raise ImportError(
- "Unsloth: Please install vLLM before enabling `fast_inference`!\n"
- "You can do this in a terminal via `pip install vllm`"
- )
- if DEVICE_TYPE_TORCH == "cuda":
- for i in range(DEVICE_COUNT):
- # [TODO] DGX Spark vLLM breaks
- if "NVIDIA GB10" in str(torch.cuda.get_device_name(i)).upper():
- print(
- "Unsloth: DGX Spark detected - `fast_inference=True` is currently broken as of January 2026.\n"
- "Defaulting to native Unsloth inference."
- )
- fast_inference = False
- break
-
- # Find FP8, BnB 4bit, other mapped names
old_model_name = model_name
- fp8_mode = None
if not use_exact_model_name:
- new_model_name = get_model_name(
- model_name, load_in_4bit = load_in_4bit, load_in_fp8 = load_in_fp8
- )
- if new_model_name is None and load_in_fp8 != False:
- fp8_mode = _get_fp8_mode_and_check_settings(
- load_in_fp8,
- fast_inference,
- full_finetuning,
- load_in_4bit,
- load_in_8bit,
- load_in_16bit,
- )
- model_name = _offline_quantize_to_fp8(model_name, fp8_mode)
- else:
- assert new_model_name is not None
- model_name = new_model_name
- # If mapper resolved to a pre-quantized FP8 model, disable
- # on-the-fly quantization to avoid double quantization
- if load_in_fp8 != False and new_model_name != old_model_name:
- load_in_fp8 = False
+ model_name = get_model_name(model_name, load_in_4bit)
- # Check if pre-quantized models are allowed
- # AMD Instinct GPUs need blocksize = 128 on bitsandbytes < 0.49.2 (our pre-quants use blocksize = 64)
- if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
- ("-unsloth-bnb-4bit", "-bnb-4bit")
- ):
- model_name = model_name.lower().removesuffix("-unsloth-bnb-4bit")
- model_name = model_name.lower().removesuffix("-bnb-4bit")
- # Change -BF16 to all False for 4bit, 8bit etc
- if model_name.lower().endswith("-bf16"):
- load_in_4bit = False
- load_in_8bit = False
- load_in_fp8 = False
- load_in_16bit = True
+ # Check versions
+ LATEST = '\nPlease use transformers via `pip install --no-deps git+https://github.com/huggingface/transformers.git`'
+ NIGHTLY = '\nPlease use nightly transformers via pip install --upgrade "transformers>=4.49.0"`'
+ if "pixtral" in model_name.lower() and transformers_version < Version("4.49.0"):
+ raise RuntimeError("Unsloth: Pixtral only works on transformers >= 4.49.0." + LATEST)
+ elif "qwen2.5" in model_name.lower() and transformers_version < Version("4.49.0"):
+ raise RuntimeError("Unsloth: Qwen 2.5 only works on transformers >= 4.49.0." + LATEST)
+ elif "aya-vision" in model_name.lower():
+ # Disable compiling for now - errors out!
+ os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"
+ if transformers_version < Version("4.50.0.dev0"):
+ raise RuntimeError("Unsloth: Aya Vision only works on transformers >= 4.50.0." + NIGHTLY)
+ elif "gemma-3" in model_name.lower() and transformers_version < Version("4.50.0.dev0"):
+ raise RuntimeError("Unsloth: Gemma 3 only works on transformers >= 4.50.0." + NIGHTLY)
+ elif "c4ai-command-a-03-2025" in model_name.lower() and transformers_version < Version("4.50.0.dev0"):
+ raise RuntimeError("Unsloth: Cohere's Command model only works on transformers >= 4.50.0." + NIGHTLY)
+ elif "granite-vision" in model_name.lower():
+ # Disable compiling for now - errors out!
+ os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"
+ if transformers_version < Version("4.50.0.dev0"):
+ raise RuntimeError("Unsloth: Granite Vision only works on transformers >= 4.50.0." + NIGHTLY)
+ elif "olmo-2" in model_name.lower() and transformers_version < Version("4.50.0.dev0"):
+ raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY)
+ pass
- # Check modelscope
if USE_MODELSCOPE and not os.path.exists(model_name):
from modelscope import snapshot_download
-
model_name = snapshot_download(model_name)
+ pass
# First check if it's a normal model via AutoConfig
- from huggingface_hub.utils import (
- disable_progress_bars,
- enable_progress_bars,
- are_progress_bars_disabled,
- )
-
+ from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
was_disabled = are_progress_bars_disabled()
disable_progress_bars()
autoconfig_error = None
peft_error = None
- model_config = None
- peft_config = None
try:
model_config = AutoConfig.from_pretrained(
model_name,
@@ -1046,15 +535,8 @@ def from_pretrained(
trust_remote_code = trust_remote_code,
)
is_model = True
- except ImportError:
- raise
except Exception as error:
autoconfig_error = str(error)
- if "architecture" in autoconfig_error:
- raise ValueError(
- f"`{model_name}` is not supported yet in `transformers=={transformers_version}`.\n"
- f"Please update transformers via `pip install --upgrade transformers` and try again."
- )
is_model = False
try:
peft_config = PeftConfig.from_pretrained(
@@ -1064,222 +546,53 @@ def from_pretrained(
trust_remote_code = trust_remote_code,
)
is_peft = True
- except ImportError:
- raise
except Exception as error:
peft_error = str(error)
- if "architecture" in peft_error:
- raise ValueError(
- f"`{model_name}` is not supported yet in `transformers=={transformers_version}`.\n"
- f"Please update transformers via `pip install --upgrade transformers` and try again."
- )
is_peft = False
- # Old transformers versions check
- both_exist = (is_model and is_peft) and not SUPPORTS_LLAMA32
- # Error out if both LoRA and normal model config exists.
- if both_exist:
- raise RuntimeError(
- "Unsloth: Your repo has a LoRA adapter and a base model.\n"
- "You have 2 files `config.json` and `adapter_config.json`.\n"
- "We must only allow one config file.\n"
- "Please separate the LoRA and base models to 2 repos."
- )
- model_types = get_transformers_model_type(
- peft_config if peft_config is not None else model_config,
- trust_remote_code = trust_remote_code,
- )
- model_types_all = ",".join(model_types) + ","
-
- # Save model types and loading method
- lowered_model_name = model_name.lower()
- string = os.environ.get("UNSLOTH_MODEL_NAME", "") + model_types_all
- if load_in_4bit:
- string += "_load_in_4bit_"
- if load_in_8bit:
- string += "_load_in_8bit_"
- if load_in_16bit:
- string += "_load_in_16bit_"
- if load_in_fp8:
- string += "load_in_fp8"
- os.environ["UNSLOTH_MODEL_NAME"] = string
-
- # Check versions
- LATEST = "\nPlease use transformers via `pip install --no-deps git+https://github.com/huggingface/transformers.git`"
- NIGHTLY = '\nPlease use nightly transformers via pip install --upgrade "transformers>=4.49.0"`'
- # Pixtral
- if "pixtral" in model_types_all and transformers_version < Version("4.49.0"):
- raise RuntimeError(
- "Unsloth: Pixtral only works on transformers >= 4.49.0." + LATEST
- )
- # Qwen 2.5
- elif "qwen2_5" in model_types_all and transformers_version < Version("4.49.0"):
- raise RuntimeError(
- "Unsloth: Qwen 2.5 only works on transformers >= 4.49.0." + LATEST
- )
- # Gemma 3N must be before Gemma 3
- elif "gemma3n" in model_types_all:
- if transformers_version < Version("4.53.0"):
- raise RuntimeError(
- "Unsloth: Gemma 3N only works on transformers >= 4.53.0" + LATEST
- )
- os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"
- os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = (
- "float16;torch.float16;torch.float16;"
- "if name.endswith('norm'): "
- "module._pre_set_compute_dtype = torch.float32\n"
- ";"
- "from unsloth_zoo.temporary_patches.gemma3n import patch_Gemma3nConv_Embed_forwards; patch_Gemma3nConv_Embed_forwards()"
- )
- # Set norms to float32 since anyways they get upcasted to float32
- # common in both gemma-3 and gemma-3n
- os.environ["UNSLOTH_HIGH_PRECISION_LAYERNORM"] = "1"
- # Gemma 3
- elif "gemma3" in model_types_all:
- if transformers_version < Version("4.50.0.dev0"):
- raise RuntimeError(
- "Unsloth: Gemma 3 only works on transformers >= 4.50.0." + NIGHTLY
- )
- # Set norms to float32 since anyways they get upcasted to float32
- # common in both gemma-3 and gemma-3n
- os.environ["UNSLOTH_HIGH_PRECISION_LAYERNORM"] = "1"
- # ROCm/HIP: Gemma3 compiled forward produces NaN on RDNA GPUs
- # (gfx1100, gfx1101, gfx1102, gfx1150, gfx1151, etc.).
- # Disable torch.compile for model forward; loss compilation is fine.
- # See https://github.com/unslothai/unsloth/issues/3385
- from unsloth.kernels.utils import is_rdna
+ pass
- if is_rdna():
- os.environ["UNSLOTH_COMPILE_DISABLE"] = "partial"
- # Cohere
- elif "cohere2" in model_types_all and transformers_version < Version(
- "4.50.0.dev0"
- ):
- raise RuntimeError(
- "Unsloth: Cohere's Command model only works on transformers >= 4.50.0."
- + NIGHTLY
- )
- # Sesame
- elif "csm" in model_types_all:
- os.environ["UNSLOTH_COMPILE_DISABLE"] = "partial" # Inference is too slow
- os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" # Sesame fails
- os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = (
- "all;torch.float32;torch.float16;"
- "if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)"
- ";"
- )
- # Granite 4
- elif "granitemoehybrid" in model_types_all:
- # Granite-4 rms norms are stored as 16 bit, but we upcast
- os.environ["UNSLOTH_HIGH_PRECISION_LAYERNORM"] = "1"
- os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"
- # Olmo 2
- elif "olmo2" in model_types_all and transformers_version < Version(
- "4.50.0.dev0"
- ):
- raise RuntimeError(
- "Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY
- )
- elif "falcon_h1" in model_types_all:
- # Falcon must use float32 Triton ie TRITON_F32_DEFAULT = 'ieee'
- # since Mamba kernels error out on using lower precision
- os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = (
- "float16;torch.float32;torch.float16;"
- "if name.endswith(('q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj', 'head')): module.to(torch.float16)"
- ";"
- "os.environ['TRITON_F32_DEFAULT'] = 'ieee'"
- )
- elif "nemotron_h" in model_types_all:
- # NemotronH (hybrid Mamba-2 + Transformer) uses same Mamba kernels as Falcon-H1
- # Mamba kernels need float32 Triton precision
- os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = (
- "float16;torch.float32;torch.float16;"
- "if name.endswith(('q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj', 'head')): module.to(torch.float16)"
- ";"
- "os.environ['TRITON_F32_DEFAULT'] = 'ieee'"
- )
- elif "gpt_oss" in model_types_all:
- os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"
- if not load_in_4bit:
- # Only upcast MoE biases for MXFP4, not BnB
- # Set norms to float32 since anyways they get upcasted to float32
- os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = (
- "all;None;None;"
- "x = 'gate_up_proj_bias'\n"
- "if hasattr(module, x): "
- "setattr(module, x, torch.nn.Parameter(getattr(module, x).to(torch.float32)) if isinstance(getattr(module, x), torch.nn.Parameter) else getattr(module, x).to(torch.float32))\n"
- ""
- "x = 'down_proj_bias'\n"
- "if hasattr(module, x): "
- "setattr(module, x, torch.nn.Parameter(getattr(module, x).to(torch.float32)) if isinstance(getattr(module, x), torch.nn.Parameter) else getattr(module, x).to(torch.float32))\n"
- ""
- ";"
- )
- else:
- # Set down projection compute dtype to be float32 for float16 machines
- # Set norms to float32 since anyways they get upcasted to float32
- os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = (
- "torch.float16;torch.bfloat16;torch.float16;"
- "if ('down_projs' in name) and hasattr(module, 'weight') and "
- "torch.amax(dequantize_module_weight(module)) >= 0:"
- "module._pre_set_compute_dtype = torch.float32\n"
- ""
- "if ('mlp.router' in name) and hasattr(module, 'weight'):"
- "module._pre_set_compute_dtype = torch.float32\n"
- ";"
- )
- # Set norms to float32 since anyways they get upcasted to float32
- os.environ["UNSLOTH_HIGH_PRECISION_LAYERNORM"] = "1"
- else:
- for check_model_name in DISABLE_COMPILE_MODEL_NAMES:
- if check_model_name in lowered_model_name:
- os.environ["UNSLOTH_COMPILE_DISABLE"] = "partial"
- os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"
- if transformers_version < Version("4.50.0.dev0"):
- raise RuntimeError(
- f"Unsloth: {check_model_name} only works on transformers >= 4.50.0."
- + NIGHTLY
- )
- break
+ # Both config.json and adapter_config.json should not exist!
- if auto_model is not None:
- # All other models need to disable static cache
- os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"
+ # Old transformers versions check
+ both_exist = (is_model and is_peft) and not SUPPORTS_LLAMA32
# New transformers need to check manually.
if SUPPORTS_LLAMA32:
# Check if folder exists locally
if os.path.isdir(model_name):
- exist_adapter_config = os.path.exists(
- os.path.join(model_name, "adapter_config.json")
- )
- exist_config = os.path.exists(os.path.join(model_name, "config.json"))
+ exist_adapter_config = os.path.exists(os.path.join(model_name, "adapter_config.json"))
+ exist_config = os.path.exists(os.path.join(model_name, "config.json"))
both_exist = exist_adapter_config and exist_config
else:
files = HfFileSystem(token = token).glob(f"{model_name}/*.json")
- files = list(os.path.split(x)[-1] for x in files)
- if (
- sum(x == "adapter_config.json" or x == "config.json" for x in files)
- >= 2
- ):
+ files = (os.path.split(x)[-1] for x in files)
+ if sum(x == "adapter_config.json" or x == "config.json" for x in files) >= 2:
both_exist = True
+ pass
+ pass
+ pass
+
+ # Error out if both LoRA and normal model config exists.
+ if both_exist:
+ raise RuntimeError(
+ "Unsloth: Your repo has a LoRA adapter and a base model.\n"\
+ "You have 2 files `config.json` and `adapter_config.json`.\n"\
+ "We must only allow one config file.\n"\
+ "Please separate the LoRA and base models to 2 repos."
+ )
- if not is_model and not is_peft:
- error = autoconfig_error if autoconfig_error is not None else peft_error
+ elif not is_model and not is_peft:
+ error = autoconfig_error or peft_error
# Old transformers version
if "rope_scaling" in error.lower() and not SUPPORTS_LLAMA31:
raise ImportError(
- f"Unsloth: Your transformers version of {transformers_version} does not support new RoPE scaling methods.\n"
- f"This includes Llama 3.1. The minimum required version is 4.43.2\n"
- f'Try `pip install --upgrade "transformers>=4.43.2"`\n'
- f"to obtain the latest transformers build, then restart this session."
- )
- # Create a combined error message showing both failures
- combined_error = (
- "Unsloth: Failed to load model. Both AutoConfig and PeftConfig loading failed.\n\n"
- f"AutoConfig error: {autoconfig_error}\n\n"
- f"PeftConfig error: {peft_error}\n\n"
- )
- raise RuntimeError(combined_error)
+ f"Unsloth: Your transformers version of {transformers_version} does not support new RoPE scaling methods.\n"\
+ f"This includes Llama 3.1. The minimum required version is 4.43.2\n"\
+ f'Try `pip install --upgrade "transformers>=4.43.2"`\n'\
+ f"to obtain the latest transformers build, then restart this session."\
+ )
+ raise RuntimeError(autoconfig_error or peft_error)
+ pass
# Get base model for PEFT:
if is_peft:
@@ -1287,28 +600,15 @@ def from_pretrained(
model_name = peft_config.base_model_name_or_path
if not use_exact_model_name:
model_name = get_model_name(model_name, load_in_4bit)
- # Check if pre-quantized models are allowed
- # AMD Instinct GPUs need blocksize = 128 on bitsandbytes < 0.49.2 (our pre-quants use blocksize = 64)
- if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
- ("-unsloth-bnb-4bit", "-bnb-4bit")
- ):
- model_name = model_name.lower().removesuffix("-unsloth-bnb-4bit")
- model_name = model_name.lower().removesuffix("-bnb-4bit")
- # Change -BF16 to all False for 4bit, 8bit etc
- if model_name.lower().endswith("-bf16"):
- load_in_4bit = False
- load_in_8bit = False
- load_in_fp8 = False
- load_in_16bit = True
-
+
model_config = AutoConfig.from_pretrained(
model_name,
token = token,
trust_remote_code = trust_remote_code,
)
+ pass
- if not was_disabled:
- enable_progress_bars()
+ if not was_disabled: enable_progress_bars()
do_logging = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1"
if do_logging:
@@ -1316,185 +616,100 @@ def from_pretrained(
else:
redirector = contextlib.redirect_stdout(open(os.devnull, "w"))
- model_types = ["siglip"] + model_types
- # Set forced float32 env flag
- os.environ["UNSLOTH_FORCE_FLOAT32"] = "0"
- do_forced_float32 = False
- for model_type_arch in model_types:
- if model_type_arch != "siglip":
- break
- global FORCE_FLOAT32
- for disable_name in FORCE_FLOAT32:
- # add comma to model_types_all matching in case of exact match for end
- if (
- disable_name.lower()
- == model_type_arch.lower().replace("-", "").replace("_", "")
- or disable_name.lower() in model_types_all
- ) and ((dtype == torch.float16) or not SUPPORTS_BFLOAT16):
- os.environ["UNSLOTH_FORCE_FLOAT32"] = "1"
- dtype = torch.bfloat16 # Change to bfloat16 loading
- break
- # Apply gradient checkpointing with smart heuristics
- use_gradient_checkpointing = apply_unsloth_gradient_checkpointing(
- use_gradient_checkpointing, max_seq_length, dtype
- )
with redirector:
patch_loss_functions(torch_compile = False)
- model_types, supports_sdpa = unsloth_compile_transformers(
- dtype = dtype,
- model_name = model_name,
- model_types = model_types,
- token = token,
- sdpa_dynamic_mask = True,
- sdpa_bool_masks = True,
- sdpa_gqa_replace = True,
- sdpa_dynamic_compile = True,
- compile_attention = True,
- disable_causal_masks = True,
- compile_torch_modules = True,
- compile_custom_modules = True,
- compile_function_calls = True,
- fuse_lm_head = True,
- gradient_checkpointing = True,
- manual_replacements = True,
- fast_lora_forwards = True,
- fast_residual_stream = False,
- accurate_accumulation = True,
- epilogue_fusion = True,
- max_autotune = False,
- shape_padding = True,
- cudagraphs = False,
- debug = False,
- fullgraph = fullgraph,
- import_from_cache = False,
- disable = False,
- return_logits = return_logits,
- trust_remote_code = trust_remote_code,
- unsloth_force_compile = unsloth_force_compile,
- )
- # Fix SDPA issues
- for model_type in DISABLE_SDPA_MODEL_NAMES:
- if model_type in model_types_all:
- supports_sdpa = False
+ model_types = unsloth_compile_transformers(
+ model_name = model_name,
+ sdpa_dynamic_mask = True,
+ sdpa_bool_masks = True,
+ sdpa_gqa_replace = True,
+ sdpa_dynamic_compile = True,
+ compile_attention = True,
+ disable_causal_masks = True,
+ compile_torch_modules = True,
+ compile_custom_modules = True,
+ compile_function_calls = True,
+ fuse_lm_head = True,
+ gradient_checkpointing = True,
+ manual_replacements = True,
+ fast_lora_forwards = True,
+ fast_residual_stream = False,
+ accurate_accumulation = True,
+ epilogue_fusion = True,
+ max_autotune = False,
+ shape_padding = True,
+ cudagraphs = False,
+ debug = False,
+ fullgraph = fullgraph,
+ import_from_cache = False,
+ disable = False,
+ return_logits = return_logits,
+ )
+ pass
# Check if this is local model since the tokenizer gets overwritten
- if (
- os.path.exists(os.path.join(old_model_name, "tokenizer_config.json"))
- and os.path.exists(os.path.join(old_model_name, "tokenizer.json"))
- and os.path.exists(os.path.join(old_model_name, "special_tokens_map.json"))
- ):
+ if os.path.exists(os.path.join(old_model_name, "tokenizer_config.json")) and \
+ os.path.exists(os.path.join(old_model_name, "tokenizer.json")) and \
+ os.path.exists(os.path.join(old_model_name, "special_tokens_map.json")):
+
tokenizer_name = old_model_name
else:
- tokenizer_name = kwargs.pop("tokenizer_name", None)
+ tokenizer_name = None
+ pass
# Check if VLM
- architectures = getattr(model_config, "architectures", None)
- if architectures is None:
- architectures = []
- is_vlm = any(x.endswith("ForConditionalGeneration") for x in architectures)
+ is_vlm = any(x.endswith("ForConditionalGeneration") for x in model_config.architectures)
is_vlm = is_vlm or hasattr(model_config, "vision_config")
- if auto_model is None:
- if is_vlm:
- # Check if the model's auto_map supports the VLM auto class.
- # Some VL models (e.g. Nemotron-VL) only register AutoModelForCausalLM
- # in their auto_map, not AutoModelForImageTextToText/AutoModelForVision2Seq.
- _auto_map = getattr(model_config, "auto_map", {}) or {}
- _vlm_class_name = AutoModelForVision2Seq.__name__
- if (
- "AutoModelForCausalLM" in _auto_map
- and _vlm_class_name not in _auto_map
- ):
- auto_model = AutoModelForCausalLM
- else:
- auto_model = AutoModelForVision2Seq
- else:
- auto_model = AutoModelForCausalLM
-
- load_in_4bit_kwargs = load_in_4bit
- load_in_8bit_kwargs = load_in_8bit
- if quantization_config is not None and not fast_inference:
- load_in_4bit_kwargs = False
- load_in_8bit_kwargs = False
+ auto_model = AutoModelForVision2Seq if is_vlm else AutoModelForCausalLM
model, tokenizer = FastBaseModel.from_pretrained(
- model_name = model_name,
- max_seq_length = max_seq_length,
- dtype = _get_dtype(dtype),
- load_in_4bit = load_in_4bit_kwargs,
- load_in_8bit = load_in_8bit_kwargs,
- load_in_16bit = load_in_16bit,
- full_finetuning = full_finetuning,
- token = token,
- device_map = device_map,
+ model_name = model_name,
+ max_seq_length = max_seq_length,
+ dtype = _get_dtype(dtype),
+ load_in_4bit = load_in_4bit,
+ load_in_8bit = load_in_8bit,
+ full_finetuning = full_finetuning,
+ token = token,
+ device_map = device_map,
trust_remote_code = trust_remote_code,
- revision = revision if not is_peft else None,
- model_types = model_types,
- tokenizer_name = tokenizer_name,
- auto_model = auto_model,
+ revision = revision if not is_peft else None,
+ model_types = model_types,
+ tokenizer_name = tokenizer_name,
+ auto_model = auto_model,
use_gradient_checkpointing = use_gradient_checkpointing,
- supports_sdpa = supports_sdpa,
- whisper_language = whisper_language,
- whisper_task = whisper_task,
- auto_config = model_config,
- offload_embedding = offload_embedding,
- float32_mixed_precision = float32_mixed_precision,
- # Pass vLLM/inference parameters
- fast_inference = fast_inference,
- gpu_memory_utilization = gpu_memory_utilization,
- float8_kv_cache = float8_kv_cache,
- random_state = random_state,
- max_lora_rank = max_lora_rank,
- disable_log_stats = disable_log_stats,
- load_in_fp8 = load_in_fp8,
- *args,
- **kwargs,
+ *args, **kwargs,
)
if resize_model_vocab is not None:
model.resize_token_embeddings(resize_model_vocab)
+ pass
# In case the model supports tagging, add the unsloth tag.
if hasattr(model, "add_model_tags"):
- model.add_model_tags(
- [
- "unsloth",
- ]
- )
+ model.add_model_tags(["unsloth",])
+ pass
if hasattr(tokenizer, "add_model_tags"):
- tokenizer.add_model_tags(
- [
- "unsloth",
- ]
- )
+ tokenizer.add_model_tags(["unsloth",])
+ pass
if load_in_4bit:
- # Fix up bitsandbytes config, but respect user-provided quantization_config
- if quantization_config is None:
- compute_dtype = dtype_from_config(model.config)
- quantization_config = {
- # Sometimes compute_dtype is not a string!!
- "bnb_4bit_compute_dtype": compute_dtype,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_use_double_quant": True,
- "llm_int8_enable_fp32_cpu_offload": False,
- "llm_int8_has_fp16_weight": False,
- "llm_int8_skip_modules": None,
- "llm_int8_threshold": 6.0,
- "load_in_4bit": True,
- "load_in_8bit": False,
- "quant_method": "bitsandbytes",
- }
- model.config.update({"quantization_config": quantization_config})
- else:
- if hasattr(quantization_config, "to_dict"):
- model.config.update(
- {"quantization_config": quantization_config.to_dict()}
- )
- elif isinstance(quantization_config, dict):
- model.config.update({"quantization_config": quantization_config})
-
- if load_in_fp8 != False:
- _tag_model_with_fp8_torchao_config(model, fp8_mode)
+ # Fix up bitsandbytes config
+ quantization_config = \
+ {
+ # Sometimes torch_dtype is not a string!!
+ "bnb_4bit_compute_dtype" : model.config.to_dict()["torch_dtype"],
+ "bnb_4bit_quant_type" : "nf4",
+ "bnb_4bit_use_double_quant" : True,
+ "llm_int8_enable_fp32_cpu_offload" : False,
+ "llm_int8_has_fp16_weight" : False,
+ "llm_int8_skip_modules" : None,
+ "llm_int8_threshold" : 6.0,
+ "load_in_4bit" : True,
+ "load_in_8bit" : False,
+ "quant_method" : "bitsandbytes",
+ }
+ model.config.update({"quantization_config" : quantization_config})
+ pass
if is_peft:
# From https://github.com/huggingface/peft/issues/184
@@ -1508,30 +723,14 @@ def from_pretrained(
trust_remote_code = trust_remote_code,
)
# Patch it as well!
- model = FastBaseModel.post_patch_model(
- model, use_gradient_checkpointing, trust_remote_code = trust_remote_code
- )
-
- # Apply QAT if specified
- if qat_scheme is not None:
- print("Unsloth: Applying QAT to mitigate quantization degradation")
- model = FastModel._prepare_for_qat(model, qat_scheme)
-
- # Patch Tiled MLP
- # to turn on set UNSLOTH_TILED_MLP to "arctic", "target", or "target:{GB}""
- patch_tiled_mlp_choice = os.environ.get(
- "UNSLOTH_TILED_MLP", "arctic" if unsloth_tiled_mlp else "0"
- )
- if patch_tiled_mlp_choice != "0" or unsloth_tiled_mlp:
- patch_tiled_mlp(model, patch_options_str = patch_tiled_mlp_choice)
-
- model = _fix_rope_inv_freq(model)
+ model = FastBaseModel.post_patch_model(model, use_gradient_checkpointing)
+ pass
return model, tokenizer
-
+ pass
+pass
class FastVisionModel(FastModel):
pass
-
class FastTextModel(FastModel):
pass
diff --git a/unsloth/models/loader_utils.py b/unsloth/models/loader_utils.py
index cf5af983a6..e3eadd8c0f 100644
--- a/unsloth/models/loader_utils.py
+++ b/unsloth/models/loader_utils.py
@@ -12,141 +12,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ..device_type import DEVICE_TYPE_TORCH
-import importlib
-import os
-import torch
-import re
-import tempfile
-from typing import Union
-from .mapper import (
- INT_TO_FLOAT_MAPPER,
- FLOAT_TO_INT_MAPPER,
- MAP_TO_UNSLOTH_16bit,
- FLOAT_TO_FP8_BLOCK_MAPPER,
- FLOAT_TO_FP8_ROW_MAPPER,
-)
-
+from .mapper import INT_TO_FLOAT_MAPPER, FLOAT_TO_INT_MAPPER, MAP_TO_UNSLOTH_16bit
# https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
+from packaging.version import Version
from transformers import __version__ as transformers_version
-from unsloth.models._utils import TorchAOConfig
-from unsloth_zoo.utils import Version
-import gc
-
transformers_version = Version(transformers_version)
SUPPORTS_FOURBIT = transformers_version >= Version("4.37")
-LOCAL_RANK_KEYS = ("LOCAL_RANK", "RANK")
-WORLD_SIZE_KEYS = ("WORLD_SIZE",)
-
-BAD_MAPPINGS = {
- "unsloth/Qwen3-32B-unsloth-bnb-4bit".lower(): "unsloth/Qwen3-32B-bnb-4bit".lower(), # 32B dynamic quant is way too big
- "unsloth/Qwen3-30B-A3B-unsloth-bnb-4bit".lower(): "unsloth/Qwen3-30B-A3B".lower(), # HF loads MoEs too slowly
- "unsloth/Qwen3-30B-A3B-bnb-4bit".lower(): "unsloth/Qwen3-30B-A3B".lower(), # We rather do it on the fly
- "unsloth/Qwen3-30B-A3B-Base-unsloth-bnb-4bit".lower(): "unsloth/Qwen3-30B-A3B-Base".lower(), # HF loads MoEs too slowly
- "unsloth/Qwen3-30B-A3B-Base-bnb-4bit".lower(): "unsloth/Qwen3-30B-A3B-Base".lower(), # We rather do it on the fly
-}
-
-
-def _get_torchao_fp8_config(fp8_mode):
- # Import lazily so an optional, broken vLLM install does not break plain `import unsloth`.
- from unsloth_zoo.vllm_utils import _get_torchao_fp8_config as _impl
-
- return _impl(fp8_mode)
-
-
-def _get_env_int(keys):
- for key in keys:
- value = os.environ.get(key)
- if value is None:
- continue
- try:
- return int(value)
- except ValueError:
- continue
- return None
-
-
-def _infer_distributed_ranks():
- if torch.distributed.is_available() and torch.distributed.is_initialized():
- try:
- return torch.distributed.get_rank(), torch.distributed.get_world_size()
- except Exception:
- pass
- return _get_env_int(LOCAL_RANK_KEYS), _get_env_int(WORLD_SIZE_KEYS)
-
-
-def is_distributed():
- rank, world_size = _infer_distributed_ranks()
- return (world_size or 1) > 1 or (rank is not None and rank > 0)
-
-
-def prepare_device_map():
- rank, world_size = _infer_distributed_ranks()
- distributed = (world_size or 1) > 1 or (rank is not None and rank > 0)
- if not distributed:
- return None, False
-
- local_rank = 0 if rank is None else rank
- device_map = {"": f"{DEVICE_TYPE_TORCH}:{local_rank}"}
- try:
- if DEVICE_TYPE_TORCH == "cuda":
- torch.cuda.set_device(local_rank)
- elif DEVICE_TYPE_TORCH == "xpu" and hasattr(torch, "xpu"):
- torch.xpu.set_device(local_rank)
- except Exception:
- pass
- return device_map, True
-
def __get_model_name(
model_name,
load_in_4bit = True,
- INT_TO_FLOAT_MAPPER = None,
- FLOAT_TO_INT_MAPPER = None,
+ INT_TO_FLOAT_MAPPER = None,
+ FLOAT_TO_INT_MAPPER = None,
MAP_TO_UNSLOTH_16bit = None,
- load_in_fp8 = False,
- FLOAT_TO_FP8_BLOCK_MAPPER = None,
- FLOAT_TO_FP8_ROW_MAPPER = None,
):
model_name = str(model_name)
lower_model_name = model_name.lower()
- assert load_in_fp8 in (True, False, "block")
- if load_in_fp8 != False:
- if load_in_fp8 == True and (os.environ.get("UNSLOTH_HAS_FBGEMM", "0") == "1"):
- if lower_model_name in FLOAT_TO_FP8_ROW_MAPPER:
- # Faster row scaling only works if FBGEMM works!
- return FLOAT_TO_FP8_ROW_MAPPER[lower_model_name]
- elif lower_model_name in FLOAT_TO_FP8_BLOCK_MAPPER:
- # Otherwise we use the slower blockwise type
- return FLOAT_TO_FP8_BLOCK_MAPPER[lower_model_name]
- else:
- if lower_model_name in FLOAT_TO_FP8_BLOCK_MAPPER:
- return FLOAT_TO_FP8_BLOCK_MAPPER[lower_model_name]
- # Mapper didn't find a pre-quantized model.
- # For vllm >= 0.12.0, we can quantize the model to FP8 on the fly,
- # so just return the original model name. Older vllm versions will
- # fall through to offline quantization via _offline_quantize_to_fp8.
- if importlib.util.find_spec("vllm") is not None:
- import vllm
-
- if Version(vllm.__version__) >= Version("0.12.0"):
- return model_name
- return None
+ if not SUPPORTS_FOURBIT and lower_model_name in INT_TO_FLOAT_MAPPER:
- elif not SUPPORTS_FOURBIT and lower_model_name in INT_TO_FLOAT_MAPPER:
model_name = INT_TO_FLOAT_MAPPER[lower_model_name]
print(
- f"Unsloth: Your transformers version of {transformers_version} does not support native "
- f"4bit loading.\nThe minimum required version is 4.37.\n"
- f'Try `pip install --upgrade "transformers>=4.37"`\n'
- f"to obtain the latest transformers build, then restart this session.\n"
+ f"Unsloth: Your transformers version of {transformers_version} does not support native "\
+ f"4bit loading.\nThe minimum required version is 4.37.\n"\
+ f'Try `pip install --upgrade "transformers>=4.37"`\n'\
+ f"to obtain the latest transformers build, then restart this session.\n"\
f"For now, we shall load `{model_name}` instead (still 4bit, just slower downloading)."
)
return model_name
-
+
elif not load_in_4bit and lower_model_name in INT_TO_FLOAT_MAPPER:
+
new_model_name = INT_TO_FLOAT_MAPPER[lower_model_name]
# logger.warning_once(
# f"Unsloth: You passed in `{model_name}` which is a 4bit model, yet you set\n"\
@@ -155,279 +52,74 @@ def __get_model_name(
return new_model_name
elif not load_in_4bit and lower_model_name in MAP_TO_UNSLOTH_16bit:
+
new_model_name = MAP_TO_UNSLOTH_16bit[lower_model_name]
return new_model_name
elif load_in_4bit and SUPPORTS_FOURBIT and lower_model_name in FLOAT_TO_INT_MAPPER:
+
# Support returning original full -bnb-4bit name if specified specifically
# since we'll map it to the dynamic version instead
if lower_model_name.endswith("-bnb-4bit"):
return lower_model_name
-
+
new_model_name = FLOAT_TO_INT_MAPPER[lower_model_name]
# logger.warning_once(
# f"Unsloth: You passed in `{model_name}` and `load_in_4bit = True`.\n"\
# f"We shall load `{new_model_name}` for 4x faster loading."
# )
return new_model_name
+ pass
return None
+pass
def _get_new_mapper():
try:
import requests
-
new_mapper = "https://raw.githubusercontent.com/unslothai/unsloth/main/unsloth/models/mapper.py"
- with requests.get(new_mapper, timeout = 3) as new_mapper:
- new_mapper = new_mapper.text
- new_mapper = new_mapper[new_mapper.find("__INT_TO_FLOAT_MAPPER") :]
- new_mapper = (
- new_mapper.replace("INT_TO_FLOAT_MAPPER", "NEW_INT_TO_FLOAT_MAPPER")
- .replace("FLOAT_TO_INT_MAPPER", "NEW_FLOAT_TO_INT_MAPPER")
+ with requests.get(new_mapper, timeout = 3) as new_mapper: new_mapper = new_mapper.text
+ new_mapper = new_mapper[new_mapper.find("__INT_TO_FLOAT_MAPPER"):]
+ new_mapper = new_mapper\
+ .replace("INT_TO_FLOAT_MAPPER", "NEW_INT_TO_FLOAT_MAPPER")\
+ .replace("FLOAT_TO_INT_MAPPER", "NEW_FLOAT_TO_INT_MAPPER")\
.replace("MAP_TO_UNSLOTH_16bit", "NEW_MAP_TO_UNSLOTH_16bit")
- )
exec(new_mapper, globals())
- return (
- NEW_INT_TO_FLOAT_MAPPER,
- NEW_FLOAT_TO_INT_MAPPER,
- NEW_MAP_TO_UNSLOTH_16bit,
- )
+ return NEW_INT_TO_FLOAT_MAPPER, NEW_FLOAT_TO_INT_MAPPER, NEW_MAP_TO_UNSLOTH_16bit
except:
return {}, {}, {}
+ pass
+pass
-def _resolve_with_mappers(
- model_name,
- load_in_4bit,
- load_in_fp8,
- int_to_float,
- float_to_int,
- map_to_unsloth_16bit,
-):
- return __get_model_name(
+def get_model_name(model_name, load_in_4bit = True):
+ new_model_name = __get_model_name(
model_name = model_name,
load_in_4bit = load_in_4bit,
- INT_TO_FLOAT_MAPPER = int_to_float,
- FLOAT_TO_INT_MAPPER = float_to_int,
- MAP_TO_UNSLOTH_16bit = map_to_unsloth_16bit,
- load_in_fp8 = load_in_fp8,
- FLOAT_TO_FP8_BLOCK_MAPPER = FLOAT_TO_FP8_BLOCK_MAPPER,
- FLOAT_TO_FP8_ROW_MAPPER = FLOAT_TO_FP8_ROW_MAPPER,
+ INT_TO_FLOAT_MAPPER = INT_TO_FLOAT_MAPPER,
+ FLOAT_TO_INT_MAPPER = FLOAT_TO_INT_MAPPER,
+ MAP_TO_UNSLOTH_16bit = MAP_TO_UNSLOTH_16bit,
)
-
-
-def get_model_name(
- model_name,
- load_in_4bit = True,
- load_in_fp8 = False,
- token = None,
- trust_remote_code = False,
-):
- assert load_in_fp8 in (True, False, "block")
- new_model_name = _resolve_with_mappers(
- model_name = model_name,
- load_in_4bit = load_in_4bit,
- load_in_fp8 = load_in_fp8,
- int_to_float = INT_TO_FLOAT_MAPPER,
- float_to_int = FLOAT_TO_INT_MAPPER,
- map_to_unsloth_16bit = MAP_TO_UNSLOTH_16bit,
- )
- # In the rare case, we convert bad model names to other names
- # For eg too large dynamic quants or MoEs
- if (
- new_model_name is not None
- and type(new_model_name) is str
- and new_model_name.lower() in BAD_MAPPINGS
- ):
- new_model_name = BAD_MAPPINGS[new_model_name.lower()]
-
- if (
- new_model_name is None
- and model_name.count("/") == 1
- and model_name[0].isalnum()
- ):
+ if new_model_name is None and model_name.count("/") == 1 and model_name[0].isalnum():
# Try checking if a new Unsloth version allows it!
- NEW_INT_TO_FLOAT_MAPPER, NEW_FLOAT_TO_INT_MAPPER, NEW_MAP_TO_UNSLOTH_16bit = (
- _get_new_mapper()
- )
- upgraded_model_name = _resolve_with_mappers(
+ NEW_INT_TO_FLOAT_MAPPER, NEW_FLOAT_TO_INT_MAPPER, NEW_MAP_TO_UNSLOTH_16bit = _get_new_mapper()
+ upgraded_model_name = __get_model_name(
model_name = model_name,
load_in_4bit = load_in_4bit,
- load_in_fp8 = load_in_fp8,
- int_to_float = NEW_INT_TO_FLOAT_MAPPER,
- float_to_int = NEW_FLOAT_TO_INT_MAPPER,
- map_to_unsloth_16bit = NEW_MAP_TO_UNSLOTH_16bit,
+ INT_TO_FLOAT_MAPPER = NEW_INT_TO_FLOAT_MAPPER,
+ FLOAT_TO_INT_MAPPER = NEW_FLOAT_TO_INT_MAPPER,
+ MAP_TO_UNSLOTH_16bit = NEW_MAP_TO_UNSLOTH_16bit,
)
if upgraded_model_name is not None:
raise NotImplementedError(
- f"Unsloth: {model_name} is not supported in your current Unsloth version! Please update Unsloth via:\n\n"
- "pip uninstall unsloth unsloth_zoo -y\n"
- 'pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
- 'pip install --upgrade --no-cache-dir "git+https://github.com/unslothai/unsloth-zoo.git"\n'
+ f"Unsloth: {model_name} is not supported in your current Unsloth version! Please update Unsloth via:\n\n"\
+ 'pip uninstall unsloth unsloth_zoo -y\n'\
+ 'pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'\
+ 'pip install --upgrade --no-cache-dir "git+https://github.com/unslothai/unsloth-zoo.git"\n'\
)
-
- if new_model_name is None:
- new_model_name = model_name
-
- return new_model_name
-
-
-def _offline_quantize_to_fp8(model_name: str, fp8_mode: str) -> str:
- """
- Quantizes the model to fp8 using torchao and saving the quantized model to a
- temporary location. Return the path to the quantized model.
-
- Note: For vllm >= 0.12.0, we should dynamically quantize the model in vllm instead:
-
- llm = LLM(
- ...
- hf_overrides={"quantization_config_file": "torchao_config.json"},
- )
- """
- temp_dir = tempfile.gettempdir()
- new_model_name = model_name.split("/")[-1] + "-fp8-" + fp8_mode
- new_model_name = os.path.join(temp_dir, new_model_name)
- print(
- f"Unsloth: Quantizing '{model_name}' to fp8, using model_name='{new_model_name}' instead"
- )
-
- if not os.path.isdir(new_model_name):
- from transformers import (
- AutoModelForCausalLM,
- AutoModelForImageTextToText,
- AutoTokenizer,
- AutoProcessor,
- TorchAoConfig,
- AutoConfig,
- )
-
- qconfig = _get_torchao_fp8_config(fp8_mode)
- qconfig = TorchAoConfig(qconfig)
- config = AutoConfig.from_pretrained(model_name)
- is_vlm = any(
- x.endswith(("ForConditionalGeneration", "ForVisionText2Text"))
- for x in config.architectures
- )
- is_vlm = is_vlm or hasattr(config, "vision_config")
- auto_model = AutoModelForImageTextToText if is_vlm else AutoModelForCausalLM
- auto_processor = AutoProcessor if is_vlm else AutoTokenizer
- model = auto_model.from_pretrained(
- model_name,
- torch_dtype = "auto",
- device_map = "auto",
- quantization_config = qconfig,
- )
- tokenizer = auto_processor.from_pretrained(model_name)
- model.save_pretrained(new_model_name, safe_serialization = False)
- del model
- for _ in range(2):
- torch.cuda.empty_cache()
- gc.collect()
- tokenizer.save_pretrained(new_model_name)
- return new_model_name
-
-
-def _tag_model_with_fp8_torchao_config(model: torch.nn.Module, fp8_mode: str):
- """
- Tag a model with a `TorchAOConfig` so downstream callers will know what to do with it.
- """
- try:
- base_config = _get_torchao_fp8_config(fp8_mode)
- model.torchao_config = TorchAOConfig(
- qat_scheme = None,
- base_config_and_filter_fns = [(base_config, None)],
- )
- except:
pass
-
-
-def _get_fp8_mode_and_check_settings(
- load_in_fp8: Union[bool, str],
- fast_inference: bool,
- full_finetuning: bool = False,
- load_in_4bit: bool = False,
- load_in_8bit: bool = False,
- load_in_16bit: bool = False,
-) -> str:
- """
- Assuming `load_in_fp8` is enabled, raise appropriate errors on incompatible settings
- and environment. Currently this feature requires:
-
- 1. H100 GPUs or after
- 2. torchao 0.15.0+ (or nightly)
- 3. torch 2.9.0+
- 4. If fbgemm_gpu_genai is installed, require 1.4.1+
-
- Returns the fp8 mode, one of "row" or "block".
- """
- assert load_in_fp8 is not False
- if load_in_fp8 is True:
- fp8_mode = "row" # default
- else:
- fp8_mode = load_in_fp8
-
- # Check user settings
- if fp8_mode not in ["row", "block"]:
- raise ValueError(
- f"Unsloth: `load_in_fp8` can only be 'row' or 'block', got '{fp8_mode}'"
- )
- if full_finetuning:
- raise ValueError(
- "Unsloth: `load_in_fp8` is not compatible with full finetuning"
- )
- if load_in_4bit or load_in_8bit or load_in_16bit:
- raise ValueError(
- "Unsloth: `load_in_fp8` is not compatible with `load_in_4bit`, `load_in_8bit` or `load_in_16bit`",
- )
-
- # Check if this is Hopper or above
- if not (
- torch.cuda.is_available()
- and torch.version.cuda
- and torch.cuda.get_device_capability() >= (9, 0)
- ):
- raise ValueError(
- "Unsloth: On the fly `load_in_fp8` requires H100 GPUs or after. Try `unsloth/Qwen3-8B` instead."
- )
-
- # Check if torch >= 2.9.0
- if Version(torch.__version__) < Version("2.9.0"):
- raise ValueError(
- "Unsloth: On the fly `load_in_fp8` requires torch 2.9.0+. Try `unsloth/Qwen3-8B` instead."
- )
-
- # Check if torchao has this PR: https://github.com/pytorch/ao/pull/3158,
- # which will be released in 0.15.0.
- if importlib.util.find_spec("torchao") is None:
- raise ValueError(
- "Unsloth: Please install torchao for on the fly float8 to work! Try `unsloth/Qwen3-8B` instead."
- )
- import torchao
-
- error_message = (
- "Unsloth: `load_in_fp8` requires torchao 0.15.0+ (or nightly).\n"
- f"You have torchao version={torchao.__version__}\n"
- "Use `pip install --upgrade --force-reinstall torchao`"
- )
- if Version(torchao.__version__) < Version("0.15.0"):
- raise ValueError(error_message)
-
- # If fbgemm_gpu_genai is installed and old, disable FBGEMM and use Triton instead
- if (
- importlib.util.find_spec("fbgemm_gpu") is not None
- and importlib.util.find_spec("fbgemm_gpu.experimental") is not None
- ):
- import fbgemm_gpu.experimental.gen_ai
-
- if Version(fbgemm_gpu.__version__) < Version("1.4.1"):
- # Old FBGEMM version - disable and use Triton kernels instead
- os.environ["UNSLOTH_HAS_FBGEMM"] = "0"
- from unsloth_zoo.log import logger
-
- logger.info(
- f"Unsloth: fbgemm_gpu_genai=={fbgemm_gpu.__version__} is old for FP8 loading. "
- f"Using Triton kernels instead."
- )
- return fp8_mode
+ pass
+ return new_model_name if new_model_name is not None else model_name
+pass
diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py
index f0f430eb7e..9af5317986 100644
--- a/unsloth/models/mapper.py
+++ b/unsloth/models/mapper.py
@@ -15,9 +15,6 @@
__all__ = [
"INT_TO_FLOAT_MAPPER",
"FLOAT_TO_INT_MAPPER",
- "MAP_TO_UNSLOTH_16bit",
- "FLOAT_TO_FP8_BLOCK_MAPPER",
- "FLOAT_TO_FP8_ROW_MAPPER",
]
__INT_TO_FLOAT_MAPPER = \
@@ -117,7 +114,7 @@
"unsloth/gemma-1.1-7b-it",
"google/gemma-1.1-7b-it",
),
- "unsloth/Starling-LM-7B-beta" : (
+ "unsloth/Starling-LM-7B-beta-bnb-4bit" : (
"unsloth/Starling-LM-7B-beta",
"Nexusflow/Starling-LM-7B-beta",
),
@@ -236,35 +233,21 @@
"meta-llama/Meta-Llama-3.1-8B",
"unsloth/Meta-Llama-3.1-8B-bnb-4bit",
),
- "unsloth/Meta-Llama-3.1-8B-Instruct-unsloth-bnb-4bit" : {
- "8" : (
- "RedHatAI/Llama-3.1-8B-Instruct-FP8",
- "unsloth/Llama-3.1-8B-Instruct-FP8-Block",
- "unsloth/Llama-3.1-8B-Instruct-FP8-Dynamic",
- ),
- "16" : (
- "unsloth/Meta-Llama-3.1-8B-Instruct",
- "meta-llama/Meta-Llama-3.1-8B-Instruct",
- "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
- ),
- },
+ "unsloth/Meta-Llama-3.1-8B-Instruct-unsloth-bnb-4bit" : (
+ "unsloth/Meta-Llama-3.1-8B-Instruct",
+ "meta-llama/Meta-Llama-3.1-8B-Instruct",
+ "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
+ ),
"unsloth/Llama-3.1-8B-unsloth-bnb-4bit" : (
"unsloth/Llama-3.1-8B",
"meta-llama/Llama-3.1-8B",
"unsloth/Llama-3.1-8B-bnb-4bit",
),
- "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit" : {
- "8" : (
- "RedHatAI/Llama-3.1-8B-Instruct-FP8",
- "unsloth/Llama-3.1-8B-Instruct-FP8-Block",
- "unsloth/Llama-3.1-8B-Instruct-FP8-Dynamic",
- ),
- "16" : (
- "unsloth/Llama-3.1-8B-Instruct",
- "meta-llama/Llama-3.1-8B-Instruct",
- "unsloth/Llama-3.1-8B-Instruct-bnb-4bit",
- ),
- },
+ "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit" : (
+ "unsloth/Llama-3.1-8B-Instruct",
+ "meta-llama/Llama-3.1-8B-Instruct",
+ "unsloth/Llama-3.1-8B-Instruct-bnb-4bit",
+ ),
"unsloth/Meta-Llama-3.1-70B-bnb-4bit" : (
"unsloth/Meta-Llama-3.1-70B",
"meta-llama/Meta-Llama-3.1-70B",
@@ -491,30 +474,16 @@
"meta-llama/Llama-3.2-3B",
"unsloth/Llama-3.2-3B-bnb-4bit",
),
- "unsloth/Llama-3.2-1B-Instruct-unsloth-bnb-4bit" : {
- "8": (
- "RedHatAI/Llama-3.2-1B-Instruct-FP8",
- "unsloth/Llama-3.2-1B-Instruct-FP8-Block",
- "unsloth/Llama-3.2-1B-Instruct-FP8-Dynamic",
- ),
- "16" : (
- "unsloth/Llama-3.2-1B-Instruct",
- "meta-llama/Llama-3.2-1B-Instruct",
- "unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
- ),
- },
- "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit" : {
- "8": (
- "RedHatAI/Llama-3.2-3B-Instruct-FP8",
- "unsloth/Llama-3.2-3B-Instruct-FP8-Block",
- "unsloth/Llama-3.2-3B-Instruct-FP8-Dynamic",
- ),
- "16" : (
- "unsloth/Llama-3.2-3B-Instruct",
- "meta-llama/Llama-3.2-3B-Instruct",
- "unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
- ),
- },
+ "unsloth/Llama-3.2-1B-Instruct-unsloth-bnb-4bit" : (
+ "unsloth/Llama-3.2-1B-Instruct",
+ "meta-llama/Llama-3.2-1B-Instruct",
+ "unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
+ ),
+ "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit" : (
+ "unsloth/Llama-3.2-3B-Instruct",
+ "meta-llama/Llama-3.2-3B-Instruct",
+ "unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
+ ),
"unsloth/Llama-3.1-Nemotron-70B-Instruct-bnb-4bit" : (
"unsloth/Llama-3.1-Nemotron-70B-Instruct",
"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
@@ -592,18 +561,10 @@
"unsloth/QwQ-32B-Preview",
"Qwen/QwQ-32B-Preview",
),
- "unsloth/Llama-3.3-70B-Instruct-unsloth-bnb-4bit" : {
- "8" : (
- "RedHatAI/Llama-3.3-70B-Instruct-FP8",
- "unsloth/Llama-3.3-70B-Instruct-FP8-Block",
- "unsloth/Llama-3.3-70B-Instruct-FP8-Dynamic",
- ),
- "16" : (
- "unsloth/Llama-3.3-70B-Instruct",
- "meta-llama/Llama-3.3-70B-Instruct",
- "unsloth/Llama-3.3-70B-Instruct-bnb-4bit",
- ),
- },
+ "unsloth/Llama-3.3-70B-Instruct-bnb-4bit" : (
+ "unsloth/Llama-3.3-70B-Instruct",
+ "meta-llama/Llama-3.3-70B-Instruct",
+ ),
"unsloth/phi-4-unsloth-bnb-4bit" : (
"unsloth/phi-4",
"microsoft/phi-4",
@@ -657,11 +618,6 @@
"Qwen/Qwen2.5-VL-7B-Instruct",
"unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit",
),
- "unsloth/Qwen2.5-VL-32B-Instruct-unsloth-bnb-4bit" : (
- "unsloth/Qwen2.5-VL-32B-Instruct",
- "Qwen/Qwen2.5-VL-32B-Instruct",
- "unsloth/Qwen2.5-VL-32B-Instruct-bnb-4bit",
- ),
"unsloth/Qwen2.5-VL-72B-Instruct-unsloth-bnb-4bit" : (
"unsloth/Qwen2.5-VL-72B-Instruct",
"Qwen/Qwen2.5-VL-72B-Instruct",
@@ -762,653 +718,34 @@
"allenai/OLMo-2-0325-32B-Instruct",
"unsloth/OLMo-2-0325-32B-Instruct-bnb-4bit",
),
- "unsloth/Mistral-Small-3.1-24B-Instruct-2503-unsloth-bnb-4bit" : (
- "unsloth/Mistral-Small-3.1-24B-Instruct-2503",
- "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
- "unsloth/Mistral-Small-3.1-24B-Instruct-2503-bnb-4bit",
- ),
- "unsloth/Mistral-Small-3.1-24B-Base-2503-unsloth-bnb-4bit" : (
- "unsloth/Mistral-Small-3.1-24B-Base-2503",
- "mistralai/Mistral-Small-3.1-24B-Base-2503",
- "unsloth/Mistral-Small-3.1-24B-Base-2503-bnb-4bit",
- ),
- "unsloth/Qwen3-0.6B-unsloth-bnb-4bit" : {
- "8" : (
- "Qwen/Qwen3-0.6B-FP8",
- "unsloth/Qwen3-0.6B-FP8",
- "unsloth/Qwen3-0.6B-FP8",
- ),
- "16" : (
- "unsloth/Qwen3-0.6B",
- "Qwen/Qwen3-0.6B",
- "unsloth/Qwen3-0.6B-bnb-4bit",
- ),
- },
- "unsloth/Qwen3-1.7B-unsloth-bnb-4bit" : {
- "8" : (
- "Qwen/Qwen3-1.7B-FP8",
- "unsloth/Qwen3-1.7B-FP8",
- "unsloth/Qwen3-1.7B-FP8",
- ),
- "16" : (
- "unsloth/Qwen3-1.7B",
- "Qwen/Qwen3-1.7B",
- "unsloth/Qwen3-1.7B-bnb-4bit",
- ),
- },
- "unsloth/Qwen3-4B-unsloth-bnb-4bit" : {
- "8" : (
- "Qwen/Qwen3-4B-FP8",
- "unsloth/Qwen3-4B-FP8",
- "unsloth/Qwen3-4B-FP8",
- ),
- "16" : (
- "unsloth/Qwen3-4B",
- "Qwen/Qwen3-4B",
- "unsloth/Qwen3-4B-bnb-4bit",
- ),
- },
- "unsloth/Qwen3-8B-unsloth-bnb-4bit" : {
- "8" : (
- "Qwen/Qwen3-8B-FP8",
- "unsloth/Qwen3-8B-FP8",
- "unsloth/Qwen3-8B-FP8",
- ),
- "16" : (
- "unsloth/Qwen3-8B",
- "Qwen/Qwen3-8B",
- "unsloth/Qwen3-8B-bnb-4bit",
- ),
- },
- "unsloth/Qwen3-14B-unsloth-bnb-4bit" : {
- "8" : (
- "Qwen/Qwen3-14B-FP8",
- "unsloth/Qwen3-14B-FP8",
- "unsloth/Qwen3-14B-FP8",
- ),
- "16" : (
- "unsloth/Qwen3-14B",
- "Qwen/Qwen3-14B",
- "unsloth/Qwen3-14B-bnb-4bit",
- ),
- },
- "unsloth/Qwen3-32B-unsloth-bnb-4bit" : {
- "8" : (
- "Qwen/Qwen3-32B-FP8",
- "unsloth/Qwen3-32B-FP8",
- "unsloth/Qwen3-32B-FP8",
- ),
- "16" : (
- "unsloth/Qwen3-32B",
- "Qwen/Qwen3-32B",
- "unsloth/Qwen3-32B-bnb-4bit",
- ),
- },
- "unsloth/Qwen3-30B-A3B-unsloth-bnb-4bit" : (
- "unsloth/Qwen3-30B-A3B",
- "Qwen/Qwen3-30B-A3B",
- "unsloth/Qwen3-30B-A3B-bnb-4bit",
- ),
- "unsloth/Qwen3-0.6B-Base-unsloth-bnb-4bit" : (
- "unsloth/Qwen3-0.6B-Base",
- "Qwen/Qwen3-0.6B-Base",
- "unsloth/Qwen3-0.6B-Base-bnb-4bit",
- ),
- "unsloth/Qwen3-1.7B-Base-unsloth-bnb-4bit" : (
- "unsloth/Qwen3-1.7B-Base",
- "Qwen/Qwen3-1.7B-Base",
- "unsloth/Qwen3-1.7B-Base-bnb-4bit",
- ),
- "unsloth/Qwen3-4B-Base-unsloth-bnb-4bit" : (
- "unsloth/Qwen3-4B-Base",
- "Qwen/Qwen3-4B-Base",
- "unsloth/Qwen3-4B-Base-bnb-4bit",
- ),
- "unsloth/Qwen3-8B-Base-unsloth-bnb-4bit" : (
- "unsloth/Qwen3-8B-Base",
- "Qwen/Qwen3-8B-Base",
- "unsloth/Qwen3-8B-Base-bnb-4bit",
- ),
- "unsloth/Qwen3-14B-Base-unsloth-bnb-4bit" : (
- "unsloth/Qwen3-14B-Base",
- "Qwen/Qwen3-14B-Base",
- "unsloth/Qwen3-14B-Base-bnb-4bit",
- ),
- "unsloth/Qwen3-30B-A3B-Base-bnb-4bit" : (
- "unsloth/Qwen3-30B-A3B-Base",
- "Qwen/Qwen3-30B-A3B-Base",
- ),
- "unsloth/phi-4-reasoning-unsloth-bnb-4bit" : (
- "unsloth/phi-4-reasoning",
- "microsoft/Phi-4-reasoning",
- "unsloth/phi-4-reasoning-bnb-4bit",
- ),
- "unsloth/phi-4-reasoning-plus-unsloth-bnb-4bit" : (
- "unsloth/phi-4-reasoning-plus",
- "microsoft/Phi-4-reasoning-plus",
- "unsloth/phi-4-reasoning-plus-bnb-4bit",
- ),
- "unsloth/phi-4-mini-reasoning-unsloth-bnb-4bit" : (
- "unsloth/phi-4-mini-reasoning",
- "microsoft/Phi-4-mini-reasoning",
- "unsloth/phi-4-mini-reasoning-bnb-4bit",
- ),
- "unsloth/Phi-4-mini-instruct-unsloth-bnb-4bit" : (
- "unsloth/Phi-4-mini-instruct",
- "microsoft/Phi-4-mini-instruct",
- "unsloth/Phi-4-mini-instruct-bnb-4bit",
- ),
- "unsloth/orpheus-3b-0.1-pretrained-unsloth-bnb-4bit" : (
- "unsloth/orpheus-3b-0.1-pretrained",
- "canopylabs/orpheus-3b-0.1-pretrained",
- "unsloth/orpheus-3b-0.1-pretrained-bnb-4bit",
- ),
- "unsloth/orpheus-3b-0.1-ft-unsloth-bnb-4bit" : (
- "unsloth/orpheus-3b-0.1-ft",
- "canopylabs/orpheus-3b-0.1-ft",
- "unsloth/orpheus-3b-0.1-ft-bnb-4bit",
- ),
- "unsloth/csm-1b" : (
- "unsloth/csm-1b",
- "sesame/csm-1b",
- ),
- "unsloth/whisper-large-v3" : (
- "unsloth/whisper-large-v3",
- "openai/whisper-large-v3",
- ),
- "unsloth/whisper-large-v3-turbo" : (
- "unsloth/whisper-large-v3-turbo",
- "openai/whisper-large-v3-turbo",
- ),
- "unsloth/whisper-small" : (
- "unsloth/whisper-small",
- "openai/whisper-small",
- ),
- "unsloth/CrisperWhisper" : (
- "unsloth/CrisperWhisper",
- "nyrahealth/CrisperWhisper",
- ),
- "unsloth/Llasa-1B" : (
- "unsloth/Llasa-1B",
- "HKUSTAudio/Llasa-1B",
- ),
- "unsloth/Spark-TTS-0.5B" : (
- "unsloth/Spark-TTS-0.5B",
- "SparkAudio/Spark-TTS-0.5B",
- ),
- "unsloth/Llama-OuteTTS-1.0-1B" : (
- "unsloth/Llama-OuteTTS-1.0-1B",
- "OuteAI/Llama-OuteTTS-1.0-1B",
- ),
- "unsloth/medgemma-4b-it-unsloth-bnb-4bit" : (
- "unsloth/medgemma-4b-it",
- "google/medgemma-4b-it",
- "unsloth/medgemma-4b-it-bnb-4bit",
- ),
- "unsloth/medgemma-27b-text-it-unsloth-bnb-4bit" : (
- "unsloth/medgemma-27b-text-it",
- "google/medgemma-27b-text-it",
- "unsloth/medgemma-27b-text-it-bnb-4bit",
- ),
- "unsloth/Devstral-Small-2505-unsloth-bnb-4bit" : (
- "unsloth/Devstral-Small-2505",
- "mistralai/Devstral-Small-2505",
- "unsloth/Devstral-Small-2505-bnb-4bit",
- ),
- "unsloth/DeepSeek-R1-0528-Qwen3-8B-unsloth-bnb-4bit" : (
- "unsloth/DeepSeek-R1-0528-Qwen3-8B",
- "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
- "unsloth/DeepSeek-R1-0528-Qwen3-8B-bnb-4bit",
- ),
- "unsloth/Magistral-Small-2506-unsloth-bnb-4bit" : (
- "unsloth/Magistral-Small-2506",
- "mistralai/Magistral-Small-2506",
- "unsloth/Magistral-Small-2506-bnb-4bit",
- ),
- "unsloth/Mistral-Small-3.2-24B-Instruct-2506-unsloth-bnb-4bit" : {
- "8" : (
- "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
- "unsloth/Mistral-Small-3.2-24B-Instruct-2506-FP8",
- "unsloth/Mistral-Small-3.2-24B-Instruct-2506-FP8",
- ),
- "16" : (
- "unsloth/Mistral-Small-3.2-24B-Instruct-2506",
- "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
- "unsloth/Mistral-Small-3.2-24B-Instruct-2506-bnb-4bit",
- ),
- },
- "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit" : (
- "unsloth/gemma-3n-E4B-it",
- "google/gemma-3n-E4B-it",
- "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit",
- ),
- "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit" : (
- "unsloth/gemma-3n-E2B-it",
- "google/gemma-3n-E2B-it",
- "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit",
- ),
- "unsloth/gemma-3n-E4B-unsloth-bnb-4bit" : (
- "unsloth/gemma-3n-E4B",
- "google/gemma-3n-E4B",
- "unsloth/gemma-3n-E4B-unsloth-bnb-4bit",
- ),
- "unsloth/gemma-3n-E2B-unsloth-bnb-4bit" : (
- "unsloth/gemma-3n-E2B",
- "google/gemma-3n-E2B",
- "unsloth/gemma-3n-E2B-unsloth-bnb-4bit",
- ),
- "unsloth/Devstral-Small-2507-unsloth-bnb-4bit" : (
- "unsloth/Devstral-Small-2507",
- "mistralai/Devstral-Small-2507",
- "unsloth/Devstral-Small-2507-bnb-4bit",
- ),
- "unsloth/Qwen3-30B-A3B-Thinking-2507" : (
- "unsloth/Qwen3-30B-A3B-Thinking-2507",
- "Qwen/Qwen3-30B-A3B-Thinking-2507",
- ),
- "unsloth/Qwen3-30B-A3B-Instruct-2507" : (
- "unsloth/Qwen3-30B-A3B-Instruct-2507",
- "Qwen/Qwen3-30B-A3B-Instruct-2507",
- ),
- "unsloth/Qwen3-Coder-30B-A3B-Instruct" : (
- "unsloth/Qwen3-Coder-30B-A3B-Instruct",
- "Qwen/Qwen3-Coder-30B-A3B-Instruct",
- ),
- "unsloth/gpt-oss-20b-unsloth-bnb-4bit" : (
- "unsloth/gpt-oss-20b",
- "openai/gpt-oss-20b",
- "unsloth/gpt-oss-20b-unsloth-bnb-4bit",
- ),
- "unsloth/gpt-oss-120b-unsloth-bnb-4bit" : (
- "unsloth/gpt-oss-120b",
- "openai/gpt-oss-120b",
- "unsloth/gpt-oss-120b-unsloth-bnb-4bit",
- ),
- "unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit" : {
- "8" : (
- "Qwen/Qwen3-4B-Instruct-2507-FP8",
- "unsloth/Qwen3-4B-Instruct-2507-FP8",
- "unsloth/Qwen3-4B-Instruct-2507-FP8",
- ),
- "16" : (
- "unsloth/Qwen3-4B-Instruct-2507",
- "Qwen/Qwen3-4B-Instruct-2507",
- "unsloth/Qwen3-4B-Instruct-2507-bnb-4bit",
- ),
- },
- "unsloth/Qwen3-4B-Thinking-2507-unsloth-bnb-4bit" : {
- "8" : (
- "Qwen/Qwen3-4B-Thinking-2507-FP8",
- "unsloth/Qwen3-4B-Thinking-2507-FP8",
- "unsloth/Qwen3-4B-Thinking-2507-FP8",
- ),
- "16" : (
- "unsloth/Qwen3-4B-Thinking-2507",
- "Qwen/Qwen3-4B-Thinking-2507",
- "unsloth/Qwen3-4B-Thinking-2507-bnb-4bit",
- ),
- },
- "unsloth/gemma-3-270m-it-unsloth-bnb-4bit" : (
- "unsloth/gemma-3-270m-it",
- "google/gemma-3-270m-it",
- "unsloth/gemma-3-270m-it-bnb-4bit",
- ),
- "unsloth/gemma-3-270m-unsloth-bnb-4bit" : (
- "unsloth/gemma-3-270m",
- "google/gemma-3-270m",
- "unsloth/gemma-3-270m-bnb-4bit",
- ),
- "unsloth/Magistral-Small-2507-unsloth-bnb-4bit" : (
- "unsloth/Magistral-Small-2507",
- "mistralai/Magistral-Small-2507",
- "unsloth/Magistral-Small-2507-bnb-4bit",
- ),
- "unsloth/Magistral-Small-2509-unsloth-bnb-4bit" : {
- "8" : (
- "mistralai/Magistral-Small-2509",
- "unsloth/Magistral-Small-2509-FP8-Dynamic",
- "unsloth/Magistral-Small-2509-FP8-Dynamic",
- ),
- "16" : (
- "unsloth/Magistral-Small-2509",
- "mistralai/Magistral-Small-2509",
- "unsloth/Magistral-Small-2509-bnb-4bit",
- ),
- },
- "unsloth/Apertus-70B-Instruct-2509-unsloth-bnb-4bit" : (
- "unsloth/Apertus-70B-Instruct-2509",
- "swiss-ai/Apertus-70B-2509",
- "unsloth/Apertus-70B-Instruct-2509-unsloth-bnb-4bit",
- ),
- "unsloth/Apertus-8B-Instruct-2509-unsloth-bnb-4bit" : (
- "unsloth/Apertus-8B-Instruct-2509",
- "swiss-ai/Apertus-8B-2509",
- "unsloth/Apertus-8B-Instruct-2509-unsloth-bnb-4bit",
- ),
- "unsloth/granite-4.0-micro-unsloth-bnb-4bit" : (
- "unsloth/granite-4.0-micro",
- "ibm-granite/granite-4.0-micro",
- "unsloth/granite-4.0-micro-bnb-4bit",
- ),
- "unsloth/granite-4.0-h-micro-unsloth-bnb-4bit" : (
- "unsloth/granite-4.0-h-micro",
- "ibm-granite/granite-4.0-h-micro",
- "unsloth/granite-4.0-h-micro-bnb-4bit",
- ),
- "unsloth/granite-4.0-micro-base-unsloth-bnb-4bit" : (
- "unsloth/granite-4.0-micro-base",
- "ibm-granite/granite-4.0-micro-base",
- "unsloth/granite-4.0-micro-base-bnb-4bit",
- ),
- "unsloth/granite-4.0-h-micro-base-unsloth-bnb-4bit" : (
- "unsloth/granite-4.0-h-micro-base",
- "ibm-granite/granite-4.0-h-micro-base",
- "unsloth/granite-4.0-h-micro-base-bnb-4bit",
- ),
- "unsloth/granite-4.0-h-tiny" : (
- "unsloth/granite-4.0-h-tiny",
- "ibm-granite/granite-4.0-h-tiny",
- ),
- "unsloth/granite-4.0-h-small" : (
- "unsloth/granite-4.0-h-small",
- "ibm-granite/granite-4.0-h-small",
- ),
- "unsloth/granite-4.0-h-tiny-base" : (
- "unsloth/granite-4.0-h-tiny-base",
- "ibm-granite/granite-4.0-h-tiny-base",
- ),
- "unsloth/granite-4.0-h-small-base" : (
- "unsloth/granite-4.0-h-small-base",
- "ibm-granite/granite-4.0-h-small-base",
- ),
- "unsloth/Qwen3-VL-4B-Thinking-unsloth-bnb-4bit" : {
- "8" : (
- "Qwen/Qwen3-VL-4B-Thinking-FP8",
- "unsloth/Qwen3-VL-4B-Thinking-FP8",
- "unsloth/Qwen3-VL-4B-Thinking-FP8",
- ),
- "16" : (
- "unsloth/Qwen3-VL-4B-Thinking",
- "Qwen/Qwen3-VL-4B-Thinking",
- "unsloth/Qwen3-VL-4B-Thinking-bnb-4bit",
- ),
- },
- "unsloth/Qwen3-VL-8B-Thinking-unsloth-bnb-4bit" : {
- "8" : (
- "Qwen/Qwen3-VL-8B-Thinking-FP8",
- "unsloth/Qwen3-VL-8B-Thinking-FP8",
- "unsloth/Qwen3-VL-8B-Thinking-FP8",
- ),
- "16" : (
- "unsloth/Qwen3-VL-8B-Thinking",
- "Qwen/Qwen3-VL-8B-Thinking",
- "unsloth/Qwen3-VL-8B-Thinking-bnb-4bit",
- ),
- },
- "unsloth/Qwen3-VL-4B-Instruct-unsloth-bnb-4bit" : {
- "8" : (
- "Qwen/Qwen3-VL-4B-Instruct-FP8",
- "unsloth/Qwen3-VL-4B-Instruct-FP8",
- "unsloth/Qwen3-VL-4B-Instruct-FP8",
- ),
- "16" : (
- "unsloth/Qwen3-VL-4B-Instruct",
- "Qwen/Qwen3-VL-4B-Instruct",
- "unsloth/Qwen3-VL-4B-Instruct-bnb-4bit",
- ),
- },
- "unsloth/Qwen3-VL-8B-Instruct-unsloth-bnb-4bit" : {
- "8" : (
- "Qwen/Qwen3-VL-8B-Instruct-FP8",
- "unsloth/Qwen3-VL-8B-Instruct-FP8",
- "unsloth/Qwen3-VL-8B-Instruct-FP8",
- ),
- "16" : (
- "unsloth/Qwen3-VL-8B-Instruct",
- "Qwen/Qwen3-VL-8B-Instruct",
- "unsloth/Qwen3-VL-8B-Instruct-bnb-4bit",
- ),
- },
- "unsloth/Qwen3-VL-2B-Thinking-unsloth-bnb-4bit" : {
- "8" : (
- "Qwen/Qwen3-VL-2B-Thinking-FP8",
- "unsloth/Qwen3-VL-2B-Thinking-FP8",
- "unsloth/Qwen3-VL-2B-Thinking-FP8",
- ),
- "16" : (
- "unsloth/Qwen3-VL-2B-Thinking",
- "Qwen/Qwen3-VL-2B-Thinking",
- "unsloth/Qwen3-VL-2B-Thinking-bnb-4bit",
- ),
- },
- "unsloth/Qwen3-VL-32B-Thinking-unsloth-bnb-4bit" : {
- "8" : (
- "Qwen/Qwen3-VL-32B-Thinking-FP8",
- "unsloth/Qwen3-VL-32B-Thinking-FP8",
- "unsloth/Qwen3-VL-32B-Thinking-FP8",
- ),
- "16" : (
- "unsloth/Qwen3-VL-32B-Thinking",
- "Qwen/Qwen3-VL-32B-Thinking",
- "unsloth/Qwen3-VL-32B-Thinking-bnb-4bit",
- ),
- },
- "unsloth/Qwen3-VL-2B-Instruct-unsloth-bnb-4bit" : {
- "8" : (
- "Qwen/Qwen3-VL-2B-Instruct-FP8",
- "unsloth/Qwen3-VL-2B-Instruct-FP8",
- "unsloth/Qwen3-VL-2B-Instruct-FP8",
- ),
- "16" : (
- "unsloth/Qwen3-VL-2B-Instruct",
- "Qwen/Qwen3-VL-2B-Instruct",
- "unsloth/Qwen3-VL-2B-Instruct-bnb-4bit",
- ),
- },
- "unsloth/Qwen3-VL-32B-Instruct-unsloth-bnb-4bit" : {
- "8" : (
- "Qwen/Qwen3-VL-32B-Instruct-FP8",
- "unsloth/Qwen3-VL-32B-Instruct-FP8",
- "unsloth/Qwen3-VL-32B-Instruct-FP8",
- ),
- "16" : (
- "unsloth/Qwen3-VL-32B-Instruct",
- "Qwen/Qwen3-VL-32B-Instruct",
- "unsloth/Qwen3-VL-32B-Instruct-bnb-4bit",
- ),
- },
- "unsloth/granite-4.0-350m-base-unsloth-bnb-4bit" : (
- "unsloth/granite-4.0-350m-base",
- "ibm-granite/granite-4.0-350m-base",
- "unsloth/granite-4.0-350m-base-bnb-4bit",
- ),
- "unsloth/granite-4.0-350m-unsloth-bnb-4bit" : (
- "unsloth/granite-4.0-350m",
- "ibm-granite/granite-4.0-350m",
- "unsloth/granite-4.0-350m-bnb-4bit",
- ),
- "unsloth/granite-4.0-h-350m-base-unsloth-bnb-4bit" : (
- "unsloth/granite-4.0-h-350m-base",
- "ibm-granite/granite-4.0-h-350m-base",
- "unsloth/granite-4.0-h-350m-base-bnb-4bit",
- ),
- "unsloth/granite-4.0-h-350m-unsloth-bnb-4bit" : (
- "unsloth/granite-4.0-h-350m",
- "ibm-granite/granite-4.0-h-350m",
- "unsloth/granite-4.0-h-350m-bnb-4bit",
- ),
- "unsloth/granite-4.0-1b-base-unsloth-bnb-4bit" : (
- "unsloth/granite-4.0-1b-base",
- "ibm-granite/granite-4.0-1b-base",
- "unsloth/granite-4.0-1b-base-bnb-4bit",
- ),
- "unsloth/granite-4.0-1b-unsloth-bnb-4bit" : (
- "unsloth/granite-4.0-1b",
- "ibm-granite/granite-4.0-1b",
- "unsloth/granite-4.0-1b-bnb-4bit",
- ),
- "unsloth/granite-4.0-h-1b-base-unsloth-bnb-4bit" : (
- "unsloth/granite-4.0-h-1b-base",
- "ibm-granite/granite-4.0-h-1b-base",
- "unsloth/granite-4.0-h-1b-base-bnb-4bit",
- ),
- "unsloth/granite-4.0-h-1b-unsloth-bnb-4bit" : (
- "unsloth/granite-4.0-h-1b",
- "ibm-granite/granite-4.0-h-1b",
- "unsloth/granite-4.0-h-1b-bnb-4bit",
- ),
- "unsloth/gpt-oss-safeguard-20b" : (
- "unsloth/gpt-oss-safeguard-20b",
- "openai/gpt-oss-safeguard-20b",
- ),
- "unsloth/gpt-oss-safeguard-120b" : (
- "unsloth/gpt-oss-safeguard-120b",
- "openai/gpt-oss-safeguard-120b",
- ),
- "unsloth/functiongemma-270m-it-unsloth-bnb-4bit" : (
- "unsloth/functiongemma-270m-it",
- "google/functiongemma-270m-it",
- "unsloth/functiongemma-270m-it-unsloth-bnb-4bit",
- ),
- # Ministral 3 models
- "unsloth/Ministral-3-3B-Instruct-2512-unsloth-bnb-4bit" : {
- "8" : (
- "mistralai/Ministral-3-3B-Instruct-2512",
- "unsloth/Ministral-3-3B-Instruct-2512-FP8",
- "unsloth/Ministral-3-3B-Instruct-2512-FP8",
- ),
- "16" : (
- "unsloth/Ministral-3-3B-Instruct-2512",
- "mistralai/Ministral-3-3B-Instruct-2512",
- "unsloth/Ministral-3-3B-Instruct-2512-bnb-4bit",
- ),
- },
- "unsloth/Ministral-3-3B-Base-2512-unsloth-bnb-4bit" : (
- "unsloth/Ministral-3-3B-Base-2512",
- "mistralai/Ministral-3-3B-Base-2512",
- "unsloth/Ministral-3-3B-Base-2512-bnb-4bit",
- ),
- "unsloth/Ministral-3-3B-Reasoning-2512-unsloth-bnb-4bit" : (
- "unsloth/Ministral-3-3B-Reasoning-2512",
- "mistralai/Ministral-3-3B-Reasoning-2512",
- "unsloth/Ministral-3-3B-Reasoning-2512-bnb-4bit",
- ),
- "unsloth/Ministral-3-8B-Instruct-2512-unsloth-bnb-4bit" : {
- "8" : (
- "mistralai/Ministral-3-8B-Instruct-2512",
- "unsloth/Ministral-3-8B-Instruct-2512-FP8",
- "unsloth/Ministral-3-8B-Instruct-2512-FP8",
- ),
- "16" : (
- "unsloth/Ministral-3-8B-Instruct-2512",
- "mistralai/Ministral-3-8B-Instruct-2512",
- "unsloth/Ministral-3-8B-Instruct-2512-bnb-4bit",
- ),
- },
- "unsloth/Ministral-3-8B-Base-2512-unsloth-bnb-4bit" : (
- "unsloth/Ministral-3-8B-Base-2512",
- "mistralai/Ministral-3-8B-Base-2512",
- "unsloth/Ministral-3-8B-Base-2512-bnb-4bit",
- ),
- "unsloth/Ministral-3-8B-Reasoning-2512-unsloth-bnb-4bit" : (
- "unsloth/Ministral-3-8B-Reasoning-2512",
- "mistralai/Ministral-3-8B-Reasoning-2512",
- "unsloth/Ministral-3-8B-Reasoning-2512-bnb-4bit",
- ),
- "unsloth/Ministral-3-14B-Instruct-2512-unsloth-bnb-4bit" : {
- "8" : (
- "mistralai/Ministral-3-14B-Instruct-2512",
- "unsloth/Ministral-3-14B-Instruct-2512-FP8",
- "unsloth/Ministral-3-14B-Instruct-2512-FP8",
- ),
- "16" : (
- "unsloth/Ministral-3-14B-Instruct-2512",
- "mistralai/Ministral-3-14B-Instruct-2512",
- "unsloth/Ministral-3-14B-Instruct-2512-bnb-4bit",
- ),
- },
- "unsloth/Ministral-3-14B-Base-2512-unsloth-bnb-4bit" : (
- "unsloth/Ministral-3-14B-Base-2512",
- "mistralai/Ministral-3-14B-Base-2512",
- "unsloth/Ministral-3-14B-Base-2512-bnb-4bit",
- ),
- "unsloth/Ministral-3-14B-Reasoning-2512-unsloth-bnb-4bit" : (
- "unsloth/Ministral-3-14B-Reasoning-2512",
- "mistralai/Ministral-3-14B-Reasoning-2512",
- "unsloth/Ministral-3-14B-Reasoning-2512-bnb-4bit",
- ),
- "unsloth/Kimi-K2-Instruct-BF16" : (
- "unsloth/Kimi-K2-Instruct",
- ),
}
INT_TO_FLOAT_MAPPER = {}
FLOAT_TO_INT_MAPPER = {}
MAP_TO_UNSLOTH_16bit = {}
-FLOAT_TO_FP8_BLOCK_MAPPER = {}
-FLOAT_TO_FP8_ROW_MAPPER = {}
-
-
-def _add_with_lower(mapper, key, value):
- if key is None:
- return
- mapper[key] = value
- mapper[key.lower()] = value
-
-
-def _add_lower_only(mapper, key, value):
- if key is None:
- return
- mapper[key.lower()] = value
for key, values in __INT_TO_FLOAT_MAPPER.items():
- block, row = None, None
- if type(values) is dict:
- assert "16" in values
- float16_values = values["16"]
- # Float8 and other quantized types
- if "8" in values:
- float8_values = values["8"]
- assert len(float8_values) == 3
- official, block, row = float8_values
- _add_lower_only(FLOAT_TO_FP8_BLOCK_MAPPER, key, block)
- _add_lower_only(FLOAT_TO_FP8_ROW_MAPPER, key, row)
- _add_lower_only(FLOAT_TO_FP8_BLOCK_MAPPER, official + "-dynamic", block)
- _add_lower_only(FLOAT_TO_FP8_ROW_MAPPER, official, row)
- _add_lower_only(FLOAT_TO_FP8_ROW_MAPPER, official + "-dynamic", row)
- for k in float8_values + float16_values:
- _add_lower_only(FLOAT_TO_FP8_BLOCK_MAPPER, k, block)
- _add_lower_only(FLOAT_TO_FP8_ROW_MAPPER, k, row)
-
- if float8_values[1] is not None and float8_values[1].startswith("unsloth"):
- for value in float8_values:
- if value is not None:
- _add_with_lower(MAP_TO_UNSLOTH_16bit, value, float8_values[1])
-
- for value in float8_values:
- if value is not None:
- FLOAT_TO_INT_MAPPER[value] = key
- FLOAT_TO_INT_MAPPER[value.lower()] = key.lower()
- values = float16_values
INT_TO_FLOAT_MAPPER[key] = values[0]
for value in values:
FLOAT_TO_INT_MAPPER[value] = key
+ pass
# Map to Unsloth version for 16bit versions
if len(values) == 2:
if values[0].startswith("unsloth"):
- _add_with_lower(MAP_TO_UNSLOTH_16bit, values[1], values[0])
- _add_with_lower(MAP_TO_UNSLOTH_16bit, block, values[0])
- _add_with_lower(MAP_TO_UNSLOTH_16bit, row, values[0])
+ MAP_TO_UNSLOTH_16bit[values[1]] = values[0]
+ MAP_TO_UNSLOTH_16bit[values[1].lower()] = values[0]
+ pass
elif len(values) == 3:
# Dynamic Unsloth quantization
if values[0].startswith("unsloth"):
- _add_with_lower(MAP_TO_UNSLOTH_16bit, values[1], values[0])
- _add_with_lower(MAP_TO_UNSLOTH_16bit, values[2], values[0])
- _add_with_lower(MAP_TO_UNSLOTH_16bit, block, values[0])
- _add_with_lower(MAP_TO_UNSLOTH_16bit, row, values[0])
+ MAP_TO_UNSLOTH_16bit[values[1]] = values[0]
+ MAP_TO_UNSLOTH_16bit[values[1].lower()] = values[0]
+ MAP_TO_UNSLOTH_16bit[values[2]] = values[0]
+ MAP_TO_UNSLOTH_16bit[values[2].lower()] = values[0]
pass
+ pass
# Get lowercased
lowered_key = key.lower()
@@ -1416,3 +753,5 @@ def _add_lower_only(mapper, key, value):
for value in values:
FLOAT_TO_INT_MAPPER[value.lower()] = lowered_key
+ pass
+pass
diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py
index 83e9ab9486..303c3d9589 100644
--- a/unsloth/models/mistral.py
+++ b/unsloth/models/mistral.py
@@ -15,19 +15,6 @@
from .llama import *
import os
from ._utils import __version__
-from unsloth_zoo.utils import _get_dtype
-from unsloth_zoo.hf_utils import dtype_from_config
-from ..utils.packing import (
- get_packed_info_from_kwargs,
- mask_packed_sequence_boundaries,
-)
-from ..utils.attention_dispatch import (
- AttentionConfig,
- AttentionContext,
- run_attention,
- SDPA,
- select_attention_backend,
-)
from .llama import (
LlamaRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
@@ -38,7 +25,6 @@
MistralModel,
MistralForCausalLM,
)
-
# For Pytorch 2.1.1
try:
from transformers.models.mistral.modeling_mistral import (
@@ -46,25 +32,26 @@
MistralFlashAttention2,
)
except:
- MistralSdpaAttention = MistralAttention
+ MistralSdpaAttention = MistralAttention
MistralFlashAttention2 = MistralAttention
+pass
from unsloth_zoo.utils import Version, _get_dtype
def MistralAttention_fast_forward(
self,
- hidden_states: torch.Tensor,
- causal_mask: Optional[BlockDiagonalCausalMask] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- padding_mask: Optional[torch.LongTensor] = None,
+ hidden_states: torch.Tensor,
+ causal_mask: Optional[BlockDiagonalCausalMask] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ padding_mask: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- *args,
- **kwargs,
+ *args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
# Clear inference
if hasattr(self, "paged_attention"):
del self.paged_attention_K
@@ -74,20 +61,20 @@ def MistralAttention_fast_forward(
del self.temp_KV
del self.RH_Q
del self.attention
+ pass
bsz, q_len, _ = hidden_states.size()
- n_heads = self.config.num_attention_heads
- n_groups = self.num_key_value_groups
+ n_heads = self.config.num_attention_heads
+ n_groups = self.num_key_value_groups
n_kv_heads = self.config.num_key_value_heads
- head_dim = self.head_dim
- assert n_kv_heads * n_groups == n_heads
+ head_dim = self.head_dim
+ assert(n_kv_heads * n_groups == n_heads)
Q, K, V = self.apply_qkv(self, hidden_states)
- Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
+ Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
- seq_info = get_packed_info_from_kwargs(kwargs, Q.device)
kv_seq_len = K.shape[-2]
if past_key_value is not None:
@@ -95,58 +82,92 @@ def MistralAttention_fast_forward(
# Extend RoPE dynamically to fit in VRAM
self.rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
- cos, sin = self.rotary_emb.get_cached(kv_seq_len, Q.device.index)
- rope_position_ids = (
- position_ids if position_ids is not None else kwargs.get("position_ids")
- )
- # Useful for LongRoPE
- Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids)
+ if position_ids is None:
+ cos = self.rotary_emb.cos_cached
+ sin = self.rotary_emb.sin_cached
+ Q, K = fast_rope_embedding(Q, K, cos, sin)
+ else:
+ cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
+ Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
+ pass
if past_key_value is not None:
K = torch.cat([past_key_value[0], K], dim = 2)
V = torch.cat([past_key_value[1], V], dim = 2)
+ pass
past_key_value = (K, V) if use_cache else None
# Attention module
- sw_cfg = getattr(self.config, "sliding_window", None)
- sw = kv_seq_len if (sw_cfg is None or sw_cfg == "null") else sw_cfg
- window_size = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
-
- use_varlen = (
- seq_info is not None and past_key_value is None and window_size == (-1, -1)
- )
- backend = (
- SDPA if attention_mask is not None else select_attention_backend(use_varlen)
- )
- attention_config = AttentionConfig(
- backend = backend,
- n_kv_heads = n_kv_heads,
- n_groups = n_groups,
- flash_dense_kwargs = {"causal": True, "window_size": window_size},
- flash_varlen_kwargs = {
- "dropout_p": 0.0,
- "causal": True,
- "softmax_scale": getattr(self, "softmax_scale", None),
- },
- )
- context = AttentionContext(
- bsz = bsz,
- q_len = q_len,
- kv_seq_len = kv_seq_len,
- n_heads = n_heads,
- head_dim = head_dim,
- requires_grad = hidden_states.requires_grad,
- seq_info = seq_info,
- attention_mask = attention_mask,
- causal_mask = causal_mask,
- )
+ if (not HAS_FLASH_ATTENTION and attention_mask is None):
+ # Xformers memory efficient attention
+ Q = Q.transpose(1, 2)
+ K = K.transpose(1, 2)
+ V = V.transpose(1, 2)
+ K_M = V_M = bsz * kv_seq_len
+ Q_M = bsz * q_len
+
+ has_swa = isinstance(causal_mask, xformers.attn_bias.BlockDiagonalCausalMask)
+
+ # Group query attention
+ K = K .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
+ V = V .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
+ K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
+ V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
+ if hidden_states.requires_grad:
+ K = K.reshape(bsz, kv_seq_len, n_heads, head_dim)
+ V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
+
+ if has_swa:
+ Q = Q.view(1, Q_M, n_heads, head_dim)
+ K = K.view(1, K_M, n_heads, head_dim)
+ V = V.view(1, V_M, n_heads, head_dim)
+ pass
+ else:
+ # Xformers does support the forward pass though
+ Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
+
+ if has_swa:
+ Q = Q.view(1, Q_M, n_kv_heads, n_groups, head_dim)
+ K = K.view(1, K_M, n_kv_heads, n_groups, head_dim)
+ V = V.view(1, V_M, n_kv_heads, n_groups, head_dim)
+ pass
+ pass
- A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V)
- attn_output = A.reshape(bsz, q_len, n_heads * head_dim)
+ A = xformers_attention(Q, K, V, attn_bias = causal_mask)
+ A = A.view(bsz, q_len, n_heads, head_dim)
+
+ elif HAS_FLASH_ATTENTION and attention_mask is None:
+ Q = Q.transpose(1, 2)
+ K = K.transpose(1, 2)
+ V = V.transpose(1, 2)
+ sw = getattr(self.config, "sliding_window", None)
+ sw = kv_seq_len if (sw is None or sw == "null") else sw
+ window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
+ A = flash_attn_func(Q, K, V, causal = True, window_size = window)
+ else:
+ # Grouped query attention
+ # if n_groups != 1:
+ K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
+ V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
+ K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
+ V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
+ # pass
+ # Must be contiguous or else results are False!
+ # https://github.com/pytorch/pytorch/issues/112577
+ Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
+ # Needs (batch_size, n_heads, seq_len, head_dim)
+ # is_casual and attention_mask must not be both set!
+ A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
+ # Go back to (batch_size, seq_len, n_heads, head_dim)
+ A = A.transpose(1, 2).contiguous()
+ pass
+
+ attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
attn_output = self.apply_o(self, attn_output)
attn_weights = None
return attn_output, attn_weights, past_key_value
+pass
def MistralForCausalLM_fast_forward(
@@ -164,91 +185,27 @@ def MistralForCausalLM_fast_forward(
return_dict: Optional[bool] = None,
num_logits_to_keep: Optional[int] = 0,
logits_to_keep: Optional[int] = 0,
- *args,
- **kwargs,
+ *args, **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
+
if causal_mask is None and past_key_values is None:
bsz, q_len = input_ids.shape
sliding_window = getattr(self.config, "sliding_window", None)
-
- if HAS_XFORMERS:
- # Always create causal mask for xformers
- if (
- sliding_window is None
- or sliding_window == "null"
- or sliding_window <= 0
- ):
- causal_mask = xformers.attn_bias.LowerTriangularMask()
- elif q_len <= sliding_window:
- causal_mask = xformers.attn_bias.LowerTriangularMask()
- else:
- causal_mask = xformers.attn_bias.BlockDiagonalCausalMask.from_seqlens(
- [q_len] * bsz
- ).make_local_attention(window_size = sliding_window)
-
- # If attention_mask exists, it will be handled in the attention forward
-
+ if sliding_window is None or sliding_window == "null" or sliding_window <= 0:
+ causal_mask = xformers.attn_bias.LowerTriangularMask()
+ elif q_len <= sliding_window:
+ causal_mask = xformers.attn_bias.LowerTriangularMask()
else:
- # Not using xformers - need to create attention masks
- if (
- sliding_window is None
- or sliding_window == "null"
- or sliding_window <= 0
- or q_len <= sliding_window
- ):
- # Fully causal mask
- causal_mask_values = torch.triu(
- torch.full((q_len, q_len), -torch.inf, device = input_ids.device),
- diagonal = 1,
- )
- else:
- # Sliding window attention
- q_indices = torch.arange(q_len, device = input_ids.device).view(-1, 1)
- k_indices = torch.arange(q_len, device = input_ids.device).view(1, -1)
-
- causal_bool_mask = k_indices <= q_indices
- window_bool_mask = (q_indices - k_indices) < sliding_window
-
- causal_mask_values = torch.where(
- causal_bool_mask & window_bool_mask, 0.0, -torch.inf
- )
-
- # Combine with existing attention_mask if present
- if attention_mask is None:
- attention_mask = causal_mask_values[None, None, :, :].expand(
- bsz, 1, q_len, q_len
- )
- else:
- if attention_mask.dim() == 2:
- # Convert 0/1 padding mask to additive format: 1->0 (keep), 0->-inf (mask)
- padding_mask = torch.where(
- attention_mask[:, None, None, :].bool(),
- 0.0,
- -torch.inf,
- )
- attention_mask = causal_mask_values[None, None, :, :] + padding_mask
- else:
- attention_mask = (
- attention_mask + causal_mask_values[None, None, :, :]
- )
-
- attention_mask = attention_mask.to(
- dtype = _get_dtype(dtype_from_config(self.config))
- )
+ causal_mask = xformers.attn_bias.BlockDiagonalCausalMask\
+ .from_seqlens([q_len]*bsz)\
+ .make_local_attention(window_size = sliding_window)
+ pass
- output_attentions = (
- output_attentions
- if output_attentions is not None
- else self.config.output_attentions
- )
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
- output_hidden_states
- if output_hidden_states is not None
- else self.config.output_hidden_states
- )
- return_dict = (
- return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
self.model._has_no_labels = labels is None
@@ -273,19 +230,18 @@ def MistralForCausalLM_fast_forward(
output_attentions = output_attentions,
output_hidden_states = output_hidden_states,
return_dict = return_dict,
- **kwargs,
)
+ pass
hidden_states = outputs[0]
bsz, q_len, hd = hidden_states.shape
lm_head = self.lm_head.weight
lm_head_device = lm_head.device
-
+
# Move items to same device as lm_head
hidden_states = hidden_states.to(lm_head_device)
- if labels is not None:
- labels = labels.to(lm_head_device)
+ if labels is not None: labels = labels.to(lm_head_device)
# If we are in GRPO mode, return raw hidden states
if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1":
@@ -299,51 +255,33 @@ def MistralForCausalLM_fast_forward(
hidden_states = outputs.hidden_states,
attentions = outputs.attentions,
)
+ pass
if bsz == 1 and q_len == 1:
logits = torch.mv(lm_head, hidden_states.ravel().to(lm_head.dtype))
logits = logits.unsqueeze(0).unsqueeze(0)
elif num_logits_to_keep != 0:
- logits = self.lm_head(
- hidden_states[:, -num_logits_to_keep:, :].to(lm_head.dtype)
- )
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :].to(lm_head.dtype))
else:
RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1"
# < 1024 Normal Unsloth uses less VRAM!
- if bsz * q_len <= 1024 and not RETURN_LOGITS:
- # Use unsloth_fused_ce_loss which actually calculates the best chunk size to reduce VRAM usage
- RETURN_LOGITS = False
-
- if not RETURN_LOGITS and labels is not None:
- n_items = kwargs.get("num_items_in_batch", None)
- if n_items is None:
- n_items = kwargs.get("n_items", None)
- logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
+ if bsz * q_len <= 1024: RETURN_LOGITS = True
- # loss = fused_linear_cross_entropy(
- # hidden_states = hidden_states,
- # lm_weight = lm_head,
- # labels = labels,
- # num_items_in_batch = n_items,
- # logit_softcapping = logit_softcapping,
- # )
- loss = unsloth_fused_ce_loss(
- trainer = None,
+ if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None:
+ n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None)
+ logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
+ loss = fused_linear_cross_entropy(
hidden_states = hidden_states,
- lm_head_weight = lm_head,
- lm_head_bias = None,
+ lm_weight = lm_head,
labels = labels,
- mask = None,
- n_items = n_items,
- scaling = getattr(self, "accelerator_scaler", None),
- target_gb = None,
- torch_compile = True,
+ num_items_in_batch = n_items,
logit_softcapping = logit_softcapping,
)
+
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
-
+
output = CausalLMOutputWithPast(
loss = loss,
logits = EMPTY_LOGITS,
@@ -354,7 +292,8 @@ def MistralForCausalLM_fast_forward(
return output
pass
logits = self.lm_head(hidden_states.to(lm_head.dtype))
- logits = logits.to(_get_dtype(dtype_from_config(self.config)))
+ pass
+ logits = logits.to(_get_dtype(self.config.torch_dtype))
loss = None
if labels is not None:
@@ -367,18 +306,12 @@ def MistralForCausalLM_fast_forward(
shift_labels = torch.empty_like(labels)
shift_labels[..., :-1] = labels[..., 1:]
shift_labels[..., -1] = -100
- mask_packed_sequence_boundaries(
- shift_labels,
- kwargs.get("packed_seq_lengths"),
- )
- n_items = kwargs.get("num_items_in_batch", None)
- if n_items is None:
- n_items = kwargs.get("n_items", None)
loss = fast_cross_entropy_loss(
- logits = shift_logits,
- labels = shift_labels,
- n_items = n_items,
+ logits = shift_logits,
+ labels = shift_labels,
+ n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None),
)
+ pass
if not return_dict:
output = (logits,) + outputs[1:]
@@ -391,6 +324,7 @@ def MistralForCausalLM_fast_forward(
hidden_states = outputs.hidden_states,
attentions = outputs.attentions,
)
+pass
# Transformers had to update for Mistral Nemo 12b since Attention is (5120, 4096) now.
@@ -408,70 +342,74 @@ def patch_mistral_nemo_attention(function):
"self.o_proj = nn.Linear(self.config.num_attention_heads * self.head_dim, self.config.hidden_size, bias=False)",
)
return function
+pass
class FastMistralModel(FastLlamaModel):
+
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
- model_name = "mistral",
- rope_module = LlamaRotaryEmbedding,
+ model_name = "mistral",
+ rope_module = LlamaRotaryEmbedding,
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
- attention_module = MistralAttention,
+ attention_module = MistralAttention,
)
# Just for Mistral Nemo models!
if function is not None and init_name is not None:
function = patch_mistral_nemo_attention(function)
# if True:#init_name is not None:
exec(function, globals())
- MistralAttention.__init__ = eval(init_name)
- MistralAttention.forward = MistralAttention_fast_forward
- MistralSdpaAttention.forward = MistralAttention_fast_forward
+ MistralAttention.__init__ = eval(init_name)
+ pass
+ MistralAttention .forward = MistralAttention_fast_forward
+ MistralSdpaAttention .forward = MistralAttention_fast_forward
MistralFlashAttention2.forward = MistralAttention_fast_forward
- MistralDecoderLayer.forward = LlamaDecoderLayer_fast_forward
- MistralModel.forward = LlamaModel_fast_forward
- MistralForCausalLM.forward = MistralForCausalLM_fast_forward
- PeftModelForCausalLM.forward = PeftModel_fast_forward
+ MistralDecoderLayer .forward = LlamaDecoderLayer_fast_forward
+ MistralModel .forward = LlamaModel_fast_forward
+ MistralForCausalLM .forward = MistralForCausalLM_fast_forward
+ PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(MistralForCausalLM)
-
+
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
- # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
+ # Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.mistral.modeling_mistral
-
- transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding = (
- LlamaRotaryEmbedding
- )
+ transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding = LlamaRotaryEmbedding
return
+ pass
+
@staticmethod
def from_pretrained(
- model_name = "unsloth/mistral-7b-bnb-4bit",
- max_seq_length = None,
- dtype = None,
- load_in_4bit = True,
- token = None,
- device_map = "sequential",
- rope_scaling = None, # Mistral does not support RoPE scaling
- fix_tokenizer = True,
- model_patcher = None,
- tokenizer_name = None,
+ model_name = "unsloth/mistral-7b-bnb-4bit",
+ max_seq_length = None,
+ dtype = None,
+ load_in_4bit = True,
+ token = None,
+ device_map = "sequential",
+ rope_scaling = None, # Mistral does not support RoPE scaling
+ fix_tokenizer = True,
+ model_patcher = None,
+ tokenizer_name = None,
trust_remote_code = False,
**kwargs,
):
return FastLlamaModel.from_pretrained(
- model_name = model_name,
- max_seq_length = max_seq_length,
- dtype = dtype,
- load_in_4bit = load_in_4bit,
- token = token,
- device_map = device_map,
- rope_scaling = rope_scaling,
- fix_tokenizer = fix_tokenizer,
- model_patcher = FastMistralModel,
- tokenizer_name = tokenizer_name,
+ model_name = model_name,
+ max_seq_length = max_seq_length,
+ dtype = dtype,
+ load_in_4bit = load_in_4bit,
+ token = token,
+ device_map = device_map,
+ rope_scaling = rope_scaling,
+ fix_tokenizer = fix_tokenizer,
+ model_patcher = FastMistralModel,
+ tokenizer_name = tokenizer_name,
trust_remote_code = trust_remote_code,
**kwargs,
)
+ pass
+pass
diff --git a/unsloth/models/qwen2.py b/unsloth/models/qwen2.py
index 3f819d6dc1..82de1951b8 100644
--- a/unsloth/models/qwen2.py
+++ b/unsloth/models/qwen2.py
@@ -23,7 +23,6 @@
Qwen2Model,
Qwen2ForCausalLM,
)
-
# For Pytorch 2.1.1
try:
from transformers.models.qwen2.modeling_qwen2 import (
@@ -31,71 +30,73 @@
Qwen2FlashAttention2,
)
except:
- Qwen2SdpaAttention = Qwen2Attention
+ Qwen2SdpaAttention = Qwen2Attention
Qwen2FlashAttention2 = Qwen2Attention
+pass
class FastQwen2Model(FastLlamaModel):
+
@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
- model_name = "qwen2",
- rope_module = LlamaRotaryEmbedding,
+ model_name = "qwen2",
+ rope_module = LlamaRotaryEmbedding,
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
- attention_module = Qwen2Attention,
+ attention_module = Qwen2Attention,
)
if init_name is not None:
exec(function, globals())
- Qwen2Attention.__init__ = eval(init_name)
- Qwen2Attention.forward = LlamaAttention_fast_forward
- Qwen2SdpaAttention.forward = LlamaAttention_fast_forward
+ Qwen2Attention.__init__ = eval(init_name)
+ pass
+ Qwen2Attention .forward = LlamaAttention_fast_forward
+ Qwen2SdpaAttention .forward = LlamaAttention_fast_forward
Qwen2FlashAttention2.forward = LlamaAttention_fast_forward
- Qwen2DecoderLayer.forward = LlamaDecoderLayer_fast_forward
- Qwen2Model.forward = LlamaModel_fast_forward
- Qwen2ForCausalLM.forward = CausalLM_fast_forward(
- LlamaModel_fast_forward_inference
- )
- PeftModelForCausalLM.forward = PeftModel_fast_forward
+ Qwen2DecoderLayer .forward = LlamaDecoderLayer_fast_forward
+ Qwen2Model .forward = LlamaModel_fast_forward
+ Qwen2ForCausalLM .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference)
+ PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
fix_prepare_inputs_for_generation(Qwen2ForCausalLM)
# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
- # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
+ # Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.qwen2.modeling_qwen2
-
- transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding = (
- LlamaRotaryEmbedding
- )
+ transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding = LlamaRotaryEmbedding
return
+ pass
+
@staticmethod
def from_pretrained(
- model_name = "Qwen/Qwen2-7B",
- max_seq_length = 4096,
- dtype = None,
- load_in_4bit = True,
- token = None,
- device_map = "sequential",
- rope_scaling = None, # Qwen2 does not support RoPE scaling
- fix_tokenizer = True,
- model_patcher = None,
- tokenizer_name = None,
+ model_name = "Qwen/Qwen2-7B",
+ max_seq_length = 4096,
+ dtype = None,
+ load_in_4bit = True,
+ token = None,
+ device_map = "sequential",
+ rope_scaling = None, # Qwen2 does not support RoPE scaling
+ fix_tokenizer = True,
+ model_patcher = None,
+ tokenizer_name = None,
trust_remote_code = False,
**kwargs,
):
return FastLlamaModel.from_pretrained(
- model_name = model_name,
- max_seq_length = max_seq_length,
- dtype = dtype,
- load_in_4bit = load_in_4bit,
- token = token,
- device_map = device_map,
- rope_scaling = rope_scaling,
- fix_tokenizer = fix_tokenizer,
- model_patcher = FastQwen2Model,
- tokenizer_name = tokenizer_name,
+ model_name = model_name,
+ max_seq_length = max_seq_length,
+ dtype = dtype,
+ load_in_4bit = load_in_4bit,
+ token = token,
+ device_map = device_map,
+ rope_scaling = rope_scaling,
+ fix_tokenizer = fix_tokenizer,
+ model_patcher = FastQwen2Model,
+ tokenizer_name = tokenizer_name,
trust_remote_code = trust_remote_code,
**kwargs,
)
+ pass
+pass
diff --git a/unsloth/models/qwen3.py b/unsloth/models/qwen3.py
deleted file mode 100644
index b93dddb186..0000000000
--- a/unsloth/models/qwen3.py
+++ /dev/null
@@ -1,475 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .llama import *
-import os
-from ._utils import __version__
-from unsloth_zoo.utils import Version, _get_dtype
-from ..utils.packing import get_packed_info_from_kwargs
-from ..utils.attention_dispatch import (
- AttentionConfig,
- AttentionContext,
- run_attention,
- SDPA,
- select_attention_backend,
-)
-from .llama import (
- LlamaRotaryEmbedding,
- LlamaLinearScalingRotaryEmbedding,
- _LlamaModel_fast_forward_inference,
-)
-
-try:
- from transformers.models.qwen3.modeling_qwen3 import (
- Qwen3Attention,
- Qwen3DecoderLayer,
- Qwen3Model,
- Qwen3ForCausalLM,
- )
-except:
- transformers_version = Version(transformers_version)
- if not transformers_version >= Version(
- "4.50.3"
- ): # TODO: Update when transformers is updated
- raise ImportError(
- f"Unsloth: Your transformers version of {transformers_version} does not support Qwen3 and Qwen3Moe.\n"
- f"The minimum required version is 4.50.3.\n"
- f'Try `pip install --upgrade "transformers>=4.50.3"`\n'
- f"to obtain the latest transformers build, then restart this session."
- )
-from transformers.modeling_attn_mask_utils import (
- _prepare_4d_causal_attention_mask_for_sdpa,
-)
-
-# For Pytorch 2.1.1
-try:
- from transformers.models.qwen3.modeling_qwen3 import (
- Qwen3SdpaAttention,
- Qwen3FlashAttention2,
- )
-except:
- Qwen3SdpaAttention = Qwen3Attention
- Qwen3FlashAttention2 = Qwen3Attention
-
-
-def Qwen3Attention_fast_forward(
- self,
- hidden_states: torch.Tensor,
- causal_mask: Optional[BlockDiagonalCausalMask] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- padding_mask: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- *args,
- **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- # Clear inference
- if hasattr(self, "paged_attention"):
- del self.paged_attention_K
- del self.paged_attention_V
- del self.paged_attention
- del self.temp_QA
- del self.temp_KV
- del self.RH_Q
- del self.attention
-
- bsz, q_len, _ = hidden_states.size()
-
- n_heads = self.config.num_attention_heads
- n_groups = self.num_key_value_groups
- n_kv_heads = self.config.num_key_value_heads
- head_dim = self.head_dim
- assert n_kv_heads * n_groups == n_heads
-
- Q, K, V = self.apply_qkv(self, hidden_states)
- Q = Q.view(
- bsz, q_len, n_heads, head_dim
- ) # .transpose(1, 2) # we will transpose after normalisation
- K = K.view(
- bsz, q_len, n_kv_heads, head_dim
- ) # .transpose(1, 2) # we will transpose after normalisation
- V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
- seq_info = get_packed_info_from_kwargs(kwargs, hidden_states.device)
-
- # Qwen3 has QKNorm. This seems to be the only difference from Qwen2.
- # Note that using fast_layernorm_compiled causes issues as the dimensions don't match up.
- # I tried to add a compiled version of the new norm but the numbers don't match up with Transformers
- # TODO: Check on the differences here.
- Q = fast_rms_layernorm(self.q_norm, Q)
- K = fast_rms_layernorm(self.k_norm, K)
-
- Q = Q.transpose(1, 2)
- K = K.transpose(1, 2)
-
- kv_seq_len = K.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
-
- # Extend RoPE dynamically to fit in VRAM
- if position_embeddings and kv_seq_len <= position_embeddings[0].shape[0]:
- cos, sin = position_embeddings
- else:
- rotary_emb = self.rotary_emb
- rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
- cos, sin = rotary_emb.get_cached(kv_seq_len, Q.device.index)
-
- rope_position_ids = (
- position_ids if position_ids is not None else kwargs.get("position_ids")
- )
- # Useful for LongRoPE
- Q, K = fast_rope_embedding(Q, K, cos, sin, rope_position_ids)
-
- if past_key_value is not None:
- K = torch.cat([past_key_value[0], K], dim = 2)
- V = torch.cat([past_key_value[1], V], dim = 2)
- past_key_value = (K, V) if use_cache else None
-
- # Attention module
- use_varlen = seq_info is not None and past_key_value is None
- backend = (
- SDPA if attention_mask is not None else select_attention_backend(use_varlen)
- )
- attention_config = AttentionConfig(
- backend = backend,
- n_kv_heads = n_kv_heads,
- n_groups = n_groups,
- flash_dense_kwargs = {"causal": True},
- flash_varlen_kwargs = {
- "dropout_p": 0.0,
- "causal": True,
- "softmax_scale": getattr(self, "softmax_scale", None),
- },
- )
- context = AttentionContext(
- bsz = bsz,
- q_len = q_len,
- kv_seq_len = kv_seq_len,
- n_heads = n_heads,
- head_dim = head_dim,
- requires_grad = hidden_states.requires_grad,
- seq_info = seq_info,
- attention_mask = attention_mask,
- causal_mask = causal_mask,
- )
-
- A = run_attention(config = attention_config, context = context, Q = Q, K = K, V = V)
-
- attn_output = A.reshape(bsz, q_len, n_heads * head_dim)
- attn_output = self.apply_o(self, attn_output)
- attn_weights = None
- return attn_output, attn_weights, past_key_value
-
-
-torch_matmul = torch.matmul
-
-
-def Qwen3Attention_fast_forward_inference(
- self,
- hidden_states: torch.Tensor,
- past_key_value: Optional[Tuple[torch.Tensor]],
- position_ids,
- do_prefill = False,
- attention_mask = None,
- **kwargs,
-):
- """
- https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406
- Fast inference using KV cache.
- QK^T can be computed in 4 chunks
-
- [Q, q] @ [K, k].T where q, k are the new tokens.
- [QK^T, Qk^T]
- [qK^T, qk^T]
-
- Since the attention mask wipes Qk^T, we just get
- [QK^T, 0]
- [qK^T, qk^T]
-
- Since softmax is row-wise, we get
- softmax([QK^T, 0])
- softmax([qK^T, qk^T])
-
- We then multiply by [V]
- [v]
- softmax([QK^T, 0]) [softmax(QK^T)V] *
- softmax([qK^T, qk^T]) [softmax([qK^T, qk^T]) @ [V, v]]
-
- But notice * [softmax(QK^T)V] is just the last attention.
- We just need to compute the last final row.
-
- This means we can pass in a row of Q, but we need to
- remember K and V, which are called the KV cache.
- """
- Xn = hidden_states
- bsz, _, hd = hidden_states.size()
- K1, V1 = past_key_value
- dtype = Xn.dtype
-
- n_heads = self.config.num_attention_heads
- n_groups = self.num_key_value_groups
- n_kv_heads = self.config.num_key_value_heads
- head_dim = self.head_dim
- # assert(n_kv_heads * n_groups == n_heads)
-
- hidden_size = self.config.hidden_size
- attention_size = n_heads * head_dim
- seq_len = K1.shape[-2]
- kv_seq_len = seq_len + 1
-
- # Prefill phase
- # if not hasattr(self, "paged_attention"):
- device = hidden_states.device
- if do_prefill:
- self.paged_attention = torch.empty(
- (KV_CACHE_INCREMENT + seq_len + 1, 2, bsz, n_kv_heads, head_dim),
- dtype = dtype,
- device = device,
- )
- self.paged_attention_K = self.paged_attention[:, 0]
- self.paged_attention_V = self.paged_attention[:, 1]
- self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
- self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
- self.temp_QA = torch.empty(
- (2, bsz, 1, attention_size), dtype = dtype, device = device
- )
- self.temp_KV = torch.empty(
- (2, bsz, 1, n_kv_heads * head_dim), dtype = dtype, device = device
- )
- self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
-
- # Mistral Nemo 12b has weird dimensions
- if attention_size != hidden_size:
- self.temp_O = torch.empty((bsz, 1, hidden_size), dtype = dtype, device = device)
- else:
- self.temp_O = self.temp_QA[1][:, :, :hidden_size]
-
- self.attention = torch.empty(
- (bsz, n_heads, 1, KV_CACHE_INCREMENT + seq_len), dtype = dtype, device = device
- )
- self.scalar = 1.0 / math_sqrt(self.head_dim)
- self.half_head_dim = head_dim // 2
- elif kv_seq_len >= self.paged_attention.shape[0]:
- self.paged_attention.resize_(
- (
- self.paged_attention.shape[0] + KV_CACHE_INCREMENT,
- 2,
- bsz,
- n_kv_heads,
- head_dim,
- )
- )
- self.paged_attention_K = self.paged_attention[:, 0]
- self.paged_attention_V = self.paged_attention[:, 1]
- self.attention.resize_(
- (bsz, n_heads, 1, self.attention.shape[-1] + KV_CACHE_INCREMENT)
- )
-
- Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
- Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
- Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
- Qn = Qn.view(
- bsz, 1, n_heads, head_dim
- ) # .transpose(1, 2) # we will transpose after normalisation
- Kn = Kn.view(
- bsz, 1, n_kv_heads, head_dim
- ) # .transpose(1, 2) # we will transpose after normalisation
- Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
-
- Qn = fast_rms_layernorm_inference(self.q_norm, Qn)
- Kn = fast_rms_layernorm_inference(self.k_norm, Kn)
-
- Qn = Qn.transpose(1, 2)
- Kn = Kn.transpose(1, 2)
-
- # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
- # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
-
- # Need to do it prior 2 steps before hitting full on short KV cache
- # or else error
- self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2)
- cos, sin = self.rotary_emb.get_cached(kv_seq_len, Qn.device.index)
- cos = cos[position_ids].unsqueeze(1)
- sin = sin[position_ids].unsqueeze(1)
- h = self.half_head_dim
-
- RH_Q = self.RH_Q
- RH_Q[:, :, :, :h] = Qn[:, :, :, h:]
- RH_Q[:, :, :, h:] = Qn[:, :, :, :h]
- RH_Q[:, :, :, :h].neg_() # torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
- Qn *= cos
- Qn.addcmul_(RH_Q, sin)
-
- RH_K = RH_Q[
- :, :n_kv_heads, :, :
- ] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
- RH_K[:, :, :, :h] = Kn[:, :, :, h:]
- RH_K[:, :, :, h:] = Kn[:, :, :, :h]
- RH_K[:, :, :, :h].neg_() # torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
- Kn *= cos
- Kn.addcmul_(RH_K, sin)
-
- # New KV cache
- # Kn = torch.cat([K1, Kn], dim = 2)
- # Vn = torch.cat([V1, Vn], dim = 2)
- self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)
- self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)
- Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)
- Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)
-
- # Handle sliding windows
- sliding_window = getattr(self.config, "sliding_window", None)
- if sliding_window is not None and kv_seq_len > sliding_window:
- start = kv_seq_len - sliding_window
- Knn = Kn[:, :, start:, :] # .contiguous()
- Vnn = Vn[:, :, start:, :] # .contiguous()
- if attention_mask is not None:
- attention_mask = attention_mask[..., start:]
- else:
- Knn, Vnn = Kn, Vn
-
- # when qlen==vlen and attn_mask is None, we should use causal attention
- Q_len = Qn.shape[-2]
- K_len = Knn.shape[-2]
- if attention_mask is not None and attention_mask.dim() == 2:
- attention_mask = attention_mask[:, None, None, :].to(torch.bool)
- elif (
- attention_mask is not None
- and attention_mask.dim() == 4
- and attention_mask.dtype != torch.bool
- ):
- attention_mask = attention_mask.eq(0)
- if attention_mask is None and Q_len == K_len:
- is_causal = True
- else:
- is_causal = False
- use_sdpa_gqa = SDPA_HAS_GQA
- if (
- use_sdpa_gqa
- and isinstance(attention_mask, torch.Tensor)
- and attention_mask.dim() >= 3
- and attention_mask.shape[0] > 1
- ):
- # Avoid SDPA GQA drift for batched masked decode.
- use_sdpa_gqa = False
-
- # Grouped query attention
- _, _, cached_len, _ = Knn.shape
- if bsz == 1 or ((not use_sdpa_gqa) and n_groups != 1):
- Knn = Knn[:, :, None, :, :].expand(
- bsz, n_kv_heads, n_groups, cached_len, head_dim
- )
- Vnn = Vnn[:, :, None, :, :].expand(
- bsz, n_kv_heads, n_groups, cached_len, head_dim
- )
- Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
- Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
-
- # Attention
- if bsz == 1:
- Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
- # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
- A = torch_matmul(
- Qn, Knn.transpose(2, 3), out = self.attention[:, :, :, :cached_len]
- )
- A[:] = torch_nn_functional_softmax(
- A, dim = -1, dtype = torch.float32
- ) # .to(A.dtype)
- A = torch_matmul(A, Vnn, out = Qn)
- else:
- if use_sdpa_gqa:
- A = scaled_dot_product_attention(
- Qn,
- Knn,
- Vnn,
- attn_mask = attention_mask,
- is_causal = is_causal,
- enable_gqa = True,
- )
- else:
- A = scaled_dot_product_attention(
- Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = is_causal
- )
- A = A.transpose(1, 2)
- A = A.reshape(bsz, 1, attention_size)
- A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
- return A, (Kn, Vn)
-
-
-class FastQwen3Model(FastLlamaModel):
- @staticmethod
- def pre_patch():
- init_name, function = patch_linear_scaling(
- model_name = "Qwen3",
- rope_module = LlamaRotaryEmbedding,
- scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
- attention_module = Qwen3Attention,
- )
- if init_name is not None:
- exec(function, globals())
- Qwen3Attention.__init__ = eval(init_name)
- Qwen3Attention.forward = Qwen3Attention_fast_forward
- Qwen3SdpaAttention.forward = Qwen3Attention_fast_forward
- Qwen3FlashAttention2.forward = Qwen3Attention_fast_forward
- Qwen3DecoderLayer.forward = LlamaDecoderLayer_fast_forward
- Qwen3Model.forward = LlamaModel_fast_forward
- Qwen3ForCausalLM.forward = CausalLM_fast_forward(
- _LlamaModel_fast_forward_inference(Qwen3Attention_fast_forward_inference)
- )
- PeftModelForCausalLM.forward = PeftModel_fast_forward
- fix_prepare_inputs_for_generation(Qwen3ForCausalLM)
-
- # Solves https://github.com/unslothai/unsloth/issues/168
- # Static KV Cache was introduced in 4.38.0, causing training to be much slower.
- # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
- # https://github.com/huggingface/transformers/pull/27931
- # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
- import transformers.models.qwen3.modeling_qwen3
-
- transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding = (
- LlamaRotaryEmbedding
- )
- return
-
- @staticmethod
- def from_pretrained( # TODO: Change after release
- model_name = "Qwen/Qwen3-7B",
- max_seq_length = 4096,
- dtype = None,
- load_in_4bit = True,
- token = None,
- device_map = "sequential",
- rope_scaling = None,
- fix_tokenizer = True,
- model_patcher = None,
- tokenizer_name = None,
- trust_remote_code = False,
- **kwargs,
- ):
- return FastLlamaModel.from_pretrained(
- model_name = model_name,
- max_seq_length = max_seq_length,
- dtype = dtype,
- load_in_4bit = load_in_4bit,
- token = token,
- device_map = device_map,
- rope_scaling = rope_scaling,
- fix_tokenizer = fix_tokenizer,
- model_patcher = FastQwen3Model,
- tokenizer_name = tokenizer_name,
- trust_remote_code = trust_remote_code,
- **kwargs,
- )
diff --git a/unsloth/models/qwen3_moe.py b/unsloth/models/qwen3_moe.py
deleted file mode 100644
index e1f8c71b6b..0000000000
--- a/unsloth/models/qwen3_moe.py
+++ /dev/null
@@ -1,243 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .llama import *
-import os
-from ._utils import __version__
-from .llama import (
- LlamaRotaryEmbedding,
- LlamaLinearScalingRotaryEmbedding,
-)
-from .qwen3 import (
- Qwen3Attention_fast_forward,
- FastQwen3Model,
-)
-from transformers.models.qwen3_moe.modeling_qwen3_moe import (
- Qwen3MoeAttention,
- Qwen3MoeSparseMoeBlock,
- Qwen3MoeMLP,
- Qwen3MoeDecoderLayer,
- Qwen3MoeModel,
- Qwen3MoeForCausalLM,
-)
-
-# For Pytorch 2.1.1
-# TODO: Transformers moved to `attention_interface`. So we might not need these anymore
-# try:
-# from transformers.models.qwen3_moe.modeling_qwen3_moe import (
-# Qwen3SdpaAttention,
-# Qwen3FlashAttention2,
-# )
-# except:
-# Qwen3SdpaAttention = Qwen3Attention
-# Qwen3FlashAttention2 = Qwen3Attention
-# pass
-from unsloth_zoo.utils import Version, _get_dtype
-
-
-torch_nn_functional_softmax = torch.nn.functional.softmax
-
-
-def Qwen3MoeSparseMoeBlock_fast_forward(self, X, temp_gate = None, temp_up = None):
- # adapted from https://github.com/huggingface/transformers/pull/36878/files#diff-0855b77fc27ad9449158a1c74953f909b011c00de7125f7c8e68d0ff209c092aR356-R370
-
- bsz, seq_len, hd = X.shape
- X = X.view(-1, hd)
-
- router_logits = fast_linear_forward(
- self.gate_proj, X, out = temp_gate
- ) # pretty much the only change from transformers implementation.
-
- routing_weights = torch_nn_functional_softmax(
- router_logits, dim = -1, dtype = torch.float32
- )
- routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim = -1)
- routing_weights /= routing_weights.sum(dim = -1, keepdim = True)
- # we cast back to the input dtype
- routing_weights = routing_weights.to(X.dtype)
- final_X = torch.zeros((bsz * seq_len, hd), dtype = torch.float32, device = X.device)
-
- # One hot encode the selected experts to create an expert mask
- # this will be used to easily index which expert is going to be sollicitated
- expert_mask = torch.nn.functional.one_hot(
- selected_experts, num_classes = self.num_experts
- ).permute(2, 1, 0)
-
- # Loop over all available experts in the model and perform the computation on each expert
- for expert_idx in range(self.num_experts):
- expert_layer = self.experts[expert_idx]
- idx, top_x = torch.where(expert_mask[expert_idx])
-
- # Index the correct hidden states and compute the expert hidden state for
- # the current expert. We need to make sure to multiply the output hidden
- # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
- current_state = X[None, top_x].reshape(-1, hd)
- current_X = (
- expert_layer(current_state) * routing_weights[top_x, idx, None]
- ) # Qwen3MoeMLP.forward = fast_swiglu_inference takes care of making this faster. Analogous to Dense models' MLP
-
- # However `index_add_` only support torch tensors for indexing so we'll use
- # the `top_x` tensor here.
- final_X.index_add_(0, top_x, current_X.to(X.dtype))
- final_X = final_X.reshape(bsz, seq_len, hd)
- return final_X, router_logits
-
-
-def Qwen3MoeDecoderLayer_fast_forward(
- self,
- hidden_states: torch.Tensor,
- causal_mask: Optional[BlockDiagonalCausalMask] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: Optional[bool] = False,
- output_router_logits: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- padding_mask: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- *args,
- **kwargs,
-):
- residual = hidden_states
-
- if use_cache and hasattr(
- self, "_flag_for_generation"
- ): # past_key_value is not None:
- residual = hidden_states
- hidden_states = fast_rms_layernorm_inference(
- self.input_layernorm, hidden_states
- )
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states = hidden_states,
- causal_mask = causal_mask,
- attention_mask = attention_mask,
- position_ids = position_ids,
- past_key_value = past_key_value,
- output_attentions = output_attentions,
- use_cache = use_cache,
- padding_mask = padding_mask,
- position_embeddings = position_embeddings,
- _flag_for_generation = self._flag_for_generation,
- )
- hidden_states += residual
-
- # MoE Router MLP
- residual = hidden_states
- hidden_states = fast_rms_layernorm_inference(
- self.post_attention_layernorm, hidden_states
- )
- hidden_states, router_logits = Qwen3MoeSparseMoeBlock_fast_forward(
- self.mlp, hidden_states
- )
- hidden_states += residual
- else:
- residual = hidden_states
- hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states = hidden_states,
- causal_mask = causal_mask,
- attention_mask = attention_mask,
- position_ids = position_ids,
- past_key_value = past_key_value,
- output_attentions = output_attentions,
- use_cache = use_cache,
- padding_mask = padding_mask,
- position_embeddings = position_embeddings,
- )
- hidden_states = residual + hidden_states
-
- # MoE Router MLP
- residual = hidden_states
- hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
- hidden_states, router_logits = self.mlp(hidden_states)
- hidden_states = residual + hidden_states
-
- outputs = (hidden_states,)
- if output_attentions:
- outputs += (self_attn_weights,)
- if output_router_logits:
- outputs += (router_logits,)
- if use_cache:
- outputs += (present_key_value,)
- return outputs
-
-
-class FastQwen3MoeModel(FastQwen3Model):
- @staticmethod
- def pre_patch():
- init_name, function = patch_linear_scaling(
- model_name = "Qwen3Moe",
- rope_module = LlamaRotaryEmbedding,
- scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
- attention_module = Qwen3MoeAttention,
- )
- if init_name is not None:
- exec(function, globals())
- Qwen3MoeAttention.__init__ = eval(init_name)
- Qwen3MoeAttention.forward = Qwen3Attention_fast_forward
- # Qwen3SdpaAttention .forward = Qwen3Attention_fast_forward
- # Qwen3FlashAttention2 .forward = Qwen3Attention_fast_forward
- Qwen3MoeSparseMoeBlock.forward = Qwen3MoeSparseMoeBlock_fast_forward
- Qwen3MoeMLP.forward = (
- fast_swiglu_inference # This is analogous to Dense models' MLP
- )
- Qwen3MoeDecoderLayer.forward = Qwen3MoeDecoderLayer_fast_forward
- Qwen3MoeModel.forward = LlamaModel_fast_forward
- Qwen3MoeForCausalLM.forward = CausalLM_fast_forward(
- LlamaModel_fast_forward_inference
- )
- PeftModelForCausalLM.forward = PeftModel_fast_forward
- fix_prepare_inputs_for_generation(Qwen3MoeForCausalLM)
-
- # Solves https://github.com/unslothai/unsloth/issues/168
- # Static KV Cache was introduced in 4.38.0, causing training to be much slower.
- # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
- # https://github.com/huggingface/transformers/pull/27931
- # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py\
- import transformers.models.qwen3_moe.modeling_qwen3_moe
-
- transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeRotaryEmbedding = (
- LlamaRotaryEmbedding
- )
- return
-
- @staticmethod
- def from_pretrained( # TODO: Change after release
- model_name = "Qwen/Qwen3-7B",
- max_seq_length = 4096,
- dtype = None,
- load_in_4bit = True,
- token = None,
- device_map = "sequential",
- rope_scaling = None,
- fix_tokenizer = True,
- model_patcher = None,
- tokenizer_name = None,
- trust_remote_code = False,
- **kwargs,
- ):
- return FastLlamaModel.from_pretrained(
- model_name = model_name,
- max_seq_length = max_seq_length,
- dtype = dtype,
- load_in_4bit = load_in_4bit,
- token = token,
- device_map = device_map,
- rope_scaling = rope_scaling,
- fix_tokenizer = fix_tokenizer,
- model_patcher = FastQwen3MoeModel,
- tokenizer_name = tokenizer_name,
- trust_remote_code = trust_remote_code,
- **kwargs,
- )
diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py
old mode 100755
new mode 100644
index 30546a048d..c450ef6df5
--- a/unsloth/models/rl.py
+++ b/unsloth/models/rl.py
@@ -22,112 +22,38 @@
import inspect
import os
import re
+import torch
from unsloth_zoo.compiler import create_new_function
-from unsloth_zoo.log import logger
from unsloth_zoo.logging_utils import PatchRLStatistics
from unsloth_zoo.rl_replacements import RL_REPLACEMENTS
-from ..device_type import DEVICE_TYPE
from .rl_replacements import (
RL_EXTRA_ARGS,
RL_FUNCTIONS,
RL_PRE_ITEMS,
RL_CONFIG_CHANGES,
RL_METRICS_CHANGES,
- RL_ADDITIONAL_FUNCTIONS,
)
+selective_log_softmax = RL_REPLACEMENTS["selective_log_softmax"]
torch_compile_options = {
- "epilogue_fusion": True,
- "max_autotune": False, # Disable Triton mm kernels
- "shape_padding": True,
- "trace.enabled": False,
- "triton.cudagraphs": False,
+ "epilogue_fusion" : True,
+ "max_autotune" : False, # Disable Triton mm kernels
+ "shape_padding" : True,
+ "trace.enabled" : False,
+ "triton.cudagraphs" : False,
}
-# vLLM compatibility shim (TRL expects GuidedDecodingParams even if vLLM doesn't provide it)
-try:
- import vllm.sampling_params as _unsloth_vllm_sp
-
- if not hasattr(_unsloth_vllm_sp, "GuidedDecodingParams"):
-
- class GuidedDecodingParams:
- def __init__(self, **kwargs):
- self.kwargs = kwargs
-
- _unsloth_vllm_sp.GuidedDecodingParams = GuidedDecodingParams
-except Exception:
- pass
-
-from trl import __version__ as trl_version_raw
-from importlib.metadata import version as importlib_version
-from unsloth_zoo.utils import Version
-
-try:
- trl_version = Version(trl_version_raw)
-except Exception:
- try:
- trl_version = Version(importlib_version("trl"))
- except Exception:
- trl_version = Version("0.0.0")
-
-# Get PyTorch version for feature detection
-try:
- torch_version = Version(torch.__version__.split("+")[0].split("a")[0].split("b")[0])
-except Exception:
- torch_version = Version("0.0.0")
-
-# Get transformers version for feature detection
-try:
- from transformers import __version__ as _transformers_version_raw
-
- transformers_version = Version(_transformers_version_raw)
-except Exception:
- transformers_version = Version("0.0.0")
-
def vLLMSamplingParams(**kwargs):
from vllm import SamplingParams
-
sampling_params = SamplingParams(**kwargs)
sampling_params._set_kwargs = kwargs
return sampling_params
-
+pass
def PatchRL(FastLanguageModel):
- try:
- from trl.models.utils import unwrap_model_for_generation
- except ImportError:
- try:
- from trl.models import unwrap_model_for_generation
- except ImportError:
- # Local fallback -- TRL removed or moved this symbol
- from contextlib import contextmanager as _cm
-
- @_cm
- def unwrap_model_for_generation(
- model, accelerator, gather_deepspeed3_params = True
- ):
- unwrapped_model = accelerator.unwrap_model(model)
- is_gc = getattr(unwrapped_model, "is_gradient_checkpointing", False)
- if is_gc:
- unwrapped_model.gradient_checkpointing_disable()
- if (
- getattr(accelerator, "state", None) is not None
- and getattr(accelerator.state, "deepspeed_plugin", None) is not None
- and accelerator.state.deepspeed_plugin.zero_stage == 3
- ):
- if not gather_deepspeed3_params:
- yield accelerator.unwrap_model(model)
- else:
- import deepspeed
-
- with deepspeed.zero.GatheredParameters(model.parameters()):
- yield accelerator.unwrap_model(model)
- else:
- yield unwrapped_model
- if is_gc:
- unwrapped_model.gradient_checkpointing_enable()
+ from trl.models.utils import unwrap_model_for_generation
from contextlib import contextmanager
@contextmanager
@@ -139,13 +65,12 @@ def unsloth_unwrap_model_for_generation(model, *args, **kwargs):
# We must use .clone for Unsloth since we force inference_mode
# Rather we should have used no_grad
original_generate = unwrapped_model.generate
-
def generate_with_clone(*args, **kwargs):
out = original_generate(*args, **kwargs)
if isinstance(out, torch.Tensor):
return out.clone()
return out
-
+ pass
unwrapped_model.generate = generate_with_clone
try:
@@ -154,144 +79,26 @@ def generate_with_clone(*args, **kwargs):
# Restore generate and return
unwrapped_model.generate = original_generate
FastLanguageModel.for_training(model)
-
- from transformers import Trainer
- from transformers.trainer_pt_utils import nested_detach
-
- @torch.no_grad()
- def unsloth_prediction_step(
- self,
- model,
- inputs,
- prediction_loss_only,
- ignore_keys,
- ):
- """
- Perform an evaluation step on `model` using `inputs`.
- Subclass and override to inject custom behavior.
- Args:
- model (`nn.Module`):
- The model to evaluate.
- inputs (`Dict[str, Union[torch.Tensor, Any]]`):
- The inputs and targets of the model.
- The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
- argument `labels`. Check your model's documentation for all accepted arguments.
- prediction_loss_only (`bool`):
- Whether or not to return the loss only.
- ignore_keys (`List[str]`, *optional*):
- A list of keys in the output of your model (if it is a dictionary) that should be ignored when
- gathering predictions.
- Return:
- Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
- logits and labels (each being optional).
- """
- has_labels = (
- False
- if len(self.label_names) == 0
- else all(inputs.get(k) is not None for k in self.label_names)
- )
- # For CLIP-like models capable of returning loss values.
- # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
- # is `True` in `model.forward`.
- return_loss = inputs.get("return_loss", None)
- if return_loss is None:
- return_loss = self.can_return_loss
- loss_without_labels = (
- True if len(self.label_names) == 0 and return_loss else False
- )
-
- inputs = self._prepare_inputs(inputs)
- if ignore_keys is None:
- if hasattr(self.model, "config"):
- ignore_keys = getattr(
- self.model.config, "keys_to_ignore_at_inference", []
- )
- else:
- ignore_keys = []
-
- # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
- if has_labels or loss_without_labels:
- labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
- if len(labels) == 1:
- labels = labels[0]
- else:
- labels = None
-
- os.environ["UNSLOTH_RETURN_LOGITS"] = "1"
- with torch.no_grad():
- if has_labels or loss_without_labels:
- with self.compute_loss_context_manager():
- loss, outputs = self.compute_loss(
- model, inputs, return_outputs = True
- )
- loss = loss.mean().detach()
-
- if isinstance(outputs, dict):
- logits = tuple(
- v for k, v in outputs.items() if k not in ignore_keys + ["loss"]
- )
- else:
- logits = outputs[1:]
- else:
- loss = None
- with self.compute_loss_context_manager():
- tokenized_output = self.processing_class(
- inputs["prompt"],
- padding = True,
- truncation = True,
- return_tensors = "pt",
- ).to(model.device)
- outputs = model(**tokenized_output)
- if isinstance(outputs, dict):
- logits = tuple(
- v for k, v in outputs.items() if k not in ignore_keys
- )
- else:
- logits = outputs
- # TODO: this needs to be fixed and made cleaner later.
- if self.args.past_index >= 0:
- self._past = outputs[self.args.past_index - 1]
- os.environ["UNSLOTH_RETURN_LOGITS"] = "0"
- if prediction_loss_only:
- return (loss, None, None)
-
- logits = nested_detach(logits)
- if len(logits) == 1:
- logits = logits[0]
-
- return (loss, logits, labels)
+ pass
+ pass
+ pass
import trl.trainer
-
trainers = dir(trl.trainer)
trainers = [x for x in trainers if x.endswith("_trainer")]
unwrap = "unwrap_model_for_generation"
for trainer in trainers:
- try:
- current_trainer = getattr(trl.trainer, trainer)
- except:
- continue
+ try: current_trainer = eval(f"trl.trainer.{trainer}")
+ except: continue
if hasattr(current_trainer, unwrap):
- try:
- setattr(current_trainer, unwrap, unsloth_unwrap_model_for_generation)
- except:
- continue
- Trainer.prediction_step = unsloth_prediction_step
-
+ try: exec(f"trl.trainer.{trainer}.{unwrap} = unsloth_{unwrap}")
+ except: continue
+ pass
+pass
-grpo_selective_log_softmax = RL_REPLACEMENTS["grpo_selective_log_softmax"]
-selective_log_softmax = RL_REPLACEMENTS["selective_log_softmax"]
-calculate_pad_tokens_in_prompt = RL_REPLACEMENTS["calculate_pad_tokens_in_prompt"]
-create_completion_attention_mask = RL_REPLACEMENTS["create_completion_attention_mask"]
-left_pack_padding = RL_REPLACEMENTS["left_pack_padding"]
-align_logprobs_with_mask = RL_REPLACEMENTS["align_logprobs_with_mask"]
-autotune_batch_and_chunks = RL_REPLACEMENTS["grpo_autotune_batch_and_chunks"]
-sanitize_logprob = RL_REPLACEMENTS["sanitize_logprob"]
RLTrainer_replacement = '''
import os
-import math
-import logging
from typing import *
from dataclasses import dataclass, field
from packaging.version import Version
@@ -299,51 +106,7 @@ def unsloth_prediction_step(
import numpy as np
from contextlib import nullcontext
from torch.nn import functional as F
-import inspect
-from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
-from transformers.training_args import ParallelMode
-from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
-
-# Wrap trainer with padding to right and enable training mode
-# Also patches W&B since multiple runs must use wandb.finish()
-import functools
-from types import MethodType
-try:
- from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
-except:
- def reset_unsloth_gradient_checkpointing_buffers(): pass
-def prepare_for_training_mode(f):
- @functools.wraps(f)
- def wrapper(self, *args, **kwargs):
- # Enable training mode
- _was_training = None
- # Get gradient checkpointing setting from training arguments
- use_gc = getattr(self.args, 'gradient_checkpointing', True)
- if hasattr(self, 'model') and hasattr(self.model, "training"):
- _was_training = self.model.training
- if hasattr(self, 'model') and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- output = f(self, *args, **kwargs)
- # Restore previous mode when possible
- if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
- if _was_training is False:
- self.model.for_inference()
- elif _was_training is True and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- # Reset gradient checkpointing buffers to free memory while staying ready for next run
- try:
- reset_unsloth_gradient_checkpointing_buffers()
- except:
- pass
- # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
- try:
- import wandb
- wandb.finish()
- except:
- pass
- return output
- return wrapper
-pass
+from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
torch_compile_options = {{
"epilogue_fusion" : True,
@@ -353,15 +116,7 @@ def wrapper(self, *args, **kwargs):
"triton.cudagraphs" : False,
}}
-{grpo_selective_log_softmax_code}
{selective_log_softmax_code}
-{calculate_pad_tokens_in_prompt_code}
-{create_completion_attention_mask_code}
-{left_pack_padding_code}
-{align_logprobs_with_mask_code}
-{autotune_batch_and_chunks_code}
-{sanitize_logprob_code}
-
{RL_pre}
@dataclass
@@ -377,38 +132,15 @@ class Unsloth{RLConfig_name}({RLConfig_name}):
default = -1,
metadata = {{'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}},
)
- unsloth_logit_chunk_multiplier : Optional[int] = field(
- default = None,
- metadata = {{'help': 'Multiplier for chunked logit computations.'}},
- )
- unsloth_grpo_mini_batch : Optional[int] = field(
- default = None,
- metadata = {{'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'}},
- )
- {max_seq_length_pre}
def __init__({RLConfig_arguments},
vllm_sampling_params = None,
unsloth_num_chunks = -1,
- unsloth_logit_chunk_multiplier = None,
- unsloth_grpo_mini_batch = None,
- {max_seq_length_call}
**kwargs,
):
{RLConfig_extra_args}
super().__init__({RLConfig_call_args}{RLConfig_kwargs})
self.vllm_sampling_params = vllm_sampling_params
self.unsloth_num_chunks = unsloth_num_chunks
- if unsloth_grpo_mini_batch is not None:
- if self.generation_batch_size >= unsloth_grpo_mini_batch:
- self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
- else:
- raise ValueError(
- f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
- f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
- )
- self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
- {max_seq_length_post}
-{RLConfig_post}
pass
{RLTrainer_extras}
@@ -422,238 +154,43 @@ def __init__({RLTrainer_arguments},
):
if args is None: args = Unsloth{RLConfig_name}()
{RLTrainer_extra_args}
- # [TODO] Fix up DataParallel multiplying batch sizes
- # [TODO] DDP works, but DP seems to not work? [TODO]
- if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
- if getattr(args, "_n_gpu", 1) != 1:
- args._n_gpu = 1
- if "model" in locals() and hasattr(model, "for_training"):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
super().__init__({RLTrainer_call_args}{RLTrainer_kwargs})
- if "model" in locals() and hasattr(model, "for_inference"):
- model.for_inference()
{RLTrainer_post}
pass
'''
-
-def _wrap_grpo_generate_and_score(trainer_cls):
- if not hasattr(trainer_cls, "_generate_and_score_completions"):
- return
- original = trainer_cls._generate_and_score_completions
- if getattr(original, "_unsloth_restore_training_wrapped", False):
- return
-
- def wrapped(self, *args, **kwargs):
- was_training = getattr(getattr(self, "model", None), "training", None)
- try:
- return original(self, *args, **kwargs)
- finally:
- if (
- was_training is False
- and hasattr(self, "model")
- and hasattr(self.model, "for_inference")
- ):
- try:
- self.model.for_inference()
- except Exception:
- pass
-
- wrapped._unsloth_restore_training_wrapped = True
- trainer_cls._generate_and_score_completions = wrapped
-
-
def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
# Patch for vLLM and Unsloth PEFT
import trl
import trl.trainer
-
try:
trainer = eval(f"trl.trainer.{trainer_file}")
except Exception as error:
- logger.info(f"Unsloth: Could not import trl.trainer.{trainer_file}: {error}")
return
-
+
# Get SFTTrainer and SFTConfig names
- name = [
- x
- for x in dir(trainer)
- if x.endswith("Trainer")
- and x != "Trainer"
- and not x.startswith("_")
- and trainer_file.split("_")[0] in x.lower()
- ]
- config = [
- x
- for x in dir(trainer)
- if x.endswith("Config")
- and x != "Config"
- and not x.startswith("_")
- and trainer_file.split("_")[0] in x.lower()
- ]
- if len(name) != 1:
- logger.info(
- f"Unsloth: Could not find Trainer class in trl.trainer.{trainer_file}. Found: {name}"
- )
- return
- if len(config) != 1:
- # TRL 0.26+: Config may be in a separate *_config.py module
- config_module_name = trainer_file.replace("_trainer", "_config")
- try:
- config_mod = eval(f"trl.trainer.{config_module_name}")
- config = [
- x
- for x in dir(config_mod)
- if x.endswith("Config")
- and x != "Config"
- and not x.startswith("_")
- and trainer_file.split("_")[0] in x.lower()
- ]
- except Exception:
- pass
- if len(config) != 1 and len(name) == 1:
- # Thin wrapper fallback: walk the Trainer's MRO to find Config
- # in the real implementation module (e.g., trl.experimental.bco)
- try:
- _temp_cls = eval(f"trl.trainer.{trainer_file}.{name[0]}")
- for _parent in _temp_cls.__mro__[1:]:
- if _parent is object:
- continue
- _parent_mod = inspect.getmodule(_parent)
- if (
- _parent_mod is None
- or _parent_mod.__name__ == f"trl.trainer.{trainer_file}"
- ):
- continue
- config = [
- x
- for x in dir(_parent_mod)
- if x.endswith("Config")
- and x != "Config"
- and not x.startswith("_")
- and trainer_file.split("_")[0] in x.lower()
- ]
- if len(config) == 1:
- break
- except Exception:
- pass
- if len(config) != 1:
- logger.info(
- f"Unsloth: Could not find Config class in trl.trainer.{trainer_file}. Found: {config}"
- )
- return
+ name = [x for x in dir(trainer) if x.endswith("Trainer") and x != "Trainer" and trainer_file.split("_")[0] in x.lower()]
+ config = [x for x in dir(trainer) if x.endswith("Config") and x != "Config" and trainer_file.split("_")[0] in x.lower()]
+ if len(name) != 1: return
+ if len(config) != 1: return
# Get SFTTrainer, SFTConfig
RLTrainer_name = name[0]
- RLConfig_name = config[0]
- try:
- RLTrainer = eval(f"trl.trainer.{trainer_file}.{RLTrainer_name}")
- except Exception as e:
- logger.info(
- f"Unsloth: Could not load {RLTrainer_name} from trl.trainer.{trainer_file}: {e}"
- )
- return
- _config_resolved_module = None
- try:
- RLConfig = eval(f"trl.trainer.{trainer_file}.{RLConfig_name}")
- except Exception:
- # TRL 0.26+: Config may be in a separate *_config.py module
- try:
- config_module_name = trainer_file.replace("_trainer", "_config")
- RLConfig = eval(f"trl.trainer.{config_module_name}.{RLConfig_name}")
- except Exception:
- # Thin wrapper fallback: load Config from parent trainer's module
- _config_loaded = False
- try:
- _temp_cls = eval(f"trl.trainer.{trainer_file}.{name[0]}")
- for _parent in _temp_cls.__mro__[1:]:
- if _parent is object:
- continue
- _parent_mod = inspect.getmodule(_parent)
- if (
- _parent_mod is None
- or _parent_mod.__name__ == f"trl.trainer.{trainer_file}"
- ):
- continue
- if hasattr(_parent_mod, RLConfig_name):
- RLConfig = getattr(_parent_mod, RLConfig_name)
- _config_resolved_module = _parent_mod
- _config_loaded = True
- break
- except Exception:
- pass
- if not _config_loaded:
- logger.info(f"Unsloth: Could not load {RLConfig_name}")
- return
+ RLConfig_name = config[0]
+ try: RLTrainer = eval(f"trl.trainer.{trainer_file}.{RLTrainer_name}")
+ except: return
+ try: RLConfig = eval(f"trl.trainer.{trainer_file}.{RLConfig_name}" )
+ except: return
# Check name
- if RLTrainer.__name__.startswith("Unsloth"):
- print(f"Unsloth: {RLTrainer.__name__} is already patched.")
- return
- if RLConfig.__name__.startswith("Unsloth"):
- print(f"Unsloth: {RLConfig.__name__} is already patched.")
- return
-
- # TRL 0.26+: Resolve thin wrappers to their experimental parent class.
- # Thin wrappers are deprecation shims in trl.trainer that just forward
- # *args/**kwargs to the real implementation in trl.experimental.
- # Only resolve if a parent class actually lives in a trl.experimental module.
- _trainer_resolved_module = None
- try:
- _trainer_src = inspect.getsource(RLTrainer)
- _trainer_module = inspect.getmodule(RLTrainer)
- _trainer_module_src = (
- inspect.getsource(_trainer_module) if _trainer_module else ""
- )
- if (
- "trl.experimental" in _trainer_src
- or "trl.experimental" in _trainer_module_src
- ):
- for _parent in RLTrainer.__mro__[1:]:
- if _parent is object:
- continue
- _parent_mod = inspect.getmodule(_parent)
- if _parent_mod is None:
- continue
- # Only resolve to a parent that lives in trl.experimental
- if "trl.experimental" in _parent_mod.__name__:
- RLTrainer = _parent
- _trainer_resolved_module = _parent_mod
- break
- except Exception:
- pass
-
- try:
- _config_src = inspect.getsource(RLConfig)
- _config_module = inspect.getmodule(RLConfig)
- _config_module_src = inspect.getsource(_config_module) if _config_module else ""
- if (
- "trl.experimental" in _config_src
- or "trl.experimental" in _config_module_src
- ):
- for _parent in RLConfig.__mro__[1:]:
- if _parent is object:
- continue
- _parent_mod = inspect.getmodule(_parent)
- if _parent_mod is None:
- continue
- # Only resolve to a parent that lives in trl.experimental
- if "trl.experimental" in _parent_mod.__name__:
- RLConfig = _parent
- break
- except Exception:
- pass
+ if RLTrainer.__name__.startswith("Unsloth"): return
+ if RLConfig .__name__.startswith("Unsloth"): return
# Get old source
old_RLTrainer_source = inspect.getsource(RLTrainer)
- old_RLConfig_source = inspect.getsource(RLConfig)
+ old_RLConfig_source = inspect.getsource(RLConfig)
- if _trainer_resolved_module is not None:
- all_imports = dir(_trainer_resolved_module)
- elif _config_resolved_module is not None:
- all_imports = dir(_config_resolved_module)
- else:
- all_imports = dir(trainer)
+ all_imports = dir(trainer)
# Fix _deprecate_arguments not getting imported so stop __ but not _
imports = [x for x in all_imports if not x.startswith("__")]
@@ -662,38 +199,23 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
processed = []
for RLobject in [RLTrainer, RLConfig]:
parameters = inspect.signature(RLobject.__init__).parameters
- types = (
- bool,
- type(None),
- int,
- float,
- str,
- )
+ types = (bool, type(None), int, float, str,)
arguments = ["self"]
call_args = []
for k, v in parameters.items():
- if k == "self":
- continue
+ if k == "self": continue
v = v.default
- if v == "\n":
- v = re.escape("\n")
- if v is EMPTY:
- arguments.append(k)
- elif type(v) is str:
- arguments.append(f"{k} = '{v}'")
- elif type(v) in types:
- arguments.append(f"{k} = {v}")
- else:
- continue
+ if v == "\n": v = re.escape("\n")
+ if v is EMPTY: arguments.append(k)
+ elif type(v) is str: arguments.append(f"{k} = '{v}'")
+ elif type(v) in types: arguments.append(f"{k} = {v}")
+ else: continue
call_args.append(f"{k} = {k}")
+ pass
arguments = f"\n{' '*8}" + f",\n{' '*8}".join(arguments)
call_args = f"\n{' '*12}" + f",\n{' '*12}".join(call_args)
- processed.append(
- (
- arguments,
- call_args,
- )
- )
+ processed.append((arguments, call_args,))
+ pass
# Process RLTrainer first
arguments, call_args = processed[0]
@@ -706,304 +228,197 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
"processing_class = processing_class",
"processing_class = tokenizer if tokenizer is not None else processing_class",
)
+ pass
- # Edit bf16, fp16 by checking model's dtype/torch_dtype directly
+ # Edit bf16, fp16 by checking model's torch_dtype directly
extra_args = ""
if "args" in call_args and "model" in call_args:
- mixed_precision = (
- "use_bf16 = getattr(args, 'bf16', False)\n"
- "if type(use_bf16) is not bool: use_bf16 = False\n"
- "use_fp16 = getattr(args, 'fp16', False)\n"
- "if type(use_fp16) is not bool: use_fp16 = False\n"
- "force_float32 = False\n"
- "full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'\n"
- "if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):\n"
- " print('Unsloth: Switching to float32 training since model cannot work with float16')\n"
- " force_float32 = True\n"
- "mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')\n"
- "dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)\n"
- "if dtype is None: dtype = model.get_input_embeddings().weight.dtype\n"
- "from unsloth_zoo.utils import _get_dtype\n"
- "dtype = _get_dtype(dtype)\n"
- "float16 = dtype == torch.float16\n"
- "if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')\n"
- "if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')\n"
- "if force_float32:\n"
- " # Forced float32 training\n"
- " args.fp16 = False\n"
- " args.bf16 = False\n"
- " os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'\n"
- " if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'\n"
- " # args.mixed_precision is a new argument which needs to be set now\n"
- "elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':\n"
- " # Mixed precision training\n"
- " args.fp16 = float16\n"
- " args.bf16 = not float16\n"
- " os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'\n"
- " if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'\n"
- " # args.mixed_precision is a new argument which needs to be set now\n"
- "elif mixed_precision_dtype == 'bfloat16':\n"
- " # Both False since bfloat16 full finetuning doesn't do any autocasting.\n"
- " args.fp16 = False\n"
- " args.bf16 = False\n"
- " os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'\n"
- " if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'\n"
- " # args.mixed_precision is a new argument which needs to be set now\n"
- "\n"
- )
+ mixed_precision = \
+ "use_bf16 = getattr(args, 'bf16', False)\n"\
+ "use_fp16 = getattr(args, 'fp16', False)\n"\
+ "force_float32 = False\n"\
+ "if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':\n"\
+ " print('Unsloth: Switching to float32 training since model cannot work with float16')\n"\
+ " force_float32 = True\n"\
+ "mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')\n"\
+ "dtype = getattr(model.config, 'torch_dtype', None)\n"\
+ "if dtype is None: dtype = model.get_input_embeddings().dtype\n"\
+ "from unsloth_zoo.utils import _get_dtype\n"\
+ "dtype = _get_dtype(dtype)\n"\
+ "float16 = dtype == torch.float16\n"\
+ "if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')\n"\
+ "if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')\n"\
+ "if force_float32:\n"\
+ " args.fp16 = False\n"\
+ " args.bf16 = False\n"\
+ " os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'\n"\
+ "elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':\n"\
+ " args.fp16 = float16\n"\
+ " args.bf16 = not float16\n"\
+ " os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'\n"
+ "elif mixed_precision_dtype == 'bfloat16':\n"\
+ " args.fp16 = False\n"\
+ " args.bf16 = False\n"\
+ " os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'\n"
extra_args += mixed_precision
+ pass
# Check if per_device_eval_batch_size (default 8) bigger than bsz
# Also use FP16 / BF16 evaluation
if "args" in call_args:
# Check eval_dataset first
if "eval_dataset" in call_args:
- check_eval_dataset = (
- "if getattr(args, 'eval_dataset', None) is not None and "
- "getattr(args, 'eval_strategy', 'no') == 'no':\n"
- " args.eval_strategy = 'steps'\n"
- " if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1\n"
- )
+ check_eval_dataset = \
+ "if getattr(args, 'eval_dataset', None) is not None and "\
+ "getattr(args, 'eval_strategy', 'no') == 'no':\n"\
+ " args.eval_strategy = 'steps'\n"\
+ " if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1\n"
extra_args += check_eval_dataset
+ pass
# Check if gradient accumulation bug fix is applied
- check_ga = (
- "ga_steps = getattr(args, 'gradient_accumulation_steps', None)\n"
- "if ga_steps is not None and ga_steps > 1:\n"
- " from transformers import __version__ as transformers_version\n"
- " if Version(transformers_version) <= Version('4.45.2'):\n"
- " print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\\n'\n"
- " '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')\n"
- )
+ check_ga = \
+ "ga_steps = getattr(args, 'gradient_accumulation_steps', None)\n"\
+ "if ga_steps is not None and ga_steps > 1:\n"\
+ " from transformers import __version__ as transformers_version\n"\
+ " if Version(transformers_version) <= Version('4.45.2'):\n"\
+ " print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\\n'\n"\
+ " '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')\n"
extra_args += check_ga
- eval_changes = (
- "if getattr(args, 'eval_strategy', 'no') != 'no':\n"
- " eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)\n"
- " if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size\n"
- " if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps\n"
- "fp16_full_eval = getattr(args, 'fp16_full_eval', False)\n"
- "if type(fp16_full_eval) is not bool: fp16_full_eval = False\n"
- "bf16_full_eval = getattr(args, 'bf16_full_eval', False)\n"
- "if type(bf16_full_eval) is not bool: bf16_full_eval = False\n"
- "if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True\n"
- "if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False\n"
- "if force_float32:\n"
- " args.bf16_full_eval = False\n"
- " args.fp16_full_eval = False\n"
- "elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':\n"
- " args.bf16_full_eval = True\n"
- " args.fp16_full_eval = False\n"
- "elif not bf16_full_eval and not fp16_full_eval:\n"
- " args.bf16_full_eval = args.bf16\n"
- " args.fp16_full_eval = args.fp16\n"
- )
+ eval_changes = \
+ "if getattr(args, 'eval_strategy', 'no') != 'no':\n"\
+ " eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)\n"\
+ " if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size\n"\
+ " if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps\n"\
+ "fp16_full_eval = getattr(args, 'fp16_full_eval', False)\n"\
+ "bf16_full_eval = getattr(args, 'bf16_full_eval', False)\n"\
+ "if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True\n"\
+ "if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False\n"\
+ "if force_float32:\n"\
+ " args.bf16_full_eval = False\n"\
+ " args.fp16_full_eval = False\n"\
+ "elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':\n"\
+ " args.bf16_full_eval = True\n"\
+ " args.fp16_full_eval = False\n"\
+ "elif not bf16_full_eval and not fp16_full_eval:\n"\
+ " args.bf16_full_eval = args.bf16\n"\
+ " args.fp16_full_eval = args.fp16\n"
extra_args += eval_changes
+ pass
# Force logits to be produced if preprocess_logits_for_metrics or compute_metrics is used
if "model" in call_args:
- logits_check = (
- "_output_logits = False\n"
- "if locals().get('compute_metrics', None) is not None: _output_logits = True\n"
- "if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True\n"
- "if _output_logits:\n"
- " os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"
- )
+ logits_check = \
+ "_output_logits = False\n"\
+ "if locals().get('compute_metrics', None) is not None: _output_logits = True\n"\
+ "if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True\n"\
+ "if _output_logits:\n"\
+ " os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"
extra_args += logits_check
- warnings_issued_check = (
- "if model is not None:\n"
- " _warnings_issued = getattr(model, 'warnings_issued', None)\n"
- " if _warnings_issued is None:\n"
- " model.warnings_issued = {}\n"
- " elif not isinstance(_warnings_issued, dict):\n"
- " try:\n"
- " model.warnings_issued = dict(_warnings_issued)\n"
- " except Exception:\n"
- " model.warnings_issued = {}\n"
- )
- extra_args += warnings_issued_check
+ pass
# Check max_seq_length
if "model" in call_args:
- length_check = (
- "if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):\n"
- " pass\n"
- "else:\n"
- " model_max_seq_length = getattr(model, 'max_seq_length', None)\n"
- " args_max_seq_length = getattr(args, 'max_seq_length', None)\n"
- " if args_max_seq_length is None and model_max_seq_length is not None:\n"
- " max_seq_length = model.max_seq_length\n"
- " if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length\n"
- " elif args_max_seq_length is not None and model_max_seq_length is not None:\n"
- " if args_max_seq_length > model_max_seq_length:\n"
- " print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '\n"
- " 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')\n"
- " args.max_seq_length = model_max_seq_length\n"
- )
+ length_check = \
+ "if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):\n"\
+ " pass\n"\
+ "else:\n"\
+ " model_max_seq_length = getattr(model, 'max_seq_length', None)\n"\
+ " args_max_seq_length = getattr(args, 'max_seq_length', None)\n"\
+ " if args_max_seq_length is None and model_max_seq_length is not None:\n"\
+ " max_seq_length = model.max_seq_length\n"\
+ " if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length\n"
+ " elif args_max_seq_length is not None and model_max_seq_length is not None:\n"\
+ " if args_max_seq_length > model_max_seq_length:\n"\
+ " print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but \n"\
+ " the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')\n"\
+ " args.max_seq_length = model_max_seq_length\n"
extra_args += length_check
-
- # At this point max_seq_length might be set, but trl is moving to max_length
- if trainer_file == "sft_trainer":
- max_length_check = (
- "if 'max_length' not in locals() and not hasattr(args, 'max_length'):\n"
- " pass\n"
- "else:\n"
- " if hasattr(args, 'max_seq_length') and args.max_seq_length is not None and args.max_seq_length > 0:\n"
- " if hasattr(args, 'max_length'):\n"
- " args.max_length = args.max_seq_length\n"
- " max_length = args.max_length\n"
- " else:\n"
- " model_max_length = getattr(model, 'max_seq_length', None)\n"
- " if model_max_length is None: model_max_length = getattr(model, 'max_length', None)\n"
- " if model_max_length is not None:\n"
- " args.max_length = model_max_length\n"
- " max_length = args.max_length\n"
- " elif hasattr(args, 'max_length') and args.max_length is not None:\n"
- " max_length = args.max_length\n"
- " # if we are here, then we are in a weird case where max_length is set but max_seq_length is not set\n"
- " setattr(model, 'max_seq_length', max_length)\n"
- " else:\n"
- " print('Unsloth: We did not find `max_seq_length` or `max_length` in the model or args. We will set it to 1024.')\n"
- " args.max_length = 1024\n"
- )
- extra_args += max_length_check
+ pass
# Enable for training and move padding side of tokenizer to right
if "model" in call_args:
- training_check = (
- "if model is not None and hasattr(model, 'for_training'):\n"
- " model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))\n"
- "if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'\n"
- "if 'processing_class' in locals():\n"
- " if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'\n"
- " if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): "
- "processing_class.tokenizer.padding_side = 'right'\n"
- )
+ training_check = \
+ "if model is not None and hasattr(model, 'for_training'):\n"\
+ " model.for_training()\n"\
+ "if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'\n"\
+ "if 'processing_class' in locals():\n"\
+ " if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'\n"\
+ " if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): "\
+ "processing_class.tokenizer.padding_side = 'right'\n"
extra_args += training_check
+ pass
# Check data collator if it's correct!
if "data_collator" in call_args and "train_dataset" in call_args:
- data_collator_check = (
- "__tokenizer = processing_class if 'processing_class' in locals() else tokenizer\n"
- "from unsloth_zoo.vision_utils import UnslothVisionDataCollator\n"
- "if not isinstance(data_collator, UnslothVisionDataCollator):\n"
- " if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:\n"
- " data_collator = TransformersDataCollatorForLanguageModeling(\n"
- " __tokenizer,\n"
- " mlm = False,\n"
- " mlm_probability = 0.0,\n"
- " pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),\n"
- " )\n"
- " elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:\n"
- " data_collator = DataCollatorForSeq2Seq(\n"
- " __tokenizer,\n"
- " pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),\n"
- " )\n"
- "else:\n"
- " if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False\n"
- " if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''\n"
- " if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}\n"
- )
+ data_collator_check = \
+ "__tokenizer = processing_class if 'processing_class' in locals() else tokenizer\n"\
+ "from unsloth_zoo.vision_utils import UnslothVisionDataCollator\n"\
+ "if not isinstance(data_collator, UnslothVisionDataCollator):\n"\
+ " if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:\n"\
+ " data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)\n"\
+ " elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:\n"\
+ " data_collator = DataCollatorForSeq2Seq(__tokenizer)\n"\
+ "else:\n"\
+ " if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False\n"\
+ " if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''\n"\
+ " if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}\n"
extra_args += data_collator_check
# Also check if .pad exists -> if not, and is VLM, then change it!
- pad_check = (
- "if not isinstance(data_collator, UnslothVisionDataCollator):\n"
- " if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):\n"
- " if isinstance(data_collator, DataCollatorForSeq2Seq):\n"
- " data_collator = DataCollatorForSeq2Seq(\n"
- " __tokenizer.tokenizer,\n"
- " pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),\n"
- " )\n"
- " else:\n"
- " data_collator = TransformersDataCollatorForLanguageModeling(\n"
- " __tokenizer.tokenizer,\n"
- " mlm = False,\n"
- " mlm_probability = 0.0,\n"
- " pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),\n"
- " )\n"
- )
+ pad_check = \
+ "if not isinstance(data_collator, UnslothVisionDataCollator):\n"\
+ " if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):\n"\
+ " if isinstance(data_collator, DataCollatorForSeq2Seq):\n"\
+ " data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)\n"\
+ " else:\n"\
+ " data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)\n"
extra_args += pad_check
+ pass
# Check NEFTune
if "model" in call_args:
- neftune_check = (
- "if hasattr(self, 'neftune_hook_handle'):\n"
- " self.neftune_hook_handle.remove()\n"
- " if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle\n"
- "if getattr(args, 'neftune_noise_alpha', None) is not None:\n"
- " model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha\n"
- "pass\n"
- )
+ neftune_check = \
+ "if hasattr(self, 'neftune_hook_handle'):\n"\
+ " self.neftune_hook_handle.remove()\n"\
+ " if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle\n"\
+ "if getattr(args, 'neftune_noise_alpha', None) is not None:\n"\
+ " model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha\n"\
+ "pass\n"
RLTrainer_post += neftune_check
-
- # Add accelerator scaler to model
- if "model" in call_args:
- accelerator_check = (
- "if hasattr(self, 'accelerator'):\n"
- " scaler = self.accelerator.scaler\n"
- " current_model = model\n"
- " while hasattr(current_model, 'model'):\n"
- " current_model.accelerator_scaler = scaler\n"
- " current_model = current_model.model\n"
- " current_model.accelerator_scaler = scaler\n"
- "pass\n"
- )
- RLTrainer_post += accelerator_check
-
- # Add enabling and disabling training modes
- if "model" in call_args:
- training_check = (
- "if hasattr(self, 'train'):\n"
- " self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)\n"
- "pass\n"
- )
- RLTrainer_post += training_check
-
- # Sync chat_template from processing_class to vLLM's tokenizer
- # This fixes base models that have custom chat templates applied after loading
- if "model" in call_args:
- vllm_chat_template_sync = (
- "if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):\n"
- " _vllm_tok = self.llm.get_tokenizer()\n"
- " _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)\n"
- " if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:\n"
- " _vllm_tok.chat_template = _pc.chat_template\n"
- "pass\n"
- )
- RLTrainer_post += vllm_chat_template_sync
+ pass
# Edit optional metrics
other_metrics_processor = ""
if trainer_file in RL_METRICS_CHANGES:
process_extra_args = RL_METRICS_CHANGES[trainer_file]
for process_extra_arg in process_extra_args:
- other_metrics_processor += process_extra_arg(
- old_RLTrainer_source, old_RLConfig_source
- )
+ other_metrics_processor += process_extra_arg(call_args, extra_args)
+ pass
# Add statistics as well!
- extra_args += (
- "other_metrics = []\n"
- f"{other_metrics_processor}\n"
- "from unsloth_zoo.logging_utils import PatchRLStatistics\n"
+ extra_args += \
+ "other_metrics = []\n"\
+ f"{other_metrics_processor}\n"\
+ "from unsloth_zoo.logging_utils import PatchRLStatistics\n"\
f"PatchRLStatistics('{trainer_file}', other_metrics)\n"
- )
# Patch optional args
if trainer_file in RL_EXTRA_ARGS:
process_extra_args = RL_EXTRA_ARGS[trainer_file]
for process_extra_arg in process_extra_args:
extra_args += process_extra_arg(call_args, extra_args)
+ pass
# Create RLTrainer args
extra_args = extra_args.split("\n")
- extra_args = "\n".join(" " * 8 + x for x in extra_args)
+ extra_args = "\n".join(" "*8 + x for x in extra_args)
RLTrainer_post = RLTrainer_post.split("\n")
- RLTrainer_post = "\n".join(" " * 8 + x for x in RLTrainer_post)
- RLTrainer_arguments = arguments
+ RLTrainer_post = "\n".join(" "*8 + x for x in RLTrainer_post)
+ RLTrainer_arguments = arguments
RLTrainer_extra_args = extra_args
- RLTrainer_call_args = call_args
+ RLTrainer_call_args = call_args
# Fix RLConfig next
arguments, call_args = processed[1]
@@ -1011,596 +426,161 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
# Edit GA / bsz and weight_decay
replacements = {
- "output_dir": None,
- "logging_nan_inf_filter": False,
- "per_device_train_batch_size": 4,
- "gradient_accumulation_steps": 2,
- "weight_decay": 0.01,
- "seed": 3407,
- "optim": "adamw_8bit",
- "learning_rate": 5e-05,
- "per_device_eval_batch_size": 4,
- "eval_accumulation_steps": 2,
- "torch_empty_cache_steps": 250,
- "logging_steps": 1,
- "max_seq_length": None,
- "num_generations": 8,
- # "steps_per_generation" : 1, # Otherwise defaults to ga_steps which is wrong
- # "generation_batch_size" : None, # Useless. If steps_per_generation set, generation_batch_size clashes
- "top_k": None,
- "vllm_mode": "colocate",
- "generation_kwargs": {},
- "bf16": False,
- "fp16": False,
- "report_to": "none",
- "include_tokens_per_second": False,
- "include_num_input_tokens_seen": False,
- "auto_find_batch_size": False, # Auto /2 batch size - too many people complained so removing
- "dataloader_pin_memory": True,
- "padding_free": None, # None = user didn't set it, allows auto-enable detection
- # Might fail so disable for now
- # "dataloader_persistent_workers" : True, # Keeps dataloader in RAM
- # "dataloader_prefetch_factor" : 2,
- # "dataloader_num_workers" : 2, # Default is 0 means 1
+ "output_dir" : None,
+ "logging_nan_inf_filter" : False,
+ "per_device_train_batch_size" : 4,
+ "gradient_accumulation_steps" : 2,
+ "weight_decay" : 0.01,
+ "warmup_ratio" : 0.1,
+ "seed" : 3407,
+ "optim" : "adamw_8bit",
+ "learning_rate" : 5e-05,
+ "per_device_eval_batch_size" : 4,
+ "eval_accumulation_steps" : 2,
+ "torch_empty_cache_steps" : 250,
+ "logging_steps" : 1,
}
- # warmup_ratio deprecated in transformers >= 5.0; warmup_steps accepts float
- if transformers_version >= Version("5.0.0"):
- replacements["warmup_steps"] = 0.1
- else:
- replacements["warmup_ratio"] = 0.1
-
for k, v in replacements.items():
x = f"{k}( = [^,\n]{{1,}})?,\n"
y = f"'{v}'" if type(v) is str else f"{v}"
y = f"{k} = {y},\n"
arguments = re.sub(x, y, arguments)
-
- # Fix GRPO beta default as 0.001 TRL used to be 0.04, now 0.00!
- # https://github.com/huggingface/trl/pull/3516
- # https://verl.readthedocs.io/en/latest/examples/config.html
- if trainer_file == "grpo_trainer":
- replacements = {
- "loss_type": "bnpo", # Default GRPO paper
- "beta": 0.001, # Recommended as seen in verl
- "auto_find_batch_size": False, # Cannot work on GRPO
- # [TODO] See https://fengyao.notion.site/off-policy-rl
- # https://github.com/huggingface/trl/pull/3867 (August 7th)
- "vllm_importance_sampling_correction": False,
- }
- for k, v in replacements.items():
- x = f"{k}( = [^,\n]{{1,}})?,\n"
- y = f"'{v}'" if type(v) is str else f"{v}"
- y = f"{k} = {y},\n"
- arguments = re.sub(x, y, arguments)
+ pass
# Warn on too large or too small learning rate
- if "learning_rate" in call_args:
- learning_rate_check = (
- "if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! "
- "Consider increasing it, otherwise gradient updates will be close to 0!')\n"
- "if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! "
- "Consider decreasing it to 1e-1, otherwise gradient updates will explode!')\n"
- )
+ if " learning_rate" in call_args:
+ learning_rate_check = \
+ "if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! "\
+ "Consider increasing it, otherwise gradient updates will be close to 0!')\n"\
+ "if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! "\
+ "Consider decreasing it to 1e-1, otherwise gradient updates will explode!')\n"
extra_args += learning_rate_check
-
- # Fix num_train_epochs = None causing TypeError in Trainer.__init__
- # Trainer does `args.num_train_epochs > 0` which fails when None
- if "num_train_epochs" in call_args:
- num_train_epochs_check = (
- "if num_train_epochs is None:\n"
- " num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override\n"
- )
- extra_args += num_train_epochs_check
-
- # Check if max_seq_length is NOT defined (max_length is now default)
- if "max_seq_length" not in call_args and "max_length" in call_args:
- max_seq_length_pre = """max_seq_length : Optional[int] = field(
- default = None,
- metadata = {'help': 'Maximum sequence length to truncate to.'},
- )"""
- max_seq_length_call = "max_seq_length = None,"
- max_seq_length_post = "self.max_seq_length = max_seq_length"
- else:
- max_seq_length_pre = ""
- max_seq_length_call = ""
- max_seq_length_post = ""
+ pass
# Add output_dir saving
if "output_dir" in call_args:
# Default checks
- saving_check = (
- "if output_dir is None and save_strategy == 'steps' and save_steps == 500:\n"
- " output_dir = 'unsloth_training_checkpoints'\n"
- " save_strategy = 'no'\n"
- )
+ saving_check = \
+ "if output_dir is None and save_strategy == 'steps' and save_steps == 500:\n"\
+ " output_dir = 'unsloth_training_checkpoints'\n"\
+ " save_strategy = 'no'\n"
extra_args += saving_check
+ pass
# Edit dataset_num_proc
if "dataset_num_proc" in call_args:
- num_proc_check = (
- "import multiprocessing as _mp\n"
- "if _mp.get_start_method() != 'fork':\n"
- " dataset_num_proc = None\n"
- "elif dataset_num_proc is None:\n"
- " import psutil\n"
- " dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)\n"
- " memory_gb_left = psutil.virtual_memory().available / (1024**3)\n"
- " if memory_gb_left <= 2: dataset_num_proc = 1\n"
- " else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))\n"
- )
+ num_proc_check = \
+ "if dataset_num_proc is None:\n"\
+ " from multiprocessing import cpu_count\n"\
+ " dataset_num_proc = cpu_count()\n"
extra_args += num_proc_check
-
- # Add padding if flex attention is added
- if "pad_to_multiple_of" in call_args:
- pad_to_multiple_of = (
- "if os.environ.get('UNSLOTH_ENABLE_FLEX_ATTENTION', '0') == '1':\n"
- " from unsloth_zoo.flex_attention import HAS_FLEX_ATTENTION\n"
- " if HAS_FLEX_ATTENTION and pad_to_multiple_of is None:\n"
- " from unsloth_zoo.flex_attention import FLEX_ATTENTION_BLOCK_SIZE\n"
- " pad_to_multiple_of = FLEX_ATTENTION_BLOCK_SIZE\n"
- "\n"
- )
- extra_args += pad_to_multiple_of
-
- # Check for loss_type = dr_grpo and scale_rewards for GRPO
- if "loss_type" in call_args and "scale_rewards" in call_args:
- # See https://github.com/huggingface/trl/issues/3130#issuecomment-2746947835
- # DAPO uses per token loss so BNPO loss used
- check_dr_grpo = (
- "if loss_type.lower() == 'dr_grpo':\n"
- " loss_type = 'dr_grpo'\n"
- "elif loss_type.lower() == 'dapo':\n"
- " loss_type = 'dapo'\n"
- "if loss_type.lower() == 'dr_grpo':\n"
- " if scale_rewards == None:\n"
- " scale_rewards = True\n"
- " elif scale_rewards == True:\n"
- " print('Unsloth: The Dr GRPO paper recommends setting `scale_rewards` to False! Will override. Set it to `None` to force False.')\n"
- " scale_rewards = False\n"
- "elif loss_type.lower() == 'dapo':\n"
- " if mask_truncated_completions != True:\n"
- " print('Unsloth: The DAPO paper recommends `mask_truncated_completions = True` - we will set it.')\n"
- " if epsilon_high != 0.28:\n"
- " print('Unsloth: The DAPO paper recommends `epsilon_high = 0.28` - we will set it.')\n"
- " if beta != 0.0:\n"
- " print(f'[WARNING] Unsloth: The DAPO paper recommends setting `beta = 0.0` to remove the KL term - You have set it to {beta}.')\n"
- " mask_truncated_completions = True\n"
- " epsilon_high = 0.28\n"
- "\n"
- )
- extra_args += check_dr_grpo
-
- # Check GRPO num_generations mismatch
- if (
- "per_device_train_batch_size" in call_args
- and "num_generations" in call_args
- and "steps_per_generation" in call_args
- and "generation_batch_size" in call_args
- ):
- # if world size is not set by accelerate or torchrun at this point it will be 1
- check_num_generations = (
- "if steps_per_generation is None and generation_batch_size is None:\n"
- " ga = gradient_accumulation_steps\n"
- " world_size = int(os.environ.get('WORLD_SIZE', '1'))\n"
- " if (ga * world_size * per_device_train_batch_size) % num_generations != 0:\n"
- " print('Unsloth: We now expect `per_device_train_batch_size` * `gradient_accumulation_steps` * `world_size` to be a multiple of `num_generations`.\\n"
- "We will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))\n"
- " per_device_train_batch_size = num_generations\n"
- "\n"
- )
- extra_args += check_num_generations
- elif "per_device_train_batch_size" in call_args and "num_generations" in call_args:
- if "steps_per_generation" not in call_args:
- print(f"Unsloth: Could not find `steps_per_generation` in {trainer_file}")
- if "generation_batch_size" not in call_args:
- print(f"Unsloth: Could not find `generation_batch_size` in {trainer_file}")
-
- check_num_generations = (
- "if (per_device_train_batch_size // num_generations) * num_generations != per_device_train_batch_size:\n"
- " print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n"
- "We will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))\n"
- " per_device_train_batch_size = num_generations\n"
- "\n"
- )
- extra_args += check_num_generations
-
- # Check temperature must not be <= 0. Also stop if >= 10
- if "temperature" in call_args:
- check_temperature = (
- "if temperature <= 0:\n"
- " raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')\n"
- "elif temperature >= 10:\n"
- " raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')\n"
- "\n"
- )
- extra_args += check_temperature
+ pass
# Edit config with anything extra
if trainer_file in RL_CONFIG_CHANGES:
process_extra_args = RL_CONFIG_CHANGES[trainer_file]
for process_extra_arg in process_extra_args:
extra_args += process_extra_arg(old_RLTrainer_source, old_RLConfig_source)
+ pass
+
+ # Edit report_to and default it to nothing if max_steps is like 60
# Create RLConfig args
extra_args = extra_args.split("\n")
- extra_args = "\n".join(" " * 8 + x for x in extra_args)
- RLConfig_arguments = arguments
+ extra_args = "\n".join(" "*8 + x for x in extra_args)
+ RLConfig_arguments = arguments
RLConfig_extra_args = extra_args
- RLConfig_call_args = call_args
-
- # TRL 0.27.0+ forces use_reentrant=False in gradient_checkpointing_kwargs.
- # Unsloth gradient checkpointing requires use_reentrant=True, so we remove
- # the setting after super().__init__() when it gets auto-applied.
- RLConfig_post = ""
- if trl_version >= Version("0.27.0") and RLConfig_name == "GRPOConfig":
- RLConfig_post = (
- " # Unsloth: Remove use_reentrant=False forced by TRL 0.27.0+\n"
- " if getattr(self, 'gradient_checkpointing_kwargs', None) is not None:\n"
- " if 'use_reentrant' in self.gradient_checkpointing_kwargs:\n"
- " del self.gradient_checkpointing_kwargs['use_reentrant']\n"
- )
+ RLConfig_call_args = call_args
# Patch vLLM and other functions
- RLTrainer_extras = patch_functions(
- RLTrainer, trainer_file, RLTrainer_name, all_imports, imports
- )
+ RLTrainer_extras = patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports)
if RLTrainer_extras is None:
RLTrainer_extras = f"_Unsloth{RLTrainer_name} = {RLTrainer_name}"
# Create full module
exec(f"from trl.trainer import ({RLTrainer_name}, {RLConfig_name},)")
__RLTrainer_doc__ = eval(f"trl.trainer.{RLTrainer_name}").__doc__
- if __RLTrainer_doc__ is None:
- __RLTrainer_doc__ = ""
- __RLConfig_doc__ = eval(f"trl.trainer.{RLConfig_name}").__doc__
- if __RLConfig_doc__ is None:
- __RLConfig_doc__ = ""
+ if __RLTrainer_doc__ is None: __RLTrainer_doc__ = ""
+ __RLConfig_doc__ = eval(f"trl.trainer.{RLConfig_name}") .__doc__
+ if __RLConfig_doc__ is None: __RLConfig_doc__ = ""
# Get all pre-modules
if trainer_file in RL_PRE_ITEMS:
RL_pre = "\n".join(RL_PRE_ITEMS[trainer_file])
else:
RL_pre = ""
+ pass
# Check if SamplingParams is in there
if "SamplingParams" in old_RLTrainer_source:
RL_pre = RL_pre + "\n" + inspect.getsource(vLLMSamplingParams)
-
- # Selective log softmax and other functions
+ pass
+
+ # Selective log softmax
selective_log_softmax_code = inspect.getsource(selective_log_softmax)
- grpo_selective_log_softmax_code = inspect.getsource(grpo_selective_log_softmax)
- calculate_pad_tokens_in_prompt_code = inspect.getsource(
- calculate_pad_tokens_in_prompt
- )
- create_completion_attention_mask_code = inspect.getsource(
- create_completion_attention_mask
- )
- left_pack_padding_code = inspect.getsource(left_pack_padding)
- align_logprobs_with_mask_code = inspect.getsource(align_logprobs_with_mask)
- autotune_batch_and_chunks_code = inspect.getsource(autotune_batch_and_chunks)
- sanitize_logprob_code = inspect.getsource(sanitize_logprob)
+
# Get final source code
RLTrainer_source = RLTrainer_replacement.format(
- RLTrainer_name = RLTrainer_name,
- __RLTrainer_doc__ = __RLTrainer_doc__,
- RLTrainer_arguments = RLTrainer_arguments,
+ RLTrainer_name = RLTrainer_name,
+ __RLTrainer_doc__ = __RLTrainer_doc__,
+ RLTrainer_arguments = RLTrainer_arguments,
RLTrainer_extra_args = RLTrainer_extra_args,
- RLTrainer_call_args = RLTrainer_call_args,
- RLTrainer_kwargs = ",**kwargs"[1 if RLTrainer_call_args.endswith(",") else 0 :],
- RLConfig_name = RLConfig_name,
- __RLConfig_doc__ = __RLConfig_doc__,
- RLConfig_arguments = RLConfig_arguments,
- RLConfig_extra_args = RLConfig_extra_args,
- RLConfig_call_args = RLConfig_call_args,
- RLConfig_kwargs = ",**kwargs"[1 if RLConfig_call_args.endswith(",") else 0 :],
- RLConfig_post = RLConfig_post,
- RLTrainer_extras = RLTrainer_extras,
- RLTrainer_post = RLTrainer_post,
- RL_pre = RL_pre,
- max_seq_length_pre = max_seq_length_pre,
- max_seq_length_call = max_seq_length_call,
- max_seq_length_post = max_seq_length_post,
- selective_log_softmax_code = selective_log_softmax_code,
- grpo_selective_log_softmax_code = grpo_selective_log_softmax_code,
- calculate_pad_tokens_in_prompt_code = calculate_pad_tokens_in_prompt_code,
- create_completion_attention_mask_code = create_completion_attention_mask_code,
- autotune_batch_and_chunks_code = autotune_batch_and_chunks_code,
- left_pack_padding_code = left_pack_padding_code,
- align_logprobs_with_mask_code = align_logprobs_with_mask_code,
- sanitize_logprob_code = sanitize_logprob_code,
- )
-
- if RLTrainer_name == "GRPOTrainer":
- # Base torch_compile_options shared by all device types
- base_options = """torch_compile_options = {
- "epilogue_fusion" : True,
- "max_autotune" : False,
- "shape_padding" : True,
- "trace.enabled" : False,"""
-
- # Generate torch_compile_options based on device type
- if DEVICE_TYPE == "cuda":
- # CUDA-specific options (added to base options)
- cuda_options = """
- "triton.enable_persistent_tma_matmul": torch.cuda.get_device_capability()[0] >= 9,"""
- # cutlass options were added in PyTorch 2.8.0
- if torch_version >= Version("2.8.0"):
- cuda_options += """
- "cuda.cutlass_epilogue_fusion_enabled": torch.cuda.get_device_capability()[0] >= 9,
- "cuda.cutlass_tma_only": torch.cuda.get_device_capability()[0] >= 9,"""
- cuda_options += """
- "cuda.compile_opt_level" : "-O2",
- "cuda.enable_cuda_lto" : True,
- }"""
- new_options = base_options + cuda_options
- else:
- # XPU, HIP, and other device types use base options only
- new_options = (
- base_options
- + """
- }"""
- )
-
- pattern = r"torch_compile_options\s*=\s*\{[^}]*\}"
-
- RLTrainer_source = re.sub(
- pattern, new_options, RLTrainer_source, flags = re.DOTALL
- )
+ RLTrainer_call_args = RLTrainer_call_args,
+ RLTrainer_kwargs = ",**kwargs"[1 if RLTrainer_call_args.endswith(",") else 0:],
- if trl_version >= Version("0.27.0"):
- peft_pattern = (
- r"\s*if is_peft_available\(\) and is_peft_model\(model\) and args\.beta != 0\.0:"
- r".*?"
- r"param\.data = param\.data\.to\(torch\.bfloat16\)"
- )
-
- replacement_comment = "\n # PEFT initialization logic removed via script for trl >= 0.27.0\n"
+ RLConfig_name = RLConfig_name,
+ __RLConfig_doc__ = __RLConfig_doc__,
+ RLConfig_arguments = RLConfig_arguments,
+ RLConfig_extra_args = RLConfig_extra_args,
+ RLConfig_call_args = RLConfig_call_args,
+ RLConfig_kwargs = ",**kwargs"[1 if RLConfig_call_args .endswith(",") else 0:],
- RLTrainer_source = re.sub(
- peft_pattern, replacement_comment, RLTrainer_source, flags = re.DOTALL
- )
+ RLTrainer_extras = RLTrainer_extras,
+ RLTrainer_post = RLTrainer_post,
+ RL_pre = RL_pre,
- elif trl_version >= Version("0.26.0"):
- peft_block_pattern = (
- r"\s*if is_peft_available\(\) and isinstance\(model, PeftModel\) and peft_config is not None:"
- r".*?"
- r"param\.data = param\.data\.to\(torch\.bfloat16\)"
- )
-
- RLTrainer_source = re.sub(
- peft_block_pattern,
- "\n # TRL PEFT 0.26.0 initialization logic removed on unsloth side.\n",
- RLTrainer_source,
- flags = re.DOTALL,
- )
-
- # Remove TRL's unconditional bfloat16 cast of trainable params (added in
- # TRL 0.26.0). TRL hardcodes bfloat16 for QLoRA per the original paper's
- # recommendation, but this is wrong: it ignores the user's requested dtype
- # and breaks GradScaler when training with fp16=True. Unsloth already
- # handles adapter dtype correctly via patch_model_and_tokenizer, so the
- # entire block is unnecessary. For GRPOTrainer the enclosing peft init
- # block is already removed above, making this a no-op for GRPO.
- RLTrainer_source = RLTrainer_source.replace(
- 'if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False):',
- "if False:",
+ selective_log_softmax_code = selective_log_softmax_code,
)
- if RLTrainer_name == "SFTTrainer":
- original_text = 'self._signature_columns = ["input_ids", "attention_mask", "completion_mask"]'
- new_text = 'self._signature_columns = ["input_ids", "attention_mask", "completion_mask","labels"]'
- RLTrainer_source = RLTrainer_source.replace(original_text, new_text)
-
- # Do NOT override _is_vlm -- let TRL detect VLM models naturally.
- # In TRL 0.27.1+, forcing _is_vlm=False causes a ValueError when
- # vision datasets are used with VLM models.
- #
- # However, some notebooks pass a bare tokenizer (processor.tokenizer) as
- # processing_class. TRL then sets _is_vlm=False even for VLM models.
- # Add a model-architecture-based override before the validation check.
- _vlm_check_original = (
- ' self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample\n'
- " if self._is_vision_dataset and not self._is_vlm:"
- )
- _vlm_check_patched = (
- ' self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample\n'
- " # Unsloth: override _is_vlm for VLM models that pass a bare tokenizer\n"
- " if not self._is_vlm and self._is_vision_dataset:\n"
- " _m = model\n"
- ' if hasattr(_m, "model"): _m = _m.model\n'
- ' if hasattr(getattr(_m, "config", None), "vision_config") or \\\n'
- ' _m.__class__.__name__.endswith("ForConditionalGeneration"):\n'
- " self._is_vlm = True\n"
- " if self._is_vision_dataset and not self._is_vlm:"
- )
- if _vlm_check_original in RLTrainer_source:
- RLTrainer_source = RLTrainer_source.replace(
- _vlm_check_original, _vlm_check_patched
- )
-
- # Fix TRL 0.22.x: VLM models with text-only datasets.
- # TRL 0.22.x checks _is_vlm (model type) not _is_vision_dataset (dataset
- # content, added in 0.25.1+). When _is_vlm=True, signature columns are
- # vision-only ["messages","prompt","completion","images"], which have zero
- # overlap with tokenized text columns. Fix: merge both column sets into the
- # VLM branch. Extra columns not in the dataset are harmlessly ignored by
- # _remove_unused_columns (it only raises when zero columns match).
- _sig_vlm_old = (
- 'self._signature_columns = ["messages", "prompt", "completion", "images"]'
- )
- _sig_vlm_new = (
- 'self._signature_columns = ["messages", "prompt", "completion", "images",'
- ' "input_ids", "labels", "attention_mask", "seq_lengths", "completion_mask", "assistant_masks"]'
- )
- RLTrainer_source = RLTrainer_source.replace(_sig_vlm_old, _sig_vlm_new)
-
- # Inject model reference before _prepare_dataset for dynamic
- # token_type_ids detection in sft_prepare_dataset
- _prep_pattern = r"([ \t]*)train_dataset = self\._prepare_dataset\("
- _prep_replacement = r"\1self._unsloth_model_ref = model\n\1train_dataset = self._prepare_dataset("
- RLTrainer_source = re.sub(
- _prep_pattern, _prep_replacement, RLTrainer_source, count = 1
- )
-
- # Silence TRL's noisy batch_size=1 + padding-free warning (handles both
- # the original "anihilate" typo and the corrected "annihilate" spelling)
- for _typo in ("anihilate", "annihilate"):
- _idx = RLTrainer_source.find(_typo)
- if _idx == -1:
- continue
- # Walk backwards to find "if args.per_device_train_batch_size"
- _block_start = RLTrainer_source.rfind(
- "if args.per_device_train_batch_size == 1", 0, _idx
- )
- if _block_start == -1:
- continue
- # Walk backwards to the newline before the if
- _line_start = RLTrainer_source.rfind("\n", 0, _block_start)
- # Walk forwards past the closing paren to the end of the block
- _close = RLTrainer_source.find(")", _idx)
- if _close == -1:
- continue
- _block_end = RLTrainer_source.find("\n", _close)
- if _block_end == -1:
- continue
- RLTrainer_source = (
- RLTrainer_source[:_line_start] + RLTrainer_source[_block_end:]
- )
- break
-
# Remove multiple doc strings
if __RLConfig_doc__ != "" and RLTrainer_source.count(__RLTrainer_doc__) == 2:
RLTrainer_source = RLTrainer_source.replace(__RLTrainer_doc__, "", 1)
+ pass
# Remove multiple newlines
RLTrainer_source = re.sub(r"[\n]{3,}", "\n", RLTrainer_source)
# Create new function
- _resolved_module = _trainer_resolved_module or _config_resolved_module
- _model_location = (
- _resolved_module.__name__
- if _resolved_module is not None
- else f"trl.trainer.{trainer_file}"
- )
created_module = create_new_function(
f"Unsloth{RLTrainer_name}",
RLTrainer_source,
- _model_location,
+ f"trl.trainer.{trainer_file}",
imports,
overwrite = False,
)
-
+
# Patch Trainer
- exec(
- f"trl.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}",
- locals(),
- globals(),
- )
- exec(
- f"trl.trainer.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}",
- locals(),
- globals(),
- )
- exec(
- f"trl.trainer.{trainer_file}.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}",
- locals(),
- globals(),
- )
-
+ exec(f"trl.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals())
+ exec(f"trl.trainer.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals())
+ exec(f"trl.trainer.{trainer_file}.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals())
+
# Patch Config
- exec(
- f"trl.{RLConfig_name} = created_module.Unsloth{RLConfig_name}",
- locals(),
- globals(),
- )
- exec(
- f"trl.trainer.{RLConfig_name} = created_module.Unsloth{RLConfig_name}",
- locals(),
- globals(),
- )
- exec(
- f"trl.trainer.{trainer_file}.{RLConfig_name} = created_module.Unsloth{RLConfig_name}",
- locals(),
- globals(),
- )
-
- if trainer_file == "grpo_trainer":
- try:
- _wrap_grpo_generate_and_score(
- getattr(created_module, f"Unsloth{RLTrainer_name}")
- )
- except Exception as e:
- logger.info(
- f"Unsloth: Could not wrap _generate_and_score_completions for {RLTrainer_name}: {e}"
- )
+ exec(f"trl.{RLConfig_name} = created_module.Unsloth{RLConfig_name}", locals(), globals())
+ exec(f"trl.trainer.{RLConfig_name} = created_module.Unsloth{RLConfig_name}", locals(), globals())
+ exec(f"trl.trainer.{trainer_file}.{RLConfig_name} = created_module.Unsloth{RLConfig_name}", locals(), globals())
+pass
def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports):
init = inspect.getsource(RLTrainer.__init__)
old_init = init
- # Remove brackets in comments since it interferes ie (...)
- comments = re.findall(r"\#[^\n]{1,}\n", init)
- bracketed_comments = [x for x in comments if "(" in x or ")" in x]
- # Replace with [...] instead
- for bracketed_comment in bracketed_comments:
- init = init.replace(
- bracketed_comment,
- bracketed_comment.replace("(", "[").replace(")", "]"),
- )
-
# Remove peft_config
init = init.replace("elif peft_config is None:", "elif False:")
init = init.replace("elif peft_config is not None:", "elif False:")
init = init.replace("if peft_config is None:", "if False:")
init = init.replace("if peft_config is not None:", "if False:")
init = init.replace("get_peft_model(model, peft_config)", "model")
- # New TRL 0.20.0
- init = init.replace(
- "if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):",
- "if False:",
- )
- # New TRL 0.20.0
- init = init.replace(
- "model = self._prepare_peft_model(model, peft_config, args)\n", "pass\n"
- )
- # TRL 0.22.0+ uses prepare_peft_model as a standalone function
- init = init.replace("model = prepare_peft_model(model, peft_config, args)", "pass")
-
- # Skip add_adapter("ref") for reference model computation
- # Unsloth: We comment out the "ref" adapter creation because:
- # 1. We want to use the original BASE MODEL as the reference model, not the SFT/LoRA model
- # 2. PEFT doesn't allow multiple adapters when target_parameters is used (MoE models)
- # When "ref" is not in peft_config, GRPO/RLOO fallback uses disable_adapter()
- # which gives the base model logits - exactly what we want
- add_adapter_block_pattern = (
- r"([ \t]*)" # Capture leading indentation
- r"if\s+is_peft_available\(\)\s+and\s+is_peft_model\(model\)\s+and\s+args\.beta\s*!=\s*0\.0\s*:"
- r"(.*?)" # Match the entire block until ref_param.data.copy_
- r"ref_param\.data\.copy_\(param\.data\)"
- )
-
- def comment_out_block(match):
- """Comment out each line in the matched block, preserving indentation."""
- full_match = match.group(0)
- indent = match.group(1)
- lines = full_match.split("\n")
- commented_lines = []
- # Add explanation comment first
- commented_lines.append(
- f"{indent}# Unsloth: Commented out - use base model as reference, not SFT/LoRA model"
- )
- # Comment out each line - insert # after leading whitespace to preserve indentation
- for line in lines:
- if line.strip():
- stripped = line.lstrip()
- leading_ws = line[: len(line) - len(stripped)]
- commented_lines.append(f"{leading_ws}# {stripped}")
- else:
- commented_lines.append(line)
- return "\n".join(commented_lines)
-
- init = re.sub(add_adapter_block_pattern, comment_out_block, init, flags = re.DOTALL)
# Set use_vllm if not set
if "args.use_vllm" in init and "model" in init and "args" in init:
@@ -1612,158 +592,84 @@ def comment_out_block(match):
)
if len(replacer) != 0:
replacer = replacer[0]
- vllm_setter = (
- "\n"
- + " " * 8
- + "if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm'):\n"
- + " " * 12
- + "if (getattr(args, 'use_vllm', False) == False):\n"
- + " " * 16
- + "args.use_vllm = True\n"
- )
- # " " * 16 + "args.vllm_importance_sampling_correction = True\n" + \
- # " " * 16 + "args.vllm_importance_sampling_cap = 2.0\n"
-
- if "grpo" in trainer_file and trl_version >= Version("0.18.0"):
- # If model has vllm_engine, then use vllm in colocate mode. Donot wait for server
- vllm_setter += " " * 12 + "args.vllm_mode='colocate'\n"
- if trl_version >= Version("0.23.0"):
- # We need to set this flag for sleep mode auto working with trl update
- vllm_setter += (
- " " * 12
- + "if os.environ.get('UNSLOTH_VLLM_STANDBY', '0') == '1':\n"
- + " " * 16
- + "args.vllm_enable_sleep_mode=True\n"
- )
-
+ vllm_setter = "\n" + " "*8 + \
+ "if hasattr(model, 'vllm_engine') and "\
+ "hasattr(args, 'use_vllm') and (getattr(args, 'use_vllm', False) == False): "\
+ "args.use_vllm = True\n"
init = init.replace(replacer, replacer + vllm_setter)
-
- # breakpoint()
+ pass
+ pass
vllm_part = re.findall(
- r"(\n[\s]{8}" r"if (self|args)\.use_vllm\:.*?" r"\n[\s]{8}" "else:\n)",
+ r"(\n[\s]{8}"\
+ r"if (self|args)\.use_vllm\:.*?"\
+ r"\n[\s]{8}"\
+ "else:\n)",
init,
flags = re.MULTILINE | re.DOTALL,
)
-
if len(vllm_part) == 1:
vllm_part, args = vllm_part[0][0], vllm_part[0][1]
# Strip all comments
- new_vllm_part = re.sub(
- r"^\s*\#[^\n]*\n?", "", vllm_part, flags = re.MULTILINE
- ) # to also remove whole comment line instead of just starting at #
- new_vllm_part = re.sub(
- r"\s*\#.*$", "", new_vllm_part, flags = re.MULTILINE
- ) # remove comments that occur after code
+ new_vllm_part = re.sub(r"\#[^\n]{1,}\n", "", vllm_part)
# Get SamplingParams
sampling_params = re.findall(
- r"\n[\s]{4,}(self\.[^\s]{1,}[\s]{0,}\=[\s]{0,}" r"SamplingParams\(.+?\))",
+ r"\n[\s]{4,}(self\.[^\s]{1,}[\s]{0,}\=[\s]{0,}"\
+ r"SamplingParams\(.+?\))",
new_vllm_part,
flags = re.MULTILINE | re.DOTALL,
)
-
if len(sampling_params) == 1:
sampling_params = sampling_params[0]
+
# Fix guided_decoding
sampling_params = sampling_params.replace(
"guided_decoding=guided_decoding,",
- "guided_decoding="
- 'GuidedDecodingParams(backend="outlines", regex=args.vllm_guided_decoding_regex) '
+ 'guided_decoding='\
+ 'GuidedDecodingParams(backend="outlines", regex=args.vllm_guided_decoding_regex) '\
'if getattr(args, "vllm_guided_decoding_regex", None) is not None else None,',
)
# Replace with our vLLM engine
- sampling_params = (
- " " * 12
- + "self.llm = model.vllm_engine; self._last_loaded_step = 0; "
- + sampling_params
- ) # Add spaces
-
- # count the indentation of last line of sampling_params.
- splitted_sampling_params = sampling_params.split("\n")
- if len(splitted_sampling_params) >= 2:
- last_line = splitted_sampling_params[-1]
- last_prev_line = splitted_sampling_params[-2]
- last_prev_indentation = len(last_prev_line) - len(
- last_prev_line.lstrip()
- )
- last_indentation = len(last_line) - len(last_line.lstrip())
-
- # Add extra arguments to SamplingParams
- extra = "**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {})"
- # Backwards replace
- to_replace = (
- ",\n"
- + " " * last_prev_indentation
- + extra
- + ",\n"
- + " " * last_indentation
- + ")"
- )
- sampling_params = to_replace.join(sampling_params.rsplit(")", 1))
- # Strip multiple commas
- sampling_params = re.sub(r"[\,][\s]{0,}\,", ",", sampling_params)
-
- new_vllm_part = (
- f"\n{' '*8}if {args}.use_vllm:\n{sampling_params}"
- f"\n{' '*8}else:\n"
- )
-
- if trl_version >= Version("0.18.0"):
- # Replace LLM init with already existing vLLM engine for colocate mode
- vllm_llm_init_pattern = r"self\.llm\s*=\s*LLM\(.*?\)*\)\s*?\n(?!,)"
- vllm_llm_replacement = "self.llm = model.vllm_engine\n"
- new_vllm_part = re.sub(
- vllm_llm_init_pattern,
- vllm_llm_replacement,
- new_vllm_part,
- flags = re.DOTALL, # Ensure . matches newlines [[5]]
- )
-
- init = init.replace(vllm_part, new_vllm_part)
+ sampling_params = \
+ " "*12 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \
+ sampling_params # Add spaces
+
+ # Add extra arguments to SamplingParams
+ extra = "**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {})"
+ # Backwards replace
+ to_replace = "," + extra + "," + ")"
+ sampling_params = to_replace.join(sampling_params.rsplit(")", 1))
+ # Strip multiple commas
+ sampling_params = re.sub(r"[\,][\s]{0,}\,", ",", sampling_params)
+
+ new_vllm_part = \
+ f"\n{' '*8}if {args}.use_vllm:\n{sampling_params}"\
+ f"\n{' '*8}else:\n"
+
+ init = init.replace(vllm_part, new_vllm_part)
+ pass
+ pass
# Search for vLLM calling in all child functions
functions = dir(RLTrainer)
RLTrainer_source = inspect.getsource(RLTrainer)
functions = [x for x in functions if f"def {x}" in RLTrainer_source]
- changed = {
- "__init__": (
- old_init,
- init,
- )
- }
+ changed = {"__init__" : (old_init, init,)}
edit_functions = RL_FUNCTIONS.get(trainer_file, [])
for function in functions:
- if not hasattr(RLTrainer, function):
- continue
- if function in changed:
- original_source, source = changed[function]
- else:
- fx = getattr(RLTrainer, function)
- try:
- source = inspect.getsource(fx)
- except:
- continue
- original_source = source
+ if not hasattr(RLTrainer, function): continue
+ fx = getattr(RLTrainer, function)
+ try: source = inspect.getsource(fx)
+ except: continue
+ original_source = source
# Check for function
for edit_function in edit_functions:
source = edit_function(function, source)
-
- """
- import torch
- X = torch.ones((2, 2048, 201088), dtype = torch.bfloat16, device = "cuda")
- X[torch.randperm(2, dtype = torch.int64, device = X.device)]
-
- will error out in torch 2.8 AcceleratorError: CUDA error: invalid configuration argument
- """
- source = re.sub(
- r"(\n[\s]{4,})generation_batch = shuffle_sequence_dict\(generation_batch\)\n",
- r"\n\1try: generation_batch = shuffle_sequence_dict(generation_batch)\n\1except: pass\n",
- source,
- )
+ pass
# llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
source = re.sub(
@@ -1785,57 +691,23 @@ def comment_out_block(match):
r"",
source,
)
-
+
# Replace self.llm.generate and self.llm.chat
- if "CUDA_VISIBLE_DEVICES" in os.environ:
- lora_name = (
- trainer_file
- + "_lora_model_' + "
- + "(os.environ.get('CUDA_VISIBLE_DEVICES', '0').replace(',',''))"
- )
- else:
- lora_name = trainer_file + "_lora_model'"
+ lora_name = trainer_file + "_lora_model"
source = re.sub(
r"(self\.llm\.(?:generate|chat)\([^\)]{1,})\)",
- r"\1, lora_request = self.model.load_lora('"
- + lora_name
- + r", load_tensors = True))",
- source,
- )
- # All these are to fix multiple commas before lora_request (in case the original code ends with something like ",)")
- # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L1388 for eg has such an ending
- source = re.sub(r"\,[\s]{1,}\,[\s]{0,}lora_request", ", lora_request", source)
- source = re.sub(r"[\s]{1,}\,[\s]{0,}lora_request", ", lora_request", source)
- source = re.sub(r"[\,]{1,}[\s]{0,}lora_request", ", lora_request", source)
- # Prefer using unsloth's sampling params and fallback to trl's if not found
- # We'll enable this later separately when combining both this and GRPOConfig params
- # source = re.sub(
- # r"sampling_params\s*=\s*sampling_params",
- # r"sampling_params = getattr(self.args, 'vllm_sampling_params', sampling_params)",
- # source
- # )
- # Fix later versions of SamplingParams via grpo_update_SamplingParams
- source = source.replace(
- "sampling_params = SamplingParams(**generation_kwargs)",
- "sampling_params = SamplingParams("
- "**grpo_update_SamplingParams("
- "SamplingParams, generation_kwargs, "
- "getattr(self.args, 'vllm_sampling_params', None)"
- ")"
- ")",
+ r"\1, lora_request = self.model.load_lora('" + lora_name + r"', load_tensors = True))",
+ source
)
# Skip if no changes done
- if source == original_source:
- continue
+ if source == original_source: continue
# Find all imports
imports += [x for x in all_imports if not x.startswith("_") and x in source]
- changed[function] = (
- original_source,
- source,
- )
+ changed[function] = (original_source, source,)
+ pass
# Import all functions
imports = list(set(imports))
@@ -1844,67 +716,29 @@ def comment_out_block(match):
for function in changed:
old, new = changed[function]
RLTrainer_source = RLTrainer_source.replace(old, new)
+ pass
RLTrainer_source = RLTrainer_source.replace(
f"class {RLTrainer_name}", f"class _Unsloth{RLTrainer_name}", 1
)
return RLTrainer_source
+pass
def patch_trl_rl_trainers():
# Patch all TRL modules if they have vLLM or PEFT
import trl.trainer
-
all_trainers = dir(trl.trainer)
- all_trainers = [
- x
- for x in all_trainers
- if x.islower() and x.endswith("_trainer") and x != "base_trainer"
- ]
+ all_trainers = [x for x in all_trainers if x.islower() and x.endswith("_trainer")]
for trainer in all_trainers:
- try:
- _patch_trl_rl_trainers(trainer)
- except Exception as e:
- logger.warning_once(f"Unsloth: Could not patch trl.trainer.{trainer}: {e}")
- return
-
-
-def patch_trl_openenv():
- for function in RL_ADDITIONAL_FUNCTIONS["openenv"]:
- logger.info(f"Unsloth: Patching trl openenv with function: {function.__name__}")
- function() # Call the function to apply the patch
- return
-
-
-def patch_trl_vllm_generation():
- # trl moved vllm stuff to trl/generation/vllm_generation.py
- # We need to min_p patch it to not instantiate another vLLM instance if we already have one with fast_inference
- # Find the instance of self.llm = LLM(..) (multiline) and wrap it around an if clause
- for function in RL_ADDITIONAL_FUNCTIONS["vllm_generation"]:
- logger.info(
- f"Unsloth: Patching trl VLLMGeneration with function: {function.__name__}"
- )
- function()
- return
-
-
-def patch_trl_vllm_generation():
- # trl moved vllm stuff to trl/generation/vllm_generation.py
- # We need to min_p patch it to not instantiate another vLLM instance if we already have one with fast_inference
- # Find the instance of self.llm = LLM(..) (multiline) and wrap it around an if clause
- for function in RL_ADDITIONAL_FUNCTIONS["vllm_generation"]:
- logger.info(
- f"Unsloth: Patching trl VLLMGeneration with function: {function.__name__}"
- )
- function()
+ _patch_trl_rl_trainers(trainer)
return
+pass
def PatchFastRL(algorithm = None, FastLanguageModel = None):
- if FastLanguageModel is not None:
- PatchRL(FastLanguageModel)
+ if FastLanguageModel is not None: PatchRL(FastLanguageModel)
patch_trl_rl_trainers()
- patch_trl_openenv()
- patch_trl_vllm_generation()
if type(algorithm) is str and algorithm.islower():
PatchRLStatistics(algorithm)
+pass
diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py
old mode 100755
new mode 100644
index 314feb5d2a..4071ef835a
--- a/unsloth/models/rl_replacements.py
+++ b/unsloth/models/rl_replacements.py
@@ -20,112 +20,66 @@
"RL_METRICS_CHANGES",
]
-import os
import re
import torch
import inspect
-import linecache
from collections import defaultdict
-from unsloth_zoo.rl_replacements import RL_REPLACEMENTS, left_pack_padding
-from unsloth_zoo.utils import Version
-from trl import __version__ as trl_version_raw
-from importlib.metadata import version as importlib_version
-from unsloth_zoo.log import logger
-from unsloth_zoo.device_type import device_synchronize
-import importlib.util
-from ..device_type import (
- is_hip,
- get_device_type,
- DEVICE_TYPE,
- DEVICE_TYPE_TORCH,
- DEVICE_COUNT,
- ALLOW_PREQUANTIZED_MODELS,
-)
-import textwrap
-from ._utils import _get_inference_mode_context_manager
-
-RL_EXTRA_ARGS = defaultdict(list)
-RL_FUNCTIONS = defaultdict(list)
-RL_PRE_ITEMS = defaultdict(list)
-RL_CONFIG_CHANGES = defaultdict(list)
+from unsloth_zoo.rl_replacements import RL_REPLACEMENTS
+RL_EXTRA_ARGS = defaultdict(list)
+RL_FUNCTIONS = defaultdict(list)
+RL_PRE_ITEMS = defaultdict(list)
+RL_CONFIG_CHANGES = defaultdict(list)
RL_METRICS_CHANGES = defaultdict(list)
-RL_ADDITIONAL_FUNCTIONS = defaultdict(list)
torch_compile_options = {
- "epilogue_fusion": True,
- "max_autotune": False, # I saw speedups, but not sure if this has issues in collab
- "shape_padding": True,
- "trace.enabled": False,
- "triton.cudagraphs": False,
+ "epilogue_fusion" : True,
+ "max_autotune" : True,
+ "shape_padding" : True,
+ "trace.enabled" : False,
+ "triton.cudagraphs" : False,
}
-try:
- trl_version = Version(trl_version_raw)
-except Exception:
- try:
- trl_version = Version(importlib_version("trl"))
- except Exception:
- trl_version = Version("0.0.0")
-
-
# Check untrained tokens
def sft_trainer_fix_untrained_tokens(call_args, extra_args):
if "model" in call_args and "train_dataset" in call_args:
- fix_tokenizer = (
- "IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\\n')\n"
- "from unsloth_zoo.tokenizer_utils import fix_untrained_tokens\n"
- "from unsloth_zoo.training_utils import fix_zero_training_loss\n"
- "if 'tokenizer' not in locals(): tokenizer = processing_class\n"
- "fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)\n"
- "fix_zero_training_loss(model, tokenizer, train_dataset)\n"
- )
+ fix_tokenizer = \
+ "IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\\n')\n"\
+ "from unsloth_zoo.tokenizer_utils import fix_untrained_tokens\n"\
+ "from unsloth_zoo.training_utils import fix_zero_training_loss\n"\
+ "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\
+ "fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)\n"\
+ "fix_zero_training_loss(model, tokenizer, train_dataset)\n"
return fix_tokenizer
return ""
-
-
+pass
RL_EXTRA_ARGS["sft_trainer"].append(sft_trainer_fix_untrained_tokens)
-# Fix top_k for GRPO vLLM.
-# https://github.com/huggingface/trl/pull/4695 with this change trl added top_k in GRPOConfig and defaults to 0
-# We don't want that since vllm's all include top_k is -1 and 0 returns an error on SamplingParams creation.
-def grpo_config_fix_vllm_top_k(old_RLTrainer_source, old_RLConfig_source):
- return "if use_vllm and (top_k is None or top_k == 0): top_k = -1\n"
-
-
-RL_CONFIG_CHANGES["grpo_trainer"].append(grpo_config_fix_vllm_top_k)
-
-
# Remove DPO columns which might randomnly be tokenized
def dpo_trainer_fix_columns(call_args, extra_args):
if "model" in call_args and "train_dataset" in call_args:
- fix_dpo = (
- "if hasattr(train_dataset, 'column_names'):\n"
- " column_names = set(train_dataset.column_names)\n"
- " check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"
- " 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"
- " 'prompt_input_ids', 'prompt_attention_mask']\n"
- " if all(x in column_names for x in check):\n"
- " train_dataset = train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"
- " del check, column_names\n"
- )
+ fix_dpo = \
+ "if hasattr(train_dataset, 'column_names'):\n"\
+ " column_names = set(train_dataset.column_names)\n"\
+ " check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"\
+ " 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"\
+ " 'prompt_input_ids', 'prompt_attention_mask']\n"\
+ " if all(x in column_names for x in check):\n"\
+ " train_dataset = train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"\
+ " del check, column_names\n"
return fix_dpo
return ""
-
-
+pass
RL_EXTRA_ARGS["dpo_trainer"].append(dpo_trainer_fix_columns)
# Fix tokenizer double BOS
def sft_trainer_prepare_dataset(function_name, function):
- if (
- function_name != "_prepare_non_packed_dataloader"
- and function_name != "_prepare_dataset"
- ):
- return function
+ if function_name != "_prepare_non_packed_dataloader" and \
+ function_name != "_prepare_dataset": return function
fast_sft_prepare_dataset = RL_REPLACEMENTS.get("sft_prepare_dataset", None)
- if fast_sft_prepare_dataset is not None:
+ if fast_sft_prepare_dataset is not None and "pack_examples" in function:
params = inspect.signature(fast_sft_prepare_dataset).parameters.keys()
params = ".*?".join(params)
matched = re.match(
@@ -137,36 +91,33 @@ def sft_trainer_prepare_dataset(function_name, function):
# Use fast version!
function = inspect.getsource(fast_sft_prepare_dataset)
function = function.split("\n")
- function = "\n".join(" " * 4 + x for x in function)
- function = function.replace(
- "def sft_prepare_dataset", "def _prepare_dataset"
- )
+ function = "\n".join(" "*4 + x for x in function)
+ function = function.replace("def sft_prepare_dataset", "def _prepare_dataset")
return function
-
- check_text = (
- "if 'skip_prepare_dataset' in locals() and skip_prepare_dataset:\n"
- " return dataset\n"
- "if 'tokenizer' not in locals(): tokenizer = processing_class\n"
- "if 'formatting_func' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `formatting_func` does not exist!')\n"
- "if 'dataset_text_field' not in locals() and 'args' in locals(): dataset_text_field = args.dataset_text_field\n"
- "if 'dataset_text_field' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `dataset_text_field` does not exist!')\n"
- "test_text = dataset[0][dataset_text_field] if (formatting_func is None and dataset_text_field is not None) else formatting_func(dataset[0])[0]\n"
- "chat_template = getattr(tokenizer, 'chat_template', None)\n"
- "chat_template = '' if chat_template is None else chat_template\n"
- "has_bos_token_already = (test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template) "
- "if getattr(tokenizer, 'bos_token', None) is not None else False\n"
- "if 'add_special_tokens' not in locals() and has_bos_token_already:\n"
- " from functools import partial\n"
- " tokenizer_call = tokenizer.__call__\n"
- " tokenizer.__call__ = partial(tokenizer_call, add_special_tokens = False)\n"
- " processing_class = tokenizer\n"
- "else:\n"
- " tokenizer_call = None\n"
- " add_special_tokens = False if has_bos_token_already else locals().get('add_special_tokens', False)\n"
- )
+ pass
+ pass
+
+ check_text = \
+ "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\
+ "if 'formatting_func' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `formatting_func` does not exist!')\n"\
+ "if 'dataset_text_field' not in locals() and 'args' in locals(): dataset_text_field = args.dataset_text_field\n"\
+ "if 'dataset_text_field' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `dataset_text_field` does not exist!')\n"\
+ "test_text = dataset[0][dataset_text_field] if (formatting_func is None and dataset_text_field is not None) else formatting_func(dataset[0])[0]\n"\
+ "chat_template = getattr(tokenizer, 'chat_template', None)\n"\
+ "chat_template = '' if chat_template is None else chat_template\n"\
+ "has_bos_token_already = (test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template) "\
+ "if getattr(tokenizer, 'bos_token', None) is not None else False\n"\
+ "if 'add_special_tokens' not in locals() and has_bos_token_already:\n"\
+ " from functools import partial\n"\
+ " tokenizer_call = tokenizer.__call__\n"\
+ " tokenizer.__call__ = partial(tokenizer_call, add_special_tokens = False)\n"\
+ " processing_class = tokenizer\n"\
+ "else:\n"\
+ " tokenizer_call = None\n"\
+ " add_special_tokens = False if has_bos_token_already else locals().get('add_special_tokens', False)\n"
check_text = check_text.split("\n")
- check_text = "\n".join(" " * 8 + x for x in check_text)
+ check_text = "\n".join(" "*8 + x for x in check_text)
check_text = check_text.rstrip() + "\n"
# .*? matches first match. .+? matches final match.
@@ -178,31 +129,26 @@ def sft_trainer_prepare_dataset(function_name, function):
if len(replacer) != 0:
replacer = replacer[0]
function = function.replace(replacer, replacer + check_text)
+ pass
# Return tokenizer's original state
- return_state = (
- "if tokenizer_call is not None: tokenizer.__call__ = tokenizer_call\n"
- )
+ return_state = "if tokenizer_call is not None: tokenizer.__call__ = tokenizer_call\n"
function = re.sub(
r"\n([ ]{4,})(return .*?[\s]{0,})$",
rf"\1{return_state}\1\2",
function,
)
return function
-
-
+pass
RL_FUNCTIONS["sft_trainer"].append(sft_trainer_prepare_dataset)
# Ignore mean_token_accuracy since it needs logits
# We override it directly with our version
def sft_trainer_compute_loss(function_name, function):
- if function_name != "compute_loss":
- return function
+ if function_name != "compute_loss": return function
- def compute_loss(
- self, model, inputs, return_outputs = False, num_items_in_batch = None
- ):
+ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
outputs = super().compute_loss(
model,
inputs,
@@ -210,919 +156,137 @@ def compute_loss(
num_items_in_batch = num_items_in_batch,
)
return outputs
+ pass
function = inspect.getsource(compute_loss)
return function
-
-
+pass
RL_FUNCTIONS["sft_trainer"].append(sft_trainer_compute_loss)
-# Fix bare pop("push_to_hub_token") in compiled SFT/IterativeSFT trainer __init__
-# On transformers 5.0+, to_dict() no longer includes push_to_hub_token, so bare pop KeyErrors
-def sft_trainer_push_to_hub_token(function_name, function):
- if function_name != "__init__":
- return function
- return function.replace(
- 'dict_args.pop("push_to_hub_token")', 'dict_args.pop("push_to_hub_token", None)'
- )
-
-
-RL_FUNCTIONS["sft_trainer"].append(sft_trainer_push_to_hub_token)
-
-
# Autocast precision for GRPO
def grpo_trainer__prepare_inputs(function_name, function):
- if function_name != "_prepare_inputs":
- return function
+ if function_name != "_prepare_inputs": return function
+
+ if "with torch.inference_mode()" not in function: return function
# Add mixed precision training
function = function.replace(
"with torch.inference_mode():",
- "with torch.inference_mode(), "
- "torch.amp.autocast(device_type = 'cuda', "
- "dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "
- "if not torch.is_autocast_enabled('cuda') else nullcontext())"
- "if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16):",
+
+ "with torch.inference_mode(), "\
+ "torch.amp.autocast(device_type = 'cuda', "\
+ "dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\
+ "if not torch.is_autocast_enabled('cuda') else nullcontext():",
)
+
+ # Disable attaching a float32 conversion hook which upcasts logits to FP32
function = function.replace(
"self.accelerator.unwrap_model(self.model)",
"self.accelerator.unwrap_model(self.model, keep_fp32_wrapper = False)",
)
return function
-
-
+pass
RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__prepare_inputs)
-# Remove collective RPC of reload weights from generate
-# trl added reload weights (potentially for quantized models), we don't need it for our use case (LoRA primarily)
-# https://github.com/huggingface/trl/commit/7856d3b1f6518601732f489883b341bb6dd36434#diff-964e6fd373aa93037604064cb2b822d7f8e2735e33f791065acf2c4c3552d393R1168-R1169
-def grpo_trainer__generate_single_turn(function_name, function):
- if function_name != "_generate_single_turn":
- return function
-
- # Remove the reload_weights collective RPC call from the generate function's source
- # function = function.replace('self.llm.collective_rpc("reload_weights")', "")
- # The regex below does the same thing but is more flexible and can handle single or double quotes
- # This is for older versions.
- function = re.sub(
- r"self\.llm\.collective_rpc\(\s*(['\"])reload_weights\1\s*\)",
- "",
- function,
- )
-
- # Current TRL versions call vllm_generation.sync_weights() every step.
- # When Unsloth fast inference LoRA is active, weights are already shared.
- sync_weights_block = re.compile(
- r"(?P[ \t]*)with profiling_context\(self,\s*(['\"])sync_weights\2\s*\):\n"
- r"(?P=indent)[ \t]+self\.vllm_generation\.sync_weights\(\)\n",
- re.MULTILINE,
- )
-
- def remove_sync_weights_block(match):
- indent = match.group("indent")
- return (
- f"{indent}# Unsloth fast inference LoRA shares weights with vLLM already.\n"
- f"{indent}# Skipping per-step vLLM sync_weights().\n"
- )
-
- function = sync_weights_block.sub(remove_sync_weights_block, function)
-
- # TRL 0.24.0-0.25.1 truncation regression fix
- #
- # TRL 0.22.2-0.23.1 used smart truncation via truncate_with_protected_tokens():
- # - Tokenizes first without truncation
- # - Then truncates keeping the RIGHTMOST tokens (preserves assistant turn)
- # - Protects special tokens (image_token, vision_start/end) from removal
- #
- # TRL 0.24.0-0.25.1 removed this and passed kwargs directly to the tokenizer:
- # max_length=self.max_prompt_length, truncation=True, add_special_tokens=False
- # This causes issues because tokenizer truncation doesn't protect special tokens
- # and may not preserve the end of the prompt properly.
- #
- # TRL 0.26.2+ removed these kwargs entirely (no tokenizer-level truncation).
- #
- # Fix: Remove these kwargs so TRL 0.24.0-0.25.1 behaves like 0.26.2+ (no truncation).
- # This is a no-op for versions that don't have these kwargs (0.22.2-0.23.1, 0.26.2+).
- for pattern in [
- r'["\']?max_length["\']?\s*[:=]\s*self\.max_prompt_length\s*,\s*\n?',
- r'["\']?truncation["\']?\s*[:=]\s*True\s*,\s*\n?',
- r'["\']?add_special_tokens["\']?\s*[:=]\s*False\s*,\s*\n?',
- ]:
- function = re.sub(pattern, "", function)
-
- return function
-
-
-RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__generate_single_turn)
-
-
-# Fix incorrect special tokens handling and truncation in older TRL versions
-def grpo_trainer__generate_and_score_completions(function_name, function):
- if function_name != "_generate_and_score_completions":
- return function
-
- # TRL 0.19.0 did skip_special_tokens = True which should be False
- function = function.replace(
- "prompt_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False",
- "prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False",
- )
-
- # Left pad prompt before calculation old and ref hidden states
- line_to_replace = 'batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size'
-
- # The new multi-line string that will replace the line above
- replacement_lines = """
- max_left_pad = None
- batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
- try:
- # TRL 0.23.1 and below path
- if not has_images:
- # Left pad prompt before calculation old and ref hidden states
- left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(prompt_completion_ids, logits_to_keep, self.processing_class.pad_token_id)
- max_left_pad = torch.max(left_pad_tokens_per_prompt).item()
- except:
- # TRL 0.24.0 and below path
- if images is None:
- # Left pad prompt before calculation old and ref hidden states
- left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(prompt_completion_ids, logits_to_keep, self.processing_class.pad_token_id)
- max_left_pad = torch.max(left_pad_tokens_per_prompt).item()
- self.model.for_training()"""
-
- function = function.replace(line_to_replace, replacement_lines)
-
- pattern_to_find = re.compile(
- r"^\s*if self\.args\.gradient_accumulation_steps % generate_every != 0 or \(\s*"
- r"self\.use_vllm and self\.vllm_importance_sampling_correction\s*"
- r"\):",
- re.MULTILINE,
- )
-
- replacement_text = """
- if self.args.gradient_accumulation_steps % generate_every != 0 or (
- self.use_vllm
- ):"""
- # Use re.sub() to perform the replacement
- function, num_replacements = pattern_to_find.subn(replacement_text, function)
-
- pattern_to_find = re.compile(
- r"(^\s*)all_logprobs = \[" # Capture indentation (group 1)
- r".*?" # Match everything inside non-greedily
- r"for output in outputs\.outputs\s*"
- r"\]",
- re.DOTALL | re.MULTILINE,
- )
-
- # sanitize_logprob is injected as a module-level function via RLTrainer_replacement
- # template in rl.py (from RL_REPLACEMENTS), so just reference it directly here.
- replacement_text = (
- r"\1all_logprobs = [\n"
- r"\1 [sanitize_logprob(next(iter(logprob.values()))) for logprob in output.logprobs]\n"
- r"\1 for outputs in all_outputs\n"
- r"\1 for output in outputs.outputs\n"
- r"\1]"
- )
-
- function, num_replacements = pattern_to_find.subn(replacement_text, function)
-
- # Always between max_prompt_length and use_vllm
- found = re.findall(
- r"\n(([ ]{8,})if self\.max_prompt_length is not None:.*?"
- r"\2if self\.use_vllm:)",
- function,
- flags = re.DOTALL | re.MULTILINE,
- )
- if len(found) != 0:
- replace_part, spacing = found[0]
- removed_comments = re.sub(r"\#[^\n]{1,}", "", replace_part)
- splits = removed_comments.split("\n")
- if (
- sum(re.match(rf"{spacing}[^\s]", x) is not None for x in splits) == 2
- and len(spacing) >= 8
- ):
- new_replacement = f"""\n{spacing}if self.max_prompt_length is not None:
- # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
- # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
- # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation).
- protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
- protected = [token for token in protected if token is not None]
- prompt_ids, prompt_mask = truncate_with_protected_tokens(
- prompt_ids, prompt_mask, self.max_prompt_length, protected
- )
-
- prompts_text = [re.sub(rf"^({{re.escape(self.pad_token)}})+", "", text) for text in prompts_text]
-
- # The chat template inserts a single image token into the prompt text. However, when this text is later
- # tokenized, the single image token string is expanded into multiple image token IDs, depending on the
- # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
- # collapse them back into a single token string to match the original template.
- if self.image_token is not None:
- prompts_text = [
- re.sub(rf"({{re.escape(self.image_token)}})+", self.image_token, text) for text in prompts_text
- ]
- # Generate completions using either vLLM or regular generation
- if self.use_vllm:"""
- function = function.replace(replace_part, new_replacement)
-
- # Important note: we disable TRL's importance sampling logic
- # It is disabled because the LLM path moves left padding to the right.
- # We must adjust the vLLM sampling_logprob tensor in Unsloth to account for this.
- string_to_find = "if self.use_vllm and self.vllm_importance_sampling_correction:"
-
- replacement_string = (
- "if False and self.use_vllm and self.vllm_importance_sampling_correction:"
- )
-
- function = function.replace(string_to_find, replacement_string)
-
- string_to_find = """ if "image_sizes" in prompt_inputs:
- output["image_sizes"] = prompt_inputs["image_sizes"]"""
-
- replacement_string = """ if "image_sizes" in prompt_inputs:
- output["image_sizes"] = prompt_inputs["image_sizes"]
- if max_left_pad is not None:
- output["max_left_pad"] = torch.tensor(prompt_ids.shape[0] * [max_left_pad]).unsqueeze(-1)
- try:
- if self.use_vllm and getattr(self, "vllm_importance_sampling_correction", False):
- output["sampling_per_token_logps"] = sampling_per_token_logps
- except NameError:
- output["sampling_per_token_logps"] = None"""
-
- function = function.replace(string_to_find, replacement_string)
-
- # TRL 0.24.0+ extracts prompts = [x["prompt"] for x in inputs], losing metadata
- # like reasoning_effort. Inject code to store per-sample chat_template_kwargs on self.
- _metadata_extraction = (
- "\n"
- " # Unsloth: Extract per-sample chat_template_kwargs before metadata is lost\n"
- " _ct_ = getattr(self.processing_class, 'chat_template', None) or ''\n"
- " _sk_ = {'prompt', 'chosen', 'rejected', 'completion', 'messages', 'label',\n"
- " 'images', 'image', 'videos', 'video', 'audios', 'audio'}\n"
- " self._unsloth_batch_chat_kwargs = []\n"
- " for _inp_ in inputs:\n"
- " _kw_ = {}\n"
- " if isinstance(_inp_, dict):\n"
- " for _k_ in _inp_.keys() - _sk_:\n"
- " if _k_ in _ct_ and isinstance(_inp_[_k_], str):\n"
- " _kw_[_k_] = _inp_[_k_]\n"
- " self._unsloth_batch_chat_kwargs.append(_kw_)\n"
- )
- # Insert after: prompts = [x["prompt"] for x in inputs]
- _target_line = 'prompts = [x["prompt"] for x in inputs]'
- if _target_line in function:
- function = function.replace(
- _target_line,
- _target_line + _metadata_extraction,
- )
-
- # This path is for TRL 0.24.0 images is a variable exclusive to this version
- string_to_find = """ if images is not None:
- output["num_images"] = num_images"""
-
- replacement_string = """ if images is not None:
- output["num_images"] = num_images
- if max_left_pad is not None:
- output["max_left_pad"] = torch.tensor(prompt_ids.shape[0] * [max_left_pad]).unsqueeze(-1)
- try:
- if self.use_vllm and getattr(self, "vllm_importance_sampling_correction", False):
- output["sampling_per_token_logps"] = sampling_per_token_logps
- except NameError:
- output["sampling_per_token_logps"] = None"""
-
- function = function.replace(string_to_find, replacement_string)
-
- if trl_version >= Version("0.24.0"):
- # We replace the call using 'completions' with one using 'completions_text'
- string_to_find = " rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)"
- replacement_string = (
- " if images is not None:\n"
- " rewards_per_func = self._calculate_rewards(inputs, prompts_text, completions_text, completion_ids_list)\n"
- " else:\n"
- " rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)"
- )
- function = function.replace(string_to_find, replacement_string)
-
- if "wake_up()" not in function:
- # Sleep functionality has been added to trl in v0.23.0. We do not want to redo this.
- # https://github.com/huggingface/trl/commit/edbe8234bc7e528f72ac76607de9d3e4753e2709
-
- pattern = re.compile(r".*self\.llm\.generate\(.*\).*", re.MULTILINE)
- matches = list(pattern.finditer(function))
- patched = function
-
- # Generally there's only one match. But this is just to make sure we don't miss any.
- for match in reversed(matches):
- line = match.group(0)
- indent_match = re.match(r"(\s*)", line)
- indent = indent_match.group(1) if indent_match else ""
-
- wrapped = (
- f"{indent}if hasattr(self, 'llm'):\n"
- f"{indent} if getattr(self.llm.llm_engine.vllm_config.model_config, 'enable_sleep_mode', False):\n"
- f"{indent} self.llm.wake_up()\n"
- f"{line}\n\n"
- f"{indent}if hasattr(self, 'llm'):\n"
- f"{indent} if getattr(self.llm.llm_engine.vllm_config.model_config, 'enable_sleep_mode', False):\n"
- f"{indent} self.llm.sleep(os.environ.get('VLLM_SLEEP_MODE', 1))\n"
- )
-
- patched = patched[: match.start()] + wrapped + patched[match.end() :]
-
- function = patched
-
- return function
-
-
-RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__generate_and_score_completions)
-
-
-# Fix {"reasoning_effort" : "high"} not applied
-def grpo_trainer_fix_maybe_apply_chat_template(function_name, function):
- spaces = function.find("def ")
- if spaces % 4 != 0:
- return function
- spaces += 4
- replacement = """
- _chat_template_ = getattr(self.processing_class, "chat_template", None)
- if _chat_template_ is None: _chat_template_ = ""
- _supported_keys_ = set(("prompt", "chosen", "rejected", "completion", "messages", "label"))
- _batch_chat_kwargs_ = getattr(self, "_unsloth_batch_chat_kwargs", None)
-
- prompts_text = []
- for _idx_, _example_ in enumerate(__INPUTS__REPLACEMENT__):
- _tokenizer_kwargs_ = {}
- if type(_example_) is not dict:
- _example_ = {"prompt": _example_}
- _left_keys_ = _example_.keys() - _supported_keys_
- for k in _left_keys_:
- if k in _chat_template_:
- v = _example_[k]
- if type(v) is str:
- _tokenizer_kwargs_[k] = v
- if _batch_chat_kwargs_ is not None and _idx_ < len(_batch_chat_kwargs_):
- for _bk_, _bv_ in _batch_chat_kwargs_[_idx_].items():
- if _bk_ not in _tokenizer_kwargs_:
- _tokenizer_kwargs_[_bk_] = _bv_
- _x_ = maybe_apply_chat_template(_example_, self.processing_class, **_tokenizer_kwargs_)["prompt"]
- prompts_text.append(_x_)
- """
- replacement = textwrap.dedent(replacement).strip()
- replacement = textwrap.indent(replacement, spaces * " ")
- replacement = f"\n{replacement}\n"
- what = 'prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]'
- function = function.replace(
- what, replacement.replace("__INPUTS__REPLACEMENT__", "inputs")
- )
-
- """prompts_text = [
- maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
- ]"""
- function = re.sub(
- r"prompts_text = \["
- r"[\s]{0,}"
- r"maybe_apply_chat_template\(\{[\"\']prompt[\"\'][\s]{0,}\:[\s]{0,}prompt[\s]{0,}\}[\s]{0,}\,[\s]{0,}self\.processing_class\)"
- r"\[[\"\']prompt[\"\']\] for prompt in prompts"
- r"[\s]{0,}"
- r"\]",
- replacement.replace("__INPUTS__REPLACEMENT__", "prompts"),
- function,
- )
- return function
-
-
-RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_fix_maybe_apply_chat_template)
-
-
# Remove _move_model_to_vllm
def grpo_trainer__move_model_to_vllm(function_name, function):
- if function_name != "_move_model_to_vllm":
- return function
+ if function_name != "_move_model_to_vllm": return function
- def _move_model_to_vllm(self, *args, **kwargs):
- return None
+ def _move_model_to_vllm(self, *args, **kwargs): return None
function = inspect.getsource(_move_model_to_vllm)
return function
-
-
+pass
RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__move_model_to_vllm)
# Edit _get_per_token_logps to handle mixed precision
def grpo_trainer__get_per_token_logps(function_name, function):
- if function_name != "_get_per_token_logps":
- return function
+ if function_name != "_get_per_token_logps": return function
- def _get_per_token_logps(
- self, model, input_ids, attention_mask, logits_to_keep, compute_efficient = False
- ):
- if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
- return None # Unsloth efficient GRPO
+ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
+ if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
+ return None # Unsloth efficient GRPO
# Otherwise, calculate normally:
- if not hasattr(self, "_autocast_dtype"):
- self._autocast_dtype = (
- torch.float16
- if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16"
- else torch.bfloat16
- )
- if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
- self._autocast_dtype = torch.float16
-
- os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
- with torch.amp.autocast(device_type = DEVICE_TYPE, dtype = self._autocast_dtype):
+ if not hasattr(self, '_autocast_dtype'):
+ self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float32
+ with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
- logits = model(
- input_ids = input_ids,
- attention_mask = attention_mask,
- logits_to_keep = logits_to_keep + 1,
- ).logits
- # logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
- return logits
- # input_ids = input_ids[:, -logits_to_keep:]
+ logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
+ logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+
+ input_ids = input_ids[:, -logits_to_keep:]
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
- # logits = logits[:, -logits_to_keep:]
- # return logits
- # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
- # logits = logits / self.temperature
- # logps = selective_log_softmax(logits, input_ids)
-
- # row_indices, col_indices = torch.where(logps < -20)
-
- # # Method 1: Check if tensors have elements
- # if len(row_indices) > 0 and len(col_indices) > 0:
- # breakpoint() # Breakpoint triggered here
- # print("Found high values!")
- # return logps # compute logprobs for the input tokens
+ logits = logits[:, -logits_to_keep:]
+ return logits
+ # return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
+ pass
+ pass
function = inspect.getsource(_get_per_token_logps)
return function
-
-
+pass
RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps)
-
-def grpo_trainer__get_per_token_logps_and_entropies(function_name, function):
- if function_name != "_get_per_token_logps_and_entropies":
- return function
-
- # Just copy over from _get_per_token_logps replacement function above. For now this returns None anyway
- def _get_per_token_logps_and_entropies(
- self,
- model,
- input_ids,
- attention_mask,
- logits_to_keep,
- batch_size = None,
- compute_entropy = False,
- compute_efficient = False,
- *args,
- **kwargs,
- ):
- # All Unsloth code here in this function is licensed under AGPL3
- # if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
- # return None, None # logps, entropies Unsloth efficient GRPO
- if compute_efficient:
- return None, None
- else:
- if not hasattr(self, "_autocast_dtype"):
- self._autocast_dtype = (
- torch.float16
- if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16"
- else torch.bfloat16
- )
- if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
- self._autocast_dtype = torch.float16
-
- pixel_values, image_grid_thw = (
- kwargs.get("pixel_values", None),
- kwargs.get("image_grid_thw", None),
- )
- pixel_attention_mask, image_sizes = (
- kwargs.get("pixel_attention_mask", None),
- kwargs.get("image_sizes", None),
- )
-
- unwrapped_model = self.accelerator.unwrap_model(
- model, keep_fp32_wrapper = False
- )
-
- lm_head = self.model.get_output_embeddings().weight
-
- dtype_bytes = (
- 16 if self._autocast_dtype in [torch.float16, torch.bfloat16] else 32
- )
- total_rows = input_ids.shape[0]
- seq_len = input_ids.shape[1]
- hidden_dim = lm_head.shape[1]
- vocab_dim = lm_head.shape[0]
-
- if self.args.unsloth_grpo_mini_batch is None:
- B, multiplier = autotune_batch_and_chunks(
- total_rows,
- seq_len,
- hidden_dim,
- vocab_dim,
- dtype_bytes,
- self.args.unsloth_logit_chunk_multiplier,
- )
- B = total_rows // B
- else:
- B = self.args.unsloth_grpo_mini_batch
-
- if self.args.unsloth_logit_chunk_multiplier is None:
- multiplier = max(4, seq_len // 4096)
- else:
- multiplier = self.args.unsloth_logit_chunk_multiplier
-
- all_logprobs_list = []
- if pixel_values is None:
- left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(
- input_ids, logits_to_keep, self.processing_class.pad_token_id
- )
- max_left_pad = torch.max(left_pad_tokens_per_prompt).item()
- input_ids = left_pack_padding(
- input_ids, self.processing_class.pad_token_id
- )
- attention_mask = input_ids != self.processing_class.pad_token_id
- attention_mask = attention_mask.to(attention_mask.dtype)
- else:
- max_left_pad = 0
-
- # input_ids_chunks = torch.chunk(input_ids, chunks = B, dim = 0)
- attention_mask_chunks = torch.chunk(attention_mask, chunks = B, dim = 0)
-
- def chunk_optional(tensor, chunks):
- if tensor is None:
- return [None] * chunks
- return torch.chunk(tensor, chunks = chunks, dim = 0)
-
- import math
-
- total_samples = input_ids.shape[0]
- batch_size = math.ceil(total_samples / B)
-
- input_ids_chunks = []
- attention_mask_chunks = []
- pixel_values_chunks = []
- image_grid_thw_chunks = []
- pixel_attention_mask_chunks = []
-
- current_pixel_idx = 0
- # TRL 0.23.0 batching logic
- for start in range(0, total_samples, batch_size):
- end = start + batch_size
-
- input_ids_chunks.append(input_ids[start:end])
- attention_mask_chunks.append(attention_mask[start:end])
-
- if image_grid_thw is not None and pixel_values is not None:
- grid_slice = image_grid_thw[start:end]
- image_grid_thw_chunks.append(grid_slice)
-
- batch_pixel_count = grid_slice.prod(dim = -1).sum().item()
-
- start_pixel_idx = current_pixel_idx
- end_pixel_idx = current_pixel_idx + batch_pixel_count
-
- pixel_values_chunks.append(
- pixel_values[start_pixel_idx:end_pixel_idx]
- )
-
- if pixel_attention_mask is not None:
- pixel_attention_mask_chunks.append(
- pixel_attention_mask[start_pixel_idx:end_pixel_idx]
- )
- else:
- pixel_attention_mask_chunks.append(None)
-
- current_pixel_idx = end_pixel_idx
-
- else:
- pixel_values_chunks.append(None)
- image_grid_thw_chunks.append(None)
- pixel_attention_mask_chunks.append(None)
-
- if image_sizes is not None and not isinstance(image_sizes, torch.Tensor):
- image_sizes_chunks = [[size] for size in image_sizes]
- else:
- image_sizes_chunks = chunk_optional(image_sizes, B)
-
- temperature = self.temperature
- logit_softcapping = getattr(model.config, "final_logit_softcapping", 0)
- if logit_softcapping is None:
- logit_softcapping = 0
- logit_scale_multiply = getattr(model.config, "logit_scale", 0)
- if logit_scale_multiply is None:
- logit_scale_multiply = 0
- logit_scale_divide = getattr(model.config, "logits_scaling", 0)
- if logit_scale_divide is None:
- logit_scale_divide = 0
-
- zipped_inputs = zip(
- input_ids_chunks,
- attention_mask_chunks,
- pixel_values_chunks,
- image_grid_thw_chunks,
- pixel_attention_mask_chunks,
- image_sizes_chunks,
- )
- os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
-
- with _get_inference_mode_context_manager(model):
- for (
- input_ids_chunk,
- attention_mask_chunk,
- pixel_values_chunk,
- image_grid_thw_chunk,
- pixel_attention_mask_chunk,
- image_sizes_chunk,
- ) in zipped_inputs:
- with torch.amp.autocast(
- device_type = "cuda", dtype = self._autocast_dtype
- ):
- if pixel_values is None:
- logits_chunk = unwrapped_model(
- input_ids = input_ids_chunk,
- attention_mask = attention_mask_chunk,
- pixel_values = pixel_values_chunk,
- image_grid_thw = image_grid_thw_chunk,
- pixel_attention_mask = pixel_attention_mask_chunk,
- image_sizes = image_sizes_chunk,
- ).logits
-
- completion_input_ids_chunk = input_ids_chunk[
- :, -(logits_to_keep + max_left_pad) :
- ]
- logits_chunk = logits_chunk[
- :, -(logits_to_keep + max_left_pad + 1) :, :
- ]
- logits_chunk = logits_chunk[:, :-1, :]
- else:
- # Essentially, for VLMs we do not go via the optimized path in models/,
- # so we don't encounter the Flash Attn left-padding issue.
- logits_chunk = unwrapped_model(
- input_ids = input_ids_chunk,
- attention_mask = attention_mask_chunk,
- pixel_values = pixel_values_chunk,
- image_grid_thw = image_grid_thw_chunk,
- pixel_attention_mask = pixel_attention_mask_chunk,
- image_sizes = image_sizes_chunk,
- logits_to_keep = logits_to_keep + 1,
- ).logits
-
- logits_chunk = logits_chunk[:, :-1, :]
- completion_input_ids_chunk = input_ids_chunk[
- :, -logits_to_keep:
- ]
-
- logprobs_chunk = chunked_hidden_states_selective_log_softmax(
- logits_chunk,
- lm_head,
- completion_input_ids_chunk,
- chunks = input_ids_chunk.shape[0] * multiplier,
- logit_scale_multiply = logit_scale_multiply,
- logit_scale_divide = logit_scale_divide,
- logit_softcapping = logit_softcapping,
- temperature = temperature,
- )
- # This is needed to avoid race conditions with GPT OSS offload_embbed=True
- # However, it seems that this line does not slow down or disrupt models.
- device_synchronize()
- all_logprobs_list.append(logprobs_chunk)
- logprobs = torch.cat(all_logprobs_list, dim = 0)
- entropies = None
-
- os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
-
- return logprobs.detach(), entropies # logps, entropies
- # input_ids = input_ids[:, -logits_to_keep:]
- # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
- # See https://github.com/huggingface/trl/issues/2770
- # logits = logits[:, -logits_to_keep:]
- # return logits
- # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
- # logits = logits / self.temperature
- # logps = selective_log_softmax(logits, input_ids)
-
- # row_indices, col_indices = torch.where(logps < -20)
-
- # # Method 1: Check if tensors have elements
- # if len(row_indices) > 0 and len(col_indices) > 0:
- # breakpoint() # Breakpoint triggered here
- # print("Found high values!")
- # return logps # compute logprobs for the input tokens
-
- function = inspect.getsource(_get_per_token_logps_and_entropies)
- return function
-
-
-RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps_and_entropies)
-
-grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"]
+grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"]
grpo_compute_loss_slow = RL_REPLACEMENTS["grpo_compute_loss_slow"]
-UnslothEfficientGRPO = RL_REPLACEMENTS["UnslothEfficientGRPO"]
-grpo_accumulated_loss = RL_REPLACEMENTS["grpo_accumulated_loss"]
-grpo_update_SamplingParams = RL_REPLACEMENTS["grpo_update_SamplingParams"]
+UnslothEfficientGRPO = RL_REPLACEMENTS["UnslothEfficientGRPO"]
+grpo_accumulated_loss = RL_REPLACEMENTS["grpo_accumulated_loss"]
RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss))
RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(UnslothEfficientGRPO))
RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_accumulated_loss))
RL_PRE_ITEMS["grpo_trainer"].append(grpo_compute_loss_slow)
-RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_update_SamplingParams))
-RL_PRE_ITEMS["grpo_trainer"].append(
- inspect.getsource(_get_inference_mode_context_manager)
-)
-
# Edit _get_per_token_logps to handle mixed precision
def grpo_trainer_compute_loss(function_name, function):
- if function_name != "compute_loss":
- return function
+ if function_name != "compute_loss": return function
- def compute_loss(
- self, model, inputs, return_outputs = False, num_items_in_batch = None
- ):
+ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")
# Compute the per-token log probabilities for the model
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
- completion_ids, completion_mask = (
- inputs["completion_ids"],
- inputs["completion_mask"],
- )
- pixel_values, image_grid_thw = (
- inputs.get("pixel_values", None),
- inputs.get("image_grid_thw", None),
- )
- pixel_attention_mask, image_sizes = (
- inputs.get("pixel_attention_mask", None),
- inputs.get("image_sizes", None),
- )
- num_items_in_batch = inputs.get("num_items_in_batch", None)
- sampling_per_token_logps = inputs.get("sampling_per_token_logps", None)
- current_gradient_accumulation_steps = self.current_gradient_accumulation_steps
- num_processes = self.accelerator.num_processes
-
- input_ids = torch.cat([prompt_ids, completion_ids], dim = 1)
+ completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+ input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
bsz, qlen = input_ids.shape
- attention_mask = torch.cat([prompt_mask, completion_mask], dim = 1)
- # attention_mask = None
- logits_to_keep = completion_ids.size(
- 1
- ) # we only need to compute the logits for the completion tokens
+ # attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+ attention_mask = None
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
_input_ids = input_ids
_logits_to_keep = logits_to_keep
+ per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
- get_logps_func = (
- lambda model,
- input_ids,
- attention_mask,
- logits_to_keep,
- batch_size = None,
- compute_entropy = False,
- compute_efficient = False: self._get_per_token_logps(
- model, input_ids, attention_mask, logits_to_keep, compute_efficient
- )
- if hasattr(self, "_get_per_token_logps")
- else self._get_per_token_logps_and_entropies(
- model,
- input_ids,
- attention_mask,
- logits_to_keep,
- batch_size,
- compute_entropy,
- compute_efficient,
- )[0]
- ) # logps
-
- per_token_logps = get_logps_func(
- model, input_ids, attention_mask, logits_to_keep, compute_efficient = True
- )
# Compute the KL divergence between the model and the reference model
- # _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves.
- # https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328
- # if self.beta != 0.0:
- # with torch.inference_mode(), model.disable_adapter():
- # ref_per_token_logps = per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)
- # else:
- # ref_per_token_logps = None
- ref_logps = inputs.get("ref_per_token_logps", None)
+ ref_per_token_logps = inputs["ref_per_token_logps"]
# per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
+
# x - x.detach() allows for preserving gradients from x
advantages = inputs["advantages"]
# per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
# per_token_loss = -(per_token_loss - self.beta * per_token_kl)
# loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
- old_logps = inputs.get("old_per_token_logps", None)
-
input_ids = input_ids[:, -logits_to_keep:]
-
- # Get logit softcapping and logit scale
- logit_softcapping = getattr(model.config, "final_logit_softcapping", 0) # Gemma
- if logit_softcapping is None:
- logit_softcapping = 0
- logit_scale_multiply = getattr(model.config, "logit_scale", 0) # Cohere
- if logit_scale_multiply is None:
- logit_scale_multiply = 0
- logit_scale_divide = getattr(model.config, "logits_scaling", 0) # Granite
- if logit_scale_divide is None:
- logit_scale_divide = 0
-
- max_left_pad = inputs.get("max_left_pad", 0)
if per_token_logps is not None:
- (
- loss,
- completion_length,
- mean_kl,
- delta,
- flat_is_ratio,
- coef_1,
- completion_mask,
- ) = grpo_compute_loss_slow(
- ref_logps,
- per_token_logps,
- old_logps,
- input_ids,
- completion_mask,
- self.beta,
- advantages,
- pixel_values = pixel_values,
- image_grid_thw = image_grid_thw,
- loss_type = self.args.loss_type,
- importance_sampling_level = self.importance_sampling_level,
- epsilon_low = self.epsilon_low,
- epsilon_high = self.epsilon_high,
- max_completion_length = self.args.max_completion_length,
- delta = self.args.delta,
- temperature = self.args.temperature,
- max_left_pad = max_left_pad,
- logit_softcapping = logit_softcapping,
- logit_scale_multiply = logit_scale_multiply,
- logit_scale_divide = logit_scale_divide,
- num_items_in_batch = num_items_in_batch,
- current_gradient_accumulation_steps = current_gradient_accumulation_steps,
- num_processes = num_processes,
- sampling_per_token_logps = sampling_per_token_logps,
+ loss, completion_length, mean_kl = grpo_compute_loss_slow(
+ ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages,
)
else:
- if hasattr(self.args, "loss_type"):
- (
- loss,
- completion_length,
- mean_kl,
- delta,
- flat_is_ratio,
- coef_1,
- completion_mask,
- ) = grpo_accumulated_loss(
- trainer = self,
- input_ids = _input_ids,
- pixel_values = pixel_values,
- image_grid_thw = image_grid_thw,
- logits_to_keep = logits_to_keep,
- completion_mask = completion_mask,
- advantages = advantages,
- old_logps = old_logps,
- ref_logps = ref_logps,
- n_chunks = self.args.unsloth_num_chunks,
- loss_type = self.args.loss_type,
- importance_sampling_level = self.importance_sampling_level,
- epsilon_low = self.epsilon_low,
- epsilon_high = self.epsilon_high,
- max_completion_length = self.args.max_completion_length,
- delta = self.args.delta,
- temperature = self.args.temperature,
- max_left_pad = max_left_pad,
- logit_softcapping = logit_softcapping,
- logit_scale_multiply = logit_scale_multiply,
- logit_scale_divide = logit_scale_divide,
- attention_mask = attention_mask,
- num_items_in_batch = num_items_in_batch,
- current_gradient_accumulation_steps = current_gradient_accumulation_steps,
- num_processes = num_processes,
- sampling_per_token_logps = sampling_per_token_logps,
- )
- else:
- # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17
- loss, completion_length, mean_kl, coef_1, completion_mask = (
- grpo_accumulated_loss(
- trainer = self,
- input_ids = _input_ids,
- logits_to_keep = logits_to_keep,
- completion_mask = completion_mask,
- advantages = advantages,
- old_logps = old_logps,
- ref_logps = ref_logps,
- n_chunks = self.args.unsloth_num_chunks,
- temperature = self.args.temperature,
- logit_softcapping = logit_softcapping,
- logit_scale_multiply = logit_scale_multiply,
- logit_scale_divide = logit_scale_divide,
- attention_mask = attention_mask,
- )
- )
+ loss, completion_length, mean_kl = grpo_accumulated_loss(
+ self, _input_ids, logits_to_keep, completion_mask, advantages,
+ n_chunks = self.args.unsloth_num_chunks,
+ )
+
+ # Log the metrics
+ # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
+
+ # mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+ # self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+
if "train" in self._metrics:
mode = "eval" if self.control.should_evaluate else "train"
self._metrics[mode]["completion_length"].append(completion_length.item())
@@ -1130,408 +294,43 @@ def compute_loss(
else:
self._metrics["completion_length"].append(completion_length.item())
self._metrics["kl"].append(mean_kl.item())
-
- if (
- self.use_vllm
- and delta is not None
- and getattr(self, "vllm_importance_sampling_correction", False)
- ):
- mean_delta = (
- torch.mean(delta)
- if delta.numel() > 0
- else torch.tensor(0.0, device = self.model.device)
- )
- max_delta = (
- torch.max(delta)
- if delta.numel() > 0
- else torch.tensor(0.0, device = self.model.device)
- )
- self._metrics[mode]["sampling/sampling_logp_difference/mean"].append(
- self.accelerator.gather(mean_delta).mean().item()
- )
- self._metrics[mode]["sampling/sampling_logp_difference/max"].append(
- self.accelerator.gather(max_delta).max().item()
- )
-
- min_importance_sampling_ratio = (
- torch.min(flat_is_ratio)
- if flat_is_ratio.numel() > 0
- else torch.tensor(0.0, device = self.model.device)
- )
- mean_importance_sampling_ratio = (
- torch.mean(flat_is_ratio)
- if flat_is_ratio.numel() > 0
- else torch.tensor(0.0, device = self.model.device)
- )
- max_importance_sampling_ratio = (
- torch.max(flat_is_ratio)
- if flat_is_ratio.numel() > 0
- else torch.tensor(0.0, device = self.model.device)
- )
- self._metrics[mode]["sampling/importance_sampling_ratio/min"].append(
- self.accelerator.gather(min_importance_sampling_ratio)
- .nan_to_num(nan = float("inf"))
- .min()
- .item()
- )
- self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append(
- self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item()
- )
- self._metrics[mode]["sampling/importance_sampling_ratio/max"].append(
- self.accelerator.gather(max_importance_sampling_ratio)
- .nan_to_num(nan = float("-inf"))
- .max()
- .item()
- )
-
- completion_token_count = completion_mask.sum().clamp(min = 1.0)
-
- def masked_batch_mean(x):
- if x.shape[1] == 1: # when importance_sampling_level == "sequence"
- return x.mean()
- else:
- return (x * completion_mask).sum() / completion_token_count
-
- if advantages.dim() == 1:
- advantages = advantages.unsqueeze(1)
-
- if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:
- # Compute the clipped probability ratios
- is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages < 0)
- is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages > 0)
- is_region_clipped = is_low_clipped | is_high_clipped
-
- low_clip = masked_batch_mean(is_low_clipped.float())
- high_clip = masked_batch_mean(is_high_clipped.float())
- clip_ratio = masked_batch_mean(is_region_clipped.float())
-
- gathered_low_clip = self.accelerator.gather(low_clip)
- self._metrics[mode]["clip_ratio/low_mean"].append(
- gathered_low_clip.nanmean().item()
- )
- self._metrics[mode]["clip_ratio/low_min"].append(
- nanmin(gathered_low_clip).item()
- )
- gathered_high_clip = self.accelerator.gather(high_clip)
- self._metrics[mode]["clip_ratio/high_mean"].append(
- gathered_high_clip.nanmean().item()
- )
- self._metrics[mode]["clip_ratio/high_max"].append(
- nanmax(gathered_high_clip).item()
- )
- gathered_clip_ratio = self.accelerator.gather(clip_ratio)
- self._metrics[mode]["clip_ratio/region_mean"].append(
- gathered_clip_ratio.nanmean().item()
- )
- elif self.loss_type == "cispo":
- is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages > 0)
- cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float())
- gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio)
- self._metrics[mode]["cispo_clip_ratio"].append(
- gathered_cispo_clip_ratio.nanmean().item()
- )
-
return loss
+ pass
function = inspect.getsource(compute_loss)
return function
-
-
+pass
RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_compute_loss)
-
-# Fix KTO shape mismatch when Unsloth model forward truncates input_ids
-# but labels aren't truncated. TRL 0.27.2+ _process_tokens only truncates
-# completions, not prompts -- so prompts exceeding max_seq_length cause the
-# model to produce shorter logits than the labels expect.
-def kto_trainer_get_batch_logps(function_name, function):
- if function_name != "get_batch_logps":
- return function
- # The raise is inside an if block inside the method, so we need
- # to preserve the exact indentation of the raise statement.
- old = 'raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")'
- new = (
- "# Unsloth: auto-truncate to shorter sequence length (model may have truncated input_ids)\n"
- " _min_len = min(logits.shape[1], labels.shape[1])\n"
- " logits = logits[:, :_min_len, :]\n"
- " labels = labels[:, :_min_len]"
- )
- function = function.replace(old, new)
- return function
-
-
-RL_FUNCTIONS["kto_trainer"].append(kto_trainer_get_batch_logps)
-
-
# https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L356
# TRL warns if batch size is not a multiple of num_generations -> fix this.
def grpo_trainer_fix_batch_size(RLTrainer_source, RLConfig_source):
- if "divisible by the number of generations" not in RLTrainer_source:
- # in later trl versions this doesn't exist anymore
- return ""
- if "num_generations" not in RLConfig_source:
- return ""
-
- check_batch_size = (
- "div = per_device_train_batch_size // num_generations\n"
- "if div * num_generations != per_device_train_batch_size:\n"
- " print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n"
- "We will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))\n"
- " per_device_train_batch_size = num_generations\n"
- )
+ if "divisible by the number of generations" not in RLTrainer_source: return ""
+ if "num_generations" not in RLConfig_source: return ""
+
+ check_batch_size = \
+ "div = per_device_train_batch_size // num_generations\n"\
+ "if div * num_generations != per_device_train_batch_size:\n"\
+ " print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n"\
+ "We will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))\n"\
+ " per_device_train_batch_size = num_generations\n"
return check_batch_size
-
-
+pass
RL_CONFIG_CHANGES["grpo_trainer"].append(grpo_trainer_fix_batch_size)
# Add other reward function names
def grpo_trainer_metrics(RLTrainer_source, RLConfig_source):
- if "reward_funcs" not in RLTrainer_source:
- return ""
-
- # For new TRL we have /mean and /std
- use_mean = "rewards/{reward_func_name}/mean" in RLTrainer_source
- use_std = "rewards/{reward_func_name}/std" in RLTrainer_source
- if not use_mean:
- use_normal = "rewards/{reward_func_name}" in RLTrainer_source
- else:
- use_normal = False
-
- log_metrics = (
- "if not isinstance(reward_funcs, list): _reward_funcs = [reward_funcs]\n"
- "else: _reward_funcs = reward_funcs\n"
- "for reward_func in _reward_funcs:\n"
- " try:\n"
- " reward_func_name = reward_func.__name__\n"
- f" if {use_mean}:\n"
- " other_metrics.append(f'rewards/{reward_func_name}/mean')\n"
- f" if {use_std}:\n"
- " other_metrics.append(f'rewards/{reward_func_name}/std')\n"
- f" if {use_normal}:\n"
- " other_metrics.append(f'rewards/{reward_func_name}')\n"
- " except: pass\n"
- )
+ if "reward_funcs" not in RLTrainer_source: return ""
+
+ log_metrics = \
+ "if not isinstance(reward_funcs, list): _reward_funcs = [reward_funcs]\n"\
+ "else: _reward_funcs = reward_funcs\n"\
+ "for reward_func in _reward_funcs:\n"\
+ " try:\n"\
+ " reward_func_name = reward_func.__name__\n"\
+ " other_metrics.append(f'rewards/{reward_func_name}')\n"\
+ " except: pass\n"
return log_metrics
-
-
+pass
RL_METRICS_CHANGES["grpo_trainer"].append(grpo_trainer_metrics)
-
-
-def openenv_vllm_reload_weights():
- # This function patches the trl openenv generate_rollout_completions function to:
- # 1. Remove the reload_weights call (unsloth handles weight reloading)
- # 2. Fix wake_up call to be compatible with unsloth (remove tags to wake everything)
- #
- # The issue: TRL's wake_up(tags=["kv_cache"]) only wakes kv_cache, leaving is_sleeping=True
- # at the executor level. This causes unsloth's patched generate to try waking up again,
- # resulting in double create_and_map on already-mapped handles.
- #
- # The fix: Use wake_up() with no tags, which wakes everything. Unsloth's patched
- # CuMemAllocator.wake_up skips weights anyway, so this is safe.
- if importlib.util.find_spec("trl") is None:
- return
- if Version(importlib_version("trl")) < Version("0.26.0"):
- return
-
- try:
- import trl.experimental.openenv.utils as openenv_utils
- import trl.experimental.openenv as openenv
- except (ImportError, NameError, Exception) as e:
- logger.info(f"Unsloth: Failed to import trl openenv: {e}")
- logger.info(
- "Unsloth: trl.experimental.openenv not available — skipping RL openenv patches."
- )
- return
-
- # trl 0.28 changed the function name yet again! Thanks trl :)
- patch_target_name = "_generate_rollout_completions_colocate"
- if hasattr(openenv_utils, patch_target_name):
- patch_target = getattr(openenv_utils, patch_target_name)
- else:
- # Older TRL versions may keep sleep/wake logic in the public dispatcher.
- patch_target_name = "generate_rollout_completions"
- patch_target = getattr(openenv_utils, patch_target_name)
-
- src = inspect.getsource(patch_target)
- src = textwrap.dedent(src)
- original_src = src
-
- # Remove the reload_weights call - unsloth handles this differently
- src = re.sub(r'.*\.collective_rpc\(\s*([\'"])reload_weights\1\s*\).*\n?', "", src)
-
- # Change wake_up(tags=["kv_cache"]) to wake_up() - wake everything to set is_sleeping=False
- # This prevents double wake_up issues. Unsloth's allocator skips weights anyway.
- src = re.sub(r"\.wake_up\(tags=\[.*?\]\)", ".wake_up()", src)
-
- if original_src == src:
- logger.warning("Unsloth: Warning - regex did not match, patch may have failed")
- return
-
- # Execute and explicitly assign to module
- local_ns = {}
- exec(compile(src, "", "exec"), openenv_utils.__dict__, local_ns)
- patched_func = local_ns[patch_target_name]
-
- # Patch the target function in utils; if dispatcher was patched also update parent module alias.
- setattr(openenv_utils, patch_target_name, patched_func)
- if patch_target_name == "generate_rollout_completions":
- openenv.generate_rollout_completions = patched_func
- logger.info(f"Unsloth: Patched trl openenv {patch_target_name}")
-
-
-RL_ADDITIONAL_FUNCTIONS["openenv"].append(openenv_vllm_reload_weights)
-
-
-def vllm_generation_init_patch():
- # trl moved vllm stuff to trl/generation/vllm_generation.py
- # We need to patch it to not instantiate another vLLM instance if we already have one with fast_inference
- # Edit the TRL source directly and install the patched function in the TRL module.
- # https://github.com/huggingface/trl/commit/0eb66d8f2fc63b3d00d8dbc18f99c3f48750bd16
- # This exists in trl versions 0.28.0 and above
-
- if importlib.util.find_spec("trl") is None:
- return
- if Version(importlib_version("trl")) < Version("0.28.0"):
- return
-
- try:
- import trl.generation.vllm_generation as vllm_generation
- except (ImportError, NameError, Exception) as e:
- logger.info(f"Unsloth: Failed to import trl.generation.vllm_generation: {e}")
- return
-
- def patch_vllm_generation_method(method_name, transform, marker, filename_suffix):
- method = getattr(vllm_generation.VLLMGeneration, method_name, None)
- if method is None:
- logger.info(f"Unsloth: Could not find VLLMGeneration.{method_name}")
- return False
-
- try:
- src = inspect.getsource(method)
- except Exception as e:
- logger.info(
- f"Unsloth: Could not get source of VLLMGeneration.{method_name}: {e}"
- )
- return False
-
- src = textwrap.dedent(src)
- if marker in src:
- return True
-
- src = transform(src)
- filename = f""
- source_lines = [line + "\n" for line in src.splitlines()]
- linecache.cache[filename] = (
- len(src),
- None,
- source_lines,
- filename,
- )
-
- local_ns = {}
- exec(compile(src, filename, "exec"), vllm_generation.__dict__, local_ns)
- setattr(vllm_generation.VLLMGeneration, method_name, local_ns[method_name])
- return True
-
- # Patch init to remove vLLM.LLM instantiation
- def patch_init_vllm(src):
- pattern = re.compile(
- r"(?P^(?P[ \t]*)self\.llm\s*=\s*LLM\s*\(\n(?:.*\n)*?^(?P=indent)\))",
- re.MULTILINE,
- )
-
- def replace_llm_block(match):
- indent = match.group("indent")
- llm_block = textwrap.dedent(match.group("llm_block"))
- return (
- f"{indent}if hasattr(model, 'vllm_engine'):\n"
- f"{indent} # Unsloth already inits vLLM in fast inference mode. Do not redo :)\n"
- f"{indent} self.llm = model.vllm_engine\n"
- f"{indent} self.unsloth_fast_inference_lora = True\n"
- f"{indent}else:\n" + textwrap.indent(llm_block, indent + " ")
- )
-
- patched_src, num_replacements = pattern.subn(replace_llm_block, src, count = 1)
- if num_replacements == 0:
- raise RuntimeError(
- "Unsloth: Warning - regex did not match, VLLMGeneration._init_vllm patch may have failed"
- )
- return patched_src
-
- # has some sync_weights or reload rpc calls.
- # we patched the grpo_trainer to strip them for prev versions
- # Ref: grpo_trainer__generate_single_turn above around L270-280
- def patch_sync_weights(src):
- pattern = re.compile(
- r"^(?Pdef sync_weights\(self\):\n)(?P(?:.*\n)*)",
- re.MULTILINE,
- )
-
- def replace_sync_weights(match):
- body = match.group("body")
- guard = (
- " if getattr(self, 'unsloth_fast_inference_lora', False):\n"
- " # Unsloth fast inference LoRA shares weights with vLLM already.\n"
- " return\n\n"
- )
- return match.group("def_line") + guard + body
-
- patched_src, num_replacements = pattern.subn(replace_sync_weights, src, count = 1)
- if num_replacements == 0:
- raise RuntimeError(
- "Unsloth: Warning - regex did not match, VLLMGeneration.sync_weights patch may have failed"
- )
- return patched_src
-
- def patch_generate(src):
- pattern = re.compile(
- r"^(?P[ \t]*)self\.llm\.collective_rpc\(\s*(['\"])reload_weights\2\s*\)\s*$",
- re.MULTILINE,
- )
-
- def replace_reload_weights(match):
- indent = match.group("indent")
- return f'{indent}pass # self.llm.collective_rpc("reload_weights")'
-
- patched_src, num_replacements = pattern.subn(
- replace_reload_weights, src, count = 1
- )
- if num_replacements == 0:
- raise RuntimeError(
- "Unsloth: Warning - regex did not match, VLLMGeneration.generate patch may have failed"
- )
- return patched_src
-
- try:
- init_patched = patch_vllm_generation_method(
- "_init_vllm",
- patch_init_vllm,
- "self.unsloth_fast_inference_lora = True",
- "init_vllm",
- )
- sync_patched = patch_vllm_generation_method(
- "sync_weights",
- patch_sync_weights,
- "if getattr(self, 'unsloth_fast_inference_lora', False):",
- "sync_weights",
- )
- generate_patched = patch_vllm_generation_method(
- "generate",
- patch_generate,
- 'pass # self.llm.collective_rpc("reload_weights")',
- "generate",
- )
- except RuntimeError as e:
- logger.warning(str(e))
- return
-
- if init_patched:
- logger.info("Unsloth: Patched trl VLLMGeneration._init_vllm")
- if sync_patched:
- logger.info("Unsloth: Patched trl VLLMGeneration.sync_weights")
- if generate_patched:
- logger.info("Unsloth: Patched trl VLLMGeneration.generate")
-
-
-RL_ADDITIONAL_FUNCTIONS["vllm_generation"].append(vllm_generation_init_patch)
diff --git a/unsloth/models/sentence_transformer.py b/unsloth/models/sentence_transformer.py
deleted file mode 100644
index ad59165a50..0000000000
--- a/unsloth/models/sentence_transformer.py
+++ /dev/null
@@ -1,2111 +0,0 @@
-# Copyright 2025 electroglyph. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-
-from .loader import FastModel, DISABLE_SDPA_MODEL_NAMES
-from ._utils import SUPPORTS_BFLOAT16
-import inspect
-import json
-import os
-import types
-from huggingface_hub import hf_hub_download
-from typing import Optional
-import torch
-from transformers.modeling_outputs import BaseModelOutput
-from collections import OrderedDict
-from transformers.models.distilbert import modeling_distilbert
-from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
-import transformers
-from packaging.version import Version
-import re
-from transformers import AutoModel, AutoConfig
-from transformers.models.auto.auto_factory import _get_model_class
-import tempfile
-from huggingface_hub import HfApi, get_token
-from ..save import unsloth_save_pretrained_torchao, unsloth_save_pretrained_gguf
-import contextlib
-import shutil
-
-
-def _save_pretrained_torchao(
- self,
- save_directory,
- tokenizer = None,
- torchao_config = None,
- push_to_hub = False,
- token = None,
-):
- self.save_pretrained(save_directory)
-
- # grab inner model
- inner_model = self[0].auto_model
- if hasattr(inner_model, "_orig_mod"):
- inner_model = inner_model._orig_mod
-
- # merge LoRA first
- if hasattr(inner_model, "merge_and_unload"):
- inner_model = inner_model.merge_and_unload()
-
- # confirm Transformer path
- transformer_path = "0_Transformer"
- modules_path = os.path.join(save_directory, "modules.json")
- if os.path.exists(modules_path):
- try:
- with open(modules_path, "r") as f:
- modules = json.load(f)
- for m in modules:
- if m.get("type", "").endswith("Transformer"):
- transformer_path = m.get("path", "")
- break
- except:
- pass
-
- transformer_dir = os.path.join(save_directory, transformer_path)
- transformer_dir = os.path.abspath(transformer_dir)
-
- if tokenizer is None:
- tokenizer = self.tokenizer
-
- @contextlib.contextmanager
- def patch_unsloth_save():
- original_causal = transformers.AutoModelForCausalLM
- original_rmtree = shutil.rmtree
- # unsloth_save_pretrained_torchao expects AutoModelForCausalLM
- transformers.AutoModelForCausalLM = transformers.AutoModel
- # prevent unsloth from deleting the unquantized model directory
- shutil.rmtree = lambda *args, **kwargs: None
- try:
- yield
- finally:
- # unpatch
- transformers.AutoModelForCausalLM = original_causal
- shutil.rmtree = original_rmtree
-
- with patch_unsloth_save():
- unsloth_save_pretrained_torchao(
- inner_model,
- transformer_dir,
- tokenizer = tokenizer,
- torchao_config = torchao_config,
- push_to_hub = push_to_hub,
- token = token,
- )
-
- # avoid `0_Transformer-torchao`, it was either this or fix modules.json
- torchao_dir = transformer_dir + "-torchao"
- if os.path.exists(torchao_dir):
- if not os.path.exists(transformer_dir):
- os.makedirs(transformer_dir, exist_ok = True)
-
- # move contents
- for item in os.listdir(torchao_dir):
- s = os.path.join(torchao_dir, item)
- d = os.path.join(transformer_dir, item)
- if os.path.isdir(s):
- shutil.copytree(s, d, dirs_exist_ok = True)
- else:
- shutil.copy2(s, d)
-
- # remove torchao dir
- shutil.rmtree(torchao_dir)
-
- # remove conflicting safetensors if we brought in bin
- if os.path.exists(os.path.join(transformer_dir, "pytorch_model.bin")):
- safetensors_path = os.path.join(transformer_dir, "model.safetensors")
- if os.path.exists(safetensors_path):
- try:
- os.remove(safetensors_path)
- except:
- pass
-
- try:
- FastSentenceTransformer._add_unsloth_branding(save_directory)
- except:
- pass
-
-
-# Thanks Etherl:
-def _save_pretrained_gguf(
- self,
- save_directory,
- tokenizer = None,
- quantization_method = "fast_quantized",
- first_conversion = None,
- push_to_hub = False,
- token = None,
- max_shard_size = "5GB",
- temporary_location = "_unsloth_temporary_saved_buffers",
- maximum_memory_usage = 0.85,
- **kwargs,
-):
- """
- Saves the SentenceTransformer model to GGUF format by saving the inner transformer model,
- converting it, and placing the resulting GGUF files in the save directory.
- """
- # 1. Save standard SentenceTransformer structure (configs, modules.json, etc.)
- self.save_pretrained(save_directory)
-
- # 2. Extract inner transformer model
- inner_model = self[0].auto_model
- if hasattr(inner_model, "_orig_mod"):
- inner_model = inner_model._orig_mod
-
- # If it's a PEFT model, unsloth_save_pretrained_gguf handles merging,
- # but we pass the inner model wrapper.
-
- # 3. Identify where the transformer weights are stored
- transformer_path = "0_Transformer"
- modules_path = os.path.join(save_directory, "modules.json")
- if os.path.exists(modules_path):
- try:
- with open(modules_path, "r") as f:
- modules = json.load(f)
- for m in modules:
- if m.get("type", "").endswith("Transformer"):
- transformer_path = m.get("path", "")
- break
- except:
- pass
-
- # This is where Unsloth will perform the save + conversion operations
- transformer_dir = os.path.join(save_directory, transformer_path)
- # Ensure this path is absolute for consistent comparison later
- transformer_dir = os.path.abspath(transformer_dir)
-
- if tokenizer is None:
- tokenizer = self.tokenizer
-
- # 4. Patch environment to ensure Unsloth treats this embedding model correctly
- @contextlib.contextmanager
- def patch_unsloth_gguf_save():
- # Prevent deletion of the directory we just created via self.save_pretrained
- original_rmtree = shutil.rmtree
- try:
- yield
- finally:
- shutil.rmtree = original_rmtree
-
- # 5. Call Unsloth's GGUF saver on the inner model targeting the transformer subdirectory
- with patch_unsloth_gguf_save():
- result = unsloth_save_pretrained_gguf(
- inner_model,
- save_directory = transformer_dir,
- tokenizer = tokenizer,
- quantization_method = quantization_method,
- first_conversion = first_conversion,
- push_to_hub = False, # Force local first to move files
- token = token,
- max_shard_size = max_shard_size,
- temporary_location = temporary_location,
- maximum_memory_usage = maximum_memory_usage,
- )
-
- # 6. Move GGUF files from the subdirectory (0_Transformer) to the root save_directory
- gguf_files = result.get("gguf_files", [])
-
- new_gguf_locations = []
-
- for gguf_file in gguf_files:
- if os.path.exists(gguf_file):
- filename = os.path.basename(gguf_file)
- dest_path = os.path.join(save_directory, filename)
-
- # Convert to absolute path to avoid mixing relative/absolute in commonpath
- abs_gguf_file = os.path.abspath(gguf_file)
-
- # Check if file is inside transformer_dir (subpath)
- try:
- is_subpath = (
- os.path.commonpath([abs_gguf_file, transformer_dir])
- == transformer_dir
- )
- except ValueError:
- # Can happen on Windows with different drives, or mix of absolute/relative (handled by abspath above)
- is_subpath = False
-
- if is_subpath:
- # If the GGUF file is inside the transformer_dir, move it out to root
- shutil.move(gguf_file, dest_path)
- new_gguf_locations.append(dest_path)
- else:
- # If it's elsewhere, move it to root if not already there
- if os.path.abspath(dest_path) != abs_gguf_file:
- shutil.move(gguf_file, dest_path)
- new_gguf_locations.append(dest_path)
-
- # Update result with new locations
- result["gguf_files"] = new_gguf_locations
-
- # 7. Add branding
- try:
- FastSentenceTransformer._add_unsloth_branding(save_directory)
-
- # Add GGUF details to README
- readme_path = os.path.join(save_directory, "README.md")
- if os.path.exists(readme_path):
- with open(readme_path, "a", encoding = "utf-8") as f:
- f.write("\n## GGUF Quantization\n")
- f.write(
- f"This model contains GGUF quantized versions in: {', '.join([os.path.basename(f) for f in new_gguf_locations])}\n"
- )
- except:
- pass
-
- # 8. Handle Push to Hub if requested
- if push_to_hub:
- if token is None:
- token = get_token()
-
- api = HfApi(token = token)
- repo_id = save_directory # Assuming save_directory is the repo name if pushing
-
- print(f"Unsloth: Uploading to {repo_id}...")
- try:
- api.create_repo(
- repo_id = repo_id, exist_ok = True, private = kwargs.get("private", False)
- )
- api.upload_folder(
- folder_path = save_directory,
- repo_id = repo_id,
- commit_message = "Upload GGUF and SentenceTransformer model",
- )
- print(f"Unsloth: Uploaded to https://huggingface.co/{repo_id}")
- except Exception as e:
- print(f"Unsloth: Upload failed: {e}")
-
- return result
-
-
-def _push_to_hub_gguf(
- self,
- repo_id,
- tokenizer = None,
- quantization_method = "fast_quantized",
- first_conversion = None,
- token = None,
- private = None,
- commit_message = "Upload GGUF SentenceTransformer model trained with Unsloth",
- commit_description = "Upload GGUF model trained with Unsloth 2x faster",
- max_shard_size = "5GB",
- temporary_location = "_unsloth_temporary_saved_buffers",
- maximum_memory_usage = 0.85,
- create_pr = False,
- revision = None,
- tags = None,
- **kwargs,
-):
- """
- Converts the SentenceTransformer model to GGUF format and pushes to the Hugging Face Hub.
-
- This method:
- 1. Saves the model locally to a temporary directory in GGUF format.
- 2. Uploads the GGUF files, config, Ollama Modelfile, and README to the Hub.
- 3. Cleans up the temporary directory.
-
- Args:
- repo_id (str): The Hugging Face Hub repo ID (e.g., "username/model-name").
- tokenizer: The tokenizer to save. Defaults to `self.tokenizer`.
- quantization_method (str or list): GGUF quantization method(s). Can be a string or list of strings.
- Choose from the following options:
- * "not_quantized" : Recommended. Fast conversion. Slow inference, big files.
- * "fast_quantized" : Recommended. Fast conversion. OK inference, OK file size.
- * "quantized" : Recommended. Slow conversion. Fast inference, small files.
- * "f32" : Not recommended. Retains 100% accuracy, but super slow and memory hungry.
- * "f16" : Fastest conversion + retains 100% accuracy. Slow and memory hungry.
- * "q8_0" : Fast conversion. High resource use, but generally acceptable.
- * "q4_k_m" : Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K
- * "q5_k_m" : Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K
- * "q2_k" : Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.
- * "q3_k_l" : Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K
- * "q3_k_m" : Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K
- * "q3_k_s" : Uses Q3_K for all tensors
- * "q4_0" : Original quant method, 4-bit.
- * "q4_1" : Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.
- * "q4_k_s" : Uses Q4_K for all tensors
- * "q5_0" : Higher accuracy, higher resource usage and slower inference.
- * "q5_1" : Even higher accuracy, resource usage and slower inference.
- * "q5_k_s" : Uses Q5_K for all tensors
- * "q6_k" : Uses Q8_K for all tensors
- first_conversion (str, optional): The initial conversion format before quantization.
- token (str, optional): Hugging Face token. Uses cached token if not provided.
- private (bool, optional): Whether the repo should be private.
- commit_message (str): Commit message for the upload.
- commit_description (str): Commit description for the upload.
- max_shard_size (str): Maximum shard size for saving.
- temporary_location (str): Temp directory for intermediate files.
- maximum_memory_usage (float): Max fraction of memory to use.
- create_pr (bool): Whether to create a pull request instead of pushing directly.
- revision (str, optional): Branch/revision to push to.
- tags (list, optional): Additional tags for the repo.
-
- Returns:
- str: The full repo ID on Hugging Face Hub.
- """
- if token is None:
- token = get_token()
- if token is None:
- raise ValueError(
- "No HF token provided. Please provide a token or login with `huggingface-cli login`"
- )
-
- api = HfApi(token = token)
-
- # Determine full repo_id
- if "/" not in repo_id:
- username = api.whoami()["name"]
- full_repo_id = f"{username}/{repo_id}"
- else:
- full_repo_id = repo_id
-
- model_name = full_repo_id.split("/")[-1]
-
- # Create repo
- try:
- api.create_repo(
- repo_id = full_repo_id,
- private = private,
- exist_ok = True,
- repo_type = "model",
- )
- except Exception as e:
- print(f"Unsloth Warning: Could not create repo: {e}")
-
- # Save to temporary directory first
- with tempfile.TemporaryDirectory(prefix = "unsloth_st_gguf_") as temp_dir:
- print(f"Unsloth: Converting SentenceTransformer to GGUF format...")
-
- # Call save_pretrained_gguf to do the local conversion
- result = _save_pretrained_gguf(
- self,
- save_directory = temp_dir,
- tokenizer = tokenizer,
- quantization_method = quantization_method,
- first_conversion = first_conversion,
- push_to_hub = False, # We handle upload ourselves
- token = token,
- max_shard_size = max_shard_size,
- temporary_location = temporary_location,
- maximum_memory_usage = maximum_memory_usage,
- )
-
- gguf_files = result.get("gguf_files", [])
- modelfile_location = result.get("modelfile_location", None)
- is_vlm = result.get("is_vlm", False)
- fix_bos_token = result.get("fix_bos_token", False)
-
- print(f"Unsloth: Uploading GGUF to https://huggingface.co/{full_repo_id}...")
-
- # Upload GGUF files
- for file_location in gguf_files:
- if os.path.exists(file_location):
- filename = os.path.basename(file_location)
- print(f" Uploading {filename}...")
- api.upload_file(
- path_or_fileobj = file_location,
- path_in_repo = filename,
- repo_id = full_repo_id,
- repo_type = "model",
- commit_message = commit_message,
- commit_description = commit_description,
- create_pr = create_pr,
- revision = revision,
- )
-
- # Upload Modelfile if exists
- if modelfile_location and os.path.exists(modelfile_location):
- print(" Uploading Ollama Modelfile...")
- api.upload_file(
- path_or_fileobj = modelfile_location,
- path_in_repo = "Modelfile",
- repo_id = full_repo_id,
- repo_type = "model",
- commit_message = f"{commit_message} - Ollama Modelfile",
- create_pr = create_pr,
- revision = revision,
- )
-
- # Upload config.json if exists
- config_path = os.path.join(temp_dir, "config.json")
- if os.path.exists(config_path):
- print(" Uploading config.json...")
- api.upload_file(
- path_or_fileobj = config_path,
- path_in_repo = "config.json",
- repo_id = full_repo_id,
- repo_type = "model",
- commit_message = f"{commit_message} - config",
- create_pr = create_pr,
- revision = revision,
- )
-
- # Create and upload README
- gguf_basenames = [os.path.basename(f) for f in gguf_files if os.path.exists(f)]
- readme_content = f"""---
-tags:
-- gguf
-- llama.cpp
-- unsloth
-- sentence-transformers
-{"- vision-language-model" if is_vlm else ""}
----
-
-# {model_name} - GGUF
-
-This sentence-transformers model was finetuned and converted to GGUF format using [Unsloth](https://github.com/unslothai/unsloth).
-
-## Available Model files:
-"""
- for fname in gguf_basenames:
- readme_content += f"- `{fname}`\n"
-
- if modelfile_location and os.path.exists(modelfile_location):
- readme_content += "\n## Ollama\n"
- readme_content += "An Ollama Modelfile is included for easy deployment.\n"
-
- if fix_bos_token:
- readme_content += "\n## Note\n"
- readme_content += (
- "The model's BOS token behavior was adjusted for GGUF compatibility.\n"
- )
-
- readme_content += (
- "\nThis was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth)\n"
- '[
](https://github.com/unslothai/unsloth)\n'
- )
-
- readme_path = os.path.join(temp_dir, "README.md")
- with open(readme_path, "w", encoding = "utf-8") as f:
- f.write(readme_content)
-
- api.upload_file(
- path_or_fileobj = readme_path,
- path_in_repo = "README.md",
- repo_id = full_repo_id,
- repo_type = "model",
- commit_message = "Add README",
- create_pr = create_pr,
- revision = revision,
- )
-
- # Add tags
- all_tags = ["gguf", "llama-cpp", "unsloth", "sentence-transformers"]
- if is_vlm:
- all_tags.append("vision-language-model")
- if tags is not None:
- if isinstance(tags, (list, tuple)):
- all_tags.extend(tags)
- else:
- all_tags.append(tags)
- try:
- api.add_tags(repo_id = full_repo_id, tags = all_tags, repo_type = "model")
- except:
- pass
-
- print(
- f"Unsloth: Successfully uploaded GGUF to https://huggingface.co/{full_repo_id}"
- )
- return full_repo_id
-
-
-class FastSentenceTransformer(FastModel):
- @staticmethod
- def _read_pooling_mode(model_name, token):
- """
- Read the pooling mode from the modules.json file if it exists, otherwise return "mean".
- """
- try:
- if os.path.exists(model_name) and os.path.exists(
- os.path.join(model_name, "modules.json")
- ):
- modules_json_path = os.path.join(model_name, "modules.json")
- else:
- modules_json_path = hf_hub_download(
- model_name, "modules.json", token = token
- )
-
- with open(modules_json_path, "r") as f:
- modules_config = json.load(f)
-
- pooling_config_path = None
- for module in modules_config:
- if module.get("type", "") == "sentence_transformers.models.Pooling":
- pooling_path = module.get("path", "")
- if pooling_path:
- # try to find config.json for pooling module
- if os.path.exists(model_name) and os.path.exists(
- os.path.join(model_name, pooling_path, "config.json")
- ):
- pooling_config_path = os.path.join(
- model_name, pooling_path, "config.json"
- )
- else:
- pooling_config_path = hf_hub_download(
- model_name,
- os.path.join(pooling_path, "config.json"),
- token = token,
- )
- break
-
- if pooling_config_path:
- with open(pooling_config_path, "r") as f:
- pooling_config = json.load(f)
- # from here:
- # https://github.com/huggingface/sentence-transformers/blob/main/sentence_transformers/models/Pooling.py#L43
- pooling_map = {
- "pooling_mode_cls_token": "cls",
- "pooling_mode_mean_tokens": "mean",
- "pooling_mode_max_tokens": "max",
- "pooling_mode_mean_sqrt_len_tokens": "mean_sqrt_len",
- "pooling_mode_weightedmean_tokens": "weightedmean",
- "pooling_mode_lasttoken": "lasttoken",
- }
- for config_key, mode in pooling_map.items():
- if pooling_config.get(config_key):
- if mode != "mean":
- print(f"Pooling mode detected as {mode}, updating...")
- return mode
-
- except Exception as e:
- print(
- f"Failed to detect pooling mode, not a sentence-transformers model. Using default pooling mode 'mean', this may or may not work."
- )
- return "mean"
-
- # should prolly be done upstream instead of this hackfest here
- @staticmethod
- def _patch_mpnet_v4():
- """
- Patch the MPNetModel to support gradient checkpointing.
- Supports transformers 4.
- """
- from transformers.models.mpnet import modeling_mpnet
-
- # add supports_gradient_checkpointing flag
- modeling_mpnet.MPNetModel.supports_gradient_checkpointing = True
-
- # add _set_gradient_checkpointing method
- def _set_gradient_checkpointing(self, module = None, value = True):
- if module is None:
- module = self.encoder
- if isinstance(module, modeling_mpnet.MPNetEncoder):
- module.gradient_checkpointing = value
-
- modeling_mpnet.MPNetModel._set_gradient_checkpointing = (
- _set_gradient_checkpointing
- )
-
- # patch MPNetEncoder.forward to support checkpointing
- # based on:
- # https://github.com/huggingface/transformers/blob/v4.57.3/src/transformers/models/mpnet/modeling_mpnet.py#L321
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- output_hidden_states: bool = False,
- return_dict: bool = False,
- **kwargs,
- ):
- position_bias = self.compute_position_bias(hidden_states)
- all_hidden_states = () if output_hidden_states else None
- all_attentions = () if output_attentions else None
-
- for i, layer_module in enumerate(self.layer):
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
-
- # do gradient checkpointing if enabled and training
- if getattr(self, "gradient_checkpointing", False) and self.training:
-
- def create_custom_forward(module):
- # bog standard checkpoint
- def custom_forward(*inputs):
- return module(*inputs, output_attentions = output_attentions)
-
- return custom_forward
-
- layer_outputs = torch.utils.checkpoint.checkpoint(
- create_custom_forward(layer_module),
- hidden_states,
- attention_mask,
- head_mask[i] if head_mask is not None else None,
- position_bias,
- use_reentrant = True, # fix for torch 2.9
- )
- else:
- # original code from here on
- layer_outputs = layer_module(
- hidden_states,
- attention_mask,
- head_mask[i] if head_mask is not None else None,
- position_bias,
- output_attentions = output_attentions,
- **kwargs,
- )
-
- hidden_states = layer_outputs[0]
-
- if output_attentions:
- all_attentions = all_attentions + (layer_outputs[1],)
-
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
-
- if not return_dict:
- return tuple(
- v
- for v in [hidden_states, all_hidden_states, all_attentions]
- if v is not None
- )
- return BaseModelOutput(
- last_hidden_state = hidden_states,
- hidden_states = all_hidden_states,
- attentions = all_attentions,
- )
-
- # assign the patched forward
- modeling_mpnet.MPNetEncoder.forward = forward
-
- @staticmethod
- def _patch_mpnet_v5():
- """
- Patch the MPNetModel to support gradient checkpointing.
- Supports transformers 5.
- """
- from transformers.models.mpnet import modeling_mpnet
-
- # add supports_gradient_checkpointing flag
- modeling_mpnet.MPNetModel.supports_gradient_checkpointing = True
-
- # add _set_gradient_checkpointing method
- def _set_gradient_checkpointing(self, module = None, value = True):
- if module is None:
- module = self.encoder
- if isinstance(module, modeling_mpnet.MPNetEncoder):
- module.gradient_checkpointing = value
-
- modeling_mpnet.MPNetModel._set_gradient_checkpointing = (
- _set_gradient_checkpointing
- )
-
- # patch MPNetEncoder.forward to support checkpointing
- # based on:
- # https://github.com/huggingface/transformers/blob/v5.0.0rc1/src/transformers/models/mpnet/modeling_mpnet.py#L284
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- output_hidden_states: bool = False,
- return_dict: bool = False,
- **kwargs,
- ):
- position_bias = self.compute_position_bias(hidden_states)
- all_hidden_states = () if output_hidden_states else None
- all_attentions = () if output_attentions else None
-
- for i, layer_module in enumerate(self.layer):
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
-
- # do gradient checkpointing if enabled and training
- if getattr(self, "gradient_checkpointing", False) and self.training:
-
- def create_custom_forward(module):
- # checkpoint
- def custom_forward(*inputs):
- return module(*inputs, output_attentions = output_attentions)
-
- return custom_forward
-
- layer_outputs = torch.utils.checkpoint.checkpoint(
- create_custom_forward(layer_module),
- hidden_states,
- attention_mask,
- position_bias,
- use_reentrant = True, # required for torch >= 2.9
- )
- else:
- # original code from here on
- layer_outputs = layer_module(
- hidden_states,
- attention_mask,
- position_bias,
- output_attentions,
- **kwargs,
- )
-
- hidden_states = layer_outputs[0]
-
- if output_attentions:
- all_attentions = all_attentions + (layer_outputs[1],)
-
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
-
- if not return_dict:
- return tuple(
- v
- for v in [hidden_states, all_hidden_states, all_attentions]
- if v is not None
- )
- return BaseModelOutput(
- last_hidden_state = hidden_states,
- hidden_states = all_hidden_states,
- attentions = all_attentions,
- )
-
- modeling_mpnet.MPNetEncoder.forward = forward
-
- @staticmethod
- def _patch_distilbert_v4():
- # change kwargs to positional args to be compatible with peft_utils
- """
- Patch the forward method of the DistilBertModel to use positional arguments instead of keyword arguments.
- Transformers 4 version.
- """
-
- # based on:
- # https://github.com/huggingface/transformers/blob/v4.57.3/src/transformers/models/distilbert/modeling_distilbert.py#L666
- # original code from here on:
- def forward(
- self,
- input_ids: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- head_mask: Optional[torch.Tensor] = None,
- inputs_embeds: Optional[torch.Tensor] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ):
- output_attentions = (
- output_attentions
- if output_attentions is not None
- else self.config.output_attentions
- )
- output_hidden_states = (
- output_hidden_states
- if output_hidden_states is not None
- else self.config.output_hidden_states
- )
- return_dict = (
- return_dict if return_dict is not None else self.config.use_return_dict
- )
-
- if input_ids is not None and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both input_ids and inputs_embeds at the same time"
- )
- elif input_ids is not None:
- self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
- input_shape = input_ids.size()
- elif inputs_embeds is not None:
- input_shape = inputs_embeds.size()[:-1]
- else:
- raise ValueError(
- "You have to specify either input_ids or inputs_embeds"
- )
-
- device = input_ids.device if input_ids is not None else inputs_embeds.device
-
- head_mask_is_none = head_mask is None
- # Prepare head mask if needed
- head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
-
- embeddings = self.embeddings(
- input_ids, inputs_embeds
- ) # (bs, seq_length, dim)
-
- if self.config._attn_implementation == "flash_attention_2":
- attention_mask = (
- attention_mask
- if (attention_mask is not None and 0 in attention_mask)
- else None
- )
- else:
- if attention_mask is None:
- attention_mask = torch.ones(
- input_shape, device = device
- ) # (bs, seq_length)
-
- if (
- self.config._attn_implementation == "sdpa"
- and head_mask_is_none
- and not output_attentions
- ):
- attention_mask = _prepare_4d_attention_mask_for_sdpa(
- attention_mask, embeddings.dtype, tgt_len = input_shape[1]
- )
- # patch here, change kwargs to positional args:
- return self.transformer(
- embeddings,
- attention_mask,
- head_mask,
- output_attentions,
- output_hidden_states,
- return_dict,
- )
-
- modeling_distilbert.DistilBertModel.forward = forward
-
- @staticmethod
- def _has_add_pooling_layer(config, auto_model_class = None):
- """
- Checks if the model class supports the `add_pooling_layer` argument
- """
- try:
- if auto_model_class is None:
- auto_model_class = AutoModel
- # try to resolve the class
- model_class = _get_model_class(config, auto_model_class._model_mapping)
-
- if model_class:
- sig = inspect.signature(model_class.__init__)
- return "add_pooling_layer" in sig.parameters
- except:
- pass
-
- return False
-
- @staticmethod
- def _patch_distilbert_v5():
- """
- Patch the forward method of the DistilBertModel to use positional arguments instead of keyword arguments.
- Transformers 5 version.
- """
- # based on:
- # https://github.com/huggingface/transformers/blob/v5.0.0rc1/src/transformers/models/distilbert/modeling_distilbert.py#L386
- # original code from here on:
- from transformers.masking_utils import create_bidirectional_mask
-
- def forward(
- self,
- input_ids: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- inputs_embeds: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.Tensor] = None,
- **kwargs,
- ):
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError(
- "You must specify exactly one of input_ids or inputs_embeds"
- )
-
- embeddings = self.embeddings(input_ids, inputs_embeds, position_ids)
-
- attention_mask = create_bidirectional_mask(
- config = self.config,
- input_embeds = embeddings,
- attention_mask = attention_mask,
- )
-
- # patch here: unsloth gradient checkpointing hook needs positional arguments
- return self.transformer(
- embeddings,
- attention_mask,
- **kwargs,
- )
-
- modeling_distilbert.DistilBertModel.forward = forward
-
- @staticmethod
- def _add_unsloth_tags(repo_id, token, tags = None):
- """
- Add Unsloth and sentence-transformers tags to the Hugging Face Hub repository.
- """
- from huggingface_hub import HfApi
-
- api = HfApi(token = token)
- if tags is None:
- tags = []
- tags.extend(["unsloth", "sentence-transformers"])
- try:
- api.add_tags(
- repo_id = repo_id,
- tags = tags,
- repo_type = "model",
- )
- except:
- pass
-
- @staticmethod
- def _add_unsloth_branding(save_directory):
- """
- Add Unsloth branding to the README.md file generated by sentence-transformers.
- """
- readme_path = os.path.join(save_directory, "README.md")
- if not os.path.exists(readme_path):
- return
-
- with open(readme_path, "r", encoding = "utf-8") as f:
- content = f.read()
-
- # add unsloth tag to frontmatter
- if "---\ntags:\n" in content:
- content = content.replace("---\ntags:\n", "---\ntags:\n- unsloth\n")
- else:
- # if tags exist but not right at start, use regex to append
- pattern = r"(^tags:\s*\n)"
- if re.search(pattern, content, re.MULTILINE):
- content = re.sub(
- pattern, r"\1- unsloth\n", content, count = 1, flags = re.MULTILINE
- )
-
- # add branding badge and text
- branding = (
- "\n\nThis model was finetuned with [Unsloth](https://github.com/unslothai/unsloth).\n\n"
- '[
](https://github.com/unslothai/unsloth)\n'
- )
-
- # add to description
- if "# SentenceTransformer" in content:
- parts = content.split("# SentenceTransformer", 1)
- content = parts[0] + "# SentenceTransformer" + branding + parts[1]
- else:
- content += branding
-
- with open(readme_path, "w", encoding = "utf-8") as f:
- f.write(content)
-
- @staticmethod
- def _module_path(model_name, token = None):
- """
- Returns the path to the modules.json file or None
- """
- try:
- if os.path.exists(model_name) and os.path.isdir(model_name):
- path = os.path.join(model_name, "modules.json")
- return path if os.path.exists(path) else None
- else:
- try:
- return hf_hub_download(model_name, "modules.json", token = token)
- except:
- return None
- except:
- return None
-
- @staticmethod
- def _create_transformer_module(
- model_name,
- model,
- tokenizer,
- max_seq_length,
- trust_remote_code,
- ):
- """Helper to create and configure a Transformer module."""
- from sentence_transformers.models import Transformer
-
- # prevents sentence-transformers from loading the model a second time, thanks Etherl
- original_from_pretrained = AutoModel.from_pretrained
-
- def return_existing_model(*args, **kwargs):
- return model
-
- try:
- # Temporarily redirect AutoModel loading to return our pre-loaded model
- AutoModel.from_pretrained = return_existing_model
-
- # Initialize Transformer
- transformer_module = Transformer(
- model_name,
- max_seq_length = max_seq_length,
- model_args = {"trust_remote_code": trust_remote_code},
- config_args = {"trust_remote_code": trust_remote_code},
- )
- finally:
- # Restore original functionality immediately
- AutoModel.from_pretrained = original_from_pretrained
-
- transformer_module.tokenizer = tokenizer
- transformer_module.do_lower_case = getattr(tokenizer, "do_lower_case", False)
-
- # sentence-transformers only passes along known keys to model.forward
- model_forward_params = list(inspect.signature(model.forward).parameters)
- transformer_module.model_forward_params = set(model_forward_params) | {
- "input_ids",
- "attention_mask",
- "token_type_ids",
- "inputs_embeds",
- }
-
- # determine max_seq_length if not provided
- if max_seq_length is None:
- if hasattr(model, "config") and hasattr(
- model.config, "max_position_embeddings"
- ):
- max_seq_length = model.config.max_position_embeddings
- elif hasattr(tokenizer, "model_max_length"):
- max_seq_length = tokenizer.model_max_length
- else:
- max_seq_length = 512
-
- transformer_module.max_seq_length = max_seq_length
- transformer_module.config_keys = ["max_seq_length", "do_lower_case"]
- transformer_module.save_in_root = True
-
- if hasattr(model, "config"):
- model.config.tokenizer_class = tokenizer.__class__.__name__
-
- return transformer_module
-
- @staticmethod
- def _load_modules(
- model_name,
- token,
- model,
- tokenizer,
- max_seq_length,
- pooling_mode,
- trust_remote_code = False,
- ) -> tuple[OrderedDict, bool]:
- """
- Load modules from modules.json if available, otherwise fallback to hard-coded modules.
-
- Returns:
- tuple[OrderedDict, bool]: (modules, no_modules_json)
- """
- from sentence_transformers.util import import_from_string, load_dir_path
- from sentence_transformers.models import Pooling, Normalize
-
- modules = OrderedDict()
- modules_json_path = FastSentenceTransformer._module_path(model_name, token)
-
- if modules_json_path:
- with open(modules_json_path, encoding = "utf8") as f:
- modules_config = json.load(f)
-
- for module_config in modules_config:
- class_ref = module_config["type"]
- name = module_config.get(
- "name", str(module_config.get("idx", len(modules)))
- )
-
- if class_ref == "sentence_transformers.models.Transformer":
- transformer_module = (
- FastSentenceTransformer._create_transformer_module(
- model_name,
- model,
- tokenizer,
- max_seq_length,
- trust_remote_code,
- )
- )
- modules[name] = transformer_module
- else:
- # load other modules (Pooling, Normalize, etc.)
- module_path = module_config["path"]
- if os.path.isdir(model_name):
- load_path = os.path.join(model_name, module_path)
- else:
- try:
- load_path = load_dir_path(
- model_name, module_path, token = token
- )
- except Exception as e:
- print(
- f"Unsloth Warning: Could not download module {module_path}: {e}"
- )
- continue
-
- module_class = import_from_string(class_ref)
- try:
- module = module_class.load(load_path)
- modules[name] = module
- except Exception as e:
- print(
- f"Unsloth Warning: Failed to load module {name} ({class_ref}): {e}"
- )
-
- return modules, False
-
- # fallback if no modules.json (non sentence-transformers models)
- print(
- "Unsloth: No modules.json found, falling back to [Transformer, Pooling, Normalize]. This may or may not work."
- )
-
- transformer_module = FastSentenceTransformer._create_transformer_module(
- model_name, model, tokenizer, max_seq_length, trust_remote_code
- )
- modules["0"] = transformer_module
-
- hidden_size = getattr(model.config, "hidden_size", 768)
-
- if pooling_mode == "mean":
- pooling_mode = FastSentenceTransformer._read_pooling_mode(model_name, token)
-
- modules["1"] = Pooling(
- word_embedding_dimension = hidden_size, pooling_mode = pooling_mode
- )
- modules["2"] = Normalize()
-
- return modules, True
-
- # Encoder model types that benefit from native torch.compile instead of Unsloth patching
- ENCODER_MODEL_TYPES = {
- "mpnet",
- "bert",
- "distilbert",
- "modernbert",
- "roberta",
- "xlm-roberta",
- "albert",
- "electra",
- }
-
- @staticmethod
- def _estimate_compile_threshold(
- model,
- batch_size = None,
- grad_accum = None,
- max_seq_length = None,
- ):
- """
- Estimate the minimum training steps needed for torch.compile to be beneficial.
- Returns the threshold with a 1.2x safety margin built in.
-
- Based on empirical benchmarks:
- - Larger models have lower breakeven (more time saved per step)
- - Warmup time scales with model size but speedup also increases
-
- Optional inputs (batch_size, grad_accum, max_seq_length) allow
- a coarse pre-run adjustment. These are intentionally conservative
- and avoid any runtime measurements.
- """
- # Get parameter count from inner model
- if hasattr(model, "__getitem__"):
- try:
- inner = model[0].auto_model
- params = sum(p.numel() for p in inner.parameters())
- except:
- params = 100_000_000 # Default to 100M if can't determine
- else:
- params = sum(p.numel() for p in model.parameters())
-
- model_type = None
- try:
- if "inner" in locals():
- model_type = getattr(getattr(inner, "config", None), "model_type", None)
- except Exception:
- model_type = None
- if isinstance(model_type, str):
- model_type = model_type.lower()
-
- params_m = params / 1e6
-
- # Empirical formula based on benchmarks with batch_size=2, grad_accum=4
- # Small models: high fixed overhead, lower speedup
- # Large models: warmup scales but speedup is significant
- if params_m < 50:
- estimated_warmup = 35 + params_m * 0.3
- base_speedup = 1.35
- elif params_m < 200:
- estimated_warmup = 12 + params_m * 0.03
- base_speedup = 1.75
- else:
- estimated_warmup = 15 + params_m * 0.04
- base_speedup = 1.60
-
- # Estimate time per step (ms) and time saved
- naive_ms = 50 + params_m * 1.0
- compiled_ms = naive_ms / base_speedup
- time_saved_per_step_s = (naive_ms - compiled_ms) / 1000
-
- if time_saved_per_step_s > 0:
- breakeven = estimated_warmup / time_saved_per_step_s
- else:
- breakeven = float("inf")
-
- # Return threshold with 1.2x safety margin
- threshold = breakeven * 1.2
-
- # Optional adjustment based on expected work per step.
- # This uses only pre-run information (batch size, grad accum, seq length).
- generic_scale = 1.0
- fast_scale = 1.0
- if (
- batch_size is not None
- or grad_accum is not None
- or max_seq_length is not None
- ):
- try:
- bs = int(batch_size) if batch_size is not None else 2
- ga = int(grad_accum) if grad_accum is not None else 4
- seq = int(max_seq_length) if max_seq_length is not None else 512
- except Exception:
- bs, ga, seq = 2, 4, 512
-
- bs = max(1, bs)
- ga = max(1, ga)
- # Guard against unbounded tokenizer.model_max_length
- seq = max(64, min(seq, 8192))
-
- ref_bs, ref_ga, ref_seq = 2, 4, 512
-
- # Generic path: lighter scaling, less conservative than params-only.
- ga_scale = (ref_ga / ga) ** 1.0
- bs_seq_scale = ((ref_bs * ref_seq) / (bs * seq)) ** 0.15
- generic_scale = 0.35 * ga_scale * bs_seq_scale
- generic_scale = max(0.05, min(generic_scale, 5.0))
-
- # Fast encoder path: stronger scaling based on observed behavior.
- fast_ga_scale = (ref_ga / ga) ** 1.5
- fast_bs_seq_scale = ((ref_bs * ref_seq) / (bs * seq)) ** 0.25
- fast_scale = 0.2 * fast_ga_scale * fast_bs_seq_scale
- fast_scale = max(0.05, min(fast_scale, 5.0))
-
- # Conservative safety factors: generic is less conservative than fast.
- generic_threshold = threshold * generic_scale * 1.25
-
- is_fast_type = (
- isinstance(model_type, str)
- and model_type in FastSentenceTransformer.ENCODER_MODEL_TYPES
- )
- if is_fast_type:
- fast_threshold = threshold * fast_scale * 1.5
- # Prefer the smaller (less conservative) of the two estimates.
- final_threshold = min(generic_threshold, fast_threshold)
- else:
- final_threshold = generic_threshold
-
- # Reduce mpnet overestimation slightly.
- if model_type == "mpnet":
- final_threshold *= 0.7
-
- # Lower bound to avoid compiling on extremely short runs.
- return int(max(20, final_threshold))
-
- @staticmethod
- def _apply_torch_compile(model, mode = "default"):
- """
- Apply torch.compile to a SentenceTransformer model.
- Includes workaround for accelerate's unwrap_model bug.
- """
- if hasattr(model, "__getitem__"):
- inner_model = model[0].auto_model
- compiled = torch.compile(inner_model, mode = mode)
- model[0].auto_model = compiled
- # Fix for accelerate unwrap_model bug:
- # When SentenceTransformer contains a compiled inner model,
- # accelerate checks has_compiled_regions() which returns True,
- # then tries to access model.__dict__["_orig_mod"] which fails.
- # This workaround sets _orig_mod to satisfy accelerate.
- model.__dict__["_orig_mod"] = model
- else:
- model = torch.compile(model, mode = mode)
- return model
-
- @staticmethod
- def from_pretrained(
- model_name,
- max_seq_length = None,
- dtype = None,
- load_in_4bit = False, # Changed default: 4-bit is slow for encoders
- load_in_8bit = False,
- load_in_16bit = True, # Changed default: 16-bit is optimal for encoders
- full_finetuning = False,
- token = None,
- device_map = "sequential",
- rope_scaling = None,
- fix_tokenizer = True,
- trust_remote_code = False,
- use_gradient_checkpointing = False, # Changed default: conflicts with torch.compile
- resize_model_vocab = None,
- revision = None,
- use_exact_model_name = False,
- offload_embedding = False,
- random_state = 3407,
- max_lora_rank = 64,
- disable_log_stats = True,
- qat_scheme = None,
- unsloth_tiled_mlp = False,
- pooling_mode = "mean",
- for_inference = False,
- **kwargs,
- ):
- try:
- from sentence_transformers import SentenceTransformer
- from sentence_transformers.models import Transformer, Pooling, Normalize
- except ImportError:
- raise ImportError(
- "Unsloth: To use `FastSentenceTransformer`, you must install `sentence-transformers`.\n"
- "Run `pip install sentence-transformers` to install it."
- )
-
- # if for_inference == True, skip Unsloth optimizations to avoid torch compile issues
- if for_inference:
- st_device = device_map
- if isinstance(st_device, dict) or (
- isinstance(st_device, str) and st_device in ["auto", "sequential"]
- ):
- st_device = None
-
- # this was added because when loading for inference it was defaulting to float32
- # propagate dtype to model_kwargs, default to "auto"
- model_kwargs = kwargs.get("model_kwargs", {})
- model_kwargs["dtype"] = dtype if dtype is not None else "auto"
-
- # filter kwargs for SentenceTransformer
- st_kwargs = {
- "device": st_device,
- "trust_remote_code": trust_remote_code,
- "token": token,
- "revision": revision,
- "model_kwargs": model_kwargs,
- }
-
- # add other known kwargs if present
- known_keys = [
- "cache_folder",
- "truncate_dim",
- "tokenizer_kwargs",
- "config_kwargs",
- ]
- for k in known_keys:
- if k in kwargs:
- st_kwargs[k] = kwargs[k]
-
- st_model = SentenceTransformer(model_name, **st_kwargs)
- return st_model
-
- # sanity check, thanks Etherl:
- if full_finetuning and (load_in_4bit or load_in_8bit):
- print(
- "Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA."
- )
- load_in_4bit = False
- load_in_8bit = False
- load_in_fp8 = False
- load_in_16bit = False
-
- if int(load_in_4bit) + int(load_in_8bit) + int(load_in_16bit) >= 2:
- raise RuntimeError(
- "Unsloth: Can only load in 4bit or 8bit or 16bit, not a combination!\n"
- "Also, we by default set `load_in_16bit = True`.\n"
- "If you want 4bit LoRA finetuning, set `load_in_16bit = False` and `load_in_4bit = True`\n"
- "If you want 8bit finetuning, set both `load_in_16bit = False` and `load_in_8bit = True`"
- )
-
- if "auto_model" not in kwargs:
- kwargs["auto_model"] = AutoModel
-
- transformers4 = Version(transformers.__version__).major < 5
- model_type = ""
- config = None
- try:
- config = AutoConfig.from_pretrained(
- model_name, token = token, trust_remote_code = trust_remote_code
- )
- model_type = getattr(config, "model_type", "")
- except:
- pass
-
- # Fast encoder path: Use native torch.compile for encoder models (6x speedup)
- # This bypasses Unsloth's auto-compiler which adds @torch.compiler.disable decorators
- # that interfere with torch.compile and cause runtime errors for encoder models.
- # NOTE: The old Unsloth path is BROKEN for encoder models with torch 2.9+ due to
- # conflicting @torch.compile and @torch.compiler.disable decorators.
- # Set UNSLOTH_COMPILE_DISABLE=1 to disable torch.compile and use the old path.
- is_encoder_model = (
- model_type.lower() in FastSentenceTransformer.ENCODER_MODEL_TYPES
- )
- use_fast_encoder = os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") != "1"
- if use_fast_encoder and is_encoder_model:
- # torch.compile mode: "default" is safest for PEFT/LoRA training
- # Note: "reduce-overhead" uses CUDA Graphs which is incompatible with PEFT
- compile_mode = "default"
-
- # Determine dtype - handle float16 machines that don't support bfloat16
- if dtype is None:
- if load_in_16bit:
- dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16
- else:
- dtype = torch.float32
- elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:
- print(
- "Unsloth: Device does not support bfloat16. Using float16 instead."
- )
- dtype = torch.float16
-
- # Determine device
- st_device = device_map
- if isinstance(st_device, dict) or (
- isinstance(st_device, str) and st_device in ["auto", "sequential"]
- ):
- st_device = "cuda"
-
- # Check if model supports SDPA (Scaled Dot Product Attention) for extra speedup
- supports_sdpa = False
- if config is not None:
- try:
- model_class = _get_model_class(
- config, kwargs.get("auto_model", AutoModel)._model_mapping
- )
- supports_sdpa = getattr(model_class, "_supports_sdpa", False)
- except:
- pass
-
- # Build model_kwargs for SentenceTransformer
- model_kwargs = {"torch_dtype": dtype}
-
- # Enable SDPA if supported (1.2x extra speedup on top of torch.compile)
- # But disable for models with known SDPA + torch.compile backward issues
- _force_eager = False
- for _sdpa_model in DISABLE_SDPA_MODEL_NAMES:
- if _sdpa_model in model_type.lower():
- supports_sdpa = False
- _force_eager = True
- break
- if supports_sdpa:
- model_kwargs["attn_implementation"] = "sdpa"
- elif _force_eager:
- model_kwargs["attn_implementation"] = "eager"
-
- # Print optimization status
- sdpa_str = " + SDPA" if supports_sdpa else ""
- if load_in_4bit:
- print(
- f"Unsloth: Using fast encoder path for {model_type} with 4-bit quantization{sdpa_str}"
- )
- else:
- print(
- f"Unsloth: Using fast encoder path for {model_type} (torch.compile{sdpa_str})"
- )
-
- # Handle 4-bit quantization via BitsAndBytesConfig
- if load_in_4bit:
- from transformers import BitsAndBytesConfig
-
- bnb_config = BitsAndBytesConfig(
- load_in_4bit = True,
- bnb_4bit_compute_dtype = dtype,
- bnb_4bit_quant_type = "nf4",
- bnb_4bit_use_double_quant = True,
- )
- model_kwargs["quantization_config"] = bnb_config
- # When using quantization, device must be handled by accelerate
- st_device = None
-
- # Handle gradient checkpointing - warn user it conflicts with torch.compile
- _use_gc = use_gradient_checkpointing
- if _use_gc and _use_gc != False:
- print(
- "Unsloth Warning: Gradient checkpointing is incompatible with torch.compile."
- )
- print("Disabling torch.compile to enable gradient checkpointing.")
- compile_mode = None # Disable compilation
-
- is_mpnet = "mpnet" == model_type.lower()
-
- if is_mpnet and transformers4:
- FastSentenceTransformer._patch_mpnet_v4()
- elif is_mpnet:
- FastSentenceTransformer._patch_mpnet_v5()
-
- # Load via native SentenceTransformer (bypasses Unsloth patching)
- st_model = SentenceTransformer(
- model_name,
- device = st_device,
- trust_remote_code = trust_remote_code,
- token = token,
- revision = revision,
- model_kwargs = model_kwargs,
- )
-
- # Store metadata for get_peft_model
- st_model._unsloth_fast_encoder = True
- st_model._compile_mode = compile_mode
- st_model._dtype = dtype
- st_model._load_in_4bit = load_in_4bit
- st_model.no_modules = False
-
- # Add save methods
- def _save_pretrained_merged(self, save_directory, **save_kwargs):
- self.save_pretrained(save_directory)
- tokenizer = save_kwargs.pop("tokenizer", self.tokenizer)
- if hasattr(self[0], "auto_model"):
- inner = self[0].auto_model
- # Handle compiled model
- if hasattr(inner, "_orig_mod"):
- inner = inner._orig_mod
- if hasattr(inner, "merge_and_unload"):
- merged = inner.merge_and_unload()
- merged.save_pretrained(save_directory)
- elif hasattr(inner, "save_pretrained"):
- inner.save_pretrained(save_directory)
- if tokenizer is not None:
- tokenizer.save_pretrained(save_directory)
- FastSentenceTransformer._add_unsloth_branding(save_directory)
-
- st_model.save_pretrained_merged = types.MethodType(
- _save_pretrained_merged, st_model
- )
-
- st_model.save_pretrained_torchao = types.MethodType(
- _save_pretrained_torchao, st_model
- )
-
- st_model.save_pretrained_gguf = types.MethodType(
- _save_pretrained_gguf, st_model
- )
-
- st_model.push_to_hub_gguf = types.MethodType(_push_to_hub_gguf, st_model)
-
- def _push_to_hub_merged(self, repo_id, **push_kwargs):
- hub_token = push_kwargs.get("token", None) or get_token()
- if hub_token is None:
- raise ValueError("No HF token provided")
- api = HfApi(token = hub_token)
- try:
- api.create_repo(
- repo_id = repo_id,
- private = push_kwargs.get("private"),
- exist_ok = True,
- repo_type = "model",
- )
- except:
- pass
- FastSentenceTransformer._add_unsloth_tags(repo_id, hub_token)
- with tempfile.TemporaryDirectory() as temp_dir:
- self.save_pretrained_merged(temp_dir, **push_kwargs)
- api.upload_folder(
- folder_path = temp_dir,
- repo_id = repo_id,
- commit_message = push_kwargs.get(
- "commit_message", "Upload model"
- ),
- )
- print(f"Unsloth: Pushed to https://huggingface.co/{repo_id}")
-
- st_model.push_to_hub_merged = types.MethodType(
- _push_to_hub_merged, st_model
- )
-
- return st_model
-
- # Warn if using 4-bit with encoder (slow due to dequantization overhead)
- if is_encoder_model and load_in_4bit:
- print(
- "Unsloth Warning: 4-bit quantization adds ~2.3x overhead for encoder models."
- )
- print("Consider using load_in_16bit=True for better performance.")
-
- # check if the model supports add_pooling_layer
- if "add_pooling_layer" not in kwargs:
- supported = FastSentenceTransformer._has_add_pooling_layer(
- config, kwargs.get("auto_model", AutoModel)
- )
- if supported:
- kwargs["add_pooling_layer"] = False
-
- # forces fp8 to be False since it's not supported
- fp8 = kwargs.pop("load_in_fp8", None)
- if fp8:
- logging.info("Unsloth: Disabling fp8 for model")
- load_in_fp8 = False
-
- # this is a fix for Snowflake/snowflake-arctic-embed-l-v2.0
- # it has pooler weights which we don't care about for training,
- # however unsloth throws an exception if "UNSLOTH_WARN_UNINITIALIZED" == 1 and it sees unused weights
- old_environ = os.environ.get("UNSLOTH_WARN_UNINITIALIZED", "1")
- os.environ["UNSLOTH_WARN_UNINITIALIZED"] = "0"
-
- is_distilbert = "distilbert" == model_type.lower()
- is_mpnet = "mpnet" == model_type.lower()
-
- if is_distilbert and transformers4:
- FastSentenceTransformer._patch_distilbert_v4()
- elif is_distilbert:
- FastSentenceTransformer._patch_distilbert_v5()
- elif is_mpnet and transformers4:
- FastSentenceTransformer._patch_mpnet_v4()
- elif is_mpnet:
- FastSentenceTransformer._patch_mpnet_v5()
-
- # check if modules.json exists - if not, force 16-bit training
- # why? because i have to implement saving myself for these models, and i don't feel like adding dequantization
- # to the save_pretrained_merged for a model that really should be trained in 16-bit anyway
- has_modules_json = (
- FastSentenceTransformer._module_path(model_name, token) is not None
- )
-
- if not has_modules_json and load_in_4bit:
- print(
- "Unsloth: No modules.json found. This is not a sentence-transformers model.\n"
- "Forcing 16-bit loading to simplify merged model saving."
- )
- load_in_4bit = False
- load_in_16bit = True
-
- try:
- model, tokenizer = FastModel.from_pretrained(
- model_name = model_name,
- max_seq_length = max_seq_length,
- dtype = dtype,
- load_in_4bit = load_in_4bit,
- load_in_8bit = load_in_8bit,
- load_in_16bit = load_in_16bit,
- full_finetuning = full_finetuning,
- token = token,
- device_map = device_map,
- rope_scaling = rope_scaling,
- fix_tokenizer = fix_tokenizer,
- trust_remote_code = trust_remote_code,
- use_gradient_checkpointing = use_gradient_checkpointing,
- resize_model_vocab = resize_model_vocab,
- revision = revision,
- return_logits = False,
- use_exact_model_name = use_exact_model_name,
- offload_embedding = offload_embedding,
- random_state = random_state,
- max_lora_rank = max_lora_rank,
- disable_log_stats = disable_log_stats,
- qat_scheme = qat_scheme,
- load_in_fp8 = load_in_fp8,
- unsloth_tiled_mlp = unsloth_tiled_mlp,
- **kwargs,
- )
- finally:
- os.environ["UNSLOTH_WARN_UNINITIALIZED"] = old_environ
-
- # try to load modules, otherwise fallback to old hard-coded modules
- from sentence_transformers import SentenceTransformer
-
- modules, no_modules = FastSentenceTransformer._load_modules(
- model_name,
- token,
- model,
- tokenizer,
- max_seq_length,
- pooling_mode,
- trust_remote_code = trust_remote_code,
- )
-
- st_device = device_map
- if isinstance(st_device, dict) or (
- isinstance(st_device, str) and st_device in ["auto", "sequential"]
- ):
- st_device = None
-
- st_model = SentenceTransformer(modules = modules, device = st_device)
- st_model.no_modules = no_modules
-
- def _save_pretrained_merged(self, save_directory, **kwargs):
- # check which adapter files exist before save_pretrained
- adapter_files = ["adapter_model.safetensors", "adapter_config.json"]
- existing_before = {
- f
- for f in adapter_files
- if os.path.exists(os.path.join(save_directory, f))
- }
-
- # sentence-transformers config and modules only get saved if we call save_pretrained
- self.save_pretrained(save_directory)
-
- # remove LoRA adapters only if they were created by save_pretrained (not pre-existing)
- for file in adapter_files:
- if file not in existing_before:
- try:
- os.remove(os.path.join(save_directory, file))
- except:
- pass
-
- tokenizer = kwargs.pop("tokenizer", self.tokenizer)
- if self.no_modules:
- # fallback for non-sentence-transformers models
- print(
- "Unsloth: No modules detected. Using standard merge_and_unload for saving..."
- )
- safe_kwargs = kwargs.copy()
- # filter out Unsloth-specific args that are not in huggingface's save_pretrained
- unsloth_args = [
- "save_method",
- "temporary_location",
- "maximum_memory_usage",
- ]
- for k in unsloth_args:
- safe_kwargs.pop(k, None)
-
- merged_model = self[0].auto_model.merge_and_unload()
- merged_model.save_pretrained(save_directory, **safe_kwargs)
- if tokenizer is not None:
- tokenizer.save_pretrained(save_directory)
- else:
- self[0].auto_model.save_pretrained_merged(
- save_directory, tokenizer = tokenizer, **kwargs
- )
-
- # add Unsloth branding to the generated README
- try:
- FastSentenceTransformer._add_unsloth_branding(save_directory)
- except Exception as e:
- print(f"Unsloth Warning: Failed to add branding to README: {e}")
-
- st_model.save_pretrained_merged = types.MethodType(
- _save_pretrained_merged, st_model
- )
-
- st_model.save_pretrained_torchao = types.MethodType(
- _save_pretrained_torchao, st_model
- )
-
- st_model.save_pretrained_gguf = types.MethodType(
- _save_pretrained_gguf, st_model
- )
-
- st_model.push_to_hub_gguf = types.MethodType(_push_to_hub_gguf, st_model)
-
- def _push_to_hub_merged(self, repo_id, **kwargs):
- token = kwargs.get("token", None) or get_token()
- if token is None:
- raise ValueError(
- "No HF token provided. Please provide a token or login with `hf auth login`"
- )
- private = kwargs.get("private", None)
- commit_message = kwargs.get("commit_message", "Upload model")
-
- from huggingface_hub import HfApi
-
- api = HfApi(token = token)
- try:
- api.create_repo(
- repo_id = repo_id,
- private = private,
- exist_ok = True,
- repo_type = "model",
- )
- except:
- pass
-
- # order doesn't seem to matter for this after repo creation...
- FastSentenceTransformer._add_unsloth_tags(repo_id, token)
-
- with tempfile.TemporaryDirectory() as temp_dir:
- self.save_pretrained_merged(temp_dir, **kwargs)
- api.upload_folder(
- folder_path = temp_dir,
- repo_id = repo_id,
- commit_message = commit_message,
- )
- print(
- f"Unsloth: Successfully pushed merged model to https://huggingface.co/{repo_id}"
- )
-
- st_model.push_to_hub_merged = types.MethodType(_push_to_hub_merged, st_model)
- return st_model
-
- @staticmethod
- def get_peft_model(
- model,
- r = 16,
- target_modules = [
- "query",
- "key",
- "value",
- "dense",
- ],
- lora_alpha = 16,
- lora_dropout = 0.0,
- bias = "none",
- layers_to_transform = None,
- layers_pattern = None,
- use_gradient_checkpointing = False, # Changed default: conflicts with torch.compile
- random_state = 3407,
- max_seq_length = 2048,
- use_rslora = False,
- modules_to_save = None,
- init_lora_weights = True,
- loftq_config = {},
- **kwargs,
- ):
- from sentence_transformers import SentenceTransformer
- from peft import LoraConfig, get_peft_model as peft_get_peft_model
-
- if "task_type" not in kwargs:
- kwargs["task_type"] = "FEATURE_EXTRACTION"
- print("Setting task_type to FEATURE_EXTRACTION")
-
- if isinstance(model, SentenceTransformer):
- # Check if this is a fast encoder model (uses torch.compile instead of Unsloth patching)
- is_fast_encoder = getattr(model, "_unsloth_fast_encoder", False)
-
- if is_fast_encoder:
- # Fast encoder path: Use native PEFT + torch.compile (6x speedup)
- transformer_module = model[0]
- inner_model = transformer_module.auto_model
-
- # Check if model is quantized (4-bit/8-bit)
- is_quantized = (
- getattr(inner_model, "is_quantized", False)
- or getattr(inner_model.config, "quantization_config", None)
- is not None
- )
-
- # Track if gradient checkpointing was actually enabled
- gc_enabled = False
-
- # this is needed when from_pretrained was called without gradient
- # checkpointing but get_peft_model requests it
- if use_gradient_checkpointing and use_gradient_checkpointing != False:
- import transformers
- from packaging.version import Version
-
- transformers4 = Version(transformers.__version__).major < 5
- model_type = getattr(inner_model.config, "model_type", "").lower()
-
- if model_type == "mpnet" and transformers4:
- FastSentenceTransformer._patch_mpnet_v4()
- elif model_type == "mpnet":
- FastSentenceTransformer._patch_mpnet_v5()
-
- # Prepare for k-bit training if quantized
- if is_quantized:
- from ._utils import prepare_model_for_kbit_training
-
- _gc_for_kbit = (
- use_gradient_checkpointing
- if use_gradient_checkpointing
- else False
- )
- try:
- inner_model = prepare_model_for_kbit_training(
- inner_model,
- use_gradient_checkpointing = _gc_for_kbit,
- )
- print("Unsloth: Prepared quantized model for k-bit training")
- gc_enabled = bool(_gc_for_kbit)
- except ValueError as e:
- if "does not support gradient checkpointing" in str(e):
- # Model doesn't support gradient checkpointing, disable it
- print(
- f"Unsloth Warning: {inner_model.__class__.__name__} does not support gradient checkpointing. Skipping."
- )
- inner_model = prepare_model_for_kbit_training(
- inner_model,
- use_gradient_checkpointing = False,
- )
- print(
- "Unsloth: Prepared quantized model for k-bit training (without gradient checkpointing)"
- )
- else:
- raise
-
- # Enable gradient checkpointing if requested (only for non-quantized, since prepare_model handles it)
- elif use_gradient_checkpointing and use_gradient_checkpointing != False:
- if hasattr(inner_model, "gradient_checkpointing_enable"):
- try:
- inner_model.gradient_checkpointing_enable()
- print("Unsloth: Enabled gradient checkpointing")
- gc_enabled = True
- except ValueError as e:
- if "does not support gradient checkpointing" in str(e):
- print(
- f"Unsloth Warning: {inner_model.__class__.__name__} does not support gradient checkpointing. Skipping."
- )
-
- # Create LoRA config
- lora_config = LoraConfig(
- r = r,
- lora_alpha = lora_alpha,
- target_modules = target_modules,
- lora_dropout = lora_dropout,
- bias = bias,
- task_type = kwargs.get("task_type", "FEATURE_EXTRACTION"),
- )
-
- # Apply PEFT directly (not through FastModel)
- peft_model = peft_get_peft_model(inner_model, lora_config)
-
- # Apply QAT if specified
- qat_scheme = kwargs.get("qat_scheme", None)
- if qat_scheme is not None:
- from ._utils import _prepare_model_for_qat
-
- peft_model = _prepare_model_for_qat(peft_model, qat_scheme)
-
- # Determine compile mode (only if not using gradient checkpointing)
- compile_mode = getattr(model, "_compile_mode", "default")
- # Re-enable torch.compile if gradient checkpointing was requested but couldn't be enabled
- if compile_mode is None and not gc_enabled:
- compile_mode = "default"
- print(
- "Unsloth: Re-enabling torch.compile since gradient checkpointing is not supported"
- )
-
- # Re-assign the peft model back to the transformer module
- transformer_module.auto_model = peft_model
-
- # Store compile info for auto-compile at trainer time
- # torch.compile is deferred until training starts so we can check max_steps
- if compile_mode is not None:
- model._compile_mode = compile_mode
- model._compile_threshold = (
- FastSentenceTransformer._estimate_compile_threshold(model)
- )
- # Flag to indicate compile has not been applied yet
- model._compile_pending = True
- print(
- f"Unsloth: torch.compile will be applied automatically if max_steps > {model._compile_threshold}"
- )
- else:
- model._compile_mode = None
- model._compile_pending = False
- print(
- "Unsloth: torch.compile disabled (gradient checkpointing enabled)"
- )
-
- return model
-
- # Original path for non-fast-encoder models
- transformer_module = model[0]
- inner_model = transformer_module.auto_model
-
- peft_model = FastModel.get_peft_model(
- model = inner_model,
- r = r,
- target_modules = target_modules,
- lora_alpha = lora_alpha,
- lora_dropout = lora_dropout,
- bias = bias,
- layers_to_transform = layers_to_transform,
- layers_pattern = layers_pattern,
- use_gradient_checkpointing = use_gradient_checkpointing,
- random_state = random_state,
- max_seq_length = max_seq_length,
- use_rslora = use_rslora,
- modules_to_save = modules_to_save,
- init_lora_weights = init_lora_weights,
- loftq_config = loftq_config,
- **kwargs,
- )
-
- # re-assign the peft model back to the transformer module
- transformer_module.auto_model = peft_model
- return model
- else:
- return FastModel.get_peft_model(
- model = model,
- r = r,
- target_modules = target_modules,
- lora_alpha = lora_alpha,
- lora_dropout = lora_dropout,
- bias = bias,
- layers_to_transform = layers_to_transform,
- layers_pattern = layers_pattern,
- use_gradient_checkpointing = use_gradient_checkpointing,
- random_state = random_state,
- max_seq_length = max_seq_length,
- use_rslora = use_rslora,
- modules_to_save = modules_to_save,
- init_lora_weights = init_lora_weights,
- loftq_config = loftq_config,
- **kwargs,
- )
-
-
-def _patch_sentence_transformer_trainer():
- """
- Patch SentenceTransformerTrainer to automatically apply torch.compile
- when training steps exceed the breakeven threshold.
-
- This is called automatically when this module is imported.
- """
- try:
- from sentence_transformers import SentenceTransformerTrainer
- except ImportError:
- return # sentence_transformers not installed
-
- if getattr(SentenceTransformerTrainer, "_unsloth_auto_compile_patched", False):
- return # Already patched
-
- from functools import wraps
-
- _original_init = SentenceTransformerTrainer.__init__
-
- @wraps(_original_init)
- def _patched_init(self, *args, **kwargs):
- # Extract model and training_args
- model = kwargs.get("model") or (args[0] if args else None)
- training_args = kwargs.get("args") or (args[1] if len(args) > 1 else None)
-
- # Check if model has pending compile
- if (
- model is not None
- and training_args is not None
- and getattr(model, "_compile_pending", False)
- ):
- max_steps = getattr(training_args, "max_steps", -1)
- compile_mode = getattr(model, "_compile_mode", "default")
-
- # Re-estimate threshold now that training args are available
- batch_size = getattr(training_args, "per_device_train_batch_size", None)
- grad_accum = getattr(training_args, "gradient_accumulation_steps", None)
- max_seq_length = getattr(model, "max_seq_length", None)
- if max_seq_length is None and hasattr(model, "__getitem__"):
- try:
- max_seq_length = getattr(model[0], "max_seq_length", None)
- except Exception:
- max_seq_length = None
- if max_seq_length is None:
- tokenizer = getattr(model, "tokenizer", None)
- max_seq_length = (
- getattr(tokenizer, "model_max_length", None)
- if tokenizer is not None
- else None
- )
-
- threshold = FastSentenceTransformer._estimate_compile_threshold(
- model,
- batch_size = batch_size,
- grad_accum = grad_accum,
- max_seq_length = max_seq_length,
- )
- model._compile_threshold = threshold
-
- if max_steps > 0 and max_steps >= threshold:
- print(
- f"Unsloth: Auto-compiling model ({max_steps} steps >= {threshold} threshold)"
- )
- FastSentenceTransformer._apply_torch_compile(model, mode = compile_mode)
- model._compile_pending = False
- elif max_steps > 0:
- print(
- f"Unsloth: Skipping torch.compile ({max_steps} steps < {threshold} threshold)"
- )
- model._compile_pending = False
-
- # Call original __init__
- _original_init(self, *args, **kwargs)
-
- # Disable mixed precision when FORCE_FLOAT32 is active (matches rl.py behavior)
- if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
- if hasattr(self, "args") and self.args is not None:
- if self.args.fp16 or self.args.bf16:
- print(
- "Unsloth: Switching to float32 training since model cannot work with float16"
- )
- self.args.fp16 = False
- self.args.bf16 = False
- if hasattr(self.args, "bf16_full_eval"):
- self.args.bf16_full_eval = False
- if hasattr(self.args, "fp16_full_eval"):
- self.args.fp16_full_eval = False
-
- SentenceTransformerTrainer.__init__ = _patched_init
- SentenceTransformerTrainer._unsloth_auto_compile_patched = True
-
-
-# Auto-patch trainer on module import
-_patch_sentence_transformer_trainer()
diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py
index a8adba99e7..24015f82fe 100644
--- a/unsloth/models/vision.py
+++ b/unsloth/models/vision.py
@@ -19,27 +19,18 @@
AutoTokenizer,
AutoModelForCausalLM,
)
-
try:
from transformers import AutoModelForImageTextToText
-
AutoModelForVision2Seq = AutoModelForImageTextToText
except:
from transformers import AutoModelForVision2Seq
+pass
from ..kernels import (
post_patch_loss_function,
)
-from ._utils import __version__, importlib_version, _prepare_model_for_qat
+from ._utils import __version__
from ._utils import *
-from .loader_utils import _get_fp8_mode_and_check_settings
from ..save import patch_saving_functions
-from ..models.loader_utils import is_distributed
-from unsloth_zoo.gradient_checkpointing import (
- unpatch_unsloth_gradient_checkpointing,
- unpatch_unsloth_smart_gradient_checkpointing,
-)
-import torch.utils.checkpoint as torch_checkpoint
-import transformers.modeling_utils as hf_modeling_utils
from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model
from peft import PeftModelForCausalLM
from transformers import set_seed as transformers_set_seed
@@ -52,119 +43,48 @@
from transformers import __version__ as transformers_version
from triton import __version__ as triton_version
from unsloth_zoo.utils import _get_dtype
-from unsloth_zoo.hf_utils import (
- dtype_from_config,
- add_dtype_kwargs,
- fix_lora_auto_mapping,
- get_auto_processor,
-)
from unsloth_zoo.patching_utils import patch_model_and_tokenizer
from unsloth_zoo.training_utils import prepare_model_for_training
-
-from unsloth_zoo.utils import Version
-from transformers import __version__ as transformers_version
-
import types
import functools
import os
import gc
import math
+import functools
from typing import Optional, Tuple, List, Union
import re, inspect, sys
-import contextlib
-
+import types
try:
from huggingface_hub.utils import get_token
except:
# Old HF Hub versions <= 0.0.25
from huggingface_hub.utils._token import get_token
-from ..device_type import (
- is_hip,
- get_device_type,
- DEVICE_TYPE,
- DEVICE_TYPE_TORCH,
- DEVICE_COUNT,
- ALLOW_PREQUANTIZED_MODELS,
-)
+pass
__all__ = [
"FastBaseModel",
]
-global NUM_LOGITS_TO_KEEP
-NUM_LOGITS_TO_KEEP = dict()
-
-VLLM_SUPPORTED_VLM = [
- "qwen2_5_vl",
+global FORCE_FLOAT32
+FORCE_FLOAT32 = [
"gemma3",
- "mistral3",
- "qwen3_vl",
- "qwen3_vl_moe",
-]
-VLLM_NON_LORA_VLM = [
- "mllama",
-]
-PRE_COMPILE_INFERENCE = [
- "gpt_oss",
]
-from transformers import GenerationConfig, CompileConfig, AutoConfig
-
-try:
- from transformers import PreTrainedConfig
-
- PretrainedConfig = PreTrainedConfig
-except:
- from transformers import PretrainedConfig
-
-HAS_TORCH_DTYPE = "torch_dtype" in PretrainedConfig.__doc__
-
-_compile_config = CompileConfig(
- fullgraph = False,
- dynamic = None,
- mode = "reduce-overhead",
-)
-_compile_config.disable = True # Must set manually
-
-try:
- torch_compiler_set_stance = torch.compiler.set_stance
-except:
- torch_compiler_set_stance = None
+global FORCE_EAGER_ATTENTION
+FORCE_EAGER_ATTENTION = [
+ "pixtral", # Pixtral SDPA not implemented
+]
+global NUM_LOGITS_TO_KEEP
+NUM_LOGITS_TO_KEEP = dict()
def unsloth_base_fast_generate(
self,
*args,
**kwargs,
):
- if len(args) != 0:
- input_ids = args[0]
- elif "input_ids" in kwargs:
- input_ids = kwargs["input_ids"]
- elif "input" in kwargs:
- input_ids = kwargs["input"]
- elif "input_features" in kwargs:
- input_ids = kwargs["input_features"]
- elif "input_embeds" in kwargs:
- input_ids = kwargs["input_embeds"]
- elif "inputs" in kwargs:
- input_ids = kwargs["inputs"]
- else:
- key = next(iter(kwargs.keys()))
- if type(kwargs[key]) is not torch.Tensor:
- raise TypeError("Unsloth: You need to pass in input_ids to .generate!")
- input_ids = kwargs[key]
- assert type(input_ids) is torch.Tensor
- bsz = input_ids.shape[0]
-
FastBaseModel.for_inference(self)
- dtype = _get_dtype(dtype_from_config(self.config))
- # Handle full float32 cases as config.dtype == torch.float32!
- do_bfloat16_mixed_precision = (
- os.environ.get("UNSLOTH_BFLOAT16_MIXED_PRECISION", "0") == "1"
- )
- if do_bfloat16_mixed_precision:
- dtype = torch.bfloat16
+ dtype = _get_dtype(self.config.torch_dtype)
# Check if VLM
is_vlm = any(
@@ -174,35 +94,38 @@ def unsloth_base_fast_generate(
is_vlm = is_vlm or hasattr(self.config, "vision_config")
arch = self.config.architectures[0]
- # Remove token_type_ids - WRONG for Gemma 3 since bidirectional attention
- if hasattr(self, "generate") and hasattr(self, "forward"):
- # did not combine with below since self might not have model
- keys = inspect.signature(self.forward).parameters.keys()
- if "token_type_ids" not in keys:
- kwargs.pop("token_type_ids", None)
- # kwargs.pop("token_type_ids", None)
+ # Remove token_type_ids
+ kwargs.pop("token_type_ids", None)
# VLMs do not allow logits_to_keep
- global NUM_LOGITS_TO_KEEP
- if arch not in NUM_LOGITS_TO_KEEP:
- m = self
- # Find which is needed ie
- # num_logits_to_keep or logits_to_keep
- while hasattr(m, "model"):
- if hasattr(m, "forward"):
- keys = inspect.signature(m.forward).parameters.keys()
- if "num_logits_to_keep" in keys:
- NUM_LOGITS_TO_KEEP[arch] = "num_logits_to_keep"
- break
- elif "logits_to_keep" in keys:
- NUM_LOGITS_TO_KEEP[arch] = "logits_to_keep"
- break
- m = m.model
+ if not is_vlm:
+ global NUM_LOGITS_TO_KEEP
if arch not in NUM_LOGITS_TO_KEEP:
- NUM_LOGITS_TO_KEEP[arch] = None
- key = NUM_LOGITS_TO_KEEP[arch]
- if key is not None and key not in kwargs:
- kwargs[key] = 1
+ m = self
+ # Find which is needed ie
+ # num_logits_to_keep or logits_to_keep
+ while hasattr(m, "model"):
+ if hasattr(m, "forward"):
+ keys = inspect.signature(m.forward).parameters.keys()
+ if "num_logits_to_keep" in keys:
+ NUM_LOGITS_TO_KEEP[arch] = "num_logits_to_keep"
+ break
+ elif "logits_to_keep" in keys:
+ NUM_LOGITS_TO_KEEP[arch] = "logits_to_keep"
+ break
+ m = m.model
+ pass
+ if arch not in NUM_LOGITS_TO_KEEP:
+ NUM_LOGITS_TO_KEEP[arch] = None
+ pass
+ pass
+ key = NUM_LOGITS_TO_KEEP[arch]
+ if key is not None and key not in kwargs:
+ kwargs[key] = 1
+ else:
+ pass
+ # kwargs.pop("logits_to_keep", None)
+ # kwargs.pop("num_logits_to_keep", None)
# Check pad_token
model_eos_token_id = getattr(self.config, "eos_token_id", None)
@@ -212,905 +135,221 @@ def unsloth_base_fast_generate(
kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id)
# Get pixel values for VLMs
- try:
- kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype)
- except:
- pass
+ try: kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype)
+ except: pass
# Mixed precision autocast
- if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
- autocaster = torch.autocast(device_type = DEVICE_TYPE_TORCH, dtype = torch.float16)
- dtype = torch.float16
- else:
- autocaster = torch.autocast(device_type = DEVICE_TYPE_TORCH, dtype = dtype)
- # Prepare LoRA
- # state_dict = convert_lora_modules(self, dtype = dtype)
-
- # Set compile dynamic shapes
- torch._dynamo.mark_static(input_ids, 0)
- torch._dynamo.mark_dynamic(input_ids, 1)
- if "attention_mask" in kwargs:
- torch._dynamo.mark_static(kwargs["attention_mask"], 0)
- torch._dynamo.mark_dynamic(kwargs["attention_mask"], 1)
- if "token_type_ids" in kwargs:
- torch._dynamo.mark_static(kwargs["token_type_ids"], 0)
- torch._dynamo.mark_dynamic(kwargs["token_type_ids"], 1)
-
- # Fix generation_config
- # Use hybrid if sliding window seen, otherwise try static
- cache_implementation = getattr(self.config, "cache_implementation", None)
- if getattr(
- self, "_supports_static_cache", getattr(self, "_can_compile_fullgraph", True)
- ):
- if os.environ.get("UNSLOTH_DISABLE_STATIC_GENERATION", "0") == "0":
- cache_implementation = "static"
- elif Version(transformers_version) < Version("4.56.0.dev0"):
- cache_implementation = None
- else:
- # Should work in latest transformers!
- cache_implementation = "static"
- else:
- cache_implementation = None
- if cache_implementation is not None:
- swa = getattr(
- getattr(self.config, "text_config", self.config), "sliding_window", None
- )
- if (swa == 0 or type(swa) is not int) and (
- getattr(self, "_can_compile_fullgraph", True) is True
- ):
- cache_implementation = "static"
- else:
- if Version(transformers_version) < Version("4.56.0.dev0"):
- cache_implementation = "hybrid"
- else:
- cache_implementation = "static"
- # [TODO] Unsure why static fails
- if do_bfloat16_mixed_precision:
- cache_implementation = None
-
- if "generation_config" in kwargs:
- kwargs["generation_config"].cache_implementation = cache_implementation
- if cache_implementation is not None:
- kwargs["generation_config"].compile_config = _compile_config
- else:
- kwargs["cache_implementation"] = cache_implementation
- if cache_implementation is not None:
- kwargs["compile_config"] = _compile_config
-
- # Delete cached Flex Attention masks to reset inference
- for name, module in self.named_modules():
- if hasattr(module, "_flex_attention_cache"):
- try:
- del module._flex_attention_cache
- except:
- pass
- # Solves AttributeError: 'SlidingWindowLayer' object has no attribute 'max_batch_size'
- if hasattr(module, "_cache") and "cache_utils" in str(module._cache.__class__):
- try:
- del module._cache
- except:
- pass
-
- # DO INFERENCE
- with torch.inference_mode(), autocaster:
+ if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": dtype = torch.float32
+ with torch.inference_mode(), torch.autocast(device_type = "cuda", dtype = dtype):
output = self._old_generate(*args, **kwargs)
+ pass
- # Delete cached Flex Attention masks to reset inference
- for name, module in self.named_modules():
- if hasattr(module, "_flex_attention_cache"):
- try:
- del module._flex_attention_cache
- except:
- pass
- # Solves AttributeError: 'SlidingWindowLayer' object has no attribute 'max_batch_size'
- if hasattr(module, "_cache") and "cache_utils" in str(module._cache.__class__):
- try:
- del module._cache
- except:
- pass
-
- # FastBaseModel.for_training(self)
+ FastBaseModel.for_training(self)
return output
-
-
-def _construct_vlm_processor_fallback(
- tokenizer_name, model_type, token, trust_remote_code
-):
- """Construct a VLM processor manually when AutoProcessor.from_pretrained fails.
-
- Some VLMs (e.g., LFM2.5-VL) have tokenizer_class entries that AutoTokenizer
- cannot resolve. This function loads the image processor and tokenizer separately,
- sets required special token attributes, and constructs the processor.
- """
- try:
- from transformers import AutoImageProcessor, PreTrainedTokenizerFast, AutoConfig
- from transformers.models.auto.processing_auto import PROCESSOR_MAPPING_NAMES
- import json
-
- # Load image processor
- image_processor = AutoImageProcessor.from_pretrained(
- tokenizer_name,
- token = token,
- trust_remote_code = trust_remote_code,
- )
- # Load tokenizer via PreTrainedTokenizerFast (bypasses tokenizer_class check)
- tok = PreTrainedTokenizerFast.from_pretrained(
- tokenizer_name,
- padding_side = "left",
- token = token,
- trust_remote_code = trust_remote_code,
- )
- # Read tokenizer_config.json for model-specific special tokens
- try:
- from huggingface_hub import hf_hub_download
-
- config_path = hf_hub_download(
- tokenizer_name, "tokenizer_config.json", token = token
- )
- with open(config_path, "r", encoding = "utf-8") as f:
- tok_config = json.load(f)
- # Set model-specific special tokens and their IDs
- for key in (
- "image_token",
- "image_start_token",
- "image_end_token",
- "image_thumbnail",
- "video_token",
- ):
- if key in tok_config and not hasattr(tok, key):
- setattr(tok, key, tok_config[key])
- id_key = key + "_id" if not key.endswith("_id") else key
- token_id = tok.convert_tokens_to_ids(tok_config[key])
- if not hasattr(tok, id_key):
- setattr(tok, id_key, token_id)
- except Exception:
- pass
-
- # Find the processor class - try model_type first, then top-level config model_type
- proc_class_name = PROCESSOR_MAPPING_NAMES.get(model_type)
- if proc_class_name is None:
- # model_type might be a sub-model type (e.g. "lfm2" instead of "lfm2_vl").
- # Try the top-level config.model_type which often has the processor mapping.
- try:
- config = AutoConfig.from_pretrained(
- tokenizer_name,
- token = token,
- trust_remote_code = trust_remote_code,
- )
- proc_class_name = PROCESSOR_MAPPING_NAMES.get(config.model_type)
- except Exception:
- pass
-
- if proc_class_name is not None:
- import transformers
-
- proc_class = getattr(transformers, proc_class_name, None)
- if proc_class is not None:
- processor = proc_class(image_processor = image_processor, tokenizer = tok)
- # Copy chat_template from tokenizer to processor if needed
- if not getattr(processor, "chat_template", None) and getattr(
- tok, "chat_template", None
- ):
- processor.chat_template = tok.chat_template
- return processor
- except Exception:
- pass
- return None
+pass
class FastBaseModel:
+
@staticmethod
def from_pretrained(
- model_name = "unsloth/Llama-3.2-1B-Instruct",
- max_seq_length = 2048,
- dtype = None,
- load_in_4bit = True,
- load_in_8bit = False,
- load_in_16bit = False,
- full_finetuning = False,
- token = None,
- device_map = "sequential",
+ model_name = "unsloth/Llama-3.2-1B-Instruct",
+ max_seq_length = 2048,
+ dtype = None,
+ load_in_4bit = True,
+ load_in_8bit = False,
+ full_finetuning = False,
+ token = None,
+ device_map = "sequential",
trust_remote_code = False,
- model_types = None,
- tokenizer_name = None,
- auto_model = AutoModelForVision2Seq,
+ model_types = None,
+ tokenizer_name = None,
+ auto_model = AutoModelForVision2Seq,
use_gradient_checkpointing = "unsloth",
- supports_sdpa = True,
- whisper_language = None,
- whisper_task = None,
- auto_config = None,
- offload_embedding = False,
- float32_mixed_precision = None, # Forces float32 mixed precision
- # vLLM parameters
- fast_inference = False,
- gpu_memory_utilization = 0.5,
- float8_kv_cache = False,
- random_state = 3407,
- max_lora_rank = 64,
- disable_log_stats = False,
- unsloth_vllm_standby = False,
- load_in_fp8 = False, # fp8 LoRA (True, False, 'block')
**kwargs,
):
- if unsloth_vllm_standby and os.environ.get("UNSLOTH_VLLM_STANDBY", "0") != "1":
- raise RuntimeError(
- "Unsloth: UNSLOTH_VLLM_STANDBY is True, but UNSLOTH_VLLM_STANDBY is not set to 1!"
- )
-
- if model_types is None:
- raise RuntimeError(
- "Unsloth: Please use FastModel or FastVisionModel and not use FastBaseModel directly!"
- )
- if os.environ.get("UNSLOTH_MODEL_NAME", "") == "":
- os.environ["UNSLOTH_MODEL_NAME"] = model_name.lower()
-
- is_vlm = auto_model in [AutoModelForVision2Seq, AutoModelForImageTextToText]
- is_whisper = whisper_language is not None and whisper_task is not None
- auto_processor = AutoProcessor if (is_vlm or is_whisper) else AutoTokenizer
-
- model_type_arch = model_types[0]
- if model_type_arch == "siglip":
- for model_type_arch in model_types:
- if model_type_arch != "siglip":
- break
-
- vllm_enable_lora = True
-
- if is_vlm and fast_inference:
- if not any(arch in VLLM_SUPPORTED_VLM for arch in model_types):
- raise RuntimeError(
- f"Unsloth: Fast inference is only supported for Language models and Qwen2.5-VL, Gemma3 among vision models. "
- f"Found architectures: {', '.join(model_types)}!"
- )
-
- if any(arch in VLLM_NON_LORA_VLM for arch in model_types):
- # mllama is still only in vllm v0 https://arc.net/l/quote/llwkfgmu
- # https://docs.vllm.ai/en/stable/models/supported_models.html#text-generation_1
- # vLLM V0 does not support LoRA on multi modal models.
- # TODO: Update this once vLLM V1 supports Llama 3.2 aka mllama
- vllm_enable_lora = False
-
os.environ["UNSLOTH_USE_NEW_MODEL"] = "1"
if trust_remote_code:
print(
- "Unsloth: WARNING `trust_remote_code` is True.\n"
+ "Unsloth: WARNING `trust_remote_code` is True.\n"\
"Are you certain you want to do remote code execution?"
)
- token = hf_login(token)
+ pass
+ if token is None: token = get_token()
SUPPORTS_BFLOAT16 = is_bfloat16_supported()
-
- if DEVICE_TYPE == "cuda":
- gpu_stats = torch.cuda.get_device_properties(0)
- gpu_stats_name = (
- gpu_stats.name + ". " if gpu_stats.name != "" else "NVIDIA GPU Device. "
- )
- gpu_version = torch.version.cuda
- gpu_stats_snippet = f"CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {gpu_version}."
- try:
- vllm_version = f" vLLM: {importlib_version('vllm')}."
- except:
- vllm_version = ""
- elif DEVICE_TYPE == "hip":
- gpu_stats = torch.cuda.get_device_properties(0)
- gpu_stats_name = resolve_hip_gpu_stats_name(gpu_stats)
- gpu_version = torch.version.hip
- gpu_stats_snippet = f"ROCm Toolkit: {gpu_version}."
- try:
- vllm_version = f" vLLM: {importlib_version('vllm')}."
- except:
- vllm_version = ""
- elif DEVICE_TYPE == "xpu":
- gpu_stats = torch.xpu.get_device_properties(0)
- gpu_stats_name = (
- gpu_stats.name + ". " if gpu_stats.name != "" else "Intel XPU Device. "
- )
- gpu_version = torch.version.xpu
- gpu_stats_snippet = f"Intel Toolkit: {gpu_version}."
- # [TODO] After adding vLLM support for XPU, change this
- vllm_version = ""
- else:
- raise ValueError(f"Unsloth: Unsupported device type: {DEVICE_TYPE}")
-
+ gpu_stats = torch.cuda.get_device_properties(0)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
- arch_name = model_type_arch.title()
- arch_name = arch_name.replace("_Vl_", "_VL_").replace("_Moe", "_MoE")
- statistics = (
- f"==((====))== Unsloth {__version__}: Fast {arch_name} patching. Transformers: {transformers_version}.{vllm_version}\n"
- f" {chr(92)}{chr(92)} /| {gpu_stats_name}Num GPUs = {DEVICE_COUNT}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"
- f"O^O/ {chr(92)}_/ {chr(92)} Torch: {torch.__version__}. {gpu_stats_snippet} Triton: {triton_version}\n"
- f"{chr(92)} / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"
- f' "-____-" Free license: http://github.com/unslothai/unsloth'
- )
+ from importlib.metadata import version as importlib_version
+ try: vllm_version = f" vLLM: {importlib_version('vllm')}."
+ except: vllm_version = ""
+ model_type_arch = model_types[0]
+ if model_type_arch == "siglip" and len(model_types) != 1:
+ model_type_arch = model_types[1]
+
+ statistics = \
+ f"==((====))== Unsloth {__version__}: Fast {model_type_arch.title()} patching. Transformers: {transformers_version}.{vllm_version}\n"\
+ f" {chr(92)}{chr(92)} /| {gpu_stats.name}. Num GPUs = {torch.cuda.device_count()}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\
+ f"O^O/ {chr(92)}_/ {chr(92)} Torch: {torch.__version__}. CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}. Triton: {triton_version}\n"\
+ f"{chr(92)} / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\
+ f' "-____-" Free license: http://github.com/unslothai/unsloth'
print(statistics)
# Warn about fast transfers
- if "HF_HUB_ENABLE_HF_TRANSFER" in os.environ:
- old_hf_transfer = os.environ["HF_HUB_ENABLE_HF_TRANSFER"]
- if old_hf_transfer in ("False", "false"):
- old_hf_transfer = "0"
- if old_hf_transfer in ("True", "true"):
- old_hf_transfer = "1"
- else:
- old_hf_transfer = "0"
- if old_hf_transfer == "1":
- print(
- "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!"
- )
- if old_hf_transfer != "0":
- os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+ old_hf_transfer = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0")
+ if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") == "1":
+ print("Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!")
+ pass
+ # Return old flag
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
- # For debugging - we use a download counter to see if environments are not breaking or if HF is down
- get_statistics(kwargs.get("local_files_only", False))
+ get_statistics() # For debugging - we use a download counter to see if environments are not breaking
if dtype is None:
dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16
- elif os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
- if dtype == torch.float16:
- dtype = torch.bfloat16
elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:
- logger.warning_once(
- "Device does not support bfloat16. Will change to float16."
- )
+ logger.warning_once("Device does not support bfloat16. Will change to float16.")
dtype = torch.float16
- assert dtype in (torch.float16, torch.bfloat16, torch.float32)
+ assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32)
+
+ global FORCE_FLOAT32
+ os.environ["UNSLOTH_FORCE_FLOAT32"] = "0"
bnb_compute_dtype = dtype
- do_forced_float32 = False
- if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
- print(
- f"Unsloth: Using float16 precision for {model_type_arch} won't work! Using float32."
- )
- bnb_compute_dtype = torch.float16
- do_forced_float32 = True
-
- # Check for custom data-types
- custom_datatype = None
- correct_dtype = None
- if os.environ.get("UNSLOTH_FORCE_CUSTOM_DTYPE", "") != "":
- custom_datatype = os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"]
- assert custom_datatype.count(";") >= 4
- checker, _dtype, _bnb_compute_dtype, _custom_datatype, execute_code = (
- custom_datatype.split(";", 4)
- )
- # Allow custom dtypes on all runs
- allow_all_runs = checker == "all"
- # Allow only on float16 datatypes
- allow_float16_runs = (
- checker == "float16" or checker == "torch.float16"
- ) and (
- dtype == torch.float16
- or os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1"
- )
- if allow_all_runs or allow_float16_runs:
- if eval(_dtype) is not None:
- dtype = eval(_dtype)
- if eval(_bnb_compute_dtype) is not None:
- bnb_compute_dtype = eval(_bnb_compute_dtype)
- correct_dtype = bnb_compute_dtype
- custom_datatype = _custom_datatype
- # Execute code as well
- if len(execute_code.strip()) != 0:
- exec(execute_code)
- else:
- custom_datatype = None
- correct_dtype = None
-
- # Stop SDPA for some archs like Pixtral / Mistral3
- flex_attn_impl = None
- if auto_config is None:
- auto_config = AutoConfig.from_pretrained(
- model_name,
- token = token,
- trust_remote_code = trust_remote_code,
- )
- try:
- model_class = auto_model._model_mapping[auto_config.__class__]
- except Exception:
- model_class = None
- flex_attn_impl = prefer_flex_attn_if_supported(model_class, auto_config)
-
- # Handle FP8 models: get_model_name has already redirected this to BF16 sibling if the model ships with
- # FP8 weights. We just need to update it here for sanity.
- auto_config.model_name = model_name
- # Re-resolve model_class after potential config change
- try:
- model_class = auto_model._model_mapping[auto_config.__class__]
- except Exception:
- model_class = None
-
- model_type = str(getattr(auto_config, "model_type", "")).lower()
- if model_type.startswith("gemma3n"):
- # Gemma3N variants initialize timm-based vision towers which do
- # not support flex_attention, so default to eager unless overridden.
- default_attn_impl = "eager"
- else:
- default_attn_impl = "flex_attention" if flex_attn_impl else "sdpa"
- if not ("attn_implementation" in kwargs):
- kwargs["attn_implementation"] = default_attn_impl
- if not supports_sdpa and kwargs.get("attn_implementation") == "sdpa":
- if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "0") == "0":
- print(
- f"Unsloth: {model_type_arch.title()} does not support SDPA - switching to fast eager."
- )
- del kwargs["attn_implementation"]
+ for disable_name in FORCE_FLOAT32:
+ if (disable_name.lower() == model_type_arch.lower() or \
+ disable_name.lower() in model_name.lower()) and \
+ dtype == torch.float16:
+
+ print(f"Unsloth: Using float16 precision for {model_type_arch} won't work! Using float32.")
+ os.environ["UNSLOTH_FORCE_FLOAT32"] = "1"
+ bnb_compute_dtype = torch.float32
+ break
+ pass
+
+ global FORCE_EAGER_ATTENTION
+ attn_implementation = "sdpa"
+ for disable_name in FORCE_EAGER_ATTENTION:
+ if (disable_name.lower() == model_type_arch.lower() or \
+ disable_name.lower() in model_name.lower()):
+
+ print(f"Unsloth: {model_type_arch} does not support SDPA - switching to eager!")
+ attn_implementation = "eager"
+ break
+ pass
bnb_config = None
- user_quantization_config = kwargs.get("quantization_config", None)
if full_finetuning and (load_in_4bit or load_in_8bit):
- print(
- "Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA."
- )
+ print("Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA.")
load_in_4bit = False
load_in_8bit = False
- load_in_16bit = False
-
- if int(load_in_4bit) + int(load_in_8bit) + int(load_in_16bit) >= 2:
- raise RuntimeError(
- "Unsloth: Can only load in 4bit or 8bit or 16bit, not a combination!"
- )
- _skip_modules = SKIP_QUANTIZATION_MODULES.copy()
- # Nemotron-H uses 'mixer' (not 'mamba') for Mamba layers.
- # Mamba fused kernels pass out_proj.weight directly to F.linear,
- # which fails with quantized Params4bit. Skip out_proj from quantization.
- if any(mt == "nemotron_h" for mt in (model_types or [])):
- _skip_modules.append("out_proj")
+ pass
+ if load_in_4bit and load_in_8bit:
+ raise RuntimeError("Unsloth: Can only load in 4bit or 8bit, not both!")
if load_in_4bit:
bnb_config = BitsAndBytesConfig(
- load_in_4bit = True,
+ load_in_4bit = True,
bnb_4bit_use_double_quant = True,
- bnb_4bit_quant_type = "nf4",
- bnb_4bit_compute_dtype = bnb_compute_dtype,
- llm_int8_skip_modules = _skip_modules,
+ bnb_4bit_quant_type = "nf4",
+ bnb_4bit_compute_dtype = bnb_compute_dtype,
+ llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(),
)
elif load_in_8bit:
bnb_config = BitsAndBytesConfig(
- load_in_8bit = True,
- llm_int8_skip_modules = _skip_modules,
+ load_in_8bit = True,
+ llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(),
)
- elif load_in_16bit:
- bnb_config = None
elif not load_in_4bit and not load_in_8bit and not full_finetuning:
- print(
- "Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA."
+ print("Unsloth: LoRA, QLoRA and full finetuning all not selected. Switching to QLoRA.")
+ load_in_4bit = True
+ bnb_config = BitsAndBytesConfig(
+ load_in_4bit = True,
+ bnb_4bit_use_double_quant = True,
+ bnb_4bit_quant_type = "nf4",
+ bnb_4bit_compute_dtype = bnb_compute_dtype,
+ llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(),
)
+ pass
if full_finetuning:
os.environ["UNSLOTH_ENABLE_FULL_FINETUNING"] = "1"
if dtype == torch.bfloat16:
- if float32_mixed_precision != True:
- print(
- f"Unsloth: Using bfloat16 full finetuning which cuts memory usage by 50%.\n"
- f"To enable float32 training, use `float32_mixed_precision = True` during FastLanguageModel.from_pretrained"
- )
- else:
- print(
- f"Unsloth: Using full float32 full finetuning. "
- f"To enable bfloat16 training to reduce VRAM usage by 50% albeit with a slightly higher loss, do:\n"
- "use `float32_mixed_precision = False` during FastLanguageModel.from_pretrained"
- )
- os.environ["UNSLOTH_BFLOAT16_MIXED_PRECISION"] = "1"
+ print("Unsloth: Using bfloat16 full finetuning which cuts memory usage by 50%.")
else:
- print(
- "Unsloth: Float16 full finetuning uses more memory since we upcast weights to float32."
- )
+ print("Unsloth: Float16 full finetuning uses more memory since we upcast weights to float32.")
else:
os.environ["UNSLOTH_ENABLE_FULL_FINETUNING"] = "0"
+ pass
- # Fix AttributeError: 'BitsAndBytesConfig' object has no attribute 'get_loading_attributes'
- if bnb_config is not None and not hasattr(bnb_config, "get_loading_attributes"):
- bnb_config.get_loading_attributes = lambda *args, **kwargs: {}
+ kwargs.pop("attn_implementation", None); # No need since we auto call it
# Cannot be None, since HF now checks for the config
- if load_in_4bit or load_in_8bit:
- # Ignore load_in_4bit / load_in_8bit for MXFP4 - best to get config file
- if (
- "gpt-oss-20b" in model_name.lower()
- or "gpt-oss-120b" in model_name.lower()
- ):
- pass
- else:
- if user_quantization_config is None:
- kwargs["quantization_config"] = bnb_config
- else:
- if auto_config is None:
- auto_config = AutoConfig.from_pretrained(
- model_name,
- token = token,
- trust_remote_code = trust_remote_code,
- )
- if hasattr(auto_config, "quantization_config"):
- from transformers.quantizers.auto import (
- AUTO_QUANTIZATION_CONFIG_MAPPING,
- )
-
- quantization_config = auto_config.quantization_config
- quant_method = quantization_config["quant_method"]
- # Sometimes bitsandbytes_4bit + bitsandbytes_8bit is provided
- if (
- quant_method == "bitsandbytes"
- and "bitsandbytes" not in AUTO_QUANTIZATION_CONFIG_MAPPING
- ):
- if "bitsandbytes_4bit" not in AUTO_QUANTIZATION_CONFIG_MAPPING:
- raise KeyError(
- "Unsloth: AUTO_QUANTIZATION_CONFIG_MAPPING does not have `bitsandbytes_4bit`"
- )
- quantizer = AUTO_QUANTIZATION_CONFIG_MAPPING["bitsandbytes_4bit"]
- else:
- quantizer = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]
- quantizer_kwargs = {}
- if quant_method == "compressed-tensors":
- # Ignore these
- pass
- else:
- # We cannot dequantize since gpt-oss-20b MXFP4 will now be gpt-oss-20b-BF16
- if (
- load_in_16bit
- and "dequantize" in inspect.signature(quantizer).parameters
- ):
- quantizer_kwargs["dequantize"] = True
- try:
- # Sometimes this fails so we wrap it in a try except
- quantization_config = quantizer.from_dict(
- quantization_config, **quantizer_kwargs
- )
- except:
- pass
- if user_quantization_config is None:
- kwargs["quantization_config"] = quantization_config
-
- # Check if using forced float32 - we load it in bfloat16, then cast to float16!
- torch_dtype = dtype
- if do_forced_float32:
- torch_dtype = torch.bfloat16
-
- kwargs = add_dtype_kwargs(torch_dtype, kwargs)
-
- config_attn_impl = kwargs.get("attn_implementation", None)
- if config_attn_impl is None:
- config_attn_impl = "sdpa" if supports_sdpa else "eager"
- if auto_config is None:
- auto_config = AutoConfig.from_pretrained(
- model_name,
- token = token,
- trust_remote_code = trust_remote_code,
- )
- setattr(auto_config, "_attn_implementation", config_attn_impl)
- if hasattr(auto_config, "attn_implementation"):
- setattr(auto_config, "attn_implementation", config_attn_impl)
- model_config = auto_config
-
- verify_fp8_support_if_applicable(model_config)
-
- raise_handler = RaiseUninitialized()
- if not fast_inference:
- # Prevent load_in_fp8 from being forwarded into HF internal model loading
- load_in_fp8 = kwargs.pop("load_in_fp8", None)
- model = auto_model.from_pretrained(
- model_name,
- config = model_config,
- device_map = device_map,
- # torch_dtype = torch_dtype, # Transformers removed torch_dtype
- # quantization_config = bnb_config,
- token = token,
- trust_remote_code = trust_remote_code,
- # attn_implementation = attn_implementation,
- **kwargs,
- )
- if hasattr(model, "generate"):
- model.fast_generate = make_fast_generate_wrapper(model.generate)
- model.fast_generate_batches = error_out_no_vllm
- if offload_embedding:
- if bool(
- os.environ.get("WSL_DISTRO_NAME") or os.environ.get("WSL_INTEROP")
- ):
- # WSL doesn't work with offloaded embeddings
- pass
- elif os.name == "nt":
- # Windows doesn't work with offloaded embeddings
- pass
- else:
- embed_tokens = model.get_input_embeddings()
- nbytes = embed_tokens.weight.numel() * embed_tokens.weight.itemsize
- ngb = round(nbytes / 1024 / 1024 / 1024, 2)
- print(f"Unsloth: Offloading embeddings to RAM to save {ngb} GB.")
- embed_tokens.to("cpu")
-
- # Add hooks to move inputs to CPU and back to CUDA
- # [TODO] Doesn't seem to work!
- # def pre_hook(module, args):
- # args[0]._old_device = args[0].device
- # return (args[0].to("cpu", non_blocking = True))
- # def post_hook(module, args, output):
- # old_device = getattr(args[0], "_old_device", "cuda")
- # return output.to(old_device, non_blocking = True)
- # embed_tokens.register_forward_pre_hook(pre_hook, prepend = True)
- # embed_tokens.register_forward_hook (post_hook, prepend = True)
- # Must free GPU memory otherwise will not free!
- torch.cuda.empty_cache()
- gc.collect()
- else:
- from unsloth_zoo.vllm_utils import (
- load_vllm,
- get_vllm_state_dict,
- convert_vllm_to_huggingface,
- generate_batches,
- get_lora_supported_ranks,
- )
-
- if full_finetuning:
- max_lora_rank = max(get_lora_supported_ranks())
- raise NotImplementedError(
- "Unsloth: `fast_inference=True` cannot be used together with `full_finetuning=True`.\n"
- "Reason: fast_inference is optimized for inference-only workflows and "
- "does not currently support full fine-tuning.\n"
- "Workaround: disable fast_inference, or use parameter-efficient fine-tuning "
- f"(e.g. LoRA with rank r={max_lora_rank})."
- )
-
- model_config.model_name = model_name
-
- if fast_inference:
- fast_inference, model_name = fast_inference_setup(
- model_name, model_config
- )
-
- fp8_mode = None
- if load_in_fp8 != False:
- fp8_mode = _get_fp8_mode_and_check_settings(
- load_in_fp8,
- fast_inference,
- full_finetuning,
- load_in_4bit,
- load_in_8bit,
- load_in_16bit,
- )
-
- allowed_args = inspect.getfullargspec(load_vllm).args
- load_vllm_kwargs = dict(
- model_name = model_name,
- config = model_config,
- gpu_memory_utilization = gpu_memory_utilization,
- max_seq_length = max_seq_length,
- dtype = dtype,
- float8_kv_cache = float8_kv_cache,
- enable_lora = vllm_enable_lora,
- max_lora_rank = max_lora_rank,
- disable_log_stats = disable_log_stats,
- use_bitsandbytes = load_in_4bit,
- unsloth_vllm_standby = unsloth_vllm_standby,
- is_vision_model = is_vlm,
- fp8_mode = fp8_mode,
- )
- for allowed_arg in allowed_args:
- if allowed_arg not in load_vllm_kwargs and allowed_arg in kwargs:
- load_vllm_kwargs[allowed_arg] = kwargs[allowed_arg]
-
- # Load vLLM first
- llm = load_vllm(**load_vllm_kwargs)
-
- # Convert to HF format
- _, quant_state_dict = get_vllm_state_dict(
- llm,
- config = model_config,
- is_vision_model = is_vlm,
- load_in_fp8 = load_in_fp8,
- )
- model = convert_vllm_to_huggingface(
- quant_state_dict,
- model_config,
- dtype,
- bnb_config,
- is_vision_model = is_vlm,
- )
- model.vllm_engine = llm
- model.fast_generate = model.vllm_engine.generate
- model.fast_generate_batches = functools.partial(
- generate_batches, model.vllm_engine
- )
-
- raise_handler.remove()
-
+ if load_in_4bit: kwargs["quantization_config"] = bnb_config
+
+ model = auto_model.from_pretrained(
+ model_name,
+ device_map = device_map,
+ torch_dtype = dtype,
+ # quantization_config = bnb_config,
+ token = token,
+ trust_remote_code = trust_remote_code,
+ attn_implementation = attn_implementation,
+ **kwargs,
+ )
# Return old flag
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer
- # Check float32 norm weights
- if os.environ.get("UNSLOTH_HIGH_PRECISION_LAYERNORM", "0") == "1":
- for jj, (name, module) in enumerate(model.named_modules()):
- if (
- name.endswith(("norm", "norm1", "norm2", "norm3", "norm4"))
- or "layernorm" in name
- or "layer_norm" in name
- ) and hasattr(module, "weight"):
- module._pre_set_compute_dtype = torch.float32
- # Edit data-types
- if custom_datatype is not None:
- with torch.no_grad():
- for jj, (name, module) in enumerate(model.named_modules()):
- exec(custom_datatype)
- # Clear deleted GPU items
- for _ in range(3):
- gc.collect()
- if DEVICE_TYPE in ("cuda", "hip"):
- torch.cuda.empty_cache()
- elif DEVICE_TYPE == "xpu":
- torch.xpu.empty_cache()
-
# Counteract saved tokenizers
tokenizer_name = model_name if tokenizer_name is None else tokenizer_name
-
- # Fix _Unsloth_Patched_ prefix in local config files from old saves (issue #4085)
- if os.path.isdir(tokenizer_name):
- import json as _json
-
- for _cfg_name in (
- "processor_config.json",
- "preprocessor_config.json",
- "tokenizer_config.json",
- ):
- _cfg_path = os.path.join(tokenizer_name, _cfg_name)
- if os.path.exists(_cfg_path):
- try:
- with open(_cfg_path, "r", encoding = "utf-8") as _f:
- _cfg = _json.load(_f)
- if _cfg.get("processor_class", "").startswith(
- "_Unsloth_Patched_"
- ):
- _cfg["processor_class"] = _cfg["processor_class"][
- len("_Unsloth_Patched_") :
- ]
- with open(_cfg_path, "w", encoding = "utf-8") as _f:
- _json.dump(_cfg, _f, indent = 2, ensure_ascii = False)
- except Exception:
- pass
-
- if (whisper_language and whisper_task) or auto_model.__name__.endswith(
- "ForConditionalGeneration"
- ):
- try:
- tokenizer = auto_processor.from_pretrained(
- tokenizer_name,
- padding_side = "left",
- token = token,
- language = whisper_language,
- task = whisper_task,
- trust_remote_code = trust_remote_code,
- )
- except Exception:
- tokenizer = None
- else:
- try:
- tokenizer = auto_processor.from_pretrained(
- tokenizer_name,
- padding_side = "left",
- token = token,
- trust_remote_code = trust_remote_code,
- )
- except:
- tokenizer = get_auto_processor(
- tokenizer_name,
- padding_side = "left",
- token = token,
- trust_remote_code = trust_remote_code,
- )
-
- # If processor loading failed (e.g., tokenizer class not found),
- # or if AutoProcessor silently degraded to a text-only tokenizer
- # instead of returning a full VLM processor (issue #4085),
- # try constructing the processor manually from separate components.
- _processor_is_degraded = (
- is_vlm
- and tokenizer is not None
- and not hasattr(tokenizer, "image_processor")
+ auto_processor = AutoProcessor if auto_model is AutoModelForVision2Seq else AutoTokenizer
+ tokenizer = auto_processor.from_pretrained(
+ tokenizer_name,
+ padding_side = "right",
+ token = token,
)
- if (tokenizer is None or _processor_is_degraded) and is_vlm:
- _fallback = _construct_vlm_processor_fallback(
- tokenizer_name,
- model_type_arch,
- token,
- trust_remote_code,
- )
- if _fallback is not None:
- tokenizer = _fallback
- if tokenizer is None:
- import sys
-
- print(
- f"Unsloth: Warning - VLM processor fallback returned None for model_type={model_type_arch}",
- file = sys.stderr,
- )
if hasattr(tokenizer, "tokenizer"):
__tokenizer = tokenizer.tokenizer
# Add padding side as well
- __tokenizer.padding_side = "left"
+ __tokenizer.padding_side = "right"
# Check bos, eos, pad tokens
if hasattr(__tokenizer, "bos_token"):
- tokenizer.bos_token = __tokenizer.bos_token
+ tokenizer.bos_token = __tokenizer.bos_token
tokenizer.bos_token_id = __tokenizer.bos_token_id
if hasattr(__tokenizer, "eos_token"):
- tokenizer.eos_token = __tokenizer.eos_token
+ tokenizer.eos_token = __tokenizer.eos_token
tokenizer.eos_token_id = __tokenizer.eos_token_id
if hasattr(__tokenizer, "pad_token"):
- tokenizer.pad_token = __tokenizer.pad_token
+ tokenizer.pad_token = __tokenizer.pad_token
tokenizer.pad_token_id = __tokenizer.pad_token_id
+ pass
+ model, tokenizer = patch_tokenizer(model, tokenizer)
+ model = post_patch_loss_function(model)
# Fix other stuff like BnB compute data types
model, tokenizer = patch_model_and_tokenizer(
model,
tokenizer,
downcast_rope = False,
fix_embeddings = False,
- do_forced_float32 = do_forced_float32,
- correct_dtype = correct_dtype,
)
- try:
- model, tokenizer = patch_tokenizer(model, tokenizer)
- except Exception as _patch_err:
- # Some VLM processors (e.g., ERNIE VL) may fail during tokenizer patching.
- # Try loading tokenizer separately via AutoTokenizer as fallback.
- try:
- from transformers import AutoTokenizer as _AutoTokenizer
-
- _fallback_tok = _AutoTokenizer.from_pretrained(
- tokenizer_name,
- padding_side = "left",
- token = token,
- trust_remote_code = trust_remote_code,
- )
- model, _fallback_tok = patch_tokenizer(model, _fallback_tok)
- # Re-attach as processor wrapper if original was a processor
- if hasattr(tokenizer, "image_processor"):
- tokenizer.tokenizer = _fallback_tok
- else:
- tokenizer = _fallback_tok
- except Exception:
- # If fallback also fails, raise the original error
- raise _patch_err
- model = post_patch_loss_function(model)
-
# Log Unsloth version for future fastpaths for inference
if hasattr(model, "config"):
- model.config.update({"unsloth_version": __version__})
+ model.config.update({"unsloth_version" : __version__})
+ pass
patch_saving_functions(model, vision = True)
- if tokenizer is None:
- # Last resort: try loading tokenizer via AutoTokenizer, then PreTrainedTokenizerFast
- try:
- from transformers import AutoTokenizer as _AutoTokenizer
-
- tokenizer = _AutoTokenizer.from_pretrained(
- tokenizer_name,
- padding_side = "left",
- token = token,
- trust_remote_code = trust_remote_code,
- )
- except Exception:
- try:
- from transformers import PreTrainedTokenizerFast
-
- tokenizer = PreTrainedTokenizerFast.from_pretrained(
- tokenizer_name,
- padding_side = "left",
- token = token,
- trust_remote_code = trust_remote_code,
- )
- except Exception:
- del model
- raise RuntimeError(
- "Unsloth: The tokenizer is weirdly not loaded? Please check if there is one."
- )
patch_saving_functions(tokenizer, vision = True)
# Fix gradient accumulation
from transformers.trainer import Trainer
-
patch_gradient_accumulation_fix(Trainer)
# Save tokenizer for inference purposes
- tokenizer.padding_side = "left" # Force inference
+ tokenizer.padding_side = "left" # Force inference
if hasattr(tokenizer, "tokenizer"):
- tokenizer.tokenizer.padding_side = "left" # Force inference
+ tokenizer.tokenizer.padding_side = "left" # Force inference
m = model
while hasattr(m, "model"):
m.max_seq_length = max_seq_length
@@ -1118,74 +357,60 @@ def from_pretrained(
# Also set is_loaded_in_8bit to disable incorrect DDP
m.is_loaded_in_8bit = True if not full_finetuning else False
m = m.model
+ pass
m.max_seq_length = max_seq_length
- # Save to modules as well
- for module in model.modules():
- module.max_seq_length = max_seq_length
m._saved_temp_tokenizer = tokenizer
# Also set is_loaded_in_8bit to disable incorrect DDP
m.is_loaded_in_8bit = True if not full_finetuning else False
# Patch generate
- if os.environ.get("UNSLOTH_DISABLE_FAST_GENERATION", "0") == "0" and hasattr(
- model, "generate"
- ):
- if model.generate.__name__ != "unsloth_base_fast_generate":
- model._old_generate = model.generate
- unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__
- model.generate = types.MethodType(unsloth_base_fast_generate, model)
- model._unsloth_trust_remote_code = trust_remote_code
+ if model.generate.__name__ != "unsloth_base_fast_generate":
+ model._old_generate = model.generate
+ unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__
+ model.generate = types.MethodType(unsloth_base_fast_generate, model)
+
# Post patches
model = FastBaseModel.post_patch_model(
model,
use_gradient_checkpointing = use_gradient_checkpointing,
- trust_remote_code = trust_remote_code,
- model_type = model_type_arch,
- tokenizer = tokenizer,
- float32_mixed_precision = float32_mixed_precision,
)
# Clear deleted GPU items
for _ in range(3):
gc.collect()
- if DEVICE_TYPE in ("cuda", "hip"):
- torch.cuda.empty_cache()
- elif DEVICE_TYPE == "xpu":
- torch.xpu.empty_cache()
+ torch.cuda.empty_cache()
+ pass
return model, tokenizer
+ pass
+
@staticmethod
def get_peft_model(
model,
- r = 16,
- target_modules = None,
- lora_alpha = 16,
- lora_dropout = 0.0,
- bias = "none",
- finetune_vision_layers = True,
- finetune_language_layers = True,
+ r = 16,
+ target_modules = None,
+ lora_alpha = 16,
+ lora_dropout = 0,
+ bias = "none",
+ finetune_vision_layers = True,
+ finetune_language_layers = True,
finetune_attention_modules = True,
- finetune_mlp_modules = True,
- layers_to_transform = None,
- layers_pattern = None,
- use_gradient_checkpointing = "unsloth",
- random_state = 3407,
- max_seq_length = 2048, # not used anymore
- use_rslora = False,
- modules_to_save = None,
- init_lora_weights = True,
- loftq_config = {},
- task_type = TaskType.CAUSAL_LM,
- temporary_location = "_unsloth_temporary_saved_buffers",
- qat_scheme = None,
- target_parameters = None, # For MoE expert layers (nn.Parameter)
- ensure_weight_tying = False, # [TODO] Add `ensure_weight_tying` for `modules_to_save` for vision models
+ finetune_mlp_modules = True,
+ layers_to_transform = None,
+ layers_pattern = None,
+ use_gradient_checkpointing = True,
+ random_state = 3407,
+ max_seq_length = 2048, # not used anymore
+ use_rslora = False,
+ modules_to_save = None,
+ init_lora_weights = True,
+ loftq_config = {},
+ temporary_location = "_unsloth_temporary_saved_buffers",
**kwargs,
):
if os.environ.get("UNSLOTH_ENABLE_FULL_FINETUNING", "0") == "1":
- print(
- "Unsloth: Full finetuning is enabled, so .get_peft_model has no effect"
- )
+ print("Unsloth: Full finetuning is enabled, so .get_peft_model has no effect")
return model
+ pass
transformers_set_seed(random_state)
if type(r) is not int:
@@ -1194,353 +419,196 @@ def get_peft_model(
raise TypeError(f"Unsloth: Rank of {str(r)} must be larger than 0.")
if isinstance(model, PeftModelForCausalLM):
- raise RuntimeError(
- "Unsloth: You already added LoRA adapters to your model!"
- )
+ raise RuntimeError("Unsloth: You already added LoRA adapters to your model!")
if target_modules == "all-linear":
- finetune_vision_layers = True
- finetune_language_layers = True
+ finetune_vision_layers = True
+ finetune_language_layers = True
finetune_attention_modules = True
- finetune_mlp_modules = True
- if target_modules is None or target_modules == "all-linear":
+ finetune_mlp_modules = True
+ pass
+ if target_modules is None:
target_modules = get_peft_regex(
model,
- finetune_vision_layers = finetune_vision_layers,
- finetune_language_layers = finetune_language_layers,
+ finetune_vision_layers = finetune_vision_layers,
+ finetune_language_layers = finetune_language_layers,
finetune_attention_modules = finetune_attention_modules,
- finetune_mlp_modules = finetune_mlp_modules,
+ finetune_mlp_modules = finetune_mlp_modules,
)
else:
- assert type(target_modules) in (
- list,
- tuple,
- str,
- )
-
- if hasattr(model, "vllm_engine"):
- if (
- hasattr(model.vllm_engine, "llm_engine")
- and hasattr(model.vllm_engine.llm_engine, "vllm_config")
- and getattr(
- model.vllm_engine.llm_engine.vllm_config, "lora_config", None
- )
- is None
- ):
- # If vLLM is being used but lora is not enabled, throw an error
- # Ref https://github.com/vllm-project/vllm/blob/51ba839555a5d122eadd91e9c16463ac288f5fa1/vllm/v1/engine/processor.py#L148-L151
- raise RuntimeError("Unsloth: LoRA is not enabled for this model!")
- if finetune_vision_layers:
- # vLLM does not support LoRA on vision layers
- # https://github.com/vllm-project/vllm/blob/main/vllm/lora/models.py#L471-L477
- # TODO: Update this once vLLM V1 supports LoRA on vision layers (possibly not happening)
- raise RuntimeError(
- "Unsloth: Finetuning vision layers is not supported for fast_inference. Only text layers are supported!"
- )
- if model.config.model_type in VLLM_NON_LORA_VLM:
- # mllama is still only in vllm v0 https://arc.net/l/quote/llwkfgmu
- # https://docs.vllm.ai/en/stable/models/supported_models.html#text-generation_1
- # vLLM V0 does not support LoRA on multi modal models.
- # TODO: Update this once vLLM V1 supports Llama 3.2 aka mllama
- raise RuntimeError(
- "Unsloth: LoRA finetuning for Llama 3.2 aka mllama models is not supported with fast_inference!"
- )
+ assert(type(target_modules) in (list, tuple,))
+ pass
# Clear deleted GPU items
for _ in range(3):
gc.collect()
- if DEVICE_TYPE in ("cuda", "hip"):
- torch.cuda.empty_cache()
- elif DEVICE_TYPE == "xpu":
- torch.xpu.empty_cache()
+ torch.cuda.empty_cache()
+ pass
max_seq_length = model.max_seq_length
- # If we pass loftq_config = None we will get an error
- loftq_config = validate_loftq_config(
- loftq_config, lora_dropout, bias, init_lora_weights, model
- )
-
- # Auto-detect MoE models and populate target_parameters for expert layers
- if target_parameters is None:
- target_parameters = get_moe_target_parameters(model, target_modules)
-
- # Get only allowed parameters for LoraConfig
- local_variables = {
- **locals(),
- **kwargs,
- }
- del local_variables["kwargs"]
- allowed_parameters = inspect.signature(LoraConfig).parameters.keys()
lora_config = LoraConfig(
- **{k: v for k, v in local_variables.items() if k in allowed_parameters},
+ r = r,
+ lora_alpha = lora_alpha,
+ target_modules = target_modules,
+ lora_dropout = lora_dropout,
+ bias = bias,
+ task_type = TaskType.CAUSAL_LM,
)
model = prepare_model_for_kbit_training(
model,
use_gradient_checkpointing = use_gradient_checkpointing,
)
model = _get_peft_model(model, lora_config)
- # Apply QAT + LoRA if specified
- if qat_scheme is not None:
- print("Unsloth: Applying QAT to mitigate quantization degradation")
- model = _prepare_model_for_qat(model, qat_scheme)
- # Fix LoraConfig.auto_mapping is None
- fix_lora_auto_mapping(model)
# Enable gradients on modules which are trainable
requires_grad_for_gradient_checkpointing(model)
- trust_remote_code = getattr(model, "_unsloth_trust_remote_code", False)
- model = FastBaseModel.post_patch_model(
- model,
- use_gradient_checkpointing = use_gradient_checkpointing,
- trust_remote_code = trust_remote_code,
- )
+
+ model = FastBaseModel.post_patch_model(model, use_gradient_checkpointing)
model.max_seq_length = max_seq_length
- # Save to modules as well
- for module in model.modules():
- module.max_seq_length = max_seq_length
+
# Clear deleted GPU items
for _ in range(3):
gc.collect()
- if DEVICE_TYPE in ("cuda", "hip"):
- torch.cuda.empty_cache()
- elif DEVICE_TYPE == "xpu":
- torch.xpu.empty_cache()
+ torch.cuda.empty_cache()
+ pass
patch_saving_functions(model, vision = True)
- patch_peft_fast_inference(model)
# Add for_inference and for_training
- model.for_training = functools.partial(FastBaseModel.for_training, model)
+ model.for_training = functools.partial(FastBaseModel.for_training, model)
model.for_inference = functools.partial(FastBaseModel.for_inference, model)
- m = model
- while hasattr(m, "model"):
- m.for_training = functools.partial(FastBaseModel.for_training, m)
- m.for_inference = functools.partial(FastBaseModel.for_inference, m)
- m = m.model
return model
+ pass
+
@staticmethod
def post_patch_model(
model,
use_gradient_checkpointing = True,
- trust_remote_code = False,
- model_type = None,
- tokenizer = None,
- float32_mixed_precision = None,
):
full_finetuning = os.environ.get("UNSLOTH_ENABLE_FULL_FINETUNING", "0") == "1"
- if type(float32_mixed_precision) is bool:
- # Respect whatever it was set before
- pass
- else:
- float32_mixed_precision = True
- if (
- _get_dtype(dtype_from_config(model.config)) == torch.bfloat16
- and full_finetuning
- ):
- # Use bfloat16 precision for full finetuning
- float32_mixed_precision = False
-
- # VLMs can hit DDP "marked ready twice" with re-entrant checkpointing.
- # See: https://github.com/unslothai/unsloth/issues/3713.
- use_reentrant = not is_distributed()
- if not use_reentrant:
- # Under DDP, avoid the offloaded/re-entrant checkpoint patch.
- unpatch_unsloth_gradient_checkpointing()
- unpatch_unsloth_smart_gradient_checkpointing()
- # Force native checkpoint to default to non-reentrant for downstream calls.
- _orig_checkpoint = torch_checkpoint.checkpoint
-
- def _nonre_checkpoint(function, *args, **kwargs):
- kwargs["use_reentrant"] = False
- return _orig_checkpoint(function, *args, **kwargs)
-
- torch_checkpoint.checkpoint = _nonre_checkpoint
- hf_modeling_utils.checkpoint = _nonre_checkpoint
+ float32_mixed_precision = True
+ if _get_dtype(model.config.torch_dtype) == torch.bfloat16 and full_finetuning:
+ # Use bfloat16 precision for full finetuning
+ float32_mixed_precision = False
model = prepare_model_for_training(
model,
use_gradient_checkpointing = use_gradient_checkpointing,
- use_reentrant = use_reentrant,
- full_finetuning = full_finetuning,
- train_layernorms = full_finetuning,
- train_embedding = full_finetuning,
- train_lm_head = full_finetuning,
- float32_mixed_precision = float32_mixed_precision,
- patch_modules_to_save = True,
+ use_reentrant = True,
+ full_finetuning = full_finetuning,
+ train_layernorms = full_finetuning,
+ train_embedding = full_finetuning,
+ train_lm_head = full_finetuning,
+ float32_mixed_precision = float32_mixed_precision,
)
- from transformers.trainer import Trainer
-
- if (
- Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop"
- and trust_remote_code == False
- ):
- raise RuntimeError("Unsloth: Unsuccessfully patched inner_training_loop")
+ from transformers.trainer import Trainer
+ if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop":
+ raise RuntimeError('Unsloth: Unsuccessfully patched inner_training_loop')
+ pass
patch_saving_functions(model, vision = True)
- # Patch tokenizer to pad to the left
+ # Patch tokenizer to pad to the right
m = model
while hasattr(m, "model"):
if hasattr(m, "_saved_temp_tokenizer"):
if hasattr(m._saved_temp_tokenizer, "tokenizer"):
- m._saved_temp_tokenizer.tokenizer.padding_side = "left"
+ m._saved_temp_tokenizer.tokenizer.padding_side = "right"
+ pass
# Also set is_loaded_in_8bit to disable incorrect DDP
m.is_loaded_in_8bit = True if not full_finetuning else False
m = m.model
+ pass
if hasattr(m, "_saved_temp_tokenizer"):
if hasattr(m._saved_temp_tokenizer, "tokenizer"):
- m._saved_temp_tokenizer.tokenizer.padding_side = "left"
+ m._saved_temp_tokenizer.tokenizer.padding_side = "right"
+ pass
# Also set is_loaded_in_8bit to disable incorrect DDP
m.is_loaded_in_8bit = True if not full_finetuning else False
# Clear deleted GPU items
for _ in range(3):
gc.collect()
- if DEVICE_TYPE in ("cuda", "hip"):
- torch.cuda.empty_cache()
- elif DEVICE_TYPE == "xpu":
- torch.xpu.empty_cache()
+ torch.cuda.empty_cache()
+ pass
# Add for_inference and for_training
- model.for_training = functools.partial(FastBaseModel.for_training, model)
+ model.for_training = functools.partial(FastBaseModel.for_training, model)
model.for_inference = functools.partial(FastBaseModel.for_inference, model)
- m = model
- while hasattr(m, "model"):
- m.for_training = functools.partial(FastBaseModel.for_training, m)
- m.for_inference = functools.partial(FastBaseModel.for_inference, m)
- m = m.model
- # Set weight[padding_idx] = 0 for embeddings that are NOT tied with the
- # lm_head. When weights are tied, zeroing the padding row also zeros
- # the corresponding lm_head row, forcing logit = 0 for the pad token.
- # Only do this if tokenizer is defined since eos_token == pad_token sometimes!
- pad_token_id = getattr(tokenizer, "pad_token_id", None)
- lm_head = getattr(model, "lm_head", None)
- lm_head_weight = (
- getattr(lm_head, "weight", None) if lm_head is not None else None
- )
- if (
- tokenizer is not None
- and getattr(tokenizer, "eos_token_id", None) != pad_token_id
- ):
- with torch.no_grad():
- for name, module in model.named_modules():
- if type(module) is torch.nn.Embedding:
- if (
- getattr(module, "weight", None) is not None
- and getattr(module, "padding_idx", None) is not None
- ):
- if (
- module.padding_idx == pad_token_id
- and module.padding_idx < module.weight.shape[0]
- ):
- # Skip if tied to lm_head
- if (
- lm_head_weight is not None
- and module.weight.data_ptr()
- == lm_head_weight.data_ptr()
- ):
- continue
- module.weight[module.padding_idx] = 0
+
+ # Patch generate
+ if model.generate.__name__ != "unsloth_base_fast_generate":
+ model._old_generate = model.generate
+ unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__
+ model.generate = types.MethodType(unsloth_base_fast_generate, model)
return model
+ pass
+
@staticmethod
def for_inference(model):
if not hasattr(model, "parameters"):
- raise TypeError(
- "Unsloth: I think you're passing a tokenizer, not the model to for_inference!"
- )
+ raise TypeError("Unsloth: I think you're passing a tokenizer, not the model to for_inference!")
def _for_inference(m):
- if hasattr(m, "gradient_checkpointing"):
- m.gradient_checkpointing = False
- if hasattr(m, "training"):
- m.training = False
+ if hasattr(m, "gradient_checkpointing"): m.gradient_checkpointing = False
+ if hasattr(m, "training"): m.training = False
# Pad tokenizer to the left
- if hasattr(m, "_saved_temp_tokenizer"):
- m._saved_temp_tokenizer.padding_side = "left"
+ if hasattr(m, "_saved_temp_tokenizer"): m._saved_temp_tokenizer.padding_side = "left"
# Set a flag for generation!
m._flag_for_generation = True
-
+ pass
m = model
while hasattr(m, "model"):
_for_inference(m)
m = m.model
_for_inference(m)
- model.eval() # to turn off training on modules deeper in
-
- # Since transformers 4.53, must turn off explicitly
- for module in model.modules():
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = False
# Also disable training for embeddings for NEFTune
if hasattr(model, "get_input_embeddings"):
embeddings = model.get_input_embeddings()
- if hasattr(embeddings, "training"):
- embeddings.training = False
+ if hasattr(embeddings, "training"): embeddings.training = False
+ pass
if hasattr(model, "get_output_embeddings"):
embeddings = model.get_output_embeddings()
- if hasattr(embeddings, "training"):
- embeddings.training = False
- # Must disable returning hidden states in the case for GRPO
- os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
- # Must enable returning logits
- os.environ["UNSLOTH_RETURN_LOGITS"] = "1"
- # Turn off skip guards and set stance to default
- if torch_compiler_set_stance is not None:
- torch_compiler_set_stance(stance = "default", skip_guard_eval_unsafe = False)
+ if hasattr(embeddings, "training"): embeddings.training = False
+ pass
return model
+ pass
+
@staticmethod
def for_training(model, use_gradient_checkpointing = True):
if not hasattr(model, "parameters"):
- raise TypeError(
- "Unsloth: I think you're passing a tokenizer, not the model to for_training!"
- )
+ raise TypeError("Unsloth: I think you're passing a tokenizer, not the model to for_training!")
# Delete all fast inference loras
for param in model.parameters():
if hasattr(param, "_fast_lora"):
del param._fast_lora
+ pass
def _for_training(m):
- if hasattr(m, "gradient_checkpointing"):
- m.gradient_checkpointing = use_gradient_checkpointing
- if hasattr(m, "training"):
- m.training = True
+ if hasattr(m, "gradient_checkpointing"): m.gradient_checkpointing = use_gradient_checkpointing
+ if hasattr(m, "training"): m.training = True
# Pad tokenizer to the left
- if hasattr(m, "_saved_temp_tokenizer"):
- m._saved_temp_tokenizer.padding_side = "right"
+ if hasattr(m, "_saved_temp_tokenizer"): m._saved_temp_tokenizer.padding_side = "right"
# Set a flag for generation!
- if hasattr(m, "_flag_for_generation"):
- try:
- # Weirdly sometimes cannot succeed so do a try except
- del m._flag_for_generation
- except:
- pass
-
+ if hasattr(m, "_flag_for_generation"): del m._flag_for_generation
+ pass
m = model
while hasattr(m, "model"):
_for_training(m)
m = m.model
_for_training(m)
- model.train() # to turn on training on modules deeper in
-
- # Since transformers 4.53, must turn on explicitly
- for module in model.modules():
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = use_gradient_checkpointing
# Also re-enable training for embeddings for NEFTune
if hasattr(model, "get_input_embeddings"):
embeddings = model.get_input_embeddings()
- if hasattr(embeddings, "training"):
- embeddings.training = True
+ if hasattr(embeddings, "training"): embeddings.training = True
+ pass
if hasattr(model, "get_output_embeddings"):
embeddings = model.get_output_embeddings()
- if hasattr(embeddings, "training"):
- embeddings.training = True
- # Can re-enable not returning logits
- os.environ["UNSLOTH_RETURN_LOGITS"] = "0"
- # Turn off skip guards and set stance to default
- if torch_compiler_set_stance is not None:
- torch_compiler_set_stance(stance = "default", skip_guard_eval_unsafe = False)
+ if hasattr(embeddings, "training"): embeddings.training = True
+ pass
return model
+ pass
+pass
diff --git a/unsloth/ollama_template_mappers.py b/unsloth/ollama_template_mappers.py
deleted file mode 100644
index 1bf77461d9..0000000000
--- a/unsloth/ollama_template_mappers.py
+++ /dev/null
@@ -1,2193 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-__all__ = [
- "OLLAMA_TEMPLATES",
- "OLLAMA_TEMPLATE_TO_MODEL_MAPPER",
- "MODEL_TO_OLLAMA_TEMPLATE_MAPPER",
-]
-
-OLLAMA_TEMPLATES = {}
-
-# =========================================== Unsloth
-
-unsloth_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{ if .System }}{{ .System }}
-{{ end }}{{ if .Prompt }}>>> User: {{ .Prompt }}
-{{ end }}>>> Assistant: {{ .Response }}{__EOS_TOKEN__}
-"""
-PARAMETER stop "{__EOS_TOKEN__}"
-PARAMETER temperature 1.5
-PARAMETER min_p 0.1
-SYSTEM """You are a helpful assistant to the user"""
-'''
-
-OLLAMA_TEMPLATES["unsloth"] = unsloth_ollama
-
-# =========================================== Zephyr
-
-zephyr_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{ if .System }}<|system|>
-{{ .System }}{__EOS_TOKEN__}
-{{ end }}{{ if .Prompt }}<|user|>
-{{ .Prompt }}{__EOS_TOKEN__}
-{{ end }}<|assistant|>
-{{ .Response }}{__EOS_TOKEN__}
-"""
-PARAMETER stop "{__EOS_TOKEN__}"
-PARAMETER temperature 1.5
-PARAMETER min_p 0.1
-'''
-
-OLLAMA_TEMPLATES["zephyr"] = zephyr_ollama
-
-# =========================================== ChatML
-chatml_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{ if .System }}<|im_start|>system
-{{ .System }}<|im_end|>
-{{ end }}{{ if .Prompt }}<|im_start|>user
-{{ .Prompt }}<|im_end|>
-{{ end }}<|im_start|>assistant
-{{ .Response }}<|im_end|>
-"""
-PARAMETER stop "<|im_start|>"
-PARAMETER stop "<|im_end|>"
-PARAMETER temperature 1.5
-PARAMETER min_p 0.1
-'''
-
-OLLAMA_TEMPLATES["chatml"] = chatml_ollama
-
-# =========================================== Mistral-1
-# Ollama from https://www.ollama.com/library/mistral
-# Mistral v0.1 https://ollama.com/library/mistral:v0.1/blobs/22e1b2e8dc2f
-# Mistral v0.2 https://ollama.com/library/mistral:v0.2/blobs/e6836092461f
-mistral_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]"""
-PARAMETER stop "[INST]"
-PARAMETER stop "[/INST]"
-'''
-
-# mistral:v0.3 https://ollama.com/library/mistral:v0.3/blobs/1ff5b64b61b9
-# mistral-large https://ollama.com/library/mistral-large:latest/blobs/96adabcf2c08
-mistral_v03_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{- if .Messages }}
-{{- range $index, $_ := .Messages }}
-{{- if eq .Role "user" }}
-{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS]
-{{- end }}[INST] {{ if and $.System (eq (len (slice $.Messages $index)) 1) }}{{ $.System }}
-
-{{ end }}{{ .Content }}[/INST]
-{{- else if eq .Role "assistant" }}
-{{- if .Content }}{{ .Content }}
-{{- else if .ToolCalls }}[TOOL_CALLS] [
-{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
-{{- end }}]
-{{- end }}
-{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}} [/TOOL_RESULTS]
-{{- end }}
-{{- end }}
-{{- else }}[INST] {{ if .System }}{{ .System }}
-
-{{ end }}{{ .Prompt }}[/INST]
-{{- end }}{{ .Response }}
-{{- if .Response }}
-{{- end }}"""
-PARAMETER stop "[INST]"
-PARAMETER stop "[/INST]"
-PARAMETER stop ""
-'''
-
-# Mistral-small https://ollama.com/library/mistral-small:latest/blobs/6db27cd4e277
-mistral_small_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{- range $index, $_ := .Messages }}
-{{- if eq .Role "system" }}[SYSTEM_PROMPT]{{ .Content }}[/SYSTEM_PROMPT]
-{{- else if eq .Role "user" }}
-{{- if and (le (len (slice $.Messages $index)) 2) $.Tools }}[AVAILABLE_TOOLS]{{ $.Tools }}[/AVAILABLE_TOOLS]
-{{- end }}[INST]{{ .Content }}[/INST]
-{{- else if eq .Role "assistant" }}
-{{- if .Content }}{{ .Content }}
-{{- if not (eq (len (slice $.Messages $index)) 1) }}
-{{- end }}
-{{- else if .ToolCalls }}[TOOL_CALLS][
-{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
-{{- end }}]
-{{- end }}
-{{- else if eq .Role "tool" }}[TOOL_RESULTS]{"content": {{ .Content }}}[/TOOL_RESULTS]
-{{- end }}
-{{- end }}"""
-PARAMETER temperature 0.15
-SYSTEM """You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris. Your knowledge base was last updated on 2023-10-01. When you're not sure about some information, you say that you don't have the information and don't make up anything. If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. "What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo" => "Where do you travel from?")"""
-'''
-
-# mistral-small-3.1 https://ollama.com/library/mistral-small3.1:latest/blobs/6db27cd4e277
-mistral_small_31_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{- range $index, $_ := .Messages }}
-{{- if eq .Role "system" }}[SYSTEM_PROMPT]{{ .Content }}[/SYSTEM_PROMPT]
-{{- else if eq .Role "user" }}
-{{- if and (le (len (slice $.Messages $index)) 2) $.Tools }}[AVAILABLE_TOOLS]{{ $.Tools }}[/AVAILABLE_TOOLS]
-{{- end }}[INST]{{ .Content }}[/INST]
-{{- else if eq .Role "assistant" }}
-{{- if .Content }}{{ .Content }}
-{{- if not (eq (len (slice $.Messages $index)) 1) }}
-{{- end }}
-{{- else if .ToolCalls }}[TOOL_CALLS][
-{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
-{{- end }}]
-{{- end }}
-{{- else if eq .Role "tool" }}[TOOL_RESULTS]{"content": {{ .Content }}}[/TOOL_RESULTS]
-{{- end }}
-{{- end }}"""
-PARAMETER num_ctx 4096
-SYSTEM """You are Mistral Small 3.1, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris.
-You power an AI assistant called Le Chat.
-Your knowledge base was last updated on 2023-10-01.
-
-When you're not sure about some information, you say that you don't have the information and don't make up anything.
-If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. "What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo" => "Where do you travel from?").
-You are always very attentive to dates, in particular you try to resolve dates (e.g. "yesterday" is {yesterday}) and when asked about information at specific dates, you discard information that is at another date.
-You follow these instructions in all languages, and always respond to the user in the language they use or request.
-Next sections describe the capabilities that you have.
-
-# WEB BROWSING INSTRUCTIONS
-
-You cannot perform any web search or access internet to open URLs, links etc. If it seems like the user is expecting you to do so, you clarify the situation and ask the user to copy paste the text directly in the chat.
-
-# MULTI-MODAL INSTRUCTIONS
-
-You have the ability to read images, but you cannot generate images. You also cannot transcribe audio files or videos.
-You cannot read nor transcribe audio files or videos."""
-'''
-
-# mistral-small-3.2 https://ollama.com/library/mistral-small3.2:latest/blobs/706c4d1164f7
-mistral_small_32_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{- range $index, $_ := .Messages }}
-{{- if eq .Role "system" }}[SYSTEM_PROMPT]{{ .Content }}[/SYSTEM_PROMPT]
-{{- else if eq .Role "user" }}
-{{- if and (le (len (slice $.Messages $index)) 2) $.Tools }}[AVAILABLE_TOOLS]{{ $.Tools }}[/AVAILABLE_TOOLS]
-{{- end }}[INST]{{ .Content }}[/INST]
-{{- else if eq .Role "assistant" }}
-{{- if .Content }}{{ .Content }}
-{{- if not (eq (len (slice $.Messages $index)) 1) }}
-{{- end }}
-{{- else if .ToolCalls }}
-{{- range $i, $_ := .ToolCalls }}[TOOL_CALLS]{{ .Function.Name }}[CALL_ID]{{ $i }}[ARGS]{{ .Function.Arguments }}
-{{- end }}
-{{- end }}
-{{- else if eq .Role "tool" }}[TOOL_RESULTS]{"content": {{ .Content }}}[/TOOL_RESULTS]
-{{- end }}
-{{- end }}"""
-PARAMETER temperature 0.15
-SYSTEM """You are Mistral Small 3.2, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris.
-You power an AI assistant called Le Chat.
-Your knowledge base was last updated on 2023-10-01.
-
-When you're not sure about some information or when the user's request requires up-to-date or specific data, you must use the available tools to fetch the information. Do not hesitate to use tools whenever they can provide a more accurate or complete response. If no relevant tools are available, then clearly state that you don't have the information and avoid making up anything.
-If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. "What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo" => "Where do you travel from?").
-You are always very attentive to dates, in particular you try to resolve dates and when asked about information at specific dates, you discard information that is at another date.
-You follow these instructions in all languages, and always respond to the user in the language they use or request.
-Next sections describe the capabilities that you have.
-
-# WEB BROWSING INSTRUCTIONS
-
-You cannot perform any web search or access internet to open URLs, links etc. If it seems like the user is expecting you to do so, you clarify the situation and ask the user to copy paste the text directly in the chat.
-
-# MULTI-MODAL INSTRUCTIONS
-
-You have the ability to read images, but you cannot generate images. You also cannot transcribe audio files or videos.
-You cannot read nor transcribe audio files or videos.
-
-TOOL CALLING INSTRUCTIONS
-
-You may have access to tools that you can use to fetch information or perform actions. You must use these tools in the following situations:
-
-1. When the request requires up-to-date information.
-2. When the request requires specific data that you do not have in your knowledge base.
-3. When the request involves actions that you cannot perform without tools.
-
-Always prioritize using tools to provide the most accurate and helpful response. If tools are not available, inform the user that you cannot perform the requested action at the moment."""
-'''
-
-
-# https://ollama.com/library/mixtral:latest/blobs/53d74de0d84c
-mixtral_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST] {{ .Response }}"""
-PARAMETER stop "[INST]"
-PARAMETER stop "[/INST]"
-'''
-
-# https://registry.ollama.ai/library/mistral-nemo:latest/blobs/438402ddac75
-mistral_nemo_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """
-{{- range $i, $_ := .Messages }}
-{{- if eq .Role "user" }}
-{{- if and $.Tools (le (len (slice $.Messages $i)) 2) }}[AVAILABLE_TOOLS]{{ $.Tools }}[/AVAILABLE_TOOLS]
-{{- end }}[INST]{{ if and $.System (eq (len (slice $.Messages $i)) 1) }}{{ $.System }}
-
-{{ end }}{{ .Content }}[/INST]
-{{- else if eq .Role "assistant" }}
-{{- if .Content }} {{ .Content }}{{ if not (eq (len (slice $.Messages $i)) 1) }}{{ end }}
-{{- else if .ToolCalls }}[TOOL_CALLS][
-{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
-{{- end }}]
-{{- end }}
-{{- else if eq .Role "tool" }}[TOOL_RESULTS]{"content": {{ .Content }}}[/TOOL_RESULTS]
-{{- end }}
-{{- end }}"""
-PARAMETER stop "[INST]"
-PARAMETER stop "[/INST]"
-'''
-
-# https://ollama.com/library/codestral:latest/blobs/51707752a87c
-codestral_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """
-{{- if .Suffix }}[SUFFIX]{{ .Suffix }}[PREFIX] {{ .Prompt }}
-{{- else if .Messages }}
-{{- range $index, $_ := .Messages }}
-{{- if eq .Role "user" }}[INST] {{ if and $.System (eq (len (slice $.Messages $index)) 1) }}{{ $.System }}
-
-{{ end }}{{ .Content }}[/INST]
-{{- else if eq .Role "assistant" }} {{ .Content }}
-{{- end }}
-{{- end }}
-{{- else }}[INST] {{ if .System }}{{ .System }}
-
-{{ end }}{{ .Prompt }} [/INST]
-{{- end }} {{ .Response }}
-{{- if .Response }}
-{{- end }}
-"""
-PARAMETER stop "[INST]"
-PARAMETER stop "[/INST]"
-PARAMETER stop "[PREFIX]"
-PARAMETER stop "[MIDDLE]"
-PARAMETER stop "[SUFFIX]"
-'''
-
-# https://ollama.com/library/devstral:latest/blobs/ea9ec42474e0
-devstral_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{- $lastUserIndex := -1 }}
-{{- range $index, $_ := .Messages }}
-{{- if eq .Role "user" }}{{ $lastUserIndex = $index }}{{ end }}
-{{- end }}
-{{- range $index, $_ := .Messages }}
-{{- if eq .Role "system" }}[SYSTEM_PROMPT]{{ .Content }}[/SYSTEM_PROMPT]
-{{- else if eq .Role "user" }}
-{{- if and (eq $lastUserIndex $index) $.Tools }}[AVAILABLE_TOOLS]{{ $.Tools }}[/AVAILABLE_TOOLS]
-{{- end }}[INST]{{ .Content }}[/INST]
-{{- else if eq .Role "assistant" }}
-{{- if .Content }}{{ .Content }}
-{{- if not (eq (len (slice $.Messages $index)) 1) }}
-{{- end }}
-{{- else if .ToolCalls }}[TOOL_CALLS][
-{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
-{{- end }}]
-{{- end }}
-{{- else if eq .Role "tool" }}[TOOL_RESULTS]{"content": {{ .Content }}}[/TOOL_RESULTS]
-{{- end }}
-{{- end }}"""
-SYSTEM """You are Devstral, a helpful agentic model trained by Mistral AI and using the OpenHands scaffold. You can interact with a computer to solve tasks.
-
-
-Your primary role is to assist users by executing commands, modifying code, and solving technical problems effectively. You should be thorough, methodical, and prioritize quality over speed.
-* If the user asks a question, like "why is X happening", don't try to fix the problem. Just give an answer to the question.
-
-
-
-* Each action you take is somewhat expensive. Wherever possible, combine multiple actions into a single action, e.g. combine multiple bash commands into one, using sed and grep to edit/view multiple files at once.
-* When exploring the codebase, use efficient tools like find, grep, and git commands with appropriate filters to minimize unnecessary operations.
-
-
-
-* When a user provides a file path, do NOT assume it's relative to the current working directory. First explore the file system to locate the file before working on it.
-* If asked to edit a file, edit the file directly, rather than creating a new file with a different filename.
-* For global search-and-replace operations, consider using `sed` instead of opening file editors multiple times.
-
-
-
-* Write clean, efficient code with minimal comments. Avoid redundancy in comments: Do not repeat information that can be easily inferred from the code itself.
-* When implementing solutions, focus on making the minimal changes needed to solve the problem.
-* Before implementing any changes, first thoroughly understand the codebase through exploration.
-* If you are adding a lot of code to a function or file, consider splitting the function or file into smaller pieces when appropriate.
-
-
-
-* When configuring git credentials, use "openhands" as the user.name and "openhands@all-hands.dev" as the user.email by default, unless explicitly instructed otherwise.
-* Exercise caution with git operations. Do NOT make potentially dangerous changes (e.g., pushing to main, deleting repositories) unless explicitly asked to do so.
-* When committing changes, use `git status` to see all modified files, and stage all files necessary for the commit. Use `git commit -a` whenever possible.
-* Do NOT commit files that typically shouldn't go into version control (e.g., node_modules/, .env files, build directories, cache files, large binaries) unless explicitly instructed by the user.
-* If unsure about committing certain files, check for the presence of .gitignore files or ask the user for clarification.
-
-
-
-* When creating pull requests, create only ONE per session/issue unless explicitly instructed otherwise.
-* When working with an existing PR, update it with new commits rather than creating additional PRs for the same issue.
-* When updating a PR, preserve the original PR title and purpose, updating description only when necessary.
-
-
-
-1. EXPLORATION: Thoroughly explore relevant files and understand the context before proposing solutions
-2. ANALYSIS: Consider multiple approaches and select the most promising one
-3. TESTING:
- * For bug fixes: Create tests to verify issues before implementing fixes
- * For new features: Consider test-driven development when appropriate
- * If the repository lacks testing infrastructure and implementing tests would require extensive setup, consult with the user before investing time in building testing infrastructure
- * If the environment is not set up to run tests, consult with the user first before investing time to install all dependencies
-4. IMPLEMENTATION: Make focused, minimal changes to address the problem
-5. VERIFICATION: If the environment is set up to run tests, test your implementation thoroughly, including edge cases. If the environment is not set up to run tests, consult with the user first before investing time to run tests.
-
-
-
-* Only use GITHUB_TOKEN and other credentials in ways the user has explicitly requested and would expect.
-* Use APIs to work with GitHub or other platforms, unless the user asks otherwise or your task requires browsing.
-
-
-
-* When user asks you to run an application, don't stop if the application is not installed. Instead, please install the application and run the command again.
-* If you encounter missing dependencies:
- 1. First, look around in the repository for existing dependency files (requirements.txt, pyproject.toml, package.json, Gemfile, etc.)
- 2. If dependency files exist, use them to install all dependencies at once (e.g., `pip install -r requirements.txt`, `npm install`, etc.)
- 3. Only install individual packages directly if no dependency files are found or if only specific packages are needed
-* Similarly, if you encounter missing dependencies for essential tools requested by the user, install them when possible.
-
-
-
-* If you've made repeated attempts to solve a problem but tests still fail or the user reports it's still broken:
- 1. Step back and reflect on 5-7 different possible sources of the problem
- 2. Assess the likelihood of each possible cause
- 3. Methodically address the most likely causes, starting with the highest probability
- 4. Document your reasoning process
-* When you run into any major issue while executing a plan from the user, please don't try to directly work around it. Instead, propose a new plan and confirm with the user before proceeding.
-"""
-'''
-
-# https://ollama.com/library/magistral:latest/blobs/35f7a1efc383
-magistral_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """
-{{- range $i, $_ := .Messages }}
-{{- $last := eq (len (slice $.Messages $i)) 1}}
-{{- if eq .Role "system" }}[SYSTEM_PROMPT]{{ .Content }}[/SYSTEM_PROMPT]
-{{- else if eq .Role "user" }}
-{{- if and (le (len (slice $.Messages $i)) 2) $.Tools }}[AVAILABLE_TOOLS]{{ $.Tools }}[/AVAILABLE_TOOLS]
-{{- end }}[INST]{{ .Content }}[/INST]
-{{- else if eq .Role "assistant" }}
-{{- if and $.IsThinkSet (and $last .Thinking) -}}
-
-{{ .Thinking }}
-
-{{ end }}
-{{- if .Content }}{{ .Content }}
-{{- end }}
-{{- if .ToolCalls }}{{ range $i, $_ := .ToolCalls }}[TOOL_CALLS]{{ .Function.Name }}[CALL_ID]{{ $i }}[ARGS]{{ .Function.Arguments }}{{ end }}
-{{- end }}
-{{- if not (eq (len (slice $.Messages $i)) 1) }}
-{{- end }}
-{{- else if eq .Role "tool" }}[TOOL_RESULTS]0[TOOL_CONTENT]{{ .Content }}[/TOOL_RESULTS]
-{{- end }}
-{{- if and $last (ne .Role "assistant") }}{{ if and $.IsThinkSet (not $.Think) -}}
-
-{{ end }}
-{{- end }}
-{{- end }}"""
-PARAMETER temperature 0.7
-PARAMETER top_p 0.95
-SYSTEM """A user will ask you to solve a task. You should first draft your thinking process (inner monologue) until you have derived the final answer. Afterwards, write a self-contained summary of your thoughts (i.e. your summary should be succinct but contain all the critical steps you needed to reach the conclusion). You should use Markdown and Latex to format your response. Write both your thoughts and summary in the same language as the task posed by the user.
-
-Your thinking process must follow the template below:
-
-Your thoughts or/and draft, like working through an exercise on scratch paper. Be as casual and as long as you want until you are confident to generate a correct answer.
-
-
-Here, provide a concise summary that reflects your reasoning and presents a clear final answer to the user.
-
-Problem:"""
-'''
-
-OLLAMA_TEMPLATES["mistral"] = mistral_ollama
-OLLAMA_TEMPLATES["mistral-v03"] = mistral_v03_ollama
-OLLAMA_TEMPLATES["mistral-small"] = mistral_small_ollama
-OLLAMA_TEMPLATES["mistral-small-31"] = mistral_small_31_ollama
-OLLAMA_TEMPLATES["mistral-small-32"] = mistral_small_32_ollama
-OLLAMA_TEMPLATES["mixtral"] = mixtral_ollama
-OLLAMA_TEMPLATES["mistral-nemo"] = mistral_nemo_ollama
-OLLAMA_TEMPLATES["devstral"] = devstral_ollama
-OLLAMA_TEMPLATES["magistral"] = magistral_ollama
-OLLAMA_TEMPLATES["codestral"] = codestral_ollama
-
-
-# =========================================== Llama-2
-# Ollama from https://www.ollama.com/library/llama3
-llama_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """[INST] <>{{ .System }}<>
-
-{{ .Prompt }} [/INST]"""
-PARAMETER stop "{__EOS_TOKEN__}"
-PARAMETER temperature 1.5
-PARAMETER min_p 0.1
-'''
-
-OLLAMA_TEMPLATES["llama"] = llama_ollama
-
-# =========================================== Vicuna
-# Ollama from https://www.ollama.com/library/vicuna
-vicuna_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{ if .System }}{{ .System }} {{ end }}{{ if .Prompt }}USER: {{ .Prompt }} {{ end }}ASSISTANT: {{ .Response }} {__EOS_TOKEN__}"""
-PARAMETER stop "{__EOS_TOKEN__}"
-PARAMETER temperature 1.5
-PARAMETER min_p 0.1
-'''
-
-OLLAMA_TEMPLATES["vicuna"] = vicuna_ollama
-
-# =========================================== Vicuna Old
-vicuna_old_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{ if .System }}{{ .System }}
-{{ end }}{{ if .Prompt }}### Human: {{ .Prompt }}
-{{ end }}### Assistant: {{ .Response }}{__EOS_TOKEN__}
-"""
-PARAMETER stop "{__EOS_TOKEN__}"
-PARAMETER temperature 1.5
-PARAMETER min_p 0.1
-SYSTEM """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."""
-'''
-
-OLLAMA_TEMPLATES["vicuna_old"] = vicuna_old_ollama
-OLLAMA_TEMPLATES["vicuna old"] = OLLAMA_TEMPLATES["vicuna_old"]
-
-# =========================================== Alpaca multi turn
-alpaca_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{ if .System }}{{ .System }}
-
-{{ end }}{{ if .Prompt }}### Instruction:
-{{ .Prompt }}{{ end }}
-
-### Response:
-{{ .Response }}{__EOS_TOKEN__}
-
-"""
-PARAMETER stop "{__EOS_TOKEN__}"
-PARAMETER temperature 1.5
-PARAMETER min_p 0.1
-SYSTEM """Below are some instructions that describe some tasks. Write responses that appropriately complete each request."""
-'''
-
-OLLAMA_TEMPLATES["alpaca"] = alpaca_ollama
-
-# =========================================== Gemma
-# Ollama from https://www.ollama.com/library/gemma
-gemma_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """user
-{{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}
-model
-{{ .Response }}
-"""
-PARAMETER repeat_penalty 1
-PARAMETER stop ""
-PARAMETER stop ""
-PARAMETER penalize_newline false
-PARAMETER temperature 1.5
-PARAMETER min_p 0.1
-'''
-
-OLLAMA_TEMPLATES["gemma"] = gemma_ollama
-
-# =========================================== Gemma with ChatML instead
-gemma_chatml_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{ if .System }}<|im_start|>system
-{{ .System }}<|im_end|>
-{{ end }}{{ if .Prompt }}<|im_start|>user
-{{ .Prompt }}<|im_end|>
-{{ end }}<|im_start|>assistant
-{{ .Response }}<|im_end|>
-"""
-PARAMETER repeat_penalty 1
-PARAMETER stop "<|im_start|>"
-PARAMETER stop "<|im_end|>"
-PARAMETER penalize_newline false
-PARAMETER temperature 1.5
-PARAMETER min_p 0.1
-'''
-
-OLLAMA_TEMPLATES["gemma_chatml"] = gemma_chatml_ollama
-
-# =========================================== Gemma 2
-# Same as Gemma 1, but with sliding window attention!
-# https://ollama.com/library/gemma2/blobs/6522ca797f47
-gemma2_ollama = gemma_ollama + "PARAMETER num_ctx 4096\n"
-OLLAMA_TEMPLATES["gemma2"] = gemma2_ollama
-
-# =========================================== Gemma 2 with ChatML instead
-gemma2_chatml_ollama = gemma_chatml_ollama + "PARAMETER num_ctx 4096\n"
-OLLAMA_TEMPLATES["gemma2_chatml"] = gemma2_chatml_ollama
-
-# =========================================== Llama-3
-# Ollama from https://www.ollama.com/library/llama3
-llama3_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
-
-{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
-
-{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
-
-{{ .Response }}<|eot_id|>"""
-PARAMETER num_keep 24
-PARAMETER stop "<|start_header_id|>"
-PARAMETER stop "<|end_header_id|>"
-PARAMETER stop "<|eot_id|>"
-PARAMETER temperature 1.5
-PARAMETER min_p 0.1
-'''
-
-OLLAMA_TEMPLATES["llama-3"] = llama3_ollama
-OLLAMA_TEMPLATES["llama3"] = llama3_ollama
-
-
-# =========================================== Phi-3
-# Ollama from https://www.ollama.com/library/phi3
-phi3_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{ if .System }}<|system|>
-{{ .System }}<|end|>
-{{ end }}{{ if .Prompt }}<|user|>
-{{ .Prompt }}<|end|>
-{{ end }}<|assistant|>
-{{ .Response }}<|end|>
-"""
-PARAMETER stop "<|end|>"
-PARAMETER stop "<|user|>"
-PARAMETER stop "<|assistant|>"
-PARAMETER temperature 1.5
-PARAMETER min_p 0.1
-'''
-
-OLLAMA_TEMPLATES["phi-3"] = phi3_ollama
-OLLAMA_TEMPLATES["phi-35"] = OLLAMA_TEMPLATES["phi-3"]
-OLLAMA_TEMPLATES["phi-3.5"] = OLLAMA_TEMPLATES["phi-3"]
-
-# =========================================== Llama-3.1
-"""
-No trimming in Llama 3.1 Instruct!
-Also an extra newline for Cutting Knowledge Date
-See https://colab.research.google.com/drive/1Xpqq5xpIgO-B00MQ-UccYMwN2J8QFgBM?usp=sharing
-
-Also should be
-
-import datetime
-tokenizer.apply_chat_template(
- messages,
- add_generation_prompt = True,
- tokenize = False,
- date_string = datetime.today().strftime("%d %B %Y")),
-)
-"""
-
-# Ollama from https://ollama.com/library/llama3.1 (needs updating!)
-llama31_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{ if .Messages }}
-{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
-{{- if .System }}
-
-{{ .System }}
-{{- end }}
-{{- if .Tools }}
-
-You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the original use question.
-{{- end }}
-{{- end }}<|eot_id|>
-{{- range $i, $_ := .Messages }}
-{{- $last := eq (len (slice $.Messages $i)) 1 }}
-{{- if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>
-{{- if and $.Tools $last }}
-
-Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
-
-Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
-
-{{ $.Tools }}
-{{- end }}
-
-{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
-
-{{ end }}
-{{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|>
-{{- if .ToolCalls }}
-
-{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}
-{{- else }}
-
-{{ .Content }}{{ if not $last }}<|eot_id|>{{ end }}
-{{- end }}
-{{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|>
-
-{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
-
-{{ end }}
-{{- end }}
-{{- end }}
-{{- else }}
-{{- if .System }}<|start_header_id|>system<|end_header_id|>
-
-{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
-
-{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
-
-{{ end }}{{ .Response }}{{ if .Response }}<|eot_id|>{{ end }}"""
-PARAMETER stop "<|start_header_id|>"
-PARAMETER stop "<|end_header_id|>"
-PARAMETER stop "<|eot_id|>"
-PARAMETER stop "<|eom_id|>"
-PARAMETER temperature 1.5
-PARAMETER min_p 0.1
-'''
-
-# https://ollama.com/ajindal/llama3.1-storm:8b/blobs/1970553b62f4
-llama_31_storm_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """
-{{ if .Messages }}
-{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
-{{- if .System }}
-
-{{ .System }}
-{{- end }}
-{{- if .Tools }}
-
-You are a function calling AI model. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into function. The user may use the terms function calling or tool use interchangeably.
-
-Here are the available functions:
-{{ json .Tools }}
-
-For each function call return a json object with function name and arguments within XML tags in the format:
-{"tool_name": , "tool_arguments": }
-{{- end }}
-{{- end }}<|eot_id|>
-{{- range $i, $_ := .Messages }}
-{{- $last := eq (len (slice $.Messages $i)) 1 }}
-{{- if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>
-
-{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
-{{ end }}
-{{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|>
-{{- if .ToolCalls }}
-
-{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}
-{{- else }}
-
-{{ .Content }}{{ if not $last }}<|eot_id|>{{ end }}
-{{- end }}
-{{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|>
-
-{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
-{{ end }}
-{{- end }}
-{{- end }}
-{{- else }}
-{{- if .System }}<|start_header_id|>system<|end_header_id|>
-
-{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
-
-{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
-
-{{ end }}{{ .Response }}{{ if .Response }}<|eot_id|>{{ end }}
-"""
-PARAMETER stop "<|start_header_id|>"
-PARAMETER stop "<|end_header_id|>"
-PARAMETER stop "<|eot_id|>"
-'''
-
-# https://ollama.com/library/nemotron:latest/blobs/4863fe3335f3
-llama_31_nemotron_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """<|start_header_id|>system<|end_header_id|>
-
-{{ if .Tools }}You have access to the following functions. To call a function, please respond with JSON for a function call. Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
-
-{{ range .Tools }}{{ . }}
-
-{{ end }}
-{{- end }}{{ .System }}<|eot_id|>
-{{- range $i, $_ := .Messages }}
-{{- $isLastMessage := eq (len (slice $.Messages $i)) 1 -}}
-{{- if eq .Role "system" }}
-{{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|>
-
-{{ if .Content }}{{ .Content }}
-{{- else if .ToolCalls }}
-{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }} }
-{{- end }}
-{{- end }}
-{{- if not $isLastMessage }}<|eot_id|>
-{{- end }}
-{{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|>
-
-{{ .Content }}<|eot_id|>
-{{- if $isLastMessage }}<|start_header_id|>assistant<|end_header_id|>
-
-{{ end }}
-{{- else }}<|start_header_id|>{{ .Role }}<|end_header_id|>
-
-{{ .Content }}<|eot_id|>
-{{- if $isLastMessage }}<|start_header_id|>assistant<|end_header_id|>
-
-{{ end }}
-{{- end }}
-{{- end }}
-"""
-PARAMETER stop "<|start_header_id|>"
-PARAMETER stop "<|end_header_id|>"
-PARAMETER stop "<|eot_id|>"
-'''
-
-# https://ollama.com/library/llama3.2-vision:latest/blobs/715415638c895a1f8e8c6
-llama_32_vision_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{- range $index, $_ := .Messages }}<|start_header_id|>{{ .Role }}<|end_header_id|>
-
-{{ .Content }}
-{{- if gt (len (slice $.Messages $index)) 1 }}<|eot_id|>
-{{- else if ne .Role "assistant" }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
-
-{{ end }}
-{{- end }}"""
-PARAMETER temperature 0.6
-PARAMETER top_p 0.9
-'''
-
-OLLAMA_TEMPLATES["llama-3.1"] = llama31_ollama
-OLLAMA_TEMPLATES["llama-31"] = llama31_ollama
-OLLAMA_TEMPLATES["llama-31-nemotron"] = llama_31_nemotron_ollama
-OLLAMA_TEMPLATES["llama-31-storm"] = llama_31_storm_ollama
-OLLAMA_TEMPLATES["llama-32-vision"] = llama_32_vision_ollama
-
-for version in ("llama-3.2", "llama-3.3", "llama-32", "llama-33"):
- OLLAMA_TEMPLATES[version] = OLLAMA_TEMPLATES["llama-3.1"]
-
-# =========================================== tinyllama
-# tinyllama-chat https://ollama.com/library/tinyllama:latest/blobs/af0ddbdaaa26
-tinyllama_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """<|system|>
-{{ .System }}
-<|user|>
-{{ .Prompt }}
-<|assistant|>"""
-PARAMETER stop "<|system|>"
-PARAMETER stop "<|user|>"
-PARAMETER stop "<|assistant|>"
-PARAMETER stop ""
-SYSTEM """You are a helpful AI assistant."""
-'''
-
-OLLAMA_TEMPLATES["tinyllama"] = tinyllama_ollama
-
-
-# =========================================== Qwen 2/2.5
-# Qwen2 https://ollama.com/library/qwen2:latest/blobs/77c91b422cc9
-# Qwen2.5 from https://ollama.com/library/qwen2.5/blobs/eb4402837c78
-qwen25_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{- if .Messages }}
-{{- if or .System .Tools }}<|im_start|>system
-{{- if .System }}
-{{ .System }}
-{{- end }}
-{{- if .Tools }}
-
-# Tools
-
-You may call one or more functions to assist with the user query.
-
-You are provided with function signatures within XML tags:
-
-{{- range .Tools }}
-{"type": "function", "function": {{ .Function }}}
-{{- end }}
-
-
-For each function call, return a json object with function name and arguments within XML tags:
-
-{"name": , "arguments": }
-
-{{- end }}<|im_end|>
-{{ end }}
-{{- range $i, $_ := .Messages }}
-{{- $last := eq (len (slice $.Messages $i)) 1 -}}
-{{- if eq .Role "user" }}<|im_start|>user
-{{ .Content }}<|im_end|>
-{{ else if eq .Role "assistant" }}<|im_start|>assistant
-{{ if .Content }}{{ .Content }}
-{{- else if .ToolCalls }}
-{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
-{{ end }}
-{{- end }}{{ if not $last }}<|im_end|>
-{{ end }}
-{{- else if eq .Role "tool" }}<|im_start|>user
-
-{{ .Content }}
-<|im_end|>
-{{ end }}
-{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
-{{ end }}
-{{- end }}
-{{- else }}
-{{- if .System }}<|im_start|>system
-{{ .System }}<|im_end|>
-{{ end }}{{ if .Prompt }}<|im_start|>user
-{{ .Prompt }}<|im_end|>
-{{ end }}<|im_start|>assistant
-{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}"""
-PARAMETER stop "<|im_end|>"
-PARAMETER stop "<|endoftext|>"
-PARAMETER temperature 1.5
-PARAMETER min_p 0.1
-SYSTEM """You are Qwen, created by Alibaba Cloud. You are a helpful assistant."""
-'''
-
-# https://ollama.com/library/qwen2.5-coder:latest/blobs/1e65450c3067
-qwen_25_coder_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{- if .Suffix }}<|fim_prefix|>{{ .Prompt }}<|fim_suffix|>{{ .Suffix }}<|fim_middle|>
-{{- else if .Messages }}
-{{- if or .System .Tools }}<|im_start|>system
-{{- if .System }}
-{{ .System }}
-{{- end }}
-{{- if .Tools }}
-
-# Tools
-
-You may call one or more functions to assist with the user query.
-
-You are provided with function signatures within :
-
-{{- range .Tools }}
-{"type": "function", "function": {{ .Function }}}
-{{- end }}
-
-
-For each function call, return a json object with function name and arguments within with NO other text. Do not include any backticks or ```json.
-
-{"name": , "arguments": }
-
-{{- end }}<|im_end|>
-{{ end }}
-{{- range $i, $_ := .Messages }}
-{{- $last := eq (len (slice $.Messages $i)) 1 -}}
-{{- if eq .Role "user" }}<|im_start|>user
-{{ .Content }}<|im_end|>
-{{ else if eq .Role "assistant" }}<|im_start|>assistant
-{{ if .Content }}{{ .Content }}
-{{- else if .ToolCalls }}
-{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
-{{ end }}
-{{- end }}{{ if not $last }}<|im_end|>
-{{ end }}
-{{- else if eq .Role "tool" }}<|im_start|>user
-
-{{ .Content }}
-<|im_end|>
-{{ end }}
-{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
-{{ end }}
-{{- end }}
-{{- else }}
-{{- if .System }}<|im_start|>system
-{{ .System }}<|im_end|>
-{{ end }}{{ if .Prompt }}<|im_start|>user
-{{ .Prompt }}<|im_end|>
-{{ end }}<|im_start|>assistant
-{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}"""
-SYSTEM """You are Qwen, created by Alibaba Cloud. You are a helpful assistant."""
-'''
-
-# https://ollama.com/library/qwen2.5vl:latest/blobs/a242d8dfdc8f
-qwen_25_vl_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{- if .System -}}
-<|im_start|>system
-{{ .System }}<|im_end|>
-{{- end -}}
-{{- range $i, $_ := .Messages }}
-{{- $last := eq (len (slice $.Messages $i)) 1 -}}
-{{- if eq .Role "user" }}
-<|im_start|>user
-{{ .Content }}<|im_end|>
-{{- else if eq .Role "assistant" }}
-<|im_start|>assistant
-{{ if .Content }}{{ .Content }}{{ if not $last }}<|im_end|>
-{{- else -}}<|im_end|>{{- end -}}
-{{- end -}}
-{{- end -}}
-{{- if and (ne .Role "assistant") $last }}
-<|im_start|>assistant
-{{ end -}}
-{{- end }}"""
-PARAMETER temperature 0.0001
-SYSTEM """You are a helpful assistant."""
-'''
-
-# https://ollama.com/library/openthinker:latest/blobs/32695b892af8
-openthinker_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{- range $i, $_ := .Messages }}
-{{- $last := eq (len (slice $.Messages $i)) 1 -}}
-<|im_start|>{{ .Role }}<|im_sep|>
-{{ .Content }}{{ if not $last }}<|im_end|>
-{{ end }}
-{{- if and (ne .Role "assistant") $last }}<|im_end|>
-<|im_start|>assistant<|im_sep|>
-{{ end }}
-{{- end }}"""
-'''
-
-
-OLLAMA_TEMPLATES["qwen-25"] = qwen25_ollama
-OLLAMA_TEMPLATES["qwen-2.5"] = qwen25_ollama
-OLLAMA_TEMPLATES["qwen-25-coder"] = qwen_25_coder_ollama
-OLLAMA_TEMPLATES["qwen-25-vl"] = qwen_25_vl_ollama
-OLLAMA_TEMPLATES["openthinker"] = openthinker_ollama
-OLLAMA_TEMPLATES["qwen-2"] = qwen25_ollama
-
-# =========================================== Phi-4
-_phi4_ollama_template = (
- "{{ if .System }}<|im_start|><|system|><|im_sep|>{{ .System }}<|im_end|>{{ end }}"
- "{{ if .Prompt }}<|im_start|><|user|><|im_sep|>{{ .Prompt }}<|im_end|>{{ end }}"
- "<|im_start|><|assistant|><|im_sep|>{{ .Response }}<|im_end|>"
-)
-
-# Ollama from https://www.ollama.com/library/phi4 is different
-phi_4_ollama = f'''
-FROM {{__FILE_LOCATION__}}
-TEMPLATE """{_phi4_ollama_template}"""
-PARAMETER stop "<|im_end|>"
-PARAMETER stop "<|im_start|>"
-PARAMETER stop "<|im_sep|>"
-PARAMETER temperature 1.5
-PARAMETER min_p 0.1
-'''
-
-# https://ollama.com/library/phi4-reasoning:latest/blobs/32695b892af8
-phi_4_reasoning_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """
-{{- range $i, $_ := .Messages }}
-{{- $last := eq (len (slice $.Messages $i)) 1 -}}
-<|im_start|>{{ .Role }}<|im_sep|>
-{{ .Content }}{{ if not $last }}<|im_end|>
-{{ end }}
-{{- if and (ne .Role "assistant") $last }}<|im_end|>
-<|im_start|>assistant<|im_sep|>
-{{ end }}
-{{- end }}"""
-PARAMETER stop "<|im_start|>"
-PARAMETER stop "<|im_end|>"
-PARAMETER stop "<|im_sep|>"
-'''
-
-# https://ollama.com/library/phi4-mini:latest/blobs/813f53fdc6e5
-phi_4_mini_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{- if or .System .Tools }}<|system|>{{ if .System }}{{ .System }}{{ end }}
-{{- if .Tools }}{{ if not .System }}You are a helpful assistant with some tools.{{ end }}<|tool|>{{ .Tools }}<|/tool|><|end|>
-{{- end }}
-{{- end }}
-{{- range $i, $_ := .Messages }}
-{{- $last := eq (len (slice $.Messages $i)) 1 -}}
-{{- if ne .Role "system" }}<|{{ .Role }}|>{{ .Content }}
-{{- if .ToolCalls }}<|tool_call|>[{{ range .ToolCalls }}{"name":"{{ .Function.Name }}","arguments":{{ .Function.Arguments }}{{ end }}]<|/tool_call|>
-{{- end }}
-{{- if not $last }}<|end|>
-{{- end }}
-{{- if and (ne .Role "assistant") $last }}<|end|><|assistant|>{{ end }}
-{{- end }}
-{{- end }}"""
-'''
-
-# https://ollama.com/library/phi4-mini-reasoning:latest/blobs/c895a1f8e8c6
-phi_4_mini_reasoning_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """
-{{- if .System }}<|system|>{{ .System }}
-{{- end }}
-{{- range $i, $_ := .Messages }}
-{{- $last := eq (len (slice $.Messages $i)) 1 -}}
-{{- if ne .Role "system" }}<|{{ .Role }}|>{{ .Content }}
-{{- if not $last }}<|end|>
-{{- end }}
-{{- if and (ne .Role "assistant") $last }}<|end|><|assistant|>{{ end }}
-{{- end }}
-{{- end }}"""
-SYSTEM """Your name is Phi, an AI math expert developed by Microsoft."""
-'''
-OLLAMA_TEMPLATES["phi-4"] = phi_4_ollama
-OLLAMA_TEMPLATES["phi-4-reasoning"] = phi_4_reasoning_ollama
-OLLAMA_TEMPLATES["phi-4-mini"] = phi_4_mini_ollama
-OLLAMA_TEMPLATES["phi-4-mini-reasoning"] = phi_4_mini_reasoning_ollama
-
-
-# =========================================== Gemma-3
-# Ollama from https://ollama.com/library/gemma3/blobs/e0a42594d802
-gemma3_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{- range $i, $_ := .Messages }}
-{{- $last := eq (len (slice $.Messages $i)) 1 }}
-{{- if or (eq .Role "user") (eq .Role "system") }}user
-{{ .Content }}
-{{ if $last }}model
-{{ end }}
-{{- else if eq .Role "assistant" }}model
-{{ .Content }}{{ if not $last }}
-{{ end }}
-{{- end }}
-{{- end }}"""
-PARAMETER stop ""
-PARAMETER stop ""
-PARAMETER temperature 1.0
-PARAMETER min_p 0.0
-PARAMETER top_k 64
-PARAMETER top_p 0.95
-PARAMETER num_predict 32768
-'''
-
-# https://ollama.com/library/gemma3:270m/blobs/4b19ac7dd2fb
-gemma3_270m_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{- $systemPromptAdded := false }}
-{{- range $i, $_ := .Messages }}
-{{- $last := eq (len (slice $.Messages $i)) 1 }}
-{{- if eq .Role "user" }}user
-{{- if (and (not $systemPromptAdded) $.System) }}
-{{- $systemPromptAdded = true }}
-{{ $.System }}
-{{ end }}
-{{ .Content }}
-{{ if $last }}model
-{{ end }}
-{{- else if eq .Role "assistant" }}model
-{{ .Content }}{{ if not $last }}
-{{ end }}
-{{- end }}
-{{- end }}
-"""
-PARAMETER stop ""
-PARAMETER top_k 64
-PARAMETER top_p 0.95
-'''
-
-OLLAMA_TEMPLATES["gemma-3"] = gemma3_ollama
-OLLAMA_TEMPLATES["gemma3"] = gemma3_ollama
-OLLAMA_TEMPLATES["gemma3-270m"] = gemma3_270m_ollama
-
-
-# =========================================== Qwen-3
-# Ollama template for Qwen-3 (see https://ollama.com/library/qwen3/blobs/eb4402837c78)
-qwen3_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{- if .Messages }}
-{{- if or .System .Tools }}<|im_start|>system
-{{- if .System }}
-{{ .System }}
-{{- end }}
-{{- if .Tools }}
-
-# Tools
-
-You may call one or more functions to assist with the user query.
-
-You are provided with function signatures within XML tags:
-
-{{- range .Tools }}
-{"type": "function", "function": {{ .Function }}}
-{{- end }}
-
-
-For each function call, return a json object with function name and arguments within XML tags:
-
-{"name": , "arguments": }
-
-{{- end }}<|im_end|>
-{{ end }}
-{{- range $i, $_ := .Messages }}
-{{- $last := eq (len (slice $.Messages $i)) 1 -}}
-{{- if eq .Role "user" }}<|im_start|>user
-{{ .Content }}<|im_end|>
-{{ else if eq .Role "assistant" }}<|im_start|>assistant
-{{ if .Content }}{{ .Content }}
-{{- else if .ToolCalls }}
-{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
-{{ end }}
-{{- end }}{{ if not $last }}<|im_end|>
-{{ end }}
-{{- else if eq .Role "tool" }}<|im_start|>user
-
-{{ .Content }}
-<|im_end|>
-{{ end }}
-{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
-{{ end }}
-{{- end }}
-{{- else }}
-{{- if .System }}<|im_start|>system
-{{ .System }}<|im_end|>
-{{ end }}{{ if .Prompt }}<|im_start|>user
-{{ .Prompt }}<|im_end|>
-{{ end }}<|im_start|>assistant
-{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}"""
-PARAMETER stop "<|im_end|>"
-PARAMETER stop "<|im_start|>"
-PARAMETER temperature 0.6
-PARAMETER min_p 0.0
-PARAMETER top_k 20
-PARAMETER top_p 0.95
-PARAMETER repeat_penalty 1
-'''
-
-qwen3_template_eos_token = "<|im_end|>"
-OLLAMA_TEMPLATES["qwen-3"] = qwen3_ollama
-OLLAMA_TEMPLATES["qwen3"] = qwen3_ollama
-
-
-# =========================================== Gemma-3n
-# Ollama from https://ollama.com/library/gemma3n/blobs/e0a42594d802
-gemma3n_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{- range $i, $_ := .Messages }}
-{{- $last := eq (len (slice $.Messages $i)) 1 }}
-{{- if or (eq .Role "user") (eq .Role "system") }}user
-{{ .Content }}
-{{ if $last }}model
-{{ end }}
-{{- else if eq .Role "assistant" }}model
-{{ .Content }}{{ if not $last }}
-{{ end }}
-{{- end }}
-{{- end }}"""
-'''
-
-OLLAMA_TEMPLATES["gemma-3n"] = gemma3n_ollama
-OLLAMA_TEMPLATES["gemma3n"] = gemma3n_ollama
-
-# =========================================== GPT-OSS
-
-# Ollama from https://ollama.com/library/gpt-oss:latest/blobs/fa6710a93d78
-gptoss_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
-Knowledge cutoff: 2024-06
-Current date: {{ currentDate }}
-{{- if and .IsThinkSet .Think (ne .ThinkLevel "") }}
-
-Reasoning: {{ .ThinkLevel }}
-{{- else if or (not .IsThinkSet) (and .IsThinkSet .Think) }}
-
-Reasoning: medium
-{{- end }}
-
-{{- $hasNonBuiltinTools := false }}
-{{- if .Tools -}}
-{{- $hasBrowserSearch := false }}
-{{- $hasBrowserOpen := false }}
-{{- $hasBrowserFind := false }}
-{{- $hasPython := false }}
- {{- range .Tools }}
- {{- if eq .Function.Name "browser.search" -}}{{- $hasBrowserSearch = true -}}
- {{- else if eq .Function.Name "browser.open" -}}{{- $hasBrowserOpen = true -}}
- {{- else if eq .Function.Name "browser.find" -}}{{- $hasBrowserFind = true -}}
- {{- else if eq .Function.Name "python" -}}{{- $hasPython = true -}}
- {{- else }}{{ $hasNonBuiltinTools = true -}}
- {{- end }}
- {{- end }}
-{{- if or $hasBrowserSearch $hasBrowserOpen $hasBrowserFind $hasPython }}
-
-# Tools
-{{- if or $hasBrowserSearch $hasBrowserOpen $hasBrowserFind }}
-
-## browser
-
-// Tool for browsing.
-// The `cursor` appears in brackets before each browsing display: `[{cursor}]`.
-// Cite information from the tool using the following format:
-// `【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`.
-// Do not quote more than 10 words directly from the tool output.
-// sources=web (default: web)
-namespace browser {
-{{- if $hasBrowserSearch }}
-
-// Searches for information related to `query` and displays `topn` results.
-type search = (_: {
-query: string,
-topn?: number, // default: 10
-source?: string,
-}) => any;
-{{- end }}
-{{- if $hasBrowserOpen }}
-
-// Opens the link `id` from the page indicated by `cursor` starting at line number `loc`, showing `num_lines` lines.
-// Valid link ids are displayed with the formatting: `【{id}†.*】`.
-// If `cursor` is not provided, the most recent page is implied.
-// If `id` is a string, it is treated as a fully qualified URL associated with `source`.
-// If `loc` is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available.
-// Use this function without `id` to scroll to a new location of an opened page.
-type open = (_: {
-id?: number | string, // default: -1
-cursor?: number, // default: -1
-loc?: number, // default: -1
-num_lines?: number, // default: -1
-view_source?: boolean, // default: false
-source?: string,
-}) => any;
-{{- end }}
-{{- if $hasBrowserFind }}
-
-// Finds exact matches of `pattern` in the current page, or the page given by `cursor`.
-type find = (_: {
-pattern: string,
-cursor?: number, // default: -1
-}) => any;
-{{- end }}
-
-} // namespace browser
-{{- end }}{{/* end if has browser tools */}}
-{{- if $hasPython }}
-
-## python
-
-Use this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).
-
-When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 120.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is UNKNOWN. Depends on the cluster.
-{{- end }}{{/* end if hasPython */}}
-{{- end }}{{/* end if has any built-in tools */}}
-{{- end }}{{/* end if .Tools */}}
-
-# Valid channels: analysis, commentary, final. Channel must be included for every message.{{ if $hasNonBuiltinTools }}
-Calls to these tools must go to the commentary channel: 'functions'.
-{{- end -}}<|end|>{{/* end of system */ -}}
-{{- if or $hasNonBuiltinTools .System -}}
-<|start|>developer<|message|>{{- if $hasNonBuiltinTools }}# Tools
-
-## functions
-
-namespace functions {
-{{- range .Tools }}
-{{- if not (or (eq .Function.Name "browser.search") (eq .Function.Name "browser.open") (eq .Function.Name "browser.find") (eq .Function.Name "python")) }}
-{{if .Function.Description }}
-// {{ .Function.Description }}
-{{- end }}
-{{- if and .Function.Parameters.Properties (gt (len .Function.Parameters.Properties) 0) }}
-type {{ .Function.Name }} = (_: {
-{{- range $name, $prop := .Function.Parameters.Properties }}
-{{- if $prop.Description }}
- // {{ $prop.Description }}
-{{- end }}
- {{ $name }}: {{ if gt (len $prop.Type) 1 }}{{ range $i, $t := $prop.Type }}{{ if $i }} | {{ end }}{{ $t }}{{ end }}{{ else }}{{ index $prop.Type 0 }}{{ end }},
-{{- end }}
-}) => any;
-{{- else }}
-type {{ .Function.Name }} = () => any;
-{{- end }}
-{{- end }}{{/* end if not browser tool */}}
-{{- end }}{{/* end of range .Tools */}}
-
-} // namespace functions
-{{- end }}{{/* end if hasNonBuiltinTools */}}
-{{- if .System}}
-
-# Instructions
-
-{{ .System }}
-{{- end -}}
-<|end|>
-{{- end -}}
-{{- /* Find the index of the last user message */ -}}
-{{- $lastUserIdx := -1 }}
-{{- $prefillingContent := false }}
-{{- $prefillingThinkingOnly := false }}
-{{- range $i, $msg := .Messages }}
- {{- $last := eq (len (slice $.Messages $i)) 1 -}}
- {{- if eq $msg.Role "user" }}
- {{- $lastUserIdx = $i }}
- {{- end -}}
- {{- if and $last (eq $msg.Role "assistant") (gt (len $msg.Content) 0) }}
- {{- $prefillingContent = true }}
- {{- else if and $last (eq $msg.Role "assistant") (gt (len $msg.Thinking) 0) }}
- {{- $prefillingThinkingOnly = true }}
- {{- end }}
-{{- end -}}
-{{- /* Now render messages */ -}}
-{{- range $i, $msg := .Messages }}
- {{- $last := eq (len (slice $.Messages $i)) 1 -}}
- {{- if (ne $msg.Role "system") -}}
- {{- if eq $msg.Role "tool" -}}
- {{- if or (eq $msg.ToolName "python") (eq $msg.ToolName "browser.search") (eq $msg.ToolName "browser.open") (eq $msg.ToolName "browser.find") -}}
- <|start|>{{ $msg.ToolName }} to=assistant<|message|>{{ $msg.Content }}<|end|>
- {{- else -}}
- <|start|>functions.{{ $msg.ToolName }} to=assistant<|message|>{{ $msg.Content }}<|end|>
- {{- end -}}
- {{- else if eq $msg.Role "assistant" -}}
- {{- if and $msg.Thinking (gt $i $lastUserIdx) -}}{{- /* Show thinking only after last user message */ -}}
- <|start|>assistant<|channel|>analysis<|message|>{{ $msg.Thinking }}{{- if not $prefillingThinkingOnly -}}<|end|>{{- end -}}
- {{- end -}}
- {{- if gt (len $msg.Content) 0 -}}
- <|start|>assistant<|channel|>final<|message|>{{ $msg.Content }}{{- if not $prefillingContent -}}<|end|>{{- end -}}
- {{- end -}}
- {{- if gt (len $msg.ToolCalls) 0 -}}
- {{- range $j, $toolCall := $msg.ToolCalls -}}
- {{- $isBuiltin := or (eq $toolCall.Function.Name "python") (eq $toolCall.Function.Name "browser.search") (eq $toolCall.Function.Name "browser.open") (eq $toolCall.Function.Name "browser.find") -}}
- <|start|>assistant<|channel|>{{ if $isBuiltin }}analysis{{ else }}commentary{{ end }} to={{ if not $isBuiltin}}functions.{{end}}{{ $toolCall.Function.Name }} <|constrain|>json<|message|>{{ $toolCall.Function.Arguments }}<|call|>
- {{- end -}}
- {{- end -}}
- {{- else if eq $msg.Role "user" -}}
- <|start|>{{ $msg.Role }}<|message|>{{ $msg.Content }}<|end|>
- {{- end }}
- {{- else }}
- {{- end }}
-{{- end -}}
-{{- if not (or $prefillingContent $prefillingThinkingOnly) -}}
-<|start|>assistant
-{{- end -}}"""
-PARAMETER temperature 1.0
-PARAMETER top_k 0
-PARAMETER top_p 1.0
-'''
-
-OLLAMA_TEMPLATES["gpt-oss"] = gptoss_ollama
-OLLAMA_TEMPLATES["gptoss"] = gptoss_ollama
-
-
-# =========================================== Qwen3
-
-# Ollama from https://ollama.com/library/qwen3/blobs/53e4ea15e8f5
-qwen3_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """
-{{- $lastUserIdx := -1 -}}
-{{- range $idx, $msg := .Messages -}}
-{{- if eq $msg.Role "user" }}{{ $lastUserIdx = $idx }}{{ end -}}
-{{- end }}
-{{- if or .System .Tools }}<|im_start|>system
-{{ if .System }}
-{{ .System }}
-{{- end }}
-{{- if .Tools }}
-
-# Tools
-
-You may call one or more functions to assist with the user query.
-
-You are provided with function signatures within XML tags:
-
-{{- range .Tools }}
-{"type": "function", "function": {{ .Function }}}
-{{- end }}
-
-
-For each function call, return a json object with function name and arguments within XML tags:
-
-{"name": , "arguments": }
-
-{{- end -}}
-<|im_end|>
-{{ end }}
-{{- range $i, $_ := .Messages }}
-{{- $last := eq (len (slice $.Messages $i)) 1 -}}
-{{- if eq .Role "user" }}<|im_start|>user
-{{ .Content }}<|im_end|>
-{{ else if eq .Role "assistant" }}<|im_start|>assistant
-{{ if (and $.IsThinkSet (and .Thinking (or $last (gt $i $lastUserIdx)))) -}}
-{{ .Thinking }}
-{{ end -}}
-{{ if .Content }}{{ .Content }}
-{{- else if .ToolCalls }}
-{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
-{{ end }}
-{{- end }}{{ if not $last }}<|im_end|>
-{{ end }}
-{{- else if eq .Role "tool" }}<|im_start|>user
-
-{{ .Content }}
-<|im_end|>
-{{ end }}
-{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
-{{ end }}
-{{- end }}
-"""
-'''
-
-OLLAMA_TEMPLATES["qwen3-instruct"] = qwen3_ollama
-OLLAMA_TEMPLATES["qwen3-thinking"] = qwen3_ollama
-
-
-# =========================================== Starling-LM
-
-
-# Ollama from https://ollama.com/library/starling-lm:7b/blobs/4b21bfc435b4
-starling_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{ if .System }}GPT4 Correct System: {{ .System }}<|end_of_turn|>
-{{ end }}{{ if .Prompt }}GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>
-{{ end }}GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|>"""
-PARAMETER stop "<|end_of_turn|>"
-PARAMETER stop "GPT4 Correct User:"
-PARAMETER stop "GPT4 Correct Assistant:"
-PARAMETER stop "GPT4 Correct System:"
-PARAMETER temperature 1.5
-PARAMETER min_p 0.1
-'''
-
-OLLAMA_TEMPLATES["starling"] = starling_ollama
-
-
-# =========================================== Yi-chat
-
-
-# Ollama from https://ollama.com/library/yi:34b-chat/blobs/62fbfd9ed093
-yi_chat_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{ if .System }}<|im_start|>system
-{{ .System }}<|im_end|>
-{{ end }}{{ if .Prompt }}<|im_start|>user
-{{ .Prompt }}<|im_end|>
-{{ end }}<|im_start|>assistant
-{{ .Response }}<|im_end|>"""
-'''
-
-OLLAMA_TEMPLATES["yi-chat"] = yi_chat_ollama
-
-# =========================================== Granite
-
-# Ollama from https://ollama.com/library/granite3.2:latest/blobs/3e7ca51acd6e
-granite_32_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{- /*
-
------- MESSAGE PARSING ------
-
-*/}}
-{{- /*
-Declare the prompt structure variables to be filled in from messages
-*/}}
-{{- $system := "" }}
-{{- $documents := "" }}
-{{- $documentCounter := 0 }}
-{{- $thinking := false }}
-{{- $citations := false }}
-{{- $hallucinations := false }}
-{{- $length := "" }}
-
-{{- /*
-Loop over messages and look for a user-provided system message and documents
-*/ -}}
-{{- range .Messages }}
-
- {{- /* User defined system prompt(s) */}}
- {{- if (eq .Role "system")}}
- {{- if (ne $system "") }}
- {{- $system = print $system " " }}
- {{- end}}
- {{- $system = print $system .Content }}
- {{- end}}
-
- {{- /*
- NOTE: Since Ollama collates consecutive roles, for control and documents, we
- work around this by allowing the role to contain a qualifier after the
- role string.
- */ -}}
-
- {{- /* Role specified thinking */ -}}
- {{- if (and (ge (len .Role) 7) (eq (slice .Role 0 7) "control")) }}
- {{- if (eq .Content "thinking")}}{{- $thinking = true }}{{- end}}
- {{- if (eq .Content "citations")}}{{- $citations = true }}{{- end}}
- {{- if (eq .Content "hallucinations")}}{{- $hallucinations = true }}{{- end}}
- {{- if (and (ge (len .Content) 7) (eq (slice .Content 0 7) "length "))}}
- {{- $length = print ` {"length": "` (slice .Content 7) `"}` }}
- {{- end}}
- {{- end}}
-
- {{- /* Role specified document */ -}}
- {{- if (and (ge (len .Role) 8) (eq (slice .Role 0 8) "document")) }}
- {{- if (ne $documentCounter 0)}}
- {{- $documents = print $documents " "}}
- {{- end}}
- {{- $identifier := $documentCounter}}
- {{- if (ge (len .Role) 9) }}
- {{- $identifier = (slice .Role 8)}}
- {{- end}}
- {{- $documents = print $documents "Document " $identifier "" .Content}}
- {{- $documentCounter = len (printf "a%*s" $documentCounter "")}}
- {{- end}}
-{{- end}}
-
-{{- /*
-If no user message provided, build the default system message
-*/ -}}
-{{- if eq $system "" }}
- {{- $system = "Knowledge Cutoff Date: April 2024.You are Granite, developed by IBM."}}
-
- {{- /* Add Tools prompt */}}
- {{- if .Tools }}
- {{- $system = print $system " You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user's query, respond with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request." }}
- {{- end}}
-
- {{- /* Add documents prompt */}}
- {{- if $documents }}
- {{- if .Tools }}
- {{- $system = print $system " "}}
- {{- else }}
- {{- $system = print $system " "}}
- {{- end}}
- {{- $system = print $system "Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." }}
- {{- if $citations}}
- {{- $system = print $system " In your response, use the symbols and to indicate when a fact comes from a document in the search result, e.g 0 for a fact from document 0. Afterwards, list all the citations with their corresponding documents in an ordered list."}}
- {{- end}}
- {{- if $hallucinations}}
- {{- $system = print $system "Finally, after the response is written, include a numbered list of sentences from the response that are potentially hallucinated and not based in the documents."}}
- {{- end}}
- {{- end}}
-
- {{- /* Prompt without tools or documents */}}
- {{- if (and (not .Tools) (not $documents)) }}
- {{- $system = print $system " You are a helpful AI assistant."}}
- {{- if $thinking}}
- {{- $system = print $system "Respond to every user query in a comprehensive and detailed way. You can write down your thought process before responding. Write your thoughts after 'Here is my thought process:' and write your response after 'Here is my response:' for each user query."}}
- {{- end}}
- {{- end}}
-
- {{- /* Add thinking prompt if no tools or documents */}}
- {{- if (and $thinking (not .Tools) (not $documents)) }}
- {{- $system = print $system " You are a helpful AI assistant.Respond to every user query in a comprehensive and detailed way. You can write down your thought process before responding. Write your thoughts after 'Here is my thought process:' and write your response after 'Here is my response:' for each user query."}}
- {{- end}}
-
-{{- end}}
-{{- /*
-
------- TEMPLATE EXPANSION ------
-
-*/}}
-{{- /* System Prompt */ -}}
-<|start_of_role|>system<|end_of_role|>{{- $system }}<|end_of_text|>
-
-{{- /* Tools */ -}}
-{{- if .Tools }}
-<|start_of_role|>tools<|end_of_role|>[
-{{- range $index, $_ := .Tools }}
-{{ . }}
-{{- if and (ne (len (slice $.Tools $index)) 1) (gt (len $.Tools) 1) }},
-{{- end}}
-{{- end }}
-]
-{{- end}}
-
-{{- /* Documents */ -}}
-{{- if $documents }}
-<|start_of_role|>documents<|end_of_role|>
-{{ $documents }}<|end_of_text|>
-{{- end}}
-
-{{- /* Standard Messages */}}
-{{- range $index, $_ := .Messages }}
-{{- if (and
- (ne .Role "system")
- (or (lt (len .Role) 7) (ne (slice .Role 0 7) "control"))
- (or (lt (len .Role) 8) (ne (slice .Role 0 8) "document"))
-)}}
-<|start_of_role|>
-{{- if eq .Role "tool" }}tool_response
-{{- else }}{{ .Role }}
-{{- end }}<|end_of_role|>
-{{- if .Content }}{{ .Content }}
-{{- else if .ToolCalls }}<|tool_call|>
-{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
-{{- end }}
-{{- end }}
-{{- if eq (len (slice $.Messages $index)) 1 }}
-{{- if eq .Role "assistant" }}
-{{- else }}<|end_of_text|>
-<|start_of_role|>assistant<|end_of_role|>
-{{- end -}}
-{{- else }}<|end_of_text|>
-{{- end }}
-{{- end }}
-{{- end }}
-"""
-'''
-
-# granite-3.2-vision https://ollama.com/library/granite3.2-vision:latest/blobs/579046ba1157
-granite_32_vision_ollama = '''
-FROM {__FILE_LOCATION__}
-TEMPLATE """{{- /* Tools */ -}}
-{{- if .Tools -}}
-<|start_of_role|>available_tools<|end_of_role|>
-{{- range $index, $_ := .Tools }}
-{{- $last := eq (len (slice $.Tools $index)) 1 }}
-{{ . }}
-{{- if not $last }}
-{{ end}}
-{{- end -}}
-<|end_of_text|>
-{{ end }}
-
-{{- /* System Prompt */ -}}
-{{- if and (gt (len .Messages) 0) (eq (index .Messages 0).Role "system") -}}
-<|system|>
-{{(index .Messages 0).Content}}
-{{- else -}}
-<|system|>
-A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
-{{- end }}
-
-{{- /*Main message loop*/ -}}
-{{- range $index, $_ := .Messages }}
-{{- $last := eq (len (slice $.Messages $index)) 1 }}
-{{- if eq .Role "system" }}
-
-{{- else if eq .Role "user" }}
-<|user|>
-{{.Content}}
-
-{{- else if eq .Role "assistant" }}
-<|assistant|>
-{{- if .Content }}
-{{.Content}}
-<|end_of_text|>
-{{ end }}
-
-{{- else if eq .Role "assistant_tool_call" }}
-<|start_of_role|>assistant<|end_of_role|><|tool_call|>{{.Content}}<|end_of_text|>
-
-{{- else if eq .Role "tool_response" }}
-<|start_of_role|>tool_response<|end_of_role|>{{.Content}}<|end_of_text|>
-{{- end }}
-
-{{- /* Add generation prompt */ -}}
-{{ if $last }}
-{{- if eq .Role "assistant" }}
-{{- else }}
-<|assistant|>
-{{- end }}
-{{- end }}
-{{- end }}"""
-PARAMETER num_ctx 16384
-PARAMETER temperature 0
-SYSTEM """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."""
-'''
-
-OLLAMA_TEMPLATES["granite-32"] = granite_32_ollama
-OLLAMA_TEMPLATES["granite-32-vision"] = granite_32_vision_ollama
-
-
-OLLAMA_TEMPLATE_TO_MODEL_MAPPER = {
- "phi-3.5": (
- "unsloth/Phi-3.5-mini-instruct-bnb-4bit",
- "unsloth/Phi-3.5-mini-instruct",
- "microsoft/Phi-3.5-mini-instruct",
- ),
- "phi-3": (
- "unsloth/Phi-3-mini-4k-instruct-bnb-4bit",
- "unsloth/Phi-3-mini-4k-instruct",
- "microsoft/Phi-3-mini-4k-instruct",
- "unsloth/Phi-3-medium-4k-instruct-bnb-4bit",
- "unsloth/Phi-3-medium-4k-instruct",
- "microsoft/Phi-3-medium-4k-instruct",
- "unsloth/Phi-3-mini-4k-instruct-v0-bnb-4bit",
- "unsloth/Phi-3-mini-4k-instruct-v0",
- ),
- "phi-4": (
- "unsloth/phi-4-unsloth-bnb-4bit",
- "unsloth/phi-4",
- "microsoft/phi-4",
- "unsloth/phi-4-bnb-4bit",
- ),
- "phi-4-reasoning": (
- "unsloth/phi-4-reasoning-unsloth-bnb-4bit",
- "unsloth/phi-4-reasoning",
- "microsoft/Phi-4-reasoning",
- "unsloth/phi-4-reasoning-bnb-4bit",
- "unsloth/phi-4-reasoning-plus-unsloth-bnb-4bit",
- "unsloth/phi-4-reasoning-plus",
- "microsoft/Phi-4-reasoning-plus",
- "unsloth/phi-4-reasoning-plus-bnb-4bit",
- ),
- "phi-4-mini": (
- "unsloth/Phi-4-mini-instruct-unsloth-bnb-4bit",
- "unsloth/Phi-4-mini-instruct",
- "microsoft/Phi-4-mini-instruct",
- "unsloth/Phi-4-mini-instruct-bnb-4bit",
- ),
- "phi-4-mini-reasoning": (
- "unsloth/phi-4-mini-reasoning-unsloth-bnb-4bit",
- "unsloth/phi-4-mini-reasoning",
- "microsoft/Phi-4-mini-reasoning",
- "unsloth/phi-4-mini-reasoning-bnb-4bit",
- ),
- "mistral": (
- "unsloth/mistral-7b-instruct-v0.1-bnb-4bit",
- "unsloth/mistral-7b-instruct-v0.1",
- "mistralai/Mistral-7B-Instruct-v0.1",
- "unsloth/mistral-7b-instruct-v0.2-bnb-4bit",
- "unsloth/mistral-7b-instruct-v0.2",
- "mistralai/Mistral-7B-Instruct-v0.2",
- ),
- "mistral-v03": (
- "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
- "unsloth/mistral-7b-instruct-v0.3",
- "mistralai/Mistral-7B-Instruct-v0.3",
- "unsloth/Mistral-Large-Instruct-2407-bnb-4bit",
- "mistralai/Mistral-Large-Instruct-2407",
- ),
- "mistral-small": (
- "unsloth/Mistral-Small-Instruct-2409-bnb-4bit",
- "unsloth/Mistral-Small-Instruct-2409",
- "mistralai/Mistral-Small-Instruct-2409",
- "unsloth/Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit",
- "unsloth/Mistral-Small-24B-Instruct-2501",
- "mistralai/Mistral-Small-24B-Instruct-2501",
- "unsloth/Mistral-Small-24B-Instruct-2501-bnb-4bit",
- ),
- "mistral-small-31": (
- "unsloth/Mistral-Small-3.1-24B-Instruct-2503-unsloth-bnb-4bit",
- "unsloth/Mistral-Small-3.1-24B-Instruct-2503",
- "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
- "unsloth/Mistral-Small-3.1-24B-Instruct-2503-bnb-4bit",
- ),
- "mistral-small-32": (
- "unsloth/Mistral-Small-3.2-24B-Instruct-2506-unsloth-bnb-4bit",
- "unsloth/Mistral-Small-3.2-24B-Instruct-2506",
- "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
- "unsloth/Mistral-Small-3.2-24B-Instruct-2506-bnb-4bit",
- ),
- "mixtral": (
- "unsloth/Mixtral-8x7B-Instruct-v0.1-unsloth-bnb-4bit",
- "unsloth/Mixtral-8x7B-Instruct-v0.1",
- "mistralai/Mixtral-8x7B-Instruct-v0.1",
- "unsloth/Mixtral-8x7B-Instruct-v0.1-bnb-4bit",
- ),
- "mistral-nemo": (
- "unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit",
- "unsloth/Mistral-Nemo-Instruct-2407",
- "mistralai/Mistral-Nemo-Instruct-2407",
- ),
- "codestral": (
- "mistralai/Codestral-22B-v0.1",
- "mistral-community/Codestral-22B-v0.1",
- ),
- "devstral": (
- "unsloth/Devstral-Small-2505-unsloth-bnb-4bit",
- "unsloth/Devstral-Small-2505",
- "mistralai/Devstral-Small-2505",
- "unsloth/Devstral-Small-2505-bnb-4bit",
- "unsloth/Devstral-Small-2507-unsloth-bnb-4bit",
- "unsloth/Devstral-Small-2507",
- "mistralai/Devstral-Small-2507",
- "unsloth/Devstral-Small-2507-bnb-4bit",
- ),
- "magistral": (
- "unsloth/Magistral-Small-2506-unsloth-bnb-4bit",
- "unsloth/Magistral-Small-2506",
- "mistralai/Magistral-Small-2506",
- "unsloth/Magistral-Small-2506-bnb-4bit",
- "unsloth/Magistral-Small-2507-unsloth-bnb-4bit",
- "unsloth/Magistral-Small-2507",
- "mistralai/Magistral-Small-2507",
- "unsloth/Magistral-Small-2507-bnb-4bit",
- "unsloth/Magistral-Small-2509-unsloth-bnb-4bit",
- "unsloth/Magistral-Small-2509",
- "mistralai/Magistral-Small-2509",
- "unsloth/Magistral-Small-2509-bnb-4bit",
- ),
- "tinyllama": (
- "unsloth/tinyllama-chat-bnb-4bit",
- "unsloth/tinyllama-chat",
- "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
- ),
- "llama": (
- "unsloth/llama-2-7b-bnb-4bit",
- "unsloth/llama-2-7b",
- "meta-llama/Llama-2-7b-hf",
- "unsloth/llama-2-13b-bnb-4bit",
- "unsloth/llama-2-13b",
- "meta-llama/Llama-2-13b-hf",
- "unsloth/llama-2-7b-chat-bnb-4bit",
- "unsloth/llama-2-7b-chat",
- "meta-llama/Llama-2-7b-chat-hf",
- ),
- "llama3": (
- "unsloth/llama-3-8b-Instruct-bnb-4bit",
- "unsloth/llama-3-8b-Instruct",
- "meta-llama/Meta-Llama-3-8B-Instruct",
- "unsloth/llama-3-70b-Instruct-bnb-4bit",
- "meta-llama/Meta-Llama-3-70B-Instruct",
- ),
- "llama-3.1": (
- "unsloth/Meta-Llama-3.1-8B-Instruct-unsloth-bnb-4bit",
- "unsloth/Meta-Llama-3.1-8B-Instruct",
- "meta-llama/Meta-Llama-3.1-8B-Instruct",
- "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
- "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit",
- "unsloth/Llama-3.1-8B-Instruct",
- "meta-llama/Llama-3.1-8B-Instruct",
- "unsloth/Llama-3.1-8B-Instruct-bnb-4bit",
- "unsloth/Meta-Llama-3.1-405B-Instruct-bnb-4bit",
- "meta-llama/Meta-Llama-3.1-405B-Instruct",
- "unsloth/Meta-Llama-3.1-70B-Instruct-bnb-4bit",
- "unsloth/Meta-Llama-3.1-70B-Instruct",
- "meta-llama/Meta-Llama-3.1-70B-Instruct",
- "unsloth/Hermes-3-Llama-3.1-8B-bnb-4bit",
- "unsloth/Hermes-3-Llama-3.1-8B",
- "NousResearch/Hermes-3-Llama-3.1-8B",
- "unsloth/Hermes-3-Llama-3.1-70B-bnb-4bit",
- "unsloth/Hermes-3-Llama-3.1-70B",
- "NousResearch/Hermes-3-Llama-3.1-70B",
- "unsloth/Hermes-3-Llama-3.1-405B-bnb-4bit",
- "NousResearch/Hermes-3-Llama-3.1-405B",
- "unsloth/Llama-3.1-Tulu-3-8B-bnb-4bit",
- "unsloth/Llama-3.1-Tulu-3-8B",
- "allenai/Llama-3.1-Tulu-3-8B",
- "unsloth/Llama-3.1-Tulu-3-70B-bnb-4bit",
- "unsloth/Llama-3.1-Tulu-3-70B",
- "allenai/Llama-3.1-Tulu-3-70B",
- ),
- "llama-31-storm": (
- "unsloth/Llama-3.1-Storm-8B-bnb-4bit",
- "unsloth/Llama-3.1-Storm-8B",
- "akjindal53244/Llama-3.1-Storm-8B",
- ),
- "llama-31-nemotron": (
- "unsloth/Llama-3.1-Nemotron-70B-Instruct-bnb-4bit",
- "unsloth/Llama-3.1-Nemotron-70B-Instruct",
- "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
- ),
- "llama-3.2": (
- "unsloth/Llama-3.2-1B-Instruct-unsloth-bnb-4bit",
- "unsloth/Llama-3.2-1B-Instruct",
- "meta-llama/Llama-3.2-1B-Instruct",
- "unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
- "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit",
- "unsloth/Llama-3.2-3B-Instruct",
- "meta-llama/Llama-3.2-3B-Instruct",
- "unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
- ),
- "llama-32-vision": (
- "unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit",
- "unsloth/Llama-3.2-11B-Vision-Instruct",
- "meta-llama/Llama-3.2-11B-Vision-Instruct",
- "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
- "unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit",
- "unsloth/Llama-3.2-90B-Vision-Instruct",
- "meta-llama/Llama-3.2-90B-Vision-Instruct",
- ),
- "llama-3.3": (
- "unsloth/Llama-3.3-70B-Instruct-bnb-4bit",
- "unsloth/Llama-3.3-70B-Instruct",
- "meta-llama/Llama-3.3-70B-Instruct",
- ),
- "gemma": (
- "unsloth/gemma-7b-it-bnb-4bit",
- "unsloth/gemma-7b-it",
- "google/gemma-7b-it",
- "google/gemma-2b-it",
- "unsloth/gemma-1.1-2b-it-bnb-4bit",
- "unsloth/gemma-1.1-2b-it",
- "google/gemma-1.1-2b-it",
- "unsloth/gemma-1.1-7b-it-bnb-4bit",
- "unsloth/gemma-1.1-7b-it",
- "google/gemma-1.1-7b-it",
- ),
- "gemma2": (
- "unsloth/gemma-2-9b-it-bnb-4bit",
- "unsloth/gemma-2-9b-it",
- "google/gemma-2-9b-it",
- "unsloth/gemma-2-27b-it-bnb-4bit",
- "unsloth/gemma-2-27b-it",
- "google/gemma-2-27b-it",
- "unsloth/gemma-2-2b-it-bnb-4bit",
- "unsloth/gemma-2-2b-it",
- "google/gemma-2-2b-it",
- ),
- "gemma-3": (
- "unsloth/gemma-3-1b-it-unsloth-bnb-4bit",
- "unsloth/gemma-3-1b-it",
- "google/gemma-3-1b-it",
- "unsloth/gemma-3-1b-it-bnb-4bit",
- "unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
- "unsloth/gemma-3-4b-it",
- "google/gemma-3-4b-it",
- "unsloth/gemma-3-4b-it-bnb-4bit",
- "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
- "unsloth/gemma-3-12b-it",
- "google/gemma-3-12b-it",
- "unsloth/gemma-3-12b-it-bnb-4bit",
- "unsloth/gemma-3-27b-it-unsloth-bnb-4bit",
- "unsloth/gemma-3-27b-it",
- "google/gemma-3-27b-it",
- "unsloth/gemma-3-27b-it-bnb-4bit",
- "unsloth/medgemma-4b-it-unsloth-bnb-4bit",
- "unsloth/medgemma-4b-it",
- "google/medgemma-4b-it",
- "unsloth/medgemma-4b-it-bnb-4bit",
- "unsloth/medgemma-27b-text-it-unsloth-bnb-4bit",
- "unsloth/medgemma-27b-text-it",
- "google/medgemma-27b-text-it",
- "unsloth/medgemma-27b-text-it-bnb-4bit",
- ),
- "gemma3n": (
- "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit",
- "unsloth/gemma-3n-E4B-it",
- "google/gemma-3n-E4B-it",
- "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit",
- "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit",
- "unsloth/gemma-3n-E2B-it",
- "google/gemma-3n-E2B-it",
- "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit",
- ),
- "gemma3-270m": (
- "unsloth/gemma-3-270m-it-unsloth-bnb-4bit",
- "unsloth/gemma-3-270m-it",
- "google/gemma-3-270m-it",
- "unsloth/gemma-3-270m-it-bnb-4bit",
- ),
- "qwen-25": (
- "unsloth/Qwen2.5-0.5B-Instruct-unsloth-bnb-4bit",
- "unsloth/Qwen2.5-0.5B-Instruct",
- "Qwen/Qwen2.5-0.5B-Instruct",
- "unsloth/Qwen2.5-0.5B-Instruct-bnb-4bit",
- "unsloth/Qwen2.5-1.5B-Instruct-unsloth-bnb-4bit",
- "unsloth/Qwen2.5-1.5B-Instruct",
- "Qwen/Qwen2.5-1.5B-Instruct",
- "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit",
- "unsloth/Qwen2.5-3B-Instruct-unsloth-bnb-4bit",
- "unsloth/Qwen2.5-3B-Instruct",
- "Qwen/Qwen2.5-3B-Instruct",
- "unsloth/Qwen2.5-3B-Instruct-bnb-4bit",
- "unsloth/Qwen2.5-7B-Instruct-unsloth-bnb-4bit",
- "unsloth/Qwen2.5-7B-Instruct",
- "Qwen/Qwen2.5-7B-Instruct",
- "unsloth/Qwen2.5-7B-Instruct-bnb-4bit",
- "unsloth/Qwen2.5-14B-Instruct-unsloth-bnb-4bit",
- "unsloth/Qwen2.5-14B-Instruct",
- "Qwen/Qwen2.5-14B-Instruct",
- "unsloth/Qwen2.5-14B-Instruct-bnb-4bit",
- "unsloth/Qwen2.5-32B-Instruct-bnb-4bit",
- "unsloth/Qwen2.5-32B-Instruct",
- "Qwen/Qwen2.5-32B-Instruct",
- "unsloth/Qwen2.5-72B-Instruct-bnb-4bit",
- "unsloth/Qwen2.5-72B-Instruct",
- "Qwen/Qwen2.5-72B-Instruct",
- "unsloth/Qwen2.5-Math-1.5B-Instruct-bnb-4bit",
- "unsloth/Qwen2.5-Math-1.5B-Instruct",
- "Qwen/Qwen2.5-Math-1.5B-Instruct",
- "unsloth/Qwen2.5-Math-7B-Instruct-bnb-4bit",
- "unsloth/Qwen2.5-Math-7B-Instruct",
- "Qwen/Qwen2.5-Math-7B-Instruct",
- "unsloth/Qwen2.5-Math-72B-Instruct-bnb-4bit",
- "unsloth/Qwen2.5-Math-72B-Instruct",
- "Qwen/Qwen2.5-Math-72B-Instruct",
- ),
- "qwen-25-coder": (
- "unsloth/Qwen2.5-Coder-0.5B-Instruct-bnb-4bit",
- "unsloth/Qwen2.5-Coder-0.5B-Instruct",
- "Qwen/Qwen2.5-Coder-0.5B-Instruct",
- "unsloth/Qwen2.5-Coder-1.5B-Instruct-bnb-4bit",
- "unsloth/Qwen2.5-Coder-1.5B-Instruct",
- "Qwen/Qwen2.5-Coder-1.5B-Instruct",
- "unsloth/Qwen2.5-Coder-3B-Instruct-bnb-4bit",
- "unsloth/Qwen2.5-Coder-3B-Instruct",
- "Qwen/Qwen2.5-Coder-3B-Instruct",
- "unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit",
- "unsloth/Qwen2.5-Coder-7B-Instruct",
- "Qwen/Qwen2.5-Coder-7B-Instruct",
- "unsloth/Qwen2.5-Coder-14B-Instruct-bnb-4bit",
- "unsloth/Qwen2.5-Coder-14B-Instruct",
- "Qwen/Qwen2.5-Coder-14B-Instruct",
- "unsloth/Qwen2.5-Coder-32B-Instruct-bnb-4bit",
- "unsloth/Qwen2.5-Coder-32B-Instruct",
- "Qwen/Qwen2.5-Coder-32B-Instruct",
- ),
- "qwen-25-vl": (
- "unsloth/Qwen2.5-VL-3B-Instruct-unsloth-bnb-4bit",
- "unsloth/Qwen2.5-VL-3B-Instruct",
- "Qwen/Qwen2.5-VL-3B-Instruct",
- "unsloth/Qwen2.5-VL-3B-Instruct-bnb-4bit",
- "unsloth/Qwen2.5-VL-7B-Instruct-unsloth-bnb-4bit",
- "unsloth/Qwen2.5-VL-7B-Instruct",
- "Qwen/Qwen2.5-VL-7B-Instruct",
- "unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit",
- "unsloth/Qwen2.5-VL-32B-Instruct-unsloth-bnb-4bit",
- "unsloth/Qwen2.5-VL-32B-Instruct",
- "Qwen/Qwen2.5-VL-32B-Instruct",
- "unsloth/Qwen2.5-VL-32B-Instruct-bnb-4bit",
- "unsloth/Qwen2.5-VL-72B-Instruct-unsloth-bnb-4bit",
- "unsloth/Qwen2.5-VL-72B-Instruct",
- "Qwen/Qwen2.5-VL-72B-Instruct",
- "unsloth/Qwen2.5-VL-72B-Instruct-bnb-4bit",
- ),
- "openthinker": (
- "unsloth/OpenThinker-7B-unsloth-bnb-4bit",
- "unsloth/OpenThinker-7B",
- "open-thoughts/OpenThinker-7B",
- "unsloth/OpenThinker-7B-bnb-4bit",
- ),
- "qwen-2": (
- "unsloth/Qwen2-0.5B-Instruct-bnb-4bit",
- "unsloth/Qwen2-0.5B-Instruct",
- "Qwen/Qwen2-0.5B-Instruct",
- "unsloth/Qwen2-1.5B-Instruct-bnb-4bit",
- "unsloth/Qwen2-1.5B-Instruct",
- "Qwen/Qwen2-1.5B-Instruct",
- "unsloth/Qwen2-7B-Instruct-bnb-4bit",
- "unsloth/Qwen2-7B-Instruct",
- "Qwen/Qwen2-7B-Instruct",
- "unsloth/Qwen2-70B-Instruct-bnb-4bit",
- "Qwen/Qwen2-70B-Instruct",
- ),
- "qwen3": (
- "unsloth/Qwen3-0.6B-unsloth-bnb-4bit",
- "unsloth/Qwen3-0.6B",
- "Qwen/Qwen3-0.6B",
- "unsloth/Qwen3-0.6B-bnb-4bit",
- "unsloth/Qwen3-1.7B-unsloth-bnb-4bit",
- "unsloth/Qwen3-1.7B",
- "Qwen/Qwen3-1.7B",
- "unsloth/Qwen3-1.7B-bnb-4bit",
- "unsloth/Qwen3-4B-unsloth-bnb-4bit",
- "unsloth/Qwen3-4B",
- "Qwen/Qwen3-4B",
- "unsloth/Qwen3-4B-bnb-4bit",
- "unsloth/Qwen3-8B-unsloth-bnb-4bit",
- "unsloth/Qwen3-8B",
- "Qwen/Qwen3-8B",
- "unsloth/Qwen3-8B-bnb-4bit",
- "unsloth/Qwen3-14B-unsloth-bnb-4bit",
- "unsloth/Qwen3-14B",
- "Qwen/Qwen3-14B",
- "unsloth/Qwen3-14B-bnb-4bit",
- "unsloth/Qwen3-32B-unsloth-bnb-4bit",
- "unsloth/Qwen3-32B",
- "Qwen/Qwen3-32B",
- "unsloth/Qwen3-32B-bnb-4bit",
- "unsloth/Qwen3-30B-A3B-unsloth-bnb-4bit",
- "unsloth/Qwen3-30B-A3B",
- "Qwen/Qwen3-30B-A3B",
- "unsloth/Qwen3-30B-A3B-bnb-4bit",
- ),
- "qwen3-instruct": (
- "unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit",
- "unsloth/Qwen3-4B-Instruct-2507",
- "Qwen/Qwen3-4B-Instruct-2507",
- "unsloth/Qwen3-4B-Instruct-2507-bnb-4bit",
- "unsloth/Qwen3-30B-A3B-Instruct-2507",
- "Qwen/Qwen3-30B-A3B-Instruct-2507",
- "unsloth/Qwen3-Coder-30B-A3B-Instruct",
- "Qwen/Qwen3-Coder-30B-A3B-Instruct",
- "unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit",
- "unsloth/Qwen3-4B-Instruct-2507",
- "Qwen/Qwen3-4B-Instruct-2507",
- "unsloth/Qwen3-4B-Instruct-2507-bnb-4bit",
- ),
- "qwen3-thinking": (
- "unsloth/QwQ-32B-Preview-bnb-4bit",
- "unsloth/QwQ-32B-Preview",
- "Qwen/QwQ-32B-Preview",
- "unsloth/QwQ-32B-unsloth-bnb-4bit",
- "unsloth/QwQ-32B",
- "Qwen/QwQ-32B",
- "unsloth/QwQ-32B-bnb-4bit",
- "unsloth/Qwen3-4B-Thinking-2507-unsloth-bnb-4bit",
- "unsloth/Qwen3-4B-Thinking-2507",
- "Qwen/Qwen3-4B-Thinking-2507",
- "unsloth/Qwen3-4B-Thinking-2507-bnb-4bit",
- "unsloth/Qwen3-30B-A3B-Thinking-2507",
- "Qwen/Qwen3-30B-A3B-Thinking-2507",
- ),
- "zephyr": (
- "unsloth/zephyr-sft-bnb-4bit",
- "unsloth/zephyr-sft",
- "HuggingFaceH4/mistral-7b-sft-beta",
- ),
- "chatml": (
- "unsloth/Hermes-2-Pro-Mistral-7B-bnb-4bit",
- "unsloth/Hermes-2-Pro-Mistral-7B",
- "NousResearch/Hermes-2-Pro-Mistral-7B",
- "unsloth/OpenHermes-2.5-Mistral-7B-bnb-4bit",
- "unsloth/OpenHermes-2.5-Mistral-7B",
- "teknium/OpenHermes-2.5-Mistral-7B",
- ),
- "gpt-oss": (
- "unsloth/gpt-oss-20b-unsloth-bnb-4bit",
- "unsloth/gpt-oss-20b",
- "openai/gpt-oss-20b",
- "unsloth/gpt-oss-20b-unsloth-bnb-4bit",
- "unsloth/gpt-oss-120b-unsloth-bnb-4bit",
- "unsloth/gpt-oss-120b",
- "openai/gpt-oss-120b",
- "unsloth/gpt-oss-120b-unsloth-bnb-4bit",
- ),
- "starling": (
- "unsloth/Starling-LM-7B-beta-bnb-4bit",
- "unsloth/Starling-LM-7B-beta",
- "Nexusflow/Starling-LM-7B-beta",
- ),
- "yi-chat": (
- "unsloth/yi-34b-chat-bnb-4bit",
- "01-ai/Yi-6B-Chat",
- "01-ai/Yi-34B-Chat",
- ),
- "granite-32": (
- "unsloth/granite-3.2-2b-instruct-unsloth-bnb-4bit",
- "unsloth/granite-3.2-2b-instruct",
- "ibm-granite/granite-3.2-2b-instruct",
- "unsloth/granite-3.2-2b-instruct-bnb-4bit",
- "unsloth/granite-3.2-8b-instruct-unsloth-bnb-4bit",
- "unsloth/granite-3.2-8b-instruct",
- "ibm-granite/granite-3.2-8b-instruct",
- "unsloth/granite-3.2-8b-instruct-bnb-4bit",
- ),
- "granite-32-vision": (
- "unsloth/granite-vision-3.2-2b-unsloth-bnb-4bit",
- "unsloth/granite-vision-3.2-2b",
- "ibm-granite/granite-vision-3.2-2b",
- "unsloth/granite-vision-3.2-2b-bnb-4bit",
- ),
-}
-
-MODEL_TO_OLLAMA_TEMPLATE_MAPPER = {}
-
-for key, values in OLLAMA_TEMPLATE_TO_MODEL_MAPPER.items():
- for value in values:
- MODEL_TO_OLLAMA_TEMPLATE_MAPPER[value] = key
-
- # Get lowercased
- lowered_key = key.lower()
- for value in values:
- MODEL_TO_OLLAMA_TEMPLATE_MAPPER[value.lower()] = lowered_key
diff --git a/unsloth/registry/REGISTRY.md b/unsloth/registry/REGISTRY.md
deleted file mode 100644
index a0b3d96cad..0000000000
--- a/unsloth/registry/REGISTRY.md
+++ /dev/null
@@ -1,110 +0,0 @@
-## Model Registry
-
-### Structure
-```
-unsloth
- -registry
- __init__.py
- registry.py
- _llama.py
- _mistral.py
- _phi.py
- ...
-```
-
-Each model is registered in a separate file within the `registry` module (e.g. `registry/_llama.py`).
-
-Within each model registration file, a high-level `ModelMeta` is created for each model version, with the following structure:
-```python
-@dataclass
-class ModelMeta:
- org: str
- base_name: str
- model_version: str
- model_info_cls: type[ModelInfo]
- model_sizes: list[str] = field(default_factory=list)
- instruct_tags: list[str] = field(default_factory=list)
- quant_types: list[QuantType] | dict[str, list[QuantType]] = field(default_factory=list)
- is_multimodal: bool = False
-```
-
-Each model then instantiates a global `ModelMeta` for its specific model version, defining how the model path (e.g. `unsloth/Llama-3.1-8B-Instruct`) is constructed since each model type has a different naming convention.
-```python
-LlamaMeta_3_1 = ModelMeta(
- org="meta-llama",
- base_name="Llama",
- instruct_tags=[None, "Instruct"],
- model_version="3.1",
- model_sizes=["8"],
- model_info_cls=LlamaModelInfo,
- is_multimodal=False,
- quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
-)
-```
-
-`LlamaModelInfo` is a subclass of `ModelInfo` that defines the model path for each model size and quant type.
-```python
-class LlamaModelInfo(ModelInfo):
- @classmethod
- def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
- key = f"{base_name}-{version}-{size}B"
- return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key)
-```
-
-Once these constructs are defined, the model is registered by writing a register_xx_models function.
-```python
-def register_llama_3_1_models(include_original_model: bool = False):
- global _IS_LLAMA_3_1_REGISTERED
- if _IS_LLAMA_3_1_REGISTERED:
- return
- _register_models(LlamaMeta_3_1, include_original_model=include_original_model)
- _IS_LLAMA_3_1_REGISTERED = True
-```
-
-`_register_models` is a helper function that registers the model with the registry. The global `_IS_XX_REGISTERED` is used to prevent duplicate registration.
-
-Once a model is registered, registry.registry.MODEL_REGISTRY is updated with the model info and can be searched with `registry.search_models`.
-
-### Tests
-
-The `tests/test_model_registry.py` file contains tests for the model registry.
-
-Also, each model registration file is an executable module that checks that all registered models are available on `huggingface_hub`.
-```python
-python unsloth.registry._llama.py
-```
-
-Prints the following (abridged) output:
-```bash
-✓ unsloth/Llama-3.1-8B
-✓ unsloth/Llama-3.1-8B-bnb-4bit
-✓ unsloth/Llama-3.1-8B-unsloth-bnb-4bit
-✓ meta-llama/Llama-3.1-8B
-✓ unsloth/Llama-3.1-8B-Instruct
-✓ unsloth/Llama-3.1-8B-Instruct-bnb-4bit
-✓ unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit
-✓ meta-llama/Llama-3.1-8B-Instruct
-✓ unsloth/Llama-3.2-1B
-✓ unsloth/Llama-3.2-1B-bnb-4bit
-✓ unsloth/Llama-3.2-1B-unsloth-bnb-4bit
-✓ meta-llama/Llama-3.2-1B
-...
-```
-
-### TODO
-- Model Collections
- - [x] Gemma3
- - [ ] Llama3.1
- - [x] Llama3.2
- - [x] MistralSmall
- - [x] Qwen2.5
- - [x] Qwen2.5-VL
- - [ ] Qwen2.5 Coder
- - [x] QwenQwQ-32B
- - [x] Deepseek v3
- - [x] Deepseek R1
- - [x] Phi-4
- - [ ] Unsloth 4-bit Dynamic Quants
- - [ ] Vision/multimodal models
-- Sync model uploads with registry
-- Add utility methods for tracking model stats
\ No newline at end of file
diff --git a/unsloth/registry/__init__.py b/unsloth/registry/__init__.py
deleted file mode 100644
index 52e6f52431..0000000000
--- a/unsloth/registry/__init__.py
+++ /dev/null
@@ -1,78 +0,0 @@
-from ._deepseek import register_deepseek_models as _register_deepseek_models
-from ._gemma import register_gemma_models as _register_gemma_models
-from ._llama import register_llama_models as _register_llama_models
-from ._mistral import register_mistral_models as _register_mistral_models
-from ._phi import register_phi_models as _register_phi_models
-from ._qwen import register_qwen_models as _register_qwen_models
-from .registry import MODEL_REGISTRY, ModelInfo, QuantType
-
-_ARE_MODELS_REGISTERED = False
-
-
-def register_models():
- global _ARE_MODELS_REGISTERED
-
- if _ARE_MODELS_REGISTERED:
- return
- _register_deepseek_models()
- _register_gemma_models()
- _register_llama_models()
- _register_mistral_models()
- _register_phi_models()
- _register_qwen_models()
-
- _ARE_MODELS_REGISTERED = True
-
-
-def search_models(
- org: str = None,
- base_name: str = None,
- version: str = None,
- size: str = None,
- quant_types: list[QuantType] = None,
- search_pattern: str = None,
-) -> list[ModelInfo]:
- """
- Get model info from the registry.
-
- See registry.ModelInfo for more fields.
-
- If search_pattern is provided, the full model path will be matched against the pattern, where the model path is the model_id on huggingface hub.
-
- """
- if not _ARE_MODELS_REGISTERED:
- register_models()
-
- model_infos = MODEL_REGISTRY.values()
- if org:
- model_infos = [
- model_info for model_info in model_infos if model_info.org == org
- ]
- if base_name:
- model_infos = [
- model_info
- for model_info in model_infos
- if model_info.base_name == base_name
- ]
- if version:
- model_infos = [
- model_info for model_info in model_infos if model_info.version == version
- ]
- if size:
- model_infos = [
- model_info for model_info in model_infos if model_info.size == size
- ]
- if quant_types:
- model_infos = [
- model_info
- for model_info in model_infos
- if any(model_info.quant_type == quant_type for quant_type in quant_types)
- ]
- if search_pattern:
- model_infos = [
- model_info
- for model_info in model_infos
- if search_pattern in model_info.model_path
- ]
-
- return model_infos
diff --git a/unsloth/registry/_deepseek.py b/unsloth/registry/_deepseek.py
deleted file mode 100644
index 4bbc852cd7..0000000000
--- a/unsloth/registry/_deepseek.py
+++ /dev/null
@@ -1,206 +0,0 @@
-from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models
-
-_IS_DEEPSEEK_V3_REGISTERED = False
-_IS_DEEPSEEK_V3_0324_REGISTERED = False
-_IS_DEEPSEEK_R1_REGISTERED = False
-_IS_DEEPSEEK_R1_ZERO_REGISTERED = False
-_IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED = False
-_IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED = False
-
-
-class DeepseekV3ModelInfo(ModelInfo):
- @classmethod
- def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
- key = f"{base_name}-V{version}"
- return super().construct_model_name(
- base_name, version, size, quant_type, instruct_tag, key
- )
-
-
-class DeepseekR1ModelInfo(ModelInfo):
- @classmethod
- def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
- key = f"{base_name}-{version}" if version else base_name
- if size:
- key = f"{key}-{size}B"
- return super().construct_model_name(
- base_name, version, size, quant_type, instruct_tag, key
- )
-
-
-# Deepseek V3 Model Meta
-DeepseekV3Meta = ModelMeta(
- org = "deepseek-ai",
- base_name = "DeepSeek",
- instruct_tags = [None],
- model_version = "3",
- model_sizes = [""],
- model_info_cls = DeepseekV3ModelInfo,
- is_multimodal = False,
- quant_types = [QuantType.NONE, QuantType.BF16],
-)
-
-DeepseekV3_0324Meta = ModelMeta(
- org = "deepseek-ai",
- base_name = "DeepSeek",
- instruct_tags = [None],
- model_version = "3-0324",
- model_sizes = [""],
- model_info_cls = DeepseekV3ModelInfo,
- is_multimodal = False,
- quant_types = [QuantType.NONE, QuantType.GGUF],
-)
-
-DeepseekR1Meta = ModelMeta(
- org = "deepseek-ai",
- base_name = "DeepSeek-R1",
- instruct_tags = [None],
- model_version = "",
- model_sizes = [""],
- model_info_cls = DeepseekR1ModelInfo,
- is_multimodal = False,
- quant_types = [QuantType.NONE, QuantType.BF16, QuantType.GGUF],
-)
-
-DeepseekR1ZeroMeta = ModelMeta(
- org = "deepseek-ai",
- base_name = "DeepSeek-R1",
- instruct_tags = [None],
- model_version = "Zero",
- model_sizes = [""],
- model_info_cls = DeepseekR1ModelInfo,
- is_multimodal = False,
- quant_types = [QuantType.NONE, QuantType.GGUF],
-)
-
-DeepseekR1DistillLlamaMeta = ModelMeta(
- org = "deepseek-ai",
- base_name = "DeepSeek-R1-Distill",
- instruct_tags = [None],
- model_version = "Llama",
- model_sizes = ["8", "70"],
- model_info_cls = DeepseekR1ModelInfo,
- is_multimodal = False,
- quant_types = {"8": [QuantType.UNSLOTH, QuantType.GGUF], "70": [QuantType.GGUF]},
-)
-
-# Deepseek R1 Distill Qwen Model Meta
-DeepseekR1DistillQwenMeta = ModelMeta(
- org = "deepseek-ai",
- base_name = "DeepSeek-R1-Distill",
- instruct_tags = [None],
- model_version = "Qwen",
- model_sizes = ["1.5", "7", "14", "32"],
- model_info_cls = DeepseekR1ModelInfo,
- is_multimodal = False,
- quant_types = {
- "1.5": [QuantType.UNSLOTH, QuantType.BNB, QuantType.GGUF],
- "7": [QuantType.UNSLOTH, QuantType.BNB],
- "14": [QuantType.UNSLOTH, QuantType.BNB, QuantType.GGUF],
- "32": [QuantType.GGUF, QuantType.BNB],
- },
-)
-
-
-def register_deepseek_v3_models(include_original_model: bool = False):
- global _IS_DEEPSEEK_V3_REGISTERED
- if _IS_DEEPSEEK_V3_REGISTERED:
- return
- _register_models(DeepseekV3Meta, include_original_model = include_original_model)
- _IS_DEEPSEEK_V3_REGISTERED = True
-
-
-def register_deepseek_v3_0324_models(include_original_model: bool = False):
- global _IS_DEEPSEEK_V3_0324_REGISTERED
- if _IS_DEEPSEEK_V3_0324_REGISTERED:
- return
- _register_models(DeepseekV3_0324Meta, include_original_model = include_original_model)
- _IS_DEEPSEEK_V3_0324_REGISTERED = True
-
-
-def register_deepseek_r1_models(include_original_model: bool = False):
- global _IS_DEEPSEEK_R1_REGISTERED
- if _IS_DEEPSEEK_R1_REGISTERED:
- return
- _register_models(DeepseekR1Meta, include_original_model = include_original_model)
- _IS_DEEPSEEK_R1_REGISTERED = True
-
-
-def register_deepseek_r1_zero_models(include_original_model: bool = False):
- global _IS_DEEPSEEK_R1_ZERO_REGISTERED
- if _IS_DEEPSEEK_R1_ZERO_REGISTERED:
- return
- _register_models(DeepseekR1ZeroMeta, include_original_model = include_original_model)
- _IS_DEEPSEEK_R1_ZERO_REGISTERED = True
-
-
-def register_deepseek_r1_distill_llama_models(include_original_model: bool = False):
- global _IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED
- if _IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED:
- return
- _register_models(
- DeepseekR1DistillLlamaMeta, include_original_model = include_original_model
- )
- _IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED = True
-
-
-def register_deepseek_r1_distill_qwen_models(include_original_model: bool = False):
- global _IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED
- if _IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED:
- return
- _register_models(
- DeepseekR1DistillQwenMeta, include_original_model = include_original_model
- )
- _IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED = True
-
-
-def register_deepseek_models(include_original_model: bool = False):
- register_deepseek_v3_models(include_original_model = include_original_model)
- register_deepseek_v3_0324_models(include_original_model = include_original_model)
- register_deepseek_r1_models(include_original_model = include_original_model)
- register_deepseek_r1_zero_models(include_original_model = include_original_model)
- register_deepseek_r1_distill_llama_models(
- include_original_model = include_original_model
- )
- register_deepseek_r1_distill_qwen_models(
- include_original_model = include_original_model
- )
-
-
-def _list_deepseek_r1_distill_models():
- from unsloth.utils.hf_hub import ModelInfo as HfModelInfo
- from unsloth.utils.hf_hub import list_models
-
- models: list[HfModelInfo] = list_models(
- author = "unsloth", search = "Distill", limit = 1000
- )
- distill_models = []
- for model in models:
- model_id = model.id
- model_name = model_id.split("/")[-1]
- # parse out only the version
- version = model_name.removeprefix("DeepSeek-R1-Distill-")
- distill_models.append(version)
-
- return distill_models
-
-
-register_deepseek_models(include_original_model = True)
-
-if __name__ == "__main__":
- from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info
-
- MODEL_REGISTRY.clear()
-
- register_deepseek_models(include_original_model = True)
-
- for model_id, model_info in MODEL_REGISTRY.items():
- model_info = _check_model_info(model_id)
- if model_info is None:
- print(f"\u2718 {model_id}")
- else:
- print(f"\u2713 {model_id}")
- # distill_models = _list_deepseek_r1_distill_models()
- # for model in sorted(distill_models):
- # if "qwen" in model.lower():
- # print(model)
diff --git a/unsloth/registry/_gemma.py b/unsloth/registry/_gemma.py
deleted file mode 100644
index c338128bc6..0000000000
--- a/unsloth/registry/_gemma.py
+++ /dev/null
@@ -1,74 +0,0 @@
-from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models
-
-_IS_GEMMA_3_BASE_REGISTERED = False
-_IS_GEMMA_3_INSTRUCT_REGISTERED = False
-
-
-class GemmaModelInfo(ModelInfo):
- @classmethod
- def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
- key = f"{base_name}-{version}-{size}B"
- return super().construct_model_name(
- base_name, version, size, quant_type, instruct_tag, key
- )
-
-
-# Gemma3 Base Model Meta
-GemmaMeta3Base = ModelMeta(
- org = "google",
- base_name = "gemma",
- instruct_tags = ["pt"], # pt = base
- model_version = "3",
- model_sizes = ["1", "4", "12", "27"],
- model_info_cls = GemmaModelInfo,
- is_multimodal = True,
- quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
-)
-
-# Gemma3 Instruct Model Meta
-GemmaMeta3Instruct = ModelMeta(
- org = "google",
- base_name = "gemma",
- instruct_tags = ["it"], # it = instruction tuned
- model_version = "3",
- model_sizes = ["1", "4", "12", "27"],
- model_info_cls = GemmaModelInfo,
- is_multimodal = True,
- quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
-)
-
-
-def register_gemma_3_base_models(include_original_model: bool = False):
- global _IS_GEMMA_3_BASE_REGISTERED
- if _IS_GEMMA_3_BASE_REGISTERED:
- return
- _register_models(GemmaMeta3Base, include_original_model = include_original_model)
- _IS_GEMMA_3_BASE_REGISTERED = True
-
-
-def register_gemma_3_instruct_models(include_original_model: bool = False):
- global _IS_GEMMA_3_INSTRUCT_REGISTERED
- if _IS_GEMMA_3_INSTRUCT_REGISTERED:
- return
- _register_models(GemmaMeta3Instruct, include_original_model = include_original_model)
- _IS_GEMMA_3_INSTRUCT_REGISTERED = True
-
-
-def register_gemma_models(include_original_model: bool = False):
- register_gemma_3_base_models(include_original_model = include_original_model)
- register_gemma_3_instruct_models(include_original_model = include_original_model)
-
-
-if __name__ == "__main__":
- from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info
-
- MODEL_REGISTRY.clear()
-
- register_gemma_models(include_original_model = True)
-
- for model_id, model_info in MODEL_REGISTRY.items():
- model_info = _check_model_info(model_id)
- if model_info is None:
- print(f"\u2718 {model_id}")
- else:
- print(f"\u2713 {model_id}")
diff --git a/unsloth/registry/_llama.py b/unsloth/registry/_llama.py
deleted file mode 100644
index f5f82372bf..0000000000
--- a/unsloth/registry/_llama.py
+++ /dev/null
@@ -1,125 +0,0 @@
-from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models
-
-_IS_LLAMA_3_1_REGISTERED = False
-_IS_LLAMA_3_2_REGISTERED = False
-_IS_LLAMA_3_2_VISION_REGISTERED = False
-
-
-class LlamaModelInfo(ModelInfo):
- @classmethod
- def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
- key = f"{base_name}-{version}-{size}B"
- return super().construct_model_name(
- base_name, version, size, quant_type, instruct_tag, key
- )
-
-
-class LlamaVisionModelInfo(ModelInfo):
- @classmethod
- def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
- key = f"{base_name}-{version}-{size}B-Vision"
- return super().construct_model_name(
- base_name, version, size, quant_type, instruct_tag, key
- )
-
-
-# Llama 3.1
-LlamaMeta_3_1 = ModelMeta(
- org = "meta-llama",
- base_name = "Llama",
- instruct_tags = [None, "Instruct"],
- model_version = "3.1",
- model_sizes = ["8"],
- model_info_cls = LlamaModelInfo,
- is_multimodal = False,
- quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
-)
-
-# Llama 3.2 Base Models
-LlamaMeta_3_2_Base = ModelMeta(
- org = "meta-llama",
- base_name = "Llama",
- instruct_tags = [None],
- model_version = "3.2",
- model_sizes = ["1", "3"],
- model_info_cls = LlamaModelInfo,
- is_multimodal = False,
- quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
-)
-
-# Llama 3.2 Instruction Tuned Models
-LlamaMeta_3_2_Instruct = ModelMeta(
- org = "meta-llama",
- base_name = "Llama",
- instruct_tags = ["Instruct"],
- model_version = "3.2",
- model_sizes = ["1", "3"],
- model_info_cls = LlamaModelInfo,
- is_multimodal = False,
- quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
-)
-
-# Llama 3.2 Vision
-LlamaMeta_3_2_Vision = ModelMeta(
- org = "meta-llama",
- base_name = "Llama",
- instruct_tags = [None, "Instruct"],
- model_version = "3.2",
- model_sizes = ["11", "90"],
- model_info_cls = LlamaVisionModelInfo,
- is_multimodal = True,
- quant_types = {
- "11": [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
- "90": [QuantType.NONE],
- },
-)
-
-
-def register_llama_3_1_models(include_original_model: bool = False):
- global _IS_LLAMA_3_1_REGISTERED
- if _IS_LLAMA_3_1_REGISTERED:
- return
- _register_models(LlamaMeta_3_1, include_original_model = include_original_model)
- _IS_LLAMA_3_1_REGISTERED = True
-
-
-def register_llama_3_2_models(include_original_model: bool = False):
- global _IS_LLAMA_3_2_REGISTERED
- if _IS_LLAMA_3_2_REGISTERED:
- return
- _register_models(LlamaMeta_3_2_Base, include_original_model = include_original_model)
- _register_models(
- LlamaMeta_3_2_Instruct, include_original_model = include_original_model
- )
- _IS_LLAMA_3_2_REGISTERED = True
-
-
-def register_llama_3_2_vision_models(include_original_model: bool = False):
- global _IS_LLAMA_3_2_VISION_REGISTERED
- if _IS_LLAMA_3_2_VISION_REGISTERED:
- return
- _register_models(
- LlamaMeta_3_2_Vision, include_original_model = include_original_model
- )
- _IS_LLAMA_3_2_VISION_REGISTERED = True
-
-
-def register_llama_models(include_original_model: bool = False):
- register_llama_3_1_models(include_original_model = include_original_model)
- register_llama_3_2_models(include_original_model = include_original_model)
- register_llama_3_2_vision_models(include_original_model = include_original_model)
-
-
-if __name__ == "__main__":
- from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info
-
- MODEL_REGISTRY.clear()
-
- register_llama_models(include_original_model = True)
-
- for model_id, model_info in MODEL_REGISTRY.items():
- model_info = _check_model_info(model_id)
- if model_info is None:
- print(f"\u2718 {model_id}")
- else:
- print(f"\u2713 {model_id}")
diff --git a/unsloth/registry/_mistral.py b/unsloth/registry/_mistral.py
deleted file mode 100644
index 173d6cfdef..0000000000
--- a/unsloth/registry/_mistral.py
+++ /dev/null
@@ -1,88 +0,0 @@
-import copy
-
-from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models
-
-_IS_MISTRAL_SMALL_REGISTERED = False
-
-_MISTRAL_SMALL_03_25_VERSION = "2503"
-_MISTRAL_SMALL_01_25_VERSION = "2501"
-_MISTRAL_SMALL_09_24_VERSION = "2409" # Not uploaded to unsloth
-
-
-class MistralSmallModelInfo(ModelInfo):
- @classmethod
- def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
- if version == _MISTRAL_SMALL_03_25_VERSION:
- key = f"{base_name}-3.1-{size}B-{instruct_tag}"
- else:
- key = f"{base_name}-{size}B-{instruct_tag}"
- key += f"-{version}"
- key = cls.append_quant_type(key, quant_type)
-
- return key
-
-
-MistralSmall_2503_Base_Meta = ModelMeta(
- org = "mistralai",
- base_name = "Mistral-Small",
- instruct_tags = ["Base"],
- model_version = _MISTRAL_SMALL_03_25_VERSION,
- model_sizes = ["24"],
- model_info_cls = MistralSmallModelInfo,
- is_multimodal = False,
- quant_types = [QuantType.NONE, QuantType.UNSLOTH, QuantType.BNB],
-)
-
-MistralSmall_2503_Instruct_Meta = copy.deepcopy(MistralSmall_2503_Base_Meta)
-MistralSmall_2503_Instruct_Meta.instruct_tags = ["Instruct"]
-MistralSmall_2503_Instruct_Meta.quant_types = [
- QuantType.NONE,
- QuantType.UNSLOTH,
- QuantType.BNB,
- QuantType.GGUF,
-]
-
-MistralSmall_2501_Base_Meta = copy.deepcopy(MistralSmall_2503_Base_Meta)
-MistralSmall_2501_Base_Meta.model_version = _MISTRAL_SMALL_01_25_VERSION
-
-MistralSmall_2501_Instruct_Meta = copy.deepcopy(MistralSmall_2503_Instruct_Meta)
-MistralSmall_2501_Instruct_Meta.model_version = _MISTRAL_SMALL_01_25_VERSION
-
-
-def register_mistral_small_models(include_original_model: bool = False):
- global _IS_MISTRAL_SMALL_REGISTERED
- if _IS_MISTRAL_SMALL_REGISTERED:
- return
- _register_models(
- MistralSmall_2503_Base_Meta, include_original_model = include_original_model
- )
- _register_models(
- MistralSmall_2503_Instruct_Meta, include_original_model = include_original_model
- )
- _register_models(
- MistralSmall_2501_Base_Meta, include_original_model = include_original_model
- )
- _register_models(
- MistralSmall_2501_Instruct_Meta, include_original_model = include_original_model
- )
-
- _IS_MISTRAL_SMALL_REGISTERED = True
-
-
-def register_mistral_models(include_original_model: bool = False):
- register_mistral_small_models(include_original_model = include_original_model)
-
-
-if __name__ == "__main__":
- from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info
-
- MODEL_REGISTRY.clear()
-
- register_mistral_models(include_original_model = True)
-
- for model_id, model_info in MODEL_REGISTRY.items():
- model_info = _check_model_info(model_id)
- if model_info is None:
- print(f"\u2718 {model_id}")
- else:
- print(f"\u2713 {model_id}")
diff --git a/unsloth/registry/_phi.py b/unsloth/registry/_phi.py
deleted file mode 100644
index a6f773c48e..0000000000
--- a/unsloth/registry/_phi.py
+++ /dev/null
@@ -1,74 +0,0 @@
-from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models
-
-_IS_PHI_4_REGISTERED = False
-_IS_PHI_4_INSTRUCT_REGISTERED = False
-
-
-class PhiModelInfo(ModelInfo):
- @classmethod
- def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
- key = f"{base_name}-{version}"
- return super().construct_model_name(
- base_name, version, size, quant_type, instruct_tag, key
- )
-
-
-# Phi Model Meta
-PhiMeta4 = ModelMeta(
- org = "microsoft",
- base_name = "phi",
- instruct_tags = [None],
- model_version = "4",
- model_sizes = ["1"], # Assuming only one size
- model_info_cls = PhiModelInfo,
- is_multimodal = False,
- quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
-)
-
-# Phi Instruct Model Meta
-PhiInstructMeta4 = ModelMeta(
- org = "microsoft",
- base_name = "phi",
- instruct_tags = ["mini-instruct"],
- model_version = "4",
- model_sizes = ["1"], # Assuming only one size
- model_info_cls = PhiModelInfo,
- is_multimodal = False,
- quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
-)
-
-
-def register_phi_4_models(include_original_model: bool = False):
- global _IS_PHI_4_REGISTERED
- if _IS_PHI_4_REGISTERED:
- return
- _register_models(PhiMeta4, include_original_model = include_original_model)
- _IS_PHI_4_REGISTERED = True
-
-
-def register_phi_4_instruct_models(include_original_model: bool = False):
- global _IS_PHI_4_INSTRUCT_REGISTERED
- if _IS_PHI_4_INSTRUCT_REGISTERED:
- return
- _register_models(PhiInstructMeta4, include_original_model = include_original_model)
- _IS_PHI_4_INSTRUCT_REGISTERED = True
-
-
-def register_phi_models(include_original_model: bool = False):
- register_phi_4_models(include_original_model = include_original_model)
- register_phi_4_instruct_models(include_original_model = include_original_model)
-
-
-if __name__ == "__main__":
- from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info
-
- MODEL_REGISTRY.clear()
-
- register_phi_models(include_original_model = True)
-
- for model_id, model_info in MODEL_REGISTRY.items():
- model_info = _check_model_info(model_id)
- if model_info is None:
- print(f"\u2718 {model_id}")
- else:
- print(f"\u2713 {model_id}")
diff --git a/unsloth/registry/_qwen.py b/unsloth/registry/_qwen.py
deleted file mode 100644
index f852cb8762..0000000000
--- a/unsloth/registry/_qwen.py
+++ /dev/null
@@ -1,136 +0,0 @@
-from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models
-
-_IS_QWEN_2_5_REGISTERED = False
-_IS_QWEN_2_5_VL_REGISTERED = False
-_IS_QWEN_QWQ_REGISTERED = False
-
-
-class QwenModelInfo(ModelInfo):
- @classmethod
- def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
- key = f"{base_name}{version}-{size}B"
- return super().construct_model_name(
- base_name, version, size, quant_type, instruct_tag, key
- )
-
-
-class QwenVLModelInfo(ModelInfo):
- @classmethod
- def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
- key = f"{base_name}{version}-VL-{size}B"
- return super().construct_model_name(
- base_name, version, size, quant_type, instruct_tag, key
- )
-
-
-class QwenQwQModelInfo(ModelInfo):
- @classmethod
- def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
- key = f"{base_name}-{size}B"
- return super().construct_model_name(
- base_name, version, size, quant_type, instruct_tag, key
- )
-
-
-class QwenQVQPreviewModelInfo(ModelInfo):
- @classmethod
- def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag):
- key = f"{base_name}-{size}B-Preview"
- return super().construct_model_name(
- base_name, version, size, quant_type, instruct_tag, key
- )
-
-
-# Qwen2.5 Model Meta
-Qwen_2_5_Meta = ModelMeta(
- org = "Qwen",
- base_name = "Qwen",
- instruct_tags = [None, "Instruct"],
- model_version = "2.5",
- model_sizes = ["3", "7"],
- model_info_cls = QwenModelInfo,
- is_multimodal = False,
- quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
-)
-
-# Qwen2.5 VL Model Meta
-Qwen_2_5_VLMeta = ModelMeta(
- org = "Qwen",
- base_name = "Qwen",
- instruct_tags = ["Instruct"], # No base, only instruction tuned
- model_version = "2.5",
- model_sizes = ["3", "7", "32", "72"],
- model_info_cls = QwenVLModelInfo,
- is_multimodal = True,
- quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH],
-)
-
-# Qwen QwQ Model Meta
-QwenQwQMeta = ModelMeta(
- org = "Qwen",
- base_name = "QwQ",
- instruct_tags = [None],
- model_version = "",
- model_sizes = ["32"],
- model_info_cls = QwenQwQModelInfo,
- is_multimodal = False,
- quant_types = [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF],
-)
-
-# Qwen QVQ Preview Model Meta
-QwenQVQPreviewMeta = ModelMeta(
- org = "Qwen",
- base_name = "QVQ",
- instruct_tags = [None],
- model_version = "",
- model_sizes = ["72"],
- model_info_cls = QwenQVQPreviewModelInfo,
- is_multimodal = True,
- quant_types = [QuantType.NONE, QuantType.BNB],
-)
-
-
-def register_qwen_2_5_models(include_original_model: bool = False):
- global _IS_QWEN_2_5_REGISTERED
- if _IS_QWEN_2_5_REGISTERED:
- return
- _register_models(Qwen_2_5_Meta, include_original_model = include_original_model)
- _IS_QWEN_2_5_REGISTERED = True
-
-
-def register_qwen_2_5_vl_models(include_original_model: bool = False):
- global _IS_QWEN_2_5_VL_REGISTERED
- if _IS_QWEN_2_5_VL_REGISTERED:
- return
- _register_models(Qwen_2_5_VLMeta, include_original_model = include_original_model)
- _IS_QWEN_2_5_VL_REGISTERED = True
-
-
-def register_qwen_qwq_models(include_original_model: bool = False):
- global _IS_QWEN_QWQ_REGISTERED
- if _IS_QWEN_QWQ_REGISTERED:
- return
- _register_models(QwenQwQMeta, include_original_model = include_original_model)
- _register_models(QwenQVQPreviewMeta, include_original_model = include_original_model)
- _IS_QWEN_QWQ_REGISTERED = True
-
-
-def register_qwen_models(include_original_model: bool = False):
- register_qwen_2_5_models(include_original_model = include_original_model)
- register_qwen_2_5_vl_models(include_original_model = include_original_model)
- register_qwen_qwq_models(include_original_model = include_original_model)
-
-
-if __name__ == "__main__":
- from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info
-
- MODEL_REGISTRY.clear()
-
- register_qwen_models(include_original_model = True)
-
- for model_id, model_info in MODEL_REGISTRY.items():
- model_info = _check_model_info(model_id)
- if model_info is None:
- print(f"\u2718 {model_id}")
- else:
- print(f"\u2713 {model_id}")
diff --git a/unsloth/registry/registry.py b/unsloth/registry/registry.py
deleted file mode 100644
index 945301420a..0000000000
--- a/unsloth/registry/registry.py
+++ /dev/null
@@ -1,191 +0,0 @@
-import warnings
-from dataclasses import dataclass, field
-from enum import Enum
-
-
-class QuantType(Enum):
- BNB = "bnb"
- UNSLOTH = "unsloth" # dynamic 4-bit quantization
- GGUF = "GGUF"
- NONE = "none"
- BF16 = "bf16" # only for Deepseek V3
-
-
-# Tags for Hugging Face model paths
-BNB_QUANTIZED_TAG = "bnb-4bit"
-UNSLOTH_DYNAMIC_QUANT_TAG = "unsloth" + "-" + BNB_QUANTIZED_TAG
-GGUF_TAG = "GGUF"
-BF16_TAG = "bf16"
-
-QUANT_TAG_MAP = {
- QuantType.BNB: BNB_QUANTIZED_TAG,
- QuantType.UNSLOTH: UNSLOTH_DYNAMIC_QUANT_TAG,
- QuantType.GGUF: GGUF_TAG,
- QuantType.NONE: None,
- QuantType.BF16: BF16_TAG,
-}
-
-
-# NOTE: models registered with org="unsloth" and QUANT_TYPE.NONE are aliases of QUANT_TYPE.UNSLOTH
-@dataclass
-class ModelInfo:
- org: str
- base_name: str
- version: str
- size: int
- name: str = None # full model name, constructed from base_name, version, and size unless provided
- is_multimodal: bool = False
- instruct_tag: str = None
- quant_type: QuantType = None
- description: str = None
-
- def __post_init__(self):
- self.name = self.name or self.construct_model_name(
- self.base_name,
- self.version,
- self.size,
- self.quant_type,
- self.instruct_tag,
- )
-
- @staticmethod
- def append_instruct_tag(key: str, instruct_tag: str = None):
- if instruct_tag:
- key = "-".join([key, instruct_tag])
- return key
-
- @staticmethod
- def append_quant_type(key: str, quant_type: QuantType = None):
- if quant_type != QuantType.NONE:
- key = "-".join([key, QUANT_TAG_MAP[quant_type]])
- return key
-
- @classmethod
- def construct_model_name(
- cls, base_name, version, size, quant_type, instruct_tag, key = ""
- ):
- key = cls.append_instruct_tag(key, instruct_tag)
- key = cls.append_quant_type(key, quant_type)
- return key
-
- @property
- def model_path(
- self,
- ) -> str:
- return f"{self.org}/{self.name}"
-
-
-@dataclass
-class ModelMeta:
- org: str
- base_name: str
- model_version: str
- model_info_cls: type[ModelInfo]
- model_sizes: list[str] = field(default_factory = list)
- instruct_tags: list[str] = field(default_factory = list)
- quant_types: list[QuantType] | dict[str, list[QuantType]] = field(
- default_factory = list
- )
- is_multimodal: bool = False
-
-
-MODEL_REGISTRY: dict[str, ModelInfo] = {}
-
-
-def register_model(
- model_info_cls: ModelInfo,
- org: str,
- base_name: str,
- version: str,
- size: int,
- instruct_tag: str = None,
- quant_type: QuantType = None,
- is_multimodal: bool = False,
- name: str = None,
-):
- name = name or model_info_cls.construct_model_name(
- base_name = base_name,
- version = version,
- size = size,
- quant_type = quant_type,
- instruct_tag = instruct_tag,
- )
- key = f"{org}/{name}"
-
- if key in MODEL_REGISTRY:
- raise ValueError(
- f"Model {key} already registered, current keys: {MODEL_REGISTRY.keys()}"
- )
-
- MODEL_REGISTRY[key] = model_info_cls(
- org = org,
- base_name = base_name,
- version = version,
- size = size,
- is_multimodal = is_multimodal,
- instruct_tag = instruct_tag,
- quant_type = quant_type,
- name = name,
- )
-
-
-def _check_model_info(model_id: str, properties: list[str] = ["lastModified"]):
- from huggingface_hub import HfApi
- from huggingface_hub import ModelInfo as HfModelInfo
- from huggingface_hub.utils import RepositoryNotFoundError
-
- api = HfApi()
-
- try:
- model_info: HfModelInfo = api.model_info(model_id, expand = properties)
- except Exception as e:
- if isinstance(e, RepositoryNotFoundError):
- warnings.warn(f"{model_id} not found on Hugging Face")
- model_info = None
- else:
- raise e
- return model_info
-
-
-def _register_models(model_meta: ModelMeta, include_original_model: bool = False):
- org = model_meta.org
- base_name = model_meta.base_name
- instruct_tags = model_meta.instruct_tags
- model_version = model_meta.model_version
- model_sizes = model_meta.model_sizes
- is_multimodal = model_meta.is_multimodal
- quant_types = model_meta.quant_types
- model_info_cls = model_meta.model_info_cls
-
- for size in model_sizes:
- for instruct_tag in instruct_tags:
- # Handle quant types per model size
- if isinstance(quant_types, dict):
- _quant_types = quant_types[size]
- else:
- _quant_types = quant_types
- for quant_type in _quant_types:
- # NOTE: models registered with org="unsloth" and QUANT_TYPE.NONE are aliases of QUANT_TYPE.UNSLOTH
- _org = "unsloth" # unsloth models -- these are all quantized versions of the original model
- register_model(
- model_info_cls = model_info_cls,
- org = _org,
- base_name = base_name,
- version = model_version,
- size = size,
- instruct_tag = instruct_tag,
- quant_type = quant_type,
- is_multimodal = is_multimodal,
- )
- # include original model from releasing organization
- if include_original_model:
- register_model(
- model_info_cls = model_info_cls,
- org = org,
- base_name = base_name,
- version = model_version,
- size = size,
- instruct_tag = instruct_tag,
- quant_type = QuantType.NONE,
- is_multimodal = is_multimodal,
- )
diff --git a/unsloth/save.py b/unsloth/save.py
index 6e38d1e952..c4ec69c7cf 100644
--- a/unsloth/save.py
+++ b/unsloth/save.py
@@ -13,26 +13,6 @@
# limitations under the License.
from unsloth_zoo.utils import Version
-from importlib.metadata import version as importlib_version
-from unsloth_zoo.hf_utils import dtype_from_config, HAS_TORCH_DTYPE
-from unsloth_zoo.llama_cpp import (
- convert_to_gguf,
- quantize_gguf,
- use_local_gguf,
- install_llama_cpp,
- check_llama_cpp,
- _download_convert_hf_to_gguf,
-)
-
-# H4: Defensive imports -- these were added in unsloth-zoo PR #526
-# and may not exist on older versions
-try:
- from unsloth_zoo.llama_cpp import LLAMA_CPP_DEFAULT_DIR, IS_WINDOWS
-except ImportError:
- import sys
-
- IS_WINDOWS = sys.platform == "win32"
- LLAMA_CPP_DEFAULT_DIR = "llama.cpp"
from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit
from peft.tuners.lora import Linear4bit as Peft_Linear4bit
from peft.tuners.lora import Linear as Peft_Linear
@@ -51,12 +31,7 @@
import re
from transformers.models.llama.modeling_llama import logger
from .tokenizer_utils import fix_sentencepiece_gguf
-from .models.loader_utils import get_model_name
-from .models._utils import _convert_torchao_model
-from .ollama_template_mappers import OLLAMA_TEMPLATES, MODEL_TO_OLLAMA_TEMPLATE_MAPPER
-from transformers import ProcessorMixin
from huggingface_hub import HfApi
-
try:
from huggingface_hub import get_token
except:
@@ -65,8 +40,9 @@
except:
# For older versions of huggingface_hub
from huggingface_hub.utils._token import get_token
+ pass
+pass
from pathlib import Path
-from peft import PeftModelForCausalLM, PeftModel
__all__ = [
"print_quantization_methods",
@@ -77,87 +53,66 @@
]
# llama.cpp specific targets - all takes 90s. Below takes 60s
-LLAMA_CPP_TARGETS = [
- "llama-quantize",
- "llama-cli",
- "llama-server",
-]
+LLAMA_CPP_TARGETS = ["llama-quantize", "llama-export-lora", "llama-cli",]
# Check environments
keynames = "\n" + "\n".join(os.environ.keys())
-IS_COLAB_ENVIRONMENT = "\nCOLAB_" in keynames
+IS_COLAB_ENVIRONMENT = "\nCOLAB_" in keynames
IS_KAGGLE_ENVIRONMENT = "\nKAGGLE_" in keynames
KAGGLE_TMP = "/tmp"
del keynames
# Weights
LLAMA_WEIGHTS = (
- "self_attn.q_proj",
- "self_attn.k_proj",
- "self_attn.v_proj",
- "self_attn.o_proj",
- "mlp.gate_proj",
- "mlp.up_proj",
- "mlp.down_proj",
+ "self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj", "self_attn.o_proj",
+ "mlp.gate_proj", "mlp.up_proj", "mlp.down_proj",
)
LLAMA_LAYERNORMS = (
- "input_layernorm",
- "post_attention_layernorm",
- "pre_feedforward_layernorm",
- "post_feedforward_layernorm",
- "self_attn.q_norm",
- "self_attn.k_norm",
+ "input_layernorm", "post_attention_layernorm",
+ "pre_feedforward_layernorm", "post_feedforward_layernorm",
)
# https://github.com/ggerganov/llama.cpp/blob/master/examples/quantize/quantize.cpp#L19
# From https://mlabonne.github.io/blog/posts/Quantize_Llama_2_models_using_ggml.html
-ALLOWED_QUANTS = {
- "not_quantized": "Recommended. Fast conversion. Slow inference, big files.",
- "fast_quantized": "Recommended. Fast conversion. OK inference, OK file size.",
- "quantized": "Recommended. Slow conversion. Fast inference, small files.",
- "f32": "Not recommended. Retains 100% accuracy, but super slow and memory hungry.",
- "bf16": "Bfloat16 - Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
- "f16": "Float16 - Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
- "q8_0": "Fast conversion. High resource use, but generally acceptable.",
- "q4_k_m": "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K",
- "q5_k_m": "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K",
- "q2_k": "Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.",
- "q3_k_l": "Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
- "q3_k_m": "Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
- "q3_k_s": "Uses Q3_K for all tensors",
- "q4_0": "Original quant method, 4-bit.",
- "q4_1": "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
- "q4_k_s": "Uses Q4_K for all tensors",
- "q4_k": "alias for q4_k_m",
- "q5_k": "alias for q5_k_m",
- "q5_0": "Higher accuracy, higher resource usage and slower inference.",
- "q5_1": "Even higher accuracy, resource usage and slower inference.",
- "q5_k_s": "Uses Q5_K for all tensors",
- "q6_k": "Uses Q8_K for all tensors",
+ALLOWED_QUANTS = \
+{
+ "not_quantized" : "Recommended. Fast conversion. Slow inference, big files.",
+ "fast_quantized" : "Recommended. Fast conversion. OK inference, OK file size.",
+ "quantized" : "Recommended. Slow conversion. Fast inference, small files.",
+ "f32" : "Not recommended. Retains 100% accuracy, but super slow and memory hungry.",
+ "bf16" : "Bfloat16 - Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
+ "f16" : "Float16 - Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
+ "q8_0" : "Fast conversion. High resource use, but generally acceptable.",
+ "q4_k_m" : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K",
+ "q5_k_m" : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K",
+ "q2_k" : "Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.",
+ "q3_k_l" : "Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
+ "q3_k_m" : "Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
+ "q3_k_s" : "Uses Q3_K for all tensors",
+ "q4_0" : "Original quant method, 4-bit.",
+ "q4_1" : "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
+ "q4_k_s" : "Uses Q4_K for all tensors",
+ "q4_k" : "alias for q4_k_m",
+ "q5_k" : "alias for q5_k_m",
+ "q5_0" : "Higher accuracy, higher resource usage and slower inference.",
+ "q5_1" : "Even higher accuracy, resource usage and slower inference.",
+ "q5_k_s" : "Uses Q5_K for all tensors",
+ "q6_k" : "Uses Q8_K for all tensors",
# "iq2_xxs" : "2.06 bpw quantization", # Not supported sadly
# "iq2_xs" : "2.31 bpw quantization",
# "iq3_xxs" : "3.06 bpw quantization",
- "q3_k_xs": "3-bit extra small quantization",
+ "q3_k_xs" : "3-bit extra small quantization",
}
-
-def has_curl():
- return shutil.which("curl") is not None
-
-
-CURL_FLAG = "-DLLAMA_CURL=ON" if has_curl() else "-DLLAMA_CURL=OFF"
-
-
def print_quantization_methods():
for key, value in ALLOWED_QUANTS.items():
print(f'"{key}" ==> {value}')
+ pass
+pass
-def check_if_sentencepiece_model(
- model, temporary_location = "_unsloth_sentencepiece_temp"
-):
- if not hasattr(model, "_saved_temp_tokenizer"):
- return False
+def check_if_sentencepiece_model(model, temporary_location = "_unsloth_sentencepiece_temp"):
+ if not hasattr(model, "_saved_temp_tokenizer"): return False
temp_tokenizer = model._saved_temp_tokenizer
sentencepiece_model = False
@@ -166,17 +121,19 @@ def check_if_sentencepiece_model(
if not os.path.exists(file_location):
created_folder = True
os.makedirs(file_location)
+ pass
temp_tokenizer.save_pretrained(file_location)
if os.path.isfile(f"{file_location}/tokenizer.model"):
sentencepiece_model = True
+ pass
if created_folder:
shutil.rmtree(file_location, ignore_errors = True)
return sentencepiece_model
+pass
def _free_cached_model(model):
from huggingface_hub import scan_cache_dir
-
cached_repos = list(scan_cache_dir().repos)
# Go through every cached repo, and delete the one that matches the model we want to save.
@@ -184,27 +141,27 @@ def _free_cached_model(model):
for cached_repo in cached_repos:
if cached_repo.repo_id == model.config._name_or_path:
remove_cache_commit = list(cached_repo.revisions)[0].commit_hash
- delete_strategy = scan_cache_dir().delete_revisions(
- remove_cache_commit,
- )
+ delete_strategy = scan_cache_dir().delete_revisions(remove_cache_commit,)
logger.warning_once(
- "Unsloth: Will remove a cached repo with size "
- + delete_strategy.expected_freed_size_str,
+ "Unsloth: Will remove a cached repo with size " + \
+ delete_strategy.expected_freed_size_str,
)
delete_strategy.execute()
+ pass
+ pass
+pass
def _merge_lora(layer, name):
+
bias = getattr(layer, "bias", None)
if isinstance(layer, (Bnb_Linear4bit, Peft_Linear4bit, Peft_Linear)):
# Is LoRA so we need to merge!
W, quant_state, A, B, s, bias = get_lora_parameters_bias(layer)
if quant_state is not None:
- dtype = (
- quant_state.dtype if type(quant_state) is not list else quant_state[2]
- )
+ dtype = quant_state.dtype if type(quant_state) is not list else quant_state[2]
W = fast_dequantize(W, quant_state)
else:
dtype = W.dtype
@@ -219,13 +176,13 @@ def _merge_lora(layer, name):
# if not torch.isfinite(W).all():
maximum_element = torch.max(W.min().abs(), W.max())
if not torch.isfinite(maximum_element).item():
- raise ValueError(
- f"Unsloth: Merge failed.\n{name} has some elements = infinity."
- )
+ raise ValueError(f"Unsloth: Merge failed.\n{name} has some elements = infinity.")
+ pass
W = W.t().to(dtype)
else:
W = layer.weight
return W, bias
+pass
def fast_save_pickle(shard, name):
@@ -239,41 +196,41 @@ def fast_save_pickle(shard, name):
# pickle_protocol = pickle.HIGHEST_PROTOCOL,
)
return
+pass
@torch.inference_mode
def unsloth_save_model(
model,
tokenizer,
- save_directory: Union[str, os.PathLike],
- save_method: str = "lora", # ["lora", "merged_16bit", "merged_4bit"]
- push_to_hub: bool = False,
- token: Optional[Union[str, bool]] = None,
- is_main_process: bool = True,
- state_dict: Optional[dict] = None,
- save_function: Callable = torch.save,
- max_shard_size: Union[int, str] = "5GB",
- safe_serialization: bool = True,
- variant: Optional[str] = None,
- save_peft_format: bool = True,
+ save_directory : Union[str, os.PathLike],
+ save_method : str = "lora", # ["lora", "merged_16bit", "merged_4bit"]
+ push_to_hub : bool = False,
+ token : Optional[Union[str, bool]] = None,
+ is_main_process : bool = True,
+ state_dict : Optional[dict] = None,
+ save_function : Callable = torch.save,
+ max_shard_size : Union[int, str] = "5GB",
+ safe_serialization : bool = True,
+ variant : Optional[str] = None,
+ save_peft_format : bool = True,
+
# Push to hub
- use_temp_dir: Optional[bool] = None,
- commit_message: Optional[str] = "Trained with Unsloth",
- private: Optional[bool] = None,
- create_pr: bool = False,
- revision: str = None,
- commit_description: str = "Upload model trained with Unsloth 2x faster",
- tags: List[str] = None,
+ use_temp_dir : Optional[bool] = None,
+ commit_message : Optional[str] = "Trained with Unsloth",
+ private : Optional[bool] = None,
+ create_pr : bool = False,
+ revision : str = None,
+ commit_description : str = "Upload model trained with Unsloth 2x faster",
+ tags : List[str] = None,
+
# Our functions
- temporary_location: str = "_unsloth_temporary_saved_buffers",
- maximum_memory_usage: float = 0.9,
- datasets: Optional[List[str]] = None,
+ temporary_location : str = "_unsloth_temporary_saved_buffers",
+ maximum_memory_usage : float = 0.9,
):
- if token is None:
- token = get_token()
+ if token is None: token = get_token()
- if commit_message is None:
- commit_message = ""
+ if commit_message is None: commit_message = ""
if "Unsloth" not in commit_message:
commit_message += " (Trained with Unsloth)"
commit_message = commit_message.lstrip()
@@ -282,214 +239,185 @@ def unsloth_save_model(
commit_description = "Upload model trained with Unsloth 2x faster"
elif "Unsloth 2x faster" not in commit_description:
commit_description += " (Trained with Unsloth 2x faster)"
+ pass
if save_method == "merged_4bit":
raise RuntimeError(
- "Unsloth: Merging into 4bit will cause your model to lose accuracy if you plan\n"
- "to merge to GGUF or others later on. I suggest you to do this as a final step\n"
- "if you're planning to do multiple saves.\n"
+ "Unsloth: Merging into 4bit will cause your model to lose accuracy if you plan\n"\
+ "to merge to GGUF or others later on. I suggest you to do this as a final step\n"\
+ "if you're planning to do multiple saves.\n"\
"If you are certain, change `save_method` to `merged_4bit_forced`."
)
elif save_method == "merged_4bit_forced":
save_method = "merged_4bit"
+ pass
save_pretrained_settings = dict(locals())
- for deletion in (
- "model",
- "tokenizer",
- "save_method",
- "temporary_location",
- "maximum_memory_usage",
- "datasets",
- ):
+ for deletion in ("model", "tokenizer", "save_method", "temporary_location", "maximum_memory_usage"):
del save_pretrained_settings[deletion]
+ pass
# First check for a token!
if push_to_hub:
from huggingface_hub import whoami
-
try:
username = whoami(token = token)["name"]
except:
raise RuntimeError(
- "Unsloth: Please supply a token!\n"
+ "Unsloth: Please supply a token!\n"\
"Go to https://huggingface.co/settings/tokens"
)
+ pass
+ pass
- assert maximum_memory_usage > 0 and maximum_memory_usage <= 0.95
+ assert(maximum_memory_usage > 0 and maximum_memory_usage <= 0.95)
# Clean memory up first
for _ in range(3):
torch.cuda.empty_cache()
gc.collect()
+ pass
save_method = save_method.lower().replace(" ", "_")
- if (
- save_method != "lora"
- and save_method != "merged_16bit"
- and save_method != "merged_4bit"
- ):
+ if save_method != "lora" and save_method != "merged_16bit" and save_method != "merged_4bit":
raise RuntimeError(
- "Unsloth: You must select one of 3 options when saving models:\n"
- '"lora" ==> This is the fastest and easiet. Just saves LoRA modules.\n'
- '"merged_16bit" ==> This merges LoRA weights and saves to float16. Needed for llama.cpp / GGUF.\n'
+ "Unsloth: You must select one of 3 options when saving models:\n"\
+ '"lora" ==> This is the fastest and easiet. Just saves LoRA modules.\n'\
+ '"merged_16bit" ==> This merges LoRA weights and saves to float16. Needed for llama.cpp / GGUF.\n'\
'"merged_4bit" ==> This merges LoRA weights and saves to 4bit. Useful for DPO / inference.'
)
+ pass
if save_method == "merged_4bit":
+
print("Unsloth: Merging 4bit and LoRA weights to 4bit...")
print("This might take 5 minutes...")
# Counteract no LoRA adapters!
if hasattr(model, "merge_and_unload"):
model = model.merge_and_unload()
+ pass
print("Done.")
+ pass
if tags is not None:
- assert isinstance(tags, (list, tuple))
- tags = list(tags) + [
- "unsloth",
- ]
+ assert(isinstance(tags, (list, tuple)))
+ tags = list(tags) + ["unsloth",]
else:
- tags = [
- "unsloth",
- ]
+ tags = ["unsloth",]
+ pass
save_pretrained_settings["tags"] = tags
if ((save_method == "lora") or (save_method == "merged_4bit")) and push_to_hub:
if token is None:
raise RuntimeError(
- "Unsloth: Pushing to HF requires a token. Pass `token = 'hf_....'`\n"
+ "Unsloth: Pushing to HF requires a token. Pass `token = 'hf_....'`\n"\
"Go to https://huggingface.co/settings/tokens."
)
+ pass
if save_method == "lora":
print("Unsloth: Saving LoRA adapters. Please wait...")
elif save_method == "merged_4bit":
print("Unsloth: Saving 4bit Bitsandbytes model. Please wait...")
+ pass
# Update model tag
_ = upload_to_huggingface(
- model,
- save_directory,
- token,
- "finetuned",
- "trl",
- file_location = None,
- old_username = None,
- private = private,
- datasets = datasets,
+ model, save_directory, token,
+ "finetuned", "trl", file_location = None,
+ old_username = None, private = private,
)
- getattr(model, "original_push_to_hub", model.push_to_hub)(
- repo_id = save_directory,
- use_temp_dir = use_temp_dir,
- commit_message = commit_message,
- private = private,
- token = token,
- max_shard_size = max_shard_size,
- create_pr = create_pr,
+ getattr(model, "original_push_to_hub", tokenizer.push_to_hub)\
+ (
+ repo_id = save_directory,
+ use_temp_dir = use_temp_dir,
+ commit_message = commit_message,
+ private = private,
+ token = token,
+ max_shard_size = max_shard_size,
+ create_pr = create_pr,
safe_serialization = safe_serialization,
- revision = revision,
+ revision = revision,
commit_description = commit_description,
- tags = tags,
+ tags = tags,
)
if tokenizer is not None:
# Set padding side to left for inference
old_padding_side = tokenizer.padding_side
tokenizer.padding_side = "left"
- getattr(tokenizer, "original_push_to_hub", tokenizer.push_to_hub)(
- repo_id = save_directory,
- use_temp_dir = use_temp_dir,
- commit_message = commit_message,
- private = private,
- token = token,
- max_shard_size = max_shard_size,
- create_pr = create_pr,
+ getattr(tokenizer, "original_push_to_hub", tokenizer.push_to_hub)\
+ (
+ repo_id = save_directory,
+ use_temp_dir = use_temp_dir,
+ commit_message = commit_message,
+ private = private,
+ token = token,
+ max_shard_size = max_shard_size,
+ create_pr = create_pr,
safe_serialization = safe_serialization,
- revision = revision,
+ revision = revision,
commit_description = commit_description,
- tags = tags,
+ tags = tags,
)
# Revert back padding side
tokenizer.padding_side = old_padding_side
+ pass
if hasattr(model, "config"):
- print(
- f"Saved {save_method} model to https://huggingface.co/" + save_directory
- )
+ print(f"Saved {save_method} model to https://huggingface.co/" + save_directory)
+ pass
return save_directory, None
+ pass
# Tokenizer has different saving arguments
- tokenizer_save_settings = {
- "save_directory": save_pretrained_settings["save_directory"],
- "legacy_format": None,
- "filename_prefix": None,
- "push_to_hub": save_pretrained_settings["push_to_hub"],
- "private": save_pretrained_settings["private"],
- "token": save_pretrained_settings["token"],
+ tokenizer_save_settings = \
+ {
+ "save_directory" : save_pretrained_settings["save_directory"],
+ "legacy_format" : None,
+ "filename_prefix" : None,
+ "push_to_hub" : save_pretrained_settings["push_to_hub"],
+ "private" : save_pretrained_settings["private"],
+ "token" : save_pretrained_settings["token"],
}
# Check if PEFT Model or not - if yes, 3 levels. If not 2 levels.
from peft import PeftModelForCausalLM
-
if isinstance(model, PeftModelForCausalLM):
internal_model = model.model
else:
internal_model = model
+ pass
# Cannot be converted properly!
- if (
- (save_method == "merged_4bit")
- or (save_method == "lora")
- or (not hasattr(model, "model") or not hasattr(internal_model.model, "layers"))
+ if (save_method == "merged_4bit") or (save_method == "lora") or (
+ not hasattr(model, "model") or \
+ not hasattr(internal_model.model, "layers")
):
# Do general saving
# Edit save_pretrained_settings
# [TODO] _create_repo has errors due to **kwargs getting accepted
# commit_description does not seem to work?
- what_to_delete = (
- (
- "use_temp_dir",
- "commit_message",
- "create_pr",
- "revision",
- "commit_description",
- "tags",
- )
- if save_pretrained_settings["push_to_hub"] is False
- else (
- "use_temp_dir",
- "create_pr",
- "revision",
- "tags",
- "commit_description",
- )
- )
+ what_to_delete = ("use_temp_dir", "commit_message", "create_pr", "revision", "commit_description", "tags",) \
+ if save_pretrained_settings["push_to_hub"] is False else \
+ ("use_temp_dir", "create_pr", "revision", "tags", "commit_description",)
for deletion in what_to_delete:
del save_pretrained_settings[deletion]
+ pass
if hasattr(model, "add_model_tags"):
- model.add_model_tags(
- [
- "unsloth",
- ]
- )
+ model.add_model_tags(["unsloth",])
# Update model tag
if push_to_hub:
- _ = upload_to_huggingface(
- model,
- save_pretrained_settings["save_directory"],
- token,
- "finetuned",
- "trl",
- file_location = None,
- old_username = None,
- private = private,
- datasets = datasets,
+ _ = upload_to_huggingface(
+ model, save_pretrained_settings["save_directory"], token,
+ "finetuned", "trl", file_location = None,
+ old_username = None, private = private,
)
+ pass
if tokenizer is not None:
print("Unsloth: Saving tokenizer...", end = "")
@@ -508,48 +436,47 @@ def unsloth_save_model(
print()
print("Unsloth: Saving model...", end = "")
- if save_method != "lora":
- print(" This might take 10 minutes for Llama-7b...", end = "")
+ if save_method != "lora": print(" This might take 10 minutes for Llama-7b...", end = "")
# [TODO] Is this correct?
if save_method == "lora":
save_pretrained_settings["selected_adapters"] = None
+ pass
model.save_pretrained(**save_pretrained_settings)
if push_to_hub and hasattr(model, "config"):
- print(
- "Saved to https://huggingface.co/"
- + save_pretrained_settings["save_directory"]
- )
+ print("Saved to https://huggingface.co/" + save_pretrained_settings["save_directory"])
+ pass
print(" Done.")
return save_directory, None
+ pass
# If push_to_hub, we must remove the .../ part of a repo
username = None
if push_to_hub and "/" in save_directory:
+
# +1 solves absolute path issues
new_save_directory = save_directory
- username = new_save_directory[: new_save_directory.find("/")]
- new_save_directory = new_save_directory[new_save_directory.find("/") + 1 :]
+ username = new_save_directory[:new_save_directory.find("/")]
+ new_save_directory = new_save_directory[new_save_directory.find("/")+1:]
if IS_KAGGLE_ENVIRONMENT:
- new_save_directory = os.path.join(
- KAGGLE_TMP, new_save_directory[new_save_directory.find("/") + 1 :]
- )
+ new_save_directory = os.path.join(KAGGLE_TMP, new_save_directory[new_save_directory.find("/")+1:])
logger.warning_once(
- "Unsloth: You are pushing to hub in Kaggle environment.\n"
+ "Unsloth: You are pushing to hub in Kaggle environment.\n"\
f"To save memory, we shall move {save_directory} to {new_save_directory}"
)
else:
logger.warning_once(
- f"Unsloth: You are pushing to hub, but you passed your HF username = {username}.\n"
+ f"Unsloth: You are pushing to hub, but you passed your HF username = {username}.\n"\
f"We shall truncate {save_directory} to {new_save_directory}"
)
save_pretrained_settings["save_directory"] = new_save_directory
- tokenizer_save_settings["save_directory"] = new_save_directory
+ tokenizer_save_settings ["save_directory"] = new_save_directory
save_directory = new_save_directory
+ pass
print("Unsloth: Merging 4bit and LoRA weights to 16bit...")
@@ -557,25 +484,18 @@ def unsloth_save_model(
max_ram = psutil.virtual_memory().available
sharded_ram_usage = 5 * 1024 * 1024 * 1024
if type(max_shard_size) is str:
- gb_found = re.match(
- r"([0-9]{1,})[\s]{0,}GB", max_shard_size, flags = re.IGNORECASE
- )
- mb_found = re.match(
- r"([0-9]{1,})[\s]{0,}MB", max_shard_size, flags = re.IGNORECASE
- )
- if gb_found:
- sharded_ram_usage = int(gb_found.group(1)) * 1024 * 1024 * 1024
- elif mb_found:
- sharded_ram_usage = int(mb_found.group(1)) * 1024 * 1024
+ gb_found = re.match(r"([0-9]{1,})[\s]{0,}GB", max_shard_size, flags = re.IGNORECASE)
+ mb_found = re.match(r"([0-9]{1,})[\s]{0,}MB", max_shard_size, flags = re.IGNORECASE)
+ if gb_found: sharded_ram_usage = int(gb_found.group(1)) * 1024 * 1024 * 1024
+ elif mb_found: sharded_ram_usage = int(mb_found.group(1)) * 1024 * 1024
elif type(max_shard_size) is int:
- sharded_ram_usage = max_shard_size
+ sharded_ram_usage = sharded_ram_usage
+ pass
# Switch to our fast saving modules if it's a slow PC!
n_cpus = psutil.cpu_count(logical = False)
- if n_cpus is None:
- n_cpus = psutil.cpu_count()
- if n_cpus is None:
- n_cpus = 1
+ if n_cpus is None: n_cpus = psutil.cpu_count()
+ if n_cpus is None: n_cpus = 1
if safe_serialization is None:
safe_serialization = True
@@ -583,27 +503,27 @@ def unsloth_save_model(
elif safe_serialization and (n_cpus <= 2):
logger.warning_once(
- f"Unsloth: You have {n_cpus} CPUs. Using `safe_serialization` is 10x slower.\n"
- f"We shall switch to Pytorch saving, which might take 3 minutes and not 30 minutes.\n"
+ f"Unsloth: You have {n_cpus} CPUs. Using `safe_serialization` is 10x slower.\n"\
+ f"We shall switch to Pytorch saving, which might take 3 minutes and not 30 minutes.\n"\
f"To force `safe_serialization`, set it to `None` instead.",
)
safe_serialization = False
save_function = fast_save_pickle
save_pretrained_settings["safe_serialization"] = safe_serialization
- save_pretrained_settings["save_function"] = save_function
+ save_pretrained_settings["save_function"] = save_function
+ pass
# Only safe_serialization uses more RAM
if safe_serialization:
max_ram -= sharded_ram_usage
else:
- max_ram -= sharded_ram_usage * 0.25 # Uses much less
+ max_ram -= sharded_ram_usage*0.25 # Uses much less
+ pass
max_ram = int(max(0, max_ram) * maximum_memory_usage)
- print(
- f"Unsloth: Will use up to "
- f"{round(max_ram/1024/1024/1024, 2)} out of "
- f"{round(psutil.virtual_memory().total/1024/1024/1024, 2)} RAM for saving."
- )
+ print(f"Unsloth: Will use up to "\
+ f"{round(max_ram/1024/1024/1024, 2)} out of "\
+ f"{round(psutil.virtual_memory().total/1024/1024/1024, 2)} RAM for saving.")
# Move temporary_location to /tmp in Kaggle
if IS_KAGGLE_ENVIRONMENT:
@@ -612,41 +532,36 @@ def unsloth_save_model(
# Max directory for disk saving
if not os.path.exists(temporary_location):
os.makedirs(temporary_location)
+ pass
# Check if Kaggle or Colab, since only 20GB of Disk space allowed.
if IS_KAGGLE_ENVIRONMENT or IS_COLAB_ENVIRONMENT:
# We free up 4GB of space
logger.warning_once(
- "Unsloth: Kaggle/Colab has limited disk space. We need to delete the downloaded\n"
+ "Unsloth: Kaggle/Colab has limited disk space. We need to delete the downloaded\n"\
"model which will save 4-16GB of disk space, allowing you to save on Kaggle/Colab."
)
_free_cached_model(internal_model)
+ pass
# HF also uses a OrderedDict
from collections import OrderedDict
-
state_dict = OrderedDict()
- torch_dtype = dtype_from_config(internal_model.config)
+ torch_dtype = internal_model.config.torch_dtype
if type(torch_dtype) is str:
- if torch_dtype == "float16":
- torch_dtype = torch.float16
- elif torch_dtype == "bfloat16":
- torch_dtype = torch.bfloat16
+ if torch_dtype == "float16": torch_dtype = torch.float16
+ elif torch_dtype == "bfloat16": torch_dtype = torch.bfloat16
+ pass
# Check modules to save float32 dtype
- state_dict["model.embed_tokens.weight"] = (
- internal_model.model.embed_tokens.weight.data.to(torch_dtype)
- )
+ state_dict["model.embed_tokens.weight"] = internal_model.model.embed_tokens.weight.data.to(torch_dtype)
- max_vram = int(
- torch.cuda.get_device_properties(0).total_memory * maximum_memory_usage
- )
+ max_vram = int(torch.cuda.get_device_properties(0).total_memory * maximum_memory_usage)
print("Unsloth: Saving model... This might take 5 minutes ...")
from tqdm import tqdm as ProgressBar
-
for j, layer in enumerate(ProgressBar(internal_model.model.layers)):
for item in LLAMA_WEIGHTS:
proj = eval(f"layer.{item}")
@@ -656,6 +571,7 @@ def unsloth_save_model(
# Bias term
if bias is not None:
state_dict[f"model.layers.{j}.{item}.bias"] = bias
+ pass
if (torch.cuda.memory_allocated() + W.nbytes) < max_vram:
# Save to GPU memory
@@ -670,104 +586,70 @@ def unsloth_save_model(
# Save to Disk
logger.warning_once("\nWe will save to Disk and not RAM now.")
filename = os.path.join(temporary_location, f"{name}.pt")
- torch.save(
- W,
- filename,
- pickle_module = pickle,
- pickle_protocol = pickle.HIGHEST_PROTOCOL,
- )
+ torch.save(W, filename, pickle_module = pickle, pickle_protocol = pickle.HIGHEST_PROTOCOL,)
# weights_only = True weirdly fails?
- state_dict[name] = torch.load(
- filename, map_location = "cpu", mmap = True, weights_only = False
- )
+ state_dict[name] = torch.load(filename, map_location = "cpu", mmap = True, weights_only = False)
+ pass
for item in LLAMA_LAYERNORMS:
try:
# Skip for Gemma 2
- state_dict[f"model.layers.{j}.{item}.weight"] = eval(
- f"layer.{item}.weight.data"
- )
+ state_dict[f"model.layers.{j}.{item}.weight"] = eval(f"layer.{item}.weight.data")
except:
continue
+ pass
+ pass
state_dict["model.norm.weight"] = internal_model.model.norm.weight.data
# Check for modules_to_save float32 dtype
# Check for tied weights
- if (
- internal_model.model.embed_tokens.weight.data_ptr()
- != internal_model.lm_head.weight.data_ptr()
- ):
- state_dict["lm_head.weight"] = internal_model.lm_head.weight.data.to(
- torch_dtype
- )
+ if internal_model.model.embed_tokens.weight.data_ptr() != internal_model.lm_head.weight.data_ptr():
+ state_dict["lm_head.weight"] = internal_model.lm_head.weight.data.to(torch_dtype)
+ pass
# All tensors MUST be type torch.Tensor and not torch.nn.parameter.Parameter
for key, value in state_dict.items():
- if hasattr(value, "data"):
- state_dict[key] = value = value.data
+ if hasattr(value, "data"): state_dict[key] = value = value.data
if type(value) is not torch.Tensor:
logger.warning_once(f"Unsloth: {key} is not a Tensor but a {type(value)}.")
+ pass
+ pass
# Edit save_pretrained_settings
# [TODO] _create_repo has errors due to **kwargs getting accepted
save_pretrained_settings["state_dict"] = state_dict
# commit_description does not seem to work?
- what_to_delete = (
- (
- "use_temp_dir",
- "commit_message",
- "create_pr",
- "revision",
- "commit_description",
- "tags",
- )
- if not push_to_hub
- else (
- "use_temp_dir",
- "create_pr",
- "revision",
- "tags",
- "commit_description",
- )
- )
+ what_to_delete = ("use_temp_dir", "commit_message", "create_pr", "revision", "commit_description", "tags",) \
+ if not push_to_hub else \
+ ("use_temp_dir", "create_pr", "revision", "tags", "commit_description",)
for deletion in what_to_delete:
del save_pretrained_settings[deletion]
+ pass
if hasattr(model, "add_model_tags"):
- model.add_model_tags(
- [
- "unsloth",
- ]
- )
+ model.add_model_tags(["unsloth",])
# Update model tag
if push_to_hub:
_ = upload_to_huggingface(
- model,
- save_pretrained_settings["save_directory"],
- token,
- "finetuned",
- "trl",
- file_location = None,
- old_username = username,
- private = private,
- datasets = datasets,
+ model, save_pretrained_settings["save_directory"], token,
+ "finetuned", "trl", file_location = None,
+ old_username = username, private = private,
)
+ pass
# First check if we're pushing to an organization!
save_directory = save_pretrained_settings["save_directory"]
if save_pretrained_settings["push_to_hub"]:
- new_save_directory, new_username = _determine_username(
- save_directory, username, token
- )
+ new_save_directory, new_username = _determine_username(save_directory, username, token)
if token is not None:
from huggingface_hub import whoami
-
actual_username = whoami(token = token)["name"]
else:
actual_username = username
+ pass
# Check if pushing to an organization
if save_pretrained_settings["push_to_hub"] and (username != actual_username):
@@ -775,6 +657,7 @@ def unsloth_save_model(
# We upload everything at the end!
tokenizer_save_settings["push_to_hub"] = False
tokenizer_save_settings["save_directory"] = new_save_directory
+ pass
# Save tokenizer
if tokenizer is not None:
@@ -792,6 +675,7 @@ def unsloth_save_model(
print(" Done.")
else:
print()
+ pass
# Since merged, edit quantization_config
old_config = model.config
@@ -830,11 +714,12 @@ def unsloth_save_model(
path_in_repo = ".",
repo_id = new_save_directory,
repo_type = "model",
- commit_message = "(Trained with Unsloth)",
+ commit_message = "(Trained with Unsloth)",
ignore_patterns = "*.md",
)
else:
internal_model.save_pretrained(**save_pretrained_settings)
+ pass
# Revert config back
original_model = model
@@ -845,9 +730,8 @@ def unsloth_save_model(
print("Done.")
if push_to_hub and hasattr(model, "config"):
- print(
- f"Saved merged model to https://huggingface.co/{username}/{save_directory.lstrip('/').split('/')[-1]}"
- )
+ print(f"Saved merged model to https://huggingface.co/{username}/{save_directory.lstrip('/').split('/')[-1]}")
+ pass
save_pretrained_settings["state_dict"] = None
@@ -856,6 +740,8 @@ def unsloth_save_model(
if j % 10 == 0:
torch.cuda.empty_cache()
gc.collect()
+ pass
+ pass
state_dict = None
del state_dict
torch.cuda.empty_cache()
@@ -863,26 +749,20 @@ def unsloth_save_model(
# Remove temporary location
import shutil
-
shutil.rmtree(temporary_location, ignore_errors = True)
for _ in range(3):
torch.cuda.empty_cache()
gc.collect()
return save_directory, username
+pass
def install_llama_cpp_clone_non_blocking():
- full_command = [
- "git",
- "clone",
- "--recursive",
- "https://github.com/ggerganov/llama.cpp",
- ]
- run_installer = subprocess.Popen(
- full_command, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT
- )
+ full_command = ["git", "clone", "--recursive", "https://github.com/ggerganov/llama.cpp"]
+ run_installer = subprocess.Popen(full_command, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT)
return run_installer
+pass
def install_llama_cpp_make_non_blocking():
@@ -894,90 +774,71 @@ def install_llama_cpp_make_non_blocking():
IS_CMAKE = False
if check == 0:
# Uses old MAKE
- n_jobs = max(int((psutil.cpu_count() or 1) * 1.5), 1)
- full_command = ["make", "all", "-j" + str(n_jobs), "-C", "llama.cpp"]
+ n_jobs = max(int(psutil.cpu_count()*1.5), 1)
+ full_command = ["make", "all", "-j"+str(n_jobs), "-C", "llama.cpp"]
IS_CMAKE = False
else:
# Uses new CMAKE
- n_jobs = max(int(psutil.cpu_count() or 1), 1) # Use less CPUs since 1.5x faster
- check = os.system(
- f"cmake llama.cpp -B llama.cpp/build -DBUILD_SHARED_LIBS=OFF -DGGML_CUDA=OFF {CURL_FLAG}"
- )
-
+ n_jobs = max(int(psutil.cpu_count()), 1) # Use less CPUs since 1.5x faster
+ check = os.system("cmake llama.cpp -B llama.cpp/build -DBUILD_SHARED_LIBS=OFF -DGGML_CUDA=OFF -DLLAMA_CURL=ON")
if check != 0:
- raise RuntimeError(
- f"*** Unsloth: Failed compiling llama.cpp using os.system(...) with error {check}. Please report this ASAP!"
- )
+ raise RuntimeError(f"*** Unsloth: Failed compiling llama.cpp using os.system(...) with error {check}. Please report this ASAP!")
+ pass
# f"cmake --build llama.cpp/build --config Release -j{psutil.cpu_count()*2} --clean-first --target {' '.join(LLAMA_CPP_TARGETS)}",
full_command = [
- "cmake",
- "--build",
- "llama.cpp/build",
- "--config",
- "Release",
- "-j" + str(n_jobs),
+ "cmake", "--build", "llama.cpp/build",
+ "--config", "Release",
+ "-j"+str(n_jobs),
"--clean-first",
"--target",
] + LLAMA_CPP_TARGETS
IS_CMAKE = True
+ pass
# https://github.com/ggerganov/llama.cpp/issues/7062
# Weirdly GPU conversion for GGUF breaks??
# run_installer = subprocess.Popen(full_command, env = env, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT)
- run_installer = subprocess.Popen(
- full_command, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT
- )
+ run_installer = subprocess.Popen(full_command, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT)
return run_installer, IS_CMAKE
+pass
def install_python_non_blocking(packages = []):
full_command = ["pip", "install"] + packages
- run_installer = subprocess.Popen(
- full_command, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT
- )
+ run_installer = subprocess.Popen(full_command, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT)
return run_installer
+pass
def try_execute(commands, force_complete = False):
for command in commands:
- with subprocess.Popen(
- command,
- shell = True,
- stdout = subprocess.PIPE,
- stderr = subprocess.STDOUT,
- bufsize = 1,
- ) as sp:
+ with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp:
for line in sp.stdout:
line = line.decode("utf-8", errors = "replace")
if "undefined reference" in line:
- raise RuntimeError(
- f"*** Unsloth: Failed compiling llama.cpp with {line}. Please report this ASAP!"
- )
+ raise RuntimeError(f"*** Unsloth: Failed compiling llama.cpp with {line}. Please report this ASAP!")
elif "deprecated" in line:
return "CMAKE"
elif "Unknown argument" in line:
- raise RuntimeError(
- f"*** Unsloth: Failed compiling llama.cpp with {line}. Please report this ASAP!"
- )
+ raise RuntimeError(f"*** Unsloth: Failed compiling llama.cpp with {line}. Please report this ASAP!")
elif "***" in line:
- raise RuntimeError(
- f"*** Unsloth: Failed compiling llama.cpp with {line}. Please report this ASAP!"
- )
+ raise RuntimeError(f"*** Unsloth: Failed compiling llama.cpp with {line}. Please report this ASAP!")
print(line, flush = True, end = "")
+ pass
if force_complete and sp.returncode is not None and sp.returncode != 0:
raise subprocess.CalledProcessError(sp.returncode, sp.args)
+ pass
+ pass
return None
+pass
def install_llama_cpp_old(version = -10):
# Download the 10th latest release since the latest might be broken!
# FALLBACK mechanism
- releases = subprocess.check_output(
- ["git", "ls-remote", "--tags", "https://github.com/ggerganov/llama.cpp.git"]
- )
+ releases = subprocess.check_output(["git", "ls-remote", "--tags", "https://github.com/ggerganov/llama.cpp.git"])
releases = releases.decode("utf-8").replace("\t", " ").split("\n")
for i, x in enumerate(releases):
- if "refs/tags/b" not in x:
- break
+ if "refs/tags/b" not in x: break
releases = releases[:i]
latest = releases[-1]
version = releases[version].split(" ")[0]
@@ -985,18 +846,17 @@ def install_llama_cpp_old(version = -10):
# Check if the llama.cpp exists
if os.path.exists("llama.cpp"):
print(
- "**[WARNING]** You have a llama.cpp directory which is broken.\n"
- "Unsloth will DELETE the broken directory and install a new one.\n"
+ "**[WARNING]** You have a llama.cpp directory which is broken.\n"\
+ "Unsloth will DELETE the broken directory and install a new one.\n"\
"Press CTRL + C / cancel this if this is wrong. We shall wait 30 seconds.\n"
)
import time
-
for i in range(30):
print(f"**[WARNING]** Deleting llama.cpp directory... {30-i} seconds left.")
time.sleep(1)
import shutil
-
shutil.rmtree("llama.cpp", ignore_errors = True)
+ pass
# Clone a specific commit
# Also don't use the GPU!
@@ -1009,33 +869,32 @@ def install_llama_cpp_old(version = -10):
# Try using MAKE
commands = [
"make clean -C llama.cpp",
- f"make all -j{(psutil.cpu_count() or 1)*2} -C llama.cpp",
+ f"make all -j{psutil.cpu_count()*2} -C llama.cpp",
]
if try_execute(commands) == "CMAKE":
# Instead use CMAKE
commands = [
- f"cmake llama.cpp -B llama.cpp/build -DBUILD_SHARED_LIBS=OFF -DGGML_CUDA=OFF {CURL_FLAG}",
- f"cmake --build llama.cpp/build --config Release -j{(psutil.cpu_count() or 1)*2} --clean-first --target {' '.join(LLAMA_CPP_TARGETS)}",
+ "cmake llama.cpp -B llama.cpp/build -DBUILD_SHARED_LIBS=OFF -DGGML_CUDA=OFF -DLLAMA_CURL=ON",
+ f"cmake --build llama.cpp/build --config Release -j{psutil.cpu_count()*2} --clean-first --target {' '.join(LLAMA_CPP_TARGETS)}",
"cp llama.cpp/build/bin/llama-* llama.cpp",
"rm -rf llama.cpp/build",
]
-
try_execute(commands)
+ pass
# Check if successful
if not (
- os.path.exists("llama.cpp/llama-quantize.exe")
- or os.path.exists("llama.cpp/llama-quantize")
- or os.path.exists("llama.cpp/quantize.exe")
- or os.path.exists("llama.cpp/quantize")
- or os.path.exists("llama.cpp/build/bin/llama-quantize")
- or os.path.exists("llama.cpp/build/bin/quantize")
+ os.path.exists("llama.cpp/llama-quantize.exe") or
+ os.path.exists("llama.cpp/llama-quantize") or
+ os.path.exists("llama.cpp/quantize.exe") or
+ os.path.exists("llama.cpp/quantize")
):
raise RuntimeError(
- "Unsloth: The file 'llama.cpp/llama-quantize' or `llama.cpp/quantize` does not exist.\n"
- "We've also double checked the building directory under 'llama.cpp/build/bin/'.\n"
- "But we expect this file to exist! Check if the file exists under llama.cpp and investigate the building process of llama.cpp (make/cmake)!"
+ "Unsloth: The file 'llama.cpp/llama-quantize' or `llama.cpp/quantize` does not exist.\n"\
+ "But we expect this file to exist! Maybe the llama.cpp developers changed the name or check extension of the llama-quantize file."
)
+ pass
+pass
def install_llama_cpp_blocking(use_cuda = False):
@@ -1047,110 +906,101 @@ def install_llama_cpp_blocking(use_cuda = False):
"git clone --recursive https://github.com/ggerganov/llama.cpp",
"pip install gguf protobuf",
]
- if os.path.exists("llama.cpp"):
- return
+ if os.path.exists("llama.cpp"): return
try_execute(commands)
commands = [
"make clean -C llama.cpp",
# https://github.com/ggerganov/llama.cpp/issues/7062
# Weirdly GPU conversion for GGUF breaks??
- # f"{use_cuda} make all -j{(psutil.cpu_count() or 1)*2} -C llama.cpp",
- f"make all -j{(psutil.cpu_count() or 1)*2} -C llama.cpp",
+ # f"{use_cuda} make all -j{psutil.cpu_count()*2} -C llama.cpp",
+ f"make all -j{psutil.cpu_count()*2} -C llama.cpp",
]
if try_execute(commands) == "CMAKE":
# Instead use CMAKE
commands = [
- f"cmake llama.cpp -B llama.cpp/build -DBUILD_SHARED_LIBS=OFF -DGGML_CUDA=OFF {CURL_FLAG}",
- f"cmake --build llama.cpp/build --config Release -j{(psutil.cpu_count() or 1)*2} --clean-first --target {' '.join(LLAMA_CPP_TARGETS)}",
+ "cmake llama.cpp -B llama.cpp/build -DBUILD_SHARED_LIBS=OFF -DGGML_CUDA=OFF -DLLAMA_CURL=ON",
+ f"cmake --build llama.cpp/build --config Release -j{psutil.cpu_count()*2} --clean-first --target {' '.join(LLAMA_CPP_TARGETS)}",
"cp llama.cpp/build/bin/llama-* llama.cpp",
"rm -rf llama.cpp/build",
]
try_execute(commands)
+ pass
+pass
def get_executable(executables):
# Get system locations (System Path).split(system separator)
system_directories = os.environ.get("PATH").split(os.pathsep)
- for directory in system_directories:
- for executable in executables:
- path = os.path.join(directory, executable)
- # Check if the executable exists and is executable
- if os.path.exists(path) and os.access(path, os.X_OK):
- return path
- return None
+ found_executables = [os.path.join(path, executable)
+ for path in system_directories
+ for executable in executables
+ if os.path.exists(os.path.join(path, executable)) and os.access(os.path.join(path, executable), os.X_OK)]
+ return found_executables[0] if found_executables else None
+pass
def save_to_gguf(
- model_name: str,
- model_type: str,
- model_dtype: str,
- is_sentencepiece: bool = False,
- model_directory: str = "unsloth_finetuned_model",
- quantization_method = "fast_quantized", # Can be a list of options! ["q4_k_m", "q8_0", "q5_k_m"]
- first_conversion: str = None,
- is_vlm: bool = False,
- is_gpt_oss: bool = False,
+ model_type : str,
+ model_dtype : str,
+ is_sentencepiece : bool = False,
+ model_directory : str = "unsloth_finetuned_model",
+ quantization_method = "fast_quantized", # Can be a list of options! ["q4_k_m", "q8_0", "q5_k_m"]
+ first_conversion : str = None,
+ _run_installer = None, # Non blocking install of llama.cpp
):
- """
- Orchestrates the complete GGUF conversion process.
- Handles installation, conversion, and quantization.
- """
- # print_output True only if UNSLOTH_ENABLE_LOGGING=1
- if os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1":
- print_output = True
- else:
- print_output = False
-
- # Validate model dtype
- assert model_dtype == "float16" or model_dtype == "bfloat16"
+ # logger.warning(
+ # "NOTICE: llama.cpp GGUF conversion is currently unstable, since llama.cpp is\n"\
+ # "undergoing some major bug fixes as at 5th of May 2024. This is not an Unsloth issue.\n"\
+ # "Please be patient - GGUF saving should still work, but might not work as well."
+ # )
+ assert(model_dtype == "float16" or model_dtype == "bfloat16")
model_dtype = "f16" if model_dtype == "float16" else "bf16"
# Convert quantization_method to list
- if isinstance(quantization_method, list):
- pass
- elif isinstance(quantization_method, str):
- quantization_method = [
- quantization_method,
- ]
- elif isinstance(quantization_method, tuple):
- quantization_method = list(quantization_method)
+ if isinstance(quantization_method, list): pass
+ elif isinstance(quantization_method, str): quantization_method = [ quantization_method, ]
+ elif isinstance(quantization_method, tuple): quantization_method = list(quantization_method)
else:
- raise TypeError(
- "Unsloth: quantization_method can only be a string or a list of strings"
- )
+ raise TypeError("Unsloth: quantization_method can only be a string or a list of strings")
+ pass
# Check if bfloat16 is supported
if model_dtype == "bf16" and not torch.cuda.is_bf16_supported():
logger.warning(
- "Unsloth: Cannot convert to bf16 GGUF since your computer doesn't support it.\n"
+ "Unsloth: Cannot convert to bf16 GGUF since your computer doesn't support it.\n"\
"We shall switch instead to f16."
)
model_dtype = "f16"
+ pass
# Check first_conversion as well
if first_conversion is None:
first_conversion = model_dtype
+ pass
# Check I quants
for quant_method in quantization_method:
if quant_method.startswith("iq2"):
- raise RuntimeError(
- "Unsloth: Currently iq2 type quantizations aren't supported yet - sorry!"
- )
+ raise RuntimeError("Unsloth: Currently iq2 type quantizations aren't supported yet - sorry!")
+ pass
+
+ # Careful convert.py is only for Llama / Mistral based archs
+ use_fast_convert = False
+ if not is_sentencepiece: use_fast_convert = False # Llama-3
+ elif model_type == "llama": use_fast_convert = True
+ elif model_type == "mistral": use_fast_convert = True
+ pass
+ logger.warning_once(f"Unsloth: Converting {model_type} model. Can use fast conversion = {use_fast_convert}.")
# Map quant methods
- new_quantization_methods = []
+ new_quantization_method = []
for quant_method in quantization_method:
- if quant_method == "not_quantized":
- quant_method = model_dtype
- elif quant_method == "fast_quantized":
- quant_method = "q8_0"
- elif quant_method == "quantized":
- quant_method = "q4_k_m"
- elif quant_method is None:
- quant_method = "q8_0"
+ if quant_method == "not_quantized": quant_method = model_dtype
+ elif quant_method == "fast_quantized": quant_method = "q8_0"
+ elif quant_method == "quantized": quant_method = "q4_k_m"
+ elif quant_method is None: quant_method = "q8_0"
# Check if wrong method
if quant_method not in ALLOWED_QUANTS.keys():
@@ -1158,249 +1008,301 @@ def save_to_gguf(
for key, value in ALLOWED_QUANTS.items():
error += f"[{key}] => {value}\n"
raise RuntimeError(error)
+ pass
- new_quantization_methods.append(quant_method)
- quantization_method = new_quantization_methods
-
- # Determine optimal first_conversion
- if is_gpt_oss:
- print("Unsloth: GPT-OSS model detected - using special conversion settings")
- first_conversion = "None" # No quantization for GPT-OSS
- # Only keep one conversion method since GPT-OSS doesn't quantize
- quantization_method = ["None"]
- else:
- if first_conversion is None:
- # Check if q8_0 is the ONLY quantization method requested
- if len(quantization_method) == 1 and quantization_method[0] == "q8_0":
- first_conversion = "None" # Let llama-quantize do the direct conversion
- else:
- # For all other cases, choose the highest precision format
- # that can be requantized to all requested formats
- strength = 0
- for quant_method in quantization_method:
- if quant_method == "f32":
- strength = max(strength, 3)
- elif quant_method == "f16":
- strength = max(strength, 2)
- elif quant_method == "bf16":
- strength = max(strength, 1)
- # Note: we don't set strength for q8_0 here since we handle it above
-
- if strength >= 3:
- first_conversion = "f32"
- elif strength >= 2:
- first_conversion = "f16"
- elif strength >= 1:
- first_conversion = "bf16"
- else:
- first_conversion = "bf16" # requantizing from q8_0 disallowed in new llama.cpp default to bf16.
-
- # Check bfloat16 support again for first_conversion
- if first_conversion == "bf16" and not torch.cuda.is_bf16_supported():
- logger.warning("Unsloth: Switching bf16 to f16 due to hardware limitations")
- first_conversion = "f16"
+ new_quantization_method.append(quant_method)
+ pass
+ quantization_method = new_quantization_method
- first_conversion_dtype = "" if first_conversion == "None" else first_conversion
- # Print conversion info
- print_info = (
- f"==((====))== Unsloth: Conversion from HF to GGUF information\n"
- f" {chr(92)}{chr(92)} /| [0] Installing llama.cpp might take 3 minutes.\n"
- f"O^O/ {chr(92)}_/ {chr(92)} [1] Converting HF to GGUF {first_conversion_dtype} might take 3 minutes.\n"
- f"{chr(92)} / [2] Converting GGUF {first_conversion_dtype} to {quantization_method} might take 10 minutes each.\n"
+ print_info = \
+ f"==((====))== Unsloth: Conversion from QLoRA to GGUF information\n"\
+ f" {chr(92)}{chr(92)} /| [0] Installing llama.cpp might take 3 minutes.\n"\
+ f"O^O/ {chr(92)}_/ {chr(92)} [1] Converting HF to GGUF 16bits might take 3 minutes.\n"\
+ f"{chr(92)} / [2] Converting GGUF 16bits to {quantization_method} might take 10 minutes each.\n"\
f' "-____-" In total, you will have to wait at least 16 minutes.\n'
- )
print(print_info)
- # Step 1: Ensure llama.cpp is installed
- try:
- quantizer_location, converter_location = check_llama_cpp()
- print("Unsloth: llama.cpp found in the system. Skipping installation.")
- except:
+ # Check first_conversion format
+ if first_conversion == "f16" : pass
+ elif first_conversion == "bf16" : pass
+ elif first_conversion == "f32" : pass
+ elif first_conversion == "q8_0" : pass
+ else:
+ raise RuntimeError(
+ f"Unsloth: `first_conversion` can only be one of ['f16', 'bf16', 'f32', 'q8_0'] and not `{first_conversion}`."
+ )
+ pass
+
+ # Determine whether the system already has llama.cpp installed and the scripts are executable
+ quantize_location = get_executable(["llama-quantize", "quantize", "llama-quantize.exe", "quantize.exe"])
+ convert_location = get_executable(["convert-hf-to-gguf.py", "convert_hf_to_gguf.py"])
+
+ error = 0
+ if quantize_location is not None and convert_location is not None:
+ print("Unsloth: llama.cpp found in the system. We shall skip installation.")
+ else:
print("Unsloth: Installing llama.cpp. This might take 3 minutes...")
- if IS_KAGGLE_ENVIRONMENT:
- # Kaggle: no CUDA support due to environment limitations
- quantizer_location, converter_location = install_llama_cpp(
- gpu_support = False, print_output = print_output
+ if _run_installer is not None:
+ _run_installer, IS_CMAKE = _run_installer
+
+ error = _run_installer.wait()
+ # Check if successful
+ if error != 0:
+ print(f"Unsloth: llama.cpp error code = {error}.")
+ install_llama_cpp_old(-10)
+ pass
+
+ if IS_CMAKE:
+ # CMAKE needs to do some extra steps
+ print("Unsloth: CMAKE detected. Finalizing some steps for installation.")
+
+ check = os.system("cp llama.cpp/build/bin/llama-* llama.cpp")
+ if check != 0: raise RuntimeError("Failed compiling llama.cpp. Please report this ASAP!")
+ check = os.system("rm -rf llama.cpp/build")
+ if check != 0: raise RuntimeError("Failed compiling llama.cpp. Please report this ASAP!")
+ pass
+ else:
+ error = 0
+ install_llama_cpp_blocking()
+ pass
+
+ # Careful llama.cpp/quantize changed to llama.cpp/llama-quantize
+ # and llama.cpp/main changed to llama.cpp/llama-cli
+ # See https://github.com/ggerganov/llama.cpp/pull/7809
+ quantize_location = None
+ if os.path.exists("llama.cpp/quantize.exe"):
+ quantize_location = "llama.cpp/quantize.exe"
+ elif os.path.exists("llama.cpp/quantize"):
+ quantize_location = "llama.cpp/quantize"
+ elif os.path.exists("llama.cpp/llama-quantize.exe"):
+ quantize_location = "llama.cpp/llama-quantize.exe"
+ elif os.path.exists("llama.cpp/llama-quantize"):
+ quantize_location = "llama.cpp/llama-quantize"
+ else:
+ raise RuntimeError(
+ "Unsloth: The file ('llama.cpp/llama-quantize' or 'llama.cpp/llama-quantize.exe' if you are on Windows WSL) or 'llama.cpp/quantize' does not exist.\n"\
+ "But we expect this file to exist! Maybe the llama.cpp developers changed the name or check extension of the llama-quantize file."
)
+ pass
+
+ # See https://github.com/unslothai/unsloth/pull/730
+ # Filenames changed again!
+ convert_location = None
+ if os.path.exists("llama.cpp/convert-hf-to-gguf.py"):
+ convert_location = "llama.cpp/convert-hf-to-gguf.py"
+ elif os.path.exists("llama.cpp/convert_hf_to_gguf.py"):
+ convert_location = "llama.cpp/convert_hf_to_gguf.py"
else:
- quantizer_location, converter_location = install_llama_cpp(
- gpu_support = False, # GGUF conversion doesn't need CUDA
- print_output = print_output,
+ raise RuntimeError(
+ "Unsloth: The file 'llama.cpp/convert-hf-to-gguf.py' or 'llama.cpp/convert_hf_to_gguf.py' does not exist.\n"\
+ "But we expect this file to exist! Maybe the llama.cpp developers changed the name?"
)
+ pass
+ pass
- # Step 2: Download and patch converter script
- print("Unsloth: Preparing converter script...")
- with use_local_gguf():
- converter_path, supported_text_archs, supported_vision_archs = (
- _download_convert_hf_to_gguf()
- )
+ # Determine maximum first_conversion state
+ if first_conversion == "f32" : strength = 3
+ elif first_conversion == "f16" : strength = 2
+ elif first_conversion == "bf16" : strength = 1
+ elif first_conversion == "q8_0" : strength = 0
- # Step 3: Initial GGUF conversion
- print(
- f"Unsloth: [1] Converting model into {first_conversion_dtype} GGUF format."
- )
- print(f"This might take 3 minutes...")
-
- initial_files, is_vlm_update = convert_to_gguf(
- model_name = model_name,
- input_folder = model_directory,
- model_dtype = model_dtype,
- quantization_type = first_conversion,
- converter_location = converter_path,
- supported_text_archs = supported_text_archs,
- supported_vision_archs = supported_vision_archs,
- is_vlm = is_vlm,
- is_gpt_oss = is_gpt_oss,
- max_shard_size = "50GB",
- print_output = print_output,
- )
- # update is_vlm switch
- is_vlm = is_vlm_update
- # Check conversion success
- for file in initial_files:
- if not os.path.exists(file):
- if IS_KAGGLE_ENVIRONMENT:
- raise RuntimeError(
- f"Unsloth: Conversion failed for {file}\n"
- "You are in a Kaggle environment with limited disk space (20GB).\n"
- "Try saving to /tmp for more space or use a smaller model.\n"
- "Alternatively, save the 16bit model first, then convert manually."
- )
- else:
- raise RuntimeError(
- f"Unsloth: Conversion failed for {file}\n"
- "Please check disk space and try again."
+ for quant_method in quantization_method:
+ if quant_method == "f32": strength = max(strength, 3)
+ elif quant_method == "f16": strength = max(strength, 2)
+ elif quant_method == "bf16": strength = max(strength, 1)
+ elif quant_method == "q8_0": strength = max(strength, 0)
+ else:
+ # Quantized models must have f16 as the default argument
+ if first_conversion == "f32" : pass
+ elif first_conversion == "f16" : pass
+ elif first_conversion == "bf16" : pass
+ elif first_conversion == "q8_0":
+ logger.warning_once(
+ "Unsloth: Using q8_0 for the `first_conversion` will lose a bit of accuracy, "\
+ "but saves disk space!"
)
+ # first_conversion = "f16"
+ pass
+ pass
+ pass
+
+ # If only q8_0:
+ if len(quantization_method) == 1 and quantization_method[0] == "q8_0":
+ strength = 0
+ pass
- # Move initial GGUF files into a dedicated _gguf directory
- gguf_directory = f"{model_directory}_gguf"
- os.makedirs(gguf_directory, exist_ok = True)
- moved_files = []
- for fpath in initial_files:
- dst = os.path.join(gguf_directory, os.path.basename(fpath))
- shutil.move(fpath, dst)
- moved_files.append(dst)
- initial_files = moved_files
+ if strength >= 3: first_conversion = "f32"
+ elif strength >= 2: first_conversion = "f16"
+ elif strength >= 1: first_conversion = "bf16"
+ else: first_conversion = "q8_0"
- print(f"Unsloth: Initial conversion completed! Files: {initial_files}")
+ # Non llama/mistral needs can only use f32 or f16
+ if not use_fast_convert and \
+ (first_conversion != "f16" or first_conversion != "bf16" or first_conversion != "f32"):
+
+ pass
+ # Latest llama.cpp works for all models for q8_0!
+
+ # logger.warning_once("Unsloth: We must use f16 for non Llama and Mistral models.")
+ # first_conversion = "f16"
+ pass
- # Step 4: Additional quantizations using llama-quantize
- all_saved_locations = initial_files.copy()
+ # Check if bfloat16 is supported
+ if first_conversion == "bf16" and not torch.cuda.is_bf16_supported():
+ logger.warning(
+ "Unsloth: Cannot convert to bf16 GGUF since your computer doesn't support it.\n"\
+ "We shall switch instead to f16."
+ )
+ first_conversion = "f16"
+ pass
- # Get CPU count for quantization
n_cpus = psutil.cpu_count()
- if n_cpus is None:
- n_cpus = 1
+ if n_cpus is None: n_cpus = 1
n_cpus *= 2
+ # Concurrency from https://rentry.org/llama-cpp-conversions#merging-loras-into-a-model
- if not is_gpt_oss:
- base_gguf = initial_files[0]
- quants_created = False
- for quant_method in quantization_method:
- if quant_method != first_conversion:
- print(
- f"Unsloth: [2] Converting GGUF {first_conversion_dtype} into {quant_method}. This might take 10 minutes..."
- )
- output_location = os.path.join(
- gguf_directory, f"{model_name}.{quant_method.upper()}.gguf"
- )
- try:
- # Use the quantize_gguf function we created
- quantized_file = quantize_gguf(
- input_gguf = base_gguf,
- output_gguf = output_location,
- quant_type = quant_method,
- quantizer_location = quantizer_location,
- print_output = print_output,
- )
- all_saved_locations.append(quantized_file)
- quants_created = True
- except Exception as e:
- if IS_KAGGLE_ENVIRONMENT:
- raise RuntimeError(
- f"Unsloth: Quantization failed for {output_location}\n"
- "You are in a Kaggle environment, which might be the reason this is failing.\n"
- "Kaggle only provides 20GB of disk space in the working directory.\n"
- "Merging to 16bit for 7b models use 16GB of space.\n"
- "This means using `model.{save_pretrained/push_to_hub}_merged` works, but\n"
- "`model.{save_pretrained/push_to_hub}_gguf will use too much disk space.\n"
- "You can try saving it to the `/tmp` directory for larger disk space.\n"
- "I suggest you to save the 16bit model first, then use manual llama.cpp conversion.\n"
- f"Error: {e}"
- )
- else:
- if IS_WINDOWS:
- build_instructions = (
- f'cd "{LLAMA_CPP_DEFAULT_DIR}"\n'
- f"cmake -S . -B build -DBUILD_SHARED_LIBS=OFF\n"
- f"cmake --build build --config Release"
- )
- else:
- build_instructions = f'cd "{LLAMA_CPP_DEFAULT_DIR}" && make clean && make all -j'
+ final_location = str((Path(model_directory) / f"unsloth.{first_conversion.upper()}.gguf").absolute())
- raise RuntimeError(
- f"Unsloth: Quantization failed for {output_location}\n"
- "You might have to compile llama.cpp yourself, then run this again.\n"
- "You do not need to close this Python program. Run the following commands in a new terminal:\n"
- f'git clone --recursive https://github.com/ggerganov/llama.cpp "{LLAMA_CPP_DEFAULT_DIR}"\n'
- f"{build_instructions}\n"
- "Once that's done, redo the quantization.\n"
- f"Error: {e}"
- )
- print("Unsloth: Model files cleanup...")
- if quants_created:
- all_saved_locations.remove(base_gguf)
- Path(base_gguf).unlink(missing_ok = True)
+ print(f"Unsloth: [1] Converting model at {model_directory} into {first_conversion} GGUF format.\n"\
+ f"The output location will be {final_location}\n"\
+ "This might take 3 minutes...")
- # flip the list to get [text_model, mmproj] order. for text models stays the same.
- all_saved_locations.reverse()
+ # We first check if tokenizer.model exists in the model_directory
+ if os.path.exists(f"{model_directory}/tokenizer.model"):
+ vocab_type = "spm,hfft,bpe"
+ # Fix Sentencepiece model as well!
+ fix_sentencepiece_gguf(model_directory)
else:
- print("Unsloth: GPT-OSS model - skipping additional quantizations")
+ vocab_type = "bpe"
+ pass
- if is_gpt_oss:
- want_full_precision = True
+ # convert.py is deprecated!
+ use_fast_convert = False
+ if use_fast_convert:
+ command = f"python llama.cpp/convert.py {model_directory} "\
+ f"--outfile {final_location} --vocab-type {vocab_type} "\
+ f"--outtype {first_conversion} --concurrency {n_cpus} --pad-vocab"
else:
- want_full_precision = first_conversion in frozenset(quantization_method)
+ command = f"python {convert_location} {model_directory} "\
+ f"--outfile {final_location} "\
+ f"--outtype {first_conversion}"
+ pass
+
+ try_execute([command,], force_complete = True)
+
+ # Check if quantization succeeded!
+ if not os.path.isfile(final_location):
+ if IS_KAGGLE_ENVIRONMENT:
+ if not Path(final_location).resolve().is_relative_to(Path('/tmp').resolve()):
+ raise RuntimeError(
+ f"Unsloth: Quantization failed for {final_location}\n"\
+ "You are in a Kaggle environment, which might be the reason this is failing.\n"\
+ "Kaggle only provides 20GB of disk space in the working directory.\n"\
+ "Merging to 16bit for 7b models use 16GB of space.\n"\
+ "This means using `model.{save_pretrained/push_to_hub}_merged` works, but\n"\
+ "`model.{save_pretrained/push_to_hub}_gguf will use too much disk space.\n"\
+ "You can try saving it to the `/tmp` directory for larger disk space.\n"\
+ "I suggest you to save the 16bit model first, then use manual llama.cpp conversion."
+ )
+ else:
+ raise RuntimeError(
+ f"Unsloth: Quantization failed for {final_location}\n"\
+ "You might have to compile llama.cpp yourself, then run this again.\n"\
+ "You do not need to close this Python program. Run the following commands in a new terminal:\n"\
+ "You must run this in the same folder as you're saving your model.\n"\
+ "git clone --recursive https://github.com/ggerganov/llama.cpp\n"\
+ "cd llama.cpp && make clean && make all -j\n"\
+ "Once that's done, redo the quantization."
+ )
+ pass
+ pass
+ print(f"Unsloth: Conversion completed! Output location: {final_location}")
+
+ full_precision_location = final_location
+
+ all_saved_locations = [full_precision_location,]
+ # Convert each type!
+ for quant_method in quantization_method:
+ if quant_method != first_conversion:
+ print(f"Unsloth: [2] Converting GGUF 16bit into {quant_method}. This might take 20 minutes...")
+ final_location = str((Path(model_directory) / f"unsloth.{quant_method.upper()}.gguf").absolute())
+
+ command = f"./{quantize_location} {full_precision_location} "\
+ f"{final_location} {quant_method} {n_cpus}"
+
+ try_execute([command,], force_complete = True)
+
+ # Check if quantization succeeded!
+ if not os.path.isfile(final_location):
+ if IS_KAGGLE_ENVIRONMENT:
+ if not Path(final_location).resolve().is_relative_to(Path('/tmp').resolve()):
+ raise RuntimeError(
+ f"Unsloth: Quantization failed for {final_location}\n"\
+ "You are in a Kaggle environment, which might be the reason this is failing.\n"\
+ "Kaggle only provides 20GB of disk space in the working directory.\n"\
+ "Merging to 16bit for 7b models use 16GB of space.\n"\
+ "This means using `model.{save_pretrained/push_to_hub}_merged` works, but\n"\
+ "`model.{save_pretrained/push_to_hub}_gguf will use too much disk space.\n"\
+ "You can try saving it to the `/tmp` directory for larger disk space.\n"\
+ "I suggest you to save the 16bit model first, then use manual llama.cpp conversion."
+ )
+ else:
+ raise RuntimeError(
+ "Unsloth: Quantization failed! You might have to compile llama.cpp yourself, then run this again.\n"\
+ "You do not need to close this Python program. Run the following commands in a new terminal:\n"\
+ "You must run this in the same folder as you're saving your model.\n"\
+ "git clone --recursive https://github.com/ggerganov/llama.cpp\n"\
+ "cd llama.cpp && make clean && make all -j\n"\
+ "Once that's done, redo the quantization."
+ )
+ pass
+ pass
- print(f"Unsloth: All GGUF conversions completed successfully!")
- print(f"Generated files: {all_saved_locations}")
+ print(f"Unsloth: Conversion completed! Output location: {final_location}")
+ all_saved_locations.append(final_location)
+ pass
+ pass
- return all_saved_locations, want_full_precision, is_vlm
+ # Finally check if first_conversion (f16, bf16 etc) was in the list of actual quant methods
+ full_precision_seen = first_conversion in frozenset(quantization_method)
+
+ return all_saved_locations, full_precision_seen
+pass
def unsloth_save_pretrained_merged(
self,
- save_directory: Union[str, os.PathLike],
- tokenizer = None,
- save_method: str = "merged_16bit", # ["lora", "merged_16bit", "merged_4bit"]
- push_to_hub: bool = False,
- token: Optional[Union[str, bool]] = None,
- is_main_process: bool = True,
- state_dict: Optional[dict] = None,
- save_function: Callable = torch.save,
- max_shard_size: Union[int, str] = "5GB",
- safe_serialization: bool = True,
- variant: Optional[str] = None,
- save_peft_format: bool = True,
- tags: List[str] = None,
- temporary_location: str = "_unsloth_temporary_saved_buffers",
- maximum_memory_usage: float = 0.75,
- datasets: Optional[List[str]] = None,
+ save_directory : Union[str, os.PathLike],
+ tokenizer = None,
+ save_method : str = "merged_16bit", # ["lora", "merged_16bit", "merged_4bit"]
+ push_to_hub : bool = False,
+ token : Optional[Union[str, bool]] = None,
+ is_main_process : bool = True,
+ state_dict : Optional[dict] = None,
+ save_function : Callable = torch.save,
+ max_shard_size : Union[int, str] = "5GB",
+ safe_serialization : bool = True,
+ variant : Optional[str] = None,
+ save_peft_format : bool = True,
+ tags : List[str] = None,
+ temporary_location : str = "_unsloth_temporary_saved_buffers",
+ maximum_memory_usage : float = 0.75,
):
"""
- Same as .save_pretrained(...) except 4bit weights are auto
- converted to float16 with as few overhead as possible.
+ Same as .save_pretrained(...) except 4bit weights are auto
+ converted to float16 with as few overhead as possible.
- Choose for `save_method` to be either:
- 1. `16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.
- 2. `4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.
- 3. `lora`: Save LoRA adapters with no merging. Useful for HF inference.
+ Choose for `save_method` to be either:
+ 1. `16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.
+ 2. `4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.
+ 3. `lora`: Save LoRA adapters with no merging. Useful for HF inference.
"""
if tokenizer is None:
logger.warning_once(
- "Unsloth: You're not saving a tokenizer as well?\n"
+ "Unsloth: You're not saving a tokenizer as well?\n"\
"You can do it separately via `tokenizer.save_pretrained(...)`"
)
+ pass
arguments = dict(locals())
arguments["model"] = self
@@ -1408,54 +1310,57 @@ def unsloth_save_pretrained_merged(
unsloth_save_model(**arguments)
for _ in range(3):
gc.collect()
+pass
def unsloth_push_to_hub_merged(
self,
- repo_id: str,
- tokenizer = None,
- save_method: str = "merged_16bit", # ["lora", "merged_16bit", "merged_4bit"]
- use_temp_dir: Optional[bool] = None,
- commit_message: Optional[str] = "Trained with Unsloth",
- private: Optional[bool] = None,
- token: Union[bool, str, None] = None,
- max_shard_size: Union[int, str, None] = "5GB",
- create_pr: bool = False,
- safe_serialization: bool = True,
- revision: str = None,
- commit_description: str = "Upload model trained with Unsloth 2x faster",
- tags: Optional[List[str]] = None,
- temporary_location: str = "_unsloth_temporary_saved_buffers",
- maximum_memory_usage: float = 0.75,
- datasets: Optional[List[str]] = None,
+ repo_id : str,
+ tokenizer = None,
+ save_method : str = "merged_16bit", # ["lora", "merged_16bit", "merged_4bit"]
+ use_temp_dir : Optional[bool] = None,
+ commit_message : Optional[str] = "Trained with Unsloth",
+ private : Optional[bool] = None,
+ token : Union[bool, str, None] = None,
+ max_shard_size : Union[int, str, None] = "5GB",
+ create_pr : bool = False,
+ safe_serialization : bool = True,
+ revision : str = None,
+ commit_description : str = "Upload model trained with Unsloth 2x faster",
+ tags : Optional[List[str]] = None,
+ temporary_location : str = "_unsloth_temporary_saved_buffers",
+ maximum_memory_usage : float = 0.75,
):
"""
- Same as .push_to_hub(...) except 4bit weights are auto
- converted to float16 with as few overhead as possible.
+ Same as .push_to_hub(...) except 4bit weights are auto
+ converted to float16 with as few overhead as possible.
- Choose for `save_method` to be either:
- 1. `16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.
- 2. `4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.
- 3. `lora`: Save LoRA adapters with no merging. Useful for HF inference.
+ Choose for `save_method` to be either:
+ 1. `16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.
+ 2. `4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.
+ 3. `lora`: Save LoRA adapters with no merging. Useful for HF inference.
"""
if tokenizer is None:
logger.warning_once(
- "Unsloth: You're not saving a tokenizer as well?\n"
+ "Unsloth: You're not saving a tokenizer as well?\n"\
"You can do it separately via `tokenizer.push_to_hub(...)`"
)
+ pass
arguments = dict(locals())
- arguments["model"] = self
+ arguments["model"] = self
arguments["save_directory"] = repo_id
- arguments["push_to_hub"] = True
+ arguments["push_to_hub"] = True
del arguments["self"]
del arguments["repo_id"]
unsloth_save_model(**arguments)
for _ in range(3):
gc.collect()
+pass
-MODEL_CARD = """---
+MODEL_CARD = \
+"""---
base_model: {base_model}
tags:
- text-generation-inference
@@ -1474,7 +1379,7 @@ def unsloth_push_to_hub_merged(
- **License:** apache-2.0
- **Finetuned from model :** {base_model}
-This {model_type} model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth)
+This {model_type} model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Huggingface's TRL library.
[
](https://github.com/unslothai/unsloth)
"""
@@ -1485,19 +1390,19 @@ def _determine_username(save_directory, old_username, token):
save_directory = save_directory.lstrip("./")
if "/" not in save_directory:
from huggingface_hub import whoami
-
try:
username = whoami(token = token)["name"]
if type(old_username) is str and username != old_username:
username = old_username
+ pass
save_directory = f"{username}/{save_directory}"
except:
- raise RuntimeError(
- f"Unsloth: {save_directory} is not a Huggingface directory."
- )
+ raise RuntimeError(f"Unsloth: {save_directory} is not a Huggingface directory.")
else:
username = save_directory.split("/")[0]
+ pass
return save_directory, username
+pass
def create_huggingface_repo(
@@ -1505,52 +1410,38 @@ def create_huggingface_repo(
save_directory,
token = None,
private = False,
- datasets = None,
):
- if token is None:
+ if token is None :
token = get_token()
- save_directory, username = _determine_username(save_directory, None, token)
+ pass
+ save_directory, username = _determine_username(save_directory, "", token)
from huggingface_hub import create_repo
-
try:
create_repo(
- repo_id = save_directory,
- token = token,
+ repo_id = save_directory,
+ token = token,
repo_type = "model",
- exist_ok = False,
- private = private,
+ exist_ok = False,
+ private = private,
)
# Create model card
from huggingface_hub import ModelCard
-
content = MODEL_CARD.format(
- username = username,
+ username = username,
base_model = model.config._name_or_path,
model_type = model.config.model_type,
- method = "",
- extra = "unsloth",
+ method = "",
+ extra = "unsloth",
)
card = ModelCard(content)
- if datasets:
- card.data.datasets = datasets
card.push_to_hub(save_directory, token = token)
except:
- # Repo already exists — update datasets metadata separately
- if datasets:
- try:
- from huggingface_hub import metadata_update
-
- metadata_update(
- save_directory, {"datasets": datasets}, overwrite = True, token = token
- )
- except Exception as e:
- logger.warning_once(
- f"Unsloth: Could not update datasets metadata for {save_directory}: {e}"
- )
+ pass
hf_api = HfApi(token = token)
return save_directory, hf_api
+pass
def upload_to_huggingface(
@@ -1563,99 +1454,85 @@ def upload_to_huggingface(
old_username = None,
private = None,
create_config = True,
- datasets = None,
):
save_directory, username = _determine_username(save_directory, old_username, token)
from huggingface_hub import create_repo
-
try:
create_repo(
- repo_id = save_directory,
- token = token,
+ repo_id = save_directory,
+ token = token,
repo_type = "model",
- exist_ok = False,
- private = private,
+ exist_ok = False,
+ private = private,
)
# Create model card
from huggingface_hub import ModelCard
-
content = MODEL_CARD.format(
- username = username,
+ username = username,
base_model = model.config._name_or_path,
model_type = model.config.model_type,
- method = "",
- extra = extra,
+ method = "",
+ extra = extra,
)
card = ModelCard(content)
- if datasets:
- card.data.datasets = datasets
card.push_to_hub(save_directory, token = token)
except:
- # Repo already exists — update datasets metadata separately
- if datasets:
- try:
- from huggingface_hub import metadata_update
-
- metadata_update(
- save_directory, {"datasets": datasets}, overwrite = True, token = token
- )
- except Exception as e:
- logger.warning_once(
- f"Unsloth: Could not update datasets metadata for {save_directory}: {e}"
- )
+ pass
if file_location is not None:
# Now upload file
hf_api = HfApi(token = token)
if "/" in file_location:
- uploaded_location = file_location[file_location.rfind("/") + 1 :]
+ uploaded_location = file_location[file_location.rfind("/")+1:]
else:
uploaded_location = file_location
+ pass
# find ftevent file from tensorboard and upload it
import glob
-
ftevent_files = glob.glob("*out.tfevents*", recursive = True)
if len(ftevent_files) > 0:
- print(
- "Unsloth: Uploading tensorboard files... Please wait...",
- file_location + "*out.tfevents*",
- )
+ print("Unsloth: Uploading tensorboard files... Please wait...", file_location + "*out.tfevents*")
for ftevent_file in ftevent_files:
hf_api.upload_file(
path_or_fileobj = ftevent_file,
- path_in_repo = ftevent_file.replace(file_location, ""),
- repo_id = save_directory,
- repo_type = "model",
- commit_message = "(Trained with Unsloth)",
+ path_in_repo = ftevent_file.replace(file_location, ""),
+ repo_id = save_directory,
+ repo_type = "model",
+ commit_message = "(Trained with Unsloth)",
)
+ pass
+ pass
hf_api.upload_file(
path_or_fileobj = file_location,
- path_in_repo = uploaded_location,
- repo_id = save_directory,
- repo_type = "model",
- commit_message = "(Trained with Unsloth)",
+ path_in_repo = uploaded_location,
+ repo_id = save_directory,
+ repo_type = "model",
+ commit_message = "(Trained with Unsloth)",
)
# We also upload a config.json file
if create_config:
import json
-
- with open("_temporary_unsloth_config.json", "w", encoding = "utf-8") as file:
- json.dump({"model_type": model.config.model_type}, file, indent = 4)
+ with open("_temporary_unsloth_config.json", "w") as file:
+ json.dump({"model_type" : model.config.model_type}, file, indent = 4)
+ pass
hf_api.upload_file(
path_or_fileobj = "_temporary_unsloth_config.json",
- path_in_repo = "config.json",
- repo_id = save_directory,
- repo_type = "model",
- commit_message = "(Trained with Unsloth)",
+ path_in_repo = "config.json",
+ repo_id = save_directory,
+ repo_type = "model",
+ commit_message = "(Trained with Unsloth)",
)
os.remove("_temporary_unsloth_config.json")
+ pass
+ pass
return username
+pass
def fix_tokenizer_bos_token(tokenizer):
@@ -1663,97 +1540,87 @@ def fix_tokenizer_bos_token(tokenizer):
fix_bos_token = False
chat_template = getattr(tokenizer, "chat_template", None)
- if tokenizer("A").input_ids[0] == getattr(tokenizer, "bos_token_id", None):
- if chat_template is not None and (
- tokenizer.bos_token in chat_template
- or "{bos_token}" in chat_template.replace(" ", "")
- or "{bos_token+" in chat_template.replace(" ", "")
- ):
+ if (tokenizer("A").input_ids[0] == getattr(tokenizer, "bos_token_id", None)):
+ if chat_template is not None and \
+ (
+ tokenizer.bos_token in chat_template or \
+ "{bos_token}" in chat_template.replace(" ", "") or \
+ "{bos_token+" in chat_template.replace(" ", "")
+ ):
+
fix_bos_token = True
logger.warning(
- "Unsloth: ##### The current model auto adds a BOS token.\n"
+ "Unsloth: ##### The current model auto adds a BOS token.\n"\
"Unsloth: ##### Your chat template has a BOS token. We shall remove it temporarily."
)
# Remove {{bos_token}}
- new_chat_template = re.sub(
- r"\{[\s]{0,}\{[\s]{0,}bos\_token[\s]{0,}\}[\s]{0,}\}", "", chat_template
- )
+ new_chat_template = re.sub(r"\{[\s]{0,}\{[\s]{0,}bos\_token[\s]{0,}\}[\s]{0,}\}", "", chat_template)
# Remove {{bos_token +
- new_chat_template = re.sub(
- r"\{[\s]{0,}\{[\s]{0,}bos\_token[\s]{0,}\+[\s]{0,}",
- "",
- new_chat_template,
- )
+ new_chat_template = re.sub(r"\{[\s]{0,}\{[\s]{0,}bos\_token[\s]{0,}\+[\s]{0,}", "", new_chat_template)
tokenizer.chat_template = new_chat_template
+ pass
+ pass
return fix_bos_token, chat_template
+pass
-def create_ollama_modelfile(tokenizer, base_model_name, model_location):
+def create_ollama_modelfile(tokenizer, gguf_location):
"""
- Creates an Ollama Modelfile.
- Use ollama.create(model = "new_ollama_model", modelfile = modelfile)
+ Creates an Ollama Modelfile.
+ Use ollama.create(model = "new_ollama_model", modelfile = modelfile)
"""
- ollama_template_name = MODEL_TO_OLLAMA_TEMPLATE_MAPPER.get(base_model_name)
- if not ollama_template_name:
- print(
- f"Unsloth: No Ollama template mapping found for model '{base_model_name}'. Skipping Ollama Modelfile"
- )
- return None
- ollama_modelfile = OLLAMA_TEMPLATES.get(ollama_template_name)
- if not ollama_modelfile:
- print(
- f"Unsloth: No Ollama template mapping found for model '{base_model_name}'. Skipping Ollama Modelfile"
- )
- return None
- tokenizer._ollama_modelfile = (
- ollama_modelfile # This comes from the unpacking above
- )
- modelfile = ollama_modelfile
+ modelfile = getattr(tokenizer, "_ollama_modelfile", None)
+ if modelfile is None: return None
FILE_LOCATION_REPLACER = "⚫@✅#🦥__FILE_LOCATION__⚡@🦥#⛵"
- EOS_TOKEN_REPLACER = "⚫@✅#🦥__EOS_TOKEN__⚡@🦥#⛵"
- LEFT_BRACKET_REPLACER = "⚫@✅#🦥"
+ EOS_TOKEN_REPLACER = "⚫@✅#🦥__EOS_TOKEN__⚡@🦥#⛵"
+ LEFT_BRACKET_REPLACER = "⚫@✅#🦥"
RIGHT_BRACKET_REPLACER = "⚡@🦥#⛵"
# Fixes https://github.com/unslothai/unsloth/issues/1087
# We must convert all {'s and }'s but keep {__FILE_LOCATION__} intact
- modelfile = (
- modelfile.replace("{__FILE_LOCATION__}", FILE_LOCATION_REPLACER)
- .replace("{__EOS_TOKEN__}", EOS_TOKEN_REPLACER)
- .replace("{", LEFT_BRACKET_REPLACER)
+ modelfile = modelfile\
+ .replace("{__FILE_LOCATION__}", FILE_LOCATION_REPLACER)\
+ .replace("{__EOS_TOKEN__}", EOS_TOKEN_REPLACER)\
+ .replace("{", LEFT_BRACKET_REPLACER)\
.replace("}", RIGHT_BRACKET_REPLACER)
- )
# Revert {__FILE_LOCATION__} back
- modelfile = modelfile.replace(
- FILE_LOCATION_REPLACER, "{__FILE_LOCATION__}"
- ).replace(EOS_TOKEN_REPLACER, "{__EOS_TOKEN__}")
+ modelfile = modelfile\
+ .replace(FILE_LOCATION_REPLACER, "{__FILE_LOCATION__}")\
+ .replace(EOS_TOKEN_REPLACER, "{__EOS_TOKEN__}")
if "__EOS_TOKEN__" in modelfile:
modelfile = modelfile.format(
- __FILE_LOCATION__ = model_location,
- __EOS_TOKEN__ = tokenizer.eos_token,
+ __FILE_LOCATION__ = gguf_location,
+ __EOS_TOKEN__ = tokenizer.eos_token,
)
else:
modelfile = modelfile.format(
- __FILE_LOCATION__ = model_location,
+ __FILE_LOCATION__ = gguf_location,
)
+ pass
- modelfile = modelfile.replace("⚫@✅#🦥", "{").replace("⚡@🦥#⛵", "}").rstrip()
+ modelfile = modelfile\
+ .replace("⚫@✅#🦥", "{")\
+ .replace("⚡@🦥#⛵", "}")\
+ .rstrip()
return modelfile
+pass
-
-def create_ollama_model(username: str, model_name: str, tag: str, modelfile_path: str):
+def create_ollama_model(
+ username: str,
+ model_name: str,
+ tag: str,
+ modelfile_path: str
+):
try:
init_check = subprocess.run(
- ["curl", "http://localhost:11434"],
- capture_output = True,
- text = True,
- timeout = 3,
+ ['curl', 'http://localhost:11434'], capture_output=True, text=True, timeout=3
)
if init_check.returncode == 0:
print(init_check.stdout.strip())
@@ -1763,22 +1630,16 @@ def create_ollama_model(username: str, model_name: str, tag: str, modelfile_path
return "Ollama Request Timeout"
process = subprocess.Popen(
- [
- "ollama",
- "create",
- f"{username}/{model_name}:{tag}",
- "-f",
- f"{modelfile_path}",
- ],
- stdout = subprocess.PIPE,
- stderr = subprocess.STDOUT,
- text = True,
- bufsize = 1,
- universal_newlines = True,
+ ['ollama', 'create', f'{username}/{model_name}:{tag}', '-f', f'{modelfile_path}'],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ text=True,
+ bufsize=1,
+ universal_newlines=True
)
- for line in iter(process.stdout.readline, ""):
- print(line, end = "")
+ for line in iter(process.stdout.readline, ''):
+ print(line, end='')
sys.stdout.flush()
return_code = process.wait()
@@ -1787,15 +1648,13 @@ def create_ollama_model(username: str, model_name: str, tag: str, modelfile_path
print(f"\nMODEL CREATED FAILED WITH RETURN CODE {return_code}")
else:
print("\nMODEL CREATED SUCCESSFULLY")
+pass
def push_to_ollama_hub(username: str, model_name: str, tag: str):
try:
init_check = subprocess.run(
- ["curl", "http://localhost:11434"],
- capture_output = True,
- text = True,
- timeout = 3,
+ ['curl', 'http://localhost:11434'], capture_output=True, text=True, timeout=3
)
if init_check.returncode == 0:
print(init_check.stdout.strip())
@@ -1805,16 +1664,16 @@ def push_to_ollama_hub(username: str, model_name: str, tag: str):
return "Ollama Request Timeout"
process = subprocess.Popen(
- ["ollama", "push", f"{username}/{model_name}:{tag}"],
- stdout = subprocess.PIPE,
- stderr = subprocess.STDOUT,
- text = True,
- bufsize = 1,
- universal_newlines = True,
+ ['ollama', 'push', f'{username}/{model_name}:{tag}'],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ text=True,
+ bufsize=1,
+ universal_newlines=True
)
- for line in iter(process.stdout.readline, ""):
- print(line, end = "")
+ for line in iter(process.stdout.readline, ''):
+ print(line, end='')
sys.stdout.flush()
return_code = process.wait()
@@ -1825,639 +1684,413 @@ def push_to_ollama_hub(username: str, model_name: str, tag: str):
print("\nMODEL PUBLISHED SUCCESSFULLY")
-def push_to_ollama(tokenizer, gguf_location, username: str, model_name: str, tag: str):
+def push_to_ollama(
+ tokenizer,
+ gguf_location,
+ username: str,
+ model_name: str,
+ tag: str
+):
model_file = create_ollama_modelfile(
- tokenizer = tokenizer, gguf_location = gguf_location
+ tokenizer=tokenizer,
+ gguf_location=gguf_location
)
- with open(f"Modelfile_{model_name}", "w", encoding = "utf-8") as f:
+ with open(f"Modelfile_{model_name}", "w") as f:
f.write(model_file)
f.close()
-
+
create_ollama_model(
- username = username,
- model_name = model_name,
- tag = tag,
- modelfile_path = f"Modelfile_{model_name}",
+ username=username,
+ model_name=model_name,
+ tag=tag,
+ modelfile_path=f"Modelfile_{model_name}"
+ )
+
+ push_to_ollama_hub(
+ username=username,
+ model_name=model_name,
+ tag=tag
)
- push_to_ollama_hub(username = username, model_name = model_name, tag = tag)
+ print("Succesfully pushed to ollama")
+
+
- print("Successfully pushed to ollama")
def unsloth_save_pretrained_gguf(
self,
- save_directory: Union[str, os.PathLike],
- tokenizer = None,
- quantization_method = "fast_quantized",
- first_conversion: str = None,
- push_to_hub: bool = False,
- token: Optional[Union[str, bool]] = None,
- private: Optional[bool] = None,
- is_main_process: bool = True,
- state_dict: Optional[dict] = None,
- save_function: Callable = torch.save,
- max_shard_size: Union[int, str] = "5GB",
- safe_serialization: bool = True,
- variant: Optional[str] = None,
- save_peft_format: bool = True,
- tags: List[str] = None,
- temporary_location: str = "_unsloth_temporary_saved_buffers",
- maximum_memory_usage: float = 0.85,
+ save_directory : Union[str, os.PathLike],
+ tokenizer = None,
+ quantization_method : str = "fast_quantized",
+ first_conversion : str = None,
+ push_to_hub : bool = False,
+ token : Optional[Union[str, bool]] = None,
+ private : Optional[bool] = None,
+ is_main_process : bool = True,
+ state_dict : Optional[dict] = None,
+ save_function : Callable = torch.save,
+ max_shard_size : Union[int, str] = "5GB",
+ safe_serialization : bool = True,
+ variant : Optional[str] = None,
+ save_peft_format : bool = True,
+ tags : List[str] = None,
+ temporary_location : str = "_unsloth_temporary_saved_buffers",
+ maximum_memory_usage : float = 0.85,
):
"""
- Same as .save_pretrained(...) except 4bit weights are auto
- converted to float16 then converted to GGUF / llama.cpp format.
-
- Choose for `quantization_method` to be:
- "not_quantized" : "Recommended. Fast conversion. Slow inference, big files.",
- "fast_quantized" : "Recommended. Fast conversion. OK inference, OK file size.",
- "quantized" : "Recommended. Slow conversion. Fast inference, small files.",
- "f32" : "Not recommended. Retains 100% accuracy, but super slow and memory hungry.",
- "f16" : "Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
- "q8_0" : "Fast conversion. High resource use, but generally acceptable.",
- "q4_k_m" : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K",
- "q5_k_m" : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K",
- "q2_k" : "Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.",
- "q3_k_l" : "Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
- "q3_k_m" : "Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
- "q3_k_s" : "Uses Q3_K for all tensors",
- "q4_0" : "Original quant method, 4-bit.",
- "q4_1" : "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
- "q4_k_s" : "Uses Q4_K for all tensors",
- "q4_k" : "alias for q4_k_m",
- "q5_k" : "alias for q5_k_m",
- "q5_0" : "Higher accuracy, higher resource usage and slower inference.",
- "q5_1" : "Even higher accuracy, resource usage and slower inference.",
- "q5_k_s" : "Uses Q5_K for all tensors",
- "q6_k" : "Uses Q8_K for all tensors",
- "iq2_xxs" : "2.06 bpw quantization",
- "iq2_xs" : "2.31 bpw quantization",
- "iq3_xxs" : "3.06 bpw quantization",
- "q3_k_xs" : "3-bit extra small quantization",
+ Same as .save_pretrained(...) except 4bit weights are auto
+ converted to float16 then converted to GGUF / llama.cpp format.
+
+ Choose for `quantization_method` to be:
+ "not_quantized" : "Recommended. Fast conversion. Slow inference, big files.",
+ "fast_quantized" : "Recommended. Fast conversion. OK inference, OK file size.",
+ "quantized" : "Recommended. Slow conversion. Fast inference, small files.",
+ "f32" : "Not recommended. Retains 100% accuracy, but super slow and memory hungry.",
+ "f16" : "Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
+ "q8_0" : "Fast conversion. High resource use, but generally acceptable.",
+ "q4_k_m" : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K",
+ "q5_k_m" : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K",
+ "q2_k" : "Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.",
+ "q3_k_l" : "Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
+ "q3_k_m" : "Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
+ "q3_k_s" : "Uses Q3_K for all tensors",
+ "q4_0" : "Original quant method, 4-bit.",
+ "q4_1" : "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
+ "q4_k_s" : "Uses Q4_K for all tensors",
+ "q4_k" : "alias for q4_k_m",
+ "q5_k" : "alias for q5_k_m",
+ "q5_0" : "Higher accuracy, higher resource usage and slower inference.",
+ "q5_1" : "Even higher accuracy, resource usage and slower inference.",
+ "q5_k_s" : "Uses Q5_K for all tensors",
+ "q6_k" : "Uses Q8_K for all tensors",
+ "iq2_xxs" : "2.06 bpw quantization",
+ "iq2_xs" : "2.31 bpw quantization",
+ "iq3_xxs" : "3.06 bpw quantization",
+ "q3_k_xs" : "3-bit extra small quantization",
"""
if tokenizer is None:
raise ValueError("Unsloth: Saving to GGUF must have a tokenizer.")
- try:
- base_model_name = get_model_name(self.config._name_or_path, load_in_4bit = False)
- model_name = base_model_name.split("/")[-1]
- except:
- base_model_name = self.config._name_or_path
- model_name = base_model_name.split("/")[-1]
-
- # Check if push_to_hub is requested
- if push_to_hub:
- raise ValueError(
- "Unsloth: Please use .push_to_hub_gguf() instead of .save_pretrained_gguf() with push_to_hub=True"
- )
-
- # Step 1: Check if this is a VLM (Vision-Language Model) and check if gpt-oss
- is_vlm = False
- if hasattr(self, "config") and hasattr(self.config, "architectures"):
- is_vlm = any(
- x.endswith(("ForConditionalGeneration", "ForVisionText2Text"))
- for x in self.config.architectures
- )
- is_vlm = is_vlm or hasattr(self.config, "vision_config")
-
- is_processor = is_vlm and isinstance(tokenizer, ProcessorMixin)
-
- is_gpt_oss = (
- True
- if (
- hasattr(self.config, "architectures")
- and self.config.architectures == "GptOssForCausalLM"
- )
- or (
- hasattr(self.config, "model_type")
- and self.config.model_type in ["gpt-oss", "gpt_oss"]
- )
- else False
- )
- # Step 2: Prepare arguments for model saving
arguments = dict(locals())
- arguments["model"] = self
- arguments["tokenizer"] = tokenizer
- arguments["push_to_hub"] = False # We handle upload ourselves
- # GPT-OSS needs mxfp4 save method
- if is_gpt_oss:
- if quantization_method is not None:
- _qm = (
- quantization_method
- if isinstance(quantization_method, (list, tuple))
- else [quantization_method]
- )
- _ignored = [q for q in _qm if str(q).lower() != "mxfp4"]
- if _ignored:
- logger.warning_once(
- f"Unsloth: GPT-OSS does not support GGUF quantization "
- f"(requested: {', '.join(str(q) for q in _ignored)}). "
- f"Overriding to MXFP4 format. "
- f"Pass quantization_method=None to suppress this warning."
- )
- arguments["save_method"] = "mxfp4"
- else:
- arguments["save_method"] = "merged_16bit"
+ arguments["model"] = self
+ arguments["tokenizer"] = tokenizer
+ arguments["push_to_hub"] = False # We save ourselves
+ arguments["save_method"] = "merged_16bit" # Must be 16bit
del arguments["self"]
del arguments["quantization_method"]
del arguments["first_conversion"]
- del arguments["is_vlm"]
- del arguments["is_gpt_oss"]
- del arguments["model_name"]
- del arguments["base_model_name"]
- del arguments["is_processor"]
-
- # Step 3: Fix tokenizer BOS token if needed
- if is_processor:
- fix_bos_token, old_chat_template = fix_tokenizer_bos_token(tokenizer.tokenizer)
- else:
- fix_bos_token, old_chat_template = fix_tokenizer_bos_token(tokenizer)
- # Step 4: Save/merge model to 16-bit format
- print(
- f'Unsloth: Merging model weights to {"mxfp4" if is_gpt_oss else "16-bit"} format...'
- )
- try:
- # Call unsloth_generic_save directly (it's in the same file)
- unsloth_generic_save(**arguments)
+ # Fix tokenizer adding an extra BOS token at the front
+ fix_bos_token, old_chat_template = fix_tokenizer_bos_token(tokenizer)
- except Exception as e:
- raise RuntimeError(f"Failed to save/merge model: {e}")
+ # Non blocking install GGUF first
+ if not os.path.exists("llama.cpp"):
- if is_processor:
- tokenizer = tokenizer.tokenizer
+ if IS_KAGGLE_ENVIRONMENT:
+ # Kaggle is weird - no blocking installs, and no CUDA?
+ python_install = install_python_non_blocking(["gguf", "protobuf"])
+ python_install.wait()
+ install_llama_cpp_blocking(use_cuda = False)
+ new_save_directory, old_username = unsloth_save_model(**arguments)
+ makefile = None
+ else:
+ git_clone = install_llama_cpp_clone_non_blocking()
+ python_install = install_python_non_blocking(["gguf", "protobuf"])
+ git_clone.wait()
+ makefile = install_llama_cpp_make_non_blocking()
+ new_save_directory, old_username = unsloth_save_model(**arguments)
+ python_install.wait()
+ pass
+ else:
+ try:
+ new_save_directory, old_username = unsloth_save_model(**arguments)
+ makefile = None
+ except:
+ # Retry by recloning llama.cpp
+ if IS_KAGGLE_ENVIRONMENT:
+ # Kaggle is weird - no blocking installs, and no CUDA?
+ python_install = install_python_non_blocking(["gguf", "protobuf"])
+ python_install.wait()
+ install_llama_cpp_blocking(use_cuda = False)
+ new_save_directory, old_username = unsloth_save_model(**arguments)
+ makefile = None
+ else:
+ git_clone = install_llama_cpp_clone_non_blocking()
+ python_install = install_python_non_blocking(["gguf", "protobuf"])
+ git_clone.wait()
+ makefile = install_llama_cpp_make_non_blocking()
+ new_save_directory, old_username = unsloth_save_model(**arguments)
+ python_install.wait()
+ pass
+ pass
+ pass
# Use old chat template if the bos is removed
if fix_bos_token:
tokenizer.chat_template = old_chat_template
+ pass
- # Step 6: Clean up memory
for _ in range(3):
- import gc
-
gc.collect()
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- # Step 7: Get model dtype and type
- try:
- model_dtype = dtype_from_config(self.config)
- model_type = self.config.model_type
- if type(model_dtype) is str:
- assert model_dtype == "float16" or model_dtype == "bfloat16"
- elif model_dtype == torch.float16:
- model_dtype = "float16"
- elif model_dtype == torch.bfloat16:
- model_dtype = "bfloat16"
- else:
- raise TypeError("Unsloth: Model dtype can only be float16 or bfloat16")
- except Exception as e:
- # Fallback if dtype_from_config fails
- print(f"Unsloth: Could not determine dtype ({e}), defaulting to float16")
+ model_dtype = self.config.torch_dtype
+ model_type = self.config.model_type
+ if type(model_dtype) is str:
+ assert(model_dtype == "float16" or model_dtype == "bfloat16")
+ elif model_dtype == torch.float16:
model_dtype = "float16"
+ elif model_dtype == torch.bfloat16:
+ model_dtype = "bfloat16"
+ else:
+ raise TypeError("Unsloth: Model dtype can only be float16 or bfloat16")
+ pass
- # Step 8: Convert to GGUF format
- print("Unsloth: Converting to GGUF format...")
-
- # Convert quantization_method to list if string
- # Use old style quantization_method
- quantization_methods = []
- if quantization_method is not None:
- # Convert quantization_method to list
- if isinstance(quantization_method, list):
- pass
- elif isinstance(quantization_method, str):
- quantization_method = [
- quantization_method,
- ]
- elif isinstance(quantization_method, tuple):
- quantization_method = list(quantization_method)
- else:
- raise TypeError(
- "Unsloth: quantization_method can only be a string or a list of strings"
- )
- for i, quant_method in enumerate(quantization_method):
- quant_method = quant_method.lower()
- if quant_method == "not_quantized":
- quant_method = "f16"
- elif quant_method == "fast_quantized":
- quant_method = "q8_0"
- elif quant_method == "quantized":
- quant_method = "q4_k_m"
- elif quant_method is None:
- quant_method = "q8_0"
- quantization_methods.append(quant_method.lower())
+ is_sentencepiece_model = check_if_sentencepiece_model(self)
- try:
- all_file_locations, want_full_precision, is_vlm_update = save_to_gguf(
- model_name = model_name,
- model_type = model_type,
- model_dtype = model_dtype,
- is_sentencepiece = False,
- model_directory = save_directory,
- quantization_method = quantization_methods,
- first_conversion = first_conversion,
- is_vlm = is_vlm, # Pass VLM flag
- is_gpt_oss = is_gpt_oss, # Pass gpt_oss Flag
- )
- except Exception as e:
- if IS_KAGGLE_ENVIRONMENT:
- raise RuntimeError(
- f"Unsloth: GGUF conversion failed in Kaggle environment.\n"
- f"This is likely due to the 20GB disk space limit.\n"
- f"Try saving to /tmp directory or use a smaller model.\n"
- f"Error: {e}"
- )
- else:
- raise RuntimeError(f"Unsloth: GGUF conversion failed: {e}")
+ # Save to GGUF
+ all_file_locations, want_full_precision = save_to_gguf(
+ model_type, model_dtype, is_sentencepiece_model,
+ new_save_directory, quantization_method, first_conversion, makefile,
+ )
- # Step 9: Create Ollama modelfile
- gguf_directory = f"{save_directory}_gguf"
+ # Save Ollama modelfile
+ modelfile = create_ollama_modelfile(tokenizer, all_file_locations[0])
modelfile_location = None
- ollama_success = False
- if all_file_locations:
- try:
- if is_vlm_update:
- modelfile = create_ollama_modelfile(tokenizer, base_model_name, ".")
- else:
- modelfile = create_ollama_modelfile(
- tokenizer,
- base_model_name,
- os.path.basename(all_file_locations[0]),
- )
- if modelfile is not None:
- modelfile_location = os.path.join(gguf_directory, "Modelfile")
- with open(modelfile_location, "w", encoding = "utf-8") as file:
- file.write(modelfile)
- ollama_success = True
- except Exception as e:
- print(f"Warning: Could not create Ollama modelfile: {e}")
-
- # Step 10: Show BOS token warning if applicable
+ if modelfile is not None:
+ modelfile_location = os.path.join(new_save_directory, "Modelfile")
+ with open(modelfile_location, "w") as file:
+ file.write(modelfile)
+ pass
+ print(f"Unsloth: Saved Ollama Modelfile to {modelfile_location}")
+ pass
+
if fix_bos_token:
logger.warning(
- "Unsloth: ##### The current model auto adds a BOS token.\n"
+ "Unsloth: ##### The current model auto adds a BOS token.\n"\
"Unsloth: ##### We removed it in GGUF's chat template for you."
)
+ pass
- _exe = ".exe" if IS_WINDOWS else ""
- if IS_WINDOWS:
- _bin_dir = os.path.join(LLAMA_CPP_DEFAULT_DIR, "build", "bin", "Release")
- else:
- _bin_dir = LLAMA_CPP_DEFAULT_DIR
+ if push_to_hub:
+ print("Unsloth: Uploading GGUF to Huggingface Hub...")
- if is_vlm_update:
- print("\n")
- print(
- f"Unsloth: example usage for Multimodal LLMs: {os.path.join(_bin_dir, 'llama-mtmd-cli' + _exe)} -m {all_file_locations[0]} --mmproj {all_file_locations[-1]}"
- )
- print("Unsloth: load image inside llama.cpp runner: /image test_image.jpg")
- print("Unsloth: Prompt model to describe the image")
- else:
- print(
- f'Unsloth: example usage for text only LLMs: {os.path.join(_bin_dir, "llama-cli" + _exe)} --model {all_file_locations[0]} -p "why is the sky blue?"'
- )
+ # If not needing full precision, skip the first
+ if not want_full_precision: all_file_locations = all_file_locations[1:]
- if ollama_success:
- print(f"Unsloth: Saved Ollama Modelfile to {modelfile_location}")
- print(
- f"Unsloth: convert model to ollama format by running - ollama create model_name -f {modelfile_location}"
- )
+ for file_location in all_file_locations:
+ username = upload_to_huggingface(
+ self, save_directory, token,
+ "GGUF converted", "gguf", file_location, old_username, private,
+ )
+ link = f"{username}/{new_save_directory.lstrip('/.')}" \
+ if username not in new_save_directory else \
+ new_save_directory.lstrip('/.')
+ print(f"Saved GGUF to https://huggingface.co/{link}")
+ pass
- # Return a dict with all needed info for push_to_hub
- return {
- "save_directory": save_directory,
- "gguf_directory": gguf_directory,
- "gguf_files": all_file_locations,
- "modelfile_location": modelfile_location,
- "want_full_precision": want_full_precision,
- "is_vlm": is_vlm_update,
- "fix_bos_token": fix_bos_token,
- }
+ # Save modelfile
+ if modelfile_location is not None:
+ username = upload_to_huggingface(
+ self, save_directory, token,
+ "GGUF converted", "gguf", modelfile_location, old_username, private,
+ )
+ print(f"Saved Ollama Modelfile to https://huggingface.co/{link}")
+ pass
+ pass
+pass
def unsloth_push_to_hub_gguf(
self,
- repo_id: str,
- tokenizer = None,
- quantization_method = "fast_quantized",
- first_conversion: str = None,
- use_temp_dir: Optional[bool] = None,
- commit_message: Optional[str] = "Trained with Unsloth",
- private: Optional[bool] = None,
- token: Union[bool, str, None] = None,
- max_shard_size: Union[int, str, None] = "5GB",
- create_pr: bool = False,
- safe_serialization: bool = True,
- revision: str = None,
- commit_description: str = "Upload model trained with Unsloth 2x faster",
- tags: Optional[List[str]] = None,
- temporary_location: str = "_unsloth_temporary_saved_buffers",
- maximum_memory_usage: float = 0.85,
- datasets: Optional[List[str]] = None,
+ repo_id : str,
+ tokenizer = None,
+ quantization_method : str = "fast_quantized",
+ first_conversion : str = None,
+ use_temp_dir : Optional[bool] = None,
+ commit_message : Optional[str] = "Trained with Unsloth",
+ private : Optional[bool] = None,
+ token : Union[bool, str, None] = None,
+ max_shard_size : Union[int, str, None] = "5GB",
+ create_pr : bool = False,
+ safe_serialization : bool = True,
+ revision : str = None,
+ commit_description : str = "Upload model trained with Unsloth 2x faster",
+ tags : Optional[List[str]] = None,
+ temporary_location : str = "_unsloth_temporary_saved_buffers",
+ maximum_memory_usage : float = 0.85,
):
"""
- Same as .push_to_hub(...) except 4bit weights are auto
- converted to float16 then converted to GGUF / llama.cpp format.
-
- Choose for `quantization_method` to be:
- "not_quantized" : "Recommended. Fast conversion. Slow inference, big files.",
- "fast_quantized" : "Recommended. Fast conversion. OK inference, OK file size.",
- "quantized" : "Recommended. Slow conversion. Fast inference, small files.",
- "f32" : "Not recommended. Retains 100% accuracy, but super slow and memory hungry.",
- "f16" : "Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
- "q8_0" : "Fast conversion. High resource use, but generally acceptable.",
- "q4_k_m" : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K",
- "q5_k_m" : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K",
- "q2_k" : "Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.",
- "q3_k_l" : "Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
- "q3_k_m" : "Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
- "q3_k_s" : "Uses Q3_K for all tensors",
- "q4_0" : "Original quant method, 4-bit.",
- "q4_1" : "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
- "q4_k_s" : "Uses Q4_K for all tensors",
- "q5_0" : "Higher accuracy, higher resource usage and slower inference.",
- "q5_1" : "Even higher accuracy, resource usage and slower inference.",
- "q5_k_s" : "Uses Q5_K for all tensors",
- "q6_k" : "Uses Q8_K for all tensors",
+ Same as .push_to_hub(...) except 4bit weights are auto
+ converted to float16 then converted to GGUF / llama.cpp format.
+
+ Choose for `quantization_method` to be:
+ "not_quantized" : "Recommended. Fast conversion. Slow inference, big files.",
+ "fast_quantized" : "Recommended. Fast conversion. OK inference, OK file size.",
+ "quantized" : "Recommended. Slow conversion. Fast inference, small files.",
+ "f32" : "Not recommended. Retains 100% accuracy, but super slow and memory hungry.",
+ "f16" : "Fastest conversion + retains 100% accuracy. Slow and memory hungry.",
+ "q8_0" : "Fast conversion. High resource use, but generally acceptable.",
+ "q4_k_m" : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K",
+ "q5_k_m" : "Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K",
+ "q2_k" : "Uses Q4_K for the attention.vw and feed_forward.w2 tensors, Q2_K for the other tensors.",
+ "q3_k_l" : "Uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
+ "q3_k_m" : "Uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, else Q3_K",
+ "q3_k_s" : "Uses Q3_K for all tensors",
+ "q4_0" : "Original quant method, 4-bit.",
+ "q4_1" : "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
+ "q4_k_s" : "Uses Q4_K for all tensors",
+ "q5_0" : "Higher accuracy, higher resource usage and slower inference.",
+ "q5_1" : "Even higher accuracy, resource usage and slower inference.",
+ "q5_k_s" : "Uses Q5_K for all tensors",
+ "q6_k" : "Uses Q8_K for all tensors",
"""
if tokenizer is None:
raise ValueError("Unsloth: Saving to GGUF must have a tokenizer.")
- # Step 1: Determine save directory
- model_name = repo_id.split("/")[-1] if "/" in repo_id else repo_id
-
- if use_temp_dir or use_temp_dir is None:
- import tempfile
-
- temp_dir = tempfile.mkdtemp(prefix = "unsloth_gguf_")
- save_directory = temp_dir
- cleanup_temp = True
- else:
- save_directory = model_name # Use model name, not repo_id
- cleanup_temp = False
-
- # Step 2: Call save_pretrained_gguf to do the conversion
- print(f"Unsloth: Converting model to GGUF format...")
-
- try:
- # Call save_pretrained_gguf - it returns all the info we need
- result = unsloth_save_pretrained_gguf(
- self = self,
- save_directory = save_directory,
- tokenizer = tokenizer,
- quantization_method = quantization_method,
- first_conversion = first_conversion,
- push_to_hub = False, # Never push from here
- token = None, # Don't need token for local save
- max_shard_size = max_shard_size,
- safe_serialization = safe_serialization,
- temporary_location = temporary_location,
- maximum_memory_usage = maximum_memory_usage,
- )
-
- # Extract results
- all_file_locations = result["gguf_files"]
- modelfile_location = result["modelfile_location"]
- want_full_precision = result["want_full_precision"]
- is_vlm = result["is_vlm"]
- fix_bos_token = result["fix_bos_token"]
- actual_save_directory = result["save_directory"]
-
- except Exception as e:
- if cleanup_temp:
- import shutil
-
- for d in [save_directory, f"{save_directory}_gguf"]:
- try:
- shutil.rmtree(d)
- except:
- pass
- raise RuntimeError(f"Failed to convert model to GGUF: {e}")
-
- # Step 3: Upload to HuggingFace Hub
- print("Unsloth: Uploading GGUF to Huggingface Hub...")
+ arguments = dict(locals())
+ arguments["model"] = self
+ arguments["tokenizer"] = tokenizer
+ arguments["save_directory"] = repo_id
+ arguments["push_to_hub"] = False # We save ourselves
+ arguments["save_method"] = "merged_16bit" # Must be 16bit
+ del arguments["self"]
+ del arguments["repo_id"]
+ del arguments["quantization_method"]
+ del arguments["first_conversion"]
- try:
- from huggingface_hub import HfApi
+ # Fix tokenizer adding an extra BOS token at the front
+ fix_bos_token, old_chat_template = fix_tokenizer_bos_token(tokenizer)
- api = HfApi(token = token)
+ # Non blocking install GGUF first
+ if not os.path.exists("llama.cpp"):
- # Get full repo id
- if "/" not in repo_id:
- username = api.whoami()["name"]
- full_repo_id = f"{username}/{repo_id}"
+ if IS_KAGGLE_ENVIRONMENT:
+ # Kaggle is weird - no blocking installs, and no CUDA?
+ python_install = install_python_non_blocking(["gguf", "protobuf"])
+ python_install.wait()
+ install_llama_cpp_blocking(use_cuda = False)
+ new_save_directory, old_username = unsloth_save_model(**arguments)
+ makefile = None
else:
- full_repo_id = repo_id
-
- # Create repo
- api.create_repo(
- repo_id = full_repo_id,
- repo_type = "model",
- private = private,
- exist_ok = True,
- )
-
- # Upload GGUF files
- for file_location in all_file_locations:
- original_name = os.path.basename(file_location)
- # Replace temp directory name with proper model name
- if cleanup_temp and "unsloth_gguf_" in original_name:
- # Extract the quantization part (e.g., ".Q8_0.gguf" or ".Q8_0-mmproj.gguf")
- quant_suffix = (
- original_name.split(".", 1)[1]
- if "." in original_name
- else original_name
- )
- proper_name = f"{model_name}.{quant_suffix}"
+ git_clone = install_llama_cpp_clone_non_blocking()
+ python_install = install_python_non_blocking(["gguf", "protobuf"])
+ git_clone.wait()
+ makefile = install_llama_cpp_make_non_blocking()
+ new_save_directory, old_username = unsloth_save_model(**arguments)
+ python_install.wait()
+ pass
+ else:
+ try:
+ new_save_directory, old_username = unsloth_save_model(**arguments)
+ makefile = None
+ except:
+ # Retry by recloning llama.cpp
+ if IS_KAGGLE_ENVIRONMENT:
+ # Kaggle is weird - no blocking installs, and no CUDA?
+ python_install = install_python_non_blocking(["gguf", "protobuf"])
+ python_install.wait()
+ install_llama_cpp_blocking(use_cuda = False)
+ new_save_directory, old_username = unsloth_save_model(**arguments)
+ makefile = None
else:
- proper_name = original_name.replace(
- os.path.basename(save_directory), model_name
- )
-
- print(f"Uploading {proper_name}...")
-
- api.upload_file(
- path_or_fileobj = file_location,
- path_in_repo = proper_name,
- repo_id = full_repo_id,
- repo_type = "model",
- commit_message = commit_message,
- commit_description = commit_description,
- create_pr = create_pr,
- revision = revision,
- )
+ git_clone = install_llama_cpp_clone_non_blocking()
+ python_install = install_python_non_blocking(["gguf", "protobuf"])
+ git_clone.wait()
+ makefile = install_llama_cpp_make_non_blocking()
+ new_save_directory, old_username = unsloth_save_model(**arguments)
+ python_install.wait()
+ pass
+ pass
+ pass
- # Upload config.json if exists
- config_path = os.path.join(actual_save_directory, "config.json")
- if os.path.exists(config_path):
- print("Uploading config.json...")
- api.upload_file(
- path_or_fileobj = config_path,
- path_in_repo = "config.json",
- repo_id = full_repo_id,
- repo_type = "model",
- commit_message = f"{commit_message} - config",
- create_pr = create_pr,
- revision = revision,
- )
+ # Use old chat template if the bos is removed
+ if fix_bos_token:
+ tokenizer.chat_template = old_chat_template
+ pass
- # Upload Modelfile if exists
- if modelfile_location and os.path.exists(modelfile_location):
- print("Uploading Ollama Modelfile...")
- api.upload_file(
- path_or_fileobj = modelfile_location,
- path_in_repo = "Modelfile",
- repo_id = full_repo_id,
- repo_type = "model",
- commit_message = f"{commit_message} - Ollama Modelfile",
- create_pr = create_pr,
- revision = revision,
- )
+ for _ in range(3):
+ gc.collect()
- # Create and upload README
- readme_content = f"""---
-tags:
-- gguf
-- llama.cpp
-- unsloth
-{"- vision-language-model" if is_vlm else ""}
----
+ model_dtype = self.config.torch_dtype
+ model_type = self.config.model_type
+ if type(model_dtype) is str:
+ assert(model_dtype == "float16" or model_dtype == "bfloat16")
+ elif model_dtype == torch.float16:
+ model_dtype = "float16"
+ elif model_dtype == torch.bfloat16:
+ model_dtype = "bfloat16"
+ else:
+ raise TypeError("Unsloth: Model dtype can only be float16 or bfloat16")
+ pass
-# {repo_id.split("/")[-1]} : GGUF
+ is_sentencepiece_model = check_if_sentencepiece_model(self)
-This model was finetuned and converted to GGUF format using [Unsloth](https://github.com/unslothai/unsloth).
+ # Save to GGUF
+ all_file_locations, want_full_precision = save_to_gguf(
+ model_type, model_dtype, is_sentencepiece_model,
+ new_save_directory, quantization_method, first_conversion, makefile,
+ )
-**Example usage**:
-- For text only LLMs: `llama-cli -hf {repo_id} --jinja`
-- For multimodal models: `llama-mtmd-cli -hf {repo_id} --jinja`
+ # Save Ollama modelfile
+ modelfile = create_ollama_modelfile(tokenizer, all_file_locations[0])
+ modelfile_location = None
+ if modelfile is not None:
+ modelfile_location = os.path.join(new_save_directory, "Modelfile")
+ with open(modelfile_location, "w") as file:
+ file.write(modelfile)
+ pass
+ print(f"Unsloth: Saved Ollama Modelfile to {modelfile_location}")
+ pass
-## Available Model files:
-"""
- for file in all_file_locations:
- # Fix filename in README too
- original_name = os.path.basename(file)
- if cleanup_temp and "unsloth_gguf_" in original_name:
- quant_suffix = (
- original_name.split(".", 1)[1]
- if "." in original_name
- else original_name
- )
- proper_name = f"{model_name}.{quant_suffix}"
- else:
- proper_name = original_name.replace(
- os.path.basename(save_directory), model_name
- )
- readme_content += f"- `{proper_name}`\n"
-
- # Special note for VLM with Modelfile
- if is_vlm and modelfile_location:
- readme_content += "\n## ⚠️ Ollama Note for Vision Models\n"
- readme_content += "**Important:** Ollama currently does not support separate mmproj files for vision models.\n\n"
- readme_content += "To create an Ollama model from this vision model:\n"
- readme_content += "1. Place the `Modelfile` in the same directory as the finetuned bf16 merged model\n"
- readme_content += "3. Run: `ollama create model_name -f ./Modelfile`\n"
- readme_content += " (Replace `model_name` with your desired name)\n\n"
- readme_content += (
- "This will create a unified bf16 model that Ollama can use.\n"
- )
- elif modelfile_location:
- readme_content += "\n## Ollama\n"
- readme_content += "An Ollama Modelfile is included for easy deployment.\n"
-
- if fix_bos_token:
- readme_content += "\n## Note\n"
- readme_content += (
- "The model's BOS token behavior was adjusted for GGUF compatibility.\n"
- )
+ # If not needing full precision, skip the first
+ if not want_full_precision: all_file_locations = all_file_locations[1:]
- readme_content += (
- "This was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth)\n"
- '[
](https://github.com/unslothai/unsloth)\n'
+ for file_location in all_file_locations:
+ print("Unsloth: Uploading GGUF to Huggingface Hub...")
+ username = upload_to_huggingface(
+ self, repo_id, token,
+ "GGUF converted", "gguf", file_location, old_username, private,
)
+ link = f"{username}/{new_save_directory.lstrip('/.')}" \
+ if username not in new_save_directory else \
+ new_save_directory.lstrip('/.')
- readme_path = os.path.join(actual_save_directory, "README.md")
- with open(readme_path, "w") as f:
- f.write(readme_content)
+ print(f"Saved GGUF to https://huggingface.co/{link}")
+ pass
- api.upload_file(
- path_or_fileobj = readme_path,
- path_in_repo = "README.md",
- repo_id = full_repo_id,
- repo_type = "model",
- commit_message = "Add README",
- create_pr = create_pr,
- revision = revision,
+ # Save modelfile
+ if modelfile_location is not None:
+ username = upload_to_huggingface(
+ self, repo_id, token,
+ "GGUF converted", "gguf", modelfile_location, old_username, private,
)
+ print(f"Saved Ollama Modelfile to https://huggingface.co/{link}")
+ pass
- print(
- f"Unsloth: Successfully uploaded GGUF to https://huggingface.co/{full_repo_id}"
+ if fix_bos_token:
+ logger.warning(
+ "Unsloth: ##### The current model auto adds a BOS token.\n"\
+ "Unsloth: ##### We removed it in GGUF's chat template for you."
)
-
- # Add tags
- if tags is None:
- tags = []
- tags.extend(["gguf", "llama-cpp", "unsloth"])
- if is_vlm:
- tags.append("vision-language-model")
-
- try:
- api.add_tags(
- repo_id = full_repo_id,
- tags = tags,
- repo_type = "model",
- )
- except:
- pass
-
- if datasets:
- try:
- from huggingface_hub import metadata_update
-
- metadata_update(
- full_repo_id, {"datasets": datasets}, overwrite = True, token = token
- )
- except Exception as e:
- logger.warning_once(
- f"Unsloth: Could not update datasets metadata for {full_repo_id}: {e}"
- )
-
- except Exception as e:
- raise RuntimeError(f"Failed to upload to Hugging Face Hub: {e}")
-
- finally:
- # Clean up temporary directory
- if cleanup_temp:
- print("Unsloth: Cleaning up temporary files...")
- import shutil
-
- for d in [save_directory, f"{save_directory}_gguf"]:
- if os.path.exists(d):
- try:
- shutil.rmtree(d)
- except:
- pass
-
- return full_repo_id
-
+ pass
+pass
# Corrected function to save LoRA to a custom directory
def save_lora_to_custom_dir(model, tokenizer, save_directory):
# Create the custom directory if it doesn't exist
- os.makedirs(save_directory, exist_ok = True)
+ os.makedirs(save_directory, exist_ok=True)
# Call the unsloth_save_model function with the custom directory
unsloth_save_model(
model,
tokenizer,
- save_directory = save_directory,
- save_method = "lora",
- push_to_hub = False,
+ save_directory=save_directory,
+ save_method="lora",
+ push_to_hub=False,
)
-
# Corrected method within the model class to convert LoRA to GGML and push to Hugging Face Hub
def unsloth_convert_lora_to_ggml_and_push_to_hub(
self,
@@ -2477,7 +2110,7 @@ def unsloth_convert_lora_to_ggml_and_push_to_hub(
if IS_KAGGLE_ENVIRONMENT:
python_install = install_python_non_blocking(["protobuf"])
python_install.wait()
- install_llama_cpp_blocking(use_cuda = False)
+ install_llama_cpp_blocking(use_cuda=False)
makefile = None
else:
git_clone = install_llama_cpp_clone_non_blocking()
@@ -2497,26 +2130,17 @@ def unsloth_convert_lora_to_ggml_and_push_to_hub(
model_type = self.config.model_type
output_file = os.path.join(lora_directory_push, "ggml-adapter-model.bin")
- print(
- f"Unsloth: Converting auto-saved LoRA adapters at {lora_directory_push} to GGML format."
- )
+ print(f"Unsloth: Converting auto-saved LoRA adapters at {lora_directory_push} to GGML format.")
print(f"The output file will be {output_file}")
command = f"python3 llama.cpp/convert-lora-to-ggml.py {lora_directory_push} {output_file} llama"
try:
- with subprocess.Popen(
- command,
- shell = True,
- stdout = subprocess.PIPE,
- stderr = subprocess.PIPE,
- bufsize = 1,
- universal_newlines = True,
- ) as sp:
+ with subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1, universal_newlines=True) as sp:
for line in sp.stdout:
- print(line, end = "", flush = True)
+ print(line, end="", flush=True)
for line in sp.stderr:
- print(line, end = "", flush = True)
+ print(line, end="", flush=True)
sp.wait()
if sp.returncode != 0:
raise subprocess.CalledProcessError(sp.returncode, command)
@@ -2528,26 +2152,17 @@ def unsloth_convert_lora_to_ggml_and_push_to_hub(
print("Unsloth: Uploading GGML file to Hugging Face Hub...")
username = upload_to_huggingface(
- self,
- repo_id,
- token,
- "GGML converted LoRA",
- "ggml",
- output_file,
- None,
- private,
+ self, repo_id, token,
+ "GGML converted LoRA", "ggml", output_file, None, private,
)
link = f"{repo_id.lstrip('/')}"
print("Unsloth: Done.")
print(f"Converted LoRA to GGML and uploaded to https://huggingface.co/{link}")
- print(
- "\nThis GGML making function was made by Maheswar. Ping him @Maheswar on the Unsloth Discord or on HuggingFace (@mahiatlinux) if you like this!"
- )
-
+ print("\nThis GGML making function was made by Maheswar. Ping him @Maheswar on the Unsloth Discord or on HuggingFace (@mahiatlinux) if you like this!")
def unsloth_convert_lora_to_ggml_and_save_locally(
self,
- save_directory: str, # Added parameter for the folder name
+ save_directory: str, # Added parameter for the folder name
tokenizer,
temporary_location: str = "_unsloth_temporary_saved_buffers",
maximum_memory_usage: float = 0.85,
@@ -2556,7 +2171,7 @@ def unsloth_convert_lora_to_ggml_and_save_locally(
if IS_KAGGLE_ENVIRONMENT:
python_install = install_python_non_blocking(["protobuf"])
python_install.wait()
- install_llama_cpp_blocking(use_cuda = False)
+ install_llama_cpp_blocking(use_cuda=False)
makefile = None
else:
git_clone = install_llama_cpp_clone_non_blocking()
@@ -2576,26 +2191,17 @@ def unsloth_convert_lora_to_ggml_and_save_locally(
model_type = self.config.model_type
output_file = os.path.join(save_directory, "ggml-adapter-model.bin")
- print(
- f"Unsloth: Converting auto-saved LoRA adapters at {save_directory} to GGML format."
- )
+ print(f"Unsloth: Converting auto-saved LoRA adapters at {save_directory} to GGML format.")
print(f"The output file will be {output_file}")
command = f"python3 llama.cpp/convert-lora-to-ggml.py {save_directory} {output_file} llama"
try:
- with subprocess.Popen(
- command,
- shell = True,
- stdout = subprocess.PIPE,
- stderr = subprocess.PIPE,
- bufsize = 1,
- universal_newlines = True,
- ) as sp:
+ with subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1, universal_newlines=True) as sp:
for line in sp.stdout:
- print(line, end = "", flush = True)
+ print(line, end="", flush=True)
for line in sp.stderr:
- print(line, end = "", flush = True)
+ print(line, end="", flush=True)
sp.wait()
if sp.returncode != 0:
raise subprocess.CalledProcessError(sp.returncode, command)
@@ -2604,9 +2210,8 @@ def unsloth_convert_lora_to_ggml_and_save_locally(
return
print("Unsloth: Done.")
print(f"Unsloth: Conversion completed! Output file: {output_file}")
- print(
- "\nThis GGML making function was made by Maheswar. Ping him @Maheswar on the Unsloth Discord or on HuggingFace (@mahiatlinux) if you like this!"
- )
+ print("\nThis GGML making function was made by Maheswar. Ping him @Maheswar on the Unsloth Discord or on HuggingFace (@mahiatlinux) if you like this!")
+pass
from .models.loader_utils import get_model_name
@@ -2619,196 +2224,129 @@ def unsloth_convert_lora_to_ggml_and_save_locally(
convert_to_gguf as _convert_to_gguf,
)
-
@torch.inference_mode
def save_to_gguf_generic(
model,
save_directory,
- tokenizer,
- quantization_method = None,
quantization_type = "Q8_0",
repo_id = None,
token = None,
):
- if token is None and repo_id is not None:
- token = get_token()
+ if token is None and repo_id is not None: token = get_token()
if repo_id is not None and token is None:
raise RuntimeError("Unsloth: Please specify a token for uploading!")
if not os.path.exists(os.path.join("llama.cpp", "unsloth_convert_hf_to_gguf.py")):
install_llama_cpp(just_clone_repo = True)
+ pass
- # Use old style quantization_method
- new_quantization_methods = []
- if quantization_method is not None:
- # Convert quantization_method to list
- if isinstance(quantization_method, list):
- pass
- elif isinstance(quantization_method, str):
- quantization_method = [
- quantization_method,
- ]
- elif isinstance(quantization_method, tuple):
- quantization_method = list(quantization_method)
- else:
- raise TypeError(
- "Unsloth: quantization_method can only be a string or a list of strings"
- )
- for i, quant_method in enumerate(quantization_method):
- quant_method = quant_method.lower()
- if quant_method == "not_quantized":
- quant_method = "f16"
- elif quant_method == "fast_quantized":
- quant_method = "q8_0"
- elif quant_method == "quantized":
- quant_method = "q4_k_m"
- elif quant_method is None:
- quant_method = "q8_0"
- new_quantization_methods.append(quant_method.lower())
- else:
- new_quantization_methods.append(quantization_type.lower())
- # Check if wrong method
- for quant_method in new_quantization_methods:
- if quant_method not in ALLOWED_QUANTS.keys():
- error = f"Unsloth: Quant method = [{quant_method}] not supported. Choose from below:\n"
- for key, value in ALLOWED_QUANTS.items():
- error += f"[{key}] => {value}\n"
- raise RuntimeError(error)
-
- # Go through all types and save individually - somewhat inefficient
- # since we save F16 / BF16 multiple times
- for quantization_type in new_quantization_methods:
- metadata = _convert_to_gguf(
- save_directory,
- print_output = True,
- quantization_type = quantization_type,
+ metadata = _convert_to_gguf(
+ save_directory,
+ print_output = True,
+ quantization_type = quantization_type,
+ )
+ if repo_id is not None:
+ prepare_saving(
+ model,
+ repo_id,
+ push_to_hub = True,
+ max_shard_size = "50GB",
+ private = True,
+ token = token,
)
- if repo_id is not None:
- prepare_saving(
- model,
- repo_id,
- push_to_hub = True,
- max_shard_size = "50GB",
- private = True,
- token = token,
- )
-
- from huggingface_hub import HfApi
- api = HfApi(token = token)
- api.upload_folder(
- folder_path = save_directory,
- repo_id = repo_id,
- repo_type = "model",
- allow_patterns = ["*.gguf"],
- )
+ from huggingface_hub import HfApi
+ api = HfApi(token = token)
+ api.upload_folder(
+ folder_path = save_directory,
+ repo_id = repo_id,
+ repo_type = "model",
+ allow_patterns = ["*.gguf"],
+ )
+ pass
return metadata
+pass
@torch.inference_mode
def unsloth_generic_save(
model,
tokenizer,
- save_directory: Union[str, os.PathLike] = "unsloth_finetuned_merge",
- save_method: str = "lora", # ["lora", "merged_16bit", "merged_4bit"]
- push_to_hub: bool = False,
- token: Optional[Union[str, bool]] = None,
- is_main_process: bool = True,
- state_dict: Optional[dict] = None,
- save_function: Callable = torch.save,
- max_shard_size: Union[int, str] = "5GB",
- safe_serialization: bool = True,
- variant: Optional[str] = None,
- save_peft_format: bool = True,
+ save_directory : Union[str, os.PathLike] = "unsloth_finetuned_merge",
+ save_method : str = "lora", # ["lora", "merged_16bit", "merged_4bit"]
+ push_to_hub : bool = False,
+ token : Optional[Union[str, bool]] = None,
+ is_main_process : bool = True,
+ state_dict : Optional[dict] = None,
+ save_function : Callable = torch.save,
+ max_shard_size : Union[int, str] = "5GB",
+ safe_serialization : bool = True,
+ variant : Optional[str] = None,
+ save_peft_format : bool = True,
+
# Push to hub
- use_temp_dir: Optional[bool] = None,
- commit_message: Optional[str] = "Trained with Unsloth",
- private: Optional[bool] = None,
- create_pr: bool = False,
- revision: str = None,
- commit_description: str = "Upload model trained with Unsloth 2x faster",
- tags: List[str] = None,
+ use_temp_dir : Optional[bool] = None,
+ commit_message : Optional[str] = "Trained with Unsloth",
+ private : Optional[bool] = None,
+ create_pr : bool = False,
+ revision : str = None,
+ commit_description : str = "Upload model trained with Unsloth 2x faster",
+ tags : List[str] = None,
+
# Our functions
- temporary_location: str = "_unsloth_temporary_saved_buffers",
- maximum_memory_usage: float = 0.9,
- datasets: Optional[List[str]] = None,
+ temporary_location : str = "_unsloth_temporary_saved_buffers",
+ maximum_memory_usage : float = 0.9,
):
- if token is None and push_to_hub:
- token = get_token()
-
- if save_method == "merged_4bit":
- raise RuntimeError(
- "Unsloth: Merging into 4bit will cause your model to lose accuracy if you plan\n"
- "to merge to GGUF or others later on. I suggest you to do this as a final step\n"
- "if you're planning to do multiple saves.\n"
- "If you are certain, change `save_method` to `merged_4bit_forced`."
- )
- elif save_method == "merged_4bit_forced":
- save_method = "merged_4bit"
-
+ if token is None and push_to_hub: token = get_token()
merge_and_overwrite_lora(
get_model_name,
- model = model,
- tokenizer = tokenizer,
- save_directory = save_directory,
- push_to_hub = push_to_hub,
- private = private,
- token = token,
- save_method = save_method,
- output_dtype = None,
+ model = model,
+ tokenizer = tokenizer,
+ save_directory = save_directory,
+ push_to_hub = push_to_hub,
+ private = private,
+ token = token,
+ output_dtype = None,
low_disk_space_usage = True,
- use_temp_file = False,
+ use_temp_file = False,
)
-
- if push_to_hub and datasets:
- try:
- from huggingface_hub import metadata_update
-
- save_dir, _ = _determine_username(save_directory, None, token)
- metadata_update(
- save_dir, {"datasets": datasets}, overwrite = True, token = token
- )
- except Exception as e:
- logger.warning_once(
- f"Unsloth: Could not update datasets metadata for {save_directory}: {e}"
- )
-
return
+pass
def unsloth_generic_save_pretrained_merged(
self,
- save_directory: Union[str, os.PathLike],
- tokenizer = None,
- save_method: str = "merged_16bit", # ["lora", "merged_16bit", "merged_4bit"]
- push_to_hub: bool = False,
- token: Optional[Union[str, bool]] = None,
- is_main_process: bool = True,
- state_dict: Optional[dict] = None,
- save_function: Callable = torch.save,
- max_shard_size: Union[int, str] = "5GB",
- safe_serialization: bool = True,
- variant: Optional[str] = None,
- save_peft_format: bool = True,
- tags: List[str] = None,
- temporary_location: str = "_unsloth_temporary_saved_buffers",
- maximum_memory_usage: float = 0.75,
- datasets: Optional[List[str]] = None,
+ save_directory : Union[str, os.PathLike],
+ tokenizer = None,
+ save_method : str = "merged_16bit", # ["lora", "merged_16bit", "merged_4bit"]
+ push_to_hub : bool = False,
+ token : Optional[Union[str, bool]] = None,
+ is_main_process : bool = True,
+ state_dict : Optional[dict] = None,
+ save_function : Callable = torch.save,
+ max_shard_size : Union[int, str] = "5GB",
+ safe_serialization : bool = True,
+ variant : Optional[str] = None,
+ save_peft_format : bool = True,
+ tags : List[str] = None,
+ temporary_location : str = "_unsloth_temporary_saved_buffers",
+ maximum_memory_usage : float = 0.75,
):
"""
- Same as .push_to_hub(...) except 4bit weights are auto
- converted to float16 with as few overhead as possible.
+ Same as .push_to_hub(...) except 4bit weights are auto
+ converted to float16 with as few overhead as possible.
- Choose for `save_method` to be either:
- 1. `16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.
- 2. `4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.
- 3. `lora`: Save LoRA adapters with no merging. Useful for HF inference.
+ Choose for `save_method` to be either:
+ 1. `16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.
+ 2. `4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.
+ 3. `lora`: Save LoRA adapters with no merging. Useful for HF inference.
"""
if tokenizer is None:
logger.warning_once(
- "Unsloth: You're not saving a tokenizer as well?\n"
+ "Unsloth: You're not saving a tokenizer as well?\n"\
"You can do it separately via `tokenizer.save_pretrained(...)`"
)
+ pass
arguments = dict(locals())
arguments["model"] = self
@@ -2816,266 +2354,58 @@ def unsloth_generic_save_pretrained_merged(
unsloth_generic_save(**arguments)
for _ in range(3):
gc.collect()
+pass
def unsloth_generic_push_to_hub_merged(
self,
- repo_id: str,
- tokenizer = None,
- save_method: str = "merged_16bit", # ["lora", "merged_16bit", "merged_4bit"]
- use_temp_dir: Optional[bool] = None,
- commit_message: Optional[str] = "Trained with Unsloth",
- private: Optional[bool] = None,
- token: Union[bool, str, None] = None,
- max_shard_size: Union[int, str, None] = "5GB",
- create_pr: bool = False,
- safe_serialization: bool = True,
- revision: str = None,
- commit_description: str = "Upload model trained with Unsloth 2x faster",
- tags: Optional[List[str]] = None,
- temporary_location: str = "_unsloth_temporary_saved_buffers",
- maximum_memory_usage: float = 0.75,
- datasets: Optional[List[str]] = None,
+ repo_id : str,
+ tokenizer = None,
+ save_method : str = "merged_16bit", # ["lora", "merged_16bit", "merged_4bit"]
+ use_temp_dir : Optional[bool] = None,
+ commit_message : Optional[str] = "Trained with Unsloth",
+ private : Optional[bool] = None,
+ token : Union[bool, str, None] = None,
+ max_shard_size : Union[int, str, None] = "5GB",
+ create_pr : bool = False,
+ safe_serialization : bool = True,
+ revision : str = None,
+ commit_description : str = "Upload model trained with Unsloth 2x faster",
+ tags : Optional[List[str]] = None,
+ temporary_location : str = "_unsloth_temporary_saved_buffers",
+ maximum_memory_usage : float = 0.75,
):
"""
- Same as .push_to_hub(...) except 4bit weights are auto
- converted to float16 with as few overhead as possible.
+ Same as .push_to_hub(...) except 4bit weights are auto
+ converted to float16 with as few overhead as possible.
- Choose for `save_method` to be either:
- 1. `16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.
- 2. `4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.
- 3. `lora`: Save LoRA adapters with no merging. Useful for HF inference.
+ Choose for `save_method` to be either:
+ 1. `16bit`: Merge LoRA into float16 weights. Useful for GGUF / llama.cpp.
+ 2. `4bit`: Merge LoRA into int4 weights. Useful for DPO / HF inference.
+ 3. `lora`: Save LoRA adapters with no merging. Useful for HF inference.
"""
if tokenizer is None:
logger.warning_once(
- "Unsloth: You're not saving a tokenizer as well?\n"
+ "Unsloth: You're not saving a tokenizer as well?\n"\
"You can do it separately via `tokenizer.push_to_hub(...)`"
)
+ pass
arguments = dict(locals())
- arguments["model"] = self
+ arguments["model"] = self
arguments["save_directory"] = repo_id
- arguments["push_to_hub"] = True
+ arguments["push_to_hub"] = True
del arguments["self"]
del arguments["repo_id"]
unsloth_generic_save(**arguments)
for _ in range(3):
gc.collect()
-
-
-def _unsloth_save_torchao_with_attached_config(
- model,
- save_directory: Union[str, os.PathLike],
- tokenizer,
- push_to_hub: bool = False,
- token: Optional[Union[str, bool]] = None,
-):
- """Save a QAT-trained model by converting fake-quantized weights to real quantized weights."""
- # Convert QAT fake-quantized weights to real quantized weights
- _convert_torchao_model(model)
- # PEFT models also might come here, so parse it
- if isinstance(model, PeftModelForCausalLM):
- _unsloth_save_torchao_with_given_config(
- model = model,
- save_directory = save_directory,
- tokenizer = tokenizer,
- torchao_config = model.config.quantization_config,
- push_to_hub = push_to_hub,
- token = token,
- )
- return
-
- # TorchAO does not support safe_serialization reliably
- safe_serialization = False
-
- if push_to_hub:
- model.push_to_hub(
- save_directory, safe_serialization = safe_serialization, token = token
- )
- tokenizer.push_to_hub(save_directory, token = token)
- else:
- model.save_pretrained(save_directory, safe_serialization = safe_serialization)
- tokenizer.save_pretrained(save_directory)
-
-
-def _unsloth_save_torchao_with_given_config(
- model,
- save_directory: Union[str, os.PathLike],
- tokenizer,
- torchao_config,
- push_to_hub: bool = False,
- token: Optional[Union[str, bool]] = None,
-):
- """Quantizes the model with torchao and saves a torchao quantized checkpoint
-
- Args
- `save_directory`: local folder path or huggingface hub ID when `push_to_hub` is set to True, e.g. `my_model`
- `torchao_config` (TorchAOBaseConfig): configuration for torchao quantization, full list: https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize
- `push_to_hub` (bool): whether to push the checkpoint to huggingface hub or save locally
- """
-
- if push_to_hub:
- assert token is not None, "Unsloth: Please specify a token for uploading!"
-
- assert (
- torchao_config is not None
- ), "Unsloth: Please specify a torchao_config for post-training quantization!"
-
- # first merge the lora weights
- arguments = dict(locals())
- arguments["push_to_hub"] = False # We save ourselves
- arguments["save_method"] = "merged_16bit" # Must be 16bit
- del arguments["torchao_config"]
-
- if not isinstance(model, PeftModelForCausalLM) and not isinstance(model, PeftModel):
- model.save_pretrained(save_directory)
- tokenizer.save_pretrained(save_directory)
- else:
- unsloth_generic_save(**arguments)
-
- for _ in range(3):
- gc.collect()
-
- from transformers import (
- AutoModelForCausalLM,
- AutoTokenizer,
- TorchAoConfig,
- AutoModelForImageTextToText,
- AutoProcessor,
- )
- from torchao import quantize_
-
- if isinstance(torchao_config, TorchAoConfig):
- quantization_config = torchao_config
- else:
- quantization_config = TorchAoConfig(quant_type = torchao_config)
-
- # Determine if this is a VLM
- is_vlm = False
- if hasattr(model, "config") and hasattr(model.config, "architectures"):
- is_vlm = any(
- x.endswith(("ForConditionalGeneration", "ForVisionText2Text"))
- for x in model.config.architectures
- )
- is_vlm = is_vlm or hasattr(model.config, "vision_config")
- auto_model = AutoModelForImageTextToText if is_vlm else AutoModelForCausalLM
- auto_processor = AutoProcessor if is_vlm else AutoTokenizer
-
- tokenizer = auto_processor.from_pretrained(save_directory)
-
- # TorchAO must only use bfloat16 for loading (float16 fails)
- if HAS_TORCH_DTYPE:
- kwargs = {"torch_dtype": torch.bfloat16}
- else:
- kwargs = {"dtype": torch.bfloat16}
-
- # Reload with quantization applied
- quantized_model = auto_model.from_pretrained(
- save_directory,
- device_map = "auto",
- quantization_config = quantization_config,
- **kwargs,
- )
-
- torchao_save_directory = save_directory + "-torchao"
-
- # TorchAO does not support safe_serialization right now 0.14.0 seems broken!
- safe_serialization = Version(importlib_version("torchao")) > Version("0.14.0")
- safe_serialization = False
-
- if push_to_hub:
- quantized_model.push_to_hub(
- torchao_save_directory, safe_serialization = safe_serialization, token = token
- )
- tokenizer.push_to_hub(torchao_save_directory, token = token)
- else:
- quantized_model.save_pretrained(
- torchao_save_directory, safe_serialization = safe_serialization
- )
- tokenizer.save_pretrained(torchao_save_directory)
-
- # Clean up the intermediate unquantized model
- if os.path.exists(save_directory):
- try:
- shutil.rmtree(save_directory)
- except:
- pass
-
-
-def unsloth_save_pretrained_torchao(
- self,
- save_directory: Union[str, os.PathLike],
- tokenizer = None,
- torchao_config = None,
- push_to_hub: bool = False,
- token: Optional[Union[str, bool]] = None,
-):
- """Saves a torchao quantized model checkpoint.
-
- This function handles two mutually exclusive workflows:
-
- 1. **QAT (Quantization-Aware Training)**: If the model was trained with `qat_scheme`
- parameter, do NOT pass `torchao_config`. The function will convert the QAT
- fake-quantized weights to real quantized weights and save directly.
-
- 2. **PTQ (Post-Training Quantization)**: If you want to apply quantization to a
- regular model, pass a `torchao_config`. The model must NOT have been trained
- with `qat_scheme`.
-
- Args:
- `save_directory`: local folder path or huggingface hub ID when `push_to_hub` is True
- `tokenizer`: the tokenizer to save alongside the model
- `torchao_config` (TorchAOBaseConfig): configuration for torchao quantization.
- Required for PTQ, must be None for QAT models.
- Options: https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize
- `push_to_hub` (bool): whether to push to huggingface hub or save locally
- `token`: HuggingFace token for pushing to hub
- """
- if token is None and push_to_hub:
- token = get_token()
-
- has_qat_config = (
- hasattr(self, "_torchao_config") and self._torchao_config is not None
- )
-
- if torchao_config is not None:
- # PTQ path: user provided a config, model must NOT have QAT config unless PEFT
- assert not has_qat_config, (
- "Unsloth: You passed `torchao_config` but this model was trained with `qat_scheme`. "
- "For QAT models, do not pass `torchao_config` - the quantization config is already "
- "attached to the model from training."
- )
- _unsloth_save_torchao_with_given_config(
- model = self,
- save_directory = save_directory,
- tokenizer = tokenizer,
- torchao_config = torchao_config,
- push_to_hub = push_to_hub,
- token = token,
- )
- else:
- # QAT path: no config provided, model must have QAT config
- assert has_qat_config, (
- "Unsloth: No `torchao_config` provided and model was not trained with `qat_scheme`. "
- "Either train with `qat_scheme` parameter, or provide a `torchao_config` for "
- "post-training quantization."
- )
- _unsloth_save_torchao_with_attached_config(
- model = self,
- save_directory = save_directory,
- tokenizer = tokenizer,
- push_to_hub = push_to_hub,
- token = token,
- )
-
- for _ in range(3):
- gc.collect()
+pass
def not_implemented_save(*args, **kwargs):
- raise NotImplementedError(
- "Unsloth: Sorry GGUF is currently not supported for vision models!"
- )
+ raise NotImplementedError("Unsloth: Sorry GGUF is currently not supported for vision models!")
+pass
def patch_saving_functions(model, vision = False):
@@ -3088,6 +2418,7 @@ def patch_saving_functions(model, vision = False):
original_push_to_hub = model.original_push_to_hub
else:
original_push_to_hub = model.push_to_hub
+ pass
signature = str(inspect.signature(original_push_to_hub)).replace("NoneType", "None")
signature = signature[1:]
@@ -3152,63 +2483,36 @@ def patch_saving_functions(model, vision = False):
original_model = model
while True:
- # Check if push_to_hub exists before accessing its __name__
- if (
- hasattr(original_model, "push_to_hub")
- and original_model.push_to_hub.__name__ != "unsloth_push_to_hub"
- ):
+
+ if original_model.push_to_hub.__name__ != "unsloth_push_to_hub":
original_model.original_push_to_hub = original_model.push_to_hub
- original_model.push_to_hub = types.MethodType(
- unsloth_push_to_hub, original_model
- )
+ original_model.push_to_hub = types.MethodType(unsloth_push_to_hub, original_model)
if hasattr(original_model, "add_model_tags"):
- original_model.add_model_tags(
- [
- "unsloth",
- ]
- )
+ original_model.add_model_tags(["unsloth",])
+ pass
+ pass
- if hasattr(original_model, "model"):
- original_model = original_model.model
- else:
- break
+ if hasattr(original_model, "model"): original_model = original_model.model
+ else: break
+ pass
# Add saving methods to top level model
if not vision:
if hasattr(model, "config"):
# Counteract tokenizers
- model.push_to_hub_merged = types.MethodType(
- unsloth_generic_push_to_hub_merged, model
- )
- model.save_pretrained_merged = types.MethodType(
- unsloth_generic_save_pretrained_merged, model
- )
- model.push_to_hub_gguf = types.MethodType(unsloth_push_to_hub_gguf, model)
- model.save_pretrained_gguf = types.MethodType(
- unsloth_save_pretrained_gguf, model
- )
- model.save_pretrained_torchao = types.MethodType(
- unsloth_save_pretrained_torchao, model
- )
- model.push_to_hub_ggml = types.MethodType(
- unsloth_convert_lora_to_ggml_and_push_to_hub, model
- )
- model.save_pretrained_ggml = types.MethodType(
- unsloth_convert_lora_to_ggml_and_save_locally, model
- )
+ model.push_to_hub_merged = types.MethodType(unsloth_push_to_hub_merged, model)
+ model.save_pretrained_merged = types.MethodType(unsloth_save_pretrained_merged, model)
+ model.push_to_hub_gguf = types.MethodType(unsloth_push_to_hub_gguf, model)
+ model.save_pretrained_gguf = types.MethodType(unsloth_save_pretrained_gguf, model)
+ model.push_to_hub_ggml = types.MethodType(unsloth_convert_lora_to_ggml_and_push_to_hub, model)
+ model.save_pretrained_ggml = types.MethodType(unsloth_convert_lora_to_ggml_and_save_locally, model)
+ pass
else:
# Vision only 1 option
- model.push_to_hub_merged = types.MethodType(
- unsloth_generic_push_to_hub_merged, model
- )
- model.save_pretrained_merged = types.MethodType(
- unsloth_generic_save_pretrained_merged, model
- )
- model.push_to_hub_gguf = types.MethodType(unsloth_push_to_hub_gguf, model)
- model.save_pretrained_gguf = types.MethodType(
- unsloth_save_pretrained_gguf, model
- )
- model.save_pretrained_torchao = types.MethodType(
- unsloth_save_pretrained_torchao, model
- )
+ model.push_to_hub_merged = types.MethodType(unsloth_generic_push_to_hub_merged, model)
+ model.save_pretrained_merged = types.MethodType(unsloth_generic_save_pretrained_merged, model)
+ model.push_to_hub_gguf = types.MethodType(save_to_gguf_generic, model)
+ model.save_pretrained_gguf = types.MethodType(save_to_gguf_generic, model)
+ pass
return model
+pass
diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py
index c445879df7..26669127d7 100644
--- a/unsloth/tokenizer_utils.py
+++ b/unsloth/tokenizer_utils.py
@@ -25,7 +25,6 @@
import numpy as np
import gc
import subprocess
-import psutil
from unsloth_zoo.tokenizer_utils import (
mean_of_trained_tokens,
@@ -45,12 +44,10 @@
]
-IGNORED_TOKENIZER_CHECKING = frozenset(
- (
- "CodeLlamaTokenizerFast",
- "CodeLlamaTokenizer",
- )
-)
+IGNORED_TOKENIZER_CHECKING = frozenset((
+ "CodeLlamaTokenizerFast",
+ "CodeLlamaTokenizer",
+))
IGNORED_TOKENIZER_NAMES = [
@@ -59,24 +56,26 @@
"unsloth/Qwen2.5-Coder-7B-Instruct",
]
IGNORED_TOKENIZER_NAMES = frozenset(
- [x.lower() for x in IGNORED_TOKENIZER_NAMES]
- + [x.lower() + "-bnb-4bit" for x in IGNORED_TOKENIZER_NAMES]
+ [x.lower() for x in IGNORED_TOKENIZER_NAMES] + \
+ [x.lower()+"-bnb-4bit" for x in IGNORED_TOKENIZER_NAMES]
)
os.environ["UNSLOTH_IGNORED_TOKENIZER_NAMES"] = "\n".join(IGNORED_TOKENIZER_NAMES)
# Check environments
keynames = "\n" + "\n".join(os.environ.keys())
-IS_COLAB_ENVIRONMENT = "\nCOLAB_" in keynames
+IS_COLAB_ENVIRONMENT = "\nCOLAB_" in keynames
IS_KAGGLE_ENVIRONMENT = "\nKAGGLE_" in keynames
KAGGLE_TMP = "/tmp"
del keynames
def try_fix_tokenizer(tokenizer, prepend = True):
+
if hasattr(tokenizer, "_tokenizer"):
converted_tokenizer = tokenizer._tokenizer
else:
converted_tokenizer = convert_slow_tokenizer(tokenizer)
+ pass
tokenizer_string = converted_tokenizer.to_str()
@@ -84,6 +83,7 @@ def try_fix_tokenizer(tokenizer, prepend = True):
prepend_text = '{"type":"Prepend","prepend":"▁"},'
if not prepend and prepend_text in tokenizer_string:
tokenizer_string = tokenizer_string.replace(prepend_text, "", 1)
+ pass
dir_names = dir(tokenizer)
# Get eos_token, bos_token etc
@@ -91,21 +91,19 @@ def try_fix_tokenizer(tokenizer, prepend = True):
for token_name in token_names:
token = getattr(tokenizer, token_name, None)
- if token is None:
- continue
+ if token is None: continue
token_id = getattr(tokenizer, token_name + "_id", None)
# Locate the token's id mapping in the string
find_text = f'"id":{token_id},"content":"'
start = tokenizer_string.find(find_text) + len(find_text)
- if start == -1:
- continue
- end = tokenizer_string.find('",', start)
+ if start == -1: continue
+ end = tokenizer_string.find('",', start)
- bad_token = tokenizer_string[start:end]
+ bad_token = tokenizer_string[start : end]
# Check if token is the actual same one - if not, edit it
if bad_token != token:
- bad_text = f'{find_text}{bad_token}",'
+ bad_text = f'{find_text}{bad_token}",'
good_text = f'{find_text}{token}",'
tokenizer_string = tokenizer_string.replace(bad_text, good_text, 1)
@@ -113,20 +111,24 @@ def try_fix_tokenizer(tokenizer, prepend = True):
bad_text = f'"{bad_token}":{token_id},'
good_text = f'"{token}":{token_id},'
tokenizer_string = tokenizer_string.replace(bad_text, good_text, 1)
+ pass
+ pass
fixed_tokenizer = converted_tokenizer.from_str(tokenizer_string)
return fixed_tokenizer
+pass
def get_sorted_dict(dictionary):
sorted_keys = sorted(dictionary.values())
- inverted_dictionary = {value: key for key, value in dictionary.items()}
+ inverted_dictionary = { value : key for key, value in dictionary.items() }
sorted_dictionary = {}
for key in sorted_keys:
value = inverted_dictionary[key]
sorted_dictionary[value] = key
return sorted_dictionary
+pass
def convert_to_fast_tokenizer(
@@ -134,14 +136,13 @@ def convert_to_fast_tokenizer(
temporary_location = "_unsloth_sentencepiece_temp",
):
is_fast = getattr(slow_tokenizer, "is_fast", False)
- if is_fast:
- return slow_tokenizer
-
+ if is_fast: return slow_tokenizer
+
try:
tokenizer_name = slow_tokenizer.__class__.__name__
lowered_tokenizer_name = tokenizer_name.lower()
if lowered_tokenizer_name.endswith("tokenizer"):
- class_name = lowered_tokenizer_name[: -len("tokenizer")]
+ class_name = lowered_tokenizer_name[:-len("tokenizer")]
FastTokenizer = eval(
f'__import__(f"transformers.models.{class_name}").{tokenizer_name}Fast'
)
@@ -149,52 +150,52 @@ def convert_to_fast_tokenizer(
FastTokenizer = PreTrainedTokenizerFast
except:
FastTokenizer = PreTrainedTokenizerFast
+ pass
# Get all arguments (bos_token, etc)
docs = FastTokenizer.__doc__
- docs = docs[docs.find("Args:") :]
+ docs = docs[docs.find("Args:"):]
args = re.findall(r"\n[\s]+([^\s]{1,}) \(", docs, flags = re.MULTILINE)
args = [x for x in args if not x.endswith("_file")]
# Also some missing maybe!
docs = PreTrainedTokenizerFast.__doc__
- docs = docs[docs.find("Args:") :]
+ docs = docs[docs.find("Args:"):]
args2 = re.findall(r"\n[\s]+([^\s]{1,}) \(", docs, flags = re.MULTILINE)
args2 = [x for x in args2 if not x.endswith("_file")]
args = list(set(args + args2))
kwargs = {}
- for arg in args:
- kwargs[arg] = getattr(slow_tokenizer, arg, None)
+ for arg in args: kwargs[arg] = getattr(slow_tokenizer, arg, None)
kwargs["tokenizer_object"] = try_fix_tokenizer(slow_tokenizer, prepend = True)
- fast_tokenizer = FastTokenizer(**kwargs)
+ fast_tokenizer = FastTokenizer( **kwargs )
# Check if they're similar!
sorted_slow_tokenizer = get_sorted_dict(slow_tokenizer.get_vocab())
sorted_fast_tokenizer = get_sorted_dict(fast_tokenizer.get_vocab())
- check_vocab = sorted_slow_tokenizer == sorted_fast_tokenizer
- check_special = (
- slow_tokenizer.all_special_tokens == fast_tokenizer.all_special_tokens
- )
+ check_vocab = (sorted_slow_tokenizer == sorted_fast_tokenizer)
+ check_special = (slow_tokenizer.all_special_tokens == fast_tokenizer.all_special_tokens)
# Failure so return slow_tokenizer
- if not check_vocab or not check_special:
- return slow_tokenizer
+ if not check_vocab or not check_special: return slow_tokenizer
# Now confirm if they match
if not assert_same_tokenization(slow_tokenizer, fast_tokenizer):
# Maybe remove prepending of __apple?
kwargs["tokenizer_object"] = try_fix_tokenizer(slow_tokenizer, prepend = False)
- fast_tokenizer = FastTokenizer(**kwargs)
+ fast_tokenizer = FastTokenizer( **kwargs )
if not assert_same_tokenization(slow_tokenizer, fast_tokenizer):
# Failure :(
return slow_tokenizer
+ pass
+ pass
# Also tokenizer.model is missing!
name = slow_tokenizer.name_or_path.replace("/", "_")
if not os.path.exists(temporary_location):
os.makedirs(temporary_location)
+ pass
new_location = f"{temporary_location}/{name}"
slow_tokenizer.save_pretrained(new_location)
fast_tokenizer.save_pretrained(new_location)
@@ -204,72 +205,66 @@ def convert_to_fast_tokenizer(
if assert_same_tokenization(slow_tokenizer, fast_tokenizer):
return fast_tokenizer
return slow_tokenizer
+pass
# Check Mistral chat template without BOS / EOS
-mistral_template = (
- "{% if messages[0]['role'] == 'system' %}"
- "{% if messages[1]['role'] == 'user' %}"
- "{{ '[INST] ' + messages[0]['content'] + ' ' + messages[1]['content'] + ' [/INST]' }}"
- "{% set loop_messages = messages[2:] %}"
- "{% else %}"
- "{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}"
- "{% set loop_messages = messages[1:] %}"
- "{% endif %}"
- "{% else %}"
- "{% set loop_messages = messages %}"
- "{% endif %}"
- "{% for message in loop_messages %}"
- "{% if message['role'] == 'user' %}"
- "{{ '[INST] ' + message['content'] + ' [/INST]' }}"
- "{% elif message['role'] == 'assistant' %}"
- "{{ message['content'] }}"
- "{% else %}"
- "{{ raise_exception('Only user and assistant roles are supported!') }}"
- "{% endif %}"
+mistral_template = \
+ "{% if messages[0]['role'] == 'system' %}"\
+ "{% if messages[1]['role'] == 'user' %}"\
+ "{{ '[INST] ' + messages[0]['content'] + ' ' + messages[1]['content'] + ' [/INST]' }}"\
+ "{% set loop_messages = messages[2:] %}"\
+ "{% else %}"\
+ "{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}"\
+ "{% set loop_messages = messages[1:] %}"\
+ "{% endif %}"\
+ "{% else %}"\
+ "{% set loop_messages = messages %}"\
+ "{% endif %}"\
+ "{% for message in loop_messages %}"\
+ "{% if message['role'] == 'user' %}"\
+ "{{ '[INST] ' + message['content'] + ' [/INST]' }}"\
+ "{% elif message['role'] == 'assistant' %}"\
+ "{{ message['content'] }}"\
+ "{% else %}"\
+ "{{ raise_exception('Only user and assistant roles are supported!') }}"\
+ "{% endif %}"\
"{% endfor %}"
-)
+pass
# Check Llama chat template without BOS / EOS
-llama_template = (
- "{% if messages[0]['role'] == 'system' %}"
- "{% if messages[1]['role'] == 'user' %}"
- "{{ '[INST] <>\n' + messages[0]['content'] + '\n<>\n\n' + messages[1]['content'] + ' [/INST]' }}"
- "{% set loop_messages = messages[2:] %}"
- "{% else %}"
- "{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}"
- "{% set loop_messages = messages[1:] %}"
- "{% endif %}"
- "{% else %}"
- "{% set loop_messages = messages %}"
- "{% endif %}"
- "{% for message in loop_messages %}"
- "{% if message['role'] == 'user' %}"
- "{{ '[INST] ' + message['content'].strip() + ' [/INST]' }}"
- "{% elif message['role'] == 'assistant' %}"
- "{{ ' ' + message['content'].strip() + ' ' }}"
- "{% else %}"
- "{{ raise_exception('Only user and assistant roles are supported!') }}"
- "{% endif %}"
+llama_template = \
+ "{% if messages[0]['role'] == 'system' %}"\
+ "{% if messages[1]['role'] == 'user' %}"\
+ "{{ '[INST] <>\n' + messages[0]['content'] + '\n<>\n\n' + messages[1]['content'] + ' [/INST]' }}"\
+ "{% set loop_messages = messages[2:] %}"\
+ "{% else %}"\
+ "{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}"\
+ "{% set loop_messages = messages[1:] %}"\
+ "{% endif %}"\
+ "{% else %}"\
+ "{% set loop_messages = messages %}"\
+ "{% endif %}"\
+ "{% for message in loop_messages %}"\
+ "{% if message['role'] == 'user' %}"\
+ "{{ '[INST] ' + message['content'].strip() + ' [/INST]' }}"\
+ "{% elif message['role'] == 'assistant' %}"\
+ "{{ ' ' + message['content'].strip() + ' ' }}"\
+ "{% else %}"\
+ "{{ raise_exception('Only user and assistant roles are supported!') }}"\
+ "{% endif %}"\
"{% endfor %}"
-)
+pass
def assert_same_tokenization(slow_tokenizer, fast_tokenizer):
# Get eos_token, bos_token etc
- if not hasattr(slow_tokenizer, "all_special_tokens"):
- return True
+ if not hasattr(slow_tokenizer, "all_special_tokens"): return True
dir_names = dir(slow_tokenizer)
- special_tokens = list(
- filter(
- None,
- (
- getattr(slow_tokenizer, x)
- for x in dir_names
- if x.endswith("_token") and x.count("_") == 1
- ),
- )
- )
+ special_tokens = list(filter(None, (
+ getattr(slow_tokenizer, x) for x in dir_names
+ if x.endswith("_token") and x.count("_") == 1
+ )))
all_special_tokens = list(set(special_tokens + slow_tokenizer.all_special_tokens))
# Remove replacement char for false positive
@@ -280,7 +275,7 @@ def assert_same_tokenization(slow_tokenizer, fast_tokenizer):
check_chat_template1 = True
check_chat_template2 = True
check_chat_template3 = True
-
+
"""
Weirdly Mistral tokenizers are actually correct??
Ie below will actually load mistral v1 and v3 incorrectly!
@@ -318,20 +313,16 @@ def assert_same_tokenization(slow_tokenizer, fast_tokenizer):
slow_tokenizer.chat_template = slow_chat_template
fast_tokenizer.chat_template = fast_chat_template
"""
- check_chat_template = (
- check_chat_template1 and check_chat_template2 and check_chat_template3
- )
+ check_chat_template = check_chat_template1 and check_chat_template2 and check_chat_template3
# Try special tokens
try:
- string = (
- "\n".join(all_special_tokens)
- + "A quick brown fox jumps over the lazy dog!!\n\nHi\n\n"
- + "".join(all_special_tokens)
- )
- check_special_tokens = (
- slow_tokenizer(string).input_ids == fast_tokenizer(string).input_ids
- )
+ string = "\n".join(all_special_tokens) + \
+ "A quick brown fox jumps over the lazy dog!!\n\nHi\n\n" + \
+ "".join(all_special_tokens)
+ check_special_tokens = \
+ slow_tokenizer(string).input_ids == \
+ fast_tokenizer(string).input_ids
return check_chat_template and check_special_tokens
except:
@@ -342,6 +333,8 @@ def assert_same_tokenization(slow_tokenizer, fast_tokenizer):
return check_chat_template
else:
return False
+ pass
+pass
def fix_sentencepiece_tokenizer(
@@ -352,39 +345,22 @@ def fix_sentencepiece_tokenizer(
):
# From https://github.com/google/sentencepiece/issues/121
# We need to manually edit the sentencepiece tokenizer!
- try:
- from transformers.convert_slow_tokenizer import import_protobuf
-
- sentencepiece_model_pb2 = import_protobuf()
- except Exception as e:
- try:
- import google.protobuf
- from unsloth_zoo.utils import Version
-
- protobuf_version = Version(google.protobuf.__version__)
- if protobuf_version > Version("3.20.3"):
- raise RuntimeError(
- f"Unsloth: Your protobuf version = {protobuf_version} is too new.\n"
- f"Please downgrade via `pip install --force-reinstall protobuf==3.20.3`"
- )
- except:
- # This will only work for older SentencePiece versions <= 3.20.3
- from transformers.utils import sentencepiece_model_pb2
+ from transformers.utils import sentencepiece_model_pb2
if not os.path.exists(temporary_location):
os.makedirs(temporary_location)
+ pass
# Check if tokenizer.model exists
if not os.path.isfile(f"{temporary_location}/tokenizer.model"):
return new_tokenizer
+ pass
# First save the old tokenizer
old_tokenizer.save_pretrained(temporary_location)
tokenizer_file = sentencepiece_model_pb2.ModelProto()
- tokenizer_file.ParseFromString(
- open(f"{temporary_location}/tokenizer.model", "rb").read()
- )
+ tokenizer_file.ParseFromString(open(f"{temporary_location}/tokenizer.model", "rb").read())
# Now save the new tokenizer
new_tokenizer.save_pretrained(temporary_location)
@@ -393,47 +369,48 @@ def fix_sentencepiece_tokenizer(
for old_token, new_token in token_mapping.items():
ids = old_tokenizer([old_token], add_special_tokens = False).input_ids
ids = ids[0]
- if len(ids) != 1:
+ if (len(ids) != 1):
# Skip this token!
- print(
- f"Skip mapping {old_token} to {new_token} since {new_token} is already in the tokenizer!"
- )
+ print(f"Skip mapping {old_token} to {new_token} since {new_token} is already in the tokenizer!")
continue
+ pass
ids = ids[0]
# [TODO] Hack for Starling - try except
try:
tokenizer_piece = tokenizer_file.pieces[ids]
except:
continue
- assert tokenizer_piece.piece == old_token
+ assert(tokenizer_piece.piece == old_token)
tokenizer_piece.piece = new_token
+ pass
# And now write it
with open(f"{temporary_location}/tokenizer.model", "wb") as file:
file.write(tokenizer_file.SerializeToString())
+ pass
# And load it!
from transformers import AutoTokenizer
-
tokenizer = AutoTokenizer.from_pretrained(
temporary_location,
eos_token = new_tokenizer.eos_token,
pad_token = new_tokenizer.pad_token,
)
return tokenizer
+pass
def fix_sentencepiece_gguf(saved_location):
"""
- Fixes sentencepiece tokenizers which did not extend the vocabulary with
- user defined tokens.
- Inspiration from https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py
+ Fixes sentencepiece tokenizers which did not extend the vocabulary with
+ user defined tokens.
+ Inspiration from https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py
"""
from copy import deepcopy
from transformers.utils import sentencepiece_model_pb2
import json
from enum import IntEnum
-
+
class SentencePieceTokenTypes(IntEnum):
NORMAL = 1
UNKNOWN = 2
@@ -441,58 +418,54 @@ class SentencePieceTokenTypes(IntEnum):
USER_DEFINED = 4
UNUSED = 5
BYTE = 6
+ pass
# Load tokenizer.model
tokenizer_file = sentencepiece_model_pb2.ModelProto()
- if not os.path.isfile(f"{saved_location}/tokenizer.model"):
- return
- tokenizer_file.ParseFromString(
- open(f"{saved_location}/tokenizer.model", "rb").read()
- )
+ if not os.path.isfile(f"{saved_location}/tokenizer.model"): return
+ tokenizer_file.ParseFromString(open(f"{saved_location}/tokenizer.model", "rb").read())
sentence_piece_size = len(tokenizer_file.pieces)
# Load added_tokens_json
- if not os.path.isfile(f"{saved_location}/added_tokens.json"):
- return
+ if not os.path.isfile(f"{saved_location}/added_tokens.json"): return
with open(f"{saved_location}/added_tokens.json", "r", encoding = "utf-8") as file:
added_tokens_json = json.load(file)
- if len(added_tokens_json) == 0:
- return
+ pass
+ if len(added_tokens_json) == 0: return
- added_tokens_json = dict(
- sorted(added_tokens_json.items(), key = lambda item: item[1])
- )
+ added_tokens_json = dict(sorted(added_tokens_json.items(), key = lambda item: item[1]))
new_size = sentence_piece_size + len(added_tokens_json)
# Confirm added_tokens_json is correct
added_tokens_ids = np.array(list(added_tokens_json.values()))
diff = np.diff(added_tokens_ids)
- if diff.min() != 1 or diff.max() != 1:
- return
- if added_tokens_ids.min() != sentence_piece_size:
- return
+ if (diff.min() != 1 or diff.max() != 1): return
+ if (added_tokens_ids.min() != sentence_piece_size): return
# Edit sentence piece tokens with added_tokens_json
logger.warning(
- f"Unsloth: Extending {saved_location}/tokenizer.model with added_tokens.json.\n"
- f"Originally tokenizer.model is of size ({sentence_piece_size}).\n"
+ f"Unsloth: Extending {saved_location}/tokenizer.model with added_tokens.json.\n"\
+ f"Originally tokenizer.model is of size ({sentence_piece_size}).\n"\
f"But we need to extend to sentencepiece vocab size ({new_size})."
)
- new_tokens = deepcopy(tokenizer_file.pieces[-len(added_tokens_ids) :])
+ new_tokens = deepcopy(tokenizer_file.pieces[-len(added_tokens_ids):])
for new_token, added_token in zip(new_tokens, added_tokens_json.keys()):
new_token.piece = added_token.encode("utf-8")
new_token.score = -1000.0
- new_token.type = SentencePieceTokenTypes.USER_DEFINED
+ new_token.type = SentencePieceTokenTypes.USER_DEFINED
+ pass
tokenizer_file.pieces.extend(new_tokens)
with open(f"{saved_location}/tokenizer.model", "wb") as file:
file.write(tokenizer_file.SerializeToString())
+ pass
# Add padding tokens
# actual_vocab_size = model.config.vocab_size
# padding = actual_vocab_size - len(tokenizer_file.pieces)
return
+pass
def _load_correct_tokenizer(
@@ -512,6 +485,7 @@ def _load_correct_tokenizer(
cache_dir = os.path.join(KAGGLE_TMP, cache_dir)
else:
cache_dir = None
+ pass
# Try loading the slow tokenizer. If it fails, then try Fast only
# Mainly to solve Deepseek models with no tokenizer.model file
@@ -519,15 +493,15 @@ def _load_correct_tokenizer(
try:
slow_tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
- model_max_length = model_max_length,
- padding_side = padding_side,
- token = token,
+ model_max_length = model_max_length,
+ padding_side = padding_side,
+ token = token,
trust_remote_code = trust_remote_code,
# Cannot just use use_fast = False as per https://twitter.com/danielhanchen/status/1789659394302718373
- use_fast = False,
- legacy = False,
- from_slow = True,
- cache_dir = cache_dir,
+ use_fast = False,
+ legacy = False,
+ from_slow = True,
+ cache_dir = cache_dir,
)
except:
slow_tokenizer = None
@@ -535,17 +509,17 @@ def _load_correct_tokenizer(
# f"Unsloth: {tokenizer_name} has no tokenizer.model file.\n"\
# "Just informing you about this - this is not a critical error."
# )
+ pass
# Unsure why this occurs!
- if type(slow_tokenizer) is bool:
- slow_tokenizer = None
+ if type(slow_tokenizer) is bool: slow_tokenizer = None
fast_tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
- model_max_length = model_max_length,
- padding_side = padding_side,
- token = token,
+ model_max_length = model_max_length,
+ padding_side = padding_side,
+ token = token,
trust_remote_code = trust_remote_code,
- cache_dir = cache_dir,
+ cache_dir = cache_dir,
)
if not fix_tokenizer or tokenizer_name in IGNORED_TOKENIZER_NAMES:
@@ -557,26 +531,22 @@ def _load_correct_tokenizer(
elif "phi-4" in tokenizer_name.lower():
return fast_tokenizer
elif slow_tokenizer is not None:
- if hasattr(fast_tokenizer, "add_bos_token") and hasattr(
- slow_tokenizer, "add_bos_token"
- ):
+ if hasattr(fast_tokenizer, "add_bos_token") and hasattr(slow_tokenizer, "add_bos_token"):
fast_tokenizer.add_bos_token = slow_tokenizer.add_bos_token
- if hasattr(fast_tokenizer, "add_eos_token") and hasattr(
- slow_tokenizer, "add_eos_token"
- ):
+ if hasattr(fast_tokenizer, "add_eos_token") and hasattr(slow_tokenizer, "add_eos_token"):
fast_tokenizer.add_eos_token = slow_tokenizer.add_eos_token
-
+
# Confirm if slow and fast are equivalent!
if assert_same_tokenization(slow_tokenizer, fast_tokenizer):
return fast_tokenizer
else:
- logger.warning(
- f"Unsloth: Will load {tokenizer_name} as a legacy tokenizer."
- )
+ logger.warning(f"Unsloth: Will load {tokenizer_name} as a legacy tokenizer.")
return convert_to_fast_tokenizer(slow_tokenizer)
pass
else:
return fast_tokenizer
+ pass
+pass
def load_correct_tokenizer(
@@ -601,21 +571,15 @@ def load_correct_tokenizer(
### 1. Fixup tokenizer's chat_template
old_chat_template = getattr(tokenizer, "chat_template", None)
- # Ignore mistral type models since they don't have an add_generation_prompt
- if any(
- s in str(getattr(tokenizer, "name_or_path", "")).lower()
- for s in ["mistral", "qwen3guard"]
- ):
+ # Ignore mistral type models since they don't have a add_generation_prompt
+ if "mistral" in str(getattr(tokenizer, "name_or_path", "")).lower():
chat_template = old_chat_template
# Also check Llama-2 old style models
- elif (
- old_chat_template is not None
- and "[/INST]" in old_chat_template
- and "[INST]" in old_chat_template
- and "bos_token" in old_chat_template
- and "eos_token" in old_chat_template
- ):
+ elif old_chat_template is not None and \
+ "[/INST]" in old_chat_template and "[INST]" in old_chat_template and \
+ "bos_token" in old_chat_template and "eos_token" in old_chat_template:
+
chat_template = old_chat_template
else:
@@ -625,9 +589,11 @@ def load_correct_tokenizer(
"Unsloth: Fixing chat template failed - please file a report immediately!"
)
pass
+ pass
tokenizer.chat_template = chat_template
return tokenizer
+pass
def _find_end_position(template, endfor, endif):
@@ -639,6 +605,8 @@ def _find_end_position(template, endfor, endif):
return endfor
else:
return endif
+ pass
+pass
def _fix_chat_template(chat_template):
@@ -651,33 +619,28 @@ def _fix_chat_template(chat_template):
chosen_end = _find_end_position(chat_template, endfor, endif)
if chosen_end is None:
return chat_template
-
+
where = chat_template.find(chosen_end)
- after_endfor = chat_template[where + len(chosen_end) :]
+ after_endfor = chat_template[where + len(chosen_end):]
dash = "-" if chosen_end.startswith("{%-") else ""
- if (
- "{%" + dash + " if" not in after_endfor
- and "{%" + dash + " set " not in after_endfor
- and after_endfor.startswith("{{")
- and after_endfor.endswith("}}")
- and after_endfor.count("{{") == 1
- and after_endfor.count("}}") == 1
- ):
- after_endfor = (
- "{%" + dash + " if add_generation_prompt %}" + after_endfor + endif
- )
+ if "{%" + dash + " if" not in after_endfor and "{%" + dash + " set " not in after_endfor and \
+ after_endfor.startswith("{{") and after_endfor.endswith("}}") and \
+ after_endfor.count("{{") == 1 and after_endfor.count("}}") == 1:
+
+ after_endfor = "{%" + dash + " if add_generation_prompt %}" + after_endfor + endif
- chat_template = chat_template[: where + len(chosen_end)] + after_endfor
+ chat_template = chat_template[:where + len(chosen_end)] + after_endfor
+ pass
return chat_template
+pass
def fix_chat_template(tokenizer):
chat_template = getattr(tokenizer, "chat_template", None)
- if chat_template is None:
- return None
+ if chat_template is None: return None
### 1. Check if add_generation_prompt works
# Check for ShareGPT style first
@@ -686,69 +649,62 @@ def fix_chat_template(tokenizer):
messages = [
{"role": "user", "content": "Who are you?"},
]
- tokenizer.apply_chat_template(
- messages, add_generation_prompt = False, tokenize = False
- )
+ tokenizer.apply_chat_template(messages, add_generation_prompt = False, tokenize = False)
is_sharegpt = False
except:
try:
messages = [
{"from": "human", "value": "Who are you?"},
]
- tokenizer.apply_chat_template(
- messages, add_generation_prompt = False, tokenize = False
- )
+ tokenizer.apply_chat_template(messages, add_generation_prompt = False, tokenize = False)
is_sharegpt = True
except:
is_sharegpt = None
+ pass
+ pass
# Not ShareGPT or HF style - just return
- if is_sharegpt is None:
- return chat_template
+ if is_sharegpt is None: return chat_template
# Tokenize
messages = [
- {"role": "user", "content": "Who are you?"}
- if not is_sharegpt
- else {"from": "human", "value": "Who are you?"}
+ {"role": "user", "content": "Who are you?"} \
+ if not is_sharegpt else \
+ {"from": "human", "value": "Who are you?"}
]
- no = tokenizer.apply_chat_template(
- messages, add_generation_prompt = False, tokenize = False
- )
- yes = tokenizer.apply_chat_template(
- messages, add_generation_prompt = True, tokenize = False
- )
+ no = tokenizer.apply_chat_template(messages, add_generation_prompt = False, tokenize = False)
+ yes = tokenizer.apply_chat_template(messages, add_generation_prompt = True, tokenize = False)
if no == yes:
# SAME?! That's not good! We check for add_generation_prompt
- if (
- "{% if add_generation_prompt %}" not in chat_template
- and "{%- if add_generation_prompt %}" not in chat_template
- ):
+ if "{% if add_generation_prompt %}" not in chat_template and \
+ "{%- if add_generation_prompt %}" not in chat_template:
# Try fixing it by adding it
new_chat_template = _fix_chat_template(chat_template)
- if (
- "{% if add_generation_prompt %}" not in new_chat_template
- and "{%- if add_generation_prompt %}" not in new_chat_template
- ):
+ if "{% if add_generation_prompt %}" not in new_chat_template and \
+ "{%- if add_generation_prompt %}" not in new_chat_template:
raise RuntimeError(
- f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n"
- "does not have a {% if add_generation_prompt %} for generation purposes.\n"
- f"Please file a bug report to the maintainers of `{tokenizer.name_or_path}` - thanks!"
+ f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n"\
+ "does not have a {% if add_generation_prompt %} for generation purposes.\n"\
+ "Please file a bug report immediately - thanks!"
)
else:
logger.warning_once(
- "Unsloth: We successfully patched the tokenizer to add a {% if add_generation_prompt %} to the chat_template.\n"
- f"This is not a bug, but please notify the maintainers of `{tokenizer.name_or_path}` - thanks!"
+ "Unsloth: We successfully patched the tokenizer to add a {% if add_generation_prompt %} to the chat_template.\n"\
+ "This is not a bug, but please notify the Unsloth maintainers - thanks!"
)
chat_template = new_chat_template
+ pass
else:
raise RuntimeError(
- f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n"
- "has a {% if add_generation_prompt %} for generation purposes, but wasn't provided correctly.\n"
+ f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n"\
+ "has a {% if add_generation_prompt %} for generation purposes, but wasn't provided correctly.\n"\
"Please file a bug report immediately - thanks!"
)
+ pass
+ pass
return chat_template
+pass
def check_tokenizer(
@@ -769,64 +725,59 @@ def check_tokenizer(
# We ignore some of them!
if tokenizer.__repr__().split("(", 1)[0] in IGNORED_TOKENIZER_CHECKING:
return tokenizer
+ pass
max_embedding_size = model.model.embed_tokens.weight.shape[0]
added_tokens_fast = tokenizer.added_tokens_decoder
- added_tokens_fast = {
- index: str(value) for index, value in added_tokens_fast.items()
- }
+ added_tokens_fast = {index : str(value) for index, value in added_tokens_fast.items()}
sorted_keys = sorted(added_tokens_fast)
- added_tokens_fast = {key: added_tokens_fast[key] for key in sorted_keys}
+ added_tokens_fast = {key : added_tokens_fast[key] for key in sorted_keys}
for j, index in enumerate(added_tokens_fast.keys()):
if index >= max_embedding_size:
- bad_indices = list(added_tokens_fast.keys())[j:]
- bad_tokens = list(added_tokens_fast.values())[j:]
+ bad_indices = list(added_tokens_fast.keys ())[j:]
+ bad_tokens = list(added_tokens_fast.values())[j:]
if not _reload:
# Try removing the token
added_tokens = [str(x) for x in tokenizer.added_tokens_decoder.values()]
special_tokens = tokenizer.special_tokens_map
import itertools
-
special_tokens = frozenset(
itertools.chain.from_iterable(
[x] if type(x) is str else x for x in special_tokens.values()
)
)
can_be_removed1 = [x for x in bad_tokens if x not in special_tokens]
- can_be_removed2 = [
- x
- for x in can_be_removed1
- if x in tokenizer._added_tokens_encoder.keys()
- ]
+ can_be_removed2 = [x for x in can_be_removed1 if x in tokenizer._added_tokens_encoder.keys()]
# Check of extra tokens can in fact we removed!
- can_be_removed = (len(can_be_removed1) == len(bad_tokens)) and (
- len(can_be_removed2) == len(bad_tokens)
- )
+ can_be_removed = \
+ (len(can_be_removed1) == len(bad_tokens)) and \
+ (len(can_be_removed2) == len(bad_tokens))
# Check if sep_token or other generic types
remove_generic = False
try_mapper = []
if not can_be_removed:
names = dir(tokenizer)
- names = (
- x for x in names if x.endswith("_token") and x.count("_") == 1
- )
+ names = (x for x in names if x.endswith("_token") and x.count("_") == 1)
generic_tokens = [(x, getattr(tokenizer, x, None)) for x in names]
try_removal = []
for token in bad_tokens:
- for name_token, check_token in generic_tokens:
+ for (name_token, check_token) in generic_tokens:
if check_token == token:
try_removal.append(token)
try_mapper.append(name_token)
+ pass
+ pass
+ pass
# Recheck!
- can_be_removed = len(try_removal) == len(bad_tokens)
- if can_be_removed:
- remove_generic = True
+ can_be_removed = (len(try_removal) == len(bad_tokens))
+ if can_be_removed: remove_generic = True
can_be_removed1 = bad_tokens
+ pass
if can_be_removed:
# Yes it can be fixed!
@@ -839,26 +790,32 @@ def check_tokenizer(
# Remove sep token for example
setattr(tokenizer, try_mapper[j], None)
setattr(tokenizer, try_mapper[j] + "_id", None)
+ pass
+ pass
# Confirm 1 more time!
if max(tokenizer.added_tokens_decoder.keys()) < max_embedding_size:
logger.warning_once(
- f"Unsloth loaded a broken tokenizer `{model_name}`, but managed to repair it!\n"
- f"Tokens {bad_tokens} with ids {bad_indices} exceeds the max vocab size of {max_embedding_size}.\n"
+ f"Unsloth loaded a broken tokenizer `{model_name}`, but managed to repair it!\n"\
+ f"Tokens {bad_tokens} with ids {bad_indices} exceeds the max vocab size of {max_embedding_size}.\n"\
"We removed these bad tokens. If you think this is incorrect, fix your tokenizer first."
)
return convert_to_fast_tokenizer(tokenizer)
+ pass
+ pass
# :( Failure
raise RuntimeError(
- f"Unsloth tried to load `{model_name}`, but cannot succeed.\n"
- f"Tokens {bad_tokens} with ids {bad_indices} exceeds the max vocab size of {max_embedding_size}.\n"
+ f"Unsloth tried to load `{model_name}`, but cannot succeed.\n"\
+ f"Tokens {bad_tokens} with ids {bad_indices} exceeds the max vocab size of {max_embedding_size}.\n"\
f"Fix your tokenizer since it'll perform out of bounds memory accesses."
)
-
+ pass
+
if IS_COLAB_ENVIRONMENT or IS_KAGGLE_ENVIRONMENT:
cache_dir = "huggingface_tokenizers_cache"
else:
cache_dir = None
+ pass
# Sometimes slow tokenizer does not work like Deepseek
try:
@@ -888,12 +845,16 @@ def check_tokenizer(
# Tokenizer has out of bounds issues and we can't
# load the slow tokenizer version :(
logger.warning_once(
- "Unsloth: Tokenizer is most likely buggy, and Unsloth failed to repair it.\n"
- "It will still work, but beware of out of bounds memory accesses.\n"
+ "Unsloth: Tokenizer is most likely buggy, and Unsloth failed to repair it.\n"\
+ "It will still work, but beware of out of bounds memory accesses.\n"\
"Please file an issue on the model owner's repo about this issue."
)
return tokenizer
+ pass
+ pass
+ pass
return convert_to_fast_tokenizer(tokenizer)
+pass
import inspect
@@ -902,11 +863,9 @@ def check_tokenizer(
import trl.trainer.sft_trainer
from trl.trainer.sft_trainer import *
from transformers.trainer import *
-
try:
from trl.trainer.sft_trainer import neftune_post_forward_hook
except:
-
def neftune_post_forward_hook(module, input, output):
"""
Implements the NEFTune forward pass for the model using forward hooks. Note this works only for
@@ -934,11 +893,13 @@ def neftune_post_forward_hook(module, input, output):
mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
return output
+ pass
+pass
def patch_sft_trainer_tokenizer():
"""
- Patches the trainer with changes
+ Patches the trainer with changes
"""
try:
sft_trainer = eval(f"trl.trainer.sft_trainer.SFTTrainer")
@@ -946,50 +907,39 @@ def patch_sft_trainer_tokenizer():
return
all_imports = dir(trl.trainer.sft_trainer)
- for (
- function_name,
- replacer,
- ) in (
+ for (function_name, replacer,) in (
# ("_prepare_non_packed_dataloader", "def tokenize(element):",),
- (
- "_prepare_non_packed_dataloader",
- None,
- ),
- (
- "_prepare_dataset",
- None,
- ),
+ ("_prepare_non_packed_dataloader", None,),
+ ("_prepare_dataset", None,),
# ("_prepare_packed_dataloader", "if dataset_text_field is not None",),
):
- if not hasattr(sft_trainer, function_name):
- continue
+ if not hasattr(sft_trainer, function_name): continue
function = getsource(eval(f"sft_trainer.{function_name}"))
where = function.find("def")
function = function.split("\n")
function = "\n".join(x[where:] for x in function)
- check_text = (
- "\n"
- "if 'tokenizer' not in locals(): tokenizer = processing_class\n"
- "if 'formatting_func' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `formatting_func` does not exist!')\n"
- "if 'dataset_text_field' not in locals() and 'args' in locals(): dataset_text_field = args.dataset_text_field\n"
- "if 'dataset_text_field' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `dataset_text_field` does not exist!')\n"
- "test_text = dataset[0][dataset_text_field] if (formatting_func is None and dataset_text_field is not None) else formatting_func(dataset[0])[0]\n"
- "chat_template = getattr(tokenizer, 'chat_template', None)\n"
- "chat_template = '' if chat_template is None else chat_template\n"
- "has_bos_token_already = (test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template) "
- "if getattr(tokenizer, 'bos_token', None) is not None else False\n"
- "if 'add_special_tokens' not in locals() and has_bos_token_already:\n"
- " from functools import partial\n"
- " tokenizer = partial(tokenizer, add_special_tokens = False)\n"
- " processing_class = tokenizer\n"
- "else:\n"
- " add_special_tokens = False if has_bos_token_already else add_special_tokens\n\n"
- )
+ check_text = \
+ "\n"\
+ "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\
+ "if 'formatting_func' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `formatting_func` does not exist!')\n"\
+ "if 'dataset_text_field' not in locals() and 'args' in locals(): dataset_text_field = args.dataset_text_field\n"\
+ "if 'dataset_text_field' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `dataset_text_field` does not exist!')\n"\
+ "test_text = dataset[0][dataset_text_field] if (formatting_func is None and dataset_text_field is not None) else formatting_func(dataset[0])[0]\n"\
+ "chat_template = getattr(tokenizer, 'chat_template', None)\n"\
+ "chat_template = '' if chat_template is None else chat_template\n"\
+ "has_bos_token_already = (test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template) "\
+ "if getattr(tokenizer, 'bos_token', None) is not None else False\n"\
+ "if 'add_special_tokens' not in locals() and has_bos_token_already:\n"\
+ " from functools import partial\n"\
+ " tokenizer = partial(tokenizer, add_special_tokens = False)\n"\
+ " processing_class = tokenizer\n"\
+ "else:\n"\
+ " add_special_tokens = False if has_bos_token_already else add_special_tokens\n\n"
check_text = check_text.split("\n")
- check_text = "\n".join(" " * where + x for x in check_text)
+ check_text = "\n".join(" "*where + x for x in check_text)
check_text = check_text.rstrip() + "\n"
if replacer is None:
@@ -999,121 +949,101 @@ def patch_sft_trainer_tokenizer():
function,
flags = re.MULTILINE | re.DOTALL,
)
- if len(replacer) == 0:
- continue
+ if len(replacer) == 0: continue
replacer = replacer[0]
function = function.replace(replacer, replacer + check_text)
else:
function = function.replace(replacer, check_text + replacer)
+ pass
x = [x for x in all_imports if x in function]
- try:
- exec(f"from trl.trainer.sft_trainer import ({','.join(x)})", locals())
- except ImportError:
- for _item in x:
- try:
- exec(f"from trl.trainer.sft_trainer import {_item}", locals())
- except ImportError:
- pass
+ exec(f"from trl.trainer.sft_trainer import ({','.join(x)})", locals())
exec(function, locals(), globals())
- exec(
- f"trl.trainer.sft_trainer.SFTTrainer.{function_name} = {function_name}",
- globals(),
- )
+ exec(f"trl.trainer.sft_trainer.SFTTrainer.{function_name} = {function_name}", globals())
+ pass
# Patch train with fix_untrained_tokens
- for path_to_trainer in (
- "sft_trainer.SFTTrainer",
- "dpo_trainer.DPOTrainer",
- "kto_trainer.KTOTrainer",
- ):
+ for path_to_trainer in \
+ ("sft_trainer.SFTTrainer", "dpo_trainer.DPOTrainer", "kto_trainer.KTOTrainer"):
+
function_name, replacer = "train", "if resume_from_checkpoint is False:"
- try:
- function = getsource(eval(f"trl.trainer.{path_to_trainer}.{function_name}"))
- except Exception:
- continue
+ function = getsource(eval(f"trl.trainer.{path_to_trainer}.{function_name}"))
where = function.find("def")
function = function.split("\n")
function = "\n".join(x[where:] for x in function)
- check_text = (
- "\n"
- "import subprocess, re, gc, numpy as np\n"
- "a = np.array([0,])\n"
- "try:\n"
- " a = subprocess.check_output('nvidia-smi --query-gpu=memory.used --format=csv', shell = True)\n"
- " a = re.findall(rb'([\\d]{1,})[\\s]{1,}M', a)\n"
- " a = np.array([int(x.decode('utf-8'))/1024 for x in a])\n"
- "except:\n"
- " if not torch.cuda.is_available():\n"
- " raise RuntimeError('Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!')\n"
- "if ((a - PRE_CHECK) >= 1).sum() > 1:\n"
- " raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')\n"
- "for _ in range(3):\n"
- " gc.collect()\n"
- " torch.cuda.empty_cache()\n"
- "pass\n"
- "\n"
- "tokenizer = self.processing_class if hasattr(self, 'processing_class') else self.tokenizer\n"
- "fix_untrained_tokens(self.model, tokenizer, self.train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)\n\n"
- "fix_zero_training_loss(self.model, tokenizer, self.train_dataset)\n\n"
- )
+ check_text = \
+ "\n"\
+ "import subprocess, re, gc, numpy as np\n"\
+ "a = np.array([0,])\n"\
+ "try:\n"\
+ " a = subprocess.check_output('nvidia-smi --query-gpu=memory.used --format=csv', shell = True)\n"\
+ " a = re.findall(rb'([\\d]{1,})[\\s]{1,}M', a)\n"\
+ " a = np.array([int(x.decode('utf-8'))/1024 for x in a])\n"\
+ "except:\n"\
+ " if not torch.cuda.is_available():\n"\
+ " raise RuntimeError('Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!')\n"\
+ "if ((a - PRE_CHECK) >= 1).sum() > 1:\n"\
+ " raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!')\n"\
+ "for _ in range(3):\n"\
+ " gc.collect()\n"\
+ " torch.cuda.empty_cache()\n"\
+ "pass\n"\
+ "\n"\
+ "tokenizer = self.processing_class if hasattr(self, 'processing_class') else self.tokenizer\n"\
+ "fix_untrained_tokens(self.model, tokenizer, self.train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)\n\n"\
+ "fix_zero_training_loss(self.model, tokenizer, self.train_dataset)\n\n"
# Warn on gradient accumulation steps if it's used
- check_text += (
- "\n"
- "try:\n"
- " gradient_accumulation_steps = self.args.gradient_accumulation_steps\n"
- " if type(gradient_accumulation_steps) is int and gradient_accumulation_steps > 1:\n"
- " from transformers import __version__ as transformers_version\n"
- " from packaging.version import Version\n"
- " if Version(transformers_version) <= Version('4.45.2'):\n"
- " print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\\n'\\\n"
- " '`pip install --upgrade --no-cache-dir --no-deps unsloth transformers git+https://github.com/huggingface/trl.git`')\n"
- "except:\n"
- " pass\n"
- "\n\n"
- )
+ check_text += \
+ "\n"\
+ "try:\n"\
+ " gradient_accumulation_steps = self.args.gradient_accumulation_steps\n"\
+ " if type(gradient_accumulation_steps) is int and gradient_accumulation_steps > 1:\n"\
+ " from transformers import __version__ as transformers_version\n"\
+ " from packaging.version import Version\n"\
+ " if Version(transformers_version) <= Version('4.45.2'):\n"\
+ " print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\\n'\\\n"\
+ " '`pip install --upgrade --no-cache-dir --no-deps unsloth transformers git+https://github.com/huggingface/trl.git`')\n"\
+ "except:\n"\
+ " pass\n"\
+ "\n\n"
# Add NEFTune since it doesn't seem to work?? We need to manually inject it
- check_text += (
- "\n"
- "if hasattr(self, 'neftune_hook_handle'):\n"
- " self.neftune_hook_handle.remove()\n"
- " if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle\n"
- "\n"
- "if getattr(self, 'neftune_noise_alpha', None) is not None:\n"
- " self.model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha\n"
- " self.neftune_hook_handle = self.model.get_input_embeddings().register_forward_hook(neftune_post_forward_hook)\n"
- "pass\n"
- "\n"
- )
+ check_text += \
+ "\n"\
+ "if hasattr(self, 'neftune_hook_handle'):\n"\
+ " self.neftune_hook_handle.remove()\n"\
+ " if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle\n"\
+ "\n"\
+ "if getattr(self, 'neftune_noise_alpha', None) is not None:\n"\
+ " self.model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha\n"\
+ " self.neftune_hook_handle = self.model.get_input_embeddings().register_forward_hook(neftune_post_forward_hook)\n"\
+ "pass\n"\
+ "\n"
# Also DPO weirdly tokenizes non numeric columns? Delete them!
- check_text += (
- "\n"
- "if hasattr(self.train_dataset, 'column_names'):\n"
- " column_names = set(self.train_dataset.column_names)\n"
- " check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"
- " 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"
- " 'prompt_input_ids', 'prompt_attention_mask']\n"
- " if all(x in column_names for x in check):\n"
- " self.train_dataset = self.train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"
- " del check, column_names\n"
- "\n"
- )
+ check_text += \
+ "\n"\
+ "if hasattr(self.train_dataset, 'column_names'):\n"\
+ " column_names = set(self.train_dataset.column_names)\n"\
+ " check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"\
+ " 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"\
+ " 'prompt_input_ids', 'prompt_attention_mask']\n"\
+ " if all(x in column_names for x in check):\n"\
+ " self.train_dataset = self.train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"\
+ " del check, column_names\n"\
+ "\n"
check_text = check_text.split("\n")
- check_text = "\n".join(" " * where + x for x in check_text)
+ check_text = "\n".join(" "*where + x for x in check_text)
function = function.replace(replacer, check_text + replacer)
exec(function, globals())
- exec(
- f"trl.trainer.{path_to_trainer}.{function_name} = {function_name}",
- globals(),
- )
-
+ exec(f"trl.trainer.{path_to_trainer}.{function_name} = {function_name}", globals())
+ pass
+pass
# Finally patch TRL tokenizer things -> moved to RL
# patch_sft_trainer_tokenizer()
diff --git a/unsloth/trainer.py b/unsloth/trainer.py
index 65abe6801f..012be4b0cb 100644
--- a/unsloth/trainer.py
+++ b/unsloth/trainer.py
@@ -12,9 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
-import os
-import psutil
import warnings
from dataclasses import dataclass, field
from typing import Optional
@@ -24,20 +21,13 @@
import inspect
from trl import SFTTrainer
from . import is_bfloat16_supported
-from unsloth.utils import (
- configure_padding_free,
- configure_sample_packing,
- enable_padding_free_metadata,
- enable_sample_packing,
-)
from unsloth_zoo.training_utils import (
unsloth_train as _unsloth_train,
)
from unsloth_zoo.vision_utils import (
UnslothVisionDataCollator,
)
-from unsloth_zoo.hf_utils import get_transformers_model_type
-from unsloth_zoo.utils import Version
+from packaging.version import Version
import dataclasses
__all__ = [
@@ -48,92 +38,41 @@
"UnslothVisionDataCollator",
]
-logger = logging.getLogger(__name__)
-
-_AUTO_PADDING_FREE_ENV_DISABLED = os.environ.get(
- "UNSLOTH_DISABLE_AUTO_PADDING_FREE", ""
-).strip().lower() in {"1", "true", "yes", "on"}
-
-PADDING_FREE_BLOCKLIST = {
- "gemma2", # - gemma2: Uses slow_attention_softcapping which has torch.compile issues
- "gpt_oss", # - gpt_oss: Uses Flex Attention which doesn't handle padding_free correctly
-}
-
-
-def _should_pack(config) -> bool:
- if config is None or not getattr(config, "packing", False):
- return False
- return not getattr(config, "_unsloth_disable_auto_packing", False)
-
-
-def _should_auto_padding_free(config) -> bool:
- if (
- config is None
- or _AUTO_PADDING_FREE_ENV_DISABLED
- or getattr(config, "packing", False)
- ):
- return False
- return getattr(config, "padding_free", None) is None
-
-
-def _disable_sample_packing(config):
- if config is None:
- return
- for attr, value in (("packing", False), ("padding_free", False)):
- if hasattr(config, attr):
- setattr(config, attr, value)
- if hasattr(config, "remove_unused_columns"):
- setattr(config, "remove_unused_columns", True)
- setattr(config, "_unsloth_disable_auto_packing", True)
-
-
-_AUTO_PACK_SKIP_MESSAGES = (
- "packing is not supported",
- "padding-free training",
- "passing a custom data collator",
-)
-
-
-def _should_skip_auto_packing_error(exc: Exception) -> bool:
- message = str(exc).lower()
- return any(msg in message for msg in _AUTO_PACK_SKIP_MESSAGES)
-
-
# Unsloth gradient accumulation fix:
-from transformers import __version__ as transformers_version, ProcessorMixin
-
+from transformers import __version__ as transformers_version
if Version(transformers_version) > Version("4.45.2"):
-
def unsloth_train(trainer, *args, **kwargs):
return trainer.train(*args, **kwargs)
-
+ pass
else:
-
def unsloth_train(trainer, *args, **kwargs):
if len(args) != 0 or len(kwargs) != 0:
raise RuntimeError(
- "Unsloth: Our custom gradient accumulation fixed trainer does not support other arguments.\n"
- "If you want to use our fix inside of HF, please update `transformers` to the latest version via:\n"
- "`pip uninstall transformers -y && pip install --upgrade --no-cache-dir transformers`"
+ "Unsloth: Our custom gradient accumulation fixed trainer does not support other arguments.\n"\
+ "If you want to use our fix inside of HF, please update `transformers` to the latest version via:\n"\
+ '`pip uninstall transformers -y && pip install --upgrade --no-cache-dir transformers`'
)
print(
- "Unsloth: Using our custom gradient accumulation fixed trainer, which is not feature complete.\n"
- "If you want to use our fix inside of HF, please update `transformers` to the latest version via:\n"
- "`pip uninstall transformers -y && pip install --upgrade --no-cache-dir transformers`"
+ "Unsloth: Using our custom gradient accumulation fixed trainer, which is not feature complete.\n"\
+ "If you want to use our fix inside of HF, please update `transformers` to the latest version via:\n"\
+ '`pip uninstall transformers -y && pip install --upgrade --no-cache-dir transformers`'
)
return _unsloth_train(trainer)
-
+ pass
+pass
try:
from trl import SFTConfig as TrainingArguments
except:
from transformers import TrainingArguments
-
-
+pass
+@dataclass
class UnslothTrainingArguments(TrainingArguments):
- def __init__(self, embedding_learning_rate: float = None, *args, **kwargs):
- embedding_learning_rate = embedding_learning_rate
- super().__init__(*args, **kwargs)
+ embedding_learning_rate : Optional[float] = field(
+ default = None,
+ metadata = {"help" : "Different learning rates for embeddings and lm_head."}
+ )
+pass
def _create_unsloth_optimizer(
@@ -145,130 +84,86 @@ def _create_unsloth_optimizer(
lr = optimizer_kwargs["lr"]
weight_decay = optimizer_kwargs.get("weight_decay", 0.0)
- param_groups = {
- "non_embeddings": {},
- "embeddings": {},
+ param_groups = \
+ {
+ "non_embeddings" : {},
+ "embeddings" : {},
}
for name, param in model.named_parameters():
- if not param.requires_grad:
- continue
+ if not param.requires_grad: continue
if name.endswith("modules_to_save.default.weight"):
- partial_name = name[: -len(".modules_to_save.default.weight")]
- partial_name = partial_name[partial_name.rfind(".") + 1 :]
- print(
- f"Unsloth: Setting lr = {embedding_lr:.2e} instead of {lr:.2e} for {partial_name}."
- )
- param_groups["embeddings"][name] = param
+ partial_name = name[:-len(".modules_to_save.default.weight")]
+ partial_name = partial_name[partial_name.rfind(".")+1:]
+ print(f"Unsloth: Setting lr = {embedding_lr:.2e} instead of {lr:.2e} for {partial_name}.")
+ param_groups["embeddings"] [name] = param
else:
param_groups["non_embeddings"][name] = param
+ pass
+ pass
optimizer_grouped_parameters = [
{
- "params": list(param_groups["non_embeddings"].values()),
- "weight_decay": weight_decay,
- "lr": lr,
+ "params" : list(param_groups["non_embeddings"].values()),
+ "weight_decay" : weight_decay,
+ "lr" : lr,
},
{
- "params": list(param_groups["embeddings"].values()),
- "weight_decay": weight_decay,
- "lr": embedding_lr,
+ "params" : list(param_groups["embeddings"].values()),
+ "weight_decay" : weight_decay,
+ "lr" : embedding_lr,
},
]
optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
return optimizer
+pass
class UnslothTrainer(SFTTrainer):
def create_optimizer(self):
embedding_learning_rate = getattr(self.args, "embedding_learning_rate", None)
- if embedding_learning_rate is None:
- return super().create_optimizer()
+ if embedding_learning_rate is None: return super().create_optimizer()
if self.optimizer is None:
- optimizer_cls, optimizer_kwargs = SFTTrainer.get_optimizer_cls_and_kwargs(
- self.args
- )
+ optimizer_cls, optimizer_kwargs = SFTTrainer.get_optimizer_cls_and_kwargs(self.args)
self.optimizer = _create_unsloth_optimizer(
self.model,
optimizer_cls,
optimizer_kwargs,
embedding_learning_rate,
)
+ pass
return self.optimizer
-
+ pass
+pass
# From `trl>=0.13.0`, they changed how to pass several params to the trainer
# We need to patch to make the transition smooth
-def _resolve_trainer_params(trainer_class, init_fn):
- """Resolve the real named parameters for a trainer __init__.
-
- Some TRL trainers (e.g., ORPOTrainer in TRL 0.27.1) are thin wrappers
- with only ``def __init__(self, *args, **kwargs)``. For those, walk the
- MRO and return the first parent class that has real named parameters.
- """
- params = inspect.signature(init_fn).parameters
- named = {
- k
- for k, v in params.items()
- if v.kind
- in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
- and k != "self"
- }
- if named:
- return set(params.keys())
-
- # Thin wrapper detected - walk MRO for real signature
- for cls in trainer_class.__mro__[1:]:
- if cls is object:
- continue
- parent_init = cls.__dict__.get("__init__")
- if parent_init is None:
- continue
- try:
- parent_params = inspect.signature(parent_init).parameters
- parent_named = {
- k
- for k, v in parent_params.items()
- if v.kind
- in (
- inspect.Parameter.POSITIONAL_OR_KEYWORD,
- inspect.Parameter.KEYWORD_ONLY,
- )
- and k != "self"
- }
- if parent_named:
- return set(parent_params.keys())
- except (ValueError, TypeError):
- continue
- return set(params.keys())
-
-
def _backwards_compatible_trainer(trainer_class, config_class):
original_init = trainer_class.__init__
-
+
@wraps(original_init)
def new_init(self, *args, **kwargs):
# All Trainer tokenizer are now called processing_class
- trainer_params = _resolve_trainer_params(trainer_class, original_init)
+ trainer_params = set(inspect.signature(original_init).parameters.keys())
if "processing_class" in trainer_params and "tokenizer" in kwargs:
kwargs["processing_class"] = kwargs.pop("tokenizer")
+ pass
- if ("args" in kwargs) and (Version(trl) >= Version("0.13.0.dev0")):
+ if ("args" in kwargs) and (Version(trl.__version__) >= Version("0.13.0.dev0")):
training_args = kwargs.pop("args", None)
# Get parameters that Trainer.__init__ actually expects
- trainer_params.remove("self")
- trainer_params.remove("args")
+ trainer_params.remove('self')
+ trainer_params.remove('args')
# Get fields that should be passed to Config init
config_fields = {
- field.name: field
- for field in dataclasses.fields(config_class)
+ field.name: field for field in dataclasses.fields(config_class)
if field.init
}
-
+
# Create config dict with valid fields from training_args
config_dict = {
name: getattr(training_args, name)
@@ -278,206 +173,54 @@ def new_init(self, *args, **kwargs):
# Get parameters that exist in Config but not in TrainingArguments
from transformers import TrainingArguments
-
- moved_params = set(inspect.signature(config_class).parameters.keys()) - set(
- inspect.signature(TrainingArguments).parameters.keys()
- )
-
+ moved_params = \
+ set(inspect.signature(config_class) .parameters.keys()) - \
+ set(inspect.signature(TrainingArguments).parameters.keys())
+
# Separate kwargs into trainer kwargs and config kwargs
trainer_kwargs = {}
additional_config_kwargs = {}
for key, value in kwargs.items():
- if key in trainer_params:
- trainer_kwargs[key] = value
+ if key in trainer_params: trainer_kwargs[key] = value
elif key in moved_params or key in config_fields:
additional_config_kwargs[key] = value
else:
additional_config_kwargs[key] = value
+ pass
+ pass
# Update config_dict with additional kwargs
config_dict.update(additional_config_kwargs)
# Create Config with all the collected parameters
- # Reinitialising config class with parameters (that were none initially but populated on first init)
- # causes the 2nd init to fail as there are mutual exclusive checks on pairs of parameters.
- # Refer: https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_config.py#L499-L502 for example
- # So we only create config class if the previous init was not TrainingArguments
- if not isinstance(training_args, TrainingArguments):
- config = config_class(**config_dict)
- else:
- config = training_args
-
+ config = config_class(**config_dict)
+
# Reconstruct kwargs for Trainer
kwargs = trainer_kwargs
kwargs["args"] = config
+ pass
original_init(self, *args, **kwargs)
-
+ pass
return new_init
-
-
-def _patch_sft_trainer_auto_packing(trl_module):
- sft_trainer = getattr(trl_module, "SFTTrainer", None)
- if sft_trainer is None:
- return
- if getattr(sft_trainer, "_unsloth_auto_packing_wrapped", False):
- return
-
- original_init = sft_trainer.__init__
-
- @wraps(original_init)
- def new_init(self, *args, **kwargs):
- config_arg = None
- if len(args) >= 2:
- config_arg = args[1]
- else:
- config_arg = kwargs.get("args")
-
- # Check if model type is unsupported for padding_free
- model = kwargs.get("model")
- is_unsupported_model = False
- is_vlm = False
- if model is not None:
- model_config = getattr(model, "config", None)
- if model_config is not None:
- model_types = get_transformers_model_type(model_config)
- # Blocklist: models that don't work correctly with padding_free
- is_unsupported_model = any(
- x in PADDING_FREE_BLOCKLIST for x in model_types
- )
-
- # Check if VLM
- architectures = getattr(model_config, "architectures", None)
- if architectures is None:
- architectures = []
- is_vlm = any(
- x.endswith("ForConditionalGeneration") for x in architectures
- )
- is_vlm = is_vlm or hasattr(model_config, "vision_config")
-
- processing_class = kwargs.get("processing_class") or kwargs.get("tokenizer")
- data_collator = kwargs.get("data_collator")
-
- # We also disable vision language models for padding free collators
- blocked = (
- (data_collator is not None)
- or isinstance(processing_class, ProcessorMixin)
- or is_vlm
- or is_unsupported_model
- or (
- os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1"
- ) # Disable padding free on forced logits
- )
- requested_pack = bool(getattr(config_arg, "packing", False))
- if blocked:
- if hasattr(config_arg, "packing"):
- setattr(config_arg, "packing", False)
- if hasattr(config_arg, "padding_free"):
- setattr(config_arg, "padding_free", False)
-
- if blocked and requested_pack:
- reason = "custom data collator"
- if data_collator is None and isinstance(processing_class, ProcessorMixin):
- reason = "processor-based model"
- elif is_vlm:
- reason = "vision-language model"
- elif is_unsupported_model:
- reason = f"unsupported model type(s): {', '.join(model_types)}"
- message = "Unsloth: Sample packing skipped " f"({reason} detected)."
- print(message)
-
- packing_active = False
- if _should_pack(config_arg) and not blocked:
- configure_sample_packing(config_arg)
- packing_active = True
- logger.info("Unsloth: Sample packing enabled for SFTTrainer instance.")
-
- # Resolve padding_free: None (default) = auto-enable unless env-disabled or packing
- auto_padding_free_active = False
- padding_free_requested = getattr(config_arg, "padding_free", None) is True
- if not blocked:
- if padding_free_requested:
- configure_padding_free(config_arg)
- elif _should_auto_padding_free(config_arg):
- configure_padding_free(config_arg)
- auto_padding_free_active = True
- logger.info(
- "Unsloth: Padding-free batching auto-enabled for SFTTrainer instance."
- )
-
- try:
- original_init(self, *args, **kwargs)
- except ValueError as exc:
- if packing_active and _should_skip_auto_packing_error(exc):
- logger.info(
- "Unsloth: Auto sample packing failed because trainer reported an incompatible setup (%s).",
- exc,
- )
- _disable_sample_packing(config_arg)
- packing_active = False
- original_init(self, *args, **kwargs)
- else:
- raise
-
- trainer_args = getattr(self, "args", None)
- trainer_packing = bool(trainer_args and getattr(trainer_args, "packing", False))
- trainer_padding_free = bool(
- trainer_args and getattr(trainer_args, "padding_free", False)
- )
-
- if blocked and trainer_args is not None:
- # Mirror the block on the trainer args to avoid re-enabling later
- setattr(trainer_args, "packing", False)
- setattr(trainer_args, "padding_free", False)
-
- if (
- not blocked
- and trainer_packing
- and (packing_active or _should_pack(trainer_args))
- ):
- enable_sample_packing(self.model, self)
- print(
- "🦥 Unsloth: Packing enabled - training is >2x faster and uses less VRAM!"
- )
- elif not blocked and trainer_padding_free:
- enable_padding_free_metadata(self.model, self)
- message = (
- "🦥 Unsloth: Padding-free auto-enabled, enabling faster training."
- if auto_padding_free_active
- else "🦥 Unsloth: Padding-free enabled, enabling faster training."
- )
- print(message)
-
- sft_trainer.__init__ = new_init
- sft_trainer._unsloth_auto_packing_wrapped = True
+pass
def _patch_trl_trainer():
import trl
-
- if hasattr(trl, "__UNSLOTH_BACKWARDS_COMPATIBLE__"):
- return
- if Version(trl) <= Version("0.11.0"):
- return
+ if hasattr(trl, "__UNSLOTH_BACKWARDS_COMPATIBLE__"): return
+ if Version(trl.__version__) <= Version("0.11.0"): return
import trl.trainer
-
trl_classes = dir(trl.trainer)
- trl_trainers = set(
- x[: -len("Trainer")] for x in trl_classes if x.endswith("Trainer")
- )
- trl_configs = set(x[: -len("Config")] for x in trl_classes if x.endswith("Config"))
+ trl_trainers = set(x[:-len("Trainer")] for x in trl_classes if x.endswith("Trainer"))
+ trl_configs = set(x[:-len("Config")] for x in trl_classes if x.endswith("Config"))
trl_classes = list(trl_trainers & trl_configs)
for x in trl_classes:
- try:
- exec(
- f"trl.{x}Trainer.__init__ = _backwards_compatible_trainer(trl.{x}Trainer, trl.{x}Config)",
- globals(),
- )
- except:
- continue
-
- _patch_sft_trainer_auto_packing(trl)
+ try: exec(f"trl.{x}Trainer.__init__ = _backwards_compatible_trainer(trl.{x}Trainer, trl.{x}Config)", globals())
+ except: continue
+ pass
trl.__UNSLOTH_BACKWARDS_COMPATIBLE__ = True
+pass
diff --git a/unsloth/utils/__init__.py b/unsloth/utils/__init__.py
deleted file mode 100644
index 9a093fedd7..0000000000
--- a/unsloth/utils/__init__.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with this program. If not, see .
-
-from .packing import (
- configure_padding_free,
- configure_sample_packing,
- enable_padding_free_metadata,
- enable_sample_packing,
- mark_allow_overlength,
-)
-from .attention_dispatch import (
- AttentionConfig,
- AttentionContext,
- FLASH_DENSE,
- FLASH_VARLEN,
- SDPA,
- XFORMERS,
- run_attention,
- select_attention_backend,
-)
-
-__all__ = [
- "configure_sample_packing",
- "configure_padding_free",
- "enable_sample_packing",
- "enable_padding_free_metadata",
- "mark_allow_overlength",
- "AttentionConfig",
- "AttentionContext",
- "FLASH_VARLEN",
- "FLASH_DENSE",
- "XFORMERS",
- "SDPA",
- "run_attention",
- "select_attention_backend",
-]
diff --git a/unsloth/utils/attention_dispatch.py b/unsloth/utils/attention_dispatch.py
deleted file mode 100644
index 72d52ab376..0000000000
--- a/unsloth/utils/attention_dispatch.py
+++ /dev/null
@@ -1,353 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with this program. If not, see .
-
-"""Shared helpers for attention backend selection and execution."""
-
-from __future__ import annotations
-
-from dataclasses import dataclass
-from typing import Any, Optional, Tuple
-
-import torch
-from torch import Tensor
-from torch.nn.functional import scaled_dot_product_attention
-
-from ..models._utils import *
-from ..utils.packing import (
- build_sdpa_packed_attention_mask,
- build_xformers_block_causal_mask,
-)
-
-if HAS_FLASH_ATTENTION:
- from flash_attn import flash_attn_func, flash_attn_varlen_func
-HAS_XFORMERS = xformers is not None
-SDPA_HAS_GQA = "enable_gqa" in (scaled_dot_product_attention.__doc__ or "")
-
-FLASH_VARLEN = "flash_varlen"
-FLASH_DENSE = "flash_dense"
-XFORMERS = "xformers"
-SDPA = "sdpa"
-
-
-XFORMERS_BLOCK_DIAG_CLS = (
- xformers.attn_bias.BlockDiagonalCausalMask if HAS_XFORMERS else None
-)
-
-
-@dataclass
-class AttentionConfig:
- """
- Per-layer attention metadata.
-
- NOTE(djsaunde): I had originally intended this to be populated once per layer, but
- we're currently constructing it on every forward pass since it can possibly be
- invalid from one forward pass to the next (e.g., switching from training to
- inference). For now, I'm keeping separate from AttentionContext for the sake of
- better grouping of params.
- """
-
- backend: str
- n_kv_heads: int
- n_groups: int
- flash_dense_kwargs: Optional[dict[str, Any]] = None
- flash_varlen_kwargs: Optional[dict[str, Any]] = None
- sdpa_kwargs: Optional[dict[str, Any]] = None
- xformers_kwargs: Optional[dict[str, Any]] = None
-
-
-@dataclass
-class AttentionContext:
- """Per-call info required to run attention."""
-
- bsz: int
- q_len: int
- kv_seq_len: int
- n_heads: int
- head_dim: int
- requires_grad: bool
- seq_info: Optional[Tuple[Tensor, Tensor, int]]
- attention_mask: Optional[Tensor]
- causal_mask: Optional[Any]
- sliding_window: Optional[int] = None
-
-
-def select_attention_backend(use_varlen: bool = False) -> str:
- """Return attention backend based on availability / priority order."""
-
- if HAS_FLASH_ATTENTION:
- if use_varlen:
- return FLASH_VARLEN
- else:
- return FLASH_DENSE
- if HAS_XFORMERS:
- return XFORMERS
- return SDPA
-
-
-def run_attention(
- *,
- config: AttentionConfig,
- context: AttentionContext,
- Q: Tensor,
- K: Tensor,
- V: Tensor,
-) -> Tensor:
- """
- Run attention using config / context info.
-
- Backend choice is prioritized for speed: FlashAttention when installed
- (`flash_varlen` for packed/variable-length inputs with `seq_info`, otherwise dense
- flash), then xFormers if flash is unavailable, with PyTorch SDPA as the final
- fallback (e.g., CPU or no fused kernels).
-
- Varlen flash is preferred when packing metadata is present because it avoids padding
- and keeps peak memory low. xFormers and SDPA can also handle packed batches (we
- pass a block-diagonal mask into each).
- """
-
- backend = config.backend
- if backend == FLASH_VARLEN and context.seq_info is None:
- backend = FLASH_DENSE if HAS_FLASH_ATTENTION else SDPA
-
- # [TODO] Flash attention does not support arbitrary attention masks (only
- # causal via flag). When a padding mask is present (e.g. left-padded
- # batched generation), fall back to SDPA which consumes attn_mask.
- # xFormers also does not thread context.attention_mask through, so the
- # same fallback applies.
- if context.attention_mask is not None and backend in (
- FLASH_DENSE,
- FLASH_VARLEN,
- XFORMERS,
- ):
- backend = SDPA
-
- flash_dense_kwargs = config.flash_dense_kwargs or {}
- flash_varlen_kwargs = config.flash_varlen_kwargs or {}
- sdpa_kwargs = config.sdpa_kwargs or {}
- xformers_kwargs = config.xformers_kwargs or {}
-
- bsz = context.bsz
- n_heads = context.n_heads
- q_len = context.q_len
- head_dim = context.head_dim
- kv_seq_len = context.kv_seq_len
- requires_grad = context.requires_grad
- sliding_window = context.sliding_window
-
- if backend == FLASH_VARLEN:
- Q_f = Q.transpose(1, 2).reshape(bsz * q_len, n_heads, head_dim)
- K_f = K.transpose(1, 2).reshape(bsz * q_len, config.n_kv_heads, head_dim)
- V_f = V.transpose(1, 2).reshape(bsz * q_len, config.n_kv_heads, head_dim)
- _, cu_seqlens, max_seqlen = context.seq_info
- return flash_attn_varlen_func(
- Q_f,
- K_f,
- V_f,
- cu_seqlens,
- cu_seqlens,
- max_seqlen,
- max_seqlen,
- **flash_varlen_kwargs,
- ).view(bsz, q_len, n_heads, head_dim)
- elif backend == FLASH_DENSE:
- Q_t = Q.transpose(1, 2)
- K_t = K.transpose(1, 2)
- V_t = V.transpose(1, 2)
- return flash_attn_func(Q_t, K_t, V_t, **flash_dense_kwargs).reshape(
- bsz, q_len, n_heads, head_dim
- )
- elif backend == XFORMERS:
- attn_bias = build_xformers_block_causal_mask(
- context.seq_info,
- sliding_window = sliding_window,
- base_mask = context.causal_mask,
- )
-
- Q_t = Q.transpose(1, 2)
- K_t = K.transpose(1, 2)
- V_t = V.transpose(1, 2)
-
- K_mod = K_t
- V_mod = V_t
- Q_mod = Q_t
-
- if config.n_groups != 1:
- K_mod = K_t.view(bsz, kv_seq_len, config.n_kv_heads, 1, head_dim)
- V_mod = V_t.view(bsz, kv_seq_len, config.n_kv_heads, 1, head_dim)
- K_mod = K_mod.expand(
- bsz, kv_seq_len, config.n_kv_heads, config.n_groups, head_dim
- )
- V_mod = V_mod.expand(
- bsz, kv_seq_len, config.n_kv_heads, config.n_groups, head_dim
- )
-
- if requires_grad:
- K_mod = K_mod.reshape(bsz, kv_seq_len, n_heads, head_dim)
- V_mod = V_mod.reshape(bsz, kv_seq_len, n_heads, head_dim)
- else:
- Q_mod = Q_t.view(
- bsz, q_len, config.n_kv_heads, config.n_groups, head_dim
- )
-
- has_block = XFORMERS_BLOCK_DIAG_CLS is not None and isinstance(
- attn_bias, XFORMERS_BLOCK_DIAG_CLS
- )
-
- if config.n_groups != 1 and has_block:
- if not requires_grad:
- Q_mod = Q_mod.view(
- 1, bsz * q_len, config.n_kv_heads, config.n_groups, head_dim
- )
- K_mod = K_mod.view(
- 1, bsz * kv_seq_len, config.n_kv_heads, config.n_groups, head_dim
- )
- V_mod = V_mod.view(
- 1, bsz * kv_seq_len, config.n_kv_heads, config.n_groups, head_dim
- )
- else:
- Q_mod = Q_mod.view(1, bsz * q_len, n_heads, head_dim)
- K_mod = K_mod.view(1, bsz * kv_seq_len, n_heads, head_dim)
- V_mod = V_mod.view(1, bsz * kv_seq_len, n_heads, head_dim)
-
- out = xformers_attention(
- Q_mod,
- K_mod,
- V_mod,
- attn_bias = attn_bias,
- **xformers_kwargs,
- )
-
- if config.n_groups != 1 and not requires_grad:
- out = out.view(bsz, q_len, config.n_kv_heads, config.n_groups, head_dim)
- out = out.reshape(bsz, q_len, n_heads, head_dim)
- else:
- out = out.view(bsz, q_len, n_heads, head_dim)
- return out
- else:
- local_mask = context.attention_mask
- is_causal_local = False
- if context.seq_info is not None and local_mask is None:
- local_mask = build_sdpa_packed_attention_mask(
- context.seq_info,
- dtype = Q.dtype,
- device = Q.device,
- sliding_window = sliding_window,
- )
- else:
- q_len_local = Q.shape[-2]
- k_len_local = K.shape[-2]
- # ---- SDPA mask normalization for left padding / 2D masks ----
- if local_mask is not None and isinstance(local_mask, torch.Tensor):
- local_mask = local_mask.to(device = Q.device)
-
- if local_mask.dim() == 2:
- # key padding keep mask: (bsz, k_len), 1/True = real token
- if local_mask.dtype == torch.bool:
- key_keep = local_mask
- else:
- # tokenizer attention_mask is typically int 0/1
- key_keep = local_mask != 0
-
- past_len = (
- k_len_local - q_len_local
- ) # works for prefill (0) and decode
- q_pos = torch.arange(
- past_len, past_len + q_len_local, device = Q.device
- )
- k_pos = torch.arange(k_len_local, device = Q.device)
-
- causal_keep = (
- k_pos[None, :] <= q_pos[:, None]
- ) # True = allowed (SDPA)
- if sliding_window is not None:
- causal_keep &= k_pos[None, :] >= (
- q_pos[:, None] - (sliding_window - 1)
- )
-
- # (bsz, 1, q_len, k_len) boolean keep mask
- local_mask = (
- causal_keep[None, None, :, :] & key_keep[:, None, None, :]
- )
-
- elif local_mask.dim() == 3:
- # (bsz, q_len, k_len) -> (bsz, 1, q_len, k_len)
- local_mask = local_mask[:, None, :, :]
-
- elif local_mask.dim() == 4:
- if local_mask.dtype != torch.bool:
- # Use boolean keep masks for better SDPA stability.
- local_mask = local_mask.eq(0)
- else:
- raise ValueError(
- f"Unsupported SDPA attention_mask rank: {local_mask.dim()}"
- )
-
- # Avoid NaNs from fully-masked rows (common with left padding).
- if local_mask.dtype == torch.bool:
- no_allowed = ~local_mask.any(
- dim = -1, keepdim = True
- ) # (bsz,1,q_len,1)
- local_mask = local_mask | no_allowed
-
- is_causal_local = local_mask is None and q_len_local == k_len_local
-
- kwargs = dict(sdpa_kwargs)
- kwargs.setdefault("attn_mask", local_mask)
- kwargs.setdefault("is_causal", is_causal_local)
-
- use_sdpa_gqa = SDPA_HAS_GQA and config.n_groups != 1
- if (
- use_sdpa_gqa
- and (not requires_grad)
- and isinstance(local_mask, torch.Tensor)
- and local_mask.dim() >= 3
- and local_mask.shape[0] > 1
- ):
- # Batched masked inference has shown row-coupled drift with SDPA GQA.
- # Fall back to explicit KV expansion for deterministic row-wise behavior.
- use_sdpa_gqa = False
-
- if use_sdpa_gqa:
- kwargs.setdefault("enable_gqa", True)
- out = scaled_dot_product_attention(Q, K, V, **kwargs)
- return out.transpose(1, 2)
-
- K_mod = K
- V_mod = V
- if config.n_groups != 1:
- K_mod = K[:, :, None, :, :].expand(
- bsz, config.n_kv_heads, config.n_groups, kv_seq_len, head_dim
- )
- V_mod = V[:, :, None, :, :].expand(
- bsz, config.n_kv_heads, config.n_groups, kv_seq_len, head_dim
- )
- K_mod = K_mod.reshape(bsz, n_heads, kv_seq_len, head_dim)
- V_mod = V_mod.reshape(bsz, n_heads, kv_seq_len, head_dim)
-
- out = scaled_dot_product_attention(
- Q.contiguous(),
- K_mod.contiguous(),
- V_mod.contiguous(),
- **kwargs,
- )
- return out.transpose(1, 2).contiguous()
-
-
-__all__ = [
- "AttentionConfig",
- "AttentionContext",
- "select_attention_backend",
- "run_attention",
-]
diff --git a/unsloth/utils/hf_hub.py b/unsloth/utils/hf_hub.py
deleted file mode 100644
index e3960ba0ce..0000000000
--- a/unsloth/utils/hf_hub.py
+++ /dev/null
@@ -1,80 +0,0 @@
-from huggingface_hub import HfApi, ModelInfo
-
-_HFAPI: HfApi = None
-
-POPULARITY_PROPERTIES = [
- "downloads",
- "downloadsAllTime",
- "trendingScore",
- "likes",
-]
-THOUSAND = 1000
-MILLION = 1000000
-BILLION = 1000000000
-
-
-def formatted_int(value: int) -> str:
- if value < THOUSAND:
- return str(value)
- elif value < MILLION:
- return f"{float(value) / 1000:,.1f}K"
- elif value < BILLION:
- return f"{float(value) / 1000000:,.1f}M"
- else:
- return f"{float(value) / 1000000000:,.1f}B"
-
-
-def get_model_info(
- model_id: str, properties: list[str] = ["safetensors", "lastModified"]
-) -> ModelInfo:
- """
- Get the model info for a specific model.
-
- properties: list[str] = See https://huggingface.co/docs/huggingface_hub/api-ref/hf_hub/hf_api/model_info
- Default properties: ["safetensors", "lastModified"], only retrieves minimal information.
- Set to None to retrieve the full model information.
- """
- global _HFAPI
- if _HFAPI is None:
- _HFAPI = HfApi()
- try:
- model_info: ModelInfo = _HFAPI.model_info(model_id, expand = properties)
- except Exception as e:
- print(f"Error getting model info for {model_id}: {e}")
- model_info = None
- return model_info
-
-
-def list_models(
- properties: list[str] = None,
- full: bool = False,
- sort: str = "downloads",
- author: str = "unsloth",
- search: str = None,
- limit: int = 10,
-) -> list[ModelInfo]:
- """
- Retrieve model information from the Hugging Face Hub.
-
- properties: list[str] = See https://huggingface.co/docs/huggingface_hub/api-ref/hf_hub/hf_api/list_models
- full: bool = Whether to retrieve the full model information, if True properties will be ignored.
- sort: str = The sort order.
- author: str = The author of the model.
- search: str = The search query for filtering models.
-
- """
- global _HFAPI
- if _HFAPI is None:
- _HFAPI = HfApi()
- if full:
- properties = None
-
- models: list[ModelInfo] = _HFAPI.list_models(
- author = author,
- search = search,
- sort = sort,
- limit = limit,
- expand = properties,
- full = full,
- )
- return models
diff --git a/unsloth/utils/packing.py b/unsloth/utils/packing.py
deleted file mode 100644
index 63a57c04da..0000000000
--- a/unsloth/utils/packing.py
+++ /dev/null
@@ -1,354 +0,0 @@
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with this program. If not, see .
-
-"""Utilities for enabling packed (padding-free) batches across Unsloth."""
-
-from __future__ import annotations
-
-import logging
-from collections import OrderedDict
-from typing import Any, Iterable, Optional, Sequence, Tuple
-
-import torch
-
-try:
- from xformers.ops.fmha.attn_bias import (
- BlockDiagonalCausalMask as _XFormersBlockMask,
- )
-except Exception:
- try:
- from xformers.attn_bias import BlockDiagonalCausalMask as _XFormersBlockMask
- except Exception:
- _XFormersBlockMask = None
-
-_XFORMERS_MASK_CACHE_MAXSIZE = 32
-_XFORMERS_MASK_CACHE: OrderedDict[Tuple[Tuple[int, ...], int], Any] = OrderedDict()
-
-
-def _window_cache_key(sliding_window: Optional[int]) -> int:
- if sliding_window is None or sliding_window <= 0:
- return 0
- return int(sliding_window)
-
-
-def _get_cached_block_mask(
- lengths: Tuple[int, ...],
- sliding_window: Optional[int],
-):
- if _XFormersBlockMask is None:
- return None
-
- window_key = _window_cache_key(sliding_window)
- cache_key = (lengths, window_key)
- cached = _XFORMERS_MASK_CACHE.get(cache_key)
- if cached is not None:
- _XFORMERS_MASK_CACHE.move_to_end(cache_key)
- return cached
-
- mask = _XFormersBlockMask.from_seqlens(list(lengths))
- if window_key and mask is not None and hasattr(mask, "make_local_attention"):
- mask = mask.make_local_attention(window_size = window_key)
-
- _XFORMERS_MASK_CACHE[cache_key] = mask
- if len(_XFORMERS_MASK_CACHE) > _XFORMERS_MASK_CACHE_MAXSIZE:
- _XFORMERS_MASK_CACHE.popitem(last = False)
- return mask
-
-
-class _TrlPackingWarningFilter(logging.Filter):
- to_filter = (
- "attention implementation is not",
- "kernels-community",
- )
-
- def filter(self, record: logging.LogRecord) -> bool:
- message = record.getMessage()
- return not any(substring in message for substring in self.to_filter)
-
-
-_TRL_FILTER_INSTALLED = False
-
-
-def _ensure_trl_warning_filter():
- global _TRL_FILTER_INSTALLED
- if _TRL_FILTER_INSTALLED:
- return
- logging.getLogger("trl.trainer.sft_trainer").addFilter(_TrlPackingWarningFilter())
- _TRL_FILTER_INSTALLED = True
-
-
-def mark_allow_overlength(module):
- """Mark a module hierarchy so padding-free batches can exceed max_seq_length."""
- if module is None:
- return
- if hasattr(module, "max_seq_length"):
- setattr(module, "_unsloth_allow_packed_overlength", True)
- children = getattr(module, "children", None)
- if children is None:
- return
- for child in children():
- mark_allow_overlength(child)
-
-
-def configure_sample_packing(config):
- """Mutate an ``SFTConfig`` so TRL prepares packed batches."""
- _ensure_trl_warning_filter()
- setattr(config, "packing", True)
- setattr(config, "padding_free", True)
- setattr(config, "remove_unused_columns", False)
-
-
-def configure_padding_free(config):
- """Mutate an ``SFTConfig`` so TRL enables padding-free batching without packing."""
- _ensure_trl_warning_filter()
- setattr(config, "padding_free", True)
- setattr(config, "remove_unused_columns", False)
-
-
-def enable_sample_packing(
- model,
- trainer,
- *,
- sequence_lengths_key: str = "seq_lengths",
-) -> None:
- """Enable runtime support for packed batches on an existing trainer."""
- if model is None or trainer is None:
- raise ValueError("model and trainer must not be None")
-
- mark_allow_overlength(model)
-
- if hasattr(trainer, "args") and hasattr(trainer.args, "remove_unused_columns"):
- trainer.args.remove_unused_columns = False
-
- collator = getattr(trainer, "data_collator", None)
- if collator is None or not hasattr(collator, "torch_call"):
- return
- if getattr(collator, "_unsloth_packing_wrapped", False):
- return
-
- if hasattr(collator, "padding_free"):
- collator.padding_free = True
- if hasattr(collator, "return_position_ids"):
- collator.return_position_ids = True
-
- original_torch_call = collator.torch_call
-
- def torch_call_with_lengths(examples: Sequence[dict]):
- batch = original_torch_call(examples)
- if examples and isinstance(examples[0], dict):
- seq_lengths: list[int] = []
- for example in examples:
- lengths = example.get(sequence_lengths_key)
- if isinstance(lengths, Iterable):
- seq_lengths.extend(int(length) for length in lengths)
- # Fallback: infer lengths from tokenized inputs when metadata is absent
- if not seq_lengths:
- for example in examples:
- ids = example.get("input_ids")
- if isinstance(ids, Iterable):
- seq_lengths.append(len(ids))
- if seq_lengths:
- batch["packed_seq_lengths"] = torch.tensor(
- seq_lengths, dtype = torch.int32
- )
- if "attention_mask" in batch:
- batch.pop("attention_mask")
- return batch
-
- collator.torch_call = torch_call_with_lengths
- collator._unsloth_packing_wrapped = True
-
-
-def enable_padding_free_metadata(model, trainer):
- """Inject seq-length metadata when padding-free batching is enabled without packing."""
- collator = getattr(trainer, "data_collator", None)
- if (
- collator is None
- or getattr(collator, "_unsloth_padding_free_lengths_wrapped", False)
- or not getattr(collator, "padding_free", False)
- ):
- return
-
- mark_allow_overlength(model)
- if hasattr(collator, "return_position_ids"):
- collator.return_position_ids = True
- if hasattr(trainer, "args") and hasattr(trainer.args, "remove_unused_columns"):
- trainer.args.remove_unused_columns = False
-
- original_torch_call = collator.torch_call
-
- def torch_call_with_padding_free_metadata(examples: Sequence[dict]):
- seq_lengths: list[int] = []
- if examples and isinstance(examples[0], dict):
- for example in examples:
- lengths = example.get("seq_lengths")
- if lengths is None:
- ids = example.get("input_ids")
- if ids is None:
- continue
- lengths = [len(ids)]
- example["seq_lengths"] = lengths
- seq_lengths.extend(lengths)
-
- batch = original_torch_call(examples)
- if seq_lengths:
- batch["packed_seq_lengths"] = torch.tensor(
- seq_lengths,
- dtype = torch.int32,
- )
- return batch
-
- collator.torch_call = torch_call_with_padding_free_metadata
- collator._unsloth_padding_free_lengths_wrapped = True
-
-
-def get_packed_info_from_kwargs(
- kwargs: dict,
- device: torch.device,
-) -> Optional[Tuple[torch.Tensor, torch.Tensor, int]]:
- """Return packed sequence metadata expected by the attention kernels."""
-
- seq_lengths = kwargs.get("packed_seq_lengths")
- if seq_lengths is None:
- return None
-
- lengths = seq_lengths.to(device = device, dtype = torch.int32, non_blocking = True)
- cu_seqlens = torch.empty(lengths.numel() + 1, dtype = torch.int32, device = device)
- cu_seqlens[0] = 0
- torch.cumsum(lengths, dim = 0, dtype = torch.int32, out = cu_seqlens[1:])
-
- max_seqlen = int(lengths.max().item())
- return lengths, cu_seqlens, max_seqlen
-
-
-def build_xformers_block_causal_mask(
- seq_info: Optional[Tuple[torch.Tensor, torch.Tensor, int]],
- *,
- sliding_window: Optional[int] = None,
- base_mask: Optional[Any] = None,
-):
- if _XFormersBlockMask is None:
- return None
- if seq_info is not None:
- seq_lengths, _, _ = seq_info
- lengths_tensor = seq_lengths.to("cpu", torch.int32)
- if lengths_tensor.numel() == 0:
- return None
- lengths = tuple(int(x) for x in lengths_tensor.tolist())
- mask = _get_cached_block_mask(lengths, sliding_window)
- else:
- mask = base_mask
-
- if (
- sliding_window is not None
- and sliding_window > 0
- and mask is not None
- and hasattr(mask, "make_local_attention")
- ):
- mask = mask.make_local_attention(window_size = sliding_window)
- return mask
-
-
-def build_sdpa_packed_attention_mask(
- seq_info: Tuple[torch.Tensor, torch.Tensor, int],
- *,
- dtype: torch.dtype,
- device: torch.device,
- sliding_window: Optional[int] = None,
-) -> torch.Tensor:
- seq_lengths, _, _ = seq_info
- total_tokens = int(seq_lengths.sum().item())
- mask = torch.full(
- (total_tokens, total_tokens),
- float("-inf"),
- dtype = dtype,
- device = device,
- )
- offset = 0
- for length in seq_lengths.tolist():
- length = int(length)
- if length <= 0:
- continue
- block = torch.zeros((length, length), dtype = dtype, device = device)
- upper = torch.triu(
- torch.ones((length, length), device = device), diagonal = 1
- ).bool()
- block = block.masked_fill(upper, float("-inf"))
- if (
- sliding_window is not None
- and sliding_window > 0
- and length > sliding_window
- ):
- idx = torch.arange(length, device = device)
- dist = idx.unsqueeze(1) - idx.unsqueeze(0)
- window_mask = dist >= sliding_window
- block = block.masked_fill(window_mask, float("-inf"))
- mask[offset : offset + length, offset : offset + length] = block
- offset += length
- return mask.unsqueeze(0).unsqueeze(0)
-
-
-def _normalize_packed_lengths(
- seq_lengths: Any,
- *,
- device: torch.device,
-) -> Optional[torch.Tensor]:
- if seq_lengths is None:
- return None
- if isinstance(seq_lengths, torch.Tensor):
- lengths = seq_lengths.to(device = device, dtype = torch.int64)
- else:
- lengths = torch.tensor(seq_lengths, device = device, dtype = torch.int64)
- if lengths.ndim != 1:
- lengths = lengths.reshape(-1)
- if lengths.numel() == 0:
- return None
- return lengths
-
-
-def mask_packed_sequence_boundaries(
- shift_labels: torch.Tensor,
- seq_lengths: Any,
- *,
- ignore_index: int = -100,
-) -> bool:
- """Mark final token of every packed sample so CE ignores boundary predictions."""
- lengths = _normalize_packed_lengths(seq_lengths, device = shift_labels.device)
- if lengths is None:
- return False
-
- flat = shift_labels.reshape(-1)
- total_tokens = flat.shape[0]
- boundary_positions = torch.cumsum(lengths, dim = 0) - 1
- valid = boundary_positions < total_tokens
- if not torch.all(valid):
- boundary_positions = boundary_positions[valid]
- if boundary_positions.numel() == 0:
- return False
- flat[boundary_positions] = ignore_index
- return True
-
-
-__all__ = [
- "configure_sample_packing",
- "configure_padding_free",
- "enable_sample_packing",
- "enable_padding_free_metadata",
- "mark_allow_overlength",
- "get_packed_info_from_kwargs",
- "build_xformers_block_causal_mask",
- "build_sdpa_packed_attention_mask",
- "mask_packed_sequence_boundaries",
-]