14 changes: 14 additions & 0 deletions README.md
@@ -2,6 +2,12 @@

torchchat is a small codebase showcasing the ability to run large language models (LLMs) seamlessly. With torchchat, you can run LLMs using Python, within your own (C/C++) application (desktop or server) and on iOS and Android.

> [!IMPORTANT]
> Update September 25, 2024: torchchat has multimodal support for **Llama3.2 11B**!
>
> To try it out, finish the [Installation](#Installation) section below, then hop
> over to our [multimodal guide](docs/multimodal.md) to learn more.


## What can you do with torchchat?
- [Run models via PyTorch / Python](#running-via-pytorch--python)
@@ -18,6 +24,7 @@ torchchat is a small codebase showcasing the ability to run large language model


## Highlights

- Command line interaction with popular LLMs such as Llama 3, Llama 2, Stories, Mistral and more
- PyTorch-native execution with performance
- Supports popular hardware and OS
@@ -514,6 +521,13 @@ aliases.

| Model | Mobile Friendly | Notes |
|------------------|---|---------------------|
|[meta-llama/Meta-Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct)|✅|Tuned for `chat`. Alias to `llama3.2-3b`.|
|[meta-llama/Meta-Llama-3.2-3B](https://huggingface.co/meta-llama/Llama-3.2-3B)|✅|Best for `generate`. Alias to `llama3.2-3b-base`.|
|[meta-llama/Llama-Guard-3-1B](https://huggingface.co/meta-llama/Llama-Guard-3-1B)|✅|Tuned for classification. Alias to `llama3-1b-guard`.|
|[meta-llama/Meta-Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct)|✅|Tuned for `chat`. Alias to `llama3.2-1b`.|
|[meta-llama/Meta-Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B)|✅|Best for `generate`. Alias to `llama3.2-1b-base`.|
|[meta-llama/Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)||Multimodal (Image + Text). Tuned for `chat`. Alias to `llama3.2-11B`.|
|[meta-llama/Llama-3.2-11B-Vision](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision)||Multimodal (Image + Text). Tuned for `generate`. Alias to `llama3.2-11B-base`.|
|[meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)|✅|Tuned for `chat`. Alias to `llama3.1`.|
|[meta-llama/Meta-Llama-3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B)|✅|Best for `generate`. Alias to `llama3.1-base`.|
|[meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)|✅|Tuned for `chat`. Alias to `llama3`.|
Binary file added assets/dog.jpg
73 changes: 73 additions & 0 deletions docs/multimodal.md
@@ -0,0 +1,73 @@
# Multimodal Models

Released on September 25th, 2024, **Llama3.2 11B Vision** is torchchat's first multimodal model.

This page goes over the different commands you can run with Llama 3.2 11B Vision.

## Model Access

> [!NOTE]
> While the commands refer to the model as some variant of "Llama 3.2 11B Vision",
> the underlying checkpoint used is based on the "Instruct" variant of the model.

**Llama3.2 11B Vision** is available via both [Hugging Face](https://huggingface.co/meta-llama) and [directly from Meta](https://www.llama.com/).
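
If you don't have the weights locally yet, one way to fetch the Hugging Face checkpoint is torchchat's `download` command with the `llama3.2-11B` alias. This is a sketch; it assumes you have already accepted the model license on Hugging Face and authenticated with `huggingface-cli login`:

```bash
python3 torchchat.py download llama3.2-11B
```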

We strongly encourage you to use the Hugging Face checkpoint, which is the default for torchchat when the commands below are given the `llama3.2-11B` argument. If you prefer to supply the checkpoint manually, replace the `llama3.2-11B` argument in the commands below with the following:

```
--checkpoint-path <file.pth> --tokenizer-path <tokenizer.model> --params-path torchchat/model_params/Llama-3.2-11B-Vision.json
```
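
For example, a `generate` run against a manually downloaded checkpoint would then look like the sketch below, where the checkpoint and tokenizer paths are placeholders for wherever you stored the Meta download:

```bash
python3 torchchat.py generate \
  --checkpoint-path /path/to/consolidated.pth \
  --tokenizer-path /path/to/tokenizer.model \
  --params-path torchchat/model_params/Llama-3.2-11B-Vision.json \
  --prompt "What's in this image?" \
  --image-prompt assets/dog.jpg
```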

## Generation

**We are currently debugging multimodal inference on MPS and will have updates soon. In the meantime, when testing on a Mac, please set `--device cpu`.**

This generates text output based on a text prompt and (optional) image prompt.

```
python torchchat.py generate llama3.2-11B --prompt "What's in this image?" --image-prompt assets/dog.jpg
```
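
On a Mac, per the MPS note above, the same command can be pinned to the CPU:

```bash
python torchchat.py generate llama3.2-11B --prompt "What's in this image?" --image-prompt assets/dog.jpg --device cpu
```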

## Server
This mode exposes a REST API for interacting with a model.
The server follows the [OpenAI API specification](https://platform.openai.com/docs/api-reference/chat) for chat completions.

To test out the REST API, **you'll need 2 terminals**: one to host the server, and one to send the request.
In one terminal, start the server:

[skip default]: begin

```bash
python3 torchchat.py server llama3.2-11B
```
[skip default]: end

In another terminal, query the server using `curl`. This query might take a few minutes to respond.

**We are currently debugging the server integration and will have updated examples shortly.**
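
Until updated examples land, here is a sketch of what a request could look like. It is adapted from the text-only `curl` example in torchchat's README (which assumes the server's default port, 5000) and from how this change's `openai_api.py` reads base64 `image_url` content parts; the exact request shape the multimodal server accepts may differ while the integration is being debugged, and `<BASE64_ENCODED_IMAGE>` is a placeholder:

```bash
curl http://127.0.0.1:5000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "llama3.2-11B",
    "stream": "true",
    "max_tokens": 300,
    "messages": [
      {
        "role": "user",
        "content": [
          {"type": "text", "text": "What is in this image?"},
          {"type": "image_url", "image_url": "data:image/jpeg;base64,<BASE64_ENCODED_IMAGE>"}
        ]
      }
    ]
  }'
```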

## Browser

This command opens a basic browser interface for local chat by querying a local server.

First, follow the steps in the Server section above to start a local server. Then, in another terminal, run the following to launch the interface; it will open a tab in your browser.

[skip default]: begin

```
streamlit run torchchat/usages/browser.py
```

[skip default]: end

**We are currently debugging the browser integration and will have updated examples shortly.**

---

# Future Work

One of the goals of torchchat is to support various execution modes for every model. The following are execution modes that will be supported for **Llama3.2 11B Vision** in the near future:

- **[torch.compile](https://pytorch.org/docs/stable/torch.compiler.html)**: Optimize inference via JIT Compilation
- **[AOTI](https://pytorch.org/blog/pytorch2-2/)**: Enable pre-compiled and C++ inference
- **[ExecuTorch](https://github.com/pytorch/executorch)**: On-device (Edge) inference

In addition, we are in the process of integrating with [lm_evaluation_harness](https://github.com/EleutherAI/lm-evaluation-harness) for multimodal model evaluation.
7 changes: 2 additions & 5 deletions torchchat/cli/builder.py
@@ -16,10 +16,7 @@
import torch._inductor.config
import torch.nn as nn

try:
    from _torchchat_test_script import flamingo_meta_to_tune
except ImportError:
    pass
from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune

from distributed import launch_distributed, ParallelDims, parallelize_llama

@@ -404,7 +401,7 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
        for submodule in model.modules():
            if isinstance(submodule, Llama3ScaledRoPE):
                submodule.__init__(head_dim, max_seq_len, rope_base)
        state_dict = flamingo_meta_to_tune(checkpoint)
        state_dict = llama3_vision_meta_to_tune(checkpoint)
        model.model.load_state_dict(state_dict, assign=True, strict=False)
    else:
        checkpoint = {"model." + k: v for k, v in checkpoint.items()}
22 changes: 22 additions & 0 deletions torchchat/cli/convert_hf_checkpoint.py
@@ -34,6 +34,8 @@ def convert_hf_checkpoint(
    if model_name is None:
        model_name = model_dir.name

    # TODO: This is an incongruent way of resolving config_args
    # See https://github.com/pytorch/torchchat/issues/1179
    config_args = ModelArgs.from_name(model_name).transformer_args['text']
    config = TransformerArgs.from_params(config_args)
    print(f"Model config {config.__dict__}")
@@ -132,6 +134,26 @@ def permute(w, n_heads):
os.remove(file)


@torch.inference_mode()
def convert_hf_checkpoint_to_tune(
    *,
    model_dir: Optional[Path] = None,
    model_name: str,
) -> None:
    assert model_dir is not None

    consolidated_pth = model_dir / "original" / "consolidated.pth"
    tokenizer_pth = model_dir / "original" / "tokenizer.model"
    if consolidated_pth.is_file() and tokenizer_pth.is_file():
        print(f"Moving checkpoint to {model_dir / 'model.pth'}.")
        os.rename(consolidated_pth, model_dir / "model.pth")
        print(f"Moving tokenizer to {model_dir / 'tokenizer.model'}.")
        os.rename(tokenizer_pth, model_dir / "tokenizer.model")
        print("Done.")
    else:
        raise RuntimeError(f"Could not find {consolidated_pth}")


if __name__ == "__main__":
    import argparse

18 changes: 12 additions & 6 deletions torchchat/cli/download.py
@@ -10,7 +10,7 @@
from pathlib import Path
from typing import Optional

from torchchat.cli.convert_hf_checkpoint import convert_hf_checkpoint
from torchchat.cli.convert_hf_checkpoint import convert_hf_checkpoint, convert_hf_checkpoint_to_tune
from torchchat.model_config.model_config import (
    load_model_configs,
    ModelConfig,
@@ -50,11 +50,17 @@ def _download_hf_snapshot(
        else:
            raise e

    # Convert the model to the torchchat format.
    print(f"Converting {model_config.name} to torchchat format...", file=sys.stderr)
    convert_hf_checkpoint(
        model_dir=artifact_dir, model_name=model_config.name, remove_bin_files=True
    )
    # Convert the Multimodal Llama model to the torchtune format.
    if model_config.name in {"meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-11B-Vision"}:
        print(f"Converting {model_config.name} to torchtune format...", file=sys.stderr)
        convert_hf_checkpoint_to_tune(model_dir=artifact_dir, model_name=model_config.name)

    else:
        # Convert the model to the torchchat format.
        print(f"Converting {model_config.name} to torchchat format...", file=sys.stderr)
        convert_hf_checkpoint(
            model_dir=artifact_dir, model_name=model_config.name, remove_bin_files=True
        )


def _download_direct(
7 changes: 2 additions & 5 deletions torchchat/generate.py
@@ -20,10 +20,7 @@
import torch._dynamo.config
import torch._inductor.config

try:
    from _torchchat_test_script import flamingo_transform
except ImportError:
    pass
from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform

from PIL import Image

@@ -753,7 +750,7 @@ def chat(
                Message(role="assistant", content=""),
            ]

            transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))
            transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))

            with torch.device(device=self.builder_args.device), set_default_dtype(self.dtype):
                data = transform({"messages": messages}, inference=True)
38 changes: 38 additions & 0 deletions torchchat/model_config/models.json
@@ -69,6 +69,44 @@
"distribution_path": "meta-llama/Meta-Llama-3.1-70B-Instruct",
"transformer_params_key": "Meta-Llama-3.1-70B-Tune"
},
"meta-llama/Meta-Llama-3.2-1B": {
"aliases": ["llama3.2-1b-base"],
"distribution_channel": "HuggingFaceSnapshot",
"distribution_path": "meta-llama/Llama-3.2-1B"
},
"meta-llama/Meta-Llama-3.2-1B-Instruct": {
"aliases": ["llama3.2-1b", "llama3.2-1b-chat", "llama3.2-1b-instruct"],
"distribution_channel": "HuggingFaceSnapshot",
"distribution_path": "meta-llama/Llama-3.2-1B-Instruct",
"transformer_params_key": "Meta-Llama-3.2-1B"
},
"meta-llama/Llama-Guard-3-1B": {
"aliases": ["llama3-1b-guard", "llama3.2-1b-guard"],
"distribution_channel": "HuggingFaceSnapshot",
"distribution_path": "meta-llama/Llama-Guard-3-1B"
},
"meta-llama/Meta-Llama-3.2-3B": {
"aliases": ["llama3.2-3b-base"],
"distribution_channel": "HuggingFaceSnapshot",
"distribution_path": "meta-llama/Llama-3.2-3B"
},
"meta-llama/Meta-Llama-3.2-3B-Instruct": {
"aliases": ["llama3.2-3b", "llama3.2-3b-chat", "llama3.2-3b-instruct"],
"distribution_channel": "HuggingFaceSnapshot",
"distribution_path": "meta-llama/Llama-3.2-3B-Instruct",
"transformer_params_key": "Meta-Llama-3.2-3B"
},
"meta-llama/Llama-3.2-11B-Vision": {
"aliases": ["llama3.2-11B-base", "Llama-3.2-11B-Vision-base"],
"distribution_channel": "HuggingFaceSnapshot",
"distribution_path": "meta-llama/Llama-3.2-11B-Vision"
},
"meta-llama/Llama-3.2-11B-Vision-Instruct": {
"aliases": ["llama3.2-11B", "Llama-3.2-11B-Vision", "Llama-3.2-mm"],
"distribution_channel": "HuggingFaceSnapshot",
"distribution_path": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"transformer_params_key": "Llama-3.2-11B-Vision"
},
"meta-llama/CodeLlama-7b-Python-hf": {
"aliases": ["codellama", "codellama-7b"],
"distribution_channel": "HuggingFaceSnapshot",
29 changes: 29 additions & 0 deletions torchchat/model_params/Llama-3.2-11B-Vision.json
@@ -0,0 +1,29 @@
{
    "model_type": "flamingo",
    "use_tiktoken": true,
    "encoder": {
        "patch_size": 14,
        "num_heads": 16,
        "clip_embed_dim": 1280,
        "clip_num_layers": 32,
        "clip_hidden_states": [3, 7, 15, 23, 30],
        "decoder_embed_dim": 4096,
        "num_layers_projection": 8,
        "tile_size": 560,
        "max_num_tiles": 4,
        "in_channels": 3
    },
    "decoder": {
        "vocab_size": 128256,
        "num_layers": 32,
        "fusion_interval": 4,
        "num_special_tokens": 8,
        "num_heads": 32,
        "num_kv_heads": 8,
        "embed_dim": 4096,
        "max_seq_len": 131072,
        "encoder_max_seq_len": 128080,
        "rope_base": 500000.0,
        "intermediate_dim": 14336
    }
}
20 changes: 20 additions & 0 deletions torchchat/model_params/Llama-Guard-3-1B-INT4.json
@@ -0,0 +1,20 @@
{
    "block_size": 131072,
    "dim": 2048,
    "hidden_dim": 6400,
    "n_layers": 12,
    "n_heads": 32,
    "n_kv_heads": 8,
    "vocab_size": 128256,
    "ffn_dim_multiplier": 1.5,
    "multiple_of": 256,
    "norm_eps": 1e-05,
    "rope_theta": 500000.0,
    "rope_scaling": {
        "factor": 32.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192
    },
    "use_tiktoken": true
}
19 changes: 19 additions & 0 deletions torchchat/model_params/Llama-Guard-3-1B.json
@@ -0,0 +1,19 @@
{
    "block_size": 131072,
    "dim": 2048,
    "n_layers": 16,
    "n_heads": 32,
    "n_kv_heads": 8,
    "vocab_size": 128256,
    "ffn_dim_multiplier": 1.5,
    "multiple_of": 256,
    "norm_eps": 1e-05,
    "rope_theta": 500000.0,
    "rope_scaling": {
        "factor": 32.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192
    },
    "use_tiktoken": true
}
19 changes: 19 additions & 0 deletions torchchat/model_params/Meta-Llama-3.2-1B.json
@@ -0,0 +1,19 @@
{
    "block_size": 131072,
    "dim": 2048,
    "n_layers": 16,
    "n_heads": 32,
    "n_kv_heads": 8,
    "vocab_size": 128256,
    "ffn_dim_multiplier": 1.5,
    "multiple_of": 256,
    "norm_eps": 1e-05,
    "rope_theta": 500000.0,
    "rope_scaling": {
        "factor": 32.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192
    },
    "use_tiktoken": true
}
19 changes: 19 additions & 0 deletions torchchat/model_params/Meta-Llama-3.2-3B.json
@@ -0,0 +1,19 @@
{
    "block_size": 131072,
    "dim": 3072,
    "n_layers": 28,
    "n_heads": 24,
    "n_kv_heads": 8,
    "vocab_size": 128256,
    "ffn_dim_multiplier": 1.0,
    "multiple_of": 256,
    "norm_eps": 1e-05,
    "rope_theta": 500000.0,
    "rope_scaling": {
        "factor": 32.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192
    },
    "use_tiktoken": true
}
8 changes: 3 additions & 5 deletions torchchat/usages/openai_api.py
@@ -17,10 +17,8 @@

import torch

try:
    from _torchchat_test_script import flamingo_transform, padded_collate
except ImportError:
    pass
from torchtune.models.llama3_2_vision._convert_weights import padded_collate
from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform

from PIL import Image
from torchtune.data import Message
@@ -376,7 +374,7 @@ def chunked_completion(self, completion_request: CompletionRequest):
        images.append(Image.open(BytesIO(base64_decoded)))
        print("images:", len(images), flush=True)
        if len(images) > 0:
            transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))
            transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
            torchtune_messages = self._openai_messages_to_torchtune(
                completion_request.messages
            )