In [1]:
import torch
from nlp_fast_unlearning.utils import prepare_dbpedia, build_noisy_dl, ensure_deterministic, DEVICE, BATCH_SIZE
from nlp_fast_unlearning.baseline_model import TextClassificationModel

_, _, _, dbpedia_vocab = prepare_dbpedia(for_baseline_only=True)

vocab_size = dbpedia_vocab.vocab_size

baseline_name = "DBpedia_baseline.pt"
unlearning_model = TextClassificationModel(vocab_size).to(DEVICE)
unlearning_model.load_state_dict(torch.load(baseline_name))

<All keys matched successfully>

In [2]:
%%time

(
    retain_samples,
    noisy_data,
    retain_valid_dl,
    forget_valid_dl,
    retain_test_dl,
    forget_test_dl,
    dbpedia_vocab,
) = prepare_dbpedia(
    for_baseline_only=False,
    classes_to_forget=[5,6,7,8,9],
    model=unlearning_model,
    retain_percentage=0.01,
    vocab_class=dbpedia_vocab,
)

Searching for error maximizing noise for class  5
Got loss 195.9832305908203 for tensor([844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844,
        844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844,
        844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844,
        844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844,
        844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844,
        844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844,
        844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844, 844,
        844, 844], device='cuda:0')
Searching for error maximizing noise for class  6
Got loss 213.3338623046875 for tensor([101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101,

In [3]:
def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))


def validation_step(model, batch):
    labels, text, offsets = batch
    out = model(text,offsets)
    loss = F.cross_entropy(out, labels)   
    acc = accuracy(out, labels)
    return {'Loss': loss.detach(), 'Acc': acc}


def validation_epoch_end(model, outputs):
    batch_losses = [x['Loss'] for x in outputs]
    epoch_loss = torch.stack(batch_losses).mean()   
    batch_accs = [x['Acc'] for x in outputs]
    epoch_acc = torch.stack(batch_accs).mean()      
    return {'Loss': epoch_loss.item(), 'Acc': epoch_acc.item()}

In [4]:
import torch.nn.functional as F

@torch.no_grad()
def evaluate_after_unlearning(model, val_loader):
    model.eval()
    outputs = [validation_step(model, batch) for batch in val_loader]
    return validation_epoch_end(model, outputs)

In [5]:
%%time
from tqdm import tqdm


# lrs = [12]
# clips = [22]
# ratios = [1.5]

# The following search candidates take a few minutes to run through
lrs = [7,8,9,10,11,12,13,14]
clips = [17,18,19,20,21,22,23,24]
ratios = [4,8,10,12,15,18]


best_retain_acc = 0

unlearned_model_name = "DBpedia_fast_unlearned.pt"

for lr in tqdm(lrs):
    for clip in clips:
        for ratio in ratios:
            ensure_deterministic()
            
            noisy_loader = build_noisy_dl(
                retain_samples,
                noisy_data,
                dbpedia_vocab,
                retain_to_forget_ratio=ratio,
            )
            unlearning_model = TextClassificationModel(vocab_size).to(DEVICE)
            unlearning_model.load_state_dict(torch.load(baseline_name))

            optimizer = torch.optim.SGD(unlearning_model.parameters(), lr = lr)


            unlearning_model.train(True)
            for epoch in range(1):
                running_loss = 0.0
                running_acc = 0
                num_batches = len(noisy_loader)
                
                for i, data in enumerate(noisy_loader):
                    labels, inputs, offsets = data

                    optimizer.zero_grad()
                    outputs = unlearning_model(inputs,offsets)
                    loss = unlearning_model.criterion(outputs, labels)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(unlearning_model.parameters(), clip)
                    optimizer.step()

                    running_loss += loss.item()
                    out = torch.argmax(outputs.detach(),dim=1)
                    running_acc += (labels==out).sum().item()/labels.size(0)
                print(f"Train loss {epoch+1}: {running_loss/num_batches},Train Acc:{running_acc*100/num_batches}%")
                print("Performance of Standard Forget Model on Forget Class")
                history = [evaluate_after_unlearning(unlearning_model, forget_valid_dl)]
                print("Accuracy: {}".format(history[0]["Acc"]*100))
                print("Loss: {}".format(history[0]["Loss"]))

                print("Performance of Standard Forget Model on Retain Class")
                history = [evaluate_after_unlearning(unlearning_model, retain_valid_dl)]
                print("Accuracy: {}".format(history[0]["Acc"]*100))
                print("Loss: {}".format(history[0]["Loss"]))
                forget_acc = evaluate_after_unlearning(unlearning_model, forget_valid_dl)["Acc"]*100
                if forget_acc == 0.0:
                    retain_acc = evaluate_after_unlearning(unlearning_model, retain_valid_dl)["Acc"]*100
                    if retain_acc > best_retain_acc:
                        best_retain_acc = retain_acc
                        best_lr = lr
                        best_clip = clip
                        best_ratio = ratio
                        torch.save(unlearning_model.state_dict(), unlearned_model_name)

