In [1]:
from datasets import Dataset
import json

# Load the dataset from a JSON file
with open('../data/corpus/train.json', 'r') as f:
    data = json.load(f)

d = {
    'slot': [item['slot'] for item in data],
    'text': [item['text'] for item in data],
    'position': [item['positions'] for item in data]
}


# create dataset from dict (train split)
dataset = Dataset.from_dict(d)
# 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = dataset.train_test_split(test_size=0.025)

In [3]:
from transformers import AutoTokenizer

model_checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [4]:
def classes(data):
    l = []
    for item in data:
        l.extend(item['slot'].keys())

    list_set = set(l)
    length = len(list_set)
    return list(list_set), length

In [5]:

class_labels, no_classes = classes(data)
print(class_labels)
print(no_classes)

['artist', 'credit', 'bpm_lower_than', 'pump_routine', 'warps', 'fakes', 'pump_halfdouble', 'scrolls', 'pump_single', 'mix', 'stops', 'meter_greater_than', 'tune', 'pump_double', 'speeds', 'bpm_greater_than', 'meter_lower_than', 'pump_couple', 'meter', 'bpm']
20


In [6]:
def class_mapper(class_labels):
    no_classes = len(class_labels)
    d = {}
    # 0 is reserved for no class
    d['O'] = 0
    for i in range(1, no_classes + 1):
        d[class_labels[i-1]] = i
    
    # and another dictionary to map the index to the class label

    d_reverse = {}
    d_reverse[0] = 'O'
    for i in range(1, no_classes + 1):
        d_reverse[i] = class_labels[i-1]

    return d, d_reverse

mapper, unmapper = class_mapper(class_labels)
print(mapper, unmapper)


{'O': 0, 'artist': 1, 'credit': 2, 'bpm_lower_than': 3, 'pump_routine': 4, 'warps': 5, 'fakes': 6, 'pump_halfdouble': 7, 'scrolls': 8, 'pump_single': 9, 'mix': 10, 'stops': 11, 'meter_greater_than': 12, 'tune': 13, 'pump_double': 14, 'speeds': 15, 'bpm_greater_than': 16, 'meter_lower_than': 17, 'pump_couple': 18, 'meter': 19, 'bpm': 20} {0: 'O', 1: 'artist', 2: 'credit', 3: 'bpm_lower_than', 4: 'pump_routine', 5: 'warps', 6: 'fakes', 7: 'pump_halfdouble', 8: 'scrolls', 9: 'pump_single', 10: 'mix', 11: 'stops', 12: 'meter_greater_than', 13: 'tune', 14: 'pump_double', 15: 'speeds', 16: 'bpm_greater_than', 17: 'meter_lower_than', 18: 'pump_couple', 19: 'meter', 20: 'bpm'}


In [7]:
def align_positions_with_tokens(input_ids, offset_mapping, position):
    begin_end_tokens = [(offset_mapping[i][0], offset_mapping[i][1])  for i, token in enumerate(input_ids)]
    labels = [ 0 for i in range(len(input_ids))]
    labels[0] = -100
    # find index of SEP token (102)
    sep_index = input_ids.index(102)
    # from sep_index to the end of the list, set the label to -100
    for i in range(sep_index, len(labels)):
        labels[i] = -100

    
    begin_end_tokens = begin_end_tokens[1:-1]
    for key, val in position.items():
        if (val != None):
            begin_gt = val['begin']
            end_gt = val['end']

            class_label = key


            # find the indices of the tokens that contain the begin and end of the ground truth
            begin_token = [i for i, token in enumerate(begin_end_tokens) if token[0] >= begin_gt and token[0] < end_gt ]
            end_token = [i for i, token in enumerate(begin_end_tokens) if token[1] == end_gt ]


            # create list of indices of tokens that are part of the ground truth
            try: 
                tokens = [i for i in range(begin_token[0], end_token[0]+1)]
            except Exception as e:
                print(begin_end_tokens)
                print(f'begin: {begin_gt}, end: {end_gt}, class: {class_label}')
                print(f'begin_token: {begin_token}, end_token: {end_token}')
                print('---')
                raise e

            # set the label of the tokens that are part of the ground truth
            for token in tokens:
                labels[token + 1] = mapper[class_label]
    
    return labels

