In [None]:
import breaching
    
import torch
%load_ext autoreload
%autoreload 2

# Send log records (message text only, INFO and above) to stdout so they render inline in the notebook.
import sys, logging
logging.basicConfig(
    format='%(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger()
from rouge_score import rouge_scorer

from tqdm import tqdm
import numpy as np

In [None]:
# Experiment identifiers. get_results() formats all of them into the checkpoint file name,
# so they must match the run that produced the checkpoint.
dataset = "cola-hash"
shared_model = "BERT"
model = "NLPMLP-600-1000"
seed = 0
epochs = 100
batch_size = 16
save_dir = "."
leak_mode = "None"  # the string "None" (not the None object) as it appears in the file name
lr = 1e-3  # learning rate; rendered as "0.001" in the checkpoint name
lr_list = lr  # backward-compatible alias for cells that still reference `lr_list`

In [None]:
# Build the breaching attack configuration for the selected dataset, then construct the case.
if "cola" in dataset:
    cfg = breaching.get_config(overrides=["case=9_bert_training", "case/data=cola",
                                          "case.data.task=classification", "attack=tag"])
    cfg.case.model = "bert-sanity-check"
    cfg.attack.optim.max_iterations = 6000
    cfg.attack.optim.step_size = 0.05
elif "wikitext" in dataset:
    cfg = breaching.get_config(overrides=["case=10_causal_lang_training", "attack=tag"])
else:
    # Without this, `cfg` would be undefined and the failure would surface as a NameError below.
    raise ValueError(f"Unsupported dataset {dataset!r}: expected a name containing 'cola' or 'wikitext'.")

# Settings shared by both datasets: a single user holding one 16-token example from the validation split.
cfg.case.user.num_data_points = 1
cfg.case.user.user_idx = 1
cfg.case.data.shape = [16]
cfg.case.data.examples_from_split = "validation"

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
setup = dict(device=device, dtype=getattr(torch, cfg.case.impl.dtype))

# The constructed network is bound to `case_model`, NOT `model`: the string `model` from the
# config cell is part of the checkpoint file name that get_results() rebuilds, and overwriting
# it would put the network's repr into that path.
user, server, case_model, loss_fn = breaching.cases.construct_case(cfg.case, setup)

metric_list = ["rouge1", "rouge2", "rougeL"]


def get_results(dataset, shared_model, model, seed, epochs, batch_size, save_dir, leak_mode, lr,
                print_example=False, version="new", num_samples=100, num_print=19):
    """Score stored text reconstructions against the ground-truth validation data with ROUGE.

    Loads a checkpoint from ``{save_dir}/checkpoint/`` whose file name is built from the
    experiment identifiers. It then compares the first ``num_samples`` entries of
    ``checkpoint["val_reconstructed_imgs"]`` with the matching entries of
    ``user.dataloader.dataset``, where ``user`` is the global returned by ``construct_case``.

    Args:
        dataset, shared_model, model, seed, epochs, batch_size, save_dir, leak_mode, lr:
            Experiment identifiers formatted into the checkpoint file name.
        print_example: If True, print the first ``num_print`` ground-truth/reconstruction pairs.
        version: ``"new"`` loads ``<name>_version1.pt``. Any other value loads the last
            ``<name>_<epoch>.pt`` in the run of consecutive epochs that exist, starting at 0.
        num_samples: Maximum number of examples to score. Clipped to the available data.
        num_print: Number of example pairs printed when ``print_example`` is True.

    Returns:
        ``[val_acc, mean rouge1 F1, mean rouge2 F1, mean rougeL F1]``. The same values are
        also printed, scaled by 100, as a LaTeX table row and as a plain row.

    Raises:
        FileNotFoundError: If ``version != "new"`` and no ``<name>_0.pt`` checkpoint exists.
    """
    import os  # the notebook's import cell does not import os

    seed_str = f"_{seed}"
    save_file_name = (f"{save_dir}/checkpoint/{dataset}_{shared_model}_{model}_{leak_mode}_"
                      f"{lr}_{epochs}_{batch_size}{seed_str}")

    # NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines.
    if version == "new":
        checkpoint = torch.load(f"{save_file_name}_version1.pt", map_location="cpu")
        if "epoch" in checkpoint:
            print(f"epoch: {checkpoint['epoch']}")
    else:
        # Per-epoch checkpoints: use the last epoch before the first missing file.
        epoch = -1
        for candidate in range(100):
            if not os.path.exists(f"{save_file_name}_{candidate}.pt"):
                break
            epoch = candidate
        if epoch < 0:
            raise FileNotFoundError(f"No per-epoch checkpoint found for {save_file_name}_<epoch>.pt")
        print(f"epoch: {epoch}")
        checkpoint = torch.load(f"{save_file_name}_{epoch}.pt", map_location="cpu")

    acc = checkpoint["val_acc"]
    reconstructed_data_all = checkpoint["val_reconstructed_imgs"]
    del checkpoint  # only the two fields above are needed
    gt_data_all = user.dataloader.dataset

    n_eval = min(num_samples, len(reconstructed_data_all), len(gt_data_all))

    # A single scorer computes every ROUGE type independently, so the scores match scoring one
    # metric at a time. Each example is still decoded (and printed) only once.
    scorer = rouge_scorer.RougeScorer(metric_list, use_stemmer=True)
    fmeasures = {metric: [] for metric in metric_list}
    for i in tqdm(range(n_eval)):
        reconstructed_data = {"data": reconstructed_data_all[i:i + 1]}
        reconstructed_data["labels"] = reconstructed_data["data"]
        gt_data = {"data": gt_data_all[i:i + 1]["input_ids"]}

        if print_example and i < num_print:
            user.print(gt_data)
            user.print(reconstructed_data)
        rec_text = user.print(reconstructed_data, print_out=False)[0]
        gt_text = user.print(gt_data, print_out=False)[0]
        # RougeScorer.score(target, prediction): the ground truth is the target.
        scores = scorer.score(gt_text, rec_text)
        for metric in metric_list:
            fmeasures[metric].append(scores[metric].fmeasure)

    print_res = [acc] + [np.asarray(fmeasures[metric]).mean() for metric in metric_list]
    print(leak_mode, "&".join(["${:10.2f}$".format(x*100) for x in print_res]) + "&")
    print(leak_mode, "".join(["{:10.2f}".format(x*100) for x in print_res]))
    return print_res

In [None]:
# Score the stored reconstructions. The config cell binds the learning rate to `lr_list`.
# The trailing ';' suppresses echoing the returned list; the function prints its own summary rows.
get_results(dataset, shared_model, model, seed, epochs, batch_size, save_dir, leak_mode, lr=lr_list,
            print_example=True, version="new");