print(
    f"Best hyperparams: LR={best_lr}, grad clip={best_clip}, "
    f"ratio={best_ratio} | Best retain acc: {best_retain_acc}"
)
unlearning_model = TextClassificationModel(vocab_size).to(DEVICE)
unlearning_model.load_state_dict(torch.load(unlearned_model_name))

  0%|                                                                             | 0/8 [00:00<?, ?it/s]

Train loss 1: 813.4583576571557,Train Acc:64.78914650537635%
Performance of Standard Forget Model on Forget Class
Accuracy: 20.236817002296448
Loss: 68.75955200195312
Performance of Standard Forget Model on Retain Class
Accuracy: 6.58646821975708
Loss: 23.246158599853516
Train loss 1: 302.0453813726252,Train Acc:71.59689529220779%
Performance of Standard Forget Model on Forget Class
Accuracy: 10.185955464839935
Loss: 31.381567001342773
Performance of Standard Forget Model on Retain Class
Accuracy: 95.0734555721283
Loss: 0.414170503616333
Train loss 1: 93.13356750179636,Train Acc:78.21800595238095%
Performance of Standard Forget Model on Forget Class
Accuracy: 25.031903386116028
Loss: 9.26378345489502
Performance of Standard Forget Model on Retain Class
Accuracy: 86.93691492080688
Loss: 0.5897656083106995
Train loss 1: 62.59503196415148,Train Acc:80.61735341643582%
Performance of Standard Forget Model on Forget Class
Accuracy: 20.509839057922363
Loss: 3.2676475048065186
Performance of S

 12%|████████▌                                                           | 1/8 [02:11<15:20, 131.52s/it]

Train loss 1: 968.9072431748913,Train Acc:66.07022849462365%
Performance of Standard Forget Model on Forget Class
Accuracy: 1.2645110487937927
Loss: 41.328617095947266
Performance of Standard Forget Model on Retain Class
Accuracy: 95.05856037139893
Loss: 0.8412427306175232
Train loss 1: 511.0372900529341,Train Acc:70.49898538961038%
Performance of Standard Forget Model on Forget Class
Accuracy: 2.193460799753666
Loss: 32.08176803588867
Performance of Standard Forget Model on Retain Class
Accuracy: 95.91948986053467
Loss: 0.21027107536792755
Train loss 1: 438.7252793194327,Train Acc:51.58110119047619%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.0
Loss: 132.43077087402344
Performance of Standard Forget Model on Retain Class
Accuracy: 73.72954487800598
Loss: 29.098426818847656
Train loss 1: 84.36116173392848,Train Acc:81.50572772391506%
Performance of Standard Forget Model on Forget Class
Accuracy: 17.742924392223358
Loss: 10.286556243896484
Performance of Standard Fo

 25%|█████████████████                                                   | 2/8 [04:27<13:25, 134.19s/it]

