<a href="https://colab.research.google.com/github/zhangguanheng66/text/blob/arrow_dataset/examples/arraw_dataset/arraw_torchtext.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
# Environment setup (Colab); %%capture hides the install logs.
# Remove the preinstalled torch packages, install pre-release (nightly, CUDA 10.1)
# torch + torchtext wheels -- later cells use torchtext.experimental APIs --
# and force-reinstall pyarrow.
# NOTE(review): the site-packages path hardcodes Python 3.6 and no package version
# is pinned, so a re-run may install a different build -- confirm before relying on it.

!rm -r /usr/local/lib/python3.6/dist-packages/torch*;
!pip install --pre torch torchtext -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html;
!pip install --upgrade --force-reinstall pyarrow;

In [None]:
import pyarrow as pa
import pandas as pd
from torchtext.utils import download_from_url, unicode_csv_reader
import io

# Download the AG News training split as CSV; `filepath` is the local path of
# the downloaded file, which is read by _create_data_from_csv below.
url = 'https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/train.csv'
filepath = download_from_url(url)

def _create_data_from_csv(data_path):
    """Yield ``(label, text)`` pairs from a CSV file without a header row.

    Args:
        data_path: Path of a UTF-8 encoded CSV file.

    Yields:
        ``(int, str)`` tuples: the first field of each row parsed as an
        integer label, and all remaining fields joined with a single space.

    Raises:
        ValueError: if the first field of a row is not an integer.
    """
    # newline='' hands the raw line endings to the CSV reader, so quoted
    # fields that contain line breaks are parsed faithfully (csv module docs).
    with io.open(data_path, encoding="utf8", newline='') as f:
        for row in unicode_csv_reader(f):
            yield int(row[0]), ' '.join(row[1:])

def build_pa_table_from_iterator(iterable, columns=None):
    """Build a ``pyarrow.Table`` from an iterable of rows.

    Args:
        iterable: Iterable of row tuples, e.g. the ``(label, text)`` pairs
            produced by ``_create_data_from_csv``. It is consumed once.
        columns: Optional list of column names, one per tuple element.

    Returns:
        A ``pyarrow.Table`` containing every row of ``iterable``.
    """
    # Use the rows the caller passed in rather than a hard-wired data source.
    df = pd.DataFrame(list(iterable), columns=columns)
    return pa.Table.from_pandas(df)
  
# Materialise the CSV rows as an Arrow table with named columns.
table = build_pa_table_from_iterator(
    iterable=_create_data_from_csv(filepath),
    columns=['label', 'text'],
)
table.column_names

train.csv: 29.5MB [00:00, 82.8MB/s]


['label', 'text']

