# Finetune pretrained model on MasakhaNER using ClearML and Google Colab

Colab notebook that downloads a pretrained SHIBA model from ClearML (stored in an S3 bucket), then finetunes it for NER using one of [several datasets based on MasakhaNER](https://github.com/cdleong/masakhane-ner). The available variants cover Kinyarwanda and Swahili, phonemized and not phonemized, with and without spaces.


# Install dependencies and imports

In [None]:
# ! pip install shiba-model jsonlines torchmetrics transformers datasets clearml boto3 awscli
! pip uninstall -y folium
! pip install boto3==1.19.6 clearml==1.1.3 datasets==1.14.0 tensorboard==2.6.0 torchmetrics==0.6.0 transformers==4.12.0 seqeval==1.2.2 shiba_model==0.1.0

# google_cloud_storage == 1.18.1
# https://download.pytorch.org/whl/cu111/torch-1.9.0%2Bcu111-cp37-cp37m-linux_x86_64.whl
# matplotlib == 3.2.2
# numpy == 1.19.5
# pathlib == 1.0.1
# scikit_learn == 0.22.2.post1
# seqeval == 1.2.2
# shiba_model == 0.1.0


In [None]:
# ! pip install torchmetrics transformers datasets

In [None]:
# !pip install clearml==1.1.2 boto3 awscli

In [None]:
# !pip install seqeval

In [None]:
! git clone https://github.com/octanove/shiba

In [None]:
# Assuming here that you've uploaded this manually. 
!cat clearml.conf > ~/clearml.conf


#AWS install
!mkdir -p ~/.aws
!cp config ~/.aws
!cp credentials ~/.aws

!ls ~/.aws
!ls ~/clearml.conf

In [None]:
import os
from typing import Optional, Dict, Tuple, List
import inspect
from dataclasses import dataclass, field

import torch
import torchmetrics
import transformers
from datasets import load_dataset, Dataset
from transformers import HfArgumentParser, Trainer, EvalPrediction, BertForSequenceClassification, AutoTokenizer, \
    DataCollatorWithPadding, TrainingArguments, EarlyStoppingCallback
from shiba import ShibaForClassification, ShibaForSequenceLabeling, CodepointTokenizer
from shiba import Shiba, get_pretrained_state_dict
from torch.nn.utils.rnn import pad_sequence
from random import randrange

In [None]:
from clearml import Task

# Setup Task and Download pretrained model

We want to connect a previous Task, and download a particular model from that Task. It needs to be one of the pytorch_model.bin files. 


See https://www.allegro.ai/clearml/docs/rst/references/clearml_python_ref/model_module/model_model.html?highlight=get_local_copy#clearml.model.Model.get_local_copy

And see https://www.allegro.ai/clearml/docs/rst/references/clearml_python_ref/task_module/task_task.html#clearml.task.Task.set_input_model


In [None]:
# copy-pasted from Google Docs spreadsheet
# input task name (note: step counts deceptive, early stopping was on)	input task ID
# text only	
# SHIBA, train='hf_swahili_no_spaces_jsonl', val='hf_swahili_no_spaces_jsonl', steps=30000, seed=420	e598bf905b184e8c92239bd9ffc1ec7c
# text only, but more text (aka if we had ASR perfectly done)	
# SHIBA, train='hf_swahili_plus_alffa_gold_no_word_boundaries_jsonl', val='hf_swahili_no_spaces_jsonl', steps=30000, seed=420	404f726453cf4304a85ea46cb36b0140
# epitran on text only	
# SHIBA, train='hf_swahili_epitran_no_spaces_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=20000, seed=42	66751d7d3aef44cc9ac96003579e6578
# SHIBA, train='hf_swahili_epitran_no_spaces_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=30000, seed=42	bd7ef48002d24058b76cc321d2e12dd2
# SHIBA, train='hf_swahili_epitran_no_spaces_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=30000, seed=420	45620c38bc3f4a14920a47dd83bacf2b
# epitran on text only, with more text	
# SHIBA, train='hf_swahili_epitran_plus_alffa_gold_epitran_no_word_boundaries_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=20000, seed=42	77fa240ccced47879db700a519db372a
# SHIBA, train='hf_swahili_epitran_plus_alffa_gold_epitran_no_word_boundaries_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=30000, seed=42	9b691f596e1141439d29c59763a78272
	
# epitran text and allosaurus	
# SHIBA, train='hf_swahili_epitran_plus_alffa_allosaurus_no_word_boundaries_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=20000, seed=42	d75762028d5742dd84ce785aa4e35acf
# SHIBA, train='hf_swahili_epitran_plus_alffa_allosaurus_no_word_boundaries_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=30000, seed=42	9f98fe0646a94392a98f4159c8f92f5d
	
	
# allosaurus only on Kinyarwanda	
# SHIBA, train='rw_allosaurus_204_of_258_train_jsonl', val='rw_allosaurus_204_of_258_train_jsonl', steps=30000, seed=42	67efb2b321f147b0908e3a6b82d98d69
# SHIBA, train='rw_allosaurus_204_of_258_train_jsonl', val='rw_allosaurus_204_of_258_train_jsonl', steps=30000, seed=420	f96628e881a442a0a52d9646b2783982

## set input task ID

In [None]:

# clearml_input_task_id="9f98fe0646a94392a98f4159c8f92f5d" # AKA "SHIBA, train='hf_swahili_epitran_plus_alffa_allosaurus_no_word_boundaries_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=30000, seed=42"
# clearml_input_task_id="66751d7d3aef44cc9ac96003579e6578" #66751d7d3aef44cc9ac96003579e6578 is SHIBA, train='hf_swahili_epitran_no_spaces_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=20000, seed=42
# clearml_input_task_id="f96628e881a442a0a52d9646b2783982"  #  AKA the second Common Voice experiment
# clearml_input_task_id = "77fa240ccced47879db700a519db372a" # HF + ALFFA Gold, both epitran
# clearml_input_task_id= "e598bf905b184e8c92239bd9ffc1ec7c"  # HF Sw only
# clearml_input_task_id= "404f726453cf4304a85ea46cb36b0140"  # HF + ALFFA
# clearml_input_task_id= "d75762028d5742dd84ce785aa4e35acf"  # SHIBA, train='hf_swahili_epitran_plus_alffa_allosaurus_no_word_boundaries_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=20000, seed=42
# clearml_input_task_id= "66751d7d3aef44cc9ac96003579e6578"  # SHIBA, train='hf_swahili_epitran_no_spaces_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=20000, seed=42
# clearml_input_task_id= "9f98fe0646a94392a98f4159c8f92f5d" # SHIBA, train='hf_swahili_epitran_plus_alffa_allosaurus_no_word_boundaries_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=30000, seed=42
# clearml_input_task_id= "bd7ef48002d24058b76cc321d2e12dd2" # SHIBA, train='hf_swahili_epitran_no_spaces_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=30000, seed=42
# clearml_input_task_id="45620c38bc3f4a14920a47dd83bacf2b" # SHIBA, train='hf_swahili_epitran_no_spaces_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=30000, seed=420
# clearml_input_task_id="9b691f596e1141439d29c59763a78272" # SHIBA, train='hf_swahili_epitran_plus_alffa_gold_epitran_no_word_boundaries_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=30000, seed=42
# clearml_input_task_id="78dd5e6f88624458975f7cfdcad48acf" #SHIBA, train='hf_swahili_epitran_plus_alffa_allosaurus_no_word_boundaries_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=30000, seed=420
# clearml_input_task_id="67efb2b321f147b0908e3a6b82d98d69" # SHIBA, train='rw_allosaurus_204_of_258_train_jsonl', val='rw_allosaurus_204_of_258_train_jsonl', steps=30000, seed=42
# clearml_input_task_id="9473c339194e4d649713278bc9a3f694" # 9473c339194e4d649713278bc9a3f694 is SHIBA, train='ALFFA_allosaurus_transcriptions_with_epitran_inventory_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=30000, seed=314
# clearml_input_task_id="c747034e4dd0439484d6b43750fdd400" # SHIBA, train='ALFFA_allosaurus_transcriptions_with_epitran_inventory_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=30000, seed=42
# clearml_input_task_id="286a17a424f24c428c5c70c74a1e50bd" # SHIBA, train='ALFFA_allosaurus_transcriptions_with_epitran_inventory_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=30000, seed=420
# clearml_input_task_id="8f463a814a464ec882cb12d9e96cfac7" # SHIBA, train='hf_swahili_epitran_plus_alffa_gold_epitran_no_word_boundaries_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=30000, seed=420
# clearml_input_task_id="e3b51b1ac2f140e59a9ee73d506c3de6" # SHIBA, train='hf_swahili_epitran_no_spaces_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=30000, seed=314
# clearml_input_task_id="d3f3be10ceb449d099e12c30ba6a2331" # SHIBA, train='hf_swahili_no_spaces_jsonl', val='hf_swahili_no_spaces_jsonl', steps=30000, seed=314
# clearml_input_task_id="c50e7aaf63d04cb4b5aefbfb9b313eb4" # SHIBA, train='common_voice_rw_epitran_no_spaces_jsonl', val='common_voice_rw_epitran_no_spaces_jsonl', steps=30000, seed=42
# clearml_input_task_id="384ee109dfc54132ba35c93c3653618a" # SHIBA, train='common_voice_rw_epitran_no_spaces_jsonl', val='common_voice_rw_epitran_no_spaces_jsonl', steps=30000, seed=420 
# clearml_input_task_id="f96628e881a442a0a52d9646b2783982" # SHIBA, train='rw_allosaurus_204_of_258_train_jsonl', val='rw_allosaurus_204_of_258_train_jsonl', steps=30000, seed=420
# clearml_input_task_id="3836cfc178e647358a12d1efa29b8b48" # SHIBA, train='hf_swahili_plus_alffa_gold_no_word_boundaries_jsonl', val='hf_swahili_no_spaces_jsonl', steps=30000, seed=314
# clearml_input_task_id="2f5bda07113947c3a32b9ba3b932afe1" # SHIBA, train='hf_swahili_plus_alffa_gold_no_word_boundaries_jsonl', val='hf_swahili_no_spaces_jsonl', steps=30000, seed=666
# clearml_input_task_id="62d1f269d6c7435bb0e4856e7eb2d1d7" # SHIBA, train='common_voice_rw_epitran_no_spaces_jsonl', val='common_voice_rw_epitran_no_spaces_jsonl', steps=30000, seed=7
# clearml_input_task_id="c50e7aaf63d04cb4b5aefbfb9b313eb4" # SHIBA, train='common_voice_rw_epitran_no_spaces_jsonl', val='common_voice_rw_epitran_no_spaces_jsonl', steps=30000, seed=42
# clearml_input_task_id="62d1f269d6c7435bb0e4856e7eb2d1d7" # SHIBA, train='common_voice_rw_epitran_no_spaces_jsonl', val='common_voice_rw_epitran_no_spaces_jsonl', steps=30000, seed=7
# clearml_input_task_id="c3d0cf324d5f4d9581c1bfea73ebebbd" # SHIBA, train='rw_allosaurus_204_of_258_train_jsonl', val='rw_allosaurus_204_of_258_train_jsonl', steps=30000, seed=7
# clearml_input_task_id="623716b9359942439096ecea8c43464e" # SHIBA, train='rw_allosaurus_204_of_258_train_jsonl', val='rw_allosaurus_204_of_258_train_jsonl', steps=30000, seed=666
# clearml_input_task_id="aaae56010e9d4f04897f2f6e25a91774" # SHIBA, train='hf_swahili_no_spaces_jsonl', val='hf_swahili_no_spaces_jsonl', steps=30000, seed=42
# clearml_input_task_id="e11ce10176454cc78faf619cb32343d5" # SHIBA, train='hf_swahili_no_spaces_jsonl', val='hf_swahili_no_spaces_jsonl', steps=30000, seed=7

# SHIBA, train='hf_swahili_epitran_plus_alffa_allosaurus_no_word_boundaries_spaces_fixed_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=300000, seed=42	17894fdd1abd4fe1a97adb7d34dd73e1
# clearml_input_task_id="17894fdd1abd4fe1a97adb7d34dd73e1"
# SHIBA, train='hf_swahili_epitran_plus_alffa_allosaurus_no_word_boundaries_spaces_fixed_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=300000, seed=420	89df4ac03d2e40769fec4012befdea0f
# clearml_input_task_id="89df4ac03d2e40769fec4012befdea0f"
# SHIBA, train='hf_swahili_epitran_plus_alffa_allosaurus_no_word_boundaries_spaces_fixed_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=300000, seed=314	ada77d92ba9c4b9baebbfeea4de45de3
# clearml_input_task_id="ada77d92ba9c4b9baebbfeea4de45de3"
# SHIBA, train='hf_swahili_epitran_plus_alffa_allosaurus_no_word_boundaries_spaces_fixed_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=300000, seed=7	8293fffb16db41f7bd87aeb6f5f13534
# clearml_input_task_id="8293fffb16db41f7bd87aeb6f5f13534"


# # SHIBA, train='ALFFA_allosaurus_transcriptions_with_epitran_inventory_no_spaces_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=300000, seed=666	90c40ebea196422197d3b9536ddaf96a
# clearml_input_task_id="90c40ebea196422197d3b9536ddaf96a" 
# # SHIBA, train='ALFFA_allosaurus_transcriptions_with_epitran_inventory_no_spaces_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=300000, seed=7	503c512efa5a4488a64f5399991ebd5d
# clearml_input_task_id="503c512efa5a4488a64f5399991ebd5d" 
# SHIBA, train='ALFFA_allosaurus_transcriptions_with_epitran_inventory_no_spaces_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=300000, seed=314	fbbac7dd5c634481be67dd74e97c3803
# clearml_input_task_id="fbbac7dd5c634481be67dd74e97c3803" 
# SHIBA, train='ALFFA_allosaurus_transcriptions_with_epitran_inventory_no_spaces_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=300000, seed=420	eb4ff7996366477797b6a6e598d1eac1
# clearml_input_task_id="eb4ff7996366477797b6a6e598d1eac1" 
# SHIBA, train='ALFFA_allosaurus_transcriptions_with_epitran_inventory_no_spaces_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=300000, seed=42	4219aafc64484c49a3bb0b2831b2fd74
# clearml_input_task_id="4219aafc64484c49a3bb0b2831b2fd74" 



# SHIBA, train='hf_swahili_epitran_plus_alffa_allosaurus_no_word_boundaries_spaces_fixed_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=300000, seed=666	dbead513ab854ed2aee6ce7875672ae6
# clearml_input_task_id="dbead513ab854ed2aee6ce7875672ae6" 
# SHIBA, train='hf_swahili_epitran_plus_alffa_allosaurus_no_word_boundaries_spaces_fixed_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=300000, seed=777	336e9095e7c04cca9373abbc375adfdf
# clearml_input_task_id="336e9095e7c04cca9373abbc375adfdf" 
# SHIBA, train='hf_swahili_epitran_plus_alffa_allosaurus_no_word_boundaries_spaces_fixed_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=300000, seed=365	ce75bdea067f44378e4d0e5a33557ceb
# clearml_input_task_id="ce75bdea067f44378e4d0e5a33557ceb" 


# SHIBA, train='common_voice_rw_epitran_no_spaces_jsonl', val='common_voice_rw_epitran_no_spaces_jsonl', steps=30000, seed=7	562b057944134af19ebfbeb0d535e15b
# (finetune on swa phones)
# clearml_input_task_id="562b057944134af19ebfbeb0d535e15b"

clearml_input_task_id="aaae56010e9d4f04897f2f6e25a91774" # SHIBA, train='hf_swahili_no_spaces_jsonl', val='hf_swahili_no_spaces_jsonl', steps=30000, seed=42

# clearml_input_task_id="" 



specific_model_artifact_name = None
# specific_model_artifact_name = "checkpoint-5000_pytorch_model.bin"


input_task = Task.get_task(clearml_input_task_id)


In [None]:
input_task.name

"SHIBA, train='hf_swahili_no_spaces_jsonl', val='hf_swahili_no_spaces_jsonl', steps=30000, seed=42"

In [None]:
input_task_artifacts = input_task.artifacts

## set task seed and name and which dataset to use

In [None]:
task_seed=42

# dataset_for_finetuning="kin_phonemes_no_word_boundaries"
# dataset_for_finetuning="swa_phonemes_no_word_boundaries"
dataset_for_finetuning="swa_no_word_boundaries"


task = Task.init(
        project_name="", # add your clearml project name here
        output_uri="s3://", # add your s3://bucket here
        task_name=f"[finetune seed {task_seed}] finetune Shiba on {dataset_for_finetuning} using model from Task '{input_task.name}'"
    )

In [None]:
task.name

"[finetune seed 42] finetune Shiba on swa_no_word_boundaries using model from Task 'SHIBA, train='hf_swahili_no_spaces_jsonl', val='hf_swahili_no_spaces_jsonl', steps=30000, seed=42'"

In [None]:
input_task_models = input_task.get_models()["output"]
print(input_task_models)

In [None]:
# The collections returned by ClearML here can apparently be indexed directly by name.
# With no specific artifact requested, pick the input task's output model named
# "pytorch_model"; otherwise pick the named entry from the input task's artifacts.
# Both branches leave the local file path in input_task_pytorch_model_local_copy,
# which later cells use to locate and load the checkpoint.
if specific_model_artifact_name is None:
  input_task_pytorch_model = input_task_models["pytorch_model"]
  print(input_task_pytorch_model.name)
  input_task_pytorch_model_local_copy = input_task_pytorch_model.get_local_copy()
else:
  input_task_pytorch_model = input_task_artifacts[specific_model_artifact_name]
  input_task_pytorch_model_local_copy = input_task_pytorch_model.get_local_copy()
  print(input_task_pytorch_model.name)

In [None]:
input_task_pytorch_model

In [None]:
print(input_task_pytorch_model_local_copy)
from pathlib import Path
pretrained_model_folder=Path(input_task_pytorch_model_local_copy).parent

# later code in the notebook expects a string, so...
pretrained_model_folder=str(pretrained_model_folder)


os.environ["pretrained_model_folder"]=pretrained_model_folder
!mkdir -p "$pretrained_model_folder"

/root/.clearml/cache/storage_manager/global/caea772d29cf1edd7b2a19b621d33204.pytorch_model.bin


# Helper functions

In [None]:
@dataclass
class DataArguments:
    """Command-line style arguments describing where the training data lives.

    ``data`` is handed to ``load_dataset('json', data_files=...)`` by
    ``prepare_data``; it defaults to ``None`` when not supplied.
    """

    data: str = field(
        metadata={"help": "The location of the Japanese wiki data to use for training."},
        default=None,
    )


@dataclass
class ShibaTrainingArguments(TrainingArguments):
    """``TrainingArguments`` extended with SHIBA-specific defaults and options.

    Besides overriding several Trainer defaults, this carries the model
    hyperparameters (``dropout``, ``deep_transformer_stack_layers``,
    ``local_attention_window``) so that ``get_model_hyperparams`` can pull
    them out of a single arguments object.
    """

    # SHIBA-specific switches
    masking_type: Optional[str] = 'rand_span'
    load_only_model: Optional[bool] = False

    # overrides of stock Trainer defaults
    group_by_length: Optional[bool] = True
    logging_first_step: Optional[bool] = True
    learning_rate: Optional[float] = 0.001

    logging_steps: Optional[int] = 200
    # a list is mutable, so it has to come from a factory
    report_to: Optional[List[str]] = field(default_factory=lambda: ['wandb'])
    evaluation_strategy: Optional[str] = 'steps'
    # evaluated once, when the class is defined
    fp16: Optional[bool] = torch.cuda.is_available()
    deepspeed: Optional = None
    warmup_ratio: Optional[float] = 0.025  # from canine

    per_device_eval_batch_size: Optional[int] = 12
    per_device_train_batch_size: Optional[int] = 12
    # max that we can fit on one GPU is 12. 12 * 21 * 8 = 2016
    gradient_accumulation_steps: Optional[int] = 21

    # model arguments - these have to be in training args for the hyperparam search
    dropout: Optional[float] = 0.1
    deep_transformer_stack_layers: Optional[int] = 12
    local_attention_window: Optional[int] = 128


@dataclass
class ShibaWordSegArgs(ShibaTrainingArguments):
    """Training arguments for per-character sequence-labeling finetuning.

    Overrides a handful of ``ShibaTrainingArguments`` defaults (no gradient
    accumulation, 6 epochs, no checkpoint saving, prediction enabled).

    ``report_to`` defaults to ``['tensorboard', 'wandb']``. Pass ``report_to``
    explicitly if one of those integrations is not installed.
    """

    do_predict: Optional[bool] = field(default=True)

    # only used for hyperparameter search
    trials: Optional[int] = field(default=2)
    deepspeed: Optional = field(default=None)
    gradient_accumulation_steps: Optional[int] = field(default=1)
    # Must be default_factory: with default=<lambda> the default value would be
    # the function object itself instead of the list it returns.
    report_to: Optional[List[str]] = field(default_factory=lambda: ['tensorboard', 'wandb'])
    num_train_epochs: Optional[int] = 6
    save_strategy: Optional[str] = 'no'

    pretrained_bert: Optional[str] = field(default=None)


@dataclass
class ShibaClassificationArgs(ShibaTrainingArguments):
    """Training arguments for sequence-classification finetuning.

    Overrides ``ShibaTrainingArguments`` defaults with a smaller learning
    rate and batch size, more frequent logging/evaluation, 6 epochs and no
    checkpoint saving.

    ``report_to`` defaults to ``['tensorboard', 'wandb']``. Pass ``report_to``
    explicitly if one of those integrations is not installed.
    """

    do_predict: Optional[bool] = field(default=True)
    eval_steps: Optional[int] = field(default=300)
    logging_steps: Optional[int] = field(default=100)
    learning_rate: Optional[float] = 2e-5
    per_device_train_batch_size: Optional[int] = 6
    num_train_epochs: Optional[int] = 6
    save_strategy: Optional[str] = 'no'
    output_dir: Optional[str] = field(default="~/runs/livedoor_classification")

    # only used for hyperparameter search
    trials: Optional[int] = field(default=2)
    deepspeed: Optional = field(default=None)
    gradient_accumulation_steps: Optional[int] = field(default=1)
    # Must be default_factory: with default=<lambda> the default value would be
    # the function object itself instead of the list it returns.
    report_to: Optional[List[str]] = field(default_factory=lambda: ['tensorboard', 'wandb'])

    pretrained_bert: Optional[str] = field(default=None)


def get_model_hyperparams(input_args):
    """Pick out the entries of ``input_args`` that ``Shiba.__init__`` accepts.

    ``input_args`` may be a dict or any object exposing ``__dict__`` (for
    example a training-arguments dataclass instance). Returns a new dict
    limited to keys that are argument names of ``Shiba.__init__``.
    """
    arg_map = input_args if isinstance(input_args, dict) else input_args.__dict__
    accepted_names = inspect.getfullargspec(Shiba.__init__).args

    selected = {}
    for name, value in arg_map.items():
        if name in accepted_names:
            selected[name] = value
    return selected


def get_base_shiba_state_dict(state_dict: Dict) -> Dict:
    """Return the encoder-only view of a checkpoint state dict.

    If any key starts with ``shiba_model``, only those keys are kept and
    their first 12 characters (the ``shiba_model.`` prefix) are dropped.
    If no key carries the prefix, the input dict is returned unchanged.
    """
    stripped = {
        name[12:]: tensor
        for name, tensor in state_dict.items()
        if name.startswith('shiba_model')
    }
    # an empty result means nothing was prefixed, so hand back the original
    return stripped if stripped else state_dict


def prepare_data(args: DataArguments) -> Tuple[Dataset, Dataset]:
    """Load the JSON data named by ``args.data`` and split it 98% / 2%.

    Returns ``(training_data, dev_data)``. The split uses a fixed seed of 42.
    """
    loaded = load_dataset('json', data_files=args.data)
    splits = loaded['train'].train_test_split(train_size=0.98, seed=42)
    return splits['train'], splits['test']


class SequenceLabelingDataCollator:
    """Collate per-character labeling examples into padded batch tensors.

    Every batch item must provide ``'input_ids'`` and ``'labels'``. Padding of
    the input ids is delegated to ``CodepointTokenizer.pad``.
    """

    def __init__(self):
        self.tokenizer = CodepointTokenizer()

    def __call__(self, batch) -> Dict[str, torch.Tensor]:
        """Pad ``batch`` and mask labels that must not contribute to the loss.

        Returns a dict with ``'input_ids'``, ``'attention_mask'`` and
        ``'labels'``. Label positions that are padding, or that line up with
        a CLS or SEP input id, are set to -100.
        """
        padded_batch = self.tokenizer.pad([x['input_ids'] for x in batch])
        input_ids = padded_batch['input_ids']
        attention_mask = padded_batch['attention_mask']

        # don't compute loss from padding
        labels = pad_sequence([torch.tensor(x['labels']) for x in batch], batch_first=True, padding_value=-100)
        # also don't compute loss from CLS or SEP tokens
        special_token_mask = (input_ids == self.tokenizer.CLS) | (input_ids == self.tokenizer.SEP)
        # masked_fill keeps the dtype/device of `labels` without building a filler tensor
        labels = labels.masked_fill(special_token_mask, -100)

        # Debugging aid: report rows that are almost entirely zero ids.
        # The row is only decoded when it is actually going to be printed.
        for i, row in enumerate(input_ids):
            nonzeroes = torch.count_nonzero(row)
            if nonzeroes < 5:
                print(f"row {i} has {nonzeroes} nonzeroes")
                print(self.tokenizer.decode(row))

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }


class ClassificationDataCollator:
    """Collate classification examples (one label per sequence) into tensors.

    Every batch item must provide ``'input_ids'`` and ``'labels'``. Padding of
    the input ids is delegated to ``CodepointTokenizer.pad``.
    """

    def __init__(self):
        self.tokenizer = CodepointTokenizer()

    def __call__(self, batch) -> Dict[str, torch.Tensor]:
        """Return padded ``input_ids``, ``attention_mask`` and a labels tensor."""
        padded = self.tokenizer.pad([item['input_ids'] for item in batch])

        collated = {name: padded[name] for name in ('input_ids', 'attention_mask')}
        collated['labels'] = torch.tensor([item['labels'] for item in batch])
        return collated

# Load pretrained model
https://github.com/octanove/shiba/issues/2 has the details, but tl;dr: you can `torch.load` the pytorch_model.bin, but its keys are all prefixed with `shiba_model.`, which Shiba's `load_state_dict` is NOT expecting. It turns out you can simply edit the state dict to strip that prefix.

In [None]:
! ls -alh "$pretrained_model_folder"

total 608M
drwxr-xr-x 2 root root 4.0K Nov  8 22:07 .
drwxr-xr-x 3 root root 4.0K Nov  8 22:07 ..
-rw-r--r-- 1 root root 608M Nov  8 22:07 caea772d29cf1edd7b2a19b621d33204.pytorch_model.bin


