This notebook is mostly from

https://mlexplained.com/2019/01/30/an-in-depth-tutorial-to-allennlp-from-basics-to-elmo-and-bert/#more-853

with slight modification and annotation by YL

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path
from typing import *
import torch
import torch.optim as optim
import numpy as np
import pandas as pd
from functools import partial
from overrides import overrides

from allennlp.data import Instance
from allennlp.data.token_indexers import TokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.nn import util as nn_util

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [3]:
class Config(dict):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for k, v in kwargs.items():
            setattr(self, k, v)
    
    def set(self, key, val):
        self[key] = val
        setattr(self, key, val)
        
config = Config(
    testing=False,
    seed=1,
    batch_size=64,
    lr=5e-5,
    epochs=5,
    hidden_sz=64,
    max_seq_len=100, # necessary to limit memory usage
    max_vocab_size=100000,
)

bert_flavour = "bert-base-multilingual-cased"
# bert_flavour = "bert-base-chinese"

#if true, read whole dataset and load pretrained weight then run fine tune training
#if false, read small dataset and load finetunned weight and skip fine tune training
retrain = False

In [4]:
from allennlp.common.checks import ConfigurationError

In [5]:
USE_GPU = torch.cuda.is_available()
USE_GPU

False

In [6]:
DATA_ROOT = Path("../")

Set random seed manually to replicate results

In [7]:
torch.manual_seed(config.seed)

<torch._C.Generator at 0x7fd79233cf90>

# Load Data

In [8]:
import pickle
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.dataset_readers import DatasetReader

### Prepare dataset

In [9]:
cat_ids = pickle.load(open(DATA_ROOT / "data/cat_id_lut.pkl","rb"))
cat_ids

{1: '吹水台',
 4: '手機台',
 5: '時事台',
 6: '體育台',
 7: '娛樂台',
 8: '動漫台',
 9: 'Apps台',
 10: '遊戲台',
 11: '影視台',
 12: '講故台',
 13: '潮流台',
 14: '上班台',
 15: '財經台',
 16: '飲食台',
 17: '旅遊台',
 18: '學術台',
 19: '校園台',
 20: '汽車台',
 21: '音樂台',
 22: '硬件台',
 23: '攝影台',
 24: '玩具台',
 25: '寵物台',
 26: '軟件台',
 27: '活動台',
 28: '站務台',
 29: '成人台',
 30: '感情台',
 31: '創意台',
 32: '黑\u3000洞',
 33: '政事台',
 34: '直播台',
 35: '電訊台',
 36: '健康台'}

In [10]:
ncats = max(cat_ids.keys()) - min(cat_ids.keys())+1
ncats

36

In [11]:
from allennlp.data.fields import TextField, MetadataField, ArrayField

class LihkgDatasetReader(DatasetReader):
    def __init__(self, tokenizer: Callable[[str], List[str]]=lambda x: x.split(),
                 token_indexers: Dict[str, TokenIndexer] = None,
                 max_seq_len: Optional[int]=config.max_seq_len) -> None:
        super().__init__(lazy=False)
        self.tokenizer = tokenizer
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        self.max_seq_len = max_seq_len

    @overrides
    def text_to_instance(self, tokens: List[Token],
                         label: int=None) -> Instance:
        sentence_field = TextField(tokens, self.token_indexers)
        fields = {"tokens": sentence_field}
        
        
        labels = np.zeros(ncats)
        if label is not None:
            labels[label-1]=1 #one hot encode
            
        label_field = ArrayField(array=labels)
        fields["label"] = label_field

        return Instance(fields)
    
    @overrides
    def _read(self, file_path: str) -> Iterator[Instance]:
        df = pd.read_csv(file_path)
        if config.testing: df = df.head(1000)
        for i, row in df.iterrows():
            yield self.text_to_instance(
                [Token(x) for x in self.tokenizer(row["title"])],
                row["cat_id"],
            )

### Prepare token handlers

In [12]:
from allennlp.data.token_indexers import PretrainedBertIndexer

token_indexer = PretrainedBertIndexer(
    pretrained_model=str(DATA_ROOT / ("pretrain/%s-vocab.txt"%bert_flavour)),
    max_pieces=config.max_seq_len,
    do_lowercase=True,
 )
# apparently we need to truncate the sequence here, which is a stupid design decision
def tokenizer(s: str):
    return token_indexer.wordpiece_tokenizer(s)[:config.max_seq_len - 2]

