# Punctuation NER

In [1]:
# Forgebox Imports
from forgebox.imports import *
from forgebox.category import Category
import pytorch_lightning as pl
from transformers import AutoTokenizer, BertForTokenClassification
from transformers import pipeline
from typing import List
import re

In [2]:
from gc_utils.env import sys_loc
# Corpus directory, built under whatever root sys_loc('DATA') returns
# (sys_loc comes from gc_utils; its lookup rules are not shown in this notebook).
DATA = sys_loc('DATA')/"nlp"/"zh"/"daizhigev20"

## Read Metadata

In [3]:
# Load the corpus metadata table stored as meta.csv inside the corpus directory.
META = pd.read_csv(DATA/"meta.csv")

In [4]:
# Keep only the rows whose `charspan` is below 15 (column meaning comes from
# meta.csv — TODO confirm), then shuffle them and renumber the index.
# The shuffle is seeded so that Restart & Run All reproduces the same file order.
LABELS = META.query("charspan<15").sample(frac=1., random_state=42).reset_index(drop=True)

In [6]:
# One character that is neither a word character nor whitespace.
punkt_regex = r'[^\w\s]'

def position_of_all_punctuation(x):
    """
    Return the start offset of every ``punkt_regex`` match in ``x``,
    in left-to-right order.
    """
    offsets = []
    for match in re.finditer(punkt_regex, x):
        offsets.append(match.start())
    return offsets

# Simplify the punctuation: map ASCII punctuation and variant quote/bracket
# characters onto a reduced set of target labels.
# Several keys share one target (e.g. four different quote characters all map
# to '"'), so the values of this dict contain repeats.
# NOTE: despite the name, the keys are not only English punctuation.
eng_punkt_to_cn_dict = {
    ".": "。",
    ",": "，",
    ":": "：",
    ";": "；",
    "?": "？",
    "!": "！",
    "“": "\"",
    "”": "\"",
    "‘": "\'",
    "’": "\'",
    "「": "（",
    "」": "）",
    "『": "\"",
    "』": "\"",
    "（": "（",
    "）": "）",
    "《": "【",
    "》": "】",
    "［": "【",
    "］": "】",
    }

def translate_eng_punkt_to_cn(char):
    """
    Normalise one label character to the reduced punctuation label set.

    - "O" and characters that already are a target label are returned unchanged.
    - A key of ``eng_punkt_to_cn_dict`` is replaced by its mapped value.
    - Anything else falls back to "。".
    """
    if char == "O" or char in eng_punkt_to_cn_dict.values():
        return char
    mapped = eng_punkt_to_cn_dict.get(char)
    return "。" if mapped is None else mapped

def punct_ner_pair(sentence):
    """
    Turn a punctuated string into a per-character labelling frame.

    Every punctuation character (as matched by ``punkt_regex``) is removed from
    the text and becomes the label of the character immediately before it.
    Characters not followed by punctuation are labelled "O". When several
    punctuation characters follow the same character, the last one wins.

    Punctuation that appears before any text character has nothing to attach
    to and is dropped (previously it was written to index -1, i.e. onto the
    last character, and an all-punctuation input raised IndexError).

    Returns:
        DataFrame with column ``x`` (characters) and ``y`` (normalised labels),
        one row per remaining character.
    """
    positions = position_of_all_punctuation(sentence)
    x = re.sub(punkt_regex, '', sentence)
    y = ["O"] * len(x)

    for i, p in enumerate(positions):
        # i punctuation characters were removed before p, so the preceding
        # text character sits at p - i - 1 in the stripped string
        target = p - i - 1
        if target < 0:
            # leading punctuation: no preceding character to label
            continue
        y[target] = sentence[p]
    p_df = pd.DataFrame({"x": list(x), "y": y})
    p_df["y"] = p_df["y"].apply(translate_eng_punkt_to_cn)
    return p_df

In [7]:
# "O" (no punctuation) first, then every distinct target label exactly once.
# Several source characters share a target, so the dict values repeat;
# dict.fromkeys drops the repeats while keeping first-seen order, so the
# label vocabulary (and the classifier head sized from it) has no dead classes.
ALL_LABELS = ["O"] + list(dict.fromkeys(eng_punkt_to_cn_dict.values()))

In [9]:
# Label vocabulary wrapper (forgebox Category); later cells rely on its
# c2i / i2c mappings and on len(cates) for the number of classes.
cates = Category(ALL_LABELS)

