In [1]:
import torch
from nlp_fast_unlearning.utils import prepare_dbpedia, build_noisy_dl, ensure_deterministic, DEVICE, BATCH_SIZE
from nlp_fast_unlearning.baseline_model import TextClassificationModel

# Only the vocabulary (fourth return value) is needed here; the other three
# values returned in baseline-only mode are discarded.
_, _, _, dbpedia_vocab = prepare_dbpedia(for_baseline_only=True)

vocab_size = dbpedia_vocab.vocab_size

baseline_name = "DBpedia_baseline.pt"
unlearning_model = TextClassificationModel(vocab_size).to(DEVICE)
# map_location remaps the stored tensors onto DEVICE, so the checkpoint also
# loads on a machine that lacks the device it was originally saved from.
unlearning_model.load_state_dict(torch.load(baseline_name, map_location=DEVICE))

<All keys matched successfully>

In [2]:
%%time

# Prepare everything the unlearning run needs. With for_baseline_only=False the
# helper returns seven values, unpacked below. The recorded output shows it
# searching for "error maximizing noise" for each class in classes_to_forget;
# presumably that result is what ends up in `noisy_data` (confirm in
# nlp_fast_unlearning.utils).
(
    retain_samples,
    noisy_data,
    retain_valid_dl,
    forget_valid_dl,
    retain_test_dl,
    forget_test_dl,
    dbpedia_vocab,  # rebinds the vocabulary object created in the first cell
) = prepare_dbpedia(
    for_baseline_only=False,
    classes_to_forget=[1, 3],  # class ids the model should unlearn
    model=unlearning_model,  # baseline model loaded in the first cell
    retain_percentage=0.01,  # presumably the fraction of retain data sampled -- TODO confirm
    vocab_class=dbpedia_vocab,  # pass in the vocabulary built in the first cell
)

