In [1]:
import numpy as np
import pandas as pd

## Multi-label Fine-tuning with BART on Drug Reviews

### Data Overview

In [2]:
df = pd.read_csv("../Data/preprocessed_osteoporosis_11labels.csv")#load labeled data
df = df[:99]#only labeled 99 records

In [3]:
#replace NaN in label columns with 0
columns = ['limb pain','gastrointestinal', 'dental', 'cardiac', 'dermatological',
       'respiratory', 'weight gain and loss', 'headache', 'menstrual',
       'fatigue', 'body temperature']

for column in columns:
    df[column] = df[column].replace(np.nan, int(0))

df

Unnamed: 0,Age,Condition,Date,Drug,DrugId,EaseofUse,Effectiveness,Reviews,Satisfaction,Sex,...,gastrointestinal,dental,cardiac,dermatological,respiratory,weight gain and loss,headache,menstrual,fatigue,body temperature
0,45-54,Post-Menopausal Osteoporosis Prevention,7/31/2017,lopreeza,167327,5,2,After taking this drug for approx. 21 days I s...,3,Male,...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0
1,45-54,Post-Menopausal Osteoporosis Prevention,12/29/2016,lopreeza,167327,5,5,I have taken this drug for almost 7 years with...,5,Female,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,45-54,Osteoporosis,1/19/2012,oyster shell + d,94390,1,1,I have severe pain in my hand and muscle joint...,1,Female,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,75+,Osteoporosis,2/23/2015,os-cal 500-vit d3,16527,1,3,Food dyes and talc...large pill for a newly re...,1,Female,...,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
4,45-54,Osteoporosis,8/27/2012,os-cal 500-vit d3,16527,5,3,I have taken it for 6 months and it did not in...,1,Male,...,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
94,75+,Osteoporosis,10/24/2016,prolia syringe,154218,1,1,"Two days after receiving shot of prolia, I beg...",1,Female,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
95,45-54,Osteoporosis,10/7/2016,prolia syringe,154218,1,1,I thought I was going crazy...but reading thes...,1,,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
96,55-64,Osteoporosis,10/6/2016,prolia syringe,154218,1,1,First shot I got diarrhea about ten days later...,1,Female,...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
97,65-74,Aromatase Inhibitor Drug-Induced Osteoporosis,9/27/2016,prolia syringe,154218,5,4,I haven taken Prolia every 6 months for 4 year...,4,Female,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [4]:
#extract column "Reviews" and label columns for tuning
df = df[['Reviews','limb pain','gastrointestinal', 'dental', 'cardiac', 'dermatological',
       'respiratory', 'weight gain and loss', 'headache', 'menstrual',
       'fatigue', 'body temperature' ]]

In [5]:
df

Unnamed: 0,Reviews,limb pain,gastrointestinal,dental,cardiac,dermatological,respiratory,weight gain and loss,headache,menstrual,fatigue,body temperature
0,After taking this drug for approx. 21 days I s...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0
1,I have taken this drug for almost 7 years with...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,I have severe pain in my hand and muscle joint...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,Food dyes and talc...large pill for a newly re...,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
4,I have taken it for 6 months and it did not in...,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...
94,"Two days after receiving shot of prolia, I beg...",1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
95,I thought I was going crazy...but reading thes...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
96,First shot I got diarrhea about ten days later...,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
97,I haven taken Prolia every 6 months for 4 year...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [6]:
#train/val/test split
df_train = df.iloc[:60]
df_val = df.iloc[60:80]
df_test = df.iloc[80:]

df_train.to_csv("../Data/train_11labels.csv", index=False)
df_val.to_csv("../Data/val_11labels.csv", index=False)
df_test.to_csv("../Data/test_11labels.csv", index=False)

### Create Data Loader

In [7]:
from datasets import load_dataset

dataset = load_dataset("csv", data_files={"train": "../Data/train_11labels.csv", "val": "../Data/val_11labels.csv", "test": "../Data/test_11labels.csv"})

Using custom data configuration default-7155160b68322eb6


Downloading and preparing dataset csv/default to C:\Users\Opal.Yang\.cache\huggingface\datasets\csv\default-7155160b68322eb6\0.0.0\652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

0 tables [00:00, ? tables/s]

0 tables [00:00, ? tables/s]

