In [17]:
from __future__ import absolute_import, division, print_function

import argparse
import csv
import logging
import os
import random
import sys
import pickle
import time
import math

import numpy as np
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertPreTrainedModel, BertModel, BertConfig, WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam 

# from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
# from pytorch_pretrained_bert.modeling import BertPreTrainedModel, BertModel, BertConfig, WEIGHTS_NAME, CONFIG_NAME
# from pytorch_pretrained_bert.tokenization import BertTokenizer
# from pytorch_pretrained_bert.optimization import BertAdam

import torch.autograd as autograd

from bert_util import *

In [2]:
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO)
logger = logging.getLogger(__name__)

def gather_flat_grad(grads):
    views = []
    for p in grads:
        if p.data.is_sparse:
            view = p.data.to_dense().view(-1)
        else:
            view = p.data.view(-1)
        views.append(view)
    return torch.cat(views, 0)

def hv(loss, model_params, v):
    grad = autograd.grad(loss, model_params, create_graph=True, retain_graph=True)
    Hv = autograd.grad(grad, model_params, grad_outputs=v)
    return Hv
#https://github.com/kohpangwei/influence-release/issues/4  -- lissa paper
def get_inverse_hvp_lissa(v, model, param_influence, train_lissa_loader, args):
    ihvp = None
    for i in range(args.lissa_repeat):
        cur_estimate = v
        lissa_data_iterator = iter(train_lissa_loader)
        for j in range(args.lissa_depth):
            try:
                tmp_elem = next(lissa_data_iterator)
                input_ids, input_mask, segment_ids, label_ids, guids = tmp_elem
            except StopIteration:
                lissa_data_iterator = iter(train_lissa_loader)
                tmp_elem = next(lissa_data_iterator)
                input_ids, input_mask, segment_ids, label_ids, guids = tmp_elem
            input_ids = input_ids.to(args.device)
            input_mask = input_mask.to(args.device)
            segment_ids = segment_ids.to(args.device)
            label_ids = label_ids.to(args.device)
            
            model.zero_grad()
            train_loss = model(input_ids = input_ids, token_type_ids = segment_ids, attention_mask =input_mask, labels = label_ids)
            hvp = hv(train_loss, param_influence, cur_estimate)
            cur_estimate = [_a + (1 - args.damping) * _b - _c / args.scale
                            for _a, _b, _c in zip(v, cur_estimate, hvp)]
            
            if (j % args.logging_steps == 0) or (j == args.lissa_depth - 1):
                logger.info(" Recursion at depth %d: norm is %f", j,
                            np.linalg.norm(gather_flat_grad(cur_estimate).cpu().numpy()))
        if ihvp == None:
            ihvp = cur_estimate
        else:
            ihvp = [_a + _b for _a, _b in zip(ihvp, cur_estimate)]
    
    return_ihvp = gather_flat_grad(ihvp)
    return_ihvp /= args.lissa_repeat
    return return_ihvp

# def main():
# parser = argparse.ArgumentParser()

# ## Required parameters
# parser.add_argument("--data_dir",
#                     default=None,
#                     type=str,
#                     required=True,
#                     help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
# parser.add_argument("--xlnet_model", default=None, type=str, required=True,
#                     help="Bert pre-trained model selected in the list: bert-base-uncased, "
#                     "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
#                     "bert-base-multilingual-cased, bert-base-chinese.")
# parser.add_argument("--output_dir",
#                     default=None,
#                     type=str,
#                     required=True,
#                     help="The output directory where the model predictions and checkpoints will be written.")

