# HW4: Deep Learning on NER

### Setup environment

In [1]:
!pip install -q datasets accelerate

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/493.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m122.9/493.7 kB[0m [31m3.4 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m493.7/493.7 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.4/261.4 kB[0m [31m30.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m311.2/311.2 kB[0m [31m33.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
!wget http://nlp.stanford.edu/data/glove.6B.zip

--2023-11-11 02:48:44--  http://nlp.stanford.edu/data/glove.6B.zip
Resolving nlp.stanford.edu (nlp.stanford.edu)... 171.64.67.140
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://nlp.stanford.edu/data/glove.6B.zip [following]
--2023-11-11 02:48:44--  https://nlp.stanford.edu/data/glove.6B.zip
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://downloads.cs.stanford.edu/nlp/data/glove.6B.zip [following]
--2023-11-11 02:48:44--  https://downloads.cs.stanford.edu/nlp/data/glove.6B.zip
Resolving downloads.cs.stanford.edu (downloads.cs.stanford.edu)... 171.64.64.22
Connecting to downloads.cs.stanford.edu (downloads.cs.stanford.edu)|171.64.64.22|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 862182613 (822M) [application/zip]
Saving to: ‘glove.6B.zip’


202

In [3]:
!unzip glove.6B.zip

Archive:  glove.6B.zip
  inflating: glove.6B.50d.txt        
  inflating: glove.6B.100d.txt       
  inflating: glove.6B.200d.txt       
  inflating: glove.6B.300d.txt       


In [4]:
!wget https://raw.githubusercontent.com/sighsmile/conlleval/master/conlleval.py

--2023-11-11 02:51:56--  https://raw.githubusercontent.com/sighsmile/conlleval/master/conlleval.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7502 (7.3K) [text/plain]
Saving to: ‘conlleval.py’


2023-11-11 02:51:56 (31.4 MB/s) - ‘conlleval.py’ saved [7502/7502]



In [5]:
!ls # Sanity check

conlleval.py	   glove.6B.200d.txt  glove.6B.50d.txt	sample_data
glove.6B.100d.txt  glove.6B.300d.txt  glove.6B.zip


## Task 0: Prepare Data

### Load dataset

In [6]:
import datasets

dataset = datasets.load_dataset("conll2003")
dataset

Downloading builder script:   0%|          | 0.00/9.57k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/3.73k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/12.3k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/983k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/14041 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3250 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3453 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3453
    })
})

### Create vocabulary

In [7]:
import itertools
from collections import Counter

word_freq = Counter(itertools.chain(*dataset['train']['tokens'])) # type: ignore

word_freq = {
    word: freq for word, freq in word_freq.items() if freq >= 3
}

word2idx = {
    word: idx for idx, word in enumerate(word_freq.keys(), start=2)
}

word2idx['[PAD]'] = 0
word2idx['[UNK]'] = 1

### Tokenize to ids

In [8]:
dataset = (
    dataset.map(lambda x: {
        'input_ids': [
            word2idx.get(word, word2idx['[UNK]']) for word in x['tokens']
        ]
    })
)

dataset['train']['input_ids'][:3]

Map:   0%|          | 0/14041 [00:00<?, ? examples/s]

Map:   0%|          | 0/3250 [00:00<?, ? examples/s]

Map:   0%|          | 0/3453 [00:00<?, ? examples/s]

[[2, 1, 3, 4, 5, 6, 7, 8, 9], [10, 11], [12, 13]]

In [9]:
dataset = dataset.rename_column('ner_tags', 'labels')
dataset = dataset.remove_columns(['pos_tags', 'chunk_tags'])
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'labels', 'input_ids'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['id', 'tokens', 'labels', 'input_ids'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['id', 'tokens', 'labels', 'input_ids'],
        num_rows: 3453
    })
})

In [10]:
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Task 1: Bidirectional LSTM Model

### Define class

In [11]:
import torch.nn as nn

embedding_dim = 100
num_lstm_layers = 1
lstm_hidden_dim = 256
lstm_dropout = 0.33
linear_output_dim = 128

