In [98]:
from datasets import load_dataset,concatenate_datasets, load_from_disk

In [93]:
# Columns that survive the clean-up step; every other column is dropped.
feature_column = ["tokens","ner_tags"]
# Split names the helper functions below expect to find in each DatasetDict.
split_list = ["train","validation","test"]

def remove_columns_from_dataset_dict(dataset_dict,feature_columns):
    """Drop every column that is not listed in ``feature_columns`` from each split.

    Args:
        dataset_dict: Mapping from split name to dataset. Its keys must be exactly
            the names in the notebook-level ``split_list``.
        feature_columns: Names of the columns to keep.

    Returns:
        The same ``dataset_dict`` object; each split has been replaced by the
        result of its ``remove_columns`` call.

    Raises:
        AssertionError: If the split names differ from ``split_list``.
    """
    assert sorted(split_list) == sorted(list(dataset_dict.keys())), "Dataset is not containing all splits for train,test,val"
    for split in split_list:
        # Filter against the argument rather than the notebook-level
        # `feature_column`, so the caller actually controls what is kept.
        remove_column_list = [col for col in list(dataset_dict[split].features) if col not in feature_columns]
        dataset_dict[split] = dataset_dict[split].remove_columns(remove_column_list)
    return dataset_dict


def merging_all_splits_from_dataset_dict(dataset1,dataset2):
    """Concatenate ``dataset2`` onto ``dataset1`` split by split.

    For every name in the notebook-level ``split_list`` the rows of
    ``dataset1[split]`` come first, followed by the rows of ``dataset2[split]``.

    Args:
        dataset1: Mapping from split name to dataset; provides the leading rows.
        dataset2: Mapping from split name to dataset; provides the trailing rows.

    Returns:
        A shallow copy of ``dataset1`` whose ``split_list`` entries hold the
        concatenated datasets. Neither input mapping is modified, so calling the
        function again with the same arguments gives the same result.

    Raises:
        AssertionError: If the feature types of a split differ between the inputs.
    """
    import copy  # local import: the notebook's import cell does not provide it

    # Work on a copy of the container so the caller's dataset1 keeps its
    # original splits (re-running the merge cell must not append twice).
    merged = copy.copy(dataset1)
    for split in split_list:
        assert dataset1[split].features.type == dataset2[split].features.type, (
            f"Feature types differ between the datasets for split '{split}'"
        )
        merged[split] = concatenate_datasets([dataset1[split],dataset2[split]])
    return merged
            

# Preprocessing `wikiann`

In [None]:
# Load the `wikiann` dataset, requesting its "en" configuration.
wikiann= load_dataset("wikiann","en")

In [80]:
# Cut the original validation and test splits in half. A fixed seed keeps the
# halves identical across kernel restarts, so the saved dataset is reproducible.
SPLIT_SEED = 42
additional_selected_validation_wikiann = wikiann["validation"].train_test_split(test_size=0.5, seed=SPLIT_SEED)
additional_selected_test_wikiann = wikiann["test"].train_test_split(test_size=0.5, seed=SPLIT_SEED)

In [81]:
# Sanity check before concatenating: the halves produced above must share the
# feature types of the train split. The test-derived half is the one that is
# actually concatenated onto train in the next cell, so it is checked as well.
assert wikiann["train"].features.type == additional_selected_validation_wikiann["train"].features.type
assert wikiann["train"].features.type == additional_selected_test_wikiann["train"].features.type

In [82]:
# Grow the train split with one half of the original test split, and shrink
# validation/test to the remaining halves.
# NOTE(review): additional_selected_validation_wikiann["train"] (the other half
# of the validation split) is not used in this cell — confirm dropping it is intended.
# NOTE(review): this overwrites `wikiann` in place, so re-running only this cell
# concatenates onto the already-extended train split.
wikiann["train"] = concatenate_datasets([additional_selected_test_wikiann["train"],wikiann["train"]])
wikiann["validation"] = additional_selected_validation_wikiann["test"]
wikiann["test"] = additional_selected_test_wikiann["test"]

In [85]:
# Strip wikiann down to the columns in feature_column. The helper edits and
# returns the DatasetDict it was given, so `wikiann_cleaned` and `wikiann` are
# the same object.
wikiann_cleaned = remove_columns_from_dataset_dict(wikiann,feature_column)

