In [1]:
"""This File Cannot Import Other Files"""

import os
import sys


def set_cuda_visible_devices(gpu_id: int) -> None:
    """Expose only GPU `gpu_id` to CUDA for this process.

    Sets `CUDA_DEVICE_ORDER=PCI_BUS_ID` and `CUDA_VISIBLE_DEVICES=<gpu_id>`.
    A deep-learning framework that has already initialized CUDA will not
    pick up the change. For that reason this function refuses to run once
    `torch` or `tensorflow` has been imported.

    Args:
        gpu_id: index of the GPU to expose (PCI-bus ordering).

    Raises:
        ValueError: if `torch` or `tensorflow` is already in `sys.modules`.
    """
    already_imported = [
        name for name in ("torch", "tensorflow") if is_module_imported(name)]
    if already_imported:
        raise ValueError(
            f"`set_cuda_visible_devices` must be called before importing "
            f"{', '.join(already_imported)}; CUDA_VISIBLE_DEVICES may be ignored.")

    # https://stackoverflow.com/questions/37893755/tensorflow-set-cuda-visible-devices-within-jupyter
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}"
    print(f"Set GPU ID to {gpu_id}")


def is_module_imported(module_name: str) -> bool:
    """Report whether `module_name` already has an entry in `sys.modules`.

    The lookup is an exact key match on the fully-qualified module name.
    """
    # https://stackoverflow.com/questions/30483246/how-to-check-if-a-python-module-has-been-imported
    loaded_modules = sys.modules
    return module_name in loaded_modules.keys()


In [2]:
# Pin this kernel to GPU 1. This must run before torch/tensorflow are imported;
# set_cuda_visible_devices raises ValueError otherwise.
set_cuda_visible_devices(1)

Set GPU ID to 1


In [3]:
import numpy as np
from experiments import constants
from experiments.misc_utils import sort_dict_keys_by_vals_with_conditions
# Display the directory that holds the KNN-mode ("visualization.*") influence files.
constants.MNLI_RETRAINING_INFLUENCE_OUTPUT_BASE_DIR2



'/export/home/hguo/Experiments-backup/20201019/visualization_outputs/'

In [4]:
import os
import time
import torch
import subprocess
import numpy as np
import matplotlib.pyplot as plt
import transformers
from tqdm import tqdm
from glob import glob
from copy import deepcopy
from contexttimer import Timer
from collections import defaultdict
from transformers import TrainingArguments
from transformers import default_data_collator
from typing import List, Dict, Tuple, Optional, Union, Any

from experiments import constants
from experiments import mnli_utils
from experiments import misc_utils
from experiments import remote_utils
from experiments import influence_helpers
from experiments import hans
from influence_utils import nn_influence_utils
from experiments.data_utils import (
    glue_output_modes,
    glue_compute_metrics)

MNLI_TRAINING_SCRIPT_NAME = "scripts/run_MNLI.20200913.sh"
NUM_DATAPOINTS_TO_REMOVE_CHOICES = [1, 5, 10, 25, 50, 100, 10000]


def _collect_test_indices(correct_mode: str) -> List[int]:
    """Return the sorted test-example indices that have a saved influence file.

    File names look like `KNN-recall.only-correct.50.0.pth.g0301.ll.unc.edu`
    (or `...only-incorrect...`). The 4th dot-separated field (here `0`) is the
    test-example index.

    Args:
        correct_mode: "correct" or "incorrect"; selects `*only-<mode>*` files.
    """
    pattern = os.path.join(
        constants.MNLI_RETRAINING_INFLUENCE_OUTPUT_BASE_DIR,
        f"*only-{correct_mode}*")
    return sorted(
        int(os.path.basename(file_name).split(".")[3])
        for file_name in glob(pattern))


CORRECT_INDICES = _collect_test_indices("correct")
INCORRECT_INDICES = _collect_test_indices("incorrect")


def _influence_scores_by_polarity(influences: Any) -> Dict[str, List[Any]]:
    """Return the influence *values* of the helpful and harmful training points.

    `misc_utils.get_helpful_harmful_indices_from_influences_dict` picks the
    keys. Each key is then looked up in `influences`, so the returned lists
    hold influence values, not training indices.
    """
    helpful_indices, harmful_indices = (
        misc_utils.get_helpful_harmful_indices_from_influences_dict(influences))
    return {
        "helpful": [influences[k] for k in helpful_indices],
        "harmful": [influences[k] for k in harmful_indices]}


