# Citation embeddings
Exploring BERT and sciBERT to predict data citations in classification and embedding tasks

Findings: 
- sciBERT determines sentences from scientific articles to be more similar to each other than BERT does
- BERT does a better job of differentiating and clustering similar citances and indicator terms

References:

- [Multi Class Text Classification With Deep Learning Using BERT](https://github.com/susanli2016/NLP-with-Python/blob/master/Text_Classification_With_BERT.ipynb)
- [Domain-Specific BERT Models](https://mccormickml.com/2020/06/22/domain-specific-bert-tutorial/)
- [sciBERT demo](https://colab.research.google.com/drive/19loLGUDjxGKy4ulZJ1m3hALq2ozNyEGe)

In [1]:
import random
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
from scipy.spatial.distance import cosine

import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

from transformers import AutoTokenizer, AutoModel
from transformers import BertTokenizer, BertModel
from transformers import BertForSequenceClassification
from transformers import AdamW, get_linear_schedule_with_warmup

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

Set the directory (data, models on Turbo)

In [2]:
DATA_DIR = '/nfs/turbo/hrg/coleridge/'

## Compare sciBERT to BERT
Token overlap between the vocabs ~42%

In [5]:
bert_model = BertModel.from_pretrained('bert-base-uncased',
                                  output_hidden_states = True) 

bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# bert_model.eval()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
scibert_model = BertModel.from_pretrained("allenai/scibert_scivocab_uncased",
                                  output_hidden_states=True)

scibert_tokenizer = BertTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")

# scibert_model.eval()

Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Example biomedical text and words

In [None]:
text = "hydrogels are hydrophilic polymer networks which may absorb from 10–20% (an arbitrary lower limit) up to thousands of times their dry weight in water."
word = 'hydrogels'
words = ['polymerization', 
         '2,2-azo-isobutyronitrile',
         'multifunctional crosslinkers'
         ]

Split the sentence into tokens with both BERT and SciBERT

In [None]:
# Tokenize the same sentence with both vocabularies.
bert_tokens = bert_tokenizer.tokenize(text)
scibert_tokens = scibert_tokenizer.tokenize(text)

# Pad whichever list is shorter with empty strings so both have the same
# length. The side-by-side printout zips the two lists, and `zip` stops at the
# shorter one, so padding only one side could silently drop trailing tokens.
n_rows = max(len(bert_tokens), len(scibert_tokens))
bert_tokens += [""] * (n_rows - len(bert_tokens))
scibert_tokens += [""] * (n_rows - len(scibert_tokens))

In [9]:
# Two left-aligned, 12-character-wide columns: BERT tokens next to SciBERT tokens.
row_template = '{:<12} {:<12}'

print(row_template.format("BERT", "SciBERT"))
print(row_template.format("----", "-------"))

for bert_tok, scibert_tok in zip(bert_tokens, scibert_tokens):
    print(row_template.format(bert_tok, scibert_tok))

BERT         SciBERT     
----         -------     
hydro        hydrogels   
##gel        are         
##s          hydrophilic 
are          polymer     
hydro        networks    
##phi        which       
##lic        may         
polymer      absorb      
networks     from        
which        10          
may          –           
absorb       20          
from         %           
10           (           
–            an          
20           arbitrary   
%            lower       
(            limit       
an           )           
arbitrary    up          
lower        to          
limit        thousands   
)            of          
up           times       
to           their       
thousands    dry         
of           weight      
times        in          
their        water       
dry          .           
weight                   
in                       
water                    
.                        


## Semantic similarity task

In [11]:
def get_word_indeces(tokenizer, text, word):
    '''
    Locate the token position(s) that `word` occupies inside the encoded `text`.

    `word` may span several words (e.g., "cell biology") and may be split into
    several sub-word tokens. Rather than aligning sub-words by hand, every
    verbatim occurrence of `word` in `text` is swapped for the same number of
    `[MASK]` tokens, the masked text is encoded, and the positions of the mask
    ids are returned.

    Returns a numpy array of positions into the id sequence produced by
    `tokenizer.encode`; the array is empty when no mask id is found.
    '''
    # Number of tokens/sub-words the tokenizer splits `word` into.
    n_pieces = len(tokenizer.tokenize(word))

    # One `[MASK]` per piece, space separated, substituted for `word`.
    placeholder = ' '.join('[MASK]' for _ in range(n_pieces))

    # `encode` tokenizes the masked text and maps the tokens to their ids.
    ids_with_masks = tokenizer.encode(text.replace(word, placeholder))

    # Positions whose id equals the tokenizer's mask id.
    is_mask = np.array(ids_with_masks) == tokenizer.mask_token_id
    return np.where(is_mask)[0]

In [12]:
def get_embedding(b_model, b_tokenizer, text, word=''):
    '''
    Uses the provided model and tokenizer to produce an embedding for the
    provided `text`, and a "contextualized" embedding for `word`, if provided.

    The model must return its per-layer hidden states as the third output item
    (the models in this notebook are loaded with `output_hidden_states=True`).

    Parameters:
        b_model:     model called as `b_model(input_ids)`; put in eval mode here.
        b_tokenizer: tokenizer matching `b_model`.
        text:        the sentence to embed.
        word:        optional word or phrase occurring verbatim in `text`.

    Returns:
        The sentence embedding as a numpy array when `word` is '', otherwise
        the tuple `(sentence_embedding, word_embedding)`.
    '''
    has_word = word != ''

    # If a word is provided, figure out which tokens correspond to it.
    if has_word:
        word_indeces = get_word_indeces(b_tokenizer, text, word)

    # Encode the text, adding the (required!) special tokens, and converting to
    # PyTorch tensors.
    encoded_dict = b_tokenizer.encode_plus(
                        text,                      # Sentence to encode.
                        add_special_tokens = True, # Add '[CLS]' and '[SEP]'
                        return_tensors = 'pt',     # Return pytorch tensors.
                )

    input_ids = encoded_dict['input_ids']

    # Inference only: eval mode, and a single forward pass with gradient
    # tracking disabled. (A second forward pass outside `no_grad` used to run
    # here; its result was never used, so it only cost time and memory.)
    b_model.eval()

    with torch.no_grad():

        outputs = b_model(input_ids)

        # Evaluating the model will return a different number of objects based on
        # how it's configured in the `from_pretrained` call earlier. In this case,
        # because we set `output_hidden_states = True`, the third item will be the
        # hidden states from all layers. See the documentation for more details:
        # https://huggingface.co/transformers/model_doc/bert.html#bertmodel
        hidden_states = outputs[2]

    # Select the embeddings from the second to last layer; `[0]` picks the
    # single sentence in the batch, leaving one vector per token.
    token_vecs = hidden_states[-2][0]

    # Sentence embedding = average of all token vectors, as a numpy array.
    sentence_embedding = torch.mean(token_vecs, dim=0).detach().numpy()

    if not has_word:
        return sentence_embedding

    # Word embedding = average of the vectors of the tokens making up `word`.
    word_embedding = torch.mean(token_vecs[word_indeces], dim=0).detach().numpy()

    return (sentence_embedding, word_embedding)

### Test functions

Get the embedding for the sentence, as well as an embedding for 'hydrogels'

In [13]:
(sen_emb, word_emb) = get_embedding(scibert_model, scibert_tokenizer, text, word)

print('Embedding sizes:')
print(sen_emb.shape)
print(word_emb.shape)

Embedding sizes:
(768,)
(768,)


Cosine similarity of the two embeddings

In [14]:
sim = 1 - cosine(sen_emb, word_emb)

## Sentence comparison examples

### Citance and non-citance similarity (from same paper)
- `text_query` and `text_A` are true citances but refer to different datasets
- `text_B` does not refer to a dataset

In [15]:
# Three sentences; query is more similar to A than B.
text_query = "In addition to HadISST, we have analyzed four other SST datasets: the NOAA Extended Reconstructed Sea Surface Temperature version 3b (ERSST.v3b; Smith et al. 2008) , the NOAA Optimum Interpolation Sea Surface Temperature version 2 (OISSTv2; Reynolds et al. 2002) , and the Japan Meteorological Agency Centennial in situ ObservationBased Estimates (COBE; Ishii et al. 2005) "
text_A = "The SST data in Fig. 15a is sourced from the Hadley Centre Sea Ice and Sea Surface Temperature dataset (HadISST; Rayner et al. 2003) , which incorporates satellite data, float, and ship measurements."
text_B = "The sudden increase of the relaxation temperature by 0.58C at year 0 increases the surface heat flux."

# SciBERT: embed the query and both candidate sentences.
emb_query, emb_A, emb_B = (
    get_embedding(scibert_model, scibert_tokenizer, sentence)
    for sentence in (text_query, text_A, text_B)
)

# Cosine similarity is 1 minus scipy's cosine distance.
sim_query_A = 1 - cosine(emb_query, emb_A)
sim_query_B = 1 - cosine(emb_query, emb_B)

print("'query' should be more similar to 'A' than to 'B'...\n")

print('SciBERT:')
print(f'  sim(query, A): {sim_query_A:.2}')
print(f'  sim(query, B): {sim_query_B:.2}')

# BERT: same sentences, same procedure.
emb_query, emb_A, emb_B = (
    get_embedding(bert_model, bert_tokenizer, sentence)
    for sentence in (text_query, text_A, text_B)
)

sim_query_A = 1 - cosine(emb_query, emb_A)
sim_query_B = 1 - cosine(emb_query, emb_B)

print('')
print('BERT:')
print(f'  sim(query, A): {sim_query_A:.2}')
print(f'  sim(query, B): {sim_query_B:.2}')

'query' should be more similar to 'A' than to 'B'...

SciBERT:
  sim(query, A): 0.96
  sim(query, B): 0.89

BERT:
  sim(query, A): 0.92
  sim(query, B): 0.73


### Non-bio/CS example
- `text_query` and `text_A` come from the same article and refer to the same dataset
- `text_B` comes from a different article and refers to a different dataset

In [46]:
# Three sentences; query is more similar to A than B.
text_query = "For example, NASS collects data from farm and ranch operations in many surveys and the Census of Agriculture (COA)."
text_A = "Also, the farm-level data collected by the Census of Agriculture is necessary to track farm transitions, such as new farmer entry (Gale, 2002) , beginning development (Ahearn & Newton, 2009) , and switching between marketing channels."
text_B = "We evaluate whether gender differences in the likelihood of obtaining a tenure track job, promotion to tenure, and promotion to full professor explain these facts using the 1973-2001 Survey of Doctorate Recipients."

# SciBERT: embed the query and both candidate sentences.
emb_query, emb_A, emb_B = (
    get_embedding(scibert_model, scibert_tokenizer, sentence)
    for sentence in (text_query, text_A, text_B)
)

# Cosine similarity is 1 minus scipy's cosine distance.
sim_query_A = 1 - cosine(emb_query, emb_A)
sim_query_B = 1 - cosine(emb_query, emb_B)

print("'query' should be more similar to 'A' than to 'B'...\n")

print('SciBERT:')
print(f'  sim(query, A): {sim_query_A:.2}')
print(f'  sim(query, B): {sim_query_B:.2}')

# BERT: same sentences, same procedure.
emb_query, emb_A, emb_B = (
    get_embedding(bert_model, bert_tokenizer, sentence)
    for sentence in (text_query, text_A, text_B)
)

sim_query_A = 1 - cosine(emb_query, emb_A)
sim_query_B = 1 - cosine(emb_query, emb_B)

print('')
print('BERT:')
print(f'  sim(query, A): {sim_query_A:.2}')
print(f'  sim(query, B): {sim_query_B:.2}')

'query' should be more similar to 'A' than to 'B'...

SciBERT:
  sim(query, A): 0.96
  sim(query, B): 0.93

BERT:
  sim(query, A): 0.89
  sim(query, B): 0.75


### Biomedical example
Sentences containing data citations; we would expect to see that query text and text A are more similar
- `text_query` and `text_A` come from the same article and refer to the same dataset
- `text_B` comes from a different article and refers to a different dataset

In [42]:
text_query = "The primary goal of ADNI has been to test whether serial MRI, positron emission tomography (PET), other biological markers, and clinical and neuropsychological assessment can be combined to measure the progression of MCI and early AD."
text_A = "The primary goal of ADNI has been to test whether serial MRI, PET, other biological markers, and clinical and neuropsychological assessment can be combined to measure the progression of mild cognitive impairment (MCI) and early Alzheimers disease (AD)." 
text_B = "In this study, we quantified individual species of plasma sphingomyelin and dihydrosphingomyelin in 992 individuals, aged 55 and older, enrolled in the Baltimore Longitudinal Study of Aging (BLSA)."

# SciBERT: embed the query and both candidate sentences.
emb_query, emb_A, emb_B = (
    get_embedding(scibert_model, scibert_tokenizer, sentence)
    for sentence in (text_query, text_A, text_B)
)

# Cosine similarity is 1 minus scipy's cosine distance.
sim_query_A = 1 - cosine(emb_query, emb_A)
sim_query_B = 1 - cosine(emb_query, emb_B)

print("'query' should be more similar to 'A' than to 'B'...\n")

print('SciBERT:')
print(f'  sim(query, A): {sim_query_A:.2}')
print(f'  sim(query, B): {sim_query_B:.2}')

# BERT: same sentences, same procedure.
emb_query, emb_A, emb_B = (
    get_embedding(bert_model, bert_tokenizer, sentence)
    for sentence in (text_query, text_A, text_B)
)

sim_query_A = 1 - cosine(emb_query, emb_A)
sim_query_B = 1 - cosine(emb_query, emb_B)

print('')
print('BERT:')
print(f'  sim(query, A): {sim_query_A:.2}')
print(f'  sim(query, B): {sim_query_B:.2}')

'query' should be more similar to 'A' than to 'B'...

SciBERT:
  sim(query, A): 0.99
  sim(query, B): 0.93

BERT:
  sim(query, A): 0.97
  sim(query, B): 0.84


## Word comparison examples

In [51]:
text = "The survey of agriculture is a large, long running, publically available data source."

print('"' + text + '"\n')

# ======== SciBERT ========

# Contextualized embeddings for "survey", "data", and "agriculture".
# The sentence embedding is returned too, but only the word vectors are compared.
emb_sen, emb_s = get_embedding(scibert_model, scibert_tokenizer, text, word="survey")
emb_sen, emb_d = get_embedding(scibert_model, scibert_tokenizer, text, word="data")
emb_sen, emb_a = get_embedding(scibert_model, scibert_tokenizer, text, word="agriculture")

# Cosine similarity between "survey" and each of the other two words.
print('SciBERT:')
print(f'  sim(survey, agriculture): {1 - cosine(emb_s, emb_a):.2}')
print(f'  sim(survey, data): {1 - cosine(emb_s, emb_d):.2}')

print('')

# ======== BERT ========

# Same three words, embedded with BERT.
emb_sen, emb_s = get_embedding(bert_model, bert_tokenizer, text, word="survey")
emb_sen, emb_d = get_embedding(bert_model, bert_tokenizer, text, word="data")
emb_sen, emb_a = get_embedding(bert_model, bert_tokenizer, text, word="agriculture")

# Cosine similarity between "survey" and each of the other two words.
print('BERT:')
print(f'  sim(survey, agriculture): {1 - cosine(emb_s, emb_a):.2}')
print(f'  sim(survey, data): {1 - cosine(emb_s, emb_d):.2}')

"The survey of agriculture is a large, long running, publically available data source."

SciBERT:
  sim(survey, agriculture): 0.79
  sim(survey, data): 0.76

BERT:
  sim(survey, agriculture): 0.44
  sim(survey, data): 0.7


## Classification task

Let's first see if a transformer model can predict which dataset a paper cites based on its title alone
- This classification task assumes a 1:1 relationship between papers and datasets so for this trial, let's drop duplicate papers

In [3]:
# Load the training table and keep one row per publication title, since the
# classification task below assumes one dataset per paper.
train_csv = DATA_DIR + 'train.csv'
df = pd.read_csv(train_csv).drop_duplicates(subset=['pub_title'])
df.head()

Unnamed: 0,Id,pub_title,dataset_title,dataset_label,cleaned_label
0,d0fa7568-7d8e-4db9-870f-f9c6f668c17b,The Impact of Dual Enrollment on College Degre...,National Education Longitudinal Study,National Education Longitudinal Study,national education longitudinal study
1,2f26f645-3dec-485d-b68d-f013c9e05e60,Educational Attainment of High School Dropouts...,National Education Longitudinal Study,National Education Longitudinal Study,national education longitudinal study
2,c5d5cd2c-59de-4f29-bbb1-6a88c7b52f29,Differences in Outcomes for Female and Male St...,National Education Longitudinal Study,National Education Longitudinal Study,national education longitudinal study
3,5c9a3bc9-41ba-4574-ad71-e25c1442c8af,Stepping Stone and Option Value in a Model of ...,National Education Longitudinal Study,National Education Longitudinal Study,national education longitudinal study
4,c754dec7-c5a3-4337-9892-c02158475064,"Parental Effort, School Resources, and Student...",National Education Longitudinal Study,National Education Longitudinal Study,national education longitudinal study


The classes are highly imbalanced

In [4]:
df.dataset_title.nunique()

44

In [5]:
df['dataset_title'].value_counts()

Alzheimer's Disease Neuroimaging Initiative (ADNI)                                             3797
Baltimore Longitudinal Study of Aging (BLSA)                                                   1157
Trends in International Mathematics and Science Study                                          1113
Early Childhood Longitudinal Study                                                              949
SARS-CoV-2 genome sequence                                                                      735
Agricultural Resource Management Survey                                                         641
Census of Agriculture                                                                           592
Rural-Urban Continuum Codes                                                                     486
Survey of Earned Doctorates                                                                     412
NOAA Tide Gauge                                                                                 398


## Encode labels

In [43]:
# Assign each distinct dataset title an integer class id, numbered in order of
# first appearance in the dataframe.
possible_labels = df.dataset_title.unique()

label_dict = {title: class_id for class_id, title in enumerate(possible_labels)}

label_dict

{'National Education Longitudinal Study': 0,
 'NOAA Tide Gauge': 1,
 'Sea, Lake, and Overland Surges from Hurricanes': 2,
 'Coastal Change Analysis Program': 3,
 'Aging Integrated Database (AGID)': 4,
 "Alzheimer's Disease Neuroimaging Initiative (ADNI)": 5,
 'Baltimore Longitudinal Study of Aging (BLSA)': 6,
 'Agricultural Resource Management Survey': 7,
 'Beginning Postsecondary Student': 8,
 "The National Institute on Aging Genetics of Alzheimer's Disease Data Storage Site (NIAGADS)": 9,
 'Common Core of Data': 10,
 'Survey of Industrial Research and Development': 11,
 'Baccalaureate and Beyond': 12,
 'International Best Track Archive for Climate Stewardship': 13,
 'National Teacher and Principal Survey': 14,
 'Higher Education Research and Development Survey': 15,
 'Survey of Earned Doctorates': 16,
 'School Survey on Crime and Safety': 17,
 'World Ocean Database': 18,
 'Program for the International Assessment of Adult Competencies': 19,
 'Early Childhood Longitudinal Study': 20,


In [44]:
# Look up the integer class id for every row's dataset title. `map` is a plain
# dictionary lookup and yields an integer column directly (every title is a key
# of `label_dict`, which was built from this same column), whereas `replace`
# depends on pandas downcasting the replaced object column back to integers.
df['label'] = df.dataset_title.map(label_dict)

Unnamed: 0,Id,pub_title,dataset_title,dataset_label,cleaned_label,label
19654,f89dd9fa-07af-4384-aa0c-0d14602c0cea,Artificial Intelligence of COVID-19 Imaging: A...,RSNA International COVID-19 Open Radiology Dat...,RSNA International COVID-19 Open Radiology Dat...,rsna international covid 19 open radiology dat...,42
19656,b3498176-8832-4033-aea6-b5ea85ea04c4,RSNA International Trends: A Global Perspectiv...,RSNA International COVID-19 Open Radiology Dat...,RSNA International COVID Open Radiology Database,rsna international covid open radiology database,42
19657,f77eb51f-c3ac-420b-9586-cb187849c321,MCCS: a novel recognition pattern-based method...,CAS COVID-19 antiviral candidate compounds dat...,CAS COVID-19 antiviral candidate compounds dat...,cas covid 19 antiviral candidate compounds dat...,43
19658,ab59bcdd-7b7c-4107-93f5-0ccaf749236c,Quantitative Structure–Activity Relationship M...,CAS COVID-19 antiviral candidate compounds dat...,CAS COVID-19 antiviral candidate compounds dat...,cas covid 19 antiviral candidate compounds dat...,43
19659,fd23e7e0-a5d2-4f98-992d-9209c85153bb,A ligand-based computational drug repurposing ...,CAS COVID-19 antiviral candidate compounds dat...,CAS COVID-19 antiviral candidate compounds dat...,cas covid 19 antiviral candidate compounds dat...,43


## Train and validation split
Labels are imbalanced so stratify the split

In [48]:
# Split on the dataframe's index labels (not positions) so the result can be
# fed straight to `df.loc`; stratify on the class id because the classes are
# imbalanced.
row_ids = df.index.values
targets = df.label.values

X_train, X_val, y_train, y_val = train_test_split(
    row_ids,
    targets,
    test_size=0.15,
    random_state=42,
    stratify=targets,
)

# Tag every row with the split it landed in.
df['data_type'] = 'not_set'
df.loc[X_train, 'data_type'] = 'train'
df.loc[X_val, 'data_type'] = 'val'

# Row counts per (dataset, class id, split).
df.groupby(['dataset_title', 'label', 'data_type']).count()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Id,pub_title,dataset_label,cleaned_label
dataset_title,label,data_type,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Advanced National Seismic System (ANSS) Comprehensive Catalog (ComCat),31,train,20,20,20,20
Advanced National Seismic System (ANSS) Comprehensive Catalog (ComCat),31,val,3,3,3,3
Aging Integrated Database (AGID),4,train,2,2,2,2
Agricultural Resource Management Survey,7,train,545,545,545,545
Agricultural Resource Management Survey,7,val,96,96,96,96
...,...,...,...,...,...,...
The National Institute on Aging Genetics of Alzheimer's Disease Data Storage Site (NIAGADS),9,val,2,2,2,2
Trends in International Mathematics and Science Study,22,train,946,946,946,946
Trends in International Mathematics and Science Study,22,val,167,167,167,167
World Ocean Database,18,train,261,261,261,261


## Tokenization and data encoding
We will use a pre-trained BERT model configuration to encode our data

In [49]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', 
                                          do_lower_case=True)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…




In [52]:
# Encode the training titles as fixed-length id/attention-mask tensors.
# `padding='max_length'` replaces the deprecated `pad_to_max_length=True`, and
# `truncation=True` clips any title longer than `max_length`; without it a
# single over-long title yields rows of unequal length that cannot be stacked
# into one tensor.
encoded_data_train = tokenizer.batch_encode_plus(
    df[df.data_type=='train'].pub_title.values.tolist(),
    add_special_tokens=True,
    return_attention_mask=True,
    padding='max_length',
    max_length=256,
    truncation=True,
    return_tensors='pt'
)

In [53]:
# Encode the validation titles as fixed-length id/attention-mask tensors.
# `padding='max_length'` replaces the deprecated `pad_to_max_length=True`, and
# `truncation=True` clips any title longer than `max_length`; without it a
# single over-long title yields rows of unequal length that cannot be stacked
# into one tensor.
encoded_data_val = tokenizer.batch_encode_plus(
    df[df.data_type=='val'].pub_title.values.tolist(),
    add_special_tokens=True,
    return_attention_mask=True,
    padding='max_length',
    max_length=256,
    truncation=True,
    return_tensors='pt'
)

In [54]:
input_ids_train = encoded_data_train['input_ids']
attention_masks_train = encoded_data_train['attention_mask']
labels_train = torch.tensor(df[df.data_type=='train'].label.values)

In [55]:
input_ids_val = encoded_data_val['input_ids']
attention_masks_val = encoded_data_val['attention_mask']
labels_val = torch.tensor(df[df.data_type=='val'].label.values)

In [56]:
dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train)
dataset_val = TensorDataset(input_ids_val, attention_masks_val, labels_val)

In [None]:
len(dataset_train), len(dataset_val)

## BERT Pre-trained model
Treat each `pub_title` as a unique sequence, so each sequence is classified into one of the 44 `dataset_title` classes

- bert-base-uncased is a smaller pre-trained model;
- num_labels indicates the number of output labels;
- we don’t care about output_attentions;
- we also don’t need output_hidden_states

About the warning:
https://github.com/huggingface/transformers/issues/5421

In [63]:
# Instantiate the classifier with one output per dataset class. This must run
# before the optimizer cell below, which reads `model.parameters()`; with the
# cell commented out, a top-to-bottom run fails because `model` is undefined.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased",
                                                      num_labels=len(label_dict),
                                                      output_attentions=False,
                                                      output_hidden_states=False)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

## Data loader

In [60]:
batch_size = 3

dataloader_train = DataLoader(dataset_train, 
                              sampler=RandomSampler(dataset_train), 
                              batch_size=batch_size)

dataloader_validation = DataLoader(dataset_val, 
                                   sampler=SequentialSampler(dataset_val), 
                                   batch_size=batch_size)

## Optimizer and scheduler
Consider changing the number of epochs

In [64]:
epochs = 5

optimizer = AdamW(model.parameters(), lr=1e-5, eps=1e-8)

# The scheduler is stepped once per batch in the training loop, so the schedule
# length is the number of batches per epoch times the number of epochs.
total_training_steps = len(dataloader_train) * epochs

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_training_steps,
)

