In [1]:
import numpy as np
import pandas as pd
from datasets import load_dataset
from datasets import Dataset, DatasetDict
from transformers.data import datasets
from torch.utils.data import DataLoader
from transformers import EarlyStoppingCallback, AutoTokenizer, DataCollatorWithPadding, AdamW
from transformers import TrainingArguments, Trainer
import torch
from transformers import BertTokenizer, BertModel

from sklearn.model_selection import train_test_split

import os
import sys
sys.path.append("/home/jovyan/20230406_ArticleClassifier/ArticleClassifier")
from src.general.utils import cc_path


In [2]:
 labels = ['human', 'mouse', 'rat', 'nonhuman', 'controlled study',
           'animal experiment', 'animal tissue', 'animal model', 'animal cell',
           'major clinical study', 'clinical article', 'case report',
           'multicenter study', 'systematic review', 'meta analysis',
           'observational study', 'pilot study', 'longitudinal study',
           'retrospective study', 'case control study', 'cohort analysis',
           'cross-sectional study', 'diagnostic test accuracy study',
           'double blind procedure', 'crossover procedure',
           'single blind procedure', 'adult', 'aged', 'middle aged', 'child',
           'adolescent', 'young adult', 'very elderly', 'infant', 'school child',
           'newborn', 'preschool child', 'embryo', 'fetus', 'male', 'female',
           'human cell', 'human tissue', 'normal human', 'human experiment',
           'phase 2 clinical trial', 'randomized controlled trial',
           'clinical trial', 'controlled clinical trial', 'phase 3 clinical trial',
           'phase 1 clinical trial', 'phase 4 clinical trial']

# Preprocessing

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [10]:
def load_whole_ds(DATA_PATH = "../data/"):

        total = pd.read_csv(cc_path(f"data/processed/canary/articles_cleaned.csv"))
        total.pui = total.pui.astype(str)
        
        total[labels] = total[labels].astype(int)
        
        total['str_keywords'] = total['keywords'].str.replace('[', ' ').str.replace(']', ' ').str.replace(', ', ' ').str.replace("'", '')
        total['embedding_text'] = total['title'] + total['str_keywords'] + total['abstract']


        with open(cc_path(f'data/canary_train_indices.txt')) as f:
            train_puis = f.read().splitlines()
            # print(train_puis)
        with open(cc_path(f'data/canary_val_indices.txt')) as f:
            val_puis = f.read().splitlines()
        with open(cc_path(f'data/canary_test_indices.txt')) as f:
            test_puis = f.read().splitlines()

        # Split data into train-validation-test sets
        train = total.loc[total.pui.isin(train_puis), :]
        val = total.loc[total.pui.isin(val_puis), :]
        test = total.loc[total.pui.isin(test_puis), :]
        
        return train, val, test

In [11]:
def load_small_ds(DATA_PATH = "../data/"):

    small_df = pd.read_csv(DATA_PATH + 'small_dataset.csv')
    all_x = small_df.iloc[:,-1]
    all_y = small_df.iloc[:,:-1]
    
    X_train, X_test, y_train, y_test = train_test_split(all_x, all_y,
    test_size=0.2, shuffle = True, random_state = 8)

    # Use the same function above for the validation set
    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, 
    test_size=0.25, random_state= 8) # 0.25 x 0.8 = 0.2
    
    train_df = y_train
    train_df['input_raw'] = X_train
    test_df = y_test
    test_df['input_raw'] = X_test
    val_df = y_val
    val_df['input_raw'] = X_val
    
    return train_df, test_df, val_df

In [12]:
%%time
# Choose whether you want to load the whole dataset for the experiment
train_df, test_df, valid_df = load_whole_ds()
# train_df, test_df, valid_df = load_small_ds(DATA_PATH)

trds = Dataset.from_pandas(train_df)
vds = Dataset.from_pandas(valid_df)
teds = Dataset.from_pandas(test_df)

full_ds = DatasetDict()

full_ds['train'] = trds
full_ds['validation'] = vds
full_ds['test'] = teds

CPU times: user 7.12 s, sys: 991 ms, total: 8.11 s
Wall time: 8.11 s


In [13]:
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}

len(labels)

52

