<a href="https://colab.research.google.com/github/rupav02/AI_Project/blob/main/POS_Tagger.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

```Notes on Data:```
- Format is "word/TAG"
- Single newline indicates new Sentence
- Double newline indicates new Paragraph

```Notes on Model:```
- Must extend or truncate sentences to be of the same size.
- Special [CLS] token used at start of all sentences for classification.
- Special [SEP] token added to end of all sentences.
- Special [PAD] token used to pad sentences.
- Must use model's own Tokenizer


In [1]:
import os
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
import torch.optim as optim
from transformers import BertModel, BertConfig, BertForTokenClassification
from torch.utils.data import DataLoader, Dataset
import ssl
import nltk
from nltk.tree import Tree
from nltk.corpus import treebank
import urllib.request
import requests
from bs4 import BeautifulSoup

#### DO NOT REDISTRIBUTE THE WSJ DATA.
It is licensed by the Linguistic Data Consortium (LDC) and is not free to distribute.

In [None]:
!unzip wsj.zip

Archive:  wsj.zip
   creating: wsj/
   creating: wsj/11/
  inflating: wsj/11/wsj_1138.mrg     
  inflating: wsj/11/wsj_1130.mrg     
  inflating: wsj/11/wsj_1197.mrg     
  inflating: wsj/11/wsj_1178.mrg     
  inflating: wsj/11/wsj_1161.mrg     
  inflating: wsj/11/wsj_1167.mrg     
  inflating: wsj/11/wsj_1175.mrg     
  inflating: wsj/11/wsj_1101.mrg     
  inflating: wsj/11/wsj_1110.mrg     
  inflating: wsj/11/wsj_1119.mrg     
  inflating: wsj/11/wsj_1176.mrg     
  inflating: wsj/11/wsj_1168.mrg     
  inflating: wsj/11/wsj_1174.mrg     
  inflating: wsj/11/wsj_1142.mrg     
  inflating: wsj/11/wsj_1190.mrg     
  inflating: wsj/11/wsj_1170.mrg     
  inflating: wsj/11/wsj_1124.mrg     
  inflating: wsj/11/wsj_1129.mrg     
  inflating: wsj/11/wsj_1154.mrg     
  inflating: wsj/11/wsj_1173.mrg     
  inflating: wsj/11/wsj_1148.mrg     
  inflating: wsj/11/wsj_1135.mrg     
  inflating: wsj/11/wsj_1162.mrg     
  inflating: wsj/11/wsj_1132.mrg     
  inflating: wsj/11/wsj_1181.mr

In [None]:
def parse_trees(file_path):
    with open(file_path, "r") as file:
        data = file.read()

    trees = []
    current_tree = []
    open_parens = 0

    for line in data.splitlines():
        line = line.strip()
        if not line:
            continue

        current_tree.append(line)
        open_parens += line.count("(") - line.count(")")

        # If open_parens reaches zero, we've completed a tree
        if open_parens == 0:
            tree_string = " ".join(current_tree)
            trees.append(tree_string)
            current_tree = []

    return trees

In [None]:
# wsj/00-18 is the train set
# wsj/19-21 is the dev set
# wsj/22-24 is the test set

# separate the trees into train, dev, and test sets

def collect_trees(range_limits, target_list):
    for i in range(*range_limits):
        directory = f"wsj/{i:02}"
        files = os.listdir(directory)
        for file in files:
            target_list.extend(parse_trees(f"{directory}/{file}"))

# Initialize tree lists
train_trees, validation_trees, test_trees = [], [], []

# Collect trees for train, validation, and test sets
collect_trees((0, 19), train_trees)
collect_trees((19, 22), validation_trees)
collect_trees((22, 25), test_trees)

# Print the number of trees in each set
print(f"Number of training trees: {len(train_trees)}")
print(f"Number of validation trees: {len(validation_trees)}")
print(f"Number of test trees: {len(test_trees)}")

Number of training trees: 38219
Number of validation trees: 5527
Number of test trees: 5462


In [None]:
print(Tree.fromstring(train_trees[0]))

(
  (S
    (NP-SBJ
      (NP (JJ Japanese))
      (NN investment)
      (PP-LOC (IN in) (NP (NNP Southeast) (NNP Asia))))
    (VP
      (VBZ is)
      (VP
        (VBG propelling)
        (NP (DT the) (NN region))
        (PP-DIR-CLR (IN toward) (NP (JJ economic) (NN integration)))))
    (. .)))


In [None]:
train = [Tree.fromstring(tree).pos() for tree in train_trees]
val = [Tree.fromstring(tree).pos() for tree in validation_trees]
test = [Tree.fromstring(tree).pos() for tree in test_trees]

In [None]:
# Example for the first tree in the training set
print(train[0])

[('Japanese', 'JJ'), ('investment', 'NN'), ('in', 'IN'), ('Southeast', 'NNP'), ('Asia', 'NNP'), ('is', 'VBZ'), ('propelling', 'VBG'), ('the', 'DT'), ('region', 'NN'), ('toward', 'IN'), ('economic', 'JJ'), ('integration', 'NN'), ('.', '.')]


In [None]:
# Processing Tags
tags = set()
for tree in train:
    for _, tag in tree:
        tags.add(tag)
tags = ",".join(sorted(tags))
tags = tags.split(",")
tags = ["<pad>"] + tags
# tag2idx = {tag: idx for idx, tag in enumerate(tags)}
# idx2tag = {idx: tag for idx, tag in enumerate(tags)}
# Update tag2idx and idx2tag
tag2idx = {tag: idx for idx, tag in enumerate(tags)}
tag2idx['<pad>'] = tag2idx.pop('<pad>')
tag2idx['X'] = len(tag2idx)

idx2tag = {idx: tag for idx, tag in enumerate(tags)}
print(tags)