In [None]:
def load_and_fix_state_dict(path_to_pytorch_model):
  """Load a checkpoint and return only its encoder weights, un-prefixed.

  Args:
    path_to_pytorch_model: path to a ``pytorch_model.bin`` / ``.pt`` file
      holding a state dict whose encoder keys start with ``shiba_model.``.

  Returns:
    A dict mapping each ``shiba_model.<name>`` key to ``<name>``. Keys that
    do not start with the prefix (e.g. "autregressive_encoder.norm2.bias")
    are dropped, because the bare ``Shiba`` model does not expect them.
  """
  prefix = "shiba_model."

  # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
  # map_location="cpu" lets a checkpoint saved on a GPU be read on any runtime.
  state_dict = torch.load(path_to_pytorch_model, map_location="cpu")

  # https://discuss.pytorch.org/t/prefix-parameter-names-in-saved-model-if-trained-by-multi-gpu/494/4
  # describes the general idea of pulling the prefixes off.
  # Match the prefix only at the start of the key, and skip a key that would
  # be reduced to the empty string.
  return {
      key[len(prefix):]: value
      for key, value in state_dict.items()
      if key.startswith(prefix) and len(key) > len(prefix)
  }

# load MasakhaNER data


In [None]:
from pathlib import Path

In [None]:
# !rm -r masakhane-ner && 
!git clone https://github.com/cdleong/masakhane-ner.git