In [13]:
reader = LihkgDatasetReader(
    tokenizer=tokenizer,
    token_indexers={"tokensIdxers1": token_indexer}
)

In [14]:
if retrain:
    train_ds, val_ds, test_ds = (reader.read(DATA_ROOT / fname) for fname in ["data/lihkg_posts_20190227_train.csv", "data/lihkg_posts_20190227_val.csv", "data/lihkg_posts_20190227_test.csv"])
else:
    train_ds, val_ds, test_ds = (reader.read(DATA_ROOT / fname) for fname in ["data/lihkg_posts_train.csv", "data/lihkg_posts_val.csv", "data/lihkg_posts_test.csv"])


14609it [00:05, 2767.26it/s]
3661it [00:01, 2846.54it/s]
2035it [00:00, 3016.07it/s]


train_ds is a list of "instance"

In [16]:
len(train_ds)

14609

In [17]:
train_ds[:10]

[<allennlp.data.instance.Instance at 0x7fd710b27b70>,
 <allennlp.data.instance.Instance at 0x7fd710b27fd0>,
 <allennlp.data.instance.Instance at 0x7fd71109f1d0>,
 <allennlp.data.instance.Instance at 0x7fd71109f2b0>,
 <allennlp.data.instance.Instance at 0x7fd71109f5c0>,
 <allennlp.data.instance.Instance at 0x7fd71109fc18>,
 <allennlp.data.instance.Instance at 0x7fd7110a33c8>,
 <allennlp.data.instance.Instance at 0x7fd7110a3780>,
 <allennlp.data.instance.Instance at 0x7fd7110a3ba8>,
 <allennlp.data.instance.Instance at 0x7fd7110a8278>]

Let's see what is inside an instance

In [18]:
train_ds[0].fields

{'tokens': <allennlp.data.fields.text_field.TextField at 0x7fd7a5212208>,
 'label': <allennlp.data.fields.array_field.ArrayField at 0x7fd7a3880ef0>}

In [19]:
vars(train_ds[0].fields["tokens"])

{'tokens': [好, ##想, ##玩, ##三, ##國, ##志, ##11],
 '_token_indexers': {'tokensIdxers1': <allennlp.data.token_indexers.wordpiece_indexer.PretrainedBertIndexer at 0x7fd7a0640278>},
 '_indexed_tokens': None,
 '_indexer_name_to_indexed_token': None}

In [20]:
vars(train_ds[0].fields['label'])

{'array': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.]), 'padding_value': 0}

### Prepare vocabulary

We don't need to build the vocab: all that is handled by the token indexer

In [21]:
vocab = Vocabulary()

### Prepare iterator

The iterator is responsible for batching the data and preparing it for input into the model. We'll use the BucketIterator that batches text sequences of smilar lengths together.

In [22]:
from allennlp.data.iterators import BucketIterator

In [23]:
iterator = BucketIterator(batch_size=config.batch_size, 
                          #sorting_keys=[("tokens", "num_tokens")],
                          sorting_keys=[("tokens", "tokensIdxers1_length")], #same as num_token since only 1 tokenIndexer
                          max_instances_in_memory = 1000,
                         )

We need to tell the iterator how to numericalize the text data. We do this by passing the vocabulary to the iterator. This step is easy to forget so be careful! 

In [24]:
iterator.index_with(vocab)

### Read sample

In [25]:
batch = next(iter(iterator(train_ds)))

In [26]:
batch.keys()

dict_keys(['tokens', 'label'])

In [27]:
batch["tokens"].keys()

dict_keys(['tokensIdxers1', 'tokensIdxers1-offsets', 'tokensIdxers1-type-ids', 'mask'])

In [28]:
batch["tokens"]["tokensIdxers1"]

tensor([[   101,  22257,   4163,  ...,      0,      0,      0],
        [   101,   8595, 117791,  ...,      0,      0,      0],
        [   101,   2796, 112987,  ...,      0,      0,      0],
        ...,
        [   101,   2104, 112440,  ...,      0,      0,      0],
        [   101,    164, 117483,  ...,      0,      0,      0],
        [   101,    113, 117673,  ...,    102,      0,      0]])

In [29]:
batch["tokens"]["tokensIdxers1"].shape

torch.Size([64, 32])

In [30]:
batch['label'].shape

torch.Size([64, 36])

# Prepare Model

In [31]:
import torch
import torch.nn as nn
import torch.optim as optim

In [32]:
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder, PytorchSeq2VecWrapper
from allennlp.nn.util import get_text_field_mask
from allennlp.models import Model
from allennlp.modules.text_field_embedders import TextFieldEmbedder

