# Inference of wsd classifiers

In [1]:
import accelerate
import transformers
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertModel, BertTokenizerFast, BertConfig,AdamW
from torch.nn import CrossEntropyLoss
import torch.nn as nn
import torch.nn.functional as F  # Import the functional module to apply softmax

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
import os
import re
from  tqdm import tqdm
import pandas as pd
import numpy as np

In [3]:
# functions
class ManageDataset(Dataset):
    def __init__(self, tokenizer, sentences, labels, target_char_spans):
        self.tokenizer = tokenizer
        self.sentences = sentences
        self.labels = labels
        self.char_spans = target_char_spans

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        # Tokenize the sentence into BERT tokens with offset mappings (fast tokenizer)
        inputs = self.tokenizer(
            self.sentences[idx],
            return_tensors="pt",
            truncation=True,
            padding='max_length',
            max_length=256,
            return_offsets_mapping=True # return tuple indicating the sub-token's start position
        )

        # Generate the manag_mask
        manag_mask = self._get_manag_mask(
            self.sentences[idx],
            inputs["input_ids"][0],
            inputs["offset_mapping"][0],
            self.char_spans[idx]
        )

        # Return tokens' embeddings and the label
        return {
            "input_ids": inputs["input_ids"][0],
            "attention_mask": inputs["attention_mask"][0],
            "manag_mask": manag_mask,
            "labels": torch.tensor(self.labels[idx], dtype=torch.long)
        }

    def _get_manag_mask(self, sentence, input_ids, offset_mapping, target_char_span):
        # Initialize manag_mask
        manag_mask = torch.zeros_like(input_ids, dtype=torch.bool)
        # Iterate over BERT tokens and align with target word's character span
        for i, (start, end) in enumerate(offset_mapping):
            if start == 0 and end == 0:
                continue  # Skip special tokens like [CLS], [SEP], [PAD]
            if (start >= target_char_span[0] and start < target_char_span[1]) or \
               (end > target_char_span[0] and end <= target_char_span[1]) or \
               (start <= target_char_span[0] and end >= target_char_span[1]):
                manag_mask[i] = True
        return manag_mask

class BERTWSDModel(nn.Module):
    def __init__(self, bert_model_name='bert-base-uncased', num_labels=2):
        super(BERTWSDModel, self).__init__()
        # Load pre-trained BERT model
        self.bert = BertModel.from_pretrained(bert_model_name)
        # Classification head
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
        # Dropout layer for regularization
        self.dropout = nn.Dropout(p=0.3)
        # Save the configuration
        self.config = self.bert.config
        self.num_labels = num_labels

    def forward(self, input_ids, attention_mask, manag_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state  # (batch_size, seq_length, hidden_size)

        # Apply manag_mask to get embeddings of target tokens
        manag_mask_expanded = manag_mask.unsqueeze(-1).expand(last_hidden_state.size())
        target_embeddings = last_hidden_state * manag_mask_expanded.float()

        # Compute average embeddings for each sample in the batch
        token_counts = manag_mask.sum(dim=1).unsqueeze(-1)  # (batch_size, 1)
        # Avoid division by zero
        token_counts[token_counts == 0] = 1
        avg_embeddings = target_embeddings.sum(dim=1) / token_counts  # (batch_size, hidden_size)

        # Apply dropout
        pooled_output = self.dropout(avg_embeddings)

        # Get logits from classifier
        logits = self.classifier(pooled_output)  # (batch_size, num_labels)

        return logits

    def save_pretrained(self, save_directory):
        if not os.path.exists(save_directory):
            os.makedirs(save_directory)
        torch.save(self.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))
        self.config.save_pretrained(save_directory)
        print(f"Model saved to {save_directory}")

    @classmethod
    def from_pretrained(cls, load_directory):
        # Load the model configuration
        config = BertModel.from_pretrained(load_directory).config
        # Initialize the model
        model = cls(bert_model_name=load_directory)
        # Load the model state dict
        model_load_path = os.path.join(load_directory, 'pytorch_model.bin')
        if torch.cuda.is_available():
            model.load_state_dict(torch.load(model_load_path))
            model = model.to('cuda')
        else:
            model.load_state_dict(torch.load(model_load_path, map_location=torch.device('cpu')))
        return model


