In [3]:
import privacyraven as pr

from privacyraven.utils.data import get_emnist_data
from privacyraven.extraction.core import ModelExtractionAttack
from privacyraven.utils.query import get_target
from privacyraven.models.victim import train_mnist_victim
from privacyraven.models.pytorch import ImagenetTransferLearning, ThreeLayerClassifier
from pl_bolts.callbacks import PrintTableMetricsCallback

callback = PrintTableMetricsCallback()

# Create a query function for a target PyTorch Lightning model
model = train_mnist_victim()


def query_mnist(input_data):
    # PrivacyRaven provides built-in query functions
    return get_target(model, input_data, (1, 1, 28, 28))

# Obtain seed (or public) data to be used in extraction
emnist_train, emnist_test = get_emnist_data()

# Run a model extraction attack
attack = ModelExtractionAttack(
    query=query_mnist,
    query_limit=100,
    victim_input_shape=(1, 1, 28, 28),
    victim_output_targets=10,
    substitute_input_shape=(1, 3, 28, 28),
    synthesizer="copycat",
    substitute_model_arch=ImagenetTransferLearning,
    substitute_input_size=1000,
    seed_data_train=emnist_train,
    seed_data_test=emnist_test,
    callback=callback
)



GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type   | Params
-----------------------------------
0 | layer_1 | Linear | 100 K 
1 | layer_2 | Linear | 33 K  
2 | layer_3 | Linear | 2 K   


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_loss': tensor(0.0972, device='cuda:0'),
 'test_loss': tensor(0.0972, device='cuda:0')}
--------------------------------------------------------------------------------

torch.Size([50, 1, 28, 28])
torch.Size([50, 10])
torch.Size([50, 1, 28, 28])
torch.Size([50, 10])
Synthesis complete
Synthetic Data Generated


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type        | Params
--------------------------------------------------
0 | feature_extractor | MobileNetV2 | 3 M   
1 | layer_1           | Linear      | 128 K 
2 | layer_2           | Linear      | 33 K  
3 | layer_3           | Linear      | 2 K   


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

avg_val_loss│val_loss│train_loss│loss
─────────────────────────────────────
2.326599597930908│2.326599597930908│2.4186367988586426│2.4186367988586426


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

avg_val_loss│val_loss│train_loss│loss
─────────────────────────────────────
2.326599597930908│2.326599597930908│2.4186367988586426│2.4186367988586426
2.3566532135009766│2.3566532135009766│1.7656333446502686│1.7656333446502686


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

avg_val_loss│val_loss│train_loss│loss
─────────────────────────────────────
2.326599597930908│2.326599597930908│2.4186367988586426│2.4186367988586426
2.3566532135009766│2.3566532135009766│1.7656333446502686│1.7656333446502686
2.3874757289886475│2.3874757289886475│1.0166378021240234│1.0166378021240234


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

avg_val_loss│val_loss│train_loss│loss
─────────────────────────────────────
2.326599597930908│2.326599597930908│2.4186367988586426│2.4186367988586426
2.3566532135009766│2.3566532135009766│1.7656333446502686│1.7656333446502686
2.3874757289886475│2.3874757289886475│1.0166378021240234│1.0166378021240234
2.5135140419006348│2.5135140419006348│0.5877996683120728│0.5877996683120728


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

avg_val_loss│val_loss│train_loss│loss
─────────────────────────────────────
2.326599597930908│2.326599597930908│2.4186367988586426│2.4186367988586426
2.3566532135009766│2.3566532135009766│1.7656333446502686│1.7656333446502686
2.3874757289886475│2.3874757289886475│1.0166378021240234│1.0166378021240234
2.5135140419006348│2.5135140419006348│0.5877996683120728│0.5877996683120728
2.737476110458374│2.737476110458374│0.27622270584106445│0.27622270584106445


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

avg_val_loss│val_loss│train_loss│loss
─────────────────────────────────────
2.326599597930908│2.326599597930908│2.4186367988586426│2.4186367988586426
2.3566532135009766│2.3566532135009766│1.7656333446502686│1.7656333446502686
2.3874757289886475│2.3874757289886475│1.0166378021240234│1.0166378021240234
2.5135140419006348│2.5135140419006348│0.5877996683120728│0.5877996683120728
2.737476110458374│2.737476110458374│0.27622270584106445│0.27622270584106445
3.0919551849365234│3.0919551849365234│0.0814519077539444│0.0814519077539444


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

avg_val_loss│val_loss│train_loss│loss
─────────────────────────────────────
2.326599597930908│2.326599597930908│2.4186367988586426│2.4186367988586426
2.3566532135009766│2.3566532135009766│1.7656333446502686│1.7656333446502686
2.3874757289886475│2.3874757289886475│1.0166378021240234│1.0166378021240234
2.5135140419006348│2.5135140419006348│0.5877996683120728│0.5877996683120728
2.737476110458374│2.737476110458374│0.27622270584106445│0.27622270584106445
3.0919551849365234│3.0919551849365234│0.0814519077539444│0.0814519077539444
3.5008156299591064│3.5008156299591064│0.028751015663146973│0.028751015663146973


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