In [10]:
class PunctDataset(Dataset):
    """
    Streaming dataset that cuts fixed-size character windows out of text files
    and converts each window into a (characters, punctuation labels) pair.

    Items are produced sequentially: ``__getitem__`` ignores the index it is
    given and advances internal cursors instead, so the dataset is meant to be
    consumed by a non-shuffling loader (see ``dataloaders``).
    """

    def __init__(
        self,
        data_dir: Path,
        filelist: List[str],
        num_threads: int = 8,
        length: int = 1000,
        size: int = 540
    ):
        """
        Args:
            - data_dir: directory that contains the files in ``filelist``
            - filelist: list of file names
            - num_threads: number of files held in memory simultaneously;
                successive items rotate through these slots
            - length: number of items reported by ``len()`` (one epoch)
            - size: number of raw characters (punctuation included) per slice
        """
        self.data_dir = Path(data_dir)
        self.filelist = filelist
        self.num_threads = num_threads
        self.length = length
        # One slot per concurrently held file, keyed by get_counter % num_threads.
        # (Keying by `length` raised KeyError whenever length < num_threads.)
        self.current_files = dict(enumerate([""]*num_threads))
        self.string_index = dict(enumerate([0]*num_threads))
        self.to_open_idx = 0
        self.size = size
        self.get_counter = 0
        self.return_string = False
        # Filled in by dataloaders(); collate_fn / align_offsets depend on them.
        self.tokenizer = None
        self.cates = None
        self.max_len = None

    def __len__(self):
        return self.length

    def __repr__(self):
        return f"PunctDataset: {len(self)}, on {len(self.filelist)} files"

    def new_file(self, idx_mod):
        """
        Read the next file of ``filelist`` into slot ``idx_mod`` and reset that
        slot's cursor. The file pointer wraps to 0 after the last file.

        Raises:
            IndexError: if ``filelist`` is empty.
        """
        if not self.filelist:
            raise IndexError("PunctDataset has an empty filelist; nothing to read")

        filename = self.filelist[self.to_open_idx]
        with open(self.data_dir/filename, "r", encoding="utf-8") as f:
            self.current_files[idx_mod] = f.read()

        # reset string_index within new article file
        self.string_index[idx_mod] = 0

        self.to_open_idx += 1

        # report progress before wrapping, so the real count is printed
        if self.to_open_idx % 500 == 0:
            print(f"went through files:\t{self.to_open_idx}")

        # reset to open article file index
        if self.to_open_idx >= len(self.filelist):
            self.to_open_idx = 0

    def __getitem__(self, idx):
        """
        Return the next (characters, labels) pair; ``idx`` is not used.

        The slot is chosen by ``get_counter % num_threads``. When the slot's
        text is exhausted a new file is loaded; empty files are skipped, trying
        each file in ``filelist`` at most once per call.
        """
        idx_mod = self.get_counter % self.num_threads

        attempts = 0
        max_attempts = max(len(self.filelist), 1)
        while (
            self.string_index[idx_mod] >= len(self.current_files[idx_mod])
            and attempts < max_attempts
        ):
            self.new_file(idx_mod)
            attempts += 1
        string_idx = self.string_index[idx_mod]

        # slicing a sentence
        sentence = self.current_files[idx_mod][string_idx:string_idx+self.size]

        # move the string_index within current article file
        self.string_index[idx_mod] += self.size

        # move the get_counter
        self.get_counter += 1
        p_df = punct_ner_pair(sentence)
        return list(p_df.x), list(p_df.y)

    def align_offsets(
        self,
        inputs,
        text_labels: List[List[str]],
        words: List[List[str]]
    ):
        """
        Add a ``labels`` tensor to the tokenizer output, aligned token by token.

        inputs: output of the tokenizer
        text_labels: labels in form of list of list of strings
        words: words in form of list of list of strings

        Tokens without a source word (``word_ids`` gives None) keep the
        label -100. When ``return_string`` is True, the label strings and the
        source words are attached too, under ``text_labels`` and ``word``.
        """
        labels = torch.full_like(inputs.input_ids, -100, dtype=torch.long)
        text_lables_array = np.empty(labels.shape, dtype=object)
        words_array = np.empty(labels.shape, dtype=object)

        for row_id in range(inputs.input_ids.shape[0]):
            row_labels = text_labels[row_id]
            for idx, pos in enumerate(inputs.word_ids(row_id)):
                if pos is None:
                    continue
                # pos indexes the word list, so bound it by that list's length
                if pos < len(row_labels):
                    labels[row_id, idx] = self.cates.c2i[row_labels[pos]]
                    if self.return_string:
                        text_lables_array[row_id, idx] = row_labels[pos]
                        words_array[row_id, idx] = words[row_id][pos]

        inputs['labels'] = labels
        if self.return_string:
            inputs['text_labels'] = text_lables_array.tolist()
            inputs['word'] = words_array.tolist()
        return inputs

    def collate_fn(self, data):
        """
        data: list of tuple (characters, labels), as returned by __getitem__

        Tokenizes the batch (already split into characters) and attaches the
        aligned labels. Requires ``dataloaders`` to have been called first.
        """
        if self.tokenizer is None or self.cates is None:
            raise RuntimeError(
                "tokenizer/cates are not set; build the loader through "
                "PunctDataset.dataloaders() before collating"
            )
        words, text_labels = zip(*data)

        inputs = self.tokenizer(
            list(words),
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=self.max_len,
            is_split_into_words=True,
            return_offsets_mapping=True,
            add_special_tokens=False,
        )
        return self.align_offsets(inputs, text_labels, words)

    def dataloaders(self, tokenizer, cates, max_len: int = 512, batch_size: int = 32):
        """
        Store the tokenizer, label vocabulary and max token length on the
        dataset, then return a sequential (non-shuffling) DataLoader over it.
        """
        self.tokenizer = tokenizer
        self.cates = cates
        self.max_len = max_len
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=self.collate_fn
        )

    def split(self, ratio: float = 0.9):
        """
        Split the dataset into train and valid

        Files are shuffled on a copy (the caller's list is left untouched) and
        divided by ``ratio``. The epoch length is divided the same way; the
        valid length is the remainder, so the two always add up to ``length``.
        """
        shuffled = list(self.filelist)
        np.random.shuffle(shuffled)
        split_idx = int(len(shuffled)*ratio)
        train_length = int(self.length*ratio)
        train_dataset = PunctDataset(
            self.data_dir,
            shuffled[:split_idx],
            num_threads=self.num_threads,
            length=train_length,
            size=self.size,
        )
        valid_dataset = PunctDataset(
            self.data_dir,
            shuffled[split_idx:],
            num_threads=self.num_threads,
            # remainder instead of int(length*(1-ratio)), which loses an item
            # to float rounding (e.g. 10000 * (1 - 0.9) -> 999)
            length=self.length - train_length,
            size=self.size,
        )
        return train_dataset, valid_dataset

