In [1]:
!pip install torch --extra-index-url https://download.pytorch.org/whl/cu113
!pip install transformers
!pip install datasets
!pip install deepspeed
!pip install mpi4py
!pip install accelerate

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu113
Collecting torch
  Downloading https://download.pytorch.org/whl/cu113/torch-1.12.1%2Bcu113-cp37-cp37m-linux_x86_64.whl (1837.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 GB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m00:01[0mm00:04[0m
Installing collected packages: torch
Successfully installed torch-1.12.1+cu113
Collecting transformers
  Using cached transformers-4.23.1-py3-none-any.whl (5.3 MB)
Collecting regex!=2019.12.17
  Downloading regex-2022.9.13-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (757 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m757.0/757.0 kB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [

Collecting pytz>=2017.3
  Using cached pytz-2022.4-py2.py3-none-any.whl (500 kB)
Installing collected packages: pytz, xxhash, pyarrow, multidict, fsspec, frozenlist, dill, asynctest, async-timeout, yarl, responses, pandas, multiprocess, aiosignal, aiohttp, datasets
Successfully installed aiohttp-3.8.3 aiosignal-1.2.0 async-timeout-4.0.2 asynctest-0.13.0 datasets-2.6.1 dill-0.3.5.1 frozenlist-1.3.1 fsspec-2022.8.2 multidict-6.0.2 multiprocess-0.70.13 pandas-1.3.5 pyarrow-9.0.0 pytz-2022.4 responses-0.18.0 xxhash-3.0.0 yarl-1.8.1
Collecting deepspeed
  Using cached deepspeed-0.7.3-py3-none-any.whl
Collecting py-cpuinfo
  Using cached py_cpuinfo-8.0.0-py3-none-any.whl
Collecting pydantic
  Downloading pydantic-1.10.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.8/11.8 MB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting hjson
  Using cached hjson-3.1.0-py3-none-any.whl (5

In [1]:
import os, re, math, random
import torch
import numpy as np

from datasets import load_dataset
from transformers import BloomTokenizerFast, BloomForCausalLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checkpoints and data are looked up in sub-directories of the current working directory.
root_path = os.getcwd()
model_root_path, data_root_path = (
    os.path.join(root_path, subdir) for subdir in ("models", "data")
)

# Use the first CUDA device when torch reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
device

'cuda:0'

In [3]:
# Checkpoints to score, one (model_name, base_model, shot_num) tuple each:
#   model_name - directory name of the checkpoint under `model_root_path`
#   base_model - "bigscience/<base_model>" is the tokenizer loaded for it
#   shot_num   - 0 scores the bare prompt; a value > 0 turns on few-shot priming
models = [
    # bloom-560m family
    ("bloom-560m", "bloom-560m", 4),
    ("finetuned-bloom-560m-002-e2", "bloom-560m", 0),
    ("finetuned-bloom-560m-004-e2-bloom-data", "bloom-560m", 0),
    # bloom-1b7 family
    ("bloom-1b7", "bloom-1b7", 4),
    ("finetuned-bloom-1b7-003-e2", "bloom-1b7", 0),
    ("finetuned-bloom-1b7-004-e2-bloom-data", "bloom-1b7", 0),
    # bloom-3b (base checkpoint only)
    ("bloom-3b", "bloom-3b", 4),
]

In [4]:
# Read the JSON-lines evaluation file from the data directory and keep its
# "train" split, which is the split this notebook reads the examples from.
data_file = "handwritten_data.jsonl"
loaded_splits = load_dataset("json", data_files=data_file, data_dir=data_root_path)
dataset = loaded_splits["train"]
dataset

Using custom data configuration default-eae2243bd3a8d665


Downloading and preparing dataset json/default to /home/gordon/.cache/huggingface/datasets/json/default-eae2243bd3a8d665/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab...


Downloading data files: 100%|████████████████████████████████████████████████████| 1/1 [00:00<00:00, 928.97it/s]
Extracting data files: 100%|█████████████████████████████████████████████████████| 1/1 [00:00<00:00, 456.50it/s]
                            

Dataset json downloaded and prepared to /home/gordon/.cache/huggingface/datasets/json/default-eae2243bd3a8d665/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab. Subsequent calls will reuse this data.


100%|████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 234.59it/s]


Dataset({
    features: ['email', 'summary'],
    num_rows: 167
})

In [5]:
# Marker strings: the first closes the prompt (email) part of an example, the
# second closes the completion (summary) part. The evaluation loop looks up
# their token ids to decide which positions contribute to the loss.
prompt_end_token, completion_end_token = "===", "END"

# NOTE(review): batch_size is not referenced by the evaluation loop in this
# part of the notebook, which scores one example per forward pass.
batch_size = 1

In [6]:
FEW_SHOT_SEED = 0  # fixed seed so the randomly drawn few-shot primers are reproducible
NUM_PRIMERS = 2    # number of worked examples placed in front of each few-shot prompt


def _few_shot_texts(texts, shot_num, rng, num_primers=NUM_PRIMERS):
    """Prefix every example with `num_primers` other examples drawn at random.

    Examples whose index is below ``len(texts) / shot_num`` draw their primers
    from the examples after them; all others draw from the examples before
    them, so an example is never used as its own primer.

    NOTE(review): `shot_num` only sets that split point. The number of primers
    is `num_primers` (2), even though the model list passes shot_num=4 —
    confirm whether 4 primers were intended.

    Args:
        texts: list of fully formatted example strings.
        shot_num: positive number used as the divisor for the split point.
        rng: a `random.Random` instance used for sampling.
        num_primers: how many primer examples to prepend.

    Returns:
        A new list, same length as `texts`, of primers and example joined by
        blank lines.

    Raises:
        ValueError: if the candidate pool holds fewer than `num_primers` items.
    """
    primed = []
    for i, text in enumerate(texts):
        pool = texts[i + 1:] if i < len(texts) / shot_num else texts[:i]
        primers = rng.sample(pool, num_primers)
        primed.append("\n\n".join(primers + [text]))
    return primed


def _completion_loss(model, tokenizer, text, start_token_id, end_token_id, device):
    """Return the model's loss on the completion part of one example.

    Only the tokens strictly between the last occurrence of `start_token_id`
    and the last occurrence of `end_token_id` keep their labels; every other
    position is set to -100, the label value the Hugging Face loss ignores.

    Raises:
        ValueError: if either marker token id does not occur in `text`.
    """
    input_ids = tokenizer.encode(text, return_tensors="pt")
    label_ids = input_ids.clone()

    start_positions = (input_ids[0] == start_token_id).nonzero()
    end_positions = (input_ids[0] == end_token_id).nonzero()
    if start_positions.numel() == 0 or end_positions.numel() == 0:
        raise ValueError("prompt/completion end marker not found in tokenized text")

    # With few-shot primers the markers occur several times; the last pair
    # belongs to the example actually being scored.
    start = int(start_positions.max())
    end = int(end_positions.max())
    label_ids[0, :start + 1] = -100
    label_ids[0, end:] = -100

    # Pure evaluation: skip building the autograd graph.
    with torch.no_grad():
        res = model(input_ids.to(device), labels=label_ids.to(device))
    return res.loss.item()


# Every model is scored on the same "<email> === <summary> END" strings, so
# they are built once, from the marker constants rather than repeated literals.
prompt_separator = "\n\n" + prompt_end_token + "\n\n"
completion_suffix = "\n" + completion_end_token
base_texts = [
    email + prompt_separator + summary + completion_suffix
    for email, summary in zip(dataset["email"], dataset["summary"])
]

for model_name, base_model, shot_num in models:
    few_shot = shot_num > 0
    model_path = os.path.join(model_root_path, model_name)

    # Drop the previous checkpoint before loading the next one so two models
    # are never held on the device at the same time.
    model = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    model = BloomForCausalLM.from_pretrained(model_path).to(device)
    model.eval()

    tokenizer = BloomTokenizerFast.from_pretrained(f"bigscience/{base_model}")
    # NOTE(review): only the first token id of each marker is used; this
    # assumes "===" and "END" each encode to a single token — confirm.
    prompt_end_token_id = tokenizer.encode(prompt_end_token)[0]
    completion_end_token_id = tokenizer.encode(completion_end_token)[0]
    print(model_name, shot_num)

    if few_shot:
        # A fresh, identically seeded generator per model gives every
        # few-shot model exactly the same primers.
        text_data = _few_shot_texts(base_texts, shot_num, random.Random(FEW_SHOT_SEED))
    else:
        text_data = base_texts

    loss = 0.0
    for text in text_data:
        loss += _completion_loss(
            model, tokenizer, text, prompt_end_token_id, completion_end_token_id, device
        )

    # Unweighted mean of the per-example losses.
    loss = loss / len(text_data)
    print("loss:", loss)

bloom-560m 4
loss: 2.8537353801870062
finetuned-bloom-560m-002-e2 0
loss: 1.124883704329597
finetuned-bloom-560m-004-e2-bloom-data 0
loss: 1.1768863393040165
bloom-1b7 4
loss: 2.700445728186897
finetuned-bloom-1b7-003-e2 0
loss: 1.1033410306920264
finetuned-bloom-1b7-004-e2-bloom-data 0
loss: 1.1443349683490334
bloom-3b 4
loss: 2.1599639508494004