avg_val_loss│val_loss│train_loss│loss
─────────────────────────────────────
2.326599597930908│2.326599597930908│2.4186367988586426│2.4186367988586426
2.3566532135009766│2.3566532135009766│1.7656333446502686│1.7656333446502686
2.3874757289886475│2.3874757289886475│1.0166378021240234│1.0166378021240234
2.5135140419006348│2.5135140419006348│0.5877996683120728│0.5877996683120728
2.737476110458374│2.737476110458374│0.27622270584106445│0.27622270584106445
3.0919551849365234│3.0919551849365234│0.0814519077539444│0.0814519077539444
3.5008156299591064│3.5008156299591064│0.028751015663146973│0.028751015663146973
3.940340757369995│3.940340757369995│0.008133718743920326│0.008133718743920326


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

avg_val_loss│val_loss│train_loss│loss
─────────────────────────────────────
2.326599597930908│2.326599597930908│2.4186367988586426│2.4186367988586426
2.3566532135009766│2.3566532135009766│1.7656333446502686│1.7656333446502686
2.3874757289886475│2.3874757289886475│1.0166378021240234│1.0166378021240234
2.5135140419006348│2.5135140419006348│0.5877996683120728│0.5877996683120728
2.737476110458374│2.737476110458374│0.27622270584106445│0.27622270584106445
3.0919551849365234│3.0919551849365234│0.0814519077539444│0.0814519077539444
3.5008156299591064│3.5008156299591064│0.028751015663146973│0.028751015663146973
3.940340757369995│3.940340757369995│0.008133718743920326│0.008133718743920326
4.384536266326904│4.384536266326904│0.0037335376255214214│0.0037335376255214214


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

avg_val_loss│val_loss│train_loss│loss
─────────────────────────────────────
2.326599597930908│2.326599597930908│2.4186367988586426│2.4186367988586426
2.3566532135009766│2.3566532135009766│1.7656333446502686│1.7656333446502686
2.3874757289886475│2.3874757289886475│1.0166378021240234│1.0166378021240234
2.5135140419006348│2.5135140419006348│0.5877996683120728│0.5877996683120728
2.737476110458374│2.737476110458374│0.27622270584106445│0.27622270584106445
3.0919551849365234│3.0919551849365234│0.0814519077539444│0.0814519077539444
3.5008156299591064│3.5008156299591064│0.028751015663146973│0.028751015663146973
3.940340757369995│3.940340757369995│0.008133718743920326│0.008133718743920326
4.384536266326904│4.384536266326904│0.0037335376255214214│0.0037335376255214214
4.788750648498535│4.788750648498535│0.0017482512630522251│0.0017482512630522251





HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'avg_test_loss': tensor(4.3974, device='cuda:0'),
 'test_loss': tensor(4.3974, device='cuda:0')}
--------------------------------------------------------------------------------

torch.Size([100, 3, 28, 28])
torch.Size([100, 10])
torch.Size([100, 1, 28, 28])
torch.Size([100, 10])
Out of 100 data points, the models agreed upon 7.


In [14]:
from privacyraven.utils.query import query_model

print(emnist_train.data.size())

print("test")


a, b = query_model(model, emnist_train.data, (1, 1, 28, 28))

print("test")

print(a.size())
print(b.size())


print(b[0])
print(emnist_train.targets[0])

torch.Size([240000, 28, 28])
test
torch.Size([240000, 1, 28, 28])
torch.Size([240000, 10])
test
torch.Size([240000, 10])
torch.Size([240000])
tensor(2)
tensor(8)


In [16]:
import torch
x = torch.reshape(emnist_train.data[0], (1, 1, 28, 28))

print("test")


a, b = query_model(model, x, (1, 1, 28, 28))

print("test")

print(a.size())
print(b)

test
torch.Size([1, 1, 28, 28])
torch.Size([1, 10])
Single


TypeError: only integer tensors of a single element can be converted to an index

In [10]:
import torch
torch.argmax(a, dim=1, keepdim=True)

tensor([[6],
        [8],
        [8],
        [8],
        [4],
        [0],
        [0],
        [5],
        [0],
        [0],
        [8],
        [3],
        [4],
        [4],
        [3],
        [0],
        [0],
        [7],
        [0],
        [0],
        [0],
        [3],
        [0],
        [8],
        [0],
        [0],
        [1],
        [0]], device='cuda:0')

In [4]:
subs = attack.substitute_model

In [5]:
type(subs)

privacyraven.models.pytorch.ImagenetTransferLearning

In [8]:
def query_subs(input_data):
    return get_target(subs, input_data, (1, 3, 28, 28))


In [9]:
print(query_subs(emnist_train.data[0]))

tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 5, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
        9, 9, 9, 9])


In [10]:
print(emnist_train.targets[0])

tensor(8)


In [None]:
# Do single tensor queries work now? 