<a href="https://colab.research.google.com/github/whoislimshady/Ai_movie/blob/main/Genrating_script_with_GPT_J_6B_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# zstd is needed to decompress the checkpoint archive below.
!apt-get install -y -qq zstd

# the "slim" version contains only bf16 weights and no optimizer parameters, which minimizes bandwidth and memory
# `wget -c` resumes a partial download, so re-running this cell after an interruption is safe.
# NOTE(review): the-eye.eu has reorganized its mirrors before; if this URL 404s,
# check the mesh-transformer-jax README for the current download link.
!time wget -c https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd

# Unpack the checkpoint (the model-loading cell reads it from step_383500/).
!time tar -I zstd -xf step_383500_slim.tar.zstd

# Skip the clone when the repo already exists so this cell can be re-run.
![ -d mesh-transformer-jax ] || git clone https://github.com/kingoflolz/mesh-transformer-jax.git
%pip install -r mesh-transformer-jax/requirements.txt

# jax 0.2.12 is required due to a regression with xmap in 0.2.13
%pip install mesh-transformer-jax/ jax==0.2.12 tensorflow==2.5.0

Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following package was automatically installed and is no longer required:
  libnvidia-common-460
Use 'apt autoremove' to remove it.
The following NEW packages will be installed:
  zstd
0 upgraded, 1 newly installed, 0 to remove and 42 not upgraded.
Need to get 278 kB of archives.
After this operation, 1,141 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu bionic-updates/universe amd64 zstd amd64 1.3.3+dfsg-2ubuntu1.2 [278 kB]
Fetched 278 kB in 1s (332 kB/s)
Selecting previously unselected package zstd.
(Reading database ... 155629 files and directories currently installed.)
Preparing to unpack .../zstd_1.3.3+dfsg-2ubuntu1.2_amd64.deb ...
Unpacking zstd (1.3.3+dfsg-2ubuntu1.2) ...
Setting up zstd (1.3.3+dfsg-2ubuntu1.2) ...
Processing triggers for man-db (2.8.3-2ubuntu0.1) ...
--2022-05-31 06:27:40--  https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.ta

## Setup Model


In [None]:
import os
import requests 
from jax.config import config

# Route JAX to the Colab TPU through the remote TPU-driver backend.
# NOTE(review): this presumably has to run before any JAX computation
# initializes a backend — keep it ahead of the model-building cells.
if 'COLAB_TPU_ADDR' not in os.environ:
    raise RuntimeError(
        "COLAB_TPU_ADDR is not set: this notebook needs a Colab TPU runtime "
        "(Runtime > Change runtime type > TPU)."
    )
colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]

# Ask the TPU host to switch to a specific TPU driver build.
url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1_dev20210607'
response = requests.post(url)
# Fail loudly here instead of with an obscure error in a later cell.
response.raise_for_status()

# The following is required to use TPU Driver as JAX's backend.
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']

The next cell sometimes fails on its first run (possibly because the TPU driver switch requested above is still in progress); if it does, just run it again ¯\\\_(ツ)\_/¯

In [None]:
import time

import jax
from jax.experimental import maps
import numpy as np
import optax
import transformers

from mesh_transformer.checkpoint import read_ckpt_lowmem
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer

In [None]:
# GPT-J-6B architecture and sharding hyperparameters.
params = {
    "layers": 28,
    "d_model": 4096,
    "n_heads": 16,
    "n_vocab": 50400,
    "norm": "layernorm",
    "pe": "rotary",
    "pe_rotary_dims": 64,

    "seq": 2048,
    "cores_per_replica": 8,
    "per_replica_batch": 1,
}

per_replica_batch = params["per_replica_batch"]
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]

# `nucleaus_sample` is spelled exactly as it is imported from mesh_transformer.sampling.
params["sampler"] = nucleaus_sample

# here we "remove" the optimizer parameters from the model (as we don't need them for inference)
params["optimizer"] = optax.scale(0)

# Arrange the devices as a (dp, mp) mesh: each model replica is sharded over
# `cores_per_replica` devices, so the device count must be a multiple of it.
# Without this check a non-TPU runtime fails inside reshape with an opaque error.
if jax.device_count() % cores_per_replica != 0:
    raise RuntimeError(
        f"found {jax.device_count()} JAX device(s), which cannot be split into "
        f"replicas of {cores_per_replica} cores; run this notebook on a TPU "
        f"runtime with a multiple of {cores_per_replica} cores"
    )
mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)

maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))

tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

Here we create the network and load the parameters from the downloaded files. Expect this to take around 5 minutes.

In [None]:
# Global batch size across all replicas (computed as
# per_replica_batch * device_count, floor-divided by cores_per_replica).
total_batch = (per_replica_batch * jax.device_count()) // cores_per_replica

# Build the model and load the downloaded checkpoint from step_383500/.
# NOTE(review): the third argument is presumably the number of model-parallel
# shards in the checkpoint (here the size of the mesh's 'mp' axis) — confirm.
network = CausalTransformer(params)
network.state = read_ckpt_lowmem(network.state, "step_383500/", devices.shape[1])

# Move the loaded state onto the device mesh via the model's xmap helper.
network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))

In [None]:
# allow text wrapping in generated output: https://stackoverflow.com/a/61401455
from IPython.display import HTML, display

def set_css(*args, **kwargs):
  """Inject CSS that makes <pre> output wrap long lines instead of scrolling.

  Used as an IPython ``pre_run_cell`` hook, so it runs before every cell.
  Any positional/keyword arguments (some IPython versions may pass the cell's
  execution info) are accepted and ignored.
  """
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))

