# GPT-J-6B Inference Demo

<a href="http://colab.research.google.com/github/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook demonstrates how to run the [GPT-J-6B model](https://github.com/kingoflolz/mesh-transformer-jax/#GPT-J-6B). See the link for more details about the model, including evaluation metrics and credits.

## Install Dependencies

First we download the model and install some dependencies. This step takes at least 5 minutes (possibly longer depending on server load).

!!! **Make sure you are using a TPU runtime!** !!!
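
A quick way to confirm the runtime before starting the long download is to look for the `COLAB_TPU_ADDR` environment variable, which the setup cell later in this notebook reads. The sketch below assumes that this variable is set only on Colab TPU runtimes; the helper name `has_colab_tpu` is illustrative and not part of any library.

```python
import os


def has_colab_tpu(env=None):
    """Return True if the environment advertises a Colab TPU address."""
    env = os.environ if env is None else env
    # Assumption: Colab sets COLAB_TPU_ADDR (host:port) only on TPU runtimes.
    return bool(env.get("COLAB_TPU_ADDR", "").strip())


if not has_colab_tpu():
    print("No TPU address found; switch the runtime type to TPU before continuing.")
```

Passing a plain dictionary as `env` makes the check easy to try without touching the real environment.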

In [None]:
!apt install zstd

# the "slim" version contains only bf16 weights and no optimizer parameters, which minimizes bandwidth and memory
!time wget -c https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd

!time tar -I zstd -xf step_383500_slim.tar.zstd

!git clone https://github.com/kingoflolz/mesh-transformer-jax.git
!pip install -r mesh-transformer-jax/requirements.txt

# jax 0.2.12 is required due to a regression with xmap in 0.2.13
!pip install mesh-transformer-jax/ jax==0.2.12 tensorflow==2.5.0

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following NEW packages will be installed:
  zstd
0 upgraded, 1 newly installed, 0 to remove and 2 not upgraded.
Need to get 603 kB of archives.
After this operation, 1,695 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/main amd64 zstd amd64 1.4.8+dfsg-3build1 [603 kB]
Fetched 603 kB in 0s (2,217 kB/s)
Selecting previously unselected package zstd.
(Reading database ... 121954 files and directories currently installed.)
Preparing to unpack .../zstd_1.4.8+dfsg-3build1_amd64.deb ...
Unpacking zstd (1.4.8+dfsg-3build1) ...
Setting up zstd (1.4.8+dfsg-3build1) ...
Processing triggers for man-db (2.10.2-1) ...
--2025-06-16 18:12:05--  https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd
Resolving the-eye.eu (the-eye.eu)... 162.213.130.250
Connecting to the-eye.eu (the-eye.eu)|162.213.130.250|:443... connected.
HTTP request sent, awaiting respo

# New Section

In [1]:
!pip install transformers torch accelerate bitsandbytes



Collecting bitsandbytes
  Downloading bitsandbytes-0.46.0-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.me

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import os

# Specify the model name
model_name = "EleutherAI/gpt-j-6B"

print(f"Loading the {model_name} model and tokenizer...")

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Pick the device to load the model on
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"The model will be loaded on {device}.")

# Quantization configuration (8-bit)
# load_in_8bit=True tries to load the model in 8-bit format
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

try:
    # Try to load the model with 8-bit quantization.
    # If this fails, it usually means your machine does not support this feature
    # or there is a dependency problem.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        device_map="auto"  # Automatically distribute across CPU and/or GPU
    )
    print("Model loaded with 8-bit quantization.")
except Exception as e:
    print(f"8-bit quantized loading failed ({e}). Trying full precision (float32)...")
    # If 8-bit fails, try loading in regular float32 precision.
    # This can also fail if there is still not enough memory.
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32).to(device)
    print("Model loaded in float32 precision.")

print(f"The {model_name} model and tokenizer were loaded successfully.")

# --- Text Generation Section (this part can stay the same) ---
print("\n--- Starting Text Generation ---")
prompt = "How will artificial intelligence affect our lives in the future?"
print(f"Input: {prompt}")

inputs = tokenizer(prompt, return_tensors="pt").to(device)

try:
    generated_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],  # pass the mask explicitly to avoid ambiguity about padding
        max_new_tokens=30,  # Keep this around 20-30 again
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95
    )

    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    print("\n--- Generated Text ---")
    print(generated_text)
    print("----------------------")

except Exception as e:
    print(f"\nAn error occurred during text generation: {e}")
    print("This is usually caused by insufficient memory (RAM or GPU memory).")
    print("Possible fixes:")
    print("1. Reduce 'max_new_tokens'.")
    print("2. If you are using a GPU, make sure your drivers are up to date.")
    print("3. Try quantization (8-bit or 4-bit); the current code already attempts this.")
    print(f"Current device: {device}. Memory used by the model: {model.get_memory_footprint() / (1024**3):.2f} GB")
    if device == "cuda":
        print(f"CUDA memory allocated: {torch.cuda.memory_allocated() / (1024**3):.2f} GB")
        print(f"CUDA memory reserved: {torch.cuda.memory_reserved() / (1024**3):.2f} GB")
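
The reason the cell above tries 8-bit loading first comes down to simple arithmetic: weight memory is roughly the parameter count times the bytes per parameter. The sketch below assumes a round figure of 6 billion parameters (the exact count for GPT-J-6B is somewhat higher) and ignores activations, the attention cache, and any quantization overhead. `weight_memory_gib` is an illustrative helper, not a library function.

```python
# Bytes needed to store one parameter in each format.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "int8": 1}


def weight_memory_gib(n_params, dtype):
    """Approximate memory for the weights alone, in GiB."""
    return n_params * BYTES_PER_PARAM[dtype] / (1024 ** 3)


N_PARAMS = 6_000_000_000  # assumption: rounded parameter count

for dtype in BYTES_PER_PARAM:
    print(f"{dtype:>8}: {weight_memory_gib(N_PARAMS, dtype):.1f} GiB")
```

Under these assumptions, float32 weights need about 22.4 GiB, bfloat16 about 11.2 GiB, and int8 about 5.6 GiB, which is why the float32 fallback can still run out of memory on a typical Colab GPU.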

## Setup Model


In [None]:
import os
import requests
from jax.config import config

colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1_dev20210607'
requests.post(url)

# The following is required to use TPU Driver as JAX's backend.
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']

Sometimes the next step fails for no apparent reason. If it does, just run it again ¯\\\_(ツ)\_/¯

In [None]:
import time

import jax
from jax.experimental import maps
import numpy as np
import optax
import transformers

from mesh_transformer.checkpoint import read_ckpt_lowmem
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer

In [None]:
params = {
  "layers": 28,
  "d_model": 4096,
  "n_heads": 16,
  "n_vocab": 50400,
  "norm": "layernorm",
  "pe": "rotary",
  "pe_rotary_dims": 64,

  "seq": 2048,
  "cores_per_replica": 8,
  "per_replica_batch": 1,
}

per_replica_batch = params["per_replica_batch"]
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]


params["sampler"] = nucleaus_sample

# here we "remove" the optimizer parameters from the model (as we don't need them for inference)
params["optimizer"] = optax.scale(0)

mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)

maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))

tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

Here we create the network and load the parameters from the downloaded files. Expect this to take around 5 minutes.

In [None]:
total_batch = per_replica_batch * jax.device_count() // cores_per_replica

network = CausalTransformer(params)

network.state = read_ckpt_lowmem(network.state, "step_383500/", devices.shape[1])

network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
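
To make the batch arithmetic in the cells above concrete, here is a small sketch in plain NumPy that mirrors how `mesh_shape` and `total_batch` are derived. It assumes 8 TPU cores, matching `cores_per_replica` in the configuration, and uses integers as stand-ins for the device objects so that it runs without a TPU. `mesh_layout` is a hypothetical helper name.

```python
import numpy as np


def mesh_layout(device_count, cores_per_replica, per_replica_batch):
    """Return (mesh_shape, total_batch) for a data-parallel x model-parallel mesh."""
    if device_count % cores_per_replica != 0:
        raise ValueError("device_count must be a multiple of cores_per_replica")
    # First axis: number of replicas ('dp'); second axis: cores per replica ('mp').
    mesh_shape = (device_count // cores_per_replica, cores_per_replica)
    total_batch = per_replica_batch * device_count // cores_per_replica
    return mesh_shape, total_batch


mesh_shape, total_batch = mesh_layout(device_count=8, cores_per_replica=8, per_replica_batch=1)
devices = np.arange(8).reshape(mesh_shape)  # integer stand-ins for jax.devices()
print(mesh_shape, total_batch, devices.shape)
```

With 8 cores and 8 cores per replica there is a single replica, so the whole model is sharded across all cores and one prompt is processed at a time.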

## Run Model

Finally, we are ready to run inference with the model! The first sample takes around a minute because of compilation; after that, each sample should take only about 10 seconds.

Feel free to experiment with the sampling parameters (top_p and temp) and with the length of the generations (gen_len; changing it triggers a recompile).
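
As a rough illustration of what `top_p` and `temp` control, the NumPy sketch below applies temperature scaling and then nucleus (top-p) filtering to a vector of logits. It is a simplified stand-in written for this explanation, not the `nucleaus_sample` implementation from `mesh_transformer.sampling`; the function names and the toy logits are assumptions, and the sketch requires `temp` to be greater than zero.

```python
import numpy as np


def top_p_filter(logits, top_p=0.9, temp=1.0):
    """Return renormalized probabilities after temperature and nucleus filtering."""
    if temp <= 0:
        raise ValueError("temp must be positive in this sketch")
    scaled = np.asarray(logits, dtype=np.float64) / temp
    probs = np.exp(scaled - scaled.max())  # numerically stable softmax
    probs /= probs.sum()

    order = np.argsort(-probs)  # token ids, most likely first
    cumulative = np.cumsum(probs[order])
    # Keep the smallest prefix whose cumulative probability reaches top_p.
    cutoff = int(np.searchsorted(cumulative, top_p)) + 1
    keep = order[:cutoff]

    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    return filtered / filtered.sum()


def top_p_sample(logits, top_p=0.9, temp=1.0, rng=None):
    """Draw one token id from the filtered distribution."""
    rng = np.random.default_rng() if rng is None else rng
    probs = top_p_filter(logits, top_p, temp)
    return int(rng.choice(len(probs), p=probs))


toy_logits = np.log([0.5, 0.3, 0.15, 0.05])  # made-up four-token vocabulary
print(top_p_filter(toy_logits, top_p=0.7))   # only the two most likely tokens survive
```

Lowering `top_p` shrinks the set of candidate tokens, while lowering `temp` concentrates probability on the most likely ones.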

You can also change other things like per_replica_batch in the previous cells to change how many generations are done in parallel. A larger batch has higher latency but higher throughput when measured in tokens generated/s. This is useful for doing things like best-of-n cherry picking.
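
Best-of-n selection can be sketched in a few lines: generate several candidates, score each one, and keep the highest-scoring result. The scoring function below (the share of distinct words) is only a placeholder assumption chosen to keep the example self-contained; in practice you would substitute whatever quality measure suits your task. Both helper names are hypothetical.

```python
def distinct_word_ratio(text):
    """Placeholder score: fraction of words in the text that are unique."""
    words = text.split()
    return len(set(words)) / len(words) if words else 0.0


def best_of_n(samples, score_fn=distinct_word_ratio):
    """Return the sample with the highest score."""
    if not samples:
        raise ValueError("samples must not be empty")
    return max(samples, key=score_fn)


# Stand-ins for the list of generations returned for a batch.
candidates = ["the the the the", "a quick brown fox", "so so good"]
print(best_of_n(candidates))
```

The same helper could be applied to the list of samples produced for a batch, since each batch entry is an independent generation for the same prompt.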

*Tip for best results: Make sure your prompt does not have any trailing spaces, which tend to confuse the model due to the BPE tokenization used during training.*

In [None]:
# allow text wrapping in generated output: https://stackoverflow.com/a/61401455
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

In [None]:
def infer(context, top_p=0.9, temp=1.0, gen_len=512):
    tokens = tokenizer.encode(context)

    provided_ctx = len(tokens)
    pad_amount = seq - provided_ctx

    padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
    batched_tokens = np.array([padded_tokens] * total_batch)
    length = np.ones(total_batch, dtype=np.uint32) * len(tokens)

    start = time.time()
    output = network.generate(batched_tokens, length, gen_len, {"top_p": np.ones(total_batch) * top_p, "temp": np.ones(total_batch) * temp})

    samples = []
    decoded_tokens = output[1][0]

    for o in decoded_tokens[:, :, 0]:
      samples.append(f"\033[1m{context}\033[0m{tokenizer.decode(o)}")

    print(f"completion done in {time.time() - start:.2f}s")
    return samples

print(infer("EleutherAI is")[0])
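
The padding step in `infer` is easy to misread, so here it is in isolation. The sketch assumes a toy sequence length of 8 instead of 2048 and made-up token ids; `left_pad` is an illustrative helper, not part of mesh-transformer-jax. It also adds an explicit length check, which the cell above does not have.

```python
import numpy as np


def left_pad(tokens, seq_len):
    """Pad a token list with zeros at the front so it is exactly seq_len long."""
    if len(tokens) > seq_len:
        raise ValueError("prompt is longer than the model's sequence length")
    pad_amount = seq_len - len(tokens)
    # ((pad_amount, 0),) pads only before the data; np.pad fills with zeros by default.
    return np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)


padded = left_pad([15496, 995, 318], seq_len=8)  # made-up token ids
print(padded.tolist())
```

The prompt tokens end up at the right edge of the fixed-length buffer, and the separate `length` array records how many of them are real.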

In [None]:
#@title  { form-width: "300px" }
top_p = 0.9 #@param {type:"slider", min:0, max:1, step:0.1}
temp = 1 #@param {type:"slider", min:0, max:1, step:0.1}

context = """In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."""

print(infer(top_p=top_p, temp=temp, gen_len=512, context=context)[0])