class BiLSTM(nn.Module):
  def __init__(self, vocab_size, num_classes):
    super(BiLSTM, self).__init__()
    self.embedding = nn.Embedding(vocab_size, embedding_dim)
    self.bi_lstm = nn.LSTM(embedding_dim, lstm_hidden_dim, num_lstm_layers, bidirectional=True, batch_first=True)
    self.linear = nn.Linear(lstm_hidden_dim * 2, linear_output_dim)
    self.elu = nn.ELU()
    self.classifier = nn.Linear(linear_output_dim, num_classes)

  def forward(self, x):
    embedding = self.embedding(x)
    output, _ = self.bi_lstm(embedding)
    output = self.linear(output)
    output = self.elu(output)
    output = self.classifier(output)

    return output


In [26]:
import os

vocab_size, num_classes = len(word2idx), 9
model = BiLSTM(vocab_size, num_classes)
model.to(device)

using_loaded_weights = False

model_path = './task1.pt'
if os.path.exists(model_path):
  using_loaded_weights = True
  model.load_state_dict(torch.load(model_path))
  print(f'Model loaded from {model_path}')

model

Model loaded from ./task1.pt


BiLSTM(
  (embedding): Embedding(8128, 100)
  (bi_lstm): LSTM(100, 256, batch_first=True, bidirectional=True)
  (linear): Linear(in_features=512, out_features=128, bias=True)
  (elu): ELU(alpha=1.0)
  (classifier): Linear(in_features=128, out_features=9, bias=True)
)

### Build train set

In [27]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
  input_ids = [item['input_ids'] for item in batch]
  labels = [item['labels'] for item in batch]

  batch_first = True
  input_ids = pad_sequence([torch.tensor(seq) for seq in input_ids], batch_first, padding_value=0)
  labels = pad_sequence([torch.tensor(seq) for seq in labels], batch_first, padding_value=9)

  return {
      'input_ids': input_ids,
      'labels': labels
  }

In [28]:
from torch.utils.data import DataLoader

batch_size = 32
shuffle = True

train_loader = DataLoader(dataset['train'], batch_size, shuffle, collate_fn=collate_fn)
dev_loader = DataLoader(dataset['validation'], batch_size, collate_fn=collate_fn)
test_loader = DataLoader(dataset['test'], batch_size, collate_fn=collate_fn)

In [29]:
# Helper function to print green text
def print_green(text):
  print(f'\033[92m{text}\033[0m')

### Train model

In [30]:
import torch.optim as optim
from conlleval import evaluate

def train_model(model):
  print('Begin training BiLSTM')

  lr = 1e-3
  loss_fn = nn.CrossEntropyLoss(ignore_index=9)
  optimizer = optim.Adam(model.parameters(), lr=lr)

  tag_to_index = {
      'O': 0,
      'B-PER': 1,
      'I-PER': 2,
      'B-ORG': 3,
      'I-ORG': 4,
      'B-LOC': 5,
      'I-LOC': 6,
      'B-MISC': 7,
      'I-MISC': 8
  }
  index_to_tag = {index: tag for tag, index in tag_to_index.items()}

  num_epochs = 20
  for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss_total = 0
    for batch in train_loader:
      inputs = batch['input_ids'].to(device)
      labels = batch['labels'].to(device)

      optimizer.zero_grad()
      outputs = model(inputs)
      loss = loss_fn(outputs.permute(0,2,1), labels.long())
      loss.backward()
      optimizer.step()

      train_loss_total += loss.item()

    train_loss_ave = train_loss_total / len(train_loader)
    print(f'Epoch {epoch+1}/{num_epochs}, train loss: {train_loss_ave:.4f}')

    # Evaluation phase
    model.eval()
    dev_loss_total = 0
    pred_tags = []
    true_tags = []
    with torch.no_grad():
      for batch in dev_loader:
        inputs = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(inputs)
        loss = loss_fn(outputs.permute(0,2,1), labels.long())
        dev_loss_total += loss.item()

        preds = torch.argmax(outputs, dim=2)
        for i in range(labels.size(0)):
          pred_seq = preds[i].cpu().numpy()
          true_seq = labels[i].cpu().numpy()

          indices_valid = true_seq != 9
          valid_pred_tags = [index_to_tag[idx] for idx in pred_seq[indices_valid]]
          valid_true_tags = [index_to_tag[idx] for idx in true_seq[indices_valid]]

          pred_tags.append(valid_pred_tags)
          true_tags.append(valid_true_tags)

    dev_loss_ave = dev_loss_total / len(dev_loader)
    print(f'Epoch {epoch+1}/{num_epochs}, dev loss: {dev_loss_ave:.4f}')

    # Calculate metrics
    pred_tags_flattened = []
    for valid_pred_tag in pred_tags:
      for tag in valid_pred_tag:
        pred_tags_flattened.append(tag)

    true_tags_flattened = []
    for valid_true_tag in true_tags:
      for tag in valid_true_tag:
        true_tags_flattened.append(tag)

    precision, recall, f1 = evaluate(true_tags_flattened, pred_tags_flattened)
    print(f'Epoch {epoch+1}/{num_epochs}, Precision: {precision}, Recall: {recall}, F1: {f1}')

    early_stopping_epoch, min_f1 = 10, 77
    if epoch >= early_stopping_epoch and f1 >= min_f1:
      print_green('Expected F1 reached! 🚀🚀'
            f'Epoch: {epoch+1}, F1: {f1}')
      break

