# GPT2 instruction tuning

This notebook demonstrates how to fine-tune a pretrained GPT-2 (124M) model to follow user instructions. We load the pretrained GPT-2 weights from Hugging Face and then instruction-tune the model on TPU.

## Determine platform

In [1]:
import os
# Infer the hosting platform from well-known filesystem roots. The first marker
# directory that exists wins; if none exists we assume a Cloud TPU VM ("GCP").
platform = next(
    (
        name
        for marker_dir, name in (('/content/', "Colab"), ('/kaggle/', "Kaggle"))
        if os.path.exists(marker_dir)
    ),
    "GCP",
)

## Setup

Install JAX and Flax first.

In [2]:
# Install the JAX AI stack (with the Grain data-loading extra) into the runtime.
!pip install -q jax-ai-stack[grain]
if platform == "Colab": # temp workaround on Colab (https://github.com/jax-ml/jax-ai-stack/issues/149)
  !pip install -Uq "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# Remaining third-party packages used later in the notebook (tokenizer, logging, datasets, ...).
!pip install -Uq tiktoken matplotlib kaggle wandb tpu-info datasets


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/456.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m450.6/456.0 kB[0m [31m16.6 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m456.0/456.0 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/473.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m473.3/473.3 kB[0m [31m23.4 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.4 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m2.4/2.4 MB[0m [31m152.1 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m2.4/2.4 MB[0m [31m152.1 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━

Confirm we have TPUs set up.

In [2]:
import jax
# Bare last expression: the notebook displays the devices JAX can see, so the
# reader can confirm the TPU cores are attached before continuing.
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

Take care of the imports.

In [3]:
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax, orbax
from collections import Counter
from dataclasses import dataclass
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import numpy as np
import tiktoken, time, wandb
from huggingface_hub import snapshot_download
from safetensors import safe_open
from pathlib import Path
from flax.nnx.nn.lora import LoRAParam

## Build the model

Define the device mesh.


In [4]:
### Alternative data and model parallel
# mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))

# 8x1 device grid with named axes ('batch', 'model'): all eight devices lie
# along 'batch' and the 'model' axis has size 1. Shardings defined later in the
# notebook refer to these axis names.
mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))

We are going to use the GPT-2 tokenizer via OpenAI's [Tiktoken](https://github.com/openai/tiktoken) library.

In [5]:
# Module-level tokenizer; later cells call tokenizer.encode / tokenizer.decode
# and read tokenizer.n_vocab from it.
tokenizer = tiktoken.get_encoding("gpt2")

Set some hyperparameters.

In [6]:
# --- Model architecture (sized for the GPT2 124M checkpoint loaded later) ---
vocab_size = tokenizer.n_vocab
GPT2_variant = "GPT2" # Only supports GPT2
num_transformer_blocks = 12
seqlen = 1024
embed_dim = 768
num_heads = 12
feed_forward_dim = 4 * embed_dim
# Global batch size depends on the accelerator generation available on each platform.
if platform == "Colab":
    batch_size = 24 # TPU v2
else:
    batch_size = 72 # TPU v3

dropout_rate = 0.1
# Rank passed as lora_rank to every nnx.LoRALinear layer built below.
lora_rank = 8

# --- Optimisation ---
# max_steps is only used as decay_steps of the cosine learning-rate schedule;
# the training loop itself runs until the data loader is exhausted.
# NOTE(review): this value looks carried over from the pretraining notebook —
# confirm it is the intended decay horizon for instruction tuning.
max_steps = 600000*12//batch_size
# Kaggle TPU limit per session is 9 hours, which is ~95K steps for GPT2
if platform == "Kaggle":
  max_steps = 90000
init_learning_rate = 5e-4
weight_decay = 1e-1
# --- Sampling (read by GPT2.sample_from) ---
top_k = 10
sampling_temp = 2
# --- Numerics: dtype is the computation dtype, param_dtype the storage dtype
# passed to the layers that accept them. ---
dtype = jnp.bfloat16
param_dtype = jnp.float32

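We are going to load the weights from Hugging Face, whose checkpoint uses a different tensor layout from the one we used in our earlier pretraining notebook (for example, the query, key and value projections are stored as a single fused `c_attn` tensor). We therefore define a custom attention layer that can consume those tensors.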

In [7]:
def causal_attention_mask(seq_len):
    """Return a lower-triangular boolean mask of shape (seq_len, seq_len).

    Entry ``[i, j]`` is True when query position ``i`` may attend to key
    position ``j`` (that is, ``j <= i``) and False otherwise.

    The mask is built with a boolean dtype because it is only ever used as a
    condition (``jnp.logical_and`` / ``jnp.where``); this avoids an implicit
    float-to-bool conversion on every attention call.
    """
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))


class CustomMHA(nnx.Module):
    """Multi-head self-attention whose projections are LoRA layers seeded from a checkpoint.

    The query/key/value/output projections are bias-free ``nnx.LoRALinear``
    layers. Their base kernels are overwritten with tensors taken from
    ``weights``; the corresponding biases are stored as separate ``nnx.Param``
    attributes and added explicitly in ``__call__``.
    """

    def __init__(self, embed_dim, num_heads, dropout_rate, layer_idx, weights, rngs):
        """Build the projections for block ``layer_idx`` and load its attention tensors.

        Args:
            embed_dim: Width of the input/output features; also the width of
                each projection.
            num_heads: Number of attention heads; ``embed_dim // num_heads`` is
                used as the per-head width.
            dropout_rate: Rate given to the dropout applied to attention weights.
            layer_idx: Index used to build the ``h.{layer_idx}.attn.*`` keys.
            weights: Mapping from checkpoint tensor names to arrays.
            rngs: ``nnx.Rngs`` used to initialise the LoRA layers.

        The LoRA rank is read from the module-level ``lora_rank``.
        """
        self.num_heads = num_heads
        self.head_dim = embed_dim // self.num_heads
        self.embed_dim = embed_dim

        kernel_init = nnx.with_partitioning(
            nnx.initializers.xavier_uniform(), (P(None, "model"),)
        )

        # The four projections are created in a fixed order (query, key, value,
        # out); each call draws from the shared `rngs`.
        self.query = nnx.LoRALinear(
            embed_dim,
            embed_dim,
            rngs=rngs,
            use_bias=False,
            kernel_init=kernel_init,
            lora_rank=lora_rank,
        )
        self.key = nnx.LoRALinear(
            embed_dim,
            embed_dim,
            rngs=rngs,
            use_bias=False,
            kernel_init=kernel_init,
            lora_rank=lora_rank,
        )
        self.value = nnx.LoRALinear(
            embed_dim,
            embed_dim,
            rngs=rngs,
            use_bias=False,
            kernel_init=kernel_init,
            lora_rank=lora_rank,
        )
        self.out = nnx.LoRALinear(
            embed_dim,
            embed_dim,
            rngs=rngs,
            use_bias=False,
            kernel_init=kernel_init,
            lora_rank=lora_rank,
        )

        # The checkpoint stores Q, K and V fused in one tensor; split it into
        # three equal parts along the last axis and overwrite the base kernels.
        qkv_kernel = weights[f"h.{layer_idx}.attn.c_attn.weight"]
        q_kernel, k_kernel, v_kernel = jnp.split(qkv_kernel, 3, axis=-1)
        self.query.kernel.value = q_kernel
        self.key.kernel.value = k_kernel
        self.value.kernel.value = v_kernel

        # The fused bias is split the same way and kept outside the LoRA layers.
        qkv_bias = weights[f"h.{layer_idx}.attn.c_attn.bias"]
        q_b, k_b, v_b = jnp.split(qkv_bias, 3, axis=-1)

        self.q_bias = nnx.Param(q_b, sharding=P("model"))
        self.k_bias = nnx.Param(k_b, sharding=P("model"))
        self.v_bias = nnx.Param(v_b, sharding=P("model"))

        self.out.kernel.value = weights[f"h.{layer_idx}.attn.c_proj.weight"]
        self.out_bias = nnx.Param(
            weights[f"h.{layer_idx}.attn.c_proj.bias"], sharding=P("model")
        )

        self.dropout = nnx.Dropout(dropout_rate)

    def __call__(self, x, mask, padding_mask=None, training: bool = False):
        """Apply masked scaled dot-product attention to ``x``.

        Args:
            x: Input of shape (batch, seq_len, features).
            mask: Attention mask combined with ``padding_mask`` (if given) via
                logical AND; may be None, in which case no causal masking is
                applied.
            padding_mask: Optional extra mask; must broadcast against ``mask``.
            training: When True the attention-weight dropout is active.

        Returns:
            Array of shape (batch, seq_len, embed_dim).
        """
        batch_size, seq_len, _ = x.shape

        q = self.query(x) + self.q_bias
        k = self.key(x) + self.k_bias
        v = self.value(x) + self.v_bias

        # Split features into heads: (batch, seq, embed) -> (batch, heads, seq, head_dim).
        q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(
            (0, 2, 1, 3)
        )
        k = k.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(
            (0, 2, 1, 3)
        )
        v = v.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(
            (0, 2, 1, 3)
        )

        # Scaled dot-product scores: (batch, heads, seq, seq).
        attn_weights = jnp.matmul(q, k.transpose((0, 1, 3, 2))) / jnp.sqrt(
            self.head_dim
        )

        combined_mask = mask
        if padding_mask is not None:
            combined_mask = jnp.logical_and(mask, padding_mask)

        # Disallowed positions get -inf so they receive zero weight after softmax.
        if combined_mask is not None:
            attn_weights = jnp.where(combined_mask, attn_weights, -jnp.inf)

        attn_weights = nnx.softmax(attn_weights, axis=-1)
        attn_weights = self.dropout(attn_weights, deterministic=not training)

        # Weighted sum of values, then merge the heads back into one feature axis.
        attn_output = jnp.matmul(attn_weights, v)
        attn_output = attn_output.transpose((0, 2, 1, 3)).reshape(
            (batch_size, seq_len, self.embed_dim)
        )

        output = self.out(attn_output) + self.out_bias
        return output

Now define the model architecture, which is the same as in our previous pretraining notebook.

In [8]:

class TransformerBlock(nnx.Module):
    """One pre-LayerNorm GPT-2 block: self-attention and an MLP, each with a residual.

    All learnable base tensors are overwritten with the pretrained values found
    in ``weights`` under the ``h.{layer_idx}.*`` keys. The two MLP projections
    are ``nnx.LoRALinear`` layers whose rank comes from the module-level
    ``lora_rank``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        ff_dim: int,
        dropout_rate: float,
        rngs: nnx.Rngs,
        layer_idx: int,
        weights: dict,
    ):
        """Build the block and load its pretrained tensors.

        Args:
            embed_dim: Feature width of the residual stream.
            num_heads: Number of attention heads.
            ff_dim: Hidden width of the MLP.
            dropout_rate: Rate for the attention and residual dropouts.
            rngs: ``nnx.Rngs`` used to initialise the layers.
            layer_idx: Index of this block inside the checkpoint.
            weights: Mapping from checkpoint tensor names to arrays.
        """
        # The pretrained GPT-2 weights were trained with a LayerNorm epsilon of
        # 1e-5, so use the same value to reproduce the checkpoint's activations.
        self.layer_norm1 = nnx.LayerNorm(
            epsilon=1e-5,
            num_features=embed_dim,
            scale_init=nnx.with_partitioning(
                nnx.initializers.ones_init(), NamedSharding(mesh, P("model"))
            ),
            bias_init=nnx.with_partitioning(
                nnx.initializers.zeros_init(), NamedSharding(mesh, P("model"))
            ),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.layer_norm1.scale.value = weights[f"h.{layer_idx}.ln_1.weight"]
        self.layer_norm1.bias.value = weights[f"h.{layer_idx}.ln_1.bias"]
        self.mha = CustomMHA(
            embed_dim, num_heads, dropout_rate, layer_idx, weights, rngs
        )
        self.dropout1 = nnx.Dropout(rate=dropout_rate)
        self.layer_norm2 = nnx.LayerNorm(
            epsilon=1e-5,
            num_features=embed_dim,
            scale_init=nnx.with_partitioning(
                nnx.initializers.ones_init(), NamedSharding(mesh, P("model"))
            ),
            bias_init=nnx.with_partitioning(
                nnx.initializers.zeros_init(), NamedSharding(mesh, P("model"))
            ),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.layer_norm2.scale.value = weights[f"h.{layer_idx}.ln_2.weight"]
        self.layer_norm2.bias.value = weights[f"h.{layer_idx}.ln_2.bias"]
        # MLP up-projection: embed_dim -> ff_dim.
        self.linear1 = nnx.LoRALinear(
            in_features=embed_dim,
            out_features=ff_dim,
            kernel_init=nnx.with_partitioning(
                nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, "model"))
            ),
            bias_init=nnx.with_partitioning(
                nnx.initializers.zeros_init(), NamedSharding(mesh, P("model"))
            ),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
            lora_rank=lora_rank,
        )
        self.linear1.kernel.value = weights[f"h.{layer_idx}.mlp.c_fc.weight"]
        self.linear1.bias.value = weights[f"h.{layer_idx}.mlp.c_fc.bias"]
        # MLP down-projection: ff_dim -> embed_dim.
        self.linear2 = nnx.LoRALinear(
            in_features=ff_dim,
            out_features=embed_dim,
            kernel_init=nnx.with_partitioning(
                nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, "model"))
            ),
            bias_init=nnx.with_partitioning(
                nnx.initializers.zeros_init(), NamedSharding(mesh, P("model"))
            ),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
            lora_rank=lora_rank,
        )
        self.linear2.kernel.value = weights[f"h.{layer_idx}.mlp.c_proj.weight"]
        self.linear2.bias.value = weights[f"h.{layer_idx}.mlp.c_proj.bias"]
        self.dropout2 = nnx.Dropout(rate=dropout_rate)

    def __call__(self, inputs, padding_mask=None, training: bool = False):
        """Run the block on ``inputs`` of shape (batch, seq_len, embed_dim).

        Args:
            inputs: Residual-stream activations.
            padding_mask: Optional mask forwarded to the attention layer.
            training: When True the dropout layers are active.

        Returns:
            Activations with the same shape as ``inputs``.
        """
        seq_len = inputs.shape[1]

        # Attention sub-layer: normalise, attend causally, add the residual.
        attention_output = self.mha(
            self.layer_norm1(inputs),
            mask=causal_attention_mask(seq_len),
            padding_mask=padding_mask,
            training=training,
        )
        x = inputs + self.dropout1(attention_output, deterministic=not training)

        # MLP sub-layer: normalise, expand, GELU, project back, add the residual.
        mlp_output = self.linear1(self.layer_norm2(x))
        mlp_output = nnx.gelu(mlp_output)
        mlp_output = self.linear2(mlp_output)
        mlp_output = self.dropout2(mlp_output, deterministic=not training)

        return x + mlp_output


class TokenAndPositionEmbedding(nnx.Module):
    """Token and learned position embedding tables, both seeded from the checkpoint."""

    def __init__(
        self,
        seqlen: int,
        vocab_size: int,
        embed_dim: int,
        rngs: nnx.Rngs,
        weights: dict,
    ):
        """Create both tables, then overwrite them with ``wte.weight`` / ``wpe.weight``."""
        # Settings common to both tables; only the number of rows differs.
        common = dict(
            features=embed_dim,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.token_emb = nnx.Embed(num_embeddings=vocab_size, **common)
        self.pos_emb = nnx.Embed(num_embeddings=seqlen, **common)

        self.token_emb.embedding.value = weights["wte.weight"]
        self.pos_emb.embedding.value = weights["wpe.weight"]

    def __call__(self, x):
        """Embed token ids ``x`` (batch, seq_len).

        Returns:
            A pair ``(token_emb_module, embeddings)``: the token embedding
            module itself (so the caller can reuse it for the output
            projection) and the sum of token and position embeddings.
        """
        embedded_tokens = self.token_emb(x)
        # One row of position ids 0..seq_len-1, broadcast over the batch axis.
        position_ids = jnp.arange(0, x.shape[1])[None, :]
        embedded_positions = self.pos_emb(position_ids)
        return self.token_emb, embedded_tokens + embedded_positions


class GPT2(nnx.Module):
    """GPT-2 language model assembled from pretrained tensors, plus sampling helpers.

    The forward pass embeds the tokens, runs the transformer blocks, applies a
    final LayerNorm and projects back to the vocabulary by reusing the token
    embedding table (weight tying).
    """

    def __init__(
        self,
        seqlen: int,
        vocab_size: int,
        embed_dim: int,
        num_heads: int,
        rate: float,
        feed_forward_dim: int,
        num_transformer_blocks: int,
        rngs: nnx.Rngs,
        weights: dict,
    ):
        """Build the model and load every tensor from ``weights``.

        Args:
            seqlen: Number of rows in the position embedding table.
            vocab_size: Number of rows in the token embedding table.
            embed_dim: Feature width of the residual stream.
            num_heads: Attention heads per block.
            rate: Dropout rate used for the embedding dropout and every block.
            feed_forward_dim: Hidden width of each block's MLP.
            num_transformer_blocks: Number of blocks to stack.
            rngs: ``nnx.Rngs`` used to initialise the layers.
            weights: Mapping from checkpoint tensor names to arrays.
        """
        self.embedding_layer = TokenAndPositionEmbedding(
            seqlen, vocab_size, embed_dim, rngs=rngs, weights=weights
        )
        self.dropout = nnx.Dropout(rate=rate)

        # Use the `rate` argument (not the module-level `dropout_rate`) so the
        # blocks honour whatever the caller passed in.
        self.transformer_blocks = [
            TransformerBlock(
                embed_dim,
                num_heads,
                feed_forward_dim,
                rate,
                rngs=rngs,
                layer_idx=i,
                weights=weights,
            )
            for i in range(num_transformer_blocks)
        ]

        # The pretrained GPT-2 weights were trained with a LayerNorm epsilon of 1e-5.
        self.layer_norm = nnx.LayerNorm(
            epsilon=1e-5,
            num_features=embed_dim,
            scale_init=nnx.with_partitioning(
                nnx.initializers.ones_init(), NamedSharding(mesh, P("model"))
            ),
            bias_init=nnx.with_partitioning(
                nnx.initializers.zeros_init(), NamedSharding(mesh, P("model"))
            ),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.layer_norm.scale.value = weights["ln_f.weight"]
        self.layer_norm.bias.value = weights["ln_f.bias"]

    def __call__(self, inputs, padding_mask=None, training: bool = False):
        """Return next-token logits for token ids ``inputs`` of shape (batch, seq_len).

        Args:
            inputs: Integer token ids.
            padding_mask: Optional mask forwarded to every attention layer.
            training: When True the dropout layers are active.
        """
        # The embedding layer returns the token-embedding module alongside the
        # embedded inputs; the module is reused below for the output projection.
        token_emb_module, x = self.embedding_layer(inputs)
        x = self.dropout(x, deterministic=not training)
        for transformer_block in self.transformer_blocks:
            x = transformer_block(x, padding_mask=padding_mask, training=training)
        x = self.layer_norm(x)
        # Weights tying
        outputs = token_emb_module.attend(x)
        return outputs

    @nnx.jit
    def sample_from(self, logits, key):
        """Sample one token id from the ``top_k`` largest logits.

        The kept logits are divided by the module-level ``sampling_temp`` before
        the softmax that turns them into sampling probabilities.
        """
        logits, indices = jax.lax.top_k(logits, k=top_k)
        logits = nnx.softmax(logits / sampling_temp)
        return jax.random.choice(key, indices, p=logits)

    @nnx.jit
    def generate_step(self, params, static_def, padded_tokens, length, key):
        """Sample the token that follows the first ``length`` tokens of ``padded_tokens``.

        Args:
            params: First value returned by ``nnx.split`` on the model.
            static_def: Second value returned by ``nnx.split`` on the model.
                (The two are simply handed to ``nnx.merge`` in the same order.)
            padded_tokens: Token ids of shape (1, seqlen), padded on the right.
            length: Number of real (non-padding) tokens.
            key: PRNG key used for sampling.
        """
        # Hide the right-hand padding from every attention layer.
        padding_mask = jnp.arange(seqlen) < length
        padding_mask = padding_mask.reshape(1, 1, 1, seqlen)

        model = nnx.merge(params, static_def)
        logits = model(padded_tokens, padding_mask=padding_mask, training=False)
        # Logits at the last real position predict the next token.
        last_token_logits = logits[:, length - 1, :]

        key, subkey = jax.random.split(key)
        next_token = self.sample_from(
            jnp.squeeze(last_token_logits), subkey
        )  # Pass subkey here
        return next_token

    def generate_text(self, max_tokens, start_tokens):
        """Stream up to ``max_tokens`` sampled tokens to stdout and return the full text.

        Generation stops early at ``<|endoftext|>`` or when the context window
        (module-level ``seqlen``) is full, so the token buffer is never written
        out of bounds.

        Args:
            max_tokens: Upper bound on the number of new tokens.
            start_tokens: Prompt token ids (at most ``seqlen`` of them).

        Returns:
            The decoded prompt plus generated continuation.
        """
        # Seeded from the wall clock, so each call produces different samples.
        key = jax.random.PRNGKey(int(time.time()))

        # nnx.split returns (graphdef, state); generate_step merges them back
        # in the same order.
        graphdef, state = nnx.split(self)

        tokens = jnp.array(start_tokens, dtype=jnp.int32)[None, :]
        end_token = tokenizer.encode(
            "<|endoftext|>", allowed_special={"<|endoftext|>"}
        )[0]

        current_len = tokens.shape[1]
        padded_tokens = jnp.pad(tokens, ((0, 0), (0, seqlen - current_len)))

        print(tokenizer.decode(tokens[0]), end="", flush=True)

        # Never generate more tokens than there are free slots in the buffer.
        budget = min(max_tokens, seqlen - current_len)
        for _ in range(budget):
            key, subkey = jax.random.split(key)

            next_token = self.generate_step(
                graphdef, state, padded_tokens, current_len, subkey
            )
            # Pull the sampled id to the host once per step.
            token_id = next_token.item()

            if token_id == end_token:
                break

            print(tokenizer.decode([token_id]), end="", flush=True)

            padded_tokens = padded_tokens.at[:, current_len].set(token_id)
            current_len += 1

        final_tokens = padded_tokens[0, :current_len]
        return tokenizer.decode(final_tokens.tolist())


def create_model(rngs, weights):
    """Instantiate ``GPT2`` from the notebook-level hyperparameters.

    Args:
        rngs: ``nnx.Rngs`` forwarded to the model constructor.
        weights: Mapping of checkpoint tensor names to arrays.
    """
    return GPT2(
        seqlen=seqlen,
        vocab_size=vocab_size,
        embed_dim=embed_dim,
        num_heads=num_heads,
        rate=dropout_rate,
        feed_forward_dim=feed_forward_dim,
        num_transformer_blocks=num_transformer_blocks,
        rngs=rngs,
        weights=weights,
    )

Although we previously pretrained a pretty good GPT2 model from scratch, it is still less capable than the OpenAI official model (this is probably because the OpenWebText dataset is less comprehensive). So we are going to load the official weights from Hugging Face now.

In [9]:
# Hugging Face repository that hosts the checkpoint.
model_id = "openai-community/gpt2"
# Choose the base directory for the download by probing the filesystem.
if os.path.exists("/kaggle"):
    weights_base_dir = "/kaggle/tmp"
elif os.path.exists("/content"):
    # Colab
    weights_base_dir = "/content"
else:
    # Local machine
    weights_base_dir = "."

# The repo id doubles as the sub-directory name under the base directory.
path_to_model_weights = os.path.join(weights_base_dir, model_id)

# Only the *.safetensors files are fetched from the repository.
snapshot_download(
    repo_id=model_id, local_dir=path_to_model_weights, allow_patterns="*.safetensors"
)


def load_safetensors():
    """Read every ``*.safetensors`` file under ``path_to_model_weights``.

    Returns:
        A dict mapping tensor name to tensor. If the same name appears in more
        than one file, the file visited last wins.
    """
    tensors = {}
    for shard_path in Path(path_to_model_weights).glob("*.safetensors"):
        with safe_open(shard_path, framework="jax", device="cpu") as shard:
            tensors.update({name: shard.get_tensor(name) for name in shard.keys()})
    return tensors


# Load the checkpoint tensors, then build the model; the constructors copy the
# tensors into the corresponding layers.
weights = load_safetensors()
# Fixed seed (0) for the parameters that are still randomly initialised.
model = create_model(rngs=nnx.Rngs(0), weights=weights)

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

In [10]:
# Sanity check of the loaded weights: continue a short free-form prompt.
start_prompt = "Once uppon a time"
# Truncate to the context window in case the prompt is longer than seqlen.
start_tokens = tokenizer.encode(start_prompt)[:seqlen]
print(f"***Initial generated text:")
# generate_text streams tokens to stdout; seqlen // 5 caps the number of new tokens.
generated_text = model.generate_text(seqlen // 5, start_tokens)

***Initial generated text:
Once uppon a time (say 10), i.i.upport is the first step and upports a time to a time. So i.o.vaport is the time. The last part of our code, if i.p. is an empty array, and is not called a second step (for example in C++11 we would have called upport, which returns upport to uap. But if we use upp, it's an empty array and we need to re-arrand the array in order to make it work), then the next part will call upp to return it and so we would have uptime. This method will be called to call our own callback method, uppon. uptime() , but if uppon is called to return a pointer, we would need a way to call the method that was passed in, i.i.upport. The callback is called to return upp. upport

For a list and for loops

Use Weights and Biases to track training progress.

In [11]:
# Make the API credentials available as environment variables. They come from
# each platform's secret store, so nothing sensitive is written in the notebook.
if platform == "Colab":
  from google.colab import userdata
  os.environ['WANDB_API_KEY'] = userdata.get('WANDB_API_KEY')
  os.environ['KAGGLE_USERNAME'] = userdata.get('KAGGLE_USERNAME')
  os.environ['KAGGLE_KEY'] = userdata.get('KAGGLE_KEY')
elif platform == "Kaggle":
  from kaggle_secrets import UserSecretsClient
  user_secrets = UserSecretsClient()
  os.environ['WANDB_API_KEY'] = user_secrets.get_secret('WANDB_API_KEY')
else:
  print("Please set the WANDB_API_KEY env variable manually") #input()

# `wandb` is already imported in the imports cell at the top of the notebook.
wandb.login()

wandb.init(
    # set the wandb project where this run will be logged
    project='GPT2-LoRA',

    # track hyperparameters and run metadata (each key listed exactly once)
    config={
      'architecture': GPT2_variant,
      # This notebook fine-tunes on the Alpaca instruction dataset.
      'dataset': 'tatsu-lab/alpaca',
      'platform': platform,
      'max_steps': max_steps,
      'batch_size': batch_size,
      'dtype': dtype,
      'param_dtype': param_dtype,
      'init_learning_rate': init_learning_rate,
      'num_transformer_blocks': num_transformer_blocks,
      'seqlen': seqlen,
      'embed_dim': embed_dim,
      'num_heads': num_heads,
      'feed_forward_dim': feed_forward_dim,
      'weight_decay': weight_decay,
      'lora_rank': lora_rank
    }
)

[34m[1mwandb[0m: Currently logged in as: [33mwindmaple[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Instruct tune

We are going to use the [Alpaca dataset](https://huggingface.co/datasets/tatsu-lab/alpaca) from Stanford.

In [12]:
import grain.python as pygrain
import pandas as pd
from datasets import load_dataset

@dataclass
class TextDataset:
    """Random-access dataset that yields fixed-length token-id lists.

    Each item is the tokenised text at a given position, truncated to
    ``seqlen`` tokens and right-padded to exactly ``seqlen`` tokens.
    """

    # Column of raw text examples. It is indexed positionally with `.iloc`, so
    # it must be a pandas Series rather than a plain list.
    data_df: pd.Series
    # Length of every returned token list.
    seqlen: int

    def __len__(self):
        """Return the number of text examples."""
        return len(self.data_df)

    def __getitem__(self, idx: int):
        """Return the ``idx``-th example as a list of exactly ``seqlen`` token ids."""
        # Id used to fill the tail of short examples.
        # NOTE(review): presumably the tokenizer's <|endoftext|> id — confirm.
        pad_token_id = 50256
        # Use Tiktoken for tokenization
        encoding = tokenizer.encode(
            self.data_df.iloc[idx], allowed_special={"<|endoftext|>"}
        )[: self.seqlen]
        return encoding + [pad_token_id] * (self.seqlen - len(encoding))


def load_and_preprocess_data(alpaca_data, batch_size, seqlen):
    """Wrap the ``"text"`` column of ``alpaca_data`` in a shuffled, batched Grain loader.

    Args:
        alpaca_data: Tabular data convertible by ``pd.DataFrame`` and holding a
            ``"text"`` column.
        batch_size: Number of examples per batch; an incomplete final batch is
            dropped.
        seqlen: Token length of every example.

    Returns:
        A ``pygrain.DataLoader`` that makes one shuffled pass (seed 42) over
        the data.
    """
    text_column = pd.DataFrame(alpaca_data)["text"]
    dataset = TextDataset(text_column, seqlen)

    index_sampler = pygrain.IndexSampler(
        len(dataset),
        shuffle=True,
        seed=42,
        shard_options=pygrain.NoSharding(),
        num_epochs=1,
    )
    batch_op = pygrain.Batch(batch_size=batch_size, drop_remainder=True)

    return pygrain.DataLoader(
        data_source=dataset,
        sampler=index_sampler,
        operations=[batch_op],
    )


# Fetch the training split of the Alpaca dataset from the Hugging Face Hub.
alpaca_data = load_dataset("tatsu-lab/alpaca", split="train")
# Data loader consumed by the training loop below.
text_dl = load_and_preprocess_data(alpaca_data, batch_size, seqlen)

Define the loss and training step function.

In [13]:
@nnx.jit
def loss_fn(model, batch):
    """Mean next-token cross-entropy of ``model`` on ``batch``.

    Args:
        model: Callable mapping token ids to logits. It is called with its
            default arguments.
        batch: Pair ``(inputs, targets)`` of integer token-id arrays.

    Returns:
        ``(loss, logits)``. The loss is averaged over every position of every
        sequence in the batch.
    """
    inputs, targets = batch[0], batch[1]
    logits = model(inputs)
    per_token_loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=targets
    )
    return per_token_loss.mean(), logits


@nnx.jit
def train_step(
    lora_params, opt_state, graphdef, static_params, metrics: nnx.MultiMetric, batch
):
    """Run one optimisation step that updates only the LoRA parameters.

    Args:
        lora_params: Trainable LoRA state; gradients are taken with respect to it.
        opt_state: Optimiser state for ``lora_params``.
        graphdef: Graph definition of the model, used to rebuild it.
        static_params: Remaining (frozen) parameters, merged in unchanged.
        metrics: Metric collection updated with this step's loss.
        batch: Pair ``(inputs, targets)`` of token-id arrays.

    Returns:
        ``(lora_params, opt_state)`` after the update.

    Uses the module-level optimiser ``tx``.
    """
    # Differentiate only through the LoRA state; the frozen parameters are
    # closed over and therefore receive no gradient.
    grad_fn = nnx.value_and_grad(
        lambda lp: loss_fn(nnx.merge(graphdef, lp, static_params), batch), has_aux=True
    )
    (loss, logits), grads = grad_fn(lora_params)
    metrics.update(loss=loss, logits=logits, labels=batch[1])
    updates, opt_state = tx.update(grads, opt_state, lora_params)
    lora_params = optax.apply_updates(lora_params, updates)
    return lora_params, opt_state

Define optimizer and metrics.

In [14]:
# Partition the model state: LoRA parameters (trained) first, then every other
# nnx.Param (kept fixed and only merged back in for the forward pass).
graphdef, lora_params, static_params = nnx.split(model, LoRAParam, nnx.Param)

# Cosine decay from init_learning_rate over max_steps optimiser steps.
schedule = optax.cosine_decay_schedule(
    init_value=init_learning_rate, decay_steps=max_steps
)
tx = optax.chain(optax.adamw(learning_rate=schedule, weight_decay=weight_decay))
# Optimiser state is created for the LoRA parameters only.
opt_state = tx.init(lora_params)


# Running average of the training loss; reset by the training loop after each report.
metrics = nnx.MultiMetric(
    loss=nnx.metrics.Average("loss"),
)

# One entry is appended per reporting interval.
metrics_history = {
    "train_loss": [],
}

Do a test run on our pretrained model to see how it responds to an instruction. Note how we use a template to format the prompt; it needs to be consistent with the format of the training data.

In [15]:
# Prompt layout with three slots: instruction, optional input, and the response.
template = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n{output}"

# Leave `output` empty so the prompt ends right after "### Response:" and the
# model has to produce the answer itself.
start_prompt = template.format(
    instruction="What is the future for human?",
    input="",
    output="",
)
# These start_tokens are reused by the training loop for its periodic samples.
start_tokens = tokenizer.encode(start_prompt)[:seqlen]
print(f"***Initial generated text:")
generated_text = model.generate_text(seqlen//5, start_tokens)

***Initial generated text:
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
What is the future for human?

### Input:


### Response:

How does the future for man look like:


How does the future look like:

 and why would I ever like it if man doesn't get a human?

I want a human in this world, so this is not my choice.

I am going to do what he wants, if I'm not careful.


What does this mean about human?


I want to do what he's going to do, and I have to do what's right in this case?


What does human think about this, and if it makes them happy, then it has something to do with my human. And I am the guy that does the things to get human to do these things for them?


And I don't care. That's why he will do anything to help man. I don't care if a human lives and I'm happy.


What do I do? That's my decision now, to make sure this human has something in his life that

As you can see, the pretrained model generates a bunch of garbage; clearly it does not know how to follow the instruction to generate an appropriate answer, which is not surprising given that we have not trained it to do so.

Now let's do the instruction tuning.

In [16]:
# Build next-token targets row by row: shift each sequence left by one token
# and append a filler id (0) for the last position, which has no successor.
prep_target_batch = jax.vmap(
    lambda tokens: jnp.concatenate((tokens[1:], jnp.array([0])))
)

# Loop-invariant: the batch axis must divide evenly across the devices.
num_devices = len(jax.devices())

step = 0
start_time = time.time()
for batch in text_dl:
    # Transpose so that each row is one sequence (prep_target_batch maps over
    # rows and the sharding below splits axis 0 across the 'batch' mesh axis).
    input_batch = jnp.array(batch).T
    # Check the batch axis of the transposed array; `len(batch)` is the length
    # of the other axis and would not detect an unevenly sized batch.
    if input_batch.shape[0] % num_devices != 0:
        continue  # skip the remaining elements
    target_batch = prep_target_batch(input_batch)
    lora_params, opt_state = train_step(
        lora_params,
        opt_state,
        graphdef,
        static_params,
        metrics,
        jax.device_put(
            (input_batch, target_batch), NamedSharding(mesh, P("batch", None))
        ),
    )

    # Every 20 steps: record the averaged loss, report it and sample some text.
    if (step + 1) % 20 == 0:
        for metric, value in metrics.compute().items():
            metrics_history[f"train_{metric}"].append(value)
        metrics.reset()

        elapsed_time = time.time() - start_time
        print(
            f"\n\nStep {step + 1}, Loss: {metrics_history['train_loss'][-1]}, Elapsed Time: {elapsed_time:.2f} seconds"
        )
        # wandb.log(data={'train_loss': metrics_history['train_loss'][-1]}, step=step)
        start_time = time.time()
        # print(f"\n***Intermediate generated text:")
        # Rebuild a full model from the current LoRA state to sample from it.
        intermediate_model = nnx.merge(graphdef, lora_params, static_params)
        generated_text = intermediate_model.generate_text(seqlen // 5, start_tokens)
    step += 1

# Final text generation
model = nnx.merge(graphdef, lora_params, static_params)
print(f"\n***Final generated text:")
generated_text = model.generate_text(seqlen // 5, start_tokens)



Step 20, Loss: 2.0909180641174316, Elapsed Time: 38.90 seconds


Step 40, Loss: 0.3573242127895355, Elapsed Time: 18.07 seconds


Step 60, Loss: 0.32646486163139343, Elapsed Time: 3.31 seconds


Step 80, Loss: 0.30937501788139343, Elapsed Time: 3.31 seconds


Step 100, Loss: 0.294921875, Elapsed Time: 3.54 seconds


Step 120, Loss: 0.26103517413139343, Elapsed Time: 3.31 seconds


Step 140, Loss: 0.22988282144069672, Elapsed Time: 3.44 seconds


Step 160, Loss: 0.21030274033546448, Elapsed Time: 3.37 seconds


Step 180, Loss: 0.20200195908546448, Elapsed Time: 3.35 seconds


Step 200, Loss: 0.19423829019069672, Elapsed Time: 3.46 seconds


Step 220, Loss: 0.19350586831569672, Elapsed Time: 3.32 seconds


Step 240, Loss: 0.1982421875, Elapsed Time: 3.40 seconds


Step 260, Loss: 0.18901367485523224, Elapsed Time: 3.35 seconds


Step 280, Loss: 0.19008789956569672, Elapsed Time: 3.55 seconds


Step 300, Loss: 0.18852539360523224, Elapsed Time: 3.35 seconds


Step 320, Loss: 0.186865240

As you can see, at the end of the fine-tuning the model is able to follow a human instruction and generate a somewhat sensible answer. The answer is also better than what we previously got from our own pretrained model.