# Protonet-LayoutLM with CORD

References:

https://github.com/cnielly/prototypical-networks-omniglot/blob/master/prototypical_networks_pytorch_omniglot.ipynb

https://github.com/jakesnell/prototypical-networks/blob/master/protonets/models/few_shot.py

In [1]:
import pandas as pd
import numpy as np
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm_notebook
from tqdm import trange
from transformers import set_seed

SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
set_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE

'cuda'

In [2]:
import warnings
warnings.filterwarnings("ignore")

## Load data

In [3]:
from datasets import load_dataset_builder, load_dataset, concatenate_datasets
from pprint import pprint
import datasets

datasets.utils.logging.set_verbosity_error()

DATASET_URI = 'katanaml/cord'

dsbuild = load_dataset_builder(DATASET_URI)
pprint(dsbuild.info.features.keys())

print("No. of labels: ", len(dsbuild.info.features['ner_tags'].feature.names))

data = load_dataset(DATASET_URI)
print(data)

dict_keys(['id', 'words', 'bboxes', 'ner_tags', 'image_path'])
No. of labels:  23


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'words', 'bboxes', 'ner_tags', 'image_path'],
        num_rows: 800
    })
    test: Dataset({
        features: ['id', 'words', 'bboxes', 'ner_tags', 'image_path'],
        num_rows: 100
    })
    validation: Dataset({
        features: ['id', 'words', 'bboxes', 'ner_tags', 'image_path'],
        num_rows: 100
    })
})


## Data sampling function

Pick random sample of given support set size, query set size, no of classes and the dataset size. 

The function will be used to sample data every episode.

In [4]:
LABELS = dsbuild.info.features['ner_tags'].feature.names

# select 15 for training, 8 for testing
TRAINLABELS = np.random.choice(LABELS, 13, replace=False)
TESTLABELS = np.random.choice(LABELS, 10, replace=False)

HDIM = 768
# TODO: remove 'O' label?

Sampler below will be used for every episode to generate a sample of data to use in the episode.

This will by default return all entities in a sample eventhough they are selected for a particular entity. We need additional preprocessing before the modeling stage to only train with the tokens that correspond to our target entity.

In [5]:
def sample_data(dataset, n_way, n_support, n_query, labels):
    # randomly select B labels to sample in the
    B = np.random.choice(labels, n_way, replace=False)
    B = [LABELS.index(c) for c in B]
    cls_sample = {}  # class -> sample
    
    # select support and query samples for each of these labels
    for cls in B:
        # select data containing the tag
        sample = dataset.filter(lambda x: cls in x,
                                input_columns=['ner_tags'],
                               keep_in_memory=True)
        # sample a small subset of the result
        sample = sample.shuffle(keep_in_memory=True) \
                       .select(np.arange(n_support+n_query), 
                                keep_in_memory=True)
        cls_sample[cls] = sample
    
    return {
        'samples': cls_sample,
        'n_way': n_way,
        'n_support': n_support,
        'n_query': n_query
    }

In [6]:
id2label = data['test'].features['ner_tags'].feature.int2str

def align_labels_with_tokens(labels, word_ids):
    new_labels = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            # Start of a new word!
            current_word = word_id
            label = -100 if word_id is None else id2label(labels[word_id])
            new_labels.append(label)
        elif word_id is None:
            # Special token
            new_labels.append(-100)
        else:
            # Same word as previous token
            label = labels[word_id]
            label = id2label(label)
            new_labels.append(label)

    return new_labels

def align_all_labels(encoded, orig_data): 
    all_labels = []
    for i in range(encoded['input_ids'].shape[0]): # for each instance
        labels = align_labels_with_tokens(orig_data['ner_tags'][i], encoded.word_ids(i))
        all_labels.append(labels)
    return all_labels

def getwords(wordids, words):
    return [words[i]  if i is not None else -100 for i in wordids]

def decode_tokens(input_ids):
    decoded = []
    for inp in input_ids:
        decoded.append([processor.decode(i) for i in inp])
    return decoded

def load_image(data):
    # update dataset with image
    data['image'] = Image.open(data['image_path']).convert('RGB')
    return data