In [14]:
%%time
#pretrained_model_name = "microsoft/xtremedistil-l6-h256-uncased"  # for xtremedistil transformer
model_version = '../../data_preparation/scibert_scivocab_uncased'
do_lower_case = True
# model = BertModel.from_pretrained(model_version)
# trained_bert_model = torch.load(cc_path(f'models/embedders/finetuned_bert_56k_20e_3lay_best_iter.pt'), map_location=torch.device('cpu'))
tokenizer = BertTokenizer.from_pretrained(model_version, do_lower_case=do_lower_case)

def preprocess_data(examples):
    # take a batch of texts
    text = examples["embedding_text"]
    # encode them
    encoding = tokenizer(text, padding="max_length", truncation=True, max_length=512)
    # add labels
    labels_batch = {k: examples[k] for k in examples.keys() if k in labels}
    # create numpy array of shape (batch_size, num_labels)
    labels_matrix = np.zeros((len(text), len(labels)))
    # fill numpy array
    for idx, label in enumerate(labels):
        labels_matrix[:, idx] = labels_batch[label]

    encoding["labels"] = labels_matrix.tolist()

    return encoding

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RobertaTokenizer'. 
The class this function is called from is 'BertTokenizer'.


CPU times: user 36.4 ms, sys: 0 ns, total: 36.4 ms
Wall time: 44.7 ms


In [15]:
%%time
try:
    encoded_dataset = full_ds.map(preprocess_data, batched=True, remove_columns=full_ds['train'].column_names)
except:
    encoded_dataset = full_ds.map(preprocess_data, batched=True, remove_columns=full_ds['test'].column_names)

Map:   0%|          | 0/36055 [00:00<?, ? examples/s]

Map:   0%|          | 0/11268 [00:00<?, ? examples/s]

Map:   0%|          | 0/9014 [00:00<?, ? examples/s]

CPU times: user 5min 30s, sys: 639 ms, total: 5min 30s
Wall time: 5min 30s


In [16]:
print(encoded_dataset.keys())
print(encoded_dataset['test'][0].keys())

dict_keys(['train', 'validation', 'test'])
dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels'])


In [17]:
from transformers import AutoModelForSequenceClassification
pretrained_model_name = "allenai/scibert_scivocab_cased"
model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name, 
                                                           problem_type="multi_label_classification", 
                                                           num_labels=len(labels),
                                                           id2label=id2label,
                                                           label2id=label2id)

Some weights of the model checkpoint at allenai/scibert_scivocab_cased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were

In [65]:
# for name, param in model.named_parameters():
#     print(name, param.requires_grad)

# Loss functions definition

In [18]:
# Use custom loss

class F1Trainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
    
        S=-1
        E=0
        y = labels
        # Sigmoid hyperparams:
        b = torch.tensor(S)
        c = torch.tensor(E)

        # Calculate the sigmoid
        sig = 1 / (1 + torch.exp(b * (logits + c)))
        tp = torch.sum(sig * y, dim=0)
        fp = torch.sum(sig * (1 - y), dim=0)
        fn = torch.sum((1 - sig) * y, dim=0)

        sigmoid_f1 = 2*tp / (2*tp + fn + fp + 1e-16)
        cost = 1 - sigmoid_f1
        macroCost = torch.mean(cost)

        return (macroCost, outputs) if return_outputs else macroCost

In [19]:
gamma = 2
alpha = 0.75

class FLTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        outputs = model(**inputs)
        
        y = labels
        y_hat = outputs.get("logits")
        
        BCEWithLogitsLoss = torch.nn.BCEWithLogitsLoss(reduction='none')
        bce = BCEWithLogitsLoss(y_hat, y)
        
        pt = torch.exp(-bce)
        
        alpha_factor = y * alpha + (1 - y) * (1 - alpha)
        modulating_factor = torch.pow((1.0 - pt), gamma)

        focal_loss = torch.mean(alpha_factor * modulating_factor * bce)
#         focal_loss = torch.mean(modulating_factor * ce)

        return (focal_loss, outputs) if return_outputs else focal_loss

In [20]:
gamma_neg=4
gamma_pos=1
clip=0.05
eps=1e-8
disable_torch_grad_focal_loss = True

class ASLTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        outputs = model(**inputs)
        y = labels
        y_hat = outputs.get("logits")
        
        # Calculating Probabilities
        x_sigmoid = torch.sigmoid(y_hat)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric Clipping
        if clip is not None and clip > 0:
            xs_neg = (xs_neg + clip).clamp(max=1)

        # Basic CE calculation
        los_pos = y * torch.log(xs_pos.clamp(min=eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=eps))
        loss = los_pos + los_neg

        # Asymmetric Focusing
        if gamma_neg > 0 or gamma_pos > 0:
            if disable_torch_grad_focal_loss:
                torch.set_grad_enabled(False)
            pt0 = xs_pos * y
            pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
            pt = pt0 + pt1
            one_sided_gamma = gamma_pos * y + gamma_neg * (1 - y)
            one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            if disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            loss *= one_sided_w
        
        asl_loss = -loss.mean()
        
        return (asl_loss, outputs) if return_outputs else asl_loss

In [21]:
# Use custom loss for HAMMING LOSS

class HLTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
    
        S=-1
        E=0
        y = labels
        # Sigmoid hyperparams:
        b = torch.tensor(S)
        c = torch.tensor(E)

        # Calculate the sigmoid
        sig = 1 / (1 + torch.exp(b * (logits + c)))

        fp = torch.sum(sig * (1 - y), dim=0)
        fn = torch.sum((1 - sig) * y, dim=0)
        
        hamm_loss = (fp + fn) / torch.sum(labels, dim=0).clamp(min=0.5)  #avoid dividing by 0 if there is no label for the class
        
        macroCost = torch.mean(hamm_loss)

        return (macroCost, outputs) if return_outputs else macroCost

In [22]:
# Use custom loss
def get_classWeights():
    train_labels = train_df[labels]
    tot = sum(train_labels.sum(axis=0))
    weight = 1 - (train_labels.sum(axis=0) / tot)
    
    return torch.tensor(weight)

weights = get_classWeights().to(device)
weights = weights**2

class F1weightTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
    
        S=-1
        E=0
        y = labels
        # Sigmoid hyperparams:
        b = torch.tensor(S)
        c = torch.tensor(E)

        # Calculate the sigmoid
        sig = 1 / (1 + torch.exp(b * (logits + c)))
        tp = torch.sum(sig * y, dim=0)
        fp = torch.sum(sig * (1 - y), dim=0)
        fn = torch.sum((1 - sig) * y, dim=0)

        sigmoid_f1 = 2*tp / (2*tp + fn + fp + 1e-16)
        cost = 1 - sigmoid_f1
        weighted_cost = torch.mul(cost, weights)
        macroCost = torch.mean(weighted_cost)

        return (macroCost, outputs) if return_outputs else macroCost

In [23]:
class F1learnTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")

        S=-2
        E=0
        y = labels
        # Sigmoid hyperparams:
        b = torch.tensor(S)
        c = self.model.F1thr # Learnable parameter!!!!!!!!!!!!!!!! :)
#         c = torch.tensor(E)

        # Calculate the sigmoid
        sig = 1 / (1 + torch.exp(b * (logits + c)))
        tp = torch.sum(sig * y, dim=0)
        fp = torch.sum(sig * (1 - y), dim=0)
        fn = torch.sum((1 - sig) * y, dim=0)

        sigmoid_f1 = 2*tp / (2*tp + fn + fp + 1e-16)
        cost = 1 - sigmoid_f1
        macroCost = torch.mean(cost)

        return (macroCost, outputs) if return_outputs else macroCost

# Trainer

In [24]:
args = TrainingArguments(
    f"scibert-cased",
    evaluation_strategy = 'steps',
#     eval_steps = 100,
    eval_steps = 500,
    save_strategy = "steps",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=10,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model='f1_micro',
    report_to="none",
    fp16=True,
    gradient_checkpointing=True
    #gradient_accumulation_steps=256
)

In [25]:
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from transformers import EvalPrediction
import torch
    
