# Introduction

Transformer-based Word Sense Disambiguation (WSD)

# Preliminaries
In the following cell we are going to:


*   Download the Semcor+OMSTI WSD dataset in XML format
*   Parse the Semcor+OMSTI WSD dataset 
*   Build a suitable JSON file with instances-glosses pairs


In [1]:
import re
import os
import json
from lxml import etree
from tqdm import tqdm
import ijson.backends.yajl2 as ijson
from typing import List, Tuple, Dict, Optional, Any, Callable, Union
from collections.abc import Iterable
from nltk.corpus import wordnet
import gc 
import random
import textwrap
from dataclasses import dataclass
from tqdm import tqdm

import torch
from torch.nn.utils import rnn
from torch.nn import functional as F
import pytorch_lightning as pl
import re

## Run 'wsddataset.py' script to build the WSD dataset

In [7]:
from nltk.corpus import wordnet as wn

# Translates the POS tag strings used as keys here into the corresponding
# WordNet part-of-speech constants exposed by nltk.
MAPPING = {"NOUN": wn.NOUN, "VERB": wn.VERB, "ADJ": wn.ADJ, "ADV": wn.ADV}

# Every synset in WordNet, materialised once; the position of a synset in this
# list is used as its integer id (output class) below.
SYNSETS = list(wn.all_synsets())

# Synset name (e.g. 'output.n.02') -> integer id (index into SYNSETS).
SYNSETS_TO_IDS = { synset.name(): index for index, synset in enumerate(SYNSETS) }

# Inverse lookup: integer id -> synset name.
IDS_TO_SYNSETS = { index: synset_name for synset_name, index in SYNSETS_TO_IDS.items() }

# Report the size of the label vocabulary.
print('Number Synsets: ',len(SYNSETS))


Number Synsets:  117659


In [None]:
import torch

# Build a row-normalised sparse adjacency matrix over all synsets: row i has
# an entry for every synset reachable from synset i through hyponym or
# hypernym links (up to `depth` steps), each weighted 1 / (number of links).

# COO coordinates ([row, col] pairs) and the matching values.
indices = []
values = []
shape = (len(SYNSETS), len(SYNSETS))


def hypo(s):
    """Return the hyponyms of synset *s* (relation used for the closure)."""
    return s.hyponyms()


def hyper(s):
    """Return the hypernyms of synset *s* (relation used for the closure)."""
    return s.hypernyms()


# Maximum number of relation steps followed by `closure`.
depth = 1

for index, synset_name in IDS_TO_SYNSETS.items():
    # Resolve the synset once per row instead of once per relation.
    synset = wn.synset(synset_name)
    linked_synsets = (
        list(synset.closure(hypo, depth=depth))
        + list(synset.closure(hyper, depth=depth))
    )
    # Synsets without neighbours simply contribute an empty row; the inner
    # loop does not run, so the division below is never evaluated for them.
    factor = len(linked_synsets)
    for linked in linked_synsets:
        indices.append([index, SYNSETS_TO_IDS[linked.name()]])
        values.append(1 / factor)

# `indices` is a list of [row, col] pairs; sparse_coo_tensor expects the
# transposed layout (all rows, then all cols), hence zip(*indices).
ADJACENCY_MATRIX = torch.sparse_coo_tensor(
    list(zip(*indices)), values, shape, dtype=torch.float32
).cpu()


In [None]:
def cache(method):
    """
    Decorator that memoises the return value of a method on its instance.

    The first call stores the result in ``self._cache`` under the method's
    name; every later call on the same instance returns that stored value
    without invoking the method again.

    Note that the cache key is the method name only: arguments are ignored,
    so a second call with different arguments still returns the first result.
    Use it on argument-free methods.
    """
    import functools  # local import keeps this cell self-contained

    method_name = method.__name__

    # functools.wraps keeps the decorated method's __name__/__doc__ intact,
    # which the bare wrapper used to overwrite.
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        store = getattr(self, '_cache', None)
        if store is None:
            # Lazily create the per-instance store on first use.
            store = self._cache = {}
        if method_name not in store:
            store[method_name] = method(self, *args, **kwargs)
        return store[method_name]

    return wrapper