0 tables [00:00, ? tables/s]

Dataset csv downloaded and prepared to C:\Users\Opal.Yang\.cache\huggingface\datasets\csv\default-7155160b68322eb6\0.0.0\652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [8]:
#an example of reviews
dataset["train"]

Dataset({
    features: ['Reviews', 'limb pain', 'gastrointestinal', 'dental', 'cardiac', 'dermatological', 'respiratory', 'weight gain and loss', 'headache', 'menstrual', 'fatigue', 'body temperature'],
    num_rows: 60
})

### Load Tokenizer 

Process the reviews and include a padding and truncation strategy to handle any variable sequence length.

In [9]:
from transformers import BartTokenizer

labels = ['limb pain', 'gastrointestinal', 'dental', 'cardiac', 'dermatological',
       'respiratory', 'weight gain and loss', 'headache', 'menstrual',
       'fatigue', 'body temperature']
       
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-mnli")

def preprocess_data(examples):
  # take a batch of texts
  text = examples["Reviews"]
  # encode them
  encoding = tokenizer(text, padding="max_length", truncation=True, max_length=128)
  # add labels
  labels_batch = {k: examples[k] for k in examples.keys() if k in labels}
  # create numpy array of shape (batch_size, num_labels)
  labels_matrix = np.zeros((len(text), len(labels)))
  # fill numpy array
  for idx, label in enumerate(labels):
    labels_matrix[:, idx] = labels_batch[label]

  encoding["labels"] = labels_matrix.tolist()
  
  return encoding

In [10]:
#encode dataset
encoded_dataset = dataset.map(preprocess_data, batched=True, remove_columns=dataset['train'].column_names)

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [11]:
example = encoded_dataset['train'][3]
print(example.keys())

dict_keys(['input_ids', 'attention_mask', 'labels'])


In [12]:
tokenizer.decode(example['input_ids'])

'<s>Food dyes and talc...large pill for a newly removed thyroid throat! So I have been dissolving them in a bit of hot water, add a little fruit juice and drink it! HAVE A BAD COUGH SINCE STARTING THIS CALCIUM. COULD IT BE AN ALLERGIC REACTION?</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>'

In [13]:
example["labels"]

[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]

In [14]:
encoded_dataset.set_format("torch")#set the format of our data as Pytorch tensor

### Load and Train the Model

Discard the head of the BART model and replace with a randomly initialized classification head. Fine-tune this new model head on my sequence classification task, transferring the knowledge of the pre-trained model to it.

In [15]:
from transformers import BartForSequenceClassification

#define the model
model = BartForSequenceClassification.from_pretrained("facebook/bart-large-mnli",problem_type="multi_label_classification", num_labels=11, ignore_mismatched_sizes=True)

Some weights of BartForSequenceClassification were not initialized from the model checkpoint at facebook/bart-large-mnli and are newly initialized because the shapes did not match:
- classification_head.out_proj.weight: found shape torch.Size([3, 1024]) in the checkpoint and torch.Size([11, 1024]) in the model instantiated
- classification_head.out_proj.bias: found shape torch.Size([3]) in the checkpoint and torch.Size([11]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [16]:
#declare training hyperparameters
from transformers import TrainingArguments

training_args = TrainingArguments(output_dir="../Output/Model",
                                learning_rate=1e-4,
                                evaluation_strategy="epoch",
                                per_device_train_batch_size=8,
                                per_device_eval_batch_size=8,
                                num_train_epochs=30,
                                weight_decay=0.01,
                                metric_for_best_model="accuracy",
                                # save_strategy="epoch",
                                # load_best_model_at_end=True

)

In [17]:
#define a compute_metrics function that returns a dictionary with the desired metric values
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from transformers import EvalPrediction
import torch

def multi_label_metrics(predictions, labels, threshold=0.5):
  sigmoid = torch.nn.Sigmoid()
  probs = sigmoid(torch.Tensor(predictions))
  y_pred = np.zeros(probs.shape)
  y_pred[np.where(probs >= threshold)] = 1
  y_true = labels
  f1_micro_average = f1_score(y_true, y_pred, average="weighted") 
  roc_auc = roc_auc_score(y_true, y_pred, average="micro") 
  accuracy = accuracy_score(y_true, y_pred)

  metrics = {"f1": f1_micro_average,
             "roc_auc": roc_auc,
             "accuracy": accuracy}
  return metrics

def compute_metrics(p: EvalPrediction):
  preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
  result = multi_label_metrics(predictions=preds, labels=p.label_ids)
  return result

In [18]:
#trainer
from transformers import Trainer

trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = encoded_dataset["train"],
    eval_dataset = encoded_dataset["val"],
    tokenizer = tokenizer,
    compute_metrics = compute_metrics
)

