In [1]:
###reproducing the author's model###
# Building models
import torch
import torch.nn as nn
import torch.nn.functional as F
!pip install transformers
from transformers import BertModel, BertTokenizer 
import tqdm

# Building datasets
# import src.preprocess
# import os
# import configs
from torch.utils.data import DataLoader

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"



In [2]:
class BertFC(nn.Module):
    def __init__(self):
        """Downloads a BERT base uncased model and adds a linear layer on top of it"""
        super(BertFC, self).__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.classifier = nn.Linear(768, 1)

    def forward(self, ids, token_ids, mask):
        """
        inputs: ids, token_ids, and mask each of dim = [bsz x seqlen]
        returns: probability of a positive label, dim = [bsz]
        """
        sequence_output = self.bert(input_ids = ids, token_type_ids = token_ids, attention_mask = mask)[0]
        pooled_output = self.bert(input_ids = ids, token_type_ids = token_ids, attention_mask = mask)[1]
        # sequence_output has the following shape: (batch_size, sequence_length, 768)
        # sequence_output = nn.ReLU()(sequence_output)
        # sequence_output = torch.tanh(sequence_output)
        # linear_output = self.classifier(sequence_output[:, 0, :])
        output = self.classifier(pooled_output)
        return output.squeeze(0)

# Build the model
bert_fc_model = BertFC()


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
class hypernetwork(nn.Module):                                                                       
    def __init__(self,                                                                               
                 n,                                                                                  
                 m,                                                                                  
                 input_size=768,  # verified                          
                 hidden_size=128,  # verified                         
                 linear_out=1024,  # verified                         
                 num_layers=1):  # verified                           
        """                                                                                          
        Args:                                                                                        
            n:                                                                                       
            m:                                                                                       
            input_size:                                                                              
            hidden_size:                                                                             
            linear_out:                                                                              
            num_layers:                                                                              
        """                                                                                          
        super(hypernetwork, self).__init__()                          
                                                                                                     
        self.bilstm = nn.LSTM(input_size=input_size,                  
                              hidden_size=hidden_size,                
                              num_layers=num_layers,                  
                              bidirectional=True,
                              batch_first = True)     
        self.num_layers = num_layers
        self.hidden_size = hidden_size           
        self.linear = nn.Linear(2*hidden_size, linear_out)            
        self.alpha_linear = nn.Linear(linear_out, m)                 
        self.beta_linear = nn.Linear(linear_out, m)                  
        self.gamma_linear = nn.Linear(linear_out, n)                 
        self.delta_linear = nn.Linear(linear_out, n)                 
                                                                                                     
        # TODO(rajiv): Maybe an intermediate layer?                   
        self.eta_linear = nn.Linear(linear_out, 1)                   
                                                                                                     
    def forward(self, X, gradW):                                                                     
        """                                                                                          
                                                                                                     
        Args:                                                                                        
            X: Vector(input [SEP] SUPPORTS [SEP] REFUTES)             
            gradW: gradients of `finetuned_bert`.                     
        Returns:                                                                                     
                                                                                                     
        """                                                                                          
        # TODO(rajiv): Make sure X is input + y +  a and not just input.                                                                                                     
        # setting hidden states to allow lstm to run
        hidden = torch.zeros((2*self.num_layers, 1, self.hidden_size))
        cell   = torch.zeros((2*self.num_layers, 1, self.hidden_size))
        # TODO: With L of 512, is the BiLSTM model going to forget parameters 
        # through time? Can we mask the bilstm input?
        _, (hidden, _) = self.bilstm(X, (hidden, cell)) #hidden dim [2 x 1 x 128]
        output = torch.tanh(self.linear(hidden.flatten())) #[1024]

        alpha = self.alpha_linear(output) #[m]                          
        beta = self.beta_linear(output)   #[m]                         
        gamma = self.gamma_linear(output) #[n]                         
        delta = self.delta_linear(output) #[n]                            
        eta = self.eta_linear(output)     #[1]
        print("gradW dim: ", gradW.size())
        print("alpha dim: ", alpha.size())                               
        print("beta dim: ", beta.size())
        print("gamma dim: ", gamma.size())
        print("delta dim: ", delta.size())
        print("eta dim: ", eta.size())
                                                                                                     
        # TODO(rajiv): While computing *_hat, we are assuming that the first
        # dimension would be the batch dimension. So we transpose the last two
        # dimensions.                                                                                
        alpha_hat = torch.outer(gamma, F.softmax(alpha))
        beta_hat = torch.outer(delta, F.softmax(beta))
        print("alpha_hat size: ", alpha_hat.size())
        print("beta_hat size: ", beta_hat.size())
                                                                                                     
        delW = torch.sigmoid(eta) * ((alpha_hat * gradW) + beta_hat)
        return delW
                                                                        

In [4]:
def get_attributes(module, names):
  """"
  inputs: Base Module and a list of module names
  returns: the corresponding module
  """
  if len(names) == 1:
    return getattr(module, names[0])
  else:
    return get_attributes(getattr(module, names[0]), names[1:])