In [7]:
def encode(sample, processor, model):
    enc = processor(sample['image'], sample['words'],
                   boxes = sample['bboxes'],
                    return_tensors='pt',
                   truncation=True, padding=True)
    with torch.no_grad():
        enc = enc.to(DEVICE)
        out = model(**enc)
    out = out.last_hidden_state.detach()
    out = out[:,:enc['input_ids'].shape[1],:]  # select only the word tokens
    return enc, out

def encode_and_filter(sample, processor, model, load_image_on_sample=False):
    # output dimension: n_way, n_support+n_query, hiddendim
    output_tensors = []
    
    for cls in sample['samples'].keys():
        instance = sample['samples'][cls]
        if load_image_on_sample:
            instance = instance.map(load_image,
                                    keep_in_memory=True,
                                   num_proc=4
                                   )
        enc, out = encode(instance, processor, model)
        aligned_labels = align_all_labels(enc, instance)
        
        # output dimension: n_support+n_query, hiddendim
        selected = []
        for i, tag in enumerate(aligned_labels):
            targetclass = LABELS[cls]
            # ids of tokens of the target class
            idxs = [idx                
                    for idx, tok in enumerate(tag)
                    if tok == targetclass ]
            # randomly select 1 token of the class
            selidx = np.random.choice(idxs, 1)
            selected.append(out[i,selidx,:])
        selected = torch.cat(selected)
        output_tensors.append(selected)
        
    output_tensors = torch.stack(output_tensors)
    
    assert output_tensors.shape == (sample['n_way'], sample['n_support']+sample['n_query'], HDIM)
    return output_tensors

## Implement ProtoNet

In [8]:
class ProtoNet(nn.Module):
    def __init__(self, encoder, processor, nnet):
        super(ProtoNet, self).__init__()
        self.encoder = encoder
        self.processor = processor
        self.nnet = nnet
        
        # exclude encoder params
        for p in self.encoder.parameters():
            p.requires_grad = False

    def set_forward_loss(self, sample, load_image_on_sample=False):
        n_way = sample['n_way']
        n_support = sample['n_support']
        n_query = sample['n_query']
        
        sample_enc = encode_and_filter(sample, self.processor, self.encoder, load_image_on_sample)

        x_support = sample_enc[:, :n_support]
        x_query = sample_enc[:, n_support:]

        #target indices are 0 ... n_way-1
        target_inds = torch.arange(0, n_way) \
                      .view(n_way, 1, 1)  \
                      .expand(n_way, n_query, 1).long()
        target_inds = target_inds.to(DEVICE)

        # encode images of the support and the query set
        x = torch.cat([x_support.contiguous().view(n_way * n_support, *x_support.size()[2:]),
                       x_query.contiguous().view(n_way * n_query, *x_query.size()[2:])], 0)

        z = self.nnet.forward(x)
        z_dim = z.size(-1)
        z_proto = z[:n_way*n_support].view(n_way, n_support, z_dim).mean(1)
        z_query = z[n_way*n_support:]

        #compute distances
        dists = self.euclidean_dist(z_query, z_proto)

        #compute probabilities
        log_p_y = F.log_softmax(-dists, dim=1).view(n_way, n_query, -1)

        loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()
        _, y_hat = log_p_y.max(2)
        acc_val = torch.eq(y_hat, target_inds.squeeze()).float().mean()

        return loss_val, {
            'loss': loss_val.item(),
            'acc': acc_val.item(),
            'y_hat': y_hat
            }
    
    def euclidean_dist(self, x, y):
        # x: N x D
        # y: M x D
        n = x.size(0)
        m = y.size(0)
        d = x.size(1)
        assert d == y.size(1)

        x = x.unsqueeze(1).expand(n, m, d)
        y = y.unsqueeze(0).expand(n, m, d)

        return torch.pow(x - y, 2).sum(2)

## Training and Validation