['<pad>', '#', '$', "''", '', '', '-LRB-', '-NONE-', '-RRB-', '.', ':', 'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NN', 'NNP', 'NNPS', 'NNS', 'PDT', 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP$', 'WRB', '``']


In [None]:
MAX_LEN = 128

In [None]:
def trees_to_tokens_labels(trees):
    tokens_list = []
    labels_list = []

    for tree in trees:
        tokens = []
        labels = []
        for token, label in tree:
            tokens.append(token)
            labels.append(label)
        tokens_list.append(tokens)
        labels_list.append(labels)

    return tokens_list, labels_list

train_tokens, train_labels = trees_to_tokens_labels(train)
val_tokens, val_labels = trees_to_tokens_labels(val)
test_tokens, test_labels = trees_to_tokens_labels(test)

print(f"Number of training sentences: {len(train_tokens)}")
print(f"Number of validation sentences: {len(val_tokens)}")
print(f"Number of test sentences: {len(test_tokens)}")

Number of training sentences: 38219
Number of validation sentences: 5527
Number of test sentences: 5462


In [None]:
def tokenize_and_align_labels(tokens_list, labels_list, tokenizer, max_len):
    input_ids = []
    attention_masks = []
    label_ids = []

    for tokens, labels in zip(tokens_list, labels_list):
        tokens_enc = []
        label_enc = []

        tokens_enc.append("[CLS]")
        label_enc.append(tag2idx['<pad>'])

        for token, label in zip(tokens, labels):
            sub_tokens = tokenizer.tokenize(token)
            if not sub_tokens:
                sub_tokens = ['[UNK]']
            tokens_enc.extend(sub_tokens)
            label_enc.append(tag2idx.get(label, tag2idx['<pad>']))
            for _ in sub_tokens[1:]:
                label_enc.append(tag2idx['X'])

        tokens_enc.append("[SEP]")
        label_enc.append(tag2idx['<pad>'])

        input_id = tokenizer.convert_tokens_to_ids(tokens_enc)
        attention_mask = [1] * len(input_id)

        # Trunc
        if len(input_id) > max_len:
            input_id = input_id[:max_len-1] + [tokenizer.convert_tokens_to_ids("[SEP]")]
            attention_mask = attention_mask[:max_len-1] + [1]
            label_enc = label_enc[:max_len-1] + [tag2idx['<pad>']]
        # ideally, we can pad the input
        else:
            padding_length = max_len - len(input_id)
            input_id += [tokenizer.pad_token_id] * padding_length
            attention_mask += [0] * padding_length
            label_enc += [tag2idx['<pad>']] * padding_length

        input_ids.append(input_id)
        attention_masks.append(attention_mask)
        label_ids.append(label_enc)

    return torch.tensor(input_ids), torch.tensor(attention_masks), torch.tensor(label_ids)


In [None]:
from transformers import BertTokenizer

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

train_input_ids, train_attention_masks, train_label_ids = tokenize_and_align_labels(
    train_tokens, train_labels, tokenizer, MAX_LEN
)

val_input_ids, val_attention_masks, val_label_ids = tokenize_and_align_labels(
    val_tokens, val_labels, tokenizer, MAX_LEN
)

test_input_ids, test_attention_masks, test_label_ids = tokenize_and_align_labels(
    test_tokens, test_labels, tokenizer, MAX_LEN
)

print(f"Training input IDs shape: {train_input_ids.shape}")
print(f"Validation input IDs shape: {val_input_ids.shape}")
print(f"Test input IDs shape: {test_input_ids.shape}")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Training input IDs shape: torch.Size([38219, 128])
Validation input IDs shape: torch.Size([5527, 128])
Test input IDs shape: torch.Size([5462, 128])


In [None]:
class POSDataset(Dataset):
    def __init__(self, input_ids, attention_masks, labels):
        self.input_ids = input_ids
        self.attention_masks = attention_masks
        self.labels = labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_masks[idx],
            'labels': self.labels[idx]
        }

# Dataset
train_dataset = POSDataset(train_input_ids, train_attention_masks, train_label_ids)
val_dataset = POSDataset(val_input_ids, val_attention_masks, val_label_ids)
test_dataset = POSDataset(test_input_ids, test_attention_masks, test_label_ids)

# DataLoaders
BATCH_SIZE = 16

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
num_labels = len(tag2idx)

