In [None]:
from datasets import load_dataset

# Load the "harmless-base" subset of the Anthropic/hh-rlhf dataset.
# NOTE(review): the variable `datasets` shadows the imported `datasets`
# package name; it is kept because later cells index into it.
datasets = load_dataset(
    path="Anthropic/hh-rlhf",
    data_dir="harmless-base",
)
datasets

In [None]:
# Break the first "chosen" transcript of the train split into a flat list of
# utterances with the "Human:" / "Assistant:" speaker tags removed.
instance = datasets["train"][0]['chosen']

# Speaker tags that mark the start of a new turn.
speaker_prefixes = ('Human:', 'Assistant:')

# Chunks are separated by blank lines; drop whitespace-only fragments.
dialogue_list = [chunk.strip() for chunk in instance.split('\n\n') if chunk.strip()]

res = []
for dialogue in dialogue_list:
    print(dialogue)
    matched_prefix = next(
        (prefix for prefix in speaker_prefixes if dialogue.startswith(prefix)),
        None,
    )
    if matched_prefix is not None:
        # Slice the tag off by length. `str.lstrip('Human:')` would treat the
        # argument as a *set of characters* and could also eat leading letters
        # of the utterance itself (e.g. "Human:nah" -> "h").
        res.append(dialogue[len(matched_prefix):].strip())
    elif res:
        # A chunk without a speaker tag is a continuation paragraph of the
        # previous turn; re-attach it with the original blank-line separator.
        res[-1] += '\n\n' + dialogue
    else:
        # No previous turn to attach to (transcript does not start with a
        # speaker tag): keep the text as its own entry instead of raising
        # IndexError on res[-1].
        res.append(dialogue)

print(res)

In [1]:
from datasets import load_dataset

# Load the dataset stored under the relative path data/Anthropic.
# The recorded output for this cell shows 'query' / 'reference' columns
# in both the train and test splits.
datasets = load_dataset(path='data/Anthropic')
datasets

  from .autonotebook import tqdm as notebook_tqdm


DatasetDict({
    train: Dataset({
        features: ['query', 'reference'],
        num_rows: 104054
    })
    test: Dataset({
        features: ['query', 'reference'],
        num_rows: 5756
    })
})

In [2]:
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Single checkpoint identifier shared by the config, tokenizer and model.
checkpoint = "microsoft/DialoGPT-small"

config = AutoConfig.from_pretrained(checkpoint)

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Reuse the end-of-sequence token as the padding token so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(checkpoint, config=config)

In [3]:
def tokenize_and_align_labels(examples):
    """Tokenize a batch of query/reference pairs.

    Args:
        examples: mapping with 'query' and 'reference' entries, as passed by
            ``Dataset.map`` in the cell below.

    Returns:
        The tokenizer output for ``examples['query']`` with an extra
        'labels' entry holding the ``input_ids`` of the tokenized
        ``examples['reference']``.
    """
    # Both sides use the same settings: no padding here, truncation enabled.
    shared_kwargs = {'padding': False, 'truncation': True}

    encoded_queries = tokenizer(examples['query'], **shared_kwargs)
    encoded_references = tokenizer(examples['reference'], **shared_kwargs)

    encoded_queries['labels'] = encoded_references['input_ids']
    # The attention mask of the references is intentionally not stored
    # (previously sketched as a 'labels_attention_mask' entry).
    return encoded_queries

# Tokenize the whole train split in batches, then drop the original columns
# so only the fields produced by tokenize_and_align_labels remain.
train_dataset = datasets['train'].map(tokenize_and_align_labels, batched=True)
train_dataset = train_dataset.remove_columns(datasets['train'].column_names)
train_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 104054
})

In [4]:
from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq

# Labels are padded with the tokenizer's pad token id.
# NOTE(review): this looks deliberate, since the labels are later passed to
# _prepare_decoding_inputs; confirm it does not affect the loss.
label_pad_token_id = tokenizer.pad_token_id

collator_options = {
    "model": model,
    "label_pad_token_id": label_pad_token_id,
    "pad_to_multiple_of": None,
}
data_collator = DataCollatorForSeq2Seq(tokenizer, **collator_options)

train_dataloader = DataLoader(
    train_dataset, batch_size=5, collate_fn=data_collator, shuffle=True
)

# Pull a single (shuffled) batch to use as sample input in the cells below.
batch_iterator = iter(train_dataloader)
inputs = next(batch_iterator)
# print(inputs)

2024-02-08 19:00:40.423430: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-02-08 19:00:40.574190: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2024-02-08 19:00:40.574222: I tensorflow/compiler/xla/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
2024-02-08 19:00:41.248088: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2024-

In [None]:
# import torch

# def _prepare_decoding_inputs(
#     inputs: dict,
# ):
#     new_inputs = inputs.copy() # Don't modify the original dict
#     labels = new_inputs.pop("labels")
#     labels_attention_mask = torch.ones_like(labels)
#     new_inputs['input_ids'] = torch.cat((inputs['input_ids'], labels), dim=1)
#     new_inputs['attention_mask'] = torch.cat((inputs['attention_mask'], labels_attention_mask), dim=1)
#     new_labels = torch.cat(
#         (-100 * torch.ones_like(inputs['input_ids']), labels), dim=1
#     )
#     new_inputs['labels'] = new_labels
#     return new_inputs