In [86]:
# Display the cleaned DatasetDict to check the remaining columns and split sizes.
wikiann_cleaned

DatasetDict({
    validation: Dataset({
        features: ['tokens', 'ner_tags'],
        num_rows: 5000
    })
    test: Dataset({
        features: ['tokens', 'ner_tags'],
        num_rows: 5000
    })
    train: Dataset({
        features: ['tokens', 'ner_tags'],
        num_rows: 25000
    })
})

In [96]:
# Persist the cleaned wikiann splits under ../data/wikiann.
wikiann_cleaned.save_to_disk("../data/wikiann")

# Preprocessing `conll2003`

In [87]:
# Load the `conll2003` dataset (no configuration name is passed).
conll = load_dataset("conll2003")

Reusing dataset conll2003 (/Users/philipp/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/40e7cb6bcc374f7c349c83acd1e9352a4f09474eb691f64f364ee62eb65d0ca6)


In [88]:
# Strip conll down to the columns in feature_column. The helper edits and
# returns the DatasetDict it was given, so `conll_cleaned` and `conll` are the
# same object.
conll_cleaned = remove_columns_from_dataset_dict(conll,feature_column)

In [89]:
# Display the cleaned DatasetDict to check the remaining columns and split sizes.
conll_cleaned

DatasetDict({
    train: Dataset({
        features: ['tokens', 'ner_tags'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['tokens', 'ner_tags'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['tokens', 'ner_tags'],
        num_rows: 3453
    })
})

In [97]:
# Persist the cleaned conll splits under ../data/conll.
conll_cleaned.save_to_disk("../data/conll")

# Merging the datasets

In [99]:
# Reload both cleaned datasets from the directories they were saved to above.
# wikiann must come from ../data/wikiann; reading ../data/conll for both would
# merge conll with itself.
loaded_conll = load_from_disk("../data/conll")
wikiann_cleaned = load_from_disk("../data/wikiann")

In [100]:
# For each split, append the rows of loaded_conll to the rows of wikiann_cleaned.
merged_dataset = merging_all_splits_from_dataset_dict(wikiann_cleaned,loaded_conll)

# Filter `ner_tags` to 3 or 4 classes

In [149]:
def change_label_to_zero(example, labels_to_zero=(7, 8)):
    """Replace selected label ids in ``example["ner_tags"]`` with 0.

    Supports both ways ``Dataset.map`` can call the function:

    * one example at a time: ``example["ner_tags"]`` is a flat list of label ids;
    * ``batched=True``: ``example["ner_tags"]`` is a list of such lists.

    Args:
        example: Dict-like row or batch with a ``"ner_tags"`` entry.
        labels_to_zero: Label ids that are rewritten to 0. Defaults to 7 and 8
            (presumably the MISC tags of conll2003 — TODO confirm against the
            dataset's label names).

    Returns:
        The same ``example`` object with ``"ner_tags"`` replaced.
    """
    def _relabel(tags):
        return [0 if label in labels_to_zero else label for label in tags]

    ner_tags = example["ner_tags"]
    # A batch holds one tag list per row. Comparing such a list with an int is
    # always False, so batches have to be relabelled row by row.
    if ner_tags and isinstance(ner_tags[0], (list, tuple)):
        example["ner_tags"] = [_relabel(tags) for tags in ner_tags]
    else:
        example["ner_tags"] = _relabel(ner_tags)
    return example

In [150]:
# Apply the relabelling to the conll train split only.
# NOTE(review): with batched=True, map hands the function a batch (each column
# is a list of rows), so change_label_to_zero must cope with a list of tag
# lists — confirm.
# NOTE(review): the validation/test splits and merged_dataset are not relabelled
# here — confirm that is intended.
conll_cleaned["train"] = conll_cleaned["train"].map(change_label_to_zero,batched=True)

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))




In [151]:
# Spot-check the first training example after the relabelling step.
conll_cleaned["train"][0]

{'ner_tags': [3, 0, 0, 0, 0, 0, 0, 0, 0],
 'tokens': ['EU',
  'rejects',
  'German',
  'call',
  'to',
  'boycott',
  'British',
  'lamb',
  '.']}