In [32]:
@dataclass
class Sample:
    """
    One WSD instance: a lemmatised sentence plus the target word to
    disambiguate and its gold synset.

    Instances are built with ``Sample(**row)`` from the rows of the JSON
    corpus, so the field names mirror the JSON keys.
    """
    # Identifier of the instance in the corpus.
    instance_id: str
    # WordNet sense key(s) of the target.
    # NOTE(review): annotated `str`, but the corpus appears to store a list
    # of keys (e.g. ['output%1:04:00::']) -- confirm and fix the annotation.
    wordnet_key: str
    # Position of the target word inside `lemmas` / `pos`.
    word_index: int
    # Sentence tokens.
    lemmas: List[str]
    # One POS tag per token; the target's tag must be a key of MAPPING.
    pos: List[str]
    # presumably the candidate synset names of the target -- TODO confirm
    synsets: List[str]
    # presumably one gloss per candidate synset -- TODO confirm
    glosses: List[str]
    # presumably usage examples per candidate synset -- TODO confirm
    examples: List[List[str]]
    # Name of the gold synset, e.g. 'output.n.02'.
    label: str

    @cache
    def process(self):
        """
        Convert the sample to ids.

        Returns a tuple ``(lemmas, word_index, candidate_ids, label_id)``
        where ``candidate_ids`` are the ids of every WordNet synset of the
        gold synset's head lemma with the target's part of speech, and
        ``label_id`` is the id of the gold synset. The result is cached on
        the instance.

        Raises:
            KeyError: if the target's POS tag is not in MAPPING or a synset
                name is missing from SYNSETS_TO_IDS.
        """
        # Synset names have the form '<lemma>.<pos>.<nn>'. The lemma itself
        # may contain dots, so strip the two trailing fields from the right
        # instead of taking everything before the first dot.
        lemma = wn.synset(self.label).name().rsplit('.', 2)[0]
        target_pos = MAPPING[self.pos[self.word_index]]
        candidate_ids = [
            SYNSETS_TO_IDS[candidate.name()]
            for candidate in wn.synsets(lemma, pos=target_pos)
        ]

        return self.lemmas, self.word_index, candidate_ids, SYNSETS_TO_IDS[self.label]


In [33]:
def load(path: str, size: Optional[Union[int, float]] = float('inf')) -> List[Sample]:
    """
    Stream samples from a JSON file whose top level is an array of objects.

    Args:
        path: location of the JSON file; every array element must carry
            exactly the keyword arguments of ``Sample``.
        size: maximum number of samples to read. ``float('inf')`` (the
            default) or ``None`` reads the whole file.

    Returns:
        The parsed samples, in file order.
    """
    if size is None:
        # The annotation advertises Optional; treat None as "no limit"
        # instead of failing on the comparison below.
        size = float('inf')

    samples = []
    # Binary mode is what ijson expects, and the context manager guarantees
    # the handle is closed even when we stop early or parsing fails.
    with open(path, 'rb') as stream:
        rows = ijson.items(stream, 'item')
        for sample_index, sample_dict in tqdm(enumerate(rows), desc="parsing samples: "):
            if sample_index >= size:
                break
            samples.append(Sample(**sample_dict))
    print('\nDone.')
    return samples

# Parse only the first 100 samples of the SemCor+OMSTI JSON corpus.
samples = load('WSD_Training_Corpora/SemCor+OMSTI/semcor+omsti+synsets.json', size = 100)

parsing samples: : 100it [00:00, 4519.87it/s]
Done.



In [35]:
# Show the id-encoded form of the first sample:
# (lemmas, word_index, candidate synset ids, gold synset id).
samples[0].process()

(['how',
  'long',
  'have',
  'it',
  'be',
  'since',
  'you',
  'review',
  'the',
  'objective',
  'of',
  'you',
  'benefit',
  'and',
  'service',
  'program',
  '?'],
 1,
 [7858, 7832, 13209, 10980, 7894, 7892, 12584, 10391, 68],
 7858)

In [36]:
# Pretty-print the raw fields of one sample followed by its processed form.
n_sample = 33
s = samples[n_sample]

# (header, attribute) pairs, printed in this order.
for header, field_name in (
    ('Info:\n', 'wordnet_key'),
    ('\nWord index:\n', 'word_index'),
    ('\nLabel:\n', 'label'),
    ('\nGlosses:\n', 'glosses'),
    ('\nExamples:\n', 'examples'),
):
    print(header)
    print(getattr(s, field_name))

# The processed tuple is computed last, after all raw fields were shown.
print('\nProcessed:\n')
print(s.process())

Info:

['output%1:04:00::']

Word index:

21

Label:

output.n.02

Glosses:

['final product; the things produced', 'production of a certain amount', 'signal that comes out of an electronic system', 'the quantity of something (as a commodity) that is created (usually within a given period of time)', 'what is produced in a given time period']

Examples:

[[], [], [], ['production was up in the second quarter'], []]

Processed:

(['do', 'you', 'measure', 'its', 'relation', 'to', 'reduced', 'absenteeism', ',', 'turnover', ',', 'accident', ',', 'and', 'grievance', ',', 'and', 'to', 'improved', 'quality', 'and', 'output', '?'], 21, [39649, 26445, 61085, 95413, 43044], 26445)


Finally, we do the usual chore of implementing subclasses of `torch.utils.data.Dataset` and `pl.LightningDataModule`, which handle the production of batches of samples for the models we are going to implement. This should all look familiar to you; if it doesn't, please (re)check the previous notebooks!

In [37]:
# NOTE(review): neither constant is referenced by the code visible in this
# notebook; presumably reserved ids for padding and for positions to ignore
# -- confirm or remove.
PAD_IDX = 0
IGN_IDX = 1

class WSDDataset(torch.utils.data.Dataset):
    """Map-style dataset that exposes an in-memory sequence of samples."""

    def __init__(self, samples):
        # The sequence is stored by reference; nothing is copied or converted.
        self.samples = samples

    def __getitem__(self, item):
        """Return the sample stored at position (or slice) *item*."""
        selected = self.samples[item]
        return selected

    def __len__(self):
        """Return the number of stored samples."""
        count = len(self.samples)
        return count

def collate_fn(samples, device=None):
    """
    Merge a list of samples into one batch dictionary.

    Each sample's ``process()`` result (tokens, word index, candidate synset
    ids, label id) is split column-wise into the lists ``'tokens'``,
    ``'key_indices'``, ``'synsets'`` and ``'labels'``. When *device* is
    given, any tensor-valued entry is moved to it; other entries are left
    untouched.
    """
    processed = [sample.process() for sample in samples]
    tokens, word_indices, synsets, labels = zip(*processed)

    batch = {
        'tokens': list(tokens),
        'key_indices': list(word_indices),
        'synsets': list(synsets),
        'labels': list(labels),
    }
    if device is None:
        return batch

    on_device = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            on_device[key] = value.to(device)
        else:
            on_device[key] = value
    return on_device


In [38]:
# Collate the first two samples to inspect the structure of a batch.
collate_fn(samples[:2])

{'tokens': [['how',
   'long',
   'have',
   'it',
   'be',
   'since',
   'you',
   'review',
   'the',
   'objective',
   'of',
   'you',
   'benefit',
   'and',
   'service',
   'program',
   '?'],
  ['how',
   'long',
   'have',
   'it',
   'be',
   'since',
   'you',
   'review',
   'the',
   'objective',
   'of',
   'you',
   'benefit',
   'and',
   'service',
   'program',
   '?']],
 'key_indices': [1, 4],
 'synsets': [[7858, 7832, 13209, 10980, 7894, 7892, 12584, 10391, 68],
  [116827,
   116870,
   117057,
   116822,
   117540,
   117103,
   116891,
   116067,
   117258,
   115191,
   116857,
   117511,
   117281]],
 'labels': [7858, 116827]}

