
Commit 187636c

committed
Adds read_data function using HF datasets
1 parent fe8ea63 commit 187636c
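
The change replaces the pandas `read_csv` path in `read_data` with Hugging Face `datasets`: the text file is loaded via `load_dataset("text", ...)` and tokenized in a single batched `map` pass. A minimal sketch of that pattern, assuming a `transformers` tokenizer and a hypothetical input file `corpus.txt` with one sentence per line (neither is taken from this repository):

    from datasets import load_dataset
    from transformers import AutoTokenizer

    # corpus.txt is a placeholder path; the "text" builder yields one row per line
    dataset = load_dataset("text", data_files="corpus.txt")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    def tokenization(example):
        # with batched=True, example["text"] is a list of raw lines
        return tokenizer(example["text"], max_length=512, padding=True, truncation=True)

    data = dataset.map(tokenization, batched=True, desc="Running tokenizer")
    input_ids = data["train"]["input_ids"]  # one list of token ids per input line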

File tree

1 file changed: +74 -47 lines changed

src/utils/data_utils_sentencepiece.py

Lines changed: 74 additions & 47 deletions
@@ -1,6 +1,7 @@
 import logging
 import torch
 import pandas as pd
+from datasets import load_dataset
 from torch.utils.data import DataLoader, Dataset
 import torch
 from functools import partial
@@ -11,8 +12,6 @@
 # tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 
 
-
-
 def get_dataloader(tokenizer, data_path, batch_size, max_seq_len):
     dataset = TextDataset(tokenizer=tokenizer, data_path=data_path)
 
@@ -27,70 +26,98 @@ def get_dataloader(tokenizer, data_path, batch_size, max_seq_len):
 
     while True:
         for batch in dataloader:
-            #batch = tokenizer(batch)
-            print("BATCH")
-            print(batch)
-            print(type(batch))
-            print(len(batch))
             yield batch
 
 
 class TextDataset(Dataset):
     def __init__(
-        self,
-        tokenizer,
-        data_path: str,
-        has_labels: bool = False
-    ) -> None:
+            self,
+            tokenizer,
+            data_path: str,
+            has_labels: bool = False
+    ) -> None:
         super().__init__()
         self.data_path = data_path
         self.tokenizer = tokenizer
+        self.input_ids = None
         self.read_data()
         if has_labels:
             self.read_labels()
 
     def read_data(self):
         logging.info("Reading data from {}".format(self.data_path))
-        data = pd.read_csv(self.data_path, sep="\t", header=None)  # read text file
-        logging.info(f"Tokenizing {len(data)} sentences")
-        print("Start converting to lists")
-        self.text = data[0].apply(lambda x: x.strip()).tolist()
-        # encoded_input = self.tokenizer(self.questions, self.paragraphs)
-        print("End converting to lists")
-        # check if tokenizer has a method 'encode_batch'
-        print("Start tokenizing")
-        if hasattr(self.tokenizer, 'encode_batch'):
+
+        dataset = load_dataset("text", data_files=self.data_path)
+
+        logging.info("Creating helper function")
+
+        def tokenization(example):
+            return self.tokenizer(example["text"],
+                                  max_length=512,
+                                  padding=True,
+                                  truncation=True)
+
+        if hasattr(self.tokenizer, 'encode_batch'):  # relic of old days
             print("encode_batch")
+            # load_dataset returns a DatasetDict keyed by split, not by column
+            self.text = [x.strip() for x in dataset["train"]["text"]]
             encoded_input = self.tokenizer.encode_batch(self.text)
             self.input_ids = [x.ids for x in encoded_input]
 
-        # elif hasattr(self.tokenizer, 'batch_encode_plus'):
-        #     print("batch_encode_plus")
-        #     encoded_input = self.tokenizer.batch_encode_plus(self.text)
-        #     self.input_ids = [x.ids for x in encoded_input]
-
         else:
-            print("not encode_batch")
-            encoded_input = self.tokenizer(self.text)
-            self.input_ids = encoded_input["input_ids"]
-
-            from pympler.asizeof import asizeof
-
-            def get_disc_size_gb(obj):
-                return asizeof(obj) / 8 / 1_000 / 1_000
-
-            print(f"Type enc input: {type(encoded_input)}\n"
-                  f"Len input ids: {len(encoded_input['input_ids'])}\n"
-                  f"Getsizeof input ids: {get_disc_size_gb(encoded_input['input_ids'])}\n"
-                  f"Getsizeof token_type_ids: {get_disc_size_gb(encoded_input['token_type_ids'])}\n"
-                  f"Getsizeof attention_mask: {get_disc_size_gb(encoded_input['attention_mask'])}\n"
-                  f"Getsizeof encoded: {get_disc_size_gb(encoded_input)}\n"
-                  f"Getsizeof data(frame): {get_disc_size_gb(data)}\n"
-                  f"Getsizeof text: {get_disc_size_gb(self.text)}\n"
-                  f"Type input ids: {type(self.input_ids)}")
-            # sys.exit(0)
-
-            print("End tokenizing")
+            logging.info("Start Batched Tokenization")
+
+            data = dataset.map(tokenization,
+                               batched=True,
+                               remove_columns=[],
+                               load_from_cache_file=True,
+                               desc=f"Running tokenizer on {self.data_path}")
+
+            self.input_ids = data['train']["input_ids"]
+
+
+    # def read_data(self):
+    #     logging.info("Reading data from {}".format(self.data_path))
+    #     data = pd.read_csv(self.data_path, sep="\t", header=None)  # read text file
+    #     logging.info(f"Tokenizing {len(data)} sentences")
+    #     print("Start converting to lists")
+    #     self.text = data[0].apply(lambda x: x.strip()).tolist()
+    #     # encoded_input = self.tokenizer(self.questions, self.paragraphs)
+    #     print("End converting to lists")
+    #     # check if tokenizer has a method 'encode_batch'
+    #     print("Start tokenizing")
+    #     if hasattr(self.tokenizer, 'encode_batch'):
+    #         print("encode_batch")
+    #         encoded_input = self.tokenizer.encode_batch(self.text)
+    #         self.input_ids = [x.ids for x in encoded_input]
+    #
+    #     # elif hasattr(self.tokenizer, 'batch_encode_plus'):
+    #     #     print("batch_encode_plus")
+    #     #     encoded_input = self.tokenizer.batch_encode_plus(self.text)
+    #     #     self.input_ids = [x.ids for x in encoded_input]
+    #
+    #     else:
+    #         print("not encode_batch")
+    #         encoded_input = self.tokenizer(self.text)
+    #         self.input_ids = encoded_input["input_ids"]
+    #
+    #         from pympler.asizeof import asizeof
+    #
+    #         def get_disc_size_gb(obj):
+    #             return asizeof(obj) / 8 / 1_000 / 1_000
+    #
+    #         print(f"Type enc input: {type(encoded_input)}\n"
+    #               f"Len input ids: {len(encoded_input['input_ids'])}\n"
+    #               f"Getsizeof input ids: {get_disc_size_gb(encoded_input['input_ids'])}\n"
+    #               f"Getsizeof token_type_ids: {get_disc_size_gb(encoded_input['token_type_ids'])}\n"
+    #               f"Getsizeof attention_mask: {get_disc_size_gb(encoded_input['attention_mask'])}\n"
+    #               f"Getsizeof encoded: {get_disc_size_gb(encoded_input)}\n"
+    #               f"Getsizeof data(frame): {get_disc_size_gb(data)}\n"
+    #               f"Getsizeof text: {get_disc_size_gb(self.text)}\n"
+    #               f"Type input ids: {type(self.input_ids)}")
+    #         # sys.exit(0)
+    #
+    #         print("End tokenizing")
+
 
     def read_labels(self):
         self.labels = pd.read_csv(self.data_path, sep="\t", header=None)[1].tolist()
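
For context, a hedged usage sketch of the updated pipeline; the module path, model name, data path, and batch settings below are placeholders rather than values confirmed by this repository:

    from transformers import AutoTokenizer
    from src.utils.data_utils_sentencepiece import get_dataloader

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder model
    # get_dataloader is a generator: its `while True` loop re-iterates the
    # DataLoader indefinitely, so consume it with next() rather than exhausting it
    batches = get_dataloader(tokenizer, data_path="corpus.txt",
                             batch_size=32, max_seq_len=128)
    first_batch = next(batches)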
