In [1]:
# Modified from https://towardsdatascience.com/transformers-for-multilabel-classification-71a1a0daf5e1

In [40]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import datetime
import pandas as pd
import numpy as np
import tensorflow as tf
import torch
from torch.nn import BCEWithLogitsLoss, BCELoss
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import classification_report, confusion_matrix, multilabel_confusion_matrix, f1_score, accuracy_score
import pickle
from transformers import *
from tqdm import tqdm, trange
from ast import literal_eval

In [2]:
# Run on CUDA when PyTorch can see a GPU; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# WordPiece tokenizer for the uncased BERT checkpoint; lower-casing matches that checkpoint.
bert_checkpoint = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(bert_checkpoint, do_lower_case=True)

In [4]:
# Page-level corpus: one row per page with its druid, its text, and one column per FAST label URI.
corpus_csv = "../goog-bks-csv/b-druids-all.csv"
all_df = pd.read_csv(corpus_csv)

In [6]:
# Quick health check of the raw corpus. Page length is measured in whitespace-separated words,
# computed once and reused for both the mean and the standard deviation.
page_word_counts = all_df.text.str.split().str.len()
text_all_unique = all_df.text.nunique() == all_df.shape[0]
any_nulls = all_df.isnull().values.any()

print("About all dataframe")
print(f"\tSize: {len(all_df)}")
print('\tUnique text:', text_all_unique)
print('\tNull values: ', any_nulls)
print('\taverage page length ', page_word_counts.mean())
print('\tstandard deviation length', page_word_counts.std())

About all dataframe
	Size: 34817
	Unique text: False
	Null values:  False
	average page length  411.2618261194244
	standard deviation length 243.75536907224742


In [7]:
# Drop pages whose text is an exact duplicate (keeping the first occurrence), so an identical
# page cannot appear in both the training and the validation split.
all_df = all_df.drop_duplicates(subset=['text'], keep='first')
# Exact invariant check; unlike nunique() this also treats repeated NaN texts as duplicates.
assert not all_df['text'].duplicated().any(), "duplicate page text survived deduplication"

In [8]:
# Number of pages remaining after deduplication.
all_df.shape[0]

34758

In [9]:
# Re-verify the two invariants after deduplication.
text_is_unique = all_df.text.nunique() == all_df.shape[0]
has_nulls = all_df.isnull().values.any()
print('Unique text:', text_is_unique)
print('Null values: ', has_nulls)

Unique text: True
Null values:  False


In [11]:
cols = all_df.columns
# The first two columns are page metadata (druid, text). Every later column is one indicator
# column per FAST subject URI. The derived `one_hot_labels` column is excluded, so re-running
# this cell after that column has been added does not count it as an extra label.
label_cols = [c for c in cols[2:] if c != 'one_hot_labels']
num_labels = len(label_cols)
print("Subset of FAST Labels\n", label_cols[0:25])

Subset of FAST Labels
 ['http://id.worldcat.org/fast/870268', 'http://id.worldcat.org/fast/894932', 'http://id.worldcat.org/fast/917265', 'http://id.worldcat.org/fast/887377', 'http://id.worldcat.org/fast/897386', 'http://id.worldcat.org/fast/1204623', 'http://id.worldcat.org/fast/1205427', 'http://id.worldcat.org/fast/1065823', 'http://id.worldcat.org/fast/998323', 'http://id.worldcat.org/fast/1132103', 'http://id.worldcat.org/fast/808830', 'http://id.worldcat.org/fast/847688', 'http://id.worldcat.org/fast/1266344', 'http://id.worldcat.org/fast/1080875', 'http://id.worldcat.org/fast/1748896', 'http://id.worldcat.org/fast/1163863', 'http://id.worldcat.org/fast/901476', 'http://id.worldcat.org/fast/832383', 'http://id.worldcat.org/fast/1148536', 'http://id.worldcat.org/fast/832181', 'http://id.worldcat.org/fast/1052928', 'http://id.worldcat.org/fast/917342', 'http://id.worldcat.org/fast/891773', 'http://id.worldcat.org/fast/836873', 'http://id.worldcat.org/fast/1432076']


In [12]:
# Number of pages tagged with each label (column sums); a sum of 0 means the label never occurs.
positives_per_label = all_df[label_cols].sum()
print("Count of 1 per label: \n", positives_per_label, "\n")

Count of 1 per label: 
 http://id.worldcat.org/fast/870268      0.0
