In [2]:
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict, load_from_disk
from transformers import T5Tokenizer, T5ForConditionalGeneration
from pathlib import Path
import torch


In [3]:
tokenizer = T5Tokenizer.from_pretrained("./models/t5-small")

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [4]:
def load_pt_data(file_path):
    return torch.load(file_path)


def preprocess_function(examples):
    input_texts = [[str(num) for num in ids] for ids in examples['input_ids']]
    label_texts = [[str(num) for num in ids] for ids in examples['labels']]
    
    inputs = tokenizer(input_texts, is_split_into_words=True, padding=False, truncation=False)
    targets = tokenizer(label_texts, is_split_into_words=True, padding=False, truncation=False)
    
    model_inputs = {
        'input_ids': inputs['input_ids'],
        'attention_mask': inputs['attention_mask'],
        'labels': targets['input_ids']
    }
    return model_inputs

In [5]:
print("loading origin data ...")
Z_data = load_pt_data('./datasets/train/v3/Z_unified_dataset.pt')
print("loading simple data ...")
A_data = load_pt_data('./datasets/train/v3/A_unified_dataset.pt')
assert len(Z_data) == len(A_data)

A_list = A_data.tolist()
Z_list = Z_data.tolist()

# split dataset for train and test
train_Z, eval_Z, train_A, eval_A = train_test_split(
    Z_list, A_list, test_size=0.1, random_state=42
)
# Create dataset
train_dataset = Dataset.from_dict({
    'input_ids': train_Z,
    'labels': train_A
})
eval_dataset = Dataset.from_dict({
    'input_ids': eval_Z,
    'labels': eval_A
})
dataset = DatasetDict({
    'train': train_dataset,
    'eval': eval_dataset
})
tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=["input_ids", "labels"])

tokenized_dataset.save_to_disk("./datasets/preprocessed_data/v3")
print("preprocessed data saved to disk successfully! ")

loading origin data ...
loading simple data ...


  return torch.load(file_path)


Map:   0%|          | 0/900000 [00:00<?, ? examples/s]

Map:   0%|          | 0/100000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/900000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/100000 [00:00<?, ? examples/s]

preprocessed data saved to disk successfully! 


In [6]:
for i in range(10000):  
    print(f"Sample {i}: input length = {len(tokenized_dataset['train'][i+10000]['input_ids'])}")

Sample 0: input length = 11
Sample 1: input length = 10
Sample 2: input length = 12
Sample 3: input length = 11
Sample 4: input length = 10
Sample 5: input length = 11
Sample 6: input length = 9
Sample 7: input length = 12
Sample 8: input length = 11
Sample 9: input length = 11
Sample 10: input length = 10
Sample 11: input length = 11
Sample 12: input length = 11
Sample 13: input length = 12
Sample 14: input length = 10
Sample 15: input length = 12
Sample 16: input length = 12
Sample 17: input length = 11
Sample 18: input length = 9
Sample 19: input length = 11
Sample 20: input length = 12
Sample 21: input length = 10
Sample 22: input length = 11
Sample 23: input length = 10
Sample 24: input length = 12
Sample 25: input length = 11
Sample 26: input length = 10
Sample 27: input length = 9
Sample 28: input length = 12
Sample 29: input length = 10
Sample 30: input length = 11
Sample 31: input length = 9
Sample 32: input length = 11
Sample 33: input length = 10
Sample 34: input length = 9


In [7]:
for i in range(10):
    print(tokenized_dataset['train'][i]['input_ids'])

[3, 18, 2555, 2294, 4906, 910, 2313, 507, 2668, 4327, 943, 2313, 1]
[3, 4525, 3539, 4542, 910, 2313, 314, 14574, 943, 2313, 1]
[3, 4608, 4327, 335, 2313, 220, 2773, 2517, 910, 2313, 1]
[1283, 3390, 519, 910, 2313, 3, 4613, 4013, 943, 2313, 1]
[3, 18, 2688, 2128, 4271, 910, 2313, 8798, 23360, 944, 2313, 1]
[3, 18, 10402, 3288, 910, 2313, 3479, 4177, 460, 2313, 1]
[3, 5947, 3840, 536, 460, 2313, 1902, 26225, 910, 2313, 1]
[3, 6832, 4389, 4560, 910, 2313, 505, 27025, 944, 2313, 1]
[668, 2606, 2658, 910, 2313, 3, 4591, 943, 2313, 1]
[6374, 1828, 519, 910, 2313, 5400, 25312, 943, 2313, 1]