In [19]:
#fine-tune BART model by calling train()
trainer.train()

***** Running training *****
  Num examples = 60
  Num Epochs = 30
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 240


  0%|          | 0/240 [00:00<?, ?it/s]

***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.37884193658828735, 'eval_f1': 0.1243604879968516, 'eval_roc_auc': 0.5463458110516934, 'eval_accuracy': 0.15, 'eval_runtime': 7.3817, 'eval_samples_per_second': 2.709, 'eval_steps_per_second': 0.406, 'epoch': 1.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.4006187915802002, 'eval_f1': 0.0, 'eval_roc_auc': 0.5, 'eval_accuracy': 0.25, 'eval_runtime': 7.5007, 'eval_samples_per_second': 2.666, 'eval_steps_per_second': 0.4, 'epoch': 2.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.3849164843559265, 'eval_f1': 0.3063241106719367, 'eval_roc_auc': 0.6319073083778965, 'eval_accuracy': 0.15, 'eval_runtime': 7.1282, 'eval_samples_per_second': 2.806, 'eval_steps_per_second': 0.421, 'epoch': 3.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.39471691846847534, 'eval_f1': 0.3306693306693307, 'eval_roc_auc': 0.6319073083778965, 'eval_accuracy': 0.25, 'eval_runtime': 7.3421, 'eval_samples_per_second': 2.724, 'eval_steps_per_second': 0.409, 'epoch': 4.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.42028871178627014, 'eval_f1': 0.3482929491935965, 'eval_roc_auc': 0.6149732620320857, 'eval_accuracy': 0.15, 'eval_runtime': 7.3511, 'eval_samples_per_second': 2.721, 'eval_steps_per_second': 0.408, 'epoch': 5.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.38244184851646423, 'eval_f1': 0.31773618538324416, 'eval_roc_auc': 0.624777183600713, 'eval_accuracy': 0.2, 'eval_runtime': 7.4899, 'eval_samples_per_second': 2.67, 'eval_steps_per_second': 0.401, 'epoch': 6.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.4563431143760681, 'eval_f1': 0.3134380239643398, 'eval_roc_auc': 0.5864527629233511, 'eval_accuracy': 0.1, 'eval_runtime': 7.2864, 'eval_samples_per_second': 2.745, 'eval_steps_per_second': 0.412, 'epoch': 7.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.39093250036239624, 'eval_f1': 0.2798573975044563, 'eval_roc_auc': 0.6024955436720142, 'eval_accuracy': 0.2, 'eval_runtime': 7.4538, 'eval_samples_per_second': 2.683, 'eval_steps_per_second': 0.402, 'epoch': 8.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.4455749988555908, 'eval_f1': 0.2943722943722944, 'eval_roc_auc': 0.5962566844919786, 'eval_accuracy': 0.2, 'eval_runtime': 7.6379, 'eval_samples_per_second': 2.619, 'eval_steps_per_second': 0.393, 'epoch': 9.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.3841358423233032, 'eval_f1': 0.39457735247208936, 'eval_roc_auc': 0.6524064171122995, 'eval_accuracy': 0.35, 'eval_runtime': 7.4722, 'eval_samples_per_second': 2.677, 'eval_steps_per_second': 0.401, 'epoch': 10.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

