# Inference of PB classifiers

In [1]:
import accelerate
import transformers
import re
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertModel, BertTokenizerFast

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class ManageDataset(Dataset):
    """Sentences paired with one target character span and two labels each.

    Every item is tokenized on access and returned together with a boolean
    ``manag_mask`` that marks the tokens overlapping the target span.

    Args:
        tokenizer: Callable tokenizer. It is invoked with
            ``return_tensors="pt"`` and ``return_offsets_mapping=True`` and
            must return ``input_ids``, ``attention_mask`` and
            ``offset_mapping``.
        sentences: Sequence of input sentences.
        primary_labels: Primary category label for each sentence.
        subcategory_labels: Subcategory label for each sentence.
        target_char_spans: One ``(start, end)`` character span per sentence.
            The comparisons in ``_get_manag_mask`` treat ``end`` as exclusive.
        max_length: Length every item is padded/truncated to. Defaults to
            256, the value that used to be hard-coded.
    """

    def __init__(self, tokenizer, sentences, primary_labels, subcategory_labels,
                 target_char_spans, max_length=256):
        self.tokenizer = tokenizer
        self.sentences = sentences
        self.primary_labels = primary_labels  # List of primary category labels
        self.subcategory_labels = subcategory_labels  # List of subcategory labels
        self.char_spans = target_char_spans  # List of character spans for target words
        self.max_length = max_length  # Padded/truncated sequence length per item

    def __len__(self):
        """Return the number of sentences."""
        return len(self.sentences)

    def __getitem__(self, idx):
        """Tokenize sentence ``idx`` and return model inputs, mask and labels.

        Returns:
            dict with ``input_ids``, ``attention_mask``, ``manag_mask`` (bool)
            and the two labels as ``torch.long`` tensors.
        """
        sentence = self.sentences[idx]
        # Offsets map each sub-token back to its character range in `sentence`.
        inputs = self.tokenizer(
            sentence,
            return_tensors="pt",
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_offsets_mapping=True,
        )
        # The tokenizer returns a batch of one; drop the batch dimension.
        input_ids = inputs["input_ids"][0]

        # NOTE: if truncation drops the target tokens, no entry of the mask
        # will be True for this item.
        manag_mask = self._get_manag_mask(
            sentence,
            input_ids,
            inputs["offset_mapping"][0],
            self.char_spans[idx],
        )

        return {
            "input_ids": input_ids,
            "attention_mask": inputs["attention_mask"][0],
            "manag_mask": manag_mask,
            "primary_labels": torch.tensor(self.primary_labels[idx], dtype=torch.long),
            "subcategory_labels": torch.tensor(self.subcategory_labels[idx], dtype=torch.long),
        }

    def _get_manag_mask(self, sentence, input_ids, offset_mapping, target_char_span):
        """Return a bool mask (shaped like ``input_ids``) of target tokens.

        A token is marked when its ``(start, end)`` offset starts inside the
        span, ends inside the span, or fully covers the span. Tokens whose
        offset is ``(0, 0)`` are never marked.

        Args:
            sentence: The raw sentence (unused; kept for interface stability).
            input_ids: 1-D tensor of token ids; only its shape/device is used.
            offset_mapping: Per-token ``(start, end)`` character offsets.
            target_char_span: ``(start, end)`` character span of the target.
        """
        span_start, span_end = target_char_span
        offsets = torch.as_tensor(offset_mapping).reshape(-1, 2)
        token_starts = offsets[:, 0]
        token_ends = offsets[:, 1]

        # Offset (0, 0) identifies special tokens like [CLS], [SEP], [PAD].
        is_special = (token_starts == 0) & (token_ends == 0)

        # Same three overlap clauses as a per-token check, evaluated for all
        # tokens at once instead of in a Python loop.
        starts_inside = (token_starts >= span_start) & (token_starts < span_end)
        ends_inside = (token_ends > span_start) & (token_ends <= span_end)
        covers_span = (token_starts <= span_start) & (token_ends >= span_end)
        hits = (starts_inside | ends_inside | covers_span) & ~is_special

        manag_mask = torch.zeros_like(input_ids, dtype=torch.bool)
        manag_mask[: hits.shape[0]] = hits.to(manag_mask.device)
        return manag_mask