# source: https://jesuslea l.io/2021/04/21/Longformer-multilabel-classification/
def multi_label_metrics(predictions, labels, threshold=0.5):
    # first, apply sigmoid on predictions which are of shape (batch_size, num_labels)
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(predictions))
    # next, use threshold to turn them into integer predictions
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= threshold)] = 1
    # finally, compute metrics
    y_true = labels
    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    f1_macro_average = f1_score(y_true=y_true, y_pred=y_pred, average='macro')
    roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')
    accuracy = accuracy_score(y_true, y_pred)
    # return as dictionary
    metrics = {'f1_micro': f1_micro_average,
               'f1_macro': f1_macro_average,
               'roc_auc': roc_auc,
               'accuracy': accuracy}
    return metrics

def compute_metrics(p: EvalPrediction):
    preds = p.predictions[0] if isinstance(p.predictions, 
            tuple) else p.predictions
    result = multi_label_metrics(
        predictions=preds, 
        labels=p.label_ids)
    return result

In [26]:
# Set the loss function
loss_fn = 'F1learn'
# ---!!!---!!! SELECT !!!---!!!---

trainers_dict = {'BCE': Trainer, 'F1': F1Trainer, 'FL': FLTrainer, 'ASL': ASLTrainer, 'HL': HLTrainer, 'F1weight': F1weightTrainer, 'F1learn': F1learnTrainer} 

if loss_fn == 'F1learn':
#     model.register_parameter(name='weights', param=torch.nn.Parameter(torch.ones(52)))
    model.register_parameter(name='F1thr', param=torch.nn.Parameter(torch.zeros(52)))
    print(model.F1thr)
    
trainer_class = trainers_dict[loss_fn]

Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.], requires_grad=True)


In [27]:
trainer = trainer_class(
    model,
    args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics, 
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3000000)]
)

In [29]:
%%time
trainer.train(resume_from_checkpoint=False)
# trainer.train()



Step,Training Loss,Validation Loss,F1 Micro,F1 Macro,Roc Auc,Accuracy
500,0.8025,0.743536,0.32843,0.281726,0.720116,0.0
1000,0.7235,0.711917,0.408753,0.321105,0.77785,0.0
1500,0.6905,0.677341,0.549061,0.388048,0.837475,0.031683
2000,0.6619,0.655611,0.610596,0.445799,0.853294,0.036297
2500,0.643,0.644278,0.623795,0.448831,0.861813,0.035233
3000,0.6336,0.639261,0.631977,0.459034,0.863695,0.033635
3500,0.6236,0.63238,0.658487,0.473253,0.8624,0.043575
4000,0.6149,0.628281,0.662336,0.488808,0.861989,0.043663
4500,0.614,0.625852,0.658324,0.491851,0.87199,0.0418
5000,0.607,0.624351,0.675905,0.501907,0.869343,0.046947


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [30]:
trainer.evaluate()

{'eval_loss': 0.6085918545722961,
 'eval_f1_micro': 0.7144662336154314,
 'eval_f1_macro': 0.5349275228983671,
 'eval_roc_auc': 0.8774336814126337,
 'eval_accuracy': 0.07570110046148384,
 'eval_runtime': 50.8285,
 'eval_samples_per_second': 221.686,
 'eval_steps_per_second': 6.945,
 'epoch': 10.0}

In [138]:
torch.save(model, cc_path(f'models/baselines/paula_finetuned_bert_56k_10e_shifted_f1.pt'))

# # Predict:
# predictions = trainer.predict(encoded_dataset['test'])
# print(predictions.predictions.shape, predictions.label_ids.shape)

torch.save(model, cc_path(f'models/baselines/paula_finetuned_bert_56k_10e_tka.pt'))
# Inference

In [11]:
model = torch.load(cc_path(f'models/baselines/paula_finetuned_bert_56k_10e_tka.pt'))


In [12]:
from tqdm import tqdm

with torch.no_grad():
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    model.to(device)
    test_dataloader = DataLoader(encoded_dataset["test"], shuffle=False, batch_size=8, collate_fn=data_collator)
    outputs = torch.Tensor()
    labels = torch.Tensor()
    for bi, batch in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        out = model(**batch)
        if len(outputs) == 0:
            outputs = out.logits
            labels = batch['labels']
        else:
            outputs = torch.cat((outputs, out.logits), 0)
            labels = torch.cat((labels, batch['labels']), 0)
            
