In [None]:
%matplotlib inline

import numpy as np
from pprint import pprint

from PIL import Image
import matplotlib.pyplot as plt
from torchvision.datasets.folder import ImageFolder
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
import torchvision
from torchvision import models, datasets, transforms
import torchvision.transforms.functional as FF
from pytorch_msssim import ssim
torch.manual_seed(50)
import os
import time
print(torch.__version__, torchvision.__version__)

In [None]:
# dataset = ImageFolder(r'C:\Users\badha\OneDrive - Florida International University\Desktop\PhD at FIU\Solid Lab\Spring 2023\CPL Attack Paper Exp\skin11\melanoma_cancer_dataset\test')
# dataset

In [None]:
# Test split of the melanoma dataset, loaded as an ImageFolder (one sub-directory
# per class). Set the MELANOMA_TEST_DIR environment variable to point at another
# copy instead of editing this machine-specific default path.
DATA_DIR = os.environ.get(
    "MELANOMA_TEST_DIR",
    r'C:\Users\badha\OneDrive - Florida International University\Desktop\PhD at FIU\Solid Lab\Spring 2023\CPL Attack Paper Exp\skin11\melanoma_cancer_dataset\test',
)
dst = ImageFolder(DATA_DIR)

# PIL image -> float tensor: shorter side resized to 32, center-cropped to 32x32,
# then ToTensor (which scales 8-bit pixels into [0, 1]).
tp = transforms.Compose([
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToTensor()
])
# Inverse direction: tensor -> PIL image (used to materialise reconstructions).
tt = transforms.ToPILImage()

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on %s" % device)

def label_to_onehot(target, num_classes=2):
    """Encode integer class labels as a dense one-hot float matrix.

    Args:
        target: 1-D int64 tensor of class indices, shape (N,).
        num_classes: width of the encoding (2 for the benign/malignant task).

    Returns:
        Float tensor of shape (N, num_classes) on ``target.device`` holding 1 in
        each row's class column and 0 everywhere else.
    """
    index = target.unsqueeze(1)  # (N,) -> (N, 1): one index per row for scatter_
    encoded = torch.zeros(index.size(0), num_classes, device=index.device)
    return encoded.scatter_(1, index, 1)

def cross_entropy_for_onehot(pred, target):
    """Mean cross-entropy between logits and a per-row target distribution.

    Unlike ``F.cross_entropy`` with class indices, ``target`` is a full
    distribution over classes. It may be a hard one-hot row or a soft one such
    as ``softmax(dummy_label)``, which is what lets the attack optimise the label.

    Args:
        pred: logits, shape (N, C).
        target: target probabilities, shape (N, C).

    Returns:
        0-d tensor: per-row ``-sum(target * log_softmax(pred))`` averaged over N.
    """
    log_probs = F.log_softmax(pred, dim=-1)
    per_sample = (-target * log_probs).sum(dim=1)
    return per_sample.mean()

In [None]:
def weights_init(m):
    """Re-initialise a module's own weight and bias uniformly in [-0.5, 0.5].

    Meant for ``net.apply(weights_init)``, which calls it on every submodule.
    Weight is drawn before bias, so with a fixed seed the draws are reproducible.

    Args:
        m: any ``nn.Module``. Modules without a ``weight``/``bias`` attribute
            (e.g. activations) are left untouched.
    """
    # getattr(..., None) rather than hasattr: layers built with bias=False (or
    # non-affine norm layers) register the attribute as None, and calling
    # `.data` on None would raise AttributeError.
    weight = getattr(m, "weight", None)
    if weight is not None:
        weight.data.uniform_(-0.5, 0.5)
    bias = getattr(m, "bias", None)
    if bias is not None:
        bias.data.uniform_(-0.5, 0.5)

