<a href="https://colab.research.google.com/github/zhangguanheng66/text/blob/arrow_dataset/examples/arraw_dataset/AI_hackathon_text_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture

# Colab-specific environment setup. %%capture suppresses this cell's output, so
# install failures are not shown here.
# NOTE(review): nothing below is version-pinned and the site-packages path is
# hard-coded for a Python 3.6 image; expect this cell to need updating elsewhere.
# Remove the preinstalled torch packages so the nightly wheels below replace them.
!rm -r /usr/local/lib/python3.6/dist-packages/torch*;
# Nightly torch + torchtext from the CUDA 10.1 wheel index; presumably needed for the
# torchtext.experimental vocab/transforms APIs imported in the DataModule cell.
!pip install --pre torch torchtext -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html;
# pyarrow + Hugging Face `datasets` back the on-disk Arrow datasets used below.
!pip install --upgrade --force-reinstall pyarrow datasets;
!pip install pytorch-lightning

## Prepare PyTorch-Lightning Module

In [2]:
import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule

class TextClassificationModel(LightningModule):
  """Bag-of-embeddings text classifier.

  Each example's token ids are pooled by an ``nn.EmbeddingBag`` (default mode,
  i.e. mean) and mapped to class logits by a single linear layer. Training uses
  SGD with a StepLR schedule that halves the learning rate after every epoch.

  Args:
    vocab_size: number of rows in the embedding table.
    embed_dim: embedding dimension.
    num_class: number of target classes.
    learning_rate: initial SGD learning rate.
  """

  def __init__(self, vocab_size, embed_dim, num_class, learning_rate):
    super().__init__()
    # sparse=True produces sparse gradients for the embedding table; plain SGD supports them.
    self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
    self.fc = nn.Linear(embed_dim, num_class)
    self.lr = learning_rate
    self.init_weights()

  def init_weights(self):
    """Draw embedding and linear weights from U(-0.5, 0.5) and zero the linear bias."""
    initrange = 0.5
    nn.init.uniform_(self.embedding.weight, -initrange, initrange)
    nn.init.uniform_(self.fc.weight, -initrange, initrange)
    nn.init.zeros_(self.fc.bias)

  def forward(self, text, offsets):
    """Return class logits for a batch.

    Args:
      text: 1-D tensor holding every example's token ids concatenated.
      offsets: 1-D tensor with the start index of each example inside ``text``.
    """
    embedded = self.embedding(text, offsets)
    return self.fc(embedded)

  def configure_optimizers(self):
    """SGD over all parameters plus a StepLR scheduler that halves the LR each step."""
    optimizer = torch.optim.SGD(self.parameters(), lr=self.lr)
    # step_size is the (integer) number of scheduler steps between decays.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
    return [optimizer], [scheduler]

  def training_step(self, batch, batch_idx):
    """Return the cross-entropy loss for one ``(labels, texts, offsets)`` batch."""
    labels, texts, offsets = batch
    logits = self(texts, offsets)
    return F.cross_entropy(logits, labels)

  def _eval_step(self, batch, batch_idx):
    """Return ``(logits, labels)`` so the epoch-end hook can compute accuracy."""
    labels, texts, offsets = batch
    return self(texts, offsets), labels

  def _eval_epoch_end(self, outputs):
    """Sum correct predictions and example counts over ``(logits, labels)`` pairs.

    Returns:
      ``(total_correct, total_count)`` as Python ints.
    """
    total_acc, total_count = 0, 0
    for logits, target_labels in outputs:
      total_acc += (logits.argmax(1) == target_labels).sum().item()
      total_count += logits.size(0)
    return total_acc, total_count

  def _log_accuracy(self, name, outputs):
    """Log the accuracy over ``outputs`` under metric ``name``."""
    total_acc, total_count = self._eval_epoch_end(outputs)
    if total_count == 0:
      # Nothing was evaluated; avoid a ZeroDivisionError.
      return
    self.log(name, total_acc / total_count, prog_bar=True)

  # NOTE(review): the *_epoch_end hooks below are the Lightning 1.x API
  # (removed in Lightning 2.0).
  def validation_step(self, batch, batch_idx):
    return self._eval_step(batch, batch_idx)

  def validation_epoch_end(self, valid_outputs):
    self._log_accuracy('val_acc', valid_outputs)

  def test_step(self, batch, batch_idx):
    return self._eval_step(batch, batch_idx)

  def test_epoch_end(self, test_outputs):
    self._log_accuracy('test_acc', test_outputs)

