In [52]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer, BertForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput


In [53]:
class GRACEAdaptor(nn.Module):
    """Key/value codebook that conditionally shifts a hidden state.

    Each entry consists of a key (a detached snapshot of a hidden state), a
    learnable value and a radius ``epsilon``.  ``forward`` adds the value of
    the nearest key to its input when the input lies strictly inside that
    key's radius; otherwise the input is returned untouched.

    Args:
        hidden_dim: Stored on the instance; not used by any method of this class.
        epsilon_init: Radius given to every new key, and the amount by which an
            existing radius grows when ``update_codebook`` hits that key again.
    """

    def __init__(self, hidden_dim, epsilon_init=0.5):
        super(GRACEAdaptor, self).__init__()
        # Keys and radii are plain Python lists: they are not registered as
        # buffers, so they are not part of state_dict() and Module.to() does
        # not move them.
        self.codebook = []
        self.hidden_dim = hidden_dim
        self.epsilon_init = epsilon_init
        self.epsilon = []
        # Values are registered parameters and therefore visible to optimizers.
        self.values = nn.ParameterList([])

    def _nearest(self, h_l_minus_1):
        """Return ``(index, distance)`` of the codebook key closest to the input.

        Distance is ``torch.dist`` (2-norm of the difference) over the whole
        tensor.  Ties resolve to the earliest key.  The codebook must be
        non-empty.
        """
        # Keys are moved to the query's device because they do not follow the module.
        distances = [
            torch.dist(h_l_minus_1, key.to(h_l_minus_1.device))
            for key in self.codebook
        ]
        nearest_idx = min(range(len(distances)), key=distances.__getitem__)
        return nearest_idx, distances[nearest_idx]

    def forward(self, h_l_minus_1):
        """Return the input, shifted by the nearest key's value if it is in range.

        Args:
            h_l_minus_1: Hidden-state tensor to look up.

        Returns:
            ``h_l_minus_1 + value`` when the nearest key is closer than its
            epsilon, otherwise ``h_l_minus_1`` itself (same object).
        """
        if not self.codebook:
            return h_l_minus_1
        nearest_idx, distance = self._nearest(h_l_minus_1)
        if distance < self.epsilon[nearest_idx]:
            return h_l_minus_1 + self.values[nearest_idx]
        return h_l_minus_1

    def update_codebook(self, h_l_minus_1, new_value):
        """Insert a new entry, or widen and overwrite the entry already covering the key.

        If the nearest existing key is closer than its epsilon, that entry's
        epsilon grows by ``epsilon_init`` and its value is replaced.  Otherwise
        (including the empty-codebook case) a new entry is appended with radius
        ``epsilon_init``.

        Args:
            h_l_minus_1: Hidden-state tensor used as the key; a detached copy is stored.
            new_value: Tensor stored as the entry's learnable value; it is
                copied, so later in-place updates of the stored parameter do
                not modify the caller's tensor.
        """
        # Copy before wrapping: nn.Parameter(tensor) would share storage with the caller.
        stored_value = nn.Parameter(new_value.clone().detach())

        if self.codebook:
            nearest_idx, distance = self._nearest(h_l_minus_1)
            if distance < self.epsilon[nearest_idx]:
                self.epsilon[nearest_idx] += self.epsilon_init
                # NOTE(review): this swaps in a new Parameter object; an
                # optimizer created earlier keeps referencing the old one.
                self.values[nearest_idx] = stored_value
                return

        self.codebook.append(h_l_minus_1.clone().detach())
        self.values.append(stored_value)
        self.epsilon.append(self.epsilon_init)