Searching for error maximizing noise for class  1
Got loss 222.51974487304688 for tensor([101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101], device='cuda:0')
Searching for error maximizing noise for class  3
Got loss 201.43650817871094 for tensor([101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 10

In [3]:
def accuracy(outputs, labels):
    """Return the fraction of rows whose highest-scoring column equals the label.

    Args:
        outputs: Tensor of scores; the prediction for each row is the index of
            its maximum along dim 1.
        labels: Tensor of target indices, compared element-wise with the
            predictions.

    Returns:
        A 0-dim tensor holding ``number_correct / number_of_predictions``.
    """
    predicted = outputs.max(dim=1)[1]
    num_correct = (predicted == labels).sum().item()
    return torch.tensor(num_correct / len(predicted))


def validation_step(model, batch):
    """Score a single batch and return its loss and accuracy.

    Args:
        model: Callable invoked as ``model(text, offsets)``; its output is
            passed to cross-entropy and to ``accuracy``.
        batch: A ``(labels, text, offsets)`` triple.

    Returns:
        dict with ``'Loss'`` (the detached cross-entropy loss tensor) and
        ``'Acc'`` (the tensor returned by ``accuracy``).
    """
    labels, text, offsets = batch
    logits = model(text, offsets)
    # Reached through the ``torch`` package rather than the ``F`` alias: that
    # alias is only imported in a later cell, so relying on it made this
    # function depend on cell execution order.
    loss = torch.nn.functional.cross_entropy(logits, labels)
    acc = accuracy(logits, labels)
    return {'Loss': loss.detach(), 'Acc': acc}


def validation_epoch_end(model, outputs):
    """Average the per-batch metrics collected over one evaluation pass.

    Args:
        model: Not used by this function.
        outputs: Sequence of dicts, each holding tensors under ``'Loss'`` and
            ``'Acc'``.

    Returns:
        dict with the unweighted mean of each metric as a Python float. Every
        batch counts equally, regardless of how many samples it contained.
    """
    def _mean_of(key):
        # Stack the per-batch scalars and reduce them to one Python float.
        return torch.stack([step[key] for step in outputs]).mean().item()

    return {'Loss': _mean_of('Loss'), 'Acc': _mean_of('Acc')}

In [4]:
import torch.nn.functional as F

def evaluate_after_unlearning(model, val_loader):
    """Evaluate ``model`` on every batch of ``val_loader`` without gradients.

    The model is switched to eval mode and is left in that mode on return.

    Args:
        model: Model handed to ``validation_step`` for each batch.
        val_loader: Iterable of batches.

    Returns:
        The dict produced by ``validation_epoch_end`` for the collected
        per-batch results.
    """
    with torch.no_grad():
        model.eval()
        step_results = []
        for batch in val_loader:
            step_results.append(validation_step(model, batch))
        return validation_epoch_end(model, step_results)

In [19]:
%%time
from tqdm import tqdm


lrs = [12]
clips = [22]
ratios = [1.5]

# The following search candidates take a few minutes to run through
# lrs = [7,8,9,10,11,12,13,14]
# clips = [17,18,19,20,21,22,23,24]
# ratios = [0.8,1,1.5]


best_retain_acc = 0
# Initialised up front so the summary below can tell "no candidate qualified"
# apart from a real result instead of failing with a NameError.
best_lr = best_clip = best_ratio = None

unlearned_model_name = "DBpedia_fast_unlearned.pt"

# Read the baseline checkpoint once. load_state_dict copies the values into
# each model's own parameters, so the same state dict can seed every candidate.
# map_location keeps the load working when the checkpoint was saved on a
# device that is not available here.
baseline_state = torch.load(baseline_name, map_location=DEVICE)

for lr in tqdm(lrs):
    for clip in clips:
        for ratio in ratios:
            ensure_deterministic()

            noisy_loader = build_noisy_dl(
                retain_samples,
                noisy_data,
                dbpedia_vocab,
                retain_to_forget_ratio=ratio,
            )
            # Every candidate starts from a fresh copy of the baseline weights.
            unlearning_model = TextClassificationModel(vocab_size).to(DEVICE)
            unlearning_model.load_state_dict(baseline_state)

            optimizer = torch.optim.SGD(unlearning_model.parameters(), lr=lr)

            for epoch in range(1):
                # evaluate_after_unlearning() leaves the model in eval mode, so
                # train mode is re-entered at the start of every epoch.
                unlearning_model.train(True)
                running_loss = 0.0
                running_acc = 0
                num_batches = len(noisy_loader)

                for i, data in enumerate(noisy_loader):
                    labels, inputs, offsets = data

                    optimizer.zero_grad()
                    outputs = unlearning_model(inputs, offsets)
                    loss = unlearning_model.criterion(outputs, labels)
                    loss.backward()
                    # Cap the gradient norm at `clip` before the SGD step.
                    torch.nn.utils.clip_grad_norm_(unlearning_model.parameters(), clip)
                    optimizer.step()

                    running_loss += loss.item()
                    out = torch.argmax(outputs.detach(), dim=1)
                    running_acc += (labels == out).sum().item() / labels.size(0)
                print(f"Train loss {epoch+1}: {running_loss/num_batches},Train Acc:{running_acc*100/num_batches}%")
                forget_acc = evaluate_after_unlearning(unlearning_model, forget_valid_dl)["Acc"]*100
                # Only candidates that score exactly 0% on forget_valid_dl are
                # eligible; among those, keep the best accuracy on retain_valid_dl.
                if forget_acc == 0.0:
                    retain_acc = evaluate_after_unlearning(unlearning_model, retain_valid_dl)["Acc"]*100
                    if retain_acc > best_retain_acc:
                        best_retain_acc = retain_acc
                        best_lr = lr
                        best_clip = clip
                        best_ratio = ratio
                        torch.save(unlearning_model.state_dict(), unlearned_model_name)

if best_lr is None:
    # Nothing was saved in this run; loading the file anyway would either fail
    # or silently pick up a stale checkpoint from an earlier run.
    raise RuntimeError(
        "No hyperparameter combination reached 0% forget accuracy with a "
        "non-zero retain accuracy; no unlearned model was saved."
    )

print(
    f"Best hyperparams: LR={best_lr}, grad clip={best_clip}, "
    f"ratio={best_ratio} | Best retain acc: {best_retain_acc}"
)
# Reload the best checkpoint so later cells evaluate the selected model rather
# than whichever candidate happened to be trained last.
unlearning_model = TextClassificationModel(vocab_size).to(DEVICE)
unlearning_model.load_state_dict(torch.load(unlearned_model_name, map_location=DEVICE))

  0%|                                                            | 0/1 [00:00<?, ?it/s]

Train loss 1: 7057.12420241038,Train Acc:45.16511041439477%


100%|████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.06s/it]


Best hyperparams: LR=12, grad clip=22, ratio=1.5 | Best retain acc: 97.17997312545776
CPU times: user 2.52 s, sys: 384 ms, total: 2.91 s
Wall time: 3.58 s


<All keys matched successfully>

In [9]:
# Report accuracy and loss of the unlearned model on forget_valid_dl, then on
# retain_valid_dl. `history` ends up holding the retain metrics, as before.
for split_title, split_loader in (
    ("Performance of Standard Forget Model on Forget Class", forget_valid_dl),
    ("Performance of Standard Forget Model on Retain Class", retain_valid_dl),
):
    print(split_title)
    history = [evaluate_after_unlearning(unlearning_model, split_loader)]
    print("Accuracy: {}".format(history[0]["Acc"]*100))
    print("Loss: {}".format(history[0]["Loss"]))

Performance of Standard Forget Model on Forget Class
Accuracy: 0.0
Loss: 65.2414321899414
Performance of Standard Forget Model on Retain Class
Accuracy: 97.18105792999268
Loss: 0.4106481671333313


In [10]:
# Report accuracy and loss of the unlearned model on forget_test_dl, then on
# retain_test_dl. `history` ends up holding the retain metrics, as before.
for split_title, split_loader in (
    ("Test on Forget Class", forget_test_dl),
    ("Test on Retain Class", retain_test_dl),
):
    print(split_title)
    history = [evaluate_after_unlearning(unlearning_model, split_loader)]
    print("Accuracy: {}".format(history[0]["Acc"]*100))
    print("Loss: {}".format(history[0]["Loss"]))

Test on Forget Class
Accuracy: 0.0
Loss: 64.73976135253906
Test on Retain Class
Accuracy: 96.82180881500244
Loss: 0.28862181305885315