In [31]:
if not using_loaded_weights:
  print('Training model...')
  train_model(model)
  torch.save(model.state_dict(), 'task1.pt')
else:
  print('Using loaded model wieghts')

Using loaded model wieghts


### Evaluate model

In [32]:
def test_model(model, loader, desc):
  tag_to_index = {
      'O': 0,
      'B-PER': 1,
      'I-PER': 2,
      'B-ORG': 3,
      'I-ORG': 4,
      'B-LOC': 5,
      'I-LOC': 6,
      'B-MISC': 7,
      'I-MISC': 8
  }
  index_to_tag = {index: tag for tag, index in tag_to_index.items()}

  # Testing phase
  model.eval()
  pred_tags = []
  true_tags = []
  with torch.no_grad():
    for batch in loader:
      inputs = batch['input_ids'].to(device)
      labels = batch['labels'].to(device)

      outputs = model(inputs)
      preds = torch.argmax(outputs, dim=2)
      for i in range(labels.size(0)):
        pred_seq = preds[i].cpu().numpy()
        true_seq = labels[i].cpu().numpy()

        indices_valid = true_seq != 9
        valid_pred_tags = [index_to_tag[idx] for idx in pred_seq[indices_valid]]
        valid_true_tags = [index_to_tag[idx] for idx in true_seq[indices_valid]]

        pred_tags.append(valid_pred_tags)
        true_tags.append(valid_true_tags)

  # Calculate metrics
  pred_tags_flattened = []
  for valid_pred_tag in pred_tags:
    for tag in valid_pred_tag:
      pred_tags_flattened.append(tag)

  true_tags_flattened = []
  for valid_true_tag in true_tags:
    for tag in valid_true_tag:
      true_tags_flattened.append(tag)

  precision, recall, f1 = evaluate(true_tags_flattened, pred_tags_flattened)
  print_green(f'{desc} Data:\n'
        f'Precision: {precision}, Recall: {recall}, F1: {f1}')

test_model(model, train_loader, 'Train')
test_model(model, dev_loader, 'Validation')
test_model(model, test_loader, 'Test')

processed 203621 tokens with 23499 phrases; found: 23527 phrases; correct: 23360.
accuracy:  99.63%; (non-O)
accuracy:  99.88%; precision:  99.29%; recall:  99.41%; FB1:  99.35
              LOC: precision:  99.18%; recall:  99.51%; FB1:  99.34  7164
             MISC: precision:  99.27%; recall:  99.39%; FB1:  99.33  3442
              ORG: precision:  99.33%; recall:  98.96%; FB1:  99.14  6297
              PER: precision:  99.38%; recall:  99.74%; FB1:  99.56  6624
[92mTrain Data:
Precision: 99.29017724316742, Recall: 99.4084854674667, F1: 99.3492961340535[0m
processed 51362 tokens with 5942 phrases; found: 5881 phrases; correct: 4563.
accuracy:  79.70%; (non-O)
accuracy:  95.52%; precision:  77.59%; recall:  76.79%; FB1:  77.19
              LOC: precision:  85.66%; recall:  83.23%; FB1:  84.43  1785
             MISC: precision:  76.54%; recall:  73.97%; FB1:  75.23  891
              ORG: precision:  69.61%; recall:  69.35%; FB1:  69.48  1336
              PER: precision:  76.0