In [1]:
import argparse
from contextlib import contextmanager
from itertools import chain, islice
import json
import math
from pathlib import Path
import random
import os
import sys
import zipfile

import accelerate
from datasets import load_dataset
import peft
import safetensors.torch as safetorch
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils import data
from tqdm import trange, tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from loguru import logger


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Auto-reload edited project modules before each cell runs, then import from the package
%load_ext autoreload
%autoreload 2

from vae_llm_worldmodels.models.bigvae.bigvae import set_adapter, DecoderOnlyTransformerVAE, VAERouter


## Params

In [6]:
def split_cli_text(text):
    """Split a multi-line block of command-line flags into an argv list.

    The text is processed line by line. Tokens are separated by whitespace,
    and a token that starts with ``#`` begins a comment that runs to the end
    of its line, so a whole flag line can be disabled by prefixing it with
    ``#``. (Flattening the text before filtering would drop only the ``#``
    token itself and still parse the rest of the "commented" line.)

    Args:
        text: flags written one or more per line.

    Returns:
        List of argument tokens in order of appearance.
    """
    tokens = []
    for line in text.splitlines():
        for token in line.split():
            if token.startswith("#"):
                # Everything after the comment marker on this line is ignored.
                break
            tokens.append(token)
    return tokens


parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--batch_size", type=int, default=2, help="microbatch size")
parser.add_argument("--dropout", type=float, default=0.0, help="dropout rate")
parser.add_argument("--epochs", type=int, default=1, help="number of epochs")
parser.add_argument(
    "--gradient_accumulation_steps", type=int, default=1, help="gradient accumulation steps"
)
parser.add_argument(
    "--gradient_checkpointing",
    action="store_true",
    default=False,
    help="use gradient checkpointing",
)
parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
parser.add_argument(
    "--model",
    type=str,
    # Alternatives tried: "mistralai/Mistral-7B-v0.1", "yichunkuo/stablelm-3b-4e1t-gptq",
    # "gpt2", "mlabonne/gpt2-GPTQ-4bit".
    default="stabilityai/stablelm-3b-4e1t",
    help="model name",
)
parser.add_argument("--context", type=int, default=2048, help="context window length")
parser.add_argument("--vae_context", type=int, default=64, help="vae embed context")
parser.add_argument("--output", type=Path, required=True, help="path to save adapter")
# The dataset cell reads `args.preprocessed`. SUPPRESS keeps the attribute absent
# unless the flag is actually supplied, so the parsed namespace is unchanged for
# argument lists that do not mention it.
parser.add_argument(
    "--preprocessed",
    type=Path,
    default=argparse.SUPPRESS,
    help="directory of preprocessed safetensors shards",
)
parser.add_argument("--rank", type=int, default=32, help="the lora rank")
parser.add_argument("--save_every", type=int, default=1000, help="save every n steps")
parser.add_argument("--start_from", type=str, help="start from existing lora")
parser.add_argument("--z_dim", type=int, default=768, help="the latent dimension")


# Notebook stand-in for a real command line: edit the flags here.
# Prefix a line (or the tail of a line) with "#" to disable it.
argvs = """
--rank 16 
--context=96 
--vae_context=32 
--batch_size=1 
--output=./output/adapter 
"""
argv = split_cli_text(argvs)
print(argv)
args = parser.parse_args(argv)
args


['--rank', '16', '--context=96', '--vae_context=32', '--batch_size=1', '--output=./output/adapter']


Namespace(batch_size=1, dropout=0.0, epochs=1, gradient_accumulation_steps=1, gradient_checkpointing=False, lr=0.0001, model='stabilityai/stablelm-3b-4e1t', context=96, vae_context=32, output=PosixPath('output/adapter'), rank=16, save_every=1000, start_from=None, z_dim=768)

In [7]:
# Keyword options meant to be passed to a tokenizer call.
# NOTE(review): nothing in the visible cells consumes `tokenizer_args` yet.
max_length = 32
tokenizer_args = {
    "padding": "max_length",
    "max_length": max_length,
    "truncation": True,
}


In [8]:
from transformers.utils.logging import _get_library_root_logger

# Environment switches for the tokenizers / transformers libraries.
# NOTE(review): transformers was already imported in the first cell; confirm that
# TRANSFORMERS_VERBOSITY is still honoured when it is set this late.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "TRANSFORMERS_VERBOSITY": "detail",
    }
)

