In [1]:
import torch
from transformers import BertTokenizer, BertForTokenClassification, Trainer, TrainingArguments
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline
import wandb

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

  from .autonotebook import tqdm as notebook_tqdm


cuda


## 1. Load your dataset


In [2]:
def process_sample(batch):
    tokens_list = []
    tags_list = []
    tokens = []
    tags = []
    
    for line in batch['text']:
        if line:  # non-empty line means we have a token-tag pair
            token, tag = line.split()  # assuming space is the delimiter
            tokens.append(token)
            tags.append(tag)
        else:  # empty line means end of sentence
            tokens_list.append(tokens)
            tags_list.append(tags)
            tokens = []
            tags = []
    
    # Add remaining tokens and tags if there's any
    if tokens:
        tokens_list.append(tokens)
        tags_list.append(tags)
    
    return {'tokens': tokens_list, 'tags': tags_list}
    

data_files = {
    'train': 'data/filtered_train.txt',
    'validation': 'data/filtered_val.txt',
    'test': 'data/filtered_test.txt'
}

# Load the dataset from local files without specifying a script
dataset = load_dataset('text', data_files=data_files)
pdataset = dataset.map(process_sample, batched=True, remove_columns=['text'])


Found cached dataset text (/home/shemmati/.cache/huggingface/datasets/text/default-406edee888e87088/0.0.0/cb1e9bd71a82ad27976be3b12b407850fe2837d80c22c5e03a28949843a8ace2)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 111.24it/s]
Loading cached processed dataset at /home/shemmati/.cache/huggingface/datasets/text/default-406edee888e87088/0.0.0/cb1e9bd71a82ad27976be3b12b407850fe2837d80c22c5e03a28949843a8ace2/cache-c5b419e1e4e695c4.arrow
Loading cached processed dataset at /home/shemmati/.cache/huggingface/datasets/text/default-406edee888e87088/0.0.0/cb1e9bd71a82ad27976be3b12b407850fe2837d80c22c5e03a28949843a8ace2/cache-0d798682e4b5b2f2.arrow
Loading cached processed dataset at /home/shemmati/.cache/huggingface/datasets/text/default-406edee888e87088/0.0.0/cb1e9bd71a82ad27976be3b12b407850fe2837d80c22c5e03a28949843a8ace2/cache-a2063a6600fe02

## 2. Load the tokenizer and model


In [16]:
id2label = {
    0: "O",
    1: "B-ap_name1",
    2: "I-ap_name1",
    3: "B-vz1",
    4: "I-vz1",
    5: "B-coordx1",
    6: "I-coordx1",
    7: "B-coordy1",
    8: "I-coordy1",
    9: "B-type1",
    10: "I-type1",
}
label2id = {"O": 0,
          "B-ap_name1": 1,
          "I-ap_name1": 2,
          "B-vz1": 3,
          "I-vz1": 4,
          "B-coordx1": 5,
          "I-coordx1": 6,
          "B-coordy1": 7,
          "I-coordy1": 8,
          "B-type1": 9,
          "I-type1": 10,
         }

In [18]:
tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER", add_prefix_space=True)
model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER", num_labels=11, id2label=id2label, label2id=label2id,ignore_mismatched_sizes=True)
#### this ignore size mismatch might cause major issues ...


# Define a data collator to handle token-level tasks (like NER)
from transformers import DataCollatorForTokenClassification
data_collator = DataCollatorForTokenClassification(tokenizer)


Some weights of BertForTokenClassification were not initialized from the model checkpoint at dslim/bert-base-NER and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([9]) in the checkpoint and torch.Size([11]) in the model instantiated
- classifier.weight: found shape torch.Size([9, 768]) in the checkpoint and torch.Size([11, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## 3. Tokenize the dataset


In [19]:
def recursive_label2id_conversion(label, label2id):
    if isinstance(label, str):
        return label2id[label]
    elif isinstance(label, list):
        return [recursive_label2id_conversion(l, label2id) for l in label]
    else:
        raise ValueError("Unsupported label type")
        
def tokenize_and_align_labels2(examples, label2id):
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    labels = []
    for i, label in enumerate(examples[f"tags"]):
        converted_label = recursive_label2id_conversion(label, label2id)

        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                label_ids.append(converted_label[word_idx])
            else:
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

#tokenized_datasets = pdataset.map(tokenize_function, batched=True, num_proc=4)
tokenized_datap = pdataset.map(tokenize_and_align_labels2, batched=True, fn_kwargs={"label2id": label2id},num_proc=4)


                                                                                                                                                                                                                 

## 4. Train


In [21]:
wandb.init(project='NEDAI',name='try1')

model.to(device)

training_args = TrainingArguments(
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    evaluation_strategy="epoch",
    logging_dir="./logs",
    report_to="wandb",  # Log to wandb
    logging_steps=100,
    do_train=True,
    do_eval=True,
    output_dir="./results",
)

# Define the Trainer
trainer = Trainer(
    model = model,
    args = training_args,
    data_collator = data_collator,
    train_dataset = tokenized_datap["train"],
    eval_dataset = tokenized_datap["validation"],
    tokenizer = tokenizer,
)

# Train the model
trainer.train()
wandb.finish()

# Save the model
model.save_pretrained("./ner_model")
tokenizer.save_pretrained("./ner_model")

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss
1,0.0216,0.026405
2,0.0148,0.028592
3,0.0064,0.039019


0,1
eval/loss,▁▂█
eval/runtime,█▂▁
eval/samples_per_second,▁▇█
eval/steps_per_second,▁▇█
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train/learning_rate,████▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▁▁▁▁
train/loss,█▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/total_flos,▁
train/train_loss,▁

0,1
eval/loss,0.03902
eval/runtime,7.5016
eval/samples_per_second,356.326
eval/steps_per_second,44.657
train/epoch,3.0
train/global_step,9129.0
train/learning_rate,0.0
train/loss,0.0064
train/total_flos,4959303043863270.0
train/train_loss,0.01782


('./ner_model/tokenizer_config.json',
 './ner_model/special_tokens_map.json',
 './ner_model/vocab.txt',
 './ner_model/added_tokens.json',
 './ner_model/tokenizer.json')

In [23]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline

tokenizer = AutoTokenizer.from_pretrained("./ner_model/")
model = AutoModelForTokenClassification.from_pretrained("./ner_model/")

nlp = pipeline("ner", model=model, tokenizer=tokenizer)
example = "In this work we studied in detail galaxies M45 and NGC 2132 which are at z=0.1 and they appeared to be red."

ner_results = nlp(example)
print(ner_results)

[{'entity': 'B-ap_name1', 'score': 0.9913223, 'index': 12, 'word': 'N', 'start': 51, 'end': 52}, {'entity': 'B-ap_name1', 'score': 0.90762424, 'index': 13, 'word': '##GC', 'start': 52, 'end': 54}, {'entity': 'I-ap_name1', 'score': 0.9939843, 'index': 14, 'word': '213', 'start': 55, 'end': 58}]