http://id.worldcat.org/fast/894932      0.0
http://id.worldcat.org/fast/917265    397.0
http://id.worldcat.org/fast/887377      0.0
http://id.worldcat.org/fast/897386      0.0
                                      ...  
http://id.worldcat.org/fast/235269    423.0
http://id.worldcat.org/fast/185500    285.0
http://id.worldcat.org/fast/675822    285.0
http://id.worldcat.org/fast/444817    243.0
http://id.worldcat.org/fast/43064     315.0
Length: 1793, dtype: float64 



In [13]:
# Shuffle all rows with a fixed seed so a re-run reproduces the same order
# (2020 matches the random_state used for the train/validation split), then renumber 0..n-1.
all_df = all_df.sample(frac=1, random_state=2020).reset_index(drop=True)

In [14]:
# Preview the first five shuffled rows.
all_df.head(n=5)

Unnamed: 0,druid,text,http://id.worldcat.org/fast/870268,http://id.worldcat.org/fast/894932,http://id.worldcat.org/fast/917265,http://id.worldcat.org/fast/887377,http://id.worldcat.org/fast/897386,http://id.worldcat.org/fast/1204623,http://id.worldcat.org/fast/1205427,http://id.worldcat.org/fast/1065823,...,http://id.worldcat.org/fast/530762,http://id.worldcat.org/fast/179858,http://id.worldcat.org/fast/434309,http://id.worldcat.org/fast/1423826,http://id.worldcat.org/fast/1423871,http://id.worldcat.org/fast/235269,http://id.worldcat.org/fast/185500,http://id.worldcat.org/fast/675822,http://id.worldcat.org/fast/444817,http://id.worldcat.org/fast/43064
0,bj902ps6664,ix LIMITS OF NATURAL SELECTION IN MAN 205\npro...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,bv003kx8777,Effect of Different Food Conditions on Surviva...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,bt415kh5496,41 l 'A THEORY BY WHICH TO WORK'\nimpossible t...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,bc672mk0535,"THE JOURNAL OF BIOLOGICAL CHEMISTRY\nVol. 279,...",0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,bg681dz7650,Coastal Taipan\nThe Coastal Taipan is Australi...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [15]:
# Store each page's label indicators as one numpy row vector in an object column.
all_df['one_hot_labels'] = [label_row for label_row in all_df[label_cols].to_numpy()]

In [17]:
# Preview the first five rows, now including the one_hot_labels column.
all_df.iloc[:5]

Unnamed: 0,druid,text,http://id.worldcat.org/fast/870268,http://id.worldcat.org/fast/894932,http://id.worldcat.org/fast/917265,http://id.worldcat.org/fast/887377,http://id.worldcat.org/fast/897386,http://id.worldcat.org/fast/1204623,http://id.worldcat.org/fast/1205427,http://id.worldcat.org/fast/1065823,...,http://id.worldcat.org/fast/179858,http://id.worldcat.org/fast/434309,http://id.worldcat.org/fast/1423826,http://id.worldcat.org/fast/1423871,http://id.worldcat.org/fast/235269,http://id.worldcat.org/fast/185500,http://id.worldcat.org/fast/675822,http://id.worldcat.org/fast/444817,http://id.worldcat.org/fast/43064,one_hot_labels
0,bj902ps6664,ix LIMITS OF NATURAL SELECTION IN MAN 205\npro...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
1,bv003kx8777,Effect of Different Food Conditions on Surviva...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
2,bt415kh5496,41 l 'A THEORY BY WHICH TO WORK'\nimpossible t...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
3,bc672mk0535,"THE JOURNAL OF BIOLOGICAL CHEMISTRY\nVol. 279,...",0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
4,bg681dz7650,Coastal Taipan\nThe Coastal Taipan is Australi...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."


In [20]:
# Parallel per-page lists used for tokenization and splitting: label vectors and raw text.
labels = list(all_df['one_hot_labels'])
pages = list(all_df['text'])

In [22]:
# Reuse the bert-base-uncased tokenizer loaded earlier in the notebook instead of
# instantiating an identical one a second time.
#
# Encode every page to exactly 512 tokens, the maximum sequence length for bert-base.
# Longer pages are truncated and shorter ones padded. `padding='max_length'` is the
# non-deprecated spelling of `pad_to_max_length=True`.
encodings = tokenizer.batch_encode_plus(pages, max_length=512, padding='max_length', truncation=True)

In [23]:
# Show which parallel fields the tokenizer produced.
encoding_fields = encodings.keys()
print("tokenizer outputs", encoding_fields)