In [None]:
!ls "masakhane-ner/data/"

amh				 luo
amh.zip				 luo.zip
hau				 pcm
hau.zip				 pcm.zip
ibo				 swa
ibo.zip				 swa_no_word_boundaries
kin				 swa_phonemes
kin_no_word_boundaries		 swa_phonemes_no_word_boundaries
kin_phonemes			 swa.zip
kin_phonemes_no_word_boundaries  wol
kin.zip				 wol.zip
lug				 yor
lug.zip				 yor.zip


In [None]:
input_task.name

"SHIBA, train='hf_swahili_no_spaces_jsonl', val='hf_swahili_no_spaces_jsonl', steps=30000, seed=42"

In [None]:
# example input model name: 
# SHIBA, train='hf_swahili_epitran_plus_alffa_allosaurus_no_word_boundaries_jsonl', val='hf_swahili_epitran_no_spaces_jsonl', steps=30000, seed=42 - pytorch_model


masakhaner_dataset = load_dataset("/content/masakhane-ner/custom_huggingface_loading_script.py", dataset_for_finetuning)

# Set based on input_model
# if "rw_allosaurus" in input_task.name or "common_voice_rw_epitran_no_spaces_jsonl" in input_task.name:
#   masakhaner_dataset = load_dataset("/content/masakhane-ner/custom_huggingface_loading_script.py", "kin_phonemes_no_word_boundaries") 
# elif "epitran" in input_task.name:
#   # that means we're in phoneme space
#   # e.g. hf_swahili_epitran_no_spaces_jsonl
#   masakhaner_dataset = load_dataset("/content/masakhane-ner/custom_huggingface_loading_script.py", "swa_phonemes_no_word_boundaries") 
# else:
#   masakhaner_dataset = load_dataset("/content/masakhane-ner/custom_huggingface_loading_script.py", "swa_no_word_boundaries") 



masakhaner_dataset_orig = masakhaner_dataset



In [None]:
print(masakhaner_dataset)

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 2109
    })
    validation: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 300
    })
    test: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 604
    })
})


In [None]:
print(masakhaner_dataset["train"])

Dataset({
    features: ['id', 'tokens', 'ner_tags'],
    num_rows: 2109
})


In [None]:
print(masakhaner_dataset["train"][0])

{'ner_tags': [3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 8, 8, 8, 8, 8, 8, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'tokens': ['W', 'i', 'z', 'a', 'r', 'a', 'y', 'a', 'a', 'f', 'y', 'a', 'y', 'a', 'T', 'a', 'n', 'z', 'a', 'n', 'i', 'a', 'i', 'm', 'e', 'r', 'i', 'p', 'o', 't', 'i', 'J', 'u', 'm', 'a', 't', 'a', 't', 'u', 'k', 'u', 'w', 'a', ',', 'w', 'a', 't', 'u', 't', 'a', 'k', 'r', 'i', 'b', 'a', 'n', '1', '4', 'z', 'a', 'i', 'd', 'i', 'w', 'a', 'm', 'e', 'p', 'a', 't', 'a', 'm', 'a', 'a', 'm', 'b', 'u', 'k', 'i', 'z', 'i', 'y', 'a', 'C', 'o', 'v', 'i', 'd', '-', '1', '9', '.'], 'id': '0'}


In [None]:
def keep_examples_by_length(example):
  """Filter predicate: keep examples that have at least 4 tokens.

  Shorter examples are printed and rejected (returns False).
  """
  minimum_tokens = 4
  if len(example["tokens"]) < minimum_tokens:
    print("found a short example:")
    print(example["tokens"])
    return False
  return True


masakhaner_dataset = masakhaner_dataset_orig.filter(keep_examples_by_length)   

  0%|          | 0/3 [00:00<?, ?ba/s]

found a short example:
['6', '2', '.']
found a short example:
['w', 'a', '.']
found a short example:
['.']
found a short example:
['5', '4', '.']
found a short example:
['7', '.']


  0%|          | 0/1 [00:00<?, ?ba/s]

found a short example:
['4', '.']
found a short example:
['4', '.']
found a short example:
['S', '.']
found a short example:
['S', '.']
found a short example:
['S', '.']
found a short example:
['g', '.']
found a short example:
['S', '.']
found a short example:
['S', '.']


  0%|          | 0/1 [00:00<?, ?ba/s]

found a short example:
['5', '.']
found a short example:
['7', '.']
found a short example:
['1', '.']
found a short example:
['D', 'r', '.']
found a short example:
['L', 'i', '.']


In [None]:
print(masakhaner_dataset_orig)
print(masakhaner_dataset)

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 2109
    })
    validation: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 300
    })
    test: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 604
    })
})
DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 2104
    })
    validation: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 292
    })
    test: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 599
    })
})


In [None]:
# Define GPU device
device = torch.device('cuda:0')

# Tokenizer from the SHIBA package
tokenizer = CodepointTokenizer()

# NOTE - this process_example function is adapted from SHIBA's word
# segmentation finetuning script, which labels each character as 0 or 1 to
# indicate word boundaries (see here:
# https://github.com/octanove/shiba/blob/main/training/finetune_word_segmentation.py#L27).
# For NER we have more than 2 possible labels for each character (one for each
# possible NER class). An earlier version of this notebook simulated that by
# assigning a random label 0-9 (simulating 10 NER classes); the function below
# instead uses the real per-character ner_tags from the dataset.
def process_example(example: Dict) -> Dict:
  """Turn one dataset example into SHIBA input ids plus aligned labels.

  The example's 'tokens' are joined into one string and encoded with the
  module-level ``tokenizer``. Because the tokenizer adds a CLS token at the
  beginning, an "O" (outside) label is placed in front of the example's
  'ner_tags' so that labels and input ids line up.
  """
  # join all the characters together and encode them
  joined_text = ''.join(example['tokens'])
  input_ids = tokenizer.encode(joined_text)['input_ids']

  # label for the leading CLS token
  outside_label = masakhaner_dataset["train"].features["ner_tags"].feature.str2int("O")

  # the loading script already gives us ints for the NER tags
  labels = [outside_label, *example["ner_tags"]]

  return {'input_ids': input_ids, 'labels': labels}

masakhaner_class_count = masakhaner_dataset["train"].features["ner_tags"].feature.num_classes # 4 Entity types, B/I of each, plus "O"
def load_model():
  """Build the sequence-labeling model around the pretrained SHIBA encoder.

  Reads the module-level ``training_args``, ``masakhaner_class_count`` and
  ``input_task_pytorch_model_local_copy``.

  Returns:
    ``(model, data_collator)``: a ``ShibaForSequenceLabeling`` whose
    ``shiba_model`` has been replaced by the pretrained encoder, and a
    ``SequenceLabelingDataCollator``.
  """
  hyperparams = get_model_hyperparams(training_args)

  # the output size is the number of NER classes in the dataset
  labeling_model = ShibaForSequenceLabeling(masakhaner_class_count, **hyperparams)

  # Load the pretrained weights (should be a something.pytorch_model.bin)
  # into a fresh encoder.
  # NOTE(review): this Shiba() is built with its default arguments, so the
  # hyperparameters taken from training_args are not applied to the encoder
  # that ends up in the model - confirm this is intended.
  pretrained_encoder = Shiba()
  pretrained_encoder.load_state_dict(load_and_fix_state_dict(input_task_pytorch_model_local_copy))

  # swap the freshly initialised internal encoder for the pretrained one
  labeling_model.shiba_model = pretrained_encoder

  return labeling_model, SequenceLabelingDataCollator()

