In [None]:
import json
import glob
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn

In [None]:
# Local snapshot of the 4-bit MLX Llama-3-8B-Instruct checkpoint in the Hugging Face hub cache.
# NOTE(review): the snapshot hash is pinned, so this path only exists where exactly this
# revision was downloaded.
model_path = Path("~/.cache/huggingface/hub/models--mlx-community--Meta-Llama-3-8B-Instruct-4bit/snapshots/c38b3b1f03cce0ce0ccd235e5c97b0d3d255e651").expanduser()

In [None]:
# Read the model config and show its "quantization" section.
with open(model_path/"config.json", "r") as f:
  config = json.load(f)

print(config["quantization"])

In [None]:
# Every safetensors shard in the snapshot.
weight_files = glob.glob(str(model_path/"model*.safetensors"))

In [None]:
weight_files

In [None]:
# Load only the first shard. NOTE(review): glob order is arbitrary and any other
# shards are not inspected below.
temp_weight = mx.load(weight_files[0])

In [None]:
temp_weight.keys()

In [None]:
# Inspect the three tensors stored for layer 0's q_proj: weight, scales and biases.
temp_weight["model.layers.0.self_attn.q_proj.weight"]

In [None]:
temp_weight["model.layers.0.self_attn.q_proj.scales"]

In [None]:
temp_weight["model.layers.0.self_attn.q_proj.biases"]

In [None]:
from exo.inference.mlx.models.llama import LlamaModel, ModelArgs

# Instantiate the exo LlamaModel from the checkpoint config (no weights are loaded here).
args = ModelArgs.from_dict(config)
model = LlamaModel(args)


In [None]:
# The model's leaf modules, walked below to decide which ones to quantize.
leaves = model.leaf_modules()

In [None]:
from mlx.utils import tree_map_with_path
from mlx.nn.layers.base import Module

In [None]:
def class_predicate(p, m, weights=None):
    """Return True if module ``m`` at path ``p`` should be swapped for its quantized form.

    A module qualifies when it can be quantized (exposes ``to_quantized``) and the
    checkpoint contains a ``model.<p>.scales`` tensor for it.

    Args:
        p: Dotted module path relative to the model root (e.g. ``layers.0.self_attn.q_proj``).
        m: The module found at that path.
        weights: Mapping of checkpoint tensor names. Defaults to the notebook-level
            ``temp_weight`` (the first safetensors shard loaded above).

    Returns:
        bool
    """
    if weights is None:
        weights = temp_weight
    return hasattr(m, "to_quantized") and f"model.{p}.scales" in weights


def _maybe_quantize(path, m, group_size=64, bits=4):
    """``tree_map_with_path`` callback: quantize ``m`` if the checkpoint has scales for it.

    ``class_predicate`` already guarantees ``m`` has ``to_quantized``, so no separate
    type check (and no unreachable error branch) is needed here.

    Args:
        path: Dotted module path supplied by ``tree_map_with_path``.
        m: The leaf module at ``path``.
        group_size: Quantization group size; presumably should match ``config["quantization"]``.
        bits: Bits per quantized weight; presumably should match ``config["quantization"]``.

    Returns:
        The quantized module, or ``m`` unchanged.
    """
    if class_predicate(path, m):
        return m.to_quantized(group_size, bits)
    return m

In [None]:
# Replace each quantizable leaf that has checkpoint scales with its quantized version.
q_leaves = tree_map_with_path(_maybe_quantize, leaves, is_leaf=Module.is_module)

In [None]:
# Same projection before ...
leaves["layers"][0]["self_attn"]["o_proj"]

In [None]:
# ... and after quantization.
q_leaves["layers"][0]["self_attn"]["o_proj"]

In [None]:
ll = nn.Linear(10, 10)

In [None]:
# Does a plain nn.Linear expose to_quantized (the first test in class_predicate)?
hasattr(ll, "to_quantized")

In [None]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-2-7b-bnb-4bit")
model = AutoModelForCausalLM.from_pretrained("unsloth/llama-2-7b-bnb-4bit")

In [None]:
!pip install bitsandbytes

In [None]:
# Collect the distinct per-layer module paths that carry quantization scales.
# For keys like "model.layers.0.self_attn.q_proj.scales", [3:-1] drops the
# "model.layers.<i>" prefix and the ".scales" suffix; keys that do not follow
# that layout yield an empty path and are printed for inspection.
ans = set()
for key in temp_weight.keys():
    if "scales" not in key:
        continue
    module_path = ".".join(key.split(".")[3:-1])
    if not module_path:
        print(key)
    ans.add(module_path)

ans

In [None]:
# Conv2d(in=3, out=1024, kernel_size=16, stride=16): non-overlapping 16x16 patch embedding.
patch = nn.Conv2d(3, 1024, 16, 16)

In [None]:
from transformers import LlavaForConditionalGeneration, AutoProcessor

In [None]:
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

In [None]:
from PIL import Image    
import requests

In [None]:
# Sample images, passed to the processor as URLs.
IMG_URLS = [
    "https://picsum.photos/id/237/400/300",
    "https://picsum.photos/id/231/200/300",
    "https://picsum.photos/id/27/500/500",
    "https://picsum.photos/id/17/150/600",
]
# NOTE(review): "[INST]...[IMG]" is the Pixtral/Mistral prompt format, while this is the
# llava-1.5 processor, which uses "<image>" placeholders — confirm this is intended.
PROMPT = "<s>[INST]Describe the images.\n[IMG][IMG][IMG][IMG][/INST]"

inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt")

In [None]:
len(inputs["pixel_values"][0])

In [None]:
ll = mx.ones((20,))

In [None]:
# Elementwise check: mx.multiply and the * operator should agree.
mx.multiply(ll, ll) == ll * ll

In [None]:
mx.multiply(ll, ll).shape

In [None]:
lp = ll * ll

In [None]:
# Scalar broadcast.
ll * 2

In [None]:
freqs = 1.0 / (1000000000.0 ** (mx.arange(0, 1024, 2) / 1024))

In [None]:
h = mx.arange(1024//16)

In [None]:
freqs_h = mx.outer(h, freqs[::2])

In [None]:
freqs_h.repeat(1, 1024//16, 1)

In [None]:
mx.tile(freqs_h[: None, :], (1, 1024//16, 1)).shape

In [None]:
freqs_h.shape

In [None]:
import torch

In [None]:
ll = torch.ones((1, 1024))
# torch tensors have no .tobytes(); go through NumPy (float32 here, so .numpy() works).
ll_bytes = ll.numpy().tobytes()
ll_bytes

In [None]:
import torch
from transformers import LlavaForConditionalGeneration, AutoProcessor

# HF reference Pixtral-12B (LLaVA-style wrapper used below via .vision_tower / .language_model), in bf16.
model_id = "hf-internal-testing/pixtral-12b"
hf_model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16)

In [None]:
model.language_model.save_pretrained("./../mistral_weights/",  save_peft_format=False)

In [None]:
model.vision_tower.save_pretrained("./../pixtral_weights/",  save_peft_format=False)

In [None]:
sum(p.numel() for p in model.parameters())

In [None]:
sum(p.numel() for p in model.language_model.parameters())

In [None]:
processor = AutoProcessor.from_pretrained(model_id)

IMG_URLS = [
    "https://picsum.photos/id/237/400/300",
    "https://picsum.photos/id/231/200/300",
    "https://picsum.photos/id/27/500/500",
    "https://picsum.photos/id/17/150/600",
]
PROMPT = "<s>[INST]Describe the images in one sentence.\n[IMG][IMG][IMG][IMG][/INST]"

inputs = processor(images=IMG_URLS, text=PROMPT, return_tensors="pt")

In [None]:
generate_ids = model.language_model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=100)

In [None]:
ouptut = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

In [None]:
ouptut

In [None]:
processor.tokenizer.save_pretrained("./../mistral_weights/")

In [None]:
model.language_model.lm_head.weight

In [None]:
inputs["input_ids"]

In [None]:
out = model.language_model.forward(inputs["input_ids"])

In [None]:
model.language_model.model.layers[10].mlp.gate_proj.weight

In [None]:
model.language_model.model.embed_tokens.weight

In [None]:
out.logits.shape

In [None]:
out.logits.dtype

In [None]:
out.logits

In [None]:
model.language_model.model.embed_tokens(inputs["input_ids"])

In [None]:
model.vision_tower.forward(inputs["pixel_values"][0])

In [None]:
import mlx.core as mx
# from exo.inference.mlx.models.pixtral import PixtralModel, PixtralVisionConfig
# Load the exported vision-tower weights into a flat name -> mx.array dict.
weights = {}
weights.update(mx.load("./../pixtral_weights/model.safetensors"))

# import json
# with open("./../pixtral_weights/config.json", "r") as f:
#     config = json.load(f)
# vision_config = PixtralVisionConfig.from_dict(config)
# model = PixtralModel(vision_config)
# sanitized_weights = model.sanitize(weights)
# model.load_weights(list(sanitized_weights.items()), strict=True)

In [None]:
# Global max/min over every tensor in the checkpoint (presumably to check that values
# fit in float16; see the cast below). Results end up as 0-d mx arrays.
_max = -10000000
_min = 10000000
for key in weights:
    _max = max(_max, weights[key].max())
    _min = min(_min, weights[key].min())

In [None]:
_max, _min

In [None]:
weights["transformer.layers.9.ffn_norm.weight"].astype(mx.float16).min()

In [None]:
weights["transformer.layers.9.ffn_norm.weight"].min()

In [None]:
# HF vision tower on the first sample's images.
hf_model.vision_tower.forward(inputs["pixel_values"][0])

In [None]:
# Same images as a list of MLX arrays, one per image.
pixel_vals = [mx.array(x) for x in inputs["pixel_values"][0]]

In [None]:
# NOTE(review): expects `model` to be the MLX PixtralModel, whose construction is
# commented out in the weight-loading cell above — depends on kernel state.
model(pixel_vals)

In [None]:
# Shape of each patch_conv weight after moving axis 1 last (transpose(0, 2, 3, 1)).
for key, value in weights.items():
    if "patch_conv" in key:
        print(weights[key].transpose(0, 2, 3, 1).shape)

In [None]:
hf_model.vision_tower.patch_conv

In [None]:
model.patch_conv

In [None]:

model.patch_conv.weight

In [None]:
hf_model.vision_tower.patch_conv.weight

In [None]:
# HF reference patch embedding: conv each image separately (cast to bf16 to match the
# model), flatten each conv output's spatial dims, move channels last, then concatenate
# all images along the sequence axis and apply ln_pre.
patch_embeds_list = [hf_model.vision_tower.patch_conv(img.unsqueeze(0).to(torch.bfloat16)) for img in inputs["pixel_values"][0]]
patch_embeds = torch.cat([p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
patch_embeds = hf_model.vision_tower.ln_pre(patch_embeds)

# hf_generate_block_attention_mask([p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds)
print(patch_embeds_list[0].shape)

In [None]:
import numpy as np

# MLX counterpart: MLX convolutions take channels-last input, hence transpose(0, 2, 3, 1);
# flatten(1, 2) merges the two spatial dims of each (1, H, W, C) output.
patch_embeds_list = [model.patch_conv(mx.expand_dims(mx.array(img), axis=0).transpose(0, 2, 3, 1)) for img in inputs["pixel_values"][0]]

patch_embeds = mx.concatenate([p.flatten(1, 2) for p in patch_embeds_list], axis=1)
patch_embeds = model.ln_pre(patch_embeds)

# print(patch_embeds_list[0].shape)

# NOTE(review): mlx_generate_block_attention_mask is defined in a later cell, so this
# call fails on a fresh Restart & Run All.
mlx_generate_block_attention_mask(
            [p.shape[1] * p.shape[2] for p in patch_embeds_list], patch_embeds
        )

In [None]:
def hf_generate_block_attention_mask(patch_embeds_list, tensor):
    """Reference (HF-style) block-diagonal additive attention mask in torch.

    Tokens may attend only to tokens of the same image: entries inside each image's
    block are 0, all others are the most negative finite value of ``tensor.dtype``.

    Args:
        patch_embeds_list: List with the number of patches per image, in sequence order.
        tensor: Patch embeddings whose dim 0 is batch and dim 1 is seq_len; also
            supplies the dtype and device of the mask.

    Returns:
        Tensor of shape (batch, 1, seq_len, seq_len).
    """
    dtype = tensor.dtype
    device = tensor.device
    seq_len = tensor.shape[1]
    d_min = torch.finfo(dtype).min
    causal_mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device)

    # Inclusive / exclusive prefix sums give each image's [start, end) span.
    block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1)
    block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1)
    for start, end in zip(block_start_idx, block_end_idx):
        causal_mask[start:end, start:end] = 0

    causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1)
    return causal_mask

def mlx_generate_block_attention_mask(patch_embeds_list, input_array):
    """MLX port of the block-diagonal additive attention mask.

    Tokens may attend only to tokens of the same image: entries inside each image's
    block are 0, all others are the most negative finite value of ``input_array``'s dtype.

    Args:
        patch_embeds_list: List with the number of patches per image, in sequence order.
        input_array: mx.array whose dim 0 is batch and dim 1 is seq_len; its dtype
            is used for the mask.

    Returns:
        mx.array of shape (batch, 1, seq_len, seq_len).
    """
    from itertools import accumulate

    seq_len = input_array.shape[1]
    # Read the limit from the dtype directly. Converting the whole array to NumPy
    # copies it and fails for bfloat16, which NumPy cannot represent.
    dmin = mx.finfo(input_array.dtype).min
    causal_mask = mx.full((seq_len, seq_len), dmin, dtype=input_array.dtype)

    # Inclusive / exclusive prefix sums give each image's [start, end) span.
    block_end_idx = list(accumulate(patch_embeds_list))
    block_start_idx = [0] + block_end_idx[:-1]

    for start, end in zip(block_start_idx, block_end_idx):
        causal_mask[start:end, start:end] = 0

    return mx.broadcast_to(causal_mask[None, None, :, :], (input_array.shape[0], 1, seq_len, seq_len))

In [None]:
mlx_conv = model.patch_conv(mx.expand_dims(mx.array(inputs["pixel_values"][0][0]), axis=0).transpose(0, 2, 3, 1))

In [None]:
torch_conv
patch_embeds = torch.cat([p.flatten(2).permute(0, 2, 1) for p in [torch_conv]], dim=1)

In [None]:
patch_embeds.shape

In [None]:
torch_conv.shape

In [None]:
mx.array(torch_conv.flatten(2).permute(0, 2, 1).to(torch.float32).detach().numpy()) == mlx_conv.flatten(1, 2)

In [None]:
mlx_conv.flatten(1, 2).shape

In [None]:
torch_result = mx.array(torch_conv.flatten(2).permute(0, 2, 1).detach().numpy())
mlx_result = mlx_conv.flatten(1, 2)

are_close = mx.allclose(torch_result, mlx_result, atol=1e-4, rtol=1e-4)


In [None]:
mlx_result

In [None]:
torch_result

In [None]:
# Elementwise absolute difference between the two patch embeddings.
diff = mx.abs(torch_result - mlx_result)

In [None]:
torch_result.shape

In [None]:
def torch_position_ids_in_meshgrid(patch_embeds_list, max_width):
    """Flattened 2-D patch position ids for each patch grid, concatenated.

    For each tensor whose last two dims are the patch grid (rows, cols), the patch
    at (row, col) gets id ``row * max_width + col``; ids are listed row-major.

    Args:
        patch_embeds_list: Tensors whose last two dims are the patch grid.
        max_width: Row stride used to encode positions.

    Returns:
        1-D integer tensor with one id per patch across all inputs.
    """
    per_grid_ids = []
    for grid in patch_embeds_list:
        rows, cols = grid.shape[-2:]
        row_ids = torch.arange(rows).unsqueeze(1)  # (rows, 1)
        col_ids = torch.arange(cols).unsqueeze(0)  # (1, cols)
        per_grid_ids.append((row_ids * max_width + col_ids).reshape(-1))
    return torch.cat(per_grid_ids)

In [None]:
# One position id per patch of torch_conv's grid.
torch_position_ids_in_meshgrid([torch_conv], 1024//16).shape

In [None]:
torch_conv.shape

In [None]:
mlx_conv.shape

In [None]:
def mlx_position_ids_in_meshgrid(patch_embeds_list, max_width):
    """MLX counterpart of the position-id helper, for channels-last patch grids.

    Each array's dims 1 and 2 are the patch grid (rows, cols); the patch at
    (row, col) gets id ``row * max_width + col``, listed row-major, and all grids
    are concatenated.
    """
    def _grid_ids(grid):
        rows, cols = grid.shape[1:3]
        row_idx, col_idx = mx.meshgrid(mx.arange(rows), mx.arange(cols), indexing="ij")
        return row_idx.reshape(-1) * max_width + col_idx.reshape(-1)

    return mx.concatenate([_grid_ids(grid) for grid in patch_embeds_list])

In [None]:
mlx_position_ids = mlx_position_ids_in_meshgrid([mlx_conv], 1024//16)
mlx_position_ids.shape

In [None]:
# NOTE(review): kp, l and jp are not defined at this point on a fresh kernel (kp/jp are
# assigned further down, `l` is never assigned); these are leftover scratch cells.
kp.shape

In [None]:
l.shape

In [None]:
kp.shape

In [None]:
torch_position_ids = torch_position_ids_in_meshgrid([torch_conv], 1024//16)
torch_position_ids.shape

In [None]:
jp.shape

In [None]:
import numpy as np

In [None]:
(np.array(kp) == jp.detach().numpy()).all()

In [None]:
mx.expand_dims(mx.array(inputs["pixel_values"][0][0]), axis=0).shape

In [None]:
inputs["pixel_values"][0][0].unsqueeze(0).shape

In [None]:
patch_embeds = torch.cat([p.flatten(2).permute(0, 2, 1) for p in [torch_conv]], dim=1)

In [None]:
torch_result

In [None]:
mlx_result

In [None]:
# HF positional embedding for the patch positions (indexed [0] / [1] below).
hf_position_embedding = hf_model.vision_tower.patch_positional_embedding(patch_embeds, torch_position_ids)

In [None]:
hf_position_embedding[0].shape

In [None]:
hf_position_embedding[1].shape

In [None]:
# NOTE(review): "1" looks like a placeholder first argument — confirm what the MLX module expects.
mlx_position_embeddings = model.patch_positional_embedding("1", mlx_position_ids)

In [None]:
mlx_position_embeddings[0].shape

In [None]:
hf_position_embedding[0].shape

In [None]:
mlx_position_embeddings[0] == hf_position_embedding[0].detach().numpy()

In [None]:
mlx_position_embeddings

In [None]:
def hf_generate_block_attention_mask(patch_embeds_list, tensor):
    """Reference (HF-style) block-diagonal additive attention mask in torch.

    Each image's patches may attend to each other (value 0); every other pair gets
    the most negative finite value of ``tensor.dtype``.

    Args:
        patch_embeds_list: List with the number of patches per image, in sequence order.
        tensor: Patch embeddings whose dim 0 is batch and dim 1 is seq_len.

    Returns:
        Tensor of shape (batch, 1, seq_len, seq_len) on ``tensor.device``.
    """
    seq_len = tensor.shape[1]
    fill = torch.finfo(tensor.dtype).min
    mask = torch.full((seq_len, seq_len), fill_value=fill, dtype=tensor.dtype, device=tensor.device)

    ends = torch.tensor(patch_embeds_list).cumsum(-1)
    starts = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1)
    for lo, hi in zip(starts, ends):
        mask[lo:hi, lo:hi] = 0

    batch = tensor.shape[0]
    return mask[None, None, :, :].expand(batch, 1, -1, -1)

def mlx_generate_block_attention_mask(patch_embeds_list, input_array):
    """MLX port of the block-diagonal additive attention mask.

    Tokens may attend only to tokens of the same image: entries inside each image's
    block are 0, all others are the most negative finite value of ``input_array``'s dtype.

    Args:
        patch_embeds_list: List with the number of patches per image, in sequence order.
        input_array: mx.array whose dim 0 is batch and dim 1 is seq_len; its dtype
            is used for the mask.

    Returns:
        mx.array of shape (batch, 1, seq_len, seq_len).
    """
    from itertools import accumulate

    seq_len = input_array.shape[1]
    # Read the limit from the dtype directly. Converting the whole array to NumPy
    # copies it and fails for bfloat16, which NumPy cannot represent.
    dmin = mx.finfo(input_array.dtype).min
    causal_mask = mx.full((seq_len, seq_len), dmin, dtype=input_array.dtype)

    # Inclusive / exclusive prefix sums give each image's [start, end) span.
    block_end_idx = list(accumulate(patch_embeds_list))
    block_start_idx = [0] + block_end_idx[:-1]

    for start, end in zip(block_start_idx, block_end_idx):
        causal_mask[start:end, start:end] = 0

    return mx.broadcast_to(causal_mask[None, None, :, :], (input_array.shape[0], 1, seq_len, seq_len))

In [None]:
def mlx_generate_block_attention_mask(patch_embeds_list, input_array):
    """MLX port of the block-diagonal additive attention mask.

    Tokens may attend only to tokens of the same image: entries inside each image's
    block are 0, all others are the most negative finite value of ``input_array``'s dtype.

    Args:
        patch_embeds_list: List with the number of patches per image, in sequence order.
        input_array: mx.array whose dim 0 is batch and dim 1 is seq_len; its dtype
            is used for the mask.

    Returns:
        mx.array of shape (batch, 1, seq_len, seq_len).
    """
    from itertools import accumulate

    seq_len = input_array.shape[1]
    # Read the limit from the dtype directly. Converting the whole array to NumPy
    # copies it and fails for bfloat16, which NumPy cannot represent.
    dmin = mx.finfo(input_array.dtype).min
    causal_mask = mx.full((seq_len, seq_len), dmin, dtype=input_array.dtype)

    # Inclusive / exclusive prefix sums give each image's [start, end) span.
    block_end_idx = list(accumulate(patch_embeds_list))
    block_start_idx = [0] + block_end_idx[:-1]

    for start, end in zip(block_start_idx, block_end_idx):
        causal_mask[start:end, start:end] = 0

    return mx.broadcast_to(causal_mask[None, None, :, :], (input_array.shape[0], 1, seq_len, seq_len))

In [None]:
jp = hf_generate_block_attention_mask(
            [p.shape[-2] * p.shape[-1] for p in [torch_conv]], patch_embeds
        )

In [None]:
mlx_patch_embeds = mx.concatenate([p.flatten(1, 2) for p in [mlx_conv]], axis=1)
kp = mlx_generate_block_attention_mask(
            [p.shape[-2] * p.shape[-1] for p in [mlx_conv]], mlx_patch_embeds
        )

In [None]:
kp.shape

In [None]:
jp.shape

In [None]:
jp.detach().numpy() == kp

In [None]:
(jp.to(torch.float32).detach().numpy() == kp).all()

In [None]:
# NOTE(review): reloads the same 12B bf16 checkpoint as the earlier hf_model cell; on
# Restart & Run All this loads it twice.
import torch
from transformers import LlavaForConditionalGeneration, AutoProcessor

model_id = "hf-internal-testing/pixtral-12b"
hf_model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16)

In [None]:
processor = AutoProcessor.from_pretrained(model_id)

IMG_URLS = [
    "https://picsum.photos/id/237/400/300",
    "https://picsum.photos/id/231/200/300",
    "https://picsum.photos/id/27/500/500",
    "https://picsum.photos/id/17/150/600",
]
# One [IMG] placeholder, matching the single image passed below.
PROMPT = "<s>[INST]Describe the images in one sentence.\n[IMG][/INST]"

# Pass text and images by keyword, as the earlier processor calls do: the processor's
# first positional parameter is not guaranteed to be `text`.
inputs = processor(text=PROMPT, images=IMG_URLS[:1], return_tensors="pt")

In [None]:
inputs["input_ids"].shape

In [None]:
inputs["input_ids"]

In [None]:
# Full LLaVA forward pass. NOTE(review): passes pixel_values[0] (the first sample's
# images) rather than the whole entry — confirm this matches what the model expects.
kp = hf_model.forward(input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"][0], attention_mask=inputs["attention_mask"])

In [None]:
# Vision tower only, on the same images.
kp = hf_model.vision_tower.forward(inputs["pixel_values"][0])

In [None]:
import mlx.core as mx
from exo.inference.mlx.models.pixtral import PixtralModel, PixtralVisionConfig
# Load the exported vision-tower weights and build the MLX PixtralModel from them.
weights = {}
weights.update(mx.load("./../pixtral_weights/model.safetensors"))

import json
with open("./../pixtral_weights/config.json", "r") as f:
    config = json.load(f)
vision_config = PixtralVisionConfig.from_dict(config)
model = PixtralModel(vision_config)
# Presumably remaps HF weight names/layouts to the MLX module (project helper).
sanitized_weights = model.sanitize(weights)
# strict=True: the weight names must match the model's parameters exactly.
model.load_weights(list(sanitized_weights.items()), strict=True)

In [None]:
pixel_vals = [mx.array(x) for x in inputs["pixel_values"][0]]

In [None]:
# HF reference output for the same images.
kp = hf_model.vision_tower.forward(inputs["pixel_values"][0])

In [None]:
# MLX output, to compare against kp.
jp = model(pixel_vals)

In [None]:
kp[0]

In [None]:
jp

In [None]:
torch.finfo(torch.bfloat16).min

In [None]:
float('-inf') + 10000000

In [None]:
import numpy as np

# Initialize the arrays. The target must have the same shape as mask for boolean
# indexing (it was (1, 5) vs (2, 5), which raises IndexError). Named `target` rather
# than `self`, which reads like a method receiver.
target = np.zeros((2, 5), dtype=int)
mask = np.array([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=bool)
source = np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])

# masked_scatter semantics (as in torch): the k-th True position of mask, in row-major
# order, takes the k-th element of the flattened source. source[mask] would instead
# take the source element at the *same* position.
target[mask] = source.ravel()[: mask.sum()]

# indices = mx.argwhere(special_image_mask)
#     inputs_embeds = mx.scatter(inputs_embeds, indices, image_features.reshape(-1, image_features.shape[-1]))

print(target)

In [None]:
# torch reference for masked_scatter. NOTE(review): self is (5,) while mask is (2, 5);
# this relies on torch broadcasting them in the out-of-place call.
self = torch.tensor([0, 0, 0, 0, 0])
mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool)
source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
self.masked_scatter(mask, source)

In [None]:
import mlx.core as mx

In [None]:
import mlx.core as mx

# Initialize the arrays (`self` is only used for its shape/dtype via zeros_like).
self = mx.zeros((2, 5), dtype=mx.int32)
mask = mx.array([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=mx.bool_)
source = mx.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])

# Create the result array
result = mx.zeros_like(self)

# Flatten the source array
flat_source = source.reshape(-1)

# Manually update the result array: walk the mask row-major and write the next unused
# source element into each True position.
scatter_index = 0
for i in range(mask.shape[0]):
    for j in range(mask.shape[1]):
        if mask[i, j]:
            # result starts at zero, so .add acts as an assignment here; .at returns a new array.
            result = result.at[i, j].add(flat_source[scatter_index])
            scatter_index += 1

print(result)

In [None]:
import mlx.core as mx

def masked_scatter(self1, mask, source):
    """
    Scatter values from ``source`` into ``self1`` at positions where ``mask`` is True.

    Mirrors ``torch.Tensor.masked_scatter``: the k-th True position of ``mask`` (in
    row-major order) receives the k-th element of the flattened ``source``; every
    other position keeps its value from ``self1``.

    Args:
    self1 (mx.array): The array to scatter into (not modified).
    mask (array-like): Boolean mask with the same shape as ``self1``.
    source (mx.array): Values to scatter; needs at least ``mask.sum()`` elements.

    Returns:
    mx.array: A new array with the shape and dtype of ``self1``.

    Raises:
    ValueError: If shapes differ or ``source`` has too few elements.
    """
    mask = mx.array(mask, dtype=mx.bool_)

    # Ensure shapes are compatible
    if self1.shape != mask.shape:
        raise ValueError("Shapes of self and mask must be the same")

    self_flat = self1.reshape(-1)
    mask_flat = mask.reshape(-1)
    source_flat = source.reshape(-1)
    mask_int = mask_flat.astype(mx.int32)

    # Ensure source has enough elements
    num_true = int(mx.sum(mask_int))
    if source_flat.size < num_true:
        raise ValueError("Source array does not have enough elements to scatter.")
    if num_true == 0:
        return self1

    # Position k takes source element (number of Trues up to and including k) - 1.
    # Positions before the first True would get -1; clamp to 0 (the where below
    # discards those values anyway).
    source_idx = mx.maximum(mx.cumsum(mask_int) - 1, 0)
    gathered = source_flat[source_idx].astype(self_flat.dtype)
    output_flat = mx.where(mask_flat, gathered, self_flat)

    # Reshape to the input's shape (previously read the global `self` by mistake).
    return output_flat.reshape(self1.shape)

# Example usage. self1 must have mask's (2, 5) shape; masked_scatter raises ValueError
# for mismatched shapes (mx.zeros(5) did not match).
self1 = mx.zeros((2, 5))
mask = mx.array([[False, False, False, True, True],
                    [True, True, False, True, True]], dtype=mx.bool_)
source = mx.array([[0, 1, 2, 3, 4],
                    [5, 6, 7, 8, 9]])

result = masked_scatter(self1, mask, source)
print(result)

In [None]:
# torch reference with a non-zero base and a single False, for checking the ports below.
self = torch.tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
mask = torch.tensor([[1, 1, 1, 1, 1], [1, 0, 1, 1, 1]], dtype=torch.bool)
source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
self.masked_scatter(mask, source)

In [None]:
def masked_scatter(inputs_embeds, special_image_mask, image_features):
    """Torch masked scatter with ``torch.Tensor.masked_scatter`` semantics.

    The k-th True position of ``special_image_mask`` (row-major order) receives the
    k-th element of the flattened ``image_features``; other positions keep the value
    from ``inputs_embeds``.

    Args:
        inputs_embeds: torch.Tensor to scatter into; it is not modified.
        special_image_mask: Mask with as many elements as ``inputs_embeds``
            (non-zero means True).
        image_features: torch.Tensor supplying the values in order.

    Returns:
        New tensor with the shape and dtype of ``inputs_embeds``.

    Raises:
        ValueError: If ``image_features`` has fewer elements than the mask has True entries.
    """
    # clone(): reshape may return a view, and writing into it would mutate the caller's tensor.
    flat_result = inputs_embeds.reshape(-1).clone()
    flat_mask = special_image_mask.reshape(-1).bool()
    flat_source = image_features.reshape(-1)

    num_masked = int(flat_mask.sum())
    if flat_source.numel() < num_masked:
        raise ValueError("Number of elements of source < number of ones in mask")

    # Boolean-mask assignment fills True positions in order, like the original
    # per-element loop but vectorized.
    flat_result[flat_mask] = flat_source[:num_masked].to(flat_result.dtype)

    # Reshape the result back to the original shape
    return flat_result.reshape(inputs_embeds.shape)

In [None]:
def masked_scatter(inputs_embeds, special_image_mask, image_features):
    """Loop-based masked scatter for arrays with ``reshape``, ``size`` and item assignment.

    The k-th True position of ``special_image_mask`` (row-major order) receives the
    k-th element of the flattened ``image_features``; other positions keep the value
    from ``inputs_embeds``.

    NOTE: with NumPy input, ``reshape(-1)`` returns a view when possible, so
    ``inputs_embeds`` may be modified in place.

    Raises:
        Exception: If ``image_features`` has fewer elements than the mask has True entries.
    """
    # Flatten all arrays
    flat_result = inputs_embeds.reshape(-1)
    flat_mask = special_image_mask.reshape(-1)
    flat_source = image_features.reshape(-1)

    source_idx = 0
    for i in range(flat_result.size):
        if flat_mask[i]:
            # Check before reading, so using up the source exactly is allowed; only
            # needing one more element than the source has is an error.
            if source_idx >= flat_source.size:
                raise Exception("Number of elements of source < number of ones in mask")
            flat_result[i] = flat_source[source_idx]
            source_idx += 1

    # Reshape the result back to the original shape
    return flat_result.reshape(inputs_embeds.shape)

In [None]:
def masked_scatter(inputs_embeds, special_image_mask, image_features):
    """Vectorized NumPy masked scatter (same semantics as ``torch.Tensor.masked_scatter``).

    The k-th True position of ``special_image_mask`` (row-major order) receives the
    k-th element of the flattened ``image_features``; other positions keep the value
    from ``inputs_embeds``. Inputs may be anything ``np.array`` can convert (the
    notebook passes MLX arrays).

    Args:
        inputs_embeds: Array to scatter into; it is copied, not modified.
        special_image_mask: Mask with as many elements as ``inputs_embeds``; non-zero means True.
        image_features: Source values, consumed in order.

    Returns:
        np.ndarray with the shape of ``inputs_embeds``.

    Raises:
        Exception: If ``image_features`` has fewer elements than the mask has True entries.
    """
    # np.array (not np.asarray) copies, so the caller's array is never written to.
    flat_result = np.array(inputs_embeds).ravel()
    flat_mask = np.array(special_image_mask).ravel()
    flat_source = np.array(image_features).ravel()

    # Flat indices of the True (non-zero) mask entries, in row-major order.
    indices = np.flatnonzero(flat_mask)

    num_masked = indices.size
    if flat_source.size < num_masked:
        raise Exception("Number of elements of source < number of ones in mask")

    # Values are cast to flat_result's dtype on assignment.
    flat_result[indices] = flat_source[:num_masked]

    # Reshape the result back to the original shape
    return flat_result.reshape(inputs_embeds.shape)

In [None]:
# self = torch.tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])
# mask = torch.tensor([[1, 1, 1, 1, 1], [1, 0, 1, 1, 1]], dtype=torch.bool)
# source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])

import torch
import mlx.core as mx

# Large random case for comparing the ports against torch; seeded so reruns give the same data.
torch.manual_seed(0)
self = torch.randn(1, 505, 5120)
source = torch.randn(1, 475, 5120)
mask = torch.randint(0, 2, (1, 505, 5120), dtype=torch.bool)

# Same data as MLX arrays.
mx_self = mx.array(self.detach().numpy())
mx_source = mx.array(source.detach().numpy())
mx_mask = mx.array(mask.detach().numpy())


In [None]:
ll = mx.array([1,2,3], dtype=mx.bfloat16)

In [None]:
# NOTE(review): index 3 is out of bounds for the length-3 `ll`.
ll[[1,2,3]]

In [None]:
# torch reference result for the random stress case.
kp = self.masked_scatter(mask, source) 

In [None]:
import numpy as np
# Port under test (whichever masked_scatter definition ran last) on the same data.
jp = masked_scatter(mx_self, mx_mask, mx_source)

In [None]:
(jp == kp.detach().numpy()).all()

In [None]:
jp

In [None]:
kp.detach().numpy()

In [None]:
ll = mx.array([1, 2, 3], dtype=mx.bfloat16)
lk = mx.array([0, 1, 0], dtype=mx.bool_)

In [None]:
# Flat indices of the True entries, computed with NumPy.
ko = mx.array(np.flatnonzero(lk))

In [None]:
ko

In [None]:
# Index-array assignment: writes ll[:1] into the positions listed in ko.
ll[ko] = ll[:1]

In [None]:
ll

In [None]:
ll

In [None]:
ll.reshape(-1)

In [None]:
ll = mx.array([1, 2, 3])

In [None]:
# NOTE(review): uses ll's values (1, 2, 3) as indices into arange(3); 3 is out of bounds.
mx.arange(ll.size)[ll]

In [None]:
ll[mx.array([1,2,0])]

In [None]:
import mlx

import numpy as np

array = mlx.core.array([True, False, True, True, False])
# mlx.core.where is the 3-argument select (condition, x, y); it has no single-argument
# "nonzero" form, so find the True positions with NumPy.
indices = mlx.core.array(np.flatnonzero(np.array(array)))

print(indices)  # expected values: [0, 2, 3]

In [None]:
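# Excerpt of safetensors tensor metadata (shape, dtype) pasted from the Hub viewer;
# the first entry's tensor name was not captured in the paste.
#   (unnamed)                              [4096, 224]   F16
#   model.layers.0.mlp.down_proj.scales    [4096, 224]   F16
#   model.layers.0.mlp.down_proj.weight    [4096, 1792]  U32
# 1792 uint32 words * 8 four-bit values = 14336 inputs per row, and 14336 / 64 = 224 groups,
# which is consistent with group size 64 and 4-bit packing.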

In [None]:
import os
# Presumably selects tinygrad's CLANG (CPU) backend; it has to be set before tinygrad is imported.
os.environ["CLANG"] = "1"
import numpy as np
import mlx.core as mx
from tinygrad import Tensor

In [None]:
# Synthetic inputs for the quantized-matmul experiments; seeded so runs are reproducible.
# NOTE(review): randint(0, 9) only fills the lowest 4-bit slot of each uint32 word; use
# the full uint32 range to exercise every packed slot.
np.random.seed(0)
w = np.random.randint(0, 9, size=(1024, 512), dtype=np.uint32)  # packed weights: 512 * 8 = 4096 inputs
s = np.random.rand(1024, 64).astype(np.float16)  # per-group scales: 4096 / 64 = 64 groups
b = np.random.rand(1024, 64).astype(np.float16)  # per-group biases
x = np.random.rand(120, 4096).astype(np.float16)  # activations (M=120, K=4096)

In [None]:
import numpy as np

def quantized_matmul(x, w_packed, scales, biases, width=4, groups=64):
    """
    Perform quantized matrix multiplication between input x and quantized weights w_packed.

    Computes ``x @ W.T`` with ``W[n, k] = q[n, k] * scales[n, k // groups] + biases[n, k // groups]``.
    Each uint32 of ``w_packed`` holds ``32 // width`` consecutive quantized values, value j
    occupying bits ``[j * width, (j + 1) * width)``.

    Parameters:
    - x: np.ndarray of shape (M, K), input activations
    - w_packed: np.ndarray of shape (N, K * width // 32), packed quantized weights
    - scales: np.ndarray of shape (N, K // groups), scales for dequantization
    - biases: np.ndarray of shape (N, K // groups), biases for dequantization
    - width: int, bits per quantized value; must divide 32 (default 4)
    - groups: int, quantization group SIZE, i.e. how many consecutive K entries share one
      scale/bias (default 64); must be a multiple of 32 // width

    Returns:
    - output: np.ndarray of shape (M, N), float32 result of the quantized matrix multiplication
    """
    M, K = x.shape
    N, K_packed = w_packed.shape
    assert 32 % width == 0, f"width must divide 32, got {width}"
    num_values_per_uint32 = 32 // width
    K_unpacked = K_packed * num_values_per_uint32

    assert K == K_unpacked, f"Mismatch in K dimensions: {K} vs {K_unpacked}"
    assert K % groups == 0, "K must be divisible by the number of groups"
    assert groups % num_values_per_uint32 == 0, "group size must be a multiple of the values packed per uint32"
    num_groups = K // groups
    assert scales.shape == biases.shape == (N, num_groups), "Scales and biases must have shape (N, K // groups)"

    # Mask and per-slot shifts used to extract each width-bit field from a uint32.
    bitmask = (1 << width) - 1
    shifts = np.arange(num_values_per_uint32, dtype=np.uint32) * np.uint32(width)
    packs_per_group = groups // num_values_per_uint32  # uint32 words per group

    # Reshape x for group-wise processing
    x_grouped = x.reshape(M, num_groups, groups)

    # Accumulate in float32 regardless of the input dtypes.
    output = np.zeros((M, N), dtype=np.float32)

    for g in range(num_groups):
        scale_g = scales[:, g].astype(np.float32)  # (N,)
        bias_g = biases[:, g].astype(np.float32)   # (N,)

        pack_start = g * packs_per_group
        w_packed_group = w_packed[:, pack_start:pack_start + packs_per_group]  # (N, packs_per_group)

        # (N, packs, values_per_uint32) -> (N, groups): slot j of word p is element
        # p * values_per_uint32 + j. The unpacked values stay in a 32-bit integer type;
        # a uint8 buffer would silently truncate widths above 8 bits.
        w_quantized_group = ((w_packed_group[:, :, None] >> shifts) & bitmask).reshape(N, groups)

        # Dequantize the unpacked weights for the current group: (N, groups)
        w_group = w_quantized_group.astype(np.float32) * scale_g[:, np.newaxis] + bias_g[:, np.newaxis]

        # Partial product for this group's slice of K: (M, groups) @ (groups, N)
        output += np.dot(x_grouped[:, g, :], w_group.T)

    return output

In [None]:
# %timeit quantized_matmul(x, w, s, b)
# NumPy reference result on the synthetic inputs.
quantized_matmul(x, w, s, b)

In [None]:
# MLX's built-in kernel on the same data (transpose=True computes x @ W.T), for comparison.
lm = mx.quantized_matmul(mx.array(x), mx.array(w), scales=mx.array(s), biases=mx.array(b), transpose=True)

In [None]:
from tinygrad import Tensor
from tinygrad.dtype import dtypes

def quantized_matmul_tg(x, w_packed, scales, biases, width=4, groups=64):
    """
    Perform quantized matrix multiplication between input x and quantized weights w_packed using tinygrad Tensors.

    Mirrors the NumPy quantized_matmul: weights are unpacked and dequantized one group
    at a time and the partial products are accumulated.

    Parameters:
    - x: Tensor of shape (M, K), input activations
    - w_packed: Tensor of shape (N, K_packed), packed quantized weights (dtype=Tensor.int32)
    - scales: Tensor of shape (N, K // groups), scales for dequantization (dtype=Tensor.float32)
    - biases: Tensor of shape (N, K // groups), biases for dequantization (dtype=Tensor.float32)
    - width: int, number of bits per quantized value (default is 4 bits)
    - groups: int, quantization group SIZE, i.e. how many consecutive K entries share one
      scale/bias (default 64); num_groups = K // groups

    Returns:
    - output: Tensor of shape (M, N), float32 result of the quantized matrix multiplication

    NOTE(review): work in progress — contains debug prints and relies on tinygrad
    behavior that still needs confirming (see the inline notes).
    """
    M, K = x.shape
    N, K_packed = w_packed.shape
    num_values_per_uint32 = 32 // width
    K_unpacked = K_packed * num_values_per_uint32
    num_groups = K // groups

    assert K == K_unpacked, f"Mismatch in K dimensions: {K} vs {K_unpacked}"
    assert scales.shape == (N, num_groups), f"Scales must have shape (N, {num_groups}), but is {scales.shape}"
    assert biases.shape == (N, num_groups), f"Biases must have shape (N, {num_groups}), but is {biases.shape}"
    assert K % groups == 0, "K must be divisible by the number of groups"

    # Prepare bitmask and shifts for unpacking
    bitmask = (1 << width) - 1  # e.g., for width=4, bitmask=0b1111
    shifts = Tensor.arange(num_values_per_uint32, dtype=dtypes.uint32) * width  # Tensor of shifts

    packs_per_group = groups // num_values_per_uint32  # Number of uint32 packs per group

    # Reshape x for group-wise processing
    x_grouped = x.reshape(M, num_groups, groups)

    # Initialize the output matrix
    output = Tensor.zeros((M, N), dtype=dtypes.float32)

    # Process each group
    for g in range(num_groups):
        # Extract scales and biases for the current group
        scale_g = scales[:, g]  # Shape: (N,)
        bias_g = biases[:, g]   # Shape: (N,)

        # Calculate the start and end indices for the packed weights of the current group
        pack_start = g * packs_per_group
        pack_end = pack_start + packs_per_group

        # Extract the packed weights for the current group
        w_packed_group = w_packed[:, pack_start:pack_end]  # Shape: (N, packs_per_group)
        
        # Unpack the quantized weights on-the-fly
        w_quantized_group = Tensor.zeros((N, groups), dtype=dtypes.uint8)  # Shape: (N, groups)
        print(w_quantized_group.shape)
        # shift.item() turns each shift into a Python int for the >> operator.
        for i, shift in enumerate(shifts):
            w_values = (w_packed_group >> shift.item()) & bitmask  # Shape: (N, packs_per_group)
            # w_values.squee
            print(w_values.numpy().shape)
            indices = Tensor.arange(i, groups, num_values_per_uint32)
            print(indices.shape)
            # NOTE(review): in-place __setitem__ with a Tensor index — confirm tinygrad supports
            # this. The uint8 buffer would also truncate widths above 8 bits.
            w_quantized_group[:, indices] = w_values

        # Dequantize the unpacked weights for the current group
        # NOTE(review): tinygrad's documented conversion is .cast(dtype); confirm `astype`
        # exists on Tensor and accepts a NumPy dtype.
        w_group = w_quantized_group.astype(np.float32)
        w_group = w_group * scale_g[:, np.newaxis] + bias_g[:, np.newaxis]  # Shape: (N, groups)

        # Extract the corresponding input activations for the current group
        x_group = x_grouped[:, g, :]  # Shape: (M, groups)

        # Perform the partial matrix multiplication and accumulate the results
        output += x_group.dot(w_group.transpose())  # Shape: (M, N)

    return output

In [None]:
# Run the tinygrad port on the same synthetic inputs as the NumPy reference.
quantized_matmul_tg(Tensor(x), Tensor(w), Tensor(s), Tensor(b))

In [None]:
# Scratch: tinygrad uint32 arange / scalar multiply / right-shift vs NumPy.
ll = Tensor.arange(32, dtype=dtypes.uint32)

In [None]:
ll.numpy()

In [None]:
lm = ll*10

In [None]:
lm.numpy()

In [None]:
# NumPy shift amounts for 4-bit fields; iterating yields NumPy integer scalars.
shifts = np.arange(8) * 4

In [None]:
for shift in shifts:
    print(type(shift))

In [None]:
ll = Tensor.arange(32, dtype=dtypes.uint32)

In [None]:
# Not idempotent: re-running this cell shifts again.
ll = ll >> 1

In [None]:
ll.numpy()

In [None]:
ll.numpy()

In [None]:
# NumPy reference for the tinygrad shift above.
np.arange(32) >> 1

In [None]:
ll = Tensor.randint((3, 2)).numpy()
ll

In [None]:
ll.reshape(-1)

In [None]:
from tinygrad.dtype import dtypes

def quantized_matmul_tg(x, w_packed, scales, biases, width=4, groups=64):
    """
    Quantized matrix multiplication ``x @ dequantize(w_packed).T`` using tinygrad.

    Each element of ``w_packed`` packs ``32 // width`` quantized values along the
    K dimension, lowest bits first. Every ``groups`` consecutive values along K
    share one scale and one bias, and a value dequantizes as ``q * scale + bias``.

    Parameters:
    - x: Tensor of shape (M, K), input activations.
    - w_packed: Tensor of shape (N, K * width // 32), packed quantized weights.
      Expected to be an unsigned 32-bit dtype (dtypes.uint32).
      NOTE(review): signed int32 input is not verified to unpack the top field correctly.
    - scales: Tensor of shape (N, K // groups), dequantization scales.
    - biases: Tensor of shape (N, K // groups), dequantization offsets.
    - width: int, number of bits per quantized value (default 4); must divide 32.
    - groups: int, the group SIZE: number of consecutive K elements sharing one
      scale/bias (default 64). The number of groups is K // groups.

    Returns:
    - output: float16 Tensor of shape (M, N).

    Raises:
    - AssertionError if width does not divide 32, or if the shapes of w_packed,
      scales or biases are inconsistent with x.
    """
    M, K = x.shape
    N, K_packed = w_packed.shape

    assert 32 % width == 0, f"width must divide 32, got {width}"
    num_values_per_uint32 = 32 // width  # E.g., for width=4, this is 8
    K_unpacked = K_packed * num_values_per_uint32

    # Check divisibility before deriving num_groups, so a bad K reports the real cause.
    assert K % groups == 0, f"K ({K}) must be divisible by the group size ({groups})"
    num_groups = K // groups

    assert K == K_unpacked, f"Mismatch in K dimensions: {K} vs {K_unpacked}"
    assert scales.shape == (N, num_groups), f"Scales must have shape (N, {num_groups}), got {scales.shape}"
    assert biases.shape == (N, num_groups), f"Biases must have shape (N, {num_groups}), got {biases.shape}"

    bitmask = (1 << width) - 1  # E.g., for width=4, bitmask=15

    # Field i of every packed word sits at bit offset i * width. Extract every field
    # of the whole matrix at once, then move the field axis last. Flattening
    # (pack index, field index) then restores the original K order.
    fields = [(w_packed >> (i * width)) & bitmask for i in range(num_values_per_uint32)]
    w_quantized = Tensor.stack(*fields, dim=0).permute(1, 2, 0).cast(dtypes.float16)

    # Dequantize all groups together: (N, num_groups, groups) * (N, num_groups, 1).
    # Working on the flattened K axis also removes the old hidden requirement that
    # `groups` be a multiple of the values packed per word.
    w_quantized = w_quantized.reshape(N, num_groups, groups)
    w_dequant = w_quantized * scales.reshape(N, num_groups, 1) + biases.reshape(N, num_groups, 1)

    # One matmul over the full K replaces num_groups partial products. Those were
    # each added into a float16 buffer, which rounded after every group.
    output = x @ w_dequant.reshape(N, K).transpose()  # Shape: (M, N)
    return output.cast(dtypes.float16)


In [None]:
# NOTE(review): x, w, s and b come from the layer-0 q_proj cell further down; run that first.
kp = quantized_matmul_tg(Tensor(x).realize(), Tensor(w).realize(), Tensor(s).realize(), Tensor(b).realize()).realize()

In [None]:
kp.numpy()

In [None]:
# MLX reference for the same inputs; compare with kp above.
mx.quantized_matmul(mx.array(x), mx.array(w), mx.array(s), mx.array(b))

In [None]:
# Sanity check: zeros materialise when cast to int32 ...
Tensor.zeros((120, 1024)).cast(dtypes.int32).numpy()

In [None]:
# ... and with the default dtype.
Tensor.zeros((120, 1024)).numpy()

In [None]:
# Spot-check a single output element.
kp[0][0].item()

In [None]:
# NOTE(review): expects `lm` to be the MLXLinear output assigned in a later cell. At
# this point in file order lm is the 1-D `ll*10` tensor, so lm[i][j] would fail.
i=0
j=100
lm[i][j], kp[i][j].item()

In [None]:
# NOTE(review): jp is only defined in a later cell (the mlx QuantizedLinear cell).
jp

In [None]:
ll = Tensor.randint((2,3))

In [None]:
ll

In [None]:
# NOTE(review): Tensor.empty() is called without a shape; this looks like an
# abandoned experiment. Verify it runs before relying on it.
ll.cat(Tensor.empty())

In [None]:
tensors = [Tensor([[1, 2]]), Tensor([[3, 4]]), Tensor([[5, 6]])]

# Concatenate tensors[1] and tensors[2] along dim 1 -> [[3 4 5 6]] (tensors[0] is skipped)
result = Tensor.cat(*tensors[1:], dim=1)
print(result.numpy())

In [None]:
ll = Tensor.zeros((120, 1024))
ll.realize().numpy()

In [None]:
# Probe a matmul between mismatched dtypes (uint32 @ float32).
# Note: rebinds ll, lm and kp, shadowing the earlier results.
ll = Tensor.ones(2, 2, dtype=dtypes.uint32).realize()
lm = Tensor.ones(2, 2, dtype=dtypes.float32).realize()

kp = ll@lm.realize()

In [None]:
ll[0].numpy()

In [None]:
# Number of 64-wide quantization groups in a 4096-wide row.
4096/64

In [None]:
import os
os.environ["CLANG"] = "1"

from tinygrad import Tensor
from tinygrad.dtype import dtypes

class MLXLinear:
  """Linear layer over MLX-style affine-quantized weights, evaluated with tinygrad.

  Mirrors the ``weight`` / ``scales`` / ``biases`` attributes of
  ``mlx.nn.QuantizedLinear``:
  - ``weight`` packs ``32 // bits`` quantized values per uint32 along the input
    dimension, lowest bits first.
  - Every ``group_size`` consecutive input elements share one entry of ``scales``
    and ``biases``.
  - A weight dequantizes as ``q * scale + bias``.

  All tensors are placeholders until real checkpoint values are assigned.

  Args:
    in_features: input dimension K; must be a multiple of ``group_size``.
    out_features: output dimension N.
    bits: bits per quantized value; must divide 32.
    group_size: number of consecutive input elements sharing one scale/bias.
    bias: if True, also allocate an additive output bias ``self.bias`` of shape
      (out_features,), initialised to zeros.
  """

  def __init__(self, in_features, out_features, bits=4, group_size=64, bias=False):
    assert in_features % group_size == 0
    assert 32 % bits == 0
    assert (in_features * bits) % 32 == 0
    self.weight = Tensor.ones(out_features, (in_features * bits) // 32, dtype=dtypes.uint32)
    self.scales = Tensor.ones(out_features, in_features // group_size, dtype=dtypes.half)
    # The quantization offsets are needed by every affine-quantized layer, so they are
    # always allocated. Previously they only existed when bias=True, and __call__
    # raised AttributeError for the default bias=False.
    self.biases = Tensor.ones(out_features, in_features // group_size, dtype=dtypes.half)
    # Separate additive output bias (as in mlx QuantizedLinear(bias=True)). It starts at
    # zero, so outputs are unchanged until a real bias is loaded.
    self.bias = Tensor.zeros(out_features, dtype=dtypes.half) if bias else None
    self.bits = bits
    self.group_size = group_size

  def __call__(self, x):
    """Compute ``x @ dequantize(weight).T`` (plus ``bias`` when present).

    Args:
      x: Tensor whose last dimension is in_features. Any leading dimensions
        (e.g. batch and sequence) are preserved.

    Returns:
      float16 Tensor of shape ``(*x.shape[:-1], out_features)``.

    Raises:
      AssertionError if x, weight, scales or biases have inconsistent shapes.
    """
    *lead, K = x.shape
    N, K_packed = self.weight.shape
    values_per_pack = 32 // self.bits

    assert K % self.group_size == 0, f"K ({K}) must be divisible by group_size ({self.group_size})"
    num_groups = K // self.group_size
    assert K == K_packed * values_per_pack, f"Mismatch in K dimensions: {K} vs {K_packed * values_per_pack}"
    assert self.scales.shape == (N, num_groups), f"Scales must have shape (N, {num_groups}), got {self.scales.shape}"
    assert self.biases.shape == (N, num_groups), f"Biases must have shape (N, {num_groups}), got {self.biases.shape}"

    bitmask = (1 << self.bits) - 1

    # Field i of every packed word sits at bit offset i * bits. Stack the fields on a
    # trailing axis so that flattening (pack, field) restores input-feature order.
    fields = [(self.weight >> (i * self.bits)) & bitmask for i in range(values_per_pack)]
    q = Tensor.stack(*fields, dim=0).permute(1, 2, 0).cast(dtypes.float16)
    q = q.reshape(N, num_groups, self.group_size)

    # Dequantize every group at once; each scale/bias broadcasts across its group.
    w = q * self.scales.reshape(N, num_groups, 1) + self.biases.reshape(N, num_groups, 1)
    w = w.reshape(N, K)

    # One matmul over the full K. The old version summed num_groups partial products
    # into a float16 buffer, which rounded after every group.
    out = x.reshape(-1, K) @ w.T
    bias = getattr(self, "bias", None)
    if bias is not None:
      out = out + bias
    return out.cast(dtypes.float16).reshape(*lead, N)

In [None]:
import numpy as np
import mlx.nn as nn
import mlx.core as mx

In [None]:
w = weights["model.layers.0.self_attn.q_proj.weight"]
s = weights["model.layers.0.self_attn.q_proj.scales"]
b = weights["model.layers.0.self_attn.q_proj.biases"]
x = np.random.rand(120, 4096).astype(np.float16)

In [None]:
import mlx.core as mx
weights = {}
weights.update(mx.load("./../model.safetensors"))

In [None]:
# List the checkpoint keys that mention the token embedding.
for l in weights.keys():
    if "embed" in l:
        print(l)

In [None]:
# The embedding carries its own scales, i.e. it is stored quantized as well.
weights["model.embed_tokens.scales"].shape

In [None]:
weights["model.embed_tokens.weight"].shape

In [None]:
# MLX reference layer with the layer-0 q_proj tensors swapped in. bias=False only
# drops the additive output bias; the quantization `biases` are still assigned below.
jp = nn.QuantizedLinear(4096, 4096, bias=False)
jp.weight = w
jp.scales = s
jp.biases = b
jp(mx.array(x))

In [None]:
# Imported as nn1 so mlx.nn stays bound to `nn`.
from tinygrad import nn as nn1

# Unquantized tinygrad Linear with freshly initialised weights: presumably only a
# shape/timing sanity check, since its values are not comparable to the layers above.
nn1.Linear(4096, 4096)(Tensor(x).realize()).realize().numpy()

In [None]:
# tinygrad port with the same layer shape as the MLX reference layer (jp).
kp = MLXLinear(4096, 4096)

In [None]:
# Copy the MLX checkpoint tensors in via NumPy (mx.array -> np.ndarray -> Tensor).
kp.weight = Tensor(np.array(w)).realize()
kp.scales = Tensor(np.array(s)).realize()
kp.biases = Tensor(np.array(b)).realize()

In [None]:
# Note: rebinds `lm` to the port's output for the shared activations x.
lm = kp(Tensor(x).realize()).realize()

In [None]:
lm.numpy().shape

In [None]:
lm.numpy()

In [None]:
# NOTE(review): quantized_matmul is not defined in this part of the notebook.
# Presumably it is a NumPy reference implementation from an earlier cell; verify.
quantized_matmul(x, np.array(w), np.array(s), np.array(b))

In [8]:
# NOTE(review): LlavaForConditionalGeneration is not used in the visible cells.
from transformers import AutoProcessor, LlavaForConditionalGeneration
# model_id = "mistral-community/pixtral-12b"
# model_id = "llava-hf/llava-1.5-7b-hf"
model_id = "varb15/hf-internal-testing-pixtral-12b"
# NOTE(review): loading is commented out, so `processor` must already exist from an
# earlier cell. This cell ran as In [8], after the chat-template override
# (In [6]/In [7]) further down. Re-loading here would discard that override, so
# Restart & Run All in file order will not reproduce the recorded prompt.
# processor = AutoProcessor.from_pretrained(model_id)

# Sample images; only referenced by the commented-out processor(...) call below.
url_dog = "https://picsum.photos/id/237/200/300"
url_mountain = "https://picsum.photos/seed/picsum/200/300"

# chat = [
#     {
#       "role": "system", "content": "haha"
#     },
#     {
#       "role": "user", "content": [
#         {"type": "text", "text": "Can this animal"}, 
#         {"type": "image", "image": "haha"}, 
#         {"type": "text", "text": "live here?"}, 
#         {"type": "image"}
#       ]
#     },
#     {
#       "role": "user", "content": [
#         {"type": "text", "text": "Can this animal"}, 
#         {"type": "image", "image": "haha"}, 
#         {"type": "text", "text": "live here?"}, 
#         {"type": "image"}
#       ]
#     },
#     {
#       "role": "assistant", "content": [{"type": "text", "text": "Can this animal"}]
#     },
#     {
#       "role": "user", "content": [
#         {"type": "text", "text": "Can this animal"}, 
#         {"type": "image", "image": "haha"}, 
#         {"type": "text", "text": "live here?"}, 
#         {"type": "image"}
#       ]
#     }
# ]
# chat = [{"role": "user", "content": [{"type": "image", "image": "image"}, {"type": "text", "text": "explain"}]}, {"role": "assistant", "content": [{"type": "text", "text": " The image displays a computer screen with multiple graphs and charts. These graphs and charts are likely used for monitoring and analyzing various data points. The graphs show different types of data, such as CPU usage, memory usage, and network traffic. The charts are organized in a way that allows for easy comparison and understanding of the data. The screen is filled with a variety of graphs and charts, indicating a complex system being monitored and analyzed."}]}, {"role": "user", "content": [{"type": "text", "text": "but like explain properly, be more clear"}]}]
# chat = [{"role": "user", "content": [{"type": "image", "image": "image"}, {"type": "text", "text": "explain"}]}, {"role": "assistant", "content": [{"type": "text", "text": " The image displays a computer screen with multiple graphs and charts. These graphs and charts are likely used for monitoring and analyzing various data points. The graphs show different types of data, such as CPU usage, memory usage, and network traffic. The charts are organized in a way that allows for easy comparison and understanding of the data. The screen is filled with a variety of graphs and charts, indicating a complex system being monitored and analyzed."}]}, {"role": "user", "content": [{"type": "text", "text": "but like explain properly, be more clear"}]}]
# chat = [{"role": "user", "content": [{"type": "image", "image": "image"}, {"type": "text", "text": "explain"}]}, {"role": "user", "content": [{"type": "text", "text": "explain"}]}]
# chat = [{"role": "user", "content": "hi"}]
chat = [
  {
    "role": "user",
    "content": [
      {
        "type": "text",
        "text": "hi"
      }
    ]
  },
  {
    "role": "assistant",
    "content": [
      {
        "type": "text",
        "text": "Hello! How can I assist you today? Let's chat about anything you'd like. ���"
      }
    ]
  },
  {
    "role": "user",
    "content": [
      {
        "type": "text",
        "text": "ok how can you help?"
      }
    ]
  },
  {
    "role": "assistant",
    "content": [
      {
        "type": "text",
        "text": "I can help in a variety of ways! Here are some things I can do:\n\n1. **Answer Questions**: Provide information on a wide range of topics, from general knowledge to specific queries.\n2. **Explain Concepts**: Break down complex ideas into simpler parts to help you understand them better.\n3. **Offer Suggestions**: Provide recommendations for books, movies, recipes, and more.\n4. **Help with Language**: Assist with grammar, vocabulary, or translations in multiple languages.\n5. **Provide Study Help**: Offer tips, summaries, and explanations for various subjects.\n6. **Engage in Conversation**: Chat on various topics to help you practice a language or just have a friendly chat.\n7. **Offer Tips and Advice**: Provide advice on productivity, wellness, and other lifestyle topics.\n\nWhat do you need help with today?"
      }
    ]
  },
  {
    "role": "user",
    "content": [
      {
        "type": "image",
        "image": "image"
      },
      {
        "type": "text",
        "text": "explain this image please?"
      }
    ]
  }
]
# Render `chat` to a prompt string using the processor's current chat template.
# The recorded output reflects the custom `ct` template set by the override cell
# below, which ran earlier (In [7]).
prompt = processor.apply_chat_template(chat)
# inputs = processor(text=prompt, images=[url_dog, url_mountain], return_tensors="pt")

In [9]:
# Display the rendered prompt string.
prompt

"<s>[INST]hi[/INST]Hello! How can I assist you today? Let's chat about anything you'd like. ���</s>[INST]ok how can you help?[/INST]I can help in a variety of ways! Here are some things I can do:\n\n1. **Answer Questions**: Provide information on a wide range of topics, from general knowledge to specific queries.\n2. **Explain Concepts**: Break down complex ideas into simpler parts to help you understand them better.\n3. **Offer Suggestions**: Provide recommendations for books, movies, recipes, and more.\n4. **Help with Language**: Assist with grammar, vocabulary, or translations in multiple languages.\n5. **Provide Study Help**: Offer tips, summaries, and explanations for various subjects.\n6. **Engage in Conversation**: Chat on various topics to help you practice a language or just have a friendly chat.\n7. **Offer Tips and Advice**: Provide advice on productivity, wellness, and other lifestyle topics.\n\nWhat do you need help with today?</s>[INST][IMG]explain this image please?[/INS

In [None]:
# Same display as the earlier `prompt` cell.
prompt

In [10]:
# Save under the current user's home directory instead of a hard-coded /Users/<name>
# path, so the cell runs on other machines. For the original author this is
# presumably the same /Users/varb/pixtral_temp folder.
processor.save_pretrained(str(Path.home() / "pixtral_temp"))

['/Users/varb/pixtral_temp/processor_config.json']

In [6]:
ct = """{%- if messages[0][\"role\"] == \"system\" %}\n  {%- set system_message = messages[0][\"content\"] %}\n  {%- set loop_messages = messages[1:] %}\n{%- else %}\n  {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n  {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n    {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n  {%- endif %}\n  {%- if message[\"role\"] == \"user\" %}\n    {%- if loop.last and system_message is defined %}\n      {{- \"[INST]\" + system_message + \"\\n\\n\" }}\n    {%- else %}\n      {{- \"[INST]\" }}\n    {%- endif %}\n    {%- if message[\"content\"] is not string %}\n      {%- for chunk in message[\"content\"] %}\n        {%- if chunk[\"type\"] == \"text\" %}\n          {{- chunk[\"text\"] }}\n        {%- elif chunk[\"type\"] == \"image\" %}\n          {{- \"[IMG]\" }}\n        {%- else %}\n          {{- raise_exception(\"Unrecognized content type!\") }}\n        {%- endif %}\n      {%- endfor %}\n    {%- else %}\n      {{- message[\"content\"] }}\n    {%- endif %}\n    {{- \"[/INST]\" }}\n  {%- elif message[\"role\"] == \"assistant\" %}\n    {%- if message[\"content\"] is not string %}\n      {%- for chunk in message[\"content\"] %}\n        {%- if chunk[\"type\"] == \"text\" %}\n          {{- chunk[\"text\"] }}\n        {%- elif chunk[\"type\"] == \"image\" %}\n          {{- \"[IMG]\" }}\n        {%- else %}\n          {{- raise_exception(\"Unrecognized content type!\") }}\n        {%- endif %}\n      {%- endfor %}\n    {%- else %}\n      {{- message[\"content\"] }}\n    {%- endif %}\n    {{- eos_token}}\n  {%- else %}\n    {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n  {%- endif %}\n{%- endfor %}"""

In [7]:
# Replace the checkpoint's chat template with the custom Jinja template `ct` from the
# previous cell:
# - user/assistant content may be a list of text/image chunks, and each image chunk
#   renders as "[IMG]";
# - roles must alternate user/assistant;
# - an optional leading system message is prepended only to the last user turn.
processor.chat_template = ct

In [None]:
# Inspect the processor's active chat template.
processor.chat_template

In [None]:
# Duplicate of the cell above.
processor.chat_template

In [None]:
# `nn` is mlx.nn here (tinygrad's nn was imported as nn1). The checkpoint's
# embed_tokens carries scales, so the MLX quantized embedding layer is the relevant one.
nn.QuantizedEmbedding