<a href="https://colab.research.google.com/github/sreejithvn/zero-shot-classification-for-long-text/blob/main/1_3_Zero_shot_BART.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers

In [None]:
import pandas as pd
import numpy as np

from sklearn.metrics import accuracy_score, f1_score

In [None]:
# Mount Google Drive at /content/gdrive (Colab-specific: relies on the
# google.colab package). The dataset is read from this mount afterwards.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
# JSON-lines dataset stored on the mounted Drive (one record per line).
dataset_path = '/content/gdrive/MyDrive/Colab Notebooks/MSC_Project/Jan2020Frontiers_20_labels.jsonl'
df = pd.read_json(dataset_path, lines=True)

In [None]:
# Class balance: number of samples per label.
df['label'].value_counts()

Physiology                             105
Genetics                                99
Neuroscience                            89
Psychiatry                              86
Neurology                               76
Chemistry                               69
Marine Science                          64
Bioengineering and Biotechnology        56
Endocrinology                           53
Cell and Developmental Biology          47
Cellular and Infection Microbiology     46
Veterinary Science                      45
Medicine                                44
Pediatrics                              43
Physics                                 35
Ecology and Evolution                   34
Public Health                           29
Aging Neuroscience                      29
Earth Science                           26
Cellular Neuroscience                   26
Name: label, dtype: int64

In [None]:
# Candidate labels handed to the zero-shot classifier: every distinct label in the dataset.
candidate_labels = list(df.label.unique())

In [None]:
# (number of samples, number of distinct labels)
len(df), len(candidate_labels)

(1101, 20)

# SPLITTING DATA into TRAIN, VALIDATION and TEST sets

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
# Keep 80% of the samples for training and hold back 20%, stratified on the
# label so that every class keeps its share in both parts.
train_texts, temp_texts, train_labels, temp_labels = train_test_split(
    df.text,
    df.label,
    test_size=0.2,
    random_state=42,
    stratify=df.label,
    shuffle=True,
)

In [None]:
# Halve the held-back samples into a validation and a test set,
# again stratified on the label and with the same seed.
val_texts, test_texts, val_labels, test_labels = train_test_split(
    temp_texts,
    temp_labels,
    test_size=0.5,
    random_state=42,
    stratify=temp_labels,
    shuffle=True,
)

In [None]:
# Give every split a fresh 0..n-1 index so that texts and labels can be
# addressed by position later on.
# Reassigning the result (instead of inplace=True inside a tuple expression)
# keeps the cell free of in-place mutation and of the stray "(None, None)" output.
train_texts, train_labels = train_texts.reset_index(drop=True), train_labels.reset_index(drop=True)
val_texts, val_labels = val_texts.reset_index(drop=True), val_labels.reset_index(drop=True)
test_texts, test_labels = test_texts.reset_index(drop=True), test_labels.reset_index(drop=True)

(None, None)

In [None]:
# Preview of the test texts after the index reset.
test_texts

0       respiratory morbidity and lung function analy...
1        flavor techniques for lfv processes: higgs d...
2       corrigendum: human milk oligosaccharide compo...
3      obsessive–compulsive personality symptoms pred...
4       blood-brain barrier and delivery of protein a...
                             ...                        
106      synergies between division of labor and gut ...
107     efficient and stable photocatalytic hydrogen ...
108     the δ-opioid receptor differentially regulate...
109      thalidomide in the treatment of sweet's synd...
110     investigating gray and white matter structura...
Name: text, Length: 111, dtype: object

In [None]:
# Sizes of the train, validation and test splits.
train_texts.shape, val_texts.shape, test_texts.shape

((880,), (110,), (111,))

In [None]:
# Number of distinct classes present in the train, validation and test splits.
# Every split should contain samples from all classes. The validation split is
# now checked too (train_labels used to be counted twice).
len(train_labels.unique()), len(val_labels.unique()), len(test_labels.unique())

(20, 20, 20)

Each set is a representative (stratified) sample: it preserves the class proportions of the full dataset. The classes themselves are not equally sized.

In [None]:
# Per-class sample counts of the three splits, one column per split.
split_label_counts = pd.DataFrame(
    [labels.value_counts() for labels in (train_labels, val_labels, test_labels)],
    index=['Train', 'Val', 'Test'],
)
split_label_counts.T

Unnamed: 0,Train,Val,Test
Physiology,84,10,11
Genetics,79,10,10
Neuroscience,71,9,9
Psychiatry,69,8,9
Neurology,61,7,8
Chemistry,55,7,7
Marine Science,51,6,7
Bioengineering and Biotechnology,45,5,6
Endocrinology,42,5,6
Cell and Developmental Biology,38,5,4


In [None]:
from transformers import pipeline

# Zero-shot classification pipeline built on the facebook/bart-large-mnli checkpoint.
# device=0 requests the first GPU, so this cell expects a GPU runtime.
classifier_gpu = pipeline("zero-shot-classification", device=0, model="facebook/bart-large-mnli") # to utilize GPU

Downloading config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.52G [00:00<?, ?B/s]

