<a href="https://colab.research.google.com/github/wshuyi/demo_chinese_text_classification_bert_fastai/blob/master/demo_refactored_dianping_classification_with_BERT_fastai.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
from fastai.text import *

In [0]:
!wget https://github.com/wshuyi/public_datasets/raw/master/dianping.csv

In [0]:
# Load the downloaded CSV; later cells read its "comment" and "sentiment" columns.
df = pd.read_csv("dianping.csv")

In [0]:
from sklearn.model_selection import train_test_split

In [0]:
# Hold out 20% of the rows as the test split; the fixed random_state keeps the
# shuffle reproducible across runs.
train, test = train_test_split(df, test_size=.2, random_state=2)

In [0]:
# Carve 20% of the remaining training rows out as the validation split.
# NOTE: this rebinds `train`, so re-running only this cell splits the already
# reduced frame again. Re-run the previous split cell first.
train, valid = train_test_split(train, test_size=.2, random_state=2)

In [0]:
# Row count of the training split.
len(train)

In [0]:
# Row count of the validation split.
len(valid)

In [0]:
# Row count of the held-out test split.
len(test)

In [0]:
# Preview the first rows of the training split.
train.head()

In [0]:
# Install the package that provides BertTokenizer and
# BertForSequenceClassification (imported in the next cell).
# NOTE(review): the version is not pinned, so a fresh run may pick up a
# different release than the one this notebook was written against.
!pip install pytorch-transformers

In [0]:
from pytorch_transformers import BertTokenizer, BertForSequenceClassification

In [0]:
# Settings shared by the tokenizer, the databunch and the model below.
bert_model = "bert-base-chinese"  # checkpoint name passed to from_pretrained()
max_seq_len = 128  # per-text token budget, counting the [CLS] and [SEP] markers
batch_size = 32  # passed as bs= when the databunch is built

In [0]:
# Load the tokenizer that belongs to the chosen pretrained checkpoint.
bert_tokenizer = BertTokenizer.from_pretrained(bert_model)

In [0]:
# Peek at five entries of the tokenizer vocabulary (positions 2000-2004).
list(bert_tokenizer.vocab.items())[2000:2005]

In [0]:
# Build a fastai Vocab from the BERT vocabulary keys, keeping their order.
# Presumably this makes fastai's numericalization produce the token ids the
# pretrained model expects. TODO confirm key order matches id order.
bert_vocab = Vocab(list(bert_tokenizer.vocab.keys()))

In [0]:
class BertFastaiTokenizer(BaseTokenizer):
    """Adapter that exposes a pretrained BERT tokenizer through a fastai BaseTokenizer.

    Args:
        tokenizer: object with a ``tokenize(text)`` method returning a list of tokens.
        max_seq_len: maximum tokens per text, counting the ``[CLS]`` and ``[SEP]``
            markers that ``tokenizer()`` adds.
        **kwargs: accepted for signature compatibility and ignored.
    """

    def __init__(self, tokenizer, max_seq_len=128, **kwargs):
        self.pretrained_tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def __call__(self, *args, **kwargs):
        """Return this instance unchanged, whatever arguments are given."""
        return self

    def tokenizer(self, t):
        """Tokenize ``t``, truncate it, and wrap it in ``[CLS]`` ... ``[SEP]``.

        Returns:
            A list of token strings that starts with "[CLS]" and ends with "[SEP]".
        """
        # Two positions are reserved for the markers. Clamp at zero: with
        # max_seq_len < 2 a negative slice bound would keep almost every token
        # instead of none.
        budget = max(self.max_seq_len - 2, 0)
        return ["[CLS]"] + self.pretrained_tokenizer.tokenize(t)[:budget] + ["[SEP]"]

In [0]:
# Instantiate the BERT adapter with the shared max_seq_len setting.
tok_func = BertFastaiTokenizer(bert_tokenizer, max_seq_len=max_seq_len)

In [0]:
# Wrap the BERT adapter in fastai's Tokenizer. Both rule lists are passed empty,
# presumably so that none of fastai's default text rules touch the input
# (TODO confirm).
bert_fastai_tokenizer = Tokenizer(tok_func=tok_func, pre_rules=[], post_rules=[])

