In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="cuda")

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=50)
print(tokenizer.decode(outputs[0]))

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

<bos>Write me a poem about Machine Learning.

I’m not sure what you mean by “write me a poem about Machine Learning.”

I’m not sure what you mean by “write me a poem about Machine Learning.”

I’m not sure what you mean by “write


In [7]:
model

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm()
        (post_attention_layernorm): GemmaRMSNorm()
      )
    )
    (norm): GemmaR

In [2]:
import torch

torch.manual_seed(42)
tokens = torch.randint(0, 256000, (1, 10))
tokens = tokens.to("cuda")
output = model(tokens)
print(output.logits)
hf_output = output.logits
loss = output.logits.sum()
loss.backward()
print(model.lm_head.weight.grad)
print(model.model.embed_tokens.weight.grad)
hf_grad = model.lm_head.weight.grad

tensor([[[ -6.6963,   2.4422, -18.7727,  ...,  -7.9401,  -2.7512,  -6.6050],
         [ -7.3636,  14.5272, -23.6110,  ...,  -9.1783,  -3.4405,  -7.2611],
         [ -7.6963,  10.9584, -20.6943,  ...,  -9.9076,  -3.3036,  -7.5720],
         ...,
         [-14.3678,   3.1154, -12.6028,  ..., -11.5705, -11.9707, -14.4218],
         [-22.8693,  -5.5922, -24.4688,  ..., -22.0343, -18.6676, -22.9308],
         [-17.8728,   1.2903, -39.8689,  ..., -19.7219, -12.3213, -17.9414]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)
tensor([[ 1.4173,  1.9047, -3.1243,  ..., -0.5948, -6.0515, -5.6282],
        [ 1.4173,  1.9047, -3.1243,  ..., -0.5948, -6.0515, -5.6282],
        [ 1.4173,  1.9047, -3.1243,  ..., -0.5948, -6.0515, -5.6282],
        ...,
        [ 1.4173,  1.9047, -3.1243,  ..., -0.5948, -6.0515, -5.6282],
        [ 1.4173,  1.9047, -3.1243,  ..., -0.5948, -6.0515, -5.6282],
        [ 1.4173,  1.9047, -3.1243,  ..., -0.5948, -6.0515, -5.6282]],
       device='cuda:0')
tensor([[

In [3]:
from torchtune.models.gemma import gemma_2b
from torchtune.utils import FullModelHFCheckpointer

checkpointer = FullModelHFCheckpointer(
  checkpoint_dir="/tmp/gemma/",
  checkpoint_files=[
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
  ],
  recipe_checkpoint=None,
  output_dir="/tmp/gemma",
  model_type="GEMMA",
)
sd = checkpointer.load_checkpoint()

model = gemma_2b()
model.load_state_dict(sd['model'])
model = model.to("cuda")
output = model(tokens)
print(output)
tt_output = output
loss = output.sum()
loss.backward()
print(model.tok_embeddings.weight.grad)
print(model.output.weight.grad)
tt_grad = model.output.weight.grad

tensor([[[ -6.6963,   2.4422, -18.7727,  ...,  -7.9401,  -2.7512,  -6.6050],
         [ -7.3636,  14.5272, -23.6109,  ...,  -9.1783,  -3.4404,  -7.2610],
         [ -7.6963,  10.9585, -20.6943,  ...,  -9.9076,  -3.3036,  -7.5720],
         ...,
         [-14.3678,   3.1154, -12.6028,  ..., -11.5705, -11.9707, -14.4217],
         [-22.8693,  -5.5922, -24.4688,  ..., -22.0343, -18.6676, -22.9308],
         [-17.8728,   1.2904, -39.8689,  ..., -19.7219, -12.3213, -17.9414]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)
tensor([[ 1.4173,  1.9047, -3.1243,  ..., -0.5948, -6.0515, -5.6282],
        [ 1.4173,  1.9047, -3.1243,  ..., -0.5948, -6.0515, -5.6282],
        [ 1.4173,  1.9047, -3.1243,  ..., -0.5948, -6.0515, -5.6282],
        ...,
        [ 1.4173,  1.9047, -3.1243,  ..., -0.5948, -6.0515, -5.6282],
        [ 1.4173,  1.9047, -3.1243,  ..., -0.5948, -6.0515, -5.6282],
        [ 1.4173,  1.9047, -3.1243,  ..., -0.5948, -6.0515, -5.6282]],
       device='cuda:0')


AttributeError: 'GemmaTransformerDecoder' object has no attribute 'output'

In [4]:
torch.testing.assert_close(tt_output, hf_output)
torch.testing.assert_close(tt_grad, hf_grad)

AssertionError: Tensor-likes are not close!

Mismatched elements: 896814 / 2560000 (35.0%)
Greatest absolute difference: 0.00025177001953125 at index (0, 7, 232983) (up to 1e-05 allowed)
Greatest relative difference: 0.29729729890823364 at index (0, 1, 91390) (up to 1.3e-06 allowed)

In [9]:
tt_output[0,1,91390]

tensor(9.9182e-05, device='cuda:0', grad_fn=<SelectBackward0>)

In [10]:
hf_output[0,1,91390]

tensor(0.0001, device='cuda:0', grad_fn=<SelectBackward0>)

In [13]:
from torchtune.models.gemma import gemma_2b
from torchtune.utils import FullModelHFCheckpointer

checkpointer = FullModelHFCheckpointer(
  checkpoint_dir="/tmp/gemma/",
  checkpoint_files=[
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
  ],
  recipe_checkpoint=None,
  output_dir="/tmp/gemma",
  model_type="GEMMA",
)
sd = checkpointer.load_checkpoint()

model = gemma_2b()
model.load_state_dict(sd['model'])
model = model.to("cuda")
output = model(tokens)
print(output)
loss = output.sum()
loss.backward()
print(model.tok_embeddings.weight.grad)

tensor([[[ -6.6963,   2.4422, -18.7727,  ...,  -7.9401,  -2.7512,  -6.6050],
         [ -7.3636,  14.5272, -23.6109,  ...,  -9.1783,  -3.4404,  -7.2610],
         [ -7.6963,  10.9585, -20.6943,  ...,  -9.9076,  -3.3036,  -7.5720],
         ...,
         [-14.3678,   3.1154, -12.6028,  ..., -11.5705, -11.9707, -14.4217],
         [-22.8693,  -5.5922, -24.4688,  ..., -22.0343, -18.6676, -22.9308],
         [-17.8728,   1.2904, -39.8689,  ..., -19.7219, -12.3213, -17.9414]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')


In [12]:
import sys
sys.path.append("../gemma_pytorch")
from gemma.model import GemmaForCausalLM
from gemma.config import get_config_for_2b

model = GemmaForCausalLM(get_config_for_2b())
model

AssertionError: tokenizer/tokenizer.model