In [1]:
#@title Load checkpoint
path_to_checkpoint = "/content/drive/MyDrive/pytorch_model.bin"  #@param {type:"string"}
checkpoint_type = "GPT-Neo-2.7B"  #@param ["GPT-Neo-1.3B", "GPT-Neo-2.7B"]
import os
from termcolor import colored
from IPython.display import clear_output

from google.colab import drive
drive.mount('/content/drive/')
import torch
# Load the whole PyTorch state dict into host memory; map_location='cpu'
# places every tensor on the CPU regardless of where it was saved from.
# NOTE(review): torch.load deserializes with pickle, which can execute
# arbitrary code - only point path_to_checkpoint at files you trust.
print("Reading from checkpoint...")
torch_checkpoint = torch.load(path_to_checkpoint, map_location='cpu')
print("Done.")

# Model hyperparameters for the selected checkpoint type. Any value other than
# "GPT-Neo-1.3B" falls through to the GPT-Neo-2.7B settings.
if checkpoint_type == "GPT-Neo-1.3B":
    total_shards = 8
    d_model = 2048
    layers = 24
else:
    total_shards = 4
    d_model = 2560
    layers = 32

# One output directory per shard. exist_ok=True keeps this cell re-runnable:
# without it a second run raises FileExistsError on the first directory.
for i in range(total_shards):
    os.makedirs(f"jax_checkpoint/shard_{i}", exist_ok=True)

# Number of .npz files that each shard's tensors are spread across by save().
pieces = 16

Mounted at /content/drive/
Reading from checkpoint...
Done.


In [2]:
#@title Convert checkpoint to be JAX-compatible { display-mode: "form" }
from termcolor import colored
from IPython.display import clear_output
import torch
import numpy as np
import jax.numpy as jnp

def reshard_reverse(x, old_shape, is_shard_bias=False):
    """Rearrange an unsharded parameter into a per-shard layout.

    Args:
        x: Array holding the full parameter. The conversion loop passes it
            with an extra leading axis (``params[None]``), so vectors arrive
            2-D and matrices arrive 3-D.
        old_shape: Target sharded shape. Its first entry is the shard count.
        is_shard_bias: Only used for 2-D inputs whose second axis already
            matches ``old_shape[1]``. When true, the replicated rows are
            divided by the shard count.

    Returns:
        A numpy array. 2-D inputs whose second axis matches ``old_shape[1]``
        are tiled to ``(old_shape[0], x.shape[1])``; every other supported
        input is reshaped (and, for one 3-D case, transposed) to ``old_shape``.

    Raises:
        AssertionError: if ``x`` is 1-D (not implemented).
        Exception: if ``x`` has an unsupported rank, or is 3-D with a shape
            that matches neither supported layout.
    """
    # The shard count is the leading entry of the target shape, so the
    # function does not need to read the module-level total_shards.
    n_shards = old_shape[0]
    rank = len(x.shape)

    if rank == 1:
        # Raised explicitly (same exception type as the former `assert False`)
        # so the guard still fires when Python runs with assertions disabled.
        raise AssertionError(f"unimplemented, 1-D input of shape {x.shape}")

    if rank == 2:
        if old_shape[1] == x.shape[1]:
            # Same vector for every shard: copy the first row n_shards times.
            replicated = np.tile(x[0:1], (n_shards, 1))
            if is_shard_bias:
                # Scale each copy down by the number of shards.
                return replicated / n_shards
            return replicated
        # Vector split across shards: cut it into n_shards consecutive pieces.
        return x.reshape(old_shape)

    if rank == 3:
        if x.shape[0] * x.shape[2] == old_shape[2]:
            # Last axis unchanged: rows are divided into consecutive groups.
            return x.reshape(old_shape)
        if x.shape[0] * x.shape[1] == old_shape[1]:
            # Row count unchanged: split the last axis, then move the shard
            # axis to the front. np.transpose keeps the result a numpy array
            # like every other branch.
            by_row = x.reshape((old_shape[1], old_shape[0], old_shape[2]))
            return np.transpose(by_row, (1, 0, 2))
        raise Exception(f"unimplemented, {x.shape}, {old_shape}")

    # Report the shape rather than the (potentially huge) array contents.
    raise Exception(f"unimplemented, {x.shape}")