In [39]:
class WSDDataModule(pl.LightningDataModule):
    """
    Lightning data module that shuffles the samples with a fixed seed and
    splits them 80% / 10% / 10% into train, validation and test sets.
    """

    def __init__(self, samples, batch_size=32, num_workers=8):
        """
        Args:
            samples: sliceable sequence of samples (a shallow copy is
                shuffled; the caller's sequence is not modified).
            batch_size: batch size used by all three dataloaders.
            num_workers: worker processes per dataloader.
        """
        super().__init__()
        self.batch_size = batch_size
        self.samples = samples
        self.num_workers = num_workers

    def _split_samples(self):
        """Shuffle a copy of the samples deterministically and split it."""
        random.seed(1337)
        samples = self.samples[:]
        random.shuffle(samples)
        i = int(len(samples) * 0.8)
        j = int(len(samples) * 0.9)
        # Slice ends are exclusive, so the next split starts exactly at the
        # previous boundary; starting at i+1 / j+1 would drop two samples.
        self.train_samples = samples[:i]
        self.valid_samples = samples[i:j]
        self.test_samples = samples[j:]

    def prepare_data(self):
        """Create the train/validation/test split."""
        self._split_samples()
        super().prepare_data()

    def setup(self, stage=None):
        """Wrap the three splits into datasets."""
        # Lightning advises against relying on state assigned in
        # prepare_data (it is not run on every process), so make sure the
        # split exists here as well.
        if not hasattr(self, 'train_samples'):
            self._split_samples()
        self.train_dataset = WSDDataset(self.train_samples)
        self.valid_dataset = WSDDataset(self.valid_samples)
        self.test_dataset = WSDDataset(self.test_samples)

    def _loader(self, dataset, shuffle):
        """Build a DataLoader with the settings shared by all splits."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn,
            pin_memory=True,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Return the (shuffled) training dataloader."""
        return self._loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the validation dataloader."""
        return self._loader(self.valid_dataset, shuffle=False)

    def test_dataloader(self):
        """Return the test dataloader."""
        return self._loader(self.test_dataset, shuffle=False)


In [40]:
import abc 

class BaseKeyEmbedder(torch.nn.Module, metaclass=abc.ABCMeta):
    """
    Common interface for modules that embed the key (target) word of each
    sentence in a batch.

    Subclasses are expected to set ``embedding_dim`` and to override
    ``forward``; the base ``forward`` does nothing and returns ``None``.
    """

    # Size of the vectors produced by `forward`; read by `embed_dim`.
    embedding_dim: int
    n_hidden_states: int
    # Whether the wrapped model should be fine-tuned.
    retrain_model: bool
    is_split_into_words: bool

    def __init__(self, retrain_model: bool = False):
        super().__init__()
        self.retrain_model = retrain_model

    def forward(
            self,
            # `A or B` between typing objects evaluates to just `A`;
            # a Union is needed to express both accepted forms.
            key_indices: Union[None, List[int], List[Tuple[int, int]]] = None,
            src_tokens_str: Union[None, List[List[str]]] = None,
            batch_major: bool = True,
            **kwargs
    ):
        """Placeholder to be overridden by subclasses; returns ``None``."""
        pass

    @property
    def device(self):
        """Device of the first parameter (raises StopIteration if there are none)."""
        return next(self.parameters()).device

    @property
    def is_cuda(self):
        """Whether the first parameter lives on a CUDA device."""
        return next(self.parameters()).is_cuda

    @property
    def embed_dim(self):
        """Alias of ``embedding_dim``."""
        return self.embedding_dim

    @property
    def embedded_dim(self):
        """Alias of ``embedding_dim``."""
        return self.embedding_dim

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Delegate to ``torch.nn.Module.state_dict``."""
        # Forward by keyword: newer PyTorch releases deprecate passing these
        # arguments positionally, while keywords work on every version.
        return super().state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )


In [41]:
from pprint import pprint
class BERTKeyEmbedder(BaseKeyEmbedder):
    """
    Frozen BERT encoder that returns one contextual vector per sentence: the
    embedding of the key (target) word.

    The key word is surrounded by ``key_token`` markers before tokenization;
    the word pieces found between the two markers are merged into a single
    vector taken from the last ``last_hidden_number`` hidden layers.
    """

    DEFAULT_MODEL = 'bert-base-uncased'

    @staticmethod
    def _do_imports():
        """Import the heavy optional dependencies only when needed."""
        import transformers as tnf
        import torchtext.data.functional as ttext
        return tnf, ttext

    def __init__(
        self,
        name: Union[str, None] = None,
        key_token: str = '*',
        weights: str = "",
        is_split_into_words: bool = True,
        key_piece_merging_mode: str = 'MEAN',
        last_hidden_number: int = 4,
        last_hidden_merging_mode: str = 'SUM',
        retrain_model: bool = False,
    ):
        """
        Args:
            name: pretrained model name; defaults to ``DEFAULT_MODEL``.
            key_token: marker placed on both sides of the key word. Any
                occurrence of it in the input text is removed first.
            weights: optional checkpoint path; its ``'state'`` entries whose
                keys start with ``'bert.'`` are loaded into the BERT model.
            is_split_into_words: ``True`` when sentences arrive as token
                lists and key indices are token positions; ``False`` when
                sentences are raw strings and key indices are
                ``(start, end)`` spans.
            key_piece_merging_mode: ``'SUM'`` or ``'MEAN'`` over the word
                pieces of the key word (anything else behaves as ``'MEAN'``).
            last_hidden_number: how many of the last hidden layers to merge.
            last_hidden_merging_mode: ``'SUM'`` or ``'MEAN'`` over those
                layers (anything else behaves as ``'MEAN'``).
            retrain_model: must be ``False``; fine-tuning is not supported.
        """
        assert not retrain_model
        assert 0 < last_hidden_number < 12
        # Initialise through the direct base class so that its own
        # initialisation (and torch.nn.Module's) both run.
        super().__init__(retrain_model=retrain_model)

        if not name:
            name = self.DEFAULT_MODEL

        self.name = name
        self.is_split_into_words = is_split_into_words
        self.last_hidden_number = last_hidden_number
        self.key_piece_merging_mode = key_piece_merging_mode
        self.last_hidden_merging_mode = last_hidden_merging_mode

        tnf, ttext = self._do_imports()
        self.bert_tokenizer = tnf.BertTokenizerFast.from_pretrained(name)
        self.bert_model = tnf.BertModel.from_pretrained(name, output_hidden_states=True)
        self.hidden_size = self.bert_model.config.hidden_size
        # Output vectors have the encoder's hidden size; exposing it here
        # makes the inherited `embed_dim` / `embedded_dim` properties usable.
        self.embedding_dim = self.hidden_size

        self.key_token = key_token
        self.key_id = self.bert_tokenizer.convert_tokens_to_ids(self.key_token)
        # Strips pre-existing marker characters from the text so that exactly
        # two markers end up in every tokenized sentence.
        self.cleaner = ttext.custom_replace([(re.escape(self.key_token), '')])

        if weights:
            # NOTE(review): torch.load unpickles the file -- only load
            # checkpoints from trusted sources.
            # map_location keeps loading independent of the saving device.
            state = torch.load(weights, map_location='cpu')['state']
            # Keep the BERT weights only and drop their 'bert.' prefix.
            state = {k.split('.', 1)[1]: v for k, v in state.items() if k.startswith('bert.')}
            self.bert_model.load_state_dict(state)
            self.name = self.name + "-" + os.path.basename(weights)

        # The encoder is used as a fixed feature extractor.
        for par in self.parameters():
            par.requires_grad = False

    def forward(
        self,
        key_indices: Union[List[int], List[Tuple[int, int]]],
        src_tokens_str: Union[List[List[str]], List[str]],
        batch_major=True,
        **kwargs
    ) -> torch.Tensor:
        """
        Embed the key word of every sentence.

        Args:
            key_indices: one token position per sentence when
                ``is_split_into_words`` is set, otherwise one
                ``(start, end)`` span per sentence.
            src_tokens_str: the sentences (token lists, or raw strings when
                ``is_split_into_words`` is ``False``).
            batch_major: accepted for interface compatibility; unused.

        Returns:
            Tensor of shape ``(batch, hidden_size)`` on the module's device.
        """
        marked_inputs = []

        if self.is_split_into_words:
            assert isinstance(key_indices[0], int)

            for tokens, key_index in zip(src_tokens_str, key_indices):
                cleaned = list(self.cleaner(tokens))
                # surround the query token
                cleaned[key_index] = self.key_token + cleaned[key_index] + self.key_token
                marked_inputs.append(cleaned)
        else:
            assert isinstance(key_indices[0], tuple) and isinstance(key_indices[0][0], int)

            # assumes each span is a pair of character offsets into the raw
            # sentence string -- TODO confirm against callers
            for sentence, (start, end) in zip(src_tokens_str, key_indices):
                # Clean the three segments separately so that removing marker
                # characters cannot shift the span boundaries.
                before, word, after = self.cleaner(
                    [sentence[:start], sentence[start:end], sentence[end:]]
                )
                # surround the query span
                marked_inputs.append(before + self.key_token + word + self.key_token + after)

        input_batch = self.bert_tokenizer(
            marked_inputs,
            return_tensors='pt',
            # Raw strings must not be treated as pre-split token lists.
            is_split_into_words=self.is_split_into_words,
            padding=True
        )

        query_token_indices = self.get_splitted_input_indices(input_batch['input_ids'])

        # Run wherever the model lives instead of assuming a GPU.
        device = self.device

        # Gradients are only ever needed when fine-tuning during training;
        # with `retrain_model` asserted False this is always disabled.
        with torch.set_grad_enabled(self.retrain_model and self.training):
            outputs = self.bert_model.eval()(
                input_ids=input_batch['input_ids'].to(device),
                attention_mask=input_batch['attention_mask'].to(device),
                token_type_ids=input_batch['token_type_ids'].to(device),
            )

        # Third output: the tuple of per-layer hidden states
        # (requested through output_hidden_states=True).
        hidden_states = outputs[2]

        # Stack the last `last_hidden_number` layers on a new trailing axis.
        stacked_hidden_states = torch.stack(
            hidden_states[-self.last_hidden_number:], dim=-1)

        if self.last_hidden_merging_mode == 'SUM':
            merged_hidden_states = stacked_hidden_states.sum(dim=-1)
        else:
            # 'MEAN' and, as before, any unrecognised mode.
            merged_hidden_states = stacked_hidden_states.mean(dim=-1)

        # contextualized words to disambiguate
        key_context_embedding = self.get_word(
            merged_hidden_states,
            query_token_indices[..., 0].to(device),
            query_token_indices[..., 1].to(device)
        )
        return key_context_embedding

    def get_splitted_input_indices(
        self,
        batch_input_ids: torch.Tensor,
        ) -> torch.Tensor:
        """
        Locate the key word pieces in every tokenized sentence.

        Returns a ``(batch, 2)`` tensor holding, per sentence, the position
        right after the opening marker and the position of the closing
        marker, i.e. the half-open range of the key word's pieces.
        """
        splitted_input_indices = []
        for input_ids in batch_input_ids:
            marker_positions = (input_ids == self.key_id).nonzero(as_tuple=True)[0]
            # Skip the opening marker itself.
            marker_positions[0] += 1
            splitted_input_indices.append(marker_positions)
        tensor_indices = torch.stack(splitted_input_indices, dim=0)
        assert tensor_indices.shape[-1] == 2, \
            "expected exactly two key markers per sentence"

        return tensor_indices

    def get_word(
        self,
        tensor: torch.Tensor,
        start: torch.Tensor,
        end: torch.Tensor
        ) -> torch.Tensor:
        """
        Merge the hidden vectors of positions ``start[b] <= n < end[b]``.

        Args:
            tensor: hidden states of shape ``(batch, seq_len, channels)``.
            start: first position (inclusive) per sentence, shape ``(batch,)``.
            end: last position (exclusive) per sentence, shape ``(batch,)``.

        Returns:
            ``(batch, channels)`` tensor: the sum of the selected vectors in
            ``'SUM'`` mode, their average otherwise.
        """
        seq_len = tensor.shape[1]
        start = start.to(tensor.device)
        end = end.to(tensor.device)

        # Boolean (batch, seq_len) mask of the positions inside each span.
        positions = torch.arange(seq_len, device=tensor.device).unsqueeze(0)
        inside = (positions >= start.unsqueeze(1)) & (positions < end.unsqueeze(1))
        summed = (tensor * inside.unsqueeze(-1).to(tensor.dtype)).sum(dim=1)

        if self.key_piece_merging_mode == 'SUM':
            return summed

        # 'MEAN' (and any unrecognised mode): divide by the number of pieces
        # only. Averaging over the padded sequence length as well would make
        # the scale depend on the longest sentence of the batch.
        # The clamp avoids NaNs when the span is empty.
        num_elem = (end - start).clamp(min=1).unsqueeze(1).to(tensor.dtype)
        return summed / num_elem

