In [1]:
"""This File Cannot Import Other Files"""

import os
import sys


def set_cuda_visible_devices(gpu_id: int):
    """Expose a single GPU to CUDA by setting the relevant environment variables.

    The function refuses to run once `torch` or `tensorflow` is present in
    `sys.modules` (presumably because changing these variables after a framework
    has been loaded no longer has the intended effect -- see the linked thread).

    Args:
        gpu_id: value written (as a string) to `CUDA_VISIBLE_DEVICES`.

    Raises:
        ValueError: if `torch` or `tensorflow` has already been imported.
    """
    for framework in ("torch", "tensorflow"):
        if is_module_imported(framework):
            raise ValueError(
                f"`{framework}` is already imported; call "
                f"`set_cuda_visible_devices` before importing it")

    # https://stackoverflow.com/questions/37893755/tensorflow-set-cuda-visible-devices-within-jupyter
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}"
    print(f"Set GPU ID to {gpu_id}")


def is_module_imported(module_name: str) -> bool:
    """Report whether `module_name` is currently a key of `sys.modules`."""
    # https://stackoverflow.com/questions/30483246/how-to-check-if-a-python-module-has-been-imported
    loaded_modules = sys.modules
    return module_name in loaded_modules


In [2]:
# Pin the notebook to GPU 1. This has to run before the cell that imports torch,
# because set_cuda_visible_devices raises once torch is in sys.modules.
set_cuda_visible_devices(1)

Set GPU ID to 1


In [8]:
import numpy as np
from experiments import constants
from experiments.misc_utils import sort_dict_keys_by_vals_with_conditions
# NOTE(review): the right-hand side of this assignment is missing, so the cell is a
# syntax error as written. `run_retraining_main` joins this attribute with a file
# name for the KNN modes, so it is presumably a directory path -- TODO restore it.
constants.MNLI_RETRAINING_INFLUENCE_OUTPUT_BASE_DIR2 = 

In [9]:
import os
import time
import torch
import subprocess
import numpy as np
import matplotlib.pyplot as plt
import transformers
from tqdm import tqdm
from glob import glob
from copy import deepcopy
from contexttimer import Timer
from collections import defaultdict
from transformers import TrainingArguments
from transformers import default_data_collator
from typing import List, Dict, Tuple, Optional, Union, Any

from experiments import constants
from experiments import mnli_utils
from experiments import misc_utils
from experiments import remote_utils
from experiments import influence_helpers
from experiments import hans
from influence_utils import nn_influence_utils
from experiments.data_utils import (
    glue_output_modes,
    glue_compute_metrics)

MNLI_TRAINING_SCRIPT_NAME = "scripts/run_MNLI.20200913.sh"
# Sizes of the prefixes that `run_retraining_main` slices off each candidate list.
NUM_DATAPOINTS_TO_REMOVE_CHOICES = [1, 10, 100, 10000]


def _collect_example_indices(pattern: str) -> List[int]:
    """Return the sorted example indices encoded in the matching file names.

    Files are looked up in `constants.MNLI_RETRAINING_INFLUENCE_OUTPUT_BASE_DIR`.
    The index is the fourth dot-separated field of the base name, e.g. the `0` in
    `KNN-recall.only-correct.50.0.pth.g0301.ll.unc.edu`.

    Args:
        pattern: glob pattern applied inside the base directory.

    Raises:
        ValueError: if the fourth field of a matching name is not an integer.
        IndexError: if a matching name has fewer than four dot-separated fields.
    """
    matches = glob(os.path.join(
        constants.MNLI_RETRAINING_INFLUENCE_OUTPUT_BASE_DIR,
        pattern))
    # `os.path.basename` instead of `split("/")` so the parsing does not depend
    # on the platform's path separator.
    return sorted(
        int(os.path.basename(path).split(".")[3])
        for path in matches)


# e.g., `KNN-recall.only-correct.50.0.pth.g0301.ll.unc.edu`
CORRECT_INDICES = _collect_example_indices("*only-correct*")
# e.g., `KNN-recall.only-incorrect.50.0.pth.g0301.ll.unc.edu`
INCORRECT_INDICES = _collect_example_indices("*only-incorrect*")


