In [1]:
import torch
from nlp_fast_unlearning.utils import prepare_dbpedia, build_noisy_dl, ensure_deterministic, DEVICE, BATCH_SIZE
from nlp_fast_unlearning.baseline_model import TextClassificationModel

_, _, _, dbpedia_vocab = prepare_dbpedia(for_baseline_only=True)

vocab_size = dbpedia_vocab.vocab_size

baseline_name = "DBpedia_baseline.pt"
unlearning_model = TextClassificationModel(vocab_size).to(DEVICE)
unlearning_model.load_state_dict(torch.load(baseline_name))

<All keys matched successfully>

In [2]:
%%time

(
    retain_samples,
    noisy_data,
    retain_valid_dl,
    forget_valid_dl,
    retain_test_dl,
    forget_test_dl,
    dbpedia_vocab,
) = prepare_dbpedia(
    for_baseline_only=False,
    classes_to_forget=[2, 4],
    model=unlearning_model,
    retain_percentage=0.01,
    vocab_class=dbpedia_vocab,
)

Searching for error maximizing noise for class  2
Got loss 203.99356079101562 for tensor([101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101], device='cuda:0')
Searching for error maximizing noise for class  4
Got loss 236.83079528808594 for tensor([101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 10

In [3]:
def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))


def validation_step(model, batch):
    labels, text, offsets = batch
    out = model(text,offsets)
    loss = F.cross_entropy(out, labels)   
    acc = accuracy(out, labels)
    return {'Loss': loss.detach(), 'Acc': acc}


def validation_epoch_end(model, outputs):
    batch_losses = [x['Loss'] for x in outputs]
    epoch_loss = torch.stack(batch_losses).mean()   
    batch_accs = [x['Acc'] for x in outputs]
    epoch_acc = torch.stack(batch_accs).mean()      
    return {'Loss': epoch_loss.item(), 'Acc': epoch_acc.item()}

In [4]:
import torch.nn.functional as F

@torch.no_grad()
def evaluate_after_unlearning(model, val_loader):
    model.eval()
    outputs = [validation_step(model, batch) for batch in val_loader]
    return validation_epoch_end(model, outputs)

In [5]:
%%time
from tqdm import tqdm


# lrs = [12]
# clips = [22]
# ratios = [1.5]

# The following search candidates take a few minutes to run through
lrs = [7,8,9,10,11,12,13,14]
clips = [17,18,19,20,21,22,23,24]
ratios = [1,2,2.5]


best_retain_acc = 0

unlearned_model_name = "DBpedia_fast_unlearned.pt"

for lr in tqdm(lrs):
    for clip in clips:
        for ratio in ratios:
            ensure_deterministic()
            
            noisy_loader = build_noisy_dl(
                retain_samples,
                noisy_data,
                dbpedia_vocab,
                retain_to_forget_ratio=ratio,
            )
            unlearning_model = TextClassificationModel(vocab_size).to(DEVICE)
            unlearning_model.load_state_dict(torch.load(baseline_name))

            optimizer = torch.optim.SGD(unlearning_model.parameters(), lr = lr)


            unlearning_model.train(True)
            for epoch in range(1):
                running_loss = 0.0
                running_acc = 0
                num_batches = len(noisy_loader)
                
                for i, data in enumerate(noisy_loader):
                    labels, inputs, offsets = data

                    optimizer.zero_grad()
                    outputs = unlearning_model(inputs,offsets)
                    loss = unlearning_model.criterion(outputs, labels)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(unlearning_model.parameters(), clip)
                    optimizer.step()

                    running_loss += loss.item()
                    out = torch.argmax(outputs.detach(),dim=1)
                    running_acc += (labels==out).sum().item()/labels.size(0)
                print(f"Train loss {epoch+1}: {running_loss/num_batches},Train Acc:{running_acc*100/num_batches}%")
                print("Performance of Standard Forget Model on Forget Class")
                history = [evaluate_after_unlearning(unlearning_model, forget_valid_dl)]
                print("Accuracy: {}".format(history[0]["Acc"]*100))
                print("Loss: {}".format(history[0]["Loss"]))

                print("Performance of Standard Forget Model on Retain Class")
                history = [evaluate_after_unlearning(unlearning_model, retain_valid_dl)]
                print("Accuracy: {}".format(history[0]["Acc"]*100))
                print("Loss: {}".format(history[0]["Loss"]))
                forget_acc = evaluate_after_unlearning(unlearning_model, forget_valid_dl)["Acc"]*100
                if forget_acc == 0.0:
                    retain_acc = evaluate_after_unlearning(unlearning_model, retain_valid_dl)["Acc"]*100
                    if retain_acc > best_retain_acc:
                        best_retain_acc = retain_acc
                        best_lr = lr
                        best_clip = clip
                        best_ratio = ratio
                        torch.save(unlearning_model.state_dict(), unlearned_model_name)