{'eval_loss': 0.42975449562072754, 'eval_f1': 0.3232323232323232, 'eval_roc_auc': 0.6122994652406417, 'eval_accuracy': 0.3, 'eval_runtime': 7.3538, 'eval_samples_per_second': 2.72, 'eval_steps_per_second': 0.408, 'epoch': 11.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.3916727900505066, 'eval_f1': 0.3913570090040678, 'eval_roc_auc': 0.642602495543672, 'eval_accuracy': 0.35, 'eval_runtime': 7.3133, 'eval_samples_per_second': 2.735, 'eval_steps_per_second': 0.41, 'epoch': 12.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.3898325562477112, 'eval_f1': 0.38396082032445666, 'eval_roc_auc': 0.6506238859180036, 'eval_accuracy': 0.4, 'eval_runtime': 8.1559, 'eval_samples_per_second': 2.452, 'eval_steps_per_second': 0.368, 'epoch': 13.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.42833980917930603, 'eval_f1': 0.3759612936083524, 'eval_roc_auc': 0.6452762923351159, 'eval_accuracy': 0.3, 'eval_runtime': 8.1285, 'eval_samples_per_second': 2.46, 'eval_steps_per_second': 0.369, 'epoch': 14.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.4109654426574707, 'eval_f1': 0.3539832716303304, 'eval_roc_auc': 0.6327985739750445, 'eval_accuracy': 0.3, 'eval_runtime': 8.1947, 'eval_samples_per_second': 2.441, 'eval_steps_per_second': 0.366, 'epoch': 15.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.41140222549438477, 'eval_f1': 0.3641549181656133, 'eval_roc_auc': 0.6381461675579322, 'eval_accuracy': 0.35, 'eval_runtime': 8.2186, 'eval_samples_per_second': 2.433, 'eval_steps_per_second': 0.365, 'epoch': 16.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.418842613697052, 'eval_f1': 0.3641549181656133, 'eval_roc_auc': 0.6354723707664884, 'eval_accuracy': 0.35, 'eval_runtime': 7.9307, 'eval_samples_per_second': 2.522, 'eval_steps_per_second': 0.378, 'epoch': 17.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.4206825792789459, 'eval_f1': 0.3784511784511784, 'eval_roc_auc': 0.6479500891265596, 'eval_accuracy': 0.35, 'eval_runtime': 7.6111, 'eval_samples_per_second': 2.628, 'eval_steps_per_second': 0.394, 'epoch': 18.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.42395925521850586, 'eval_f1': 0.38396082032445666, 'eval_roc_auc': 0.6506238859180036, 'eval_accuracy': 0.4, 'eval_runtime': 7.5438, 'eval_samples_per_second': 2.651, 'eval_steps_per_second': 0.398, 'epoch': 19.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.4247366487979889, 'eval_f1': 0.3784511784511784, 'eval_roc_auc': 0.6479500891265596, 'eval_accuracy': 0.35, 'eval_runtime': 7.4982, 'eval_samples_per_second': 2.667, 'eval_steps_per_second': 0.4, 'epoch': 20.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.4244500994682312, 'eval_f1': 0.3784511784511784, 'eval_roc_auc': 0.6479500891265596, 'eval_accuracy': 0.35, 'eval_runtime': 7.532, 'eval_samples_per_second': 2.655, 'eval_steps_per_second': 0.398, 'epoch': 21.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.42786312103271484, 'eval_f1': 0.3784511784511784, 'eval_roc_auc': 0.6479500891265596, 'eval_accuracy': 0.35, 'eval_runtime': 8.0792, 'eval_samples_per_second': 2.475, 'eval_steps_per_second': 0.371, 'epoch': 22.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.43019142746925354, 'eval_f1': 0.3784511784511784, 'eval_roc_auc': 0.6479500891265596, 'eval_accuracy': 0.35, 'eval_runtime': 8.1294, 'eval_samples_per_second': 2.46, 'eval_steps_per_second': 0.369, 'epoch': 23.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.4331286549568176, 'eval_f1': 0.3784511784511784, 'eval_roc_auc': 0.6479500891265596, 'eval_accuracy': 0.35, 'eval_runtime': 8.0094, 'eval_samples_per_second': 2.497, 'eval_steps_per_second': 0.375, 'epoch': 24.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.43553757667541504, 'eval_f1': 0.3784511784511784, 'eval_roc_auc': 0.6506238859180036, 'eval_accuracy': 0.35, 'eval_runtime': 7.5382, 'eval_samples_per_second': 2.653, 'eval_steps_per_second': 0.398, 'epoch': 25.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.4354124069213867, 'eval_f1': 0.3784511784511784, 'eval_roc_auc': 0.6506238859180036, 'eval_accuracy': 0.35, 'eval_runtime': 8.0856, 'eval_samples_per_second': 2.474, 'eval_steps_per_second': 0.371, 'epoch': 26.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.43470054864883423, 'eval_f1': 0.3784511784511784, 'eval_roc_auc': 0.6506238859180036, 'eval_accuracy': 0.35, 'eval_runtime': 7.4859, 'eval_samples_per_second': 2.672, 'eval_steps_per_second': 0.401, 'epoch': 27.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.43469470739364624, 'eval_f1': 0.3784511784511784, 'eval_roc_auc': 0.6506238859180036, 'eval_accuracy': 0.35, 'eval_runtime': 7.4634, 'eval_samples_per_second': 2.68, 'eval_steps_per_second': 0.402, 'epoch': 28.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.43491458892822266, 'eval_f1': 0.3784511784511784, 'eval_roc_auc': 0.6506238859180036, 'eval_accuracy': 0.35, 'eval_runtime': 7.4189, 'eval_samples_per_second': 2.696, 'eval_steps_per_second': 0.404, 'epoch': 29.0}