def get_old_shape(t, dim=2):
    """Return the shape of one shard of ``t`` when split ``total_shards`` ways.

    Args:
        t: Object with a ``.shape`` of length 1 or 2.
        dim: For 2-D input, which axis is divided: 1 for the first axis,
            2 for the second. Not consulted for 1-D input.

    Returns:
        A tuple with the divided axis reduced by a factor of ``total_shards``.

    Raises:
        AssertionError: if the divided axis is not a multiple of
            ``total_shards``.
        ValueError: if ``dim`` is neither 1 nor 2 for 2-D input, or if ``t``
            is neither 1-D nor 2-D.
    """
    dims = t.shape
    rank = len(dims)

    if rank == 2:
        rows, cols = dims[0], dims[1]
        if dim == 1:
            assert rows % total_shards == 0
            return (rows // total_shards, cols)
        if dim == 2:
            assert cols % total_shards == 0
            return (rows, cols // total_shards)
        raise ValueError(f"unsupported dim {dim}")

    if rank == 1:
        length = dims[0]
        assert length % total_shards == 0
        return (length // total_shards,)

    raise ValueError(f"unsupported shape {t.shape}")


def split(a, n):
    """Lazily cut ``a`` into ``n`` consecutive slices of near-equal length.

    The first ``len(a) % n`` slices each hold one more element than the rest.
    Returns a generator; the sizes are computed when ``split`` is called, the
    slices are taken as the generator is consumed.
    """
    base, extra = divmod(len(a), n)
    positions = range(n)

    def _pieces():
        start = 0
        for pos in positions:
            # The leading `extra` pieces absorb the remainder, one item each.
            stop = start + base + (1 if pos < extra else 0)
            yield a[start:stop]
            start = stop

    return _pieces()

def save(cpu_flattened):
    """Write the converted tensors to ``jax_checkpoint/shard_<i>/<j>.npz``.

    For every shard index, the list of tensors is cut into ``pieces`` groups
    with ``split``; group ``j`` is stored in file ``<j>.npz`` and contains,
    for each tensor in the group, the slice at that shard index.
    """
    for shard_idx in range(total_shards):
        for piece_idx, group in enumerate(split(cpu_flattened, pieces)):
            with open(f"jax_checkpoint/shard_{shard_idx}/{piece_idx}.npz", "wb") as f:
                # Take this shard's slice of every tensor in the group.
                shard_arrays = [tensor[shard_idx] for tensor in group]
                np.savez(f, *shard_arrays)


# Conversion table: one (name, is_shard_bias, dim) tuple per tensor.
#   name          - key looked up in torch_checkpoint
#   is_shard_bias - forwarded to reshard_reverse(..., is_shard_bias=...)
#   dim           - forwarded to get_old_shape as its `dim` argument; None
#                   means the target shape is (total_shards, params.shape[0])
# Tensors are converted and appended to `checkpoint` in the order listed here.
transforms = [
    ("transformer.wpe.weight", False, 2),
    ("transformer.wte.weight", False, 1)
]

# Receives the converted arrays, one per entry of `transforms`.
checkpoint = []

# The layer indices are sorted as strings, so layers are visited in
# lexicographic order ("0", "1", "10", "11", ..., "2", "20", ...), not
# numeric order.
# NOTE(review): presumably this mirrors the parameter ordering expected by the
# JAX side - confirm before changing it.
layer_names = sorted(map(str, range(layers)))
for layer in layer_names:
    transforms.extend([
        (f"transformer.h.{layer}.attn.attention.q_proj.weight", False, 2),
        (f"transformer.h.{layer}.attn.attention.v_proj.weight", False, 2),
        (f"transformer.h.{layer}.attn.attention.k_proj.weight", False, 2),
        (f"transformer.h.{layer}.attn.attention.out_proj.bias", True, None),
        (f"transformer.h.{layer}.attn.attention.out_proj.weight", False, 1),
        (f"transformer.h.{layer}.mlp.c_fc.bias", False, 1),
        (f"transformer.h.{layer}.mlp.c_fc.weight", False, 2),
        (f"transformer.h.{layer}.mlp.c_proj.bias", True, None),
        (f"transformer.h.{layer}.mlp.c_proj.weight", False, 1),
        (f"transformer.h.{layer}.ln_1.bias", False, None),
        (f"transformer.h.{layer}.ln_1.weight", False, None),
        (f"transformer.h.{layer}.ln_2.bias", False, None),
        (f"transformer.h.{layer}.ln_2.weight", False, None),
    ])
# The final layer norm comes after all per-layer entries.
transforms.extend([
    ("transformer.ln_f.bias", False, None),
    ("transformer.ln_f.weight", False, None),
])

# Row count the embedding table is padded to (50257 rows + 143 padding rows
# for the stock GPT-Neo vocabulary).
padded_vocab_rows = 50400

# Convert every tensor listed in `transforms`, in order. Iterating directly
# (instead of repeatedly popping the head of the list) is linear and leaves
# the table intact after the loop.
for transform in transforms:
    name, is_shard_bias, shard_dim = transform

    params = torch_checkpoint[name]

    # Pad input and output embeddings with 0 at the bottom to have 50400 rows
    # instead of 50257 rows (the padding value doesn't have to be 0, it doesn't
    # even have to be a constant value; the only thing the padding affects is
    # it adds junk logits to the end of the logits array the transformer returns
    # without affecting the other logits)
    if name in ("transformer.wte.weight", "lm_head.weight"):
        pad_rows = padded_vocab_rows - params.shape[0]
        if pad_rows < 0:
            raise ValueError(
                f"{name} has {params.shape[0]} rows, more than the padded size {padded_vocab_rows}"
            )
        if pad_rows > 0:
            # new_zeros matches the dtype and device of `params`, so the
            # concatenation never mixes dtypes (torch.zeros defaults to float32).
            params = torch.cat((params, params.new_zeros((pad_rows, params.shape[1]))))

    # torch.nn.Linear uses a transposed version of the equivalent tensor that
    # haiku.Linear uses, so we have to un-transpose the tensor first.
    # Only 2-D tensors are transposed: .T is a no-op on 1-D tensors and its
    # use on non-2-D tensors is deprecated in PyTorch.
    if params.ndim == 2 and not any(s in name for s in ("wte", "wpe")):
        params = params.T

    # Target per-shard shape: split along `shard_dim`, or keep the full vector
    # for every shard when no split dimension is given.
    if shard_dim is not None:
        old_shape = (total_shards,) + get_old_shape(params, shard_dim)
    else:
        old_shape = (total_shards, params.shape[0],)
    print(f"< [{name}] {params.shape} to {old_shape}")

    # Add a leading axis and cast to bfloat16 before resharding.
    params = np.asarray(params[None], dtype=jnp.bfloat16)
    params = reshard_reverse(params, old_shape, is_shard_bias=is_shard_bias)

    # The bfloat16 cast must not have produced NaN or Inf values.
    if np.isnan(params).any() or np.isinf(params).any():
        raise ValueError(f"bfloat16 overflow/underflow in {name}")

    print(f"> [{name}] {params.shape}")
    assert params.shape == old_shape
    checkpoint.append(params)

# Append the checkpoint step number (can be set to an arbitrary value, in this
# case 0, as long as we're only using inference and not training the model)
checkpoint.append(np.zeros(total_shards, dtype=np.int32))

print("saving")
save(checkpoint)
# Drop the reference to the converted arrays once they are on disk.
del checkpoint
print(colored("DONE! The JAX checkpoint is now stored at /content/jax_checkpoint", "green"))

< [transformer.wpe.weight] torch.Size([2048, 2560]) to (4, 2048, 640)




> [transformer.wpe.weight] (4, 2048, 640)
< [transformer.wte.weight] torch.Size([50400, 2560]) to (4, 12600, 2560)
> [transformer.wte.weight] (4, 12600, 2560)
< [transformer.h.0.attn.attention.q_proj.weight] torch.Size([2560, 2560]) to (4, 2560, 640)
> [transformer.h.0.attn.attention.q_proj.weight] (4, 2560, 640)
< [transformer.h.0.attn.attention.v_proj.weight] torch.Size([2560, 2560]) to (4, 2560, 640)
> [transformer.h.0.attn.attention.v_proj.weight] (4, 2560, 640)
< [transformer.h.0.attn.attention.k_proj.weight] torch.Size([2560, 2560]) to (4, 2560, 640)
> [transformer.h.0.attn.attention.k_proj.weight] (4, 2560, 640)
< [transformer.h.0.attn.attention.out_proj.bias] torch.Size([2560]) to (4, 2560)
> [transformer.h.0.attn.attention.out_proj.bias] (4, 2560)
< [transformer.h.0.attn.attention.out_proj.weight] torch.Size([2560, 2560]) to (4, 640, 2560)
> [transformer.h.0.attn.attention.out_proj.weight] (4, 640, 2560)
< [transformer.h.0.mlp.c_fc.bias] torch.Size([10240]) to (4, 2560)
> [tra