def _load_full_influence_scores(
        correct_mode: str, example_index: int) -> Dict[str, List[Any]]:
    """Load the full-training-set influences of one test example ("full" mode).

    Raises:
        ValueError: if the file's `test_index` or `correct` flag does not match
            `example_index` / `correct_mode`.
    """
    file_name = os.path.join(
        constants.MNLI_RETRAINING_INFLUENCE_OUTPUT_BASE_DIR,
        f"KNN-recall.only-{correct_mode}.50.{example_index}"
        f".pth.g0301.ll.unc.edu")

    influences_dict = torch.load(file_name)
    if example_index != influences_dict["test_index"]:
        raise ValueError(
            f"{file_name}: expected test_index {example_index}, "
            f"got {influences_dict['test_index']}")

    # Explicit parentheses: `correct` must be exactly True for the "correct"
    # split and must not be True for the "incorrect" split.
    if ((correct_mode == "correct" and influences_dict["correct"] is not True) or
            (correct_mode == "incorrect" and influences_dict["correct"] is True)):
        raise ValueError(
            f"{file_name}: `correct` flag {influences_dict['correct']!r} "
            f"does not match split {correct_mode!r}")

    return _influence_scores_by_polarity(influences_dict["influences"])


def _load_knn_influence_scores(
        mode: str,
        correct_mode: str,
        example_index: int,
        example_relative_index: int,
        file_cache: Dict[str, Any]) -> Dict[str, List[Any]]:
    """Load the KNN-restricted influences of one test example ("KNN-*" modes).

    One KNN file per (split, k) is indexed by the example's relative position.
    It is loaded once and memoized in `file_cache` (keyed by path) instead of
    being re-read for every example.

    Raises:
        ValueError: if the entry's `index` does not match `example_index`.
    """
    kNN_k = {"KNN-1000": 1000, "KNN-10000": 10000}[mode]
    file_name = os.path.join(
        constants.MNLI_RETRAINING_INFLUENCE_OUTPUT_BASE_DIR2,
        f"visualization"
        f".only-{correct_mode}"
        f".5.mnli-mnli-None-mnli"
        f".{kNN_k}.True.pth.g0306.ll.unc.edu")

    if file_name not in file_cache:
        file_cache[file_name] = torch.load(file_name)

    influences_dict = file_cache[file_name][example_relative_index]
    if example_index != influences_dict["index"]:
        raise ValueError(
            f"{file_name}[{example_relative_index}]: expected index "
            f"{example_index}, got {influences_dict['index']}")

    # `n` is left at the helper's default rather than `kNN_k`, so a tag may
    # come back with fewer than `kNN_k` entries.
    return _influence_scores_by_polarity(influences_dict["influences"])


def _load_random_indices() -> Dict[str, Any]:
    """Shuffle each label's entry of `mnli_utils.get_label_to_indices_map()`.

    Uses the global NumPy RNG without seeding it, so results differ between
    calls unless the caller seeds NumPy.
    """
    label_to_indices = mnli_utils.get_label_to_indices_map()
    shuffled = {}
    for label in ["neutral", "entailment", "contradiction"]:
        np.random.shuffle(label_to_indices[label])
        shuffled[label] = label_to_indices[label]
    return shuffled