## Prepare the PyTorch-Lightning DataModule

In [3]:
import io
import torch
import datasets as ds
from torchtext.experimental.vocab import build_vocab_from_iterator
from torchtext.experimental.transforms import basic_english_normalize
from torchtext.utils import download_from_url, unicode_csv_reader
from pytorch_lightning import LightningDataModule

def create_data_from_csv(data_path):
    """Lazily yield ``(label, text)`` pairs from a CSV file.

    Every row is treated as data: the first column is parsed as an int label
    and all remaining columns are joined with single spaces into the text.
    """
    with open(data_path, encoding="utf8") as csv_file:
        for row in unicode_csv_reader(csv_file):
            label_field, text_fields = row[0], row[1:]
            yield int(label_field), ' '.join(text_fields)

def convert_to_arrow(file_path, raw_data):
    """Write ``(label, text)`` pairs to disk as a Hugging Face ``Dataset``.

    Args:
        file_path: directory passed to ``Dataset.save_to_disk``.
        raw_data: iterable of ``(label, text)`` pairs.

    Returns:
        The in-memory ``Dataset`` (columns ``labels`` and ``texts``) that was saved.
        ``save_to_disk`` itself returns ``None``, so the dataset is returned explicitly.
    """
    pairs = list(raw_data)
    # Build the columns explicitly so an empty input gives empty columns instead of
    # the ValueError that unpacking ``zip(*[])`` raises.
    labels = [label for label, _ in pairs]
    texts = [text for _, text in pairs]
    dataset = ds.Dataset.from_dict({"labels": labels, "texts": texts})
    dataset.save_to_disk(file_path)
    return dataset

def process_raw_data(arrow_ds, tokenizer, vocab):
    """Return a mapped copy of ``arrow_ds`` ready for batching.

    ``labels`` are converted to int and shifted down by one. ``texts`` are
    tokenized and then converted to token ids with ``vocab``.
    """
    def shift_label(label):
        return {'labels': int(label) - 1}

    def numericalize(text):
        return {'texts': vocab(tokenizer(text))}

    labels_shifted = arrow_ds.map(function=shift_label, input_columns='labels')
    return labels_shifted.map(function=numericalize, input_columns='texts')

def generate_batch(batch):
    """Collate processed examples into ``(labels, token_ids, offsets)``.

    ``token_ids`` is every example's ids concatenated into one int64 tensor.
    ``offsets[i]`` is the index in ``token_ids`` where example ``i`` starts.
    """
    labels = torch.tensor([example['labels'] for example in batch], dtype=torch.int64)
    token_tensors = [torch.tensor(example['texts'], dtype=torch.int64) for example in batch]
    # Example i starts after the lengths of examples 0..i-1 (the first one starts at 0).
    start_gaps = [0] + [tokens.size(0) for tokens in token_tensors[:-1]]
    offsets = torch.tensor(start_gaps).cumsum(dim=0)
    return labels, torch.cat(token_tensors), offsets