Create the dataset object

* `length`: the number of samples in one epoch
* `size`: the number of raw characters per sample (the sequence length before punctuation is stripped)
* `num_threads`: the number of files held in memory at the same time

In [11]:
# Build one dataset over all selected files (10000 samples per epoch, 512 raw
# characters per sample, 8 files in memory), then split files and epoch length 90/10.
ds = PunctDataset(DATA, list(LABELS.filepath), num_threads=8, length=10000, size=512)
train_ds, valid_ds = ds.split(0.9)

### lightning data module

In [12]:
class PunctDataModule(pl.LightningDataModule):
    """
    Lightning data module that wraps an already split pair of datasets and
    builds their loaders with a shared tokenizer and label vocabulary.
    """

    def __init__(self, train_ds, valid_ds, tokenizer, cates,
                 max_len=512, batch_size=32):
        """
        Args:
            - train_ds, valid_ds: datasets exposing
                ``dataloaders(tokenizer, cates, max_len, batch_size)``
            - tokenizer: tokenizer handed to the datasets' loaders
            - cates: label vocabulary handed to the datasets' loaders
            - max_len: maximum token length per sample
            - batch_size: training batch size (validation uses 4x this value)
        """
        super().__init__()
        self.train_ds, self.valid_ds = train_ds, valid_ds
        self.tokenizer = tokenizer
        self.cates = cates
        self.max_len = max_len
        self.batch_size = batch_size

    def split_data(self):
        """Return the (train, valid) datasets held by this module."""
        # use the instance attributes, not the notebook-level globals
        return self.train_ds, self.valid_ds

    def train_dataloader(self):
        return self.train_ds.dataloaders(
            self.tokenizer,
            self.cates,
            self.max_len,
            self.batch_size,
        )

    def val_dataloader(self):
        # validation batches are 4x larger than training batches
        return self.valid_ds.dataloaders(
            self.tokenizer,
            self.cates,
            self.max_len,
            self.batch_size*4)

