In [1]:
import torch
from nlp_fast_unlearning.utils import prepare_dbpedia, build_noisy_dl, ensure_deterministic, DEVICE, BATCH_SIZE
from nlp_fast_unlearning.baseline_model import TextClassificationModel

# Only the vocabulary is needed at this point: the first three values returned
# by prepare_dbpedia are discarded, and the vocabulary size is used to build
# the classifier below.
_, _, _, dbpedia_vocab = prepare_dbpedia(for_baseline_only=True)

vocab_size = dbpedia_vocab.vocab_size

baseline_name = "DBpedia_baseline.pt"
unlearning_model = TextClassificationModel(vocab_size).to(DEVICE)
# map_location remaps the stored tensors onto DEVICE, so a checkpoint written
# on a GPU still loads when this notebook runs on a CPU-only machine.
# NOTE(review): assumes DEVICE is a torch.device or device string - confirm in utils.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
unlearning_model.load_state_dict(torch.load(baseline_name, map_location=DEVICE))

<All keys matched successfully>

In [2]:
%%time

# Prepare the data for the unlearning run (the call logs its search for
# error-maximizing noise for each class listed in classes_to_forget).
# The baseline model and the vocabulary built earlier are passed in, and
# dbpedia_vocab is rebound to the vocabulary object the call returns.
unlearning_artifacts = prepare_dbpedia(
    model=unlearning_model,
    vocab_class=dbpedia_vocab,
    classes_to_forget=[10, 11, 12, 13, 14],
    retain_percentage=0.01,
    for_baseline_only=False,
)

# Unpack the seven returned values in the order the helper yields them.
(
    retain_samples, noisy_data,
    retain_valid_dl, forget_valid_dl,
    retain_test_dl, forget_test_dl,
    dbpedia_vocab,
) = unlearning_artifacts

Searching for error maximizing noise for class  10
Got loss 217.09841918945312 for tensor([103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103,
        103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103,
        103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103,
        103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103,
        103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103,
        103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103,
        103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103,
        103, 103], device='cuda:0')