# Re-running this cell creates a new `set_css` object, which IPython would
# register as an additional hook. Drop any earlier hook with the same name first
# so the cell is idempotent.
_pre_run_hooks = get_ipython().events.callbacks['pre_run_cell']
_pre_run_hooks[:] = [
    hook for hook in _pre_run_hooks if getattr(hook, '__name__', None) != 'set_css'
]
get_ipython().events.register('pre_run_cell', set_css)

In [None]:
def infer(context, top_p=0.9, temp=1.0, gen_len=512):
    """Continue `context` with the loaded model and return one sample per batch row.

    The prompt is tokenized, left-padded with zeros to `seq` tokens, and
    replicated `total_batch` times so every batch row samples its own
    continuation with the same settings.

    Args:
        context: Prompt text to continue.
        top_p: Nucleus-sampling threshold, applied to every batch row.
        temp: Sampling temperature, applied to every batch row.
        gen_len: Number of tokens to generate.

    Returns:
        A list with one string per batch row: the prompt wrapped in ANSI bold
        escape codes, followed by the decoded continuation.

    Raises:
        ValueError: if `context` encodes to no tokens, or to more than `seq`
            tokens (the prompt must fit in the model's context window).
    """
    tokens = tokenizer.encode(context)
    n_ctx = len(tokens)
    if n_ctx == 0:
        raise ValueError("context must encode to at least one token")
    if n_ctx > seq:
        # np.pad would otherwise fail with a cryptic negative-padding error.
        raise ValueError(
            f"context is {n_ctx} tokens long, but the model accepts at most {seq}"
        )

    # Left-pad so the prompt ends exactly at the last position of the window.
    padded_tokens = np.pad(tokens, ((seq - n_ctx, 0),)).astype(np.uint32)
    batched_tokens = np.array([padded_tokens] * total_batch)
    # Unpadded prompt length, one entry per batch row.
    length = np.full(total_batch, n_ctx, dtype=np.uint32)
    sampler_options = {
        "top_p": np.ones(total_batch) * top_p,
        "temp": np.ones(total_batch) * temp,
    }

    start = time.time()
    output = network.generate(batched_tokens, length, gen_len, sampler_options)

    # NOTE(review): assumes output[1][0] holds the generated token ids with a
    # trailing singleton axis, as the [:, :, 0] indexing implies — confirm
    # against CausalTransformer.generate.
    decoded_tokens = output[1][0]
    samples = [
        f"\033[1m{context}\033[0m{tokenizer.decode(row)}"
        for row in decoded_tokens[:, :, 0]
    ]

    print(f"completion done in {time.time() - start:.2f}s")
    return samples

print(infer("""A BAD MIDDLE SCHOOL BAND PLAYS THE DISNEY LOGO THEME.
ONCE IT ENDS...
JOE (O.S.)
Alright! Let’s try something else.
Uh...from the top. Ready? One, two,
three...
INT. MIDDLE SCHOOL BAND ROOM.
JOE GARDNER, a passionate, well-dressed middle-aged man,
conducts an off-key middle school band. It’s painfully bad.
JOE
One, two, three, four! Stay on the
beat! Two, three four--that’s a C
Sharp, horns!
A TROMBONIST loses his trombone end, which lands on the floor
with a CLANK.
A TRUMPETER uses his horn to vacuum up M&Ms from the floor.
CALEB, a saxophonist, pretends to play while actually on his
iPHONE.
JOE
Two, three, I see you, Caleb!
Startled, Caleb tosses the phone into a neighboring student’s
sax.
JOE
(to another student)
Rachel, now you!
But Rachel lies across a few chairs.
RACHEL
Forgot my sax, Mr. G.
JOE
Okay, she forgot her sax! Aaand now-
- aaaaall you, Connie. Go for it!
Joe then motions to CONNIE, a Chinese American girl holding a
trombone. She’s his last hope.
Connie plays her solo, strong and passionate. Joe smiles.
But some of the other kids start giggling, and Connie’s
confidence (and playing) suddenly wilts.""")[0])

In [None]:


context = """A BAD MIDDLE SCHOOL BAND PLAYS THE DISNEY LOGO THEME.
ONCE IT ENDS...
JOE (O.S.)
Alright! Let’s try something else.
Uh...from the top. Ready? One, two,
three...
INT. MIDDLE SCHOOL BAND ROOM.
JOE GARDNER, a passionate, well-dressed middle-aged man,
conducts an off-key middle school band. It’s painfully bad.
JOE
One, two, three, four! Stay on the
beat! Two, three four--that’s a C
Sharp, horns!
A TROMBONIST loses his trombone end, which lands on the floor
with a CLANK.
A TRUMPETER uses his horn to vacuum up M&Ms from the floor.
CALEB, a saxophonist, pretends to play while actually on his
iPHONE.
JOE
Two, three, I see you, Caleb!
Startled, Caleb tosses the phone into a neighboring student’s
sax.
JOE
(to another student)
Rachel, now you!
But Rachel lies across a few chairs.
RACHEL
Forgot my sax, Mr. G.
JOE
Okay, she forgot her sax! Aaand now-
- aaaaall you, Connie. Go for it!
Joe then motions to CONNIE, a Chinese American girl holding a
trombone. She’s his last hope.
Connie plays her solo, strong and passionate. Joe smiles.
But some of the other kids start giggling, and Connie’s
confidence (and playing) suddenly wilts."""

print(infer(top_p=top_p, temp=temp, gen_len=512, context=context)[0])