In [47]:
from pprint import pprint
from transformers import BertModel
from torchmetrics import MetricCollection, Accuracy, F1, Recall, Precision
from sparselinear import SparseLinear

# Preferred device: the GPU when CUDA is available, the CPU otherwise.
# NOTE(review): not referenced by the code visible in this notebook.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class WSDModel(pl.LightningModule):
    """
    Word sense disambiguation classifier.

    A frozen ``BERTKeyEmbedder`` embeds the target word; batch norm and a
    two-layer feed-forward head map it to one logit per synset, and a sparse
    linear layer built from the synset adjacency matrix adds to every logit
    the weighted logits of its neighbouring synsets.
    """

    def __init__(self, hparams, *args, **kwargs):
        """
        Args:
            hparams: object exposing ``is_split_into_words``, ``vocab_size``,
                ``coalesce_adjacency_matrix`` (coalesced sparse tensor),
                ``train_adjacency``, ``dropout``, ``hidden_size`` and ``lr``.
        """
        super().__init__(*args, **kwargs)
        self.save_hyperparameters(hparams.__dict__)

        self.bert_key_embedder = BERTKeyEmbedder(
            is_split_into_words=hparams.is_split_into_words)
        # assert Bert if freeze
        self.bert_key_embedder.training = False
        self.vocab_size = hparams.vocab_size

        self.adjacency = SparseLinear(
            hparams.vocab_size, hparams.vocab_size,
            connectivity=hparams.coalesce_adjacency_matrix.indices(), bias=False)
        # Initialise the sparse weights with the adjacency values. The
        # trainable flag has to be set on the parameter itself: setting a
        # `requires_grad` attribute on the module has no effect.
        self.adjacency.weights = torch.nn.Parameter(
            hparams.coalesce_adjacency_matrix.values(),
            requires_grad=hparams.train_adjacency)

        # NOTE(review): the dropout layer is created but never applied in
        # `forward`.
        self.dropout = torch.nn.Dropout(hparams.dropout)
        self.bnorm = torch.nn.BatchNorm1d(self.bert_key_embedder.hidden_size)
        self.swish = torch.nn.SiLU()
        self.classifier1 = torch.nn.Linear(self.bert_key_embedder.hidden_size, hparams.hidden_size)
        self.classifier2 = torch.nn.Linear(hparams.hidden_size, hparams.vocab_size, bias=False)

        torch.nn.init.xavier_uniform_(self.classifier1.weight)
        torch.nn.init.xavier_uniform_(self.classifier2.weight)

        self.softmax = torch.nn.Softmax(dim=-1)
        self.loss = torch.nn.CrossEntropyLoss(reduction='mean')
        self.lr = hparams.lr

        # Measures
        metrics = MetricCollection([
            Accuracy(num_classes=hparams.vocab_size),
            F1(num_classes=hparams.vocab_size, average='micro'),
            Recall(num_classes=hparams.vocab_size, average='micro'),
            Precision(num_classes=hparams.vocab_size, average='micro')])
        self.train_metrics = metrics.clone(prefix='train_')
        self.val_metrics = metrics.clone(prefix='val_')
        self.test_metrics = metrics.clone(prefix='test_')

    def forward(
        self,
        batch: Union[Dict[str, List], torch.Tensor],
        ) -> torch.Tensor:
        """
        Computes the forward pass, returning unnormalized log probabilities (the logits)

        *batch* is either a batch dictionary (embedded with ``embed``) or an
        already computed embedding tensor.
        """
        emb = self.embed(batch) if isinstance(batch, dict) else batch
        out = self.bnorm(emb)
        out = self.classifier1(out)
        out = self.swish(out)
        out = self.classifier2(out)
        # Add the neighbourhood contribution on top of the plain logits.
        logits = self.adjacency(out) + out

        return logits

    def embed(
        self,
        batch: Dict[str, List]
        ) -> torch.Tensor:
        """
        Returns the key-word embeddings for a batch dictionary with the
        entries ``'tokens'`` and ``'key_indices'``.
        """
        tokens = batch['tokens']
        key_indices = batch['key_indices']
        emb = self.bert_key_embedder.forward(key_indices=key_indices, src_tokens_str=tokens)

        return emb

    def _candidate_argmax(
        self,
        logits: torch.Tensor,
        candidates: List[List[int]],
        ) -> torch.Tensor:
        """
        Returns, per row, the candidate synset id with the highest logit.
        """
        allowed = self.mask_synsets(candidates).to(logits.device).to_dense().bool()
        # Mask the logits rather than the softmax output: a probability that
        # underflows to zero would tie with the masked-out entries.
        return logits.masked_fill(~allowed, float('-inf')).argmax(dim=-1)

    @torch.no_grad()
    def predict(
        self,
        batch: Dict[str, List],
        ) -> Dict[str, torch.Tensor]:
        """
        Computes a batch of predictions (synset ids restricted to each
        sample's candidates) together with the logits
        """
        logits = self(batch)
        pred = self._candidate_argmax(logits, batch['synsets'])
        return {'logits': logits, 'pred': pred}

    def mask_synsets(
        self,
        batch_indices: List[List[int]],
        ) -> torch.Tensor:
        """
        Builds a coalesced sparse ``(batch, vocab_size)`` tensor holding a
        one at every (row, candidate synset id) position.
        """
        coordinates = [
            [row, col]
            for row, candidate_ids in enumerate(batch_indices)
            for col in candidate_ids
        ]
        values = torch.ones(len(coordinates))
        sparse_mask = torch.sparse_coo_tensor(
            list(zip(*coordinates)),
            values,
            (len(batch_indices), self.vocab_size)
            )
        return sparse_mask.coalesce()

    def basic_step(
        self,
        batch: Dict[str, List],
        ) -> Dict[str, torch.Tensor]:
        """
        Evaluates a batch against the ground truth: returns the loss over
        all synsets, the candidate-restricted predictions and the gold ids
        """
        logits = self(batch)
        # Create the labels on the model's device instead of assuming CUDA.
        gold = torch.tensor(batch['labels'], device=logits.device)
        pred = self._candidate_argmax(logits, batch['synsets'])
        loss = self.loss(logits, gold)
        return {'loss': loss, 'pred': pred, 'gold': gold}

    def training_step(
        self,
        batch: Dict[str, List],
        batch_idx: int
        ) -> Dict[str, torch.Tensor]:
        """
        [Required by lightning]
        Computes loss to be used for .backward()
        """
        return self.basic_step(batch)

    def write_metrics_end(
        self,
        batch_parts: Dict[str, torch.Tensor],
        metrics: MetricCollection,
        ):
        """
        Updates the metrics with a step's predictions and logs the
        step-level values
        """
        output = metrics(batch_parts['pred'], batch_parts['gold'])
        self.log_dict(output, on_step=True, on_epoch=False, prog_bar=True)

    def training_step_end(
        self,
        batch_parts: Dict[str, torch.Tensor],
        ):
        """
        [Required by lightning]
        Logs the training metrics and returns the loss used for .backward()
        """
        self.write_metrics_end(batch_parts, self.train_metrics)
        return batch_parts['loss']

    @torch.no_grad()
    def validation_step(
        self,
        batch: Dict[str, List],
        batch_idx: int
        ) -> Dict[str, torch.Tensor]:
        """
        [Required by lightning]
        Evaluates on batch of validation samples
        """
        return self.basic_step(batch)

    def validation_step_end(
        self,
        batch_parts: Dict[str, torch.Tensor],
        ):
        """
        [lightning]
        Updates and logs the validation metrics for the step
        """
        self.write_metrics_end(batch_parts, self.val_metrics)

    @torch.no_grad()
    def test_step(
        self,
        batch: Dict[str, List],
        batch_idx: int
        ) -> Dict[str, torch.Tensor]:
        """
        [Required by lightning]
        Evaluates on batch of test samples
        """
        return self.basic_step(batch)

    def test_step_end(
        self,
        batch_parts: Dict[str, torch.Tensor],
        ):
        """
        [lightning]
        Updates and logs the test metrics for the step
        """
        self.write_metrics_end(batch_parts, self.test_metrics)

    def configure_optimizers(self):
        """
        [Required by lightning]
        Initializes the optimizer
        """
        # `self.lr` is the attribute stored in __init__.
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def basic_on_epoch_end(
        self,
        metrics: MetricCollection,
        ):
        """
        Log reduction of metrics, then reset them for the next epoch
        """
        output = metrics.compute()
        self.log_dict(output, on_step=False, on_epoch=True, prog_bar=True)
        # Without the reset the accumulated state would leak into the
        # following epochs.
        metrics.reset()

    def on_train_epoch_end(self):
        """
        [lightning]
        Logging and metric reset (training)
        """
        self.basic_on_epoch_end(self.train_metrics)

    def on_validation_epoch_end(self):
        """
        [lightning]
        Logging and metric reset (validation)
        """
        self.basic_on_epoch_end(self.val_metrics)

    def on_test_epoch_end(self):
        """
        [lightning]
        Logging and metric reset (test)
        """
        self.basic_on_epoch_end(self.test_metrics)