tokenizer outputs dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])


In [24]:
# Unpack the tokenizer output into three parallel per-page lists:
# token ids, segment (token type) ids, and attention masks (1 = real token, 0 = padding).
input_ids, token_type_ids, attention_masks = (
    encodings[field] for field in ('input_ids', 'token_type_ids', 'attention_mask')
)

In [25]:
# How many pages share each label vector, keyed by the vector's string form.
# NOTE(review): numpy abbreviates str() of arrays longer than its print threshold (1000 by
# default) with '...', so with ~1.8k labels distinct vectors can share one key here.
label_counts = (
    all_df['one_hot_labels']
    .astype(str)
    .value_counts()
)

In [26]:
# Find pages whose exact label combination occurs only once. A stratified split cannot place
# such a singleton stratum in both train and validation.
# Each row is keyed by the tuple of its label values, which is exact. Grouping on `.astype(str)`
# is lossy here: numpy abbreviates the string form of arrays longer than its print threshold
# with '...', so distinct label vectors could collapse onto one key and hide singletons.
label_keys = all_df['one_hot_labels'].map(tuple)
singleton_mask = ~label_keys.duplicated(keep=False)
# Label combinations that occur exactly once.
one_freq = list(label_keys[singleton_mask])
# Row indices of those pages, largest first (safe order for popping from parallel lists).
one_freq_idxs = sorted(all_df.index[singleton_mask.to_numpy()], reverse=True)

In [27]:
# Stratified splitting needs every label combination at least twice; fail early with a clear
# message rather than deep inside train_test_split.
assert not one_freq_idxs, (
    f"{len(one_freq_idxs)} pages have a unique label combination and cannot be stratified"
)
one_freq_idxs

[]

In [28]:
# Hold out 10% of pages for validation. All four parallel lists are split with the same
# permutation. Stratifying on `labels` treats each distinct label vector as one stratum.
(
    train_inputs, validation_inputs,
    train_labels, validation_labels,
    train_token_types, validation_token_types,
    train_masks, validation_masks,
) = train_test_split(
    input_ids, labels, token_type_ids, attention_masks,
    test_size=0.10,
    random_state=2020,
    stratify=labels,
)

In [29]:
# Convert all training data into PyTorch tensors.
# The labels are a list of per-page numpy arrays. Stacking them into one ndarray first avoids
# torch's very slow element-by-element conversion of a list of ndarrays; dtype and values
# are unchanged.
train_inputs = torch.tensor(train_inputs)
train_labels = torch.tensor(np.array(train_labels))
train_masks = torch.tensor(train_masks)
train_token_types = torch.tensor(train_token_types)

In [30]:
# Convert all validation data into PyTorch tensors.
# The labels are a list of per-page numpy arrays. Stacking them into one ndarray first avoids
# torch's very slow element-by-element conversion of a list of ndarrays; dtype and values
# are unchanged.
validation_inputs = torch.tensor(validation_inputs)
validation_labels = torch.tensor(np.array(validation_labels))
validation_masks = torch.tensor(validation_masks)
validation_token_types = torch.tensor(validation_token_types)

In [32]:
batch_size = 32

# Tensor order inside each dataset is (input_ids, attention_mask, labels, token_type_ids);
# the training loop unpacks every batch in exactly this order.
train_data = TensorDataset(train_inputs, train_masks, train_labels, train_token_types)
validation_data = TensorDataset(validation_inputs, validation_masks, validation_labels, validation_token_types)

# Random order for training batches; fixed order for validation so predictions line up with the data.
train_sampler, validation_sampler = RandomSampler(train_data), SequentialSampler(validation_data)

train_dataloader, validation_dataloader = (
    DataLoader(dataset, sampler=sampler, batch_size=batch_size)
    for dataset, sampler in ((train_data, train_sampler), (validation_data, validation_sampler))
)

In [34]:
import os

# Persist both loaders (presumably for reuse outside this notebook). torch.save does not create
# missing directories, so make sure `data/` exists first.
# NOTE(review): this pickles whole DataLoader objects; saving the TensorDatasets instead would be
# less tied to the installed torch version.
os.makedirs('data', exist_ok=True)
torch.save(validation_dataloader, 'data/validation_data_loader')
torch.save(train_dataloader, 'data/train_data_loader')