## Load Pretrained

In [14]:
# Tokenizer matching the pretrained checkpoint loaded below.
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

Load the pretrained model with a classification head sized to the number of label categories

In [15]:
# Token-classification model on top of bert-base-chinese; the head has one
# output per entry of the label vocabulary.
model = BertForTokenClassification.from_pretrained("bert-base-chinese", num_labels=len(cates),)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForTokenClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-c

In [16]:
# Bundle both datasets with the tokenizer and label vocabulary;
# max_len keeps its default of 512.
data_module = PunctDataModule(train_ds, valid_ds, tokenizer, cates,
                              batch_size=32,)

### Run data pipeline

In [17]:
# Smoke test: pull one validation batch through the whole data pipeline.
# Note this advances the validation dataset's internal cursors.
inputs = next(iter(data_module.val_dataloader()))

In [18]:
# (batch, padded token length) of the sample batch
inputs.input_ids.shape

torch.Size([128, 464])

In [19]:
# labels are created with the same shape as input_ids
inputs.labels.shape

torch.Size([128, 464])

In [20]:
# Optional inspection widget, left disabled. It reads inputs['word'] and
# inputs['text_labels'], which the dataset only attaches when its
# return_string flag is True.
# @interact
# def view_label(idx=range(0,31)):
#     for x,y in zip(inputs['word'][idx], inputs['text_labels'][idx]):
#         print(f"{x}-{y}", end="\t")

## NER training module

In [21]:
from forgebox.thunder.callbacks import DataFrameMetricsCallback
from forgebox.hf.train import NERModule

In [22]:
# Wrap the token-classification model in forgebox's Lightning module for NER training.
module = NERModule(model)

In [23]:
# Keep the two checkpoints with the lowest validation loss.
# NOTE(review): dirpath is a hard-coded absolute path; adjust it per machine.
save_callback = pl.callbacks.ModelCheckpoint(
    dirpath="/GCI/transformers/weights/punkt_ner/",
    save_top_k=2,
    verbose=True,
    monitor='val_loss',
    mode='min',
)
# Metrics display callback from forgebox (behaviour defined outside this notebook).
df_show = DataFrameMetricsCallback()

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Override the `configure_optimizers` method of `NERModule` to use discriminative learning rates

In [24]:
def configure_optimizers(self):
    """
    Build an Adam optimizer with discriminative learning rates:
    a small rate for the BERT backbone and a larger one for the
    classification head.
    """
    backbone_lr, head_lr = 5e-6, 1e-3
    backbone_group = dict(params=self.model.bert.parameters(), lr=backbone_lr)
    head_group = dict(params=self.model.classifier.parameters(), lr=head_lr)
    # the optimizer-level lr is only a default; both groups set their own
    return torch.optim.Adam([backbone_group, head_group], lr=head_lr)

# Monkey-patch the class, so every NERModule instance (including the one
# created above) picks up the new optimizer setup.
NERModule.configure_optimizers = configure_optimizers

Trainer

In [25]:
# Train for up to 100 epochs with the metrics and checkpoint callbacks.
# NOTE(review): `gpus=[0]` is the older pytorch_lightning way to select a
# device; newer releases use `accelerator`/`devices` — confirm against the
# installed version.
trainer = pl.Trainer(
    gpus=[0],
    max_epochs=100,
    callbacks=[df_show, save_callback],
    )

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [None]:
# Long-running: starts training, drawing batches from the data module.
trainer.fit(module, datamodule=data_module)

## Load the best model

In [29]:
# Reload the best checkpoint recorded by save_callback (lowest val_loss).
# `model` is passed through because the module's constructor takes it.
module = module.load_from_checkpoint(save_callback.best_model_path, model=model)

In [28]:
# Store the label mappings in the model config so they are saved with the
# model; presumably the pipeline below reads them to name entities — confirm.
module.model.config.id2label = dict(enumerate(cates.i2c))
module.model.config.label2id = cates.c2i.dict

In [35]:
from transformers import pipeline

In [40]:
# Switch to inference mode and move the weights to the CPU in one chain.
module.model = module.model.eval().cpu()

## Push to model hub

In [32]:
# Model hub repository id that the model and tokenizer are pushed to.
TAG = "raynardj/classical-chinese-punctuation-guwen-biaodian"