In [54]:
class GRACEModelWrapper(nn.Module):
    """Thin wrapper around a sequence-classification model with a one-step update helper.

    Args:
        model: Wrapped module.  It is called as
            ``model(input_ids, attention_mask=..., output_hidden_states=True)``
            and its result must expose a ``logits`` attribute.
        grace_layers: Optional list stored on the instance.  None of the
            methods in this class reads it.
    """

    def __init__(self, model, grace_layers=None):
        super(GRACEModelWrapper, self).__init__()
        self.model = model
        # Kept as a plain list, so anything placed in it is not registered as a submodule.
        self.grace_layers = grace_layers or []

    def forward(self, input_ids, attention_mask=None):
        """Run the wrapped model with hidden states enabled and return its output object."""
        return self.model(input_ids, attention_mask=attention_mask, output_hidden_states=True)

    def update_grace(self, input_ids, true_labels, criterion, optimizer, scale_factor=100.0,
                     attention_mask=None):
        """Take one optimizer step on ``criterion(logits, true_labels)`` and print shifted logits.

        Args:
            input_ids: Token-id tensor passed to the wrapped model.
            true_labels: Integer class indices, one per row of ``input_ids``.
            criterion: Callable ``(logits, true_labels) -> scalar loss``.
            optimizer: Optimizer whose ``zero_grad``/``step`` are invoked once.
            scale_factor: Multiplier applied to ``1 - true_class_logit`` when
                building the printed logits.
            attention_mask: Optional mask forwarded to the wrapped model.

        Returns:
            None.  The effects are the optimizer step and one printed line.
        """
        outputs = self.forward(input_ids, attention_mask=attention_mask)
        loss = criterion(outputs.logits, true_labels)
        optimizer.zero_grad()
        loss.backward()

        logits = outputs.logits.clone().detach()
        # Pick each row's own true-class logit; indexing with true_labels.item()
        # only works when the batch holds exactly one label.
        true_class_logits = logits.gather(1, true_labels.view(-1, 1)).squeeze(1)
        # NOTE(review): the offset is broadcast across every class of a row, so it
        # cannot change the argmax, and `outputs` is local to this call, so the
        # assignment below has no effect after returning.  Confirm the intended formula.
        modified_logits = logits + (scale_factor * (1 - true_class_logits)).unsqueeze(1)
        outputs.logits = modified_logits
        print(f"Modified logits: {modified_logits}")
        optimizer.step()


In [55]:
# The checkpoint ships without a classification head, so that head is randomly
# initialised on load; seed first so every run starts from the same weights.
SEED = 0
torch.manual_seed(SEED)

# Single source of truth for the checkpoint shared by tokenizer and model.
MODEL_NAME = 'prajjwal1/bert-tiny'
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertForSequenceClassification.from_pretrained(MODEL_NAME)

grace_model = GRACEModelWrapper(model)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [56]:
# Adam over all parameters reachable from the wrapper.
optimizer = optim.Adam(params=grace_model.parameters(), lr=1e-5)
# Cross-entropy between the logits and the target class indices.
criterion = nn.CrossEntropyLoss()


In [57]:
# One sample sentence, tokenised into PyTorch tensors; only the ids are kept.
input_text = "This is an example sentence."
inputs = tokenizer(text=input_text, return_tensors='pt')
input_ids = inputs['input_ids']


In [58]:
# Inference only: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    outputs = grace_model(input_ids)
predicted_class = torch.argmax(outputs.logits, dim=-1)
print(f"Predicted class before correction: {predicted_class.item()}")


Predicted class before correction: 1


In [61]:
# Target class for the example sentence, as a one-element label tensor.
true_labels = torch.tensor([0])

# Apply a single update step towards the target class.
grace_model.update_grace(
    input_ids=input_ids,
    true_labels=true_labels,
    criterion=criterion,
    optimizer=optimizer,
    scale_factor=500.0,
)


Modified logits: tensor([[698.4131, 698.6824]])


In [63]:
# Inference only: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    outputs_after_correction = grace_model(input_ids)
predicted_class_after_correction = torch.argmax(outputs_after_correction.logits, dim=-1)
print(f"Predicted class after correction: {predicted_class_after_correction.item()}")


Predicted class after correction: 1