Searching for error maximizing noise for class  11
Got loss 209.3299560546875 for tensor([101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 1

In [3]:
def accuracy(outputs, labels):
    """Return the fraction of correct predictions as a 0-dim tensor.

    The prediction for each row of ``outputs`` is the index of its largest
    value along dim 1; it counts as correct when it equals the matching
    entry of ``labels``.
    """
    predicted = torch.max(outputs, dim=1).indices
    num_correct = (predicted == labels).sum().item()
    # The ratio is computed in Python floats before being wrapped in a tensor.
    return torch.tensor(num_correct / len(predicted))


def validation_step(model, batch):
    """Score one batch and return its loss and accuracy.

    Args:
        model: callable invoked as ``model(text, offsets)`` that returns
            per-class scores.
        batch: a ``(labels, text, offsets)`` triple.

    Returns:
        dict with ``'Loss'`` (cross-entropy, detached from the autograd graph)
        and ``'Acc'`` (the value returned by ``accuracy``).
    """
    labels, text, offsets = batch
    out = model(text, offsets)
    # Resolve cross_entropy through the already-imported torch package so this
    # cell does not depend on the ``F`` alias, which is only imported in a
    # later cell of the notebook.
    loss = torch.nn.functional.cross_entropy(out, labels)
    acc = accuracy(out, labels)
    return {'Loss': loss.detach(), 'Acc': acc}


def validation_epoch_end(model, outputs):
    """Average per-batch metrics into one epoch-level summary.

    Args:
        model: unused; kept so the signature matches the other validation hooks.
        outputs: list of dicts, each holding ``'Loss'`` and ``'Acc'`` tensors.

    Returns:
        dict with the plain-float mean of ``'Loss'`` and of ``'Acc'``. Every
        batch carries equal weight in the mean, whatever its size.
    """
    summary = {}
    for metric in ('Loss', 'Acc'):
        per_batch = torch.stack([step[metric] for step in outputs])
        summary[metric] = per_batch.mean().item()
    return summary

In [4]:
import torch.nn.functional as F

def evaluate_after_unlearning(model, val_loader):
    """Evaluate ``model`` on every batch of ``val_loader`` with gradients disabled.

    The model is switched to eval mode (and left there). Each batch goes
    through ``validation_step`` and the per-batch results are combined by
    ``validation_epoch_end``, whose return value is passed back to the caller.
    """
    with torch.no_grad():
        model.eval()
        step_results = []
        for batch in val_loader:
            step_results.append(validation_step(model, batch))
        return validation_epoch_end(model, step_results)

In [18]:
%%time
from tqdm import tqdm


# Quick single-configuration alternative (uncomment to skip the grid search):
# lrs = [12]
# clips = [22]
# ratios = [1.5]

# The following search candidates take a few minutes to run through
lrs = [7,8,9,10,11,12,13,14]
clips = [17,18,19,20,21,22,23,24]
ratios = [4,8,12]

unlearned_model_name = "DBpedia_fast_unlearned.pt"

# Reset the search state explicitly. Without this, a run in which no
# configuration qualifies would either fail with NameError or, on a re-run,
# silently report hyper-parameters left over from an earlier execution.
best_retain_acc = 0
best_lr = best_clip = best_ratio = None

# Read the baseline checkpoint once instead of once per configuration.
# load_state_dict copies the values into each fresh model, so the shared
# state dict is not modified by training. map_location keeps the load working
# when the checkpoint was written on a different device.
baseline_state = torch.load(baseline_name, map_location=DEVICE)

for lr in tqdm(lrs):
    for clip in clips:
        for ratio in ratios:
            ensure_deterministic()

            noisy_loader = build_noisy_dl(
                retain_samples,
                noisy_data,
                dbpedia_vocab,
                retain_to_forget_ratio=ratio,
            )
            # Every configuration starts again from the baseline weights.
            unlearning_model = TextClassificationModel(vocab_size).to(DEVICE)
            unlearning_model.load_state_dict(baseline_state)

            optimizer = torch.optim.SGD(unlearning_model.parameters(), lr=lr)

            unlearning_model.train(True)
            for epoch in range(1):
                running_loss = 0.0
                running_acc = 0
                num_batches = len(noisy_loader)

                for labels, inputs, offsets in noisy_loader:
                    optimizer.zero_grad()
                    outputs = unlearning_model(inputs, offsets)
                    loss = unlearning_model.criterion(outputs, labels)
                    loss.backward()
                    # Cap the gradient norm at the candidate `clip` value.
                    torch.nn.utils.clip_grad_norm_(unlearning_model.parameters(), clip)
                    optimizer.step()

                    running_loss += loss.item()
                    out = torch.argmax(outputs.detach(), dim=1)
                    running_acc += (labels == out).sum().item() / labels.size(0)
                print(f"Train loss {epoch+1}: {running_loss/num_batches},Train Acc:{running_acc*100/num_batches}%")

                # Evaluate each validation split exactly once and reuse the
                # result for both the log lines and model selection, so the
                # numbers printed are the numbers the selection is based on.
                print("Performance of Standard Forget Model on Forget Class")
                history = [evaluate_after_unlearning(unlearning_model, forget_valid_dl)]
                forget_acc = history[0]["Acc"]*100
                print("Accuracy: {}".format(forget_acc))
                print("Loss: {}".format(history[0]["Loss"]))

                print("Performance of Standard Forget Model on Retain Class")
                history = [evaluate_after_unlearning(unlearning_model, retain_valid_dl)]
                retain_acc = history[0]["Acc"]*100
                print("Accuracy: {}".format(retain_acc))
                print("Loss: {}".format(history[0]["Loss"]))

                # Only configurations that fully forget (exactly 0% accuracy on
                # the forget split) are eligible; among those, keep the one
                # with the highest retain accuracy.
                if forget_acc == 0.0 and retain_acc > best_retain_acc:
                    best_retain_acc = retain_acc
                    best_lr = lr
                    best_clip = clip
                    best_ratio = ratio
                    torch.save(unlearning_model.state_dict(), unlearned_model_name)

if best_lr is None:
    # Nothing was saved during this run, so loading the checkpoint below would
    # either fail or pick up a stale file from an earlier run.
    raise RuntimeError(
        "No hyper-parameter combination reached 0% forget accuracy with a "
        "non-zero retain accuracy; widen the search grid."
    )

print(
    f"Best hyperparams: LR={best_lr}, grad clip={best_clip}, "
    f"ratio={best_ratio} | Best retain acc: {best_retain_acc}"
)
# Reload the best checkpoint saved during the search into a fresh model.
unlearning_model = TextClassificationModel(vocab_size).to(DEVICE)
unlearning_model.load_state_dict(torch.load(unlearned_model_name, map_location=DEVICE))

  0%|                                                                           | 0/8 [00:00<?, ?it/s]

Train loss 1: 1030.4310541817497,Train Acc:56.741431451612904%
Performance of Standard Forget Model on Forget Class
Accuracy: 14.603720605373383
Loss: 40.03542709350586
Performance of Standard Forget Model on Retain Class
Accuracy: 90.41445255279541
Loss: 2.6900699138641357
Train loss 1: 610.0576876293529,Train Acc:62.548424586776854%
Performance of Standard Forget Model on Forget Class
Accuracy: 12.449189275503159
Loss: 34.84757614135742
Performance of Standard Forget Model on Retain Class
Accuracy: 80.76491951942444
Loss: 3.8820762634277344
Train loss 1: 384.22266006469727,Train Acc:73.46739614456446%
Performance of Standard Forget Model on Forget Class
Accuracy: 19.180862605571747
Loss: 21.69687271118164
Performance of Standard Forget Model on Retain Class
Accuracy: 87.11174726486206
Loss: 5.915497779846191
Train loss 1: 1184.8233123727803,Train Acc:55.63256048387097%
Performance of Standard Forget Model on Forget Class
Accuracy: 15.055876970291138
Loss: 23.373981475830078
Performan

 12%|████████▍                                                          | 1/8 [01:09<08:03, 69.03s/it]

Train loss 1: 1277.1730982603565,Train Acc:58.59915034562212%
Performance of Standard Forget Model on Forget Class
Accuracy: 37.07266449928284
Loss: 28.735427856445312
Performance of Standard Forget Model on Retain Class
Accuracy: 77.8555154800415
Loss: 2.719019889831543
Train loss 1: 905.1896093975414,Train Acc:53.72062241735537%
Performance of Standard Forget Model on Forget Class
Accuracy: 20.985300838947296
Loss: 76.41410064697266
Performance of Standard Forget Model on Retain Class
Accuracy: 95.92437744140625
Loss: 0.45405784249305725
Train loss 1: 511.57462581835296,Train Acc:69.40403427373019%
Performance of Standard Forget Model on Forget Class
Accuracy: 16.33223593235016
Loss: 37.91170883178711
Performance of Standard Forget Model on Retain Class
Accuracy: 93.03171634674072
Loss: 0.41489988565444946
Train loss 1: 1443.1893314084698,Train Acc:57.51548099078341%
Performance of Standard Forget Model on Forget Class
Accuracy: 24.045832455158234
Loss: 24.182592391967773
Performance

 25%|████████████████▊                                                  | 2/8 [02:10<06:28, 64.76s/it]

Train loss 1: 1744.5103039427795,Train Acc:49.77318548387097%
Performance of Standard Forget Model on Forget Class
Accuracy: 37.15372383594513
Loss: 39.67557907104492
Performance of Standard Forget Model on Retain Class
Accuracy: 85.74455976486206
Loss: 11.948479652404785
Train loss 1: 1145.2359503832731,Train Acc:51.43821022727273%
Performance of Standard Forget Model on Forget Class
Accuracy: 1.6691874712705612
Loss: 124.63445281982422
Performance of Standard Forget Model on Retain Class
Accuracy: 90.983647108078
Loss: 0.46149876713752747
Train loss 1: 756.1232625057822,Train Acc:46.88433693403815%
Performance of Standard Forget Model on Forget Class
Accuracy: 19.539988040924072
Loss: 363.3919982910156
Performance of Standard Forget Model on Retain Class
Accuracy: 0.005580357174039818
Loss: 253.80152893066406
Train loss 1: 2016.078970655078,Train Acc:51.071068548387096%
Performance of Standard Forget Model on Forget Class
Accuracy: 20.674268901348114
Loss: 50.3224983215332
Performanc

 38%|█████████████████████████▏                                         | 3/8 [03:16<05:26, 65.31s/it]

Train loss 1: 2218.173499451341,Train Acc:49.06754032258065%
Performance of Standard Forget Model on Forget Class
Accuracy: 13.808080554008484
Loss: 49.82835388183594
Performance of Standard Forget Model on Retain Class
Accuracy: 86.63835525512695
Loss: 3.133941173553467
Train loss 1: 1339.1373856284401,Train Acc:55.28473657024793%
Performance of Standard Forget Model on Forget Class
Accuracy: 17.758017778396606
Loss: 77.97987365722656
Performance of Standard Forget Model on Retain Class
Accuracy: 87.96569108963013
Loss: 1.0685395002365112
Train loss 1: 1051.3480902220074,Train Acc:39.555561939783956%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.009765625145519152
Loss: 285.8570861816406
Performance of Standard Forget Model on Retain Class
Accuracy: 81.8949043750763
Loss: 5.920705318450928
Train loss 1: 2503.517289506572,Train Acc:47.46723790322581%
Performance of Standard Forget Model on Forget Class
Accuracy: 12.78122067451477
Loss: 52.96213912963867
Performance o

 50%|█████████████████████████████████▌                                 | 4/8 [04:27<04:29, 67.41s/it]

Train loss 1: 2991.115646824451,Train Acc:47.631048387096776%
Performance of Standard Forget Model on Forget Class
Accuracy: 2.0380785688757896
Loss: 51.15435028076172
Performance of Standard Forget Model on Retain Class
Accuracy: 89.99411463737488
Loss: 3.955109119415283
Train loss 1: 1604.4182359955528,Train Acc:53.06204803719008%
Performance of Standard Forget Model on Forget Class
Accuracy: 22.51703441143036
Loss: 109.66899871826172
Performance of Standard Forget Model on Retain Class
Accuracy: 85.39393544197083
Loss: 3.1435537338256836
Train loss 1: 1131.5503070228979,Train Acc:39.793653757756836%
Performance of Standard Forget Model on Forget Class
Accuracy: 14.252158999443054
Loss: 149.87962341308594
Performance of Standard Forget Model on Retain Class
Accuracy: 78.1180202960968
Loss: 6.917719841003418
Train loss 1: 2947.3254416152918,Train Acc:43.83820564516129%
Performance of Standard Forget Model on Forget Class
Accuracy: 8.996563404798508
Loss: 48.25835037231445
Performance 

 62%|█████████████████████████████████████████▉                         | 5/8 [05:36<03:23, 67.86s/it]

Train loss 1: 3143.3154396202513,Train Acc:43.44758064516129%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.42968750931322575
Loss: 92.35118103027344
Performance of Standard Forget Model on Retain Class
Accuracy: 83.56256484985352
Loss: 35.23063659667969
Train loss 1: 1949.8249461434104,Train Acc:54.204868285123965%
Performance of Standard Forget Model on Forget Class
Accuracy: 16.467486321926117
Loss: 97.59504699707031
Performance of Standard Forget Model on Retain Class
Accuracy: 83.55525135993958
Loss: 0.6877520084381104
Train loss 1: 1456.506899432132,Train Acc:36.03472980349345%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.05859375232830644
Loss: 246.3173828125
Performance of Standard Forget Model on Retain Class
Accuracy: 56.55097961425781
Loss: 12.116887092590332
Train loss 1: 3515.2703354704763,Train Acc:38.76548099078341%
Performance of Standard Forget Model on Forget Class
Accuracy: 27.466005086898804
Loss: 88.1341323852539
Performance o

 75%|██████████████████████████████████████████████████▎                | 6/8 [06:40<02:13, 66.80s/it]

Train loss 1: 3734.4291566418065,Train Acc:40.07776497695853%
Performance of Standard Forget Model on Forget Class
Accuracy: 19.539988040924072
Loss: 1699.9691162109375
Performance of Standard Forget Model on Retain Class
Accuracy: 0.022321428696159273
Loss: 1270.7373046875
Train loss 1: 2527.048370361328,Train Acc:34.28622159090909%
Performance of Standard Forget Model on Forget Class
Accuracy: 7.425032556056976
Loss: 195.05662536621094
Performance of Standard Forget Model on Retain Class
Accuracy: 83.64351391792297
Loss: 29.395435333251953
Train loss 1: 1681.01779164766,Train Acc:39.32052258101586%
Performance of Standard Forget Model on Forget Class
Accuracy: 16.94064289331436
Loss: 375.10626220703125
Performance of Standard Forget Model on Retain Class
Accuracy: 3.389238566160202
Loss: 64.957763671875
Train loss 1: 4220.201319696442,Train Acc:41.05882776497696%
Performance of Standard Forget Model on Forget Class
Accuracy: 19.61958110332489
Loss: 348.4623718261719
Performance of St

 88%|██████████████████████████████████████████████████████████▋        | 7/8 [07:43<01:05, 65.51s/it]

Train loss 1: 3970.5425999010763,Train Acc:37.5810051843318%
Performance of Standard Forget Model on Forget Class
Accuracy: 10.847626626491547
Loss: 337.9039001464844
Performance of Standard Forget Model on Retain Class
Accuracy: 5.82510270178318
Loss: 146.4703369140625
Train loss 1: 3226.011701410467,Train Acc:26.412383780991732%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.33203125931322575
Loss: 237.41514587402344
Performance of Standard Forget Model on Retain Class
Accuracy: 55.76603412628174
Loss: 27.063329696655273
Train loss 1: 1920.540557208814,Train Acc:36.77243737071937%
Performance of Standard Forget Model on Forget Class
Accuracy: 0.06835937383584678
Loss: 389.8408203125
Performance of Standard Forget Model on Retain Class
Accuracy: 49.36717450618744
Loss: 52.65946578979492
Train loss 1: 4612.472620279504,Train Acc:47.127016129032256%
Performance of Standard Forget Model on Forget Class
Accuracy: 18.28242540359497
Loss: 53.417877197265625
Performance of 

100%|███████████████████████████████████████████████████████████████████| 8/8 [08:47<00:00, 65.90s/it]


Best hyperparams: LR=11, grad clip=19, ratio=12 | Best retain acc: 88.00011873245239
CPU times: user 8min 27s, sys: 19.6 s, total: 8min 46s
Wall time: 8min 47s


<All keys matched successfully>

In [22]:
# Report accuracy and loss of the selected unlearned model on the forget and
# retain validation loaders, in that order.
for title, loader in (
    ("Performance of Standard Forget Model on Forget Class", forget_valid_dl),
    ("Performance of Standard Forget Model on Retain Class", retain_valid_dl),
):
    print(title)
    history = [evaluate_after_unlearning(unlearning_model, loader)]
    print("Accuracy: {}".format(history[0]["Acc"]*100))
    print("Loss: {}".format(history[0]["Loss"]))

Performance of Standard Forget Model on Forget Class
Accuracy: 0.0
Loss: 256.3982849121094
Performance of Standard Forget Model on Retain Class
Accuracy: 87.99634575843811
Loss: 11.496419906616211


In [23]:
# Report accuracy and loss of the selected unlearned model on the forget and
# retain test loaders, in that order.
for title, loader in (
    ("Test on Forget Class", forget_test_dl),
    ("Test on Retain Class", retain_test_dl),
):
    print(title)
    history = [evaluate_after_unlearning(unlearning_model, loader)]
    print("Accuracy: {}".format(history[0]["Acc"]*100))
    print("Loss: {}".format(history[0]["Loss"]))

Test on Forget Class
Accuracy: 0.01195790828205645
Loss: 258.3566589355469
Test on Retain Class
Accuracy: 87.88360953330994
Loss: 11.101592063903809