Train loss 1: 1358.38572742093,Train Acc:60.63088037634409%
Performance of Standard Forget Model on Forget Class
Accuracy: 22.027096152305603
Loss: 136.39959716796875
Performance of Standard Forget Model on Retain Class
Accuracy: 1.44819812849164
Loss: 89.01139068603516
Train loss 1: 810.8203209963712,Train Acc:48.58360389610389%
Performance of Standard Forget Model on Forget Class
Accuracy: 20.236817002296448
Loss: 116.09325408935547
Performance of Standard Forget Model on Retain Class
Accuracy: 39.66488540172577
Loss: 45.44211196899414
Train loss 1: 653.6389972365079,Train Acc:63.70907738095238%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.30297255143523216
Loss: 135.77581787109375
Performance of Standard Forget Model on Retain Class
Accuracy: 77.7990460395813
Loss: 1.4372508525848389
Train loss 1: 169.4982651158383,Train Acc:75.94536299630654%
Performance of Standard Forget Model on Forget Class
Accuracy: 2.5465134531259537
Loss: 48.3371696472168
Performance of S

 38%|█████████████████████████▌                                          | 3/8 [06:43<11:15, 135.10s/it]

Train loss 1: 1763.4451542515908,Train Acc:56.783434139784944%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.04006410308647901
Loss: 59.373287200927734
Performance of Standard Forget Model on Retain Class
Accuracy: 93.31733584403992
Loss: 1.8753191232681274
Train loss 1: 1047.1856464039195,Train Acc:53.465807629870135%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.08012820617295802
Loss: 107.4648208618164
Performance of Standard Forget Model on Retain Class
Accuracy: 74.80238676071167
Loss: 4.931105613708496
Train loss 1: 858.8114992777506,Train Acc:44.661458333333336%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.08012820617295802
Loss: 323.2149963378906
Performance of Standard Forget Model on Retain Class
Accuracy: 10.947567969560623
Loss: 169.2346954345703
Train loss 1: 636.6547472100509,Train Acc:56.6968923130194%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.0
Loss: 136.69639587402344
Performance of Standa

 50%|██████████████████████████████████                                  | 4/8 [09:06<09:11, 137.94s/it]

Train loss 1: 2039.3323639900455,Train Acc:53.146001344086024%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.0
Loss: 89.15209197998047
Performance of Standard Forget Model on Retain Class
Accuracy: 89.76197242736816
Loss: 3.4734389781951904
Train loss 1: 1249.9367242292924,Train Acc:40.789265422077925%
Performance of Standard Forget Model on Forget Class
Accuracy: 20.226801931858063
Loss: 370.4637451171875
Performance of Standard Forget Model on Retain Class
Accuracy: 0.011003520921804011
Loss: 229.09217834472656
Train loss 1: 868.2285716650631,Train Acc:70.77752976190476%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.20032052416354418
Loss: 64.73258209228516
Performance of Standard Forget Model on Retain Class
Accuracy: 93.78907084465027
Loss: 0.5414455533027649
Train loss 1: 778.9587267825478,Train Acc:40.24880251615882%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.15024038730189204
Loss: 226.44354248046875
Performance of Stan

 62%|██████████████████████████████████████████▌                         | 5/8 [11:32<07:03, 141.08s/it]

Train loss 1: 2211.6891713296213,Train Acc:56.95144489247312%
Performance of Standard Forget Model on Forget Class
Accuracy: 15.406845510005951
Loss: 136.88412475585938
Performance of Standard Forget Model on Retain Class
Accuracy: 25.406494736671448
Loss: 25.140731811523438
Train loss 1: 1423.9229826493697,Train Acc:41.7270698051948%
Performance of Standard Forget Model on Forget Class
Accuracy: 19.906288385391235
Loss: 335.09381103515625
Performance of Standard Forget Model on Retain Class
Accuracy: 21.86800241470337
Loss: 34.93797302246094
Train loss 1: 1251.7572246588757,Train Acc:47.56324404761905%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.08261999464593828
Loss: 187.99696350097656
Performance of Standard Forget Model on Retain Class
Accuracy: 88.10204863548279
Loss: 1.0304754972457886
Train loss 1: 990.4831165514494,Train Acc:43.25008656509695%
Performance of Standard Forget Model on Forget Class
Accuracy: 2.789340913295746
Loss: 198.59725952148438
Performa

 75%|███████████████████████████████████████████████████                 | 6/8 [14:09<04:52, 146.43s/it]