# Let records from the transformers root logger propagate to parent loggers.
library_root_logger = _get_library_root_logger()
library_root_logger.propagate = True



## Load


In [9]:
# Build the accelerator, load the tokenizer and the 4-bit quantised base model,
# wrap the model with LoRA adapters, construct the VAE and the router around it,
# and create the optimizer. Statement order matters here (freeze -> re-enable
# `w_d` -> select adapter -> build optimizer), so the code is left untouched.
accelerator = accelerate.Accelerator(
    mixed_precision="bf16", gradient_accumulation_steps=args.gradient_accumulation_steps
)
# Single-process runs are pinned to "cuda:0"; only multi-process runs use the
# device chosen by the accelerator.
device = accelerator.device if accelerator.num_processes > 1 else "cuda:0"
is_main = accelerator.is_main_process

# NOTE: this rebinds the notebook-global name `print` to loguru's `logger.info`
# (wrapped in tqdm's external-write context). Every later `print(...)` call in
# the notebook therefore logs through loguru; `print0` does so on the main
# process only.
print = tqdm.external_write_mode()(logger.info)
print0 = accelerator.on_main_process(print)

# If `--model` names an existing local path, use its resolved absolute path;
# otherwise pass the string through unchanged.
if Path(args.model).exists():
    model_name = Path(args.model).resolve()
else:
    model_name = args.model

print0(f"Loading model: {model_name}")
with accelerator.main_process_first():
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.padding_side = "left"
    # Tokenizers without a pad token reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # 4-bit NF4 quantisation with double quantisation; compute in bfloat16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    # The whole model is placed on `device` (single-entry device map).
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map={"": device},
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16, 
        trust_remote_code=True
    )
    # LoRA on the attention and MLP projections; rank and dropout come from the
    # CLI arguments, alpha is fixed at 8.
    peft_config = peft.LoraConfig(
        peft.TaskType.CAUSAL_LM,
        inference_mode=False,
        r=args.rank,
        lora_alpha=8,
        lora_dropout=args.dropout,
        target_modules=[
            "self_attn.q_proj",
            "self_attn.k_proj",
            "self_attn.v_proj",
            "self_attn.o_proj",
            "mlp.gate_proj",
            "mlp.up_proj",
            "mlp.down_proj",
        ],
    )
    base_model_peft = peft.get_peft_model(base_model, peft_config)
    vae_model = DecoderOnlyTransformerVAE(
        base_model_peft, peft_config, device=device, z_dim=args.z_dim,
    )
    if args.start_from:
        vae_model.load_pretrained(args.start_from)
    # Freeze the PEFT-wrapped model and the VAE, then re-enable gradients for
    # `vae_model.vae.w_d` only.
    base_model_peft.requires_grad_(False)
    vae_model.vae.requires_grad_(False)
    vae_model.vae.w_d.requires_grad_()
    # NOTE(review): the recorded output of this cell is a TypeError from
    # VAERouter.__init__ about the number of positional arguments - confirm
    # this call against the installed VAERouter signature.
    router = VAERouter(base_model_peft, vae_model, device)
    if args.start_from:
        router.load_pretrained(args.start_from, is_trainable=True)
accelerator.wait_for_everyone()

router.train()
if args.gradient_checkpointing:
    router.model.gradient_checkpointing_enable()
    router.model.enable_input_require_grads()

if is_main:
    router.model.print_trainable_parameters()

# Select the "router" adapter before the optimizer is built. The optimizer is
# handed every parameter of `router.model`, not only the trainable ones.
router.model.set_adapter("router")
opt = optim.Adam(router.model.parameters(),
                     lr=args.lr,
                     betas=(0.9, 0.99))
accelerator.wait_for_everyone()


