<div>
<img src="https://camo.githubusercontent.com/473dd9f992924d27457650251786464f72e54121ac6e9210add0f483ca849277/68747470733a2f2f692e696d6775722e636f6d2f3765523750616e2e706e67" width="40%">  
</div>

# Distributed Bloom for Text Generation using Prompt Tuning

In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt a test 6B version of the [BLOOM](https://huggingface.co/bigscience/bloom) model to a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers host the BLOOM blocks, which stay frozen during adaptation, while gradient descent learns a few prefix tokens stored on the Petals client.
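
To make this division of labour concrete, here is a minimal, dependency-free sketch of the idea behind prompt tuning. It is a toy and not the Petals implementation: the "model" is a single frozen weight vector, and all names (`forward`, `train_prefix`, `frozen_weights`) are hypothetical. Only the prefix vectors receive gradient updates.

```python
# Toy illustration of prompt tuning: the model weights stay frozen and
# only a few prefix vectors, prepended to the input, are trained.
# Everything here is a hypothetical stand-in, not Petals code.

def forward(prefix, token_embeddings, frozen_weights):
    """Score = dot(frozen_weights, mean of all vectors in prefix + tokens)."""
    sequence = prefix + token_embeddings  # prepend the trainable prefix
    dim = len(frozen_weights)
    mean = [sum(vec[i] for vec in sequence) / len(sequence) for i in range(dim)]
    return sum(w * m for w, m in zip(frozen_weights, mean))


def train_prefix(prefix, token_embeddings, frozen_weights, target, lr=0.5, steps=200):
    """Gradient descent on a squared error, updating the prefix only."""
    prefix = [list(vec) for vec in prefix]  # copy so the caller's data is untouched
    seq_len = len(prefix) + len(token_embeddings)
    for _ in range(steps):
        error = forward(prefix, token_embeddings, frozen_weights) - target
        # d(error^2) / d(prefix[j][i]) = 2 * error * frozen_weights[i] / seq_len
        for vec in prefix:
            for i, w in enumerate(frozen_weights):
                vec[i] -= lr * 2.0 * error * w / seq_len
    return prefix


frozen_weights = [0.5, -1.0, 2.0]             # never modified
tokens = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]   # fixed "input embeddings"
prefix = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]   # two trainable prefix vectors

tuned = train_prefix(prefix, tokens, frozen_weights, target=1.0)
before = forward(prefix, tokens, frozen_weights)
after = forward(tuned, tokens, frozen_weights)
print(f"score before: {before:.3f}, after: {after:.3f}")
```

In the real setup the frozen part is the stack of BLOOM blocks served remotely, and the prefix corresponds to the `NUM_PREFIX_TOKENS` trainable embeddings kept on the client.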

We will adapt the BLOOM model for the chatbot task using the [Personachat](https://huggingface.co/datasets/bavard/personachat_truecased) dataset. For a given dialogue context, the model has to provide a relevant answer.

First, we have to prepare all dependencies.

In [20]:
import os
import sys
sys.path.insert(0, "../../../petals")
 
import torch
import transformers
import wandb
from datasets import load_dataset
from tqdm import tqdm
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import get_scheduler
from torch import nn

# Import a Petals model
from src.client.remote_model import DistributedBloomForCausalLM

Let's set some hyperparameters for training:

In [2]:
MODEL_NAME = "bigscience/test-bloomd-6b3" # select the model you like
INITIAL_PEERS = ["/ip4/193.106.95.184/tcp/31000/p2p/QmSg7izCDtowVTACbUmWvEiQZNY4wgCQ9T9Doo66K59X6q"] # add your peers' addresses here, like "/ip4/192.168.1.2/tcp/31000/p2p/Qma...."
NUM_PREFIX_TOKENS = 16
DEVICE = 'cpu'
BATCH_SIZE = 4
LR = 1e-2
WEIGHT_DECAY = 0.0
NUM_SAMPLES = 1000
SEED = 42
MODEL_MAX_LENGTH = 256
TUNING_MODE = 'ptune' # choose between ['ptune', 'deep_ptune'] 

Prepare the tokenizer and the distributed model, then connect the model to the servers.

In [3]:
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
tokenizer.padding_side = 'right'
tokenizer.model_max_length = MODEL_MAX_LENGTH
model = DistributedBloomForCausalLM.from_pretrained(
    MODEL_NAME, 
    initial_peers=INITIAL_PEERS, 
    pre_seq_len=NUM_PREFIX_TOKENS, 
    tuning_mode=TUNING_MODE
).to(DEVICE)

Oct 28 12:56:03.927 [WARN] [/home/jagiljazev/personalized-chat-bot/notebooks/gilyazev/../../../petals/src/client/remote_sequential.py.__init__:34] RemoteSequential is in active development; expect adventures
Some weights of DistributedBloomForCausalLM were not initialized from the model checkpoint at bigscience/test-bloomd-6b3 and are newly initialized: ['lm_head.word_embeddings.weight', 'prompt_embeddings.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Let's prepare the Personachat dataset. We need two mapping functions, one to concatenate history and candidate answers, and another for tokenization.

In [4]:
dataset = load_dataset("bavard/personachat_truecased")


def chunking(examples):
    inputs = [
        "\n-----\n".join(history) + "\n-----\n" + candidate
        for history, candidates in zip(examples["history"], examples["candidates"])
        for candidate in candidates
    ]
    return {"chunks": inputs}


def tokenize(examples):
    outputs = {
        "input_ids": tokenizer(examples["chunks"], padding='max_length', truncation=True)["input_ids"]
    }
    outputs["labels"] = outputs["input_ids"]
    return outputs


tokenized_datasets = (
    dataset
        .map(chunking, batched=True, remove_columns=dataset["train"].column_names)
        .map(tokenize, batched=True, remove_columns=["chunks"])
)


tokenized_datasets.set_format("torch")
train_dataset = tokenized_datasets["train"].shuffle(seed=SEED)
train_dataloader = DataLoader(
    train_dataset.select(list(range(NUM_SAMPLES))),
    shuffle=True,
    batch_size=BATCH_SIZE,
    drop_last=True,
)

Oct 28 12:57:16.929 [WARN] [datasets.builder._create_builder_config:427] No config specified, defaulting to: personachat_truecased/full
Oct 28 12:57:16.987 [WARN] [datasets.builder.download_and_prepare:739] Found cached dataset personachat_truecased (/home/jagiljazev/.cache/huggingface/datasets/bavard___personachat_truecased/full/1.0.0/73ee8f1a0d9e42255af5a8301877a2f3ac638e55b1cd9cbccca5ab7e23d2b638)


  0%|          | 0/2 [00:00<?, ?it/s]

Oct 28 12:57:17.053 [WARN] [datasets.arrow_dataset._map_single:2793] Loading cached processed dataset at /home/jagiljazev/.cache/huggingface/datasets/bavard___personachat_truecased/full/1.0.0/73ee8f1a0d9e42255af5a8301877a2f3ac638e55b1cd9cbccca5ab7e23d2b638/cache-5ecae882ebbd418d.arrow
Oct 28 12:57:21.147 [WARN] [datasets.arrow_dataset._map_single:2793] Loading cached processed dataset at /home/jagiljazev/.cache/huggingface/datasets/bavard___personachat_truecased/full/1.0.0/73ee8f1a0d9e42255af5a8301877a2f3ac638e55b1cd9cbccca5ab7e23d2b638/cache-7000c64a1a527e4d.arrow
Oct 28 12:57:21.478 [WARN] [datasets.arrow_dataset._map_single:2793] Loading cached processed dataset at /home/jagiljazev/.cache/huggingface/datasets/bavard___personachat_truecased/full/1.0.0/73ee8f1a0d9e42255af5a8301877a2f3ac638e55b1cd9cbccca5ab7e23d2b638/cache-76265556d7dc8064.arrow
Oct 28 12:57:38.196 [WARN] [datasets.arrow_dataset._map_single:2793] Loading cached processed dataset at /home/jagiljazev/.cache/huggingface/d
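
To see what the first mapping function produces, the sketch below applies the same concatenation logic to a tiny hand-written batch. The two-turn dialogue and its candidates are made up for illustration and are not taken from Personachat; the function body mirrors `chunking` above.

```python
# Same concatenation logic as `chunking`, applied to a made-up batch.
SEPARATOR = "\n-----\n"


def chunking(examples):
    # One output string per (history, candidate) pair: every candidate
    # answer is appended to the full dialogue history.
    inputs = [
        SEPARATOR.join(history) + SEPARATOR + candidate
        for history, candidates in zip(examples["history"], examples["candidates"])
        for candidate in candidates
    ]
    return {"chunks": inputs}


toy_batch = {
    "history": [["Hi, how are you?", "Fine, thanks. And you?"]],
    "candidates": [["I am great.", "I like trains."]],
}

chunks = chunking(toy_batch)["chunks"]
for chunk in chunks:
    print(repr(chunk))
```

A dialogue with N candidates becomes N training strings, so the mapped dataset has many more rows than the source has dialogue turns.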

Before setting up optimizers, check the model parameters that will be trained.

In [5]:
for n, p in model.named_parameters():
    if p.requires_grad:
        print(n, p.requires_grad, p.device)

transformer.prompt_embeddings.weight True cpu
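
A back-of-the-envelope count shows how small the trainable part is. The hidden size and the total parameter count below are assumptions about the 6B test checkpoint, not values printed by this notebook; with the real model, `model.transformer.prompt_embeddings.weight.numel()` should give the exact number of trainable values.

```python
# Rough size of the trainable prompt for TUNING_MODE = 'ptune'.
NUM_PREFIX_TOKENS = 16          # from the hyperparameter cell
ASSUMED_HIDDEN_SIZE = 4096      # assumption, not read from the checkpoint
ASSUMED_TOTAL_PARAMS = 6.3e9    # assumption based on the "6b3" in the model name

trainable = NUM_PREFIX_TOKENS * ASSUMED_HIDDEN_SIZE
share = trainable / ASSUMED_TOTAL_PARAMS

print(f"trainable prompt parameters: {trainable:,}")
print(f"share of the full model: {share:.6%}")
```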


The optimizer will only update the **prompts**, since they are the only trainable parameters. Let's initialize the optimizer and the learning rate scheduler.

In [6]:
optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)
)
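
For intuition, the "linear" schedule with zero warm-up steps can be written out by hand: the learning rate starts at `LR` and decays linearly to zero over the training steps. The helper below is a sketch of that rule as commonly described, not the library source, and `linear_lr` is a hypothetical name.

```python
# Linear decay with optional warm-up, written out by hand.
LR = 1e-2
NUM_TRAINING_STEPS = 250  # 1000 samples / batch size 4, as in this notebook


def linear_lr(step, base_lr=LR, total_steps=NUM_TRAINING_STEPS, warmup_steps=0):
    """Learning rate in effect after `step` scheduler updates."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


for step in (0, 125, 250):
    print(step, linear_lr(step))
```

Because the decay is tied to `len(train_dataloader)`, interrupting training early means the learning rate never reaches zero.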

Let's initialize wandb for logging and start the training loop!

In [18]:
for batch in tqdm(train_dataloader):
    batch = {k: v.to(DEVICE) for k, v in batch.items()}
    print(batch)
    break

  0%|          | 0/250 [00:00<?, ?it/s]

{'input_ids': tensor([[  1270, 158204,  84148,  ...,      3,      3,      3],
        [ 47569,   1130,  25008,  ...,      3,      3,      3],
        [  1270, 158204,  84148,  ...,      3,      3,      3],
        [  1270, 158204,  84148,  ...,      3,      3,      3]]), 'labels': tensor([[  1270, 158204,  84148,  ...,      3,      3,      3],
        [ 47569,   1130,  25008,  ...,      3,      3,      3],
        [  1270, 158204,  84148,  ...,      3,      3,      3],
        [  1270, 158204,  84148,  ...,      3,      3,      3]])}
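
The printed batch shows that `labels` is an exact copy of `input_ids`, including the trailing padding id `3`, so padded positions appear to contribute to the loss as well. A common alternative, not used in this notebook, is to replace the labels at padded positions with `-100`, the default `ignore_index` of PyTorch's cross-entropy loss. The sketch below shows that masking on plain lists; the helper name `mask_padding` and the toy batch are hypothetical.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_padding(input_ids, attention_mask, ignore_index=IGNORE_INDEX):
    """Copy input_ids into labels, replacing padded positions by ignore_index."""
    return [
        [tok if keep == 1 else ignore_index for tok, keep in zip(ids, mask)]
        for ids, mask in zip(input_ids, attention_mask)
    ]


# Toy batch: 3 is the padding id seen in the printed tensors above.
input_ids = [[1270, 158204, 84148, 3, 3], [47569, 1130, 3, 3, 3]]
attention_mask = [[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]]

labels = mask_padding(input_ids, attention_mask)
print(labels)
```

Using the attention mask rather than comparing against the padding id avoids masking a real token that happens to share that id.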





In [14]:
dataset.map(chunking, batched=True, remove_columns=dataset["train"].column_names)

Oct 28 14:52:25.194 [WARN] [datasets.arrow_dataset._map_single:2793] Loading cached processed dataset at /home/jagiljazev/.cache/huggingface/datasets/bavard___personachat_truecased/full/1.0.0/73ee8f1a0d9e42255af5a8301877a2f3ac638e55b1cd9cbccca5ab7e23d2b638/cache-5ecae882ebbd418d.arrow
Oct 28 14:52:25.244 [WARN] [datasets.arrow_dataset._map_single:2793] Loading cached processed dataset at /home/jagiljazev/.cache/huggingface/datasets/bavard___personachat_truecased/full/1.0.0/73ee8f1a0d9e42255af5a8301877a2f3ac638e55b1cd9cbccca5ab7e23d2b638/cache-7000c64a1a527e4d.arrow


DatasetDict({
    train: Dataset({
        features: ['chunks'],
        num_rows: 2628760
    })
    validation: Dataset({
        features: ['chunks'],
        num_rows: 156020
    })
})

In [16]:
dataset.map(chunking, batched=True, remove_columns=dataset["train"].column_names)['train']['chunks']

Oct 28 14:52:43.547 [WARN] [datasets.arrow_dataset._map_single:2793] Loading cached processed dataset at /home/jagiljazev/.cache/huggingface/datasets/bavard___personachat_truecased/full/1.0.0/73ee8f1a0d9e42255af5a8301877a2f3ac638e55b1cd9cbccca5ab7e23d2b638/cache-5ecae882ebbd418d.arrow
Oct 28 14:52:43.596 [WARN] [datasets.arrow_dataset._map_single:2793] Loading cached processed dataset at /home/jagiljazev/.cache/huggingface/datasets/bavard___personachat_truecased/full/1.0.0/73ee8f1a0d9e42255af5a8301877a2f3ac638e55b1cd9cbccca5ab7e23d2b638/cache-7000c64a1a527e4d.arrow


["Hi, how are you doing? I'm getting ready to do some cheetah chasing to stay in shape.\n-----\nMy mom was single with 3 boys, so we never left the projects.",
 "Hi, how are you doing? I'm getting ready to do some cheetah chasing to stay in shape.\n-----\nI try to wear all black every day. It makes me feel comfortable.",
 "Hi, how are you doing? I'm getting ready to do some cheetah chasing to stay in shape.\n-----\nWell nursing stresses you out so I wish luck with sister.",
 "Hi, how are you doing? I'm getting ready to do some cheetah chasing to stay in shape.\n-----\nYeah just want to pick up Nba nfl getting old.",
 "Hi, how are you doing? I'm getting ready to do some cheetah chasing to stay in shape.\n-----\nI really like Celine Dion. What about you?",
 "Hi, how are you doing? I'm getting ready to do some cheetah chasing to stay in shape.\n-----\nNo. I live near farms.",
 "Hi, how are you doing? I'm getting ready to do some cheetah chasing to stay in shape.\n-----\nI wish I had a dau

In [27]:
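# Example of a target given as class probabilities (soft labels)
loss_fn = nn.CrossEntropyLoss()
# Named `logits`, not `input`: rebinding `input` would shadow the built-in
# that the chat loop at the end of this notebook relies on.
logits = torch.randn(3, 5, requires_grad=True)
print(logits)
target = torch.randn(3, 5).softmax(dim=1)
print(target)
output = loss_fn(logits, target)
output.backward()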

tensor([[-0.0754,  1.3154,  1.8942, -1.5395, -1.9320],
        [ 0.6195,  0.4196,  1.4826, -1.3767, -1.4139],
        [-2.5223,  0.6830, -2.5059,  0.1596, -0.1500]], requires_grad=True)
tensor([[0.1011, 0.0790, 0.1870, 0.4164, 0.2165],
        [0.0636, 0.5943, 0.1131, 0.1221, 0.1070],
        [0.5366, 0.0769, 0.0543, 0.3117, 0.0204]])
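
For reference, the quantity that cross-entropy with probability targets computes can be written out by hand: for each row, take the log-softmax of the logits, weight it by the target probabilities, negate the sum, and average over rows (the default `mean` reduction). The pure-Python sketch below follows that recipe; the function names are hypothetical.

```python
import math


def log_softmax(row):
    # Subtract the maximum first for numerical stability.
    m = max(row)
    log_sum = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_sum for x in row]


def soft_cross_entropy(logits, targets):
    """Mean over rows of -sum_c targets[c] * log_softmax(logits)[c]."""
    row_losses = [
        -sum(t * lp for t, lp in zip(target_row, log_softmax(logit_row)))
        for logit_row, target_row in zip(logits, targets)
    ]
    return sum(row_losses) / len(row_losses)


# Uniform logits over 4 classes: every class gets probability 1/4,
# so the loss equals log(4) for any valid target distribution.
toy_logits = [[0.0, 0.0, 0.0, 0.0]]
toy_targets = [[0.1, 0.2, 0.3, 0.4]]
print(soft_cross_entropy(toy_logits, toy_targets))
```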


In [8]:
# wandb.log() fails unless wandb.init() has been called in this session
wandb.init(
    project="bloom-personachat",
    config={
        "num_samples": NUM_SAMPLES,
        "batch_size": BATCH_SIZE,
        "learning_rate": LR,
        "weight_decay": WEIGHT_DECAY,
        "num_prefix_tokens": NUM_PREFIX_TOKENS,
        "model_name": MODEL_NAME,
        "seed": SEED,
    }
)
loss_hist = []
print('wandb initialized\n')

model.train()
for batch in tqdm(train_dataloader):
    batch = {k: v.to(DEVICE) for k, v in batch.items()}

    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()

    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

    # Store a plain float so the history does not keep autograd graphs alive
    loss_value = loss.item()
    print(f"Train Loss: {loss_value}")
    loss_hist.append(loss_value)

    wandb.log({"Train Loss": loss_value})

wandb initialized



  0%|          | 1/250 [00:33<2:18:05, 33.28s/it]

Train Loss: 5.8812384605407715


  1%|          | 2/250 [01:05<2:14:39, 32.58s/it]

Train Loss: 6.158038139343262


  1%|          | 3/250 [01:37<2:13:18, 32.38s/it]

Train Loss: 6.3185272216796875


  2%|▏         | 4/250 [02:09<2:11:27, 32.06s/it]

Train Loss: 5.083332538604736


  2%|▏         | 5/250 [02:41<2:10:55, 32.06s/it]

Train Loss: 5.77915620803833


  2%|▏         | 6/250 [03:12<2:09:15, 31.79s/it]

Train Loss: 5.025237560272217


  3%|▎         | 7/250 [03:42<2:07:07, 31.39s/it]

Train Loss: 6.119544506072998


  3%|▎         | 8/250 [04:13<2:05:23, 31.09s/it]

Train Loss: 5.077054023742676


  4%|▎         | 9/250 [04:44<2:04:24, 30.97s/it]

Train Loss: 4.961986064910889


  4%|▍         | 10/250 [05:14<2:03:27, 30.87s/it]

Train Loss: 5.047936916351318


  4%|▍         | 11/250 [05:45<2:03:11, 30.93s/it]

Train Loss: 4.147397518157959


  5%|▍         | 12/250 [06:19<2:05:49, 31.72s/it]

Train Loss: 5.502799987792969


  5%|▌         | 13/250 [06:52<2:06:44, 32.09s/it]

Train Loss: 5.048236846923828


  6%|▌         | 14/250 [07:23<2:05:33, 31.92s/it]

Train Loss: 4.592619895935059


  6%|▌         | 15/250 [07:55<2:04:50, 31.88s/it]

Train Loss: 4.619165897369385


  6%|▋         | 16/250 [08:26<2:03:24, 31.64s/it]

Train Loss: 4.7517499923706055


  7%|▋         | 17/250 [08:57<2:02:26, 31.53s/it]

Train Loss: 4.201253414154053


  7%|▋         | 18/250 [09:28<2:00:59, 31.29s/it]

Train Loss: 4.865632057189941


  8%|▊         | 19/250 [10:00<2:00:36, 31.33s/it]

Train Loss: 4.848703861236572


  8%|▊         | 20/250 [10:31<1:59:47, 31.25s/it]

Train Loss: 4.93528413772583


  8%|▊         | 21/250 [11:02<1:59:52, 31.41s/it]

Train Loss: 4.177621841430664


  9%|▉         | 22/250 [11:34<2:00:00, 31.58s/it]

Train Loss: 4.013919353485107


  9%|▉         | 23/250 [12:05<1:58:32, 31.33s/it]

Train Loss: 3.7343552112579346


 10%|▉         | 24/250 [12:37<1:58:46, 31.53s/it]

Train Loss: 4.407346248626709


 10%|█         | 25/250 [13:08<1:57:58, 31.46s/it]

Train Loss: 3.802149772644043


 10%|█         | 26/250 [13:39<1:56:23, 31.18s/it]

Train Loss: 3.6604557037353516


 11%|█         | 27/250 [14:10<1:55:14, 31.01s/it]

Train Loss: 4.012041091918945


 11%|█         | 28/250 [14:42<1:56:32, 31.50s/it]

Train Loss: 4.028496265411377


 12%|█▏        | 29/250 [15:16<1:57:59, 32.03s/it]

Train Loss: 3.9141862392425537


 12%|█▏        | 30/250 [15:47<1:57:05, 31.93s/it]

Train Loss: 3.6024794578552246


 12%|█▏        | 31/250 [16:17<1:54:41, 31.42s/it]

Train Loss: 3.8550846576690674


 13%|█▎        | 32/250 [16:48<1:53:08, 31.14s/it]

Train Loss: 3.7658965587615967


 13%|█▎        | 33/250 [17:18<1:51:57, 30.95s/it]

Train Loss: 3.674410581588745


 14%|█▎        | 34/250 [17:49<1:51:02, 30.84s/it]

Train Loss: 3.5821011066436768


 14%|█▍        | 35/250 [18:20<1:50:16, 30.77s/it]

Train Loss: 3.7951884269714355


 14%|█▍        | 36/250 [18:50<1:49:47, 30.78s/it]

Train Loss: 3.3672428131103516


 15%|█▍        | 37/250 [19:21<1:49:03, 30.72s/it]

Train Loss: 3.8997981548309326


 15%|█▌        | 38/250 [19:52<1:48:22, 30.67s/it]

Train Loss: 3.5416626930236816


 16%|█▌        | 39/250 [20:22<1:47:59, 30.71s/it]

Train Loss: 3.277491807937622


 16%|█▌        | 40/250 [20:52<1:46:34, 30.45s/it]

Train Loss: 3.5675172805786133


 16%|█▋        | 41/250 [21:23<1:45:58, 30.42s/it]

Train Loss: 3.488633632659912


 17%|█▋        | 42/250 [21:53<1:45:18, 30.38s/it]

Train Loss: 3.5238099098205566


 17%|█▋        | 43/250 [22:23<1:44:43, 30.35s/it]

Train Loss: 3.3181333541870117


 18%|█▊        | 44/250 [22:54<1:44:32, 30.45s/it]

Train Loss: 3.3862814903259277


 18%|█▊        | 45/250 [23:24<1:44:06, 30.47s/it]

Train Loss: 3.283531427383423


 18%|█▊        | 46/250 [23:55<1:44:04, 30.61s/it]

Train Loss: 3.437098503112793


 19%|█▉        | 47/250 [24:26<1:43:52, 30.70s/it]

Train Loss: 3.110487937927246


 19%|█▉        | 48/250 [24:57<1:43:56, 30.87s/it]

Train Loss: 3.411040782928467


 20%|█▉        | 49/250 [25:29<1:44:04, 31.07s/it]

Train Loss: 3.231167793273926


 20%|██        | 50/250 [26:01<1:44:19, 31.30s/it]

Train Loss: 3.13914155960083


 20%|██        | 51/250 [26:34<1:45:18, 31.75s/it]

Train Loss: 3.315136671066284


 21%|██        | 52/250 [27:05<1:44:41, 31.72s/it]

Train Loss: 3.284080743789673


 21%|██        | 53/250 [27:37<1:44:05, 31.70s/it]

Train Loss: 3.1710739135742188


 22%|██▏       | 54/250 [28:09<1:44:08, 31.88s/it]

Train Loss: 3.1664555072784424


 22%|██▏       | 55/250 [28:41<1:43:46, 31.93s/it]

Train Loss: 3.0723533630371094


 22%|██▏       | 56/250 [29:14<1:43:54, 32.14s/it]

Train Loss: 3.113518476486206


 23%|██▎       | 57/250 [29:46<1:43:33, 32.19s/it]

Train Loss: 2.8343820571899414


 23%|██▎       | 58/250 [30:19<1:43:53, 32.47s/it]

Train Loss: 3.1861753463745117


 24%|██▎       | 59/250 [30:53<1:44:12, 32.73s/it]

Train Loss: 2.976379871368408


 24%|██▍       | 60/250 [31:23<1:41:46, 32.14s/it]

Train Loss: 3.065255880355835


 24%|██▍       | 61/250 [31:56<1:41:35, 32.25s/it]

Train Loss: 3.0720670223236084


 25%|██▍       | 62/250 [32:28<1:40:50, 32.18s/it]

Train Loss: 2.7995290756225586


 25%|██▌       | 63/250 [33:00<1:40:30, 32.25s/it]

Train Loss: 2.917301893234253


 26%|██▌       | 64/250 [33:33<1:40:37, 32.46s/it]

Train Loss: 2.8905181884765625


 26%|██▌       | 65/250 [34:06<1:40:07, 32.47s/it]

Train Loss: 2.8617348670959473


 26%|██▋       | 66/250 [34:38<1:39:18, 32.38s/it]

Train Loss: 2.9991891384124756


 27%|██▋       | 67/250 [35:10<1:38:22, 32.25s/it]

Train Loss: 2.7494521141052246


 27%|██▋       | 68/250 [35:42<1:37:36, 32.18s/it]

Train Loss: 2.90533709526062


 28%|██▊       | 69/250 [36:14<1:36:39, 32.04s/it]

Train Loss: 2.8967907428741455


 28%|██▊       | 70/250 [36:45<1:35:39, 31.88s/it]

Train Loss: 2.866926670074463


 28%|██▊       | 71/250 [37:17<1:35:02, 31.86s/it]

Train Loss: 2.836390972137451


 29%|██▉       | 72/250 [37:49<1:34:59, 32.02s/it]

Train Loss: 2.85862398147583


 29%|██▉       | 73/250 [38:22<1:34:46, 32.13s/it]

Train Loss: 2.5076403617858887


 30%|██▉       | 74/250 [38:54<1:34:14, 32.13s/it]

Train Loss: 2.931793689727783


 30%|███       | 75/250 [39:26<1:33:59, 32.23s/it]

Train Loss: 2.7575552463531494


 30%|███       | 76/250 [39:59<1:33:44, 32.32s/it]

Train Loss: 2.8047196865081787


 31%|███       | 77/250 [40:31<1:33:18, 32.36s/it]

Train Loss: 2.7385048866271973


 31%|███       | 78/250 [41:04<1:32:38, 32.32s/it]

Train Loss: 2.5961692333221436


 32%|███▏      | 79/250 [41:35<1:31:43, 32.19s/it]

Train Loss: 2.6907615661621094


 32%|███▏      | 80/250 [42:07<1:30:59, 32.11s/it]

Train Loss: 2.7227864265441895


 32%|███▏      | 81/250 [42:39<1:29:56, 31.93s/it]

Train Loss: 2.743349313735962


 33%|███▎      | 82/250 [43:11<1:29:42, 32.04s/it]

Train Loss: 2.73982834815979


 33%|███▎      | 83/250 [43:42<1:28:19, 31.73s/it]

Train Loss: 2.6048011779785156


 34%|███▎      | 84/250 [44:14<1:28:10, 31.87s/it]

Train Loss: 2.677286148071289


 34%|███▍      | 85/250 [44:46<1:27:14, 31.72s/it]

Train Loss: 2.6447112560272217


 34%|███▍      | 86/250 [45:18<1:27:01, 31.84s/it]

Train Loss: 2.5557456016540527


 35%|███▍      | 87/250 [45:49<1:26:04, 31.68s/it]

Train Loss: 2.641582489013672


 35%|███▌      | 88/250 [46:20<1:25:00, 31.48s/it]

Train Loss: 2.6647748947143555


 36%|███▌      | 89/250 [46:52<1:24:25, 31.46s/it]

Train Loss: 2.713804244995117


 36%|███▌      | 90/250 [47:24<1:24:23, 31.65s/it]

Train Loss: 2.7313854694366455


 36%|███▋      | 91/250 [47:56<1:24:13, 31.78s/it]

Train Loss: 2.5104103088378906


 37%|███▋      | 92/250 [48:28<1:23:40, 31.77s/it]

Train Loss: 2.6287195682525635


 37%|███▋      | 93/250 [49:00<1:23:16, 31.82s/it]

Train Loss: 2.589596748352051


 38%|███▊      | 94/250 [49:31<1:22:25, 31.70s/it]

Train Loss: 2.6537177562713623


 38%|███▊      | 95/250 [50:03<1:21:58, 31.73s/it]

Train Loss: 2.460665702819824


 38%|███▊      | 96/250 [50:35<1:21:27, 31.74s/it]

Train Loss: 2.6272408962249756


 39%|███▉      | 97/250 [51:06<1:20:59, 31.76s/it]

Train Loss: 2.611649751663208


 39%|███▉      | 98/250 [51:37<1:19:41, 31.45s/it]

Train Loss: 2.44063663482666


 40%|███▉      | 99/250 [52:08<1:19:05, 31.43s/it]

Train Loss: 2.4935708045959473


 40%|████      | 100/250 [52:39<1:18:02, 31.21s/it]

Train Loss: 2.409193515777588


 40%|████      | 101/250 [53:10<1:17:08, 31.06s/it]

Train Loss: 2.5296080112457275


 41%|████      | 102/250 [53:40<1:16:11, 30.89s/it]

Train Loss: 2.4958877563476562


 41%|████      | 103/250 [54:12<1:15:54, 30.98s/it]

Train Loss: 2.590679168701172


 42%|████▏     | 104/250 [54:43<1:16:02, 31.25s/it]

Train Loss: 2.6269984245300293


 42%|████▏     | 105/250 [55:15<1:15:31, 31.25s/it]

Train Loss: 2.5811498165130615


 42%|████▏     | 106/250 [55:46<1:14:57, 31.23s/it]

Train Loss: 2.6402740478515625


 43%|████▎     | 107/250 [56:18<1:14:46, 31.37s/it]

Train Loss: 2.4228885173797607


 43%|████▎     | 108/250 [56:48<1:13:50, 31.20s/it]

Train Loss: 2.5843892097473145


 44%|████▎     | 109/250 [57:20<1:13:48, 31.41s/it]

Train Loss: 2.456522226333618


 44%|████▍     | 110/250 [57:52<1:13:13, 31.39s/it]

Train Loss: 2.5367212295532227


 44%|████▍     | 111/250 [58:23<1:13:01, 31.52s/it]

Train Loss: 2.433985471725464


 45%|████▍     | 112/250 [58:54<1:11:43, 31.19s/it]

Train Loss: 2.5998218059539795


 45%|████▌     | 113/250 [59:25<1:11:03, 31.12s/it]

Train Loss: 2.5738017559051514


 46%|████▌     | 114/250 [59:55<1:10:00, 30.89s/it]

Train Loss: 2.423733711242676


 46%|████▌     | 115/250 [1:00:26<1:09:27, 30.87s/it]

Train Loss: 2.607238531112671


 46%|████▋     | 116/250 [1:00:57<1:08:47, 30.80s/it]

Train Loss: 2.523813486099243


 47%|████▋     | 117/250 [1:01:27<1:07:55, 30.65s/it]

Train Loss: 2.329479455947876


 47%|████▋     | 118/250 [1:01:57<1:07:13, 30.56s/it]

Train Loss: 2.5054357051849365


 48%|████▊     | 119/250 [1:02:28<1:07:05, 30.73s/it]

Train Loss: 2.51655650138855


 48%|████▊     | 120/250 [1:02:59<1:06:13, 30.57s/it]

Train Loss: 2.380293846130371


 48%|████▊     | 121/250 [1:03:30<1:05:57, 30.68s/it]

Train Loss: 2.4686641693115234


 49%|████▉     | 122/250 [1:04:00<1:05:17, 30.60s/it]

Train Loss: 2.3209757804870605


 49%|████▉     | 123/250 [1:04:31<1:04:52, 30.65s/it]

Train Loss: 2.400566577911377


 50%|████▉     | 124/250 [1:05:02<1:04:31, 30.73s/it]

Train Loss: 2.312504291534424


 50%|█████     | 125/250 [1:05:32<1:03:54, 30.68s/it]

Train Loss: 2.4093875885009766


 50%|█████     | 126/250 [1:06:03<1:03:32, 30.75s/it]

Train Loss: 2.3602957725524902


 51%|█████     | 127/250 [1:06:34<1:03:24, 30.93s/it]

Train Loss: 2.359422206878662


 51%|█████     | 128/250 [1:07:05<1:02:58, 30.97s/it]

Train Loss: 2.5276870727539062


 52%|█████▏    | 129/250 [1:07:36<1:02:13, 30.85s/it]

Train Loss: 2.5789918899536133


 52%|█████▏    | 130/250 [1:08:07<1:01:38, 30.82s/it]

Train Loss: 2.35882568359375


 52%|█████▏    | 131/250 [1:08:37<1:01:01, 30.76s/it]

Train Loss: 2.3288216590881348


 53%|█████▎    | 132/250 [1:09:08<1:00:36, 30.82s/it]

Train Loss: 2.2998480796813965


 53%|█████▎    | 133/250 [1:09:39<59:50, 30.69s/it]  

Train Loss: 2.510432004928589


 54%|█████▎    | 134/250 [1:10:09<58:59, 30.51s/it]

Train Loss: 2.402573823928833


 54%|█████▍    | 135/250 [1:10:39<58:17, 30.42s/it]

Train Loss: 2.456371545791626


 54%|█████▍    | 136/250 [1:11:09<57:31, 30.28s/it]

Train Loss: 2.3318004608154297


 55%|█████▍    | 137/250 [1:11:40<57:09, 30.35s/it]

Train Loss: 2.559959888458252


 55%|█████▌    | 138/250 [1:12:10<56:45, 30.40s/it]

Train Loss: 2.565134048461914


 56%|█████▌    | 139/250 [1:12:41<56:46, 30.68s/it]

Train Loss: 2.296376943588257


 56%|█████▌    | 140/250 [1:13:13<56:38, 30.90s/it]

Train Loss: 2.386653423309326


 56%|█████▋    | 141/250 [1:13:44<56:15, 30.97s/it]

Train Loss: 2.3849446773529053


 57%|█████▋    | 142/250 [1:14:15<55:32, 30.86s/it]

Train Loss: 2.2451601028442383


 57%|█████▋    | 143/250 [1:14:45<54:58, 30.82s/it]

Train Loss: 2.289623498916626


 58%|█████▊    | 144/250 [1:15:16<54:12, 30.69s/it]

Train Loss: 2.3907856941223145


 58%|█████▊    | 145/250 [1:15:46<53:41, 30.68s/it]

Train Loss: 2.3045554161071777


 58%|█████▊    | 146/250 [1:16:18<53:30, 30.87s/it]

Train Loss: 2.3177802562713623


 59%|█████▉    | 147/250 [1:16:49<53:16, 31.03s/it]

Train Loss: 2.443358898162842


 59%|█████▉    | 148/250 [1:17:19<52:16, 30.75s/it]

Train Loss: 2.441258668899536


 60%|█████▉    | 149/250 [1:17:50<51:37, 30.67s/it]

Train Loss: 2.512977123260498


 60%|██████    | 150/250 [1:18:21<51:18, 30.78s/it]

Train Loss: 2.459224224090576


 60%|██████    | 151/250 [1:18:51<50:26, 30.57s/it]

Train Loss: 2.1266398429870605


 61%|██████    | 152/250 [1:19:21<49:46, 30.47s/it]

Train Loss: 2.375915765762329


 61%|██████    | 153/250 [1:19:51<49:12, 30.43s/it]

Train Loss: 2.252338409423828


 62%|██████▏   | 154/250 [1:20:21<48:28, 30.29s/it]

Train Loss: 2.1460061073303223


 62%|██████▏   | 155/250 [1:20:52<48:12, 30.45s/it]

Train Loss: 2.251965045928955


 62%|██████▏   | 156/250 [1:21:22<47:39, 30.42s/it]

Train Loss: 2.22326922416687


 63%|██████▎   | 157/250 [1:21:52<46:55, 30.28s/it]

Train Loss: 2.284468412399292


 63%|██████▎   | 158/250 [1:22:23<46:42, 30.47s/it]

Train Loss: 2.23396635055542


 64%|██████▎   | 159/250 [1:22:53<45:58, 30.32s/it]

Train Loss: 2.2709875106811523


 64%|██████▍   | 160/250 [1:23:23<45:23, 30.26s/it]

Train Loss: 2.2408230304718018


 64%|██████▍   | 161/250 [1:23:54<44:54, 30.27s/it]

Train Loss: 2.5437631607055664


 65%|██████▍   | 162/250 [1:24:24<44:22, 30.26s/it]

Train Loss: 2.2765543460845947


 65%|██████▌   | 163/250 [1:24:54<43:55, 30.30s/it]

Train Loss: 2.1473639011383057


 66%|██████▌   | 164/250 [1:25:25<43:42, 30.49s/it]

Train Loss: 2.2796924114227295


 66%|██████▌   | 165/250 [1:25:56<43:18, 30.56s/it]

Train Loss: 2.464749336242676


 66%|██████▋   | 166/250 [1:26:27<42:54, 30.65s/it]

Train Loss: 2.226282835006714


 67%|██████▋   | 167/250 [1:26:58<42:37, 30.82s/it]

Train Loss: 2.433509588241577


 67%|██████▋   | 168/250 [1:27:28<41:51, 30.63s/it]

Train Loss: 2.271265983581543


 68%|██████▊   | 169/250 [1:27:59<41:23, 30.67s/it]

Train Loss: 2.311041831970215


 68%|██████▊   | 170/250 [1:28:29<40:37, 30.46s/it]

Train Loss: 2.4015541076660156


 68%|██████▊   | 171/250 [1:29:00<40:17, 30.60s/it]

Train Loss: 2.2536227703094482


 69%|██████▉   | 172/250 [1:29:31<39:51, 30.67s/it]

Train Loss: 2.180079936981201


 69%|██████▉   | 173/250 [1:30:01<39:20, 30.66s/it]

Train Loss: 2.14896821975708


 70%|██████▉   | 174/250 [1:30:32<38:45, 30.60s/it]

Train Loss: 2.179696559906006


 70%|███████   | 175/250 [1:31:03<38:25, 30.74s/it]

Train Loss: 2.284982442855835


 70%|███████   | 176/250 [1:31:34<38:01, 30.83s/it]

Train Loss: 2.1029622554779053


 71%|███████   | 177/250 [1:32:06<37:51, 31.12s/it]

Train Loss: 2.0492265224456787


 71%|███████   | 178/250 [1:32:36<37:07, 30.94s/it]

Train Loss: 2.1655073165893555


 72%|███████▏  | 179/250 [1:33:08<36:51, 31.14s/it]

Train Loss: 2.145443916320801


 72%|███████▏  | 180/250 [1:33:39<36:29, 31.28s/it]

Train Loss: 2.6361117362976074


 72%|███████▏  | 181/250 [1:34:10<35:53, 31.20s/it]

Train Loss: 2.139732837677002


 73%|███████▎  | 182/250 [1:34:41<35:05, 30.97s/it]

Train Loss: 2.198423147201538


 73%|███████▎  | 183/250 [1:35:11<34:23, 30.80s/it]

Train Loss: 2.24401593208313


 74%|███████▎  | 184/250 [1:35:42<33:44, 30.68s/it]

Train Loss: 2.069396734237671


 74%|███████▍  | 185/250 [1:36:12<33:12, 30.66s/it]

Train Loss: 2.0441231727600098


 74%|███████▍  | 186/250 [1:36:43<32:38, 30.60s/it]

Train Loss: 2.036651849746704


 75%|███████▍  | 187/250 [1:37:13<32:06, 30.59s/it]

Train Loss: 2.105814218521118


 75%|███████▌  | 188/250 [1:37:44<31:43, 30.70s/it]

Train Loss: 2.1078715324401855


 76%|███████▌  | 189/250 [1:38:15<31:08, 30.63s/it]

Train Loss: 2.1771240234375


 76%|███████▌  | 190/250 [1:38:46<30:42, 30.71s/it]

Train Loss: 2.1337177753448486


 76%|███████▋  | 191/250 [1:39:16<30:04, 30.58s/it]

Train Loss: 1.9625576734542847


 77%|███████▋  | 192/250 [1:39:46<29:30, 30.52s/it]

Train Loss: 2.1983182430267334


 77%|███████▋  | 193/250 [1:40:16<28:51, 30.37s/it]

Train Loss: 2.4785044193267822


 78%|███████▊  | 194/250 [1:40:46<28:14, 30.25s/it]

Train Loss: 2.002462148666382


 78%|███████▊  | 195/250 [1:41:16<27:40, 30.18s/it]

Train Loss: 2.2556262016296387


 78%|███████▊  | 196/250 [1:41:47<27:10, 30.20s/it]

Train Loss: 2.34643292427063


 79%|███████▉  | 197/250 [1:42:17<26:47, 30.34s/it]

Train Loss: 2.522667407989502


 79%|███████▉  | 198/250 [1:42:49<26:32, 30.63s/it]

Train Loss: 2.169942617416382


 80%|███████▉  | 199/250 [1:43:19<26:00, 30.59s/it]

Train Loss: 2.2558305263519287


 80%|███████▉  | 199/250 [1:43:21<26:29, 31.16s/it]


KeyboardInterrupt: 
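
Per-step losses are noisy, so the downward trend is easier to read after smoothing. The sketch below applies an exponential moving average to the first five and the last five loss values from the log above, rounded to three decimals. The smoothing factor of 0.5 is an arbitrary choice for illustration, and the helper name is hypothetical.

```python
def exponential_moving_average(values, alpha=0.5):
    """EMA sequence: s[0] = v[0], s[t] = alpha * v[t] + (1 - alpha) * s[t-1]."""
    smoothed = []
    for value in values:
        if not smoothed:
            smoothed.append(value)
        else:
            smoothed.append(alpha * value + (1 - alpha) * smoothed[-1])
    return smoothed


first_losses = [5.881, 6.158, 6.319, 5.083, 5.779]
last_losses = [2.256, 2.346, 2.523, 2.170, 2.256]

print([round(x, 3) for x in exponential_moving_average(first_losses)])
print([round(x, 3) for x in exponential_moving_average(last_losses)])
```

The same function could be applied to `loss_hist` once each entry is a plain float.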

In [16]:
%%time
outputs = model.generate(
    tokenizer(['something'], return_tensors='pt')['input_ids'],
    temperature=1.0,
    do_sample=True,
    top_k=10,
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=250,
)

CPU times: user 28min 16s, sys: 7.11 s, total: 28min 23s
Wall time: 1min 57s
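
The `temperature`, `top_k` and `do_sample` arguments control how each next token is drawn. As a rough illustration of the usual recipe (not the library's actual implementation), the sketch below divides the logits by the temperature, keeps the `top_k` largest, renormalizes them with a softmax and samples one index. The function name and the toy logits are hypothetical.

```python
import math
import random


def sample_top_k(logits, top_k, temperature, rng=random):
    """Sample one index from the top_k highest logits after temperature scaling."""
    scaled = [x / temperature for x in logits]
    # Indices of the top_k largest scaled logits.
    kept = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:top_k]
    # Softmax weights restricted to the kept indices (max subtracted for stability).
    m = max(scaled[i] for i in kept)
    weights = [math.exp(scaled[i] - m) for i in kept]
    return rng.choices(kept, weights=weights, k=1)[0]


toy_logits = [2.0, 0.5, 1.0, -1.0, 3.0]
rng = random.Random(0)

# With top_k=1 the choice is deterministic: always the arg-max.
greedy = sample_top_k(toy_logits, top_k=1, temperature=1.0, rng=rng)
print(greedy)

# With top_k=3 only the three largest logits (indices 4, 0 and 2) can be drawn.
draws = {sample_top_k(toy_logits, top_k=3, temperature=0.6, rng=rng) for _ in range(200)}
print(sorted(draws))
```

Lower temperatures sharpen the distribution over the kept tokens, while a larger `top_k` lets less likely tokens through.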


In [20]:
print(tokenizer.decode(outputs[0]))

something else, it shouldnt be able to recognize the word "engineer", as it does not exist in its lexicon.
Is it possible to do this?
Thanks in advance,
Alex.

A:

The problem is that the "engineer" word isn't in your lexicon, therefore, when you run the model and the model sees "engineer", it doesn't understand it. You can try the following code to make it understand it (just replace "engineer" with the word you want):
from __future__ import print_function

from gensim import models
from sklearn.externals.joblib import Parallel, delayed
from sklearn.preprocessing import StandardScaler
from gensim.models import Word2Vec, SkipGramModel
from gensim.models.word2vec.utils import pad_sequences

import numpy as np
from scipy.sparse import csr_matrix 
from sklearn import preprocessing
from sklearn.feature_selection import f_classif
from sklearn.cross_validation import KfoldCV as sklearnCV
from gensim.models import Word2Vec, SkipGramModel
from gensim.models.word2vec.utils import pad_sequences


Try to talk with the trained model! Submit an empty input to stop the execution.


__Note__: In this example, we pass the whole dialogue as a prefix when generating each new reply. In the future, we will support a faster "interactive" dialogue mode, so that generating a new reply can reuse the inference caches from the previous one.

In [None]:
MAX_TOKENS = 16
TOP_K = 100
TEMPERATURE = 0.6
dialog = ""

while True:
    user_phrase = input()
    if len(user_phrase) == 0:
        break
    dialog += f"{user_phrase}\n-----\n"
    inputs = tokenizer([dialog], return_tensors='pt')['input_ids']
    outputs = model.generate(
        inputs,
        temperature=TEMPERATURE,
        do_sample=True,
        top_k=TOP_K,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=MAX_TOKENS,
    )
    bloom_answer = tokenizer.batch_decode(outputs)[0]
    bloom_answer = bloom_answer[len(dialog):].split("\n")[0]
    print(bloom_answer)
    dialog += f"{bloom_answer}\n-----\n"
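
The loop above can be factored into a small function that is easy to test without a model. In the sketch below, `fake_generate` is a hypothetical stand-in for tokenizing, calling `model.generate` and decoding: it returns the prompt followed by a canned continuation. The point is to show how the reply is cut out of the decoded text and how `dialog` keeps growing, which is why each turn re-processes the whole history.

```python
SEPARATOR = "\n-----\n"


def fake_generate(prompt):
    # Hypothetical stand-in for the model: echoes the prompt and appends
    # a canned reply followed by extra text that should be cut off.
    return prompt + "Nice to meet you!\n-----\nignored tail"


def chat_turn(dialog, user_phrase, generate):
    """Run one turn; return (updated dialog, extracted reply)."""
    dialog += user_phrase + SEPARATOR
    decoded = generate(dialog)
    # Drop the echoed prompt, then keep only the first generated line.
    reply = decoded[len(dialog):].split("\n")[0]
    dialog += reply + SEPARATOR
    return dialog, reply


dialog = ""
for phrase in ["Hello!", "What do you do?"]:
    dialog, reply = chat_turn(dialog, phrase, fake_generate)
    print(f"user: {phrase!r} -> bot: {reply!r} (prompt now {len(dialog)} chars)")
```

Slicing by `len(dialog)` assumes the decoded text starts with the prompt exactly as typed, which holds for this stub but may not hold for every tokenizer round trip.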