In [33]:
# Upload the model weights and config. NOTE(review): the push output can echo
# the access token inside the remote URL; clear it before sharing the notebook.
module.model.push_to_hub(TAG)

Upload file pytorch_model.bin:   0%|          | 32.0k/388M [00:00<?, ?B/s]

To https://user:<REDACTED_TOKEN>@huggingface.co/raynardj/classical-chinese-punctuation-guwen-biaodian
   da1b1fa..163772b  main -> main



'https://huggingface.co/raynardj/classical-chinese-punctuation-guwen-biaodian/commit/163772b14564fa2930b1460f48be30fa7c9f8438'

In [34]:
# Upload the tokenizer files to the same repository as the model.
tokenizer.push_to_hub(TAG)

To https://user:<REDACTED_TOKEN>@huggingface.co/raynardj/classical-chinese-punctuation-guwen-biaodian
   163772b..c83256b  main -> main



'https://huggingface.co/raynardj/classical-chinese-punctuation-guwen-biaodian/commit/c83256b9ba08883a91c78512cce496b3cebe27a5'

In [36]:
# Token-classification pipeline over the fine-tuned model; mark_sentence below
# reads the 'end' and 'entity' fields of each result it returns.
ner = pipeline("ner",module.model,tokenizer=tokenizer)

In [37]:
def mark_sentence(x: str):
    """
    Punctuate ``x``: run the ``ner`` pipeline and splice every predicted
    entity into the text right after the character span it was predicted for.
    """
    chars = list(x)
    shift = 0  # number of marks already spliced in
    for entity in ner(x):
        at = entity['end'] + shift
        chars[at:at] = [entity['entity']]
        shift += 1
    return "".join(chars)

In [42]:
# Demo: punctuate a short unpunctuated classical Chinese passage.
mark_sentence("""是书虽称文粹实与地志相表里东南文献多借是以有征与范成大呉郡志相辅而行亦如骖有靳矣乾隆四十二年三月恭校上""")

'是书虽称文粹，实与地志相表里。东南文献多借。是以有征与范成大呉郡志相辅而行，亦如骖有靳矣。乾隆四十二年三月，恭校上。'

In [47]:
mark_sentence("""郡邑置夫子庙于学以嵗时释奠盖自唐贞观以来未之或改我宋有天下因其制而损益之姑苏当浙右要区规模尤大更建炎戎马荡然无遗虽修学宫于荆榛瓦砾之余独殿宇未遑议也每春秋展礼于斋庐已则置不问殆为阙典今寳文阁直学士括苍梁公来牧之明年实绍兴十有一禩也二月上丁修祀既毕乃愓然自咎揖诸生而告之曰天子不以汝嘉为不肖俾再守兹土顾治民事神皆守之职惟是夫子之祀教化所基尤宜严且谨而拜跪荐祭之地卑陋乃尔其何以掲防妥灵汝嘉不敢避其责曩常去此弥年若有所负尚安得以罢輭自恕复累后人乎他日或克就绪愿与诸君落之于是谋之僚吏搜故府得遗材千枚取赢资以给其费鸠工庀役各举其任嵗月讫工民不与知像设礼器百用具修至于堂室廊序门牖垣墙皆一新之""")

'郡邑，置夫子庙于学，以嵗时释奠。盖自唐贞观以来，未之或改。我宋有天下因其制而损益之。姑苏当浙右要区，规模尤大，更建炎戎马，荡然无遗。虽修学宫于荆榛瓦砾之余，独殿宇未遑议也。每春秋展礼于斋庐，已则置不问，殆为阙典。今寳文阁直学士括苍梁公来牧之。明年，实绍兴十有一禩也。二月，上丁修祀既毕，乃愓然自咎，揖诸生而告之曰"天子不以汝嘉为不肖，俾再守兹土，顾治民事，神皆守之职。惟是夫子之祀，教化所基，尤宜严且谨。而拜跪荐祭之地，卑陋乃尔。其何以掲防妥灵？汝嘉不敢避其责。曩常去此弥年，若有所负，尚安得以罢輭自恕，复累后人乎！他日或克就绪，愿与诸君落之。于是谋之，僚吏搜故府，得遗材千枚，取赢资以给其费。鸠工庀役，各举其任。嵗月讫，工民不与知像，设礼器，百用具修。至于堂室。廊序。门牖。垣墙，皆一新之。'