In [1]:
import torch
from nlp_fast_unlearning.utils import prepare_dbpedia, build_noisy_dl, ensure_deterministic, DEVICE, BATCH_SIZE
from nlp_fast_unlearning.baseline_model import TextClassificationModel

# Only the vocabulary is needed at this point; the other three values returned
# by the baseline-only call are discarded.
_, _, _, dbpedia_vocab = prepare_dbpedia(for_baseline_only=True)

vocab_size = dbpedia_vocab.vocab_size

baseline_name = "DBpedia_baseline.pt"
unlearning_model = TextClassificationModel(vocab_size).to(DEVICE)
# map_location keeps the checkpoint loadable when it was saved on a device that
# is not available in this session (e.g. a CUDA checkpoint on a CPU-only host).
unlearning_model.load_state_dict(torch.load(baseline_name, map_location=DEVICE))

<All keys matched successfully>

In [2]:
%%time

# Build everything the unlearning run needs from a single prepare_dbpedia call.
# The names unpacked here are consumed by the grid-search and evaluation cells
# below. Per the captured cell output, this call searches for error-maximizing
# noise for each class listed in classes_to_forget, using the baseline model.
# NOTE(review): noisy_data presumably holds the result of that search — confirm
# in nlp_fast_unlearning.utils.
(
    retain_samples,
    noisy_data,
    retain_valid_dl,
    forget_valid_dl,
    retain_test_dl,
    forget_test_dl,
    dbpedia_vocab,
) = prepare_dbpedia(
    for_baseline_only=False,
    classes_to_forget=[2, 4],
    model=unlearning_model,  # baseline-initialised model loaded in the first cell
    retain_percentage=0.01,  # presumably the fraction of retain data sampled — TODO confirm in utils
    vocab_class=dbpedia_vocab,  # reuse the vocabulary obtained in the first cell
)

