<a href="https://colab.research.google.com/github/suhas-chowdary/20newsgroup/blob/main/distilbert_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [42]:
import os
import re
import collections
import timeit
import torch
import pandas as pd
import pkbar
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
from sklearn.metrics import  f1_score,classification_report


In [None]:
# Uses GPU if available
from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'

In [None]:
# If running on google colab: Uncomment below code and install transformers and pkbar libraries.

# !pip install transformers
# !pip install pkbar
from transformers import DistilBertConfig,DistilBertTokenizer,DistilBertModel

In [None]:
# If running on google colab:  Upload 'data.zip' file present in git repository to colab, uncomment below code and run it.
# This mounts the data to google colab.
# !unzip data.zip

### Load data

In [4]:
# Path to documents.
dir_path = os.getcwd()
rel_path = "data"
data_path = os.path.join(dir_path, rel_path)
news_groups = [f for f in os.listdir(data_path)]
news_group_idx = {v:i for i,v in enumerate(news_groups)}


In [5]:
# Preprocess data: Remove all special characters, convert to lower case.
def data_preprocess(cur):
        cur = cur.lower()
        cur = re.sub(r'[\w\.-]+@[\w\.-]+',' ',cur)
        cur = re.sub("[^a-zA-Z,.']", ' ', cur)
        cur = re.sub(r'\.{2,}',' ',cur)   
        cur = re.sub('\s+',' ',cur)
        cur = " ".join(cur.split())
        return cur

In [6]:
# Prepare dataset: Read data from docs. 
X = []
y = []

for category in news_groups:
    cls = []
    f_path = os.path.join(data_path,category)
    # Read data from all docs.
    for files in os.listdir(f_path):
        path = os.path.join(f_path,files)
        with open(path,'r',errors='ignore',encoding="utf8") as file:
            cur_doc = data_preprocess(file.read().replace('\n',' '))
            X.append(cur_doc)
        y.append(int(news_group_idx[category]))

In [7]:
# train test split
X_train,X_test,y_train,y_test= train_test_split(X,y,stratify=y,test_size=0.2, 
                                                random_state=9)
print('number of training samples:', len(X_train))
print('number of test samples:', len(X_test))

number of training samples: 16333
number of test samples: 4084


In [8]:
train_df = pd.DataFrame({'doc':X_train,
                         'labels':y_train})
test_df = pd.DataFrame({'doc':X_test,
                         'labels':y_test})

In [34]:
# Distil-bert model parameters
MAX_LEN = 512
TRAIN_BATCH_SIZE = 16
VALID_BATCH_SIZE = 8
EPOCHS = 8
LEARNING_RATE = 1e-05
num_classes = len(news_groups)
num_of_batches_per_epoch = len(X_train)//TRAIN_BATCH_SIZE
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

### Convert data into bert consumable format.

BERT expects input data in a specific format, with special tokens to mark the beginning ([CLS]) and separation/end of sentences ([SEP]). 

We need to tokenize our text into tokens that correspond to BERT’s vocabulary. Bert uses word piece tokenizer and has a vocabulary size of ~30k words.

For each tokenized sentence, Bert requires: 

a. Input ids, a sequence of integers identifying each input token to its index number in the BERT tokenizer vocabulary.

b. Mask ids, a sequence of integers identifying each masked location as bert uses Masked Language Model(MLM).

c. Corresonding targets.


In [10]:
class BertDataFormat(Dataset):
    def __init__(self, dataframe, tokenizer, max_len):
        self.len = len(dataframe)
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len
        
    def __getitem__(self, index):
        cur_doc = str(self.data.doc[index])
        cur_doc = " ".join(cur_doc.split())
        inputs = self.tokenizer.encode_plus(
            cur_doc,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True
        )
        ids = inputs['input_ids']
        mask = inputs['attention_mask']

        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
            'targets': torch.tensor(self.data.labels[index], dtype=torch.long)
        } 
    
    def __len__(self):
        return self.len