In [None]:
dep = masakhaner_dataset
# use our custom process_example to convert the characters to shiba input IDs
# and also rename the columns
dep = dep.map(process_example, remove_columns=list(dep['train'][0].keys()))

In [None]:
first_train_element = dep["train"][0]["input_ids"]

In [None]:
tokenizer.encode("wizara")

{'attention_mask': tensor([False, False, False, False, False, False, False]),
 'input_ids': tensor([57344,   119,   105,   122,    97,   114,    97])}

In [None]:
tokenizer.decode([57344,   119,   105,   122,    97,   114,    97])

'[CLS]wizara'

In [None]:
tokenizer.decode(first_train_element)

'[CLS]WizarayaafyayaTanzaniaimeripotiJumatatukuwa,watutakriban14zaidiwamepatamaambukiziyaCovid-19.'

In [None]:
print(masakhaner_class_count)

9


In [None]:
print(dep)

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 2104
    })
    validation: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 292
    })
    test: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 599
    })
})


In [None]:
for split in ["train", "validation", "test"]:
  print("Original (without the [CLS]): ", masakhaner_dataset[split]['tokens'][0])
  print('Input Ids (first [CLS]):', dep[split]['input_ids'][0])
  print(len(dep[split]['input_ids'][0]))
  print('Labels:', dep[split]['labels'][0])
  print(len(dep[split]['labels'][0]))
  
  