class LeNet(nn.Module):
    """Small sigmoid CNN used as the shared (victim) model in the attack.

    Body: four 5x5 convolutions with 12 output maps and padding 2, the first two
    with stride 2 and the last two with stride 1, each followed by a sigmoid.
    Head: a single linear layer producing 2 logits. The 768-wide flatten equals
    12 * 8 * 8, so the model assumes 32x32 inputs: two stride-2 convs halve
    32 -> 16 -> 8.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, stride) per conv layer, in order.
        conv_specs = [(3, 2), (12, 2), (12, 1), (12, 1)]
        layers = []
        for in_ch, stride in conv_specs:
            layers.append(nn.Conv2d(in_ch, 12, kernel_size=5, padding=2, stride=stride))
            layers.append(nn.Sigmoid())
        self.body = nn.Sequential(*layers)
        self.fc = nn.Sequential(nn.Linear(768, 2))

    def forward(self, x):
        features = self.body(x)
        flat = features.view(features.size(0), -1)  # (N, 12, 8, 8) -> (N, 768)
        return self.fc(flat)

# Victim model shared by the honest client and the attacker: LeNet moved to the
# selected device, then every weight/bias re-drawn by weights_init.
# (Module.apply returns the module itself, so the calls can be chained.)
net = LeNet().to(device).apply(weights_init)
# Soft-target cross-entropy, so the attacker's dummy label can be optimised too.
criterion = cross_entropy_for_onehot

In [None]:
# directory=['recovered_image_skin_2024exp/With SGD Optimizer/0','recovered_image_skin_2024exp/With SGD Optimizer/benign','recovered_image_skin_2024exp/With SGD Optimizer/malignant']
# directory

In [None]:
# Number of images in the test split; the attack cell below visits every 10th one.
len(dst)

In [None]:
# Deep Leakage from Gradients (DLG) attack on every SAMPLE_STRIDE-th test image:
# the honest client computes a gradient on one real image; the attacker then
# optimises random dummy data and a dummy label with L-BFGS until their gradient
# matches the shared one. Each reconstruction is scored with SSIM and MSE.
N_ITERS = 300        # L-BFGS steps per image
SAMPLE_STRIDE = 10   # attack every 10th image of `dst`
LOG_EVERY = 100      # print the gradient-matching loss every LOG_EVERY steps
SSIM_SUCCESS = 0.90  # a reconstruction with SSIM >= this counts as a success

AS = 0           # number of successful reconstructions
whole_time = 0   # summed attack time over all images, in seconds
sSim = []        # per-image SSIM values (0-d tensors)
mSe = []         # per-image MSE values (0-d tensors)
count = 0        # number of images attacked so far
for ii in range(0, len(dst), SAMPLE_STRIDE):
    ######### honest participant #########
    print("Image: ", count+1)

    img_index = ii
    # Decode the sample once; it is reused for the gradient and for scoring.
    gt_pil, gt_class = dst[img_index]
    gt_data = tp(gt_pil).to(device).unsqueeze(0)  # (1, C, 32, 32)
    gt_label = torch.tensor([gt_class], dtype=torch.long, device=device)
    gt_onehot_label = label_to_onehot(gt_label, num_classes=2)

    # compute original gradient
    out = net(gt_data)
    y = criterion(out, gt_onehot_label)
    dy_dx = torch.autograd.grad(y, net.parameters())

    # share the gradients with other clients (this is all the attacker sees)
    original_dy_dx = [g.detach().clone() for g in dy_dx]

    # NOTE(review): time.process_time() is CPU time of the whole process, not
    # wall-clock; consider time.perf_counter() if wall-clock is what is reported.
    start = time.process_time()
    # generate dummy data and label
    dummy_data = torch.randn(gt_data.size()).to(device).requires_grad_(True)
    dummy_label = torch.randn(gt_onehot_label.size()).to(device).requires_grad_(True)
    optimizer = torch.optim.LBFGS([dummy_data, dummy_label])

    def closure():
        """Back-propagate and return the squared L2 gradient-matching loss."""
        optimizer.zero_grad()
        pred = net(dummy_data)
        dummy_onehot_label = F.softmax(dummy_label, dim=-1)
        dummy_loss = criterion(pred, dummy_onehot_label)
        # create_graph=True so the gradient itself can be differentiated below.
        dummy_dy_dx = torch.autograd.grad(dummy_loss, net.parameters(), create_graph=True)
        grad_diff = sum(((gx - gy) ** 2).sum() for gx, gy in zip(dummy_dy_dx, original_dy_dx))
        grad_diff.backward()
        return grad_diff

    for iters in range(N_ITERS):
        optimizer.step(closure)
        if iters % LOG_EVERY == 0:
            current_loss = closure()
            print(iters, "%.4f" % current_loss.item())
    end = time.process_time()
    total_time = end - start
    whole_time = whole_time + total_time

    # Final reconstruction as an 8-bit image, converted once and outside the
    # timed region so total_time measures only the attack. Clamp first:
    # ToPILImage scales float tensors by 255 and casts to uint8 without
    # saturating, so pixels outside [0, 1] would overflow into unrelated
    # intensities and distort the SSIM/MSE below.
    recovered_img = tt(dummy_data[0].detach().clamp(0, 1).cpu())

    # Both images as (1, C, H, W) float tensors in [0, 1] on the CPU.
    img1_tensor = gt_data.cpu()
    img2_tensor = FF.to_tensor(recovered_img).unsqueeze(0)

    # Calculate SSIM between the two images
    ssim_value = ssim(img1_tensor, img2_tensor, data_range=1.0, size_average=True)
    sSim.append(ssim_value)
    print(f"SSIM: {ssim_value:.4f}")
    mse = F.mse_loss(img1_tensor, img2_tensor)
    mSe.append(mse)
    print(f"MSE: {mse:.8f}")
    print("Time Needed: ",total_time)
    print()

    if ssim_value >= SSIM_SUCCESS:
        AS = AS + 1
    count = count + 1




In [None]:
# Per-image SSIM between ground truth and reconstruction, one value per line.
for score in sSim:
    print(score.item())

In [None]:
# Per-image MSE between ground truth and reconstruction, one value per line.
for err in mSe:
    print(err.item())

In [None]:
# Summary over all attacked images. `count` is incremented once per attacked
# image, so it already equals the number of images and is the correct divisor
# (dividing by count + 1 under-reported the success rate and the average time).
n_images = count
print("Count", n_images)
if n_images:
    print("Attack Success Rate: ", AS / n_images)
    print("Avg. SSIM: ", np.mean([float(s) for s in sSim]))
    print("Avg. MSE: ", np.mean([float(m) for m in mSe]))
    print("Avg. Time:", whole_time / n_images)
else:
    print("No images were attacked; nothing to average.")