[32m2023-11-12 09:44:45.920[0m | [1mINFO    [0m | [36mcontextlib[0m:[36minner[0m:[36m81[0m - [1mLoading model: stabilityai/stablelm-3b-4e1t[0m
Using pad_token, but it is not set yet.


TypeError: VAERouter.__init__() takes from 3 to 4 positional arguments but 5 were given

Reference (upstream minihf training script, `adavae-moe` branch): https://github.com/JD-P/minihf/blob/adavae-moe/train_vae_router.py#L277


In [None]:
# prepare dataset
# Load every shard in the preprocessed directory, cut it into windows of
# `args.context` tokens, keep only windows with no padding, and wrap the result
# in a shuffled DataLoader.

# Fail with an actionable message instead of a bare AttributeError when the
# parsed arguments carry no shard directory.
preprocessed_dir = getattr(args, "preprocessed", None)
if preprocessed_dir is None:
    raise ValueError(
        "args.preprocessed is not set: pass --preprocessed <dir> pointing at the "
        "directory of preprocessed safetensors shards"
    )

# os.listdir returns entries in arbitrary order; sort so the dataset is
# assembled in the same order on every run.
shard_names = sorted(os.listdir(preprocessed_dir))
if not shard_names:
    raise FileNotFoundError(f"no preprocessed shards found in {preprocessed_dir}")

input_ids_all, attention_mask_all = [], []
for shard_name in shard_names:
    data_path = os.path.join(preprocessed_dir, shard_name)
    data_file = safetorch.load_file(data_path)
    # Split along the sequence dimension into windows of `args.context` tokens.
    input_ids = torch.split(data_file["input_ids"], args.context, dim=1)
    attention_mask = torch.split(data_file["attention_mask"], args.context, dim=1)
    # The final window is shorter when the sequence length is not a multiple
    # of the context size; drop it.
    if input_ids[-1].shape[1] != args.context:
        input_ids = input_ids[:-1]
        attention_mask = attention_mask[:-1]
    input_ids_all.extend(input_ids)
    attention_mask_all.extend(attention_mask)
# Release the per-shard temporaries before concatenating.
del data_file, input_ids, attention_mask
input_ids_all = torch.cat(input_ids_all)
attention_mask_all = torch.cat(attention_mask_all)
# Keep only rows whose attention mask sums to the full window length.
valid_indices = attention_mask_all.sum(dim=1) == args.context
input_ids_all = input_ids_all[valid_indices]
attention_mask_all = attention_mask_all[valid_indices]
del valid_indices

preprocessed = data.TensorDataset(input_ids_all, attention_mask_all)

dataloader = data.DataLoader(
    preprocessed,
    batch_size=args.batch_size,
    shuffle=True,
    drop_last=True,
)

# NOTE: this rebinds `router` and `opt` to their prepared versions, so the cell
# is not idempotent - re-run the load cell before running this one again.
router, opt, dataloader = accelerator.prepare(router, opt, dataloader)


In [None]:
from vae_llm_worldmodels.utils import cosine_warmup


@torch.no_grad()
@torch.cuda.amp.autocast(dtype=torch.bfloat16)
def demo(model, input_ids, attention_mask, n_tokens):
    """Log a few sampled generations for the first rows of a batch.

    A random split point is drawn; the tokens before it form the context and
    the next `args.vae_context` tokens are encoded. For each of the first
    (at most two) rows, two latents are sampled from the encoder output and
    passed to `model.generate` together with the context. The decoded input
    and the decoded generations are written through the notebook's `print`.

    Reads the notebook globals `args` and `tokenizer`.

    Args:
        model: unwrapped router exposing `encode`, `generate` and `vae.vae.sample`.
        input_ids: token ids, indexed as (row, position).
        attention_mask: mask aligned with `input_ids`.
        n_tokens: forwarded to `model.generate`
            (presumably the number of tokens to generate - confirm).

    Raises:
        ValueError: from `random.randrange` when
            `args.context <= 2 * args.vae_context` (empty range).
    """
    bs = min(input_ids.shape[0], 2)
    n_outputs = 2
    tau = 0.8

    index = random.randrange(args.context - (args.vae_context * 2))
    # Only the first `bs` rows are ever shown, so slice once up front.
    context_ids = input_ids[:bs, :index]
    context_mask = attention_mask[:bs, :index]
    embed_ids = input_ids[:bs, index:index + args.vae_context]
    embed_mask = attention_mask[:bs, index:index + args.vae_context]

    in_texts = [tokenizer.decode(toks, skip_special_tokens=True)
                for toks in torch.cat([context_ids, embed_ids], dim=1)]
    mean = model.encode(embed_ids, embed_mask)
    # Repeat each row so that every input gets `n_outputs` independent samples.
    z = model.vae.vae.sample(mean.repeat_interleave(n_outputs, 0), tau=tau)
    context_ids = context_ids.repeat_interleave(n_outputs, 0)
    context_mask = context_mask.repeat_interleave(n_outputs, 0)
    output_ids = model.generate(z, context_ids, context_mask, n_tokens, tau=tau)
    out_texts = [tokenizer.decode(toks, skip_special_tokens=True) for toks in output_ids]
    # Group the flat list back into one chunk of `n_outputs` texts per input row.
    # (Done inline: no `batched` helper is defined or imported in this notebook.)
    out_batches = [out_texts[start:start + n_outputs]
                   for start in range(0, len(out_texts), n_outputs)]
    print("======")
    for in_text, out_batch in zip(in_texts, out_batches):
        print("=== Input ===")
        print(in_text)
        print("=== Outputs ===")
        for k, out_text in enumerate(out_batch):
            print(out_text)
            if k < len(out_batch) - 1:
                print("===")
    print("======")

def save():
    """Save the router and a small training-state file to `args.output`.

    Every process logs (main process only, via `print0`) and then waits at a
    barrier; the main process alone unwraps the router, calls its
    `save_pretrained`, and writes `state.json` next to it.

    Reads the notebook globals `args`, `accelerator`, `router`, `i` (current
    step) and `kl_sched`; the last two are defined in the training cell, so
    that cell must have run before this function is called.
    """
    # `print0` wraps loguru's `logger.info`, which applies `str.format` to the
    # message whenever extra arguments are given - so no `file=` kwarg here.
    print0(f"### Saving model to {args.output}")
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unwrapped_model = accelerator.unwrap_model(router)
        unwrapped_model.save_pretrained(args.output)
        # Make sure the directory exists before writing the state file, in case
        # `save_pretrained` did not create it.
        args.output.mkdir(parents=True, exist_ok=True)
        state_obj = {"step": i, "last_kl_weight": kl_sched(i)}
        with open(args.output / "state.json", "w") as f:
            json.dump(state_obj, f)


In [None]:
# train
DEMO_EVERY = 100          # steps between sample generations on the main process
CONTEXT_DROP_PROB = 0.5   # per-row probability of blanking the context

# `i` (global step) and `kl_sched` are also read by save().
i = 0
kl_sched = cosine_warmup(5000, 0.01)


accelerator.wait_for_everyone()
for epoch in trange(args.epochs, disable=not is_main):
    for input_ids, attention_mask in tqdm(dataloader, disable=not is_main):
        input_ids = input_ids.long()
        if is_main and i % DEMO_EVERY == 0:
            demo(accelerator.unwrap_model(router), input_ids, attention_mask, args.vae_context)
        with accelerator.accumulate(router):
            # Random split point: tokens before `index` are context, the next
            # `vae_context` tokens are encoded, and the next `2 * vae_context`
            # tokens (starting at the same point) are the prediction target.
            index = random.randrange(args.context - (args.vae_context * 2))
            context_ids = input_ids[:,:index]
            context_mask = attention_mask[:,:index]
            embed_ids = input_ids[:,index:index + args.vae_context]
            embed_mask = attention_mask[:,index:index + args.vae_context]
            target_ids = input_ids[:,index:index + args.vae_context * 2]
            target_mask = attention_mask[:,index:index + args.vae_context * 2]

            # Zero out ids and mask of the context for a random subset of rows.
            drop_mask = torch.rand([context_ids.shape[0], 1], device=device) < CONTEXT_DROP_PROB
            context_ids = torch.where(drop_mask, torch.zeros_like(context_ids), context_ids)
            context_mask = torch.where(drop_mask, torch.zeros_like(context_mask), context_mask)
            outputs = router(embed_ids, embed_mask,
                                target_ids[:,:-1], target_mask[:,:-1],
                                context_ids, context_mask)
            # Score the last `2 * vae_context` logit positions against the
            # target; cross_entropy expects the class dimension second.
            rec_losses = F.cross_entropy(
                outputs.logits[:, -args.vae_context * 2:].transpose(-1, -2),
                target_ids,
                reduction="none",
            )
            n_toks = target_mask.sum()
            rec_loss = torch.sum(rec_losses * target_mask, dtype=torch.float32) / n_toks
            # The KL term is currently disabled: the loss is reconstruction only.
            loss = rec_loss

            accelerator.backward(loss)
            opt.step()
            opt.zero_grad()

            loss_global, rec_global = accelerator.reduce(
                (loss, rec_loss), "mean"
            )
            # `print0` wraps loguru's `logger.info`, which applies `str.format`
            # to the message when extra arguments are given - so no `file=` kwarg.
            print0(
                f"epoch: {epoch}, step: {i}, loss: {loss_global.item():g}, rec: {rec_global.item():g}"
            )
            i += 1

            if i % args.save_every == 0:
                save()

    # Always checkpoint at the end of each epoch.
    save()