# prtrained BERT model
model = BertForTokenClassification.from_pretrained(
    "bert-base-cased",
    num_labels=num_labels,
    output_attentions=False,
    output_hidden_states=False
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

print(f"Using device: {device}")


model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Using device: cuda


In [None]:
from transformers import get_linear_schedule_with_warmup

optimizer = optim.AdamW(model.parameters(), lr=3e-5)

epochs = 3
total_steps = len(train_dataloader) * epochs

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

loss_fn = nn.CrossEntropyLoss(ignore_index=tag2idx['<pad>'])

def train_epoch(model, dataloader, optimizer, scheduler, device, loss_fn):
    model.train()
    total_loss = 0

    for batch in tqdm(dataloader, desc="Training"):
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        loss = outputs.loss
        total_loss += loss.item()

        loss.backward()
        optimizer.step()
        scheduler.step()

    avg_loss = total_loss / len(dataloader)
    return avg_loss

def eval_model(model, dataloader, device, loss_fn):
    model.eval()
    total_loss = 0
    predictions = []
    true_labels = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = outputs.loss
            total_loss += loss.item()

            logits = outputs.logits
            preds = torch.argmax(logits, dim=2)

            # Move tensors to CPU
            preds = preds.detach().cpu().numpy()
            labels = labels.detach().cpu().numpy()

            predictions.extend(preds)
            true_labels.extend(labels)

    avg_loss = total_loss / len(dataloader)

    return avg_loss, predictions, true_labels


In [None]:
!pip install seqeval

Collecting seqeval
  Downloading seqeval-1.2.2.tar.gz (43 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: seqeval
  Building wheel for seqeval (setup.py) ... [?25l[?25hdone
  Created wheel for seqeval: filename=seqeval-1.2.2-py3-none-any.whl size=16161 sha256=0ebb92a3613a8cb20439e347cdb490669ed04ee0cd3618bd2e49580b7df2bc4e
  Stored in directory: /root/.cache/pip/wheels/1a/67/4a/ad4082dd7dfc30f2abfe4d80a2ed5926a506eb8a972b4767fa
Successfully built seqeval
Installing collected packages: seqeval
Successfully installed seqeval-1.2.2


In [None]:
from seqeval.metrics import classification_report, f1_score


In [None]:
for epoch in range(epochs):
    print(f"\n======== Epoch {epoch + 1} / {epochs} ========")
    print("Training...")

    train_loss = train_epoch(model, train_dataloader, optimizer, scheduler, device, loss_fn)
    print(f"Training loss: {train_loss}")

    print("Evaluating...")
    val_loss, val_preds, val_true = eval_model(model, val_dataloader, device, loss_fn)
    print(f"Validation loss: {val_loss}")

    val_preds_labels = []
    val_true_labels = []

    for pred_seq, true_seq in zip(val_preds, val_true):
        pred_labels = []
        true_label = []
        for p, t in zip(pred_seq, true_seq):
            if t != tag2idx['<pad>']:
                pred_labels.append(idx2tag.get(p, "X"))
                true_label.append(idx2tag.get(t, "X"))
        val_preds_labels.append(pred_labels)
        val_true_labels.append(true_label)

    f1 = f1_score(val_true_labels, val_preds_labels)
    print(f"Validation F1 Score: {f1:.4f}")

    print("Classification Report:")
    print(classification_report(val_true_labels, val_preds_labels))



Training...


Training: 100%|██████████| 2389/2389 [03:14<00:00, 12.27it/s]


Training loss: 0.048294703268932156
Evaluating...


Evaluating: 100%|██████████| 346/346 [00:09<00:00, 37.30it/s]


Validation loss: 0.01387806895062478




Validation F1 Score: 0.9728
Classification Report:


  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           '       1.00      1.00      1.00       980
           B       0.94      0.94      0.94      7019
          BD       0.98      0.98      0.98      4618
          BG       0.95      0.94      0.94      1951
          BN       0.91      0.93      0.92      2746
          BP       0.96      0.97      0.97      1485
          BR       0.69      0.80      0.74       229
          BS       0.89      0.80      0.85        51
          BZ       0.99      0.99      0.99      2814
           C       1.00      0.99      1.00      3380
           D       0.99      0.99      0.99      5889
          DT       0.95      0.97      0.96       609
           H       1.00      0.08      0.14        13
           J       0.91      0.93      0.92      7575
          JR       0.89      0.89      0.89       444
          JS       0.94      0.97      0.95       262
           N       0.96      0.95      0.96     24004
          NP       0.96    

Training: 100%|██████████| 2389/2389 [03:13<00:00, 12.35it/s]


Training loss: 0.011554799458501621
Evaluating...


Evaluating: 100%|██████████| 346/346 [00:09<00:00, 37.42it/s]


Validation loss: 0.013202031628858578
Validation F1 Score: 0.9733
Classification Report:
              precision    recall  f1-score   support

           '       1.00      1.00      1.00       980
           B       0.93      0.94      0.94      7019
          BD       0.98      0.98      0.98      4618
          BG       0.94      0.95      0.94      1951
          BN       0.92      0.92      0.92      2746
          BP       0.97      0.98      0.97      1485
          BR       0.70      0.81      0.75       229
          BS       0.93      0.80      0.86        51
          BZ       0.99      0.99      0.99      2814
           C       1.00      0.99      1.00      3380
           D       0.99      1.00      1.00      5889
          DT       0.95      0.96      0.96       609
           H       0.86      0.46      0.60        13
           J       0.90      0.94      0.92      7575
          JR       0.90      0.91      0.90       444
          JS       0.96      0.98      0.97   

Training: 100%|██████████| 2389/2389 [03:13<00:00, 12.36it/s]


Training loss: 0.008789662173547649
Evaluating...


Evaluating: 100%|██████████| 346/346 [00:09<00:00, 37.54it/s]


Validation loss: 0.012999179710451019
Validation F1 Score: 0.9745
Classification Report:
              precision    recall  f1-score   support

           '       1.00      1.00      1.00       980
           B       0.94      0.94      0.94      7019
          BD       0.98      0.98      0.98      4618
          BG       0.94      0.95      0.94      1951
          BN       0.91      0.93      0.92      2746
          BP       0.97      0.98      0.97      1485
          BR       0.74      0.79      0.76       229
          BS       0.93      0.80      0.86        51
          BZ       0.99      0.99      0.99      2814
           C       1.00      1.00      1.00      3380
           D       0.99      1.00      1.00      5889
          DT       0.95      0.97      0.96       609
           H       0.88      0.54      0.67        13
           J       0.91      0.93      0.92      7575
          JR       0.89      0.93      0.91       444
          JS       0.96      0.98      0.97   

In [None]:
# FINAL EVAL ONLY, TEST SET
from itertools import chain
print("\nEvaluating on Test Set...")
test_loss, test_preds, test_true = eval_model(model, test_dataloader, device, loss_fn)
print(f"Test loss: {test_loss}")

test_preds_labels = []
test_true_labels = []

for pred_seq, true_seq in zip(test_preds, test_true):
    pred_labels = []
    true_label = []
    for p, t in zip(pred_seq, true_seq):
        if t != tag2idx['<pad>']:
            pred_labels.append(idx2tag.get(p, "X"))
            true_label.append(idx2tag.get(t, "X"))
    test_preds_labels.append(pred_labels)
    test_true_labels.append(true_label)

test_f1 = f1_score(test_true_labels, test_preds_labels)
print(f"Test F1 Score: {test_f1:.4f}")

# Print test accuracy

# Flatten the predicted and true label lists
flat_test_preds_labels = list(chain.from_iterable(test_preds_labels))
flat_test_true_labels = list(chain.from_iterable(test_true_labels))

# Compute accuracy
test_acc = (np.array(flat_test_preds_labels) == np.array(flat_test_true_labels)).mean()
print(f"Test Accuracy: {test_acc :.4f}")

print("Test Classification Report:")
print(classification_report(test_true_labels, test_preds_labels))



Evaluating on Test Set...


Evaluating: 100%|██████████| 342/342 [00:09<00:00, 37.27it/s]


Test loss: 0.012540405189648423




Test F1 Score: 0.9738
Test Accuracy: 0.9831
Test Classification Report:
              precision    recall  f1-score   support

           '       1.00      1.00      1.00      1041
           B       0.94      0.94      0.94      7330
          BD       0.98      0.97      0.98      4554
          BG       0.95      0.94      0.94      1930
          BN       0.90      0.93      0.92      2608
          BP       0.97      0.99      0.98      1564
          BR       0.88      0.81      0.84       271
          BS       0.93      0.91      0.92        69
          BZ       0.99      0.99      0.99      2639
           C       1.00      1.00      1.00      3243
           D       0.99      0.99      0.99      5566
          DT       0.94      0.97      0.95       628
           H       0.69      0.73      0.71        15
           J       0.92      0.93      0.93      7721
          JR       0.90      0.96      0.93       422
          JS       0.97      0.97      0.97       267
         



In [None]:
# Save the model
output_dir = './bert-pos-model/'


# Create the directory if it does not exist
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Save model weights
torch.save(model.state_dict(), os.path.join(output_dir, "model_weights.pth"))

# Save model config
with open(os.path.join(output_dir, "config.json"), 'w') as f:
    f.write(model.config.to_json_string())

# Save tokenizer
tokenizer.save_pretrained(output_dir)


('./bert-pos-model/tokenizer_config.json',
 './bert-pos-model/special_tokens_map.json',
 './bert-pos-model/vocab.txt',
 './bert-pos-model/added_tokens.json')

In [None]:
# TO LOAD THE MODEL LATER!

from transformers import BertForTokenClassification, BertTokenizer, BertConfig
import torch

# Load tokenizer and config
tokenizer = BertTokenizer.from_pretrained('./bert-pos-model/')
config = BertConfig.from_pretrained('./bert-pos-model/config.json')

# Initialize the model
model = BertForTokenClassification(config)

# Load weights
model.load_state_dict(torch.load('./bert-pos-model/model_weights.pth'))
model.eval()  # Set to evaluation mode

print("Model loaded successfully!")


#### BERT with CRF

In [7]:
! pip install pytorch-crf



In [8]:
# INITIAL REFERENCE CODE!
# from torchcrf import CRF


# class BertCRFTagger(nn.Module):

#   def __init__(self, bert, hidden_size, num_tags, dropout):
#     super().__init__()
#     self.bert = bert
#     self.crf = CRF(num_tags, batch_first=True)
#     self.fc = nn.Linear(hidden_size, num_tags)
#     self.dropout = nn.Dropout(dropout)

#   def generate_mask(self, input_temaplte):
#     bs = input_temaplte.size(0)
#     seq_len = torch.max(input_temaplte)
#     mask = torch.ByteTensor(bs, seq_len).fill_(0)
#     for i in range(bs):
#       mask[i, :input_temaplte[i]] = 1
#     return mask

#   def forward(self, input_ids, text_lens, tags=None):
#     bert_output = self.bert(input_ids)
#     last_hidden_state = bert_output['hidden_states'][-1]

#     emission = self.fc(last_hidden_state)
#     mask = self.generate_mask(text_lens).to(device)

#     if tags is not None:
#       loss = -self.crf(torch.log_softmax(emission, dim=2), tags, mask=mask, reduction='mean')
#       return loss
#     else:
#       prediction = self.crf.decode(emission, mask=mask)
#       return prediction

In [9]:
from torchcrf import CRF


class BertCRFTagger(nn.Module):

  def __init__(self, bert_model_name, hidden_size, num_tags, dropout_rate = 0.1):
    super(BertCRFTagger, self).__init__()
    self.bert = BertModel.from_pretrained(bert_model_name)
    self.dropout = nn.Dropout(dropout_rate)
    self.fc = nn.Linear(hidden_size, num_tags)
    self.crf = CRF(num_tags, batch_first=True)

  def forward(self, input_ids, attention_mask, labels = None):
    # Get BERT embeddings
    outputs = self.bert(input_ids, attention_mask=attention_mask)
    sequence_output = self.dropout(outputs.last_hidden_state)

    # Linear layer for emissions
    emissions = self.fc(sequence_output)

    # If labels are provided, calculate the loss
    if labels is not None:
      mask = attention_mask.bool() # CRF requires a mask
      loss = -self.crf(emissions, labels, mask=mask, reduction='mean')
      return loss

    # If labels are not provided, decode the tags
    else:
      mask = attention_mask.bool()
      predictions = self.crf.decode(emissions, mask=mask)
      return predictions

In [6]:
# Update model initializations

hidden_size = 768 # Hidden size for BERT-base
num_labels = len(tag2idx)
dropout_rate = 0.1
bert_model_name = 'bert-base-cased'

model = BertCRFTagger(bert_model_name, hidden_size, num_labels, dropout_rate).to(device)

print(f"Using device:{device}")

# Optimizer and scheduler
optimizer = optim.AdamW(model.parameters(), lr=3e-5)
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

NameError: name 'tag2idx' is not defined

In [10]:
# Update training function - the training function needs to compute loss using the CRF layer now

def train_epoch(model, dataloader, optimizer, scheduler, device):
  model.train()
  total_loss = 0

  for batch in tqdm(dataloader, desc = "Training"):
    optimizer.zero_grad()

    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['labels'].to(device)

    # Forward pass
    loss = model(input_ids = input_ids, attention_mask = attention_mask, labels = labels)
    total_loss += loss.item()

    # Backward pass and optimization
    loss.backward()
    optimizer.step()
    scheduler.step()

  avg_loss = total_loss / len(dataloader)
  return avg_loss


In [11]:
# Update the evaluation function
def eval_model(model, dataloader, device):
  model.eval()
  total_loss = 0
  predictions = []
  true_labels = []

  with torch.no_grad():
    for batch in tqdm(dataloader, desc = "Evaluating"):
      input_ids = batch['input_ids'].to(device)
      attention_mask = batch['attention_mask'].to(device)
      labels = batch['labels'].to(device)

      # Forward pass
      loss = model(input_ids = input_ids, attention_mask = attention_mask, labels = labels)
      total_loss += loss.item()

      # Get predictions
      preds = model(input_ids = input_ids, attention_mask = attention_mask)
      predictions.extend(preds)
      true_labels.extend(labels.detach().cpu().tolist())

  avg_loss = total_loss / len(dataloader)
  return avg_loss, predictions, true_labels

In [None]:
# train the model
for epoch in range(epochs):
    print(f"\n======== Epoch {epoch + 1} / {epochs} ========")
    print("Training...")

    train_loss = train_epoch(model, train_dataloader, optimizer, scheduler, device)
    print(f"Training loss: {train_loss}")

    print("Evaluating...")
    val_loss, val_preds, val_true = eval_model(model, val_dataloader, device)
    print(f"Validation loss: {val_loss}")

    # # Convert true labels and predictions to tag names
    # val_preds_labels = [[idx2tag[p] for p in seq] for seq in val_preds]
    # val_true_labels = [[idx2tag[t] for t in seq if t != tag2idx['<pad>']] for seq in val_true]

    # f1 = f1_score(val_true_labels, val_preds_labels, average='weighted')
    # print(f"Validation F1 Score: {f1:.4f}")

    # print("Classification Report:")
    # print(classification_report(val_true_labels, val_preds_labels))

    val_preds_labels = []
    val_true_labels = []

    for pred_seq, true_seq in zip(val_preds, val_true):
        filtered_preds = [idx2tag[p] for p, t in zip(pred_seq, true_seq) if t != tag2idx['<pad>']]
        filtered_labels = [idx2tag[t] for t in true_seq if t != tag2idx['<pad>']]

        if len(filtered_preds) == len(filtered_labels):  # Ensure alignment
            val_preds_labels.append(filtered_preds)
            val_true_labels.append(filtered_labels)

    # Compute F1 score
    f1 = f1_score(val_true_labels, val_preds_labels, average='weighted')
    print(f"Validation F1 Score: {f1:.4f}")

    # Classification report
    print("Classification Report:")
    print(classification_report(val_true_labels, val_preds_labels))


Training...


Training: 100%|██████████| 2389/2389 [08:19<00:00,  4.78it/s]


Training loss: 3.754462001662017
Evaluating...


Evaluating: 100%|██████████| 346/346 [00:38<00:00,  8.88it/s]


Validation loss: 1.7037329026040313




Validation F1 Score: 0.9731
Classification Report:


  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           '       1.00      1.00      1.00       980
           B       0.94      0.94      0.94      7019
          BD       0.97      0.99      0.98      4618
          BG       0.94      0.95      0.94      1951
          BN       0.91      0.92      0.92      2746
          BP       0.97      0.98      0.97      1485
          BR       0.72      0.76      0.74       229
          BS       0.89      0.80      0.85        51
          BZ       0.99      0.99      0.99      2814
           C       0.99      1.00      1.00      3380
           D       0.99      1.00      1.00      5889
          DT       0.96      0.92      0.94       609
           H       0.50      0.08      0.13        13
           J       0.92      0.92      0.92      7575
          JR       0.88      0.94      0.91       444
          JS       0.96      0.97      0.96       262
           N       0.96      0.95      0.96     24004
          NP       0.96    

Training: 100%|██████████| 2389/2389 [08:23<00:00,  4.75it/s]


Training loss: 1.4000765681017289
Evaluating...


Evaluating: 100%|██████████| 346/346 [00:38<00:00,  8.88it/s]


Validation loss: 1.6344475208679377
Validation F1 Score: 0.9744
Classification Report:
              precision    recall  f1-score   support

           '       1.00      1.00      1.00       980
           B       0.93      0.94      0.94      7019
          BD       0.97      0.98      0.98      4618
          BG       0.94      0.94      0.94      1951
          BN       0.90      0.95      0.92      2746
          BP       0.97      0.98      0.97      1485
          BR       0.79      0.72      0.76       229
          BS       0.91      0.80      0.85        51
          BZ       0.99      0.99      0.99      2814
           C       1.00      1.00      1.00      3380
           D       0.99      1.00      1.00      5889
          DT       0.95      0.96      0.95       609
           H       0.78      0.54      0.64        13
           J       0.93      0.92      0.92      7575
          JR       0.87      0.95      0.91       444
          JS       0.96      0.97      0.97     

Training: 100%|██████████| 2389/2389 [08:21<00:00,  4.77it/s]


Training loss: 1.0683629369201797
Evaluating...


Evaluating: 100%|██████████| 346/346 [00:38<00:00,  8.89it/s]


Validation loss: 1.6206061357707646
Validation F1 Score: 0.9749
Classification Report:
              precision    recall  f1-score   support

           '       1.00      1.00      1.00       980
           B       0.94      0.94      0.94      7019
          BD       0.98      0.98      0.98      4618
          BG       0.94      0.95      0.94      1951
          BN       0.92      0.93      0.92      2746
          BP       0.97      0.98      0.97      1485
          BR       0.73      0.77      0.75       229
          BS       0.91      0.80      0.85        51
          BZ       0.99      0.99      0.99      2814
           C       1.00      1.00      1.00      3380
           D       0.99      1.00      1.00      5889
          DT       0.96      0.97      0.96       609
           H       0.88      0.54      0.67        13
           J       0.92      0.93      0.92      7575
          JR       0.88      0.93      0.90       444
          JS       0.96      0.97      0.97     

In [12]:
def evaluate_test_set_with_accuracy(model, dataloader, device, idx2tag, tag2idx):
    model.eval()
    test_preds = []
    test_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # CRF Decode
            preds = model(input_ids=input_ids, attention_mask=attention_mask)
            test_preds.extend(preds)  # CRF decode returns a list of lists
            test_labels.extend(labels.cpu().numpy().tolist())  # Convert to list

    # Align predictions and labels
    aligned_preds = []
    aligned_labels = []
    for pred_seq, true_seq in zip(test_preds, test_labels):
        filtered_preds = [idx2tag[p] for p, t in zip(pred_seq, true_seq) if t != tag2idx['<pad>']]
        filtered_labels = [idx2tag[t] for t in true_seq if t != tag2idx['<pad>']]

        if len(filtered_preds) == len(filtered_labels):  # Ensure length match
            aligned_preds.append(filtered_preds)
            aligned_labels.append(filtered_labels)

    # Calculate F1 score
    f1 = f1_score(aligned_labels, aligned_preds, average='weighted')
    print(f"Test F1 Score: {f1:.4f}")

    # Calculate Accuracy
    correct = sum(p == l for preds, labels in zip(aligned_preds, aligned_labels) for p, l in zip(preds, labels))
    total = sum(len(labels) for labels in aligned_labels)
    accuracy = correct / total
    print(f"Test Accuracy: {accuracy:.4f}")

    print("Test Classification Report:")
    print(classification_report(aligned_labels, aligned_preds))

    return aligned_preds, aligned_labels, accuracy

test_preds_labels, test_true_labels, test_accuracy = evaluate_test_set_with_accuracy(
    model=model,
    dataloader=test_dataloader,
    device=device,
    idx2tag=idx2tag,
    tag2idx=tag2idx
)


In [None]:
# Saving the model

# Directory to save the model
output_dir = './bert-crf-pos-model/'

# Create the directory if it doesn't exist
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Save model weights
torch.save(model.state_dict(), os.path.join(output_dir, "model_weights.pth"))

# Save model configuration
with open(os.path.join(output_dir, "config.json"), 'w') as f:
    f.write(model.bert.config.to_json_string())  # Save BERT config

# Save tokenizer
tokenizer.save_pretrained(output_dir)

print(f"Model and related files saved to {output_dir}")

Model and related files saved to ./bert-crf-pos-model/


In [None]:
# Code to load in the model again

from transformers import BertModel, BertTokenizer
import torch

# Load tokenizer and BERT configuration
tokenizer = BertTokenizer.from_pretrained('./bert-crf-pos-model/')
bert_config = BertModel.from_pretrained('./bert-crf-pos-model/config.json').config

# Reinitialize the model
loaded_model = BertCRFTagger(
    bert_model_name="bert-base-cased",  # Ensure it matches your pretrained BERT
    hidden_size=768,
    num_labels=len(tag2idx),
    dropout_rate=0.1
)

# Load weights
loaded_model.load_state_dict(torch.load('./bert-crf-pos-model/model_weights.pth'))
loaded_model.to(device)

print("Model reloaded successfully!")

In [None]:
# Example input
input_ids = torch.tensor([[101, 2023, 2003, 1037, 2742, 102]]).to(device)  # Replace with your actual test input
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]]).to(device)

# Decode predictions
loaded_model.eval()
with torch.no_grad():
    preds = loaded_model(input_ids=input_ids, attention_mask=attention_mask)

print("Predicted labels:", preds)

#### BERT + CRF on augmented data

In [13]:
data_augmentation = False

In [14]:
!unzip wsj.zip

Archive:  wsj.zip
   creating: wsj/
   creating: wsj/11/
  inflating: wsj/11/wsj_1138.mrg     
  inflating: wsj/11/wsj_1130.mrg     
  inflating: wsj/11/wsj_1197.mrg     
  inflating: wsj/11/wsj_1178.mrg     
  inflating: wsj/11/wsj_1161.mrg     
  inflating: wsj/11/wsj_1167.mrg     
  inflating: wsj/11/wsj_1175.mrg     
  inflating: wsj/11/wsj_1101.mrg     
  inflating: wsj/11/wsj_1110.mrg     
  inflating: wsj/11/wsj_1119.mrg     
  inflating: wsj/11/wsj_1176.mrg     
  inflating: wsj/11/wsj_1168.mrg     
  inflating: wsj/11/wsj_1174.mrg     
  inflating: wsj/11/wsj_1142.mrg     
  inflating: wsj/11/wsj_1190.mrg     
  inflating: wsj/11/wsj_1170.mrg     
  inflating: wsj/11/wsj_1124.mrg     
  inflating: wsj/11/wsj_1129.mrg     
  inflating: wsj/11/wsj_1154.mrg     
  inflating: wsj/11/wsj_1173.mrg     
  inflating: wsj/11/wsj_1148.mrg     
  inflating: wsj/11/wsj_1135.mrg     
  inflating: wsj/11/wsj_1162.mrg     
  inflating: wsj/11/wsj_1132.mrg     
  inflating: wsj/11/wsj_1181.mr

In [15]:
def parse_trees(file_path):
    with open(file_path, "r") as file:
        data = file.read()

    trees = []
    current_tree = []
    open_parens = 0

    for line in data.splitlines():
        line = line.strip()
        if not line:
            continue

        current_tree.append(line)
        open_parens += line.count("(") - line.count(")")

        # If open_parens reaches zero, we've completed a tree
        if open_parens == 0:
            tree_string = " ".join(current_tree)
            trees.append(tree_string)
            current_tree = []

    return trees

In [16]:
# wsj/00-18 is the train set
# wsj/19-21 is the dev set
# wsj/22-24 is the test set

# separate the trees into train, dev, and test sets

def collect_trees(range_limits, target_list):
    for i in range(*range_limits):
        directory = f"wsj/{i:02}"
        files = os.listdir(directory)
        for file in files:
            target_list.extend(parse_trees(f"{directory}/{file}"))

# Initialize tree lists
train_trees, validation_trees, test_trees = [], [], []

# Collect trees for train, validation, and test sets
collect_trees((0, 19), train_trees)
collect_trees((19, 22), validation_trees)
collect_trees((22, 25), test_trees)

# Print the number of trees in each set
print(f"Number of training trees: {len(train_trees)}")
print(f"Number of validation trees: {len(validation_trees)}")
print(f"Number of test trees: {len(test_trees)}")

Number of training trees: 38219
Number of validation trees: 5527
Number of test trees: 5462


In [17]:
print(Tree.fromstring(train_trees[0]))

(
  (S
    (NP-SBJ
      (NP (JJ Japanese))
      (NN investment)
      (PP-LOC (IN in) (NP (NNP Southeast) (NNP Asia))))
    (VP
      (VBZ is)
      (VP
        (VBG propelling)
        (NP (DT the) (NN region))
        (PP-DIR-CLR (IN toward) (NP (JJ economic) (NN integration)))))
    (. .)))


In [18]:
train = [Tree.fromstring(tree).pos() for tree in train_trees]
val = [Tree.fromstring(tree).pos() for tree in validation_trees]
test = [Tree.fromstring(tree).pos() for tree in test_trees]

In [19]:
# Example for the first tree in the training set
print(train[0])

[('Japanese', 'JJ'), ('investment', 'NN'), ('in', 'IN'), ('Southeast', 'NNP'), ('Asia', 'NNP'), ('is', 'VBZ'), ('propelling', 'VBG'), ('the', 'DT'), ('region', 'NN'), ('toward', 'IN'), ('economic', 'JJ'), ('integration', 'NN'), ('.', '.')]


In [20]:
# Processing Tags
tags = set()
for tree in train:
    for _, tag in tree:
        tags.add(tag)
tags = ",".join(sorted(tags))
tags = tags.split(",")
tags = ["<pad>"] + tags
# tag2idx = {tag: idx for idx, tag in enumerate(tags)}
# idx2tag = {idx: tag for idx, tag in enumerate(tags)}
# Update tag2idx and idx2tag
tag2idx = {tag: idx for idx, tag in enumerate(tags)}
tag2idx['<pad>'] = tag2idx.pop('<pad>')
tag2idx['X'] = len(tag2idx)

idx2tag = {idx: tag for idx, tag in enumerate(tags)}
print(tags)

['<pad>', '#', '$', "''", '', '', '-LRB-', '-NONE-', '-RRB-', '.', ':', 'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NN', 'NNP', 'NNPS', 'NNS', 'PDT', 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP$', 'WRB', '``']


In [21]:
MAX_LEN = 128

In [22]:
def trees_to_tokens_labels(trees):
    tokens_list = []
    labels_list = []

    for tree in trees:
        tokens = []
        labels = []
        for token, label in tree:
            tokens.append(token)
            labels.append(label)
        tokens_list.append(tokens)
        labels_list.append(labels)

    return tokens_list, labels_list

train_tokens, train_labels = trees_to_tokens_labels(train)
val_tokens, val_labels = trees_to_tokens_labels(val)
test_tokens, test_labels = trees_to_tokens_labels(test)

print(f"Number of training sentences: {len(train_tokens)}")
print(f"Number of validation sentences: {len(val_tokens)}")
print(f"Number of test sentences: {len(test_tokens)}")

Number of training sentences: 38219
Number of validation sentences: 5527
Number of test sentences: 5462


In [23]:
import random
import string
from nltk.corpus import wordnet

# Probability settings
SYNONYM_PROB = 0.15      # Now also used to determine how many sentences to augment
NOISE_PROB = 0.05
POS_CONFUSION_PROB = 0.1

def penn_to_wn_pos(tag):
    tag = tag.upper()
    return (wordnet.NOUN if tag.startswith('NN') else
            wordnet.VERB if tag.startswith('VB') else
            wordnet.ADJ  if tag.startswith('JJ') else
            wordnet.ADV  if tag.startswith('RB') else None)

def wn_pos_to_penn(wn_pos):
    """Map a WordNet POS back to a representative Penn Treebank tag."""
    if wn_pos == wordnet.NOUN:
        return "NN"
    elif wn_pos == wordnet.VERB:
        return "VB"
    elif wn_pos == wordnet.ADJ:
        return "JJ"
    elif wn_pos == wordnet.ADV:
        return "RB"
    else:
        return None

def get_pos_synonyms(word, wn_pos):
    """Get a list of synonyms for 'word' that share the given WordNet POS."""
    synsets = wordnet.synsets(word, pos=wn_pos)
    synonyms = set()
    for syn in synsets:
        for lemma in syn.lemmas():
            lemma_name = lemma.name().replace('_', ' ')
            # Exclude the original word itself
            if lemma_name.lower() != word.lower():
                synonyms.add(lemma_name)
    return list(synonyms)

def introduce_typo(word):
    """Introduce a small typo by swapping two adjacent characters."""
    if len(word) < 2:
        return word
    idx = random.randint(0, len(word)-2)
    word_as_list = list(word)
    word_as_list[idx], word_as_list[idx+1] = word_as_list[idx+1], word_as_list[idx]
    return "".join(word_as_list)

def find_alternative_pos(word, original_wn_pos):
    """Find alternative POS categories the given word can appear in, based on all its synsets."""
    all_synsets = wordnet.synsets(word)
    # Collect all possible POS forms for this word
    possible_pos = {s.pos() for s in all_synsets}
    # Remove the original POS from the set
    alternative_pos = possible_pos - {original_wn_pos}
    return alternative_pos

def augment_tokens_for_pos(tokens, labels,
                           synonym_prob=SYNONYM_PROB,
                           noise_prob=NOISE_PROB,
                           pos_confusion_prob=POS_CONFUSION_PROB):
    augmented_tokens = []
    augmented_labels = []

    for (word, label) in zip(tokens, labels):
        wn_pos = penn_to_wn_pos(label)
        original_word = word
        new_label = label  # Default to original label unless we change it

        # Check if we apply synonym replacement to this token
        if wn_pos and random.random() < synonym_prob:
            # Get synonyms for the same POS
            same_pos_synonyms = get_pos_synonyms(word, wn_pos)

            # Check if we should attempt POS confusion
            if random.random() < pos_confusion_prob:
                # Find alternative POS forms for the same lemma
                alternative_pos_list = find_alternative_pos(original_word, wn_pos)

                pos_confused = False
                for alt_pos in alternative_pos_list:
                    alt_synonyms = get_pos_synonyms(original_word, alt_pos)
                    if alt_synonyms:
                        word = random.choice(alt_synonyms)
                        # Update label to reflect the new POS
                        alt_label = wn_pos_to_penn(alt_pos)
                        if alt_label:
                            new_label = alt_label
                        pos_confused = True
                        break

                # If no alternative POS synonyms found, fallback to same POS synonyms
                if not pos_confused:
                    if same_pos_synonyms:
                        word = random.choice(same_pos_synonyms)
                        # Same POS, so keep original label
                        new_label = label
                    else:
                        # No synonyms found, keep original word and label
                        word = original_word
                        new_label = label
            else:
                # No POS confusion, just pick from same POS synonyms if available
                if same_pos_synonyms:
                    word = random.choice(same_pos_synonyms)
                    # Same POS, keep original label
                    new_label = label
                else:
                    # No synonyms found, keep original word and label
                    word = original_word
                    new_label = label

        # Introduce noise (typo) if applicable
        if random.random() < noise_prob:
            word = introduce_typo(word)
            # Noise doesn't change the POS, so no label change

        augmented_tokens.append(word)
        augmented_labels.append(new_label)

    return augmented_tokens, augmented_labels


# Example usage to increase dataset size by ~15% (based on SYNONYM_PROB)
if data_augmentation:
    num_original = len(train_tokens)
    # Determine how many sentences to augment (about synonym_prob * 100%)
    num_to_augment = int(num_original * SYNONYM_PROB)

    # Randomly sample which sentences to augment
    indices_to_augment = random.sample(range(num_original), num_to_augment)

    augmented_tokens_list = []
    augmented_labels_list = []

    for i in indices_to_augment:
        tokens, labels = train_tokens[i], train_labels[i]
        aug_tokens, aug_labels = augment_tokens_for_pos(tokens, labels)
        augmented_tokens_list.append(aug_tokens)
        augmented_labels_list.append(aug_labels)

    # Append augmented data to the original dataset
    train_tokens += augmented_tokens_list
    train_labels += augmented_labels_list

    print(f"Number of original training sentences: {num_original}")
    print(f"Number of augmented training sentences: {len(augmented_tokens_list)}")
    print(f"Total number of training sentences after augmentation: {len(train_tokens)}")

In [24]:
def tokenize_and_align_labels(tokens_list, labels_list, tokenizer, max_len):
    input_ids = []
    attention_masks = []
    label_ids = []

    for tokens, labels in zip(tokens_list, labels_list):
        tokens_enc = []
        label_enc = []

        tokens_enc.append("[CLS]")
        label_enc.append(tag2idx['<pad>'])

        for token, label in zip(tokens, labels):
            sub_tokens = tokenizer.tokenize(token)
            if not sub_tokens:
                sub_tokens = ['[UNK]']
            tokens_enc.extend(sub_tokens)
            label_enc.append(tag2idx.get(label, tag2idx['<pad>']))
            for _ in sub_tokens[1:]:
                label_enc.append(tag2idx['X'])

        tokens_enc.append("[SEP]")
        label_enc.append(tag2idx['<pad>'])

        input_id = tokenizer.convert_tokens_to_ids(tokens_enc)
        attention_mask = [1] * len(input_id)

        # Trunc
        if len(input_id) > max_len:
            input_id = input_id[:max_len-1] + [tokenizer.convert_tokens_to_ids("[SEP]")]
            attention_mask = attention_mask[:max_len-1] + [1]
            label_enc = label_enc[:max_len-1] + [tag2idx['<pad>']]
        # ideally, we can pad the input
        else:
            padding_length = max_len - len(input_id)
            input_id += [tokenizer.pad_token_id] * padding_length
            attention_mask += [0] * padding_length
            label_enc += [tag2idx['<pad>']] * padding_length

        input_ids.append(input_id)
        attention_masks.append(attention_mask)
        label_ids.append(label_enc)

    return torch.tensor(input_ids), torch.tensor(attention_masks), torch.tensor(label_ids)


In [25]:
from transformers import BertTokenizer

In [26]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

train_input_ids, train_attention_masks, train_label_ids = tokenize_and_align_labels(
    train_tokens, train_labels, tokenizer, MAX_LEN
)

val_input_ids, val_attention_masks, val_label_ids = tokenize_and_align_labels(
    val_tokens, val_labels, tokenizer, MAX_LEN
)

test_input_ids, test_attention_masks, test_label_ids = tokenize_and_align_labels(
    test_tokens, test_labels, tokenizer, MAX_LEN
)

print(f"Training input IDs shape: {train_input_ids.shape}")
print(f"Validation input IDs shape: {val_input_ids.shape}")
print(f"Test input IDs shape: {test_input_ids.shape}")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Training input IDs shape: torch.Size([38219, 128])
Validation input IDs shape: torch.Size([5527, 128])
Test input IDs shape: torch.Size([5462, 128])


In [27]:
class POSDataset(Dataset):
    def __init__(self, input_ids, attention_masks, labels):
        self.input_ids = input_ids
        self.attention_masks = attention_masks
        self.labels = labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_masks[idx],
            'labels': self.labels[idx]
        }

# Dataset
train_dataset = POSDataset(train_input_ids, train_attention_masks, train_label_ids)
val_dataset = POSDataset(val_input_ids, val_attention_masks, val_label_ids)
test_dataset = POSDataset(test_input_ids, test_attention_masks, test_label_ids)

# DataLoaders
BATCH_SIZE = 16

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [31]:
# Update model initializations
import torch
import torch.nn as nn
from transformers import BertModel
from torchcrf import CRF
from transformers import get_linear_schedule_with_warmup
import torch.optim as optim
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 3
hidden_size = 768 # Hidden size for BERT-base
num_labels = len(tag2idx)
dropout_rate = 0.1
bert_model_name = 'bert-base-cased'

model = BertCRFTagger(bert_model_name, hidden_size, num_labels, dropout_rate).to(device)

print(f"Using device:{device}")

# Optimizer and scheduler
optimizer = optim.AdamW(model.parameters(), lr=3e-5)
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

Using device:cuda


In [None]:
from sklearn.metrics import f1_score, classification_report, accuracy_score

# train the model
for epoch in range(epochs):
    print(f"\n======== Epoch {epoch + 1} / {epochs} ========")
    print("Training...")

    train_loss = train_epoch(model, train_dataloader, optimizer, scheduler, device)
    print(f"Training loss: {train_loss}")

    print("Evaluating...")
    val_loss, val_preds, val_true = eval_model(model, val_dataloader, device)
    print(f"Validation loss: {val_loss}")

    # # Convert true labels and predictions to tag names
    # val_preds_labels = [[idx2tag[p] for p in seq] for seq in val_preds]
    # val_true_labels = [[idx2tag[t] for t in seq if t != tag2idx['<pad>']] for seq in val_true]

    # f1 = f1_score(val_true_labels, val_preds_labels, average='weighted')
    # print(f"Validation F1 Score: {f1:.4f}")

    # print("Classification Report:")
    # print(classification_report(val_true_labels, val_preds_labels))

    val_preds_labels = []
    val_true_labels = []

    for pred_seq, true_seq in zip(val_preds, val_true):
        filtered_preds = [idx2tag[p] for p, t in zip(pred_seq, true_seq) if t != tag2idx['<pad>']]
        filtered_labels = [idx2tag[t] for t in true_seq if t != tag2idx['<pad>']]

        if len(filtered_preds) == len(filtered_labels):  # Ensure alignment
            val_preds_labels.append(filtered_preds)
            val_true_labels.append(filtered_labels)


    # Flatten the nested lists into one-dimensional lists
    flattened_preds = [label for seq in val_preds_labels for label in seq]
    flattened_true = [label for seq in val_true_labels for label in seq]

    # Compute F1 score
    f1 = f1_score(flattened_true, flattened_preds, average='weighted')
    print(f"Validation F1 Score: {f1:.4f}")


    # Compute accuracy
    accuracy = accuracy_score(flattened_true, flattened_preds)
    print(f"Validation Accuracy: {accuracy:.4f}")


    # Classification report
    print("Classification Report:")
    print(classification_report(flattened_true, flattened_preds))


Training...


Training:   2%|▏         | 53/2389 [00:10<08:11,  4.75it/s]

In [None]:
test_preds_labels, test_true_labels, test_accuracy = evaluate_test_set_with_accuracy(
    model=model,
    dataloader=test_dataloader,
    device=device,
    idx2tag=idx2tag,
    tag2idx=tag2idx
)

In [None]:
# Saving the model

# Directory to save the model
output_dir = './bert-crf-pos-model/'

# Create the directory if it doesn't exist
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Save model weights
torch.save(model.state_dict(), os.path.join(output_dir, "model_weights.pth"))

# Save model configuration
with open(os.path.join(output_dir, "config.json"), 'w') as f:
    f.write(model.bert.config.to_json_string())  # Save BERT config

# Save tokenizer
tokenizer.save_pretrained(output_dir)

print(f"Model and related files saved to {output_dir}")