Downloading tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

Downloading vocab.json:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading merges.txt:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

In [None]:
from transformers import AutoTokenizer

# Tokenizer of the same checkpoint as the classifier; split_sequence uses it
# to cut long texts into chunks.
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')

In [None]:
def split_sequence(sequence, chunk_size=256, tok=None):
  """Split a text into consecutive chunks of at most `chunk_size` tokens.

  The text is tokenized, the token list is cut into slices of `chunk_size`
  tokens, and every slice is converted back into a plain string.

  The slices are detokenized with the tokenizer's own
  `convert_tokens_to_string`. The previous `' '.join(...).replace('##', '')`
  was WordPiece (BERT-style) handling; the BART tokenizer is a byte-level BPE
  tokenizer, so that approach left its space markers in the text and put
  spaces between sub-word pieces.

  Args:
    sequence: The text to split.
    chunk_size: Maximum number of tokens per chunk (default 256).
    tok: Tokenizer offering `tokenize` and `convert_tokens_to_string`.
      Defaults to the notebook-level `tokenizer`.

  Returns:
    A list of chunk strings in document order; empty if the text yields no tokens.

  Raises:
    ValueError: If `chunk_size` is smaller than 1.
  """
  if chunk_size < 1:
    raise ValueError('chunk_size must be a positive integer')
  tok = tokenizer if tok is None else tok
  tokens = tok.tokenize(sequence)
  # No strip(): a leading space is part of the chunk's first token.
  return [tok.convert_tokens_to_string(tokens[start:start + chunk_size])
          for start in range(0, len(tokens), chunk_size)]

In [None]:
# Chunk every test text; each element of the resulting Series is a list of chunk strings.
test_data_split = test_texts.apply(split_sequence)

In [None]:
# Preview of the chunked test texts.
pd.DataFrame(test_data_split)

In [None]:
import time

# Baseline: the classifier only considers the first 512 tokens of each test sample

In [None]:
# Baseline: classify each test sample from its unsplit text, so the classifier
# only sees the part that survives truncation (author-reported runtime: about 2 min).
# NOTE(review): the "512" naming assumes inputs are truncated to 512 tokens;
# confirm against the maximum input length of this tokenizer/checkpoint.

test_sequences = list(test_texts)    # flat list: one unsplit text string per test sample

candidate_labels = list(df.label.unique())

results = classifier_gpu(test_sequences, candidate_labels, batch_size=1) # one text per batch; results are tabulated below by 'labels' and 'scores'

# One row per test sample, keeping the 'labels' and 'scores' entries of each result.
scores_df = pd.DataFrame(results, columns=['labels', 'scores'])

In [None]:
# Take the entry at position 0 of each sample's label list as the prediction.
pred_labels_512 = scores_df['labels'].map(lambda ranked_labels: ranked_labels[0])

pred_labels_512

0                                Neurology
1      Cellular and Infection Microbiology
2                               Pediatrics
3                    Ecology and Evolution
4                                Neurology
                      ...                 
106    Cellular and Infection Microbiology
107    Cellular and Infection Microbiology
108                  Cellular Neuroscience
109                             Physiology
110    Cellular and Infection Microbiology
Name: labels, Length: 111, dtype: object

### Accuracy score for Test Data, with default truncation to 512

In [None]:
# Accuracy and macro-averaged F1 of the truncated-input baseline on the test set.

accuracy_test_data_512 = accuracy_score(test_labels, pred_labels_512)
f1_score_test_data_512 = f1_score(test_labels, pred_labels_512, average='macro')
print(f'Test data first 512 -> Accuracy: {accuracy_test_data_512*100:.2f}, F1_score: {f1_score_test_data_512*100:.2f}')

Test data first 512 -> Accuracy: 9.01, F1_score: 7.76


# Classifying the entire text of each test sample (chunk by chunk)

In [None]:
# Classify the entire text of every test sample chunk by chunk (the chunks come
# from `test_data_split` and are sent to the pipeline in batches of 8), then
# derive two document-level predictions from the per-chunk results:
#   * highest probability sum: label whose scores, summed over all chunks, are largest
#   * most first-position count: label ranked first by the largest number of chunks

candidate_labels = list(df.label.unique())

pred_labels_count_list = []
pred_labels_prob_list = []

start = time.perf_counter()

for ix, long_text in enumerate(test_data_split):

  # Positional lookup: `ix` is an enumeration counter, not an index label.
  print(f'Sample {ix}: True Label: {test_labels.iloc[ix]}')

  result = classifier_gpu(long_text, candidate_labels, batch_size=8)
  # NOTE(review): some transformers releases may return a bare dict instead of a
  # one-element list for a single chunk; normalise so the code below always
  # iterates over per-chunk results.
  if isinstance(result, dict):
    result = [result]

  # Accumulate, per label, the scores received from every chunk of this sample.
  prob_score_dict = dict.fromkeys(candidate_labels, 0)
  for chunk_result in result:
    for label, score in zip(chunk_result['labels'], chunk_result['scores']):
      prob_score_dict[label] += score

  max_prob_label = max(prob_score_dict, key=prob_score_dict.get)

  print('Predicted label based on highest probabilty score:', max_prob_label)

  # Top-ranked label (position 0) of every chunk; the label that takes first
  # place most often becomes the prediction for the whole sample.
  first_place_labels = pd.Series([chunk_result['labels'][0] for chunk_result in result])
  most_count_label = first_place_labels.value_counts().index[0]

  print('Predicted label based on most first occurence count:', most_count_label)

  pred_labels_count_list.append(most_count_label)
  pred_labels_prob_list.append(max_prob_label)