class KnowledgeEditor(nn.Module):                                     
  def __init__(self, BERT_model):                                   
    """                                                           
    given a bert model, set up a hypernetwork                     
    for every non-bias, embedding, or layer-norm                  
    """                                                           
    super(KnowledgeEditor, self).__init__()                       
    self.bert = BERT_model                                        
    self.hyper_network_dict = nn.ModuleDict()                                  
    for name, layer in BERT_model.named_parameters():            
      if ("LayerNorm" not in name and 
          "bias" not in name and 
          "embed" not in name):
      # Layers of size NxM 
        h = hypernetwork(layer.size()[0], layer.size()[1])
        self.hyper_network_dict[str(name).replace(".", "-")] = h
                                                                      
  def forward(self, X_ids, X_type_ids, X_mask, A, 
              X_Y_A_ids, X_Y_A_type_ids, X_Y_A_mask):
    """takes tokenized input X, alternative answer A, and 
    hypernetwork specialized input X-<SEP>-Y-<SEP>-A                        
    returns updated parameters in a dictionary of {Layer_name: delta_params}
    """                                 
    # run the BERT model forward                                  
    self.bert.zero_grad()
    outputs = self.bert(X_ids, X_type_ids, X_mask)       
    bert_loss = torch.nn.BCEWithLogitsLoss()
    bert_losses = bert_loss(outputs, A)
    bert_losses.backward() #get the gradients of BERT
    #get pooled encoding from BERT
    encoded_input = self.bert.bert(X_Y_A_ids, X_Y_A_type_ids, X_Y_A_mask)[0] 

    # Run the hypernetwork
    h_outputs = {}
    for layer, network in self.hyper_network_dict.items():
      #get the grad for the layer in question
      gradW = get_attributes(self.bert, layer.split("-") + ['grad'])
      #puts the new parameters in a dictionary of layer names 
      h_outputs[layer] = network(encoded_input, gradW)

    return h_outputs

ke = KnowledgeEditor(bert_fc_model)

In [5]:
from transformers import BertTokenizer
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
token = bert_tokenizer("this is a test", max_length = 512, 
                       padding='max_length',
                       truncation=True)
x_id = torch.tensor(token.input_ids).unsqueeze(0)
x_type_id = torch.tensor(token.token_type_ids).unsqueeze(0)
x_mask = torch.tensor(token.attention_mask).unsqueeze(0)
a = torch.FloatTensor([1])

token = bert_tokenizer(["this is a test [SEP] 1 [SEP] 0"], max_length = 512, 
                       padding='max_length',
                       truncation=True)

xya_id = torch.tensor(token.input_ids)
xya_type_id = torch.tensor(token.token_type_ids)
xya_mask = torch.tensor(token.attention_mask)


In [6]:
ke(x_id, x_type_id, x_mask, a, xya_id, xya_type_id, xya_mask)

gradW dim:  torch.Size([768, 768])
alpha dim:  torch.Size([768])
beta dim:  torch.Size([768])
gamma dim:  torch.Size([768])
delta dim:  torch.Size([768])
eta dim:  torch.Size([1])
alpha_hat size:  torch.Size([768, 768])
beta_hat size:  torch.Size([768, 768])
gradW dim:  torch.Size([768, 768])
alpha dim:  torch.Size([768])
beta dim:  torch.Size([768])
gamma dim:  torch.Size([768])
delta dim:  torch.Size([768])
eta dim:  torch.Size([1])
alpha_hat size:  torch.Size([768, 768])
beta_hat size:  torch.Size([768, 768])




