<div>
<img src="https://camo.githubusercontent.com/473dd9f992924d27457650251786464f72e54121ac6e9210add0f483ca849277/68747470733a2f2f692e696d6775722e636f6d2f3765523750616e2e706e67" width="40%">  
</div>

# Distributed Bloom for Text Generation using Prompt Tuning

In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt a test 6B version of the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and gradient descent will learn a few prefix tokens stored on the Petals client.
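The core idea can be illustrated without BLOOM or Petals. The sketch below is a toy, not the Petals implementation: the frozen "model" is a hypothetical mean-pooling dot product with fixed weights `w`, and the gradient of a squared-error loss is written out by hand instead of being computed by autograd. Only the prefix vectors are updated, which mirrors how the BLOOM blocks on the servers stay fixed while the client learns the prefix tokens.

```python
import random


def mean_pool(vectors):
    """Average a list of equal-length vectors element-wise."""
    dim = len(vectors[0])
    return [sum(v[j] for v in vectors) / len(vectors) for j in range(dim)]


def forward(prefix, inputs, w):
    """Hypothetical frozen 'model': frozen weights dotted with the pooled sequence."""
    pooled = mean_pool(prefix + inputs)  # prefix vectors are prepended to the inputs
    return sum(wj * hj for wj, hj in zip(w, pooled))


def prompt_tune(inputs, target, w, num_prefix=2, lr=0.5, steps=50, seed=0):
    """Learn only the prefix vectors; `w` and `inputs` are never modified."""
    rng = random.Random(seed)
    dim = len(w)
    prefix = [[rng.uniform(-0.1, 0.1) for _ in range(dim)] for _ in range(num_prefix)]
    seq_len = num_prefix + len(inputs)
    losses = []
    for _ in range(steps):
        err = forward(prefix, inputs, w) - target
        losses.append(err ** 2)
        # d(loss)/d(prefix[i][j]) = 2 * err * w[j] / seq_len (hand-derived gradient)
        for vec in prefix:
            for j in range(dim):
                vec[j] -= lr * 2 * err * w[j] / seq_len
    return prefix, losses


frozen_w = [0.5, -1.0, 2.0]
inputs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
prefix, losses = prompt_tune(inputs, target=1.0, w=frozen_w)
print(f"loss: {losses[0]:.4f} -> {losses[-1]:.2e}")
```

Because the toy model is linear, the loss shrinks geometrically here; with BLOOM the same principle applies, but gradients flow back through the frozen remote blocks to the prefix embeddings.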

We will adapt the BLOOM model for the chatbot task using the [Personachat](https://huggingface.co/datasets/bavard/personachat_truecased) dataset. For a given dialogue context, the model has to provide a relevant answer.

First, we have to prepare all dependencies.

In [3]:
!pip install transformers==4.21.3

In [1]:
import sys

# Make a local checkout of the Petals repository importable (adjust the path to your setup)
sys.path.insert(0, "../../../petals")

import torch
import transformers
import wandb
from datasets import load_dataset
from tqdm import tqdm
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import get_scheduler

# Import a Petals model
from src.client.remote_model import DistributedBloomForCausalLM



Let's set some hyperparameters for training:

In [2]:
MODEL_NAME = "bigscience/test-bloomd-6b3"  # select the model you like
INITIAL_PEERS = ["/ip4/193.106.95.184/tcp/31000/p2p/QmSg7izCDtowVTACbUmWvEiQZNY4wgCQ9T9Doo66K59X6q"]  # add your peers' addresses here, like "/ip4/192.168.1.2/tcp/31000/p2p/Qma...."
NUM_PREFIX_TOKENS = 16
DEVICE = 'cpu'
BATCH_SIZE = 4
LR = 1e-2
WEIGHT_DECAY = 0.0
NUM_SAMPLES = 1000
SEED = 42
MODEL_MAX_LENGTH = 256
TUNING_MODE = 'ptune'  # choose between ['ptune', 'deep_ptune']
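Each entry of `INITIAL_PEERS` is a libp2p multiaddr: a sequence of `/protocol/value` pairs giving the IP address, the TCP port, and the peer ID. The hypothetical helper below (not part of Petals or libp2p) splits such a string so you can sanity-check an address before connecting. Real multiaddrs may also contain protocols without a value, which this sketch does not handle.

```python
def parse_multiaddr(addr: str) -> dict:
    """Split '/ip4/<ip>/tcp/<port>/p2p/<peer id>' into {protocol: value} pairs.

    Hypothetical helper: assumes every protocol in the address carries a value.
    """
    parts = addr.strip("/").split("/")
    if len(parts) % 2 != 0:
        raise ValueError(f"expected protocol/value pairs, got {addr!r}")
    return dict(zip(parts[0::2], parts[1::2]))


peer = parse_multiaddr(
    "/ip4/193.106.95.184/tcp/31000/p2p/QmSg7izCDtowVTACbUmWvEiQZNY4wgCQ9T9Doo66K59X6q"
)
print(peer["ip4"], peer["tcp"], peer["p2p"])
```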

Prepare the tokenizer and the distributed model, and connect the model to the servers.

In [3]:
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
tokenizer.padding_side = 'right'
tokenizer.model_max_length = MODEL_MAX_LENGTH
model = DistributedBloomForCausalLM.from_pretrained(
    MODEL_NAME, 
    initial_peers=INITIAL_PEERS, 
    pre_seq_len=NUM_PREFIX_TOKENS, 
    tuning_mode=TUNING_MODE
).to(DEVICE)

Oct 22 14:15:13.116 [WARN] [/home/jagiljazev/personalized-chat-bot/notebooks/gilyazev/../../../petals/src/client/remote_sequential.py.__init__:34] RemoteSequential is in active development; expect adventures
Some weights of DistributedBloomForCausalLM were not initialized from the model checkpoint at bigscience/test-bloomd-6b3 and are newly initialized: ['lm_head.word_embeddings.weight', 'prompt_embeddings.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Let's prepare the Personachat dataset. We need two mapping functions: one that concatenates the dialogue history with each candidate answer, and another for tokenization.
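To see what the first mapping produces, here is a stdlib-only copy of the `chunking` function from the next cell (with the separator pulled into a constant), applied to a made-up batch that uses the same `history` and `candidates` columns. Each example yields one chunk per candidate, so the output has more rows than the input. A batched `map` in Hugging Face Datasets may change the number of rows, which is presumably why the next cell passes `batched=True` and removes the original columns.

```python
SEPARATOR = "\n-----\n"


def chunking(examples):
    # One training text per (history, candidate) pair:
    # the history turns joined by SEPARATOR, then the candidate reply.
    inputs = [
        SEPARATOR.join(history) + SEPARATOR + candidate
        for history, candidates in zip(examples["history"], examples["candidates"])
        for candidate in candidates
    ]
    return {"chunks": inputs}


# Made-up batch with a single dialogue and two candidate replies
toy_batch = {
    "history": [["Hi!", "Hello, how are you?"]],
    "candidates": [["I like dogs.", "I am fine, thanks."]],
}
for chunk in chunking(toy_batch)["chunks"]:
    print(repr(chunk))
```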

In [4]:
dataset = load_dataset("bavard/personachat_truecased")


def chunking(examples):
    inputs = [
        "\n-----\n".join(history) + "\n-----\n" + candidate
        for history, candidates in zip(examples["history"], examples["candidates"])
        for candidate in candidates
    ]
    return {"chunks": inputs}


def tokenize(examples):
    outputs = {
        "input_ids": tokenizer(examples["chunks"], padding='max_length', truncation=True)["input_ids"]
    }
    outputs["labels"] = outputs["input_ids"]
    return outputs


tokenized_datasets = (
    dataset
        .map(chunking, batched=True, remove_columns=dataset["train"].column_names)
        .map(tokenize, batched=True, remove_columns=["chunks"])
)


tokenized_datasets.set_format("torch")
train_dataset = tokenized_datasets["train"].shuffle(seed=SEED)
train_dataloader = DataLoader(
    train_dataset.select(list(range(NUM_SAMPLES))),
    shuffle=True,
    batch_size=BATCH_SIZE,
    drop_last=True,
)
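Setting `labels` equal to `input_ids` works because, assuming `DistributedBloomForCausalLM` follows the standard Hugging Face causal-LM convention, the model shifts the labels internally: the logits at position `t` are scored against `labels[t + 1]`. The sketch below reproduces that pairing on made-up token ids (the pad id is hypothetical). It also shows that with right padding and unmasked labels, the trailing pad positions are prediction targets as well.

```python
def next_token_pairs(input_ids, labels):
    """Pair the token seen at each position with the label predicted there.

    Mirrors the usual causal-LM shift: position t is scored against labels[t + 1].
    """
    return list(zip(input_ids[:-1], labels[1:]))


PAD = 3  # hypothetical pad id for this toy example
input_ids = [10, 11, 12, PAD, PAD]  # right-padded, as with padding_side = 'right'
labels = list(input_ids)            # labels = input_ids, as in `tokenize`
pairs = next_token_pairs(input_ids, labels)
print(pairs)
```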


Before setting up the optimizer, let's check which model parameters will be trained.

In [5]:
for n, p in model.named_parameters():
    if p.requires_grad:
        print(n, p.requires_grad, p.device)

transformer.prompt_embeddings.weight True cpu


The optimizer will only update the **prompts**, since they are the only trainable parameters. Let's initialize the optimizer and the learning rate scheduler.

In [6]:
optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)
)
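With `num_warmup_steps=0`, the `"linear"` schedule starts at `LR` and decays linearly to zero over `num_training_steps`. Below is a stdlib re-implementation of the multiplier we understand `get_scheduler` to use here (the one from `get_linear_schedule_with_warmup`), evaluated for this notebook's settings: 1000 samples in batches of 4 with `drop_last=True` give 250 steps, which matches the `250` total in the training progress bar.

```python
def linear_lr(step, base_lr, num_training_steps, num_warmup_steps=0):
    """Learning rate at `step` for a linear warmup-then-decay schedule."""
    if step < num_warmup_steps:
        return base_lr * step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    return base_lr * max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


NUM_SAMPLES, BATCH_SIZE, LR = 1000, 4, 1e-2
num_steps = NUM_SAMPLES // BATCH_SIZE  # drop_last=True discards any incomplete batch
for step in (0, num_steps // 2, num_steps):
    print(step, linear_lr(step, LR, num_steps))
```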

Let's initialize wandb for logging and start the training loop! (The `wandb` calls are commented out; uncomment them to enable logging.)

In [None]:
# wandb.init(
#     project="bloom-personachat",
#     config={
#         "num_samples": NUM_SAMPLES,
#         "batch_size": BATCH_SIZE,
#         "learning_rate": LR,
#         "weight_decay": WEIGHT_DECAY,
#         "num_prefix_tokens": NUM_PREFIX_TOKENS,
#         "model_name": MODEL_NAME,
#         "seed": SEED,
#     }
# )
loss_hist = []
print('wandb initialized\n')

for batch in tqdm(train_dataloader):
    batch = {k: v.to(DEVICE) for k, v in batch.items()}

    model.train()
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()

    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
    print(f"Train Loss: {loss.item()}")
    loss_hist.append(loss.item())  # store a Python float, not the tensor and its autograd history

#     wandb.log({"Train Loss": loss.item()})

wandb initialized



  0%|▍                                                                                                       | 1/250 [00:34<2:21:10, 34.02s/it]

Train Loss: 5.633817195892334


Try to talk with the trained model! Submit an empty input to stop the execution.


__Note__: In this example, we use the whole dialogue as a prefix when generating each new reply. In the future, we will support a faster "interactive" dialogue mode, in which generating a new reply will reuse the inference caches from the previous one.
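A back-of-the-envelope count shows why such caching matters. Because the loop below passes the whole dialogue to `model.generate` on every turn, the number of prompt tokens processed grows quadratically with the number of turns, while reusing a cache would process each new turn only once. The turn lengths are hypothetical, and the count ignores the tokens generated for each reply.

```python
def tokens_processed(turn_lengths, reuse_cache):
    """Total prompt tokens processed over a dialogue.

    Without a cache, every new reply re-reads the whole dialogue so far;
    with a cache, only the newly added tokens are processed.
    """
    total, context = 0, 0
    for n in turn_lengths:
        context += n
        total += n if reuse_cache else context
    return total


turns = [20] * 10  # hypothetical: 10 turns of 20 tokens each
print(tokens_processed(turns, reuse_cache=False), tokens_processed(turns, reuse_cache=True))
```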

In [None]:
MAX_TOKENS = 16
TOP_K = 100
TEMPERATURE = 0.6
dialog = ""

while True:
    user_phrase = input()
    if len(user_phrase) == 0:
        break
    dialog += f"{user_phrase}\n-----\n"
    inputs = tokenizer([dialog], return_tensors='pt')['input_ids'].to(DEVICE)
    outputs = model.generate(
        inputs,
        temperature=TEMPERATURE,
        do_sample=True,
        top_k=TOP_K,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=MAX_TOKENS,
    )
    bloom_answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]  # drop special tokens such as EOS
    bloom_answer = bloom_answer[len(dialog):].split("\n")[0]
    print(bloom_answer)
    dialog += f"{bloom_answer}\n-----\n"
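The last lines of the loop assume the decoded output starts with the exact prompt text: they cut off `len(dialog)` characters and keep everything up to the first newline, i.e. up to the next `-----` separator, which the model may emit since it was trained on that format. A minimal sketch of that post-processing, using a hypothetical `extract_reply` helper and an invented decoded string:

```python
SEPARATOR = "\n-----\n"


def extract_reply(decoded: str, dialog: str) -> str:
    """Hypothetical helper: keep the text generated after the prompt, up to the first newline."""
    return decoded[len(dialog):].split("\n")[0]


dialog = "Hi! How are you?" + SEPARATOR
# Invented model output: the prompt, a reply, then the start of another turn
decoded = dialog + "I am fine, thanks." + SEPARATOR + "What about"
print(repr(extract_reply(decoded, dialog)))
```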