In [18]:
import triton
import torch
import triton.language as tl
import math



# Streams K/V tiles for one query block and folds them into the running
# online-softmax state (O, L, M); returns the updated state.
#
# Arguments, as they are built by gqa_kernel in this file:
#   Q_data      : query tile, (group_size, BLOCK_SIZE_OQ, head_dim)
#   O, L, M     : running un-normalised output, row sums and row maxima
#   K_ptr/V_ptr : base pointers already advanced to the batch and kv head
#   k_block_ptr : offsets of a transposed K tile, (head_dim, BLOCK_SIZE_KV)
#   v_block_ptr : offsets of a V tile, (BLOCK_SIZE_KV, head_dim)
#   scale       : score multiplier; the caller multiplies it by log2(e)
#                 because the softmax below is evaluated with exp2
#   DIAGONAL    : False -> keys [0, q_block_idx*BLOCK_SIZE_OQ), no causal mask
#                 True  -> keys [q_block_idx*BLOCK_SIZE_OQ, (q_block_idx+1)*BLOCK_SIZE_OQ)
#                          with the causal mask applied
#
# NOTE(review): the causal mask compares qo_seq_len_offsets with
# kv_seq_len_offsets directly, so query i only sees keys 0..i. That looks like
# it assumes queries and keys are aligned at position 0 (q_seq_len ==
# kv_seq_len) -- confirm before calling with a longer cached K/V.
# NOTE(review): out-of-range keys are loaded as 0.0, so their score is 0 rather
# than -inf; only the causal mask in the DIAGONAL pass removes them -- verify
# this is sufficient for every shape the kernel is launched with.
@triton.jit
def __attn_fwd_inner(
    Q_data,O,L,M,K_ptr,V_ptr,k_block_ptr,v_block_ptr,q_block_idx,scale,k_stride_s,v_stride_s,BLOCK_SIZE_OQ:tl.constexpr,BLOCK_SIZE_KV:tl.constexpr,
    DIAGONAL: tl.constexpr, qo_seq_len_offsets,kv_seq_len_offsets,kv_seq_len:tl.constexpr,head_dim:tl.constexpr,group_size:tl.constexpr

):
    # Select the key range this pass is responsible for.
    if DIAGONAL:
        lo = q_block_idx * BLOCK_SIZE_OQ
        hi = (q_block_idx+1) * BLOCK_SIZE_OQ
    else:
        lo = 0
        hi = q_block_idx * BLOCK_SIZE_OQ
    # Move the K/V tile offsets and the key indices to the start of that range.
    k_block_ptr+= lo * k_stride_s
    v_block_ptr += lo * v_stride_s
    kv_seq_len_offsets+= lo

    for start_kv in range(lo,hi,BLOCK_SIZE_KV):
        # Keys past kv_seq_len are replaced by 0.0 on load.
        kv_seq_len_mask = kv_seq_len_offsets < kv_seq_len
        K_data_T = tl.load(K_ptr+k_block_ptr,mask=kv_seq_len_mask[None,:],other=0.0) #head_dim,BLOCK_SIZE_KV
        V_data = tl.load(V_ptr+v_block_ptr,mask= kv_seq_len_mask[:,None],other=0.0).to(tl.float32) #BLOCK_SIZE_KV,head_dim

        # The same K/V tile is shared by every query head of the group.
        K_data_T_expanded = tl.broadcast_to(K_data_T[None, :, :], (group_size, head_dim, BLOCK_SIZE_KV))
        V_data_expanded = tl.broadcast_to(V_data[None, :, :], (group_size, BLOCK_SIZE_KV, head_dim)).to(tl.float32)


        S = tl.dot(Q_data,K_data_T_expanded) * scale #group_size,BLOCK_SIZE_OQ,BLOCK_SIZE_KV
        if DIAGONAL:
           # A query may only attend to keys whose index is <= its own index.
           causal_mask = qo_seq_len_offsets[:,None]>=kv_seq_len_offsets[None,:]
           S= tl.where(causal_mask,S,-float('inf')) #group_size,BLOCK_SIZE_OQ,BLOCK_SIZE_KV; the 2-D mask broadcasts over the group_size dim
        # Online softmax: update the running row maximum, then rescale the
        # previously accumulated sum and output to the new maximum.
        new_max = tl.maximum(M,tl.max(S,axis=-1)) #group_size,BLOCK_SIZE_OQ
        S = S - new_max[:,:,None]
        P = tl.exp2(S)
        L_new = tl.sum(P,axis=-1) #group_size,BLOCK_SIZE_OQ

        correction_factor = tl.exp2(M - new_max) #group_size,BLOCK_SIZE_OQ
        L = L * correction_factor + L_new
        O = O * correction_factor[:,:,None]

        O = tl.dot(P,V_data_expanded,acc=O) # (group_size,BLOCK_SIZE_OQ,BLOCK_SIZE_KV) x (group_size,BLOCK_SIZE_KV,head_dim), accumulated into O

        M = new_max

        # Advance to the next K/V tile.
        k_block_ptr+= BLOCK_SIZE_KV * k_stride_s
        v_block_ptr += BLOCK_SIZE_KV * v_stride_s
        kv_seq_len_offsets+= BLOCK_SIZE_KV

    # O is still un-normalised here; the caller divides by L.
    return O,L,M