sample = dataset['train'][12239]
inputs = tokenizer(sample["text"], is_split_into_words=False, return_offsets_mapping=True)
print(sample['text'])
print(inputs.tokens())
labels = align_positions_with_tokens(inputs['input_ids'], inputs['offset_mapping'], sample['position'])
print(labels)


    

album EXCEED~ZE
['[CLS]', 'album', 'exceed', '~', 'ze', '[SEP]']
[-100, 0, 10, 10, 10, -100]


In [8]:
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(
        examples["text"], is_split_into_words=False, return_offsets_mapping=True, truncation=True, padding=True, max_length=48
    )
    new_labels = []
    for i, position in enumerate(examples['position']):
        try:
            new_labels.append(align_positions_with_tokens(tokenized_inputs['input_ids'][i], tokenized_inputs['offset_mapping'][i], position))
        except Exception as e:
            print( examples['text'][i])
            print(tokenized_inputs.tokens(i))
            print(f'index: {i}')
            raise e

    tokenized_inputs["labels"] = new_labels
    return tokenized_inputs

In [9]:
tokenized_datasets = dataset.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=["slot", "text", "position"],
)

Map: 100%|██████████| 97500/97500 [00:07<00:00, 12942.10 examples/s]
Map: 100%|██████████| 2500/2500 [00:00<00:00, 14671.39 examples/s]


In [10]:
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

In [11]:
import evaluate

metric = evaluate.load("seqeval")

In [12]:
def decode_labels(labels):
    return [unmapper[label] if label != -100 else 'O' for label in labels]

In [13]:

decoded_labels = decode_labels(tokenized_datasets['train']["labels"][0])
print(decoded_labels)
pred = decoded_labels.copy()
pred[1] = 'bpm_greater_than'

metric.compute(predictions=[pred], references=[decoded_labels], )

['O', 'stops', 'stops', 'O', 'O', 'meter_lower_than', 'O', 'O', 'O', 'bpm_greater_than', 'tune', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']




{'eter_lower_than': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1},
 'pm_greater_than': {'precision': 0.5,
  'recall': 1.0,
  'f1': 0.6666666666666666,
  'number': 1},
 'tops': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1},
 'une': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1},
 'overall_precision': 0.6,
 'overall_recall': 0.75,
 'overall_f1': 0.6666666666666665,
 'overall_accuracy': 0.9629629629629629}