stop = time.perf_counter()
# Total wall-clock time of the loop in seconds (kept for inspection, not printed).
runtime = stop-start

Sample 0: True Label: Pediatrics
Predicted label based on highest probabilty score: Pediatrics
Predicted label based on most first occurence count: Pediatrics
Sample 1: True Label: Physics
Predicted label based on highest probabilty score: Physics
Predicted label based on most first occurence count: Physics
Sample 2: True Label: Pediatrics
Predicted label based on highest probabilty score: Pediatrics
Predicted label based on most first occurence count: Pediatrics
Sample 3: True Label: Psychiatry
Predicted label based on highest probabilty score: Psychiatry
Predicted label based on most first occurence count: Psychiatry
Sample 4: True Label: Aging Neuroscience
Predicted label based on highest probabilty score: Bioengineering and Biotechnology
Predicted label based on most first occurence count: Bioengineering and Biotechnology
Sample 5: True Label: Pediatrics
Predicted label based on highest probabilty score: Genetics
Predicted label based on most first occurence count: Genetics
Sample 



Predicted label based on highest probabilty score: Psychiatry
Predicted label based on most first occurence count: Pediatrics
Sample 10: True Label: Genetics
Predicted label based on highest probabilty score: Genetics
Predicted label based on most first occurence count: Genetics
Sample 11: True Label: Marine Science
Predicted label based on highest probabilty score: Chemistry
Predicted label based on most first occurence count: Chemistry
Sample 12: True Label: Neuroscience
Predicted label based on highest probabilty score: Neuroscience
Predicted label based on most first occurence count: Neuroscience
Sample 13: True Label: Physiology
Predicted label based on highest probabilty score: Physics
Predicted label based on most first occurence count: Public Health
Sample 14: True Label: Public Health
Predicted label based on highest probabilty score: Public Health
Predicted label based on most first occurence count: Public Health
Sample 15: True Label: Marine Science
Predicted label based on 

# Accuracy and F1 Score results for TEST Data

In [None]:
def report_test_metrics(description, predicted_labels):
  """Print and return accuracy and macro-F1 of predictions against `test_labels`.

  Args:
    description: Text printed in front of the scores.
    predicted_labels: Predicted label for every test sample, in test-set order.

  Returns:
    Tuple `(accuracy, macro_f1)` as fractions (not percentages).
  """
  accuracy = accuracy_score(test_labels, predicted_labels)
  macro_f1 = f1_score(test_labels, predicted_labels, average='macro')
  print(f'{description} -> Accuracy: {accuracy*100:.2f}, F1_score: {macro_f1*100:.2f}')
  return accuracy, macro_f1


# Baseline: predictions from the truncated beginning of each text (pred_labels_512)
accuracy_test_data_512, f1_score_test_data_512 = report_test_metrics(
    'Test data only first 512', pred_labels_512)

# Entire long text, label ranked first by most chunks (pred_labels_count_list)
accuracy_count, f1_score_count = report_test_metrics(
    'Test data long text (most first position count)', pred_labels_count_list)

# Entire long text, label with the highest summed score (pred_labels_prob_list)
accuracy_prob, f1_score_prob = report_test_metrics(
    'Test data long text (highest probability sum)', pred_labels_prob_list)


Test data only first 512 -> Accuracy: 9.01, F1_score: 7.76
Test data long text (most first position count) -> Accuracy: 38.74, F1_score: 31.55
Test data long text (highest probability sum) -> Accuracy: 46.85, F1_score: 39.42


In [None]:
# Summary table: accuracy and macro-F1 of the three prediction strategies,
# as percentages rounded to two decimals.
metrics_by_strategy = {
    'Only first 512': (accuracy_test_data_512, f1_score_test_data_512),
    'Long text (most first position count)': (accuracy_count, f1_score_count),
    'Long text (highest probability sum)': (accuracy_prob, f1_score_prob),
}

BART_zero_shot_metrics_table_df = pd.DataFrame(
    data=[[np.round(accuracy * 100, 2), np.round(f1 * 100, 2)]
          for accuracy, f1 in metrics_by_strategy.values()],
    columns=['Accuracy', 'F1_score'],
    index=list(metrics_by_strategy),
)

BART_zero_shot_metrics_table_df

Unnamed: 0,Accuracy,F1_score
Only first 512,9.01,7.76
Long text (most first position count),38.74,31.55
Long text (highest probability sum),46.85,39.42