# new_inputs = _prepare_decoding_inputs(inputs)
# print({k: v.shape for k, v in new_inputs.items()})
# # print(new_inputs)

# input_ids = new_inputs['input_ids']
# vocab_size = tokenizer.vocab_size  # Get the vocabulary size
# assert input_ids.max() < vocab_size, "Token index exceeds vocabulary size."

# max_input_length = model.config.max_position_embeddings
# assert input_ids.size(1) <= max_input_length, "Input length exceeds model's maximum input length."

# # Forward pass for CLM
# outputs = model(**new_inputs)

# # Extract logits
# logits = outputs.logits  # Shape: [batch_size, sequence_length, vocab_size]
# loss = outputs.loss

# print('decoding loss: ', loss)

In [5]:
import os
import sys

# Prevent .pyc files from being written when the local `models` module is imported.
sys.dont_write_bytecode = True
from collections import defaultdict
from models import get_stages, _prepare_inputs, _prepare_decoding_inputs

# Read the Hugging Face access token from the environment instead of
# hardcoding it in the notebook. The token that used to be committed here
# must be considered leaked and should be revoked.
# If HF_TOKEN is unset, None is passed.
# NOTE(review): confirm get_stages accepts token=None for public checkpoints.
hf_token = os.environ.get("HF_TOKEN")

# Collects per-key lists of timing measurements filled in by the stages.
# NOTE(review): exact keys are defined inside `models`; not visible here.
timing_info = defaultdict(list)

# Build the pipeline stages. The recorded output of this cell shows four
# stages placed on devices 0 through 3.
stages = get_stages(
    config=config,
    token=hf_token,
    model_name_or_path="microsoft/DialoGPT-small",
    num_stages=4,
    init_device=0,
    timing_info=timing_info,
)

# Move the sample batch to the first stage's device and run that stage.
inputs = _prepare_inputs(inputs, stages[0].device)
# NOTE(review): model_inputs is computed but never used; the first stage is
# called with `inputs`. Confirm whether `model_inputs` was intended here.
model_inputs = _prepare_decoding_inputs(inputs)
tuple_outputs = stages[0](**inputs)

Put stage GPTStartingStage (60647424 parameters) on device 0
Put stage GPTIntermediateStage (21263616 parameters) on device 1
Put stage GPTIntermediateStage (21263616 parameters) on device 2
Put stage GPTEndingStage (59862528 parameters) on device 3


In [13]:
from models import CustomizedGPT2Out

# Repackage the positional outputs of the first stage into the named
# container. Only the hidden states and the attention mask are moved to
# device 1; the recorded log output places the second stage on that device.
stage_fields = {
    "hidden_states": tuple_outputs[0].to(1),
    "attention_mask": tuple_outputs[1].to(1),
    "head_mask": tuple_outputs[2],
    "encoder_hidden_states": tuple_outputs[3],
    "encoder_attention_mask": tuple_outputs[4],
    "all_hidden_states": tuple_outputs[5],
    "all_self_attentions": tuple_outputs[6],
    "all_cross_attentions": tuple_outputs[7],
    "output_shape": tuple_outputs[8],
}
outputs = CustomizedGPT2Out(**stage_fields)
outputs

CustomizedGPT2Out(hidden_states=tensor([[[ 8.8920e+00, -5.6468e-01,  3.6520e+00,  ..., -7.9809e+00,
          -8.7312e+00, -7.5942e+00],
         [-4.0279e-01, -5.0313e-01,  7.6434e-01,  ...,  1.3460e+00,
          -2.2955e+00,  3.3652e-02],
         [ 4.4937e-02,  3.8293e-02,  5.0785e-01,  ..., -8.2714e-01,
           1.3216e+00,  1.0583e+00],
         ...,
         [-8.5965e-01,  5.4871e-01, -1.9659e-03,  ...,  2.0170e-01,
          -1.6414e+00,  6.6160e-01],
         [-8.6410e-01,  5.4152e-01, -3.4350e-03,  ...,  1.9571e-01,
          -1.6354e+00,  6.5914e-01],
         [-8.7025e-01,  5.3639e-01,  1.9031e-02,  ...,  2.1268e-01,
          -1.6416e+00,  6.3638e-01]],

        [[ 8.7684e+00, -1.2114e+00,  4.3377e+00,  ..., -7.5928e+00,
          -1.0622e+01, -8.8779e+00],
         [-6.4104e-01, -1.5081e+00, -3.8524e-01,  ...,  1.0774e+00,
          -2.8544e-01,  1.1921e+00],
         [-8.0381e-01, -4.2084e+00, -2.6535e+00,  ...,  2.1318e-01,
           1.7541e+00, -4.5577e-01],
       