Searching for error maximizing noise for class  2
Got loss 203.99356079101562 for tensor([101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101], device='cuda:0')
Searching for error maximizing noise for class  4
Got loss 236.83079528808594 for tensor([101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 10

In [3]:
def accuracy(outputs, labels):
    """Return the share of rows whose highest-scoring column equals the label.

    Args:
        outputs: 2-D score tensor; the prediction for each row is the index of
            its maximum along dim 1.
        labels: tensor of target indices, compared element-wise with the
            predictions.

    Returns:
        A 0-dim tensor holding ``correct / number_of_rows``.
    """
    predicted = torch.max(outputs, dim=1)[1]
    num_correct = torch.sum(predicted == labels).item()
    return torch.tensor(num_correct / len(predicted))


def validation_step(model, batch):
    """Score one batch and return its loss and accuracy.

    Args:
        model: callable invoked as ``model(text, offsets)``.
        batch: a ``(labels, text, offsets)`` triple.

    Returns:
        dict with ``'Loss'`` (detached cross-entropy tensor) and ``'Acc'``
        (the tensor produced by ``accuracy``).

    Note:
        ``F`` (torch.nn.functional) is resolved when the function is called,
        so it must have been imported by then.
    """
    targets, tokens, bag_offsets = batch
    logits = model(tokens, bag_offsets)
    return {
        'Loss': F.cross_entropy(logits, targets).detach(),
        'Acc': accuracy(logits, targets),
    }


def validation_epoch_end(model, outputs):
    """Collapse per-batch metrics into one loss and one accuracy value.

    Args:
        model: unused; kept so the signature stays the same for callers.
        outputs: sequence of dicts, each holding ``'Loss'`` and ``'Acc'``
            tensors for one batch.

    Returns:
        dict with the unweighted mean of the batch losses under ``'Loss'`` and
        of the batch accuracies under ``'Acc'``, both as Python floats.
    """
    def _mean_of(key):
        # Every batch counts equally, regardless of how many samples it held.
        return torch.stack([step[key] for step in outputs]).mean()

    mean_loss = _mean_of('Loss')
    mean_acc = _mean_of('Acc')
    return {'Loss': mean_loss.item(), 'Acc': mean_acc.item()}

In [4]:
import torch.nn.functional as F

@torch.no_grad()
def evaluate_after_unlearning(model, val_loader):
    """Evaluate ``model`` on every batch of ``val_loader`` with gradients off.

    The model is switched to eval mode and is left in that mode on return.

    Args:
        model: model forwarded to ``validation_step`` for each batch.
        val_loader: iterable of batches.

    Returns:
        The dict produced by ``validation_epoch_end`` for the collected
        per-batch results.
    """
    model.eval()
    step_results = []
    for batch in val_loader:
        step_results.append(validation_step(model, batch))
    return validation_epoch_end(model, step_results)

In [8]:
%%time
from tqdm import tqdm


# Quick single-candidate run (uncomment to skip the full search):
# lrs = [12]
# clips = [22]
# ratios = [1.5]

# The following search candidates take a few minutes to run through
lrs = [7,8,9,10,11,12,13,14]
clips = [17,18,19,20,21,22,23,24]
ratios = [0.8,1,1.5]


# Reset every "best" slot so that re-running this cell can never report values
# left in the kernel by an earlier execution.
best_retain_acc = 0
best_lr = best_clip = best_ratio = None

unlearned_model_name = "DBpedia_fast_unlearned.pt"

# Read the baseline checkpoint once. load_state_dict copies the values into
# each freshly built model, so the same dict is reused for every candidate.
baseline_state = torch.load(baseline_name, map_location=DEVICE)

for lr in tqdm(lrs):
    for clip in clips:
        for ratio in ratios:
            # Called before each candidate so all candidates start from the
            # same state (presumably re-seeds the RNGs — confirm in utils).
            ensure_deterministic()

            noisy_loader = build_noisy_dl(
                retain_samples,
                noisy_data,
                dbpedia_vocab,
                retain_to_forget_ratio=ratio,
            )
            # Every candidate starts from the untouched baseline weights.
            unlearning_model = TextClassificationModel(vocab_size).to(DEVICE)
            unlearning_model.load_state_dict(baseline_state)

            optimizer = torch.optim.SGD(unlearning_model.parameters(), lr=lr)

            for epoch in range(1):
                # Evaluation below switches the model to eval mode, so training
                # mode is (re-)enabled at the start of every epoch.
                unlearning_model.train(True)
                running_loss = 0.0
                running_acc = 0
                num_batches = len(noisy_loader)

                for labels, inputs, offsets in noisy_loader:
                    optimizer.zero_grad()
                    outputs = unlearning_model(inputs, offsets)
                    loss = unlearning_model.criterion(outputs, labels)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(unlearning_model.parameters(), clip)
                    optimizer.step()

                    running_loss += loss.item()
                    out = torch.argmax(outputs.detach(), dim=1)
                    running_acc += (labels == out).sum().item() / labels.size(0)
                print(f"Train loss {epoch+1}: {running_loss/num_batches},Train Acc:{running_acc*100/num_batches}%")

                # Each validation loader is evaluated once; the printed numbers
                # are the same measurements used for model selection below.
                print("Performance of Standard Forget Model on Forget Class")
                forget_metrics = evaluate_after_unlearning(unlearning_model, forget_valid_dl)
                print("Accuracy: {}".format(forget_metrics["Acc"]*100))
                print("Loss: {}".format(forget_metrics["Loss"]))

                print("Performance of Standard Forget Model on Retain Class")
                retain_metrics = evaluate_after_unlearning(unlearning_model, retain_valid_dl)
                print("Accuracy: {}".format(retain_metrics["Acc"]*100))
                print("Loss: {}".format(retain_metrics["Loss"]))

                forget_acc = forget_metrics["Acc"]*100
                retain_acc = retain_metrics["Acc"]*100
                # Only candidates that forget the target classes completely are
                # eligible; among those, keep the best retain accuracy.
                if forget_acc == 0.0 and retain_acc > best_retain_acc:
                    best_retain_acc = retain_acc
                    best_lr = lr
                    best_clip = clip
                    best_ratio = ratio
                    torch.save(unlearning_model.state_dict(), unlearned_model_name)

if best_lr is None:
    # Without this guard the cell would fail with a NameError on a fresh
    # kernel, or load a checkpoint written by a previous run.
    raise RuntimeError(
        "No hyperparameter combination reached 0% accuracy on the forget "
        "classes with a non-zero retain accuracy; nothing was saved to "
        f"{unlearned_model_name}."
    )

print(
    f"Best hyperparams: LR={best_lr}, grad clip={best_clip}, "
    f"ratio={best_ratio} | Best retain acc: {best_retain_acc}"
)
# Reload the best checkpoint so the following cells evaluate the selected model
# rather than the last candidate tried.
unlearning_model = TextClassificationModel(vocab_size).to(DEVICE)
unlearning_model.load_state_dict(torch.load(unlearned_model_name, map_location=DEVICE))

  0%|                                                                           | 0/8 [00:00<?, ?it/s]

Train loss 1: 1954.1223998561738,Train Acc:43.84920634920635%
Performance of Standard Forget Model on Forget Class
Accuracy: 4.848108068108559
Loss: 14.142340660095215
Performance of Standard Forget Model on Retain Class
Accuracy: 80.6059718132019
Loss: 5.4084882736206055
Train loss 1: 1833.93407592067,Train Acc:47.148345153664295%
Performance of Standard Forget Model on Forget Class
Accuracy: 47.04448580741882
Loss: 8.000864028930664
Performance of Standard Forget Model on Retain Class
Accuracy: 82.2912335395813
Loss: 2.6581461429595947
Train loss 1: 1237.41339858373,Train Acc:61.17273351648352%
Performance of Standard Forget Model on Forget Class
Accuracy: 34.61853563785553
Loss: 18.675622940063477
Performance of Standard Forget Model on Retain Class
Accuracy: 87.29400634765625
Loss: 5.0146708488464355
Train loss 1: 1890.960366828101,Train Acc:47.48564514189514%
Performance of Standard Forget Model on Forget Class
Accuracy: 33.410948514938354
Loss: 18.67273712158203
Performance of St

 12%|████████▍                                                          | 1/8 [01:04<07:33, 64.73s/it]

Accuracy: 86.56262159347534
Loss: 2.717987537384033
Train loss 1: 2389.3373268369646,Train Acc:48.48879419191919%
Performance of Standard Forget Model on Forget Class
Accuracy: 49.29521977901459
Loss: 22.376758575439453
Performance of Standard Forget Model on Retain Class
Accuracy: 77.37258672714233
Loss: 16.248149871826172
Train loss 1: 1978.6324221116524,Train Acc:54.511580230496456%
Performance of Standard Forget Model on Forget Class
Accuracy: 50.485050678253174
Loss: 27.891767501831055
Performance of Standard Forget Model on Retain Class
Accuracy: 71.03227376937866
Loss: 5.684261798858643
Train loss 1: 1863.577796198073,Train Acc:49.76619734432235%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.0
Loss: 78.37171936035156
Performance of Standard Forget Model on Retain Class
Accuracy: 73.88986349105835
Loss: 6.144255638122559
Train loss 1: 2382.788270526462,Train Acc:45.54849086099086%
Performance of Standard Forget Model on Forget Class
Accuracy: 66.11772179603577


 25%|████████████████▊                                                  | 2/8 [02:08<06:25, 64.29s/it]

Train loss 1: 2850.361792428153,Train Acc:49.69693211880712%
Performance of Standard Forget Model on Forget Class
Accuracy: 45.92708945274353
Loss: 43.03944778442383
Performance of Standard Forget Model on Retain Class
Accuracy: 43.76183748245239
Loss: 7.491744518280029
Train loss 1: 2505.0262049922235,Train Acc:52.612323926319945%
Performance of Standard Forget Model on Forget Class
Accuracy: 34.39315855503082
Loss: 41.77104949951172
Performance of Standard Forget Model on Retain Class
Accuracy: 86.53588891029358
Loss: 1.029057264328003
Train loss 1: 2232.8617223103843,Train Acc:50.84177541208791%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.8084888570010662
Loss: 57.955081939697266
Performance of Standard Forget Model on Retain Class
Accuracy: 17.66776591539383
Loss: 26.555126190185547
Train loss 1: 3204.901436941964,Train Acc:46.85733826358826%
Performance of Standard Forget Model on Forget Class
Accuracy: 16.105008125305176
Loss: 12.609291076660156
Performance o

 38%|█████████████████████████▏                                         | 3/8 [03:08<05:11, 62.34s/it]

Accuracy: 90.84141254425049
Loss: 5.910426616668701
Train loss 1: 3515.893128561595,Train Acc:46.12418831168831%
Performance of Standard Forget Model on Forget Class
Accuracy: 92.20808148384094
Loss: 0.7163382172584534
Performance of Standard Forget Model on Retain Class
Accuracy: 75.03489851951599
Loss: 8.285689353942871
Train loss 1: 3279.9450131169074,Train Acc:48.73100743695824%
Performance of Standard Forget Model on Forget Class
Accuracy: 18.869447708129883
Loss: 12.083215713500977
Performance of Standard Forget Model on Retain Class
Accuracy: 88.56369853019714
Loss: 7.515978813171387
Train loss 1: 2577.3665381357782,Train Acc:58.60634157509158%
Performance of Standard Forget Model on Forget Class
Accuracy: 41.728675365448
Loss: 20.77687644958496
Performance of Standard Forget Model on Retain Class
Accuracy: 48.15923869609833
Loss: 6.940013408660889
Train loss 1: 4114.26227281964,Train Acc:47.36257665945166%
Performance of Standard Forget Model on Forget Class
Accuracy: 5.0678346

 50%|█████████████████████████████████▌                                 | 4/8 [04:07<04:04, 61.04s/it]

Accuracy: 89.3744945526123
Loss: 11.441539764404297
Train loss 1: 4490.31599308574,Train Acc:49.12424092111592%
Performance of Standard Forget Model on Forget Class
Accuracy: 5.998393893241882
Loss: 53.48793029785156
Performance of Standard Forget Model on Retain Class
Accuracy: 64.5340621471405
Loss: 40.66432571411133
Train loss 1: 3864.805141961133,Train Acc:44.84214933018124%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.048828125
Loss: 61.89704895019531
Performance of Standard Forget Model on Retain Class
Accuracy: 75.87522864341736
Loss: 18.893972396850586
Train loss 1: 3103.674538021996,Train Acc:51.31424565018315%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.5887622479349375
Loss: 41.18579864501953
Performance of Standard Forget Model on Retain Class
Accuracy: 86.50814890861511
Loss: 16.03072738647461
Train loss 1: 4740.201330343883,Train Acc:51.5051933020683%
Performance of Standard Forget Model on Forget Class
Accuracy: 25.942060351371765

 62%|█████████████████████████████████████████▉                         | 5/8 [05:07<03:02, 60.70s/it]

Train loss 1: 5599.73670113276,Train Acc:44.36402717652717%
Performance of Standard Forget Model on Forget Class
Accuracy: 42.17761158943176
Loss: 12.327576637268066
Performance of Standard Forget Model on Retain Class
Accuracy: 80.56947588920593
Loss: 6.0359930992126465
Train loss 1: 5574.46665152797,Train Acc:45.63248005319149%
Performance of Standard Forget Model on Forget Class
Accuracy: 19.426329433918
Loss: 15.621918678283691
Performance of Standard Forget Model on Retain Class
Accuracy: 82.0682942867279
Loss: 1.6745206117630005
Train loss 1: 4509.326312837146,Train Acc:37.56982600732601%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.0
Loss: 65.34602355957031
Performance of Standard Forget Model on Retain Class
Accuracy: 78.51492762565613
Loss: 10.307090759277344
Train loss 1: 6129.813278561547,Train Acc:39.67558772246272%
Performance of Standard Forget Model on Forget Class
Accuracy: 9.01544764637947
Loss: 29.924278259277344
Performance of Standard Forget Mode

 75%|██████████████████████████████████████████████████▎                | 6/8 [06:11<02:03, 61.58s/it]

Train loss 1: 6420.78174881708,Train Acc:43.85672198172198%
Performance of Standard Forget Model on Forget Class
Accuracy: 78.36611270904541
Loss: 3.8072409629821777
Performance of Standard Forget Model on Retain Class
Accuracy: 44.98021602630615
Loss: 68.86455535888672
Train loss 1: 5740.819993619566,Train Acc:52.32174078999211%
Performance of Standard Forget Model on Forget Class
Accuracy: 27.366548776626587
Loss: 10.820219039916992
Performance of Standard Forget Model on Retain Class
Accuracy: 84.82748866081238
Loss: 10.16382884979248
Train loss 1: 4524.617451985677,Train Acc:48.59045902014652%
Performance of Standard Forget Model on Forget Class
Accuracy: 65.4454231262207
Loss: 13.821834564208984
Performance of Standard Forget Model on Retain Class
Accuracy: 54.82828617095947
Loss: 34.61254119873047
Train loss 1: 7061.924585039654,Train Acc:44.78415103415103%
Performance of Standard Forget Model on Forget Class
Accuracy: 19.415031373500824
Loss: 66.489013671875
Performance of Stand

 88%|██████████████████████████████████████████████████████████▋        | 7/8 [07:18<01:03, 63.46s/it]

Accuracy: 86.66425943374634
Loss: 7.6049723625183105
Train loss 1: 7145.972614167228,Train Acc:45.33655002405003%
Performance of Standard Forget Model on Forget Class
Accuracy: 41.03196859359741
Loss: 23.933429718017578
Performance of Standard Forget Model on Retain Class
Accuracy: 88.63524794578552
Loss: 6.267319679260254
Train loss 1: 6801.186072314227,Train Acc:45.27617710795902%
Performance of Standard Forget Model on Forget Class
Accuracy: 3.478096053004265
Loss: 37.865692138671875
Performance of Standard Forget Model on Retain Class
Accuracy: 65.43402075767517
Loss: 32.94458770751953
Train loss 1: 5476.63759122576,Train Acc:48.211137820512825%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.0
Loss: 62.03722381591797
Performance of Standard Forget Model on Retain Class
Accuracy: 86.08797788619995
Loss: 1.7531734704971313
Train loss 1: 8932.990634070502,Train Acc:41.57441227753728%
Performance of Standard Forget Model on Forget Class
Accuracy: 56.367027759552
Loss:

100%|███████████████████████████████████████████████████████████████████| 8/8 [08:26<00:00, 63.25s/it]


Best hyperparams: LR=14, grad clip=22, ratio=1 | Best retain acc: 93.76998543739319
CPU times: user 8min 3s, sys: 20.3 s, total: 8min 24s
Wall time: 8min 26s


<All keys matched successfully>

In [9]:
# Report validation metrics of the selected unlearned model, first on the
# forgotten classes and then on the retained ones.
for header, loader in (
    ("Performance of Standard Forget Model on Forget Class", forget_valid_dl),
    ("Performance of Standard Forget Model on Retain Class", retain_valid_dl),
):
    print(header)
    history = [evaluate_after_unlearning(unlearning_model, loader)]
    print("Accuracy: {}".format(history[0]["Acc"]*100))
    print("Loss: {}".format(history[0]["Loss"]))

Performance of Standard Forget Model on Forget Class
Accuracy: 0.0
Loss: 134.42599487304688
Performance of Standard Forget Model on Retain Class
Accuracy: 93.76683235168457
Loss: 0.622072696685791


In [10]:
# Report held-out test metrics of the selected unlearned model, first on the
# forgotten classes and then on the retained ones.
for header, loader in (
    ("Test on Forget Class", forget_test_dl),
    ("Test on Retain Class", retain_test_dl),
):
    print(header)
    history = [evaluate_after_unlearning(unlearning_model, loader)]
    print("Accuracy: {}".format(history[0]["Acc"]*100))
    print("Loss: {}".format(history[0]["Loss"]))

Test on Forget Class
Accuracy: 0.0
Loss: 134.0271453857422
Test on Retain Class
Accuracy: 93.9660906791687
Loss: 0.6111023426055908
