In [5]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Character Language Model of CIM

This notebook implements the character language model using the Hugging Face
transformers implementations of BERT and BART. It includes:
  - A custom top_k_top_p_filtering function (Option B)
  - A custom Damerau–Levenshtein distance implementation (removes the dependency on an external package)
  
"""

import os
import json
import copy
import argparse
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
import multiprocessing

# Upper bound on the number of character positions (used as the decoder's
# max_position_embeddings and as the tokenizer's default maximum length).
DEFAULT_MAX_CHARACTER_POSITIONS = 64

# Character vocabulary: digits, lowercase letters, then punctuation/operators.
# The punctuation piece is a raw string to avoid escape sequence warnings.
char_tokens = list(
    "0123456789"
    "abcdefghijklmnopqrstuvwxyz"
    r"+-*/^.,;:=!?'()[]{}&"
)

# Special tokens in fairseq order (BOS, PAD, EOS, UNK) and their BERT counterparts.
special_tokens_fairseq = ['<s>', '<pad>', '</s>', '<unk>']
special_tokens_bert = ['[CLS]', '[PAD]', '[SEP]', '[UNK]']  # Note: CLS and SEP do not exactly match.

# --- Custom Filtering Function (Option B) ---
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    The filtering is applied along the last dimension, so ``logits`` may be a
    single vector ``(vocab,)`` or carry any number of leading batch dimensions.
    The input tensor is never modified; a filtered copy is returned.

    Args:
        logits: tensor of unnormalised scores; the last dimension is the vocabulary.
        top_k: if > 0, keep only the ``top_k`` highest-scoring entries. The value is
            clamped to ``[min_tokens_to_keep, vocab_size]``.
        top_p: if < 1.0, keep the smallest set of highest-scoring entries whose
            cumulative probability exceeds ``top_p``.
        filter_value: value written to every filtered-out position.
        min_tokens_to_keep: minimum number of entries that survive the filtering.
            At least one entry is always kept.

    Returns:
        A new tensor with the same shape as ``logits``.
    """
    # Work on a copy to avoid in-place modification of the caller's tensor.
    logits = logits.clone()
    vocab_size = logits.size(-1)
    if vocab_size == 0:
        return logits

    if top_k > 0:
        # Clamp so that torch.topk never receives k larger than the vocabulary.
        top_k = min(max(top_k, min_tokens_to_keep), vocab_size)
        kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_value, filter_value)

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first entry that crosses the threshold is still kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        # Always keep the best entry, even when min_tokens_to_keep is 0.
        sorted_indices_to_remove[..., :max(min_tokens_to_keep, 1)] = False
        # Map the mask from sorted order back to vocabulary order along the
        # vocabulary axis (the last one), not along the batch axis.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, filter_value)
    return logits

# --- Custom Damerau–Levenshtein Implementation ---
def damerauLevenshtein(seq1, seq2, similarity=False):
    """
    Compute the Damerau–Levenshtein distance between two sequences.

    Supported edit operations are insertion, deletion, substitution and the
    transposition of two adjacent elements, each with cost 1.

    Args:
        seq1: first sequence (string, list, ... anything supporting ``len`` and indexing).
        seq2: second sequence.
        similarity: when False (default) return the integer edit distance.
            When True return a normalised similarity in ``[0, 1]`` computed as
            ``1 - distance / max(len(seq1), len(seq2))``; two empty sequences
            have similarity 1.0.

    Returns:
        ``int`` distance, or ``float`` similarity when ``similarity`` is True.
    """
    len1 = len(seq1)
    len2 = len(seq2)

    # d[i][j] = distance between the first i elements of seq1 and the first j of seq2.
    d = [[0] * (len2 + 1) for _ in range(len1 + 1)]
    for i in range(len1 + 1):
        d[i][0] = i
    for j in range(len2 + 1):
        d[0][j] = j

    for i in range(1, len1 + 1):
        for j in range(1, len2 + 1):
            cost = 0 if seq1[i - 1] == seq2[j - 1] else 1
            best = min(
                d[i - 1][j] + 1,         # deletion
                d[i][j - 1] + 1,         # insertion
                d[i - 1][j - 1] + cost   # substitution (or match)
            )
            if i > 1 and j > 1 and seq1[i - 1] == seq2[j - 2] and seq1[i - 2] == seq2[j - 1]:
                # Adjacent transposition. When cost is 0 all four elements are equal
                # and this candidate coincides with the plain diagonal path.
                best = min(best, d[i - 2][j - 2] + cost)
            d[i][j] = best

    distance = d[len1][len2]
    if not similarity:
        return distance

    longest = max(len1, len2)
    if longest == 0:
        return 1.0
    return 1.0 - distance / longest

