In [None]:
import torch
import torchvision

import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from PIL import Image
from pytorch_msssim import ssim
import torchvision.transforms as T
from PIL import Image
import datetime
import time
import os

In [None]:
# Report which accelerator the reconstruction will run on. get_device_name()
# raises on machines without CUDA, so the device queries are guarded and the
# cell also survives Restart & Run All on a CPU-only kernel.
cuda_available = torch.cuda.is_available()
print(cuda_available)
if cuda_available:
    print(torch.cuda.get_device_name())
    print(torch.cuda.memory_allocated())

# Choose variants here:

In [None]:
# Experiment variant: `trained_model` is passed as `pretrained=` when the
# torchvision model is built; `arch` is the architecture label embedded in the
# names of the saved PNG files.
trained_model, arch = True, 'ResNet18'

## System setup:

In [None]:
import inversefed
setup = inversefed.utils.system_startup()
defs = inversefed.training_strategy('conservative')

# NOTE(review): the Brain-Tumor-MRI folder is loaded through the 'ImageNet'
# branch of construct_dataloaders -- confirm that branch was adapted to this
# dataset and check which mean/std it normalizes with.
loss_fn, trainloader, validloader = inversefed.construct_dataloaders('ImageNet', defs,
                                                                     data_path=r'inversefed/data/Brain-Tumor-MRI-Dataset')

# Build the network named by `arch` (e.g. 'ResNet18' -> torchvision.models.resnet18)
# instead of hard-coding it, so the model and the `arch` label written into the
# output file names always agree.
model_builder = getattr(torchvision.models, arch.lower(), None)
if not callable(model_builder):
    raise ValueError(f"Unknown architecture {arch!r}: torchvision.models has no builder '{arch.lower()}'")
model = model_builder(pretrained=trained_model)
model.to(**setup)
model.eval();

In [None]:
# Per-channel mean (dm) and std (ds), reshaped with [:, None, None] so they
# broadcast over image tensors; used to undo the normalization when plotting,
# saving and scoring images, and handed to the reconstructor.
# NOTE(review): these are the CIFAR-10 constants, while the data comes through
# the 'ImageNet' loader branch and feeds an ImageNet ResNet -- confirm they match
# the normalization the data loader actually applies.
dm, ds = (torch.as_tensor(channel_stats, **setup)[:, None, None]
          for channel_stats in (inversefed.consts.cifar10_mean, inversefed.consts.cifar10_std))
def plot(tensor):
    """Display normalized image tensor(s) after mapping them back to [0, 1].

    Undoes the normalization with the module-level `ds` (std) and `dm` (mean).

    Args:
        tensor: batch of images indexed along dim 0; each image is permuted
            from (C, H, W) to (H, W, C) for imshow.

    Returns:
        For a single image, the AxesImage from ``plt.imshow`` (drawn into the
        current axes). For several images, None; they are drawn side by side in
        a new 1 x N figure.
    """
    tensor = tensor.clone().detach()  # work on a copy; the caller's tensor is untouched
    tensor.mul_(ds).add_(dm).clamp_(0, 1)
    if tensor.shape[0] == 1:
        return plt.imshow(tensor[0].permute(1, 2, 0).cpu())
    num_images = tensor.shape[0]
    # One row of N panels: 12 in wide and one panel tall, so roughly square
    # images fill the figure instead of floating in an N*12-inch-tall canvas.
    fig, axes = plt.subplots(1, num_images, figsize=(12, 12 / num_images))
    for ax, im in zip(axes, tensor):
        ax.imshow(im.permute(1, 2, 0).cpu())
        

# Reconstruct

In [None]:
# Number of training images available; the attack loop samples indices from this range.
num_train_images = len(trainloader.dataset)
num_train_images

### Prepare the output folder (the ground-truth gradient of each image is built inside the attack loop below)

In [None]:
# Folder that receives every ground-truth input and its reconstruction as PNG files.
output_folder = "Brain_MRI_input_and_recovered_images_for_Resnet_trained_100_img_v3"
# exist_ok=True replaces the racy exists()-then-create check and keeps the cell re-runnable.
os.makedirs(output_folder, exist_ok=True)

# Attack about 100 images (every 57th training image)

In [None]:
# Gradient-inversion attack over a strided subset of the training set. For each
# sampled image: save the input, compute the model's gradient on it, reconstruct
# an image from that gradient with inversefed.GradientReconstructor, score the
# reconstruction (MSE, PSNR, SSIM, feature MSE) and save it. The running totals
# feed the summary cells that follow.

SAMPLE_STRIDE = 57        # attack every 57th training image
ASR_SSIM_THRESHOLD = 0.9  # SSIM at/above this counts as a successful attack

# Reconstruction hyper-parameters; identical for every image, so built once.
config = dict(signed=True,
              boxed=True,
              cost_fn='sim',
              indices='def',
              weights='equal',
              lr=0.1,
              optim='adam',
              restarts=8,
              max_iterations=24000,
              total_variation=1e-1,
              init='randn',
              filter='none',
              lr_decay=True,
              scoring_choice='loss')


