In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import models, datasets, transforms
import numpy as np
from models.vision import LeNetMnist, weights_init, LeNet
from utils import label_to_onehot, cross_entropy_for_onehot
from tqdm import tqdm
from scipy.optimize import linear_sum_assignment
import lpips
import os

In [None]:
def mse(gt_data, reconstructed_data):
    """Mean squared error taken over every element of the two inputs (no clamping)."""
    squared_error = (gt_data - reconstructed_data) ** 2
    return squared_error.mean()

def mse_after_projection(gt_data, reconstructed_data):
    """Mean squared error after projecting the reconstruction onto the range [0, 1]."""
    projected = reconstructed_data.clamp(min=0, max=1)
    return ((gt_data - projected) ** 2).mean()

def psnr(gt_data, reconstructed_data, max_val=1):
    """Mean per-sample peak signal-to-noise ratio (dB) between two batches.

    The reconstruction is first projected onto the valid pixel range
    ``[0, max_val]``. Each sample (first dimension) is then flattened and its
    MSE computed, and ``20 * log10(max_val / sqrt(mse))`` is averaged over
    samples.

    Args:
        gt_data: ground-truth batch; the first dimension indexes samples.
        reconstructed_data: reconstruction with the same number of elements per sample.
        max_val: peak pixel value (1 for [0, 1] images, 255 for 8-bit images).

    Returns:
        0-d tensor holding the mean PSNR. A sample whose reconstruction is exact
        has MSE 0 and contributes ``inf``.
    """
    # Clamp to the same peak used in the formula; a fixed max of 1 would
    # corrupt the score for any max_val other than 1.
    reconstructed_data = torch.clamp(reconstructed_data, min=0, max=max_val)
    n = len(gt_data)
    # reshape (not view) so non-contiguous inputs such as permuted tensors work.
    per_sample_mse = ((gt_data.reshape(n, -1) - reconstructed_data.reshape(n, -1)) ** 2).mean(1)
    return (20 * torch.log10(max_val / torch.sqrt(per_sample_mse))).mean()

# LPIPS perceptual metric with a VGG backbone, constructed once at module level
# and reused by lpips_score below.
loss_fn_vgg = lpips.LPIPS(net="vgg")
def lpips_score(gt_data, reconstructed_data, arr=False, normalize=True):
    """LPIPS (VGG) perceptual distance between two image batches.

    Args:
        gt_data: ground-truth images; assumed (N, 3, H, W) with values in [0, 1],
            matching the [0, 1] clamping used by the other metrics here — TODO confirm.
        reconstructed_data: reconstructed images with the same layout as ``gt_data``.
        arr: if True, return the per-sample distances as produced by
            ``loss_fn_vgg``; otherwise return their mean.
        normalize: forwarded to LPIPS. When True, LPIPS rescales [0, 1] inputs to
            [-1, 1], the range its network expects. Pass False only for inputs
            that are already in [-1, 1].

    Returns:
        Tensor of per-sample distances (``arr=True``) or a 0-d tensor mean.
    """
    scores = loss_fn_vgg(gt_data, reconstructed_data, normalize=normalize)
    return scores if arr else scores.mean()
    
tt = transforms.ToPILImage()


def _save_image_row(imgs, ids, num, img_shape, path):
    """Draw ``imgs[ids]`` left-to-right on a 1 x ``num`` grid of axes and save it to ``path``."""
    # squeeze=False keeps a 2-D Axes array even when num == 1 (a bare Axes is not indexable).
    fig, axs = plt.subplots(1, num, figsize=(8, 6 * num), squeeze=False)
    for j, idx in enumerate(ids):
        # ToPILImage multiplies floats by 255 and casts to uint8 without clamping,
        # so out-of-range pixels would wrap around; project onto [0, 1] first.
        img = imgs[idx].detach().reshape(*img_shape).clamp(0, 1).cpu()
        axs[0, j].imshow(tt(img))
        axs[0, j].axis('off')
    fig.savefig(path, bbox_inches='tight')


def show_examples(gt_imgs, pred_imgs, loss, leak_mode, num=10, bias=6, name="main", img_shape=(3, 32, 32)):
    """Save side-by-side PDFs of randomly chosen ground-truth and reconstructed images.

    A fixed (seed 0) random permutation of the sample indices is drawn, and the
    ``num`` indices starting at offset ``bias`` are used for both figures, so the
    two PDFs show the same samples in the same order.

    Args:
        gt_imgs: ground-truth images, indexable by sample; each reshaped to ``img_shape``.
        pred_imgs: reconstructed images aligned with ``gt_imgs``.
        loss: per-sample scores. Only its length is used, to size the pool of
            candidate indices.
        leak_mode: inserted into the reconstruction PDF's file name.
        num: number of images per row.
        bias: offset into the random permutation.
        name: suffix for both output file names.
        img_shape: shape each image is reshaped to before conversion to PIL.

    Writes ``checkpoint/our_gt_{name}.pdf`` and
    ``checkpoint/our_{leak_mode}_{name}.pdf``.
    """
    # A private generator yields the seed-0 permutation without reseeding
    # torch's global RNG for the rest of the notebook.
    gen = torch.Generator().manual_seed(0)
    random_id_list = torch.randperm(len(loss), generator=gen)[bias:num + bias]
    os.makedirs("checkpoint", exist_ok=True)
    _save_image_row(gt_imgs, random_id_list, num, img_shape, f"checkpoint/our_gt_{name}.pdf")
    _save_image_row(pred_imgs, random_id_list, num, img_shape, f"checkpoint/our_{leak_mode}_{name}.pdf")

In [None]:
# Experiment settings identifying which saved reconstruction to evaluate.
shared_model = "LeNet"
name = "LeNet_batch1"
seed = 0
lr = 1e-4
epochs = 200
model = "MLP-3000"
batch_size = 256
dataset = "CIFAR10"
leak_mode = None

# File-name fields joined by "_"; str() gives the same text as f-string formatting here.
# NOTE(review): the literal 0 may correspond to `seed` — confirm before substituting it.
checkpoint_name = "checkpoint/" + "_".join(
    map(str, (dataset, shared_model, model, leak_mode, lr, epochs, batch_size, 0, "version1"))
) + ".pt"

In [None]:
# Load on CPU so the reconstruction lines up with the CPU ground truth even if
# the checkpoint was written from a GPU run.
checkpoint = torch.load(checkpoint_name, map_location="cpu")
reconstructed_data = checkpoint["reconstructed_imgs"].view(-1, 3, 32, 32)
gt_data = torch.cat(checkpoint["gt_data"])
gt_data = gt_data.view(-1, 3, 32, 32).cpu()
# Only the first len(reconstructed_data) ground-truth images have a reconstruction;
# every metric below compares against this aligned slice.
gt_eval = gt_data[:len(reconstructed_data)]
with torch.no_grad():  # evaluation only; don't build autograd graphs (e.g. through LPIPS)
    res = [evaluate(gt_eval, reconstructed_data).item()
           for evaluate in (mse_after_projection, psnr, lpips_score)]
    # Per-sample MSE after projecting onto [0, 1]; passed to show_examples.
    mse_arr = ((gt_eval - torch.clamp(reconstructed_data, min=0, max=1)).reshape(len(gt_eval), -1) ** 2).mean(1)
show_examples(gt_eval, reconstructed_data, mse_arr, leak_mode, num=4, bias=0)
print(leak_mode, res)