def run_retraining_main(
        mode: str,
        num_examples_to_test: int):
    """Enumerate the candidate removal sets a retraining run would be given.

    The retraining call itself is commented out in this notebook; the function
    only yields what would have been passed to it so the caller can inspect it.

    For every relative example index in `range(num_examples_to_test)` and for
    both the "correct" and the "incorrect" test example at that position, a
    dictionary of tagged lists is built according to `mode`, and every prefix
    length in `NUM_DATAPOINTS_TO_REMOVE_CHOICES` is yielded for every tag.

    Note that for the "full" and "KNN-*" modes the yielded lists hold the
    influence *values* (`influences[k]`) in sorted-key order, not the keys
    themselves; for "random" they hold whatever
    `mnli_utils.get_label_to_indices_map()` stores per label, shuffled in place.

    Args:
        mode: one of "full", "KNN-1000", "KNN-10000", "random".
        num_examples_to_test: how many entries of `CORRECT_INDICES` and of
            `INCORRECT_INDICES` to walk through.

    Returns:
        A generator of
        `(example_relative_index, tag, list_prefix, num_data_points_to_remove)`.

    Raises:
        ValueError: immediately, if `mode` is not recognised. While iterating,
            if a loaded file does not match the requested example or a list is
            shorter than the requested prefix length.
    """
    # Validate here rather than inside the generator body so that a bad `mode`
    # fails at call time instead of on the first `next()`.
    if mode not in ["full", "KNN-1000", "KNN-10000", "random"]:
        raise ValueError(f"Unrecognized `mode` {mode}")

    return _generate_retraining_candidates(mode, num_examples_to_test)


def _generate_retraining_candidates(mode: str, num_examples_to_test: int):
    """Generator behind `run_retraining_main`; `mode` is assumed to be valid."""
    for example_relative_index in range(num_examples_to_test):
        for correct_mode in ["correct", "incorrect"]:
            if correct_mode == "correct":
                example_index = CORRECT_INDICES[example_relative_index]
            else:
                example_index = INCORRECT_INDICES[example_relative_index]

            if mode == "full":
                indices_dict = _full_influence_values(
                    correct_mode, example_index)
            elif mode == "random":
                indices_dict = _shuffled_label_indices()
            else:
                indices_dict = _knn_influence_values(
                    mode, correct_mode, example_relative_index, example_index)

            for tag, indices in indices_dict.items():
                for num_data_points_to_remove in NUM_DATAPOINTS_TO_REMOVE_CHOICES:
                    if len(indices) < num_data_points_to_remove:
                        raise ValueError(f"`indices` have only {len(indices)} elememts "
                                         f"whereas {num_data_points_to_remove} is needed")
                    yield example_relative_index, tag, indices[:num_data_points_to_remove], num_data_points_to_remove
#                     run_one_retraining(
#                         indices=indices[:num_data_points_to_remove],
#                         dir_name=(
#                             f"./retraining-remove-"
#                             f"{example_index}-"
#                             f"{correct_mode}-"
#                             f"{mode}-"
#                             f"{tag}-"
#                             f"{num_data_points_to_remove}"))


def _full_influence_values(correct_mode: str, example_index: int) -> Dict[str, Any]:
    """Load the per-example influence file and return its values, both orderings."""
    # Load file from local or sync from remote
    file_name = os.path.join(
        constants.MNLI_RETRAINING_INFLUENCE_OUTPUT_BASE_DIR,
        f"KNN-recall.only-{correct_mode}.50.{example_index}"
        f".pth.g0301.ll.unc.edu")

    # NOTE: `torch.load` unpickles the file; only point this at trusted outputs.
    influences_dict = torch.load(file_name)
    stored_index = influences_dict["test_index"]
    if example_index != stored_index:
        raise ValueError(
            f"{file_name} stores test index {stored_index!r}, "
            f"expected {example_index!r}")

    # The file must agree with the correct/incorrect bucket it was found in.
    stored_correct = influences_dict["correct"] is True
    if stored_correct != (correct_mode == "correct"):
        raise ValueError(
            f"{file_name} has `correct`={influences_dict['correct']!r}, "
            f"which contradicts correct_mode={correct_mode!r}")

    influences = influences_dict["influences"]
    # Ordering is whatever `sort_dict_keys_by_vals` defines; "harmful" is
    # simply the reverse of "helpful".
    helpful_indices = misc_utils.sort_dict_keys_by_vals(influences)
    harmful_indices = helpful_indices[::-1]
    return {
        "helpful": [influences[k] for k in helpful_indices],
        "harmful": [influences[k] for k in harmful_indices]}