class BaselineModel(Model):
    def __init__(self, word_embeddings: TextFieldEmbedder,
                 encoder: Seq2VecEncoder,
                 out_sz: int=ncats):
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)
        self.loss = nn.BCEWithLogitsLoss()
        
    def forward(self, tokens: Dict[str, torch.Tensor],
                label: torch.Tensor) -> torch.Tensor:
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        state = self.encoder(embeddings, mask)
        class_logits = self.projection(state)
        
        output = {"class_logits": class_logits}
        output["loss"] = self.loss(class_logits, label)

        return output

### Prepare embeddings

In [33]:
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders.bert_token_embedder import PretrainedBertEmbedder

bert_embedder = PretrainedBertEmbedder(
        pretrained_model=str(DATA_ROOT / ("pretrain/%s.tar.gz"% bert_flavour)),
        requires_grad=True, #Finetune BERT weight or not
        top_layer_only=True, # if False, embedding is weighted average of all layers in BERT
)
word_embeddings: TextFieldEmbedder = BasicTextFieldEmbedder({"tokensIdxers1": bert_embedder},
                                                            # we'll be ignoring masks so we'll need to set this to True
                                                           allow_unmatched_keys = True)

# this is because the bert_indexer generate more then 1 output
# ordinary indexer just gen 1 output with key the same as the indexer i.e. "tokensIdxess1"
# bert indexer gen output with key 'tokensIdxers1', 'tokensIdxers1-offsets', 'tokensIdxers1-type-ids', 'mask'

In [34]:
BERT_DIM = word_embeddings.get_output_dim()

class BertSentencePooler(Seq2VecEncoder):
    def forward(self, embs: torch.tensor, 
                mask: torch.tensor=None) -> torch.tensor:
        # extract first token tensor
        return embs[:, 0]
    
    @overrides
    def get_output_dim(self) -> int:
        return BERT_DIM
    
encoder = BertSentencePooler(vocab)

Notice how simple and modular the code for initializing the model is. All the complexity is delegated to each component.

In [35]:
model = BaselineModel(
    word_embeddings, 
    encoder, 
)

In [36]:
if USE_GPU: model.cuda()
else: model

# Basic sanity checks

In [39]:
batch = nn_util.move_to_device(batch, 0 if USE_GPU else -1)

In [40]:
tokens = batch["tokens"]

In [41]:
tokens

{'tokensIdxers1': tensor([[   101,  22257,   4163,  ...,      0,      0,      0],
         [   101,   8595, 117791,  ...,      0,      0,      0],
         [   101,   2796, 112987,  ...,      0,      0,      0],
         ...,
         [   101,   2104, 112440,  ...,      0,      0,      0],
         [   101,    164, 117483,  ...,      0,      0,      0],
         [   101,    113, 117673,  ...,    102,      0,      0]]),
 'tokensIdxers1-offsets': tensor([[ 1,  2,  3,  ...,  0,  0,  0],
         [ 1,  2,  3,  ...,  0,  0,  0],
         [ 1,  2,  3,  ...,  0,  0,  0],
         ...,
         [ 1,  2,  3,  ...,  0,  0,  0],
         [ 1,  2,  3,  ...,  0,  0,  0],
         [ 1,  2,  3,  ..., 27, 28,  0]]),
 'tokensIdxers1-type-ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]),
 'mask': tensor([[1, 1, 1,  ..., 0, 0,

In [42]:
mask = get_text_field_mask(tokens)
mask

tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 0]])

In [43]:
embeddings = model.word_embeddings(tokens)
state = model.encoder(embeddings, mask)
class_logits = model.projection(state)
class_logits

tensor([[ 0.0933, -0.0759, -0.0174,  ..., -0.5220,  0.4285,  0.3026],
        [-0.0986,  0.0405,  0.0439,  ..., -0.4200,  0.3379,  0.0592],
        [-0.1062,  0.0241,  0.1827,  ..., -0.6862,  0.1846,  0.2976],
        ...,
        [-0.1739, -0.0776,  0.4444,  ..., -0.4775,  0.3417,  0.1587],
        [-0.0259, -0.0951,  0.0562,  ..., -0.6593,  0.2696,  0.0848],
        [-0.1038, -0.2220,  0.0832,  ..., -0.3604,  0.2358,  0.1642]],
       grad_fn=<AddmmBackward>)

In [44]:
model(**batch)

{'class_logits': tensor([[ 0.1668,  0.0560,  0.0265,  ..., -0.2267,  0.1827,  0.1101],
         [-0.3499, -0.0192,  0.2075,  ..., -0.4669,  0.3533,  0.2253],
         [-0.1326, -0.0518,  0.3157,  ..., -0.5325,  0.2581,  0.0155],
         ...,
         [-0.2375, -0.0828,  0.1574,  ..., -0.5315,  0.0620, -0.0714],
         [-0.2675,  0.1637,  0.2599,  ..., -0.5501,  0.1942,  0.2000],
         [-0.0214, -0.0584,  0.2488,  ..., -0.6487,  0.3309,  0.1936]],
        grad_fn=<AddmmBackward>),
 'loss': tensor(0.7198, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)}

In [45]:
loss = model(**batch)["loss"]

In [46]:
loss

tensor(0.7201, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)

In [47]:
loss.backward()

Our "encoder" just takes the 1st token's embedding, thus has no free parameters

In [48]:
[x.grad for x in list(model.encoder.parameters())]

[]

# Train

In [49]:
optimizer = optim.Adam(model.parameters(), lr=config.lr)

In [52]:
from allennlp.training.trainer import Trainer

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_ds,
    validation_dataset=val_ds,
    cuda_device=0 if USE_GPU else -1,
    num_epochs=config.epochs,
    patience=10,
)

