In [None]:
import torch
from transformers.models.llama.modeling_llama import LlamaForCausalLM

In [None]:


class BaselineLlamaForCausalLm(LlamaForCausalLM):
    """`LlamaForCausalLM` variant whose `forward` returns only the logits.

    `forward` takes exactly two tensors and returns a single tensor instead
    of the parent's output object. It runs with gradient tracking disabled
    and always passes `use_cache=False` to the parent.
    """

    @torch.no_grad()
    def forward(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> torch.Tensor:
        """Run the parent forward pass and return its `logits`.

        Args:
            input_ids: Forwarded unchanged to the parent as `input_ids`.
            attention_mask: Forwarded unchanged to the parent as
                `attention_mask`.

        Returns:
            The `logits` attribute of the parent model's output.
        """
        parent_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=False,
        )
        return super().forward(**parent_kwargs).logits

# Model identifier handed to `from_pretrained`.
model_id: str = "meta-llama/Llama-3.1-8B-Instruct"

# Instantiate the logits-only wrapper from the pretrained checkpoint and call
# `.eval()` on it; the assignment keeps the cell from echoing the model repr.
torch_model = (
    BaselineLlamaForCausalLm
    .from_pretrained(model_id)
    .eval()
)



In [None]:
import coremltools as ct
import numpy as np


# Fixed shape of the example tensors used for tracing.
batch_size = 1
context_size = 2048
input_shape = (batch_size, context_size)

# Two all-zero int32 tensors of `input_shape`. `torch.jit.trace` passes them
# positionally, so they map to `forward`'s (input_ids, attention_mask).
example_inputs: tuple[torch.Tensor, torch.Tensor] = (
    torch.zeros(batch_size, context_size, dtype=torch.int32),
    torch.zeros(batch_size, context_size, dtype=torch.int32),
)

# Record the model's forward pass on the example inputs as a TorchScript module.
traced_model: torch.jit.ScriptModule = torch.jit.trace(
    torch_model,
    example_inputs=example_inputs,
)

In [None]:
# Core ML input specs, in the same order as the traced model's positional
# arguments: both are int32 tensors of the fixed `input_shape`.
inputs: list[ct.TensorType] = [
    ct.TensorType(name=feature_name, shape=input_shape, dtype=np.int32)
    for feature_name in ("inputIds", "attentionMask")
]

# Single float16 output, exposed under the name "logits".
outputs: list[ct.TensorType] = [
    ct.TensorType(name="logits", dtype=np.float16),
]

# Convert the traced module to a Core ML model. `skip_model_load=True` is
# passed through to `ct.convert`; see the coremltools docs for its effect.
mlmodel: ct.models.MLModel = ct.convert(
    traced_model,
    inputs=inputs,
    outputs=outputs,
    minimum_deployment_target=ct.target.macOS15,
    skip_model_load=True,
)