In [50]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline
import transformers
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler,TensorDataset
# from datasets import Dataset
from torch.utils.data import Dataset 
import torch
from torch.optim import AdamW
from transformers import get_scheduler
from tqdm.auto import tqdm
import evaluate
from torch.nn import CrossEntropyLoss
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import os
import glob
import torch

In [51]:
df = pd.read_csv("data/balancednewcategory.csv")
df.head()

Unnamed: 0,link,headline,category,short_description,authors,date,PESTEL_label
0,https://www.huffingtonpost.com/entry/mortgage-...,Mortgage Deal Reached In 2008 Shows Pitfalls T...,BUSINESS,"The Obama administration, which is pushing sta...",Loren Berlin,2012-02-05,Economic
1,https://www.huffingtonpost.com/entry/women-in-...,"Women in Business: Kate O'Brien Minson, Presid...",BUSINESS,Kate has lived and breathed the therapeutic ap...,"Laura Dunn, ContributorSocial Media and Commun...",2015-04-25,Economic
2,https://www.huffingtonpost.com/entry/like-athl...,"Like Athletes, Business Owners Need to Learn F...",BUSINESS,"Business owners and top executives can also ""w...","Mary Ellen Biery, ContributorResearch Speciali...",2015-01-19,Economic
3,https://www.huffingtonpost.com/entry/donald-tr...,Trump Could Trigger The Longest Recession Sinc...,BUSINESS,Yikes.,Ben Walsh,2016-06-27,Economic
4,https://www.huffingtonpost.com/entry/grocery-c...,Grocery Chains Made A Promise To The First Lad...,BUSINESS,An AP investigation found that major grocers o...,"Mike Schneider, AP",2015-12-07,Economic


In [52]:
df["content"] = df["headline"] + " " + df["short_description"]
df = df[['PESTEL_label', 'content']]
df.head()

Unnamed: 0,PESTEL_label,content
0,Economic,Mortgage Deal Reached In 2008 Shows Pitfalls T...
1,Economic,"Women in Business: Kate O'Brien Minson, Presid..."
2,Economic,"Like Athletes, Business Owners Need to Learn F..."
3,Economic,Trump Could Trigger The Longest Recession Sinc...
4,Economic,Grocery Chains Made A Promise To The First Lad...


In [53]:
# Convert content to string and handle 'NaN' values
df['content'] = df['content'].apply(lambda x: '' if pd.isna(x) else str(x))

In [54]:
def clean_text(text):
    text = text.replace('\n', ' ').strip()
    return text

df['content'] = df['content'].apply(clean_text)

In [55]:
# test set 20%
train_val_df, test_df = train_test_split(
    df, test_size=0.2, stratify=df['PESTEL_label'], random_state=42
)

# train set 70%, val set 10%
train_df, val_df = train_test_split(
    train_val_df, test_size=0.125, stratify=train_val_df['PESTEL_label'], random_state=42
)

In [56]:
class NewsDataset(Dataset):
    def __init__(self, data, tokenizer, max_len, pestel_to_idx):
        self.df = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.pestel_to_idx = pestel_to_idx

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index] 
        content = row['content']
        label = row['PESTEL_label']

        inputs = self.tokenizer.encode_plus(
            content,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True
        )

        return {
            'ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            'targets': torch.tensor(self.pestel_to_idx[label], dtype=torch.long)
        }
        
pestel_to_idx = {
    "Political": 0,
    "Economic": 1,
    "Social": 2,
    "Technological": 3,
    "Environmental": 4,
    "Legal": 5
}

In [57]:
tokenizer = transformers.DistilBertTokenizer.from_pretrained('distilbert-base-uncased')



In [58]:
distilbert_model = transformers.DistilBertModel.from_pretrained('distilbert-base-uncased')

In [59]:
class PestelClassifier(torch.nn.Module):
    def __init__(self, distilbert, num_classes):
        super(PestelClassifier, self).__init__()
        self.distilbert = distilbert
        self.dropout = torch.nn.Dropout(0.3)
        self.output = torch.nn.Linear(768, num_classes)
        
        # self.classifier = torch.nn.Sequential(
        #     torch.nn.Linear(768, 256),
        #     torch.nn.ReLU(),
        #     torch.nn.Dropout(0.3),
        #     torch.nn.Linear(256, num_classes)
        # )


    def forward(self, ids, mask):
        output = self.distilbert(ids, attention_mask=mask)
        output = self.dropout(output[0][:, 0, :])  # CLS token
        output = self.output(output)
        return output
    
        # x = self.distilbert(ids, attention_mask=mask).last_hidden_state[:, 0, :]  # CLS
        # x = self.classifier(x)
        # return x

In [60]:
num_classes = 6  # For PESTEL
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PestelClassifier(distilbert_model, num_classes)
model.to(device)

PestelClassifier(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): 

In [61]:
EPOCHS = 20
LEARNING_RATE = 1e-5 
MAX_LEN = 128 
BATCH_SIZE = 32 

