In [None]:
import numpy as np
from torch import fft
import torchvision.transforms as trans
from PIL import Image
from matplotlib import pyplot as plt
import seaborn as sns
from focal_frequency_loss import FocalFrequencyLoss as FFL
def plot_fft(image_path, size=224):
    """Load an image and return it with the magnitude of its centred 2-D FFT.

    Despite its name this function does not plot anything; it only prepares
    the arrays that the caller draws.

    Args:
        image_path: Path of the image file to open with PIL.
        size: Target side length (default 224). If either spatial dimension is
            larger than ``size`` the image is resized to ``(size, size)``; the
            aspect ratio is not preserved and smaller images are left untouched.

    Returns:
        ``(image, magnitude, image_tensor)`` where
          * ``image`` is the (possibly resized) tensor converted back to a PIL image,
          * ``magnitude`` is ``|sum_c fftshift(fft2(image_tensor))[c]|``, a real
            tensor of shape ``(H, W)`` with the zero frequency at the centre,
          * ``image_tensor`` is the ``(C, H, W)`` float tensor produced by
            ``ToTensor`` (after resizing, if any).
    """
    with Image.open(image_path) as pil_image:
        # ToTensor copies the pixel data, so the file handle can be released here
        # instead of leaking until garbage collection.
        image_tensor = trans.ToTensor()(pil_image)
    # Resize whenever a side exceeds the target. The former ``>= 244`` test (a
    # likely typo for 224) let 225-243 px images keep their own size, so their
    # spectra could not be subtracted from those of resized images.
    if image_tensor.shape[1] > size or image_tensor.shape[2] > size:
        image_tensor = trans.Resize((size, size))(image_tensor)
    # fft2 acts on the last two (spatial) dims. Shift only those: the default
    # dim=None would also roll the channel axis, to which the channel sum below
    # is insensitive, but that is not what is meant.
    spectrum = fft.fftshift(fft.fft2(image_tensor), dim=(-2, -1))
    # Channels are summed as complex values *before* taking the magnitude.
    magnitude = spectrum.sum(dim=0).abs()
    image = trans.ToPILImage()(image_tensor)
    return image, magnitude, image_tensor

In [None]:
# Compare every image against the reference image (the first entry of
# `file_names`) in the frequency domain. Each figure row has `column` panels:
#    1  the image itself
#    2  log1p(|FFT|) of the reference image
#    3  log1p(|FFT|) of this image
#    4  |difference| on a linear scale, titled with the Focal Frequency Loss
#    5  log1p(|difference|)
#    6  log1p of the positive part of the difference (image > reference)
#    7  log1p of the negative part of the difference (image < reference)
#    8  radial profile of the squared difference (all radii)
#    9  the same profile restricted to the largest inscribed circle
#   10  the inscribed profile with the first LOW_FREQ_SKIP bins dropped
# Log scaling is used on most heatmaps so that small differences stay visible.
file_names = ['img_9.png', '9-bpp.png', '9-sig.png', '9-badnet.png', '9-blended.png', '9-trojann.png', '9-ssba.png', '9-inputAware.png', '9-wanet.png']
length = len(file_names)
column = 10
# Radial bins dropped in the last panel; presumably so the large low-radius
# values do not dominate the y-axis scale.
LOW_FREQ_SKIP = 10


def radial_energy_profile(diff):
    """Sum the squared values of a 2-D array over unit-width rings around its centre.

    Bin ``k`` collects every element whose distance ``d`` from the centre
    ``((h-1)/2, (w-1)/2)`` satisfies ``k < d <= k+1``. An element exactly at the
    centre (only possible for odd sizes) is put in bin 0.

    Args:
        diff: 2-D array-like (numpy array or CPU tensor).

    Returns:
        1-D float64 array of length ``max(1, ceil(corner distance))``.
    """
    energy = np.square(np.asarray(diff, dtype=np.float64))
    h, w = energy.shape
    cy, cx = (h - 1) / 2, (w - 1) / 2
    n_bins = int(np.ceil(np.sqrt(cy ** 2 + cx ** 2)))
    rows, cols = np.indices((h, w))
    radii = np.ceil(np.sqrt((rows - cy) ** 2 + (cols - cx) ** 2)).astype(int)
    # Indexing with an unclamped `radius - 1` would send a centre element
    # (radius 0) to index -1, i.e. the highest-radius bin; clamp it to bin 0.
    bins = np.maximum(radii - 1, 0)
    return np.bincount(bins.ravel(), weights=energy.ravel(), minlength=n_bins)


def _heatmap_panel(ax, data, title, fontsize=20):
    """Draw `data` on `ax` as a viridis heatmap without colour bar, hide the axes, and set the title."""
    sns.heatmap(np.asarray(data), cmap='viridis', cbar=False, ax=ax)
    ax.axis('off')
    ax.set_title(title, fontsize=fontsize)


fig = plt.figure(figsize=(50, length * 10))
ffl = FFL(loss_weight=1.0, alpha=1.0)
original_image_fft = 0
orignal_image_tensor = 0  # (sic) name kept in case later cells reference it
for index, file_name in enumerate(file_names):
    image, image_fft, image_tensor = plot_fft(file_name)
    if index == 0:
        # The first file is the reference that every image is compared to.
        original_image_fft = image_fft
        orignal_image_tensor = image_tensor.clone().unsqueeze(0)
    if tuple(image_fft.shape) != tuple(original_image_fft.shape):
        raise ValueError(
            f'{file_name}: FFT shape {tuple(image_fft.shape)} does not match '
            f'reference shape {tuple(original_image_fft.shape)}'
        )

    # One axes per panel of this row, created in the same order as before (1..column).
    axes = [fig.add_subplot(length, column, index * column + k) for k in range(1, column + 1)]

    axes[0].imshow(image)
    axes[0].set_title(file_name, fontsize=40)
    axes[0].axis('off')

    fft_diff = np.asarray(image_fft - original_image_fft)
    _heatmap_panel(axes[1], np.log1p(np.asarray(original_image_fft)), 'log view of origin fft')
    _heatmap_panel(axes[2], np.log1p(np.asarray(image_fft)), 'log view of backdoored fft')

    # FFL expects batched input, hence the leading batch dimension.
    fflloss = ffl(image_tensor.unsqueeze(0), orignal_image_tensor)
    _heatmap_panel(axes[3], np.abs(fft_diff), f'FFL: {float(fflloss):.8f}', fontsize=40)
    _heatmap_panel(axes[4], np.log1p(np.abs(fft_diff)), 'diff log view')
    _heatmap_panel(axes[5], np.log1p(np.clip(fft_diff, 0, None)), 'Positive diff')
    _heatmap_panel(axes[6], np.log1p(np.clip(-fft_diff, 0, None)), 'Negative diff')

    fft_distribution = radial_energy_profile(fft_diff)
    h, w = fft_diff.shape
    # Radius of the largest circle fully inside the spectrum; bins beyond it
    # only cover the corners.
    inscribed = int(np.floor(min((h - 1) / 2, (w - 1) / 2)))
    axes[7].plot(fft_distribution)
    axes[7].set_title('radial squared diff', fontsize=20)
    axes[8].plot(fft_distribution[:inscribed])
    axes[8].set_title('radial squared diff (inscribed)', fontsize=20)
    axes[9].plot(fft_distribution[LOW_FREQ_SKIP:inscribed])
    axes[9].set_title(f'radial squared diff (bins {LOW_FREQ_SKIP}+)', fontsize=20)
fig.savefig('freqcompare_all.png')