def run_retraining_main(
        mode: str,
        num_examples_to_test: int):
    """Yield the training points to remove for each retraining experiment.

    For each of the first `num_examples_to_test` correct and incorrect test
    examples, this loads the ranking for `mode`. Then, for every count in
    `NUM_DATAPOINTS_TO_REMOVE_CHOICES`, it yields the first `count` entries
    per tag.

    Args:
        mode: "full", "KNN-1000", "KNN-10000" or "random".
        num_examples_to_test: number of test examples taken from each of
            `CORRECT_INDICES` and `INCORRECT_INDICES`.

    Yields:
        Tuples `(example_index, tag, values, num_data_points_to_remove)`.
        For "full"/"KNN-*", `tag` is "helpful"/"harmful" and `values` are
        influence scores. For "random", `tag` is an MNLI label and `values`
        are the shuffled entries of the label-to-indices map. When fewer than
        `num_data_points_to_remove` values exist, a warning is printed and
        the shorter slice is yielded.

    Raises:
        ValueError: on an unknown `mode`, when more examples are requested
            than are available, or when a loaded file does not match its
            example.
    """
    if mode not in ["full", "KNN-1000", "KNN-10000", "random"]:
        raise ValueError(f"Unrecognized `mode` {mode}")

    # Fail up front instead of with an IndexError halfway through the stream.
    num_available = min(len(CORRECT_INDICES), len(INCORRECT_INDICES))
    if num_examples_to_test > num_available:
        raise ValueError(
            f"`num_examples_to_test`={num_examples_to_test} exceeds the "
            f"{num_available} test examples available per split")

    knn_file_cache: Dict[str, Any] = {}
    for example_relative_index in range(num_examples_to_test):
        for correct_mode in ["correct", "incorrect"]:
            if correct_mode == "correct":
                example_index = CORRECT_INDICES[example_relative_index]
            else:
                example_index = INCORRECT_INDICES[example_relative_index]

            if mode == "full":
                indices_dict = _load_full_influence_scores(
                    correct_mode, example_index)
            elif mode in ["KNN-1000", "KNN-10000"]:
                indices_dict = _load_knn_influence_scores(
                    mode, correct_mode, example_index,
                    example_relative_index, knn_file_cache)
            else:  # mode == "random"
                indices_dict = _load_random_indices()

            for tag, indices in indices_dict.items():
                for num_data_points_to_remove in NUM_DATAPOINTS_TO_REMOVE_CHOICES:
                    if len(indices) < num_data_points_to_remove:
                        print(f"\t\t\t\t\t\t\t`indices` have only {len(indices)} elememts "
                              f"whereas {num_data_points_to_remove} is needed")
                    yield example_index, tag, indices[:num_data_points_to_remove], num_data_points_to_remove


In [6]:
# Sign/length sanity check restricted to the KNN-10000 mode.
NUM_RETRAINING_EXPERIMENTS = 3
for mode in ["KNN-10000"]:
    indices_generator = run_retraining_main(mode, NUM_RETRAINING_EXPERIMENTS)
    for index, tag, indices, n in indices_generator:
        # Flag tags whose values do not all share one sign. An empty list has
        # no sign to compare (np.max/np.min would raise on it); it is still
        # reported through `under_length`.
        sign_flipped = len(indices) > 0 and bool(
            np.sign(np.max(indices)) != np.sign(np.min(indices)))
        under_length = len(indices) != n

        if sign_flipped or under_length:
            print(f"{mode:<10}", f"{index:<3}", tag, f"{n:<5}", "\t", sign_flipped, under_length)

ValueError: `helpful_indices` have only 5517 elememts whereas 10000 is needed

# Which experiments did not behave as expected

In [5]:
# Number of test examples per split (correct / incorrect) to inspect below.
NUM_RETRAINING_EXPERIMENTS = 3

In [6]:
# For every mode, report (mode, example, tag, n) combinations whose values
# mix signs or are shorter than requested.
# NOTE(review): in "random" mode the values are presumably non-negative
# training indices, so a 0 entry registers as a sign flip there.
for mode in ["full", "KNN-1000", "KNN-10000", "random"]:
    indices_generator = run_retraining_main(mode, NUM_RETRAINING_EXPERIMENTS)
    for index, tag, indices, n in indices_generator:
        # An empty list has no sign to compare (np.max/np.min would raise on
        # it); it is still reported through `under_length`.
        sign_flipped = len(indices) > 0 and bool(
            np.sign(np.max(indices)) != np.sign(np.min(indices)))
        under_length = len(indices) != n

        if sign_flipped or under_length:
            print(f"{mode:<10}", f"{index:<3}", tag, f"{n:<5}", "\t", sign_flipped, under_length)

							`indices` have only 632 elememts whereas 10000 is needed
KNN-1000   0   helpful 10000 	 False True
							`indices` have only 368 elememts whereas 10000 is needed
KNN-1000   0   harmful 10000 	 False True
							`indices` have only 529 elememts whereas 10000 is needed
KNN-1000   3   helpful 10000 	 False True
							`indices` have only 471 elememts whereas 10000 is needed
KNN-1000   3   harmful 10000 	 False True
							`indices` have only 975 elememts whereas 10000 is needed
