# GPT-J-6B Inference Demo

<a href="http://colab.research.google.com/github/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook demonstrates how to run the [GPT-J-6B model](https://github.com/kingoflolz/mesh-transformer-jax/#GPT-J-6B). See the link for more details about the model, including evaluation metrics and credits.

## Install Dependencies

First we download the model and install some dependencies. This step takes at least 5 minutes (possibly longer depending on server load).

!!! **Make sure you are using a TPU runtime!** !!!
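If you are unsure whether the runtime is set up correctly, a check like the one below fails early with a readable message. It is a minimal sketch. It assumes that a Colab TPU runtime can be recognized by the COLAB_TPU_ADDR environment variable, which the setup cell further down reads. check_tpu_runtime is a hypothetical helper, not part of mesh-transformer-jax.

```python
import os

def check_tpu_runtime(environ=os.environ):
    """Return the TPU host:port string, or raise if this doesn't look like a Colab TPU runtime."""
    # assumption: Colab TPU runtimes set COLAB_TPU_ADDR (the setup cell relies on it)
    addr = environ.get("COLAB_TPU_ADDR")
    if not addr:
        raise RuntimeError(
            "COLAB_TPU_ADDR is not set: switch to a TPU runtime "
            "(Runtime > Change runtime type) and run the notebook again."
        )
    return addr
```

Calling check_tpu_runtime() at the top of the notebook either returns the address or stops before the slow download starts.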

In [None]:
!apt install zstd

# the "slim" version contains only bf16 weights and no optimizer parameters, which minimizes bandwidth and memory
!time wget -c https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd

!time tar -I zstd -xf step_383500_slim.tar.zstd

!git clone https://github.com/kingoflolz/mesh-transformer-jax.git
!pip install -r mesh-transformer-jax/requirements.txt

# jax 0.2.12 is required due to a regression with xmap in 0.2.13
!pip install mesh-transformer-jax/ jax==0.2.12 tensorflow==2.5.0


## Setup Model


In [None]:
import os
import requests
from jax.config import config

# ask the Colab TPU host to switch to a specific TPU driver build (tpu_driver0.1_dev20210607)
colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1_dev20210607'
requests.post(url)

# The following is required to use TPU Driver as JAX's backend.
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']

Sometimes the next step fails for no obvious reason; if it does, just run it again ¯\\\_(ツ)\_/¯

In [None]:
import time

import jax
from jax.experimental import maps
import numpy as np
import optax
import transformers

from mesh_transformer.checkpoint import read_ckpt_lowmem
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer

In [None]:
params = {
  "layers": 28,
  "d_model": 4096,
  "n_heads": 16,
  "n_vocab": 50400,
  "norm": "layernorm",
  "pe": "rotary",
  "pe_rotary_dims": 64,

  "seq": 2048,
  "cores_per_replica": 8,
  "per_replica_batch": 1,
}

per_replica_batch = params["per_replica_batch"]
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]


params["sampler"] = nucleaus_sample

# here we "remove" the optimizer parameters from the model (as we don't need them for inference)
params["optimizer"] = optax.scale(0)

mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)

maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))

tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')


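The params cell above arranges the TPU cores into a 2-D mesh with axes 'dp' (data-parallel replicas) and 'mp' (model-parallel cores per replica), and total_batch is the number of prompts processed per call. The sketch below reproduces that arithmetic with plain integers standing in for jax.devices(), so it runs without a TPU; sketch_mesh is a hypothetical helper. On the 8-core Colab TPU used here it gives a (1, 8) mesh, matching the "dp 1" / "mp 8" lines printed when the model loads.

```python
import numpy as np

def sketch_mesh(device_count, cores_per_replica, per_replica_batch):
    """Return (mesh_shape, device_grid, total_batch) for a hypothetical device count."""
    if device_count % cores_per_replica != 0:
        raise ValueError("device_count must be a multiple of cores_per_replica")
    shape = (device_count // cores_per_replica, cores_per_replica)  # ('dp', 'mp')
    # integers stand in for jax.devices(); row i holds the cores of replica i
    grid = np.arange(device_count).reshape(shape)
    batch = per_replica_batch * device_count // cores_per_replica
    return shape, grid, batch

print(sketch_mesh(8, 8, 1)[0])   # (1, 8): one replica split across 8 cores
print(sketch_mesh(16, 8, 1)[0])  # (2, 8): two replicas of 8 cores each
```

The helper uses fresh names on purpose, so running it inside the notebook does not overwrite mesh_shape or total_batch.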
Here we create the network and load the parameters from the downloaded files. Expect this to take around 5 minutes.
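Because loading is slow, it can help to confirm first that the download and extraction actually produced the checkpoint directory that the next cell reads (step_383500/). This is a minimal sketch. checkpoint_ready is a hypothetical helper, and it only checks that the directory exists and is non-empty, not that the files inside are complete.

```python
import os

def checkpoint_ready(path="step_383500/"):
    """Return True if the extracted checkpoint directory exists and contains at least one entry."""
    return os.path.isdir(path) and len(os.listdir(path)) > 0
```

If checkpoint_ready() returns False, re-run the download and tar cells before loading.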

In [None]:
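# number of prompts processed per generate() call across all replicas
total_batch = per_replica_batch * jax.device_count() // cores_per_replica

network = CausalTransformer(params)

# load the checkpoint; devices.shape[1] is the model-parallel size (cores_per_replica = 8 here)
network.state = read_ckpt_lowmem(network.state, "step_383500/", devices.shape[1])

# move the loaded parameters onto the TPU mesh
network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))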

  warn("xmap is an experimental feature and probably has bugs!")


key shape (8, 2)
in shape (1, 2048)
dp 1
mp 8
Total parameters: 6053381344
read from disk/gcs in 329.802s
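As a rough cross-check on the "Total parameters" line, the count can be estimated from the params dict. This sketch assumes the GPT-J block layout as we understand it: attention and feed-forward share one LayerNorm per block, attention projections have no biases, the feed-forward width is 4 * d_model, and the output layer is untied and has a bias. None of these details appear in the notebook, and estimate_gptj_params is a hypothetical helper. The estimate lands within about 0.05% of the printed 6053381344; we have not tracked down the small remaining difference.

```python
def estimate_gptj_params(layers=28, d_model=4096, n_vocab=50400, d_ff=None):
    """Approximate GPT-J parameter count under the assumptions stated above."""
    d_ff = 4 * d_model if d_ff is None else d_ff   # assumption: 4x feed-forward width
    attn = 4 * d_model * d_model                   # q, k, v, out projections, no biases
    mlp = 2 * d_model * d_ff + d_ff + d_model      # two dense layers with biases
    ln = 2 * d_model                               # one LayerNorm (scale + offset) per block
    per_layer = attn + mlp + ln                    # rotary embeddings add no parameters
    embed = n_vocab * d_model                      # input token embedding
    head = d_model * n_vocab + n_vocab             # untied output projection with bias
    final_ln = 2 * d_model
    return layers * per_layer + embed + head + final_ln

print(f"{estimate_gptj_params():,}")  # 6,050,882,784
```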


## Run Model

Finally, we are ready to run inference with the model! The first sample takes around a minute because the model has to be compiled first, but after that each sample should take only about 10 seconds.

Feel free to experiment with the sampling parameters (top_p and temp) and with the length of the generations (gen_len). Note that changing gen_len triggers a recompile.
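To build intuition for what top_p and temp do, here is a generic NumPy sketch of temperature scaling followed by nucleus (top-p) filtering for a single step's logits. It is not mesh_transformer's nucleaus_sample (that function runs in JAX inside the model, and its handling of edge cases is not shown in this notebook). Treating temp=0 as greedy decoding is an assumption of this sketch, and nucleus_sample_sketch is a hypothetical name.

```python
import numpy as np

def nucleus_sample_sketch(logits, top_p=0.9, temp=1.0, rng=None):
    """Sample one token id using temperature scaling plus top-p (nucleus) filtering."""
    logits = np.asarray(logits, dtype=np.float64)
    if temp <= 0:
        return int(np.argmax(logits))  # assumption: temp=0 means greedy decoding
    rng = np.random.default_rng() if rng is None else rng
    scaled = logits / temp             # temp < 1 sharpens, temp > 1 flattens the distribution
    probs = np.exp(scaled - scaled.max())  # numerically stable softmax
    probs /= probs.sum()
    order = np.argsort(-probs)         # token ids, most likely first
    sorted_probs = probs[order]
    # keep a token if the probability mass *before* it is still below top_p;
    # this keeps the smallest prefix whose total mass reaches top_p
    keep = (np.cumsum(sorted_probs) - sorted_probs) < top_p
    keep[0] = True                     # always keep the most likely token
    kept_ids = order[keep]
    kept_probs = sorted_probs[keep] / sorted_probs[keep].sum()
    return int(rng.choice(kept_ids, p=kept_probs))
```

Lower top_p restricts sampling to fewer, more likely tokens; top_p=0 degenerates to always picking the most likely token.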

You can also change other settings in the previous cells, such as per_replica_batch, to control how many generations run in parallel. A larger batch has higher latency but also higher throughput, measured in tokens generated per second. This is useful for things like best-of-n cherry-picking.
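Since infer() returns one sample per batch element (total_batch of them), best-of-n selection amounts to scoring each sample and keeping the highest-scoring one. A minimal sketch: best_of_n and the scorer closeness_to_60_words are hypothetical placeholders, and in practice you would substitute whatever quality measure suits your task.

```python
def best_of_n(samples, score):
    """Return the highest-scoring sample; `score` is any callable mapping str -> number."""
    if not samples:
        raise ValueError("need at least one sample")
    return max(samples, key=score)

def closeness_to_60_words(text):
    """Hypothetical scorer: prefer texts whose word count is closest to 60."""
    return -abs(len(text.split()) - 60)
```

For example, after raising per_replica_batch you could call best_of_n(infer(prompt), closeness_to_60_words). Each sample returned by infer() still begins with the prompt, so a length-based scorer like this one counts the prompt's words too.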

*Tip for best results: Make sure your prompt does not have any trailing spaces, which tend to confuse the model due to the BPE tokenization used during training.*
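One way to follow this tip automatically is to clean the prompt before passing it to infer(). Our understanding is that GPT-2's BPE usually attaches a space to the start of the following word, so a trailing space ends up as an unusual standalone token. The sketch uses str.rstrip, which also removes trailing newlines and tabs; prepare_prompt is a hypothetical helper.

```python
import warnings

def prepare_prompt(text):
    """Strip trailing whitespace from a prompt, warning when anything was removed."""
    cleaned = text.rstrip()
    if cleaned != text:
        warnings.warn("removed trailing whitespace from the prompt")
    return cleaned
```

Usage: infer(context=prepare_prompt(prompt)).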

In [None]:
# allow text wrapping in generated output: https://stackoverflow.com/a/61401455
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

In [None]:
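def infer(context, top_p=0.9, temp=1.0, gen_len=300):
    tokens = tokenizer.encode(context)

    provided_ctx = len(tokens)
    if provided_ctx > seq:
        raise ValueError(f"prompt is {provided_ctx} tokens, but the context window is only {seq} tokens")
    pad_amount = seq - provided_ctx

    # left-pad the prompt with zeros to the full context length, then copy it once per batch element
    padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
    batched_tokens = np.array([padded_tokens] * total_batch)
    # number of real (non-padding) prompt tokens for each batch element
    length = np.ones(total_batch, dtype=np.uint32) * provided_ctx

    start = time.time()
    output = network.generate(
        batched_tokens, length, gen_len,
        {"top_p": np.ones(total_batch) * top_p, "temp": np.ones(total_batch) * temp},
    )

    samples = []
    decoded_tokens = output[1][0]

    # one generated token sequence per batch element; the prompt is wrapped in ANSI bold codes
    for o in decoded_tokens[:, :, 0]:
        samples.append(f"\033[1m{context}\033[0m{tokenizer.decode(o)}")

    print(f"completion done in {time.time() - start:06}s")
    return samples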
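The preprocessing inside infer() can be tried on its own without a TPU. The sketch below mirrors it: the prompt's token ids are left-padded with zeros to the full context length seq, copied once per batch element, and paired with the number of real prompt tokens. pad_and_batch is a hypothetical helper and the token ids in the example are arbitrary. It uses np.tile and np.full where infer() builds the same arrays with a list repeat and np.ones.

```python
import numpy as np

def pad_and_batch(tokens, seq, total_batch):
    """Left-pad token ids to length `seq` and repeat them `total_batch` times, as infer() does."""
    if len(tokens) > seq:
        raise ValueError("prompt is longer than the context window")
    pad_amount = seq - len(tokens)
    # zeros go on the left so the prompt ends exactly where generation begins
    padded = np.pad(np.asarray(tokens, dtype=np.uint32), (pad_amount, 0))
    batched = np.tile(padded, (total_batch, 1))        # shape (total_batch, seq)
    lengths = np.full(total_batch, len(tokens), dtype=np.uint32)
    return batched, lengths

batched, lengths = pad_and_batch([15496, 995], seq=8, total_batch=2)
print(batched.shape)  # (2, 8)
```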



In [None]:
prompt = 'Please summarize the article below. \
\n""" \
Microsoft in talks to buy TikTok \
Negotiations for ByteDance-owned social media group come as Trump threatens action \
Microsoft has held talks to acquire TikTok, whose Chinese owner ByteDance faces mounting pressure from the US government to sell the video sharing app or risk being blacklisted in the country, said people briefed on the matter.\
The approach to buy TikTok was at an early stage, and it was unclear whether Microsoft would succeed in taking it over, as the transaction faced multiple hurdles, said people familiar with the negotiation.\
The people also said multiple parties were interested in acquiring TikTok and the process remained in its preliminary stages. Microsoft and TikTok declined to comment.\
Takeover interest comes amid rising scrutiny of TikTok and ByteDance’s data-sharing practices, which US President Donald Trump’s re-election campaign ads have implied amount to Chinese spying on American users.\
Mr Trump said on Friday that he was considering “a couple of options” to address TikTok, including a ban. The Financial Times reported this month that the White House had considered placing ByteDance on the “entity list” which would effectively bar it from doing business with US companies.\
Separately, the Committee on Foreign Investment in the United States has been reviewing ByteDance’s 2017 purchase of Musical.ly, which paved the way for TikTok’s growth. Steven Mnuchin, Treasury secretary, said on Wednesday he would make a recommendation to Mr Trump by the end of the week.\
“The fears about TikTok are more likely to be answered if the company is fully acquired by a non-China-based entity than if ByteDance retains any ownership,” said Erik Gordon, a professor at the University of Michigan’s Ross School of Business. “If ByteDance is completely out, neither Treasury nor Cfius has anything to worry about.”\
It was not immediately clear how much TikTok would be worth in a sale, though its value has been estimated in the tens of billions of dollars. ByteDance has been valued as high as $140bn in private share transactions, according to one person familiar with the trades.\
Despite its main focus on the business technology markets, Microsoft had built a highly successful consumer operation around its Xbox gaming platform and could use TikTok to push deeper into a younger demographic, according to analysts.\
“They want to grow up with a younger audience,” said Brent Thill, an analyst at Jefferies. “They could introduce new services along the way, as they age up.” He said that the high engagement TikTok had achieved with many of its younger users had eaten into their time playing games, making it a natural complementary service for Microsoft to explore.\
It was a potential greenfield opportunity for Microsoft as it looks to expand into new markets, added Youssef Squali, an analyst at SunTrust.\
A purchase of Tik Tok would be likely to bring an abrupt end to Microsoft’s loose alliance with Facebook, which dates back to when it saw the social media company as an ally in its battle with Google in the search market.\
US investors led by General Atlantic and Sequoia Capital had also been discussing a buyout of TikTok in which ByteDance would retain a minority stake. The investors have held discussions with the Treasury about whether the buyout would satisfy US concerns about the app, the FT reported last week.\
General Atlantic and Sequoia declined to comment on Friday.\
Josh Hawley, a Republican senator from Missouri and frequent critic of TikTok, told the FT this week that a full divestment of the app would represent a “major step forward” but not go far enough to satisfy concerns about personal data collection.\
Earlier this week, Kevin Mayer, TikTok’s new chief executive, defended the company, saying it had a “commitment to accountability”. In his first public comments since joining from Disney in June, he said that without TikTok, US advertisers “would again be left with few choices”, adding: “TikTok has become the latest target, but we are not the enemy.” \
"""\
\nCould you please summarize the article above in three sentences?'

In [None]:
print(prompt)

Please summarize the article below. 
""" Microsoft in talks to buy TikTok Negotiations for ByteDance-owned social media group come as Trump threatens action Microsoft has held talks to acquire TikTok, whose Chinese owner ByteDance faces mounting pressure from the US government to sell the video sharing app or risk being blacklisted in the country, said people briefed on the matter.The approach to buy TikTok was at an early stage, and it was unclear whether Microsoft would succeed in taking it over, as the transaction faced multiple hurdles, said people familiar with the negotiation.The people also said multiple parties were interested in acquiring TikTok and the process remained in its preliminary stages. Microsoft and TikTok declined to comment.Takeover interest comes amid rising scrutiny of TikTok and ByteDance’s data-sharing practices, which US President Donald Trump’s re-election campaign ads have implied amount to Chinese spying on American users.Mr Trump said on Friday that he wa
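In the printed prompt above, sentences run together ("briefed on the matter.The approach") because each article line of the string literal ends in a backslash continuation with no space before it, which joins the lines with nothing in between. One hedged alternative is to keep the article as a list of lines and join them explicitly. build_summary_prompt is a hypothetical helper that approximates the layout of the original prompt rather than reproducing it exactly.

```python
def build_summary_prompt(article_lines, question):
    """Join article lines with single spaces and wrap them in triple quotes."""
    article = " ".join(line.strip() for line in article_lines)
    return f'Please summarize the article below.\n"""\n{article}\n"""\n{question}'
```

For example, build_summary_prompt(["Microsoft in talks to buy TikTok", ...], "Could you please summarize the article above in three sentences?"), where the list holds the article's lines.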

In [None]:
#@title  { form-width: "300px" }
top_p = 0.9 #@param {type:"slider", min:0, max:1, step:0.1}
temp = 0.9 #@param {type:"slider", min:0, max:1, step:0.1}

print(infer(top_p=top_p, temp=temp, gen_len=300, context=prompt)[0])

completion done in 47.6471312046051s
Please summarize the article below.
""" Microsoft in talks to buy TikTok Negotiations for ByteDance-owned social media group come as Trump threatens action Microsoft has held talks to acquire TikTok, whose Chinese owner ByteDance faces mounting pressure from the US government to sell the video sharing app or risk being blacklisted in the country, said people briefed on the matter.The approach to buy TikTok was at an early stage, and it was unclear whether Microsoft would succeed in taking it over, as the transaction faced multiple hurdles, said people familiar with the negotiation.The people also said multiple parties were interested in acquiring TikTok and the process remained in its preliminary stages. Microsoft and TikTok declined to comment.Takeover interest comes amid rising scrutiny of TikTok and ByteDance’s data-sharing practices, which US President Donald Trump’s re-election campaign ads have implied amount to Chinese spying on American
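Each sample returned by infer() starts with the prompt wrapped between the ANSI bold code \033[1m and the reset code \033[0m, followed by the generated text. To keep only the completion, you can split on the first reset code. completion_only is a hypothetical helper, and it assumes the prompt itself contains no \033[0m.

```python
def completion_only(sample):
    """Drop the bold-formatted prompt that infer() prepends, keeping only the generated text."""
    marker = "\033[0m"  # ANSI reset code that infer() places right after the prompt
    _, sep, tail = sample.partition(marker)
    return tail if sep else sample
```

Usage: print(completion_only(infer(context=prompt)[0])).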