## Performance metrics

In [66]:
def f1_score_func(preds, labels):
    """Return the weighted F1 score of per-class scores against true labels.

    `preds` holds one row of class scores per example; the predicted class is
    the argmax along axis 1. `labels` is an array of true class ids.
    """
    predicted_classes = np.argmax(preds, axis=1).flatten()
    true_classes = labels.flatten()
    return f1_score(true_classes, predicted_classes, average='weighted')

def accuracy_per_class(preds, labels):
    """Print, for every class present in `labels`, how many of its examples
    were predicted correctly, as `correct/total`.

    `preds` holds one row of class scores per example (argmax along axis 1 is
    the predicted class); `labels` is an array of true class ids. Class names
    are looked up through the module-level `label_dict`.
    """
    # Reverse lookup: class id -> dataset title.
    id_to_name = dict(zip(label_dict.values(), label_dict.keys()))

    predicted = np.argmax(preds, axis=1).flatten()
    actual = labels.flatten()

    for class_id in np.unique(actual):
        in_class = actual == class_id
        class_preds = predicted[in_class]
        n_total = len(actual[in_class])
        n_correct = len(class_preds[class_preds == class_id])
        print(f'Class: {id_to_name[class_id]}')
        print(f'Accuracy: {n_correct}/{n_total}\n')