# Confirmation message so the cell output shows that both helper functions were defined.
print("Custom filtering and Damerau-Levenshtein functions loaded.")


Custom filtering and Damerau-Levenshtein functions loaded.


In [17]:
#################################
# Classes and Model Definitions #
#################################

class BeamHypotheses(object):
    """Bounded n-best list of finished beam-search hypotheses.

    Each stored entry is a tuple ``(score, hyp, score2)`` where both scores are
    the summed log-probabilities divided by ``len(hyp) ** length_penalty``.
    """

    def __init__(self, num_beams, max_length, length_penalty, early_stopping):
        """
        Initialize n-best list of hypotheses.
        """
        self.num_beams = num_beams
        self.max_length = max_length - 1  # ignoring BOS token
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.beams = []
        # Starts very high; lowered to the true minimum as hypotheses are added.
        self.worst_score = 1e9

    def __len__(self):
        """Number of hypotheses currently stored."""
        return len(self.beams)

    def add(self, hyp, sum_logprobs, sum_logprobs2=0.0):
        """Insert ``hyp`` if there is room or it beats the current worst score."""
        normalizer = len(hyp) ** self.length_penalty
        score = sum_logprobs / normalizer
        score2 = sum_logprobs2 / normalizer

        if not (len(self) < self.num_beams or score > self.worst_score):
            return

        self.beams.append((score, hyp, score2))
        if len(self) <= self.num_beams:
            self.worst_score = min(score, self.worst_score)
            return

        # Over capacity: evict the lowest-scoring entry; the runner-up becomes the worst.
        ranked = sorted((entry[0], position) for position, entry in enumerate(self.beams))
        del self.beams[ranked[0][1]]
        self.worst_score = ranked[1][0]

    def is_done(self, best_sum_logprobs, cur_len):
        """Return True when no open hypothesis can still improve the list."""
        if len(self) < self.num_beams:
            return False
        if self.early_stopping:
            return True
        best_possible = best_sum_logprobs / cur_len ** self.length_penalty
        return self.worst_score >= best_possible

