## Fine-tuning mamba for Named Entry Recognition (NER)

In [1]:
import torch
print(torch.cuda.is_available())

True


In [2]:
from datasets import load_dataset
dataset = load_dataset("wikiann", "en")

model_ckpt = "state-spaces/mamba-790m-hf"

  from .autonotebook import tqdm as notebook_tqdm


In [85]:
dataset['train'][5]

{'tokens': ['St.',
  'Mary',
  "'s",
  'Catholic',
  'Church',
  '(',
  'Sandusky',
  ',',
  'Ohio',
  ')'],
 'ner_tags': [3, 4, 4, 4, 4, 4, 4, 4, 4, 4],
 'langs': ['en', 'en', 'en', 'en', 'en', 'en', 'en', 'en', 'en', 'en'],
 'spans': ["ORG: St. Mary 's Catholic Church ( Sandusky , Ohio )"]}

In [3]:
import torch.nn as nn
from transformers import MambaConfig
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.mamba.modeling_mamba import MambaModel, MambaPreTrainedModel



class MambaForTokenClassification(MambaPreTrainedModel):
    config_class = MambaConfig
    model_ckpt = "state-spaces/mamba-790m-hf"
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.mamba = MambaModel.from_pretrained(model_ckpt, config=config)
        self.dropout = nn.Dropout(p=0.2)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, **kwargs):
        outputs = self.mamba(input_ids, attention_mask=attention_mask, **kwargs) # hidden state(latent expression)는 [0]으로 접근 가능
        sequence_output = self.dropout(outputs['last_hidden_state'])
        logits = self.classifier(sequence_output)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return TokenClassifierOutput(loss=loss, logits=logits,
                                     hidden_states=outputs.last_hidden_state)     # named tuple로 결과를 참조할 수 있도록

In [None]:
from transformers import AutoConfig

tag2index = {
    "O": 0, "B-PER": 1, "I-PER": 2, "B-ORG": 3, "I-ORG": 4, "B-LOC": 5, "I-LOC": 6
}
index2tag = {tag2index[tag]:tag for tag in tag2index}
mamba_config = AutoConfig.from_pretrained(model_ckpt, num_labels=len(tag2index), id2label=index2tag, label2id=tag2index)
mamba_model = (MambaForTokenClassification
               .from_pretrained(model_ckpt, config=mamba_config)
               .to("cuda"))
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
tokenizer_output = tokenizer("Hello, this is an example text for you", return_tensors="pt").to("cuda")
mamba_model.forward(**tokenizer_output)

In [5]:
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    labels = []
    for idx, label in enumerate(examples["ner_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=idx)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None or word_idx == previous_word_idx:
                label_ids.append(-100)
            else:
                label_ids.append(label[word_idx])
            previous_word_idx = word_idx
        labels.append(label_ids)
    tokenized_inputs["labels"] = labels
    return tokenized_inputs

def encode_panx_dataset(corpus):
    return corpus.map(tokenize_and_align_labels, batched=True, remove_columns=['langs', 'ner_tags', 'tokens'])

dataset_decoded = dataset.map(tokenize_and_align_labels, batched=True, remove_columns=['langs', 'ner_tags', 'tokens'])

In [6]:
import numpy as np
def align_predictions(predictions, label_ids):
    preds = np.argmax(predictions, axis=2)
    batch_size, seq_len = preds.shape
    labels_list, preds_list = [], []

    for batch_idx in range(batch_size):
        example_labels, example_preds = [], []
        for seq_idx in range(seq_len):
            if label_ids[batch_idx, seq_idx] != -100:
                example_labels.append(index2tag[label_ids[batch_idx][seq_idx]])
                example_preds.append(index2tag[preds[batch_idx][seq_idx]])

        labels_list.append(example_labels)
        preds_list.append(example_preds)
    return preds_list, labels_list

In [7]:
from transformers import Trainer, TrainingArguments
from transformers import DataCollatorForTokenClassification

training_args = TrainingArguments(
    output_dir="./result",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=32,
    save_steps=100,
    logging_steps=100,
    push_to_hub=False,
    evaluation_strategy="steps",
    eval_steps=100,
    weight_decay=0.01,
    gradient_accumulation_steps=4,
)

data_collator= DataCollatorForTokenClassification(tokenizer)  

trainer = Trainer(
    model=mamba_model,
    data_collator=data_collator,
    args=training_args,
    train_dataset=dataset_decoded['train'],
    eval_dataset=dataset_decoded['test'].shuffle().select(range(64)),
)


trainer.train()




Step,Training Loss,Validation Loss
100,4.1143,0.439956
200,2.1269,0.379877
300,1.8978,0.299387
400,1.8537,0.306148
500,1.6485,0.309016
600,1.7569,0.265519
700,1.6369,0.270684
800,1.6069,0.266185
900,1.5865,0.254824
1000,1.6823,0.253208


TrainOutput(global_step=1250, training_loss=1.9025521545410156, metrics={'train_runtime': 4400.3376, 'train_samples_per_second': 4.545, 'train_steps_per_second': 0.284, 'total_flos': 1804900566516192.0, 'train_loss': 1.9025521545410156, 'epoch': 1.0})

In [132]:
import pandas as pd

data_id = 50

tag2index = {
    "O": 0, "B-PER": 1, "I-PER": 2, "B-ORG": 3, "I-ORG": 4, "B-LOC": 5, "I-LOC": 6
}
tokens = dataset['validation'][data_id]['tokens']
labels = dataset['validation'][data_id]['ner_tags']
token_ids = torch.IntTensor(tokenizer.convert_tokens_to_ids(tokens)).to("cuda").unsqueeze(0)
output = mamba_model.forward(input_ids=token_ids)
output = output['logits'].squeeze().cpu().detach().numpy().argmax(axis=-1).tolist()

# tokenizer.decode()
pd.DataFrame([tokens, [index2tag[e] for e in output], [index2tag[e] for e in labels]], index=["Tokens", "result", "label"])

Unnamed: 0,0,1,2,3
Tokens,Carthay,",",Los,Angeles
result,O,O,B-LOC,O
label,B-ORG,I-ORG,I-ORG,I-ORG