In [53]:
weight_filename = "chkpoints/%s-finetune.pth" % bert_flavour

In [54]:
if not retrain:
    with open(DATA_ROOT / weight_filename, 'rb') as f:
        if USE_GPU:
            model.load_state_dict(torch.load(f))
        else:
            model.load_state_dict(torch.load(f, map_location='cpu'))


In [55]:
if retrain:
    metrics = trainer.train()
    with open(DATA_ROOT / weight_filename, 'wb') as f:
        torch.save(model.state_dict(), f)
    print (metrics)


# Generating Predictions in bulk

In [56]:
from allennlp.data.iterators import DataIterator
from tqdm import tqdm
from scipy.special import expit # the sigmoid function

def tonp(tsr): return tsr.detach().cpu().numpy()

class Predictor:
    def __init__(self, model: Model, iterator: DataIterator,
                 cuda_device: int=-1) -> None:
        self.model = model
        self.iterator = iterator
        self.cuda_device = cuda_device
        
    def _extract_data(self, batch) -> np.ndarray:
        out_dict = self.model(**batch)
        return expit(tonp(out_dict["class_logits"]))
    
    def predict(self, ds: Iterable[Instance]) -> np.ndarray:
        pred_generator = self.iterator(ds, num_epochs=1, shuffle=False)
        self.model.eval()
        pred_generator_tqdm = tqdm(pred_generator,
                                   total=self.iterator.get_num_batches(ds))
        preds = []
        with torch.no_grad():
            for batch in pred_generator_tqdm:
                batch = nn_util.move_to_device(batch, self.cuda_device)
                preds.append(self._extract_data(batch))
        return np.concatenate(preds, axis=0)

In [57]:
from allennlp.data.iterators import BasicIterator
# iterate over the dataset without changing its order
seq_iterator = BasicIterator(batch_size=64)
seq_iterator.index_with(vocab)