# ## Other parameters
# parser.add_argument("--cache_dir",
#                     default="",
#                     type=str,
#                     help="Where do you want to store the pre-trained models downloaded from s3")
# parser.add_argument("--trained_model_dir",
#                     default="",
#                     type=str,
#                     help="Where is the fine-tuned (with the cloze-style LM objective) xlnet model?")
# parser.add_argument("--max_seq_length",
#                     default=128,
#                     type=int,
#                     help="The maximum total input sequence length after WordPiece tokenization. \n"
#                          "Sequences longer than this will be truncated, and sequences shorter \n"
#                          "than this will be padded.")
# parser.add_argument("--do_lower_case",
#                     action='store_true',
#                     help="Set this flag if you are using an uncased model.")
# parser.add_argument("--train_batch_size",
#                     default=32,
#                     type=int,
#                     help="Total batch size for training.")
# parser.add_argument("--eval_batch_size",
#                     default=8,
#                     type=int,
#                     help="Total batch size for eval.")
# parser.add_argument("--no_cuda",
#                     action='store_true',
#                     help="Whether not to use CUDA when available")
# parser.add_argument('--seed',
#                     type=int,
#                     default=42,
#                     help="random seed for initialization")
# parser.add_argument('--freeze_bert',
#                     action='store_true',
#                     help="Whether to freeze BERT")
# parser.add_argument('--full_bert',
#                     action='store_true',
#                     help="Whether to use full BERT")
# parser.add_argument('--num_train_samples',
#                     type=int,
#                     default=-1,
#                     help="-1 for full train set, otherwise please specify")
# parser.add_argument('--damping',
#                     type=float,
#                     default=0.0,
#                     help="probably need damping for deep models")
# parser.add_argument('--test_idx',
#                     type=int,
#                     default=1,
#                     help="test index we want to examine")
# parser.add_argument('--start_test_idx',
#                     type=int,
#                     default=-1,
#                     help="when not -1, --test_idx will be disabled")
# parser.add_argument('--end_test_idx',
#                     type=int,
#                     default=-1,
#                     help="when not -1, --test_idx will be disabled")
# parser.add_argument("--lissa_repeat",
#                     default=1,
#                     type=int)
# parser.add_argument("--lissa_depth_pct",
#                     default=1.0,
#                     type=float)
# parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
# parser.add_argument('--scale',
#                     type=float,
#                     default=1e4,
#                     help="probably need scaling for deep models")
# parser.add_argument("--alt_mode",
#                     default="",
#                     type=str,
#                     help="whether to use extended data split (ext) or only control data split (ctr)")


# args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# args.device = device

# random.seed(args.seed)
# np.random.seed(args.seed)
# torch.manual_seed(args.seed)

# if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
#     #raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
#     logger.info("WARNING: Output directory already exists and is not empty.")
# if not os.path.exists(args.output_dir):
#     os.makedirs(args.output_dir)

In [18]:
ma_processor = MAProcessor()
label_list = ma_processor.get_labels()
num_labels = len(label_list)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

# Prepare model
model = MyBertForSequenceClassification.from_pretrained('/project/my_model/model-runner-output', num_labels=num_labels)
model.to(device)

# Prepare optimizer
param_optimizer = list(model.named_parameters())
#     if freeze_bert:
#         frozen = ['bert']
#     elif args.full_bert:
#         frozen = []
#     else:
frozen = ['bert.embeddings.',
          'bert.encoder.layer.0.',
          'bert.encoder.layer.1.',
          'bert.encoder.layer.2.',
          'bert.encoder.layer.3.',
          'bert.encoder.layer.4.',
          'bert.encoder.layer.5.',
          'bert.encoder.layer.6.',
          'bert.encoder.layer.7.',
         ] # *** change here to filter out params we don't want to track ***

param_influence = []
for n, p in param_optimizer:
    if (not any(fr in n for fr in frozen)):
        param_influence.append(p)
    elif 'bert.embeddings.word_embeddings.' in n:
        pass # need gradients through embedding layer for computing saliency map
    else:
        p.requires_grad = False


08/13/2021 19:24:01 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/faculty/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
08/13/2021 19:24:01 - INFO - pytorch_pretrained_bert.modeling -   loading archive file /project/my_model/model-runner-output
08/13/2021 19:24:01 - INFO - pytorch_pretrained_bert.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}



In [19]:
param_shape_tensor = []
param_size = 0
for p in param_influence:
    tmp_p = p.clone().detach()
    param_shape_tensor.append(tmp_p)
    param_size += torch.numel(tmp_p)
logger.info("  Parameter size = %d", param_size)


08/13/2021 19:24:05 - INFO - __main__ -     Parameter size = 28943618


In [20]:
train_examples = ma_processor.get_direct_control_train_examples("/project/my_data")


In [21]:
train_features = convert_examples_to_features(
    train_examples, label_list, 200, tokenizer)
logger.info("***** Train set *****")
logger.info("  Num examples = %d", len(train_examples))
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
all_label_id = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
all_guids = torch.tensor([f.guid for f in train_features], dtype=torch.long)
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_id, all_guids)
train_dataloader = DataLoader(train_data, sampler=SequentialSampler(train_data), batch_size=1)

test_examples = ma_processor.get_adv_examples('/project/my_data')

test_features = convert_examples_to_features(
    test_examples, label_list,200, tokenizer)
logger.info("***** Test set *****")
logger.info("  Num examples = %d", len(test_examples))
all_input_ids = torch.tensor([f.input_ids for f in test_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in test_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in test_features], dtype=torch.long)
all_label_id = torch.tensor([f.label_id for f in test_features], dtype=torch.long)
all_guids = torch.tensor([f.guid for f in test_features], dtype=torch.long)
test_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_id, all_guids)
test_dataloader = DataLoader(test_data, sampler=SequentialSampler(test_data), batch_size=1)