class TextClassificationDataModule(LightningDataModule):
    """AG News data module backed by Hugging Face ``datasets`` Arrow files.

    Construction downloads the CSVs, writes them to ``train_arrow`` / ``test_arrow``,
    and builds the tokenizer and vocabulary. The vocab must exist before
    ``trainer.fit`` because the model is sized from it. ``setup`` reloads the
    raw splits. Each dataloader numericalizes its split once per ``setup`` and
    reuses the result, so repeated dataloader calls never re-process
    already-processed data.

    Args:
        train_valid_split: fraction of the training CSV kept for training; the
            rest becomes the validation set.
        batch_size: examples per batch for every dataloader.
        num_workers: ``DataLoader`` worker processes.
        seed: seed for the train/valid shuffle; ``None`` leaves it unseeded.
    """

    def __init__(self, train_valid_split=0.9, batch_size=16, num_workers=1, seed=None):
        super().__init__()
        self.train_valid_split = train_valid_split
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.base_url = 'https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/'
        self.train_filepath = download_from_url(self.base_url + 'train.csv')
        self.test_filepath = download_from_url(self.base_url + 'test.csv')
        convert_to_arrow('train_arrow', list(create_data_from_csv(self.train_filepath)))
        convert_to_arrow('test_arrow', list(create_data_from_csv(self.test_filepath)))
        self.tokenizer = basic_english_normalize().to_ivalue()
        # The vocab is built from the whole training CSV, including the rows that
        # `setup` later holds out for validation.
        train_ds = ds.Dataset.load_from_disk('train_arrow')
        self.vocab = build_vocab_from_iterator(
            self.tokenizer(line) for line in train_ds['texts']).to_ivalue()
        self._processed_splits = {}

    def setup(self, stage=None):
        """Load the raw Arrow splits and carve a validation set out of the training data.

        ``self.train``, ``self.valid`` and ``self.test`` always hold raw (untokenized)
        datasets. Processed copies are rebuilt lazily by the dataloaders.
        """
        train_dataset = ds.Dataset.load_from_disk('train_arrow')
        dict_train_valid = train_dataset.train_test_split(test_size=1 - self.train_valid_split,
                                                          train_size=self.train_valid_split,
                                                          shuffle=True,
                                                          seed=self.seed)
        self.train = dict_train_valid['train']  # raw dataset
        self.valid = dict_train_valid['test']   # raw dataset
        self.test = ds.Dataset.load_from_disk('test_arrow')  # raw dataset
        # New raw splits invalidate any previously processed copies.
        self._processed_splits = {}

    def _processed_split(self, name):
        """Return split ``name`` ('train'/'valid'/'test') numericalized, processing it at most once."""
        if name not in self._processed_splits:
            self._processed_splits[name] = process_raw_data(getattr(self, name),
                                                            self.tokenizer, self.vocab)
        return self._processed_splits[name]

    def _make_loader(self, dataset, shuffle=False):
        """Wrap ``dataset`` in a DataLoader that collates with ``generate_batch``."""
        return torch.utils.data.DataLoader(dataset, shuffle=shuffle,
                                           batch_size=self.batch_size,
                                           num_workers=self.num_workers,
                                           collate_fn=generate_batch)

    def train_dataloader(self):
        return self._make_loader(self._processed_split('train'), shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self._processed_split('valid'))

    def test_dataloader(self):
        return self._make_loader(self._processed_split('test'))

## Initialize and run the PyTorch-Lightning Trainer

In [4]:
import torch
from pytorch_lightning import Trainer, seed_everything

LR = 5  # initial SGD learning rate (decayed by the model's StepLR scheduler)
NUM_CLASS = 4  # number of target classes
EMBED = 256  # embedding dimension
EPOCH = 3  # maximum number of training epochs
SEED = 0  # global seed for the Python, NumPy and torch RNGs

# Seed before building the data module so weight init is repeatable. This presumably
# also fixes the unseeded train/valid shuffle in `setup` (it draws from NumPy's
# global RNG); TODO confirm against the installed `datasets` version.
seed_everything(SEED)

data_module = TextClassificationDataModule()
model = TextClassificationModel(len(data_module.vocab), EMBED, NUM_CLASS, LR)
# Fall back to CPU when no GPU is present instead of failing at Trainer construction.
trainer = Trainer(gpus=1 if torch.cuda.is_available() else 0,
                  max_epochs=EPOCH, progress_bar_refresh_rate=40)
trainer.fit(model, data_module)

# Run the test set, using the datamodule attached during `fit`.
result = trainer.test()
print("The accuracy for test set is: {:4.3f}".format(result[0]['test_acc']))

train.csv: 29.5MB [00:00, 98.3MB/s]
test.csv: 1.86MB [00:00, 67.1MB/s]                  
GPU available: True, used: True
INFO:lightning:GPU available: True, used: True
TPU available: False, using: 0 TPU cores
INFO:lightning:TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type         | Params
-------------------------------------------
0 | embedding | EmbeddingBag | 24.5 M
1 | fc        | Linear       | 1.0 K 
INFO:lightning:
  | Name      | Type         | Params
-------------------------------------------
0 | embedding | EmbeddingBag | 24.5 M
1 | fc        | Linear       | 1.0 K 


HBox(children=(FloatProgress(value=0.0, max=12000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=12000.0), HTML(value='')))




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=0.0, max=108000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=108000.0), HTML(value='')))




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




HBox(children=(FloatProgress(value=0.0, max=7600.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=7600.0), HTML(value='')))




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.9090789473684211}
--------------------------------------------------------------------------------

The accuracy for test set is: 0.909