***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


Training completed. Do not forget to share your model on huggingface.co/models =)




{'eval_loss': 0.4349861145019531, 'eval_f1': 0.3784511784511784, 'eval_roc_auc': 0.6506238859180036, 'eval_accuracy': 0.35, 'eval_runtime': 7.4696, 'eval_samples_per_second': 2.678, 'eval_steps_per_second': 0.402, 'epoch': 30.0}
{'train_runtime': 2500.1789, 'train_samples_per_second': 0.72, 'train_steps_per_second': 0.096, 'train_loss': 0.09198493162790934, 'epoch': 30.0}


TrainOutput(global_step=240, training_loss=0.09198493162790934, metrics={'train_runtime': 2500.1789, 'train_samples_per_second': 0.72, 'train_steps_per_second': 0.096, 'train_loss': 0.09198493162790934, 'epoch': 30.0})

### Evaluation

In [29]:
#evaluate on val set
trainer.evaluate()

***** Running Evaluation *****
  Num examples = 20
  Batch size = 8


  0%|          | 0/3 [00:00<?, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


{'eval_loss': 0.4349861145019531,
 'eval_f1': 0.3784511784511784,
 'eval_roc_auc': 0.6506238859180036,
 'eval_accuracy': 0.35,
 'eval_runtime': 6.5573,
 'eval_samples_per_second': 3.05,
 'eval_steps_per_second': 0.458,
 'epoch': 30.0}

### Save the model

In [30]:
trainer.save_model("../Output/Model/Best_Model")#save the model

Saving model checkpoint to ../Output/Model/Best_Model
Configuration saved in ../Output/Model/Best_Model\config.json
Model weights saved in ../Output/Model/Best_Model\pytorch_model.bin
tokenizer config file saved in ../Output/Model/Best_Model\tokenizer_config.json
Special tokens file saved in ../Output/Model/Best_Model\special_tokens_map.json


### Review Classification on Original Data

In [2]:
df_test = pd.read_csv("../Data/test_11labels.csv")
df_test = df_test[["Reviews"]]

In [3]:
df_test

Unnamed: 0,Reviews
0,I am a fit 63 year old male who was surprised ...
1,I had my first injection in January. No side ...
2,"At the age of 62, I discovered I had severe os..."
3,"About a week after first injection, developed ..."
4,I want to thank everyone here that rated this ...
5,"5th injection, no side effects,now normal bone..."
6,This is the worse medication my husband has ev...
7,Call FDA to report negative side effects 1 888...
8,This month I have received my 4th injection of...
9,This drug has ruined my mother's daily life. P...


In [4]:
#Load the fine-tuned BART model
from transformers import BartForSequenceClassification

model = BartForSequenceClassification.from_pretrained("../Output/Model/Best_Model", problem_type="multi_label_classification", num_labels=11)

You passed along `num_labels=3` with an incompatible id to label map: {'0': 'LABEL_0', '1': 'LABEL_1', '2': 'LABEL_2', '3': 'LABEL_3', '4': 'LABEL_4', '5': 'LABEL_5', '6': 'LABEL_6', '7': 'LABEL_7', '8': 'LABEL_8', '9': 'LABEL_9', '10': 'LABEL_10'}. The number of labels wil be overwritten to 11.


In [5]:
#Load the tokenizer
from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("../Output/Model/Best_Model")

In [6]:
import torch

limb_pain = []
gastrointestinal = []
dental = []
cardiac = []
dermatological = []
respiratory = []
weight_gain_and_loss = []
headache = []
menstrual = []
fatigue = []
body_temperature = []

for i in range(len(df_test["Reviews"])):
  text = df_test["Reviews"][i]
  encoding = tokenizer(text, return_tensors="pt")
  encoding = {k: v.to(model.device) for k,v in encoding.items()}
  outputs = model(**encoding)

  #apply sigmoid and threshold
  sigmoid = torch.nn.Sigmoid()
  logits = outputs.logits
  probs = sigmoid(logits.squeeze().cpu())

  limb_pain.append(probs[0].detach().numpy())
  gastrointestinal.append(probs[1].detach().numpy())
  dental.append(probs[2].detach().numpy())
  cardiac.append(probs[3].detach().numpy())
  dermatological.append(probs[4].detach().numpy())
  respiratory.append(probs[5].detach().numpy())
  weight_gain_and_loss.append(probs[6].detach().numpy())
  headache.append(probs[7].detach().numpy())
  menstrual.append(probs[8].detach().numpy())
  fatigue.append(probs[9].detach().numpy())
  body_temperature.append(probs[10].detach().numpy())

In [7]:
df_test["limb_pain"] = limb_pain
df_test["gastrointestinal"] = gastrointestinal
df_test["dental"] = dental
df_test["cardiac"] = cardiac
df_test["dermatological"] = dermatological
df_test["respiratory"] = respiratory
df_test["weight_gain_and_loss"] = weight_gain_and_loss
df_test["headache"] = headache
df_test["menstrual"] = menstrual
df_test["fatigue"] = fatigue
df_test["body_temperature"] = body_temperature

In [8]:
df_test

Unnamed: 0,Reviews,limb_pain,gastrointestinal,dental,cardiac,dermatological,respiratory,weight_gain_and_loss,headache,menstrual,fatigue,body_temperature
0,I am a fit 63 year old male who was surprised ...,0.87346125,0.014791279,0.005041794,0.0072805667,0.0031952278,0.005019809,0.005157012,0.004917573,0.0114999665,0.002084134,0.0034286182
1,I had my first injection in January. No side ...,0.78821015,0.017522097,0.047899466,0.010138917,0.005762535,0.0029755388,0.002707698,0.0072094025,0.0066541205,0.0040339604,0.0021250097
2,"At the age of 62, I discovered I had severe os...",0.0052116048,0.040672902,0.0060933647,0.003509785,0.0033169547,0.005264402,0.0038471883,0.004853187,0.032130107,0.20140901,0.0028962877
3,"About a week after first injection, developed ...",0.44608274,0.15374866,0.012810566,0.0047481125,0.011469235,0.0028344982,0.0044184434,0.022913624,0.039060514,0.87535805,0.009981869
4,I want to thank everyone here that rated this ...,0.015231547,0.008141291,0.0028759514,0.003921434,0.0034522698,0.0034405636,0.0028801074,0.0027956045,0.008328738,0.0035875633,0.0042090714
5,"5th injection, no side effects,now normal bone...",0.0060847104,0.005623187,0.0052013774,0.004691545,0.0050250464,0.0056666937,0.005384205,0.009908317,0.0050761653,0.012118012,0.011281631
6,This is the worse medication my husband has ev...,0.9872578,0.01142529,0.015266546,0.014431485,0.01166296,0.01800879,0.005758682,0.007977222,0.0025931848,0.9864871,0.008574757
7,Call FDA to report negative side effects 1 888...,0.01007461,0.016177002,0.005448156,0.0035813998,0.0066903564,0.035818312,0.001419228,0.0077077644,0.0137136495,0.33798015,0.057396676
8,This month I have received my 4th injection of...,0.008940042,0.0029049953,0.0031237362,0.010312338,0.0031979503,0.009327203,0.0033492548,0.0038484603,0.0065683336,0.89161366,0.021210818
9,This drug has ruined my mother's daily life. P...,0.13072574,0.0068879803,0.002961098,0.011723445,0.52697027,0.0059988596,0.0017055868,0.0044649704,0.0020288283,0.7531725,0.0052424264


In [None]:
df_test.to_csv("temp.csv")