lissa_depth = int(0 * len(train_examples))

# test_idx = args.test_idx
# start_test_idx = args.start_test_idx
# end_test_idx = args.end_test_idx

influence_dict = dict()
ihvp_dict = dict()


08/13/2021 19:24:11 - INFO - __main__ -   ***** Train set *****
08/13/2021 19:24:11 - INFO - __main__ -     Num examples = 13000
08/13/2021 19:24:11 - INFO - __main__ -   ***** Test set *****
08/13/2021 19:24:11 - INFO - __main__ -     Num examples = 100


In [22]:
 for tmp_idx, (input_ids, input_mask, segment_ids, label_ids, guids) in enumerate(test_dataloader):
#         if args.start_test_idx != -1 and args.end_test_idx != -1:
#             if tmp_idx < args.start_test_idx:
#                 continue
#             if tmp_idx > args.end_test_idx:
#                 break
#         else:
#             if tmp_idx < args.test_idx:
#                 continue
#             if tmp_idx > args.test_idx:
#                 break
                
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)
        label_ids = label_ids.to(device)
        
        influence_dict[tmp_idx] = np.zeros(len(train_examples))

In [23]:
model.eval()
model.zero_grad()
test_loss = model(input_ids, segment_ids, input_mask, label_ids)
test_grads = autograd.grad(test_loss, param_influence)

In [32]:
test_grads[25]

tensor([ 1.5184e-06, -1.9397e-06, -4.7402e-06, -3.6225e-06, -4.2587e-06,
         6.6472e-06,  3.6700e-06,  4.0629e-06, -4.2408e-06, -4.5758e-06,
         1.3419e-06,  1.8024e-06, -1.7919e-06,  2.6335e-06, -3.1649e-06,
         4.2624e-06,  2.1541e-06,  1.1529e-06, -2.6763e-06, -3.1226e-06,
         1.5486e-06, -1.6042e-06, -4.1684e-06,  4.7140e-06,  4.2723e-06,
         6.6284e-07,  3.1581e-06,  2.3107e-07,  2.1825e-06, -1.4779e-06,
         5.8748e-06,  3.4218e-07, -3.9005e-07, -1.8554e-06,  7.1721e-07,
         1.9981e-06, -6.5760e-06,  2.1120e-06,  2.5114e-06, -3.0164e-06,
        -1.0389e-07,  2.3667e-06,  2.6859e-07,  2.4271e-06, -2.4554e-06,
         8.5510e-07,  1.0009e-06,  1.9137e-06, -2.7216e-06, -4.8227e-06,
        -2.0230e-06, -1.5983e-06,  4.3830e-07, -2.8043e-06, -3.8221e-07,
         1.0770e-07,  3.7807e-07,  7.6936e-07,  1.8155e-06, -1.8223e-07,
        -3.0247e-06,  5.5795e-06, -6.3071e-07, -5.8459e-06, -5.1287e-06,
         3.0692e-06, -1.8823e-06,  5.4689e-07, -4.9

In [25]:
param_influence


[Parameter containing:
 tensor([[ 0.0437,  0.0610,  0.0168,  ..., -0.0464, -0.0252,  0.0074],
         [ 0.0233, -0.0206,  0.0179,  ...,  0.0304, -0.0035,  0.0485],
         [ 0.0311, -0.1272,  0.0061,  ..., -0.0477,  0.0382, -0.0359],
         ...,
         [-0.0033, -0.0347, -0.0326,  ..., -0.0210, -0.0028, -0.0318],
         [ 0.0071, -0.0319, -0.0551,  ...,  0.0889, -0.0564,  0.0020],
         [-0.0211,  0.0544,  0.0503,  ...,  0.0030,  0.0134,  0.0266]],
        requires_grad=True),
 Parameter containing:
 tensor([-2.5058e-01, -2.3638e-01,  2.4868e-02,  9.3544e-03,  1.8485e-01,
         -1.0751e-01,  4.3065e-01,  1.4016e-01,  4.9633e-01, -1.9919e-01,
         -5.2133e-02,  5.5362e-02, -1.1410e-01, -1.1738e-01, -1.8928e-01,
         -9.0274e-02, -1.2046e-01, -6.4587e-02,  7.4744e-02,  7.4048e-01,
          1.1913e-01, -3.9970e-01,  9.3598e-02,  1.4659e-01, -9.4422e-03,
          1.4193e-01,  1.7886e-01,  1.4586e-01, -1.6771e-03, -2.3344e-03,
         -2.5971e-01,  2.4574e-01,  1.74

False