KNN-1000   1   helpful 10000 	 False True
							`indices` have only 25 elememts whereas 50 is needed
KNN-1000   1   harmful 50    	 False True
							`indices` have only 25 elememts whereas 100 is needed
KNN-1000   1   harmful 100   	 False True
							`indices` have only 25 elememts whereas 10000 is needed
KNN-1000   1   harmful 10000 	 False True
							`indices` have only 236 elememts whereas 10000 is needed
KNN-1000   14  helpful 10000 	 False True
							`indices` have only 764 elememts whereas 10000 

# KNN-Recall, Simulator, and Augmentation Experiments

In [1]:
# Work from the repository root (presumably so the `experiments` package imports below resolve).
cd /workspace/fast-influence-functions/

/workspace/fast-influence-functions


In [2]:
import os
# Expose only GPU 1 (PCI-bus ordering); must run before torch is imported.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",  # see issue #152
    "CUDA_VISIBLE_DEVICES": "1",
})

In [3]:
import torch
from tqdm import tqdm
from glob import glob
from experiments.misc_utils import get_helpful_harmful_indices_from_influences_dict



In [4]:
# full_influences_collections = [
#     torch.load(file_name)
#     for file_name in tqdm(glob(f"/export/home/hguo/Experiments/20200922/KNN-recall.only-*.50.*.pth.g0301.ll.unc.edu"))]

# N = 1000

100%|██████████| 100/100 [00:24<00:00,  4.04it/s]


In [4]:
# full_influences_collections = (
#     torch.load("/export/home/hguo/Experiments/20201119/imiator_experiments.only-correct.10.pt") +
#     torch.load("/export/home/hguo/Experiments/20201119/imiator_experiments.only-incorrect.10.pt")
# )

# N = 10

In [7]:
# Alternative single-file source:
# torch.load("./figures/another-hans-augentation-experimental.pth")
BASE_PATH = "/export/home/hguo/Experiments/20201118/"
# Which augmentation run to load; the other run available here is "hans".
AUGMENTATION_SOURCE = "mnli-2"
HANS_SUBSETS = ["lexical_overlap", "constituent", "subsequence"]
output_collections = {
    subset: torch.load(os.path.join(
        BASE_PATH,
        f"hans-augmentation-new.{AUGMENTATION_SOURCE}.{subset}.3.False.pth"))
    for subset in HANS_SUBSETS
}

In [8]:
# Flatten the nested outputs into a list of {"influences": ...} records.
# Every record must come from a run with `num_datapoints == 1` and must list
# more than one entry in `datapoint_indices`.
full_influences_collections = []
for heuristic, heuristic_outputs in output_collections.items():
    for subset, outputs in heuristic_outputs.items():
        for index in range(len(outputs)):
            output = outputs[index]
            location = f"output_collections[{heuristic!r}][{subset!r}][{index}]"
            if output["num_datapoints"] != 1:
                raise ValueError(
                    f"{location}: expected `num_datapoints` == 1, "
                    f"got {output['num_datapoints']}")
            if len(output["datapoint_indices"]) <= 1:
                raise ValueError(
                    f"{location}: expected more than one `datapoint_indices` "
                    f"entry, got {len(output['datapoint_indices'])}")
            full_influences_collections.append({
                "influences": output["influences"],
            })

# Minimum number of helpful and of harmful indices each collection should yield.
N = 1

In [9]:
# Check that every collection provides at least N helpful and N harmful
# indices; print the counts of any that fall short.
for collection in tqdm(full_influences_collections):
    if collection["influences"] is None:
        # No influences were stored for this record; nothing to check.
        continue

    try:
        helpful_indices, harmful_indices = (
            get_helpful_harmful_indices_from_influences_dict(
                collection["influences"], n=N))
    except ValueError as error:
        # NOTE(review): assumes the helper signals "fewer than `n` candidates"
        # with ValueError (as in the earlier KNN-10000 traceback) -- confirm.
        # Other exceptions now propagate instead of being silently retried.
        print(error)
        helpful_indices, harmful_indices = (
            get_helpful_harmful_indices_from_influences_dict(
                collection["influences"]))

    if len(helpful_indices) < N or len(harmful_indices) < N:
        print(len(helpful_indices),
              len(harmful_indices))

100%|██████████| 270/270 [00:00<00:00, 1287.04it/s]
