# GPT2 LoRA finetuning

This notebook demonstrates how to LoRA-finetune a pretrained GPT2 (124M) model to follow user instructions. We load the pretrained GPT2 weights from Hugging Face and then instruction-tune the model on TPU with LoRA. If you are not familiar with LoRA, we recommend reading the original [paper](https://arxiv.org/abs/2106.09685) first.
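As a quick refresher, LoRA keeps a pretrained kernel `W` frozen and learns a low-rank update `A @ B`, so an adapted layer computes `x @ W + (x @ A) @ B`. The NumPy sketch below illustrates only that idea. It is not Flax's `nnx.LoRALinear` implementation, it omits the paper's alpha/r scaling factor, and the shapes and initialization scale are illustrative assumptions. As in the paper, one factor starts at zero, so the adapted layer initially behaves exactly like the pretrained one.

```python
import numpy as np

rng = np.random.default_rng(0)
in_features, out_features, rank = 768, 768, 8

W = rng.normal(size=(in_features, out_features))      # frozen pretrained kernel
A = rng.normal(scale=0.01, size=(in_features, rank))  # trainable down-projection
B = np.zeros((rank, out_features))                    # trainable up-projection, zero at init


def lora_linear(x, W, A, B):
    """Frozen base path plus a low-rank update: x @ W + (x @ A) @ B."""
    return x @ W + (x @ A) @ B


x = rng.normal(size=(2, in_features))

# With B = 0 the adapter contributes nothing, so outputs match the pretrained layer.
assert np.allclose(lora_linear(x, W, A, B), x @ W)

# After training, the update can be folded back into a single kernel, W + A @ B.
B_trained = rng.normal(size=(rank, out_features))
assert np.allclose(lora_linear(x, W, A, B_trained), x @ (W + A @ B_trained))
```

Only `A` and `B` would be trained here: 2 × 768 × 8 = 12,288 values, compared with 589,824 in `W`.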

## Determine platform

In [1]:
import os
if os.path.exists('/content/'):
  platform = "Colab"
elif os.path.exists('/kaggle/'):
  platform = "Kaggle"
else:
  # Assume using Cloud TPU otherwise
  platform = "GCP"

## Setup

Install JAX and Flax first.

In [2]:
!pip install -q jax-ai-stack[grain]
if platform == "Colab": # temp workaround on Colab (https://github.com/jax-ml/jax-ai-stack/issues/149)
  !pip install -Uq "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install -Uq tiktoken matplotlib kaggle wandb tpu-info datasets


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/456.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m450.6/456.0 kB[0m [31m14.4 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m456.0/456.0 kB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m473.3/473.3 kB[0m [31m19.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m67.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m319.2/319.2 kB[0m [31m23.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m406.3/406.3 kB[0m [31m22.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.2/86.2 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━

Confirm we have TPUs set up.

In [3]:
import jax
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

Take care of the imports.

In [4]:
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax, orbax
from collections import Counter
from dataclasses import dataclass
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import numpy as np
import tiktoken, time, wandb
from huggingface_hub import snapshot_download
from safetensors import safe_open
from pathlib import Path
from flax.nnx.nn.lora import LoRAParam

## Build the model

Define the device mesh.


In [5]:
### Alternative data and model parallel
# mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))

mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))
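The mesh above arranges the TPU cores into a `('batch', 'model')` grid. With shape `(8, 1)`, every core gets a slice of the batch and a full copy of the model, which is plain data parallelism. The sketch below shows how `NamedSharding` with `P('batch', None)`, the spec the training loop uses later, splits a token array along its first axis. It builds the mesh by reshaping `jax.devices()` directly instead of calling `mesh_utils`, so it runs on any backend. The toy sizes and the name `toy_mesh` are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

num_devices = len(jax.devices())
# Same axis names as the notebook: all devices on 'batch', a size-1 'model' axis.
toy_mesh = Mesh(np.array(jax.devices()).reshape(num_devices, 1), ("batch", "model"))

rows_per_device, toy_seqlen = 4, 16
tokens = jnp.arange(rows_per_device * num_devices * toy_seqlen).reshape(
    rows_per_device * num_devices, toy_seqlen
)

# Split rows across the 'batch' mesh axis; the sequence axis is not split (None).
sharded = jax.device_put(tokens, NamedSharding(toy_mesh, P("batch", None)))

for shard in sharded.addressable_shards:
    print(shard.device, shard.data.shape)  # each device holds a (4, 16) slice
```

Switching to the commented-out `(4, 2)` mesh would additionally split the parameters that carry a `"model"` partitioning annotation across pairs of cores.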

We are going to use the GPT-2 tokenizer via OpenAI's [Tiktoken](https://github.com/openai/tiktoken) library.

In [6]:
tokenizer = tiktoken.get_encoding("gpt2")

Set some hyperparameters. We set LoRA rank to 8.

In [7]:
vocab_size = tokenizer.n_vocab
GPT2_variant = "GPT2" # Only supports GPT2
num_transformer_blocks = 12
seqlen = 1024
embed_dim = 768
num_heads = 12
feed_forward_dim = 4 * embed_dim
if platform == "Colab":
    batch_size = 24 # TPU v2
else:
    batch_size = 72 # TPU v3

dropout_rate = 0.1
lora_rank = 8

max_steps = 600000*12//batch_size
init_learning_rate = 5e-4
weight_decay = 1e-1
top_k = 10
sampling_temp = 2
dtype = jnp.bfloat16
param_dtype = jnp.float32
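To get a feel for how small the trainable part is, here is a back-of-the-envelope count based on these hyperparameters. It assumes that each LoRA adapter adds an `(in_features, rank)` and a `(rank, out_features)` matrix, that adapters sit on the four attention projections and the two MLP layers (as in the model defined below), and that the output head is tied to the token embedding. The helper names are hypothetical.

```python
def gpt2_param_count(vocab_size=50257, seqlen=1024, embed_dim=768, ff_dim=3072, num_blocks=12):
    """Parameters of a GPT-2 style model with a tied output head."""
    embeddings = vocab_size * embed_dim + seqlen * embed_dim  # token + position tables
    layer_norm = 2 * embed_dim                                # scale + bias
    attention = (embed_dim * 3 * embed_dim + 3 * embed_dim    # fused q/k/v kernel + bias
                 + embed_dim * embed_dim + embed_dim)         # output projection + bias
    mlp = embed_dim * ff_dim + ff_dim + ff_dim * embed_dim + embed_dim
    block = 2 * layer_norm + attention + mlp
    return embeddings + num_blocks * block + layer_norm       # plus the final LayerNorm


def lora_param_count(rank=8, embed_dim=768, ff_dim=3072, num_blocks=12):
    """Parameters added by rank-`rank` adapters on q/k/v/out and both MLP layers."""
    def adapter(d_in, d_out):
        return rank * (d_in + d_out)

    attention = 4 * adapter(embed_dim, embed_dim)
    mlp = adapter(embed_dim, ff_dim) + adapter(ff_dim, embed_dim)
    return num_blocks * (attention + mlp)


base_count = gpt2_param_count()
adapter_count = lora_param_count()
print(base_count, adapter_count, f"{100 * adapter_count / base_count:.2f}%")  # 124439808 1327104 1.07%
```

Under these assumptions, only about 1% of the parameter count receives gradients during finetuning.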

Let's define a custom multi-head attention class first. Since we are doing LoRA finetuning this time, we are going to replace all Linear layers with [nnx.LoRALinear layers](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/lora.html#flax.nnx.LoRALinear). Notice how the query, key, value and out projection layers are defined.
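Hugging Face's GPT-2 stores the query, key, and value projections as one fused `c_attn` weight of shape `(embed_dim, 3 * embed_dim)`. It is laid out as `(in_features, out_features)`, which is the same layout as a Flax kernel, so the code can split it along the last axis and assign the pieces without transposing. The NumPy sketch below checks that splitting the fused kernel and bias gives the same result as the fused projection. The toy dimension is an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
toy_dim = 8  # toy size; GPT-2 small uses 768

c_attn_w = rng.normal(size=(toy_dim, 3 * toy_dim))  # fused kernel, (in, 3 * out)
c_attn_b = rng.normal(size=(3 * toy_dim,))
x = rng.normal(size=(2, 5, toy_dim))                # (batch, seq, embed)

fused = x @ c_attn_w + c_attn_b

# Split exactly as CustomMHA does: three equal chunks along the output axis.
q_w, k_w, v_w = np.split(c_attn_w, 3, axis=-1)
q_b, k_b, v_b = np.split(c_attn_b, 3, axis=-1)
separate = np.concatenate([x @ q_w + q_b, x @ k_w + k_b, x @ v_w + v_b], axis=-1)

print(np.allclose(fused, separate))  # True
```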

In [8]:
def causal_attention_mask(seq_len):
    return jnp.tril(jnp.ones((seq_len, seq_len)))


class CustomMHA(nnx.Module):
    def __init__(self, embed_dim, num_heads, dropout_rate, layer_idx, weights, rngs):
        self.num_heads = num_heads
        self.head_dim = embed_dim // self.num_heads
        self.embed_dim = embed_dim

        kernel_init = nnx.with_partitioning(
            nnx.initializers.xavier_uniform(), P(None, "model")
        )

        self.query = nnx.LoRALinear(
            embed_dim,
            embed_dim,
            rngs=rngs,
            use_bias=False,
            kernel_init=kernel_init,
            lora_rank=lora_rank,
        )
        self.key = nnx.LoRALinear(
            embed_dim,
            embed_dim,
            rngs=rngs,
            use_bias=False,
            kernel_init=kernel_init,
            lora_rank=lora_rank,
        )
        self.value = nnx.LoRALinear(
            embed_dim,
            embed_dim,
            rngs=rngs,
            use_bias=False,
            kernel_init=kernel_init,
            lora_rank=lora_rank,
        )
        self.out = nnx.LoRALinear(
            embed_dim,
            embed_dim,
            rngs=rngs,
            use_bias=False,
            kernel_init=kernel_init,
            lora_rank=lora_rank,
        )

        qkv_kernel = weights[f"h.{layer_idx}.attn.c_attn.weight"]
        q_kernel, k_kernel, v_kernel = jnp.split(qkv_kernel, 3, axis=-1)
        self.query.kernel.value = q_kernel
        self.key.kernel.value = k_kernel
        self.value.kernel.value = v_kernel

        qkv_bias = weights[f"h.{layer_idx}.attn.c_attn.bias"]
        q_b, k_b, v_b = jnp.split(qkv_bias, 3, axis=-1)

        self.q_bias = nnx.Param(q_b, sharding=P("model"))
        self.k_bias = nnx.Param(k_b, sharding=P("model"))
        self.v_bias = nnx.Param(v_b, sharding=P("model"))

        self.out.kernel.value = weights[f"h.{layer_idx}.attn.c_proj.weight"]
        self.out_bias = nnx.Param(
            weights[f"h.{layer_idx}.attn.c_proj.bias"], sharding=P("model")
        )

        self.dropout = nnx.Dropout(dropout_rate)

    def __call__(self, x, mask, padding_mask=None, training: bool = False):
        batch_size, seq_len, _ = x.shape

        q = self.query(x) + self.q_bias
        k = self.key(x) + self.k_bias
        v = self.value(x) + self.v_bias

        q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(
            (0, 2, 1, 3)
        )
        k = k.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(
            (0, 2, 1, 3)
        )
        v = v.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(
            (0, 2, 1, 3)
        )

        attn_weights = jnp.matmul(q, k.transpose((0, 1, 3, 2))) / jnp.sqrt(
            self.head_dim
        )

        combined_mask = mask
        if padding_mask is not None:
            combined_mask = jnp.logical_and(mask, padding_mask)

        if combined_mask is not None:
            attn_weights = jnp.where(combined_mask, attn_weights, -jnp.inf)

        attn_weights = nnx.softmax(attn_weights, axis=-1)
        attn_weights = self.dropout(attn_weights, deterministic=not training)

        attn_output = jnp.matmul(attn_weights, v)
        attn_output = attn_output.transpose((0, 2, 1, 3)).reshape(
            (batch_size, seq_len, self.embed_dim)
        )

        output = self.out(attn_output) + self.out_bias
        return output
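`CustomMHA.__call__` combines the `(seq, seq)` causal mask with an optional padding mask. `generate_step` builds that padding mask with shape `(1, 1, 1, seqlen)`, so broadcasting applies it to every batch element, head, and query position. The NumPy sketch below uses toy sizes and all-zero attention scores to show the combined mask and the resulting softmax.

```python
import numpy as np

seq_len, length = 6, 4  # toy sizes: 4 real tokens followed by 2 padding slots

causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))          # query i sees keys j <= i
padding = (np.arange(seq_len) < length).reshape(1, 1, 1, seq_len)  # hide padded keys
combined = np.logical_and(causal, padding)                          # broadcasts to (1, 1, seq, seq)

scores = np.zeros((1, 1, seq_len, seq_len))  # stand-in attention logits
masked = np.where(combined, scores, -np.inf)
attn_probs = np.exp(masked - masked.max(axis=-1, keepdims=True))
attn_probs /= attn_probs.sum(axis=-1, keepdims=True)  # softmax over keys

# Key 0 is always visible, so no row is entirely -inf and the softmax stays finite.
print(attn_probs[0, 0].round(3))
```

The last query row spreads its weight evenly over the four real tokens and gives zero weight to the two padding slots.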

Now define the model architecture. Similarly, notice how the up and down projection layers are defined with the NNX LoRALinear layer.
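The architecture cell that follows also ties the output projection to the token embedding. `GPT2.__call__` computes the logits with `token_embedding.attend(x)`, which takes the dot product of the hidden states with the transposed embedding table. The NumPy sketch below shows the same computation with toy sizes.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab, dim = 11, 4  # toy sizes; GPT-2 small uses 50257 x 768
wte = rng.normal(size=(vocab, dim))  # shared token embedding table

token_ids = np.array([[3, 7, 1]])
hidden = wte[token_ids]  # input side: row lookup, shape (1, 3, dim)

# Output side: the same table, transposed, maps hidden states to vocabulary logits.
tied_logits = hidden @ wte.T  # shape (1, 3, vocab)
print(tied_logits.shape)
```

Because the output head reuses `wte`, the model does not store a separate 768 × 50257 output matrix, which saves 38,597,376 parameters for GPT-2 small.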

In [9]:

class TransformerBlock(nnx.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        ff_dim: int,
        dropout_rate: float,
        rngs: nnx.Rngs,
        layer_idx: int,
        weights: dict,
    ):
        self.layer_norm1 = nnx.LayerNorm(
            epsilon=1e-6,
            num_features=embed_dim,
            scale_init=nnx.with_partitioning(
                nnx.initializers.ones_init(), NamedSharding(mesh, P("model"))
            ),
            bias_init=nnx.with_partitioning(
                nnx.initializers.zeros_init(), NamedSharding(mesh, P("model"))
            ),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.layer_norm1.scale.value = weights[f"h.{layer_idx}.ln_1.weight"]
        self.layer_norm1.bias.value = weights[f"h.{layer_idx}.ln_1.bias"]
        self.mha = CustomMHA(
            embed_dim, num_heads, dropout_rate, layer_idx, weights, rngs
        )
        self.dropout1 = nnx.Dropout(rate=dropout_rate)
        self.layer_norm2 = nnx.LayerNorm(
            epsilon=1e-6,
            num_features=embed_dim,
            scale_init=nnx.with_partitioning(
                nnx.initializers.ones_init(), NamedSharding(mesh, P("model"))
            ),
            bias_init=nnx.with_partitioning(
                nnx.initializers.zeros_init(), NamedSharding(mesh, P("model"))
            ),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.layer_norm2.scale.value = weights[f"h.{layer_idx}.ln_2.weight"]
        self.layer_norm2.bias.value = weights[f"h.{layer_idx}.ln_2.bias"]
        self.linear1 = nnx.LoRALinear(
            in_features=embed_dim,
            out_features=ff_dim,
            kernel_init=nnx.with_partitioning(
                nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, "model"))
            ),
            bias_init=nnx.with_partitioning(
                nnx.initializers.zeros_init(), NamedSharding(mesh, P("model"))
            ),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
            lora_rank=lora_rank,
        )
        self.linear1.kernel.value = weights[f"h.{layer_idx}.mlp.c_fc.weight"]
        self.linear1.bias.value = weights[f"h.{layer_idx}.mlp.c_fc.bias"]
        self.linear2 = nnx.LoRALinear(
            in_features=ff_dim,
            out_features=embed_dim,
            kernel_init=nnx.with_partitioning(
                nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, "model"))
            ),
            bias_init=nnx.with_partitioning(
                nnx.initializers.zeros_init(), NamedSharding(mesh, P("model"))
            ),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
            lora_rank=lora_rank,
        )
        self.linear2.kernel.value = weights[f"h.{layer_idx}.mlp.c_proj.weight"]
        self.linear2.bias.value = weights[f"h.{layer_idx}.mlp.c_proj.bias"]
        self.dropout2 = nnx.Dropout(rate=dropout_rate)

    def __call__(self, inputs, padding_mask=None, training: bool = False):
        input_shape = inputs.shape
        bs, seq_len, emb_sz = input_shape

        attention_output = self.mha(
            self.layer_norm1(inputs),
            mask=causal_attention_mask(seq_len),
            padding_mask=padding_mask,
            training=training,
        )
        x = inputs + self.dropout1(attention_output, deterministic=not training)

        # MLP
        mlp_output = self.linear1(self.layer_norm2(x))
        mlp_output = nnx.gelu(mlp_output)
        mlp_output = self.linear2(mlp_output)
        mlp_output = self.dropout2(mlp_output, deterministic=not training)

        return x + mlp_output


class TokenAndPositionEmbedding(nnx.Module):
    def __init__(
        self,
        seqlen: int,
        vocab_size: int,
        embed_dim: int,
        rngs: nnx.Rngs,
        weights: dict,
    ):
        self.token_emb = nnx.Embed(
            num_embeddings=vocab_size,
            features=embed_dim,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.pos_emb = nnx.Embed(
            num_embeddings=seqlen,
            features=embed_dim,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.token_emb.embedding.value = weights["wte.weight"]
        self.pos_emb.embedding.value = weights["wpe.weight"]

    def __call__(self, x):
        positions = jnp.arange(0, x.shape[1])[None, :]
        position_embedding = self.pos_emb(positions)
        token_embedding = self.token_emb(x)
        return self.token_emb, token_embedding + position_embedding


class GPT2(nnx.Module):
    def __init__(
        self,
        seqlen: int,
        vocab_size: int,
        embed_dim: int,
        num_heads: int,
        rate: float,
        feed_forward_dim: int,
        num_transformer_blocks: int,
        rngs: nnx.Rngs,
        weights: dict,
    ):
        self.embedding_layer = TokenAndPositionEmbedding(
            seqlen, vocab_size, embed_dim, rngs=rngs, weights=weights
        )
        self.dropout = nnx.Dropout(rate=rate)

        self.transformer_blocks = [
            TransformerBlock(
                embed_dim,
                num_heads,
                feed_forward_dim,
                rate,
                rngs=rngs,
                layer_idx=i,
                weights=weights,
            )
            for i in range(num_transformer_blocks)
        ]

        self.layer_norm = nnx.LayerNorm(
            epsilon=1e-6,
            num_features=embed_dim,
            scale_init=nnx.with_partitioning(
                nnx.initializers.ones_init(), NamedSharding(mesh, P("model"))
            ),
            bias_init=nnx.with_partitioning(
                nnx.initializers.zeros_init(), NamedSharding(mesh, P("model"))
            ),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.layer_norm.scale.value = weights["ln_f.weight"]
        self.layer_norm.bias.value = weights["ln_f.bias"]

    def __call__(self, inputs, padding_mask=None, training: bool = False):
        token_embedding, x = self.embedding_layer(inputs)
        x = self.dropout(x, deterministic=not training)
        for transformer_block in self.transformer_blocks:
            x = transformer_block(x, padding_mask=padding_mask, training=training)
        x = self.layer_norm(x)
        # Weights tying
        outputs = token_embedding.attend(x)
        return outputs

    @nnx.jit
    def sample_from(self, logits, key):
        logits, indices = jax.lax.top_k(logits, k=top_k)
        logits = nnx.softmax(logits / sampling_temp)
        return jax.random.choice(key, indices, p=logits)

    @nnx.jit
    def generate_step(self, params, static_def, padded_tokens, length, key):
        padding_mask = jnp.arange(seqlen) < length
        padding_mask = padding_mask.reshape(1, 1, 1, seqlen)

        model = nnx.merge(params, static_def)
        logits = model(padded_tokens, padding_mask=padding_mask, training=False)
        last_token_logits = logits[:, length - 1, :]

        key, subkey = jax.random.split(key)
        next_token = self.sample_from(
            jnp.squeeze(last_token_logits), subkey
        )  # Pass subkey here
        return next_token

    def generate_text(self, max_tokens, start_tokens):
        key = jax.random.PRNGKey(int(time.time()))

        params, static_def = nnx.split(self)

        tokens = jnp.array(start_tokens, dtype=jnp.int32)[None, :]
        end_token = tokenizer.encode(
            "<|endoftext|>", allowed_special={"<|endoftext|>"}
        )[0]

        current_len = tokens.shape[1]
        padded_tokens = jnp.pad(tokens, ((0, 0), (0, seqlen - current_len)))

        print(tokenizer.decode(tokens[0]), end="", flush=True)

        for i in range(max_tokens):
            key, subkey = jax.random.split(key)

            next_token = self.generate_step(
                params, static_def, padded_tokens, current_len, subkey
            )

            if next_token.item() == end_token:
                break

            print(tokenizer.decode([next_token.item()]), end="", flush=True)

            padded_tokens = padded_tokens.at[:, current_len].set(next_token.item())
            current_len += 1

        final_tokens = padded_tokens[0, :current_len]
        return tokenizer.decode(final_tokens.tolist())


def create_model(rngs, weights):
    return GPT2(
        seqlen,
        vocab_size,
        embed_dim,
        num_heads,
        dropout_rate,
        feed_forward_dim,
        num_transformer_blocks,
        rngs=rngs,
        weights=weights,
    )
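`sample_from` keeps the `top_k` largest logits, divides them by `sampling_temp`, and samples from the resulting softmax. The standalone sketch below reproduces that logic with the same JAX calls on a toy logit vector. The function name and toy values are illustrative.

```python
import jax
import jax.numpy as jnp


def sample_top_k(logits, key, k=10, temperature=2.0):
    """Sample one token id from the k highest logits after temperature scaling."""
    top_logits, top_indices = jax.lax.top_k(logits, k)
    probs = jax.nn.softmax(top_logits / temperature)
    return jax.random.choice(key, top_indices, p=probs)


toy_logits = jnp.arange(100.0)  # token i has logit i, so ids 90..99 form the top 10
keys = jax.random.split(jax.random.PRNGKey(0), 5)
samples = [int(sample_top_k(toy_logits, k_)) for k_ in keys]
print(samples)  # every id lies in range(90, 100)

# A higher temperature flattens the distribution over the same k candidates.
top = jax.lax.top_k(toy_logits, 10)[0]
print(jax.nn.softmax(top / 1.0).max(), jax.nn.softmax(top / 2.0).max())
```

With `sampling_temp = 2`, as set above, the model picks less likely candidates among the top 10 more often than it would at temperature 1.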

Now we are going to load the GPT2 model weights from Hugging Face.
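The loading code below copies each Hugging Face tensor directly into the corresponding Flax parameter, so it relies on the checkpoint's key names and `(in, out)` layouts. The hypothetical helpers below list the keys and shapes the model constructor reads, under the assumption that the safetensors file uses the unprefixed GPT-2 names that appear in the code (`wte.weight`, `h.{i}.attn.c_attn.weight`, ...). They also report any tensor that is missing or has an unexpected shape.

```python
import numpy as np


def expected_gpt2_shapes(vocab_size=50257, seqlen=1024, embed_dim=768, ff_dim=3072, num_blocks=12):
    """Shapes of the GPT-2 tensors that the model constructor reads (hypothetical helper)."""
    shapes = {
        "wte.weight": (vocab_size, embed_dim),
        "wpe.weight": (seqlen, embed_dim),
        "ln_f.weight": (embed_dim,),
        "ln_f.bias": (embed_dim,),
    }
    for i in range(num_blocks):
        p = f"h.{i}."
        shapes.update({
            p + "ln_1.weight": (embed_dim,),
            p + "ln_1.bias": (embed_dim,),
            p + "attn.c_attn.weight": (embed_dim, 3 * embed_dim),  # fused q/k/v, (in, out)
            p + "attn.c_attn.bias": (3 * embed_dim,),
            p + "attn.c_proj.weight": (embed_dim, embed_dim),
            p + "attn.c_proj.bias": (embed_dim,),
            p + "ln_2.weight": (embed_dim,),
            p + "ln_2.bias": (embed_dim,),
            p + "mlp.c_fc.weight": (embed_dim, ff_dim),
            p + "mlp.c_fc.bias": (ff_dim,),
            p + "mlp.c_proj.weight": (ff_dim, embed_dim),
            p + "mlp.c_proj.bias": (embed_dim,),
        })
    return shapes


def mismatched_keys(tensors, expected):
    """Keys that are missing from `tensors` or whose shape differs from `expected`."""
    return [
        key for key, shape in expected.items()
        if key not in tensors or tuple(tensors[key].shape) != shape
    ]


# Self-check on toy-sized zero tensors (not the real checkpoint).
toy_expected = expected_gpt2_shapes(vocab_size=10, seqlen=4, embed_dim=2, ff_dim=8, num_blocks=1)
toy_tensors = {k: np.zeros(s) for k, s in toy_expected.items()}
print(mismatched_keys(toy_tensors, toy_expected))  # []
```

After the next cell runs, `mismatched_keys(weights, expected_gpt2_shapes())` should return an empty list if the checkpoint matches these assumptions. Extra keys in the file are ignored.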

In [10]:
model_id = "openai-community/gpt2"
if os.path.exists("/content"):
    # Colab
    weights_base_dir = "/content"
else:
    # Local machine
    weights_base_dir = "."

path_to_model_weights = os.path.join(weights_base_dir, model_id)

snapshot_download(
    repo_id=model_id, local_dir=path_to_model_weights, allow_patterns="*.safetensors"
)


def load_safetensors():
    weights = {}
    safetensors_files = Path(path_to_model_weights).glob("*.safetensors")

    for file in safetensors_files:
        with safe_open(file, framework="jax", device="cpu") as f:
            for key in f.keys():
                weights[key] = f.get_tensor(key)
    return weights


weights = load_safetensors()
model = create_model(rngs=nnx.Rngs(0), weights=weights)


Use Weights and Biases to track training progress.

In [11]:
if platform == "Colab":
  from google.colab import userdata
  os.environ['WANDB_API_KEY'] = userdata.get('WANDB_API_KEY')
  os.environ['KAGGLE_USERNAME'] = userdata.get('KAGGLE_USERNAME')
  os.environ['KAGGLE_KEY'] = userdata.get('KAGGLE_KEY')
elif platform == "Kaggle":
  from kaggle_secrets import UserSecretsClient
  user_secrets = UserSecretsClient()
  os.environ['WANDB_API_KEY'] = user_secrets.get_secret('WANDB_API_KEY')
else:
  print("Please set the WANDB_API_KEY env variable manually") #input()

wandb.login()

wandb.init(
    # set the wandb project where this run will be logged
    project='GPT2-LoRA',

    # track hyperparameters and run metadata
    config={
      'architecture': GPT2_variant,
      'dataset': 'Alpaca',
      'platform': platform,
      'max_steps': max_steps,
      'batch_size': batch_size,
      'dtype': dtype,
      'param_dtype': param_dtype,
      'init_learning_rate': init_learning_rate,
      'num_transformer_blocks': num_transformer_blocks,
      'seqlen': seqlen,
      'embed_dim': embed_dim,
      'num_heads': num_heads,
      'feed_forward_dim': feed_forward_dim,
      'weight_decay': weight_decay,
      'lora_rank': lora_rank,
    }
)

[34m[1mwandb[0m: Currently logged in as: [33mwindmaple[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Instruction tuning

We are going to use the [Alpaca dataset](https://huggingface.co/datasets/tatsu-lab/alpaca) from Stanford.
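`TextDataset.__getitem__` in the next cell truncates each tokenized example to `seqlen` and right-pads it with token 50256 (`<|endoftext|>`). Later, `prep_target_batch` in the training cell builds next-token targets by shifting each row left by one position and appending 0. The plain-Python sketch below mirrors both steps on a toy sequence. The helper names are hypothetical.

```python
EOT_ID = 50256  # GPT-2's <|endoftext|> id, used as the padding token in TextDataset


def pad_or_truncate(token_ids, seqlen, pad_id=EOT_ID):
    """Cut a token list to `seqlen`, then right-pad it with `pad_id` (mirrors TextDataset)."""
    token_ids = list(token_ids)[:seqlen]
    return token_ids + [pad_id] * (seqlen - len(token_ids))


def shift_targets(inputs, fill_id=0):
    """Next-token targets: drop the first id and append `fill_id` (mirrors prep_target_batch)."""
    return inputs[1:] + [fill_id]


toy_inputs = pad_or_truncate([15, 27, 99], seqlen=5)
print(toy_inputs)                 # [15, 27, 99, 50256, 50256]
print(shift_targets(toy_inputs))  # [27, 99, 50256, 50256, 0]
```

Many Alpaca examples are far shorter than 1,024 tokens, so a large share of each row is padding. The loss in this notebook does not mask those positions out.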

In [12]:
import grain.python as pygrain
import pandas as pd
from datasets import load_dataset

@dataclass
class TextDataset:
    data_df: pd.Series  # the "text" column of the Alpaca DataFrame
    seqlen: int

    def __len__(self):
        return len(self.data_df)

    def __getitem__(self, idx: int):
        # Tokenize with Tiktoken, truncate to seqlen, then right-pad with <|endoftext|> (id 50256)
        encoding = tokenizer.encode(
            self.data_df.iloc[idx], allowed_special={"<|endoftext|>"}
        )[: self.seqlen]
        return encoding + [50256] * (self.seqlen - len(encoding))


def load_and_preprocess_data(alpaca_data, batch_size, seqlen):
    alpaca_data_df = pd.DataFrame(alpaca_data)
    dataset = TextDataset(alpaca_data_df["text"], seqlen)
    sampler = pygrain.IndexSampler(
        len(dataset),
        shuffle=True,
        seed=42,
        shard_options=pygrain.NoSharding(),
        num_epochs=1,
    )
    dl = pygrain.DataLoader(
        data_source=dataset,
        sampler=sampler,
        operations=[pygrain.Batch(batch_size=batch_size, drop_remainder=True)],
    )
    return dl


alpaca_data = load_dataset("tatsu-lab/alpaca", split="train")
text_dl = load_and_preprocess_data(alpaca_data, batch_size, seqlen)


Define the loss and training step functions. In `train_step()`, gradients are computed only with respect to the LoRA parameters. The frozen base parameters are merged back into the model as constants, so they receive no updates.
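`loss_fn` uses `optax.softmax_cross_entropy_with_integer_labels(...).mean()`, which is the average negative log-probability the model assigns to each target token. The sketch below spells out that computation with plain JAX ops for illustration. It is not optax's implementation, and the toy shapes are assumptions.

```python
import jax
import jax.numpy as jnp


def next_token_cross_entropy(logits, labels):
    """Mean negative log-likelihood of integer `labels` under softmax(`logits`).

    logits: (batch, seq, vocab) floats; labels: (batch, seq) integer token ids.
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    label_log_probs = jnp.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    return -label_log_probs.mean()


# Uniform logits over a 5-token vocabulary give a loss of log(5), whatever the labels.
uniform_logits = jnp.zeros((2, 3, 5))
toy_labels = jnp.array([[0, 1, 2], [3, 4, 0]])
print(next_token_cross_entropy(uniform_logits, toy_labels))  # ≈ 1.609, i.e. log(5)
```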

In [13]:
@nnx.jit
def loss_fn(model, batch):
    logits = model(batch[0])
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch[1]
    ).mean()
    return loss, logits


@nnx.jit
def train_step(
    lora_params, opt_state, graphdef, static_params, metrics: nnx.MultiMetric, batch
):
    grad_fn = nnx.value_and_grad(
        lambda lp: loss_fn(nnx.merge(graphdef, lp, static_params), batch), has_aux=True
    )
    (loss, logits), grads = grad_fn(lora_params)
    metrics.update(loss=loss, logits=logits, labels=batch[1])
    updates, opt_state = tx.update(grads, opt_state, lora_params)
    lora_params = optax.apply_updates(lora_params, updates)
    return lora_params, opt_state
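The same idea can be shown without NNX. Pass the frozen weights as an argument that `jax.value_and_grad` does not differentiate (by default it differentiates only argument 0), then update only the LoRA tree. The sketch below is a toy linear regression with made-up shapes and a plain SGD step, not the notebook's AdamW setup.

```python
import jax
import jax.numpy as jnp


def toy_loss(lora, frozen, x, y):
    """Mean squared error of a LoRA-adapted linear layer: x @ W + (x @ A) @ B."""
    pred = x @ frozen["W"] + (x @ lora["A"]) @ lora["B"]
    return jnp.mean((pred - y) ** 2)


key_w, key_a, key_x, key_y = jax.random.split(jax.random.PRNGKey(0), 4)
frozen = {"W": jax.random.normal(key_w, (4, 3))}
lora = {"A": 0.1 * jax.random.normal(key_a, (4, 2)), "B": jnp.zeros((2, 3))}
x = jax.random.normal(key_x, (8, 4))
y = jax.random.normal(key_y, (8, 3))

# Differentiates only argument 0 (`lora`); `frozen` is treated as a constant.
initial_loss, grads = jax.value_and_grad(toy_loss)(lora, frozen, x, y)

# Plain SGD on the LoRA tree only; `frozen` is never touched.
learning_rate = 0.1
new_lora = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, lora, grads)

print(sorted(grads))  # ['A', 'B']: no gradient entry exists for W
# Because B starts at zero, the gradient for A is exactly zero on this first step.
print(float(initial_loss), float(toy_loss(new_lora, frozen, x, y)))
```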

Save a copy of the initial model parameters so that we can later verify which parameters changed during LoRA finetuning and which did not.

In [14]:
graphdef, lora_params, static_params = nnx.split(model, LoRAParam, nnx.Param)

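# Keep references to the initial parameters for later verification. JAX arrays are immutable
# and the training loop rebinds `lora_params` to new values, so no deep copy is needed.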
initial_lora_params = lora_params
initial_static_params = static_params

Define optimizer and metrics.

In [15]:

schedule = optax.cosine_decay_schedule(
    init_value=init_learning_rate, decay_steps=max_steps
)
tx = optax.chain(optax.adamw(learning_rate=schedule, weight_decay=weight_decay))
opt_state = tx.init(lora_params)


metrics = nnx.MultiMetric(
    loss=nnx.metrics.Average("loss"),
)

metrics_history = {
    "train_loss": [],
}
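`optax.cosine_decay_schedule` decays the learning rate from `init_value` along a half cosine over `decay_steps` steps. Assuming its default `alpha=0.0` (the final value as a fraction of the initial one), the schedule can be written in plain Python as below. This mirrors our understanding of the optax formula and is not its source code.

```python
import math


def cosine_decay(step, init_value=5e-4, decay_steps=100_000, alpha=0.0):
    """Half-cosine decay from init_value to alpha * init_value over decay_steps steps."""
    step = min(step, decay_steps)
    cosine = 0.5 * (1 + math.cos(math.pi * step / decay_steps))
    return init_value * ((1 - alpha) * cosine + alpha)


print(cosine_decay(0), cosine_decay(50_000), cosine_decay(100_000))
print(cosine_decay(722) / 5e-4)  # ≈ 0.99987
```

With `batch_size = 72`, `max_steps = 600000*12//72` is 100,000. One epoch of the 52,002 Alpaca examples, however, is only 722 full batches. Over that epoch, the learning rate stays at roughly 99.99% of its initial value, so in this run the schedule behaves almost like a constant rate.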

Do a test run on our pretrained model to see how it responds to a user instruction. Note that we use a template to format the prompt, and it needs to be consistent with the format of the training data.
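One caveat worth checking: the Stanford Alpaca repository defines two prompt templates. One is for examples with an input. The other, for examples without an input, omits the `### Input:` section entirely, and the with-input template uses a slightly longer preamble. If the dataset's `text` column follows those templates, as we believe it does, a prompt with an empty input would match the training data more closely with the no-input variant. The helper below is a hypothetical sketch of choosing between the two. The template strings are reproduced from the Alpaca repository as best we know them and are worth double-checking against a few rows of `alpaca_data["text"]`.

```python
PROMPT_WITH_INPUT = (
    "Below is an instruction that describes a task, paired with an input that provides "
    "further context. Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
)
PROMPT_NO_INPUT = (
    "Below is an instruction that describes a task. Write a response that appropriately "
    "completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:\n"
)


def build_alpaca_prompt(instruction, input_text=""):
    """Pick the Alpaca-style template depending on whether an input is given (hypothetical helper)."""
    if input_text.strip():
        return PROMPT_WITH_INPUT.format(instruction=instruction, input=input_text)
    return PROMPT_NO_INPUT.format(instruction=instruction)


print(build_alpaca_prompt("What is the future for human?"))
```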

In [16]:
template = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n{output}"

start_prompt = template.format(
    instruction="What is the future for human?",
    input="",
    output="",
)
start_tokens = tokenizer.encode(start_prompt)[:seqlen]
print(f"***Initial generated text:")
generated_text = model.generate_text(seqlen//5, start_tokens)

***Initial generated text:
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
What is the future for human?

### Input:


### Response:

What do the tasks in this program do, besides making your computer do something you're not sure is necessary. You will learn how to do something that isn't important. You learn to not let your body, which isn't so different, decide what will work best for what purposes.

### Method of action:


What will it cost, how much time is left in it to learn what the next steps might be like and how do they relate to each task?

If you can't see that, then you're just a lazy ass! You should learn what will help your life go faster and easier, instead of trying to learn all these different things that you're supposed to get your brain to do each day to improve their productivity and productivity.

As you can see, the pretrained model generates a bunch of random text; clearly it does not know how to follow the instruction to generate an appropriate answer, which is not surprising given that we have not trained it to do so.

Now let's do the instruction tuning with LoRA.

In [17]:
prep_target_batch = jax.vmap(
    lambda tokens: jnp.concatenate((tokens[1:], jnp.array([0])))
)

step = 0
start_time = time.time()
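for batch in text_dl:
    # The transpose turns the collated batch into a (batch_size, seqlen) array.
    input_batch = jnp.array(batch).T
    if input_batch.shape[0] % len(jax.devices()) != 0:
        continue  # skip a batch that cannot be split evenly across devices
    target_batch = prep_target_batch(input_batch)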
    lora_params, opt_state = train_step(
        lora_params,
        opt_state,
        graphdef,
        static_params,
        metrics,
        jax.device_put(
            (input_batch, target_batch), NamedSharding(mesh, P("batch", None))
        ),
    )

    if (step + 1) % 20 == 0:
        for metric, value in metrics.compute().items():
            metrics_history[f"train_{metric}"].append(value)
        metrics.reset()

        elapsed_time = time.time() - start_time
        print(
            f"\n\nStep {step + 1}, Loss: {metrics_history['train_loss'][-1]}, Elapsed Time: {elapsed_time:.2f} seconds"
        )
        wandb.log(data={'train_loss': metrics_history['train_loss'][-1]}, step=step)
        start_time = time.time()
        print(f"\n***Intermediate generated text:")
        intermediate_model = nnx.merge(graphdef, lora_params, static_params)
        generated_text = intermediate_model.generate_text(seqlen // 5, start_tokens)
    step += 1

# Final text generation
model = nnx.merge(graphdef, lora_params, static_params)
print(f"\n***Final generated text:")
generated_text = model.generate_text(seqlen // 5, start_tokens)



Step 20, Loss: 2.0909180641174316, Elapsed Time: 35.95 seconds

***Intermediate generated text:
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
What is the future for human?

### Input:


### Response:

### Output:

Step 40, Loss: 0.3573242127895355, Elapsed Time: 38.03 seconds

***Intermediate generated text:
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
What is the future for human?

### Input:


### Response:
What are we to learn from this? Is that you?

Step 60, Loss: 0.32646486163139343, Elapsed Time: 5.31 seconds

***Intermediate generated text:
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
What is the future for human?

### Input:


### Response:
What do you mean, what do you expect in the future that's going to be better for us, for humanity?




#

As you can see, by the end of finetuning the model is able to follow the instruction and generate a somewhat sensible answer. The answer is also better than what we previously got from our own pretrained model.

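To verify that only the LoRA parameters were trained, we compare the final parameters with the initial copies we made before finetuning. Because `train_step` only returns updated LoRA parameters, the frozen parameters cannot change by construction. The check below confirms this.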

In [18]:
# Check that the frozen base-model parameters (static_params) have NOT changed.
static_leaves_before = jax.tree_util.tree_leaves(initial_static_params)
static_leaves_after = jax.tree_util.tree_leaves(static_params)

static_params_are_unchanged = all(
    jnp.allclose(b, a) for b, a in zip(static_leaves_before, static_leaves_after)
)
print(f"Original model parameters (static) remained frozen: {static_params_are_unchanged}")

# Check that the LoRA parameters HAVE changed.
lora_leaves_before = jax.tree_util.tree_leaves(initial_lora_params)
lora_leaves_after = jax.tree_util.tree_leaves(lora_params)

lora_params_have_changed = not all(
    jnp.allclose(b, a) for b, a in zip(lora_leaves_before, lora_leaves_after)
)
print(f"LoRA parameters were updated: {lora_params_have_changed}")

Original model parameters (static) remained frozen: True
LoRA parameters were updated: True
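For a more detailed view than a single True/False, a small helper can list which leaves changed, identified by key path. The sketch below uses `jax.tree_util.tree_flatten_with_path` and `jax.tree_util.keystr` on toy dicts. Applying it to `initial_lora_params` and `lora_params` should work the same way if both trees flatten to the same structure, which is an assumption worth checking for NNX `State` objects.

```python
import jax
import numpy as np


def changed_leaf_paths(before, after):
    """Key paths of leaves that differ between two pytrees with the same structure."""
    leaves_before, _ = jax.tree_util.tree_flatten_with_path(before)
    leaves_after, _ = jax.tree_util.tree_flatten_with_path(after)
    return [
        jax.tree_util.keystr(path)
        for (path, b), (_, a) in zip(leaves_before, leaves_after)
        if not np.array_equal(np.asarray(b), np.asarray(a))
    ]


# Toy example: only the adapter leaf differs, so only its path is reported.
before = {"kernel": np.ones((2, 2)), "lora_a": np.zeros((2, 1))}
after = {"kernel": np.ones((2, 2)), "lora_a": np.full((2, 1), 0.5)}
print(changed_leaf_paths(before, after))
```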


One other thing you could do is compare the HBM usage of this LoRA run with that of the previous full-parameter finetuning run (for example, by eyeballing the tpu-info output). You should see a noticeable reduction in memory usage, because gradients and AdamW optimizer state are only kept for the small set of LoRA parameters.
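A rough way to see where the savings come from is to count the fp32 values that must stay in memory during training. Each trainable value needs the parameter itself, its gradient, and two AdamW moments, while a frozen value needs only the parameter. This ignores activations, which depend on batch size and sequence length and can dominate at these settings, as well as any buffers XLA allocates. Treat it as a lower-bound sketch rather than a prediction of what tpu-info will show. The parameter counts come from GPT-2 small's layer shapes with the rank-8 adapters used here.

```python
BYTES_FP32 = 4


def training_state_bytes(trainable, frozen=0, bytes_per_value=BYTES_FP32):
    """Rough bytes for params + grads + AdamW moments, ignoring activations.

    Trainable values need 4 copies (param, grad, mu, nu); frozen values need only the param.
    """
    return bytes_per_value * (4 * trainable + frozen)


GPT2_PARAMS = 124_439_808  # GPT-2 small with a tied output head
LORA_PARAMS = 1_327_104    # rank-8 adapters on q/k/v/out and both MLP layers, 12 blocks

full = training_state_bytes(trainable=GPT2_PARAMS)
lora = training_state_bytes(trainable=LORA_PARAMS, frozen=GPT2_PARAMS)
print(f"full finetune: {full / 1e9:.2f} GB, LoRA: {lora / 1e9:.2f} GB")  # full finetune: 1.99 GB, LoRA: 0.52 GB
```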