# Note: We now subclass BartPreTrainedModel (as recommended) instead of the deprecated PretrainedBartModel.
class CharacterLanguageModel(transformers.BartPreTrainedModel):
    """Character-level language model: a BERT encoder feeding a BART character decoder.

    The encoder is ``bert_model.bert``; the decoder is the module passed as
    ``char_decoder``. Output logits are obtained by projecting the decoder output
    onto the decoder's own token-embedding matrix plus ``final_logits_bias``.
    """

    # Size of the process pool used for edit-distance computation.
    _ED_NUM_WORKERS = 4

    def __init__(self, args, config, bert_model, char_decoder):
        """
        Args:
            args: namespace; only ``args.train_bert`` is read here.
            config: configuration object forwarded to the base class; ``config.vocab_size``
                sizes ``final_logits_bias``.
            bert_model: object exposing the encoder as its ``.bert`` attribute.
            char_decoder: decoder module exposing ``embed_tokens``.
        """
        # Imported explicitly: ``import multiprocessing`` alone does not guarantee
        # that the ``multiprocessing.pool`` submodule is loaded as an attribute.
        from multiprocessing.pool import ThreadPool

        super().__init__(config)
        self.args = args
        self.config = config
        self.final_logits_bias = torch.nn.parameter.Parameter(torch.zeros(1, self.config.vocab_size))
        self.encoder = bert_model.bert
        self.decoder = char_decoder
        self.train_bert = args.train_bert
        if not self.train_bert:
            # Freeze the encoder when BERT is not being trained.
            for param in self.encoder.parameters():
                param.requires_grad = False
        self.bos, self.pad, self.eos, self.unk = special_tokens_fairseq
        self.ed_pool_master = ThreadPool(processes=1)
        self.ed_pool_worker = multiprocessing.Pool(processes=self._ED_NUM_WORKERS)

    def close_pools(self):
        """Terminate the thread/process pools created in ``__init__``.

        Call this when the model is no longer needed (for example before
        re-creating it in a notebook) so the worker processes do not linger.
        """
        for pool_name in ("ed_pool_master", "ed_pool_worker"):
            pool = getattr(self, pool_name, None)
            if pool is not None:
                pool.terminate()
                pool.join()

    @classmethod
    def get_bert_config(cls, bert_config_file):
        """Build a ``transformers.BertConfig`` from a BERT JSON configuration file.

        Raises:
            KeyError: if one of the expected keys is missing from the JSON file.
        """
        with open(bert_config_file, 'r', encoding='utf-8') as fd:
            config_orig = json.load(fd)
        bert_config = transformers.BertConfig(
            vocab_size=config_orig['vocab_size'],
            hidden_size=config_orig['hidden_size'],
            num_hidden_layers=config_orig['num_hidden_layers'],
            num_attention_heads=config_orig['num_attention_heads'],
            intermediate_size=config_orig['intermediate_size'],
            hidden_act=config_orig['hidden_act'],
            hidden_dropout_prob=config_orig['hidden_dropout_prob'],
            attention_probs_dropout_prob=config_orig['attention_probs_dropout_prob'],
            max_position_embeddings=config_orig['max_position_embeddings'],
            type_vocab_size=config_orig['type_vocab_size'],
            initializer_range=config_orig['initializer_range'],
            layer_norm_eps=1e-12
        )
        return bert_config

    @classmethod
    def build_bert_model(cls, bert_config):
        """Instantiate ``transformers.BertForPreTraining`` from ``bert_config``."""
        bert_model = transformers.BertForPreTraining(bert_config)
        return bert_model

    @classmethod
    def get_bert_tokenizer(cls, bert_vocab_file):
        """Create a lower-casing ``transformers.BertTokenizer`` from a vocabulary file."""
        bert_tokenizer = transformers.BertTokenizer(
            bert_vocab_file,
            do_lower_case=True,
            do_basic_tokenize=True
        )
        return bert_tokenizer

    @classmethod
    def get_char_embeddings_from_bert(cls, bert_embeddings, bert_tokenizer):
        """Build the character embedding table from rows of the BERT embedding.

        Row order is: CLS, PAD, SEP, UNK (standing in for BOS, PAD, EOS, UNK),
        followed by one row per entry of ``char_tokens``.

        Args:
            bert_embeddings: a ``torch.nn.Embedding`` or a 2-D weight tensor.
            bert_tokenizer: tokenizer providing the special token ids and
                ``convert_tokens_to_ids``.

        Returns:
            A trainable ``torch.nn.Embedding`` with ``padding_idx=1``.

        Raises:
            TypeError: if ``bert_embeddings`` is neither an Embedding nor a Tensor.
        """
        char_word_ids = [
            bert_tokenizer.cls_token_id,  # CLS -> BOS
            bert_tokenizer.pad_token_id,
            bert_tokenizer.sep_token_id,  # SEP -> EOS
            bert_tokenizer.unk_token_id,
        ]
        char_word_ids.extend(bert_tokenizer.convert_tokens_to_ids(char_tokens))
        if isinstance(bert_embeddings, torch.nn.modules.sparse.Embedding):
            # Build the index on the embedding's device so lookups also work off-CPU.
            index = torch.tensor(char_word_ids, device=bert_embeddings.weight.device)
            embedding_matrix = bert_embeddings(index).detach()
        elif isinstance(bert_embeddings, torch.Tensor):
            embedding_matrix = bert_embeddings[char_word_ids, :].detach()
        else:
            raise TypeError(
                "bert_embeddings must be a torch.nn.Embedding or a torch.Tensor, got "
                + type(bert_embeddings).__name__
            )
        char_embeddings = torch.nn.Embedding.from_pretrained(embedding_matrix,
                                freeze=False, padding_idx=1)
        return char_embeddings

    @classmethod
    def get_char_decoder_config(cls, bert_config_file, args):
        """Build the ``transformers.BartConfig`` for the character decoder.

        The hidden size and initializer range are read from the BERT JSON
        configuration; the number of decoder layers comes from ``args.decoder_layers``.
        """
        with open(bert_config_file, 'r', encoding='utf-8') as fd:
            config_orig = json.load(fd)
        bart_config = transformers.BartConfig(
            is_decoder=True,
            vocab_size=len(char_tokens) + len(special_tokens_fairseq),
            d_model=config_orig['hidden_size'],
            decoder_layers=args.decoder_layers,
            max_position_embeddings=DEFAULT_MAX_CHARACTER_POSITIONS,
            # BartConfig's field is ``init_std``; a misspelled keyword would be
            # accepted silently and leave the default standard deviation in place.
            init_std=config_orig['initializer_range']
        )
        return bart_config

    @classmethod
    def build_char_decoder(cls, bart_config, char_embedding):
        """Create BART decoder for character LM."""
        from transformers.models.bart.modeling_bart import BartDecoder  # Updated module path
        bart_decoder = BartDecoder(bart_config, char_embedding)
        return bart_decoder

    def forward(self,
                input_ids_context,
                attention_mask_context=None,
                input_ids_correct=None,
                attention_mask_correct=None,
                target_correct=None,
                encoder_outputs=None,
                encoder_embeds=None,
                past_key_values=None,
                use_cache=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                **kwargs):
        """Run the encoder (unless its output is supplied) and the character decoder.

        Args:
            input_ids_context: token ids for the encoder; unused when
                ``encoder_outputs`` or ``encoder_embeds`` is given.
            attention_mask_context: attention mask passed to the encoder and the decoder.
            input_ids_correct: decoder input ids (required).
            attention_mask_correct: optional mask for ``input_ids_correct``.
            target_correct: accepted for interface compatibility; not used here.
            encoder_outputs: precomputed encoder outputs; element 0 is used.
            encoder_embeds: precomputed encoder hidden states; takes precedence
                over ``encoder_outputs``.
            **kwargs: ``return_encoder_embeds`` (truthy) additionally returns the
                encoder hidden states.

        Returns:
            The output logits, or ``(encoder_hidden_states, logits)`` when
            ``return_encoder_embeds`` is truthy.
        """
        if encoder_embeds is not None:
            encoder_outputs = (encoder_embeds,)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids_context,
                attention_mask=attention_mask_context,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # Remove EOS (token id 2) from correct input by converting them to PAD (token id 1)
        decoder_input_ids = input_ids_correct - input_ids_correct.eq(2).int()
        if attention_mask_correct is not None:
            decoder_padding_mask = attention_mask_correct.eq(0) | input_ids_correct.eq(2)
        else:
            decoder_padding_mask = decoder_input_ids.eq(1)
        # NOTE(review): ``transformers.modeling_bart`` and the ``decoder_causal_mask``
        # keyword below look like a different transformers layout than the
        # ``transformers.models.bart`` imports used elsewhere in this class —
        # TODO confirm against the installed transformers version.
        # Only the causal mask of the returned triple is used.
        _, _, causal_mask = transformers.modeling_bart._prepare_bart_decoder_inputs(
            self.config,
            input_ids=input_ids_correct,
            decoder_input_ids=None,
            decoder_padding_mask=attention_mask_correct,
            causal_mask_dtype=self.decoder.embed_tokens.weight.dtype,
        )
        decoder_outputs = self.decoder(
            decoder_input_ids,
            encoder_outputs[0],
            attention_mask_context,
            decoder_padding_mask,
            decoder_causal_mask=causal_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Tie the output projection to the decoder's token embeddings.
        output_logits = F.linear(decoder_outputs[0], self.decoder.embed_tokens.weight,
                                 bias=self.final_logits_bias)

        if not kwargs.get('return_encoder_embeds'):
            return output_logits
        else:
            return encoder_outputs[0], output_logits

    def get_encoder(self):
        """Return the encoder module."""
        return self.encoder

    def get_output_embeddings(self):
        """Return a linear layer built from the decoder's token embeddings."""
        from transformers.models.bart.modeling_bart import _make_linear_from_emb
        return _make_linear_from_emb(self.decoder.embed_tokens)

    def prepare_inputs_for_generation(self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs):
        """Map generation-loop arguments onto the keyword names used by ``forward``."""
        return {
            "input_ids_context": None,
            "encoder_outputs": encoder_outputs,
            "past_key_values": past,
            "input_ids_correct": decoder_input_ids,
            "attention_mask_context": attention_mask,
            "use_cache": use_cache,
        }

    def adjust_logits_during_generation(self, logits, cur_len, max_length):
        """Force BOS at step 1 (if configured) and EOS at the last step; returns ``logits``."""
        if cur_len == 1 and self.config.force_bos_token_to_be_generated:
            self._force_token_id_to_be_generated(logits, self.config.bos_token_id)
        elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
            self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
        return logits

    @staticmethod
    def _force_token_id_to_be_generated(scores, token_id) -> None:
        """Set every column of ``scores`` except ``token_id`` to -inf, in place."""
        scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")

    def parameters(self, recurse=True):
        """Iterate over the module's parameters.

        Falls back to ``_former_parameters`` / ``_parameters`` of the (sub)modules
        when the regular lookup yields nothing. ``recurse`` has the same meaning
        as in ``torch.nn.Module.parameters``.
        """
        regular = list(super().parameters(recurse=recurse))
        if regular:
            return iter(regular)
        modules = list(self.modules()) if recurse else [self]
        if hasattr(self, "_former_parameters"):
            return (p for m in modules for p in m._former_parameters.values())
        if hasattr(self, "_parameters"):
            return (p for m in modules for p in m._parameters.values())
        return iter(())

    def _compute_edit_distances(self, input_ids_list, typo_token_ids_strip, vocab_size, edit_distance_extra_len):
        """Compute edit distances for every beam and every candidate next token.

        ``input_ids_list`` holds ``batch_size * num_beams`` prefixes, grouped by
        example; ``typo_token_ids_strip`` holds one typo sequence per example.
        Work is distributed over ``self.ed_pool_worker``.

        Returns:
            A list with one entry per example, each the result of
            ``_edit_distance_pool_job`` for that example; ``[]`` for an empty batch.
        """
        batch_size = len(typo_token_ids_strip)
        if batch_size == 0:
            return []
        num_beams = len(input_ids_list) // batch_size
        dataset = [
            (input_ids_list[i * num_beams : (i + 1) * num_beams],
             typo_token_ids_strip[i],
             vocab_size,
             edit_distance_extra_len)
            for i in range(batch_size)
        ]
        num_process = min(batch_size, self._ED_NUM_WORKERS)
        chunksize = max(1, len(dataset) // num_process)
        return list(self.ed_pool_worker.imap(_edit_distance_pool_job, dataset, chunksize=chunksize))

    def get_param_group(self, finetune_bert=False, finetune_factor=0.1):
        """Return optimizer parameter groups tagged with an ``lr_factor``.

        Without ``finetune_bert`` a single group with factor 1.0 is returned.
        Otherwise encoder parameters form a second group with ``finetune_factor``.
        """
        if not finetune_bert:
            return [{'params': list(self.parameters()), 'lr_factor': 1.0}]

        bert_params = list(self.encoder.parameters())
        # Compare by identity; tensor equality is element-wise and not a membership test.
        bert_param_ids = {id(p) for p in bert_params}
        other_params = [p for p in self.parameters() if id(p) not in bert_param_ids]
        return [{'params': other_params, 'lr_factor': 1.0},
                {'params': bert_params, 'lr_factor': finetune_factor}]

    @staticmethod
    def get_set_lr():
        """Return a ``set_lr(self, lr)`` function that scales ``lr`` by each group's ``lr_factor``.

        The returned function expects ``self.optimizer.param_groups`` on the object
        it is bound to.
        """
        def set_lr(self, lr):
            for param_group in self.optimizer.param_groups:
                if 'lr_factor' in param_group:
                    param_group['lr'] = lr * param_group['lr_factor']
                else:
                    param_group['lr'] = lr
        return set_lr

def _edit_distance_pool_job(data):
    """Pool worker: edit distances for one example's beams against its typo.

    ``data`` is the tuple ``(beams, typo_ids, vocab_size, extra_len)``. For each
    beam prefix and each candidate token id in ``range(vocab_size)``, the prefix
    extended by that candidate is compared with the typo sequence truncated to
    ``len(prefix) + 1 + extra_len`` elements.

    Returns:
        A ``num_beams x vocab_size`` nested list of distances.
    """
    beams, typo_ids, vocab_size, extra_len = data
    prefix_len = len(beams[0])
    # Only the leading part of the typo can matter for a prefix of this length.
    typo_window = typo_ids[: prefix_len + 1 + extra_len]
    # Uses the custom damerauLevenshtein function defined in this notebook.
    return [
        [
            damerauLevenshtein(prefix + [candidate], typo_window, similarity=False)
            for candidate in range(vocab_size)
        ]
        for prefix in beams
    ]

def _trie_score_pool_job(input_ids_ex, trie, vocab_size):
    num_examples, cur_len = len(input_ids_ex), len(input_ids_ex[0])
    score = [[float('-inf') for _ in range(vocab_size)] for _ in range(num_examples)]
    for i, prefix_ids in enumerate(input_ids_ex):
        cand_ids = trie.get_candidate_chars(prefix_ids[1:])  # ignore BOS token
        for j in cand_ids:
            score[i][j] = 0.0
    return score


In [19]:
class CharTokenizer(object):
    """Character-level tokenizer over ``special_tokens_fairseq + char_tokens``.

    Every character of the input is one token; characters outside the
    vocabulary are replaced by the UNK token.
    """

    def __init__(self, max_length=DEFAULT_MAX_CHARACTER_POSITIONS):
        """
        Args:
            max_length: maximum sequence length, including BOS/EOS when they are added.
        """
        self.max_length = max_length
        self.bos, self.pad, self.eos, self.unk = special_tokens_fairseq
        vocab = special_tokens_fairseq + char_tokens
        self.char_to_id = {token: index for index, token in enumerate(vocab)}
        self.id_to_char = dict(enumerate(vocab))
        # Derived from the vocabulary so they cannot drift from the token table.
        self.bos_index = self.char_to_id[self.bos]
        self.pad_index = self.char_to_id[self.pad]
        self.eos_index = self.char_to_id[self.eos]
        self.unk_index = self.char_to_id[self.unk]

    def tokenize(self, text, eos_bos=True, padding_end=False, max_length=None, output_token_ids=False):
        """Split ``text`` into character tokens.

        Args:
            text: input string.
            eos_bos: wrap the tokens in BOS/EOS; two positions of
                ``self.max_length`` are reserved for them when truncating.
            padding_end: append PAD tokens (mask 0) up to the padding length.
            max_length: padding length; defaults to ``self.max_length``. It only
                affects padding, truncation always uses ``self.max_length``.
            output_token_ids: return token ids instead of token strings.

        Returns:
            ``(tokens_or_ids, attention_mask)`` — two lists of equal length.
        """
        assert isinstance(text, str)
        max_seq_len = self.max_length - 2 if eos_bos else self.max_length
        tokens = [c if c in self.char_to_id else self.unk for c in text[:max_seq_len]]
        if eos_bos:
            tokens = [self.bos] + tokens + [self.eos]
        attention_mask = [1] * len(tokens)
        if padding_end:
            pad_to = max_length if max_length is not None else self.max_length
            missing = max(0, pad_to - len(tokens))
            tokens.extend([self.pad] * missing)
            attention_mask.extend([0] * missing)
        if output_token_ids:
            return self.convert_tokens_to_ids(tokens), attention_mask
        return tokens, attention_mask

    def convert_tokens_to_ids(self, tokens):
        """Map tokens to ids; raises ``KeyError`` for a token outside the vocabulary."""
        return [self.char_to_id[t] for t in tokens]

    def convert_ids_to_tokens(self, ids):
        """Map ids (list or tensor) back to tokens; raises ``KeyError`` for unknown ids."""
        if isinstance(ids, torch.Tensor):
            ids = ids.cpu().detach().tolist()
        return [self.id_to_char[i] for i in ids]

class SmoothCrossEntropyLoss(nn.Module):
    """Per-element label-smoothed cross entropy (no reduction)."""

    def __init__(self):
        super().__init__()

    def forward(self, logits, target, smoothing=0.0):
        """Return ``(loss, nll_loss)``, both without the class dimension.

        ``target`` may have the same number of dimensions as ``logits`` (with a
        trailing axis of size 1) or one fewer. ``loss`` mixes the negative
        log-likelihood with the uniform-distribution term using ``smoothing``.
        """
        log_probs = F.log_softmax(logits, dim=-1)
        needs_class_axis = target.dim() == logits.dim() - 1
        index = target.unsqueeze(-1) if needs_class_axis else target

        nll_loss = -log_probs.gather(dim=-1, index=index).squeeze(-1)
        smooth_loss = -log_probs.sum(dim=-1)

        num_classes = log_probs.size(-1)
        eps_i = smoothing / num_classes
        loss = (1.0 - smoothing) * nll_loss + eps_i * smooth_loss
        return loss, nll_loss

class CrossEntropyLoss(nn.Module):
    """Cross entropy averaged per sequence (optionally masked), then over the batch."""

    def __init__(self):
        super().__init__()

    def forward(self, logits, targets=None, target_mask=None):
        """Compute the mean negative log-likelihood of ``targets``.

        Args:
            logits: unnormalised scores; the last dimension is the class axis.
            targets: class indices with the same number of dimensions as
                ``logits`` or one fewer. Required despite the ``None`` default.
            target_mask: optional weights/mask with the shape of the per-element
                loss. When given, each row is averaged over its mask sum and the
                rows are then averaged. A row whose mask sums to zero contributes
                0 instead of producing NaN.

        Returns:
            A scalar tensor.

        Raises:
            ValueError: if ``targets`` is None.
        """
        if targets is None:
            raise ValueError("CrossEntropyLoss.forward requires `targets`")
        if targets.dim() == logits.dim() - 1:
            targets = targets.unsqueeze(-1)
        lprobs = F.log_softmax(logits, dim=-1)
        nll_loss = -lprobs.gather(dim=-1, index=targets).squeeze(-1)
        if target_mask is None:
            return nll_loss.mean()

        mask = target_mask.to(nll_loss.dtype)
        denom = mask.sum(-1)
        # Guard against fully masked rows: their numerator is 0, so dividing by 1 yields 0.
        denom = torch.where(denom == 0, torch.ones_like(denom), denom)
        return ((nll_loss * mask).sum(-1) / denom).mean()

class Trie(object):
    """Prefix tree over token-id sequences, stored as nested dictionaries.

    The end of a word is marked by a child keyed with ``eos_token_id``.
    """

    def __init__(self, char_tokenizer, eos_token_id=2):
        """
        Args:
            char_tokenizer: object providing ``convert_tokens_to_ids``.
            eos_token_id: id used as the end-of-word marker.
        """
        self.char_tokenizer = char_tokenizer
        self.eos_token_id = eos_token_id
        self._trie = {}

    def add_word_ids(self, word_token_ids):
        """Insert one word given as a sequence of token ids."""
        node = self._trie
        for token_id in word_token_ids:
            node = node.setdefault(token_id, {})
        node[self.eos_token_id] = {}

    def add_words(self, words):
        """Insert every word; words containing an unknown character are skipped."""
        for word in words:
            try:
                word_token_ids = self.char_tokenizer.convert_tokens_to_ids(word)
            except KeyError:
                # The tokenizer signals an out-of-vocabulary character with KeyError.
                continue
            self.add_word_ids(word_token_ids)

    def get_candidate_chars(self, prefix_ids):
        """Return the token ids that may follow ``prefix_ids``.

        Always returns a new list (empty when the prefix is not in the trie), so
        callers never hold a live view of the internal dictionaries.
        """
        node = self._trie
        for token_id in prefix_ids:
            if token_id not in node:
                return []
            node = node[token_id]
        return list(node.keys())


In [21]:
# Example: Initialize the model.
# SimpleNamespace gives attribute-style access (args.bert_dir, args.decoder_layers, ...)
# without going through a command-line parser.
from types import SimpleNamespace

# Define your parameters inline.
# Values are grouped in the order they are listed: locations, run control, model,
# synthetic corruption, optimisation, and beam search / decoding.
# NOTE(review): the grouping is inferred from the option names — confirm against the code that reads them.
args = SimpleNamespace(
    mimic_csv_dir="data/mimic3/split",
    data_dir="data/mimic_synthetic",
    dict_file="data/lexicon/lexicon_en.json",
    bert_dir="bert/ncbi_bert_base",
    output_dir="results/cim_base",
    is_train=True,
    test_file="test.tsv",
    init_ckpt=None,
    init_step=0,
    seed=123,
    decoder_layers=12,
    train_bert=True,
    bert_finetune_factor=1.0,
    dropout=0.1,
    synthetic_min_word_len=3,
    do_substitution=True,
    do_transposition=True,
    max_word_corruptions=2,
    no_corruption_prob=0.0,
    train_with_ed=False,
    batch_size=256,
    num_gpus=4,
    optimizer='adam',
    # NOTE(review): adam_betas is a string, presumably parsed by the consumer — confirm.
    adam_betas="(0.9, 0.999)",
    adam_eps=1e-08,
    weight_decay=0.01,
    training_step=500000,
    display_iter=100,
    eval_iter=25000,
    # lr is a one-element list, not a bare float.
    lr=[0.0001],
    warmup_updates=10000,
    warmup_init_lr=1e-6,
    num_beams=3,
    edit_distance_weight=5.0,
    edit_distance_extra_len=100,
    length_penalty=1.0,
    beam_sort_linear_ed=False,
    beam_final_score_normalize_ed=False,
    dict_matching=True
)

# Locate the BERT configuration and vocabulary files.
# The project root can be overridden with the CIM_PROJECT_ROOT environment variable;
# the default is the location this notebook was originally run from.
project_root = os.environ.get(
    "CIM_PROJECT_ROOT", r"C:\Users\chen5\Desktop\FinalP1\cim-misspelling-main"
)
bert_dir = os.path.join(project_root, *args.bert_dir.split("/"))
bert_config_file = os.path.join(bert_dir, "bert_config.json")
bert_vocab_file = os.path.join(bert_dir, "vocab.txt")

# Load the configuration and initialize the BERT model and its tokenizer.
try:
    # Load configuration from the JSON file
    bert_config = CharacterLanguageModel.get_bert_config(bert_config_file)
    print("BERT configuration loaded successfully.")

    # Build the BERT model using the configuration
    bert_model = CharacterLanguageModel.build_bert_model(bert_config)
    print("BERT model built successfully.")

    # Create the tokenizer from the vocabulary file
    bert_tokenizer = CharacterLanguageModel.get_bert_tokenizer(bert_vocab_file)
    print("BERT tokenizer created successfully.")
except Exception as e:
    print("Error initializing BERT components:", e)
    # Re-raise: later cells need bert_config / bert_model / bert_tokenizer, so
    # continuing would only defer the failure to a less informative NameError.
    raise



BERT configuration loaded successfully.
BERT model built successfully.
BERT tokenizer created successfully.