def _knn_influence_values(
        mode: str,
        correct_mode: str,
        example_relative_index: int,
        example_index: int) -> Dict[str, Any]:
    """Load the KNN visualization file and return negative/positive influence values."""
    kNN_k = {"KNN-1000": 1000, "KNN-10000": 10000}[mode]

    file_name = os.path.join(
        constants.MNLI_RETRAINING_INFLUENCE_OUTPUT_BASE_DIR2,
        f"visualization"
        f".only-{correct_mode}"
        f".5.mnli-mnli-None-mnli"
        f".{kNN_k}.True.pth.g0306.ll.unc.edu")

    # NOTE: `torch.load` unpickles the file; only point this at trusted outputs.
    # The file holds one entry per example, addressed by relative position.
    influences_dict = torch.load(file_name)[example_relative_index]
    stored_index = influences_dict["index"]
    if example_index != stored_index:
        raise ValueError(
            f"{file_name}[{example_relative_index}] stores index "
            f"{stored_index!r}, expected {example_index!r}")

    influences = influences_dict["influences"]
    # "helpful" keeps only strictly negative influences, "harmful" only
    # strictly positive ones (reversed); zeros appear in neither list.
    helpful_indices = misc_utils.sort_dict_keys_by_vals_with_conditions(
        influences,
        condition_func=lambda k_v: k_v[1] < 0.0
    )
    harmful_indices = misc_utils.sort_dict_keys_by_vals_with_conditions(
        influences,
        condition_func=lambda k_v: k_v[1] > 0.0
    )[::-1]

    return {
        "helpful": [influences[k] for k in helpful_indices],
        "harmful": [influences[k] for k in harmful_indices]}


def _shuffled_label_indices() -> Dict[str, Any]:
    """Return the label -> indices map with each label's entries shuffled in place."""
    # Get indices corresponding to each label
    label_to_indices = mnli_utils.get_label_to_indices_map()
    labels = ["neutral", "entailment", "contradiction"]
    for label in labels:
        np.random.shuffle(label_to_indices[label])
    return {label: label_to_indices[label] for label in labels}


In [10]:
NUM_RETRAINING_EXPERIMENTS = 3
# Print every yielded slice whose max and min have different signs, or whose
# length differs from the requested size `n`.
for mode in ["full", "KNN-10000", "random"]:
    indices_generator = run_retraining_main(mode, NUM_RETRAINING_EXPERIMENTS)
    for index, tag, indices, n in indices_generator:
        sign_flipped = bool(np.sign(np.max(indices)) != np.sign(np.min(indices)))
        under_length = len(indices) != n

        if not (sign_flipped or under_length):
            continue
        print(f"{mode:<10}", index, tag, f"{n:<5}", "\t", sign_flipped, under_length)

ValueError: `indices` have only 5517 elememts whereas 10000 is needed

# Which kinds of experiments did not go as expected

In [9]:
# Passed to run_retraining_main as `num_examples_to_test`: how many correct and
# how many incorrect test examples are walked through per mode.
NUM_RETRAINING_EXPERIMENTS = 3

In [10]:
# Same check over all four modes: report slices whose max and min differ in sign
# or whose length is not the requested `n`.
for mode in ["full", "KNN-1000", "KNN-10000", "random"]:
    indices_generator = run_retraining_main(mode, NUM_RETRAINING_EXPERIMENTS)
    for index, tag, indices, n in indices_generator:
        sign_flipped = bool(np.sign(np.max(indices)) != np.sign(np.min(indices)))
        under_length = len(indices) != n

        if not (sign_flipped or under_length):
            continue
        print(f"{mode:<10}", index, tag, f"{n:<5}", "\t", sign_flipped, under_length)

KNN-1000   0 helpful 10000 	 True True
KNN-1000   0 harmful 10000 	 True True
KNN-1000   0 helpful 10000 	 True True
KNN-1000   0 harmful 10000 	 True True
KNN-1000   1 helpful 10000 	 True True
KNN-1000   1 harmful 100   	 True False
KNN-1000   1 harmful 10000 	 True True
KNN-1000   1 helpful 10000 	 True True
KNN-1000   1 harmful 10000 	 True True
KNN-1000   2 helpful 10000 	 True True
KNN-1000   2 harmful 10000 	 True True
KNN-1000   2 helpful 100   	 True False
KNN-1000   2 helpful 10000 	 True True
KNN-1000   2 harmful 10000 	 True True
KNN-10000  0 helpful 10000 	 True False
KNN-10000  0 harmful 10000 	 True False
KNN-10000  0 helpful 10000 	 True False
KNN-10000  0 harmful 10000 	 True False
KNN-10000  1 helpful 10000 	 True False
KNN-10000  1 harmful 10000 	 True False
KNN-10000  1 helpful 10000 	 True False
KNN-10000  1 harmful 10000 	 True False
KNN-10000  2 helpful 10000 	 True False
KNN-10000  2 harmful 10000 	 True False
KNN-10000  2 helpful 10000 	 True False
KNN-10000  2