In [None]:
import torch
class Dataset(torch.utils.data.Dataset):
    """Map-style torch ``Dataset`` backed by a ``pyarrow.Table``.

    Each row of the wrapped table is one example. Indexing returns that row
    as a one-row ``pandas.DataFrame``.
    """

    def __init__(self, arrow_table: pa.Table):
        """Wrap ``arrow_table``; the table is stored as-is, not copied."""
        super().__init__()
        self.arrow_table = arrow_table

    @property
    def num_rows(self):
        """The number of rows in the dataset."""
        return self.arrow_table.num_rows

    @property
    def column_names(self):
        """The column names of the dataset."""
        return self.arrow_table.column_names

    def __len__(self) -> int:
        """Number of rows in the dataset (same as ``num_rows``)."""
        return self.num_rows

    def get_column(self, column_name):
        """Return column ``column_name`` converted to a ``pandas.Series``."""
        return self.arrow_table[column_name].to_pandas()

    def __getitem__(self, index):
        """Return row ``index`` as a one-row ``pandas.DataFrame``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        num_rows = self.arrow_table.num_rows
        # Table.slice clamps the offset, so an out-of-range index would
        # otherwise yield an empty frame instead of an error.
        if not -num_rows <= index < num_rows:
            raise IndexError(
                f"index {index} is out of range for a dataset of {num_rows} rows")
        if index < 0:
            index += num_rows
        return self.arrow_table.slice(index, 1).to_pandas()

    def __iter__(self):
        """Yield every row in order, each as a one-row ``pandas.DataFrame``."""
        for index in range(self.arrow_table.num_rows):
            yield self[index]

    def map(self, func, column):
        """Apply ``func`` to each value of ``column`` and store the result.

        The wrapped table is replaced by a new ``pyarrow.Table`` built from
        the transformed data; the previous table object is left unmodified.

        Args:
            func: Callable applied element-wise to the values of ``column``.
            column: Name of the column to transform.
        """
        pd_df = self.arrow_table.to_pandas()
        # Assign positionally: DataFrame.update() would align on the index and
        # silently keep the old value wherever func returned None/NaN.
        pd_df[column] = [func(item) for item in pd_df[column]]
        self.arrow_table = pa.Table.from_pandas(pd_df)

    def to_batches(self, batch_size):
        """Split the table into ``pyarrow.RecordBatch`` objects.

        Each batch holds at most ``batch_size`` rows; pyarrow may emit
        smaller batches at chunk boundaries of the underlying table.
        """
        return self.arrow_table.to_batches(batch_size)

In [None]:
# Process raw text data
from torchtext.experimental.vocab import build_vocab_from_iterator
from torchtext.experimental.transforms import basic_english_normalize
arrow_dataset = Dataset(table)
tokenizer = basic_english_normalize()

# Build the vocabulary from the tokenized 'text' column.
vocab = build_vocab_from_iterator(
    (tokenizer(line) for line in arrow_dataset.get_column('text')),
    min_freq=1,
)

# Numericalize: labels are cast with int, texts become lists of vocab ids.
arrow_dataset.map(int, 'label')
arrow_dataset.map(lambda text: vocab(tokenizer(text)), 'text')
arrow_dataset[120]



Unnamed: 0,label,text
0,4,"[196, 34518, 232, 4, 1318, 335, 4, 1088, 13, 3..."


In [None]:
# Split the dataset into Arrow record batches of (at most) 8 rows each.
batches = arrow_dataset.to_batches(8)
# Column 0 holds the labels and column 1 the token ids of batch 12.
print(*(batches[12][col] for col in (0, 1)))

[
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4
] [
  [
    1453,
    634,
    39,
    4613,
    53200,
    13,
    27,
    14,
    27,
    15,
    ...
    499,
    53,
    8,
    303,
    25,
    1967,
    42,
    123,
    1453,
    1
  ],
  [
    183,
    16,
    9,
    7,
    5,
    951,
    80,
    582,
    3,
    2893,
    ...
    24,
    457,
    53203,
    95,
    8242,
    19,
    4864,
    1,
    477,
    80
  ],
  [
    386,
    452,
    1877,
    4,
    9439,
    761,
    7937,
    13,
    27,
    14,
    ...
    37,
    2,
    47,
    558,
    3280,
    21875,
    3611,
    7,
    358,
    1
  ],
  [
    370,
    7839,
    7422,
    570,
    1219,
    14731,
    3,
    11848,
    13,
    180,
    ...
    1582,
    899,
    35,
    1555,
    1705,
    24,
    5,
    180,
    2032,
    1
  ],
  [
    11849,
    3,
    18545,
    8,
    3786,
    436,
    5,
    5199,
    615,
    13,
    ...
    37,
    19,
    2142,
    19,
    23306,
    3,
    23,
    10263,
    4180,
    1
  ],


In [None]:
# Generate DataLoader to iterate dataset
def generate_batch(batch_data):
    """Identity collate function for ``DataLoader``.

    Args:
        batch_data: The samples the DataLoader gathered for one batch.

    Returns:
        ``batch_data`` itself, with no stacking or conversion applied.
    """
    collated = batch_data
    return collated

# Fixed seed so the shuffled order -- and therefore the printed rows -- is the
# same on every Restart & Run All.
SHUFFLE_SEED = 0

data_iter = torch.utils.data.DataLoader(
    arrow_dataset,
    shuffle=True,
    batch_size=16,
    num_workers=2,
    collate_fn=generate_batch,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

# Print the first sample of each of the first 21 batches (idx 0..20).
for idx, item in enumerate(data_iter):
    print(item[0])
    if idx == 20:
        break

   label                                               text
0      2  [2204, 653, 604, 465, 20, 50, 2140, 2205, 13, ...
   label                                               text
0      3  [217, 197, 36, 691, 148, 1042, 94, 36, 691, 18...
   label                                               text
0      2  [4955, 1200, 3460, 1691, 252, 5022, 4955, 12, ...
   label                                               text
0      1  [8609, 142, 7, 17664, 7207, 50740, 3, 494, 15,...
   label                                               text
0      4  [74193, 1743, 8795, 7, 205, 36224, 241, 41, 16...
   label                                               text
0      3  [945, 151, 193, 12, 9, 2735, 775, 10, 128, 12,...
   label                                               text
0      4  [386, 4, 3097, 2587, 7, 1320, 4012, 2515, 1113...
   label                                               text
0      4  [1163, 3783, 35, 3362, 2, 3258, 13, 4035, 1, 1...
   label                                