In [9]:
def train(model, optimizer, scheduler,
          dataset, n_way, n_support, n_query, 
          max_epoch, epoch_size,
          load_image=False):
    epoch = 0
    
    while epoch < max_epoch:
        running_loss = 0
        running_acc = 0
        
        for episode in trange(epoch_size, desc="Epoch {:d}/{:d} train".format(epoch+1, max_epoch)):
            sample = sample_data(dataset,
                                 n_way, n_support, n_query,
                                 TRAINLABELS)
            optimizer.zero_grad()
            loss, output = model.set_forward_loss(sample, load_image)
            running_loss += output['loss']
            running_acc += output['acc']
            loss.backward()
            optimizer.step()
            
        epoch_loss = running_loss / epoch_size
        epoch_acc = running_acc / epoch_size
        print('Epoch {:d} -- Loss: {:.4f} Acc: {:.4f}'.format(epoch+1,epoch_loss, epoch_acc))
        epoch += 1
        scheduler.step()

In [10]:
def test(model, 
         dataset, n_way, n_support, n_query, 
         test_episode,
            load_image_on_sample=False):
    running_loss = 0.0
    running_acc = 0.0
    for episode in trange(test_episode):
        sample = sample_data(dataset,
                             n_way, n_support, n_query, 
                             TESTLABELS)
        loss, output = model.set_forward_loss(sample, load_image_on_sample)
        running_loss += output['loss']
        running_acc += output['acc']
    avg_loss = running_loss / test_episode
    avg_acc = running_acc / test_episode
    print('Test results -- Loss: {:.4f} Acc: {:.4f}'.format(avg_loss, avg_acc))
    return avg_acc

## Train the model

In [11]:
N_WAY = 5
N_SUPPORT = 5
N_QUERY = 5

MAX_EPOCHS = 10
EPOCH_SIZE = 100

In [13]:
from transformers import AutoProcessor, AutoModelForTokenClassification, AutoModel
processor = AutoProcessor.from_pretrained('microsoft/layoutlmv3-base', apply_ocr=False)
model = AutoModel.from_pretrained('microsoft/layoutlmv3-base')
model.eval()
model.to(DEVICE);

In [14]:
nnet = nn.Sequential(
    nn.Linear(HDIM, 256),
    nn.Linear(256, 64),
).to(DEVICE)

In [15]:
pnet = ProtoNet(model, processor, nnet)

In [16]:
optimizer = optim.Adam(nnet.parameters(), lr=0.005)

In [17]:
scheduler = optim.lr_scheduler.StepLR(optimizer, 2, gamma=0.5, last_epoch=-1)

In [18]:
dataset = concatenate_datasets([data['train'], data['test']])

In [19]:
dataset = dataset.map(load_image, keep_in_memory=True, num_proc=3)  
load_image_on_sample = False

In [20]:
train(pnet, optimizer, scheduler,
      dataset,
      N_WAY, N_SUPPORT, N_QUERY,
      MAX_EPOCHS, EPOCH_SIZE,
      load_image_on_sample)

Epoch 1/2 train:   0%|          | 0/2 [00:00<?, ?it/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Epoch 1/2 train:  50%|█████     | 1/2 [00:19<00:19, 19.45s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Epoch 1/2 train: 100%|██████████| 2/2 [00:36<00:00, 18.13s/it]


Epoch 1 -- Loss: 4.4302 Acc: 0.5600


Epoch 2/2 train:   0%|          | 0/2 [00:00<?, ?it/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Epoch 2/2 train:  50%|█████     | 1/2 [00:19<00:19, 19.06s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Epoch 2/2 train: 100%|██████████| 2/2 [00:40<00:00, 20.06s/it]

Epoch 2 -- Loss: 8.5733 Acc: 0.4800





In [21]:
acc = test(pnet,
      dataset,
      N_WAY, N_SUPPORT, N_QUERY,
      EPOCH_SIZE,
      load_image_on_sample)

  0%|          | 0/2 [00:00<?, ?it/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

 50%|█████     | 1/2 [00:19<00:19, 19.26s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

100%|██████████| 2/2 [00:38<00:00, 19.17s/it]

Test results -- Loss: 9.6178 Acc: 0.6400





In [22]:
import os
from datetime import datetime

In [23]:
OUTPATH = "outputs"
filename = f"pnet_{N_SUPPORT}_{N_QUERY}_{N_WAY}_{round(acc, 4)*100}_{datetime.now().strftime('%m-%d-%H:%M')}.pt"
torch.save(pnet, os.path.join(OUTPATH, filename))