In [1]:
from transformers import AutoModelForCausalLM
# import Prompt Tuning Config class 
from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType
import torch
from datasets import load_dataset
import os
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from transformers import default_data_collator, get_linear_schedule_with_warmup
from tqdm import tqdm
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm





## Prompt Tuning 參數設定

In [11]:
device = "cuda"

model_name_or_path = "bigscience/bloomz-560m"
tokenizer_name_or_path = "bigscience/bloomz-560m"

# 建立微調方法
peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM, # 指定任務類型， e.g. 條件生成任務（SEQ_2_SEQ_LM），因果語言建模（CAUSAL_LM）等...
    prompt_tuning_init=PromptTuningInit.TEXT, # 提示嵌入的初始方法，PEFT支持文本（TEXT）和随機（RANDOM）初始化
    num_virtual_tokens=8, # 指定虛擬 Token 數量
    prompt_tuning_init_text="Classify if the tweet is a complaint or not:", # 用於初始化提示嵌入的文本，在使用文本（TEXT）初始化方法時使用
    tokenizer_name_or_path=model_name_or_path,
)

dataset_name = "twitter_complaints"

text_column = "Tweet text"
label_column = "text_label"
max_length = 64
lr = 3e-2
num_epochs = 10
batch_size = 8

In [3]:
peft_config

PromptTuningConfig(peft_type=<PeftType.PROMPT_TUNING: 'PROMPT_TUNING'>, auto_mapping=None, base_model_name_or_path=None, revision=None, task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, num_virtual_tokens=8, token_dim=None, num_transformer_submodules=None, num_attention_heads=None, num_layers=None, prompt_tuning_init=<PromptTuningInit.TEXT: 'TEXT'>, prompt_tuning_init_text='Classify if the tweet is a complaint or not:', tokenizer_name_or_path='/data/nfs/llm/model/bloomz-560m', tokenizer_kwargs=None)

## 讀取資料
- 使用 PyTorch datasets

In [4]:
from datasets import load_dataset

dataset = load_dataset("ought/raft", dataset_name)
# dataset = load_dataset("/home/guodong.li/data/peft/raft/raft.py", dataset_name, cache_dir="/home/guodong.li/data/peft/data")

classes = [k.replace("_", " ") for k in dataset["train"].features["Label"].names]
print(classes)

dataset = dataset.map(
    lambda x: {"text_label": [classes[label] for label in x["Label"]]},
    batched=True,
    num_proc=1,
)
print(dataset)

dataset["train"][0]

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


['Unlabeled', 'complaint', 'no complaint']


Map: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 4331.88 examples/s]
Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3399/3399 [00:00<00:00, 208937.60 examples/s]

DatasetDict({
    train: Dataset({
        features: ['Tweet text', 'ID', 'Label', 'text_label'],
        num_rows: 50
    })
    test: Dataset({
        features: ['Tweet text', 'ID', 'Label', 'text_label'],
        num_rows: 3399
    })
})





{'Tweet text': '@HMRCcustomers No this is my first job',
 'ID': 0,
 'Label': 2,
 'text_label': 'no complaint'}

In [5]:
dataset

DatasetDict({
    train: Dataset({
        features: ['Tweet text', 'ID', 'Label', 'text_label'],
        num_rows: 50
    })
    test: Dataset({
        features: ['Tweet text', 'ID', 'Label', 'text_label'],
        num_rows: 3399
    })
})

In [6]:
dataset["train"]

Dataset({
    features: ['Tweet text', 'ID', 'Label', 'text_label'],
    num_rows: 50
})

In [7]:
dataset["train"].features

{'Tweet text': Value(dtype='string', id=None),
 'ID': Value(dtype='int32', id=None),
 'Label': ClassLabel(names=['Unlabeled', 'complaint', 'no complaint'], id=None),
 'text_label': Value(dtype='string', id=None)}

In [8]:
dataset["train"].features["Label"]

ClassLabel(names=['Unlabeled', 'complaint', 'no complaint'], id=None)

In [9]:
dataset["train"].features["Label"].names

['Unlabeled', 'complaint', 'no complaint']

In [22]:
text_column

'Tweet text'

In [24]:
dir(dataset["train"].features[text_column])

['__annotations__',
 '__call__',
 '__class__',
 '__dataclass_fields__',
 '__dataclass_params__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__post_init__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__slotnames__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_type',
 'dtype',
 'encode_example',
 'id',
 'pa_type']

In [29]:
dataset["train"].features[text_column]

Value(dtype='string', id=None)

In [28]:
dataset["train"].features[text_column].pa_type

DataType(string)

In [31]:
dataset["train"][0][text_column]

'@HMRCcustomers No this is my first job'

In [32]:
dataset["train"][0][label_column]

'no complaint'

## 資料前處理

In [39]:
""" Data format
DatasetDict({
    train: Dataset({
        features: ['Tweet text', 'ID', 'Label', 'text_label'],
        num_rows: 50
    })
    test: Dataset({
        features: ['Tweet text', 'ID', 'Label', 'text_label'],
        num_rows: 3399
    })
})
"""

text_column = "Tweet text"
label_column = "text_label"
max_length = 64
lr = 3e-2
num_epochs = 10
batch_size = 8

# data preprocessing
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

target_max_length = max([len(tokenizer(class_label)["input_ids"]) for class_label in classes])
# 在 class label: ['Unlabeled', 'complaint', 'no complaint'] 當中，經過轉換後所需的 token 長度
print("target_max_length:", target_max_length)


# 預處理模型的 input
def preprocess_function(examples):
    batch_size = len(examples[text_column]) # text_column: 'Tweet text'
    print("batch_size:", batch_size)
    
    # 處理 input 資料格式
    ## [""" 'Tweet text': '@HMRCcustomers No this is my first job' Label : """]
    inputs = [f"{text_column} : {x} Label : " for x in examples[text_column]]
    # label_column: 'text_label', x 是 ['Unlabeled', 'complaint', 'no complaint'] 其中一個
    targets = [str(x) for x in examples[label_column]]
    
    model_inputs = tokenizer(inputs) # 段詞
    labels = tokenizer(targets) # 轉成 token
    
    # 建立 輸入資料以及 fine-tuning 模型所需要的資訊
    for i in range(batch_size):
        sample_input_ids = model_inputs["input_ids"][i]
        label_input_ids = labels["input_ids"][i] + [tokenizer.pad_token_id] # tokenizer.pad_token_id = 3
        if i == 0:
            print(i, sample_input_ids, label_input_ids)
        model_inputs["input_ids"][i] = sample_input_ids + label_input_ids
        labels["input_ids"][i] = [-100] * len(sample_input_ids) + label_input_ids
        model_inputs["attention_mask"][i] = [1] * len(model_inputs["input_ids"][i])
    print(model_inputs)
    
    # 進一步處理成符合模型需求的格式
    for i in range(batch_size):
        sample_input_ids = model_inputs["input_ids"][i]
        label_input_ids = labels["input_ids"][i]
        
        model_inputs["input_ids"][i] = [tokenizer.pad_token_id] * (max_length - len(sample_input_ids)) + sample_input_ids # Pre-padding: 後面內容較為重要， e.g. tensor([ 3, 3, 915, 210, 1936, 106863, 3])
        model_inputs["attention_mask"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs["attention_mask"][i] # 告知模型 attention mask, e.g. tensor([0, 0, 0, 1, 1, 1])
        labels["input_ids"][i] = [-100] * (max_length - len(sample_input_ids)) + label_input_ids
        
        # max_length: 64
        model_inputs["input_ids"][i] = torch.tensor(model_inputs["input_ids"][i][:max_length])
        model_inputs["attention_mask"][i] = torch.tensor(model_inputs["attention_mask"][i][:max_length])
        labels["input_ids"][i] = torch.tensor(labels["input_ids"][i][:max_length])
        if i == 0:
            print("model_inputs input_ids:", model_inputs["input_ids"][i])
            print("model_inputs attention_mask:", model_inputs["attention_mask"][i])
            print("labels input_ids:", labels["input_ids"][i])
        
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


print("column_names:", dataset["train"].column_names)

# 將原始資料集和測試資料集同時預處理，然後作為訓練和評估數據集
processed_datasets = dataset.map(
    preprocess_function,
    batched=True,
    num_proc=1,
    remove_columns=dataset["train"].column_names,
    load_from_cache_file=False,
    desc="Running tokenizer on dataset",
)

train_dataset = processed_datasets["train"]
eval_dataset = processed_datasets["train"]

# 訓練與評估使用同一份數據，但是訓練數據打亂
train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)
eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)
print(len(train_dataloader))
print(len(eval_dataloader))