# Fused causal grouped-query attention forward kernel.
#
# Launch grid: axis 0 indexes blocks of BLOCK_SIZE_OQ query positions, axis 1
# indexes (batch, kv_head) pairs. Each program handles all `group_size` query
# heads that share one kv head, so every K/V tile is loaded once per group.
# Writes the normalised attention output to O_ptr (cast to bfloat16) and the
# base-2 log-sum-exp of the scaled scores to LSE_ptr.
# `batch_size` and `num_heads` are accepted but not referenced in the body.
@triton.autotune(
    configs=[

        triton.Config(
            kwargs={"BLOCK_SIZE_OQ": 16, "BLOCK_SIZE_KV": 16},
        ),
    ],
    key=["group_size", "head_dim"],
)
@triton.jit
def gqa_kernel(
    Q_ptr, K_ptr, V_ptr, O_ptr, LSE_ptr,
    q_stride_bs, q_stride_h, q_stride_s, q_stride_d,
    k_stride_bs, k_stride_h, k_stride_s, k_stride_d,
    v_stride_bs, v_stride_h, v_stride_s, v_stride_d,
    o_stride_bs, o_stride_h, o_stride_s, o_stride_d,
    lse_stride_bs, lse_stride_h, lse_stride_s,
    batch_size, q_seq_len, kv_seq_len,
    scale: tl.constexpr,
    group_size: tl.constexpr,
    num_heads: tl.constexpr,
    num_kv_heads: tl.constexpr,
    head_dim: tl.constexpr,
    BLOCK_SIZE_OQ: tl.constexpr,
    BLOCK_SIZE_KV: tl.constexpr,
):
    # log2(e): folded into the scale so the softmax can be evaluated with exp2.
    rln2: tl.constexpr = 1.4426950408889634
    scale *= rln2
    q_block_idx = tl.program_id(0)
    bs_nkv_heads = tl.program_id(1)
    batch_index = bs_nkv_heads // num_kv_heads
    head_index  = bs_nkv_heads % num_kv_heads
    # K/V pointers are moved to the right batch and kv head; Q/O pointers are
    # moved to the batch only, their head offset is part of the tile offsets.
    Q_ptr += batch_index * q_stride_bs
    K_ptr += batch_index * k_stride_bs + head_index * k_stride_h
    V_ptr += batch_index * v_stride_bs + head_index * v_stride_h
    O_ptr += batch_index * o_stride_bs
    # Query heads served by this kv head: group_size consecutive heads.
    qo_heads_offsets = group_size * head_index + tl.arange(0,group_size)
    qo_seq_len_offsets = q_block_idx * BLOCK_SIZE_OQ + tl.arange(0,BLOCK_SIZE_OQ)
    qo_seq_len_mask = qo_seq_len_offsets < q_seq_len # (BLOCK_SIZE_OQ,); expanded with [None,:,None] where it is used
    head_dim_offsets = tl.arange(0,head_dim)
    qo_block_ptr  = (qo_heads_offsets[:,None,None] * q_stride_h) + (qo_seq_len_offsets[None,:,None] * q_stride_s) + (head_dim_offsets[None,None,:]* q_stride_d)
    # qo_block_ptr / q_data: group_size,BLOCK_SIZE_OQ,head_dim (rows past q_seq_len load as 0.0)
    q_data = tl.load(Q_ptr+qo_block_ptr,mask=qo_seq_len_mask[None,:,None],other=0.0)
    kv_seq_len_offsets = tl.arange(0,BLOCK_SIZE_KV)
    # K is addressed transposed (head_dim,BLOCK_SIZE_KV); V as (BLOCK_SIZE_KV,head_dim).
    k_block_ptr = head_dim_offsets[:,None] * k_stride_d + kv_seq_len_offsets[None,:] * k_stride_s
    v_block_ptr  =  kv_seq_len_offsets[:,None] * v_stride_s + head_dim_offsets[None,:] * v_stride_d
    # Online-softmax state: running max, running sum, un-normalised output.
    M = tl.full(shape=[group_size,BLOCK_SIZE_OQ],value = float("-inf"),dtype=tl.float32)
    L = tl.full(shape=[group_size,BLOCK_SIZE_OQ],value = float(0),dtype=tl.float32)
    O = tl.zeros([group_size,BLOCK_SIZE_OQ,head_dim],dtype=tl.float32)

    # Pass 1 (DIAGONAL=False): key blocks strictly before this query block, no masking.
    O,L,M = __attn_fwd_inner(
        q_data,O,L,M,K_ptr,V_ptr,k_block_ptr,v_block_ptr,q_block_idx,scale,k_stride_s,v_stride_s,BLOCK_SIZE_OQ,BLOCK_SIZE_KV,
        False,qo_seq_len_offsets,kv_seq_len_offsets,kv_seq_len,head_dim,group_size
    )
    # Pass 2 (DIAGONAL=True): the key range overlapping this query block, causally masked.
    O,L,M = __attn_fwd_inner(
        q_data,O,L,M,K_ptr,V_ptr,k_block_ptr,v_block_ptr,q_block_idx,scale,k_stride_s,v_stride_s,BLOCK_SIZE_OQ,BLOCK_SIZE_KV,
        True,qo_seq_len_offsets,kv_seq_len_offsets,kv_seq_len,head_dim,group_size
    )

    # Final softmax normalisation.
    O = O/L[:,:,None]
    # Base-2 log-sum-exp of the scaled scores (M and L are in exp2 units).
    LSE = M + tl.math.log2(L) #GROUP_SIZE,BLOCK_SIZE_OQ

    # NOTE(review): lse_stride_s is not applied to qo_seq_len_offsets, which
    # presumably relies on the LSE tensor being contiguous in its last dim -- confirm.
    LSE_offsets = batch_index*lse_stride_bs + qo_heads_offsets[:,None] * lse_stride_h + qo_seq_len_offsets
    tl.store(LSE_ptr+LSE_offsets,LSE,mask = qo_seq_len_mask[None,:])

    o_block_ptr  = (qo_heads_offsets[:,None,None] * o_stride_h) + (qo_seq_len_offsets[None,:,None] * o_stride_s) + (head_dim_offsets[None,None,:]* o_stride_d)
    # The output is cast to bfloat16 before the store; rows past q_seq_len are skipped.
    tl.store(O_ptr+o_block_ptr,O.to(tl.bfloat16),mask=qo_seq_len_mask[None,:,None])




