In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from retentionengine import RetentionEngine
from thelethe.titans import AtlasConfig

In [2]:
# Load Gemma3 4b model and tokenizer
basemodel_name = "google/gemma-3-4b-it"
tokenizer = AutoTokenizer.from_pretrained(basemodel_name)
basemodel = AutoModelForCausalLM.from_pretrained(basemodel_name)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
# Load adapter config
adapter_name = "retentionlabs/atlas-g-4b"
config = AtlasConfig.from_pretrained(adapter_name)

In [4]:
# Convert the transformer to a retentive model
model = RetentionEngine(basemodel, config)
model.to("cuda")

RetentionEngine(
  (module): PreTrainedTitansModel(
    (model): Gemma3ForConditionalGeneration(
      (vision_tower): SiglipVisionModel(
        (vision_model): SiglipVisionTransformer(
          (embeddings): SiglipVisionEmbeddings(
            (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
            (position_embedding): Embedding(4096, 1152)
          )
          (encoder): SiglipEncoder(
            (layers): ModuleList(
              (0-26): 27 x SiglipEncoderLayer(
                (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
                (self_attn): SiglipAttention(
                  (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
                  (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
                  (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
                  (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
             

In [5]:
# Adapt the memory adapter to the model
model.adapt()

In [6]:
# Save the full model
import os
os.mkdir("pretrained")
model.save_pretrained("pretrained/gemma-3-4b-retentive")