target_max_length: 3
column_names: ['Tweet text', 'ID', 'Label', 'text_label']


Running tokenizer on dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 1282.21 examples/s]

batch_size: 50
0 [227985, 5484, 915, 2566, 169403, 15296, 36272, 525, 3928, 1119, 632, 2670, 3968, 15270, 77658, 915, 210] [1936, 106863, 3]
{'input_ids': [[227985, 5484, 915, 2566, 169403, 15296, 36272, 525, 3928, 1119, 632, 2670, 3968, 15270, 77658, 915, 210, 1936, 106863, 3], [227985, 5484, 915, 2566, 88653, 2321, 144017, 138861, 59283, 1152, 613, 2632, 12120, 4, 5673, 1152, 32153, 427, 36992, 15, 1152, 1400, 5065, 114438, 66455, 919, 404, 146304, 14078, 87856, 7061, 2906, 17, 77658, 915, 210, 1936, 106863, 3], [227985, 5484, 915, 5673, 473, 11229, 2213, 2670, 35307, 28629, 461, 2566, 2765, 1531, 3470, 47134, 10144, 2765, 1531, 427, 2909, 17918, 6782, 27268, 4390, 1517, 17, 3904, 632, 267, 6497, 483, 361, 2670, 101848, 17, 32465, 9585, 2566, 37, 2481, 2566, 37, 2481, 12384, 77658, 915, 210, 16449, 5952, 3], [227985, 5484, 915, 2566, 15157, 4867, 14731, 165189, 2021, 769, 11528, 7220, 35025, 530, 27937, 149533, 1965, 43435, 163255, 1141, 3611, 17, 30655, 632, 1119, 17, 77658, 915, 21


Running tokenizer on dataset:  29%|█████████████████████████████████████▋                                                                                          | 1000/3399 [00:00<00:00, 8728.99 examples/s]

batch_size: 1000
0 [227985, 5484, 915, 2566, 74757, 64626, 12384, 44639, 613, 52282, 2670, 79920, 3344, 1002, 368, 17646, 14472, 8348, 664, 718, 4, 19036, 17, 31849, 17, 6312, 76, 44, 62470, 56, 91, 50, 14839, 21, 77658, 915, 210] [3074, 4762, 60943, 3]
{'input_ids': [[227985, 5484, 915, 2566, 74757, 64626, 12384, 44639, 613, 52282, 2670, 79920, 3344, 1002, 368, 17646, 14472, 8348, 664, 718, 4, 19036, 17, 31849, 17, 6312, 76, 44, 62470, 56, 91, 50, 14839, 21, 77658, 915, 210, 3074, 4762, 60943, 3], [227985, 5484, 915, 405, 187059, 2256, 664, 2550, 18833, 18607, 162467, 4, 1387, 6199, 3291, 23405, 613, 4657, 17082, 566, 3432, 368, 78851, 1185, 61273, 23181, 1553, 15596, 212, 116057, 77658, 915, 210, 3074, 4762, 60943, 3], [227985, 5484, 915, 39762, 2566, 22253, 6201, 75701, 15, 632, 718, 5840, 10006, 6201, 18881, 427, 3804, 19528, 267, 158974, 1320, 368, 10029, 632, 49666, 92, 34, 77658, 915, 210, 3074, 4762, 60943, 3], [227985, 5484, 915, 2566, 104565, 8695, 2089, 6140, 109676, 99579, 

Running tokenizer on dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3399/3399 [00:00<00:00, 8490.51 examples/s]

batch_size: 1000
0 [227985, 5484, 915, 2566, 99198, 53312, 2566, 99198, 53312, 7064, 1074, 1800, 138435, 17, 77658, 915, 210] [3074, 4762, 60943, 3]
{'input_ids': [[227985, 5484, 915, 2566, 99198, 53312, 2566, 99198, 53312, 7064, 1074, 1800, 138435, 17, 77658, 915, 210, 3074, 4762, 60943, 3], [227985, 5484, 915, 2566, 64228, 2309, 18584, 3595, 361, 368, 8876, 17, 47411, 6281, 361, 158974, 64787, 361, 8431, 61970, 17, 53003, 19168, 4, 4020, 2782, 267, 13473, 613, 660, 6281, 3269, 1119, 34, 77658, 915, 210, 3074, 4762, 60943, 3], [227985, 5484, 915, 2566, 96186, 29756, 351, 5568, 42696, 4472, 2782, 632, 267, 2550, 47490, 1199, 361, 2550, 47, 120856, 2550, 102289, 72697, 4020, 718, 99433, 3390, 46898, 661, 6355, 34, 77658, 915, 210, 3074, 4762, 60943, 3], [227985, 5484, 915, 2550, 1339, 19336, 11257, 1306, 1152, 1130, 664, 368, 5025, 5746, 661, 368, 57880, 14519, 525, 53078, 2525, 42696, 47258, 1074, 201363, 4, 2566, 1339, 19336, 5231, 262, 77658, 915, 210, 3074, 4762, 60943, 3], [227985,




In [37]:
processed_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 50
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 3399
    })
})

In [38]:
processed_datasets["train"]

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 50
})

In [36]:
dataset["train"].column_names

['Tweet text', 'ID', 'Label', 'text_label']

In [35]:
print("column_names:", dataset["train"].column_names)

column_names: ['Tweet text', 'ID', 'Label', 'text_label']


In [13]:
tokenizer

BloomTokenizerFast(name_or_path='bigscience/bloomz-560m', vocab_size=250680, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [14]:
tokenizer.pad_token_id

3

In [15]:
tokenizer.eos_token_id

2

In [None]:
target_max_length = max([len(tokenizer(class_label)["input_ids"]) for class_label in classes])

In [17]:
[ class_label for class_label in classes]

['Unlabeled', 'complaint', 'no complaint']

In [18]:
[ tokenizer(class_label) for class_label in classes]

[{'input_ids': [3074, 4762, 60943], 'attention_mask': [1, 1, 1]},
 {'input_ids': [16449, 5952], 'attention_mask': [1, 1]},
 {'input_ids': [1936, 106863], 'attention_mask': [1, 1]}]

In [19]:
[ tokenizer(class_label)["input_ids"] for class_label in classes]

[[3074, 4762, 60943], [16449, 5952], [1936, 106863]]

In [20]:
max([len(tokenizer(class_label)["input_ids"]) for class_label in classes])

3

## 測試集前處理

In [40]:
def test_preprocess_function(examples):
    batch_size = len(examples[text_column])
    inputs = [f"{text_column} : {x} Label : " for x in examples[text_column]]
    model_inputs = tokenizer(inputs)
    print(model_inputs)
    for i in range(batch_size):
        sample_input_ids = model_inputs["input_ids"][i]
        
        model_inputs["input_ids"][i] = [tokenizer.pad_token_id] * ( max_length - len(sample_input_ids)) + sample_input_ids
        model_inputs["attention_mask"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs["attention_mask"][i]
        
        model_inputs["input_ids"][i] = torch.tensor(model_inputs["input_ids"][i][:max_length])
        model_inputs["attention_mask"][i] = torch.tensor(model_inputs["attention_mask"][i][:max_length])
    return model_inputs

# 將原始的數據用於測試
test_dataset = dataset["test"].map(
    test_preprocess_function,
    batched=True,
    num_proc=1,
    remove_columns=dataset["train"].column_names,
    load_from_cache_file=False,
    desc="Running tokenizer on dataset",
)

test_dataloader = DataLoader(test_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)
next(iter(test_dataloader))

Running tokenizer on dataset:  59%|██████████████████████████████████████████████████████████████████████████▋                                                    | 2000/3399 [00:00<00:00, 13757.77 examples/s]

{'input_ids': [[227985, 5484, 915, 2566, 74757, 64626, 12384, 44639, 613, 52282, 2670, 79920, 3344, 1002, 368, 17646, 14472, 8348, 664, 718, 4, 19036, 17, 31849, 17, 6312, 76, 44, 62470, 56, 91, 50, 14839, 21, 77658, 915, 210], [227985, 5484, 915, 405, 187059, 2256, 664, 2550, 18833, 18607, 162467, 4, 1387, 6199, 3291, 23405, 613, 4657, 17082, 566, 3432, 368, 78851, 1185, 61273, 23181, 1553, 15596, 212, 116057, 77658, 915, 210], [227985, 5484, 915, 39762, 2566, 22253, 6201, 75701, 15, 632, 718, 5840, 10006, 6201, 18881, 427, 3804, 19528, 267, 158974, 1320, 368, 10029, 632, 49666, 92, 34, 77658, 915, 210], [227985, 5484, 915, 2566, 104565, 8695, 2089, 6140, 109676, 99579, 1369, 512, 368, 4570, 54, 632, 368, 1503, 241485, 132226, 15, 982, 727, 1152, 18100, 861, 32596, 77597, 168154, 1306, 132226, 4346, 87843, 17, 130462, 364, 32923, 89, 53, 8309, 20, 75, 77658, 915, 210], [227985, 5484, 915, 2566, 14173, 2960, 29906, 387, 20706, 49337, 1369, 77658, 915, 210], [227985, 5484, 915, 2566, 21

Running tokenizer on dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3399/3399 [00:00<00:00, 13493.75 examples/s]


{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
               3,      3,      3,      3,      3,      3,      3,      3,      3,
               3,      3,      3,      3,      3,      3,      3,      3,      3,
          227985,   5484,    915,   2566,  74757,  64626,  12384,  44639,    613,
           52282,   2670,  79920,   3344,   1002,    368,  17646,  14472,   8348,
             664,    718,      4,  19036,     17,  31849,     17,   6312,     76,
              44,  62470,     56,     91,     50,  14839,     21,  77658,    915,
             210],
         [     3,      3,      3,      3,      3,      3,      3,      3,      3,
               3,      3,      3,      3,      3,      3,      3,      3,      3,
               3,      3,      3,      3,      3,      3,      3,      3,      3,
               3,      3,      3,      3, 227985,   5484,    915,    405, 187059,
            2256,    664,   2550,  18833,  18607, 162467,      4, 

## 預訓模型

### 預訓練模型

In [41]:
model_name_or_path

'bigscience/bloomz-560m'

In [48]:
# creating model
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
model

BloomForCausalLM(
  (transformer): BloomModel(
    (word_embeddings): Embedding(250880, 1024)
    (word_embeddings_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (h): ModuleList(
      (0-23): 24 x BloomBlock(
        (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (self_attention): BloomAttention(
          (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)
          (dense): Linear(in_features=1024, out_features=1024, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): BloomMLP(
          (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)
          (gelu_impl): BloomGelu()
          (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (

### 加入Prompt Tuning

In [49]:
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

trainable params: 8,192 || all params: 559,222,784 || trainable%: 0.0014648902430985358


In [51]:
# 只在輸入層加入 prompt virtual tokens
model

PeftModelForCausalLM(
  (base_model): BloomForCausalLM(
    (transformer): BloomModel(
      (word_embeddings): Embedding(250880, 1024)
      (word_embeddings_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (h): ModuleList(
        (0-23): 24 x BloomBlock(
          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)
          )
        )
      

In [None]:
"""
PromptEmbedding source code: https://github.com/huggingface/peft/blob/v0.8.2/src/peft/tuners/prompt_tuning/model.py#L22

class PromptEmbedding(torch.nn.Module):
    def __init__(self, config, word_embeddings):
        super().__init__()

        total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
        # 初始化 embedding 層
        self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim)
        
        # 如果使用文本進行初始化，執行如下邏輯，PromptTuningConfig 配置 class 需要傳入初始化文本
        if config.prompt_tuning_init == PromptTuningInit.TEXT:
            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)
            init_text = config.prompt_tuning_init_text
            init_token_ids = tokenizer(init_text)["input_ids"]
            # Trim or iterate until num_text_tokens matches total_virtual_tokens
            num_text_tokens = len(init_token_ids)
            if num_text_tokens > total_virtual_tokens:
                init_token_ids = init_token_ids[:total_virtual_tokens]
            elif num_text_tokens < total_virtual_tokens:
                num_reps = math.ceil(total_virtual_tokens / num_text_tokens)
                init_token_ids = init_token_ids * num_reps
            init_token_ids = init_token_ids[:total_virtual_tokens]

            word_embedding_weights = word_embeddings(torch.LongTensor(init_token_ids)).detach().clone()
            word_embedding_weights = word_embedding_weights.to(torch.float32)
            # 初始化embedding層的權重
            self.embedding.weight = torch.nn.Parameter(word_embedding_weights)

    def forward(self, indices):
        # Just get embeddings
        prompt_embeddings = self.embedding(indices)
        return prompt_embeddings
"""

### model config
- optimizer
- lr scheduler

In [53]:
lr

0.03

In [54]:
# model
# optimizer and lr scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) * num_epochs),
)

### 訓練模型

In [55]:
device

'cuda'

In [56]:
# training and evaluation
from tqdm import tqdm
model = model.to(device)

for epoch in tqdm(range(num_epochs)):
    model.train()
    total_loss = 0
    for step, batch in enumerate(tqdm(train_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        print(batch)
        print(batch["input_ids"].shape)
        outputs = model(**batch)
        loss = outputs.loss
        total_loss += loss.detach().float()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    model.eval()
    eval_loss = 0
    eval_preds = []
    for step, batch in enumerate(tqdm(eval_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        loss = outputs.loss
        eval_loss += loss.detach().float()
        eval_preds.extend(
            tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)
        )

    eval_epoch_loss = eval_loss / len(eval_dataloader)
    eval_ppl = torch.exp(eval_epoch_loss)
    train_epoch_loss = total_loss / len(train_dataloader)
    train_ppl = torch.exp(train_epoch_loss)
    print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}")

  0%|                                                                                                                                                                                     | 0/7 [00:00<?, ?it/s]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3, 227985,
           5484,    915,   2566, 182441,     55,  35040,    435,  16796,  79920,
            427,    661,   6355,     17,   5568,   2213,   3172,    960, 126355,
           6216,   5559,  61273,  23181,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3, 227985,   5484,
            915,   2566,  69408,  73736,   1400,    473,   2213,    267, 131388,
          17817,   9781, 158974,   3262,    718,  35752,   2496,   1336,  209

 14%|████████████████████████▋                                                                                                                                                    | 1/7 [00:21<02:11, 21.97s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3, 227985,
           5484,    915, 191971,    261,  74182,      4, 226928,    427,  25608,
           1002,  39839,    473,   4472,   2550,     41,  86461,   4352,  19821,
           2550, 238683,  13663,  87843,     17,   1594,  14663,     36,     52,
             44,   6794,     81,  87781,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   2566,
          15157,   4867,  14731, 165189,   2021,    769,  11528,   7220,  350

 29%|█████████████████████████████████████████████████▍                                                                                                                           | 2/7 [00:47<02:00, 24.06s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3, 227985,   5484,    915,   2566,  14173,   2960,  29906,    387,
          73303,    473,   9283,  11257,    368,  64129,    361,  11571,    461,
            490,   4283,  40067,   1620,   1130,   1186,  14881,   1002,     75,
           1728,    368,  63049,     17,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
         227985,   5484,    915,   1387,  18688,    632,  31139,   3478,     

 43%|██████████████████████████████████████████████████████████████████████████▏                                                                                                  | 3/7 [01:14<01:41, 25.34s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   2566,
          81048,  44166,  55675,    473,  19134,   1152,   1965,   3262,  52282,
           1074,  52787,  14685,  20425,   5926,   2971,  32564,   3509,   2550,
         242086,    290,   3143,   1317,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3, 227985,   5484,    9

 57%|██████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                          | 4/7 [01:39<01:16, 25.45s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
         227985,   5484,    915,   2566,  21714,  14571,  10215,   2566,   1339,
          19336,    575,   5110,   1074,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   25

 71%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                 | 5/7 [02:04<00:50, 25.26s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   2566,
          96186,  29756,    351,    473,   1542,    654,   9322,    530,    368,
          21851,    632,   6644,    530,  48132,     17,   6728,   1152,    727,
           7747,   3638,   1119,     34,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,     

 86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                        | 6/7 [02:29<00:25, 25.11s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
         227985,   5484,    915,   6782,    297,  12245,  11246,   1002,   2550,
          32375,  27516,  67121,   9512,  87843,     17,   1594,  55189,  58465,
         220018,  58251,    667,     40,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
         227985,   5484,    915,   2566, 154507,     66, 104833,   2566, 132893,
          26569,   3678,   3597,   2566,     57,   4492, 208259,    770,  50224,
            267,  10512,  18453,   1776,  12480,  56006,   2221,  32544,    5

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:45<00:00, 23.58s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:04<00:00, 17.79s/it]


epoch=0: train_ppl=tensor(7.6856e+15, device='cuda:0') train_epoch_loss=tensor(36.5781, device='cuda:0') eval_ppl=tensor(9512.2666, device='cuda:0') eval_epoch_loss=tensor(9.1603, device='cuda:0')


  0%|                                                                                                                                                                                     | 0/7 [00:00<?, ?it/s]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3, 227985,   5484,
            915,   2566,     61,  31311,   6640,   1935,  15527,   3784,  46823,
            664,    267,  57502,    427,   2670, 148307,    530,    524,  23099,
            613,  15226,   5840,     34,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3, 227985,   5484,
            915,   2566,  69408,  73736,   1400,    473,   2213,    267, 131388,
          17817,   9781, 158974,   3262,    718,  35752,   2496,   1336,  209

 14%|████████████████████████▋                                                                                                                                                    | 1/7 [00:23<02:19, 23.18s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   2566,
         112229,   9107,  53312,   3262,   1306,   1152, 157816,   2084,  44326,
          40006,    613,  27019,  16680,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3, 227985,   5484,    915,  43504,  51648,   5456,
         176088,     34,  57180,  28627,    269,  13041,   2566, 114242,    672,
           2338,  56114,    427,   4054,    530, 182640,   7963,    427,  19134,
            718,      4,   2550,  55061,  17209,   2550, 114242,    672,   97

 29%|█████████████████████████████████████████████████▍                                                                                                                           | 2/7 [00:48<02:01, 24.32s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3, 227985,
           5484,    915,   2566, 182441,     55,  35040,    435,  16796,  79920,
            427,    661,   6355,     17,   5568,   2213,   3172,    960, 126355,
           6216,   5559,  61273,  23181,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
         227985,   5484,    915,   2566, 172168,   6220,   5068,    473,   9016,
           4085,    380,  25004, 105384,    361,    272,   3049,     17,    473,
           4026,    427,   8265,  43624,   1445,   3264,  95495,     17,  699

 43%|██████████████████████████████████████████████████████████████████████████▏                                                                                                  | 3/7 [01:15<01:41, 25.47s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3, 227985,   5484,    915,   5161,  13500,    386,
              4, 122906,    415,  66027,  42431,    613,  70016,    361,   2550,
             51,  21351,    322,  17896,     15,   2550,     49,     45,   2550,
          21450,    388,    655,   2550,     75,  14263,    916,     86,   1881,
          36534,  53902,      4,   4346,  87843,     17, 130462,   8188,     23,
           5949, 187347,     78, 130589,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
         227985,   5484,    915,   1387,  18688,    632,  31139,   3478,     

 57%|██████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                          | 4/7 [01:38<01:14, 24.75s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
         227985,   5484,    915,   2566, 154507,     66, 104833,   2566, 132893,
          26569,   3678,   3597,   2566,     57,   4492, 208259,    770,  50224,
            267,  10512,  18453,   1776,  12480,  56006,   2221,  32544,    530,
         122327,   1531,  75164,    473,   9318,  25755,   6610,   1427,  20500,
              4,  61273, 164516,  23181,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3, 227985,
           5484,    915,   2566,     49,    656,    266,  53312,   2566,     49,
            656,    266, 100800,  10966,    368,  62798,    632,    267, 1907

 71%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                 | 5/7 [02:04<00:49, 24.96s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3, 227985,   5484,    915,  69168,  70823,    919,   2566,
         137298,   8106,   6700,   3958,   5734,     17,  39660,   3509,    473,
           1955,    361,   2782,     15,   3808,  67667,   1306,   5007, 117731,
             18,   1191,  36547,  16549,     17,  26402,   7083,   2670,  39347,
            427,   2566,  19593,  15450,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3, 227985,
           5484,    915,   2566,  18247,  11847,  53312,   1728,    461,    2

 86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                        | 6/7 [02:28<00:24, 24.81s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3, 227985,   5484,
            915,   2566,  88653,   2321, 144017, 138861,  59283,   1152,    613,
           2632,  12120,      4,   5673,   1152,  32153,    427,  36992,     15,
           1152,   1400,   5065, 114438,  66455,    919,    404, 146304,  14078,
          87856,   7061,   2906,     17,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3, 227985,   5484,    915,   5673,    473,
          11229,   2213,   2670,  35307,  28629,    461,   2566,   2765,   1531,
           3470,  47134,  10144,   2765,   1531,    427,   2909,  17918,   6782,
          27268,   4390,   1517,     17,   3904,    632,    267,   6497,    4

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:43<00:00, 23.29s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:10<00:00, 18.60s/it]


epoch=1: train_ppl=tensor(2794.6868, device='cuda:0') train_epoch_loss=tensor(7.9355, device='cuda:0') eval_ppl=tensor(498.3484, device='cuda:0') eval_epoch_loss=tensor(6.2113, device='cuda:0')


  0%|                                                                                                                                                                                     | 0/7 [00:00<?, ?it/s]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3, 227985,   5484,    915,   2566, 125474, 168916,
           2566,  14971,  14167,   2632,  34822,    632,  43435,  68932,    632,
          14005,    530,   5616,   6216,  29180, 173064,     17,  12018,    718,
          15564,   6648,    971,   4212,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,     

 14%|████████████████████████▋                                                                                                                                                    | 1/7 [00:21<02:07, 21.18s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3, 227985,   5484,    915,   2566,    296,   3143,   5990,   1475,
           1026,      4,   3162,   3403,   6440,   1152,    267,  57733,   1002,
           2632,  31335,   9313,   4040,    530,   3595,   3509,   3291, 137057,
            722,  33766,   2256,      4,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3, 227985,   5484,    915,   5161,  13500,    386,
              4, 122906,    415,  66027,  42431,    613,  70016,    361,   2550,
             51,  21351,    322,  17896,     15,   2550,     49,     45,   2550,
          21450,    388,    655,   2550,     75,  14263,    916,     86,   18

 29%|█████████████████████████████████████████████████▍                                                                                                                           | 2/7 [00:48<02:03, 24.73s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   2566,
          15157,   4867,  14731, 165189,   2021,    769,  11528,   7220,  35025,
            530,  27937, 149533,   1965,  43435, 163255,   1141,   3611,     17,
          30655,    632,   1119,     17,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   25

 43%|██████████████████████████████████████████████████████████████████████████▏                                                                                                  | 3/7 [01:22<01:56, 29.05s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   2566,
          96186,  29756,    351,    473,   1542,    654,   9322,    530,    368,
          21851,    632,   6644,    530,  48132,     17,   6728,   1152,    727,
           7747,   3638,   1119,     34,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3, 227985,   5484,
            915,    473,   3370,  29408,    973,  44805,    427, 162074,     

 57%|██████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                          | 4/7 [01:49<01:24, 28.20s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3, 227985,   5484,    915,   2566,     58,  24673,     21,
          34274,   1244,    613,   2910,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,     

 71%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                 | 5/7 [02:16<00:55, 27.74s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3, 227985,   5484,    915,   2566,     43,   9043,  19624,   2670,
         113385,   2152,   1130,  10916,   1074,    427,  32003,   9671,  12917,
            718,   3804,  92078, 156616,  21734,   2550, 101367,   4973,  11168,
           2550,   3767,   8855,  80772,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3, 227985,   5484,    915,   2566,  14173,   2960,  29906,    387,
          73303,    473,   9283,  11257,    368,  64129,    361,  11571,    4

 86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                        | 6/7 [02:42<00:27, 27.03s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
         227985,   5484,    915,   2566, 172168,   6220,   5068,    473,   9016,
           4085,    380,  25004, 105384,    361,    272,   3049,     17,    473,
           4026,    427,   8265,  43624,   1445,   3264,  95495,     17,  69949,
          31335,   3403,    473,   2971,     17,   6728,    473,   2971,    368,
          43624,   1392,   2592,     17,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3, 227985,
           5484,    915,   2566, 182441,     55,  35040,    435,  16796,  799

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:58<00:00, 25.53s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:15<00:00, 19.36s/it]


epoch=2: train_ppl=tensor(332.0180, device='cuda:0') train_epoch_loss=tensor(5.8052, device='cuda:0') eval_ppl=tensor(274.5290, device='cuda:0') eval_epoch_loss=tensor(5.6151, device='cuda:0')


  0%|                                                                                                                                                                                     | 0/7 [00:00<?, ?it/s]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3, 227985,   5484,    915,   2566,     43,   9043,  19624,   2670,
         113385,   2152,   1130,  10916,   1074,    427,  32003,   9671,  12917,
            718,   3804,  92078, 156616,  21734,   2550, 101367,   4973,  11168,
           2550,   3767,   8855,  80772,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   25

 14%|████████████████████████▋                                                                                                                                                    | 1/7 [00:21<02:08, 21.34s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3, 227985,   5484,    915,
           2566,     38,  16261,  12462,   2566,     39,  61302,   2566,   2338,
         188609,   3395,     38,   7708,   9293,  31335,  11919,   6738,   7396,
           1809,   3784, 168950,    530,  48430,     15,   1965,   3595,   3638,
            368,  98045,  11919,     34,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
         227985,   5484,    915,   1387,  18688,    632,  31139,   3478,     

 29%|█████████████████████████████████████████████████▍                                                                                                                           | 2/7 [00:45<01:56, 23.28s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3, 227985,   5484,
            915,   2566,  69408,  73736,   1400,    473,   2213,    267, 131388,
          17817,   9781, 158974,   3262,    718,  35752,   2496,   1336,  20941,
            530,   1701,  44920, 133198,     34,   2550,     44,    328,  61066,
           1258,   8049,   7171,   5448,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,     

 43%|██████████████████████████████████████████████████████████████████████████▏                                                                                                  | 3/7 [01:12<01:38, 24.59s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3, 227985,   5484,
            915,   2566,  88653,   2321, 144017, 138861,  59283,   1152,    613,
           2632,  12120,      4,   5673,   1152,  32153,    427,  36992,     15,
           1152,   1400,   5065, 114438,  66455,    919,    404, 146304,  14078,
          87856,   7061,   2906,     17,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,     

 57%|██████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                          | 4/7 [01:35<01:12, 24.15s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   2566,
          17785,  45614,  19985,   1400,  14831,  45614,    973,     55,    727,
          11571,  37643,  65221,     34,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
         227985,   5484,    915,   2566, 154507,     66, 104833,   2566, 132893,
          26569,   3678,   3597,   2566,     57,   4492, 208259,    770,  50224,
            267,  10512,  18453,   1776,  12480,  56006,   2221,  32544,    5

 71%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                 | 5/7 [02:00<00:48, 24.31s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3, 227985,   5484,    915,   2566,     60,  80772,   1400,
           1701,   2213,    368,  12171,  67777,    613,    267,  18210,  76252,
            375,    916,   6635,   1320,   3776,    934,  44805,   1965,  13002,
            934,     17,     21,     12,    791,    727,   1701,   2971,    267,
          35307,  20845,  10172,     34,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   2566,
          81048,  44166,  55675,    473,  19134,   1152,   1965,   3262,  522

 86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                        | 6/7 [02:24<00:24, 24.20s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   2566,
          31934,   5227,   6640,  16261,  87843,     17, 130462,   9600, 169668,
             28,  13604, 112581,     40,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,     

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:38<00:00, 22.69s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:02<00:00, 17.56s/it]


epoch=3: train_ppl=tensor(215.2742, device='cuda:0') train_epoch_loss=tensor(5.3719, device='cuda:0') eval_ppl=tensor(178.5299, device='cuda:0') eval_epoch_loss=tensor(5.1848, device='cuda:0')


  0%|                                                                                                                                                                                     | 0/7 [00:00<?, ?it/s]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3, 227985,   5484,    915,  43504,  51648,   5456,
         176088,     34,  57180,  28627,    269,  13041,   2566, 114242,    672,
           2338,  56114,    427,   4054,    530, 182640,   7963,    427,  19134,
            718,      4,   2550,  55061,  17209,   2550, 114242,    672,   9702,
             90,    647,   4346,  87843,     17, 130462,  40081,  10881,     73,
             84,     38,    624,     43,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3, 227985,   5484,    915,   5673,    473,
          11229,   2213,   2670,  35307,  28629,    461,   2566,   2765,   1531,
           3470,  47134,  10144,   2765,   1531,    427,   2909,  17918,   6782,
          27268,   4390,   1517,     17,   3904,    632,    267,   6497,    4

 14%|████████████████████████▋                                                                                                                                                    | 1/7 [00:21<02:06, 21.07s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   2566,
          31934,   5227,   6640,  16261,  87843,     17, 130462,   9600, 169668,
             28,  13604, 112581,     40,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3, 227985,   5484,
            915,   2566,     61,  31311,   6640,   1935,  15527,   3784,  468

 29%|█████████████████████████████████████████████████▍                                                                                                                           | 2/7 [00:46<01:57, 23.58s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   2566,
          57647,    327,  38804,     86,  35631,    368,   7733,   4676,    427,
          10665,  57903,    664,    267,   6917,  18706,    427,    368,  16698,
          35633,   3383,  27409,     34,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3, 227985,   5484,    915,   2566, 125474, 168916,
           2566,  14971,  14167,   2632,  34822,    632,  43435,  68932,    6

 43%|██████████████████████████████████████████████████████████████████████████▏                                                                                                  | 3/7 [01:14<01:41, 25.44s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3, 227985,   5484,
            915,   2566,     44,    256,  67875,  21033,  86274,  79707,   2632,
           9999,    427,   2150,  54036,  98091,     34, 112164,  15971,  16154,
           5382,    861,   7220,     17,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3, 227985,
           5484,    915,   2566, 182441,     55,  35040,    435,  16796,  799

 57%|██████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                          | 4/7 [01:38<01:15, 25.19s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
         227985,   5484,    915,   6782,    297,  12245,  11246,   1002,   2550,
          32375,  27516,  67121,   9512,  87843,     17,   1594,  55189,  58465,
         220018,  58251,    667,     40,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   2566,
          96186,  29756,    351,    473,   1542,    654,   9322,    530,    3

 71%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                 | 5/7 [02:04<00:50, 25.29s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3, 227985,   5484,    915,   2566,     43,   9043,  19624,   2670,
         113385,   2152,   1130,  10916,   1074,    427,  32003,   9671,  12917,
            718,   3804,  92078, 156616,  21734,   2550, 101367,   4973,  11168,
           2550,   3767,   8855,  80772,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   2566,
          15157,   4867,  14731, 165189,   2021,    769,  11528,   7220,  350

 86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                        | 6/7 [02:28<00:24, 24.83s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
         227985,   5484,    915,   2566,  21714,  14571,  10215,   2566,   1339,
          19336,    575,   5110,   1074,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,     

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:43<00:00, 23.37s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:06<00:00, 18.14s/it]


epoch=4: train_ppl=tensor(163.2530, device='cuda:0') train_epoch_loss=tensor(5.0953, device='cuda:0') eval_ppl=tensor(137.8402, device='cuda:0') eval_epoch_loss=tensor(4.9261, device='cuda:0')


  0%|                                                                                                                                                                                     | 0/7 [00:00<?, ?it/s]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3, 227985,   5484,
            915,    473,   3370,  29408,    973,  44805,    427, 162074,     72,
            919,   2566,    951,  23323, 228277,    351,    613,    368,  10087,
           9440,    473,  15017,      4,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3, 227985,   5484,
            915,   2566,  69408,  73736,   1400,    473,   2213,    267, 131388,
          17817,   9781, 158974,   3262,    718,  35752,   2496,   1336,  209

 14%|████████████████████████▋                                                                                                                                                    | 1/7 [00:22<02:16, 22.70s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3, 227985,   5484,    915,
           2566,     38,  16261,  12462,   2566,     39,  61302,   2566,   2338,
         188609,   3395,     38,   7708,   9293,  31335,  11919,   6738,   7396,
           1809,   3784, 168950,    530,  48430,     15,   1965,   3595,   3638,
            368,  98045,  11919,     34,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3, 227985,   5484,    915,   2566,  47959,   6745,
          19624,  13929,   2152,    722,  11045,    635,   3869,    290,  321

 29%|█████████████████████████████████████████████████▍                                                                                                                           | 2/7 [00:48<02:01, 24.38s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3, 227985,
           5484,    915,   2566, 182441,     55,  35040,    435,  16796,  79920,
            427,    661,   6355,     17,   5568,   2213,   3172,    960, 126355,
           6216,   5559,  61273,  23181,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3, 227985,   5484,    915,
           2566,     80,  11010,    905, 200058,   3904,   9746,   3370,    722,
           1074,     15,   1965,    600,  50713, 191765,   4973,     34, 200008,
         123467,   1306,   1427,  16198,   3262,  11700,  35237,  12602,   12

 43%|██████████████████████████████████████████████████████████████████████████▏                                                                                                  | 3/7 [01:14<01:40, 25.12s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   2566,
          31934,   5227,   6640,  16261,  87843,     17, 130462,   9600, 169668,
             28,  13604, 112581,     40,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   25

 57%|██████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                          | 4/7 [01:38<01:14, 24.83s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   2566,
          15157,   4867,  14731, 165189,   2021,    769,  11528,   7220,  35025,
            530,  27937, 149533,   1965,  43435, 163255,   1141,   3611,     17,
          30655,    632,   1119,     17,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3, 227985,   5484,    9

 71%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                 | 5/7 [02:04<00:50, 25.27s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3, 227985,   5484,
            915,   2566,  88653,   2321, 144017, 138861,  59283,   1152,    613,
           2632,  12120,      4,   5673,   1152,  32153,    427,  36992,     15,
           1152,   1400,   5065, 114438,  66455,    919,    404, 146304,  14078,
          87856,   7061,   2906,     17,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3, 227985,   5484,    915,  51591,  23746,    727,  16916,
           3638,   9322,    578,  17444,    361,  51950,   2084,   2307,     

 86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                        | 6/7 [02:29<00:25, 25.03s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3, 227985,   5484,    915,  17585,  74460,
            267,  31355,  87843,     17,   1594,  48082,   8027,     41,     90,
             38,     77,  48012,     88,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3, 227985,   5484,    915,  35673,   85

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:44<00:00, 23.48s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:07<00:00, 18.26s/it]


epoch=5: train_ppl=tensor(121.9483, device='cuda:0') train_epoch_loss=tensor(4.8036, device='cuda:0') eval_ppl=tensor(109.8494, device='cuda:0') eval_epoch_loss=tensor(4.6991, device='cuda:0')


  0%|                                                                                                                                                                                     | 0/7 [00:00<?, ?it/s]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3, 227985,   5484,
            915,   2566,  88653,   2321, 144017, 138861,  59283,   1152,    613,
           2632,  12120,      4,   5673,   1152,  32153,    427,  36992,     15,
           1152,   1400,   5065, 114438,  66455,    919,    404, 146304,  14078,
          87856,   7061,   2906,     17,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3, 227985,   5484,
            915,   2566,  69408,  73736,   1400,    473,   2213,    267, 131388,
          17817,   9781, 158974,   3262,    718,  35752,   2496,   1336,  209

 14%|████████████████████████▋                                                                                                                                                    | 1/7 [00:21<02:07, 21.26s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3, 227985,   5484,    915,
           2566,     80,  11010,    905, 200058,   3904,   9746,   3370,    722,
           1074,     15,   1965,    600,  50713, 191765,   4973,     34, 200008,
         123467,   1306,   1427,  16198,   3262,  11700,  35237,  12602,   1293,
           8398,    530,   1999,   4346,  87843,     17,   1594,  35367, 241792,
         130376,    894,   3143,     42,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3, 227985,   5484,    915,   2566,     60,  80772,   1400,
           1701,   2213,    368,  12171,  67777,    613,    267,  18210,  76252,
            375,    916,   6635,   1320,   3776,    934,  44805,   1965,  130

 29%|█████████████████████████████████████████████████▍                                                                                                                           | 2/7 [00:45<01:56, 23.27s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3, 227985,   5484,    915,   2566, 137538,  78869,  12122,   2963,
           3226,  15756,   1965,   3276,  14967,   6610,    664,   3509,    427,
         112046,   1800,  21859,   3250,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,     

 43%|██████████████████████████████████████████████████████████████████████████▏                                                                                                  | 3/7 [01:13<01:40, 25.16s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3, 227985,
           5484,    915,   2566, 182441,     55,  35040,    435,  16796,  79920,
            427,    661,   6355,     17,   5568,   2213,   3172,    960, 126355,
           6216,   5559,  61273,  23181,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3, 227985,   5484,    915,
           2566,     38,  16261,  12462,   2566,     39,  61302,   2566,   2338,
         188609,   3395,     38,   7708,   9293,  31335,  11919,   6738,   73

 57%|██████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                          | 4/7 [01:37<01:14, 24.79s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3, 227985,
           5484,    915,   2566,     49,    656,    266,  53312,   2566,     49,
            656,    266, 100800,  10966,    368,  62798,    632,    267, 190795,
         230331,  15226,    427,   2213,  20889,   5011,  20073,  13538,   5840,
         221781,  41051,  49337,  42696,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   25

 71%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                 | 5/7 [02:03<00:50, 25.10s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3, 227985,
           5484,    915, 191971,    261,  74182,      4, 226928,    427,  25608,
           1002,  39839,    473,   4472,   2550,     41,  86461,   4352,  19821,
           2550, 238683,  13663,  87843,     17,   1594,  14663,     36,     52,
             44,   6794,     81,  87781,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3, 227985,   5484,    915,   2566,  14173,   2960,  29906,    387,
          73303,    473,   9283,  11257,    368,  64129,    361,  11571,    4

 86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                        | 6/7 [02:27<00:24, 24.95s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3, 227985,   5484,    915,  35673,   8562,
          29826, 102530,     15,   1427, 207595,     17,    915,     12,   2550,
          81623,  14282,   5715,  37623,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3, 227985,   5484,    915,   2566,     46,  30579,     

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:43<00:00, 23.33s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:16<00:00, 19.44s/it]


epoch=6: train_ppl=tensor(105.1474, device='cuda:0') train_epoch_loss=tensor(4.6554, device='cuda:0') eval_ppl=tensor(92.9772, device='cuda:0') eval_epoch_loss=tensor(4.5324, device='cuda:0')


  0%|                                                                                                                                                                                     | 0/7 [00:00<?, ?it/s]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3, 227985,
           5484,    915,   2566,  18247,  11847,  53312,   1728,    461,    267,
          53531,    473,  11229,  14456,    427,   2670,  25357,  82707,  14218,
           1965,  60115,   2592,  11859,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
         227985,   5484,    915,   6782,    297,  12245,  11246,   1002,   25

 14%|████████████████████████▋                                                                                                                                                    | 1/7 [00:21<02:11, 21.97s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3, 227985,   5484,    915,  69168,  70823,    919,   2566,
         137298,   8106,   6700,   3958,   5734,     17,  39660,   3509,    473,
           1955,    361,   2782,     15,   3808,  67667,   1306,   5007, 117731,
             18,   1191,  36547,  16549,     17,  26402,   7083,   2670,  39347,
            427,   2566,  19593,  15450,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3, 227985,   5484,    915,   2566,   2774,
           4114,     53,   1711, 193163,   7708,  39762,  49337,    613,  347

 29%|█████████████████████████████████████████████████▍                                                                                                                           | 2/7 [00:47<02:01, 24.35s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   2566,
          17785,  45614,  19985,   1400,  14831,  45614,    973,     55,    727,
          11571,  37643,  65221,     34,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   2566,
          15157,   4867,  14731, 165189,   2021,    769,  11528,   7220,  350

 43%|██████████████████████████████████████████████████████████████████████████▏                                                                                                  | 3/7 [01:15<01:43, 25.78s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3, 227985,   5484,
            915,   2566,  69408,  73736,   1400,    473,   2213,    267, 131388,
          17817,   9781, 158974,   3262,    718,  35752,   2496,   1336,  20941,
            530,   1701,  44920, 133198,     34,   2550,     44,    328,  61066,
           1258,   8049,   7171,   5448,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3, 227985,   5484,
            915,   2566,  88653,   2321, 144017, 138861,  59283,   1152,    613,
           2632,  12120,      4,   5673,   1152,  32153,    427,  36992,     

 57%|██████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                          | 4/7 [01:40<01:16, 25.45s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3, 227985,   5484,    915,   2566,    296,   3143,   5990,   1475,
           1026,      4,   3162,   3403,   6440,   1152,    267,  57733,   1002,
           2632,  31335,   9313,   4040,    530,   3595,   3509,   3291, 137057,
            722,  33766,   2256,      4,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,     

 71%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                 | 5/7 [02:06<00:51, 25.71s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3, 227985,   5484,    915,   2566,  14173,   2960,  29906,    387,
          73303,    473,   9283,  11257,    368,  64129,    361,  11571,    461,
            490,   4283,  40067,   1620,   1130,   1186,  14881,   1002,     75,
           1728,    368,  63049,     17,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3, 227985,   5484,    915,   2566,  47959,   6745,
          19624,  13929,   2152,    722,  11045,    635,   3869,    290,  321

 86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                        | 6/7 [02:32<00:25, 25.77s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   2566,
         112229,   9107,  53312,   3262,   1306,   1152, 157816,   2084,  44326,
          40006,    613,  27019,  16680,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
         227985,   5484,    915,   2566,     80,   2335,    488,    905,  527

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:48<00:00, 24.03s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:20<00:00, 20.08s/it]


epoch=7: train_ppl=tensor(87.1513, device='cuda:0') train_epoch_loss=tensor(4.4676, device='cuda:0') eval_ppl=tensor(82.4511, device='cuda:0') eval_epoch_loss=tensor(4.4122, device='cuda:0')


  0%|                                                                                                                                                                                     | 0/7 [00:00<?, ?it/s]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   2566,
          96186,  29756,    351,    473,   1542,    654,   9322,    530,    368,
          21851,    632,   6644,    530,  48132,     17,   6728,   1152,    727,
           7747,   3638,   1119,     34,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3, 227985,
           5484,    915,   2566,     49,    656,    266,  53312,   2566,     49,
            656,    266, 100800,  10966,    368,  62798,    632,    267, 1907

 14%|████████████████████████▋                                                                                                                                                    | 1/7 [00:21<02:09, 21.59s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3, 227985,   5484,    915,  43504,  51648,   5456,
         176088,     34,  57180,  28627,    269,  13041,   2566, 114242,    672,
           2338,  56114,    427,   4054,    530, 182640,   7963,    427,  19134,
            718,      4,   2550,  55061,  17209,   2550, 114242,    672,   9702,
             90,    647,   4346,  87843,     17, 130462,  40081,  10881,     73,
             84,     38,    624,     43,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3, 227985,
           5484,    915,   2566,  18247,  11847,  53312,   1728,    461,    2

 29%|█████████████████████████████████████████████████▍                                                                                                                           | 2/7 [00:47<02:00, 24.12s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3, 227985,   5484,    915,   2566, 137538,  78869,  12122,   2963,
           3226,  15756,   1965,   3276,  14967,   6610,    664,   3509,    427,
         112046,   1800,  21859,   3250,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3, 227985,   5484,    915,   2566,    296,   3143,   5990,   1475,
           1026,      4,   3162,   3403,   6440,   1152,    267,  57733,   10

 43%|██████████████████████████████████████████████████████████████████████████▏                                                                                                  | 3/7 [01:14<01:41, 25.48s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3, 227985,   5484,    915,   5673,    473,
          11229,   2213,   2670,  35307,  28629,    461,   2566,   2765,   1531,
           3470,  47134,  10144,   2765,   1531,    427,   2909,  17918,   6782,
          27268,   4390,   1517,     17,   3904,    632,    267,   6497,    483,
            361,   2670, 101848,     17,  32465,   9585,   2566,     37,   2481,
           2566,     37,   2481,  12384,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3, 227985,   5484,    915,   2566,     43,   9043,  19624,   2670,
         113385,   2152,   1130,  10916,   1074,    427,  32003,   9671,  129

 57%|██████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                          | 4/7 [01:39<01:15, 25.13s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   2566,
          17785,  45614,  19985,   1400,  14831,  45614,    973,     55,    727,
          11571,  37643,  65221,     34,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,     

 71%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                 | 5/7 [02:05<00:51, 25.65s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   2566,
          81048,  44166,  55675,    473,  19134,   1152,   1965,   3262,  52282,
           1074,  52787,  14685,  20425,   5926,   2971,  32564,   3509,   2550,
         242086,    290,   3143,   1317,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3, 227985,   5484,    915,
           2566,     38,  16261,  12462,   2566,     39,  61302,   2566,   2338,
         188609,   3395,     38,   7708,   9293,  31335,  11919,   6738,   73

 86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                        | 6/7 [02:32<00:25, 25.96s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3, 227985,   5484,    915,   2566,  47959,   6745,
          19624,  13929,   2152,    722,  11045,    635,   3869,    290,  32107,
             75,   2481,  56557,   1002, 208814,  16924,   1231,     17,     19,
             34,  59283,   1152,      4,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   2566,
          57647,    327,  38804,     86,  35631,    368,   7733,   4676,    4

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:48<00:00, 24.13s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:20<00:00, 20.01s/it]


epoch=8: train_ppl=tensor(77.6554, device='cuda:0') train_epoch_loss=tensor(4.3523, device='cuda:0') eval_ppl=tensor(79.1418, device='cuda:0') eval_epoch_loss=tensor(4.3712, device='cuda:0')


  0%|                                                                                                                                                                                     | 0/7 [00:00<?, ?it/s]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   2566,
          57647,    327,  38804,     86,  35631,    368,   7733,   4676,    427,
          10665,  57903,    664,    267,   6917,  18706,    427,    368,  16698,
          35633,   3383,  27409,     34,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3, 227985,   5484,    915,   25

 14%|████████████████████████▋                                                                                                                                                    | 1/7 [00:21<02:09, 21.62s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3, 227985,   5484,    915,   2566,     60,  80772,   1400,
           1701,   2213,    368,  12171,  67777,    613,    267,  18210,  76252,
            375,    916,   6635,   1320,   3776,    934,  44805,   1965,  13002,
            934,     17,     21,     12,    791,    727,   1701,   2971,    267,
          35307,  20845,  10172,     34,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3, 227985,   5484,
            915,   2566,  88653,   2321, 144017, 138861,  59283,   1152,    613,
           2632,  12120,      4,   5673,   1152,  32153,    427,  36992,     

 29%|█████████████████████████████████████████████████▍                                                                                                                           | 2/7 [00:49<02:06, 25.37s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3, 227985,   5484,    915,   2566,  14173,   2960,  29906,    387,
          73303,    473,   9283,  11257,    368,  64129,    361,  11571,    461,
            490,   4283,  40067,   1620,   1130,   1186,  14881,   1002,     75,
           1728,    368,  63049,     17,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,     

 43%|██████████████████████████████████████████████████████████████████████████▏                                                                                                  | 3/7 [01:17<01:45, 26.34s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3, 227985,   5484,    915,
           2566,     38,  16261,  12462,   2566,     39,  61302,   2566,   2338,
         188609,   3395,     38,   7708,   9293,  31335,  11919,   6738,   7396,
           1809,   3784, 168950,    530,  48430,     15,   1965,   3595,   3638,
            368,  98045,  11919,     34,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3, 227985,   5484,    915,  69168,  70823,    919,   2566,
         137298,   8106,   6700,   3958,   5734,     17,  39660,   3509,    473,
           1955,    361,   2782,     15,   3808,  67667,   1306,   5007, 1177

 57%|██████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                          | 4/7 [01:42<01:17, 25.79s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3, 227985,   5484,
            915,   2566,     44,    256,  67875,  21033,  86274,  79707,   2632,
           9999,    427,   2150,  54036,  98091,     34, 112164,  15971,  16154,
           5382,    861,   7220,     17,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,     

 71%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                 | 5/7 [02:08<00:51, 26.00s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
         227985,   5484,    915,   2566,     80,   2335,    488,    905,  52775,
            368,  23984,     34,  12899,   9313,   4143,    473,    727,    718,
            664,  31885,    445,   9313,  77658,    915,    210,   1936, 106863,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3, 227985,   5484,    915,   2566, 137538,  78869,  12122,   29

 86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                        | 6/7 [02:32<00:25, 25.28s/it]

{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3, 227985,   5484,    915,  51591,  23746,    727,  16916,
           3638,   9322,    578,  17444,    361,  51950,   2084,   2307,     17,
           2137,   3025,   1790,      4,   2566, 216744,     38,   1316,     54,
          42705,   2566, 110647, 216744,  77658,    915,    210,  16449,   5952,
              3],
        [     3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3, 227985,   5484,    915,  17585,  744

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:47<00:00, 23.87s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:13<00:00, 19.10s/it]

epoch=9: train_ppl=tensor(75.3501, device='cuda:0') train_epoch_loss=tensor(4.3221, device='cuda:0') eval_ppl=tensor(75.2927, device='cuda:0') eval_epoch_loss=tensor(4.3214, device='cuda:0')





### 模型評估

In [57]:
model.eval()

i = 33
inputs = tokenizer(f'{text_column} : {dataset["test"][i]["Tweet text"]} Label : ', return_tensors="pt")
print(dataset["test"][i]["Tweet text"])
print(inputs)

with torch.no_grad():
    inputs = {k: v.to(device) for k, v in inputs.items()}
    outputs = model.generate(
        input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=10, eos_token_id=3
    )
    print(outputs)
    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))

@TommyHilfiger Dramatic shopping exp. ordered 6 jeans same size (30/32) 2 fits / 2 too large / 2 too slim : same brand &gt; different sizing
{'input_ids': tensor([[227985,   5484,    915,   2566, 226154, 126015,   5385,    259, 239364,
           3396,  70823,   5853,     17,  57247,   1231, 191040,   5025,   7869,
            375,   2324, 149349,     12,    415, 122321,    897,    415,  10136,
          10021,    897,    415,  10136,   6497,    381,    915,   5025,  51950,
          66869,   5955,    272,  20311,  77658,    915,    210]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
tensor([[227985,   5484,    915,   2566, 226154, 126015,   5385,    259, 239364,
           3396,  70823,   5853,     17,  57247,   1231, 191040,   5025,   7869,
            375,   2324, 149349,     12,    415, 122321,    897,    415,  10136,
          10021,    897,    415,  10136,  

### 儲存模型

In [59]:
# saving model
peft_model_id = f"{model_name_or_path}-{peft_config.peft_type}-{peft_config.task_type}"
peft_model_id

'bigscience/bloomz-560m-PROMPT_TUNING-CAUSAL_LM'

In [60]:
model.save_pretrained(peft_model_id)

In [64]:
!powershell -Command "ls bigscience/bloomz-560m-PROMPT_TUNING-CAUSAL_LM/adapter_model.safetensors | Select-Object Name, @{Name='Size';Expression={\$_length/1KB -as [int]}}, LastWriteTime"


Name                      Size LastWriteTime          
----                      ---- -------------          
adapter_model.safetensors      2024/2/25 下午 01:06:04




In [66]:
!tree /F /A bigscience/

列出磁碟區 DATA 的資料夾 PATH
磁碟區序號為 D0F1-8E10
D:\NLP\PEFT\BIGSCIENCE
\---bloomz-560m-PROMPT_TUNING-CAUSAL_LM
        adapter_config.json
        adapter_model.safetensors
        README.md
        


### 載入 prompt-tuned pretrained model

In [69]:
from peft import PeftModel, PeftConfig

peft_model_id = f"{model_name_or_path}-{peft_config.peft_type}-{peft_config.task_type}"

config = PeftConfig.from_pretrained(peft_model_id)
print("model path:", config.base_model_name_or_path)
# 載入基礎模型
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
# 加载 prompt tuning 模型
model = PeftModel.from_pretrained(model, peft_model_id)

model path: bigscience/bloomz-560m


In [70]:
from peft import PeftModel, PeftConfig

peft_model_id = f"{model_name_or_path}-{peft_config.peft_type}-{peft_config.task_type}"

# 載入 PEFT 配置
config = PeftConfig.from_pretrained(peft_model_id)

# 載入基礎模型
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)

# 載入 prompt tuning 模型
model = PeftModel.from_pretrained(model, peft_model_id)

# Tokenizer 編碼
inputs = tokenizer(f'{text_column} : {dataset["test"][i]["Tweet text"]} Label : ', return_tensors="pt")

# 模型推理
outputs = model.generate(
        input_ids=inputs["input_ids"], 
        attention_mask=inputs["attention_mask"], 
        max_new_tokens=10, 
        eos_token_id=3
    )

# Tokenizer 解碼
print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))

['Tweet text : @TommyHilfiger Dramatic shopping exp. ordered 6 jeans same size (30/32) 2 fits / 2 too large / 2 too slim : same brand &gt; different sizing Label : no complaintNo complaint<b data-parso']