def _plot_activation_errors(data):
    """Plot the per-layer errors returned by inversefed.metrics.activation_errors.

    `data` holds 'se', 'mse' and 'sim' dicts whose keys are matched against
    layer-name fragments ('conv', 'conv1', 'conv2', 'bn'). The last three
    entries of each series are left out of the top-row plots. Returns the
    new figure.
    """
    fig, axes = plt.subplots(2, 4, sharey=False, figsize=(14, 8))
    axes[0, 0].semilogy(list(data['se'].values())[:-3])
    axes[0, 0].set_title('SE')
    axes[0, 1].semilogy(list(data['mse'].values())[:-3])
    axes[0, 1].set_title('MSE')
    axes[0, 2].plot(list(data['sim'].values())[:-3])
    axes[0, 2].set_title('Similarity')

    layer_mse = data['mse']
    axes[1, 0].semilogy([val for key, val in layer_mse.items() if 'conv' in key])
    axes[1, 0].set_title('MSE - conv layers')
    axes[1, 1].semilogy([val for key, val in layer_mse.items() if 'conv1' in key])
    axes[1, 1].semilogy([val for key, val in layer_mse.items() if 'conv2' in key])
    axes[1, 1].set_title('MSE - conv1 vs conv2 layers')
    axes[1, 2].plot([val for key, val in layer_mse.items() if 'bn' in key])
    axes[1, 2].set_title('MSE - bn layers')
    fig.suptitle('Error between layers')
    return fig


test_mse_final = 0
sSim_final = 0
whole_duration = 0
loss_value = 0
count = 0
ASR = 0
ssim_store = []
mse_store = []
for i in range(0, len(trainloader.dataset), SAMPLE_STRIDE):
    print("Image No.", count+1)
    print("For image : ", i)
    idx = i

    img, label = trainloader.dataset[idx]
    labels = torch.as_tensor((label,), device=setup['device'])
    ground_truth = img.to(**setup).unsqueeze(0)

    # Save the ground-truth input in [0, 1] pixel space.
    ground_truth_denormalized = torch.clamp(ground_truth * ds + dm, 0, 1)
    output_path = os.path.join(output_folder, f'{idx}_{arch}_Brain_MRI_trained_input_loop_test.png')
    torchvision.utils.save_image(ground_truth_denormalized, output_path)

    # The gradient the attack inverts: loss gradient w.r.t. all model parameters.
    model.zero_grad()
    target_loss, _, _ = loss_fn(model(ground_truth), labels)
    input_gradient = torch.autograd.grad(target_loss, model.parameters())
    input_gradient = [grad.detach() for grad in input_gradient]
    full_norm = torch.stack([g.norm() for g in input_gradient]).mean()
    print(f'Full gradient norm is {full_norm:e}.')

    start_time = time.time()
    rec_machine = inversefed.GradientReconstructor(model, (dm, ds), config, num_images=1)
    output, stats = rec_machine.reconstruct(input_gradient, labels, img_shape=(3, 224, 224))
    end_time = time.time()

    # MSE is measured on the normalized tensors; the running tensor sum is what
    # the summary cell reads with .item().
    test_mse = (output.detach() - ground_truth).pow(2).mean()
    mse_store.append(test_mse)
    test_mse_final = test_mse_final + test_mse

    # Metric-only forward passes: no autograd graph needed.
    with torch.no_grad():
        feat_mse = (model(output.detach()) - model(ground_truth)).pow(2).mean()
    test_psnr = inversefed.metrics.psnr(output, ground_truth, factor=1/ds)

    # pytorch_msssim.ssim expects non-negative images in [0, data_range] and
    # defaults to data_range=255; on normalized tensors that default compresses
    # SSIM towards 1 regardless of reconstruction quality (inflating ASR).
    # Score both images in [0, 1] pixel space instead.
    output_denormalized = torch.clamp(output.detach() * ds + dm, 0, 1)
    sSim = ssim(output_denormalized, ground_truth_denormalized, data_range=1.0).item()
    ssim_store.append(sSim)
    sSim_final = sSim_final + sSim

    loss_value = loss_value + stats['opt']

    # Fresh figure: plot() draws with plt.imshow into the *current* axes, which
    # would otherwise be the last panel of the previous iteration's error figure.
    plt.figure()
    plot(output)
    plt.title(f"Rec. loss: {stats['opt']:2.4f} | MSE: {test_mse:2.4f} | SSIM: {sSim:2.4f}"
              f"| PSNR: {test_psnr:4.2f} | FMSE: {feat_mse:2.4e} |")

    output_path = os.path.join(output_folder, f'{idx}_{arch}_Brain_MRI_trained_output_loop_test.png')
    torchvision.utils.save_image(output_denormalized, output_path)

    data = inversefed.metrics.activation_errors(model, output, ground_truth)
    _plot_activation_errors(data)
    # Render this image's figures now (each reconstruction is slow); with the
    # default inline backend this also closes them, so figures do not pile up.
    plt.show()

    duration = end_time - start_time
    whole_duration = whole_duration + duration
    print("Time to recover Image", i, ": ", duration)
    if sSim >= ASR_SSIM_THRESHOLD:
        ASR = ASR + 1
    print()
    count = count + 1

In [None]:
# Summary of the attack loop: attack success rate (ASR counts images whose SSIM
# reached the loop's success threshold) and per-image averages of the metrics.
if count == 0:
    # Nothing was attacked: dividing by count would raise, and test_mse_final is
    # still the int 0 (no .item()).
    print("No images were attacked; nothing to summarize.")
else:
    print("Total no. of Images: ",count)
    print("Attack Success Rate: ", ASR/count)
    print("Avg. MSE: ",test_mse_final.item()/count)
    print("Avg. SSIM: ",sSim_final/count)
    print("Avg. loss value: ",loss_value/count)
    print("Avg. Duration to recover: ",whole_duration/count)


In [None]:
# Share of reconstructions with SSIM >= 0.8, a looser bar than the 0.9 SSIM
# threshold the attack loop uses for the Attack Success Rate.
SSIM_REPORT_THRESHOLD = 0.8
print(ssim_store)
num_above_threshold = sum(1 for score in ssim_store if score >= SSIM_REPORT_THRESHOLD)
wwwww = num_above_threshold  # legacy name, kept in case other cells read it
if ssim_store:
    print(num_above_threshold / len(ssim_store))
else:
    print("ssim_store is empty; no SSIM scores to summarize.")