class GQA_Torch(torch.autograd.Function):
    """
    Forward-only autograd wrapper around the fused Triton ``gqa_kernel``.

    The kernel applies a causal mask that compares query and key indices
    starting from 0 on both sides, so it is only correct when the query and
    key/value sequences have the same length (no previously cached keys).
    """

    @staticmethod
    def forward(ctx, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
        """
        Launch the fused causal grouped-query attention kernel.

        Args:
            Q: queries, shape (batch, num_heads, seq_len, head_dim).
            K: keys, shape (batch, num_kv_heads, seq_len, head_dim).
            V: values, same shape as ``K``.

        Returns:
            Attention output with the same shape, dtype and device as ``Q``.

        Raises:
            AssertionError: if the shapes/devices are inconsistent, head_dim
                exceeds 128, num_heads is not a multiple of num_kv_heads, or the
                query and key sequence lengths differ.
        """
        batch_size, num_heads, q_seq_len, head_dim = Q.shape
        _, num_kv_heads, kv_seq_len, _ = K.shape
        assert Q.shape[-1] == K.shape[-1] == V.shape[-1]
        assert Q.shape[-1] <= 128
        assert Q.device == K.device == V.device
        assert K.shape == V.shape
        assert Q.shape[0] == K.shape[0], "Q and K/V must share the batch dimension"
        assert num_heads % num_kv_heads == 0
        # The kernel aligns the causal mask at position 0 of both sequences; with
        # extra (cached) keys it would silently let query i see only keys 0..i.
        assert q_seq_len == kv_seq_len, (
            "gqa_kernel only supports q_seq_len == kv_seq_len "
            f"(got q_seq_len={q_seq_len}, kv_seq_len={kv_seq_len})"
        )
        O = torch.empty_like(Q)
        # Written by the kernel but currently neither returned nor saved.
        # NOTE(review): bfloat16 is likely too coarse if a backward pass ever
        # consumes this buffer.
        LSE = torch.empty(
            batch_size,
            num_heads,
            q_seq_len,
            device=Q.device,
            dtype=torch.bfloat16,
        )
        # One program per (query block, batch * kv head).
        grid = lambda args: (
            triton.cdiv(q_seq_len, args["BLOCK_SIZE_OQ"]),
            batch_size * num_kv_heads,
        )
        scale = 1 / math.sqrt(head_dim)
        group_size = num_heads // num_kv_heads
        gqa_kernel[grid](
            Q, K, V, O, LSE,
            Q.stride(0), Q.stride(1), Q.stride(2), Q.stride(3),
            K.stride(0), K.stride(1), K.stride(2), K.stride(3),
            V.stride(0), V.stride(1), V.stride(2), V.stride(3),
            O.stride(0), O.stride(1), O.stride(2), O.stride(3),
            LSE.stride(0), LSE.stride(1), LSE.stride(2),
            batch_size,
            q_seq_len,
            kv_seq_len,
            scale,
            group_size,
            num_heads,
            num_kv_heads,
            head_dim,
        )
        return O

    @staticmethod
    def backward(ctx, dldO):
        """No backward kernel exists; fail loudly instead of returning ``None``."""
        raise NotImplementedError(
            "GQA_Torch is forward-only: no backward kernel is implemented"
        )


In [19]:


import torch
import torch.nn as nn

class FFN(nn.Module):
    """Bias-free gated feed-forward block: ``fc3(silu(fc1(x)) * fc2(x))``.

    ``cfg`` must provide ``embed_dim`` (input/output width), ``hidden_dim``
    (intermediate width) and ``dtype`` (parameter dtype).
    """

    def __init__(self,cfg):
        super().__init__()
        embed_dim = cfg['embed_dim']
        hidden_dim = cfg['hidden_dim']
        param_dtype = cfg['dtype']
        # fc1 is the gate branch, fc2 the linear branch, fc3 projects back down.
        self.fc1 = nn.Linear(embed_dim,hidden_dim,bias=False,dtype=param_dtype)
        self.fc2 = nn.Linear(embed_dim,hidden_dim,bias=False,dtype=param_dtype)
        self.fc3 = nn.Linear(hidden_dim,embed_dim,bias=False,dtype=param_dtype)

    def forward(self,x):
        """Apply the gated projection to the last dimension of ``x``."""
        gate = nn.functional.silu(self.fc1(x))
        gated_hidden = gate * self.fc2(x)
        return self.fc3(gated_hidden)

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Args:
        emb_dim: size of the normalised (last) dimension.
        eps: constant added to the mean square before the reciprocal sqrt.
        bias: when True, a learnable additive ``shift`` is applied after scaling.
        qwen_3_compaitable: when True, the input is upcast to float32 for the
            computation; the result is always cast back to the input dtype.
    """

    def __init__(self,emb_dim,eps=1e-6,bias=False,qwen_3_compaitable=True):
        super().__init__()
        self.eps = eps
        self.qwen_3_compaitable = qwen_3_compaitable
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None

    def forward(self,x):
        """Return ``scale * x / sqrt(mean(x**2) + eps)`` (+ ``shift``) in ``x``'s dtype."""
        input_dtype = x.dtype
        if self.qwen_3_compaitable:
            x = x.to(torch.float32)
        variance = x.pow(2).mean(dim=-1,keepdim=True) # (..., 1)
        norm_x = x * torch.rsqrt(variance+self.eps)
        norm_x = self.scale*norm_x
        # Compare against None: truth-testing a multi-element Parameter raises
        # "Boolean value of Tensor with more than one value is ambiguous".
        if self.shift is not None:
            norm_x = norm_x + self.shift

        return norm_x.to(input_dtype)


def compute_rope_params(head_dim,theta_base=10000,context_length=4096,dtype=torch.float32):
    """Precompute the rotary-embedding cosine and sine tables.

    Args:
        head_dim: per-head feature size; must be even.
        theta_base: base of the geometric frequency progression.
        context_length: number of positions to tabulate.
        dtype: dtype used for the frequency and position tensors.

    Returns:
        ``(cos, sin)``, each of shape (context_length, head_dim). The second half
        of the last dimension repeats the first half.
    """
    assert head_dim%2==0 ,'embedding dim is not even'
    # One inverse frequency per pair of features: shape (head_dim // 2,)
    exponents = torch.arange(0, head_dim, 2, dtype=dtype) / head_dim
    inv_freq = 1.0 / (theta_base ** exponents)

    # Outer product position x frequency: shape (context_length, head_dim // 2)
    positions = torch.arange(context_length,dtype=dtype)
    half_angles = positions[:, None] * inv_freq[None, :]

    # Duplicate so the table covers the full head_dim.
    angles = torch.cat((half_angles,half_angles),dim=-1)

    return torch.cos(angles),torch.sin(angles)

def apply_rope(x,cos,sin,offset=0):
    """Rotate ``x`` with the precomputed rotary tables.

    Args:
        x: tensor of shape (batch, n_heads, seq_len, head_dim), head_dim even.
        cos, sin: tables indexed as (position, head_dim).
        offset: position of the first token of ``x`` inside the tables.

    Returns:
        ``x * cos + rotate_half(x) * sin`` cast back to ``x``'s dtype.
    """
    _, _, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0
    half = head_dim // 2
    first_half, second_half = x[..., :half], x[..., half:]

    # Rows of the tables for the positions covered by x, broadcast over
    # the batch and head dimensions: (1, 1, seq_len, head_dim)
    window = slice(offset, offset + seq_len)
    cos_rows = cos[window, :][None, None]
    sin_rows = sin[window, :][None, None]

    swapped = torch.cat((-second_half, first_half), dim=-1)
    return (x * cos_rows + (swapped * sin_rows)).to(dtype=x.dtype)

class GQA(nn.Module):
    """Grouped-query attention layer with optional QK RMSNorm and a KV cache.

    ``num_heads`` query heads share ``num_kv_groups`` key/value heads;
    ``group_size = num_heads // num_kv_groups`` consecutive query heads use the
    same key/value head.

    Args:
        d_in: input (and output) feature size.
        num_heads: number of query heads.
        num_kv_groups: number of key/value heads; must divide ``num_heads``.
        head_dim: per-head size; defaults to ``d_in // num_heads``.
        qk_norm: when True, queries and keys are RMS-normalised per head.
        dtype: dtype of the projection weights.
    """

    def __init__(self,d_in,num_heads,num_kv_groups,head_dim=None,qk_norm=False,dtype=None):
        super().__init__()
        assert num_heads % num_kv_groups == 0
        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups
        if head_dim is  None:
            assert d_in % num_heads == 0
            head_dim = d_in // num_heads
        self.head_dim = head_dim
        self.d_out = num_heads * head_dim
        self.W_query = nn.Linear(d_in,self.d_out,bias=False,dtype=dtype)
        self.W_key = nn.Linear(d_in,num_kv_groups*head_dim,bias=False,dtype=dtype)
        self.W_value = nn.Linear(d_in,num_kv_groups*head_dim,bias=False,dtype=dtype)

        self.out_proj = nn.Linear(self.d_out,d_in,bias=False,dtype=dtype)

        if qk_norm:
            self.q_norm = RMSNorm(head_dim,eps=1e-6)
            self.k_norm = RMSNorm(head_dim,eps=1e-6)
        else:
            self.q_norm = None
            self.k_norm = None

    def _masked_attention(self,queries,keys,values,mask):
        """Reference PyTorch attention used when keys outnumber queries.

        ``mask`` is True where a key must be hidden and must broadcast to
        (batch, heads, q_len, kv_len). When it is None, a causal mask is built in
        which the last query is aligned with the last key.
        """
        q_len, kv_len = queries.shape[2], keys.shape[2]
        if mask is None:
            query_pos = torch.arange(kv_len - q_len, kv_len, device=queries.device)
            key_pos = torch.arange(kv_len, device=queries.device)
            mask = key_pos[None, :] > query_pos[:, None]
        # Repeat each kv head for the group_size consecutive query heads it serves.
        keys = keys.repeat_interleave(self.group_size,dim=1)
        values = values.repeat_interleave(self.group_size,dim=1)
        # Scores and softmax in float32 for numerical stability.
        scores = torch.matmul(queries,keys.transpose(2,3)).float() * (self.head_dim ** -0.5)
        scores = scores.masked_fill(mask,float("-inf"))
        weights = torch.softmax(scores,dim=-1).to(values.dtype)
        return torch.matmul(weights,values)

    def forward(self,x,mask,cos,sin,start_pos=0,cache=None):
        """Attend over ``x`` (batch, seq_len, d_in).

        Args:
            x: input activations.
            mask: boolean mask (True = hidden); only used on the cached path,
                because the fused kernel applies its own causal mask.
            cos, sin: rotary tables passed to ``apply_rope``.
            start_pos: position of the first token of ``x`` in the tables.
            cache: optional ``(keys, values)`` from earlier calls.

        Returns:
            ``(output, next_cache)`` where ``output`` is (batch, seq_len, d_in)
            and ``next_cache`` holds all keys/values seen so far.
        """
        batch_size,seq_len,_ = x.shape
        queries  = self.W_query(x)  #-> bs,seqlen,numheads*head_dim
        keys_new  = self.W_key(x)   #-> bs,seqlen,num_kv_groups*head_dim
        values_new  = self.W_value(x) #-> bs,seqlen,num_kv_groups*head_dim

        # Split the feature dim into heads: bs,heads,seqlen,head_dim
        queries  = queries.view(batch_size,seq_len,self.num_heads,self.head_dim).transpose(1,2)
        keys_new = keys_new.view(batch_size,seq_len,self.num_kv_groups,self.head_dim).transpose(1,2)
        values_new = values_new.view(batch_size,seq_len,self.num_kv_groups,self.head_dim).transpose(1,2)

        if self.q_norm is not None:
            queries = self.q_norm(queries)
        if self.k_norm is not None:
            keys_new = self.k_norm(keys_new)

        queries = apply_rope(queries,cos,sin,offset=start_pos)
        keys_new = apply_rope(keys_new,cos,sin,offset=start_pos)

        if cache is not None:
            prev_k,prev_v = cache
            keys = torch.cat((prev_k,keys_new),dim=2)
            values = torch.cat((prev_v,values_new),dim=2)
        else:
            keys,values  = keys_new,values_new
        next_cache = (keys,values)

        if keys.shape[2] == seq_len:
            # No earlier keys: the fused kernel's built-in causal mask is correct.
            context = GQA_Torch.apply(queries,keys,values)
        else:
            # Cached keys present: the fused kernel aligns its causal mask at
            # position 0, which would let a new token see only the first keys,
            # so use the explicit masked attention instead.
            context = self._masked_attention(queries,keys,values,mask)

        context = context.transpose(1,2).reshape(batch_size,seq_len,self.d_out)

        return self.out_proj(context),next_cache



class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention and feed-forward, each with a residual.

    ``cfg`` supplies ``embed_dim``, ``num_heads``, ``head_dim``,
    ``num_kv_groups``, ``qk_norm`` and ``dtype`` (plus the keys ``FFN`` reads).
    """

    def __init__(self, cfg):
        super().__init__()
        embed_dim = cfg['embed_dim']
        self.att = GQA(
            d_in=embed_dim,
            num_heads=cfg['num_heads'],
            head_dim=cfg['head_dim'],
            num_kv_groups=cfg['num_kv_groups'],
            qk_norm=cfg['qk_norm'],
            dtype=cfg['dtype'],
        )
        self.ffn = FFN(cfg)
        self.norm1 = RMSNorm(embed_dim, eps=1e-6)
        self.norm2 = RMSNorm(embed_dim, eps=1e-6)

    def forward(self,x,mask,cos,sin,start_pos=0,cache=None):
        """Run one layer; returns ``(hidden_states, next_cache)``."""
        # Attention sub-layer: normalise, attend, add the residual.
        attn_out, next_cache = self.att(
            self.norm1(x), mask, cos, sin, start_pos=start_pos, cache=cache
        )
        after_attn = attn_out + x

        # Feed-forward sub-layer with its own residual.
        after_ffn = self.ffn(self.norm2(after_attn)) + after_attn

        return after_ffn, next_cache

class KVCache:
    """Per-layer storage for cached attention state, indexed by layer number."""

    def __init__(self, n_layers):
        # One slot per layer; None means "nothing cached yet".
        self.cache = [None for _ in range(n_layers)]

    def get(self, layer_idx):
        """Return the entry stored for ``layer_idx`` (None when empty)."""
        return self.cache[layer_idx]

    def update(self, layer_idx, value):
        """Replace the entry stored for ``layer_idx`` with ``value``."""
        self.cache[layer_idx] = value

    def get_all(self):
        """Return the underlying list of per-layer entries (not a copy)."""
        return self.cache

    def reset(self):
        """Clear every slot in place, keeping the same list object."""
        self.cache[:] = [None] * len(self.cache)


class Qwen3Model(nn.Module):
    """Decoder-only transformer: embedding, ``n_layers`` blocks, final norm, LM head.

    ``cfg`` supplies ``vocab_size``, ``embed_dim``, ``n_layers``, ``head_dim``,
    ``rope_base``, ``context_length`` and ``dtype`` (plus the keys read by
    ``TransformerBlock``).

    The model keeps ``current_pos``, the number of tokens already fed through a
    KV cache. Call ``reset_kv_cache()`` whenever a fresh cache is started.
    """

    def __init__(self,cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg['vocab_size'],cfg['embed_dim'],dtype=cfg['dtype'])
        self.blocks = nn.ModuleList(
            [TransformerBlock(cfg) for _ in range(cfg['n_layers'])]
        )
        self.final_norm = RMSNorm(cfg['embed_dim'])
        self.out_head = nn.Linear(cfg['embed_dim'],cfg['vocab_size'],bias=False,dtype=cfg['dtype'])

        # Rotary tables are derived from cfg, so they are not saved in the state dict.
        cos,sin = compute_rope_params(head_dim=cfg['head_dim'],theta_base=cfg['rope_base'],context_length=cfg['context_length'])
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        self.cfg = cfg
        self.current_pos =  0

    def forward(self,in_idx,cache=None):
        """Compute logits for ``in_idx`` (batch, seq_len).

        Args:
            in_idx: token ids.
            cache: optional ``KVCache``; when given, ``in_idx`` is treated as the
                continuation starting at ``self.current_pos`` and the cache is
                updated in place.

        Returns:
            Logits of shape (batch, seq_len, vocab_size).

        Raises:
            ValueError: if the sequence would extend past the rotary tables.
        """
        x = self.tok_emb(in_idx)
        seq_len = x.shape[1]

        pos_start = self.current_pos if cache is not None else 0
        pos_end = pos_start + seq_len
        # Guard before touching state: slicing the rotary tables past their end
        # yields a short (possibly empty) slice that can broadcast silently.
        max_positions = self.cos.shape[0]
        if pos_end > max_positions:
            raise ValueError(
                f"Sequence end position {pos_end} exceeds the context length {max_positions}"
            )
        if cache is not None:
            self.current_pos = pos_end

        # Causal mask, True = hidden. Rows are the new query positions, columns
        # all key positions seen so far; built directly at (seq_len, pos_end)
        # instead of slicing a (pos_end, pos_end) matrix.
        query_pos = torch.arange(pos_start,pos_end,device=x.device)
        key_pos = torch.arange(pos_end,device=x.device)
        mask = (key_pos[None,:] > query_pos[:,None])[None,None,:,:]

        for i , block in enumerate(self.blocks):
            blk_cache = cache.get(i) if cache is not None else None
            x,new_blk_cache = block(x,mask,self.cos,self.sin,start_pos=pos_start,cache=blk_cache)
            if cache is not None:
                cache.update(i,new_blk_cache)
        x = self.final_norm(x)
        logits  = self.out_head(x.to(self.cfg['dtype']))
        return logits

    def reset_kv_cache(self):
        """Restart position tracking; call before using a new, empty cache."""
        self.current_pos = 0



# Hyper-parameters of the Qwen3 0.6B checkpoint loaded below.
QWEN3_CONFIG = dict(
    vocab_size=151_936,        # vocabulary size
    context_length=40_960,     # context length the model was trained with
    embed_dim=1024,            # embedding dimension
    num_heads=16,              # number of attention (query) heads
    n_layers=28,               # number of transformer blocks
    hidden_dim=3072,           # intermediate width of the feed-forward block
    head_dim=128,              # per-head size in GQA
    qk_norm=True,              # normalise queries and keys in GQA
    num_kv_groups=8,           # key/value groups for grouped-query attention
    rope_base=1_000_000.0,     # base of RoPE's "theta"
    dtype=torch.bfloat16,      # lower-precision dtype to reduce memory usage
)


def load_weights_into_qwen(model, param_config, params):
    """Copy a Hugging Face style Qwen3 state dict into ``model`` in place.

    Args:
        model: model exposing ``tok_emb``, ``blocks``, ``final_norm`` and ``out_head``.
        param_config: mapping providing ``n_layers``.
        params: mapping from checkpoint tensor names to tensors/arrays.

    Raises:
        ValueError: when a checkpoint tensor's shape differs from its target.

    When ``lm_head.weight`` is absent the output head shares the embedding weight.
    """
    def assign(left, right, tensor_name="unknown"):
        # Shape check first so a bad checkpoint never partially overwrites `left`.
        if left.shape != right.shape:
            raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")

        with torch.no_grad():
            source = right if isinstance(right, torch.Tensor) else torch.as_tensor(right, dtype=left.dtype, device=left.device)
            left.copy_(source)

        return left

    def load_into(owner, attr, key):
        # Copy params[key] into owner.<attr> and rebind the attribute.
        setattr(owner, attr, assign(getattr(owner, attr), params[key], key))

    def layer_targets(block, prefix):
        # Lazily yields (owner, attribute, checkpoint key) in load order.
        att = block.att

        # Q, K, V projections
        yield att.W_query, "weight", f"{prefix}.self_attn.q_proj.weight"
        yield att.W_key, "weight", f"{prefix}.self_attn.k_proj.weight"
        yield att.W_value, "weight", f"{prefix}.self_attn.v_proj.weight"

        # Output projection
        yield att.out_proj, "weight", f"{prefix}.self_attn.o_proj.weight"

        # QK norms (only present when the model was built with qk_norm)
        if getattr(att, "q_norm", None) is not None:
            yield att.q_norm, "scale", f"{prefix}.self_attn.q_norm.weight"
        if getattr(att, "k_norm", None) is not None:
            yield att.k_norm, "scale", f"{prefix}.self_attn.k_norm.weight"

        # Attention layernorm
        yield block.norm1, "scale", f"{prefix}.input_layernorm.weight"

        # Feedforward weights
        yield block.ffn.fc1, "weight", f"{prefix}.mlp.gate_proj.weight"
        yield block.ffn.fc2, "weight", f"{prefix}.mlp.up_proj.weight"
        yield block.ffn.fc3, "weight", f"{prefix}.mlp.down_proj.weight"
        yield block.norm2, "scale", f"{prefix}.post_attention_layernorm.weight"

    load_into(model.tok_emb, "weight", "model.embed_tokens.weight")

    for layer_idx in range(param_config["n_layers"]):
        for owner, attr, key in layer_targets(model.blocks[layer_idx], f"model.layers.{layer_idx}"):
            load_into(owner, attr, key)

    # Final normalization and output head
    load_into(model.final_norm, "scale", "model.norm.weight")

    if "lm_head.weight" in params:
        load_into(model.out_head, "weight", "lm_head.weight")
    else:
        model.out_head.weight = model.tok_emb.weight
        print("Model uses weight tying.")

import json
import os
from pathlib import Path
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download, snapshot_download

# Hugging Face Hub repository holding the checkpoint; the weights are stored
# in a local folder named after the last component of the repository id.
repo_id = f"Qwen/Qwen3-{0.6}B"
local_dir = Path(repo_id).name

# Download (or reuse) the single-file safetensors checkpoint and read it.
weights_file = hf_hub_download(
    repo_id=repo_id,
    filename="model.safetensors",
    local_dir=local_dir,
)
weights_dict = load_file(weights_file)

# Build the model, copy the checkpoint into it, then move it to device 0.
model = Qwen3Model(QWEN3_CONFIG)
load_weights_into_qwen(model, QWEN3_CONFIG, weights_dict)
model = model.to(0)

# The raw checkpoint tensors are no longer needed once they have been copied.
del weights_dict

In [20]:
from transformers import AutoTokenizer

# Load the tokenizer published in the same Hub repository as the weights (repo_id).
tokenizer  = AutoTokenizer.from_pretrained(repo_id)

In [21]:


# Single-turn chat prompt used for the generation demo below.
messages = [
    {
        'role': "user",
        "content": "Give me a short introduction to large language models with gandhiji as word used",
    },
]

# Render the chat template to token ids (with the assistant prefix appended),
# add a batch dimension and move the tensor to device 0.
ids = torch.tensor(
    tokenizer.apply_chat_template(messages, add_generation_prompt=True)
).unsqueeze(0).to(0)



In [26]:
def stream_text_from_model(
    model,
    token_ids,
    max_new_tokens=512,
    eos_token_id=None,
    use_cache: bool = False,
):
    """
    Streams greedily decoded tokens from a causal LM.

    use_cache=False:
        - No KV cache
        - Recomputes full sequence each step (slow, but simple)

    use_cache=True:
        - Uses KV cache
        - First pass full prompt, then token-by-token

    Args:
        model: callable as ``model(token_ids, cache=...)`` returning logits of
            shape (batch, seq_len, vocab); must provide ``eval()`` and, for the
            cached path, ``blocks`` and ``reset_kv_cache()``.
        token_ids: prompt ids of shape (batch, seq_len).
        max_new_tokens: maximum number of tokens to yield.
        eos_token_id: when set, generation stops (without yielding) as soon as
            every sequence in the batch predicts this id.
        use_cache: whether to decode incrementally with a KV cache.

    Yields:
        Tensors of shape (batch, 1) holding the next token id(s).
    """

    def _forward_no_grad(tokens, kv_cache):
        # Gradients are disabled only around the model call. Holding a
        # `with torch.no_grad()` block open across `yield` would leave grad
        # mode disabled in the caller's code between tokens.
        with torch.no_grad():
            return model(tokens, cache=kv_cache)

    model.eval()

    # ---- setup ----
    if use_cache:
        cache = KVCache(n_layers=len(model.blocks))
        model.reset_kv_cache()
    else:
        cache = None

    # ---- first forward (full prompt) ----
    logits = _forward_no_grad(token_ids, cache)

    generated = token_ids

    for step in range(max_new_tokens):

        # pick next token (greedy)
        next_token = torch.argmax(
            logits[:, -1, :], dim=-1, keepdim=True
        )  # (bs, 1)

        # EOS check
        if eos_token_id is not None and torch.all(next_token == eos_token_id):
            break

        yield next_token

        # The logits that would follow the final token are never used, so skip
        # that forward pass.
        if step + 1 == max_new_tokens:
            break

        if use_cache:
            # ---- cached path: feed ONLY new token ----
            logits = _forward_no_grad(next_token, cache)
        else:
            # ---- no-cache path: re-feed entire sequence ----
            generated = torch.cat([generated, next_token], dim=1)
            logits = _forward_no_grad(generated, None)


In [None]:
# Stream the greedy completion token by token. Passing the tokenizer's EOS id
# lets generation stop at the end of the assistant turn instead of always
# producing all 512 tokens.
for token in stream_text_from_model(model,ids,512,eos_token_id=tokenizer.eos_token_id):
    print(tokenizer.decode(token[0]),end='',flush=True)



<think>
Okay, the user wants a short introduction to large language models, but they mentioned using "Gandhiji" as a word. First, I need to check if "Gandhiji" is a real person. Gandhiji is a prominent Indian philosopher and social activist, known for his ideas on non-violence and social justice. So, the user might be mixing up the name or using it incorrectly.

I should clarify that "Gandhiji" is a real person, not a word used in the context of large language models. Maybe they meant "Gandhi" or "Gandhiji" as part of the model's name? But since the user specified "Gandhiji," I need to make sure to correct that. 

Next, I should provide a general introduction to large language models, explaining their purpose, capabilities, and how they work. Then, mention Gandh