## Training

In [72]:
seed_val = 17
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

In [73]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

print(device)

cpu


In [74]:
def evaluate(dataloader_val):
    """Run the module-level `model` over `dataloader_val` without gradients.

    Each batch is expected to be (input_ids, attention_mask, labels). Returns
    a tuple `(mean_loss, logits, labels)`: the loss averaged over the number
    of batches, and the per-batch logits and label ids concatenated along
    axis 0 as numpy arrays.
    """
    model.eval()

    running_loss = 0
    logit_chunks = []
    label_chunks = []

    for raw_batch in dataloader_val:

        # Move every tensor of the batch onto the model's device.
        on_device = tuple(tensor.to(device) for tensor in raw_batch)

        inputs = dict(
            input_ids=on_device[0],
            attention_mask=on_device[1],
            labels=on_device[2],
        )

        with torch.no_grad():
            outputs = model(**inputs)

        # With labels supplied, the first output is the loss and the second
        # the logits.
        batch_loss, batch_logits = outputs[0], outputs[1]
        running_loss += batch_loss.item()

        # Collect results on the CPU as numpy arrays.
        logit_chunks.append(batch_logits.detach().cpu().numpy())
        label_chunks.append(inputs['labels'].cpu().numpy())

    mean_loss = running_loss / len(dataloader_val)

    all_logits = np.concatenate(logit_chunks, axis=0)
    all_labels = np.concatenate(label_chunks, axis=0)

    return mean_loss, all_logits, all_labels