print('Done')

100%|██████████| 1127/1127 [01:26<00:00, 13.06it/s]

Done





In [140]:
out = pd.DataFrame(outputs.cpu().numpy())
out.to_csv('out.csv')

In [13]:

threshold = 0.5
sigmoid = torch.nn.Sigmoid()
probs = sigmoid(outputs.to('cpu'))
probs
# next, use threshold to turn them into integer predictions
y_pred = np.zeros(probs.shape)
y_pred[np.where(probs >= threshold)] = 1

# finally, compute metrics
y_true = labels.cpu().numpy()
f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
f1_macro_average = f1_score(y_true=y_true, y_pred=y_pred, average='macro')
roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')
accuracy = accuracy_score(y_true, y_pred)
# return as dictionary
metrics = {'f1_micro': f1_micro_average,
           'f1_macro': f1_macro_average,
           'roc_auc': roc_auc,
           'accuracy': accuracy}
metrics

NameError: name 'f1_score' is not defined

In [15]:
def get_metrics(preds, labels, thr=0.5):
    '''
    Create some metrics: precison, recall, F1...
    
    A macro-average will compute the metric independently for each class and then take the average hence 
    treating all classes equally, whereas a micro-average will aggregate the contributions of all classes
    to compute the average metric.
    '''
    
    # Convert the lists to dataframes
    lab_df = pd.DataFrame(labels.cpu().numpy())
    y_pred = np.zeros(preds.shape)
    y_pred[np.where(preds >= threshold)] = 1
    pred_df = pd.DataFrame(y_pred)
    
    # Calculate tp/fp/fn/tn per class:
    tp = (pred_df + lab_df).eq(2).sum()
    fp = (pred_df - lab_df).eq(1).sum()
    fn = (pred_df - lab_df).eq(-1).sum()
    tn = (pred_df + lab_df).eq(0).sum()
    
    # Calculate precision and recall:
    prec = [tp[i] / (tp[i] + fp[i]) if tp[i] + fp[i] != 0 else 0.0 for i in range(len(tp))]
    rec = [tp[i] / (tp[i] + fn[i]) if tp[i] + fn[i] != 0 else 0.0 for i in range(len(tp))]
    
    # Calculate F1 score:
    f1_score = [2 * prec[i] * rec[i] / (prec[i] + rec[i]) if tp[i] > 0 else 0.0 for i in range(len(tp))]
    
    # Weighted F1 score:
    weight = lab_df.sum() / sum(lab_df.sum())
    f1_wght = [weight[i] * 2 * prec[i] * rec[i] / (prec[i] + rec[i]) if tp[i] > 0 else 0.0 for i in range(len(tp))]
    
    # Macro average (average over classes):
    prec_avg = sum(prec) / len(prec)
    rec_avg = sum(rec) / len(rec)
    f1_avg = sum(f1_score) / len(f1_score)
    f1wgt_avg = sum(f1_wght)
    
    # Micro scores (treat all samples together):
    tp_mic = sum(tp)
    tn_mic = sum(tn)
    fp_mic = sum(fp)
    fn_mic = sum(fn)
    prec_mic = tp_mic / (tp_mic+fp_mic)
    rec_mic = tp_mic / (tp_mic+fn_mic)
    f1_mic = (2*prec_mic*rec_mic) / (prec_mic+rec_mic)
    
    return {
        'Precision': prec, 'Recall': rec, 'F1 score': f1_score,
        'weights': weight, 'Weighted F1 score': f1_wght,
        'Macro precision': prec_avg.round(4), 'Macro recall': rec_avg.round(4), 'Macro F1 score': f1_avg.round(4),
        'Weighted F1 score': f1wgt_avg.round(4),
        'CM TP': tp, 'CM FP': fp,'CM FN': fn, 'CM TN': tn,
        'Micro Precision': round(prec_mic, 4), 'Micro Recall': round(rec_mic, 4), 'Micro F1 score': round(f1_mic, 4),
    }



In [21]:
all_metrics = get_metrics(probs, labels)

for metr, val in all_metrics.items():
    if 'Micro' in metr or 'Macro' in metr:
        print(metr, val)
        
        