In [35]:
# Load pretrained BERT with a newly initialised classification head of num_labels outputs
# (one logit per FAST label).
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=num_labels)
# Training batches are moved to `device`, so the model must live there too;
# otherwise every forward pass fails with a device mismatch on GPU.
model = model.to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [36]:
# Use torch's AdamW; the transformers.AdamW class is deprecated.
# eps=1e-6 and weight_decay=0.0 match transformers.AdamW's defaults, which differ from
# torch's own defaults (eps=1e-8, weight_decay=0.01).
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, eps=1e-6, weight_decay=0.0)

In [42]:
train_loss_set = []
epochs = 3
# Sigmoid probability above which a label counts as predicted during validation.
threshold = 0.50
# Multi-label objective: an independent binary cross-entropy per label logit. The loss object is
# stateless, so it is built once rather than on every step.
loss_func = BCEWithLogitsLoss()
# Batches are moved to `device` below, so the model must be there too (no-op if it already is).
model.to(device)
start = datetime.datetime.utcnow()

print(f"Started at {start} for {epochs} epochs")
for _ in trange(epochs, desc="Epoch"):
    # ---------------- training pass ----------------
    model.train()

    tr_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0

    for step, batch in enumerate(train_dataloader):
        batch = tuple(t.to(device) for t in batch)
        # Same tensor order as the TensorDataset behind the loader.
        b_input_ids, b_input_mask, b_labels, b_token_types = batch
        optimizer.zero_grad()

        outputs = model(b_input_ids, token_type_ids=b_token_types, attention_mask=b_input_mask)
        logits = outputs[0]

        loss = loss_func(logits.view(-1, num_labels), b_labels.type_as(logits).view(-1, num_labels))
        loss.backward()
        optimizer.step()

        # .item() copies the scalar back to the host; do it once per step.
        batch_loss = loss.item()
        train_loss_set.append(batch_loss)
        tr_loss += batch_loss
        nb_tr_examples += b_input_ids.size(0)
        nb_tr_steps += 1

    finished_training = datetime.datetime.utcnow()
    # total_seconds() rather than .seconds: .seconds drops whole days, and epochs here can exceed 24h.
    print(f"Train loss: {(tr_loss/nb_tr_steps)} finished: {finished_training}, elapsed time {(finished_training-start).total_seconds() /60.} mins")

    # ---------------- validation pass ----------------
    model.eval()

    logit_preds, true_labels, pred_labels, tokenized_texts = [], [], [], []

    # Score every validation batch inside the loop, so all of them contribute to the metrics.
    with torch.no_grad():
        for i, batch in enumerate(validation_dataloader):
            batch = tuple(t.to(device) for t in batch)
            b_input_ids, b_input_mask, b_labels, b_token_types = batch

            outs = model(b_input_ids, token_type_ids=b_token_types, attention_mask=b_input_mask)
            b_logit_pred = outs[0]
            pred_label = torch.sigmoid(b_logit_pred)

            # Keep accumulated results in host memory so they do not pin device memory.
            tokenized_texts.append(b_input_ids.cpu())
            logit_preds.append(b_logit_pred.cpu().numpy())
            true_labels.append(b_labels.cpu().numpy())
            pred_labels.append(pred_label.cpu().numpy())

    # Flatten the per-batch arrays into one entry per validation page.
    pred_labels = [item for sublist in pred_labels for item in sublist]
    true_labels = [item for sublist in true_labels for item in sublist]

    pred_bools = [pl > threshold for pl in pred_labels]
    true_bools = [tl == 1 for tl in true_labels]
    # Micro-averaged F1 pools all (page, label) decisions. On multi-label input, accuracy_score
    # is exact-match (subset) accuracy.
    val_f1_accuracy = f1_score(true_bools, pred_bools, average='micro') * 100
    val_flat_accuracy = accuracy_score(true_bools, pred_bools) * 100

    print('F1 Validation Accuracy: ', val_f1_accuracy)
    print('Flat Validation Accuracy: ', val_flat_accuracy)
    finished_validation = datetime.datetime.utcnow()
    print(f"Finished validation at {finished_validation} elapsed time {(finished_validation-start).total_seconds() / 60.} mins")

end = datetime.datetime.utcnow()
print(f"Finished all training at {end}, total time {(end-start).total_seconds() / 60.} mins")

Epoch:   0%|          | 0/3 [00:00<?, ?it/s]

Started at 2020-10-27 20:15:27.741590 for 3 epochs


Epoch:   0%|          | 0/3 [20:17:23<?, ?it/s]


KeyboardInterrupt: 