class BERTClassificationModel(nn.Module):
    """BERT encoder with two linear heads over the mean target-token embedding.

    The hidden states of the tokens selected by ``manag_mask`` are averaged
    per sample, passed through dropout, and fed to a primary-category head
    and a subcategory head.

    Args:
        bert_model_name: Name or directory passed to ``BertModel.from_pretrained``.
        num_primary_labels: Output size of the primary classifier.
        num_subcategory_labels: Output size of the subcategory classifier.
    """

    def __init__(self, bert_model_name='bert-base-uncased', num_primary_labels=3, num_subcategory_labels=10):
        super().__init__()
        # Load pre-trained BERT model
        self.bert = BertModel.from_pretrained(bert_model_name)
        hidden_size = self.bert.config.hidden_size
        # Classification heads for primary and subcategories
        self.primary_classifier = nn.Linear(hidden_size, num_primary_labels)
        self.subcategory_classifier = nn.Linear(hidden_size, num_subcategory_labels)
        # Dropout layer for regularization
        self.dropout = nn.Dropout(p=0.3)
        # Keep the encoder config so it can be saved next to the weights
        self.config = self.bert.config
        self.num_primary_labels = num_primary_labels
        self.num_subcategory_labels = num_subcategory_labels

    def forward(self, input_ids, attention_mask, manag_mask):
        """Return ``(primary_logits, subcategory_logits)`` for a batch.

        Args:
            input_ids: Token ids, (batch_size, seq_length).
            attention_mask: Attention mask, (batch_size, seq_length).
            manag_mask: Bool mask of target tokens, (batch_size, seq_length).
                It is not modified.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = encoded.last_hidden_state  # (batch_size, seq_length, hidden_size)

        # Zero out every non-target token; the trailing singleton dimension
        # broadcasts over hidden_size.
        target_embeddings = last_hidden_state * manag_mask.unsqueeze(-1).float()

        # Mean over target tokens. Samples with an empty mask divide by 1
        # (instead of 0) and therefore yield an all-zero embedding.
        token_counts = manag_mask.sum(dim=1, keepdim=True).clamp(min=1)  # (batch_size, 1)
        avg_embeddings = target_embeddings.sum(dim=1) / token_counts  # (batch_size, hidden_size)

        pooled_output = self.dropout(avg_embeddings)

        primary_logits = self.primary_classifier(pooled_output)  # (batch_size, num_primary_labels)
        subcategory_logits = self.subcategory_classifier(pooled_output)  # (batch_size, num_subcategory_labels)
        return primary_logits, subcategory_logits

    def save_pretrained(self, save_directory):
        """Write the state dict and the config (with label counts) to a directory.

        Creates ``save_directory`` if needed, stores the weights as
        ``pytorch_model.bin`` and saves the config via its own
        ``save_pretrained``.
        """
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(save_directory, exist_ok=True)
        # Save model state dict
        torch.save(self.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))
        # Store the label counts on the config so from_pretrained can rebuild the heads
        self.config.num_primary_labels = self.num_primary_labels
        self.config.num_subcategory_labels = self.num_subcategory_labels
        self.config.save_pretrained(save_directory)
        print(f"Model saved to {save_directory}")

    @classmethod
    def from_pretrained(cls, load_directory):
        """Rebuild a model written by ``save_pretrained``.

        Reads the label counts from the saved config, constructs the model,
        loads ``pytorch_model.bin`` and moves the model to CUDA when available.

        Args:
            load_directory: Directory containing the config and ``pytorch_model.bin``.

        Returns:
            The restored model (on ``cuda`` if available, otherwise on CPU).
        """
        # Read only the config; instantiating a full BertModel here would load
        # the encoder weights just to throw them away.
        config = transformers.BertConfig.from_pretrained(load_directory)
        model = cls(
            bert_model_name=load_directory,
            num_primary_labels=config.num_primary_labels,
            num_subcategory_labels=config.num_subcategory_labels,
        )
        model_load_path = os.path.join(load_directory, 'pytorch_model.bin')
        # The checkpoint is a plain state dict, so restrict unpickling to
        # tensors (weights_only). Always deserialize onto the CPU and move the
        # whole model afterwards, so one code path serves both CPU and GPU hosts.
        state_dict = torch.load(model_load_path, map_location=torch.device('cpu'), weights_only=True)
        model.load_state_dict(state_dict)
        if torch.cuda.is_available():
            model = model.to('cuda')
        return model

In [3]:
# Restore the fine-tuned classifier and its tokenizer from the checkpoint directory.
save_directory = "/zfs/projects/faculty/amirgo-management/BERT/PB_MultiClass_Full_Oct30/"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
tokenizer = BertTokenizerFast.from_pretrained(save_directory)
model = BERTClassificationModel.from_pretrained(save_directory)
# Kept as the last expression of the cell so the notebook shows the module tree.
model.to(device)

  model.load_state_dict(torch.load(model_load_path, map_location=torch.device('cpu')))


BERTClassificationModel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, e

In [4]:
def infer(sentences, char_spans, model, tokenizer, batch_size=10):
    """Predict primary and subcategory labels for the target span of each sentence.

    Args:
        sentences: Sequence of sentences.
        char_spans: One ``(start, end)`` target character span per sentence.
        model: Classifier called as ``model(input_ids, attention_mask, manag_mask)``
            and returning ``(primary_logits, subcategory_logits)``. It is
            switched to eval mode and left in that mode.
        tokenizer: Tokenizer handed to ``ManageDataset``.
        batch_size: Number of sentences per forward pass.

    Returns:
        Four parallel lists of NumPy scalars: predicted primary label ids,
        predicted subcategory label ids, primary confidences and subcategory
        confidences (the softmax probability of the predicted class).
    """
    # Labels are not used for inference; zeros only satisfy the dataset interface.
    placeholder_labels = [0] * len(sentences)
    dataset = ManageDataset(tokenizer, sentences, placeholder_labels, placeholder_labels, char_spans)
    loader = DataLoader(dataset, batch_size=batch_size)

    # Run on whatever device the given model lives on instead of relying on a
    # module-level `device` global that may not match this model.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    pred_primary_labels = []
    pred_subcategory_labels = []
    primary_confidences = []
    subcategory_confidences = []

    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(run_device)
            attention_mask = batch['attention_mask'].to(run_device)
            manag_mask = batch['manag_mask'].to(run_device)

            primary_logits, subcategory_logits = model(input_ids, attention_mask, manag_mask)

            # One max() per head yields both the winning probability
            # (confidence) and its index (predicted label).
            primary_confidence, primary_preds = torch.softmax(primary_logits, dim=1).max(dim=1)
            subcategory_confidence, subcategory_preds = torch.softmax(subcategory_logits, dim=1).max(dim=1)

            pred_primary_labels.extend(primary_preds.cpu().numpy())
            pred_subcategory_labels.extend(subcategory_preds.cpu().numpy())
            primary_confidences.extend(primary_confidence.cpu().numpy())
            subcategory_confidences.extend(subcategory_confidence.cpu().numpy())

    return pred_primary_labels, pred_subcategory_labels, primary_confidences, subcategory_confidences

In [5]:
# Label vocabularies: name -> integer id, plus the inverse id -> name lookups
# used to turn predicted ids back into readable labels.
secondary_map_num = {
    label: index
    for index, label in enumerate((
        "Others",
        "Financials",
        "Emotion and subjective experiences",
        "Human body",
        "Household",
        "Family",
        "Time",
        "Romantic relationships",
        "Friendship",
        "Business Operations",
    ))
}

primary_map_num = {
    label: index
    for index, label in enumerate(('Personal', 'Business and Professional', 'Others'))
}

reverse_primary_map = dict(zip(primary_map_num.values(), primary_map_num.keys()))
reverse_secondary_map_num = dict(zip(secondary_map_num.values(), secondary_map_num.keys()))


def get_word_char_spans(sentence, words):
    """Locate each word in ``sentence``, in order, and return its character span.

    The words are searched left to right: the search for each word starts
    where the previous match ended. For every word a whole-word occurrence
    (not embedded in a longer run of word characters) is preferred, so that
    e.g. ``"manage"`` is not matched inside an earlier ``"manager"``. If no
    whole-word occurrence exists, the first plain substring occurrence is
    used. Matching is case-sensitive.

    Args:
        sentence: The text to search.
        words: Words to locate, in their order of appearance.

    Returns:
        List of ``(start, end)`` character spans (``end`` exclusive), one per word.

    Raises:
        ValueError: If a word does not occur after the previous match.
    """
    char_spans = []
    current_pos = 0
    for word in words:
        match = None
        if word:
            # Lookarounds instead of \b so words that begin or end with a
            # non-word character (e.g. "C++") can still match as whole words.
            # Searching with `pos` (rather than slicing) keeps the preceding
            # characters visible to the lookbehind.
            whole_word = re.compile(r"(?<!\w)" + re.escape(word) + r"(?!\w)")
            match = whole_word.search(sentence, current_pos)
        if match is not None:
            start_idx = match.start()
        else:
            # Fall back to a plain substring match.
            start_idx = sentence.find(word, current_pos)
            if start_idx == -1:
                raise ValueError(f"Word '{word}' not found in sentence.")
        end_idx = start_idx + len(word)
        char_spans.append((start_idx, end_idx))
        current_pos = end_idx
    return char_spans

def infer_individual_sentence(sentence, target_word):
    """Print the predicted labels and confidences for one target word.

    The span of ``target_word`` is looked up with ``get_word_char_spans`` and
    the sentence is scored with the module-level ``model`` and ``tokenizer``.
    Two lines are printed: the primary label with its confidence, then the
    subcategory label with its confidence. Nothing is returned.
    """
    spans = get_word_char_spans(sentence, [target_word])
    primary_ids, subcategory_ids, primary_scores, subcategory_scores = infer(
        [sentence], [spans[0]], model, tokenizer
    )
    print(reverse_primary_map[primary_ids[0]], primary_scores[0])
    print(reverse_secondary_map_num[subcategory_ids[0]], subcategory_scores[0])

In [8]:
# individual prediction
# The sentence also contains "managing" and "manage", but a single span is
# requested for the target word "manager", so only that span is scored.
test = "the manager is good at managing his children, but he doesn't know how to manage his employee."
infer_individual_sentence(test, 'manager')

Business and Professional 0.9943504
Business Operations 0.9914539


In [9]:
# Short sentence with the noun "manager" as the target word.
test = "Joe is also a good manager."
infer_individual_sentence(test, 'manager')

Business and Professional 0.9757144
Business Operations 0.9939633


In [10]:
# Verb "manage" as the target, used in the "manage to do something" sense.
test = "I manage, with all my effort, to smile at their customers."
infer_individual_sentence(test, 'manage')

Others 0.99932766
Others 0.9985071


In [12]:
# Noun "management" as the target word.
test = "She still need to work on management skills."
infer_individual_sentence(test, 'management')

Business and Professional 0.9210463
Business Operations 0.85461885


In [13]:
# Noun "manager" as the target, in a sentence about a theatre.
test = "The manager, foreseeing a thining theatre, gave us free admission."
infer_individual_sentence(test, 'manager')

Business and Professional 0.999788
Business Operations 0.9992181


In [14]:
# Sentence-initial target. The span lookup is case-sensitive, so the target
# is passed with the same capitalization as in the sentence.
test = "Managing sexuallity is a difficult task for many people."
infer_individual_sentence(test, 'Managing')

Personal 0.9624546
Romantic relationships 0.94145226


In [15]:
# Verb "manages" as the target word.
test = "She manages a weak smile, but her eyes are full of tears."
infer_individual_sentence(test, 'manages')

Personal 0.99833554
Emotion and subjective experiences 0.99834096


In [27]:
# Verb "manage" as the target, in a sentence about a doctor and patients.
# Author's note: the training dataset is not perfect for this kind of
# sentence, but that should not be a big problem, since the major trend
# should still be captured.
test = "The doctor know how to manage patients with mental health issues."
infer_individual_sentence(test, 'manage')

Personal 0.9860299
Human body 0.98886174
