Usage: from tfrecord_pytorch import TFRecordPytorch file_name = "train.tfrecord" col_mapping={ "input_ids":tf.io.VarLenFeature(tf.int64), "label_ids":tf.io.VarLenFeature(tf.int64) } *Note pytorch Iterable Dataset doesnt allow shuffle in Data Loader dataset = TFRecordPytorch(file_name,col_mapping,shuffle=True,buffer_size=10000) loader = torch.utils.DataLoader(dataset,batch_size=4,collate_fn=pad_and_sort) iterator = iter(loader) print(next(iterator))
-
Notifications
You must be signed in to change notification settings - Fork 1
santhoshkolloju/TfRecordPytorch
Folders and files
Name | Name | Last commit message | Last commit date | |
---|---|---|---|---|
Repository files navigation
About
tfrecord
Topics
Resources
Stars
Watchers
Forks
Releases
No releases published
Packages 0
No packages published