print(
    f"Best hyperparams: LR={best_lr}, grad clip={best_clip}, "
    f"ratio={best_ratio} | Best retain acc: {best_retain_acc}"
)
unlearning_model = TextClassificationModel(vocab_size).to(DEVICE)
unlearning_model.load_state_dict(torch.load(unlearned_model_name))

  0%|                                                                             | 0/8 [00:00<?, ?it/s]

Train loss 1: 1833.93407592067,Train Acc:47.148345153664295%
Performance of Standard Forget Model on Forget Class
Accuracy: 47.04448580741882
Loss: 8.000864028930664
Performance of Standard Forget Model on Retain Class
Accuracy: 82.2912335395813
Loss: 2.6581461429595947
Train loss 1: 997.1193838119507,Train Acc:65.16193928303304%
Performance of Standard Forget Model on Forget Class
Accuracy: 46.00961208343506
Loss: 18.482261657714844
Performance of Standard Forget Model on Retain Class
Accuracy: 89.30250406265259
Loss: 1.3878140449523926
Train loss 1: 961.0374676404577,Train Acc:55.859375%
Performance of Standard Forget Model on Forget Class
Accuracy: 36.25427782535553
Loss: 24.206205368041992
Performance of Standard Forget Model on Retain Class
Accuracy: 84.23694968223572
Loss: 1.4677088260650635
Train loss 1: 1998.4589576544586,Train Acc:56.67894872931442%
Performance of Standard Forget Model on Forget Class
Accuracy: 8.676475286483765
Loss: 10.032983779907227
Performance of Standard

 12%|████████▋                                                            | 1/8 [01:01<07:07, 61.12s/it]

Accuracy: 1.4567035250365734
Loss: 88.2271499633789
Train loss 1: 1978.6324221116524,Train Acc:54.511580230496456%
Performance of Standard Forget Model on Forget Class
Accuracy: 50.485050678253174
Loss: 27.891767501831055
Performance of Standard Forget Model on Retain Class
Accuracy: 71.03227376937866
Loss: 5.684261798858643
Train loss 1: 1538.8531637191772,Train Acc:57.42539414414414%
Performance of Standard Forget Model on Forget Class
Accuracy: 5.095073580741882
Loss: 13.866820335388184
Performance of Standard Forget Model on Retain Class
Accuracy: 89.61226344108582
Loss: 2.8266055583953857
Train loss 1: 1309.8608646970806,Train Acc:58.937026515151516%
Performance of Standard Forget Model on Forget Class
Accuracy: 14.492671191692352
Loss: 32.617244720458984
Performance of Standard Forget Model on Retain Class
Accuracy: 89.02958035469055
Loss: 8.161430358886719
Train loss 1: 2438.847771441495,Train Acc:48.55893543144207%
Performance of Standard Forget Model on Forget Class
Accuracy: 

 25%|█████████████████▎                                                   | 2/8 [02:03<06:10, 61.68s/it]