In [52]:
# Parse the first 100,000 samples of the corpus for the actual training run.
samples = load('WSD_Training_Corpora/SemCor+OMSTI/semcor+omsti+synsets.json', size = 100_000)

parsing samples: : 100000it [00:26, 3845.75it/s]
Done.



In [44]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Data module over the loaded samples, 200 samples per batch.
data = WSDDataModule(samples, batch_size=200)

# Seed the random number generators before the model is created.
pl.seed_everything(41296)

# Plain container for the settings read by WSDModel.__init__.
# NOTE(review): these are class attributes, so `hparams.__dict__` (what
# WSDModel hands to save_hyperparameters) is empty for an instance.
class HParams():
    hidden_size = 512
    dropout = 0.0
    vocab_size = len(SYNSETS)
    coalesce_adjacency_matrix = ADJACENCY_MATRIX.coalesce()
    is_split_into_words = True
    train_adjacency = False
    lr = 5e-4

hparams = HParams()

model = WSDModel(hparams)

# Keep only the checkpoint with the highest value of 'val_Accuracy'.
checkpoint_callback = ModelCheckpoint(
    monitor='val_Accuracy',
    filename='wsd-key-{epoch:02d}-{val_Accuracy:.2f}',
    save_top_k=1,
    mode='max',
    )

# Train for at most 6 epochs, on one GPU when CUDA is available.
trainer = pl.Trainer(
    max_epochs=6,
    gpus=(1 if torch.cuda.is_available() else 0),
    callbacks=[checkpoint_callback]
)
trainer.fit(model, data)
# Evaluate on the held-out test split.
trainer.test(model, data.test_dataloader())

Global seed set to 41296
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