In [None]:
# Fine-tune for `epochs` passes over the training data. After every epoch the
# weights are checkpointed and the model is scored on the validation set.
for epoch in tqdm(range(1, epochs+1)):
    
    model.train()
    
    # Sum of per-batch losses, averaged over the number of batches below.
    loss_train_total = 0

    progress_bar = tqdm(dataloader_train, desc='Epoch {:1d}'.format(epoch), leave=False, disable=False)
    for batch in progress_bar:

        # Clear gradients left over from the previous step.
        model.zero_grad()
        
        # Each batch is (input_ids, attention_mask, labels); move it to the device.
        batch = tuple(b.to(device) for b in batch)
        
        inputs = {'input_ids':      batch[0],
                  'attention_mask': batch[1],
                  'labels':         batch[2],
                 }       

        outputs = model(**inputs)
        
        # With labels supplied, the first output is the loss.
        loss = outputs[0]
        batch_loss = loss.item()
        loss_train_total += batch_loss
        loss.backward()

        # Cap the gradient norm at 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()
        
        # Show the batch loss as-is. It was previously divided by `len(batch)`,
        # which is the number of tensors in the batch tuple (3), not the number
        # of examples, so the displayed value was wrong.
        progress_bar.set_postfix({'training_loss': '{:.3f}'.format(batch_loss)})
         
        
    # Checkpoint the weights at the end of every epoch.
    torch.save(model.state_dict(), f'data_volume/finetuned_BERT_epoch_{epoch}.model')
        
    tqdm.write(f'\nEpoch {epoch}')
    
    loss_train_avg = loss_train_total/len(dataloader_train)            
    tqdm.write(f'Training loss: {loss_train_avg}')
    
    val_loss, predictions, true_vals = evaluate(dataloader_validation)
    val_f1 = f1_score_func(predictions, true_vals)
    tqdm.write(f'Validation loss: {val_loss}')
    tqdm.write(f'F1 Score (Weighted): {val_f1}')

HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=4044.0, style=ProgressStyle(description_wid…

In [None]:
model = BertForSequenceClassification.from_pretrained("bert-base-uncased",
                                                      num_labels=len(label_dict),
                                                      output_attentions=False,
                                                      output_hidden_states=False)

model.to(device)

In [None]:
model.load_state_dict(torch.load('data_volume/finetuned_BERT_epoch_1.model', map_location=torch.device('cpu')))

In [None]:
_, predictions, true_vals = evaluate(dataloader_validation)

In [None]:
accuracy_per_class(predictions, true_vals)