Original (without the [CLS]):  ['W', 'i', 'z', 'a', 'r', 'a', 'y', 'a', 'a', 'f', 'y', 'a', 'y', 'a', 'T', 'a', 'n', 'z', 'a', 'n', 'i', 'a', 'i', 'm', 'e', 'r', 'i', 'p', 'o', 't', 'i', 'J', 'u', 'm', 'a', 't', 'a', 't', 'u', 'k', 'u', 'w', 'a', ',', 'w', 'a', 't', 'u', 't', 'a', 'k', 'r', 'i', 'b', 'a', 'n', '1', '4', 'z', 'a', 'i', 'd', 'i', 'w', 'a', 'm', 'e', 'p', 'a', 't', 'a', 'm', 'a', 'a', 'm', 'b', 'u', 'k', 'i', 'z', 'i', 'y', 'a', 'C', 'o', 'v', 'i', 'd', '-', '1', '9', '.']
Input Ids (first [CLS]): [57344, 87, 105, 122, 97, 114, 97, 121, 97, 97, 102, 121, 97, 121, 97, 84, 97, 110, 122, 97, 110, 105, 97, 105, 109, 101, 114, 105, 112, 111, 116, 105, 74, 117, 109, 97, 116, 97, 116, 117, 107, 117, 119, 97, 44, 119, 97, 116, 117, 116, 97, 107, 114, 105, 98, 97, 110, 49, 52, 122, 97, 105, 100, 105, 119, 97, 109, 101, 112, 97, 116, 97, 109, 97, 97, 109, 98, 117, 107, 105, 122, 105, 121, 97, 67, 111, 118, 105, 100, 45, 49, 57, 46]
93
Labels: [0, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,

# Training

In [None]:
!nvidia-smi

Mon Nov  8 22:07:44 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   34C    P0    23W / 300W |      2MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Metrics for NER


## converting back and forth from class name to NER tag
Let's make some helper functions to ease the process

In [None]:
def classname_to_int(class_name:str)->int:
  """Return the integer NER tag for ``class_name`` (e.g. "B-PER").

  The lookup goes through the ``ner_tags`` feature of the train split of the
  notebook-level ``masakhaner_dataset``.
  """
  tag_feature = masakhaner_dataset["train"].features["ner_tags"].feature
  return tag_feature.str2int(class_name)

def int_to_class_name(ner_tag:int)->str:
  """Return the class name (e.g. "B-PER") for the integer ``ner_tag``.

  ``ner_tag`` is coerced with ``int()`` first, so anything ``int()`` accepts
  can be passed in. The lookup goes through the ``ner_tags`` feature of the
  train split of the notebook-level ``masakhaner_dataset``.
  """
  tag_feature = masakhaner_dataset["train"].features["ner_tags"].feature
  return tag_feature.int2str(int(ner_tag))

In [None]:
# Sanity check for the two helpers above: print the name -> tag mapping and
# then the tag -> name mapping so the round trip can be checked by eye.
# print(masakhaner_dataset["train"].features["ner_tags"])
class_names = masakhaner_dataset["train"].features["ner_tags"].feature.names
print(f"class names in train set are: {class_names}")
# Direction 1: every class name to its integer tag.
for class_name in class_names:
  class_ner_tag = classname_to_int(class_name=class_name)
  print(f"class with name {class_name} has ner_tag of {class_ner_tag}")

# Direction 2: every integer tag in [0, num_classes) back to its class name.
for i in range(masakhaner_dataset["train"].features["ner_tags"].feature.num_classes):
  class_name=int_to_class_name(i)
  print(f"class with tag of {i} has name of {class_name}")

class names in train set are: ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-DATE', 'I-DATE']
class with name O has ner_tag of 0
class with name B-PER has ner_tag of 1
class with name I-PER has ner_tag of 2
class with name B-ORG has ner_tag of 3
class with name I-ORG has ner_tag of 4
class with name B-LOC has ner_tag of 5
class with name I-LOC has ner_tag of 6
class with name B-DATE has ner_tag of 7
class with name I-DATE has ner_tag of 8
class with tag of 0 has name of O
class with tag of 1 has name of B-PER
class with tag of 2 has name of I-PER
class with tag of 3 has name of B-ORG
class with tag of 4 has name of I-ORG
class with tag of 5 has name of B-LOC
class with tag of 6 has name of I-LOC
class with tag of 7 has name of B-DATE
class with tag of 8 has name of I-DATE


In [None]:
# foo_string_label = "B-PER"
# foopred = ["B-PER", "I-ORG"]
# # barpred=foopred
# barpred = ["B-PER", "I-PER", "I-PER", "I-ORG"]
# foo_class_name = foo_string_label.split("-")[1]

# #https://stackoverflow.com/questions/6890170/how-to-find-the-last-occurrence-of-an-item-in-a-python-list
# print(barpred)
# index_of_last_item_matching=max(loc for loc, val in enumerate(barpred) if foo_class_name in val)
# print(index_of_last_item_matching)
# print(barpred[index_of_last_item_matching])

# bar_consistent = barpred[:index_of_last_item_matching+1] # string slicing is noninclusive
# print(bar_consistent)

# print(len(bar_consistent))


# def test_thing():
#   class_name = "PER"
#   preds = ["B-PER", "I-PER", "I-PER", "I-ORG", "I-PER"]
#   index_of_last_item_matching=0
#   for i, val in enumerate(preds):
#     print(val)
#     if class_name not in val:
#       print(i)
#       index_of_last_item_matching= i-1
#       break
#   else:
#     index_of_last_item_matching=0

#   print(index_of_last_item_matching)
# test_thing()

## Per-entity metrics

In [None]:
import numpy as np
# based on https://vitalflux.com/accuracy-precision-recall-f1-score-python-example/
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix
import matplotlib.pyplot as plt

#https://docs.python.org/3/library/dataclasses.html
@dataclass
class NEREntity():
  """One entity span, stored as aligned gold labels and predictions.

  An instance is opened with the element that starts the span and grown one
  position at a time with ``increment``. The proportion methods divide by the
  span length, so they raise ``ZeroDivisionError`` on an empty span.
  """
  string_label: str  # class name of the element that opened the span, e.g. "B-PER"
  labels: list  # gold tag ids, one per position in the span
  preds: list  # predicted tag ids, aligned position-by-position with ``labels``
  start_index: int  # the index of the starting B within the row

  def get_length(self):
    """Return the number of positions in the span (length of ``labels``)."""
    return len(self.labels)

  def increment(self, lbl, pred):
    """Extend the span by one position with its gold label and prediction."""
    self.labels.append(lbl)
    self.preds.append(pred)

  def totally_correct(self):
    """Return whether every prediction equals its gold label."""
    return self.labels == self.preds

  def first_element_correct(self):
    """Return whether the first prediction equals the first gold label."""
    return self.labels[0] == self.preds[0]

  def proportion_correct(self):
    """Return the fraction of positions whose prediction equals the label."""
    correct_preds = sum(1 for lbl, pred in zip(self.labels, self.preds) if lbl == pred)
    return correct_preds / self.get_length()

  def length_of_predicted_class_consistency(self):
    """Return how many leading predictions stay in this span's class.

    Whether the span is real or hallucinated, once the predicted class
    changes it cannot be right any more: in B-PER, I-ORG the I-ORG is
    definitely wrong. So only the run *before* the first prediction of
    another class counts, even if later predictions switch back, e.g.
    ["B-PER", "I-PER", "I-PER", "I-ORG", "I-PER"] -> 3.

    Returns 0 for an empty ``preds`` list.
    """
    class_name = self.string_label.split("-")[1]  # PER, ORG, etc

    consistent_length = 0
    for pred in self.preds:
      # Substring test: "PER" is in "B-PER" and "I-PER" but not in "O" or "I-ORG".
      if class_name not in int_to_class_name(pred):
        break
      consistent_length += 1
    return consistent_length

  def proportion_of_entity_that_is_consistent(self):
    """Return the class-consistent leading run as a fraction of the span.

    In other words: how far can you get into the span before the predicted
    class veers off?
    """
    return self.length_of_predicted_class_consistency() / self.get_length()

  def length_of_correctness(self):
    """Return how many leading positions are predicted exactly right.

    This is the index of the first position where label and prediction
    differ: 0 when even the first one is wrong, the full length when
    everything matches. It cannot tell whether the prediction ran on too
    long past the end of the span; that is not visible from here.
    """
    matching_prefix = 0
    for label, pred in zip(self.labels, self.preds):
      if label != pred:
        break
      matching_prefix += 1
    return matching_prefix

  def proportion_of_entity_that_is_consistent_and_correct(self):
    """Return the fraction of the span that is both correct and class-consistent.

    The span could be consistent but wrong, in which case correctness is the
    limit; or inconsistent, in which case consistency is. The shorter of the
    two leading runs is what counts.
    """
    limiting_length = min(
        self.length_of_correctness(),
        self.length_of_predicted_class_consistency(),
    )
    return limiting_length / self.get_length()

  


def per_entity_metric(label_row, prediction_row):
  """Split one row into its true entities and its predicted entities.

  True entities follow the gold labels: a B label opens a new entity and each
  later I label is appended to the most recently opened one. Predicted
  entities follow the predictions: a B prediction opens a new entity and an I
  prediction is appended only when it sits directly after the end of the most
  recently opened predicted entity.

  Args:
    label_row: gold tag ids for one row; positions equal to -100 are skipped.
    prediction_row: predicted tag ids aligned with ``label_row``.

  Returns:
    A ``(true_entities, predicted_entities)`` tuple of ``NEREntity`` lists.
    Labels and predictions are stored in them as plain ``int`` values.
  """
  true_entities = []
  predicted_entities = []

  for i, (lbl, pred) in enumerate(zip(label_row, prediction_row)):
    if lbl == -100:
      # Ignored position. Presumably the loss-ignore index used for
      # padding / special tokens; TODO confirm.
      continue

    # Normalise both values so entities never mix ints with tensor scalars.
    lbl = int(lbl)
    pred = int(pred)
    string_label = int_to_class_name(lbl)  # O, B-something, I-something
    string_pred = int_to_class_name(pred)

    # We track true entities based on label.
    if string_label.startswith("B"):
      # start tracking a new real entity
      true_entities.append(
          NEREntity(string_label=string_label, labels=[lbl], preds=[pred], start_index=i)
      )
    elif string_label.startswith("I"):
      # An I label with no entity opened before it has nothing to attach to.
      if true_entities:
        true_entities[-1].increment(lbl, pred)
    elif not string_label.startswith("O"):
      # "O" is not an entity and is skipped; anything else should be impossible.
      print("not appending or passing or making a new entity")

    # We track predicted entities based on predictions.
    if string_pred.startswith("B"):
      # new predicted entity! Start tracking it!
      predicted_entities.append(
          NEREntity(string_label=string_pred, labels=[lbl], preds=[pred], start_index=i)
      )
    elif string_pred.startswith("I") and predicted_entities:
      # Only extend the open entity when this position directly follows it.
      # With preds like 3, 2, 0, 2, 2, 0 the 2s after the 0 must not be added
      # to the entity opened by the 3. The class of the I tag is NOT checked
      # here: from B-PER, I-PER, I-ORG, I-ORG, I-PER there is no way to
      # decide locally, so NEREntity's consistency methods handle that.
      open_entity = predicted_entities[-1]
      expected_index = open_entity.start_index + open_entity.get_length()
      if i == expected_index:
        open_entity.increment(lbl, pred)

  return true_entities, predicted_entities


def _show_first_char_confusion_matrix(conf_matrix):
  """Draw ``conf_matrix`` with matplotlib, writing each cell's count in the cell."""
  fig, ax = plt.subplots(figsize=(5, 5))
  ax.matshow(conf_matrix, cmap=plt.cm.Oranges, alpha=0.3)
  for i in range(conf_matrix.shape[0]):
    for j in range(conf_matrix.shape[1]):
      ax.text(x=j, y=i, s=conf_matrix[i, j], va='center', ha='center', size='xx-large')

  plt.xlabel('B Predictions', fontsize=18)
  plt.ylabel('B Actuals', fontsize=18)
  plt.title('Confusion Matrix by first character', fontsize=18)
  plt.show()


def calculate_score_for_entities(true_entities:list, predicted_entities:list):
  """Compute entity-level scores from true and predicted entities.

  Besides returning the scores, this prints one intermediate mean and the
  first-character confusion matrix, and shows that matrix as a matplotlib
  figure.

  Args:
    true_entities: ``NEREntity`` objects built from the gold labels.
    predicted_entities: ``NEREntity`` objects built from the predictions.

  Returns:
    A dict of scores. Proportions and means are 0.0 when the list they are
    taken over is empty, and the four sklearn scores are 0.0 when there is
    nothing at all to score.
  """
  true_entity_count = len(true_entities)
  predicted_entity_count = len(predicted_entities)

  # CALCULATE SCORES IN ENTITY SPACE BY FIRST CHAR
  # https://www.numpyninja.com/post/recall-specificity-precison-f1-scores-and-accuracy
  # One flag per true entity: was its first character predicted correctly?
  first_char_hits = [bool(entity.first_element_correct()) for entity in true_entities]
  true_positives_by_first_char = sum(first_char_hits)  # true entities with the first character correct
  false_negatives_by_first_char = true_entity_count - true_positives_by_first_char  # ... with it wrong
  # predicted entities with the first character wrong
  false_positives_by_first_char = sum(
      1 for entity in predicted_entities if not entity.first_element_correct()
  )
  # true negatives would be basically wherever we had "O"

  if true_entity_count:
    # How many of the true entities did we label every character correctly?
    # This doesn't catch it when the label goes too LONG,
    # e.g. if the correct answer is "B-PER", "I-PER", "O"
    # but the prediction was "B-PER", "I-PER", "I-PER"
    totally_correct_true_entities = sum(1 for entity in true_entities if entity.totally_correct())
    proportion_of_true_entities_totally_correct = totally_correct_true_entities / true_entity_count

    # Within the true entities, what proportion of the characters were correct?
    true_entities_mean_proportion_of_correct_chars = np.mean(
        [entity.proportion_correct() for entity in true_entities]
    )

    # within the true entities, what proportion got at least the first character?
    true_entities_proportion_with_first_char_correct = true_positives_by_first_char / true_entity_count
  else:
    # No true entities: report 0.0 rather than dividing by zero.
    proportion_of_true_entities_totally_correct = 0.0
    true_entities_mean_proportion_of_correct_chars = 0.0
    true_entities_proportion_with_first_char_correct = 0.0

  # what proportion of the predicted_entities is consistent and correct?
  proportions_consistent_and_correct = []
  proportions_consistent_and_correct_given_first_correct = []
  for predicted_entity in predicted_entities:
    proportion = predicted_entity.proportion_of_entity_that_is_consistent_and_correct()
    proportions_consistent_and_correct.append(proportion)
    if predicted_entity.first_element_correct():
      proportions_consistent_and_correct_given_first_correct.append(proportion)

  if proportions_consistent_and_correct:
    mean_proportion_of_each_predicted_entity_which_is_consistent_and_correct = np.mean(proportions_consistent_and_correct)
  else:
    # No predicted entities at all (e.g. everything predicted as "O").
    mean_proportion_of_each_predicted_entity_which_is_consistent_and_correct = 0.0

  if proportions_consistent_and_correct_given_first_correct:
    mean_proportion_of_each_predicted_entity_which_stays_consistent_given_it_starts_correct = np.mean(proportions_consistent_and_correct_given_first_correct)
  else:
    mean_proportion_of_each_predicted_entity_which_stays_consistent_given_it_starts_correct = 0.0

  print("mean_proportion_of_each_predicted_entity_which_stays_consistent_given_it_starts_correct", mean_proportion_of_each_predicted_entity_which_stays_consistent_given_it_starts_correct)

  # Generate non-overlapping actuals/preds lists for scikit-learn.
  # Every true entity is an actual positive: predicted positive when its
  # first character was right (true positive), negative otherwise (false
  # negative). Predicted entities whose first character is right coincide
  # with a true entity and are already covered; the others are false
  # positives (actual answer was "no entity", but we predicted one).
  sklearn_true = [True] * true_entity_count + [False] * false_positives_by_first_char
  sklearn_preds = first_char_hits + [True] * false_positives_by_first_char

  if sklearn_true:
    conf_matrix_by_first_char = confusion_matrix(y_true=sklearn_true, y_pred=sklearn_preds)
    _show_first_char_confusion_matrix(conf_matrix_by_first_char)
    print(conf_matrix_by_first_char)

    precision_score_by_first_char = precision_score(y_true=sklearn_true, y_pred=sklearn_preds)
    recall_score_by_first_char = recall_score(y_true=sklearn_true, y_pred=sklearn_preds)
    accuracy_score_by_first_char = accuracy_score(y_true=sklearn_true, y_pred=sklearn_preds)
    f1_score_by_first_char = f1_score(y_true=sklearn_true, y_pred=sklearn_preds)
  else:
    # Neither true nor predicted entities: nothing for scikit-learn to score.
    precision_score_by_first_char = 0.0
    recall_score_by_first_char = 0.0
    accuracy_score_by_first_char = 0.0
    f1_score_by_first_char = 0.0

  return {
          "proportion_of_true_entities_totally_correct" : proportion_of_true_entities_totally_correct,
          "true_entities_mean_proportion_of_correct_chars" : true_entities_mean_proportion_of_correct_chars,
          "true_entities_proportion_with_first_char_correct": true_entities_proportion_with_first_char_correct,
          "mean_proportion_of_each_predicted_entity_which_is_consistent_and_correct": mean_proportion_of_each_predicted_entity_which_is_consistent_and_correct,
          # "mean_proportion_of_each_predicted_entity_which_stays_consistent_given_it_starts_correct": mean_proportion_of_each_predicted_entity_which_stays_consistent_given_it_starts_correct,
          "true_positives_by_first_char": true_positives_by_first_char,
          "false_positives_by_first_char": false_positives_by_first_char,
          "false_negatives_by_first_char" : false_negatives_by_first_char,
          "true_entity_count": true_entity_count,
          "predicted_entity_count": predicted_entity_count,
          "precision_score_by_first_char": precision_score_by_first_char,
          "recall_score_by_first_char": recall_score_by_first_char,
          "accuracy_score_by_first_char": accuracy_score_by_first_char,
          "f1_score_by_first_char": f1_score_by_first_char,
          }
  


### Testing per-entity metrics

In [None]:
# Testing it. 
def test_per_entity_metric():
  """Run the per-entity metrics on one hand-written row and print the results.

  The gold row holds a 22-character ORG entity, five "O" characters and a
  5-character PER entity. Exactly one prediction row is active; the others
  are kept as comments and can be swapped in to try a different scenario.
  """
  # Characters the rows stand for; kept for reference only, not used below.
  inputs = ['w', 'i', 'z', 'a', 'ɾ', 'a', 'j', 'a', 'a', 'f', 'j', 'a', 'j', 'a', 't', 'a', 'n', 'z', 'a', 'n', 'i', 'a', 
  'i', 'm', 'e', 'ɾ', 'i',  # these 5 are not in an entity. Apologies to Swahili speakers for incoherence
  't', 'r', 'u', 'm', 'p'  # another random entity
  ]

  label_row = [3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 
               0, 0, 0, 0, 0,
               1, 2, 2, 2, 2, 
               ]

  # ---- Alternative scenarios (uncomment one pair to use it) ----

  # Predictions identical to the labels.
  # predi_row = [3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 
  #              0, 0, 0, 0, 0,
  #              1, 2, 2, 2, 2, 
  #              ]
  # print(f"SCORES for the scenario where predictions are perfect:")

  # A hallucinated extra entity.
  # predi_row = [3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 
  #              0, 0, 3, 4, 0,
  #              1, 2, 2, 2, 2, 
  #              ]
  # print(f"SCORES for the scenario where it made up a new entity")

  # The prediction switches from I-ORG to I-PER halfway through.
  # predi_row = [3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
  #              0, 0, 0, 0, 0,
  #              1, 2, 2, 2, 2, 
  #              ]
  # print(f"SCORES for the scenario where the label switches from I-ORG to I-PER halfway through")

  # The prediction switches from I-ORG to I-PER and back.
  # predi_row = [3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 
  #              0, 0, 0, 0, 0,
  #              1, 2, 2, 2, 2, 
  #              ]
  # print(f"SCORES for the scenario where the label switches from I-ORG to I-PER and back")

  # Switches from I-ORG to I-PER and back, plus a spurious entity.
  # predi_row = [3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 
  #              0, 1, 2, 4, 0,
  #              1, 2, 2, 2, 2, 
  #              ]
  # print(f"SCORES for the scenario where the label switches from I-ORG to I-PER and back, plus there's an extra")

  # Just two wrong: a long entity gets broken up by extra B tags.
  # predi_row = [3, 4, 4, 4, 4, 3, 4, 4, 4, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 
  #              0, 0, 0, 0, 0,
  #              1, 2, 2, 2, 2, 
  #              ]
  # print(f"SCORES for the scenario where a long entity got broken up into a few")

  # Started right, but the prediction continued too long.
  # predi_row = [3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 
  #              4, 4, 4, 0, 0,
  #              1, 2, 2, 2, 2, 
  #              ]
  # print(f"SCORES for where the prediction went on too long")

  # Started right, but the prediction stopped too early.
  # predi_row = [3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 
  #              0, 0, 0, 0, 0,
  #              1, 2, 2, 2, 2, 
  #              ]
  # print(f"SCORES for where the prediction didn't go on long enough")

  # Two wrong, and the second entity missed entirely.
  # predi_row = [3, 4, 4, 4, 4, 3, 4, 4, 4, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 
  #              0, 0, 0, 0, 0,
  #              0, 0, 0, 0, 0, 
  #              ]
  # print(f"SCORES for where a long entity got broken up, AND one entity was missed")

  # First entity right, second entity missed entirely.
  # predi_row = [3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 
  #              0, 0, 0, 0, 0,
  #              0, 0, 0, 0, 0, 
  #              ]
  # print(f"SCORES for where one entity was missed but the other was caught")

  # ---- Active scenario: I predictions with no B before them ----
  predi_row = [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, # will it crash?
               3, 2, 0, 2, 2, # <--- right here: will it add the 2 to the 3 entity?
               0, 1, 2, 2, 2, 
               ]
  print("SCORES for where predicting I without corresponding B")

  entity_pair = per_entity_metric(label_row=label_row, prediction_row=predi_row)
  true_entities, predicted_entities = entity_pair
  print("true entities",true_entities)
  print("predicted entities",predicted_entities)

  scores = calculate_score_for_entities(true_entities, predicted_entities)
  for key, value in scores.items():
      print(key, value)
  
# test_per_entity_metric()

## add in seqeval because why not?

In [None]:
# https://github.com/chakki-works/seqeval
from seqeval.metrics import accuracy_score as seqeval_accuracy_score
from seqeval.metrics import classification_report as seqeval_classification_report
from seqeval.metrics import f1_score as seqeval_f1_score
from seqeval.scheme import IOB2 # pretty sure MasakhaNER is this
def calculate_seqeval_scores(labels, predictions):
  """Score ``predictions`` against ``labels`` with seqeval.

  Each row is paired position by position, positions whose label is -100 are
  dropped, and the remaining tag ids are turned into class-name strings,
  because seqeval expects ``List[List[str]]`` for both arguments. A strict
  IOB2 classification report is printed on the way.

  Args:
    labels: rows of gold tag ids (e.g. a 2-D tensor of shape [rows, length]).
    predictions: rows of predicted tag ids aligned with ``labels``.

  Returns:
    A dict with "seqeval_accuracy" and "seqeval_f1".
  """
  string_labels = []
  string_predictions = []
  for label_row, prediction_row in zip(labels, predictions):
    # Keep only the positions that carry a real label.
    kept_pairs = [
        (lbl, pred)
        for lbl, pred in zip(label_row, prediction_row)
        if lbl != -100
    ]
    string_labels.append([int_to_class_name(lbl) for lbl, _ in kept_pairs])
    string_predictions.append([int_to_class_name(pred) for _, pred in kept_pairs])

  strict_report = seqeval_classification_report(
      string_labels, string_predictions, mode='strict', scheme=IOB2
  )
  print("strict seqeval report, scheme=IOB2")
  print(strict_report)

  scores = {
      # "strict_report":strict_report,
      "seqeval_accuracy": seqeval_accuracy_score(string_labels, string_predictions),
  }
  scores["seqeval_f1"] = seqeval_f1_score(string_labels, string_predictions)
  return scores


## Set up compute_metrics for both the character-level and entity-level cases

In [None]:
# 
# Here's where we define our evaluation metric
def compute_metrics(pred: EvalPrediction) -> Dict:
    """Compute character-level, entity-level and seqeval scores for the Trainer.

    Args:
        pred: evaluation output. ``pred.predictions`` is unpacked as a pair
            whose first item holds the per-class scores, indexed
            [row, position, class]; the second item is not used.
            ``pred.label_ids`` holds the gold tag ids, with -100 marking
            positions to ignore.

    Returns:
        One flat dict holding the entity scores, the character-level
        'f1' / 'f1_without_o', and the seqeval scores.
    """
    label_probs, _unused_embeddings = pred.predictions
    labels = torch.tensor(pred.label_ids)
    # Index of the highest-scoring class at every position.
    predictions = torch.max(torch.exp(torch.tensor(label_probs)), dim=2)[1]

    metric = torchmetrics.F1(multiclass=True)
    # Same F1, but with class 0 (the "O" tag) left out.
    metric_without_o = torchmetrics.F1(multiclass=True, ignore_index=0)
    true_entities = []
    predicted_entities = []

    # A single pass over the rows feeds both F1 metrics and the entity lists.
    for label_row, prediction_row in zip(labels, predictions):
        keep = label_row != -100  # mask of positions that carry a real label
        row_labels = label_row[keep]
        row_predictions = prediction_row[keep]
        metric.update(row_predictions, row_labels)
        metric_without_o.update(row_predictions, row_labels)

        # per_entity_metric does its own -100 filtering, so it gets full rows.
        row_true_entities, row_predicted_entities = per_entity_metric(label_row=label_row, prediction_row=prediction_row)
        true_entities.extend(row_true_entities)
        predicted_entities.extend(row_predicted_entities)

    char_scores = {
        'f1': metric.compute().item(),
        'f1_without_o': metric_without_o.compute().item(),
    }
    entity_scores = calculate_score_for_entities(true_entities=true_entities, predicted_entities=predicted_entities)
    seqeval_scores = calculate_seqeval_scores(labels, predictions)

    scores = {}
    scores.update(entity_scores)  # update works in-place
    scores.update(char_scores)
    scores.update(seqeval_scores)
    return scores

## training args

In [None]:
# Transformers setup
# Transformers setup
# Turn on info-level logging from the transformers library.
transformers.logging.set_verbosity_info()

# training_args = ShibaWordSegArgs(
#     "shiba_ner_trainer",
#     do_predict=True,
#     save_strategy='epoch',
# )

# NOTE(review): `parser` is only used by the commented-out
# parse_args_into_dataclasses() call further down; the live path builds
# TrainingArguments directly.
parser = HfArgumentParser((ShibaWordSegArgs,))

# TODO - Not sure if we need to modify any of the other training parameters, but
# I just used what SHIBA used for the word segmentation. Might need to modify 
# this for early stopping or lr changes.

# Training arguments
training_args = TrainingArguments(
    # Output directory: checkpoints go here, and it is reused as the logging dir below.
    "shiba_ner_trainer",
    do_train=True,
    do_eval=True,
    do_predict=True,
    # Save and evaluate once per epoch.
    save_strategy='epoch',
    evaluation_strategy='epoch',
    debug="underflow_overflow",
    learning_rate=0.0004, # 0.0004 is the default from shiba Training.md 

    # Effectively an upper bound: the Trainer created later in this notebook
    # is given an EarlyStoppingCallback.
    num_train_epochs=5000,
    # num_train_epochs=5,
    load_best_model_at_end=True,
    per_device_eval_batch_size=16, # P100 GPU on Colab Pro only 50-75% used at 8, but 14787MiB / 16280MiB at 16
    per_device_train_batch_size=16,
    logging_steps=100,
    # task_seed is defined earlier in the notebook.
    seed = task_seed
)

# training_args= parser.parse_args_into_dataclasses()[0]



# Write logs into the output directory itself ("shiba_ner_trainer"), the same
# directory the %tensorboard cell below points at.
training_args.logging_dir = training_args.output_dir
print(training_args)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


TrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_find_unused_parameters=None,
debug=[<DebugOption.UNDERFLOW_OVERFLOW: 'underflow_overflow'>],
deepspeed=None,
disable_tqdm=False,
do_eval=True,
do_predict=True,
do_train=True,
eval_accumulation_steps=None,
eval_steps=None,
evaluation_strategy=IntervalStrategy.EPOCH,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
gradient_accumulation_steps=1,
gradient_checkpointing=False,
greater_is_better=False,
group_by_length=False,
hub_model_id=None,
hub_strategy=HubStrategy.EVERY_SAVE,
hub_token=<HUB_TOKEN>,
ignore_data_skip=False,
label_names=None,
label_smoothing_factor=0.0,
learning_rate=0.0004,
length_column_name=length,
load_best_model_at_end=True,
local_rank=-1,
log_level=-1,
log_level_replica=-1,
log_on_each_node=True,
logging_dir=shiba_ner_trainer,
logging_first_step=False,
logging_n

In [None]:
#https://colab.research.google.com/github/tensorflow/tensorboard/blob/master/docs/get_started.ipynb#scrollTo=6B95Hb6YVgPZ
!ls shiba_ner_trainer/
!rm -r shiba_ner_trainer/*  # delete previous runs


ls: cannot access 'shiba_ner_trainer/': No such file or directory
rm: cannot remove 'shiba_ner_trainer/*': No such file or directory


In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

In [None]:
%tensorboard --logdir shiba_ner_trainer/

## actually run the trainer


In [None]:
# Build the model and its data collator with load_model(), defined earlier in the notebook.
model, data_collator = load_model()

# Task.current_task().set_parameters_as_dict(training_args.to_dict())
# Attach the training arguments to the current ClearML task under the name 'training args'.
Task.current_task().connect(training_args, name='training args')
# Let's do some training
trainer = Trainer(model=model,
                  args=training_args,
                  data_collator=data_collator,
                  # `dep` is the dataset dict prepared earlier in the notebook;
                  # train on its 'train' split, evaluate on 'validation'.
                  train_dataset=dep['train'],
                  eval_dataset=dep['validation'],
                  
                  # Character-level, entity-level and seqeval scores on every evaluation.
                  compute_metrics=compute_metrics,
                  callbacks=[
                          # Early stopping with a patience of 20; evaluation_strategy
                          # is 'epoch', so presumably that means 20 epochs - TODO confirm.
                          EarlyStoppingCallback(early_stopping_patience=20),
                      ]
                  # pretrained_bert=None,
                  )

trainer.train()

In [None]:
print(model)

# Post-train metrics

In [None]:

# Run the trained model on the test split and keep the three parts of the
# result (metrics, raw predictions, label ids) for the cells below.
posttrain_tuple = trainer.predict(dep['test'])
posttrain_metrics = posttrain_tuple.metrics
posttrain_predictions = posttrain_tuple.predictions
posttrain_label_ids = posttrain_tuple.label_ids
print(posttrain_metrics)


In [None]:
from clearml import Logger
# Print every test metric and report it to ClearML as a scalar with title
# "test", one series per metric name, at iteration 0.
for test_metric in posttrain_metrics:
  print(test_metric, posttrain_metrics[test_metric])
  
  #report_scalar(title, series, value, iteration)
  Logger.current_logger().report_scalar("test", test_metric, posttrain_metrics[test_metric], 0)

In [None]:
from pathlib import Path

# Trainer.predict is defined here:
# https://github.com/huggingface/transformers/blob/a13c8145bc2810e3f0a52da22ae6a6366587a41b/src/transformers/trainer.py#L2076
# and its return type, PredictionOutput, here:
# https://github.com/huggingface/transformers/blob/a13c8145bc2810e3f0a52da22ae6a6366587a41b/src/transformers/trainer_utils.py#L88

# Inspect the prediction outputs before writing them to disk. The trailing
# comments record what a previous run printed.
print(type(posttrain_predictions))  # tuple
print(type(posttrain_label_ids))  # numpy.ndarray
print(len(posttrain_predictions))  # 2
print(posttrain_label_ids.shape)  # (599, 462)

# Write each element of the predictions tuple to its own .npy file. Saving the
# whole tuple with a single np.save call failed with:
#   ValueError: could not broadcast input array from shape (599,462,9) into shape (599,462)
for idx, prediction_array in enumerate(posttrain_predictions):
    print(f"Thing {idx} is type() of {type(prediction_array)} with shape {prediction_array.shape}")
    np.save(f"posttrain_predictions_{idx}", prediction_array)

# The label ids are a single array, so one file is enough.
np.save("posttrain_label_ids", posttrain_label_ids)





In [None]:
# Round-trip check: load every saved .npy file back and confirm it matches the
# in-memory array element for element (each check prints True on success).
reloaded_predictions = []
for idx, original_array in enumerate(posttrain_predictions):
    from_disk = np.load(f"posttrain_predictions_{idx}.npy")
    print(np.all(from_disk == original_array))
    reloaded_predictions.append(from_disk)

# Restore the same container type as the original predictions (a tuple).
reloaded_predictions = tuple(reloaded_predictions)

# Same check for the label ids.
reloaded_label_ids = np.load("posttrain_label_ids.npy")
print(np.all(posttrain_label_ids == reloaded_label_ids))

In [None]:
import json

# Persist the test-set metrics next to the saved prediction arrays; the
# "posttrain" filename prefix is what the upload cell globs for.
with open("posttrain_metrics.json", "w") as metrics_file:
    json.dump(posttrain_metrics, metrics_file)

In [None]:
# Attach every file in the working directory whose name starts with
# "posttrain" to the current ClearML task, using the file name as the
# artifact name.
for artifact_path in Path.cwd().glob("posttrain*"):
    Task.current_task().upload_artifact(artifact_path.name, artifact_object=artifact_path)

In [None]:
# Final step of the notebook: call completed() on the current ClearML task.
# NOTE(review): presumably this marks the task as finished on the ClearML
# server; confirm against the pinned clearml version's API.
Task.current_task().completed()