In [14]:
import numpy as np
def compute_metrics(eval_preds):
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    # Remove ignored index (special tokens) and convert to labels
    true_labels = [[ unmapper[label] for label in sample if label != -100] for sample in labels]
    true_predictions = [
        [unmapper[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": all_metrics["overall_precision"],
        "recall": all_metrics["overall_recall"],
        "f1": all_metrics["overall_f1"],
        "accuracy": all_metrics["overall_accuracy"],
    }

In [15]:
from transformers import DistilBertForTokenClassification 

model = DistilBertForTokenClassification.from_pretrained(
    model_checkpoint,
    id2label=unmapper,
    label2id=mapper,
)

model.config.num_labels = no_classes + 1


Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [16]:
from transformers import TrainingArguments

args = TrainingArguments(
    "distilbert-piu-search",
    evaluation_strategy='steps',
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=1,
    weight_decay=0.01,
    eval_steps=500,
    per_device_train_batch_size=16,
)

In [17]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)
trainer.train()

  0%|          | 0/18282 [00:00<?, ?it/s]You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  3%|▎         | 500/18282 [00:14<07:59, 37.08it/s]

{'loss': 0.4065, 'learning_rate': 1.9453013893447107e-05, 'epoch': 0.08}


                                                   
  3%|▎         | 505/18282 [00:15<31:01,  9.55it/s]

{'eval_loss': 0.05714792385697365, 'eval_precision': 0.9706809151938621, 'eval_recall': 0.976433296582139, 'eval_f1': 0.9735486087255237, 'eval_accuracy': 0.9834633092173259, 'eval_runtime': 1.4685, 'eval_samples_per_second': 1702.375, 'eval_steps_per_second': 213.137, 'epoch': 0.08}


  5%|▌         | 1000/18282 [00:30<08:08, 35.35it/s]

{'loss': 0.0414, 'learning_rate': 1.8906027786894213e-05, 'epoch': 0.16}


                                                    
  5%|▌         | 1005/18282 [00:31<30:46,  9.36it/s]

{'eval_loss': 0.0379437692463398, 'eval_precision': 0.9803679297089511, 'eval_recall': 0.9841510474090408, 'eval_f1': 0.9822558459422285, 'eval_accuracy': 0.9887249835572677, 'eval_runtime': 1.5014, 'eval_samples_per_second': 1665.168, 'eval_steps_per_second': 208.479, 'epoch': 0.16}


  8%|▊         | 1500/18282 [00:44<07:12, 38.77it/s]

{'loss': 0.0326, 'learning_rate': 1.8359041680341322e-05, 'epoch': 0.25}


                                                    
  8%|▊         | 1507/18282 [00:46<28:23,  9.85it/s]

{'eval_loss': 0.03233273699879646, 'eval_precision': 0.982147761603955, 'eval_recall': 0.9856670341786108, 'eval_f1': 0.983904250928601, 'eval_accuracy': 0.9901343606126092, 'eval_runtime': 1.4467, 'eval_samples_per_second': 1728.09, 'eval_steps_per_second': 216.357, 'epoch': 0.25}


 11%|█         | 2000/18282 [00:59<06:51, 39.60it/s]

{'loss': 0.0279, 'learning_rate': 1.7812055573788428e-05, 'epoch': 0.33}


                                                    
 11%|█         | 2006/18282 [01:00<26:54, 10.08it/s]

{'eval_loss': 0.025209957733750343, 'eval_precision': 0.9848797250859107, 'eval_recall': 0.9874586549062845, 'eval_f1': 0.9861675039570573, 'eval_accuracy': 0.9916376961383069, 'eval_runtime': 1.456, 'eval_samples_per_second': 1716.99, 'eval_steps_per_second': 214.967, 'epoch': 0.33}


 14%|█▎        | 2500/18282 [01:14<06:57, 37.78it/s]

{'loss': 0.0245, 'learning_rate': 1.7265069467235534e-05, 'epoch': 0.41}


                                                    
 14%|█▎        | 2505/18282 [01:15<26:57,  9.75it/s]

{'eval_loss': 0.025493081659078598, 'eval_precision': 0.9854215376151836, 'eval_recall': 0.9874586549062845, 'eval_f1': 0.9864390445377573, 'eval_accuracy': 0.991731654608663, 'eval_runtime': 1.4382, 'eval_samples_per_second': 1738.32, 'eval_steps_per_second': 217.638, 'epoch': 0.41}


 16%|█▋        | 3000/18282 [01:29<07:15, 35.11it/s]

{'loss': 0.0206, 'learning_rate': 1.671808336068264e-05, 'epoch': 0.49}


                                                    
 16%|█▋        | 3005/18282 [01:30<27:40,  9.20it/s]

{'eval_loss': 0.021747741848230362, 'eval_precision': 0.9857142857142858, 'eval_recall': 0.9889746416758545, 'eval_f1': 0.9873417721518987, 'eval_accuracy': 0.9928591562529362, 'eval_runtime': 1.5272, 'eval_samples_per_second': 1636.943, 'eval_steps_per_second': 204.945, 'epoch': 0.49}


 19%|█▉        | 3500/18282 [01:44<06:34, 37.45it/s]

{'loss': 0.0201, 'learning_rate': 1.6171097254129746e-05, 'epoch': 0.57}


                                                    
 19%|█▉        | 3505/18282 [01:46<26:40,  9.23it/s]

{'eval_loss': 0.0211471114307642, 'eval_precision': 0.9887161139397276, 'eval_recall': 0.9902149944873209, 'eval_f1': 0.9894649865730221, 'eval_accuracy': 0.9934229070750729, 'eval_runtime': 1.5386, 'eval_samples_per_second': 1624.888, 'eval_steps_per_second': 203.436, 'epoch': 0.57}


 22%|██▏       | 4000/18282 [02:00<06:27, 36.87it/s]

{'loss': 0.0185, 'learning_rate': 1.562411114757685e-05, 'epoch': 0.66}


                                                    
 22%|██▏       | 4005/18282 [02:01<25:29,  9.33it/s]

{'eval_loss': 0.018788378685712814, 'eval_precision': 0.989946288390029, 'eval_recall': 0.9906284454244763, 'eval_f1': 0.9902872494317008, 'eval_accuracy': 0.9942215540730996, 'eval_runtime': 1.5042, 'eval_samples_per_second': 1662.05, 'eval_steps_per_second': 208.089, 'epoch': 0.66}


 25%|██▍       | 4500/18282 [02:15<06:07, 37.55it/s]

{'loss': 0.0171, 'learning_rate': 1.5077125041023959e-05, 'epoch': 0.74}


                                                    
 25%|██▍       | 4505/18282 [02:17<25:15,  9.09it/s]

{'eval_loss': 0.016791068017482758, 'eval_precision': 0.9920165175498967, 'eval_recall': 0.9932469680264608, 'eval_f1': 0.9926313614764823, 'eval_accuracy': 0.9953490557173729, 'eval_runtime': 1.5678, 'eval_samples_per_second': 1594.578, 'eval_steps_per_second': 199.641, 'epoch': 0.74}


 27%|██▋       | 5000/18282 [02:31<05:50, 37.94it/s]

{'loss': 0.0153, 'learning_rate': 1.4530138934471065e-05, 'epoch': 0.82}


                                                    
 27%|██▋       | 5005/18282 [02:32<23:57,  9.23it/s]

{'eval_loss': 0.01705148071050644, 'eval_precision': 0.9929752066115702, 'eval_recall': 0.9935226019845645, 'eval_f1': 0.9932488288784789, 'eval_accuracy': 0.995396034952551, 'eval_runtime': 1.5385, 'eval_samples_per_second': 1624.914, 'eval_steps_per_second': 203.439, 'epoch': 0.82}


 30%|███       | 5500/18282 [02:46<05:34, 38.18it/s]

{'loss': 0.015, 'learning_rate': 1.3983152827918172e-05, 'epoch': 0.9}


                                                    
 30%|███       | 5505/18282 [02:48<22:46,  9.35it/s]

{'eval_loss': 0.012150906957685947, 'eval_precision': 0.9936665289825141, 'eval_recall': 0.994625137816979, 'eval_f1': 0.9941456023142088, 'eval_accuracy': 0.9967114535375364, 'eval_runtime': 1.5147, 'eval_samples_per_second': 1650.456, 'eval_steps_per_second': 206.637, 'epoch': 0.9}


 33%|███▎      | 6000/18282 [03:02<05:25, 37.76it/s]

{'loss': 0.0149, 'learning_rate': 1.3436166721365278e-05, 'epoch': 0.98}


                                                    
 33%|███▎      | 6005/18282 [03:04<23:50,  8.58it/s]

{'eval_loss': 0.011320582590997219, 'eval_precision': 0.9924304982108451, 'eval_recall': 0.9937982359426681, 'eval_f1': 0.9931138961575541, 'eval_accuracy': 0.9959597857746876, 'eval_runtime': 1.702, 'eval_samples_per_second': 1468.899, 'eval_steps_per_second': 183.906, 'epoch': 0.98}


 36%|███▌      | 6500/18282 [03:18<05:15, 37.36it/s]

{'loss': 0.0097, 'learning_rate': 1.2889180614812384e-05, 'epoch': 1.07}


                                                    
 36%|███▌      | 6505/18282 [03:20<21:31,  9.12it/s]

{'eval_loss': 0.012142508290708065, 'eval_precision': 0.9946280991735538, 'eval_recall': 0.9951764057331863, 'eval_f1': 0.9949021769082392, 'eval_accuracy': 0.9967584327727145, 'eval_runtime': 1.5616, 'eval_samples_per_second': 1600.903, 'eval_steps_per_second': 200.433, 'epoch': 1.07}


 38%|███▊      | 6913/18282 [03:31<05:03, 37.47it/s]

KeyboardInterrupt: 

In [18]:
# load model from disk checkpoint
model_cpu = DistilBertForTokenClassification.from_pretrained(
    "distilbert-piu-search/checkpoint-6094",
)

In [19]:
from transformers import pipeline
token_classifier = pipeline(
    "token-classification", model=model_cpu, tokenizer=tokenizer, aggregation_strategy="simple"
)


In [20]:

pred = token_classifier("EXC d 20", )
print(pred)


[{'entity_group': 'mix', 'score': 0.9676077, 'word': 'exc', 'start': 0, 'end': 3}, {'entity_group': 'pump_double', 'score': 0.9950138, 'word': 'd', 'start': 4, 'end': 5}, {'entity_group': 'meter', 'score': 0.9997477, 'word': '20', 'start': 6, 'end': 8}]


 38%|███▊      | 6914/18282 [03:45<05:03, 37.47it/s]

In [None]:
[ { "class": p['entity_group'], "word": p['word'], "start": p['start'], "end": p['end']} for p in pred]