Accuracy: 89.16210532188416
Loss: 1.7363709211349487
Train loss 1: 2505.0262049922235,Train Acc:52.612323926319945%
Performance of Standard Forget Model on Forget Class
Accuracy: 34.39315855503082
Loss: 41.77104949951172
Performance of Standard Forget Model on Retain Class
Accuracy: 86.53588891029358
Loss: 1.029057264328003
Train loss 1: 1947.4724136988323,Train Acc:54.70773507882883%
Performance of Standard Forget Model on Forget Class
Accuracy: 35.068279504776
Loss: 4.676785945892334
Performance of Standard Forget Model on Retain Class
Accuracy: 87.9330575466156
Loss: 3.1867406368255615
Train loss 1: 1785.066838539008,Train Acc:49.79876893939394%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.6648292299360037
Loss: 61.81329345703125
Performance of Standard Forget Model on Retain Class
Accuracy: 77.93954014778137
Loss: 2.289846420288086
Train loss 1: 2859.739163151494,Train Acc:48.9310911643026%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.0
Loss:

 38%|█████████████████████████▉                                           | 3/8 [03:14<05:31, 66.25s/it]

Train loss 1: 3279.9450131169074,Train Acc:48.73100743695824%
Performance of Standard Forget Model on Forget Class
Accuracy: 18.869447708129883
Loss: 12.083215713500977
Performance of Standard Forget Model on Retain Class
Accuracy: 88.56369853019714
Loss: 7.515978813171387
Train loss 1: 3005.757942994436,Train Acc:43.586653059309306%
Performance of Standard Forget Model on Forget Class
Accuracy: 11.330142617225647
Loss: 70.35594177246094
Performance of Standard Forget Model on Retain Class
Accuracy: 78.18236947059631
Loss: 13.064959526062012
Train loss 1: 2213.871187614672,Train Acc:48.86363636363637%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.146484375
Loss: 182.79434204101562
Performance of Standard Forget Model on Retain Class
Accuracy: 11.528169363737106
Loss: 39.01448059082031
Train loss 1: 3565.028717606156,Train Acc:53.46468060480693%
Performance of Standard Forget Model on Forget Class
Accuracy: 85.18994450569153
Loss: 0.9557057023048401
Performance of Sta

 50%|██████████████████████████████████▌                                  | 4/8 [04:17<04:18, 64.73s/it]

Train loss 1: 3864.805141961133,Train Acc:44.84214933018124%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.048828125
Loss: 61.89704895019531
Performance of Standard Forget Model on Retain Class
Accuracy: 75.87522864341736
Loss: 18.893972396850586
Train loss 1: 2545.460158665975,Train Acc:62.78974286786787%
Performance of Standard Forget Model on Forget Class
Accuracy: 15.705908834934235
Loss: 26.613309860229492
Performance of Standard Forget Model on Retain Class
Accuracy: 91.12815260887146
Loss: 3.9767048358917236
Train loss 1: 2880.303505406235,Train Acc:45.13494318181818%
Performance of Standard Forget Model on Forget Class
Accuracy: 1.5653247013688087
Loss: 69.72611999511719
Performance of Standard Forget Model on Retain Class
Accuracy: 87.40620613098145
Loss: 6.711477279663086
Train loss 1: 4282.255319100839,Train Acc:49.51102615248227%
Performance of Standard Forget Model on Forget Class
Accuracy: 18.095663189888
Loss: 14.325858116149902
Performance of Standard

 62%|███████████████████████████████████████████▏                         | 5/8 [05:21<03:13, 64.63s/it]

Train loss 1: 5574.46665152797,Train Acc:45.63248005319149%
Performance of Standard Forget Model on Forget Class
Accuracy: 19.426329433918
Loss: 15.621918678283691
Performance of Standard Forget Model on Retain Class
Accuracy: 82.0682942867279
Loss: 1.6745206117630005
Train loss 1: 4108.099779552884,Train Acc:31.5283056493994%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.2197265625
Loss: 73.30574035644531
Performance of Standard Forget Model on Retain Class
Accuracy: 66.76695942878723
Loss: 26.988149642944336
Train loss 1: 3189.4814685474744,Train Acc:50.47348484848485%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.5887622945010662
Loss: 50.09455108642578
Performance of Standard Forget Model on Retain Class
Accuracy: 89.34034705162048
Loss: 9.951517105102539
Train loss 1: 5017.76274172465,Train Acc:46.637208185579205%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.2197265625
Loss: 58.25229263305664
Performance of Standard Forget 

 75%|███████████████████████████████████████████████████▊                 | 6/8 [06:24<02:08, 64.10s/it]