training_set = BertDataFormat(train_df, tokenizer, MAX_LEN)
testing_set = BertDataFormat(test_df, tokenizer, MAX_LEN)

In [11]:
train_params = {'batch_size': TRAIN_BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0
                }

test_params = {'batch_size': VALID_BATCH_SIZE,
                'shuffle': False,
                'num_workers': 0
                }

training_loader = DataLoader(training_set, **train_params)
testing_loader = DataLoader(testing_set, **test_params)

### Baseline distil-bert model.

For baseline model, I used default pretrained distil-bert model from hugging face and tried to classify the data without any fine tuning.

In [30]:
# Creating the customized model, by adding a drop out and a dense layer on top of distil bert to get the final output for the model. 

class DistillBERTClass(torch.nn.Module):
    def __init__(self,num_classes):
        super(DistillBERTClass, self).__init__()
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.classifier = torch.nn.Linear(768, num_classes)

    def forward(self, input_ids, attention_mask):
        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = output_1[0]
        bert_last = hidden_state[:, 0]
        output = self.classifier(bert_last)
        return output

In [None]:
# Copy model to device.
baseline_model = DistillBERTClass(num_classes)
baseline_model.to(device)

In [32]:
# Create the loss function and optimizer
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params =  baseline_model.parameters(), lr=LEARNING_RATE)

In [12]:
# Calcuate accuracy of the model
def acc_cal(big_idx, targets):
    n_correct = (big_idx==targets).sum().item()
    return n_correct

In [13]:
# train model
def train(epoch,model):
    tr_loss = 0
    n_correct = 0
    nb_tr_steps = 0
    nb_tr_examples = 0
    model.train()

    # progress bar
    train_per_epoch = num_of_batches_per_epoch
    kbar = pkbar.Kbar(target=train_per_epoch, epoch=epoch, 
                      num_epochs=EPOCHS, width=8, 
                      always_stateful=False)

    for idx,data in enumerate(training_loader, 0):

        # copy tensors to gpu
        ids = data['ids'].to(device, dtype = torch.long)
        mask = data['mask'].to(device, dtype = torch.long)
        targets = data['targets'].to(device, dtype = torch.long)

        # get output and calculate loss.
        outputs = model(ids, mask)
        loss = loss_function(outputs, targets)
        tr_loss += loss.item()
        big_val, big_idx = torch.max(outputs.data, dim=1)
        n_correct += acc_cal(big_idx, targets)

        nb_tr_steps += 1
        nb_tr_examples+=targets.size(0)
      
        optimizer.zero_grad()
        loss.backward()
        # # When using GPU
        optimizer.step()
        kbar.update(idx, values=[("train_loss", tr_loss/(idx+1))])


    epoch_loss = tr_loss/nb_tr_steps
    epoch_accu = (n_correct*100)/nb_tr_examples
    print(f"Training Loss Epoch: {epoch_loss}")
    print(f"Training Accuracy Epoch: {epoch_accu}")

    return

In [None]:
# function to predict output.
def valid(model, testing_loader):
    predicted_labels = []
    true_labels = []
    nb_tr_steps = 0
    tr_loss =0
    nb_tr_examples=0
    model.eval()
    n_correct = 0; n_wrong = 0; total = 0
    with torch.no_grad():
        for _, data in enumerate(testing_loader, 0):

            # copy tensors to gpu.
            ids = data['ids'].to(device, dtype = torch.long)
            mask = data['mask'].to(device, dtype = torch.long)
            targets = data['targets'].to(device, dtype = torch.long)
            outputs = model(ids, mask).squeeze()

            # calculate loss
            loss = loss_function(outputs, targets)
            tr_loss += loss.item()
            big_val, big_idx = torch.max(outputs.data, dim=1)
            predicted_labels += big_idx
            true_labels += targets
            
            n_correct += acc_cal(big_idx, targets)

            nb_tr_steps += 1
            nb_tr_examples+=targets.size(0)
            
    
    epoch_loss = tr_loss/nb_tr_steps
    epoch_accu = (n_correct*100)/nb_tr_examples
    return epoch_accu,epoch_loss,predicted_labels,true_labels

