In [2]:
from google.colab import drive
drive.mount('/content/drive')


In [2]:
# !apt install zstd
# # zstd is a fast lossless compression algorithm and data compression tool, with command line syntax similar to gzip (1) and xz (1). It is based on the LZ77 family, with further FSE & huff0 entropy stages
# # the "slim" version contain only bf16 weights and no optimizer parameters, which minimizes bandwidth and memory

# # !time wget -c https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd
# # !time tar -I zstd -xf step_383500_slim.tar.zstd

# Download updated "Mesh Tranformer Jax" repo
!git clone https://github.com/raogsm7/mesh-transformer-jax.git

#install packages and librararies from requirements
!pip install -r mesh-transformer-jax/requirements.txt

# jax 0.2.12 is required due to a regression with xmap in 0.2.13, tensorflow 2.5.0, colan has different version
!pip install mesh-transformer-jax/ jax==0.2.12 tensorflow==2.5.0

Cloning into 'mesh-transformer-jax'...
remote: Enumerating objects: 812, done.[K
remote: Counting objects: 100% (443/443), done.[K
remote: Compressing objects: 100% (141/141), done.[K
remote: Total 812 (delta 369), reused 322 (delta 298), pack-reused 369[K
Receiving objects: 100% (812/812), 233.16 KiB | 3.19 MiB/s, done.
Resolving deltas: 100% (535/535), done.
Collecting git+https://github.com/deepmind/dm-haiku (from -r mesh-transformer-jax/requirements.txt (line 8))
  Cloning https://github.com/deepmind/dm-haiku to /tmp/pip-req-build-etemj97x
  Running command git clone -q https://github.com/deepmind/dm-haiku /tmp/pip-req-build-etemj97x
Collecting git+https://github.com/EleutherAI/lm-evaluation-harness/ (from -r mesh-transformer-jax/requirements.txt (line 9))
  Cloning https://github.com/EleutherAI/lm-evaluation-harness/ to /tmp/pip-req-build-gj2wnpjx
  Running command git clone -q https://github.com/EleutherAI/lm-evaluation-harness/ /tmp/pip-req-build-gj2wnpjx
Collecting tqdm~=4.

Processing ./mesh-transformer-jax
[33m  DEPRECATION: A future pip version will change local packages to be built in-place without first copying to a temporary directory. We recommend you use --use-feature=in-tree-build to test your packages with this new behavior before it becomes the default.
   pip 21.3 will remove support for this functionality. You can find discussion regarding this at https://github.com/pypa/pip/issues/7555.[0m
Collecting jax==0.2.12
  Downloading jax-0.2.12.tar.gz (590 kB)
[K     |████████████████████████████████| 590 kB 5.3 MB/s 
[?25hCollecting tensorflow==2.5.0
  Downloading tensorflow-2.5.0-cp37-cp37m-manylinux2010_x86_64.whl (454.3 MB)
[K     |████████████████████████████████| 454.3 MB 9.5 kB/s 
Building wheels for collected packages: mesh-transformer, jax
  Building wheel for mesh-transformer (setup.py) ... [?25l[?25hdone
  Created wheel for mesh-transformer: filename=mesh_transformer-0.0.0-py3-none-any.whl size=31968 sha256=36591f04326b617c77a6cf9be

In [1]:
import os
import requests 
from jax.config import config

colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1_dev20210607'
requests.post(url)

# The following is required to use TPU Driver as JAX's backend.
config.FLAGS.jax_xla_backend   = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']

print (colab_tpu_addr)
print("All devices: ", os.environ['COLAB_TPU_ADDR'])

10.94.189.50
All devices:  10.94.189.50:8470


In [2]:
import time

import jax
from jax.experimental import maps
import numpy as np
import optax
import transformers

from mesh_transformer.checkpoint import read_ckpt
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer

In [3]:
params = {
  "layers": 3,
  "d_model": 512,
  "n_heads": 8,
  "n_vocab": 50400,
  "norm": "layernorm",
  "pe": "rotary",
  "pe_rotary_dims": 64,

  "seq": 256,
  "cores_per_replica": 1,
  "per_replica_batch": 1,
}

per_replica_batch = params["per_replica_batch"]
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]
print(per_replica_batch,cores_per_replica,seq)

params["sampler"] = nucleaus_sample
print(params)
# here we "remove" the optimizer parameters from the model (as we don't need them for inference)
params["optimizer"] = optax.scale(0)
print(params)
print("device_count", jax.device_count())
print("jax.device", jax.devices())
print(cores_per_replica)
mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
print("mesh_shape", mesh_shape)
devices = np.array(jax.devices()).reshape(mesh_shape)
print("devices", devices)
print("maps.Mesh",maps.Mesh(devices, ('dp', 'mp')))
maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))
print(maps.thread_resources.env)


1 1 256
{'layers': 3, 'd_model': 512, 'n_heads': 8, 'n_vocab': 50400, 'norm': 'layernorm', 'pe': 'rotary', 'pe_rotary_dims': 64, 'seq': 256, 'cores_per_replica': 1, 'per_replica_batch': 1, 'sampler': <function nucleaus_sample at 0x7f146ff7edd0>}
{'layers': 3, 'd_model': 512, 'n_heads': 8, 'n_vocab': 50400, 'norm': 'layernorm', 'pe': 'rotary', 'pe_rotary_dims': 64, 'seq': 256, 'cores_per_replica': 1, 'per_replica_batch': 1, 'sampler': <function nucleaus_sample at 0x7f146ff7edd0>, 'optimizer': GradientTransformation(init=<function scale.<locals>.init_fn at 0x7f146ff85200>, update=<function scale.<locals>.update_fn at 0x7f146fdc2050>)}
device_count 8
jax.device [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=4, process_index=0, coords=(0,0,0), core_on_chip=0), Tp

In [8]:
############ FINETUNING ##################

# run finetune program using config and tune model path, finetuned model path may contain or may not contain checkpoints

# with pre-existing checkpoints at tune-model-path
!python3 /content/mesh-transformer-jax/train_ft_6jb.py --config=/content/mesh-transformer-jax/configs/6B_roto_256_ft.json --tune-model-path=/content/drive/MyDrive/colab_data/ckpt_dir/

# # without pre-existing checkpoints
# !python3 /content/mesh-transformer-jax/train_ft_6jb.py --config=/content/mesh-transformer-jax/configs/6B_roto_256_ft.json

# # for testing new parameters 
# !python3 /content/mesh-transformer-jax/train_ft_6jb.py --config=/content/mesh-transformer-jax/configs/6B_roto_256_ft.json --tune-model-path=/content/mesh-transformer-jax/data/ckpt_dir/ --fresh-opt=True


bucket . model_dir /content/drive/MyDrive/colab_data/finetuned_ckpt_dir layers 3 d_model 512 n_heads 8 n_vocab 50400 seq 256 norm layernorm val_batches 1 val_every 10 ckpt_every 10 keep_every 10 total_steps 10 total_steps 10
jax devices: 8
jax runtime initialized in 8.93245s
`--tune_model_path` passed: we are beginning a fine-tuning run
path to load checkpoint from: /content/drive/MyDrive/colab_data/ckpt_dir/
setting up datasets
initializing network
  warn("xmap is an experimental feature and probably has bugs!")
key shape (1, 2)
in shape (8, 256)
dp 8
mp 1
Total parameters: 61109472
loading network
network loaded in 9.53674e-07s
compiling train fn
start 1631902351.7511172
2021-09-17 18:12:31.841829: E tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
network <mesh_transformer.transformer_shard.CausalTransformer object at 0x7faee4739750>
inputs {'obs': array([[[  464, 44799,  2839, ...,   339,  1139,    1

In [11]:
# ############## TRAINING #################
# device_name = os.environ['COLAB_TPU_ADDR']
# print(device_name)
# # run train program using config and empty model path, model path does not contain checkpoints
# !python3 /content/mesh-transformer-jax/train_nt.py --config=/content/mesh-transformer-jax/configs/6B_roto_256_t.json --tpu=device_name --tpu_region=us-central1-a --preemptible



In [4]:
tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

In [10]:
total_batch = per_replica_batch * jax.device_count() // cores_per_replica
print("total_batch",total_batch)
network = CausalTransformer(params)
print("****network",network)
print("dev_shape", devices.shape[1])
network.state = read_ckpt(network.state, "/content/drive/MyDrive/colab_data/ckpt_dir/step_10/", devices.shape[1])
network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))

total_batch 8
key shape (1, 2)
in shape (8, 256)
dp 8
mp 1


  warn("xmap is an experimental feature and probably has bugs!")


Total parameters: 61109472
****network <mesh_transformer.transformer_shard.CausalTransformer object at 0x7f141f9569d0>
dev_shape 1


In [None]:
# allow text wrapping in generated output: https://stackoverflow.com/a/61401455
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

In [None]:
def infer(context, top_p=0.9, temp=1.0, gen_len=512):
    tokens = tokenizer.encode(context)

    provided_ctx = len(tokens)
    pad_amount = seq - provided_ctx

    padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
    batched_tokens = np.array([padded_tokens] * total_batch)
    length = np.ones(total_batch, dtype=np.uint32) * len(tokens)

    start = time.time()
    output = network.generate(batched_tokens, length, gen_len, {"top_p": np.ones(total_batch) * top_p, "temp": np.ones(total_batch) * temp})

    samples = []
    decoded_tokens = output[1][0]

    for o in decoded_tokens[:, :, 0]:
      samples.append(f"\033[1m{context}\033[0m{tokenizer.decode(o)}")
# 
    print(f"completion done in {time.time() - start:06}s")
    return samples
textInput = input("Enter your text: ")
# print(infer("Albert Einstein is")[0])
print(infer(textInput)[0])

In [None]:
#@title  { form-width: "300px" }
top_p = 1 #@param {type:"slider", min:0, max:1, step:0.1}
temp = 0.7 #@param {type:"slider", min:0, max:1, step:0.1}
context = input("Enter your text: ")
# context = """In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."""
# context = "OneBOT CRM is chaning the world. It is coupled with AI and Deep Learning making a difference for future businesses"

print(infer(top_p=top_p, temp=temp, gen_len=64, context=context)[0])

In [12]:
drive.flush_and_unmount()
print('All changes made in this colab session should now be visible in Drive.')