Accuracy: 89.21658396720886
Loss: 30.574399948120117
Train loss 1: 5740.819993619566,Train Acc:52.32174078999211%
Performance of Standard Forget Model on Forget Class
Accuracy: 27.366548776626587
Loss: 10.820219039916992
Performance of Standard Forget Model on Retain Class
Accuracy: 84.82748866081238
Loss: 10.16382884979248
Train loss 1: 4116.666667408414,Train Acc:41.76051051051051%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.146484375
Loss: 150.11532592773438
Performance of Standard Forget Model on Retain Class
Accuracy: 18.977011740207672
Loss: 25.815444946289062
Train loss 1: 3723.6323657099047,Train Acc:43.098958333333336%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.0
Loss: 121.6712875366211
Performance of Standard Forget Model on Retain Class
Accuracy: 42.5746887922287
Loss: 14.162519454956055
Train loss 1: 6476.552634424634,Train Acc:39.18193459416863%
Performance of Standard Forget Model on Forget Class
Accuracy: 51.876652240753174
Loss

 88%|████████████████████████████████████████████████████████████▍        | 7/8 [07:30<01:04, 64.49s/it]

Accuracy: 72.03108072280884
Loss: 12.19597339630127
Train loss 1: 6801.186072314227,Train Acc:45.27617710795902%
Performance of Standard Forget Model on Forget Class
Accuracy: 3.478096053004265
Loss: 37.865692138671875
Performance of Standard Forget Model on Retain Class
Accuracy: 65.43402075767517
Loss: 32.94458770751953
Train loss 1: 4843.911001205444,Train Acc:47.07676426426427%
Performance of Standard Forget Model on Forget Class
Accuracy: 11.983672529459
Loss: 111.47112274169922
Performance of Standard Forget Model on Retain Class
Accuracy: 88.7668788433075
Loss: 47.27890396118164
Train loss 1: 4768.860942840576,Train Acc:40.73153409090909%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.0
Loss: 700.404541015625
Performance of Standard Forget Model on Retain Class
Accuracy: 8.111485838890076
Loss: 274.0776672363281
Train loss 1: 8338.64200394242,Train Acc:39.87530166469661%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.1220703125
Loss: 46.037406

100%|█████████████████████████████████████████████████████████████████████| 8/8 [08:34<00:00, 64.32s/it]

Accuracy: 85.87660193443298
Loss: 14.576805114746094
Best hyperparams: LR=14, grad clip=22, ratio=1 | Best retain acc: 93.76998543739319





CPU times: user 8min 9s, sys: 22.5 s, total: 8min 31s
Wall time: 8min 35s


<All keys matched successfully>

In [6]:
print("Performance of Standard Forget Model on Forget Class")
history = [evaluate_after_unlearning(unlearning_model, forget_valid_dl)]
print("Accuracy: {}".format(history[0]["Acc"]*100))
print("Loss: {}".format(history[0]["Loss"]))

print("Performance of Standard Forget Model on Retain Class")
history = [evaluate_after_unlearning(unlearning_model, retain_valid_dl)]
print("Accuracy: {}".format(history[0]["Acc"]*100))
print("Loss: {}".format(history[0]["Loss"]))

Performance of Standard Forget Model on Forget Class
Accuracy: 0.0
Loss: 134.45887756347656
Performance of Standard Forget Model on Retain Class
Accuracy: 93.76998543739319
Loss: 0.6215648651123047


In [7]:
print("Test on Forget Class")
history = [evaluate_after_unlearning(unlearning_model, forget_test_dl)]
print("Accuracy: {}".format(history[0]["Acc"]*100))
print("Loss: {}".format(history[0]["Loss"]))

print("Test on Retain Class")
history = [evaluate_after_unlearning(unlearning_model, retain_test_dl)]
print("Accuracy: {}".format(history[0]["Acc"]*100))
print("Loss: {}".format(history[0]["Loss"]))

Test on Forget Class
Accuracy: 0.0
Loss: 133.45765686035156
Test on Retain Class
Accuracy: 93.96054744720459
Loss: 0.6118340492248535