gradW dim:  torch.Size([768, 768])
alpha dim:  torch.Size([768])
beta dim:  torch.Size([768])
gamma dim:  torch.Size([768])
delta dim:  torch.Size([768])
eta dim:  torch.Size([1])
alpha_hat size:  torch.Size([768, 768])
beta_hat size:  torch.Size([768, 768])
gradW dim:  torch.Size([768, 768])
alpha dim:  torch.Size([768])
beta dim:  torch.Size([768])
gamma dim:  torch.Size([768])
delta dim:  torch.Size([768])
eta dim:  torch.Size([1])
alpha_hat size:  torch.Size([768, 768])
beta_hat size:  torch.Size([768, 768])
gradW dim:  torch.Size([3072, 768])
alpha dim:  torch.Size([768])
beta dim:  torch.Size([768])
gamma dim:  torch.Size([3072])
delta dim:  torch.Size([3072])
eta dim:  torch.Size([1])
alpha_hat size:  torch.Size([3072, 768])
beta_hat size:  torch.Size([3072, 768])
gradW dim:  torch.Size([768, 3072])
alpha dim:  torch.Size([3072])
beta dim:  torch.Size([3072])
gamma dim:  torch.Size([768])
delta dim:  torch.Size([768])
eta dim:  torch.Size([1])
alpha_hat size:  torch.Size([768, 3

{'bert-encoder-layer-0-attention-output-dense-weight': tensor([[ 2.7598e-05,  2.8345e-05,  2.2954e-05,  ...,  2.5395e-05,
           2.3192e-05,  2.4585e-05],
         [ 1.0541e-04,  1.0833e-04,  8.7673e-05,  ...,  9.7038e-05,
           8.8658e-05,  9.3853e-05],
         [ 3.8623e-05,  3.9696e-05,  3.2131e-05,  ...,  3.5562e-05,
           3.2487e-05,  3.4395e-05],
         ...,
         [-5.2005e-05, -5.3439e-05, -4.3249e-05,  ..., -4.7870e-05,
          -4.3737e-05, -4.6297e-05],
         [ 3.6744e-05,  3.7766e-05,  3.0567e-05,  ...,  3.3832e-05,
           3.0908e-05,  3.2721e-05],
         [-6.3287e-06, -6.5037e-06, -5.2650e-06,  ..., -5.8265e-06,
          -5.3221e-06, -5.6371e-06]], grad_fn=<MulBackward0>),
 'bert-encoder-layer-0-attention-self-key-weight': tensor([[ 1.3809e-05,  1.1305e-05,  1.1670e-05,  ...,  9.1390e-06,
           1.2878e-05,  1.2148e-05],
         [ 5.2235e-06,  4.2732e-06,  4.4141e-06,  ...,  3.4518e-06,
           4.8685e-06,  4.5886e-06],
         [ 1.362

In [7]:
x = torch.randn(2, 1, 4)
print(x[0])
print(x[1])
print(x.flatten().size())

tensor([[-0.9363, -0.4891, -0.2635, -0.6071]])
tensor([[ 1.2396, -1.1095, -1.3711,  0.7535]])
torch.Size([8])


In [None]:
# output = bert_fc_model(id.unsqueeze(0), type_id.unsqueeze(0), mask.unsqueeze(0)).unsqueeze(0)
# y = torch.FloatTensor(1)
# bert_loss = torch.nn.BCELoss()
# loss = bert_loss(output, y)
# loss.backward()
# bert_fc_model.zero_grad()
# for name, param in bert_fc_model.named_parameters():
  # print(name)
  # print(param.grad)

In [None]:
# # authors_model['state_dict']
# authors_model = torch.load('data/FC_model.ckpt')

# def modify_authors_state_dict(state_dict):
#     """The state dicts prefixes don't match (ours is bert.xyz, 
#     their's is model.model.xyz. This function alters the 
#     naming in their state_dict to match"""
#     from collections import OrderedDict
#     new_state_dict = OrderedDict()
#     for x in state_dict.items():
#         # print(x)
#         name = x[0]
#         vals = x[1]
#         if name[:11] == "model.model":
#             new_name = "bert" + name[11:]
#         else:
#             new_name = name[6:]
#         new_state_dict[new_name] = vals
#     return new_state_dict

# modified_state_dict = modify_authors_state_dict(authors_model['state_dict'])
# # bert_fc_model.load_state_dict(authors_model['state_dict'])
# bert_fc_model.load_state_dict(modified_state_dict)

<All keys matched successfully>

In [None]:
# # Build the datasets
# batch_size = 32
# train_dataset = src.preprocess.FeverDataset(os.path.join(configs.DATA_DIR, 'fever', 'train.jsonl'))
# train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# dev_dataset = src.preprocess.FeverDataset(os.path.join(configs.DATA_DIR, 'fever', 'dev.jsonl'))
# dev_loader = DataLoader(dev_dataset, batch_size=32, shuffle=True)

In [None]:
# bert_fc_model = bert_fc_model.to(device)

# print(len(train_loader))
# counter = 0
# running_acc = 0

# with torch.no_grad():
#     for x,y in dev_loader:
#         counter += 1
#         print(counter)
#         ids, type_ids, mask = x["input_ids"], x["token_type_ids"], x["attention_mask"]
#         ids, type_ids, mask = ids.squeeze(1), type_ids.squeeze(1), mask.squeeze(1)
#         ids, type_ids, mask, y = ids.to(device), type_ids.to(device), mask.to(device), y.to(device)
#         prob_pos = bert_fc_model(ids, type_ids, mask)
#         preds = torch.round(prob_pos)
#         accuracy = torch.sum(y == preds).item()
#         running_acc += accuracy
        
#         print("Latest accuracy: ", accuracy/batch_size)
#         print("Running accuracy: ", running_acc/(batch_size * counter))

3281
1
Latest accuracy:  0.84375
Running accuracy:  0.84375
2
Latest accuracy:  0.90625
Running accuracy:  0.875
3
Latest accuracy:  0.75
Running accuracy:  0.8333333333333334
4
Latest accuracy:  0.625
Running accuracy:  0.78125
5
Latest accuracy:  0.75
Running accuracy:  0.775
6
Latest accuracy:  0.8125
Running accuracy:  0.78125
7
Latest accuracy:  0.71875
Running accuracy:  0.7723214285714286
8
Latest accuracy:  0.8125
Running accuracy:  0.77734375
9
Latest accuracy:  0.78125
Running accuracy:  0.7777777777777778
10


KeyboardInterrupt: 