In [59]:
predictor = Predictor(model, seq_iterator, cuda_device=0 if USE_GPU else -1)
# train_preds = predictor.predict(train_ds) 
test_preds = predictor.predict(test_ds) 


  0%|          | 0/32 [00:00<?, ?it/s][A
  3%|▎         | 1/32 [00:04<02:21,  4.55s/it][A
  6%|▋         | 2/32 [00:07<02:02,  4.08s/it][A
  9%|▉         | 3/32 [00:10<01:47,  3.71s/it][A
 12%|█▎        | 4/32 [00:14<01:44,  3.72s/it][A
 16%|█▌        | 5/32 [00:18<01:42,  3.81s/it][A
 19%|█▉        | 6/32 [00:21<01:37,  3.74s/it][A
 22%|██▏       | 7/32 [00:24<01:28,  3.52s/it][A
 25%|██▌       | 8/32 [00:28<01:27,  3.63s/it][A
 28%|██▊       | 9/32 [00:31<01:15,  3.30s/it][A
 31%|███▏      | 10/32 [00:33<01:07,  3.09s/it][A
 34%|███▍      | 11/32 [00:36<01:03,  3.04s/it][A
 38%|███▊      | 12/32 [00:39<01:01,  3.08s/it][A
 41%|████      | 13/32 [00:42<00:59,  3.11s/it][A
 44%|████▍     | 14/32 [00:46<00:57,  3.22s/it][A
 47%|████▋     | 15/32 [00:49<00:53,  3.16s/it][A
 50%|█████     | 16/32 [00:52<00:51,  3.22s/it][A
 53%|█████▎    | 17/32 [00:56<00:50,  3.39s/it][A
 56%|█████▋    | 18/32 [01:00<00:48,  3.47s/it][A
 59%|█████▉    | 19/32 [01:03<00:44,  3.46s/it]

In [60]:
test_preds[0]

array([3.0256435e-03, 5.2302953e-09, 6.6384516e-09, 2.7938961e-04,
       5.1119551e-04, 9.0638740e-04, 1.3369879e-04, 2.3170123e-04,
       2.8342550e-04, 9.9209142e-01, 1.9863427e-04, 3.5612855e-05,
       4.8792692e-05, 1.2237292e-04, 2.4284588e-04, 1.0908637e-04,
       8.8957328e-05, 2.2884137e-04, 2.1211820e-04, 1.9855321e-04,
       2.1967693e-04, 1.9203312e-03, 9.6240437e-05, 3.8534845e-04,
       2.2629087e-05, 2.3867024e-04, 1.2670888e-04, 2.8933393e-05,
       1.3840353e-04, 1.5202230e-04, 2.9313882e-04, 8.0075715e-06,
       7.5074902e-05, 7.0124120e-04, 4.0976520e-05, 2.5300418e-05],
      dtype=float32)

# Generating Predictions per instance

In [61]:
from allennlp.predictors.sentence_tagger import SentenceTaggerPredictor

model.eval()
tagger = SentenceTaggerPredictor(model, reader)

def predict_cat(s):
    logits = tagger.predict(s)['class_logits']
    probs = expit(logits)
    cat_ranked = probs.argsort()[::-1]
    for cat in cat_ranked[:5]:
        print ("%2d %s %0.2f" % (cat+1, cat_ids.get(cat+1,'Nil'), probs[cat]))
    return probs

In [62]:
predict_cat('有無巴絲想創業賣暖手杯？')

31 創意台 0.37
14 上班台 0.19
30 感情台 0.16
 1 吹水台 0.13
15 財經台 0.03


array([1.30229738e-01, 9.68756547e-08, 7.82302534e-08, 4.29208223e-03,
       4.03621973e-04, 1.62166415e-03, 3.20505841e-03, 1.61382171e-03,
       2.56355792e-03, 1.33242549e-03, 5.55419318e-04, 2.57377612e-03,
       2.09480984e-02, 1.89542223e-01, 3.44811291e-02, 2.37401844e-03,
       6.74083538e-04, 1.44032373e-02, 9.29753838e-04, 8.42620630e-04,
       1.44957371e-03, 1.38655105e-03, 3.23537888e-03, 1.04428094e-02,
       8.03332830e-03, 6.41325793e-03, 8.64076251e-03, 1.17262867e-03,
       9.41072244e-03, 1.59753568e-01, 3.65698823e-01, 1.67630522e-04,
       3.44324442e-05, 3.72896553e-05, 5.06894585e-05, 1.68918455e-03])

In [63]:
predict_cat('中國四大偉人：毛澤東、鄧小平、習近平、貧僧')

18 學術台 0.62
 5 時事台 0.20
31 創意台 0.07
33 政事台 0.06
 1 吹水台 0.03


array([2.82041618e-02, 5.82127678e-09, 7.75350130e-09, 2.84836857e-04,
       1.99179655e-01, 1.25203246e-02, 1.38632244e-02, 1.19707607e-03,
       6.08758834e-05, 1.91235374e-04, 8.16013017e-03, 2.23782212e-03,
       6.44997313e-04, 1.68288568e-03, 1.22917348e-02, 3.67496999e-03,
       5.65681267e-04, 6.18406327e-01, 6.12195146e-03, 1.66929984e-05,
       3.92323312e-04, 1.94203670e-05, 7.96323730e-05, 7.55593419e-05,
       7.97885011e-05, 9.86057605e-05, 3.30366512e-04, 2.76948116e-04,
       2.25995666e-04, 1.55878164e-03, 6.50506960e-02, 5.58609565e-06,
       5.92488836e-02, 1.21655289e-04, 1.49689708e-05, 1.06933874e-03])

In [64]:
predict_cat("[ALL IN US] ITZY討論區(1) IT'z Different大發~~~~")

 7 娛樂台 0.57
21 音樂台 0.52
 1 吹水台 0.01
30 感情台 0.00
11 影視台 0.00


array([5.28733937e-03, 5.97462056e-09, 5.50169581e-09, 4.81850084e-05,
       1.27002943e-04, 5.00252889e-04, 5.71169596e-01, 2.75188324e-04,
       6.48161862e-05, 6.40406414e-05, 1.48426226e-03, 5.59883542e-05,
       1.14716496e-03, 1.26386932e-04, 1.83471336e-04, 1.08561192e-04,
       1.69060101e-04, 2.19638183e-04, 1.97755341e-04, 4.15588344e-05,
       5.18831785e-01, 2.37996032e-04, 5.14486406e-05, 2.55534965e-05,
       1.82427066e-05, 3.59101682e-05, 2.60833461e-04, 3.24450924e-05,
       6.08080002e-05, 1.67675952e-03, 7.90514695e-04, 2.78521552e-06,
       3.17497633e-05, 3.09394175e-05, 5.84617594e-06, 6.36811578e-05])

In [65]:
predict_cat("有無巴打食過狗糧？ 咩味？")

25 寵物台 0.46
16 飲食台 0.46
31 創意台 0.06
 1 吹水台 0.04
18 學術台 0.02


array([4.45906485e-02, 1.05464807e-08, 9.91994384e-09, 1.41109741e-03,
       1.11162316e-03, 2.50632383e-03, 7.80373077e-04, 7.57171726e-04,
       5.90223057e-04, 9.63383967e-04, 3.18183019e-04, 2.91407163e-03,
       3.38568138e-04, 1.61592621e-02, 5.02451719e-03, 4.56044820e-01,
       5.27513357e-04, 1.93788835e-02, 1.53796477e-03, 7.29677135e-04,
       9.06565308e-04, 1.73104754e-03, 5.61649490e-04, 3.32808541e-04,
       4.64066695e-01, 8.43469731e-04, 1.29416916e-04, 4.77583143e-04,
       9.43065140e-04, 4.89875325e-03, 6.28404662e-02, 9.74792941e-05,
       3.28291810e-04, 1.20680593e-04, 7.02659955e-05, 1.95548360e-03])

In [66]:
predict_cat("西野カナ fans揮手區(5) 活動休止中")

 7 娛樂台 0.86
21 音樂台 0.09
 8 動漫台 0.08
 1 吹水台 0.00
11 影視台 0.00


array([4.16704358e-03, 1.18593103e-08, 6.94446923e-09, 1.25296766e-04,
       2.43168394e-04, 2.12584806e-03, 8.62017581e-01, 8.47825680e-02,
       5.36677543e-04, 1.20059663e-03, 2.29705821e-03, 3.15451946e-04,
       4.26230405e-04, 1.79752870e-04, 1.38878525e-04, 1.66694575e-04,
       1.60906888e-04, 1.39018703e-04, 2.81078969e-04, 7.40532153e-05,
       9.47564088e-02, 4.38377990e-05, 1.01587204e-04, 1.46183797e-04,
       3.73129194e-05, 4.33523894e-05, 2.61034476e-04, 1.10513337e-04,
       2.07680987e-03, 1.64214797e-03, 2.04445787e-03, 5.43810165e-06,
       5.33113101e-05, 1.73729902e-04, 5.03013863e-06, 4.98549227e-05])

In [67]:
predict_cat("Annual Dinner 抽中左Dyson 風筒，HR打黎話要收返")

14 上班台 0.70
31 創意台 0.15
 1 吹水台 0.12
30 感情台 0.04
 5 時事台 0.02


array([1.20766512e-01, 7.80079108e-09, 7.09452324e-09, 1.55773708e-03,
       2.28576223e-02, 7.01079612e-03, 3.17047582e-03, 7.62227088e-04,
       9.17044176e-05, 1.35244559e-03, 1.27855468e-03, 1.45804302e-03,
       1.51424715e-02, 6.98473609e-01, 6.85537399e-03, 8.99559108e-03,
       1.13343704e-04, 1.27359278e-02, 4.46021036e-03, 2.48043903e-04,
       2.02212525e-04, 5.48392768e-03, 4.14365124e-05, 2.67368337e-04,
       1.04350359e-04, 4.57699812e-03, 7.52210175e-05, 5.02096609e-05,
       8.49452806e-04, 4.42502162e-02, 1.45827749e-01, 8.49899383e-06,
       2.82615207e-04, 8.42950162e-05, 1.39365514e-04, 3.89645649e-03])