label_names = ['human', 'mouse', 'rat', 'nonhuman', 'controlled study',
           'animal experiment', 'animal tissue', 'animal model', 'animal cell',
           'major clinical study', 'clinical article', 'case report',
           'multicenter study', 'systematic review', 'meta analysis',
           'observational study', 'pilot study', 'longitudinal study',
           'retrospective study', 'case control study', 'cohort analysis',
           'cross-sectional study', 'diagnostic test accuracy study',
           'double blind procedure', 'crossover procedure',
           'single blind procedure', 'adult', 'aged', 'middle aged', 'child',
           'adolescent', 'young adult', 'very elderly', 'infant', 'school child',
           'newborn', 'preschool child', 'embryo', 'fetus', 'male', 'female',
           'human cell', 'human tissue', 'normal human', 'human experiment',
           'phase 2 clinical trial', 'randomized controlled trial',
           'clinical trial', 'controlled clinical trial', 'phase 3 clinical trial',
           'phase 1 clinical trial', 'phase 4 clinical trial']

Macro precision 0.6003
Macro recall 0.5858
Macro F1 score 0.5774
Micro Precision 0.7165
Micro Recall 0.8314
Micro F1 score 0.7697


In [24]:
dict(zip(label_names, all_metrics['Recall']))

{'human': 0.9602497715504112,
 'mouse': 0.8479467258601554,
 'rat': 0.8104395604395604,
 'nonhuman': 0.8655400440852314,
 'controlled study': 0.9622364802933089,
 'animal experiment': 0.8684603886397608,
 'animal tissue': 0.8333333333333334,
 'animal model': 0.8568464730290456,
 'animal cell': 0.7227722772277227,
 'major clinical study': 0.8945001729505362,
 'clinical article': 0.7278645833333334,
 'case report': 0.8568773234200744,
 'multicenter study': 0.5167464114832536,
 'systematic review': 0.7831325301204819,
 'meta analysis': 0.9056603773584906,
 'observational study': 0.5108108108108108,
 'pilot study': 0.6875,
 'longitudinal study': 0.6118421052631579,
 'retrospective study': 0.8083538083538083,
 'case control study': 0.639751552795031,
 'cohort analysis': 0.6593886462882096,
 'cross-sectional study': 0.7593984962406015,
 'diagnostic test accuracy study': 0.4873096446700508,
 'double blind procedure': 0.7402597402597403,
 'crossover procedure': 0.5238095238095238,
 'single bli

## train metrics

In [144]:
from tqdm import tqdm

with torch.no_grad():
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    model.to(device)
    train_dataloader = DataLoader(encoded_dataset["train"], shuffle=False, batch_size=8, collate_fn=data_collator)
    outputs = torch.Tensor()
    labels = torch.Tensor()
    for bi, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        out = model(**batch)
        if len(outputs) == 0:
            outputs = out.logits
            labels = batch['labels']
        else:
            outputs = torch.cat((outputs, out.logits), 0)
            labels = torch.cat((labels, batch['labels']), 0)
print('Done')

threshold = 0.5
sigmoid = torch.nn.Sigmoid()
probs = sigmoid(outputs.to('cpu'))
probs
# next, use threshold to turn them into integer predictions
y_pred = np.zeros(probs.shape)
y_pred[np.where(probs >= threshold)] = 1

# finally, compute metrics
y_true = labels.cpu().numpy()
f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
f1_macro_average = f1_score(y_true=y_true, y_pred=y_pred, average='macro')
roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')
accuracy = accuracy_score(y_true, y_pred)
# return as dictionary
metrics = {'f1_micro': f1_micro_average,
           'f1_macro': f1_macro_average,
           'roc_auc': roc_auc,
           'accuracy': accuracy}

all_metrics = get_metrics(probs, labels)

for metr, val in all_metrics.items():
    if 'Micro' in metr or 'Macro' in metr:
        print(metr, val)

100%|██████████| 4507/4507 [05:39<00:00, 13.26it/s]


Done
Macro precision 0.7533
Macro recall 0.809
Macro F1 score 0.7584
Micro Precision 0.794
Micro Recall 0.9475
Micro F1 score 0.864