In [0]:
# Current working directory, passed as the databunch path below.
path = Path(".")

In [0]:
# Build the classification databunch from the three DataFrame splits.
# Text comes from the "comment" column and labels from "sentiment".
# BOS/EOS insertion is switched off because the tokenizer already adds
# [CLS]/[SEP]; batches are padded at the end with id 0.
databunch = TextClasDataBunch.from_df(
    path, train, valid, test,
    text_cols="comment",
    label_cols='sentiment',
    tokenizer=bert_fastai_tokenizer,
    vocab=bert_vocab,
    include_bos=False,
    include_eos=False,
    collate_fn=partial(pad_collate, pad_first=False, pad_idx=0),
    bs=batch_size,
)

In [0]:
# Display a sample batch to eyeball the tokenized text and its labels.
databunch.show_batch()

In [0]:
class MyNoTupleModel(BertForSequenceClassification):
  """BertForSequenceClassification whose forward returns only element [0] of the parent's output.

  The parent forward returns an indexable result; only its first element
  (presumably the logits when no labels are passed — confirm) is handed back,
  so the output can be fed straight to a loss function.
  """

  # Token id treated as padding. Matches the pad_idx=0 used when batches are
  # collated; assumes id 0 is [PAD] in the checkpoint's vocab — TODO confirm
  # when switching checkpoints.
  _pad_idx = 0

  def forward(self, *args, **kwargs):
    """Run the parent forward and return the first element of its result.

    When the caller passes only input ids and no ``attention_mask``, a mask
    is derived from the ids so that padded positions are not attended to.
    Without it, padded training batches and unpadded single-text predictions
    would be processed inconsistently.
    """
    input_ids = args[0] if args else kwargs.get("input_ids")
    # Only fill in the mask in the unambiguous case (at most one positional
    # argument). Extra positional arguments are left alone because their
    # meaning depends on the installed library version.
    if input_ids is not None and len(args) <= 1 and kwargs.get("attention_mask") is None:
      kwargs["attention_mask"] = (input_ids != self._pad_idx).long()
    return super().forward(*args, **kwargs)[0]

In [0]:
# Load the pretrained checkpoint into the wrapper class, configured for two labels.
bert_pretrained_model = MyNoTupleModel.from_pretrained(bert_model, num_labels=2)

In [0]:
# Cross-entropy loss, handed to the Learner below.
loss_func = nn.CrossEntropyLoss()

In [0]:
# Tie data, model, loss and the accuracy metric together in a fastai Learner.
learn = Learner(
    databunch, bert_pretrained_model, loss_func=loss_func, metrics=accuracy
)

In [0]:
# Run fastai's learning-rate finder; its results are plotted in the next cell.
learn.lr_find()

In [0]:
# Plot what the recorder captured during the learning-rate search.
learn.recorder.plot()

In [0]:
# Fine-tune with the one-cycle schedule. The arguments are presumably
# 2 epochs and a peak learning rate of 2e-5 (TODO confirm against fastai docs).
learn.fit_one_cycle(2, 2e-5)

In [0]:
def dumb_series_prediction(n):
  """Predict, one text at a time, for the first ``n`` rows of the global ``test``.

  Each row's 'comment' value is passed to the global ``learn.predict``;
  element [1] of the result is cast to ``int``.

  Returns:
    A list of ``n`` ints, in row order.
  """
  return [int(learn.predict(test.iloc[row]['comment'])[1]) for row in range(n)]

In [0]:
# Predict every row of the test split, one text at a time.
preds = dumb_series_prediction(len(test))

In [0]:
# First ten predicted labels.
preds[:10]

In [0]:
from sklearn.metrics import classification_report, confusion_matrix

In [0]:
# Classification metrics for the test split: true labels vs. predictions.
print(classification_report(test.sentiment, preds))

In [0]:
# Confusion matrix for the test split: true labels vs. predictions.
print(confusion_matrix(test.sentiment, preds))