In [4]:
# Load your model and tokenizer
save_directory = "/zfs/projects/faculty/amirgo-management/BERT/WSD_Oct21/"
loaded_tokenizer = BertTokenizerFast.from_pretrained(save_directory)
loaded_model = BERTWSDModel.from_pretrained(save_directory)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loaded_model.to(device)

  model.load_state_dict(torch.load(model_load_path, map_location=torch.device('cpu')))


BERTWSDModel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_

In [5]:
def infer(sentences, char_spans, model, tokenizer, batch_size=10):
    dataset = ManageDataset(tokenizer, sentences, [0]*len(sentences), char_spans)  # Dummy labels just for data processing
    loader = DataLoader(dataset, batch_size)  # Set batch size according to your needs

    model.eval()
    pred_labels = []
    confidences = []  # To store prediction confidences

    with torch.no_grad():
        for batch in tqdm(loader, desc="Inferencing", unit="batch"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            manag_mask = batch["manag_mask"].to(device)
            labels = batch["labels"].to(device)

            logits = model(input_ids, attention_mask, manag_mask)

            # Convert logits to probabilities using softmax
            probs = F.softmax(logits, dim=1)

            # Get the predicted labels and their corresponding confidences
            preds = torch.argmax(logits, dim=1)
            conf = probs[range(probs.shape[0]), preds].tolist()  # Get the confidence of the predicted class for each sample

            pred_labels.extend(preds.tolist())
            confidences.extend(conf)

    return pred_labels, confidences  # Return both predicted labels and their confidences

In [15]:
label_dict = {0: "Intransitive", 1: "Transitive"}

def get_word_char_spans(sentence, words):
    char_spans = []
    current_pos = 0
    for word in words:
        pattern = re.escape(word)
        match = re.search(pattern, sentence[current_pos:])
        if match is None:
            raise ValueError(f"Word '{word}' not found in sentence.")
        start_idx = current_pos + match.start()
        end_idx = current_pos + match.end()
        char_spans.append((start_idx, end_idx))
        current_pos = end_idx
    return char_spans

def infer_individual_sentence(sentence, target_word, model=loaded_model, tokenizer=loaded_tokenizer):
    char_span = get_word_char_spans(sentence, [target_word])[0]
    pred_labels,confidences = infer([sentence],[char_span], model, tokenizer)
    print(label_dict[pred_labels[0]], confidences[0])
    return char_span

In [14]:
# individual prediction
test = "I don't think I could really manage to do that."
infer_individual_sentence(test, 'manage')

Inferencing: 100%|██████████| 1/1 [00:00<00:00,  9.08batch/s]

Intransitive 0.9905189871788025





array([29, 35])

In [15]:
# individual prediction
test = "We barely manage this year."
infer_individual_sentence(test, 'manage')

Inferencing: 100%|██████████| 1/1 [00:00<00:00,  5.49batch/s]

Intransitive 0.9639004468917847





In [9]:
# individual prediction
test = "I don't think I could really manage being married to your mother anymore."
infer_individual_sentence(test, 'manage')

Inferencing: 100%|██████████| 1/1 [00:00<00:00,  5.12batch/s]

Intransitive 0.9039086699485779





In [10]:
test = "For one year that remained they could manage; if george wasn't willing to try, it wasn't money that was stopping him., it was the idea of marriage itself."
infer_individual_sentence(test, 'manage')

Inferencing: 100%|██████████| 1/1 [00:00<00:00,  5.21batch/s]


Intransitive 0.5622024536132812


In [12]:
test = "A wealthy widow and a unmanageable daughter."
infer_individual_sentence(test, 'unmanageable')

Inferencing: 100%|██████████| 1/1 [00:00<00:00,  4.99batch/s]

Transitive 0.8727024793624878





In [13]:
# not the perfect kind of training dataset, but I don't think it's a big problem as the major trend should be captured
test = "The doctor know how to manage patients with mental health issues."
infer_individual_sentence(test, 'manage')

Inferencing: 100%|██████████| 1/1 [00:00<00:00,  6.78batch/s]

Transitive 0.996315062046051





# file load examples