From bda66c515311c7393593ef3376578e80de968801 Mon Sep 17 00:00:00 2001 From: Qasim Khan Date: Sat, 3 Jun 2023 23:27:38 +0500 Subject: [PATCH 01/10] Change model architecture Model architecture was not the same as that of the one in Basic MNIST Example, so it has been changed to be the exact same --- beginner_source/fgsm_tutorial.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/beginner_source/fgsm_tutorial.py b/beginner_source/fgsm_tutorial.py index fa23680496c..b730e55a419 100644 --- a/beginner_source/fgsm_tutorial.py +++ b/beginner_source/fgsm_tutorial.py @@ -160,20 +160,27 @@ class Net(nn.Module): def __init__(self): super(Net, self).__init__() - self.conv1 = nn.Conv2d(1, 10, kernel_size=5) - self.conv2 = nn.Conv2d(10, 20, kernel_size=5) - self.conv2_drop = nn.Dropout2d() - self.fc1 = nn.Linear(320, 50) - self.fc2 = nn.Linear(50, 10) + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.dropout1 = nn.Dropout(0.25) + self.dropout2 = nn.Dropout(0.5) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) def forward(self, x): - x = F.relu(F.max_pool2d(self.conv1(x), 2)) - x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) - x = x.view(-1, 320) - x = F.relu(self.fc1(x)) - x = F.dropout(x, training=self.training) + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) x = self.fc2(x) - return F.log_softmax(x, dim=1) + output = F.log_softmax(x, dim=1) + return output # MNIST Test dataset and dataloader declaration test_loader = torch.utils.data.DataLoader( From 8c0cdec60fc3048cfd6039b3effe243a2af79a22 Mon Sep 17 00:00:00 2001 From: Qasim Khan Date: Sat, 3 Jun 2023 23:29:10 +0500 Subject: [PATCH 02/10] Add normalization transform in dataloader The model is trained on normalized data, so it is unfair to use unnormalized data in this example. --- beginner_source/fgsm_tutorial.py | 1 + 1 file changed, 1 insertion(+) diff --git a/beginner_source/fgsm_tutorial.py b/beginner_source/fgsm_tutorial.py index b730e55a419..32e3d4473bb 100644 --- a/beginner_source/fgsm_tutorial.py +++ b/beginner_source/fgsm_tutorial.py @@ -186,6 +186,7 @@ def forward(self, x): test_loader = torch.utils.data.DataLoader( datasets.MNIST('../data', train=False, download=True, transform=transforms.Compose([ transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) ])), batch_size=1, shuffle=True) From 35305b1dad81a784858787e21fed19fe97f712e8 Mon Sep 17 00:00:00 2001 From: Qasim Khan Date: Sat, 3 Jun 2023 23:35:47 +0500 Subject: [PATCH 03/10] Add denormalization code The MNIST model is trained with normalized data but no normalization was applied in this tutorial. Thus, a denorm function is created, which is called to denorm the data before performing FGSM. The perturbed data is again normalized before feeding it to the model. --- beginner_source/fgsm_tutorial.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/beginner_source/fgsm_tutorial.py b/beginner_source/fgsm_tutorial.py index 32e3d4473bb..90ce158071f 100644 --- a/beginner_source/fgsm_tutorial.py +++ b/beginner_source/fgsm_tutorial.py @@ -233,6 +233,27 @@ def fgsm_attack(image, epsilon, data_grad): # Return the perturbed image return perturbed_image +# denormalize the tensors before performing an FGSM attack +# because FGSM only works with the original unnormalized image +def denorm(batch, mean=[0.1307], std=[0.3081]): + """ + Denormalizes a batch of tensors. + + Args: + batch (torch.Tensor): Batch of normalized tensors. + mean (torch.Tensor or list): Mean used for normalization. + std (torch.Tensor or list): Standard deviation used for normalization. + + Returns: + torch.Tensor: Denormalized batch of tensors. + """ + if isinstance(mean, list): + mean = torch.tensor(mean).to(device) + if isinstance(std, list): + std = torch.tensor(std).to(device) + + return batch * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1) + ###################################################################### # Testing Function @@ -287,11 +308,17 @@ def test( model, device, test_loader, epsilon ): # Collect ``datagrad`` data_grad = data.grad.data + # Denormalize the data + data_denorm = denorm(data) + # Call FGSM Attack - perturbed_data = fgsm_attack(data, epsilon, data_grad) + perturbed_data = fgsm_attack(data_denorm, epsilon, data_grad) + + # Reapply normalization + perturbed_data_normalized = transforms.Normalize((0.1307,), (0.3081,))(perturbed_data) # Re-classify the perturbed image - output = model(perturbed_data) + output = model(perturbed_data_normalized) # Check for success final_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability From 538fe9f5cf3a9b8bd44afe926ec465c09745e85f Mon Sep 17 00:00:00 2001 From: Qasim Khan Date: Sat, 3 Jun 2023 23:36:34 +0500 Subject: [PATCH 04/10] fix formatting in test --- beginner_source/fgsm_tutorial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/beginner_source/fgsm_tutorial.py b/beginner_source/fgsm_tutorial.py index 90ce158071f..f41ee60c39b 100644 --- a/beginner_source/fgsm_tutorial.py +++ b/beginner_source/fgsm_tutorial.py @@ -308,7 +308,7 @@ def test( model, device, test_loader, epsilon ): # Collect ``datagrad`` data_grad = data.grad.data - # Denormalize the data + # Denormalize the data data_denorm = denorm(data) # Call FGSM Attack From 3d2c2bad28e5963815601993d083bbde8bb08588 Mon Sep 17 00:00:00 2001 From: Qasim Khan Date: Sat, 3 Jun 2023 23:37:27 +0500 Subject: [PATCH 05/10] load state_dict on device instead of cpu --- beginner_source/fgsm_tutorial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/beginner_source/fgsm_tutorial.py b/beginner_source/fgsm_tutorial.py index f41ee60c39b..2cfd682a234 100644 --- a/beginner_source/fgsm_tutorial.py +++ b/beginner_source/fgsm_tutorial.py @@ -198,7 +198,7 @@ def forward(self, x): model = Net().to(device) # Load the pretrained model -model.load_state_dict(torch.load(pretrained_model, map_location='cpu')) +model.load_state_dict(torch.load(pretrained_model, map_location=device)) # Set the model in evaluation mode. In this case this is for the Dropout layers model.eval() From a9ff79dae9554604356693918093efe04bb01eaf Mon Sep 17 00:00:00 2001 From: Qasim Khan Date: Sun, 4 Jun 2023 00:12:18 +0500 Subject: [PATCH 06/10] Fix spellings --- beginner_source/fgsm_tutorial.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/beginner_source/fgsm_tutorial.py b/beginner_source/fgsm_tutorial.py index 2cfd682a234..b58137ad738 100644 --- a/beginner_source/fgsm_tutorial.py +++ b/beginner_source/fgsm_tutorial.py @@ -233,11 +233,10 @@ def fgsm_attack(image, epsilon, data_grad): # Return the perturbed image return perturbed_image -# denormalize the tensors before performing an FGSM attack -# because FGSM only works with the original unnormalized image +# restores the tensors to their original scale def denorm(batch, mean=[0.1307], std=[0.3081]): """ - Denormalizes a batch of tensors. + Convert a batch of tensors to their original scale. Args: batch (torch.Tensor): Batch of normalized tensors. @@ -245,7 +244,7 @@ def denorm(batch, mean=[0.1307], std=[0.3081]): std (torch.Tensor or list): Standard deviation used for normalization. Returns: - torch.Tensor: Denormalized batch of tensors. + torch.Tensor: batch of tensors without normalization applied to them. """ if isinstance(mean, list): mean = torch.tensor(mean).to(device) @@ -308,7 +307,7 @@ def test( model, device, test_loader, epsilon ): # Collect ``datagrad`` data_grad = data.grad.data - # Denormalize the data + # Restore the data to its original scale data_denorm = denorm(data) # Call FGSM Attack From b13f072b4404339ee015c12c4a102e683186c73c Mon Sep 17 00:00:00 2001 From: Qasim Khan Date: Fri, 9 Jun 2023 14:49:41 +0500 Subject: [PATCH 07/10] remove deprecated argument --- beginner_source/fgsm_tutorial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/beginner_source/fgsm_tutorial.py b/beginner_source/fgsm_tutorial.py index d5aafbe2bca..3c9ce88a9e8 100644 --- a/beginner_source/fgsm_tutorial.py +++ b/beginner_source/fgsm_tutorial.py @@ -192,7 +192,7 @@ def forward(self, x): model = Net().to(device) # Load the pretrained model -model.load_state_dict(torch.load(pretrained_model, weights_only=True, map_location=device)) +model.load_state_dict(torch.load(pretrained_model, map_location=device)) # Set the model in evaluation mode. In this case this is for the Dropout layers model.eval() From c9f1d2ff8b3b7bbdcab966a4078c4ce9cc94c3c4 Mon Sep 17 00:00:00 2001 From: Qasim Khan Date: Fri, 9 Jun 2023 14:52:17 +0500 Subject: [PATCH 08/10] Update model weights link --- beginner_source/fgsm_tutorial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/beginner_source/fgsm_tutorial.py b/beginner_source/fgsm_tutorial.py index 3c9ce88a9e8..2629f3a7a00 100644 --- a/beginner_source/fgsm_tutorial.py +++ b/beginner_source/fgsm_tutorial.py @@ -123,7 +123,7 @@ # - ``pretrained_model`` - path to the pretrained MNIST model which was # trained with # `pytorch/examples/mnist `__. -# For simplicity, download the pretrained model `here `__. +# For simplicity, download the pretrained model `here `__. # # - ``use_cuda`` - boolean flag to use CUDA if desired and available. # Note, a GPU with CUDA is not critical for this tutorial as a CPU will From a341c00f6d70442cfc8695bc2d1a9e297579d9c8 Mon Sep 17 00:00:00 2001 From: Qasim Khan Date: Fri, 9 Jun 2023 14:59:29 +0500 Subject: [PATCH 09/10] Update location of fgsm weights --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index ed0ade00465..70ee12dc8b5 100644 --- a/Makefile +++ b/Makefile @@ -82,7 +82,7 @@ download: tar $(TAROPTS) -xzf $(DATADIR)/UrbanSound8K.tar.gz -C ./beginner_source/data/ # Download model for beginner_source/fgsm_tutorial.py - wget -nv -N https://s3.amazonaws.com/pytorch-tutorial-assets/lenet_mnist_model.pth -P $(DATADIR) + wget -nv -N 'https://docs.google.com/uc?export=download&id=1HJV2nUHJqclXQ8flKvcWmjZ-OU5DGatl' -O lenet_mnist_model.pth -P $(DATADIR) cp $(DATADIR)/lenet_mnist_model.pth ./beginner_source/data/lenet_mnist_model.pth # Download model for advanced_source/dynamic_quantization_tutorial.py From 29c9c93ac14da2c88cc66363e8defe439520a669 Mon Sep 17 00:00:00 2001 From: Qasim Khan Date: Fri, 9 Jun 2023 15:28:50 +0500 Subject: [PATCH 10/10] Update command to download fgsm MNIST weights --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 70ee12dc8b5..bbe76125ec5 100644 --- a/Makefile +++ b/Makefile @@ -82,7 +82,7 @@ download: tar $(TAROPTS) -xzf $(DATADIR)/UrbanSound8K.tar.gz -C ./beginner_source/data/ # Download model for beginner_source/fgsm_tutorial.py - wget -nv -N 'https://docs.google.com/uc?export=download&id=1HJV2nUHJqclXQ8flKvcWmjZ-OU5DGatl' -O lenet_mnist_model.pth -P $(DATADIR) + wget -nv -N 'https://docs.google.com/uc?export=download&id=1HJV2nUHJqclXQ8flKvcWmjZ-OU5DGatl' -O $(DATADIR)/lenet_mnist_model.pth cp $(DATADIR)/lenet_mnist_model.pth ./beginner_source/data/lenet_mnist_model.pth # Download model for advanced_source/dynamic_quantization_tutorial.py