train_set = NewsDataset(train_df, tokenizer, MAX_LEN, pestel_to_idx)
val_set = NewsDataset(val_df, tokenizer, MAX_LEN, pestel_to_idx)
test_set = NewsDataset(test_df, tokenizer, MAX_LEN, pestel_to_idx)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
loss_function = torch.nn.CrossEntropyLoss()

In [62]:
def train_bert_model(model, train_loader, val_loader, device, epochs, optimizer, loss_function, patience=3, checkpoint_dir=None, resume=False):
    start_epoch = 0
    best_val_loss = float('inf')
    patience_counter = 0

    # Resume from latest checkpoint if specified
    if resume and checkpoint_dir:
        latest_ckpt = find_latest_checkpoint(checkpoint_dir)
        if latest_ckpt:
            print(f"Resuming from checkpoint: {latest_ckpt}")
            model, optimizer, start_epoch = load_checkpoint(model, optimizer, latest_ckpt, device)
            start_epoch += 1

    for epoch in range(start_epoch, epochs):
        model.train()
        train_loss = 0
        train_preds = []
        train_labels = []

        train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1} - Training")
        for batch in train_bar:
            ids = batch['ids'].to(device)
            mask = batch['mask'].to(device)
            targets = batch['targets'].to(device)

            optimizer.zero_grad()
            outputs = model(ids, mask)
            loss = loss_function(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            train_preds.extend(predicted.cpu().tolist())
            train_labels.extend(targets.cpu().tolist())

            train_bar.set_postfix(loss=loss.item())

        avg_train_loss = train_loss / len(train_loader)
        train_acc = accuracy_score(train_labels, train_preds)
        train_precision = precision_score(train_labels, train_preds, average='macro', zero_division=0)
        train_recall = recall_score(train_labels, train_preds, average='macro', zero_division=0)
        train_f1 = f1_score(train_labels, train_preds, average='macro', zero_division=0)

        val_loss, val_acc, val_precision, val_recall, val_f1 = evaluate_bert_model(
            model, val_loader, loss_function, device, set_name="Validation"
        )

        print(f"Epoch {epoch+1}")
        print(f"\tTrain Loss: {avg_train_loss:.4f} | Acc: {train_acc*100:.2f}% | P: {train_precision:.4f} | R: {train_recall:.4f} | F1: {train_f1:.4f}")
        print(f"\tVal   Loss: {val_loss:.4f} | Acc: {val_acc*100:.2f}% | P: {val_precision:.4f} | R: {val_recall:.4f} | F1: {val_f1:.4f}")

        # Save checkpoint every 2 epochs
        if checkpoint_dir and (epoch + 1) % 4 == 0:
            path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pt")
            save_checkpoint(model, optimizer, epoch, path)
            print(f"Checkpoint saved to {path}")

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs.")
                break


In [63]:
def evaluate_bert_model(model, data_loader, loss_function, device, set_name="Test"):
    model.eval()
    total_loss = 0
    all_preds = []
    all_targets = []
    loop = tqdm(data_loader, desc=f"{set_name} Evaluation")

    with torch.no_grad():
        for batch in loop:
            ids = batch['ids'].to(device)
            mask = batch['mask'].to(device)
            targets = batch['targets'].to(device)

            outputs = model(ids, mask)
            loss = loss_function(outputs, targets)
            total_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            all_preds.extend(predicted.cpu().tolist())
            all_targets.extend(targets.cpu().tolist())

            loop.set_postfix(loss=loss.item())

    avg_loss = total_loss / len(data_loader)
    accuracy = accuracy_score(all_targets, all_preds)
    precision = precision_score(all_targets, all_preds, average='macro', zero_division=0)
    recall = recall_score(all_targets, all_preds, average='macro', zero_division=0)
    f1 = f1_score(all_targets, all_preds, average='macro', zero_division=0)

    print(f"{set_name} | Loss: {avg_loss:.4f} | Acc: {accuracy*100:.2f}% | P: {precision:.4f} | R: {recall:.4f} | F1: {f1:.4f}")

    return avg_loss, accuracy, precision, recall, f1

In [64]:
def save_checkpoint(model, optimizer, epoch, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, path)

def load_checkpoint(model, optimizer, path, device):
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return model, optimizer, checkpoint['epoch']

def find_latest_checkpoint(checkpoint_dir):
    checkpoints = glob.glob(os.path.join(checkpoint_dir, 'checkpoint_epoch_*.pt'))
    if not checkpoints:
        return None
    return max(checkpoints, key=os.path.getctime)

In [65]:
train_bert_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    epochs=EPOCHS,
    optimizer=optimizer,
    loss_function=loss_function,
    patience=3,
    checkpoint_dir="./checkpoints",
    resume=True
)

Epoch 1 - Training:   1%|          | 2/368 [00:04<14:38,  2.40s/it, loss=1.82]


KeyboardInterrupt: 

In [None]:
test_loss, test_accuracy = evaluate_bert_model(
    model=model,
    data_loader=test_loader,
    loss_function=loss_function,
    device=device,
)

print(f"\nFinal Test Accuracy: {test_accuracy * 100:.2f}% | Test Loss: {test_loss:.4f}")

Test Evaluation:   3%|▎         | 7/210 [00:03<01:54,  1.77it/s, loss=1.77]


KeyboardInterrupt: 