Train loss 1: 2888.3034955916864,Train Acc:48.98773521505376%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.030048078042455018
Loss: 186.08316040039062
Performance of Standard Forget Model on Retain Class
Accuracy: 80.8553695678711
Loss: 9.3108549118042
Train loss 1: 1657.331828030673,Train Acc:47.95454545454545%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.010016025771619752
Loss: 203.9792022705078
Performance of Standard Forget Model on Retain Class
Accuracy: 89.58448767662048
Loss: 3.6518683433532715
Train loss 1: 1552.2508635748,Train Acc:47.28422619047619%
Performance of Standard Forget Model on Forget Class
Accuracy: 15.081202983856201
Loss: 728.0101318359375
Performance of Standard Forget Model on Retain Class
Accuracy: 8.531060069799423
Loss: 244.34396362304688
Train loss 1: 1167.9630052666914,Train Acc:35.7470567867036%
Performance of Standard Forget Model on Forget Class
Accuracy: 5.002882704138756
Loss: 285.1639404296875
Performance of 

 88%|███████████████████████████████████████████████████████████▌        | 7/8 [16:30<02:24, 144.63s/it]

Train loss 1: 3112.6291727865896,Train Acc:48.214885752688176%
Performance of Standard Forget Model on Forget Class
Accuracy: 22.1073716878891
Loss: 210.71705627441406
Performance of Standard Forget Model on Retain Class
Accuracy: 4.82182539999485
Loss: 66.48725891113281
Train loss 1: 1866.4179874766958,Train Acc:44.51887175324675%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.020032051543239504
Loss: 281.2443542480469
Performance of Standard Forget Model on Retain Class
Accuracy: 75.77904462814331
Loss: 6.745492935180664
Train loss 1: 1740.966783796038,Train Acc:33.61235119047619%
Performance of Standard Forget Model on Forget Class
Accuracy: 21.45378887653351
Loss: 83.48030853271484
Performance of Standard Forget Model on Retain Class
Accuracy: 6.418757885694504
Loss: 91.76441192626953
Train loss 1: 1359.4770628276624,Train Acc:30.70716470452447%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.030048078042455018
Loss: 240.08441162109375
Performance

100%|████████████████████████████████████████████████████████████████████| 8/8 [18:52<00:00, 141.53s/it]


Best hyperparams: LR=8, grad clip=19, ratio=10 | Best retain acc: 95.72798013687134
CPU times: user 18min 5s, sys: 45.2 s, total: 18min 50s
Wall time: 18min 52s


<All keys matched successfully>

In [6]:
print("Performance of Standard Forget Model on Forget Class")
history = [evaluate_after_unlearning(unlearning_model, forget_valid_dl)]
print("Accuracy: {}".format(history[0]["Acc"]*100))
print("Loss: {}".format(history[0]["Loss"]))

print("Performance of Standard Forget Model on Retain Class")
history = [evaluate_after_unlearning(unlearning_model, retain_valid_dl)]
print("Accuracy: {}".format(history[0]["Acc"]*100))
print("Loss: {}".format(history[0]["Loss"]))

Performance of Standard Forget Model on Forget Class
Accuracy: 0.0
Loss: 132.2335968017578
Performance of Standard Forget Model on Retain Class
Accuracy: 95.72389721870422
Loss: 0.4859219193458557


In [7]:
print("Test on Forget Class")
history = [evaluate_after_unlearning(unlearning_model, forget_test_dl)]
print("Accuracy: {}".format(history[0]["Acc"]*100))
print("Loss: {}".format(history[0]["Loss"]))

print("Test on Retain Class")
history = [evaluate_after_unlearning(unlearning_model, retain_test_dl)]
print("Accuracy: {}".format(history[0]["Acc"]*100))
print("Loss: {}".format(history[0]["Loss"]))

Test on Forget Class
Accuracy: 0.0
Loss: 131.6812744140625
Test on Retain Class
Accuracy: 95.88307738304138
Loss: 0.3123016953468323
