In [10]:
!pip install faiss-cpu




In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import faiss
import time
from transformers import BertTokenizer, BertForSequenceClassification

LIPSCHITZ_CONSTANT = 1.0

In [12]:
class GRACEAdaptor(nn.Module):
    def __init__(self, hidden_dim, epsilon_init=0.5, lipschitz_constant=LIPSCHITZ_CONSTANT):
        super(GRACEAdaptor, self).__init__()
        self.hidden_dim = hidden_dim
        self.epsilon_init = epsilon_init
        self.lipschitz_constant = lipschitz_constant
        self.codebook = []
        self.values = nn.ParameterList([])
        self.epsilon = []
        self.index = faiss.IndexFlatL2(hidden_dim)

    def forward(self, h_l_minus_1):
        if not self.codebook:
            return h_l_minus_1


        hidden_state_np = h_l_minus_1.cpu().detach().numpy()


        D, I = self.index.search(hidden_state_np, 1)


        min_distance = D[0][0]
        min_idx = I[0][0]


        if min_distance < self.epsilon[min_idx]:

            value_update = self.values[min_idx] * self.lipschitz_constant
            h_l = h_l_minus_1 + value_update
            return h_l
        else:
            return h_l_minus_1

    def update_codebook(self, h_l_minus_1, new_value):

        hidden_state_np = h_l_minus_1.cpu().detach().numpy().astype('float32')

        if not self.codebook:
            self.codebook.append(hidden_state_np)
            self.values.append(nn.Parameter(new_value))
            self.epsilon.append(self.epsilon_init)
            self.index.add(hidden_state_np)
        else:
            D, I = self.index.search(hidden_state_np, 1)
            min_distance = D[0][0]
            min_idx = I[0][0]

            if min_distance < self.epsilon[min_idx]:
                self.epsilon[min_idx] += self.epsilon_init
                self.values[min_idx] = nn.Parameter(new_value)
            else:
                self.codebook.append(hidden_state_np)
                self.values.append(nn.Parameter(new_value))
                self.epsilon.append(self.epsilon_init)
                self.index.add(hidden_state_np)


In [13]:
class GRACEModelWrapper(nn.Module):
    def __init__(self, model, grace_layers=None, lipschitz_constant=LIPSCHITZ_CONSTANT):
        super(GRACEModelWrapper, self).__init__()
        self.model = model
        self.grace_layers = grace_layers or []
        self.grace_adaptors = nn.ModuleList([GRACEAdaptor(self.model.config.hidden_size, lipschitz_constant=lipschitz_constant) for _ in self.grace_layers])

    def forward(self, input_ids, attention_mask=None):
        outputs = self.model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        hidden_states = list(outputs.hidden_states)


        for i, grace_adaptor in enumerate(self.grace_adaptors):
            hidden_states[self.grace_layers[i]] = grace_adaptor(hidden_states[self.grace_layers[i]])

        return outputs

    def update_grace(self, input_ids, true_labels, criterion, optimizer, scale_factor=100.0):

        outputs = self.forward(input_ids)
        logits = outputs.logits
        loss = criterion(logits, true_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def measure_performance(model, input_ids, criterion, optimizer, true_labels):
    model.train()
    start_time = time.time()

    outputs = model(input_ids)
    loss = criterion(outputs.logits, true_labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    elapsed_time = time.time() - start_time
    predicted_class = torch.argmax(outputs.logits, dim=-1)

    return loss.item(), predicted_class.item(), elapsed_time


In [15]:
tokenizer = BertTokenizer.from_pretrained('prajjwal1/bert-tiny')
model = BertForSequenceClassification.from_pretrained('prajjwal1/bert-tiny')


grace_model_baseline = GRACEModelWrapper(model, grace_layers=[0, 1])


grace_model_improved = GRACEModelWrapper(model, grace_layers=[0, 1])


criterion = nn.CrossEntropyLoss()
optimizer_baseline = optim.Adam(grace_model_baseline.parameters(), lr=1e-5)
optimizer_improved = optim.Adam(grace_model_improved.parameters(), lr=1e-5)

input_text = "This is an example sentence."
inputs = tokenizer(input_text, return_tensors='pt')
input_ids = inputs['input_ids']
true_labels = torch.tensor([1])

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [16]:
baseline_loss, baseline_prediction, baseline_time = measure_performance(grace_model_baseline, input_ids, criterion, optimizer_baseline, true_labels)
print(f"Baseline - Loss: {baseline_loss}, Prediction: {baseline_prediction}, Time: {baseline_time:.4f} seconds")

Baseline - Loss: 0.7011982202529907, Prediction: 0, Time: 0.1268 seconds


In [17]:
improved_loss, improved_prediction, improved_time = measure_performance(grace_model_improved, input_ids, criterion, optimizer_improved, true_labels)
print(f"Improved - Loss: {improved_loss}, Prediction: {improved_prediction}, Time: {improved_time:.4f} seconds")

Improved - Loss: 0.5796358585357666, Prediction: 1, Time: 0.0922 seconds