In [35]:
for epoch in range(EPOCHS):
    train(epoch,baseline_model)
    print('\n')

Epoch: 1/8
Training Loss Epoch: 1.151628428075853
Training Accuracy Epoch: 65.61562480866957


Epoch: 2/8
Training Loss Epoch: 0.6041948671665236
Training Accuracy Epoch: 80.38327312802302


Epoch: 3/8
Training Loss Epoch: 0.4213431019967205
Training Accuracy Epoch: 86.64666625849507


Epoch: 4/8
Training Loss Epoch: 0.3093200043631931
Training Accuracy Epoch: 90.06918508540991


Epoch: 5/8
Training Loss Epoch: 0.2327240092121612
Training Accuracy Epoch: 92.38351803098023


Epoch: 6/8
Training Loss Epoch: 0.18626880633346482
Training Accuracy Epoch: 93.61415539092634


Epoch: 7/8
Training Loss Epoch: 0.15773281787798524
Training Accuracy Epoch: 94.36110941039612


Epoch: 8/8
Training Loss Epoch: 0.1369942877634112
Training Accuracy Epoch: 94.93663135982366




In [37]:
acc,loss,predicted_labels,true_labels = valid(baseline_model, testing_loader)
print("test accuracy on baseline distilbert model =",round(acc,2))

test accuracy on baseline distilbert model = 83.15


In [40]:
predicted_labels = [i.item() for i in predicted_labels]
true_labels = [i.item() for i in true_labels]
baseline_f1 = f1_score(true_labels,predicted_labels,average='macro')
print("F1 score on baseline model = ",baseline_f1)

F1 score on baseline model =  0.8346094307162266


Classification report of baseline bert model.

In [41]:
print(classification_report(true_labels,predicted_labels,target_names=news_groups))

                          precision    recall  f1-score   support

             alt.atheism       0.72      0.64      0.68       203
            misc.forsale       0.87      0.77      0.82       202
      talk.religion.misc       0.53      0.55      0.54       225
                 sci.med       0.88      0.94      0.91       200
           comp.graphics       0.88      0.70      0.78       202
      talk.politics.guns       0.83      0.81      0.82       211
      talk.politics.misc       0.66      0.68      0.67       234
        rec.sport.hockey       0.98      0.98      0.98       200
          comp.windows.x       0.86      0.91      0.89       200
               sci.crypt       0.92      0.92      0.92       200
  soc.religion.christian       0.89      0.86      0.88       199
   talk.politics.mideast       0.85      0.86      0.85       201
 comp.os.ms-windows.misc       0.72      0.87      0.79       200
comp.sys.ibm.pc.hardware       0.76      0.77      0.76       203
         

### Model 2: 

