In [1]:
!pip install -q transformers datasets

[K     |████████████████████████████████| 2.5MB 7.5MB/s 
[K     |████████████████████████████████| 266kB 50.5MB/s 
[K     |████████████████████████████████| 901kB 54.0MB/s 
[K     |████████████████████████████████| 3.3MB 50.2MB/s 
[K     |████████████████████████████████| 122kB 57.9MB/s 
[K     |████████████████████████████████| 245kB 49.8MB/s 
[?25h

In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from torch.optim.lr_scheduler import StepLR
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm

In [3]:
transformers_model = "bert-base-multilingual-cased"
class NERNet(nn.Module):
    def __init__(self, num_labels):
        super(NERNet, self).__init__()
        self.bert = AutoModel.from_pretrained(transformers_model)
        self.mlp = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size, self.bert.config.hidden_size//2),
            nn.ReLU(),
            nn.Linear(self.bert.config.hidden_size//2, num_labels)
        )
    
    def forward(self, tokens, mask):
        encoding = self.bert(input_ids=tokens, attention_mask=mask)["last_hidden_state"]
        assert encoding.shape[0] == tokens.shape[0] and encoding.shape[1] == tokens.shape[1] and encoding.shape[2] == self.bert.config.hidden_size and len(encoding.shape) == 3
        logits = self.mlp(encoding)
        return logits

In [4]:
from datasets import load_dataset
data = load_dataset('conll2003')
label_map = {"O": 0, "B-PER": 1, "I-PER": 2, "B-ORG": 3, "I-ORG": 4, "B-LOC": 5, "I-LOC": 6, "B-MISC": 7, "I-MISC": 8}
inv_label_map = {v: k for k, v in label_map.items()}

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2603.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1781.0, style=ProgressStyle(description…


Downloading and preparing dataset conll2003/conll2003 (download: 4.63 MiB, generated: 9.78 MiB, post-processed: Unknown size, total: 14.41 MiB) to /root/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/40e7cb6bcc374f7c349c83acd1e9352a4f09474eb691f64f364ee62eb65d0ca6...


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=649539.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=162714.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=145897.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Dataset conll2003 downloaded and prepared to /root/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/40e7cb6bcc374f7c349c83acd1e9352a4f09474eb691f64f364ee62eb65d0ca6. Subsequent calls will reuse this data.


In [5]:
import random
seed = 20
random.seed(seed)
class NERDataset(torch.utils.data.Dataset):
    def __init__(self, split, sample_size=-1):
        self.data = data[split]
        self.tokenizer = AutoTokenizer.from_pretrained(transformers_model)
        self.max_length = 171
        # for datum in self.data:
        #     self.max_length = max(self.max_length, len(self.tokenizer.tokenize(" ".join(datum["tokens"])))+2)
        # print('max length', self.max_length)
        self.pad_token = "[PAD]"
        if sample_size > 0:
          self.data = [self.data[idx] for idx in range(len(self.data))]
          self.data = random.sample(self.data, sample_size)
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        input_ids = ['[CLS]']
        mask = [1]
        labels = [label_map['O']]
        for i, token in enumerate(self.data[idx]["tokens"]):
            wordpieces = self.tokenizer.tokenize(token)
            input_ids += wordpieces
            mask += [1 for _ in wordpieces]
            if inv_label_map[self.data[idx]['ner_tags'][i]][:2] == 'B-':
                labels.append(self.data[idx]['ner_tags'][i])
                for _ in range(len(wordpieces)-1):
                    labels.append(label_map['I-'+inv_label_map[self.data[idx]['ner_tags'][i]][2:]])
            else:
                labels += [self.data[idx]['ner_tags'][i] for _ in wordpieces]
        input_ids.append('[SEP]')
        mask.append(1)
        labels.append(label_map['O'])
        assert len(input_ids) == len(mask) == len(labels)
        for _ in range(self.max_length-len(input_ids)):
            input_ids.append(self.pad_token)
            mask.append(0)
            labels.append(label_map["O"])
        input_ids = self.tokenizer.convert_tokens_to_ids(input_ids)
        return torch.LongTensor(input_ids), torch.BoolTensor(mask), torch.LongTensor(labels)

In [6]:
train_dataset = NERDataset("train", sample_size=500)
dev_dataset = NERDataset("validation")
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)
dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=4, shuffle=True)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=29.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=625.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=995526.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1961828.0, style=ProgressStyle(descript…




In [7]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [8]:
%cd /content/drive/MyDrive/pytorch-constraints
!python setup.py install

/content/drive/MyDrive/pytorch-constraints
running install
running bdist_egg
running egg_info
writing pytorch_constraints.egg-info/PKG-INFO
writing dependency_links to pytorch_constraints.egg-info/dependency_links.txt
writing top-level names to pytorch_constraints.egg-info/top_level.txt
adding license file 'LICENSE'
writing manifest file 'pytorch_constraints.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_py
creating build/bdist.linux-x86_64/egg
creating build/bdist.linux-x86_64/egg/pytorch_constraints
copying build/lib/pytorch_constraints/solver.py -> build/bdist.linux-x86_64/egg/pytorch_constraints
copying build/lib/pytorch_constraints/utils.py -> build/bdist.linux-x86_64/egg/pytorch_constraints
copying build/lib/pytorch_constraints/ast_visitor.py -> build/bdist.linux-x86_64/egg/pytorch_constraints
copying build/lib/pytorch_constraints/brute_force_solver.py -> build/bdist.linux-x86_64/egg/pytorch_constraints
copying buil

In [9]:
import pytorch_constraints

In [10]:
from pytorch_constraints.constraint import constraint
from pytorch_constraints.sampling_solver import *
from pytorch_constraints.brute_force_solver import *
from pytorch_constraints.tnorm_solver import ProductTNormLogicSolver
def bi_constraint(index, tags):
  # if len(tags.shape) == 1:
  #   tags = tags.unsqueeze(0)
  # assert len(tags.shape) == 2
  tags_shifted = tags[:,index:index+1]
  tags_truncated = tags[:,index-1:index]
  overall = torch.ones_like(tags[:,0]).bool().to(tags.device)
  if index > 0:
    overall = overall & ((tags_shifted != label_map["I-"+"PER"]) | (tags_truncated == label_map["B-"+"PER"])).all(dim=-1)
    overall = overall & ((tags_shifted != label_map["I-"+"ORG"]) | (tags_truncated == label_map["B-"+"ORG"])).all(dim=-1)
    overall = overall & ((tags_shifted != label_map["I-"+"LOC"]) | (tags_truncated == label_map["B-"+"LOC"])).all(dim=-1)
    overall = overall & ((tags_shifted != label_map["I-"+"MISC"]) | (tags_truncated == label_map["B-"+"MISC"])).all(dim=-1)
  else:
    overall = overall & (tags[:,0] != label_map["I-"+"PER"]).all(dim=-1)
    overall = overall & (tags[:,0] != label_map["I-"+"ORG"]).all(dim=-1)
    overall = overall & (tags[:,0] != label_map["I-"+"LOC"]).all(dim=-1)
    overall = overall & (tags[:,0] != label_map["I-"+"MISC"]).all(dim=-1)
  """for label_type in ["PER", "ORG", "LOC", "MISC"]:
    if index > 0:
      overall = overall & ((tags_shifted != label_map["I-"+label_type]) | (tags_truncated == label_map["B-"+label_type])).all(dim=-1)
    else:
      assert index == 0
      overall = overall & (tags[:,0] != label_map["I-"+label_type]).all(dim=-1)"""
  # print(overall.any())
  # print([inv_label_map[num] for num in tags[0,:].tolist()])
  return overall
def bi_constraint_full(tags):
  tags_shifted = tags[:,1:]
  tags_truncated = tags[:,:-1]
  # overall = (torch.index_select(tags, 1, torch.tensor([0]).to(tags.device)) == 0 | torch.index_select(tags, 1, torch.tensor([0]).to(tags.device)) != 0)
  # overall = (tags[:,0] == 0 | tags[:,0] % 2 != 0) # & ((tags_shifted == 0 | tags_shifted % 2 != 0) | (tags_shifted-tags_truncated == 1 & tags_shifted % 2 == 0)).all(dim=-1)
  overall = tags[:,0] != 2 and tags[:,0] != 4 and tags[:,0] != 6 and tags[:,0] != 8 and ((tags_shifted == 0 or tags_shifted == 1 or tags_shifted == 3 or tags_shifted == 5 or tags_shifted == 7) or (tags_shifted-tags_truncated == 1 and (tags_shifted == 2 or tags_shifted == 4 or tags_shifted == 6 or tags_shifted == 8))).all(dim=-1)
  return overall
use_constraint = True
constraint_multiplier = 0.05

In [11]:
import random
random.seed(seed)
model = NERNet(len(label_map))
if torch.cuda.is_available():
  model = model.cuda()
  print("cuda")
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
NUM_EPOCHS = 10
max_sequence_length = None
for epoch in range(NUM_EPOCHS):
    count = 0
    model.train()
    print('Epoch:', epoch)
    constraint_violations = 0
    for i, batch in enumerate(tqdm(train_dataloader)):
        optimizer.zero_grad()
        input_ids, mask, labels = batch
        if torch.cuda.is_available():
          input_ids = input_ids.to("cuda")
          mask = mask.to("cuda")
          labels = labels.to("cuda")
        if max_sequence_length is not None and input_ids.shape[1] > max_sequence_length:
          input_ids = input_ids[:,:max_sequence_length].contiguous()
          mask = mask[:,:max_sequence_length].contiguous()
          labels = labels[:,:max_sequence_length].contiguous()
        logits = model(input_ids, mask)
        loss = (nn.CrossEntropyLoss(reduction='none')(logits.view(-1, len(label_map)), labels.view(-1))*mask.view(-1).float()).sum()/mask.float().sum()
        # print(logits.shape)
        if use_constraint:
          for index in range(1):
            # cons = constraint(lambda x: bi_constraint(index, x), WeightedSamplingSolver(num_samples=100))
            cons = constraint(bi_constraint_full, ProductTNormLogicSolver())
            closs = cons(logits)
            # print(closs)
            if not torch.isinf(closs).item():
              loss += constraint_multiplier*closs
          # for i in range(logits.shape[0]):
          #   closs = cons(logits[i,:,:].cpu()).to(logits.device)
          #   loss += constraint_multiplier*closs
        loss.backward()
        optimizer.step()
        predicted_tags = logits.argmax(-1)
        for j in range(input_ids.shape[0]):
          if inv_label_map[predicted_tags[j,0].item()][:2] == "I-":
            constraint_violations += 1
        if i % 10 == 0:
          print("Train violations:", constraint_violations)
    num_correct_entities = 0
    num_pred_entities = 0
    num_gold_entities = 0
    constraint_violations = 0
    for batch in tqdm(dev_dataloader):
        input_ids, mask, labels = batch
        if torch.cuda.is_available():
          input_ids = input_ids.to("cuda")
          mask = mask.to("cuda")
          labels = labels.to("cuda")
        logits = model(input_ids, mask)
        pred = logits.argmax(-1)
        pred_entities = set()
        gold_entities = set()
        for i in range(input_ids.shape[0]):
            pred_current = [-1]
            gold_current = [-1]
            if inv_label_map[pred[i,0].item()][:2] == "I-":
              constraint_violations += 1
            for j in range(input_ids.shape[1]):
                if not mask[i,j].item():
                    continue
                pred_tag = inv_label_map[pred[i,j].item()]
                if pred_tag[:2] == 'B-':
                    if pred_current[0] > -1:
                        pred_entities.add((i, pred_current[0], j-1, pred_current[1]))
                    pred_current = [j, pred_tag[2:]]
                if pred_tag == "O":
                    if pred_current[0] > -1:
                        pred_entities.add((i, pred_current[0], j-1, pred_current[1]))
                    pred_current = [-1]
                if pred_tag == "I-":
                    if pred_current[0] > -1 and pred_current[1] != pred_tag[2:]:
                        pred_entities.add((i, pred_current[0], j-1, pred_current[1]))
                        pred_current = [j, pred_tag[2:]]
                gold_tag = inv_label_map[labels[i,j].item()]
                if gold_tag[:2] == "B-":
                    if gold_current[0] > -1:
                        gold_entities.add((i, gold_current[0], j-1, gold_current[1]))
                    gold_current = [j, gold_tag[2:]]
                if gold_tag == "O":
                    if gold_current[0] > -1:
                        gold_entities.add((i, gold_current[0], j-1, gold_current[1]))
                    gold_current = [-1]
                if gold_tag == "I-":
                    if gold_current[0] > -1 and gold_current[1] != gold_tag[2:]:
                        assert False
        num_correct_entities += len(pred_entities.intersection(gold_entities))
        num_pred_entities += len(pred_entities)
        num_gold_entities += len(gold_entities)
    precision = 0
    recall = 0
    if num_pred_entities > 0:
      precision = num_correct_entities/num_pred_entities
    if num_gold_entities > 0:
      recall = num_correct_entities/num_gold_entities
    if precision+recall == 0:
      f1 = 0
    else:
      f1 = 2*precision*recall/(precision+recall)
    print('Dev F1:', f1)
    print('Dev precision:', precision)
    print('Dev recall:', recall)
    print('Dev Violations: ', constraint_violations)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=714314041.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  0%|          | 0/125 [00:00<?, ?it/s]

cuda
Epoch: 0


  0%|          | 0/125 [00:00<?, ?it/s]

BinOp(left=Name(id='tags_shifted', ctx=Load()), op=Sub(), right=Name(id='tags_truncated', ctx=Load()))





NotImplementedError: ignored

In [None]:
del model
torch.cuda.empty_cache()

In [None]:
model