I am inspired by DocBERT: BERT for Document Classification architecture developed by [Adhikari et al., 2019](https://arxiv.org/pdf/1904.08398.pdf) which tries to distil knowledge from bert to LSTMs which are proven to be effective due to long text format of documents. DocBert is current state of the art solution for document classification on Multiple datasets like Reuters dataset. So I implemented DocBert architecture and tweaked it on 20 Newsgroup dataset.

In [15]:

class DocBERTClass(torch.nn.Module):
    def __init__(self):
        super(DocBERTClass, self).__init__()
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.dimension = 64
        #self.lstm = nn.LSTM(input_size=768, hidden_size=dimension,num_layers=1,
        #                    bidirectional=True)
        hidden_dim = 64 
        embed_dim = 768

        self.lstm  = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.pre_classifier = torch.nn.Linear(64, 64)
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(64, 20)

    def forward(self, input_ids, attention_mask):
        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = output_1[0]
        bert_final = hidden_state[:, 0]
        bert_final = bert_final.squeeze().unsqueeze(dim=0)
        lstm1, (h, c) = self.lstm(bert_final)
        #print(lstm1.shape)
        lstm1 = lstm1.view((lstm1.shape)[1],64)
        linear1 = self.pre_classifier(lstm1)
        linear1 = torch.nn.ReLU()(linear1)
        linear1 = self.dropout(linear1)
        #print(linear1.shape)
        output = self.classifier(linear1)
        #print(output.shape)
        return output

In [16]:
# Copy model to device.
docbert_model = DocBERTClass()
docbert_model.to(device)

DocBERTClass(
  (l1): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Linear(in_features

In [17]:
# Create the loss function and optimizer
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params =  docbert_model.parameters(), lr=LEARNING_RATE)

In [21]:
print('DocBert Model')
EPOCHS = 13
for epoch in range(EPOCHS):
    train(epoch,docbert_model)
    print('\n')


DocBert Model
Epoch: 1/13
Training Loss Epoch: 2.024186410352836
Training Accuracy Epoch: 68.86671156554215


Epoch: 2/13
Training Loss Epoch: 1.7390003384152075
Training Accuracy Epoch: 76.64238045674402


Epoch: 3/13
Training Loss Epoch: 1.495943967205537
Training Accuracy Epoch: 81.17308516500337


Epoch: 4/13
Training Loss Epoch: 1.2848184296129732
Training Accuracy Epoch: 84.26498499969387


Epoch: 5/13
Training Loss Epoch: 1.098104508044554
Training Accuracy Epoch: 86.71401457172595


Epoch: 6/13
Training Loss Epoch: 0.945668080994945
Training Accuracy Epoch: 88.13445172350455


Epoch: 7/13
Training Loss Epoch: 0.816197077803
Training Accuracy Epoch: 89.67734035388477


Epoch: 8/13
Training Loss Epoch: 0.7088179806946073
Training Accuracy Epoch: 90.96308087920161


Epoch: 9/13
Training Loss Epoch: 0.6139628174664575
Training Accuracy Epoch: 91.7712606379722


Epoch: 10/13
Training Loss Epoch: 0.5396569252276864
Training Accuracy Epoch: 92.51821465744199


Epoch: 11/13
Training Lo

In [24]:
acc,loss,predicted_labels,true_labels = valid(docbert_model, testing_loader)
print("test accuracy on DocBert model =",acc, '%')

test accuracy on DocBert model = 88.54064642507346 %


In [26]:
predicted_labels = [i.item() for i in predicted_labels]
true_labels = [i.item() for i in true_labels]
docbert_f1 = f1_score(true_labels,predicted_labels,average='macro')
print("F1 score on docbert model = ",docbert_f1)

F1 score on docbert model =  0.8970506268868808


### Classification report on DocBert architecture.

In [25]:
print(classification_report(true_labels,predicted_labels,target_names=news_groups))

                          precision    recall  f1-score   support

             alt.atheism       0.84      0.63      0.72       203
            misc.forsale       0.88      0.97      0.92       202
      talk.religion.misc       0.66      0.60      0.62       225
                 sci.med       0.95      0.97      0.96       200
           comp.graphics       0.91      0.85      0.88       202
      talk.politics.guns       0.81      0.90      0.85       211
      talk.politics.misc       0.74      0.75      0.75       234
        rec.sport.hockey       1.00      0.99      1.00       200
          comp.windows.x       0.95      0.94      0.95       200
               sci.crypt       0.98      0.95      0.97       200
  soc.religion.christian       0.86      0.98      0.92       199
   talk.politics.mideast       0.87      0.95      0.91       201
 comp.os.ms-windows.misc       0.89      0.87      0.88       200
comp.sys.ibm.pc.hardware       0.87      0.78      0.82       203
         